diff --git a/docs/source/training.rst b/docs/source/training.rst
index 339d8be6..24ec8f06 100644
--- a/docs/source/training.rst
+++ b/docs/source/training.rst
@@ -150,3 +150,21 @@ potential difficulty in the task.
 
 .. autoclass:: matsciml.models.base.NodeDenoisingTask
    :members:
+
+
+Loss functions
+==============
+
+Each task has a default loss function that is known to work for that particular task.
+Some tasks and datasets may need more flexibility in how training signals are computed,
+and the task interface currently allows a dictionary mapping task keys to loss functions
+to be passed as a hyperparameter. In addition to typical PyTorch losses such as
+``L1Loss`` and ``MSELoss``, we also implement loss functions based on ideas borrowed
+from other repositories, such as MACE.
+
+.. autoclass:: matsciml.models.losses.AtomWeightedMSE
+   :members:
+
+
+.. autoclass:: matsciml.models.losses.BatchQuantileLoss
+   :members:
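For reference, the loss-mapping hyperparameter described in the documentation above is exercised end-to-end by the new ``test_quantile_loss_egnn`` test at the bottom of this diff; a condensed sketch of that usage, reusing the same classes and argument values as the test, looks like this::

    import torch

    from matsciml.models import losses
    from matsciml.models.base import ForceRegressionTask
    from matsciml.models.pyg import EGNN

    # per-task-key losses passed via the ``loss_func`` hyperparameter
    task = ForceRegressionTask(
        encoder_class=EGNN,
        encoder_kwargs={"hidden_dim": 64, "output_dim": 64},
        output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64},
        loss_func={
            "energy": torch.nn.MSELoss,
            "force": losses.BatchQuantileLoss(
                {0.1: 0.5, 0.25: 0.9, 0.5: 1.5, 0.85: 2.0, 0.95: 1.0},
                loss_func="huber",
                huber_delta=0.01,
            ),
        },
    )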
diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py
index 02205310..207da6d2 100644
--- a/matsciml/models/losses.py
+++ b/matsciml/models/losses.py
@@ -1,4 +1,6 @@
 from __future__ import annotations
+from functools import partial
+from typing import Callable, Literal
 
 import torch
 from torch import nn
@@ -63,3 +65,125 @@ def forward(
         # ensures that atoms_per_graph is type cast correctly
         squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0
         return squared_error.mean()
+
+
+class BatchQuantileLoss(nn.Module):
+    def __init__(
+        self,
+        quantile_weights: dict[float, float],
+        loss_func: Callable | Literal["mse", "mae", "rmse", "huber"],
+        use_norm: bool = True,
+        huber_delta: float | None = None,
+    ) -> None:
+        """
+        Implements a batch-based or dynamic quantile loss function.
+
+        This loss function uses user-defined quantiles and associated
+        weights for training: the high-level idea is to allow flexibility
+        in optimizing model performance against certain outliers, while
+        still ensuring that the model generalizes well.
+
+        The function uses either the target values, or the norm of the
+        target values (more meaningful for vector quantities like forces),
+        to compute quantile values for the requested bins. A weight tensor
+        is then generated (broadcastable against the targets) to weight
+        the predicted vs. actual margins, as computed with ``loss_func``.
+        The mean of the weighted loss is then returned.
+
+        In the case of ``loss_func='huber'``, the ``huber_delta`` argument
+        specifies the margin used to switch between MAE and MSE losses.
+        This value is applied globally (i.e. regardless of the quantile).
+
+        Parameters
+        ----------
+        quantile_weights : dict[float, float]
+            Dictionary mapping each quantile to the weight ascribed to
+            that bin. Values smaller than the first quantile and larger
+            than the last quantile take on those respective weights,
+            while in-between bins include the lower quantile and run up
+            to (but not including) the next quantile.
+        loss_func : Callable | Literal['mse', 'mae', 'rmse', 'huber']
+            Actual metric function. If a string literal is given, one of
+            the built-in PyTorch functional losses is used, based on MSE,
+            MAE, RMSE, or Huber loss. If a ``Callable`` is passed, its
+            output **must** have the same dimensions as the targets
+            (i.e. the behavior of ``keepdim`` or no reduction), as the
+            weights are applied afterwards.
+        use_norm : bool, default True
+            Whether to use the norm of the targets instead of an
+            element-wise application. This makes sense for vector
+            quantities whose components are coupled, e.g. force vectors.
+            If ``False``, this will still work with scalar and vector
+            quantities alike, but choosing between the two requires some
+            intuition about the target quantity.
+        huber_delta : float, optional
+            If ``loss_func`` is set to 'huber', this value is used as the
+            ``delta`` argument in ``torch.nn.functional.huber_loss``, which
+            corresponds to the margin between the MAE/L1 and MSE regimes.
+
+        Raises
+        ------
+        NotImplementedError:
+            RMSE is not currently implemented; requesting it will trigger
+            this exception.
+        """
+        super().__init__()
+        for key, value in quantile_weights.items():
+            assert isinstance(
+                key, float
+            ), "Expected quantile keys to be floats between [0,1]."
+            assert isinstance(
+                value, float
+            ), "Expected quantile dict values to be floats."
+            assert (
+                0.0 <= key <= 1.0
+            ), f"Quantile value {key} invalid; must be between [0,1]."
+        quantiles = torch.Tensor(list(quantile_weights.keys()))
+        self.register_buffer("quantiles", quantiles)
+        weights = torch.Tensor(list(quantile_weights.values()))
+        self.register_buffer("weights", weights)
+        self.use_norm = use_norm
+        # each loss is wrapped as a partial to provide static arguments, primarily
+        # because we do not want to apply the reduction immediately
+        if isinstance(loss_func, str):
+            if loss_func == "mse":
+                loss_func = partial(torch.nn.functional.mse_loss, reduction="none")
+            elif loss_func == "mae":
+                loss_func = partial(torch.nn.functional.l1_loss, reduction="none")
+            elif loss_func == "rmse":
+                raise NotImplementedError("RMSE function has not yet been implemented.")
+            elif loss_func == "huber":
+                assert (
+                    huber_delta
+                ), "Huber loss specified but no margin provided to ``huber_delta``."
+                loss_func = partial(
+                    torch.nn.functional.huber_loss, delta=huber_delta, reduction="none"
+                )
+        self.loss_func = loss_func
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        if self.use_norm:
+            if target.ndim == 1:
+                temp_target = target.unsqueeze(-1)
+            else:
+                temp_target = target
+            target_quantity = temp_target.norm(dim=-1, keepdim=True)
+        else:
+            target_quantity = target
+        target_quantiles = torch.quantile(target_quantity, q=self.quantiles)
+        target_weights = torch.empty_like(target_quantity)
+        # define the first quantile bracket
+        target_weights[target_quantity < target_quantiles[0]] = self.weights[0]
+        # now do the quantiles in between, comparing against the computed
+        # target quantiles rather than the raw quantile fractions
+        for index in range(len(self.weights) - 1):
+            curr_quantile = target_quantiles[index]
+            next_quantile = target_quantiles[index + 1]
+            curr_weight = self.weights[index]
+            mask = (target_quantity >= curr_quantile) & (
+                target_quantity < next_quantile
+            )
+            target_weights[mask] = curr_weight
+        # the last bin
+        target_weights[target_quantity >= target_quantiles[-1]] = self.weights[-1]
+        unweighted_loss = self.loss_func(input, target)
+        # drop the trailing singleton dimension left by the norm when the loss
+        # itself is 1D, to avoid an unintended (N, N) broadcast
+        if unweighted_loss.ndim != target_weights.ndim:
+            target_weights = target_weights.squeeze(-1)
+        weighted_loss = unweighted_loss * target_weights
+        return weighted_loss.mean()
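``BatchQuantileLoss`` can also be used as a standalone ``nn.Module``, as the parametrized ``test_quantile_loss`` below does; a minimal sketch, with illustrative quantile/weight values only::

    import torch

    from matsciml.models import losses

    # up-weight the largest-magnitude targets (top quartile of the per-sample
    # force norm) relative to the bulk of the distribution
    loss_fn = losses.BatchQuantileLoss(
        quantile_weights={0.25: 0.5, 0.5: 1.0, 0.75: 3.21},
        loss_func="huber",
        use_norm=True,
        huber_delta=0.01,
    )
    pred = torch.rand(50, 3, requires_grad=True)
    target = torch.rand(50, 3)
    loss = loss_fn(pred, target)  # scalar: mean of the quantile-weighted Huber loss
    loss.backward()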
diff --git a/matsciml/models/tests/test_losses.py b/matsciml/models/tests/test_losses.py
index b39373e9..db407e2a 100644
--- a/matsciml/models/tests/test_losses.py
+++ b/matsciml/models/tests/test_losses.py
@@ -2,7 +2,15 @@
 
 import pytest
 import torch
+from lightning import pytorch as pl
 
+from matsciml.lightning import MatSciMLDataModule
+from matsciml.datasets.transforms import (
+    PeriodicPropertiesTransform,
+    PointCloudToGraphTransform,
+)
+from matsciml.models.base import ForceRegressionTask
+from matsciml.models.pyg import EGNN
 from matsciml.models import losses
 
 
@@ -29,3 +37,48 @@ def test_weighted_mse(atom_weighted_mse, shape):
     target = torch.rand_like(pred)
     ptr = torch.randint(1, 100, (shape[0],))
     atom_weighted_mse(pred, target, ptr)
+
+
+@pytest.mark.parametrize("shape", [(10,), (50, 3)])
+@pytest.mark.parametrize(
+    "quantiles",
+    [
+        {0.25: 0.5, 0.5: 1.0, 0.75: 3.21},
+        {0.02: 0.2, 0.32: 0.67, 0.5: 1.0, 0.83: 2.0, 0.95: 10.0},
+    ],
+)
+@pytest.mark.parametrize("use_norm", [True, False])
+@pytest.mark.parametrize("loss_func", ["mse", "huber"])
+def test_quantile_loss(shape, quantiles, use_norm, loss_func):
+    # ensure we test against back prop as well
+    x, y = torch.rand(2, *shape, requires_grad=True)
+    l_func = losses.BatchQuantileLoss(quantiles, loss_func, use_norm, huber_delta=0.01)
+    loss = l_func(x, y)
+    loss.mean().backward()
+
+
+def test_quantile_loss_egnn():
+    task = ForceRegressionTask(
+        encoder_class=EGNN,
+        encoder_kwargs={"hidden_dim": 64, "output_dim": 64},
+        output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64},
+        loss_func={
+            "energy": torch.nn.MSELoss,
+            "force": losses.BatchQuantileLoss(
+                {0.1: 0.5, 0.25: 0.9, 0.5: 1.5, 0.85: 2.0, 0.95: 1.0},
+                loss_func="huber",
+                huber_delta=0.01,
+            ),
+        },
+    )
+    dm = MatSciMLDataModule.from_devset(
+        "LiPSDataset",
+        dset_kwargs={
+            "transforms": [
+                PeriodicPropertiesTransform(6.0),
+                PointCloudToGraphTransform("pyg"),
+            ]
+        },
+    )
+    trainer = pl.Trainer(fast_dev_run=10)
+    trainer.fit(task, datamodule=dm)
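``AtomWeightedMSE``, the other loss documented above, is exercised by the pre-existing ``test_weighted_mse`` shown as context in this diff. A usage sketch along the same lines, assuming the default (argument-free) constructor used by the test fixture, where the third argument is the number of atoms in each graph or sample::

    import torch

    from matsciml.models import losses

    loss_fn = losses.AtomWeightedMSE()
    pred = torch.rand(50, 3)
    target = torch.rand_like(pred)
    # number of atoms per graph, used to normalize the squared error
    atoms_per_graph = torch.randint(1, 100, (50,))
    loss = loss_fn(pred, target, atoms_per_graph)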