机器学习(线性回归模型预测)

机器学习(线性回归模型预测),第1张

机器学习(线性回归模型预测)
#线性回归模型优缺点:
    #优点:快速;没有调节参数;可轻易解释;容易理解
    #缺点:相比较于其他复杂模型,其准确率不高,因为它假设特征和响应之间存在确定的线性关系,这种假设对于非线性的关系,不能得到合适的解决方法
#实例
import pandas as pd
data = pd.read_csv('Advertising.csv')
#使用pandas构建X(特征向量)和 y(标签列)
#创建特征列表
feature_cols = ['TV','radio','newspaper']
X  =  data[feature_cols]
# X = data[['TV','radio','newspaper']]
# print(X.head())
# print(type(X))          #
# print(X.shape)         #(200, 3)

y = data['sales']
# y =  data.sales
# print(y.head())
#构建训练集和测试集
from sklearn.cross_validation import train_test_split   #引用交叉验证
#75%用于训练 ,25%用于测试
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=1)
# print(X_train.shape)
# print(y_train.shape)
# print(X_test.shape)
# print(y_test.shape)
# print(type(y_test))       
#sklearn的线性回归
from sklearn.linear_model import LinearRegression
linreg  = LinearRegression()
model = linreg.fit(X_train,y_train)
print(model)    #LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
print(linreg.intercept_)
print(linreg.coef_)
#输出变量的回归系数
#将特征名称与系数对应
res = zip(feature_cols,linreg.coef_)
print(list(res))    #[('TV', 0.04656456787415029), ('radio', 0.17915812245088839), ('newspaper', 0.003450464711180378)]
#即线性回归结果为:y=0.0465*TV+0.1791*radio+0.0034*newspaper+2.877
#预测
y_pred = linreg.predict(X_test)
print(y_pred)
print(type(y_pred))

#评价测度(分类问题的评价是准确率,回归问题则是连续数值的评价测度)
    #平均绝对误差(Mean Absolute Error)-MAE
    #均方误差(Mean Squared Error)-MSE
    #均方根误差(Root Mean Squared Error)-RMSE
#计算sales预测的RMSE
print(type(y_pred),type(y_test))            # 
print(len(y_pred),len(y_test))                #50 50
print(y_pred.shape,y_test.shape)         #(50,) (50,)
from sklearn import metrics
import numpy as np
sum_mean = 0
for i in range(len(y_pred)):
    sum_mean+=(y_pred[i]-y_test.values[i])**2
sum_err = np.sqrt(sum_mean/50)
print("RMSE:",sum_err)                  #RMSE: 1.404651423032895
#绘制ROC曲线
import matplotlib.pyplot as plt
plt.figure()
plt.plot(range(len(y_pred)),y_pred,'b',label='predict')
plt.plot(range(len(y_pred)),y_test,'r',label='test')
plt.legend(loc='upper right')
plt.xlabel('the nmber of sales')
plt.ylabel('value of sales')
plt.show()

曲线结果:

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5491489.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-12
下一篇 2022-12-12

发表评论

登录后才能评论

评论列表(0条)

保存