Parity inference task #294

Merged
merged 10 commits on Sep 26, 2024
62 changes: 62 additions & 0 deletions docs/source/inference.rst
@@ -0,0 +1,62 @@
Inference
=========

"Inference" can be a bit of an overloaded term, and this page is broken down into different possible
downstream use cases for trained models.

Parity plots and model evaluations
----------------------------------

The simplest and most straightforward way to check the performance of a model is to look beyond reduced
metrics, i.e. anything that has been averaged over batches, epochs, etc. Parity plots help verify the linear
relationship between predictions and ground truths by simply iterating over the evaluation subset of data
without any averaging.

The ``ParityInferenceTask`` helps perform this task by using the PyTorch Lightning ``predict`` pipelines. With a pre-trained
``matsciml`` task checkpoint, you simply need to run the following:

.. code-block:: python

import pytorch_lightning as pl

from matsciml.models.inference import ParityInferenceTask
from matsciml.lightning import MatSciMLDataModule

# configure data module the way that you need to
dm = MatSciMLDataModule(
    dataset="NameofDataset",
    pred_split="/path/to/lmdb/split",
    batch_size=64,  # this is just to amortize model calls
)
task = ParityInferenceTask.from_pretrained_checkpoint("/path/to/checkpoint")

trainer = pl.Trainer() # optionally, configure logger/limit_predict_batches
trainer.predict(task, datamodule=dm)


The default ``Trainer`` settings will create a ``lightning_logs`` directory, followed by an experiment
number. Once your inference run completes, that directory will contain an ``inference_data.json`` file that
you can then load in. The data is organized by the name of each target (e.g. ``energy``, ``bandgap``), and
under each target name there are ``predictions`` and ``targets`` keys. Note that ``pred_split`` does not
necessarily have to be a completely separate hold-out set: you can pass your training LMDB path if you wish
to double check the performance of your model after training, or you can point it at unseen samples.
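
Since the JSON layout is just parallel lists per target, it can be loaded and plotted directly. The snippet
below is a minimal sketch: the ``version_0`` directory depends on which experiment number Lightning assigned
to your run, and the ``energy`` key assumes an energy regression target was predicted.

.. code-block:: python

    import json

    import matplotlib.pyplot as plt

    # hypothetical log directory; adjust the version number to your run
    with open("lightning_logs/version_0/inference_data.json") as handle:
        data = json.load(handle)

    # each target name maps to "predictions" and "targets" lists
    energy = data["energy"]
    plt.scatter(energy["targets"], energy["predictions"], s=4)
    plt.xlabel("Ground truth")
    plt.ylabel("Prediction")
    plt.savefig("energy_parity.png")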

.. note::

For developers, this is handled by the ``matsciml.models.inference.ParityData`` class. This is
mainly to standardize the output and provide a means to serialize the data as JSON.
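
A minimal sketch of that accumulation pattern (the target name and tensor values below are made up purely
for illustration):

.. code-block:: python

    import torch

    from matsciml.models.inference import ParityData

    acc = ParityData("energy")
    # the property setters append per-batch tensors to internal lists
    acc.targets = torch.tensor([[1.0], [2.0]])
    acc.predictions = torch.tensor([[1.1], [1.9]])
    # ``to_json`` stacks the accumulated tensors and returns plain lists
    # keyed by "predictions", "targets", and "name", ready for ``json.dump``
    payload = acc.to_json()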



.. autoclass:: matsciml.models.inference.ParityInferenceTask
:members:



Performing molecular dynamics simulations
-----------------------------------------

Currently, the main method of interfacing with dynamical simulations is through the ``ase`` package.
Documentation for this is ongoing, but examples can be found under ``examples/interfaces``.
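
As a rough sketch of what that looks like (the ``from_pretrained_force_regression`` constructor shown below
is an assumption; consult ``examples/interfaces`` for the exact way to build the calculator, including any
transforms your checkpoint expects):

.. code-block:: python

    from ase.build import bulk

    from matsciml.interfaces.ase import MatSciMLCalculator

    # hypothetical constructor: wrap a pretrained force regression checkpoint
    calc = MatSciMLCalculator.from_pretrained_force_regression("/path/to/checkpoint")

    atoms = bulk("Cu", "fcc", a=3.6)
    atoms.calc = calc
    print(atoms.get_potential_energy())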

.. autoclass:: matsciml.interfaces.ase.MatSciMLCalculator
:members:
22 changes: 19 additions & 3 deletions matsciml/lightning/data_utils.py
@@ -104,14 +104,15 @@ def __init__(
num_workers: int = 0,
val_split: str | Path | float | None = 0.0,
test_split: str | Path | float | None = 0.0,
pred_split: str | Path | None = None,
seed: int | None = None,
dset_kwargs: dict[str, Any] | None = None,
persistent_workers: bool | None = None,
):
super().__init__()
# make sure we have something to work with
assert any(
[i for i in [dataset, train_path, val_split, test_split]],
[i for i in [dataset, train_path, val_split, test_split, pred_split]],
), "No splits provided to datamodule."
# if floats are passed to splits, make sure dataset is provided for inference
if any([isinstance(i, float) for i in [val_split, test_split]]):
@@ -122,7 +123,7 @@ def __init__(
assert any(
[
isinstance(p, (str, Path))
for p in [train_path, val_split, test_split]
for p in [train_path, val_split, test_split, pred_split]
],
), "Dataset type passed, but no paths to construct with."
self.dataset = dataset
@@ -248,6 +249,17 @@ def setup(self, stage: str | None = None) -> None:
if isinstance(split_path, (str, Path)):
dset = self._make_dataset(split_path, self.dataset)
splits[key] = dset
# special case for 'inference' or prediction runs
if isinstance(self.hparams.pred_split, (str, Path)):
pred_split_path = self.hparams.pred_split
if isinstance(pred_split_path, str):
pred_split_path = Path(pred_split_path)
if not pred_split_path.exists():
raise FileNotFoundError(
f"Prediction split provided, but not found: {pred_split_path}"
)
dset = self._make_dataset(pred_split_path, self.dataset)
splits["pred"] = dset
# the last case assumes only the dataset is passed, we will treat it as train
if len(splits) == 0:
splits["train"] = self.dataset
@@ -268,8 +280,12 @@ def predict_dataloader(self):
"""
Predict behavior uses the dedicated ``pred`` split if one was provided; otherwise the whole dataset is used for inference.
"""
if "pred" in self.splits:
target = self.splits["pred"]
else:
target = self.dataset
return DataLoader(
self.dataset,
target,
batch_size=self.hparams.batch_size,
num_workers=self.hparams.num_workers,
collate_fn=self.dataset.collate_fn,
184 changes: 180 additions & 4 deletions matsciml/models/inference.py
@@ -1,30 +1,114 @@
from __future__ import annotations

import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Union
from typing import Any
from logging import getLogger

import pytorch_lightning as pl
import torch
from torch import nn

from matsciml.common.registry import registry
from matsciml.common.types import BatchDict, DataDict
from matsciml.models.base import BaseTaskModule, MultiTaskLitModule


class ParityData:
def __init__(self, name: str) -> None:
"""
Class to help accumulate inference results.

This class should be created per target, and uses property
setters to accumulate target and prediction tensors,
and at the final step, aggregate them all into a single
tensor and with the `to_json` method, produce serializable
data.

Parameters
----------
name : str
Name of the target property being tracked.
"""
super().__init__()
self.name = name
self.logger = getLogger(f"matsciml.inference.{name}-parity")

@property
def ndim(self) -> int:
if not hasattr(self, "_targets"):
raise RuntimeError("No data set to accumulator yet.")
sample = self._targets[0]
if isinstance(sample, torch.Tensor):
return sample.ndim
else:
return 0

@property
def targets(self) -> torch.Tensor:
return torch.vstack(self._targets)

@targets.setter
def targets(self, values: torch.Tensor) -> None:
if not hasattr(self, "_targets"):
self._targets = []
if isinstance(values, torch.Tensor):
# remove erroneous "empty" dimensions
values.squeeze_()
self._targets.append(values)

@property
def predictions(self) -> torch.Tensor:
return torch.vstack(self._predictions)

@predictions.setter
def predictions(self, values: torch.Tensor) -> None:
if not hasattr(self, "_predictions"):
self._predictions = []
if isinstance(values, torch.Tensor):
values.squeeze_()
self._predictions.append(values)

def to_json(self) -> dict[str, list]:
return_dict = {}
targets = self.targets.cpu()
predictions = self.predictions.cpu()
# do some preliminary checks to the data
if targets.ndim != predictions.ndim:
self.logger.warning(
"Target/prediction dimensionality mismatch\n"
f" Target: {targets.ndim}, predictions: {predictions.ndim}"
)
if targets.shape != predictions.shape:
self.logger.warning(
"Target/prediction shape mismatch\n"
f" Target: {targets.shape}, predictions: {predictions.shape}."
)
return_dict["predictions"] = predictions.tolist()
return_dict["targets"] = targets.tolist()
return_dict["name"] = self.name
return return_dict


class BaseInferenceTask(ABC, pl.LightningModule):
def __init__(self, pretrained_model: nn.Module, *args, **kwargs):
super().__init__()
self.model = pretrained_model

def training_step(self, *args, **kwargs) -> None:
"""Overrides Lightning method to prevent task being used for training."""
raise NotImplementedError(
f"{self.__class__.__name__} is not intended for training."
)

@abstractmethod
def predict_step(
self,
batch: BatchDict,
batch_idx: int,
dataloader_idx: int = 0,
) -> Any:
...
) -> Any: ...

@classmethod
def from_pretrained_checkpoint(
@@ -58,7 +142,7 @@ def from_pretrained_checkpoint(
task_ckpt_path = Path(task_ckpt_path)
assert (
task_ckpt_path.exists()
), f"Encoder checkpoint filepath specified but does not exist."
), "Encoder checkpoint filepath specified but does not exist."
ckpt = torch.load(task_ckpt_path)
select_kwargs = {}
for key in ["encoder_class", "encoder_kwargs"]:
@@ -117,3 +201,95 @@ def predict_step(
for key in ["targets", "symmetry"]:
return_dict[key] = batch.get(key)
return return_dict


@registry.register_task("ParityInferenceTask")
class ParityInferenceTask(BaseInferenceTask):
def __init__(self, pretrained_model: BaseTaskModule):
"""
Use a pretrained model to produce pair-plot data, i.e. predicted vs.
ground truth.

Example usage
-------------
The intended usage is to load a pretrained model, define a data module
that points to some data to perform predictions with, then call Lightning
Trainer's ``predict`` method.

>>> task = ParityInferenceTask.from_pretrained_checkpoint(...)
>>> dm = MatSciMLDataModule("DatasetName", pred_split=...)
>>> trainer = pl.Trainer()
>>> trainer.predict(task, datamodule=dm)

Parameters
----------
pretrained_model : BaseTaskModule
An instance of a subclass of ``BaseTaskModule``, e.g. a
``ForceRegressionTask`` object.

Raises
------
NotImplementedError
Currently, multitask modules are not yet supported.
"""
if isinstance(pretrained_model, MultiTaskLitModule):
raise NotImplementedError(
"ParityInferenceTask currently only supports single task modules."
)
assert hasattr(pretrained_model, "predict") and callable(
pretrained_model.predict
), "Model passed does not have a `predict` method; is it a `matsciml` task?"
super().__init__(pretrained_model)
self.common_keys = set()
self.accumulators = {}

def forward(self, batch: BatchDict) -> dict[str, float | torch.Tensor]:
"""
Forward call for the inference task. This wraps the underlying
``matsciml`` task module's ``predict`` function to ensure that
normalization is 'reversed', i.e. predictions are reported in
the original unit space.

Parameters
----------
batch : BatchDict
Batch of samples to process.

Returns
-------
dict[str, float | torch.Tensor]
Prediction output, which should correspond to a key/tensor
mapping of output head/task name, and the associated outputs.
"""
preds = self.model.predict(batch)
return preds

def on_predict_start(self) -> None:
"""Verify that logging is enabled, as it is needed."""
if not self.trainer.log_dir:
raise RuntimeError(
"ParityInferenceTask requires logging to be enabled; no `log_dir` detected in Trainer."
)

def predict_step(
self, batch: BatchDict, batch_idx: int, dataloader_idx: int = 0
) -> None:
predictions = self(batch)
pred_keys = set(predictions.keys())
batch_keys = set(batch["targets"].keys())
self.common_keys = pred_keys.intersection(batch_keys)
# loop over keys that are mutually available in predictions and data
for key in self.common_keys:
if key not in self.accumulators:
self.accumulators[key] = ParityData(key)
acc = self.accumulators[key]
acc.targets = batch["targets"][key].detach()
acc.predictions = predictions[key].detach()

def on_predict_epoch_end(self) -> None:
"""At the end of the dataset, write results to ``<log_dir>/inference_data.json``."""
log_dir = Path(self.trainer.log_dir)
output_file = log_dir.joinpath("inference_data.json")
with open(output_file, "w+") as write_file:
data = {key: acc.to_json() for key, acc in self.accumulators.items()}
json.dump(data, write_file, indent=2)