Skip to content

eguidotti/torchabc

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

TorchABC

torchabc is a lightweight package that provides an Abstract Base Class (ABC) to structure PyTorch projects and keep code well organized.

The core of the package is the TorchABC class. This class defines the abstract training and inference workflows and must be subclassed to implement a concrete logic.

This package has no extra dependencies beyond PyTorch and it consists of a simple self-contained file. It is ideal for research, prototyping, and teaching.

Structure

The TorchABC class structures a project into the following main steps:

diagram

  1. Dataloaders - load raw data samples.
  2. Preprocess – transform raw samples.
  3. Collate - batch preprocessed samples.
  4. Network - compute model outputs.
  5. Loss - compute error against targets.
  6. Optimizer - update model parameters.
  7. Postprocess - transform outputs into predictions.

Each step corresponds to an abstract method in TorchABC. To use TorchABC, create a concrete subclass and implement these methods.

Quick start

Install the package.

pip install torchabc

Generate a template using the command line interface.

torchabc --create template.py --min

Fill out the template by implementing the methods below. The documentation of each method is available here.

import torch
from torchabc import TorchABC
from functools import cached_property


class MyModel(TorchABC):
    
    @cached_property
    def dataloaders(self):
        raise NotImplementedError
    
    @staticmethod
    def preprocess(sample, hparams, flag=''):
        return sample

    @staticmethod
    def collate(samples):
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        raise NotImplementedError
    
    @staticmethod
    def loss(outputs, targets, hparams):
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        raise NotImplementedError
    
    @staticmethod
    def postprocess(outputs, hparams):
        return outputs

Usage

Once a subclass of TorchABC is implemented, it can be used for training, evaluation, checkpointing, and inference.

Initialization

model = MyModel()

Initialize the model.

Training

model.train(epochs=5, on="train", val="val")

Train the model for 5 epochs using the train and val dataloaders.

Evaluation

metrics = model.eval(on="test")

Evaluate on the test dataloader and return metrics.

Checkpoints

model.save("checkpoint.pth")
model.load("checkpoint.pth")

Save and restore the model state.

Inference

preds = model(samples)

Run predictions on raw input samples.

API Reference

The TorchABC class defines a standard workflow for PyTorch projects. Some methods are abstract (must be implemented in subclasses), others are optional (can be overridden but have defaults), and a few are concrete (should not be overridden).


Abstract Methods

Method Description
dataloaders Must return dict[str, torch.utils.data.DataLoader]. Example keys: "train", "val", "test".
preprocess(sample, hparams, flag='') Transform a raw dataset sample.
Parameters:
- sample (Any): raw sample.
- hparams (dict): hyperparameters.
- flag (str, optional): mode flag.
Returns: Tensor or iterable of tensors.
collate(samples) Collate a batch of preprocessed samples.
Parameters:
- samples (Iterable[Tensor])
Returns: Tensor or iterable of tensors.
network Must return a torch.nn.Module. Inputs and outputs must use (batch_size, ...) format.
optimizer Must return a torch.optim.Optimizer for self.network.parameters().
loss(outputs, targets, hparams) Compute loss for a batch.
Parameters:
- outputs (Tensor or iterable)
- targets (Tensor or iterable)
- hparams (dict)
Returns: dict[str, Any] containing key "loss".
postprocess(outputs, hparams) Convert network outputs into predictions.
Parameters:
- outputs (Tensor or iterable)
- hparams (dict)
Returns: predictions (Any).

Default Methods

Method Description
scheduler Learning rate scheduler. May return None, torch.optim.lr_scheduler.LRScheduler, or ReduceLROnPlateau. Default is None.
backward(batch, gas) Backpropagation step.
Parameters:
- batch (dict[str, Any]): must contain key "loss".
- gas (int): gradient accumulation steps.
metrics(batches, hparams) Compute evaluation metrics.
Parameters:
- batches (deque[dict[str, Any]]): batch results.
- hparams (dict)
Returns: dict[str, Any]. Default computes average loss.
checkpoint(epoch, metrics, out) Checkpoint step. Saves model if loss improves.
Parameters:
- epoch (int): epoch number.
- metrics (dict[str, float]): validation metrics.
- out (str or None): output path to save checkpoints.
Returns: bool indicating early stopping.
move(data) Move data to current device. Supports Tensor, list, tuple, dict.
detach(data) Detach data from computation graph. Supports Tensor, list, tuple, dict.

Concrete Methods

Method Description
TorchABC(device=None, logger=print, hparams=None, **kwargs) Initialize the model.
Parameters:
- device (str or torch.device, optional): computation device. Defaults to CUDA if available, otherwise MPS or CPU.
- logger (Callable[[dict], None], optional): logging function. Defaults to print.
- hparams (dict, optional): dictionary of hyperparameters.
- kwargs: additional attributes stored in the instance.
train(epochs, gas=1, mas=None, on='train', val='val', out=None) Train the model.
Parameters:
- epochs (int): number of training epochs.
- gas (int, optional): gradient accumulation steps. Defaults to 1.
- mas (int, optional): metrics accumulation steps. Defaults to gas.
- on (str, optional): training dataloader name. Default "train".
- val (str, optional): validation dataloader name. Default "val". If None, validation is skipped.
- out (str, optional): output path to save checkpoints.
eval(on) Evaluate the model.
Parameters:
- on (str): dataloader name.
Returns: dict[str, float] of evaluation metrics.
__call__(samples) Run inference on raw samples.
Parameters:
- samples (Iterable[Any]): raw samples.
Returns: postprocessed predictions.
save(path) Save a checkpoint.
Parameters:
- path (str): file path.
load(path) Load a checkpoint.
Parameters:
- path (str): file path.

Examples

Get started with simple self-contained examples:

Run the examples

Install the dependencies

poetry install --with examples

Run the examples by replacing <name> with one of the filenames in the examples folder

poetry run python examples/<name>.py

Contribute

Contributions are welcome! Submit pull requests with new examples or improvements to the core TorchABC class itself.

About

A simple abstract class for training and inference in PyTorch

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages