PyTorch implementation of Evidential Deep Learning to Quantify Classification Uncertainty and Deep Evidential Regression, deterministic methods for quantifying the uncertainty of neural network predictions.
pip install edl_pytorch
See examples/mnist.py and examples/cubic.py for classification and regression examples, respectively; these scripts produce the figures above.
For classification, use the Dirichlet layer as the final layer of the model, together with the evidential_classification loss:
import torch
from torch import nn
from edl_pytorch import Dirichlet, evidential_classification

model = nn.Sequential(
    nn.Linear(2, 16),  # two input dims
    nn.ReLU(),
    Dirichlet(16, 2),  # two output classes
)

x = torch.randn(1, 2)  # (batch, dim)
y = torch.randint(0, 2, (1, 1))  # integer class labels

pred_dirichlet = model(x)  # (1, 2)

loss = evidential_classification(
    pred_dirichlet,  # predicted Dirichlet parameters
    y,               # target labels
    lamb=0.001,      # regularization coefficient
)
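The Dirichlet layer outputs concentration parameters alpha rather than class probabilities. In Sensoy et al. (2018), the Dirichlet strength is S = sum(alpha), the expected probability of class k is alpha_k / S, and the uncertainty mass is u = K / S for K classes, so u = 1 when every alpha_k = 1 (no evidence). The paper's data term is the expected sum of squared errors under the Dirichlet. The plain-Python sketch below writes these out for a single sample; the function names are hypothetical, the regularization term weighted by lamb is omitted, and it illustrates the paper's formulas rather than how evidential_classification is implemented.

```python
# Framework-free sketch (plain Python, one sample); function names are hypothetical.


def dirichlet_summary(alpha):
    """Expected class probabilities and uncertainty mass u = K / S for one Dirichlet."""
    strength = sum(alpha)                  # S: total evidence plus K
    probs = [a / strength for a in alpha]  # E[p_k] = alpha_k / S
    uncertainty = len(alpha) / strength    # u = K / S, equals 1 when there is no evidence
    return probs, uncertainty


def dirichlet_sse(alpha, label):
    """Expected sum-of-squares error under Dir(alpha) for an integer class label."""
    strength = sum(alpha)
    loss = 0.0
    for k, a in enumerate(alpha):
        p = a / strength
        y = 1.0 if k == label else 0.0
        # squared error of the mean plus the variance of p_k under the Dirichlet
        loss += (y - p) ** 2 + p * (1.0 - p) / (strength + 1.0)
    return loss


# alpha = 1 for every class (no evidence): uniform prediction, maximal uncertainty
print(dirichlet_summary([1.0, 1.0]))   # ([0.5, 0.5], 1.0)
# strong evidence for class 0: confident prediction, low uncertainty
print(dirichlet_summary([19.0, 1.0]))  # ([0.95, 0.05], 0.1)
```

Reporting u alongside the expected probabilities is what distinguishes this output from a softmax: two samples can share the same argmax while carrying very different amounts of evidence.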
For regression, use the NormalInvGamma layer as the final layer of the model, together with the evidential_regression loss:
import torch
from torch import nn
from edl_pytorch import NormalInvGamma, evidential_regression

model = nn.Sequential(
    nn.Linear(1, 16),  # one input dim
    nn.ReLU(),
    NormalInvGamma(16, 1),  # one target variable
)

x = torch.randn(1, 1)  # (batch, dim)
y = torch.randn(1, 1)  # real-valued targets

pred_nig = model(x)  # (mu, v, alpha, beta)

loss = evidential_regression(
    pred_nig,    # predicted Normal Inverse Gamma parameters
    y,           # regression targets
    lamb=0.001,  # regularization coefficient
)
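The tuple returned by NormalInvGamma holds the four Normal Inverse Gamma parameters (mu, v, alpha, beta). Amini et al. (2020) read the prediction and two kinds of uncertainty off them: the prediction is E[mu] = mu, the aleatoric (data) uncertainty is E[sigma^2] = beta / (alpha - 1), and the epistemic (model) uncertainty is Var[mu] = beta / (v * (alpha - 1)), both finite only for alpha > 1. The paper's objective is the negative log-likelihood of the resulting Student-t marginal plus lamb times the regularizer |y - mu| * (2v + alpha). The plain-Python sketch below writes these out for a single scalar target; the function names are hypothetical, and it illustrates the paper's formulas rather than the internals of evidential_regression.

```python
import math

# Framework-free sketch (plain Python, scalar target); function names are hypothetical.


def nig_uncertainty(mu, v, alpha, beta):
    """Prediction, aleatoric E[sigma^2] and epistemic Var[mu] from NIG parameters."""
    if alpha <= 1.0:
        raise ValueError("alpha must exceed 1 for both variances to be finite")
    aleatoric = beta / (alpha - 1.0)
    epistemic = beta / (v * (alpha - 1.0))
    return mu, aleatoric, epistemic


def nig_nll(y, mu, v, alpha, beta):
    """Negative log-likelihood of y under the Student-t marginal of the NIG model."""
    omega = 2.0 * beta * (1.0 + v)
    return (
        0.5 * math.log(math.pi / v)
        - alpha * math.log(omega)
        + (alpha + 0.5) * math.log(v * (y - mu) ** 2 + omega)
        + math.lgamma(alpha)
        - math.lgamma(alpha + 0.5)
    )


def nig_loss(y, mu, v, alpha, beta, lamb=0.001):
    """NLL plus lamb * |y - mu| * (2v + alpha), which penalizes evidence placed on errors."""
    return nig_nll(y, mu, v, alpha, beta) + lamb * abs(y - mu) * (2.0 * v + alpha)


print(nig_uncertainty(0.0, 2.0, 3.0, 4.0))  # (0.0, 2.0, 1.0)
```

Because v divides the epistemic term but not the aleatoric one, the network can report noisy data (large beta / (alpha - 1)) and lack of evidence (small v) separately from a single forward pass.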
- https://muratsensoy.github.io/uncertainty.html, original code for Evidential Deep Learning to Quantify Classification Uncertainty in TensorFlow
- aamini/evidential-deep-learning, original code for Deep Evidential Regression in TensorFlow/Keras
- add examples
- allow specification of the evidence function (currently F.softplus)
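The evidence function maps a layer's raw linear outputs to non-negative evidence, and per the list item above it is currently fixed to F.softplus. Assuming, as in Sensoy et al. (2018), that Dirichlet parameters are formed as alpha = evidence + 1, the plain-Python sketch below shows what a pluggable evidence function could look like for one sample. dirichlet_alpha and its evidence argument are hypothetical and not part of edl_pytorch; ReLU is included only as an example alternative.

```python
import math

# Hypothetical sketch of a configurable evidence function (not the edl_pytorch API).


def softplus(x):
    """Numerically stable softplus: log(1 + exp(x)), always >= 0."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def relu(x):
    """ReLU evidence: zero for negative logits, identity otherwise."""
    return max(x, 0.0)


def dirichlet_alpha(logits, evidence=softplus):
    """Assumed mapping from logits to Dirichlet parameters: alpha_k = evidence(z_k) + 1."""
    return [evidence(z) + 1.0 for z in logits]


print(dirichlet_alpha([0.0, 2.0], evidence=relu))  # [1.0, 3.0]
```

Any non-negative function keeps alpha_k >= 1; softplus has the practical advantage over ReLU of a non-zero gradient for negative logits.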
@article{sensoy2018evidential,
title={Evidential deep learning to quantify classification uncertainty},
author={Sensoy, Murat and Kaplan, Lance and Kandemir, Melih},
journal={Advances in Neural Information Processing Systems},
volume={31},
year={2018}
}
@article{amini2020deep,
title={Deep evidential regression},
author={Amini, Alexander and Schwarting, Wilko and Soleimany, Ava and Rus, Daniela},
journal={Advances in Neural Information Processing Systems},
volume={33},
pages={14927--14937},
year={2020}
}