Commit
Merge pull request #303 from laserkelvin/quantile-loss
Quantile-based loss mechanism
laserkelvin authored Oct 7, 2024
2 parents 25969cc + 91953af commit b0e7cd9
Showing 3 changed files with 195 additions and 0 deletions.
18 changes: 18 additions & 0 deletions docs/source/training.rst
@@ -150,3 +150,21 @@ potential difficulty in the task.

.. autoclass:: matsciml.models.base.NodeDenoisingTask
:members:


Loss functions
==============

Each task has a default loss function that is known to work for that particular task.
Some tasks and datasets may need more flexibility in how training signals are computed,
so the task interface also accepts a hyperparameter that maps each task key to a loss
function. In addition to the typical PyTorch losses like ``L1Loss`` and ``MSELoss``, we
implement some loss functions based on ideas borrowed from other repositories, such as MACE.
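
As a short sketch mirroring the accompanying unit tests (the quantile and weight values
are illustrative rather than recommendations), a force regression task can pair a
standard energy loss with a quantile-weighted force loss:

.. code-block:: python

    from torch import nn

    from matsciml.models import losses
    from matsciml.models.base import ForceRegressionTask
    from matsciml.models.pyg import EGNN

    # per-task-key losses are passed as a single hyperparameter; this
    # particular force loss up-weights the largest force norms in a batch
    task = ForceRegressionTask(
        encoder_class=EGNN,
        encoder_kwargs={"hidden_dim": 64, "output_dim": 64},
        output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64},
        loss_func={
            "energy": nn.MSELoss,
            "force": losses.BatchQuantileLoss(
                {0.25: 0.5, 0.5: 1.0, 0.75: 2.0},
                loss_func="huber",
                huber_delta=0.01,
            ),
        },
    )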

.. autoclass:: matsciml.models.losses.AtomWeightedMSE
:members:


.. autoclass:: matsciml.models.losses.BatchQuantileLoss
:members:
124 changes: 124 additions & 0 deletions matsciml/models/losses.py
@@ -1,4 +1,6 @@
from __future__ import annotations
from functools import partial
from typing import Callable, Literal

import torch
from torch import nn
@@ -63,3 +65,125 @@ def forward(
# ensures that atoms_per_graph is type cast correctly
squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0
return squared_error.mean()


class BatchQuantileLoss(nn.Module):
def __init__(
self,
quantile_weights: dict[float, float],
loss_func: Callable | Literal["mse", "mae", "rmse", "huber"],
use_norm: bool = True,
huber_delta: float | None = None,
) -> None:
"""
Implements a batch-based or dynamic quantile loss function.
This loss function uses user-defined quantiles and associated
weights for training: the high-level idea is to allow flexibility
in optimizing model performance against certain outliers, and
ensuring that the model generalizes well.
The function will either use the target values, or the norm of
the target values (more meaningful for vector quantities like
forces) to compute quantile values based on bins requested. A weight
tensor is then generated (with the same shape as the targets) to
weight predicted vs. actual margins, as computed with ``loss_func``.
The mean of the weighted loss is then returned.
In the case of ``loss_func='huber'``, the ``huber_delta`` argument specifies
the margin used to switch between MAE and MSE losses. This value
is applied globally (i.e. regardless of the quantile).
Parameters
----------
quantile_weights : dict[float, float]
Dictionary mapping of quantile and the weighting to ascribe
to that bin. Values smaller than the first bin, and larger
than the last bin take on these respective values, while
quantile in between bin ranges include the lower quantile
and up to (not including) the next bin.
loss_func : Callable | Literal['mse', 'mae', 'rmse', 'huber']
Actual metric function. If a string literal is given, then
one of the built-in PyTorch functional losses are used
based on either MSE, RMSE, or Huber loss. If a ``Callable``
is passed, the output **must** be of the same dimension
as the targets, i.e. the behavior of ``keepdim`` or no
reduction, as the weights are applied afterwards.
use_norm : bool, default True
Whether to use the norm of targets, instead of an elementwise
wise application. This makes sense for vector quantities that
are coupled, e.g. force vectors. If ``False``, this will still
work with scalar and vector quantities-alike, but requires an
intuition for one over the other.
huber_delta : float, optional
If ``loss_func`` is set to 'huber', this value is used as the
``delta`` argument in ``torch.nn.functional.huber_loss``, which
corresponds to the margin between MAE/L1 and MSE functions.
Raises
------
NotImplementedError:
Currently RMSE is not implemented, and will trigger this
exception.
"""
super().__init__()
for key, value in quantile_weights.items():
assert isinstance(
key, float
), "Expected quantile keys to be floats between [0,1]."
assert isinstance(
value, float
), "Expected quantile dict values to be floats."
assert (
0.0 <= key <= 1.0
), f"Quantile value {key} invalid; must be between [0,1]."
quantiles = torch.Tensor(list(quantile_weights.keys()))
self.register_buffer("quantiles", quantiles)
weights = torch.Tensor(list(quantile_weights.values()))
self.register_buffer("weights", weights)
self.use_norm = use_norm
# each loss is wrapped as a partial to provide static arguments, primarily
# as we want to not apply the reduction immediately
if isinstance(loss_func, str):
if loss_func == "mse":
loss_func = partial(torch.nn.functional.mse_loss, reduction="none")
elif loss_func == "mae":
loss_func = partial(torch.nn.functional.l1_loss, reduction="none")
elif loss_func == "rmse":
raise NotImplementedError("RMSE function has not yet been implemented.")
elif loss_func == "huber":
assert (
huber_delta
), "Huber loss specified but no margin provided to ``huber_delta``."
loss_func = partial(
torch.nn.functional.huber_loss, delta=huber_delta, reduction="none"
)
self.loss_func = loss_func

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if self.use_norm:
            if target.ndim == 1:
                # for 1D targets the norm reduces to the absolute value
                target_quantity = target.abs()
            else:
                # reduce vector targets (e.g. forces) to a per-row norm
                target_quantity = target.norm(dim=-1, keepdim=True)
        else:
            target_quantity = target
target_quantiles = torch.quantile(target_quantity, q=self.quantiles)
target_weights = torch.empty_like(target_quantity)
        # define the first quantile bracket
        target_weights[target_quantity < target_quantiles[0]] = self.weights[0]
        # now fill in the intermediate brackets, comparing against the computed
        # quantile values of the targets (not the quantile fractions themselves)
        for index in range(len(self.weights) - 1):
            curr_quantile = target_quantiles[index]
            next_quantile = target_quantiles[index + 1]
            curr_weight = self.weights[index]
            mask = (target_quantity >= curr_quantile) & (
                target_quantity < next_quantile
            )
            target_weights[mask] = curr_weight
        # the last bin
        target_weights[target_quantity >= target_quantiles[-1]] = self.weights[-1]
unweighted_loss = self.loss_func(input, target)
weighted_loss = unweighted_loss * target_weights
return weighted_loss.mean()
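
A minimal standalone sketch of the new loss (shapes, quantiles, and weights are
illustrative only): with use_norm=True the quantiles are computed over per-row force
norms, so rows below the median norm are down-weighted (0.5), rows between the median
and the 75th percentile keep weight 1.0, and the top quarter of norms are up-weighted (2.0).

import torch

from matsciml.models import losses

# illustrative values: down-weight small force norms, up-weight the largest ones
loss_fn = losses.BatchQuantileLoss(
    {0.25: 0.5, 0.5: 1.0, 0.75: 2.0},
    loss_func="mse",
    use_norm=True,
)
pred = torch.rand(50, 3)  # e.g. predicted force vectors
target = torch.rand(50, 3)  # e.g. reference force vectors
loss = loss_fn(pred, target)  # scalar mean of the quantile-weighted squared errors
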
53 changes: 53 additions & 0 deletions matsciml/models/tests/test_losses.py
@@ -2,7 +2,15 @@

import pytest
import torch
from lightning import pytorch as pl

from matsciml.lightning import MatSciMLDataModule
from matsciml.datasets.transforms import (
PeriodicPropertiesTransform,
PointCloudToGraphTransform,
)
from matsciml.models.base import ForceRegressionTask
from matsciml.models.pyg import EGNN
from matsciml.models import losses


@@ -29,3 +37,48 @@ def test_weighted_mse(atom_weighted_mse, shape):
target = torch.rand_like(pred)
ptr = torch.randint(1, 100, (shape[0],))
atom_weighted_mse(pred, target, ptr)


@pytest.mark.parametrize("shape", [(10,), (50, 3)])
@pytest.mark.parametrize(
"quantiles",
[
{0.25: 0.5, 0.5: 1.0, 0.75: 3.21},
{0.02: 0.2, 0.32: 0.67, 0.5: 1.0, 0.83: 2.0, 0.95: 10.0},
],
)
@pytest.mark.parametrize("use_norm", [True, False])
@pytest.mark.parametrize("loss_func", ["mse", "huber"])
def test_quantile_loss(shape, quantiles, use_norm, loss_func):
# ensure we test against back prop as well
x, y = torch.rand(2, *shape, requires_grad=True)
l_func = losses.BatchQuantileLoss(quantiles, loss_func, use_norm, huber_delta=0.01)
loss = l_func(x, y)
loss.mean().backward()


def test_quantile_loss_egnn():
task = ForceRegressionTask(
encoder_class=EGNN,
encoder_kwargs={"hidden_dim": 64, "output_dim": 64},
output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64},
loss_func={
"energy": torch.nn.MSELoss,
"force": losses.BatchQuantileLoss(
{0.1: 0.5, 0.25: 0.9, 0.5: 1.5, 0.85: 2.0, 0.95: 1.0},
loss_func="huber",
huber_delta=0.01,
),
},
)
dm = MatSciMLDataModule.from_devset(
"LiPSDataset",
dset_kwargs={
"transforms": [
PeriodicPropertiesTransform(6.0),
PointCloudToGraphTransform("pyg"),
]
},
)
trainer = pl.Trainer(fast_dev_run=10)
trainer.fit(task, datamodule=dm)
