A PyTorch implementation of DeepMind's Perceiver IO with PyTorch Lightning scripts for distributed training

Lightning-Sandbox/perceiver-io

Perceiver, Perceiver IO and Perceiver AR

This repository is a PyTorch implementation of Perceiver, Perceiver IO and Perceiver AR, with PyTorch Lightning interfaces for model training and Hugging Face 🤗 interfaces for inference.

Perceiver: General Perception with Iterative Attention (paper, video)
Perceiver IO: A General Architecture for Structured Inputs & Outputs (paper, blog post)
General-purpose, long-context autoregressive modeling with Perceiver AR (paper, blog post)

Overview

At the core of the perceiver-io library are backend models: lightweight PyTorch implementations of Perceiver, Perceiver IO and Perceiver AR. They can be wrapped into PyTorch Lightning modules for training (Lightning interface) and 🤗 modules for inference (Hugging Face interface). See library design for details.

[Figure: library design]
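The layering described above can be pictured with a small, framework-free sketch. All class names here (`BackendModel`, `TrainingWrapper`, `InferenceWrapper`, `ModelOutput`) are hypothetical stand-ins and not the library's actual classes; the real wrappers build on PyTorch Lightning and Hugging Face base classes. The only point is that both wrappers delegate to the same backend model.

```python
from dataclasses import dataclass
from typing import Callable, List


class BackendModel:
    """Hypothetical stand-in for a backend model: maps inputs to raw logits."""

    def __init__(self, weights: List[float]):
        self.weights = weights

    def __call__(self, x: List[float]) -> List[float]:
        # Toy "forward pass": scale the input sum by one weight per class.
        total = sum(x)
        return [w * total for w in self.weights]


class TrainingWrapper:
    """Hypothetical training wrapper: adds a loss on top of the backend model."""

    def __init__(self, backend_model: BackendModel,
                 loss_fn: Callable[[List[float], int], float]):
        self.backend_model = backend_model
        self.loss_fn = loss_fn

    def training_step(self, x: List[float], label: int) -> float:
        return self.loss_fn(self.backend_model(x), label)


@dataclass
class ModelOutput:
    logits: List[float]


class InferenceWrapper:
    """Hypothetical inference wrapper: returns a structured output object."""

    def __init__(self, backend_model: BackendModel):
        self.backend_model = backend_model

    def __call__(self, x: List[float]) -> ModelOutput:
        return ModelOutput(logits=self.backend_model(x))


def negative_logit_loss(logits: List[float], label: int) -> float:
    # Toy loss for illustration only: lower when the target logit is higher.
    return -logits[label]


backend = BackendModel(weights=[0.5, -1.0, 2.0])
training_view = TrainingWrapper(backend, negative_logit_loss)
inference_view = InferenceWrapper(backend)

x = [1.0, 2.0, 3.0]
print(inference_view(x).logits)            # logits via the inference wrapper
print(backend(x))                          # same logits from the backend directly
print(training_view.training_step(x, 2))   # loss via the training wrapper
```

Keeping the backend free of framework code is what lets one set of weights serve both the training and the inference interface.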

The command line interface for training is implemented with Lightning CLI. Training datasets are 🤗 datasets wrapped into PyTorch Lightning data modules. For NLP tasks, perceiver-io supports all 🤗 fast tokenizers and the 🤗 Perceiver UTF-8 bytes tokenizer.

Documentation

Installation

Via pip

pip install perceiver-io[text,vision,audio]
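After installing, one quick way to confirm that the package is visible to the current interpreter is to query its distribution metadata. The sketch below uses only the standard library; the helper name `installed_version` is made up for illustration, and the distribution name is the one used in the pip command above.

```python
from importlib import metadata
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("perceiver-io")
print(f"perceiver-io: {version or 'not installed'}")
```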

From sources

Installation from sources requires a Miniconda and a Poetry (1.2.0 or higher) installation.

Create and activate the perceiver-io conda environment:

conda env create -f environment.yml
conda activate perceiver-io

Install main and test dependencies, including all extras:

# Without dependencies required for examples
poetry install --all-extras

If you want to run the examples locally, additionally use --with examples:

poetry install --all-extras --with examples

Docker image

docker pull ghcr.io/krasserm/perceiver-io:latest

See Docker image for details.

Getting started

Inference

Optical flow

Compute the optical flow between consecutive frames of an input video and write the rendered results to an output video:

from urllib.request import urlretrieve
from transformers import pipeline

from perceiver.data.vision import video_utils
from perceiver.model.vision import optical_flow  # register auto-classes and pipeline

urlretrieve(
    url="https://martin-krasser.com/perceiver/flow/sintel_clip_cave_dragon_fight.mp4",
    filename="sintel_clip_cave_dragon_fight.mp4",
)

# Create optical flow pipeline
optical_flow_pipeline = pipeline("optical-flow", model="krasserm/perceiver-io-optical-flow", device="cuda:0")

# load consecutive video frame pairs
frame_pairs = video_utils.read_video_frame_pairs("sintel_clip_cave_dragon_fight.mp4")

# create and render optical flow for all frame pairs
optical_flows = optical_flow_pipeline(frame_pairs, render=True, device="cuda:0")

# create video with rendered optical flows
video_utils.write_video("sintel_clip_cave_dragon_fight_output.mp4", optical_flows, fps=24)

Here is a side-by-side comparison of the input and output video:

[Figure: side-by-side comparison of the input video and the rendered optical flow]
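To make "consecutive frame pairs" concrete, here is a minimal sketch of the pairing idea. It assumes overlapping pairs, so a clip with N frames yields N - 1 pairs and therefore N - 1 flow fields. The helper `consecutive_pairs` is hypothetical and is not the implementation of `video_utils.read_video_frame_pairs`.

```python
from typing import List, Sequence, Tuple, TypeVar

Frame = TypeVar("Frame")


def consecutive_pairs(frames: Sequence[Frame]) -> List[Tuple[Frame, Frame]]:
    """Pair each frame with its successor: (f0, f1), (f1, f2), ..."""
    return list(zip(frames[:-1], frames[1:]))


# Frames are plain labels here; real frames would be image arrays.
frames = ["f0", "f1", "f2", "f3"]
pairs = consecutive_pairs(frames)
print(pairs)
```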

Symbolic audio generation

Create audio sequences by generating symbolic (MIDI) audio data and converting the generated audio symbols into WAV output using fluidsynth (Note: fluidsynth must be installed in order for the following example to work):

from transformers import pipeline
from pretty_midi import PrettyMIDI
from perceiver.model.audio import symbolic  # auto-class registration

repo_id = "krasserm/perceiver-ar-sam-giant-midi"

prompt = PrettyMIDI("prompt.mid")
audio_generator = pipeline("symbolic-audio-generation", model=repo_id)

output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0, render=True)

with open("generated_audio.wav", "wb") as f:
    f.write(output["generated_audio_wav"])

Examples of generated audio sequences are available on the 🤗 hub.
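The `do_sample`, `top_p` and `temperature` arguments in the generation call above control how each new token is drawn. As a rough illustration of the general technique (temperature scaling followed by nucleus, or top-p, sampling), here is a plain-Python sketch. The function names and toy logits are made up for this example; it is not the code path used by the pipeline.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs: List[float], top_p: float) -> List[float]:
    """Keep the smallest set of most likely tokens whose mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / cumulative  # renormalize over the kept tokens
    return filtered


def sample_token(logits: List[float], top_p: float = 0.95,
                 temperature: float = 1.0,
                 rng: Optional[random.Random] = None) -> int:
    """Draw one token index from the temperature-scaled, top-p filtered distribution."""
    if rng is None:
        rng = random.Random()
    probs = top_p_filter(softmax(logits, temperature), top_p)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)  # fixed seed so the sketch is reproducible
tokens = [sample_token([2.0, 1.0, 0.0, -10.0], rng=rng) for _ in range(8)]
print(tokens)
```

With `top_p=0.95`, the very unlikely last token in the toy logits falls outside the nucleus and is never sampled. Generating `max_new_tokens` tokens corresponds to repeating this draw once per new token.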

See inference examples for more examples.

Training

Train a small Perceiver IO image classifier (907K parameters) on MNIST from the command line. The classifier cross-attends to individual pixels of input images with repeated cross-attention. See image classification training example for more details.

python -m perceiver.scripts.vision.image_classifier fit \
  --model.num_latents=32 \
  --model.num_latent_channels=128 \
  --model.encoder.num_frequency_bands=32 \
  --model.encoder.num_cross_attention_layers=2 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.num_self_attention_layers_per_block=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.1 \
  --model.encoder.init_scale=0.1 \
  --model.decoder.num_output_query_channels=128 \
  --model.decoder.dropout=0.1 \
  --model.decoder.init_scale=0.1 \
  --data=MNISTDataModule \
  --data.batch_size=64 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --lr_scheduler.warmup_steps=500 \
  --trainer.accelerator=gpu \
  --trainer.devices=1 \
  --trainer.max_epochs=30 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=img_clf
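The dotted flags in the command above address fields of a nested configuration: `--model.encoder.dropout=0.1` sets `dropout` inside `encoder` inside `model`. The following sketch illustrates that mapping in plain Python. It is not how Lightning CLI parses arguments internally, the helpers `parse_value` and `nest_flags` are hypothetical, and class-selection flags such as `--data=MNISTDataModule` are deliberately not handled.

```python
from typing import Any, Dict, List


def parse_value(raw: str) -> Any:
    """Interpret a raw string as bool, int or float, else keep it as a string."""
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            continue
    return raw


def nest_flags(args: List[str]) -> Dict[str, Any]:
    """Turn ["--a.b=1", "--a.c=x"] into {"a": {"b": 1, "c": "x"}}."""
    config: Dict[str, Any] = {}
    for arg in args:
        key, _, raw = arg.lstrip("-").partition("=")
        *parents, leaf = key.split(".")
        node = config
        for name in parents:
            node = node.setdefault(name, {})  # descend, creating levels as needed
        node[leaf] = parse_value(raw)
    return config


args = [
    "--model.num_latents=32",
    "--model.encoder.dropout=0.1",
    "--model.encoder.first_self_attention_block_shared=false",
    "--optimizer.lr=1e-3",
    "--trainer.max_epochs=30",
]
config = nest_flags(args)
print(config)
```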

Model construction describes how to implement model-specific command line interfaces with the Lightning CLI. Training checkpoints are written to the logs/img_clf/version_0/checkpoints directory. Assuming a checkpoint with filename epoch=025-val_loss=0.065.ckpt exists, it can be converted to a perceiver-io 🤗 model with

from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint

convert_mnist_classifier_checkpoint(
    save_dir="example/mnist-classifier",
    ckpt_url="logs/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
)

so that it can be used in a 🤗 image classification pipeline

from datasets import load_dataset
from transformers import pipeline

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

classifier = pipeline("image-classification", model="example/mnist-classifier")
predictions = [pred[0]["label"] for pred in classifier(images)]

print(f"Labels:      {labels}")
print(f"Predictions: {predictions}")

Output:

Labels:      [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]

or loaded directly:

import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor

model = AutoModelForImageClassification.from_pretrained("example/mnist-classifier")
processor = AutoImageProcessor.from_pretrained("example/mnist-classifier")

inputs = processor(images, return_tensors="pt")

with torch.no_grad():
    # use perceiver-io Hugging Face model
    output_1 = model(**inputs).logits

with torch.no_grad():
    # or use perceiver-io backend model directly  
    output_2 = model.backend_model(inputs.pixel_values)

print(f"Predictions: {output_1.argmax(dim=-1).numpy().tolist()}")
print(f"Predictions: {output_2.argmax(dim=-1).numpy().tolist()}")

Output:

Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
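Taking the top-scoring pipeline entry (`pred[0]["label"]`) and taking `argmax` over the logits select the same class, because softmax is monotonic and so preserves the ordering of the logits. A minimal sketch in plain Python, with made-up logits and a hypothetical helper `top_k_predictions` that mimics the label/score shape of pipeline output:

```python
import math
from typing import Any, Dict, List


def top_k_predictions(logits: List[float], k: int = 1) -> List[Dict[str, Any]]:
    """Return the k most likely classes as {"label", "score"} dicts, best first."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    scored = [{"label": i, "score": e / total} for i, e in enumerate(exps)]
    return sorted(scored, key=lambda p: p["score"], reverse=True)[:k]


logits = [0.1, 2.5, -1.0, 0.7]  # made-up values for illustration
argmax_label = max(range(len(logits)), key=lambda i: logits[i])
top1_label = top_k_predictions(logits)[0]["label"]
print(argmax_label, top1_label)
```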

See training examples for more examples.

