Decision Tree Code: A scikit-learn Example


Data preparation

Prepare the following training data:

    Day   Outlook Temperature Humidity    Wind Play Golf
0    D1     Sunny         Hot     High    Weak        No
1    D2     Sunny         Hot     High  Strong        No
2    D3  Overcast         Hot     High    Weak       Yes
3    D4      Rain        Mild     High    Weak       Yes
4    D5      Rain        Cool   Normal    Weak       Yes
5    D6      Rain        Cool   Normal  Strong        No
6    D7  Overcast        Cool   Normal  Strong       Yes
7    D8     Sunny        Mild     High    Weak        No
8    D9     Sunny        Cool   Normal    Weak       Yes
9   D10      Rain        Mild   Normal    Weak       Yes
10  D11     Sunny        Mild   Normal  Strong       Yes
11  D12  Overcast        Mild     High  Strong       Yes
12  D13  Overcast         Hot   Normal    Weak       Yes
13  D14      Rain        Mild     High  Strong        No

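Save it as a CSV file named data1.csv, keeping the header row. The leftmost 0 to 13 column is only a row index, not data, so leave it out.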
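If you would rather generate the file than type it, here is a minimal sketch that writes the table with Python's csv module. It assumes the column layout the code reads later (Day first, the label last, no index column); write_golf_csv is a hypothetical helper name.

```python
import csv

# Header and rows copied from the table above. The 0..13 row index is
# deliberately omitted (assumption: the CSV holds only the data columns).
HEADER = ["Day", "Outlook", "Temperature", "Humidity", "Wind", "Play Golf"]
ROWS = [
    ["D1", "Sunny", "Hot", "High", "Weak", "No"],
    ["D2", "Sunny", "Hot", "High", "Strong", "No"],
    ["D3", "Overcast", "Hot", "High", "Weak", "Yes"],
    ["D4", "Rain", "Mild", "High", "Weak", "Yes"],
    ["D5", "Rain", "Cool", "Normal", "Weak", "Yes"],
    ["D6", "Rain", "Cool", "Normal", "Strong", "No"],
    ["D7", "Overcast", "Cool", "Normal", "Strong", "Yes"],
    ["D8", "Sunny", "Mild", "High", "Weak", "No"],
    ["D9", "Sunny", "Cool", "Normal", "Weak", "Yes"],
    ["D10", "Rain", "Mild", "Normal", "Weak", "Yes"],
    ["D11", "Sunny", "Mild", "Normal", "Strong", "Yes"],
    ["D12", "Overcast", "Mild", "High", "Strong", "Yes"],
    ["D13", "Overcast", "Hot", "Normal", "Weak", "Yes"],
    ["D14", "Rain", "Mild", "High", "Strong", "No"],
]


def write_golf_csv(path="data1.csv"):
    """Write the header plus the 14 data rows as a plain CSV file."""
    # newline="" lets the csv module choose the line endings itself
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(HEADER)
        writer.writerows(ROWS)
```

Calling write_golf_csv() once produces the data1.csv that the reading code opens.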

Reading the data

import csv
from sklearn import preprocessing, tree
from sklearn.feature_extraction import DictVectorizer

# Open the CSV and consume the header row first
data = open("data1.csv", "r", newline="")
reader = csv.reader(data)
headers = next(reader)
print(headers)

Output:
['Day', 'Outlook', 'Temperature', 'Humidity', 'Wind', 'Play Golf']

Building the feature dictionaries

featureList = []  # one {attribute: value} dict per day
labelList = []    # the "Play Golf" answer per day
for row in reader:
    labelList.append(row[-1])  # the last column is the label
    rowDict = {}
    # Skip column 0 (Day, only an ID) and the last column (the label)
    for i in range(1, len(row) - 1):
        rowDict[headers[i]] = row[i]
    featureList.append(rowDict)
data.close()
print(featureList)

Output:
[{'Outlook': 'Sunny', 'Temperature': 'Hot', 'Humidity': 'High', 'Wind': 'Weak'}, {'Outlook': 'Sunny', 'Temperature': 'Hot', 'Humidity': 'High', 'Wind': 'Strong'}, {'Outlook': 'Overcast', 'Temperature': 'Hot', 'Humidity': 'High', 'Wind': 'Weak'}, {'Outlook': 'Rain', 'Temperature': 'Mild', 'Humidity': 'High', 'Wind': 'Weak'}, {'Outlook': 'Rain', 'Temperature': 'Cool', 'Humidity': 'Normal', 'Wind': 'Weak'}, {'Outlook': 'Rain', 'Temperature': 'Cool', 'Humidity': 'Normal', 'Wind': 'Strong'}, {'Outlook': 'Overcast', 'Temperature': 'Cool', 'Humidity': 'Normal', 'Wind': 'Strong'}, {'Outlook': 'Sunny', 'Temperature': 'Mild', 'Humidity': 'High', 'Wind': 'Weak'}, {'Outlook': 'Sunny', 'Temperature': 'Cool', 'Humidity': 'Normal', 'Wind': 'Weak'}, {'Outlook': 'Rain', 'Temperature': 'Mild', 'Humidity': 'Normal', 'Wind': 'Weak'}, {'Outlook': 'Sunny', 'Temperature': 'Mild', 'Humidity': 'Normal', 'Wind': 'Strong'}, {'Outlook': 'Overcast', 'Temperature': 'Mild', 'Humidity': 'High', 'Wind': 'Strong'}, {'Outlook': 'Overcast', 'Temperature': 'Hot', 'Humidity': 'Normal', 'Wind': 'Weak'}, {'Outlook': 'Rain', 'Temperature': 'Mild', 'Humidity': 'High', 'Wind': 'Strong'}]
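The index-based loop works, but csv.DictReader already yields one dict per row keyed by the header, so building the feature dicts reduces to removing the label and the Day ID. A sketch under the same file layout; a few rows are inlined through io.StringIO so it runs on its own, and load_features is a hypothetical helper name.

```python
import csv
import io

# First rows of data1.csv, inlined so the sketch is self-contained
# (assumption: same header as the file read above, no index column).
SAMPLE = """Day,Outlook,Temperature,Humidity,Wind,Play Golf
D1,Sunny,Hot,High,Weak,No
D2,Sunny,Hot,High,Strong,No
D3,Overcast,Hot,High,Weak,Yes
"""


def load_features(f, label_col="Play Golf", drop=("Day",)):
    """Split each CSV record into a feature dict and its label."""
    features, labels = [], []
    for record in csv.DictReader(f):
        labels.append(record.pop(label_col))  # remove the label column
        for col in drop:
            record.pop(col, None)  # remove ID-like columns
        features.append(record)
    return features, labels


featureList, labelList = load_features(io.StringIO(SAMPLE))
print(featureList[0])
print(labelList)
```

With the real file the call becomes `with open("data1.csv", newline="") as f: featureList, labelList = load_features(f)`.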

Encoding the features (one-hot)

# DictVectorizer turns each dict into a 0/1 vector with one column per "attribute=value"
vec = DictVectorizer()
X = vec.fit_transform(featureList).toarray()  # dense array instead of a sparse matrix
# get_feature_names() was removed in scikit-learn 1.2; get_feature_names_out() replaces it
print(list(vec.get_feature_names_out()))
print(X)

Output:
['Humidity=High', 'Humidity=Normal', 'Outlook=Overcast', 'Outlook=Rain', 'Outlook=Sunny', 'Temperature=Cool', 'Temperature=Hot', 'Temperature=Mild', 'Wind=Strong', 'Wind=Weak']
[[ 1.  0.  0.  0.  1.  0.  1.  0.  0.  1.]
 [ 1.  0.  0.  0.  1.  0.  1.  0.  1.  0.]
 [ 1.  0.  1.  0.  0.  0.  1.  0.  0.  1.]
 [ 1.  0.  0.  1.  0.  0.  0.  1.  0.  1.]
 [ 0.  1.  0.  1.  0.  1.  0.  0.  0.  1.]
 [ 0.  1.  0.  1.  0.  1.  0.  0.  1.  0.]
 [ 0.  1.  1.  0.  0.  1.  0.  0.  1.  0.]
 [ 1.  0.  0.  0.  1.  0.  0.  1.  0.  1.]
 [ 0.  1.  0.  0.  1.  1.  0.  0.  0.  1.]
 [ 0.  1.  0.  1.  0.  0.  0.  1.  0.  1.]
 [ 0.  1.  0.  0.  1.  0.  0.  1.  1.  0.]
 [ 1.  0.  1.  0.  0.  0.  0.  1.  1.  0.]
 [ 0.  1.  1.  0.  0.  0.  1.  0.  0.  1.]
 [ 1.  0.  0.  1.  0.  0.  0.  1.  1.  0.]]
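To see what DictVectorizer does with string values, here is a pure-Python sketch of the same one-hot encoding: one column per "attribute=value" pair, with the column names sorted alphabetically (DictVectorizer's default). This is an illustration of the idea, not scikit-learn's implementation, and one_hot is a hypothetical name.

```python
# Feature columns of the golf table (Day and the label are excluded).
COLUMNS = ["Outlook", "Temperature", "Humidity", "Wind"]
TABLE = [
    ("Sunny", "Hot", "High", "Weak"),
    ("Sunny", "Hot", "High", "Strong"),
    ("Overcast", "Hot", "High", "Weak"),
    ("Rain", "Mild", "High", "Weak"),
    ("Rain", "Cool", "Normal", "Weak"),
    ("Rain", "Cool", "Normal", "Strong"),
    ("Overcast", "Cool", "Normal", "Strong"),
    ("Sunny", "Mild", "High", "Weak"),
    ("Sunny", "Cool", "Normal", "Weak"),
    ("Rain", "Mild", "Normal", "Weak"),
    ("Sunny", "Mild", "Normal", "Strong"),
    ("Overcast", "Mild", "High", "Strong"),
    ("Overcast", "Hot", "Normal", "Weak"),
    ("Rain", "Mild", "High", "Strong"),
]


def one_hot(records):
    """One-hot encode string-valued dicts, one column per 'key=value'."""
    # Vocabulary: every key=value pair seen in the data, sorted alphabetically
    names = sorted({f"{k}={v}" for r in records for k, v in r.items()})
    index = {name: j for j, name in enumerate(names)}
    rows = []
    for r in records:
        row = [0.0] * len(names)
        for k, v in r.items():
            row[index[f"{k}={v}"]] = 1.0  # exactly one 1.0 per attribute
        rows.append(row)
    return names, rows


records = [dict(zip(COLUMNS, values)) for values in TABLE]
names, rows = one_hot(records)
print(names)
print(rows[0])
```

Each row has exactly four ones (one per attribute), and the sorted names reproduce the column order printed above, which is why X[0] starts with Humidity=High.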

Encoding the labels

print(labelList)
# Map the two classes to 0/1; classes are sorted, so 'No' -> 0 and 'Yes' -> 1
lb = preprocessing.LabelBinarizer()
Y = lb.fit_transform(labelList)
print(Y)

Output:
['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No']
[[0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]]
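For two classes, LabelBinarizer sorts the class names (available afterwards as lb.classes_) and marks the second one as 1, returning a single column. The pure-Python sketch below reproduces that two-class behaviour and also maps 0/1 codes back to names; binarize and unbinarize are hypothetical helpers, and with three or more classes LabelBinarizer instead produces one column per class, which the sketch does not attempt.

```python
def binarize(labels):
    """Two-class sketch of LabelBinarizer: sorted classes, second one -> 1."""
    classes = sorted(set(labels))
    if len(classes) != 2:
        raise ValueError("this sketch only handles exactly two classes")
    positive = classes[1]
    # One single-element row per sample, matching the (n_samples, 1) shape above
    column = [[1 if y == positive else 0] for y in labels]
    return classes, column


def unbinarize(classes, codes):
    """Map 0/1 codes (e.g. a model's predictions) back to class names."""
    return [classes[int(c)] for c in codes]


labelList = ["No", "No", "Yes", "Yes", "Yes", "No", "Yes",
             "No", "Yes", "Yes", "Yes", "Yes", "Yes", "No"]
classes, Y = binarize(labelList)
print(classes)
print(unbinarize(classes, [1]))
```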

Training the model

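# criterion="entropy" scores candidate splits by information gain
clf = tree.DecisionTreeClassifier(criterion="entropy")
clf = clf.fit(X, Y)
print(clf)

Output (from an older scikit-learn release; since 0.23 only non-default parameters are printed, so the same model shows as DecisionTreeClassifier(criterion='entropy')):
DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')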
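With criterion="entropy" the tree scores each candidate split by information gain: the entropy of the labels at the node minus the size-weighted entropy of the two child groups. Because X is one-hot, every split asks "is this attribute equal to this value?". The sketch below computes that quantity by hand for two candidate columns at the root. It is a worked illustration of the formula, not scikit-learn's internal code, and the helper names are hypothetical.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(labels, mask):
    """Entropy reduction from splitting labels into mask-True / mask-False groups."""
    left = [y for y, m in zip(labels, mask) if m]
    right = [y for y, m in zip(labels, mask) if not m]
    n = len(labels)
    children = sum(len(g) / n * entropy(g) for g in (left, right) if g)
    return entropy(labels) - children


# Columns copied from the table above, in D1..D14 order
PLAY = ["No", "No", "Yes", "Yes", "Yes", "No", "Yes",
        "No", "Yes", "Yes", "Yes", "Yes", "Yes", "No"]
OUTLOOK = ["Sunny", "Sunny", "Overcast", "Rain", "Rain", "Rain", "Overcast",
           "Sunny", "Sunny", "Rain", "Sunny", "Overcast", "Overcast", "Rain"]
HUMIDITY = ["High", "High", "High", "High", "Normal", "Normal", "Normal",
            "High", "Normal", "Normal", "Normal", "High", "Normal", "High"]

print(round(entropy(PLAY), 3))  # 0.94
print(round(information_gain(PLAY, [o == "Overcast" for o in OUTLOOK]), 3))  # 0.226
print(round(information_gain(PLAY, [h == "High" for h in HUMIDITY]), 3))  # 0.152
```

Splitting on Outlook=Overcast isolates four pure "Yes" days, so it removes more uncertainty than Humidity=High in this comparison.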

Writing the model to a file

# export_graphviz writes into out_file and returns None, so don't reassign f
with open("tree.dot", "w") as f:
    tree.export_graphviz(clf, feature_names=vec.get_feature_names_out(), out_file=f)

This produces the model file tree.dot.

Viewing the model

Use the Graphviz dot tool to convert the .dot file into a PDF:

dot -Tpdf tree.dot -o tree.pdf

[Figure: the resulting decision tree, as rendered in tree.pdf]

Creating a new sample

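Take the first row of the training data, change it slightly (Humidity from High to Normal), and use the trained model to predict its label.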

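oneRowX = X[0, :]
print(oneRowX)
# copy() matters: X[0, :] is a view, so editing it in place would also overwrite X
newRowX = oneRowX.copy()
newRowX[0] = 0  # Humidity=High   -> off
newRowX[1] = 1  # Humidity=Normal -> on
print(newRowX)

Output:
[ 1.  0.  0.  0.  1.  0.  1.  0.  0.  1.]
[ 0.  1.  0.  0.  1.  0.  1.  0.  0.  1.]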
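Setting newRowX[0] and newRowX[1] by position only works if you remember which column is which. A sketch of building the same row from a readable dict, assuming the column order printed by the vectorizer earlier; encode is a hypothetical helper that mimics what vec.transform([new_record]).toarray() does for string values, including silently ignoring attribute=value pairs never seen during fitting.

```python
# Column order printed by the fitted DictVectorizer above
FEATURE_NAMES = ["Humidity=High", "Humidity=Normal", "Outlook=Overcast",
                 "Outlook=Rain", "Outlook=Sunny", "Temperature=Cool",
                 "Temperature=Hot", "Temperature=Mild", "Wind=Strong",
                 "Wind=Weak"]


def encode(record, feature_names=FEATURE_NAMES):
    """One-hot encode a dict against a fixed, already-known column order."""
    row = [0.0] * len(feature_names)
    for key, value in record.items():
        name = f"{key}={value}"
        if name in feature_names:  # unseen pairs are skipped, not errors
            row[feature_names.index(name)] = 1.0
    return row


# Day D1 with Humidity changed from High to Normal
new_record = {"Outlook": "Sunny", "Temperature": "Hot",
              "Humidity": "Normal", "Wind": "Weak"}
print(encode(new_record))
```

The printed row matches newRowX, but the intent (which attribute changed) is visible in the dict rather than hidden in column indices.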

Predicting

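# predict expects a 2-D array of shape (n_samples, n_features), hence reshape(1, -1)
predictedY = clf.predict(newRowX.reshape(1, -1))
print(predictedY)  # 1 encodes 'Yes' (lb.classes_ is ['No', 'Yes'])

Output:
[1]

So the model predicts "Yes" (golf is played) for the modified day.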