import csv
from sklearn.utils import Bunch
# 读取西瓜数据集
def readWatermelonDataSet(path=r"E:\My Word\study\RL0314\data.csv"):
    """Load the watermelon data set from a CSV file into a ``Bunch``.

    The first non-blank row is treated as the header. In every row the first
    column is skipped (presumably a sample-id column -- confirm against the
    data file) and the last column is taken as the label; everything in
    between is a feature.

    Args:
        path: Location of the CSV file. Defaults to the path originally
            hard-coded in this script.

    Returns:
        A ``Bunch`` with ``data`` (list of feature rows, values kept as the
        strings read from the file), ``target`` (list of label strings) and
        ``feature_names`` (list of feature column names).

    Raises:
        OSError: If the file cannot be opened.
    """
    feature_names = []
    feature_rows = []
    labels = []

    # The context manager guarantees the file is closed even if parsing
    # fails; newline="" is what the csv module expects so that it can handle
    # line endings (including ones embedded in quoted fields) itself.
    with open(path, "r", newline="") as csv_file:
        header_seen = False
        for row in csv.reader(csv_file):
            if not row:
                # Blank lines carry no data and would otherwise break the
                # label lookup below.
                continue
            if not header_seen:
                feature_names = row[1:-1]
                header_seen = True
            else:
                feature_rows.append(row[1:-1])
                labels.append(row[-1])

    print(feature_names)
    print(feature_rows)
    print(labels)
    return Bunch(
        data=feature_rows,
        target=labels,
        feature_names=feature_names,
    )
注意:如果想直接使用 sklearn 的后续算法,数据集中的特征必须是数值型数据;如果把西瓜数据集中其他非数值(类别型)的列直接加入,后续训练会报错,因此需要先做好数据预处理(例如先将类别特征编码为数值)。
这里使用的数据集是这样的:
完整决策树生成代码:
import csv
from sklearn.utils import Bunch
from sklearn import tree
from sklearn.model_selection import train_test_split
import pandas as pd
import graphviz
import os
# 读取西瓜数据集
def readWatermelonDataSet(path=r"E:\My Word\study\RL0314\data.csv"):
    """Load the watermelon data set from a CSV file into a ``Bunch``.

    The first non-blank row is treated as the header. In every row the first
    column is skipped (presumably a sample-id column -- confirm against the
    data file) and the last column is taken as the label; everything in
    between is a feature.

    Args:
        path: Location of the CSV file. Defaults to the path originally
            hard-coded in this script.

    Returns:
        A ``Bunch`` with ``data`` (list of feature rows, values kept as the
        strings read from the file), ``target`` (list of label strings) and
        ``feature_names`` (list of feature column names).

    Raises:
        OSError: If the file cannot be opened.
    """
    feature_names = []
    feature_rows = []
    labels = []

    # The context manager guarantees the file is closed even if parsing
    # fails; newline="" is what the csv module expects so that it can handle
    # line endings (including ones embedded in quoted fields) itself.
    with open(path, "r", newline="") as csv_file:
        header_seen = False
        for row in csv.reader(csv_file):
            if not row:
                # Blank lines carry no data and would otherwise break the
                # label lookup below.
                continue
            if not header_seen:
                feature_names = row[1:-1]
                header_seen = True
            else:
                feature_rows.append(row[1:-1])
                labels.append(row[-1])

    print(feature_names)
    print(feature_rows)
    print(labels)
    return Bunch(
        data=feature_rows,
        target=labels,
        feature_names=feature_names,
    )
def main():
    """Train an entropy-based decision tree on the watermelon data and render it.

    Loads the data set via ``readWatermelonDataSet``, holds out 30% of the
    samples for testing, fits a ``DecisionTreeClassifier`` with the entropy
    criterion, prints the accuracy on the held-out samples, and renders the
    fitted tree with Graphviz to a file named ``watermelon1`` (opened in a
    viewer afterwards).
    """
    watermelon = readWatermelonDataSet()

    # 70% of the samples for training, 30% held out for testing. The split is
    # not seeded, so the score and the tree differ from run to run.
    x_train, x_test, y_train, y_test = train_test_split(
        watermelon.data, watermelon.target, test_size=0.3
    )

    # Build the model: a classification tree that splits on information gain.
    clf = tree.DecisionTreeClassifier(criterion="entropy")
    clf = clf.fit(x_train, y_train)

    # The accuracy used to be a bare expression, which is silently discarded
    # when this runs as a script; report it explicitly instead.
    score = clf.score(x_test, y_test)
    print("Test accuracy:", score)

    # NOTE(review): export_graphviz assigns class_names in ascending order of
    # the fitted class labels (clf.classes_) -- confirm that this order
    # matches the label encoding used in the CSV.
    dot_data = tree.export_graphviz(
        clf,
        feature_names=watermelon.feature_names,
        class_names=["好瓜", "坏瓜"],
        filled=True,
        rounded=True,
        special_characters=True,
        fontname="Microsoft YaHei",
    )

    # Make the Graphviz executables discoverable before rendering.
    os.environ["PATH"] += os.pathsep + 'D:/DiyProgram/graphviz/bin/'
    graph = graphviz.Source(dot_data)
    graph.render("watermelon1", view=True)


if __name__ == "__main__":
    main()
运行结果:
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)