
Commit

updated docstrings
HMUNACHI committed May 12, 2024
1 parent f72df49 commit 29dc796
Showing 21 changed files with 184 additions and 449 deletions.
52 changes: 28 additions & 24 deletions README.md
@@ -12,7 +12,7 @@ N/B: Codes are implemented pedagogically at the expense of repetition.
Each model is purposefully contained in a file without inter-file dependencies.

## Overview
-Developing and training transformer-based models is typically resource-intensive and time-consuming and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:
+Developing and training transformer-based models is typically resource-intensive and time-consuming and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks and abstracts distributed training, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:

- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
- An extensive selection of models like Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, CLIP etc.
@@ -57,27 +57,30 @@ We provide various example usages of the nanodl API.

```py
import jax
import nanodl
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
-from nanodl import GPT4, GPTDataParallelTrainer, Tokenizer
+from nanodl import GPT4, GPTDataParallelTrainer

# Preparing your dataset
batch_size = 8
max_length = 50
vocab_size = 1000

# Create random data
-data = nanodl.uniform(shape=(batch, max_length))
+data = nanodl.uniform(
+    shape=(batch_size, max_length),
+    minval=0, maxval=vocab_size-1
+).astype(jnp.int32)

# Shift to create next-token prediction dataset
dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]

# Create dataset and dataloader
dataset = ArrayDataset(dummy_inputs, dummy_targets)
-dataloader = DataLoader(dataset,
-                        batch_size=batch_size,
-                        shuffle=True,
-                        drop_last=False)
+dataloader = DataLoader(
+    dataset, batch_size=batch_size, shuffle=True, drop_last=False
+)

# model parameters
hyperparams = {
@@ -96,29 +99,32 @@ hyperparams = {
# Inferred GPT4 model
model = GPT4(**hyperparams)

-trainer = GPTDataParallelTrainer(model,
-                                 dummy_inputs.shape,
-                                 'params.pkl')
+trainer = GPTDataParallelTrainer(
+    model, dummy_inputs.shape, 'params.pkl'
+)

-trainer.train(train_loader=dataloader,
-              num_epochs=100,
-              val_loader=dataloader) # use actual val data
+trainer.train(
+    train_loader=dataloader, num_epochs=100, val_loader=dataloader
+) # use actual val data

# Generating from a start token
start_tokens = jnp.array([[123, 456]])

# Remember to load the trained parameters
params = trainer.load_params('params.pkl')
-outputs = model.apply({'params': params},
-                      start_tokens,
-                      rngs={'dropout': nanodl.time_rng_key()},
-                      method=model.generate)

+outputs = model.apply(
+    {'params': params},
+    start_tokens,
+    rngs={'dropout': nanodl.time_rng_key()},
+    method=model.generate
+)
```

Vision example

```py
import jax
import nanodl
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer
@@ -143,7 +149,7 @@ trainer = DiffusionDataParallelTrainer(diffusion_model,
weights_filename='params.pkl',
learning_rate=1e-4)

-trainer.train(dataloader, 10, dataloader) # use actual val data
+trainer.train(dataloader, 10)

# Generate some samples: Each model is a Flax.linen module
# Use as you normally would
@@ -205,17 +211,15 @@ params = trainer.load_params('params.pkl')

# for more than one sample, often use model.generate_batch
transcripts = model.apply({'params': params},
-                          dummy_inputs[:1],
-                          rngs=rngs,
+                          dummy_inputs[:1],
                           method=model.generate)
```

Reward Model example for RLHF

```py
import jax
import nanodl
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer

@@ -266,7 +270,7 @@ rewards = reward_model.apply({'params': params},
PCA example

```py
import jax
import nanodl
from nanodl import PCA

# Use actual data
36 changes: 0 additions & 36 deletions nanodl/__src/models/clip.py
@@ -18,10 +18,6 @@ class PositionalEncoding(nn.Module):
Attributes:
num_embeddings (int): The maximum number of positions for which to generate positional encodings.
features (int): The dimensionality of the embeddings/positional encodings.
Methods:
setup(): Initializes the positional encoding matrix based on the provided attributes.
__call__(x: jnp.ndarray): Adds positional encodings to the input embeddings.
"""

num_embeddings: int
@@ -58,10 +54,6 @@ class TokenAndPositionEmbedding(nn.Module):
vocab_size (int): Size of the vocabulary.
embed_dim (int): Dimension of the embeddings.
learned_position (bool): Flag to use learned positional embeddings instead of fixed positional encodings.
Methods:
setup(): Initializes token and positional embeddings.
__call__(x: jnp.ndarray): Applies token embeddings and adds positional information to the input sequence.
"""

max_len: int
@@ -100,11 +92,6 @@ class SelfMultiHeadAttention(nn.Module):
Attributes:
hidden_dim (int): Dimensionality of the input and output features.
num_heads (int): Number of attention heads.
Methods:
setup(): Initializes projection matrices for queries, keys, values, and the output projection.
__call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism.
attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors.
"""

hidden_dim: int
@@ -172,10 +159,6 @@ class PositionWiseFFN(nn.Module):
Attributes:
num_hiddens (int): The number of hidden units in the first linear layer.
num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size).
Methods:
setup(): Initializes the two linear layers.
__call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor.
"""

num_hiddens: int
@@ -201,9 +184,6 @@ class AddNorm(nn.Module):
Attributes:
dropout (float): Dropout rate for the residual connection.
Methods:
__call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization.
"""

dropout: int
@@ -226,10 +206,6 @@ class EncoderBlock(nn.Module):
num_heads (int): Number of attention heads.
feedforward_dim (int): Dimension of the feed-forward network.
dropout (float): Dropout rate.
Methods:
setup(): Initializes the attention, feed-forward network, and normalization layers.
__call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the encoder block.
"""

hidden_dim: int
@@ -271,10 +247,6 @@ class TextEncoder(nn.Module):
vocab_size (int): Size of the vocabulary.
embed_dim (int): Dimension of the embeddings.
learned_position (bool): Flag to use learned positional embeddings instead of fixed positional encodings.
Methods:
setup(): Initializes the embedding layer and the encoder blocks.
__call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the transformer encoder.
"""

num_layers: int
@@ -319,9 +291,6 @@ class PatchEmbedding(nn.Module):
patch_size (tuple): Size (height, width) of the patches to extract from input images.
embed_dim (int): Dimension of the embeddings for the patches.
Methods:
__call__(x: jnp.ndarray): Extracts patches from the input images and applies patch embedding.
extract_patches(images: jnp.ndarray): Extracts and flattens patches from input images.
"""

patch_size: Tuple[int, int]
@@ -371,9 +340,6 @@ class ImageEncoder(nn.Module):
feedforward_dim (int): Dimension of the feed-forward network in the transformer encoder.
dropout (float): Dropout rate for regularization.
Methods:
setup(): Initializes the patch embedding and encoder blocks.
__call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input images through the vision transformer encoder.
"""

patch_size: Tuple[int, int]
@@ -429,8 +395,6 @@ class CLIP(nn.Module):
- num_layers_images (int): Number of transformer layers for image encoding.
Methods:
- setup(): Initializes the model components and parameters.
- __call__(texts, images, training): Computes embeddings for text and images.
- get_attention_maps(texts, images): Computes attention maps for text and images.
- encode_text(texts): Encodes text data using the text encoder.
- encode_image(images): Encodes image data using the image encoder.
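
The SelfMultiHeadAttention docstring removed above describes an `attention_function` that computes attention scores and applies them to the value vectors under an optional mask. A minimal sketch of that idea follows; it is illustrative only, not code from this commit, and the function name, shapes, and masking value are assumptions.

```py
# Illustrative sketch only -- not NanoDL's implementation.
import jax
import jax.numpy as jnp

def scaled_dot_product_attention(query, key, value, mask=None):
    # query/key/value: (batch, heads, seq_len, head_dim); mask broadcasts over the scores.
    d_k = query.shape[-1]
    scores = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(d_k)
    if mask is not None:
        scores = jnp.where(mask == 0, -1e9, scores)  # hide masked positions from the softmax
    weights = jax.nn.softmax(scores, axis=-1)        # attention distribution over keys
    return jnp.matmul(weights, value), weights       # weighted values and the attention map
```
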
16 changes: 0 additions & 16 deletions nanodl/__src/models/diffusion.py
@@ -20,9 +20,6 @@ class SinusoidalEmbedding(nn.Module):
embedding_min_frequency (float): The minimum frequency used in the sinusoidal embedding.
embedding_max_frequency (float): The maximum frequency used in the sinusoidal embedding.
Methods:
setup(): Initializes the layer by computing the angular speeds for the sinusoidal functions based on the specified frequency range.
__call__(x: jnp.ndarray): Generates the sinusoidal embeddings for the input positions.
"""

embedding_dims: int
@@ -53,8 +50,6 @@ class UNetResidualBlock(nn.Module):
Attributes:
width (int): The number of output channels for the convolutional layers within the block.
Methods:
__call__(x: jnp.ndarray): Processes the input tensor through the residual block and returns the result.
"""

width: int
@@ -92,9 +87,6 @@ class UNetDownBlock(nn.Module):
width (int): The number of output channels for the convolutional layers within the block.
block_depth (int): The number of residual blocks to include in the down-sampling block.
Methods:
setup(): Initializes the sequence of residual blocks.
__call__(x: jnp.ndarray): Processes the input tensor through the down-sampling block and returns the result.
"""

width: int
@@ -122,9 +114,6 @@ class UNetUpBlock(nn.Module):
width (int): The number of output channels for the convolutional layers within the block.
block_depth (int): The number of residual blocks to include in the up-sampling block.
Methods:
setup(): Initializes the sequence of residual blocks.
__call__(x: jnp.ndarray, skip: jnp.ndarray): Processes the input tensor and a skip connection from the encoding pathway through the up-sampling block and returns the result.
"""

width: int
@@ -159,9 +148,6 @@ class UNet(nn.Module):
embed_min_freq (float): The minimum frequency for the sinusoidal embeddings.
embed_max_freq (float): The maximum frequency for the sinusoidal embeddings.
Methods:
setup(): Initializes the U-Net architecture including the sinusoidal embedding layer, down-sampling blocks, residual blocks, and up-sampling blocks.
__call__(noisy_images: jnp.ndarray, noise_variances: jnp.ndarray): Processes noisy images and their associated noise variances through the U-Net and returns the denoised images.
"""

image_size: Tuple[int, int]
@@ -237,10 +223,8 @@ class DiffusionModel(nn.Module):
embed_max_freq (float): The maximum frequency for the sinusoidal embeddings.
Methods:
setup(): Initializes the diffusion model including the U-Net architecture.
diffusion_schedule(diffusion_times: jnp.ndarray): Computes the noise and signal rates for given diffusion times.
denoise(noisy_images: jnp.ndarray, noise_rates: jnp.ndarray, signal_rates: jnp.ndarray): Denoises images given their noise and signal rates.
__call__(images: jnp.ndarray): Applies the diffusion process to a batch of images.
reverse_diffusion(initial_noise: jnp.ndarray, diffusion_steps: int): Reverses the diffusion process to generate images from noise.
generate(num_images: int, diffusion_steps: int): Generates images by reversing the diffusion process from random noise.
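
The DiffusionModel docstring above keeps a `diffusion_schedule(diffusion_times)` method that maps diffusion times to noise and signal rates. A minimal sketch of one common way to do this is shown below; it is a cosine-style schedule given here for illustration, and the min/max signal rates are assumed values, not necessarily what NanoDL uses.

```py
# Illustrative sketch only -- not NanoDL's implementation.
import jax.numpy as jnp

def diffusion_schedule(diffusion_times, min_signal_rate=0.02, max_signal_rate=0.95):
    # Interpolate angles so that signal_rate**2 + noise_rate**2 == 1 at every time in [0, 1].
    start_angle = jnp.arccos(max_signal_rate)
    end_angle = jnp.arccos(min_signal_rate)
    angles = start_angle + diffusion_times * (end_angle - start_angle)
    signal_rates = jnp.cos(angles)  # fraction of the clean image kept
    noise_rates = jnp.sin(angles)   # fraction of noise mixed in
    return noise_rates, signal_rates

# A denoise step would then recover images from a predicted noise tensor:
# pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
```
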
19 changes: 0 additions & 19 deletions nanodl/__src/models/gemma.py
@@ -18,11 +18,6 @@ class RotaryPositionalEncoding:
Attributes:
dim_model (int): The dimensionality of the model embeddings.
Methods:
_update_cos_sin_tables(x, seq_dimension): Updates cosine and sine tables based on the sequence length.
rotate_half(x): Rotates the last half of the dimensions of x by swapping them and changing signs to simulate a 90-degree rotation.
apply_rotary_pos_emb(x, cos, sin): Applies the rotary positional encoding to the input embeddings.
__call__(q, k): Applies rotary positional encoding to query and key tensors in attention mechanisms.
"""

def __init__(self, dim_model: int):
@@ -79,11 +74,6 @@ class GroupedRotaryMultiHeadAttention(nn.Module):
num_heads (int): Number of attention heads.
num_groups (int): Number of groups to split the heads into for applying rotary positional embeddings separately.
Methods:
setup(): Initializes the projections for query, key, value, and output, along with the rotary positional encoder.
__call__(inputs, context, mask): Processes the input and context tensors through the grouped rotary multi-head attention mechanism.
process_group(query, key, value, mask): Processes a single group of heads through rotary positional encoding and attention.
attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors.
"""

hidden_dim: int # Output dimension
@@ -217,10 +207,6 @@ class GemmaDecoderBlock(nn.Module):
dropout (float): Dropout rate for regularization.
num_groups (int): Number of groups for the grouped rotary positional embeddings.
Methods:
setup(): Initializes the components of the Gemma decoder block.
causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism.
__call__(x, training): Processes the input tensor through the Gemma decoder block.
"""

hidden_dim: int
@@ -290,9 +276,6 @@ class GemmaDecoder(nn.Module):
vocab_size (float): Size of the vocabulary.
embed_dim (float): Dimensionality of the token embeddings.
Methods:
setup(): Initializes the components of the LLaMA2 decoder.
__call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 decoder.
"""

num_layers: int
@@ -356,8 +339,6 @@ class Gemma(nn.Module):
end_token (int): Token that indicates the end of a generated sequence.
Methods:
setup(): Initializes the LLaMA2 model including the decoder component.
__call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 model.
generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively.
generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively.
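
The RotaryPositionalEncoding docstring removed above describes cos/sin tables per position, a `rotate_half` operation, and `apply_rotary_pos_emb` applied to queries and keys. A minimal standalone sketch of standard rotary embeddings along those lines follows; it is illustrative only, and the 10000 base frequency and helper names are assumptions rather than NanoDL's exact code.

```py
# Illustrative sketch only -- not NanoDL's implementation.
import jax.numpy as jnp

def rotate_half(x):
    # Swap the two halves of the last dimension and negate the second half.
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)

def rotary_tables(seq_len, dim_model):
    # Per-position angles at geometrically spaced frequencies.
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim_model, 2) / dim_model))
    angles = jnp.outer(jnp.arange(seq_len), inv_freq)    # (seq_len, dim_model // 2)
    angles = jnp.concatenate((angles, angles), axis=-1)  # (seq_len, dim_model)
    return jnp.cos(angles), jnp.sin(angles)

def apply_rotary_pos_emb(x, cos, sin):
    # x: (batch, seq_len, dim_model); cos/sin broadcast over the batch axis.
    return x * cos + rotate_half(x) * sin

# Queries and keys are rotated with the same tables before attention:
# cos, sin = rotary_tables(q.shape[1], q.shape[-1])
# q, k = apply_rotary_pos_emb(q, cos, sin), apply_rotary_pos_emb(k, cos, sin)
```
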