决策树分类与回归模型的实现和可视化

Overview

DecisionTree

决策树分类与回归模型,以及可视化

ID3

ID3决策树是最朴素的决策树分类器:

  • 无剪枝
  • 只支持离散属性
  • 采用信息增益准则

data.py中,我们记录了一个小的西瓜数据集,用于离散属性的二分类任务。我们可以像下面这样训练一个ID3决策树分类器:

from ID3 import ID3Classifier
from data import load_watermelon2
import numpy as np

X, y = load_watermelon2(return_X_y=True) # 函数参数仿照sklearn.datasets
model = ID3Classifier()
model.fit(X, y)
pred = model.predict(X)
print(np.mean(pred == y))

输出1.0,说明我们生成的决策树是正确的。

C4.5

C4.5决策树分类器对ID3进行了改进:

  • 用信息增益率的启发式方法来选择划分特征;
  • 能够处理离散型和连续型的属性类型,即将连续型的属性进行离散化处理;
  • 剪枝;
  • 能够处理具有缺失属性值的训练数据;

我们实现了前两点,以及第三点中的预剪枝功能(超参数)

data.py中还有一个连续离散特征混合的西瓜数据集,我们用它来测试C4.5决策树的效果:

from C4_5 import C4_5Classifier
from data import load_watermelon3
import numpy as np

X, y = load_watermelon3(return_X_y=True) # 函数参数仿照sklearn.datasets
model = C4_5Classifier()
model.fit(X, y)
pred = model.predict(X)
print(np.mean(pred == y))

输出1.0,说明我们生成的决策树正确.

CART

分类

CART(Classification and Regression Tree)是C4.5决策树的扩展,支持分类和回归。CART分类树算法使用基尼系数选择特征,此外对于离散特征,CART决策树在每个节点二分划分,缓解了过拟合。

这里我们用sklearn中的鸢尾花数据集测试:

from CART import CARTClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model = CARTClassifier()
model.fit(train_X, train_y)
pred = model.predict(test_X)
print(accuracy_score(test_y, pred))

准确率95.55%。

回归

CARTRegressor类实现了决策树回归,以sklearn的波士顿数据集为例:

from CART import CARTRegressor
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

X, y = load_boston(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model = CARTRegressor()
model.fit(train_X, train_y)
pred = model.predict(test_X)
print(mean_squared_error(test_y, pred))

输出26.352171052631576,sklearn决策树回归的Baseline是22.46,性能近似,说明我们的实现正确。

决策树绘制

分类树

利用python3的graphviz第三方库和Graphviz(需要安装),我们可以将决策树可视化:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
model = CARTClassifier()
model.fit(X, y)
tree_plot(model)

运行,文件夹中生成tree.png

iris_tree

如果提供了特征的名词和标签的名称,决策树会更明显:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_iris

iris = load_iris()
model = CARTClassifier()
model.fit(iris.data, iris.target)
tree_plot(model,
          filename="tree2",
          feature_names=iris.feature_names,
          target_names=iris.target_names)

iris_tree2

绘制西瓜数据集2对应的ID3决策树:

from plot import tree_plot
from ID3 import ID3Classifier
from data import load_watermelon2

watermelon = load_watermelon2()
model = ID3Classifier()
model.fit(watermelon.data, watermelon.target)
tree_plot(
    model,
    filename="tree",
    font="SimHei",
    feature_names=watermelon.feature_names,
    target_names=watermelon.target_names,
)

这里要自定义字体,否则无法显示中文:

watermelon

回归树

用同样的方法,我们可以进行回归树的绘制:

from plot import tree_plot
from ID3 import ID3Classifier
from sklearn.datasets import load_boston

boston = load_boston()
model = ID3Classifier(max_depth=5)
model.fit(boston.data, boston.target)
tree_plot(
    model,
    feature_names=boston.feature_names,
)

由于生成的回归树很大,我们限制最大深度再绘制:

regression

调参

CART和C4.5都是有超参数的,我们让它们作为sklearn.base.BaseEstimator的派生类,借助sklearn的GridSearchCV,就可以实现调参:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split, GridSearchCV

wine = load_wine()
train_X, test_X, train_y, test_y = train_test_split(
    wine.data,
    wine.target,
    train_size=0.7,
)
model = CARTClassifier()
grid_param = {
    'max_depth': [2, 4, 6, 8, 10],
    'min_samples_leaf': [1, 3, 5, 7],
}

search = GridSearchCV(model, grid_param, n_jobs=4, verbose=5)
search.fit(train_X, train_y)
best_model = search.best_estimator_
print(search.best_params_, search.best_estimator_.score(test_X, test_y))
tree_plot(
    best_model,
    feature_names=wine.feature_names,
    target_names=wine.target_names,
)

输出最优参数和最优模型在测试集上的表现:

{'max_depth': 4, 'min_samples_leaf': 3} 0.8518518518518519

绘制对应的决策树:

wine

剪枝

在ID3和CART回归中加入了REP剪枝,C4.5则支持了PEP剪枝。

对IRIS数据集训练后的决策树进行PEP剪枝:

iris = load_iris()
model = C4_5Classifier()
X, y = iris.data, iris.target
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model.fit(train_X, train_y)
print(model.score(test_X, test_y))
tree_plot(model,
          filename="src/pre_prune",
          feature_names=iris.feature_names,
          target_names=iris.target_names)
model.pep_pruning()
print(model.score(test_X, test_y))
tree_plot(model,
          filename="src/post_prune",
          feature_names=iris.feature_names,
          target_names=iris.target_names,
)

剪枝前后的准确率分别为97.78%,100%,即泛化性能的提升:

prepre

Owner
Welt Xing
Undergraduate in AI school, Nanjing University. Main interest(for now): Machine learning and deep learning.
Welt Xing
flexible time-series processing & feature extraction

A corona statistics and information telegram bot.

PreDiCT.IDLab 206 Dec 28, 2022
A Python implementation of GRAIL, a generic framework to learn compact time series representations.

GRAIL A Python implementation of GRAIL, a generic framework to learn compact time series representations. Requirements Python 3.6+ numpy scipy tslearn

3 Nov 24, 2021
Implemented four supervised learning Machine Learning algorithms

Implemented four supervised learning Machine Learning algorithms from an algorithmic family called Classification and Regression Trees (CARTs), details see README_Report.

Teng (Elijah) Xue 0 Jan 31, 2022
Provide an input CSV and a target field to predict, generate a model + code to run it.

automl-gs Give an input CSV file and a target field you want to predict to automl-gs, and get a trained high-performing machine learning or deep learn

Max Woolf 1.8k Jan 04, 2023
Retrieve annotated intron sequences and classify them as minor (U12-type) or major (U2-type)

(intron I nterrogator and C lassifier) intronIC is a program that can be used to classify intron sequences as minor (U12-type) or major (U2-type), usi

Graham Larue 4 Jul 26, 2022
Tool for producing high quality forecasts for time series data that has multiple seasonality with linear or non-linear growth.

Prophet: Automatic Forecasting Procedure Prophet is a procedure for forecasting time series data based on an additive model where non-linear trends ar

Facebook 15.4k Jan 07, 2023
A Python package for time series classification

pyts: a Python package for time series classification pyts is a Python package for time series classification. It aims to make time series classificat

Johann Faouzi 1.4k Jan 01, 2023
Evidently helps analyze machine learning models during validation or production monitoring

Evidently helps analyze machine learning models during validation or production monitoring. The tool generates interactive visual reports and JSON profiles from pandas DataFrame or csv files. Current

Evidently AI 3.1k Jan 07, 2023
A Streamlit demo to interactively visualize Uber pickups in New York City

Streamlit Demo: Uber Pickups in New York City A Streamlit demo written in pure Python to interactively visualize Uber pickups in New York City. View t

Streamlit 230 Dec 28, 2022
Napari sklearn decomposition

napari-sklearn-decomposition A simple plugin to use with napari This napari plug

1 Sep 01, 2022
Combines MLflow with a database (PostgreSQL) and a reverse proxy (NGINX) into a multi-container Docker application

Combines MLflow with a database (PostgreSQL) and a reverse proxy (NGINX) into a multi-container Docker application (with docker-compose).

Philip May 2 Dec 03, 2021
Implementation of different ML Algorithms from scratch, written in Python 3.x

Implementation of different ML Algorithms from scratch, written in Python 3.x

Gautam J 393 Nov 29, 2022
Machine Learning Algorithms ( Desion Tree, XG Boost, Random Forest )

implementation of machine learning Algorithms such as decision tree and random forest and xgboost on darasets then compare results for each and implement ant colony and genetic algorithms on tsp map,

Mohamadreza Rezaei 1 Jan 19, 2022
2D fluid simulation implementation of Jos Stam paper on real-time fuild dynamics, including some suggested extensions.

Fluid Simulation Usage Download this repo and store it in your computer. Open a terminal and go to the root directory of this folder. Make sure you ha

Mariana Ávalos Arce 5 Dec 02, 2022
Educational python for Neural Networks, written in pure Python/NumPy.

Educational python for Neural Networks, written in pure Python/NumPy.

127 Oct 27, 2022
Turns your machine learning code into microservices with web API, interactive GUI, and more.

Turns your machine learning code into microservices with web API, interactive GUI, and more.

Machine Learning Tooling 2.8k Jan 02, 2023
TensorFlow Decision Forests (TF-DF) is a collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models.

TensorFlow Decision Forests (TF-DF) is a collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models. The library is a collection of Keras models

538 Jan 01, 2023
A library of sklearn compatible categorical variable encoders

Categorical Encoding Methods A set of scikit-learn-style transformers for encoding categorical variables into numeric by means of different techniques

2.1k Jan 07, 2023
BASTA: The BAyesian STellar Algorithm

BASTA: BAyesian STellar Algorithm Current stable version: v1.0 Important note: BASTA is developed for Python 3.8, but Python 3.7 should work as well.

BASTA team 16 Nov 15, 2022