torchabc is a lightweight package that provides an Abstract Base Class (ABC) to structure PyTorch projects and keep code well organized.
The core of the package is the TorchABC class. This class defines the abstract training and inference workflows and must be subclassed to implement concrete logic.
This package has no dependencies beyond PyTorch, and it consists of a simple self-contained file. It is well suited to research, prototyping, and teaching.
The TorchABC class structures a project into the following main steps:
- Dataloaders - load raw data samples.
- Preprocess - transform raw samples.
- Collate - batch preprocessed samples.
- Network - compute model outputs.
- Loss - compute error against targets.
- Optimizer - update model parameters.
- Postprocess - transform outputs into predictions.
Each step corresponds to an abstract method in TorchABC. To use TorchABC, create a concrete subclass and implement these methods.
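The pattern behind this design can be sketched without PyTorch. The following is a minimal illustration, not TorchABC's actual source: `PipelineABC`, `Doubler`, and `predict` are hypothetical names, only four of the steps are modelled, and plain Python lists stand in for tensors. The base class fixes the order of the steps, and the subclass supplies each step.

```python
from abc import ABC, abstractmethod
from functools import cached_property


class PipelineABC(ABC):
    """Hypothetical base class: fixes the workflow, leaves the steps abstract."""

    @staticmethod
    @abstractmethod
    def preprocess(sample):
        """Transform one raw sample."""

    @staticmethod
    @abstractmethod
    def collate(samples):
        """Group preprocessed samples into a batch."""

    @property
    @abstractmethod
    def network(self):
        """Return a callable that maps a batch to outputs."""

    @staticmethod
    @abstractmethod
    def postprocess(outputs):
        """Turn outputs into predictions."""

    def predict(self, samples):
        # Concrete workflow: subclasses implement the steps, not this method.
        batch = self.collate([self.preprocess(s) for s in samples])
        return self.postprocess(self.network(batch))


class Doubler(PipelineABC):
    """Toy subclass: doubles each number and rounds the result."""

    @staticmethod
    def preprocess(sample):
        return float(sample)

    @staticmethod
    def collate(samples):
        return list(samples)

    @cached_property
    def network(self):
        # Built once on first access, then cached on the instance.
        return lambda batch: [2.0 * x for x in batch]

    @staticmethod
    def postprocess(outputs):
        return [round(y) for y in outputs]


model = Doubler()
predictions = model.predict(["1", "2.4", "3"])
print(predictions)
```

Instantiating `PipelineABC` directly raises `TypeError`, because its abstract steps are not implemented; this is the mechanism that forces a subclass to fill in every step.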
Install the package.
pip install torchabc
Generate a template using the command line interface.
torchabc --create template.py --min
Fill out the template by implementing the methods below. Each method is documented in the tables further down.
import torch
from torchabc import TorchABC
from functools import cached_property


class MyModel(TorchABC):

    @cached_property
    def dataloaders(self):
        raise NotImplementedError

    @staticmethod
    def preprocess(sample, hparams, flag=''):
        return sample

    @staticmethod
    def collate(samples):
        return torch.utils.data.default_collate(samples)

    @cached_property
    def network(self):
        raise NotImplementedError

    @staticmethod
    def loss(outputs, targets, hparams):
        raise NotImplementedError

    @cached_property
    def optimizer(self):
        raise NotImplementedError

    @staticmethod
    def postprocess(outputs, hparams):
        return outputs
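As a rough guide to what the four static methods might contain, here is a sketch for a toy classification task. It is written as standalone functions so that it runs without torchabc. The sample layout (a `(features, label)` tuple), the choice of cross-entropy, the extra `"accuracy"` entry, and the toy linear network are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import default_collate


def preprocess(sample, hparams, flag=''):
    # Assumed raw sample layout: (list of floats, integer class label).
    features, label = sample
    return torch.tensor(features, dtype=torch.float32), torch.tensor(label)


def collate(samples):
    # Stack per-sample tensors into (batch_size, ...) tensors.
    return default_collate(samples)


def loss(outputs, targets, hparams):
    # The returned dict must contain the key "loss".
    # The "accuracy" entry is an assumed extra, not something the package requires.
    accuracy = (outputs.argmax(dim=1) == targets).float().mean()
    return {"loss": F.cross_entropy(outputs, targets), "accuracy": accuracy}


def postprocess(outputs, hparams):
    # Convert logits into predicted class indices.
    return outputs.argmax(dim=1).tolist()


raw = [([0.0, 1.0], 1), ([1.0, 0.0], 0), ([0.5, 0.5], 1)]
inputs, targets = collate([preprocess(s, hparams={}) for s in raw])
network = torch.nn.Linear(2, 2)  # toy stand-in for the `network` property
outputs = network(inputs)
print(loss(outputs, targets, hparams={})["loss"].item())
print(postprocess(outputs, hparams={}))
```

The function bodies can be pasted into the corresponding static methods of the template once the sample layout matches the real dataset.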
Once a subclass of TorchABC is implemented, it can be used for training, evaluation, checkpointing, and inference.
model = MyModel()
Initialize the model.
model.train(epochs=5, on="train", val="val")
Train the model for 5 epochs using the "train" and "val" dataloaders.
metrics = model.eval(on="test")
Evaluate on the "test" dataloader and return metrics.
model.save("checkpoint.pth")
model.load("checkpoint.pth")
Save and restore the model state.
preds = model(samples)
Run predictions on raw input samples.
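For readers new to checkpointing, the save and load step above can be pictured with plain PyTorch state dictionaries. The checkpoint layout that TorchABC writes is not described here, so the dictionary keys `"network"` and `"optimizer"` below are assumptions, and the snippet is a generic sketch rather than the package's implementation.

```python
import os
import tempfile

import torch

network = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    # Save the network and optimizer states in one dictionary (assumed layout).
    torch.save(
        {"network": network.state_dict(), "optimizer": optimizer.state_dict()},
        path,
    )

    # Restore the saved states into freshly constructed objects.
    restored = torch.nn.Linear(4, 2)
    restored_optimizer = torch.optim.SGD(restored.parameters(), lr=0.1)
    state = torch.load(path)
    restored.load_state_dict(state["network"])
    restored_optimizer.load_state_dict(state["optimizer"])

# Check that every parameter survived the round trip unchanged.
same = all(
    torch.equal(a, b)
    for a, b in zip(network.parameters(), restored.parameters())
)
print(same)
```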
The TorchABC class defines a standard workflow for PyTorch projects. Some methods are abstract (must be implemented in subclasses), others are optional (can be overridden but have defaults), and a few are concrete (should not be overridden).
| Abstract method | Description |
|---|---|
| dataloaders | Must return dict[str, torch.utils.data.DataLoader]. Example keys: "train", "val", "test". |
| preprocess(sample, hparams, flag='') | Transform a raw dataset sample. Parameters: sample (Any): raw sample; hparams (dict): hyperparameters; flag (str, optional): mode flag. Returns: Tensor or iterable of tensors. |
| collate(samples) | Collate a batch of preprocessed samples. Parameters: samples (Iterable[Tensor]). Returns: Tensor or iterable of tensors. |
| network | Must return a torch.nn.Module. Inputs and outputs must use (batch_size, ...) format. |
| optimizer | Must return a torch.optim.Optimizer for self.network.parameters(). |
| loss(outputs, targets, hparams) | Compute loss for a batch. Parameters: outputs (Tensor or iterable); targets (Tensor or iterable); hparams (dict). Returns: dict[str, Any] containing key "loss". |
| postprocess(outputs, hparams) | Convert network outputs into predictions. Parameters: outputs (Tensor or iterable); hparams (dict). Returns: predictions (Any). |
| Optional method | Description |
|---|---|
| scheduler | Learning rate scheduler. May return None, torch.optim.lr_scheduler.LRScheduler, or ReduceLROnPlateau. Default is None. |
| backward(batch, gas) | Backpropagation step. Parameters: batch (dict[str, Any]): must contain key "loss"; gas (int): gradient accumulation steps. |
| metrics(batches, hparams) | Compute evaluation metrics. Parameters: batches (deque[dict[str, Any]]): batch results; hparams (dict). Returns: dict[str, Any]. Default computes average loss. |
| checkpoint(epoch, metrics, out) | Checkpoint step. Saves model if loss improves. Parameters: epoch (int): epoch number; metrics (dict[str, float]): validation metrics; out (str or None): output path to save checkpoints. Returns: bool indicating early stopping. |
| move(data) | Move data to current device. Supports Tensor, list, tuple, dict. |
| detach(data) | Detach data from computation graph. Supports Tensor, list, tuple, dict. |
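Overriding an optional method is mostly a matter of matching its signature. As an illustration for `metrics`, the function below averages every numeric entry across the batch results. It is a standalone sketch that uses plain floats, and the `"accuracy"` key is an assumed extra entry produced by a user-defined `loss`. In a real subclass the values may be tensors and would need converting first, for example with `.item()`.

```python
from collections import defaultdict, deque


def metrics(batches, hparams):
    """Average each numeric value over all batch results."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for batch in batches:
        for key, value in batch.items():
            # Skip non-numeric entries such as strings or nested containers.
            if isinstance(value, (int, float)):
                totals[key] += value
                counts[key] += 1
    return {key: totals[key] / counts[key] for key in totals}


# Three hypothetical batch results, as a training or evaluation loop might collect them.
batches = deque([
    {"loss": 0.9, "accuracy": 0.50},
    {"loss": 0.7, "accuracy": 0.75},
    {"loss": 0.5, "accuracy": 1.00},
])
print(metrics(batches, hparams={}))
```

This gives every batch the same weight. If the last batch of an epoch is smaller than the others, weighting each entry by its batch size would give a more exact average.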
| Concrete method | Description |
|---|---|
| TorchABC(device=None, logger=print, hparams=None, **kwargs) | Initialize the model. Parameters: device (str or torch.device, optional): computation device, defaults to CUDA if available, otherwise MPS or CPU; logger (Callable[[dict], None], optional): logging function, defaults to print; hparams (dict, optional): dictionary of hyperparameters; kwargs: additional attributes stored in the instance. |
| train(epochs, gas=1, mas=None, on='train', val='val', out=None) | Train the model. Parameters: epochs (int): number of training epochs; gas (int, optional): gradient accumulation steps, defaults to 1; mas (int, optional): metrics accumulation steps, defaults to gas; on (str, optional): training dataloader name, defaults to "train"; val (str, optional): validation dataloader name, defaults to "val" (if None, validation is skipped); out (str, optional): output path to save checkpoints. |
| eval(on) | Evaluate the model. Parameters: on (str): dataloader name. Returns: dict[str, float] of evaluation metrics. |
| \_\_call\_\_(samples) | Run inference on raw samples. Parameters: samples (Iterable[Any]): raw samples. Returns: postprocessed predictions. |
| save(path) | Save a checkpoint. Parameters: path (str): file path. |
| load(path) | Load a checkpoint. Parameters: path (str): file path. |
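The `gas` parameter refers to gradient accumulation: gradients from several small batches are summed before a single optimizer step, which emulates a larger batch. The loop below is a generic PyTorch sketch of that idea, not TorchABC's training loop. The toy data, the linear model, and the choice to divide the loss by `gas` are assumptions.

```python
import torch

torch.manual_seed(0)
network = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)
inputs, targets = torch.randn(8, 3), torch.randn(8, 1)

# Reference: gradient of the loss on the full batch of 8 samples.
torch.nn.functional.mse_loss(network(inputs), targets).backward()
full_batch_grads = [p.grad.clone() for p in network.parameters()]

gas = 4  # gradient accumulation steps
micro_batches = zip(inputs.chunk(gas), targets.chunk(gas))

optimizer.zero_grad()
for step, (x, y) in enumerate(micro_batches, start=1):
    loss = torch.nn.functional.mse_loss(network(x), y)
    # Scale so the accumulated gradient equals the mean over all micro-batches.
    (loss / gas).backward()
    if step % gas == 0:
        # Keep a copy for the comparison below, then update and reset.
        accumulated = [p.grad.clone() for p in network.parameters()]
        optimizer.step()
        optimizer.zero_grad()

# Four micro-batches of 2 samples reproduce the gradient of one batch of 8.
matches = all(
    torch.allclose(a, f, atol=1e-6)
    for a, f in zip(accumulated, full_batch_grads)
)
print(matches)
```

The equivalence holds here because all micro-batches have the same size; with uneven batches the scaled sum is only an approximation of the full-batch gradient.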
Get started with simple self-contained examples:
Install the dependencies:
poetry install --with examples
Run the examples by replacing <name> with one of the filenames in the examples folder:
poetry run python examples/<name>.py
Contributions are welcome! Submit pull requests with new examples or improvements to the core TorchABC class itself.