Merge pull request #44 from krasserm/wip-hugging-face-api

Hugging Face interface for inference

krasserm authored Apr 23, 2023
2 parents 737a766 + 3572972 commit 452a6ea
Showing 79 changed files with 17,267 additions and 10,201 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -16,7 +16,6 @@ repos:
- id: check-toml
- id: check-docstring-first
- id: check-case-conflict
- id: check-added-large-files
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
179 changes: 164 additions & 15 deletions README.md
@@ -1,6 +1,7 @@
# Perceiver, Perceiver IO and Perceiver AR

This repository is a PyTorch and PyTorch Lightning implementation of
This repository is a PyTorch implementation of Perceiver, Perceiver IO and Perceiver AR, with PyTorch Lightning
interfaces for model training and Hugging Face 🤗 interfaces for inference.

<table>
<tr>
@@ -29,13 +30,32 @@ This repository is a PyTorch and PyTorch Lightning implementation of
</tr>
</table>

All model classes are written in plain PyTorch and can be wrapped into [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/)
modules for training at scale. The command line interface is implemented with the [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/stable/cli/lightning_cli.html).
[Pretrained weights](docs/pretrained-models.md) can be imported for [official models](docs/pretrained-models.md#official-models)
from the 🤗 Hub, [training checkpoints](docs/pretrained-models.md#training-checkpoints) from [training examples](docs/training-examples.md)
are available for download too. Datasets used in the training examples are 🤗 [datasets](https://huggingface.co/docs/datasets)
wrapped into PyTorch Lightning [data modules](perceiver/data). For NLP tasks, this library supports all 🤗
[fast tokenizers](https://huggingface.co/docs/transformers/fast_tokenizers) and the 🤗 Perceiver UTF-8 bytes tokenizer.
## Overview

At the core of the `perceiver-io` library are *backend models*, lightweight PyTorch implementations of Perceiver,
Perceiver IO and Perceiver AR. They can be wrapped into [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/)
modules for training (*Lightning interface*) and 🤗 modules for inference (*Hugging Face interface*). See
[library design](docs/library-design.md) for details.

<p align="center">
<img src="docs/images/library-design-small.jpg" alt="library-design"/>
</p>

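As a minimal sketch of this layering (reusing names from the examples below; the `example/mnist-classifier` directory is assumed to have been created by the checkpoint conversion step in [Getting started](#getting-started)):

```python
from transformers import AutoModelForImageClassification

# Hugging Face interface: a 🤗 model that wraps a perceiver-io backend model
hf_model = AutoModelForImageClassification.from_pretrained("example/mnist-classifier")

# The wrapped backend model is a plain PyTorch module and can be used directly
backend_model = hf_model.backend_model
```
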
The command line interface for training is implemented with [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/stable/cli/lightning_cli.html).
Training datasets are 🤗 [datasets](https://huggingface.co/docs/datasets) wrapped into PyTorch Lightning data modules.
For NLP tasks, `perceiver-io` supports all 🤗 [fast tokenizers](https://huggingface.co/docs/transformers/fast_tokenizers)
and the 🤗 Perceiver UTF-8 bytes tokenizer.
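
For example, the UTF-8 bytes tokenizer can be loaded with the 🤗 `transformers` auto-class (a small sketch; the `deepmind/language-perceiver` checkpoint is used here only as a source of that tokenizer):

```python
from transformers import AutoTokenizer

# UTF-8 bytes tokenizer used by the official Perceiver IO language models
tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")

ids = tokenizer("Hello world!").input_ids
print(tokenizer.decode(ids, skip_special_tokens=True))  # Hello world!
```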

## Documentation

- [Installation](#installation)
- [Getting started](#getting-started)
- [Library design](docs/library-design.md)
- [Pretrained models](docs/pretrained-models.md)
- [Training examples](docs/training-examples.md)
- [Inference examples](examples/inference.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/perceiver-io/blob/main/examples/inference.ipynb)
- [Model construction](docs/model-construction.md)
- [Building blocks](docs/building-blocks.md)

## Installation

@@ -78,14 +98,143 @@ docker pull ghcr.io/krasserm/perceiver-io:latest

See [Docker image](docs/docker-image.md) for details.

## Documentation
## Getting started

- [Getting started](docs/getting-started.md)
- [Model construction](docs/model-construction.md)
- [Pretrained models](docs/pretrained-models.md)
- [Training examples](docs/training-examples.md)
- [Inference examples](examples/inference.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/krasserm/perceiver-io/blob/0.8.2/examples/inference.ipynb)
- [Building blocks](docs/building-blocks.md)
### Inference

Compute the optical flow between consecutive frames of an input video and write the rendered results to an output
video:

```python
from urllib.request import urlretrieve
from transformers import pipeline

from perceiver.data.vision import video_utils
from perceiver.model.vision import optical_flow # register auto-classes and pipeline

urlretrieve(
    url="https://martin-krasser.com/perceiver/flow/sintel_clip_cave_dragon_fight.mp4",
    filename="sintel_clip_cave_dragon_fight.mp4",
)

# Create optical flow pipeline
optical_flow_pipeline = pipeline("optical-flow", model="krasserm/perceiver-io-optical-flow", device="cuda:0")

# load consecutive video frame pairs
frame_pairs = video_utils.read_video_frame_pairs("sintel_clip_cave_dragon_fight.mp4")

# create and render optical flow for all frame pairs
optical_flows = optical_flow_pipeline(frame_pairs, render=True, device="cuda:0")

# create video with rendered optical flows
video_utils.write_video("sintel_clip_cave_dragon_fight_output.mp4", optical_flows, fps=24)
```

Here is a side-by-side comparison of the input and output video:

<p align="center">
<img src="docs/images/optical-flow.gif" alt="optical-flow-sbs">
</p>

See [inference examples](https://colab.research.google.com/github/krasserm/perceiver-io/blob/main/examples/inference.ipynb)
for more examples.

### Training

Train a small Perceiver IO image classifier (907K parameters) on MNIST from the command line. The classifier
cross-attends to individual pixels of input images with [repeated cross-attention](docs/building-blocks.md).
See [image classification](docs/training-examples.md#image-classification) training example for more details.

```shell
python -m perceiver.scripts.vision.image_classifier fit \
  --model.num_latents=32 \
  --model.num_latent_channels=128 \
  --model.encoder.num_frequency_bands=32 \
  --model.encoder.num_cross_attention_layers=2 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.num_self_attention_layers_per_block=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.1 \
  --model.encoder.init_scale=0.1 \
  --model.decoder.num_output_query_channels=128 \
  --model.decoder.dropout=0.1 \
  --model.decoder.init_scale=0.1 \
  --data=MNISTDataModule \
  --data.batch_size=64 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --lr_scheduler.warmup_steps=500 \
  --trainer.accelerator=gpu \
  --trainer.devices=1 \
  --trainer.max_epochs=30 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=img_clf
```
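
Training progress can be monitored with TensorBoard (assuming `tensorboard` is installed in the environment; the log directory matches the `--trainer.logger.save_dir` setting above):

```shell
tensorboard --logdir logs
```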

[Model construction](docs/model-construction.md) describes how to implement model-specific command line interfaces
with the Lightning CLI. Training checkpoints are written to the `logs/img_clf/version_0/checkpoints` directory. Assuming
a checkpoint with filename `epoch=025-val_loss=0.065.ckpt` exists, it can be converted to a `perceiver-io` 🤗 model with

```python
from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint

convert_mnist_classifier_checkpoint(
    save_dir="example/mnist-classifier",
    ckpt_url="logs/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
)
```

so that it can be used in a 🤗 image classification pipeline

```python
from datasets import load_dataset
from transformers import pipeline

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

classifier = pipeline("image-classification", model="example/mnist-classifier")
predictions = [int(pred[0]["label"]) for pred in classifier(images)]

print(f"Labels: {labels}")
print(f"Predictions: {predictions}")
```
```
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```

or loaded directly:

```python
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor

model = AutoModelForImageClassification.from_pretrained("example/mnist-classifier")
processor = AutoImageProcessor.from_pretrained("example/mnist-classifier")

inputs = processor(images, return_tensors="pt")

with torch.no_grad():
    # use perceiver-io Hugging Face model
    output_1 = model(**inputs).logits

with torch.no_grad():
    # or use perceiver-io backend model directly
    output_2 = model.backend_model(inputs.pixel_values)

print(f"Predictions: {output_1.argmax(dim=-1).numpy().tolist()}")
print(f"Predictions: {output_2.argmax(dim=-1).numpy().tolist()}")
```
```
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
```

See [training examples](docs/training-examples.md) for more examples.

## Articles

3 changes: 2 additions & 1 deletion docs/building-blocks.md
@@ -1,7 +1,8 @@
# Building blocks

The following subsections map Perceiver, Perceiver IO and Perceiver AR concepts to the [core modules](../perceiver/model/core/modules.py)
of this library. Core modules are the building blocks of concrete PyTorch models (see also [model construction](model-construction.md)).
of this library. Core modules are the building blocks of concrete backend models (see also [library design](library-design.md)
and [model construction](model-construction.md)).

## Perceiver IO

180 changes: 0 additions & 180 deletions docs/getting-started.md

This file was deleted.

Binary file added docs/images/library-design-small.jpg
Binary file added docs/images/library-design.jpg
Binary file added docs/images/optical-flow.gif
