Skip to content

Commit

Permalink
docs: added docstrings to parity inference task
Browse files Browse the repository at this point in the history
Signed-off-by: Kin Long Kelvin Lee <[email protected]>
  • Loading branch information
laserkelvin committed Sep 25, 2024
1 parent e3c38e8 commit eabfef8
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions matsciml/models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,33 @@ def predict_step(

@registry.register_task("ParityInferenceTask")
class ParityInferenceTask(BaseInferenceTask):
def __init__(self, pretrained_model: BaseTaskModule, *args, **kwargs):
def __init__(self, pretrained_model: BaseTaskModule):
"""
Use a pretrained model to produce pair-plot data, i.e. predicted vs.
ground truth.
Example usage
-------------
The intended usage is to load a pretrained model, define a data module
that points to some data to perform predictions with, then call Lightning
Trainer's ``predict`` method.
>>> task = ParityInferenceTask.from_pretrained_checkpoint(...)
>>> dm = MatSciMLDataModule("DatasetName", pred_path=...)
>>> trainer = pl.Trainer()
>>> trainer.predict(task, datamodule=dm)
Parameters
----------
pretrained_model : BaseTaskModule
An instance of a subclass of ``BaseTaskModule``, e.g. a
``ForceRegressionTask`` object.
Raises
------
NotImplementedError
Currently, multitask modules are not yet supported.
"""
if isinstance(pretrained_model, MultiTaskLitModule):
raise NotImplementedError(
"ParityInferenceTask currently only supports single task modules."
Expand All @@ -218,11 +244,28 @@ def __init__(self, pretrained_model: BaseTaskModule, *args, **kwargs):
self.accumulators = {}

def forward(self, batch: BatchDict) -> dict[str, float | torch.Tensor]:
# initially try using the predict method
"""
Forward call for the inference task. This wraps the underlying
``matsciml`` task module's ``predict`` function to ensure that
normalization is 'reversed', i.e. predictions are reported in
the original unit space.
Parameters
----------
batch : BatchDict
Batch of samples to process.
Returns
-------
dict[str, float | torch.Tensor]
Prediction output, which should correspond to a key/tensor
mapping of output head/task name, and the associated outputs.
"""
preds = self.model.predict(batch)
return preds

def on_predict_start(self) -> None:
"""Verify that logging is enabled, as it is needed."""
if not self.trainer.log_dir:
raise RuntimeError(
"ParityInferenceTask requires logging to be enabled; no `log_dir` detected in Trainer."
Expand All @@ -244,6 +287,7 @@ def predict_step(
acc.predictions = predictions[key].detach()

def on_predict_epoch_end(self) -> None:
"""At the end of the dataset, write results to ``<log_dir>/inference_data.json``."""
log_dir = Path(self.trainer.log_dir)
output_file = log_dir.joinpath("inference_data.json")
with open(output_file, "w+") as write_file:
Expand Down

0 comments on commit eabfef8

Please sign in to comment.