Dataset link:
Baidu Netdisk: https://pan.baidu.com/s/1_pQ-3iG4dr0hrvU_5hYUtg
Extraction code: 9520
The environment we will use is Jupyter Notebook. The reason for choosing it is that it lets you run a program cell by cell.
Beginners easily make mistakes in their code. If you run everything in PyCharm, each time an error occurs you have to re-run the whole script, which wastes a lot of time on a large dataset. Jupyter runs interactively and cell by cell, which makes debugging much friendlier.
How do you start Jupyter? See the screenshot below: when you install Anaconda3, Jupyter Notebook is installed by default, and you just click it to launch.
A web page will then open (shown below). The directory it lists is a path under the user folder on the C drive; some folder names may be shown in English, but that does not affect how you work. Click New and choose Python 3.
Once inside you will see the interface below, and you can start running code.
Shortcut to run a cell: Shift+Enter.
That covers how to use Jupyter, so let's use it to run our code. The code below is saved cell by cell, so you can run it cell by cell the same way.
import pandas as pd
from sklearn.feature_extraction import DictVectorizer  # feature extraction from dicts: string features become one-hot columns; sparse=False returns a dense array instead of a sparse matrix
from sklearn.ensemble import RandomForestClassifier  # random forest
from sklearn.tree import DecisionTreeClassifier  # decision tree
from sklearn.model_selection import GridSearchCV  # grid search (for hyperparameter tuning)
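If you are unsure what DictVectorizer actually produces, here is a minimal sketch on two made-up rows (not part of our dataset); the one-hot expansion we rely on in step (3) works the same way.
toy = [{'Sex': 'male', 'Age': 22.0}, {'Sex': 'female', 'Age': 38.0}]  # toy dicts, not Titanic data
toy_vec = DictVectorizer(sparse=False)
print(toy_vec.fit_transform(toy))   # roughly [[22. 0. 1.] [38. 1. 0.]]
print(toy_vec.feature_names_)       # ['Age', 'Sex=female', 'Sex=male']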
(1) Load the required data
train_data = pd.read_csv('./data/train.csv')
test_data = pd.read_csv('./data/test.csv')
y_test = pd.read_csv('./data/gender_submission.csv')
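Before cleaning, it helps to see which columns actually contain missing values. The quick check below is a sketch; the column names follow the standard Kaggle Titanic files.
print(train_data.shape, test_data.shape)
print(train_data.isnull().sum())  # in the standard files: Age, Cabin and Embarked contain NaN
print(test_data.isnull().sum())   # in the standard files: Age, Fare and Cabin contain NaN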
(2) Data cleaning
# Fill the NaN values in Age with the mean age
train_data['Age'].fillna(train_data['Age'].mean(), inplace=True)
test_data['Age'].fillna(test_data['Age'].mean(), inplace=True)
# Fill the NaN values in Fare with the mean fare
train_data['Fare'].fillna(train_data['Fare'].mean(), inplace=True)
test_data['Fare'].fillna(test_data['Fare'].mean(), inplace=True)
print(train_data['Embarked'].value_counts())
# Fill the NaN values in Embarked with the most common port of embarkation
train_data['Embarked'].fillna('S', inplace=True)
test_data['Embarked'].fillna('S', inplace=True)
S 644
C 168
Q 77
Name: Embarked, dtype: int64
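After filling, you can verify that the three columns we touched no longer contain NaN; this check is optional and not part of the original cell.
print(train_data[['Age', 'Fare', 'Embarked']].isnull().sum())
print(test_data[['Age', 'Fare', 'Embarked']].isnull().sum())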
(3) Feature selection
features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']
x_train = train_data[features]
y_train = train_data['Survived']
x_test = test_data[features]
y_test = y_test['Survived']  # the loaded file has two columns; take the 'Survived' column as the test labels
dvec = DictVectorizer(sparse=False)
x_train = dvec.fit_transform(x_train.to_dict(orient='records'))  # use 'records', not the deprecated short name 'record'
x_test = dvec.transform(x_test.to_dict(orient='records'))
(The original run used orient='record'; pandas then prints a FutureWarning asking for one of 'dict', 'list', 'series', 'split', 'records', 'index', and newer pandas versions no longer accept the short name at all.)
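To see exactly which one-hot columns the vectorizer generated, you can print its recorded feature names (an optional inspection step).
print(dvec.feature_names_)
# Expect something like ['Age', 'Embarked=C', 'Embarked=Q', 'Embarked=S', 'Fare', 'Parch', 'Pclass', 'Sex=female', 'Sex=male', 'SibSp']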
y_test
0 0
1 1
2 0
3 0
4 1
..
413 0
414 1
415 0
416 0
417 0
Name: Survived, Length: 418, dtype: int64
(4) Build the random forest
estimator = RandomForestClassifier()
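DecisionTreeClassifier was imported earlier but is not used in the main flow. If you want a single-tree baseline to compare the forest against, a minimal sketch (untuned, random_state chosen arbitrarily) looks like this:
tree = DecisionTreeClassifier(random_state=0)  # baseline sketch, not part of the original workflow
tree.fit(x_train, y_train)
print('Decision tree test accuracy:', tree.score(x_test, y_test))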
(5) Add grid search and cross-validation
param_dict = {"n_estimators":[120,200,300,500,800,1200],
"max_depth":[5,8,15,25,30]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=3)  # cv=3 means 3-fold cross-validation, so each of the 6 × 5 = 30 parameter combinations is fitted 3 times
estimator.fit(x_train,y_train)
GridSearchCV(cv=3, estimator=RandomForestClassifier(),
param_grid={'max_depth': [5, 8, 15, 25, 30],
'n_estimators': [120, 200, 300, 500, 800, 1200]})
(6) Evaluate the model
# Method 1: compare the true labels with the predictions directly
y_predict = estimator.predict(x_test)
print('y_predict:\n', y_predict)
print('Comparing true labels with predictions:\n', y_test == y_predict)
y_predict:
[0 0 0 0 1 0 1 0 1 0 0 0 1 0 1 1 0 0 0 1 1 1 1 0 1 0 1 0 0 0 0 0 1 0 1 0 0
0 0 1 0 0 0 1 1 0 0 0 1 1 0 0 1 1 0 0 0 0 0 1 0 0 0 1 1 1 1 0 0 1 1 0 0 0
1 0 0 1 0 1 1 0 0 0 0 0 1 0 0 1 0 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0
0 1 1 1 0 0 1 0 1 1 0 1 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0
0 0 1 0 0 1 0 0 1 0 0 1 0 1 1 0 0 0 0 0 1 0 0 0 0 0 0 1 1 1 1 1 0 0 1 0 1
0 1 0 0 0 0 0 1 0 1 0 1 1 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 1 1 0 1 0 1 0 1 0
1 0 1 0 0 0 0 0 0 1 0 0 1 0 1 0 1 1 1 1 0 0 0 0 1 0 1 0 1 0 0 0 0 0 0 0 1
0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1 1 0 1 0 0 0 1 0 0
1 0 0 0 0 0 0 0 0 0 1 0 1 0 1 0 1 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 0 0 1 0
0 1 0 0 1 1 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0 1 0 0 0 1 0 1 0 0 1 0 1 0 0 0 0
0 0 1 0 1 0 0 1 0 0 1]
Comparing true labels with predictions:
0 True
1 False
2 True
3 True
4 True
...
413 True
414 True
415 True
416 True
417 False
Name: Survived, Length: 418, dtype: bool
# Method 2: compute the accuracy
score = estimator.score(x_test, y_test)
print("Accuracy:\n", score)
# Analysis of the results:
print('Best parameters:\n', estimator.best_params_)
# best_score_: the best mean validation score found during cross-validation
print('Best cross-validation score:\n', estimator.best_score_)
# best_estimator_: the estimator refitted with the best parameters
print('Best estimator:\n', estimator.best_estimator_)
# cv_results_: per-fold timing and validation-score details for every parameter combination
print('Cross-validation results:\n', estimator.cv_results_)
Accuracy:
 0.861244019138756
Best parameters:
 {'max_depth': 8, 'n_estimators': 800}
Best cross-validation score:
 0.8260381593714928
Best estimator:
 RandomForestClassifier(max_depth=8, n_estimators=800)
Cross-validation results:
{'mean_fit_time': array([0.14993318, 0.24965994, 0.38132485, 0.61536312, 1.04751142,
1.57679296, 0.15622576, 0.25996598, 0.38662124, 0.64927038,
1.03956922, 1.54551975, 0.18119264, 0.32046803, 0.48171075,
0.72539409, 1.18317087, 1.83011953, 0.1668752 , 0.27692644,
0.41985854, 0.71610451, 1.1459322 , 1.72572255, 0.16987888,
0.30083632, 0.41321762, 0.75131655, 1.13660574, 1.64195848]), 'std_fit_time': array([0.00287001, 0.00737044, 0.00308405, 0.00643914, 0.07266875,
0.10127349, 0.0020616 , 0.00248251, 0.00046787, 0.00295435,
0.00633619, 0.00846855, 0.01510928, 0.01061129, 0.04667695,
0.03663693, 0.01839278, 0.02784425, 0.00329037, 0.0065818 ,
0.00488676, 0.02092371, 0.01521066, 0.07668334, 0.00418458,
0.03867806, 0.00377439, 0.02325702, 0.00683441, 0.01049551]), 'mean_score_time': array([0.0126303 , 0.02026908, 0.03123816, 0.06316447, 0.08343434,
0.11901553, 0.01332045, 0.02326671, 0.03324294, 0.05385939,
0.0877467 , 0.1286722 , 0.01628272, 0.02326957, 0.03890379,
0.06084267, 0.10037716, 0.13929447, 0.0143044 , 0.02393691,
0.0355804 , 0.06049252, 0.09140317, 0.14261834, 0.01329788,
0.02727048, 0.03690338, 0.06350621, 0.10572728, 0.13629254]), 'std_score_time': array([0.00045032, 0.00093432, 0.00093255, 0.01600419, 0.00410913,
0.00248874, 0.00045671, 0.00124138, 0.00124407, 0.00081108,
0.0008147 , 0.00162775, 0.00093646, 0.00170441, 0.0024408 ,
0.00571806, 0.00692931, 0.00384891, 0.00046342, 0.00141029,
0.00045754, 0.00048385, 0.00047451, 0.0029412 , 0.00047064,
0.00576717, 0.00080651, 0.00400502, 0.01742332, 0.00125819]), 'param_max_depth': masked_array(data=[5, 5, 5, 5, 5, 5, 8, 8, 8, 8, 8, 8, 15, 15, 15, 15, 15,
15, 25, 25, 25, 25, 25, 25, 30, 30, 30, 30, 30, 30],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False],
fill_value='?',
dtype=object), 'param_n_estimators': masked_array(data=[120, 200, 300, 500, 800, 1200, 120, 200, 300, 500, 800,
1200, 120, 200, 300, 500, 800, 1200, 120, 200, 300,
500, 800, 1200, 120, 200, 300, 500, 800, 1200],
mask=[False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False,
False, False, False, False, False, False],
fill_value='?',
dtype=object), 'params': [{'max_depth': 5, 'n_estimators': 120}, {'max_depth': 5, 'n_estimators': 200}, {'max_depth': 5, 'n_estimators': 300}, {'max_depth': 5, 'n_estimators': 500}, {'max_depth': 5, 'n_estimators': 800}, {'max_depth': 5, 'n_estimators': 1200}, {'max_depth': 8, 'n_estimators': 120}, {'max_depth': 8, 'n_estimators': 200}, {'max_depth': 8, 'n_estimators': 300}, {'max_depth': 8, 'n_estimators': 500}, {'max_depth': 8, 'n_estimators': 800}, {'max_depth': 8, 'n_estimators': 1200}, {'max_depth': 15, 'n_estimators': 120}, {'max_depth': 15, 'n_estimators': 200}, {'max_depth': 15, 'n_estimators': 300}, {'max_depth': 15, 'n_estimators': 500}, {'max_depth': 15, 'n_estimators': 800}, {'max_depth': 15, 'n_estimators': 1200}, {'max_depth': 25, 'n_estimators': 120}, {'max_depth': 25, 'n_estimators': 200}, {'max_depth': 25, 'n_estimators': 300}, {'max_depth': 25, 'n_estimators': 500}, {'max_depth': 25, 'n_estimators': 800}, {'max_depth': 25, 'n_estimators': 1200}, {'max_depth': 30, 'n_estimators': 120}, {'max_depth': 30, 'n_estimators': 200}, {'max_depth': 30, 'n_estimators': 300}, {'max_depth': 30, 'n_estimators': 500}, {'max_depth': 30, 'n_estimators': 800}, {'max_depth': 30, 'n_estimators': 1200}], 'split0_test_score': array([0.8047138 , 0.8047138 , 0.81481481, 0.81144781, 0.81144781,
0.81818182, 0.80808081, 0.79461279, 0.79124579, 0.7979798 ,
0.8047138 , 0.8013468 , 0.77777778, 0.77777778, 0.77104377,
0.77441077, 0.77104377, 0.77104377, 0.78114478, 0.76767677,
0.77104377, 0.77777778, 0.77777778, 0.76767677, 0.78114478,
0.77104377, 0.76767677, 0.77104377, 0.77104377, 0.77104377]), 'split1_test_score': array([0.83501684, 0.83501684, 0.83501684, 0.82828283, 0.83501684,
0.82491582, 0.84848485, 0.84511785, 0.83838384, 0.83501684,
0.85521886, 0.83838384, 0.81144781, 0.81144781, 0.8047138 ,
0.81818182, 0.82491582, 0.81818182, 0.82154882, 0.82154882,
0.82491582, 0.82491582, 0.82154882, 0.82154882, 0.82491582,
0.82491582, 0.82154882, 0.82154882, 0.82154882, 0.82491582]), 'split2_test_score': array([0.81144781, 0.8047138 , 0.81481481, 0.80808081, 0.80808081,
0.81144781, 0.81481481, 0.81481481, 0.82491582, 0.82154882,
0.81818182, 0.82491582, 0.8047138 , 0.81144781, 0.80808081,
0.81481481, 0.81481481, 0.81144781, 0.7979798 , 0.78451178,
0.79124579, 0.79461279, 0.7979798 , 0.8047138 , 0.8047138 ,
0.8013468 , 0.79461279, 0.79124579, 0.79461279, 0.8013468 ]), 'mean_test_score': array([0.81705948, 0.81481481, 0.82154882, 0.81593715, 0.81818182,
0.81818182, 0.82379349, 0.81818182, 0.81818182, 0.81818182,
0.82603816, 0.82154882, 0.7979798 , 0.80022447, 0.79461279,
0.80246914, 0.80359147, 0.80022447, 0.80022447, 0.79124579,
0.79573513, 0.79910213, 0.79910213, 0.7979798 , 0.80359147,
0.79910213, 0.79461279, 0.79461279, 0.79573513, 0.79910213]), 'std_test_score': array([0.01299196, 0.01428499, 0.00952332, 0.00883727, 0.01198325,
0.00549829, 0.01767454, 0.0207556 , 0.01982438, 0.0153066 ,
0.02135387, 0.0153066 , 0.01454712, 0.01587221, 0.01672241,
0.01988782, 0.02338122, 0.0208162 , 0.01657107, 0.02250274,
0.02222109, 0.01950409, 0.01788707, 0.02250274, 0.01788707,
0.02205037, 0.02199317, 0.0207556 , 0.02063387, 0.02205037]), 'rank_test_score': array([10, 12, 3, 11, 5, 5, 2, 5, 5, 5, 1, 3, 23, 16, 27, 15, 13,
16, 16, 30, 25, 19, 19, 23, 13, 19, 27, 27, 25, 19])}
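The raw cv_results_ dictionary above is hard to read. One way (a sketch) to condense it is to load it into a DataFrame and keep only the parameter and score columns:
cv_df = pd.DataFrame(estimator.cv_results_)  # one row per parameter combination
cols = ['param_max_depth', 'param_n_estimators', 'mean_test_score', 'rank_test_score']
print(cv_df[cols].sort_values('rank_test_score').head(10))  # the ten best combinations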