Applying Decision Trees in sklearn (Python)
Abstract: This article applies decision trees to classification and regression tasks.
00 Install the scikit-learn library
pip install scikit-learn
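To confirm the install, importing the package and printing its version is enough (a minimal check; the version string will vary by environment):
import sklearn
print(sklearn.__version__)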
01 Load the iris dataset from sklearn
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

# Load the 150-sample iris dataset and draw 120 random indices for training
iris = datasets.load_iris()
dex1 = np.random.choice(150, size=120, replace=False)
# The 30 indices not drawn form the test set
dex2 = []
for i in range(150):
    if i not in dex1:
        dex2.append(i)
train_x = iris.data[dex1, :]
train_y = iris.target[dex1]
test_x = iris.data[dex2, :]
test_y = iris.target[dex2]
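As a side note, sklearn ships a one-line utility that performs the same random split; a minimal sketch (test_size=30 reproduces the 120/30 split above; random_state=0 is an arbitrary seed):
from sklearn.model_selection import train_test_split
train_x, test_x, train_y, test_y = train_test_split(
    iris.data, iris.target, test_size=30, random_state=0)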
02 Classification tree
classi=DecisionTreeClassifier()
classi.fit(train_x,train_y)
classi.score(test_x,test_y)
Out[112]: 0.9333333333333333
classi.predict(test_x)
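For a classifier, score() is mean accuracy, so the same number can be recomputed from the predictions; a minimal sketch with sklearn.metrics:
from sklearn.metrics import accuracy_score
pred = classi.predict(test_x)
print(accuracy_score(test_y, pred))  # same value as classi.score(test_x, test_y)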
classi=DecisionTreeClassifier(criterion='gini')
classi.fit(train_x,train_y)
classi.score(test_x,test_y)
Out[135]: 0.9
classi=DecisionTreeClassifier(criterion='entropy')
classi.fit(train_x,train_y)
classi.score(test_x,test_y)
Out[136]: 0.9333333333333333
classi=DecisionTreeClassifier(splitter='best')
classi.fit(train_x,train_y)
classi.score(test_x,test_y)
Out[139]: 0.9666666666666667
classi=DecisionTreeClassifier(splitter='random')
classi.fit(train_x,train_y)
classi.score(test_x,test_y)
Out[142]: 0.8666666666666667
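Note that criterion='gini' and splitter='best' are already the defaults, so several of the fits above use identical settings; the scores still differ because ties between equally good splits are broken at random when random_state is unset. A minimal sketch for reproducible runs (the seed 0 is an arbitrary choice):
classi = DecisionTreeClassifier(criterion='gini', splitter='best', random_state=0)
classi.fit(train_x, train_y)
print(classi.score(test_x, test_y))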
# Scan max_depth from 1 to 19 and record the test accuracy at each depth
scor = []
for i in range(1, 20):
    classi = DecisionTreeClassifier(max_depth=i)
    classi.fit(train_x, train_y)
    sco = classi.score(test_x, test_y)
    scor.append(sco)
plt.plot(range(1, 20), scor)
plt.grid(axis='both')
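A single 30-sample test set makes this depth curve noisy; cross-validation over the full dataset gives a steadier picture. A minimal sketch using 5-fold cross-validation (cv=5 and random_state=0 are assumptions, not from the original):
from sklearn.model_selection import cross_val_score
scor = []
for i in range(1, 20):
    classi = DecisionTreeClassifier(max_depth=i, random_state=0)
    # mean accuracy over the 5 folds
    scor.append(cross_val_score(classi, iris.data, iris.target, cv=5).mean())
plt.plot(range(1, 20), scor)
plt.grid(axis='both')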
03 Construct the data
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor

# 100 random points on [0, 2*pi) with a sine response
x = np.pi * 2 * np.random.rand(100)
y = np.sin(x)
# Perturb every 5th sample (20 in total) with uniform noise in [-1, 1)
y[::5] += (np.random.rand(20) - 0.5) * 2
# Random 75/25 train/test split
dex1 = np.random.choice(100, 75, replace=False)
dex2 = []
for i in range(100):
    if i not in dex1:
        dex2.append(i)
# sklearn expects 2-D feature arrays, hence reshape(-1, 1)
train_x = x[dex1].reshape(-1, 1)
train_y = y[dex1].reshape(-1, 1)
test_x = x[dex2].reshape(-1, 1)
test_y = y[dex2].reshape(-1, 1)
04 Regression tree
regre=DecisionTreeRegressor()
regre.fit(train_x,train_y)
regre.score(test_x,test_y)
Out[145]: 0.6619563827409216
regre.predict(test_x)
Out[146]:
array([ 0.98288727, -0.42291139, -0.11918768, 1.21105977, -0.1792897 ,
0.88046883, 0.77566524, -0.89663976, -0.1565603 , -0.57729087,
-0.89663976, -1.09538944, -0.97786811, 0.51036697, -1.09538944,
0.20600983, 0.14411034, 0.88046883, -0.60787944, 0.20600983,
-0.1792897 , 0.77566524, 0.41635144, 0.36389462, 0.88046883])
x1=np.arange(0,6,0.1).reshape(-1,1)
y1=regre.predict(x1).reshape(-1,1)
plt.plot(x1,y1,linewidth=0.5,marker='*',ms=10,mfc='g',mec='g',
alpha=0.4,label='predicted data')
plt.scatter(x,y,s=12,color='r',label='original data')
plt.legend(loc='upper right')
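The predicted curve is a step function: a regression tree outputs the mean training target of whichever leaf a query falls into, so the prediction is piecewise constant. A minimal check on how many constant segments the fully grown tree produced:
print(regre.get_n_leaves())  # number of leaves = number of constant segments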
regre=DecisionTreeRegressor(splitter='best')
regre.fit(train_x,train_y)
regre.score(test_x,test_y)
Out[150]: 0.6619563827409216
regre=DecisionTreeRegressor(splitter='random')
regre.fit(train_x,train_y)
regre.score(test_x,test_y)
Out[151]: 0.6543112875182975
# Scan max_depth for the regressor; score() here is the R^2 coefficient of determination
scor = []
for i in range(1, 20):
    regre = DecisionTreeRegressor(max_depth=i)
    regre.fit(train_x, train_y)
    sco = regre.score(test_x, test_y)
    scor.append(sco)
plt.plot(range(1, 20), scor)
plt.grid(axis='both')
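To inspect what a shallow tree actually learned, the split rules can be printed as text; a minimal sketch (max_depth=3 is an arbitrary choice to keep the printout readable):
from sklearn.tree import export_text
regre = DecisionTreeRegressor(max_depth=3)
regre.fit(train_x, train_y)
print(export_text(regre, feature_names=['x']))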
05 Summary
01 Key concepts of decision trees: entropy, information gain, and the information gain ratio (see the entropy sketch after this list);
02 Controlling model complexity: pruning and depth limits such as max_depth (a balanced tree's depth grows roughly as log2(samples)).
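As a quick illustration of the first point, the entropy of a label vector is H = -sum(p_k * log2(p_k)) over the class proportions p_k. A minimal sketch computing it for the iris targets (the helper name entropy_of is hypothetical):
import numpy as np

def entropy_of(labels):
    # class proportions, then H = -sum(p * log2(p))
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

print(entropy_of(iris.target))  # three balanced classes -> log2(3) ≈ 1.585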