도찐개찐

[머신러닝] 09. 의사결정 나무 본문

PYTHON/데이터분석

[머신러닝] 09. 의사결정 나무

도개진 2023. 1. 3. 10:11

의사결정트리

  • 나무 모양의 그래프를 사용해서 최적의 결정을 돕는 분석기법
  • 기회비용에 대한 고려, 기대 이익 계산, 위험 관리등 효율적인 결정이 필요한 많은 분야에 사용되고 있음
  • 의사결정트리는 회귀,분류에 사용되지만 주로 분류에 많이 이용되고 있음
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score

from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

iris 데이터셋을 이용한 분석

from sklearn.datasets import load_iris
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, train_size = 0.7,
                stratify=iris.target, random_state=2211171705)
dtclf = DecisionTreeClassifier(criterion='entropy') # 기본값은 gini 계수 > entropy 로 변환
dtclf.fit(X_train, y_train)
dtclf.score(X_train, y_train)
1.0
pred = dtclf.predict(X_test)
accuracy_score(y_test, pred)
0.8888888888888888

의사결정나무 시각화

  • 나무 모양의 그래프를 사용해서 최적의 결정을 돕는 분석기법
  • 기회비용에 대한 고려, 기대 이익 계산, 위험 관리등 효율적인 결정이 필요한 많은 분야에 사용되고 있음
  • 의사결정트리는 회귀,분류에 사용되지만 주로 분류에 많이 이용되고 있음
  • 출력하려면 graphviz라는 라이브러리 필요!
    • graphviz.org => stable 2.38 windows => graphviz-2.38.zip
      • c:/Java 아래에 압축 해제
      • 폴더명은 graphviz-2.38
      • bin 폴더를 PATH 환경변수로 등록
    • 리눅스는 yum install -y graphviz
    • 맥은 brew install graphviz
  • 파이썬 pydotplus 패키지도 필요함
#!conda install -y pydotplus
import pydotplus
from sklearn import tree
dot_data = tree.export_graphviz(dtclf, out_file=None,
                               feature_names=iris.feature_names,
                                class_names=iris.target_names
                               )
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('../img/iris.png')
True
import matplotlib.image as pltimg

img = pltimg.imread('../img/iris.png')

plt.figure(figsize=(12,8))
plt.imshow(img)
plt.axis('off')
plt.show()

변수별 중요도 확인

dtclf.feature_importances_
array([0.01655372, 0.0149634 , 0.07455295, 0.89392993])
titanic = pd.read_csv('../data/titanic2.csv')
titanic
  Unnamed: 0 pclass survived name sex age sibsp parch fare embarked title gender Embarked Title
0 0 1 1 Allen, Miss. Elisabeth Walton female 29.0000 0 0 211.3375 S Miss 0 2 10
1 1 1 1 Allison, Master. Hudson Trevor male 0.9167 1 2 151.5500 S Master 1 2 9
2 2 1 0 Allison, Miss. Helen Loraine female 2.0000 1 2 151.5500 S Miss 0 2 10
3 3 1 0 Allison, Mr. Hudson Joshua Creighton male 30.0000 1 2 151.5500 S Mr 1 2 13
4 4 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female 25.0000 1 2 151.5500 S Mrs 0 2 14
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1301 1304 3 0 Zabour, Miss. Hileni female 14.5000 1 0 14.4542 C Miss 0 0 10
1302 1305 3 0 Zabour, Miss. Thamine female 28.0000 1 0 14.4542 C Miss 0 0 10
1303 1306 3 0 Zakarian, Mr. Mapriededer male 26.5000 0 0 7.2250 C Mr 1 0 13
1304 1307 3 0 Zakarian, Mr. Ortin male 27.0000 0 0 7.2250 C Mr 1 0 13
1305 1308 3 0 Zimmerman, Mr. Leo male 29.0000 0 0 7.8750 S Mr 1 2 13

1306 rows × 14 columns

data = titanic.iloc[:, [1,5,6,7,8, 11,12,13]]
target = titanic.survived
X_train, X_test, y_train, y_test = train_test_split(data, target, train_size = 0.7,
                stratify=target, random_state=2211171705)
dtclf = DecisionTreeClassifier(criterion='entropy', max_depth=5) # 기본값은 gini 계수 > entropy 로 변환
dtclf.fit(X_train, y_train)
dtclf.score(X_train, y_train)
0.8413566739606126
pred = dtclf.predict(X_test)
accuracy_score(y_test, pred)
0.8010204081632653
dot_data = tree.export_graphviz(dtclf, out_file=None,
                               feature_names=data.columns
                               #  class_names=iris.target_names
                               )
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('../img/titanic.png')
True

변수별 중요도 확인

dtclf.feature_importances_
array([0.17367191, 0.11992278, 0.02903506, 0.        , 0.12023138,
       0.47256337, 0.02291981, 0.06165569])
img = pltimg.imread('../img/titanic.png')

plt.figure(figsize=(12,8))
plt.imshow(img)
plt.axis('off')
plt.show()

728x90
Comments