From 5055aba5c6dc204b2888a520fe9a11f6aa892f44 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 10:25:46 -0700 Subject: [PATCH 01/10] feat: implemented dynamic quantile loss --- matsciml/models/losses.py | 68 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py index 02205310..7a803e9b 100644 --- a/matsciml/models/losses.py +++ b/matsciml/models/losses.py @@ -1,4 +1,6 @@ from __future__ import annotations +from functools import partial +from typing import Callable, Literal import torch from torch import nn @@ -63,3 +65,69 @@ def forward( # ensures that atoms_per_graph is type cast correctly squared_error = ((input - target) / atoms_per_graph.to(input.dtype)) ** 2.0 return squared_error.mean() + + +class BatchQuantileLoss(nn.Module): + def __init__( + self, + quantile_weights: dict[float, float], + loss_func: Callable | Literal["mse", "rmse", "huber"], + use_norm: bool = True, + huber_delta: float | None = None, + ) -> None: + super().__init__() + for key, value in quantile_weights.items(): + assert isinstance( + key, float + ), "Expected quantile keys to be floats between [0,1]." + assert isinstance( + value, float + ), "Expected quantile dict values to be floats." + assert ( + 0.0 <= key <= 1.0 + ), f"Quantile value {key} invalid; must be between [0,1]." + quantiles = torch.Tensor(list(quantile_weights.keys())) + self.register_buffer("quantiles", quantiles) + weights = torch.Tensor(list(quantile_weights.values())) + self.register_buffer("weights", weights) + self.use_norm = use_norm + # each loss is wrapped as a partial to provide static arguments, primarily + # as we want to not apply the reduction immediately + if isinstance(loss_func, str): + if loss_func == "mse": + loss_func = partial(torch.nn.functional.mse_loss, reduction="none") + elif loss_func == "rmse": + raise NotImplementedError("RMSE function has not yet been implemented.") + elif loss_func == "huber": + assert ( + huber_delta + ), "Huber loss specified but no margin provided to ``huber_delta``." + loss_func = partial( + torch.nn.functional.huber_loss, delta=huber_delta, reduction="none" + ) + self.loss_func = loss_func + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + assert target.ndim >= 2, "BatchQuantileLoss assumes vector quantites." 
+ if self.use_norm: + target_quantity = target.norm(dim=-1) + else: + target_quantity = target + target_quantiles = torch.quantile(target_quantity, q=self.quantiles) + target_weights = torch.empty_like(target_quantity) + # define the first quantile bracket + target_weights[target_quantity < target_quantiles[0]] = self.weights[0] + # now do quantiles in between + for index in range(1, len(self.weights) - 1): + curr_quantile = self.quantiles[index] + next_quantile = self.quantiles[index + 1] + curr_weight = self.weights[index] + mask = (target_quantity >= curr_quantile) & ( + target_quantity < next_quantile + ) + target_weights[mask] = curr_weight + # the last bin + target_weights[target_quantity > target_quantiles[-1]] = self.weights[-1] + unweighted_loss = self.loss_func(input, target) + weighted_loss = unweighted_loss * target_weights + return weighted_loss.mean() From 1a784b0f6f52f39bf1839f4fc6f5e78b5928f066 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 10:38:29 -0700 Subject: [PATCH 02/10] fix & test: removed dim assertion and added passing parametrized unit test --- matsciml/models/losses.py | 3 +-- matsciml/models/tests/test_losses.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py index 7a803e9b..020d789c 100644 --- a/matsciml/models/losses.py +++ b/matsciml/models/losses.py @@ -108,9 +108,8 @@ def __init__( self.loss_func = loss_func def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - assert target.ndim >= 2, "BatchQuantileLoss assumes vector quantites." if self.use_norm: - target_quantity = target.norm(dim=-1) + target_quantity = target.norm(dim=-1, keepdim=True) else: target_quantity = target target_quantiles = torch.quantile(target_quantity, q=self.quantiles) diff --git a/matsciml/models/tests/test_losses.py b/matsciml/models/tests/test_losses.py index b39373e9..1e94be94 100644 --- a/matsciml/models/tests/test_losses.py +++ b/matsciml/models/tests/test_losses.py @@ -29,3 +29,21 @@ def test_weighted_mse(atom_weighted_mse, shape): target = torch.rand_like(pred) ptr = torch.randint(1, 100, (shape[0],)) atom_weighted_mse(pred, target, ptr) + + +@pytest.mark.parametrize("shape", [(10,), (50, 3)]) +@pytest.mark.parametrize( + "quantiles", + [ + {0.25: 0.5, 0.5: 1.0, 0.75: 3.21}, + {0.02: 0.2, 0.32: 0.67, 0.5: 1.0, 0.83: 2.0, 0.95: 10.0}, + ], +) +@pytest.mark.parametrize("use_norm", [True, False]) +@pytest.mark.parametrize("loss_func", ["mse", "huber"]) +def test_quantile_loss(shape, quantiles, use_norm, loss_func): + # ensure we test against back prop as well + x, y = torch.rand(2, *shape, requires_grad=True) + l_func = losses.BatchQuantileLoss(quantiles, loss_func, use_norm, huber_delta=0.01) + loss = l_func(x, y) + loss.mean().backward() From 1610c76f939504f91418f3d9f173d28e991bf3d1 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 10:49:17 -0700 Subject: [PATCH 03/10] docs: added docstring for batch quantile function --- matsciml/models/losses.py | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py index 020d789c..897502a5 100644 --- a/matsciml/models/losses.py +++ b/matsciml/models/losses.py @@ -75,6 +75,53 @@ def __init__( use_norm: bool = True, huber_delta: float | None = None, ) -> None: + """ + Implements a batch-based or dynamic quantile loss function. 
+ + This loss function uses user-defined quantiles and associated + weights for training: the high-level idea is to allow flexibility + in optimizing model performance against certain outliers, and + ensuring that the model generalizes well. + + The function will either use the target values, or the norm of + the target values (more meaningful for vector quantities like + forces) to compute quantile values based on bins requested. A weight + tensor is then generated (with the same shape as the targets) to + weight predicted vs. actual margins, as computed with ``loss_func``. + The mean of the weighted loss is then returned. + + Parameters + ---------- + quantile_weights : dict[float, float] + Dictionary mapping of quantile and the weighting to ascribe + to that bin. Values smaller than the first bin, and larger + than the last bin take on these respective values, while + quantile in between bin ranges include the lower quantile + and up to (not including) the next bin. + loss_func : Callable | Literal['mse', 'rmse', 'huber'] + Actual metric function. If a string literal is given, then + one of the built-in PyTorch functional losses are used + based on either MSE, RMSE, or Huber loss. If a ``Callable`` + is passed, the output **must** be of the same dimension + as the targets, i.e. the behavior of ``keepdim`` or no + reduction, as the weights are applied afterwards. + use_norm : bool, default True + Whether to use the norm of targets, instead of an elementwise + wise application. This makes sense for vector quantities that + are coupled, e.g. force vectors. If ``False``, this will still + work with scalar and vector quantities-alike, but requires an + intuition for one over the other. + huber_delta : float, optional + If ``loss_func`` is set to 'huber', this value is used as the + ``delta`` argument in ``torch.nn.functional.huber_loss``, which + corresponds to the margin between MAE/L1 and MSE functions. + + Raises + ------ + NotImplementedError: + Currently RMSE is not implemented, and will trigger this + exception. + """ super().__init__() for key, value in quantile_weights.items(): assert isinstance( From 51d0be908cdbb818b925ad6241bed3d2a503fee8 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 10:50:23 -0700 Subject: [PATCH 04/10] refactor: allowing mae to be specified in batch quantile loss --- matsciml/models/losses.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py index 897502a5..b1e2d873 100644 --- a/matsciml/models/losses.py +++ b/matsciml/models/losses.py @@ -71,7 +71,7 @@ class BatchQuantileLoss(nn.Module): def __init__( self, quantile_weights: dict[float, float], - loss_func: Callable | Literal["mse", "rmse", "huber"], + loss_func: Callable | Literal["mse", "mae", "rmse", "huber"], use_norm: bool = True, huber_delta: float | None = None, ) -> None: @@ -98,7 +98,7 @@ def __init__( than the last bin take on these respective values, while quantile in between bin ranges include the lower quantile and up to (not including) the next bin. - loss_func : Callable | Literal['mse', 'rmse', 'huber'] + loss_func : Callable | Literal['mse', 'mae', 'rmse', 'huber'] Actual metric function. If a string literal is given, then one of the built-in PyTorch functional losses are used based on either MSE, RMSE, or Huber loss. 
If a ``Callable`` @@ -143,6 +143,8 @@ def __init__( if isinstance(loss_func, str): if loss_func == "mse": loss_func = partial(torch.nn.functional.mse_loss, reduction="none") + elif loss_func == "mae": + loss_func = partial(torch.nn.functional.l1_loss, reduction="none") elif loss_func == "rmse": raise NotImplementedError("RMSE function has not yet been implemented.") elif loss_func == "huber": From a44e5d1b90c044af6c9a5f04469c1719bb84de9b Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 11:10:27 -0700 Subject: [PATCH 05/10] fix: accidentally skipped first bin --- matsciml/models/losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py index b1e2d873..de0b1134 100644 --- a/matsciml/models/losses.py +++ b/matsciml/models/losses.py @@ -166,7 +166,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # define the first quantile bracket target_weights[target_quantity < target_quantiles[0]] = self.weights[0] # now do quantiles in between - for index in range(1, len(self.weights) - 1): + for index in range(len(self.weights) - 1): curr_quantile = self.quantiles[index] next_quantile = self.quantiles[index + 1] curr_weight = self.weights[index] From a8d590e0ddee5e98740be9cb912beeeafe352536 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 11:10:47 -0700 Subject: [PATCH 06/10] fix: will pad scalars to make norm computation sensible --- matsciml/models/losses.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py index de0b1134..446992b7 100644 --- a/matsciml/models/losses.py +++ b/matsciml/models/losses.py @@ -158,7 +158,11 @@ def __init__( def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.use_norm: - target_quantity = target.norm(dim=-1, keepdim=True) + if target.ndim == 1: + temp_target = target.unsqueeze(-1) + else: + temp_target = target + target_quantity = temp_target.norm(dim=-1, keepdim=True) else: target_quantity = target target_quantiles = torch.quantile(target_quantity, q=self.quantiles) From a1c0b20328a68ed363b92489bd261906194d6018 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 11:11:08 -0700 Subject: [PATCH 07/10] fix: including last quantile range value for weights --- matsciml/models/losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py index 446992b7..0d2bbd92 100644 --- a/matsciml/models/losses.py +++ b/matsciml/models/losses.py @@ -179,7 +179,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ) target_weights[mask] = curr_weight # the last bin - target_weights[target_quantity > target_quantiles[-1]] = self.weights[-1] + target_weights[target_quantity >= target_quantiles[-1]] = self.weights[-1] unweighted_loss = self.loss_func(input, target) weighted_loss = unweighted_loss * target_weights return weighted_loss.mean() From f78d2092013b9caf1efe23b622237157d887d8c8 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 12:12:08 -0700 Subject: [PATCH 08/10] test: added end-to-end unit test with EGNN for quantile loss Signed-off-by: Lee, Kin Long Kelvin --- matsciml/models/tests/test_losses.py | 35 ++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/matsciml/models/tests/test_losses.py b/matsciml/models/tests/test_losses.py index 1e94be94..db407e2a 
100644 --- a/matsciml/models/tests/test_losses.py +++ b/matsciml/models/tests/test_losses.py @@ -2,7 +2,15 @@ import pytest import torch +from lightning import pytorch as pl +from matsciml.lightning import MatSciMLDataModule +from matsciml.datasets.transforms import ( + PeriodicPropertiesTransform, + PointCloudToGraphTransform, +) +from matsciml.models.base import ForceRegressionTask +from matsciml.models.pyg import EGNN from matsciml.models import losses @@ -47,3 +55,30 @@ def test_quantile_loss(shape, quantiles, use_norm, loss_func): l_func = losses.BatchQuantileLoss(quantiles, loss_func, use_norm, huber_delta=0.01) loss = l_func(x, y) loss.mean().backward() + + +def test_quantile_loss_egnn(): + task = ForceRegressionTask( + encoder_class=EGNN, + encoder_kwargs={"hidden_dim": 64, "output_dim": 64}, + output_kwargs={"lazy": False, "input_dim": 64, "hidden_dim": 64}, + loss_func={ + "energy": torch.nn.MSELoss, + "force": losses.BatchQuantileLoss( + {0.1: 0.5, 0.25: 0.9, 0.5: 1.5, 0.85: 2.0, 0.95: 1.0}, + loss_func="huber", + huber_delta=0.01, + ), + }, + ) + dm = MatSciMLDataModule.from_devset( + "LiPSDataset", + dset_kwargs={ + "transforms": [ + PeriodicPropertiesTransform(6.0), + PointCloudToGraphTransform("pyg"), + ] + }, + ) + trainer = pl.Trainer(fast_dev_run=10) + trainer.fit(task, datamodule=dm) From b9d68e5f02beea9c4d2044edc32b8e324e3dc368 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 12:18:36 -0700 Subject: [PATCH 09/10] docs: updated quantile huber_delta docstring --- matsciml/models/losses.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/matsciml/models/losses.py b/matsciml/models/losses.py index 0d2bbd92..207da6d2 100644 --- a/matsciml/models/losses.py +++ b/matsciml/models/losses.py @@ -90,6 +90,10 @@ def __init__( weight predicted vs. actual margins, as computed with ``loss_func``. The mean of the weighted loss is then returned. + In the case of ``loss_func='huber'``, the ``huber_delta`` argument specifies + the margin used to switch between MAE and MSE losses. This value + is applied globally (i.e. regardless of the quantile). + Parameters ---------- quantile_weights : dict[float, float] From 91953afd163d9f241ef1cfa99247c9c3d8c0ab6e Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 7 Oct 2024 12:34:47 -0700 Subject: [PATCH 10/10] docs: added loss function docs in generated documentation --- docs/source/training.rst | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/source/training.rst b/docs/source/training.rst index 339d8be6..24ec8f06 100644 --- a/docs/source/training.rst +++ b/docs/source/training.rst @@ -150,3 +150,21 @@ potential difficulty in the task. .. autoclass:: matsciml.models.base.NodeDenoisingTask :members: + + +Loss functions +============== + +Each task will have a default loss function that is ensured to work for that particular task. +Some tasks and datasets may need more flexibility in how training signals are computed, and +the currently implemented interface for tasks allows a dictionary mapping of task key and +loss function to be passed as a hyperparameter. In addition to the typical PyTorch losses +like ``L1Loss`` and ``MSELoss``, we also implement some loss functions based on ideas borrowed +from other repositories, such as MACE. + +.. autoclass:: matsciml.models.losses.AtomWeightedMSE + :members: + + +.. autoclass:: matsciml.models.losses.BatchQuantileLoss + :members:
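
For reviewers, a minimal usage sketch of the new loss outside of the unit tests follows. It is not part of the patches above and only exercises the API introduced in this series (``BatchQuantileLoss`` in ``matsciml/models/losses.py``); the quantile-to-weight mapping and the Huber margin are illustrative values lifted from the tests in this series, not tuned recommendations.

import torch

from matsciml.models.losses import BatchQuantileLoss

# Map quantiles of the target quantity to loss weights: with use_norm=True the
# binning is done on the norm of each target vector, which is the intended use
# for coupled vector targets such as forces.
loss_fn = BatchQuantileLoss(
    quantile_weights={0.1: 0.5, 0.25: 0.9, 0.5: 1.5, 0.85: 2.0, 0.95: 1.0},
    loss_func="huber",
    use_norm=True,
    huber_delta=0.01,  # margin between the L1 and MSE regimes of the Huber loss
)

# Stand-ins for predicted and reference forces on a 32-atom batch.
pred_forces = torch.randn(32, 3, requires_grad=True)
true_forces = torch.randn(32, 3)

loss = loss_fn(pred_forces, true_forces)  # scalar: mean of the quantile-weighted Huber loss
loss.backward()

As in the end-to-end EGNN test added in patch 08, the same instance can also be supplied through a task's ``loss_func`` mapping, e.g. loss_func={"energy": torch.nn.MSELoss, "force": BatchQuantileLoss(...)}, so that only the force target uses the quantile weighting while the energy target keeps a conventional loss.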