linear-tree

A Python library to build Model Trees with Linear Models at the leaves.

Overview

Linear Model Trees combine the learning ability of Decision Trees with the predictive and explicative power of Linear Models. As in tree-based algorithms, the data are split according to simple decision rules. The goodness of each split is evaluated in terms of gain, by fitting Linear Models in the nodes. As a result, the models in the leaves are linear instead of the constant approximations used in classical Decision Trees.
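The split evaluation described above can be sketched in a few lines of NumPy. This is a minimal illustration, assuming a single numeric feature, mean squared error as the loss and ordinary least squares in every node; `linear_mse` and `best_split` are hypothetical helpers, not linear-tree's implementation (which, for instance, bins candidate thresholds and accepts any sklearn.linear_model estimator).

```python
import numpy as np


def linear_mse(x, y):
    """Fit y ~ a + b*x by least squares and return the mean squared residual."""
    A = np.column_stack([np.ones_like(x), x])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return float(np.mean((y - A @ coef) ** 2))


def best_split(x, y, min_samples=10):
    """Try every threshold on x and return (threshold, gain) with the largest loss reduction."""
    n = len(y)
    parent_loss = linear_mse(x, y)
    best_t, best_gain = None, 0.0
    for t in np.unique(x)[1:]:
        left = x < t
        n_left = int(left.sum())
        if n_left < min_samples or n - n_left < min_samples:
            continue  # skip children too small to fit a stable line
        # Sample-weighted loss of the two child linear models
        child_loss = (n_left * linear_mse(x[left], y[left])
                      + (n - n_left) * linear_mse(x[~left], y[~left])) / n
        gain = parent_loss - child_loss
        if gain > best_gain:
            best_t, best_gain = float(t), gain
    return best_t, best_gain


# Two linear pieces with a jump at x = 0.5
x = np.linspace(0.0, 1.0, 101)
y = np.where(x < 0.5, 2.0 * x, 3.0 - 2.0 * x)
t, gain = best_split(x, y)
print(t)  # 0.5
```

Gain is measured here as the parent node's loss minus the sample-weighted loss of the two children, so a split only pays off when two separate lines fit noticeably better than one.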

linear-tree is developed to be fully compatible with scikit-learn. LinearTreeRegressor and LinearTreeClassifier are provided as scikit-learn BaseEstimator subclasses. They are wrappers that build a decision tree on the data, fitting a linear estimator from sklearn.linear_model in the nodes. All the models available in sklearn.linear_model can be used as linear estimators.
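Because both estimators are BaseEstimator wrappers around a linear model, scikit-learn's nested-parameter convention should let the tree parameters and the inner linear model's parameters be addressed together, for example in a GridSearchCV param_grid. The sketch below shows that convention on a hypothetical stand-in class, `ToyLinearTree`, which only mimics the constructor pattern (a `base_estimator` plus a tree parameter such as `max_depth`); it is not linear-tree itself, so check the exact parameter names in the API Reference.

```python
from sklearn.base import BaseEstimator
from sklearn.linear_model import Ridge


class ToyLinearTree(BaseEstimator):
    """Hypothetical stand-in: stores a linear base_estimator plus a tree parameter."""

    def __init__(self, base_estimator=None, max_depth=5):
        self.base_estimator = base_estimator
        self.max_depth = max_depth


est = ToyLinearTree(base_estimator=Ridge(alpha=1.0), max_depth=5)

# get_params(deep=True) exposes the inner estimator's params with a "base_estimator__" prefix
params = est.get_params(deep=True)
print(params["base_estimator__alpha"])  # 1.0

# set_params routes "base_estimator__alpha" to the wrapped Ridge instance
est.set_params(max_depth=3, base_estimator__alpha=0.1)
print(est.max_depth, est.base_estimator.alpha)  # 3 0.1

# The same double-underscore keys can be used in a GridSearchCV param_grid
param_grid = {"max_depth": [3, 5], "base_estimator__alpha": [0.1, 1.0]}
```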

Installation

pip install linear-tree

The module depends on NumPy, SciPy and Scikit-Learn (>=0.23.0). Python 3.6 or above is supported.

Usage

Regression

from sklearn.linear_model import LinearRegression
from lineartree import LinearTreeRegressor
from sklearn.datasets import make_regression
# Toy regression problem: 4 features, 2 of them informative
X, y = make_regression(n_samples=100, n_features=4,
                       n_informative=2, n_targets=1,
                       random_state=0, shuffle=False)
# The tree nodes are fitted with LinearRegression
regr = LinearTreeRegressor(base_estimator=LinearRegression())
regr.fit(X, y)

Classification

from sklearn.linear_model import RidgeClassifier
from lineartree import LinearTreeClassifier
from sklearn.datasets import make_classification
# Toy binary classification problem: 4 features, 2 of them informative
X, y = make_classification(n_samples=100, n_features=4,
                           n_informative=2, n_redundant=0,
                           random_state=0, shuffle=False)
# The tree nodes are fitted with RidgeClassifier
clf = LinearTreeClassifier(base_estimator=RidgeClassifier())
clf.fit(X, y)

More examples are available in the notebooks folder.

Check the API Reference to see the parameter configurations and the available methods.

Examples

Show the model tree structure:

[image: plot tree]

Linear Tree Regressor at work:

[image: linear tree regressor]

Linear Tree Classifier at work:

[image: linear tree classifier]

Extract and examine coefficients at the leaves:

[image: leaf coefficients]

Comments
  • finding breakpoint

    Hello,

    Thank you for your nice tool. I am using LinearTreeRegressor to draw a continuous piecewise linear function. It works well. I am wondering: is it possible to show the location (the coordinates) of the breakpoints?

    Thank you

    opened by ZhengLiu1119
  • Allow the hyperparameter "max_depth = 0"

    Thanks for the good library.

    When using LinearTreeRegressor, I think that max_depth is often optimized by cross-validation.

    This library allows max_depth in the range 1-20. However, depending on the dataset, simple linear regression may be suitable. Even for such a dataset, max_depth is forced to be 1 or more, so simple linear regression cannot be applied properly with LinearTreeRegressor.

    • Of course, it is appropriate to use sklearn.linear_model.LinearRegression for such datasets.

    My suggestion is to change the code so that, when "max_depth = 0", base_estimator alone is used to perform the regression. With this change, LinearTreeRegressor could flexibly handle both segmented regression and simple regression just by changing hyperparameters.

    opened by jckkvs
  • Error when running with multiple jobs: unexpected keyword argument 'target_offload'

    I have been using your library for quite a while and am super happy with it. So first, thanks a lot!

    Lately, I ran my framework (which also uses your library) on a modern many-core server with many jobs, and it worked fine. Now I have updated everything via pip, and with 8 jobs on my MacBook I got the following error.

    This error does not occur when using only a single job (I pass the number of jobs to n_jobs).

    I cannot nail down the actual problem, but since it occurred right after the upgrade, I assume the upgrade might be the reason?

    Am I doing something wrong here?

    """
    Traceback (most recent call last):
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py", line 436, in _process_worker
        r = call_item()
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py", line 288, in __call__
        return self.fn(*self.args, **self.kwargs)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 595, in __call__
        return self.func(*args, **kwargs)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 263, in __call__
        for func, args, kwargs in self.items]
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 263, in <listcomp>
        for func, args, kwargs in self.items]
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/_classes.py", line 56, in __call__
        with config_context(**self.config):
      File "/Users/martin/opt/anaconda3/lib/python3.7/contextlib.py", line 239, in helper
        return _GeneratorContextManager(func, args, kwds)
      File "/Users/martin/opt/anaconda3/lib/python3.7/contextlib.py", line 82, in __init__
        self.gen = func(*args, **kwds)
    TypeError: config_context() got an unexpected keyword argument 'target_offload'
    """
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "compression_selection_pipeline.py", line 41, in <module>
        model_pipeline.learn_runtime_models(calibration_result_dir)
      File "/Users/martin/Programming/compression_selection_v3/hyrise_calibration/model_pipeline.py", line 670, in learn_runtime_models
        non_splitting_models("table_scan", table_scans)
      File "/Users/martin/Programming/compression_selection_v3/hyrise_calibration/model_pipeline.py", line 590, in non_splitting_models
        fitted_model = model_dict["model"].fit(X_train, y_train)
      File "/Users/martin/Programming/compression_selection_v3/hyrise_calibration/model_pipeline.py", line 209, in fit
        return self.regression.fit(X, y)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/lineartree.py", line 187, in fit
        self._fit(X, y, sample_weight)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/_classes.py", line 576, in _fit
        self._grow(X, y, sample_weight)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/_classes.py", line 387, in _grow
        loss=loss)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/lineartree/_classes.py", line 285, in _split
        for feat in split_feat)
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 1056, in __call__
        self.retrieve()
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 935, in retrieve
        self._output.extend(job.get(timeout=self.timeout))
      File "/Users/martin/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 542, in wrap_future_result
        return future.result(timeout=timeout)
      File "/Users/martin/opt/anaconda3/lib/python3.7/concurrent/futures/_base.py", line 435, in result
        return self.__get_result()
      File "/Users/martin/opt/anaconda3/lib/python3.7/concurrent/futures/_base.py", line 384, in __get_result
        raise self._exception
    TypeError: config_context() got an unexpected keyword argument 'target_offload'
    

    PS: I have already left a star. :D

    opened by Bouncner
  • Option to specify features to use for splitting and for leaf models

    Added two additional parameters:

    • split_features: Indices of features that can be used for splitting. Default: all.
    • linear_features: Indices of features that are used by the linear models in the leaves. Default: all except for categorical features.

    This implements a feature requested in https://github.com/cerlymarco/linear-tree/issues/2

    Potential performance improvement: currently, the code still computes bins for all features, not only for those used for splitting.

    opened by JonasRauch
  • Rationale for rounding during _parallel_binning_fit and _grow

    I noticed that the implementations of _parallel_binning_fit and _grow internally round loss values to 5 decimal places. This makes the regression results dependent on the scale of the labels, as data with a lower natural loss value will result in many different splits of the data having the same loss when rounded to 5 decimal places. Is there a reason why this is the case?

    This behavior can be observed by fitting a LinearTreeRegressor using the default loss function and multiplying the scale of the labels by a small number (like 1e-9). This will result in the regressor no longer learning any splits.

    opened by session-id
  • ValueError: Invalid parameter linearforestregression for estimator Pipeline

    Great work! I'm new to ML and stuck with this. I'm trying to combine Pipeline and GridSearchCV to search for the best possible hyperparameters for a model.

    [image]

    I got the following error:

    [image]

    Kindly help :)

    opened by NousMei
  • Performance and possibility to split only on subset of features

    Hey, I have been playing around a lot with your linear trees. Like them very much. Thanks!

    Nevertheless, I am somewhat disappointed by the runtime performance. Compared to XGBoost Regressors (I know it's not a fair comparison) or linear regressions (also not fair), the linear tree is reeeeeaally slow. 50k observations, 80 features: 2s for linear regression, 27s for XGBoost, and 300s for the linear tree. Have you seen similar runtimes or might I be using it wrong?

    Another aspect that interests me is whether it is possible to limit the features used for splits. I haven't found this option in the code. Any chance to see it in the future?

    opened by Bouncner
  • export to graphviz - AttributeError: 'LinearTreeRegressor' object has no attribute 'n_features_'

    Hi

    Thanks for writing this great package!

    I was trying to display the decision tree with graphviz, and I get this error:

    AttributeError: 'LinearTreeRegressor' object has no attribute 'n_features_'

    from lineartree import LinearTreeRegressor
    from sklearn.linear_model import LinearRegression

    reg = LinearTreeRegressor(base_estimator=LinearRegression())
    reg.fit(train[x_cols], train["y"])

    from graphviz import Source
    from sklearn import tree

    graph = Source(tree.export_graphviz(reg, out_file=None, feature_names=train.columns))

    opened by ricmarchao
  • numpy deprecation warning

    /lineartree/_classes.py:338: DeprecationWarning:

    the interpolation= argument to quantile was renamed to method=, which has additional options. Users of the modes 'nearest', 'lower', 'higher', or 'midpoint' are encouraged to review the method they. (Deprecated NumPy 1.22)

    Seems like a quick update here would get this warning to stop showing up, right? I can always ignore it, but figured I would mention it in case it is actually an error on my side.

    Also, sorry, I don't actually know what the best open-source etiquette is. If I'm supposed to create a pull request with a proposed fix instead of just mentioning it, then feel free to correct me.

    opened by paul-brenner
  • How to gridsearch tree and regression parameters?

    Hi, I am wondering how to perform a GridSearchCV to find the best parameters for both the tree and the regression model. For now, I am able to tune the tree component of my model:

    from lineartree import LinearForestRegressor
    from sklearn.linear_model import ElasticNet
    from sklearn.model_selection import GridSearchCV, RepeatedKFold

    param_grid = {
        'n_estimators': [50, 100, 500, 700],
        'max_depth': [10, 20, 30, 50],
        'min_samples_split': [2, 4, 8, 16, 32],
        'max_features': ['sqrt', 'log2', None]
    }
    cv = RepeatedKFold(n_repeats=3,
                       n_splits=3,
                       random_state=1)

    model = GridSearchCV(
        LinearForestRegressor(ElasticNet(random_state=0), random_state=42),
        param_grid=param_grid,
        n_jobs=-1,
        cv=cv,
        scoring='neg_root_mean_squared_error'
    )

    opened by zuzannakarwowska
  • Potential bug in LinearForestClassifier 'predict_proba'

    Hello! Thank you for the useful package!

    I think I might have found a potential bug in LinearForestClassifier.

    I expected 'predict_proba' to use 'self.decision_function', similarly to 'predict', so that it includes the predictions of both estimators (base + forest). Is that a bug, or am I wrong here?

    https://github.com/cerlymarco/linear-tree/blob/8d5beca8d492cb8c57e6618e3fb770860f28b550/lineartree/lineartree.py#L1560

    opened by PiotrKaszuba 1
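The scale dependence described in the "Rationale for rounding during _parallel_binning_fit and _grow" issue can be reproduced without linear-tree: the mean squared error of a fit scales with the square of the label scale, so multiplying the labels by 1e-9 shrinks every candidate split's loss by a factor of 1e-18, and rounding to 5 decimal places then maps all of them to 0. The sketch below assumes MSE loss and straight-line fits; `linear_mse` and `split_loss` are hypothetical helpers, not library functions.

```python
import numpy as np


def linear_mse(x, y):
    """Mean squared residual of a straight-line least-squares fit."""
    slope, intercept = np.polyfit(x, y, 1)
    return float(np.mean((y - (slope * x + intercept)) ** 2))


def split_loss(x, y, t):
    """Sample-weighted MSE of separate line fits on x < t and x >= t."""
    left = x < t
    n_left, n_right = int(left.sum()), int((~left).sum())
    return (n_left * linear_mse(x[left], y[left])
            + n_right * linear_mse(x[~left], y[~left])) / len(x)


x = np.linspace(0.0, 1.0, 101)
y = np.where(x < 0.5, 2.0 * x, 3.0 - 2.0 * x)  # two linear pieces with a jump at x = 0.5
thresholds = [0.25, 0.5, 0.75]

results = {}
for scale in (1.0, 1e-9):
    losses = np.array([split_loss(x, scale * y, t) for t in thresholds])
    results[scale] = losses
    # Unrounded, t = 0.5 is the best split at both scales; after rounding to
    # 5 decimals, the three candidates are only distinguishable at scale 1.0.
    print(scale, np.argmin(losses), np.round(losses, 5))
```

At scale 1e-9 the rounded losses are all 0.0, so every candidate split ties and the best one can no longer be identified from the rounded values.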
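For the "Invalid parameter linearforestregression for estimator Pipeline" issue, the screenshots are not available, so the following is only a guess at the cause. When a pipeline is built with sklearn.pipeline.make_pipeline, each step is named after its class in lowercase, and grid keys must have the form `<step name>__<parameter>`; a LinearForestRegressor step would therefore be named `linearforestregressor`, not `linearforestregression`. The sketch uses Ridge as a stand-in so it runs without linear-tree.

```python
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# make_pipeline names each step after its class, lowercased
pipe = make_pipeline(StandardScaler(), Ridge())
step_names = [name for name, _ in pipe.steps]
print(step_names)  # ['standardscaler', 'ridge']

# Valid key: "<step name>__<parameter>"; the same keys work in a GridSearchCV param_grid
param_grid = {"ridge__alpha": [0.1, 1.0, 10.0]}
pipe.set_params(ridge__alpha=0.5)

# A key whose prefix does not match any step name raises ValueError
error_message = None
try:
    pipe.set_params(ridgeregression__alpha=0.5)
except ValueError as err:
    error_message = str(err)
print(error_message)
```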
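The warning in the "numpy deprecation warning" issue concerns only the keyword name: since NumPy 1.22, np.quantile takes `method=` instead of `interpolation=`. A version-tolerant call can be sketched as below; the wrapper name `quantile_compat` is hypothetical, and "midpoint" is used only as an example mode, since the issue does not show which mode linear-tree passes.

```python
import numpy as np


def quantile_compat(a, q, method="linear"):
    """Call np.quantile with `method=` (NumPy >= 1.22) and fall back to the
    older `interpolation=` keyword on NumPy versions that reject `method=`."""
    try:
        return np.quantile(a, q, method=method)
    except TypeError:
        # Older NumPy: `method` is an unexpected keyword argument
        return np.quantile(a, q, interpolation=method)


values = np.arange(10, dtype=float)
result = quantile_compat(values, [0.25, 0.5, 0.75], method="midpoint")
print(result)  # [2.5 4.5 6.5]
```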
Releases(0.3.5)
Owner
Marco Cerliani
Statistician Hacker & Data Scientist