當前位置:網站首頁>【MATLAB】機器學習: 線性回歸實驗(梯度下降+閉式解)

【MATLAB】機器學習: 線性回歸實驗(梯度下降+閉式解)

2022-01-28 00:35:43 Orange_Jet

實驗內容

1.根據梯度下降法完成一元線性回歸實驗。
2.根據閉式解完成一元線性回歸實驗。
3.比較兩種解下的實驗結果。

實驗代碼

clear;clc;
%% 數據導入;劃分訓練集和測試集
% 數據導入
data=load("ex1data1.txt");
X=data(:,1);
Y=data(:,2);
% 劃分訓練集和測試集
ind=[];Train_ind=[];Test_ind=[];    % 隨機數索引;訓練集索引;測試集索引
ind=randperm(length(X));
Train_ind=ind(1:round(length(ind)*(2/3)));  % 訓練集樣本索引
Test_ind=ind(round(length(ind)*(2/3))+1:end);   % 測試集樣本索引
Xtrain=X(Train_ind,:);     % X中的訓練樣本
Xtest=X(Test_ind,:);       % X中的測試樣本
Ytrain=Y(Train_ind,:);     % Y中的訓練樣本
Ytest=Y(Test_ind,:);       % Y中的測試樣本
%% 梯度下降求解一元線性回歸
m=length(Xtrain);
X_train1=[ones(m,1),Xtrain];
theta=zeros(2,1);
iterations=1500;    % 設置迭代次數為1500
alpha=0.01;         % 設置梯度下降步長為0.01
[theta_1, J_history] = gradientDescent(X_train1, Ytrain, theta, alpha, iterations);
%% 閉式解求解一元線性回歸
[w,b] = closed_formSolution(Xtrain,Ytrain);
theta_2=[b;w];
%% 繪制兩種方法的圖像
scatter(Xtrain,Ytrain); % 繪制真實數據的散點圖
hold on;
x=0:1:25;
y1=theta_1(2,1).*x+theta_1(1,1);    % 梯度下降求解的y1函數
y2=theta_2(2,1).*x+theta_2(1,1);    % 閉式解求解的y2函數
plot(x,y1,'-xb',x,y2,':.r');    % 繪制y1和y2的圖像
xlabel('x');
ylabel('y');
legend('真實數據點:(x,y)','梯度下降求解:y=1.1816x-3.5597','閉式解求解:y=1.2114x-3.8586','best');
%% 使用均方誤差衡量兩種方法的優劣
Ytest_predict1=theta_1(2,1).*Xtest+theta_1(1,1);    % 使用梯度下降求解得到的測試集的預測值
Ytest_predict2=theta_2(2,1).*Xtest+theta_2(1,1);    % 使用閉式解求解得到的測試集的預測值
mse=zeros(2,1);
mse(1,1)=sum((Ytest_predict1-Ytest).^2,1);  % 梯度下降的均方誤差
mse(2,1)=sum((Ytest_predict2-Ytest).^2,1);  % 閉式解的均方誤差



%% **********************函數一(梯度下降)*********************
%定義一個實現梯度下降的函數
function [theta, J_history] = gradientDescent(X, y, theta, alpha, num_iters)
m = length(y);%取數據的長度
J_history = zeros(num_iters, 1);%定義J_history為num_iters行1列的向量
%其中,num_iters是在調用該梯度下降函數時的參數
for iter = 1:num_iters
    S = (1 / m) * (X' * (X * theta - y));%相當於求導
    theta = theta - alpha .* S;%theta的更新
    J_history(iter) = computeCost(X, y, theta);%在這裏調用computeCost(X, y, theta)
 %相當於在慢慢的减小代價函數
end
end

function J = computeCost(X, y, theta)
% 函數功能:求代價函數
m = length(y); % y的數據量
J = 0;
%h(x)
h = X*theta;
loss = (h-y).^2;
J = 1/(2*m)*(X*theta-y)'*(X*theta-y);
end

%% **********************函數二(閉式解)*********************
function [w,b] = closed_formSolution(X,Y)
% 函數功能:閉式解求出一元線性回歸函數Y=ωX+b中的ω和b
% 輸入參數:X錶示自變量;Y錶示因變量
% 函數返回值:w錶示自變量X的系數;b錶示常量

m=length(X);    % 取X的長度,

X_mean=mean(X,1);
tmp=repmat(X_mean,m,1);
w_top=sum(Y.*(X-tmp),1);
w_bottom=sum(X.*X,1)-(sum(X,1))^2/m;
w=w_top/w_bottom;

b=sum(Y-w*X,1)/m;
end

實驗結果

實驗心得

通過本次“線性回歸實驗”,加深了我對梯度下降和閉式解兩種方法的理解,並且能够熟練使用。在回歸任務的度量中,可以使用均方誤差。繪制散點圖時,可以使用scatter函數。

版權聲明
本文為[Orange_Jet]所創,轉載請帶上原文鏈接,感謝
https://cht.chowdera.com/2022/01/202201280035434077.html

隨機推薦