# Node-GAMLSS: Differentiable Additive Model for Interpretable Distributional Deep Learning

Node-GAMLSS integrates the Node-GAM framework with the NAMLSS/GAMLSS framework to train Generalized Additive Models for Location, Scale, and Shape (GAMLSS) using multi-layer differentiable trees. The model minimizes the negative log-likelihood of a chosen distribution, with built-in support for handling distributional parameter constraints. The implementation adapts the Node-GAM approach to distributional regression, making it suitable for deep learning applications that require interpretability.
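In practice, "handling distributional parameter constraints" means mapping unconstrained model outputs onto each parameter's valid domain before evaluating the likelihood. Below is a minimal PyTorch sketch of that idea for the normal distribution; `normal_nll` is purely illustrative and not part of the library, and the transforms Node-GAMLSS uses internally may differ:

```python
import torch
import torch.nn.functional as F

def normal_nll(y, raw_loc, raw_scale):
    """Illustrative normal negative log-likelihood with a constrained scale.

    The location is unconstrained, but the scale must be positive, so the
    raw network output is passed through a softplus before use (one common
    choice; not necessarily the transform Node-GAMLSS applies internally).
    """
    loc = raw_loc
    scale = F.softplus(raw_scale) + 1e-6  # enforce scale > 0
    return -torch.distributions.Normal(loc, scale).log_prob(y).mean()
```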

## Relevant Papers

The codebase is inspired by and adapted from the Node-GAM GitHub repository.

## Installation

Install Node-GAMLSS directly from the repository using pip:

```bash
pip install git+https://github.com/AnFreTh/NodeGAMLSS.git
```

## Node-GAMLSS Training

### Sklearn interface

To minimize the negative log-likelihood of a normal distribution on a given dataset, use:

```python
from nodegamlss.model import NodeGAMLSS

model = NodeGAMLSS(
    in_features=X.shape[1],   # number of input features
    family="normal",          # distribution whose NLL is minimized
    device="cpu",             # or "cuda"
    verbose=False,
    max_steps=5000,           # number of training steps
    lr=0.001,                 # learning rate
    report_frequency=100,     # how often progress is reported
    num_trees=75,             # number of differentiable trees
)

record = model.fit(X, y)
```
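Here `X` and `y` are a numeric feature matrix and target vector. For a quick, self-contained run, synthetic data suffices (a sketch; only the shapes and numeric dtypes matter):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5)).astype(np.float32)  # 1000 samples, 5 features
# Target depends on the first two features plus noise.
y = (X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=1000)).astype(np.float32)
```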

The "family" parameter defines the distribution that is used.

See `nodegamlss/distributions` for details on the distributions and their parameters. The following distributions are implemented so far; feel free to raise an issue if a distribution you need is missing:

  1. Normal
  2. Poisson
  3. Gamma
  4. Inverse Gamma
  5. Dirichlet
  6. Beta
  7. StudentT
  8. Negative Binomial
  9. Categorical
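For example, a count-valued target can be modeled by switching the family. A sketch, assuming the family key follows the lowercase naming of the normal example (check `nodegamlss/distributions` for the exact identifiers):

```python
# Hypothetical count-data model; `y_counts` holds non-negative integer targets.
poisson_model = NodeGAMLSS(
    in_features=X.shape[1],
    family="poisson",  # assumed key; see nodegamlss/distributions
    device="cpu",
    max_steps=5000,
    lr=0.001,
    num_trees=75,
)
record = poisson_model.fit(X, y_counts)
```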

After fitting, the individual feature and parameter effects are easy to visualize, and there are multiple ways to do so. First, we can leverage the Node-GAM visualizations and create a Node-GAM-style plot for each distributional parameter:

```python
fig, axes, df = model.visualize_nodegam(X)
```

This creates one figure per distributional parameter, with a subplot for each feature and feature interaction.

Additionally, you can create Dash plots for each distributional parameter. For a normal distribution, for example, you can visualize the mean predictions via:

```python
model.plot_single_feature_effects(X, parameter="mean")
```

And the variance via:

```python
model.plot_single_feature_effects(X, parameter="variance")
```

The interaction plots are created similarly via:

```python
model.plot_interaction_effects(X, port=8051, parameter="mean")
```

Furthermore, as in Node-GAM, it is easy to visualize the training loss and the evaluation metric:

```python
import matplotlib.pyplot as plt

plt.figure(figsize=[18, 6])
plt.subplot(1, 2, 1)
plt.plot(record['train_losses'])
plt.title('Training Loss')
plt.xlabel('Steps')
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(record['val_metrics'])
plt.title('Validation Metrics')
plt.xlabel('Steps (every 100 steps)')
plt.grid()
plt.show()
```