diff --git a/README.md b/README.md index d50b196..2aee2e2 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ N/B: Codes are implemented pedagogically at the expense of repetition. Each model is purposefully contained in a file without inter-file dependencies. ## Overview -Developing and training transformer-based models is typically resource-intensive and time-consuming and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features: +Developing and training transformer-based models is typically resource-intensive and time-consuming and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks and abstracts distributed training, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features: - A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch. - An extensive selection of models like Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, CLIP etc. @@ -57,9 +57,10 @@ We provide various example usages of the nanodl API. ```py import jax +import nanodl import jax.numpy as jnp from nanodl import ArrayDataset, DataLoader -from nanodl import GPT4, GPTDataParallelTrainer, Tokenizer +from nanodl import GPT4, GPTDataParallelTrainer # Preparing your dataset batch_size = 8 @@ -67,17 +68,19 @@ max_length = 50 vocab_size = 1000 # Create random data -data = nanodl.uniform(shape=(batch, max_length)) +data = nanodl.uniform( + shape=(batch_size, max_length), + minval=0, maxval=vocab_size-1 + ).astype(jnp.int32) # Shift to create next-token prediction dataset dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:] # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) -dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) +dataloader = DataLoader( + dataset, batch_size=batch_size, shuffle=True, drop_last=False + ) # model parameters hyperparams = { @@ -96,29 +99,32 @@ hyperparams = { # Inferred GPT4 model model = GPT4(**hyperparams) -trainer = GPTDataParallelTrainer(model, - dummy_inputs.shape, - 'params.pkl') +trainer = GPTDataParallelTrainer( + model, dummy_inputs.shape, 'params.pkl' + ) -trainer.train(train_loader=dataloader, - num_epochs=100, - val_loader=dataloader) # use actual val data +trainer.train( + train_loader=dataloader, num_epochs=100, val_loader=dataloader + ) # use actual val data # Generating from a start token start_tokens = jnp.array([[123, 456]]) # Remember to load the trained parameters params = trainer.load_params('params.pkl') -outputs = model.apply({'params': params}, - start_tokens, - rngs={'dropout': nanodl.time_rng_key()}, - method=model.generate) + +outputs = model.apply( + {'params': params}, + start_tokens, + rngs={'dropout': nanodl.time_rng_key()}, + method=model.generate + ) ``` Vision example ```py -import jax +import nanodl import jax.numpy as jnp from nanodl import ArrayDataset, DataLoader from nanodl import DiffusionModel, DiffusionDataParallelTrainer @@ -143,7 +149,7 @@ trainer = DiffusionDataParallelTrainer(diffusion_model, weights_filename='params.pkl', learning_rate=1e-4) -trainer.train(dataloader, 10, dataloader) # use actual val data +trainer.train(dataloader, 10) # Generate some samples: Each model is a Flax.linen module # Use as you normally would @@ -205,17 +211,15 @@ params = trainer.load_params('params.pkl') # for more than one sample, often use model.generate_batch transcripts = model.apply({'params': params}, - dummy_inputs[:1], - rngs=rngs, + dummy_inputs[:1], method=model.generate) ``` Reward Model example for RLHF ```py -import jax +import nanodl import jax.numpy as jnp -from nanodl import time_rng_key from nanodl import ArrayDataset, DataLoader from nanodl import Mistral, RewardModel, RewardDataParallelTrainer @@ -266,7 +270,7 @@ rewards = reward_model.apply({'params': params}, PCA example ```py -import jax +import nanodl from nanodl import PCA # Use actual data diff --git a/nanodl/__src/models/clip.py b/nanodl/__src/models/clip.py index 31601d6..45b2361 100644 --- a/nanodl/__src/models/clip.py +++ b/nanodl/__src/models/clip.py @@ -18,10 +18,6 @@ class PositionalEncoding(nn.Module): Attributes: num_embeddings (int): The maximum number of positions for which to generate positional encodings. features (int): The dimensionality of the embeddings/positional encodings. - - Methods: - setup(): Initializes the positional encoding matrix based on the provided attributes. - __call__(x: jnp.ndarray): Adds positional encodings to the input embeddings. """ num_embeddings: int @@ -58,10 +54,6 @@ class TokenAndPositionEmbedding(nn.Module): vocab_size (int): Size of the vocabulary. embed_dim (int): Dimension of the embeddings. learned_position (bool): Flag to use learned positional embeddings instead of fixed positional encodings. - - Methods: - setup(): Initializes token and positional embeddings. - __call__(x: jnp.ndarray): Applies token embeddings and adds positional information to the input sequence. """ max_len: int @@ -100,11 +92,6 @@ class SelfMultiHeadAttention(nn.Module): Attributes: hidden_dim (int): Dimensionality of the input and output features. num_heads (int): Number of attention heads. - - Methods: - setup(): Initializes projection matrices for queries, keys, values, and the output projection. - __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. - attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ hidden_dim: int @@ -172,10 +159,6 @@ class PositionWiseFFN(nn.Module): Attributes: num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - - Methods: - setup(): Initializes the two linear layers. - __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ num_hiddens: int @@ -201,9 +184,6 @@ class AddNorm(nn.Module): Attributes: dropout (float): Dropout rate for the residual connection. - - Methods: - __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ dropout: int @@ -226,10 +206,6 @@ class EncoderBlock(nn.Module): num_heads (int): Number of attention heads. feedforward_dim (int): Dimension of the feed-forward network. dropout (float): Dropout rate. - - Methods: - setup(): Initializes the attention, feed-forward network, and normalization layers. - __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the encoder block. """ hidden_dim: int @@ -271,10 +247,6 @@ class TextEncoder(nn.Module): vocab_size (int): Size of the vocabulary. embed_dim (int): Dimension of the embeddings. learned_position (bool): Flag to use learned positional embeddings instead of fixed positional encodings. - - Methods: - setup(): Initializes the embedding layer and the encoder blocks. - __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the transformer encoder. """ num_layers: int @@ -319,9 +291,6 @@ class PatchEmbedding(nn.Module): patch_size (tuple): Size (height, width) of the patches to extract from input images. embed_dim (int): Dimension of the embeddings for the patches. - Methods: - __call__(x: jnp.ndarray): Extracts patches from the input images and applies patch embedding. - extract_patches(images: jnp.ndarray): Extracts and flattens patches from input images. """ patch_size: Tuple[int, int] @@ -371,9 +340,6 @@ class ImageEncoder(nn.Module): feedforward_dim (int): Dimension of the feed-forward network in the transformer encoder. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the patch embedding and encoder blocks. - __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input images through the vision transformer encoder. """ patch_size: Tuple[int, int] @@ -429,8 +395,6 @@ class CLIP(nn.Module): - num_layers_images (int): Number of transformer layers for image encoding. Methods: - - setup(): Initializes the model components and parameters. - - __call__(texts, images, training): Computes embeddings for text and images. - get_attention_maps(texts, images): Computes attention maps for text and images. - encode_text(texts): Encodes text data using the text encoder. - encode_image(images): Encodes image data using the image encoder. diff --git a/nanodl/__src/models/diffusion.py b/nanodl/__src/models/diffusion.py index 6019f39..4b5a612 100644 --- a/nanodl/__src/models/diffusion.py +++ b/nanodl/__src/models/diffusion.py @@ -20,9 +20,6 @@ class SinusoidalEmbedding(nn.Module): embedding_min_frequency (float): The minimum frequency used in the sinusoidal embedding. embedding_max_frequency (float): The maximum frequency used in the sinusoidal embedding. - Methods: - setup(): Initializes the layer by computing the angular speeds for the sinusoidal functions based on the specified frequency range. - __call__(x: jnp.ndarray): Generates the sinusoidal embeddings for the input positions. """ embedding_dims: int @@ -53,8 +50,6 @@ class UNetResidualBlock(nn.Module): Attributes: width (int): The number of output channels for the convolutional layers within the block. - Methods: - __call__(x: jnp.ndarray): Processes the input tensor through the residual block and returns the result. """ width: int @@ -92,9 +87,6 @@ class UNetDownBlock(nn.Module): width (int): The number of output channels for the convolutional layers within the block. block_depth (int): The number of residual blocks to include in the down-sampling block. - Methods: - setup(): Initializes the sequence of residual blocks. - __call__(x: jnp.ndarray): Processes the input tensor through the down-sampling block and returns the result. """ width: int @@ -122,9 +114,6 @@ class UNetUpBlock(nn.Module): width (int): The number of output channels for the convolutional layers within the block. block_depth (int): The number of residual blocks to include in the up-sampling block. - Methods: - setup(): Initializes the sequence of residual blocks. - __call__(x: jnp.ndarray, skip: jnp.ndarray): Processes the input tensor and a skip connection from the encoding pathway through the up-sampling block and returns the result. """ width: int @@ -159,9 +148,6 @@ class UNet(nn.Module): embed_min_freq (float): The minimum frequency for the sinusoidal embeddings. embed_max_freq (float): The maximum frequency for the sinusoidal embeddings. - Methods: - setup(): Initializes the U-Net architecture including the sinusoidal embedding layer, down-sampling blocks, residual blocks, and up-sampling blocks. - __call__(noisy_images: jnp.ndarray, noise_variances: jnp.ndarray): Processes noisy images and their associated noise variances through the U-Net and returns the denoised images. """ image_size: Tuple[int, int] @@ -237,10 +223,8 @@ class DiffusionModel(nn.Module): embed_max_freq (float): The maximum frequency for the sinusoidal embeddings. Methods: - setup(): Initializes the diffusion model including the U-Net architecture. diffusion_schedule(diffusion_times: jnp.ndarray): Computes the noise and signal rates for given diffusion times. denoise(noisy_images: jnp.ndarray, noise_rates: jnp.ndarray, signal_rates: jnp.ndarray): Denoises images given their noise and signal rates. - __call__(images: jnp.ndarray): Applies the diffusion process to a batch of images. reverse_diffusion(initial_noise: jnp.ndarray, diffusion_steps: int): Reverses the diffusion process to generate images from noise. generate(num_images: int, diffusion_steps: int): Generates images by reversing the diffusion process from random noise. diff --git a/nanodl/__src/models/gemma.py b/nanodl/__src/models/gemma.py index b73d1c0..4eb4684 100644 --- a/nanodl/__src/models/gemma.py +++ b/nanodl/__src/models/gemma.py @@ -18,11 +18,6 @@ class RotaryPositionalEncoding: Attributes: dim_model (int): The dimensionality of the model embeddings. - Methods: - _update_cos_sin_tables(x, seq_dimension): Updates cosine and sine tables based on the sequence length. - rotate_half(x): Rotates the last half of the dimensions of x by swapping them and changing signs to simulate a 90-degree rotation. - apply_rotary_pos_emb(x, cos, sin): Applies the rotary positional encoding to the input embeddings. - __call__(q, k): Applies rotary positional encoding to query and key tensors in attention mechanisms. """ def __init__(self, dim_model: int): @@ -79,11 +74,6 @@ class GroupedRotaryMultiHeadAttention(nn.Module): num_heads (int): Number of attention heads. num_groups (int): Number of groups to split the heads into for applying rotary positional embeddings separately. - Methods: - setup(): Initializes the projections for query, key, value, and output, along with the rotary positional encoder. - __call__(inputs, context, mask): Processes the input and context tensors through the grouped rotary multi-head attention mechanism. - process_group(query, key, value, mask): Processes a single group of heads through rotary positional encoding and attention. - attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors. """ hidden_dim: int # Output dimension @@ -217,10 +207,6 @@ class GemmaDecoderBlock(nn.Module): dropout (float): Dropout rate for regularization. num_groups (int): Number of groups for the grouped rotary positional embeddings. - Methods: - setup(): Initializes the components of the Gemma decoder block. - causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. - __call__(x, training): Processes the input tensor through the Gemma decoder block. """ hidden_dim: int @@ -290,9 +276,6 @@ class GemmaDecoder(nn.Module): vocab_size (float): Size of the vocabulary. embed_dim (float): Dimensionality of the token embeddings. - Methods: - setup(): Initializes the components of the LLaMA2 decoder. - __call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 decoder. """ num_layers: int @@ -356,8 +339,6 @@ class Gemma(nn.Module): end_token (int): Token that indicates the end of a generated sequence. Methods: - setup(): Initializes the LLaMA2 model including the decoder component. - __call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. diff --git a/nanodl/__src/models/gpt.py b/nanodl/__src/models/gpt.py index 7f76ea4..96c11ca 100644 --- a/nanodl/__src/models/gpt.py +++ b/nanodl/__src/models/gpt.py @@ -19,10 +19,6 @@ class SelfMultiHeadAttention(nn.Module): hidden_dim (int): Dimensionality of the input and output features. num_heads (int): Number of attention heads. - Methods: - setup(): Initializes projection matrices for queries, keys, values, and the output projection. - __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. - attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ hidden_dim: int @@ -92,9 +88,6 @@ class PositionWiseFFN(nn.Module): num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - Methods: - setup(): Initializes the two linear layers. - __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ num_hiddens: int @@ -148,10 +141,6 @@ class GPT3Block(nn.Module): feedforward_dim (int): Dimension of the feedforward layer. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the GPT-3 block, including attention mechanisms, feedforward network, and normalization layers. - causal_mask(batch_size, destination_dim, source_dim): Creates a causal mask to ensure that predictions for a position can depend only on known outputs at earlier positions. - __call__(x, mask=None, training=False): Defines the computation performed at every call of the GPT-3 block. """ hidden_dim: int @@ -228,9 +217,6 @@ class GPT3Decoder(nn.Module): vocab_size (int): The size of the vocabulary. embed_dim (int): The dimensionality of the token embeddings. - Methods: - setup(): Initializes the components of the GPT-3 decoder including the embedding layer, GPT-3 blocks, and the output layer. - __call__(x, mask, training, drop_last_layer): Processes the input tensor through the GPT-3 decoder, generating predictions for the next token in the sequence. """ num_layers: int @@ -296,8 +282,6 @@ class GPT3(nn.Module): end_token (int): The token that indicates the end of a generated sequence. Methods: - setup(): Initializes the GPT-3 model including the decoder component. - __call__(x, training, drop_last_layer): Processes the input tensor through the GPT-3 model, generating predictions for the next token in the sequence. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively, starting from an optional initial sequence. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. @@ -518,19 +502,6 @@ class SparseMixtureOfExperts(nn.Module): num_experts (int): Number of experts. top_k (int): Number of top experts to use for each input instance. - Methods: - setup(): Initializes the experts, the gating mechanism, and the final dense layer. - - __call__(X: jnp.ndarray) -> jnp.ndarray: - Performs a forward pass through the Mixture of Experts layer. - - Args: - X (jnp.ndarray): Input tensor of shape (batch_size, seq_length, input_dim). - - Returns: - jnp.ndarray: Output tensor after processing through the MoE layer. The output - tensor has the same batch and sequence length dimensions as the input tensor, - but the last dimension is equal to num_outputs. """ num_hiddens: int @@ -585,10 +556,6 @@ class GPT4Block(nn.Module): num_experts (int): Number of experts in the sparse mixture of experts layer. top_k (int): Number of experts to be activated for each input in the sparse mixture of experts layer. - Methods: - setup(): Initializes the components of the GPT-4 block. - causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. - __call__(x, mask, training): Processes the input tensor through the GPT-4 block. """ hidden_dim: int @@ -672,9 +639,6 @@ class GPT4Decoder(nn.Module): num_experts (int): Number of experts in the sparse mixture of experts layer in each GPT-4 block. top_k (int): Number of experts to be activated for each input in the sparse mixture of experts layer in each GPT-4 block. - Methods: - setup(): Initializes the components of the GPT-4 decoder. - __call__(x, mask, training, drop_last_layer): Processes the input tensor through the GPT-4 decoder. """ num_layers: int @@ -749,8 +713,6 @@ class GPT4(nn.Module): top_k (int): Number of experts to be activated for each input in the sparse mixture of experts layer. Methods: - setup(): Initializes the GPT-4 model including the decoder component. - __call__(x, training, drop_last_layer): Processes the input tensor through the GPT-4 model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. diff --git a/nanodl/__src/models/ijepa.py b/nanodl/__src/models/ijepa.py index 345c218..422ec5c 100644 --- a/nanodl/__src/models/ijepa.py +++ b/nanodl/__src/models/ijepa.py @@ -417,6 +417,25 @@ def __call__( class IJEPADataSampler: + """ + Implements a data sampler for the IJEPA model. + + The data sampler is used to sample data for the IJEPA model. + It samples the scale of the target block using a uniform random distribution and scales it within the target scale range. + Also samples the scale of the context using a uniform random distribution and scales it within the context scale range. + + Attributes: + image_size (int): The size of the image. + patch_size (int): The size of the patches into which the image is divided. + M (int): The number of patches. + context_scale_range (tuple): The range of scales for the context. + target_scale_range (tuple): The range of scales for the target. + target_aspect_ratio_range (tuple): The range of aspect ratios for the target. + h (int): The height of the image divided by the patch size. + w (int): The width of the image divided by the patch size. + to_scale (function): A function to scale a value within a specified range. + random_key (int): A seed for generating random numbers. + """ to_scale: Any = lambda self, x, a, b: (b - a) * x + a random_key: int = 0 random_key = jax.random.PRNGKey(random_key) @@ -520,6 +539,26 @@ def __call__(self) -> Tuple[jnp.ndarray, jnp.ndarray]: class IJEPADataParallelTrainer: + """ + Implements a parallel trainer for the IJEPA model. + + The IJEPADataParallelTrainer is used to train the IJEPA model in parallel. + + Attributes: + model (Any): The model to be trained. + input_shape (Tuple[int, ...]): The shape of the input data. + weights_filename (str): The filename of the weights of the model. + data_sampler (IJEPADataSampler): The data sampler used to sample data for training. + learning_rate (float): The learning rate for training. Default is 1e-4. + params_path (str, optional): The path to the parameters of the model. Default is None. + params (Any): The parameters of the model. Initialized as None. + num_parameters (int): The number of parameters in the model. Initialized as None. + best_val_loss (float): The best validation loss achieved during training. Initialized as infinity. + num_devices (int): The number of devices used for training. + train_step (function): The function used to perform a training step. + evaluation_step (function): The function used to perform an evaluation step. + state (Any): The state of the model during training. + """ def __init__( self, model: Any, diff --git a/nanodl/__src/models/lamda.py b/nanodl/__src/models/lamda.py index 18e59ef..df42377 100644 --- a/nanodl/__src/models/lamda.py +++ b/nanodl/__src/models/lamda.py @@ -19,10 +19,6 @@ class RelativeMultiHeadAttention(nn.Module): hidden_dim (int): Dimensionality of the input and output features. num_heads (int): Number of attention heads. - Methods: - setup(): Initializes the projections for query, key, value, and output. - __call__(inputs, context, mask, clip): Processes the input and context tensors through the relative multi-head attention mechanism. - attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors, incorporating relative position information. """ hidden_dim: int @@ -129,9 +125,6 @@ class PositionWiseFFN(nn.Module): num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - Methods: - setup(): Initializes the two linear layers. - __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ num_hiddens: int @@ -159,8 +152,6 @@ class AddNorm(nn.Module): Attributes: dropout (float): Dropout rate for the residual connection. - Methods: - __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ dropout: int @@ -208,11 +199,7 @@ class LaMDABlock(nn.Module): feedforward_dim (int): Dimensionality of the inner layer of the feed-forward network. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the LaMDA block. - causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. - __call__(x, mask, training): Processes the input tensor through the LaMDA block. - """ + M""" hidden_dim: int num_heads: int @@ -280,9 +267,6 @@ class LaMDADecoder(nn.Module): vocab_size (float): Size of the vocabulary. embed_dim (float): Dimensionality of the token embeddings. - Methods: - setup(): Initializes the components of the LaMDA decoder. - __call__(x, mask, training, drop_last_layer): Processes the input tensor through the LaMDA decoder. """ num_layers: int @@ -348,8 +332,6 @@ class LaMDA(nn.Module): end_token (int): Token that indicates the end of a generated sequence. Methods: - setup(): Initializes the LaMDA model including the decoder component. - __call__(x, training, drop_last_layer): Processes the input tensor through the LaMDA model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. diff --git a/nanodl/__src/models/llama.py b/nanodl/__src/models/llama.py index e24e4fb..dd97909 100644 --- a/nanodl/__src/models/llama.py +++ b/nanodl/__src/models/llama.py @@ -18,11 +18,6 @@ class RotaryPositionalEncoding: Attributes: dim_model (int): The dimensionality of the model embeddings. - Methods: - _update_cos_sin_tables(x, seq_dimension): Updates cosine and sine tables based on the sequence length. - rotate_half(x): Rotates the last half of the dimensions of x by swapping them and changing signs to simulate a 90-degree rotation. - apply_rotary_pos_emb(x, cos, sin): Applies the rotary positional encoding to the input embeddings. - __call__(q, k): Applies rotary positional encoding to query and key tensors in attention mechanisms. """ def __init__(self, dim_model: int): @@ -79,11 +74,6 @@ class GroupedRotaryMultiHeadAttention(nn.Module): num_heads (int): Number of attention heads. num_groups (int): Number of groups to split the heads into for applying rotary positional embeddings separately. - Methods: - setup(): Initializes the projections for query, key, value, and output, along with the rotary positional encoder. - __call__(inputs, context, mask): Processes the input and context tensors through the grouped rotary multi-head attention mechanism. - process_group(query, key, value, mask): Processes a single group of heads through rotary positional encoding and attention. - attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors. """ hidden_dim: int # Output dimension @@ -185,8 +175,6 @@ class PositionWiseFFN(nn.Module): num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - Methods: - setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ @@ -219,10 +207,6 @@ class Llama3DecoderBlock(nn.Module): dropout (float): Dropout rate for regularization. num_groups (int): Number of groups for the grouped rotary positional embeddings. - Methods: - setup(): Initializes the components of the Llama3 decoder block. - causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. - __call__(x, training): Processes the input tensor through the Llama3 decoder block. """ hidden_dim: int @@ -304,9 +288,6 @@ class Llama3Decoder(nn.Module): vocab_size (float): Size of the vocabulary. embed_dim (float): Dimensionality of the token embeddings. - Methods: - setup(): Initializes the components of the Llama3 decoder. - __call__(x, training, drop_last_layer): Processes the input tensor through the Llama3 decoder. """ num_layers: int @@ -374,8 +355,6 @@ class Llama3(nn.Module): end_token (int): Token that indicates the end of a generated sequence. Methods: - setup(): Initializes the Llama3 model including the decoder component. - __call__(x, training, drop_last_layer): Processes the input tensor through the Llama3 model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. diff --git a/nanodl/__src/models/mistral.py b/nanodl/__src/models/mistral.py index 352d9df..d823f0a 100644 --- a/nanodl/__src/models/mistral.py +++ b/nanodl/__src/models/mistral.py @@ -18,11 +18,6 @@ class RotaryPositionalEncoding: Attributes: dim_model (int): The dimensionality of the model embeddings. - Methods: - _update_cos_sin_tables(x, seq_dimension): Updates cosine and sine tables based on the sequence length. - rotate_half(x): Rotates the last half of the dimensions of x by swapping them and changing signs to simulate a 90-degree rotation. - apply_rotary_pos_emb(x, cos, sin): Applies the rotary positional encoding to the input embeddings. - __call__(q, k): Applies rotary positional encoding to query and key tensors in attention mechanisms. """ def __init__(self, dim_model: int): @@ -81,13 +76,6 @@ class GroupedRotaryShiftedWindowMultiHeadAttention(nn.Module): window_size (int): Size of each window for processing local context. shift_size (int): Number of positions to shift the window at each layer to capture global context. - Methods: - setup(): Initializes the projections for query, key, value, and output, along with the rotary positional encoder. - __call__(inputs, context, mask): Processes the input and context tensors through the grouped rotary and shifted window multi-head attention mechanism. - process_group(query, key, value, mask): Processes a single group of heads through rotary positional encoding, shifted window partitioning, and attention. - window_partition(x): Partitions the input tensor into windows of a specified size. - attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors within each window. - causal_mask(shape): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism within windows. """ hidden_dim: int # Output dimension @@ -235,9 +223,6 @@ class PositionWiseFFN(nn.Module): num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - Methods: - setup(): Initializes the two linear layers. - __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ hidden_dim: int @@ -271,9 +256,6 @@ class MistralDecoderBlock(nn.Module): window_size (int): Size of each window for processing local context. shift_size (int): Number of positions to shift the window at each layer to capture global context. - Methods: - setup(): Initializes the components of the Mistral decoder block. - __call__(x, training): Processes the input tensor through the Mistral decoder block. """ hidden_dim: int @@ -347,9 +329,6 @@ class MistralDecoder(nn.Module): window_size (int): Window size used in grouped rotary shifted window multi-head attention. shift_size (int): Shift size used in grouped rotary shifted window multi-head attention. - Methods: - setup(): Initializes the components of the Mistral decoder. - __call__(x, training, drop_last_layer): Processes the input tensor through the Mistral decoder. """ num_layers: int @@ -423,8 +402,6 @@ class Mistral(nn.Module): shift_size (int): Shift size used in grouped rotary shifted window multi-head attention. Methods: - setup(): Initializes the Mistral model including the decoder component. - __call__(x, training, drop_last_layer): Processes the input tensor through the Mistral model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. @@ -666,19 +643,6 @@ class SparseMixtureOfExperts(nn.Module): num_experts (int): Number of experts. top_k (int): Number of top experts to use for each input instance. - Methods: - setup(): Initializes the experts, the gating mechanism, and the final dense layer. - - __call__(X: jnp.ndarray) -> jnp.ndarray: - Performs a forward pass through the Mixture of Experts layer. - - Args: - X (jnp.ndarray): Input tensor of shape (batch_size, seq_length, input_dim). - - Returns: - jnp.ndarray: Output tensor after processing through the MoE layer. The output - tensor has the same batch and sequence length dimensions as the input tensor, - but the last dimension is equal to num_outputs. """ num_hiddens: int @@ -730,9 +694,6 @@ class MixtralDecoderBlock(nn.Module): window_size (int): Size of each window for processing local context. shift_size (int): Number of positions to shift the window at each layer to capture global context. - Methods: - setup(): Initializes the components of the Mixtral decoder block. - __call__(x, training): Processes the input tensor through the Mixtral decoder block. """ hidden_dim: int @@ -808,9 +769,6 @@ class MixtralDecoder(nn.Module): window_size (int): Window size used in grouped rotary shifted window multi-head attention. shift_size (int): Shift size used in grouped rotary shifted window multi-head attention. - Methods: - setup(): Initializes the components of the Mixtral decoder. - __call__(x, training, drop_last_layer): Processes the input tensor through the Mixtral decoder. """ num_layers: int @@ -884,8 +842,6 @@ class Mixtral(nn.Module): shift_size (int): Shift size used in grouped rotary shifted window multi-head attention. Methods: - setup(): Initializes the Mixtral model including the decoder component. - __call__(x, training, drop_last_layer): Processes the input tensor through the Mixtral model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. diff --git a/nanodl/__src/models/mixer.py b/nanodl/__src/models/mixer.py index 53544cc..7bf276a 100644 --- a/nanodl/__src/models/mixer.py +++ b/nanodl/__src/models/mixer.py @@ -19,9 +19,6 @@ class PatchEmbedding(nn.Module): patch_size (tuple): Size (height, width) of the patches to extract from input images. embed_dim (int): Dimension of the embeddings for the patches. - Methods: - __call__(x: jnp.ndarray): Extracts patches from the input images and applies patch embedding. - extract_patches(images: jnp.ndarray): Extracts and flattens patches from input images. """ patch_size: Tuple[int, int] @@ -63,8 +60,6 @@ class MixerBlock(nn.Module): The Mixer block applies a two-step mixing process: the first step mixes per-location features across the channel dimension, and the second step mixes per-channel features across spatial locations. It aims to capture both channel-wise and spatial interactions within the input. - Methods: - __call__(x): Processes the input tensor through the Mixer block. """ @nn.compact @@ -94,9 +89,6 @@ class MixerEncoder(nn.Module): feedforward_dim (int): Dimensionality of the feedforward network within the MixerBlock. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the MixerEncoder. - __call__(x, training): Processes the input tensor through the encoder. """ patch_size: Tuple[int, int] @@ -137,10 +129,6 @@ class Mixer(nn.Module): dropout (float): Dropout rate for regularization. n_outputs (int): Number of output classes. - Methods: - setup(): Initializes the components of the Mixer model. - __call__(x, training): Processes the input tensor through the model and produces class logits. - MLP Mixers are a recent architectural innovation in the field of deep learning, introduced to address the limitations of traditional Convolutional Neural Networks (CNNs) and Transformers. The motivation behind MLP Mixers arises from the need to handle diverse data types and leverage multi-modal information efficiently. Unlike transformers that rely on self-attention mechanisms, MLP Mixers employ a simple yet powerful approach using Multi-Layer Perceptrons (MLPs) to process data. This architecture is designed to work with sequences, images, or even a combination of both, diff --git a/nanodl/__src/models/reward.py b/nanodl/__src/models/reward.py index 17550c9..04406fa 100644 --- a/nanodl/__src/models/reward.py +++ b/nanodl/__src/models/reward.py @@ -16,59 +16,64 @@ class RewardModel(nn.Module): It uses the last hidden state of a transformer-based model to generate a scalar reward prediction, guiding the agent's behavior by evaluating the desirability or utility of its generated outputs. + Args: + model (nn.Module): The neural network model to be used. + dim (int): The dimension of the input data. + dropout (float): The dropout rate for the model, a value between 0 and 1. + Example: - ```python - from nanodl import ArrayDataset, DataLoader - from nanodl import Gemma, RewardModel, RewardDataParallelTrainer - - # Generate dummy data - batch_size = 8 - max_length = 10 - - # Replace with actual tokenised data - dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) - dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) - - # Create dataset and dataloader - dataset = ArrayDataset(dummy_chosen, dummy_rejected) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) - - # model parameters - hyperparams = { - 'num_layers': 1, - 'hidden_dim': 256, - 'num_heads': 2, - 'feedforward_dim': 256, - 'dropout': 0.1, - 'vocab_size': 1000, - 'embed_dim': 256, - 'max_length': max_length, - 'start_token': 0, - 'end_token': 50, - 'num_groups': 2, - } - - # Initialize reward model from Gemma - model = Gemma(**hyperparams) - reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1) - - # Train the reward model - trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl') - trainer.train(dataloader, 5, dataloader) - params = trainer.load_params('reward_model_weights.pkl') - - # Call as you would a regular Flax model - rngs = jax.random.PRNGKey(0) - rngs, dropout_rng = jax.random.split(rngs) - rewards = reward_model.apply({'params': params}, - dummy_chosen, - rngs={'dropout': dropout_rng}) - - print(rewards.shape) - ``` + ```python + from nanodl import ArrayDataset, DataLoader + from nanodl import Gemma, RewardModel, RewardDataParallelTrainer + + # Generate dummy data + batch_size = 8 + max_length = 10 + + # Replace with actual tokenised data + dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) + dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) + + # Create dataset and dataloader + dataset = ArrayDataset(dummy_chosen, dummy_rejected) + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + drop_last=False) + + # model parameters + hyperparams = { + 'num_layers': 1, + 'hidden_dim': 256, + 'num_heads': 2, + 'feedforward_dim': 256, + 'dropout': 0.1, + 'vocab_size': 1000, + 'embed_dim': 256, + 'max_length': max_length, + 'start_token': 0, + 'end_token': 50, + 'num_groups': 2, + } + + # Initialize reward model from Gemma + model = Gemma(**hyperparams) + reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1) + + # Train the reward model + trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl') + trainer.train(dataloader, 5, dataloader) + params = trainer.load_params('reward_model_weights.pkl') + + # Call as you would a regular Flax model + rngs = jax.random.PRNGKey(0) + rngs, dropout_rng = jax.random.split(rngs) + rewards = reward_model.apply({'params': params}, + dummy_chosen, + rngs={'dropout': dropout_rng}) + + print(rewards.shape) + ``` """ model: nn.Module diff --git a/nanodl/__src/models/t5.py b/nanodl/__src/models/t5.py index e65d832..31697c0 100644 --- a/nanodl/__src/models/t5.py +++ b/nanodl/__src/models/t5.py @@ -19,10 +19,6 @@ class RelativeMultiHeadAttention(nn.Module): hidden_dim (int): Dimensionality of the input and output features. num_heads (int): Number of attention heads. - Methods: - setup(): Initializes the projections for query, key, value, and output. - __call__(inputs, context, mask, clip): Processes the input and context tensors through the relative multi-head attention mechanism. - attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors, incorporating relative position information. """ hidden_dim: int # Output dimension @@ -129,9 +125,6 @@ class PositionWiseFFN(nn.Module): num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - Methods: - setup(): Initializes the two linear layers. - __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ num_hiddens: int @@ -159,8 +152,6 @@ class AddNorm(nn.Module): Attributes: dropout (float): Dropout rate for the residual connection. - Methods: - __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ dropout: int @@ -185,9 +176,6 @@ class T5EncoderBlock(nn.Module): feedforward_dim (int): Dimensionality of the inner layer of the feed-forward network. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the T5 encoder block. - __call__(x, mask, training): Processes the input tensor through the encoder block. """ hidden_dim: int @@ -229,9 +217,6 @@ class T5Encoder(nn.Module): vocab_size (float): Size of the vocabulary. embed_dim (float): Dimensionality of the token embeddings. - Methods: - setup(): Initializes the components of the T5 encoder. - __call__(x, mask, training): Processes the input tensor through the encoder. """ num_layers: int @@ -278,9 +263,6 @@ class T5DecoderBlock(nn.Module): feedforward_dim (int): Dimensionality of the inner layer of the feed-forward network. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the T5 decoder block. - __call__(x, context, training): Processes the input tensor through the decoder block, incorporating context from the encoder. """ hidden_dim: int @@ -349,9 +331,6 @@ class T5Decoder(nn.Module): vocab_size (float): Size of the vocabulary. embed_dim (float): Dimensionality of the token embeddings. - Methods: - setup(): Initializes the components of the T5 decoder. - __call__(x, context, training): Processes the input tensor through the decoder, incorporating context from the encoder. """ num_layers: int @@ -413,8 +392,6 @@ class T5(nn.Module): end_token (int): Token that indicates the end of a generated sequence. Methods: - setup(): Initializes the T5 model including both the encoder and decoder components. - __call__(x, y, training): Processes the input tensor through the T5 model, generating predictions. generate(x, temperature, deterministic): Generates output sequences from input sequences. generate_batch(x, temperature, deterministic): Generates output sequences for a batch of input sequences. diff --git a/nanodl/__src/models/transformer.py b/nanodl/__src/models/transformer.py index f0baf15..4afceff 100644 --- a/nanodl/__src/models/transformer.py +++ b/nanodl/__src/models/transformer.py @@ -19,9 +19,6 @@ class PositionalEncoding(nn.Module): num_embeddings (int): The maximum number of positions for which to generate positional encodings. features (int): The dimensionality of the embeddings/positional encodings. - Methods: - setup(): Initializes the positional encoding matrix based on the provided attributes. - __call__(x: jnp.ndarray): Adds positional encodings to the input embeddings. """ num_embeddings: int @@ -59,9 +56,6 @@ class TokenAndPositionEmbedding(nn.Module): embed_dim (int): Dimension of the embeddings. learned_position (bool): Flag to use learned positional embeddings instead of fixed positional encodings. - Methods: - setup(): Initializes token and positional embeddings. - __call__(x: jnp.ndarray): Applies token embeddings and adds positional information to the input sequence. """ max_len: int @@ -101,10 +95,6 @@ class MultiHeadAttention(nn.Module): hidden_dim (int): Dimensionality of the input and output features. num_heads (int): Number of attention heads. - Methods: - setup(): Initializes projection matrices for queries, keys, values, and the output projection. - __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. - attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ hidden_dim: int @@ -186,9 +176,6 @@ class PositionWiseFFN(nn.Module): num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - Methods: - setup(): Initializes the two linear layers. - __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ num_hiddens: int @@ -216,8 +203,6 @@ class AddNorm(nn.Module): Attributes: dropout (float): Dropout rate for the residual connection. - Methods: - __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ dropout: int @@ -241,9 +226,6 @@ class TransformerEncoderBlock(nn.Module): feedforward_dim (int): Dimension of the feed-forward network. dropout (float): Dropout rate. - Methods: - setup(): Initializes the attention, feed-forward network, and normalization layers. - __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the encoder block. """ hidden_dim: int @@ -286,9 +268,6 @@ class TransformerEncoder(nn.Module): embed_dim (int): Dimension of the embeddings. learned_position (bool): Flag to use learned positional embeddings instead of fixed positional encodings. - Methods: - setup(): Initializes the embedding layer and the encoder blocks. - __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the transformer encoder. """ num_layers: int @@ -336,9 +315,6 @@ class TransformerDecoderBlock(nn.Module): feedforward_dim (int): Dimensionality of the inner layer of the feed-forward network. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the Transformer decoder block. - __call__(x, context, training): Processes the input tensor through the decoder block. """ hidden_dim: int @@ -409,9 +385,6 @@ class TransformerDecoder(nn.Module): embed_dim (float): Dimensionality of the token embeddings. learned_position (bool): Indicates if positional embeddings are learned or static. - Methods: - setup(): Initializes the components of the Transformer decoder. - __call__(x, context, training): Processes the input tensor through the decoder. """ num_layers: int @@ -475,8 +448,6 @@ class Transformer(nn.Module): end_token (int): Token that indicates the end of a generated sequence. Methods: - setup(): Initializes the Transformer model including both the encoder and decoder components. - __call__(x, y, training): Processes the input tensor through the Transformer model, generating predictions. generate(x, temperature, deterministic): Generates output sequences from input sequences. generate_batch(x, temperature, deterministic): Generates output sequences for a batch of input sequences. @@ -568,7 +539,7 @@ class Transformer(nn.Module): rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) - ``` + ``` """ num_layers: int diff --git a/nanodl/__src/models/vit.py b/nanodl/__src/models/vit.py index 4c5606e..913d937 100644 --- a/nanodl/__src/models/vit.py +++ b/nanodl/__src/models/vit.py @@ -19,9 +19,6 @@ class PatchEmbedding(nn.Module): patch_size (tuple): Size (height, width) of the patches to extract from input images. embed_dim (int): Dimension of the embeddings for the patches. - Methods: - __call__(x: jnp.ndarray): Extracts patches from the input images and applies patch embedding. - extract_patches(images: jnp.ndarray): Extracts and flattens patches from input images. """ patch_size: Tuple[int, int] @@ -67,10 +64,6 @@ class SelfMultiHeadAttention(nn.Module): hidden_dim (int): Dimensionality of the input and output features. num_heads (int): Number of attention heads. - Methods: - setup(): Initializes projection matrices for queries, keys, values, and the output projection. - __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. - attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ hidden_dim: int # Output dimension @@ -140,9 +133,6 @@ class PositionWiseFFN(nn.Module): num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - Methods: - setup(): Initializes the two linear layers. - __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ num_hiddens: int @@ -169,23 +159,12 @@ class AddNorm(nn.Module): Attributes: dropout (float): Dropout rate for the residual connection. - Methods: - __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ dropout: int @nn.compact def __call__(self, X: jnp.ndarray, Y: jnp.ndarray, training=False) -> jnp.ndarray: - """ - Apply AddNorm to input tensors. - Args: - X (jnp.ndarray): Input tensor X. - Y (jnp.ndarray): Input tensor Y. - training (bool): Training mode. - Returns: - jnp.ndarray: Output tensor after applying AddNorm. - """ return nn.LayerNorm()( nn.Dropout(self.dropout)(Y, deterministic=not training) + X ) @@ -203,9 +182,6 @@ class ViTBlock(nn.Module): feedforward_dim (int): Dimension of the feed-forward network. dropout (float): Dropout rate. - Methods: - setup(): Initializes the attention, feed-forward network, and normalization layers. - __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the encoder block. """ hidden_dim: int @@ -246,9 +222,6 @@ class ViTEncoder(nn.Module): feedforward_dim (int): Dimension of the feed-forward network in the transformer encoder. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the patch embedding and encoder blocks. - __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input images through the vision transformer encoder. """ patch_size: Tuple[int, int] @@ -294,10 +267,6 @@ class ViT(nn.Module): feedforward_dim (int): Dimensionality of the feedforward network within each Transformer encoder layer. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the ViTEncoder. - __call__(x, mask, training): Processes the input tensor through the encoder, returning encoded features and attention maps. - Vision Transformers, or ViTs, have emerged as a groundbreaking architectural paradigm in computer vision and deep learning. The motivation behind Vision Transformers lies in the desire to extend the success of transformers, originally designed for natural language processing, to visual data. These models aim to replace diff --git a/nanodl/__src/models/whisper.py b/nanodl/__src/models/whisper.py index 437cd51..6b5a826 100644 --- a/nanodl/__src/models/whisper.py +++ b/nanodl/__src/models/whisper.py @@ -15,9 +15,6 @@ class SpeechEmbedding(nn.Module): This layer applies two convolutional operations followed by GELU activations to the input audio signals. The first convolution maintains the sequence length, while the second halves it. Additionally, it adds sinusoidal embeddings to capture positional information within the audio sequence. - Methods: - __call__(x): Processes the input audio tensor through the convolutional layers and adds sinusoidal embeddings. - sinusoidal_embedding(x, max_position): Generates sinusoidal embeddings based on the sequence length and hidden dimension of the input. """ @nn.compact @@ -51,9 +48,6 @@ class PositionalEncoding(nn.Module): num_embeddings (int): The maximum number of positions for which to generate positional encodings. features (int): The dimensionality of the embeddings/positional encodings. - Methods: - setup(): Initializes the positional encoding matrix based on the provided attributes. - __call__(x: jnp.ndarray): Adds positional encodings to the input embeddings. """ num_embeddings: int @@ -82,6 +76,7 @@ def __call__(self, x): class TokenAndPositionEmbedding(nn.Module): """ Token and Position Embedding. + Args: max_len (int): Maximum sequence length. vocab_size (int): Vocabulary size. @@ -125,10 +120,6 @@ class MultiHeadAttention(nn.Module): hidden_dim (int): Dimensionality of the input and output features. num_heads (int): Number of attention heads. - Methods: - setup(): Initializes projection matrices for queries, keys, values, and the output projection. - __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. - attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ hidden_dim: int # Output dimension @@ -211,9 +202,6 @@ class PositionWiseFFN(nn.Module): num_hiddens (int): The number of hidden units in the first linear layer. num_outputs (int): The number of output units in the second linear layer (usually the same as the model's hidden size). - Methods: - setup(): Initializes the two linear layers. - __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ num_hiddens: int @@ -261,9 +249,6 @@ class WhisperSpeechEncoderBlock(nn.Module): feedforward_dim (int): Dimensionality of the inner layer of the feed-forward network. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the WhisperSpeechEncoderBlock. - __call__(x, mask, training): Processes the input tensor through the encoder block. """ hidden_dim: int @@ -303,9 +288,6 @@ class WhisperSpeechEncoder(nn.Module): feedforward_dim (int): Dimensionality of the feedforward network within each encoder block. dropout (float): Dropout rate used for regularization. - Methods: - setup(): Initializes the components of the WhisperSpeechEncoder. - __call__(x, mask, training): Processes the input audio tensor through the encoder, returning encoded features and attention maps. """ num_layers: int @@ -348,9 +330,6 @@ class WhisperTextDecoderBlock(nn.Module): feedforward_dim (int): Dimensionality of the inner layer of the feed-forward network. dropout (float): Dropout rate for regularization. - Methods: - setup(): Initializes the components of the Transformer decoder block. - __call__(x, context, training): Processes the input tensor through the decoder block. """ hidden_dim: int @@ -421,9 +400,6 @@ class WhisperTextDecoder(nn.Module): embed_dim (float): Dimensionality of the token embeddings. learned_position (bool): Indicates if positional embeddings are learned or static. - Methods: - setup(): Initializes the components of the Transformer decoder. - __call__(x, context, training): Processes the input tensor through the decoder. """ num_layers: int @@ -487,8 +463,6 @@ class Whisper(nn.Module): end_token (int): Token that indicates the end of a generated sequence. Methods: - setup(): Initializes the Whisper model including both the encoder and decoder components. - __call__(x, y, training): Processes the input audio tensor through the Whisper model, generating textual predictions. generate(x, temperature, deterministic): Generates textual output from input audio sequences. Whisper uses an encoder-decoder Transformer (Vaswani et al., 2017) as this, All audio is re-sampled to 16,000 Hz, and an 80-channel logmagnitude Mel spectrogram representation is computed on diff --git a/nanodl/__src/utils/data.py b/nanodl/__src/utils/data.py index c004d9c..6e4746b 100644 --- a/nanodl/__src/utils/data.py +++ b/nanodl/__src/utils/data.py @@ -110,7 +110,7 @@ def __init__( self.shuffle = shuffle self.drop_last = drop_last - self.keys = PRNGSequence(seed=Config.default().global_seed) + self.keys = __PRNGSequence(seed=Config.default().global_seed) self.data_len = len(dataset) # Length of the dataset self.indices = jnp.arange(self.data_len) # available indices in the dataset self.pose = 0 # record the current position in the dataset @@ -162,7 +162,7 @@ def default(cls): return cls(rng_reserve_size=1, global_seed=42) -class PRNGSequence(Iterator[jax.random.PRNGKey]): +class __PRNGSequence(Iterator[jax.random.PRNGKey]): """ An Iterator of Jax PRNGKey (minimal version of `haiku.PRNGSequence`). diff --git a/nanodl/__src/utils/ml.py b/nanodl/__src/utils/ml.py index d0c41e3..b34b545 100644 --- a/nanodl/__src/utils/ml.py +++ b/nanodl/__src/utils/ml.py @@ -19,12 +19,12 @@ def batch_cosine_similarities( jnp.ndarray: Array of cosine similarity scores of shape (N,). Example usage: - ``` + ``` >>> source = jnp.array([1, 0, 0]) >>> candidates = jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) >>> similarities = batch_cosine_similarities(source, candidates) >>> print(similarities) - ``` + ``` """ dot_products = jnp.einsum("ij,j->i", candidates, source) norm_source = jnp.sqrt(jnp.einsum("i,i->", source, source)) @@ -45,12 +45,12 @@ def batch_pearsonr(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: jnp.ndarray: Array of Pearson correlation coefficients of shape (N,). Example usage: - ``` + ``` >>> x = jnp.array([[1, 2, 3], [4, 5, 6]]) >>> y = jnp.array([[1, 5, 7], [2, 6, 8]]) >>> correlations = batch_pearsonr(x, y) >>> print(correlations) - ``` + ``` """ x = jnp.asarray(x).T y = jnp.asarray(y).T @@ -76,11 +76,11 @@ def classification_scores(labels: jnp.ndarray, preds: jnp.ndarray) -> jnp.ndarra jnp.ndarray: Array containing accuracy, precision, recall, and F1-score. Example usage: - ``` + ``` >>> labels = jnp.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0]) >>> preds = jnp.array([1, 1, 1, 0, 1, 0, 1, 0, 0, 0]) >>> print(classification_scores(labels, preds)) - ``` + ``` """ true_positives = jnp.sum(jnp.logical_and(preds == 1, labels == 1)) true_negatives = jnp.sum(jnp.logical_and(preds == 0, labels == 0)) @@ -100,14 +100,14 @@ def mean_reciprocal_rank(predictions: jnp.ndarray) -> float: Calculate the Mean Reciprocal Rank (MRR) for a list of ranked predictions using JAX. Example usage: - ``` + ``` predictions = jnp.array([ [0, 1, 2], # "correct" prediction at index 0 [1, 0, 2], # "correct" prediction at index 1 [2, 1, 0] # "correct" prediction at index 2 ]) mrr_score = mean_reciprocal_rank(predictions) - ``` + ``` Args: predictions (jnp.ndarray): 2D array where each row contains ranked predictions @@ -135,12 +135,12 @@ def jaccard(sequence1: List, sequence2: List) -> float: float: Jaccard similarity score. Example usage: - ```py + ```py >>> sequence1 = [1, 2, 3] >>> sequence2 = [2, 3, 4] >>> similarity = jaccard(sequence1, sequence2) >>> print(similarity) - ``` + ``` """ numerator = len(set(sequence1).intersection(sequence2)) denominator = len(set(sequence1).union(sequence2)) @@ -160,12 +160,12 @@ def hamming(sequence1: jnp.ndarray, sequence2: jnp.ndarray) -> int: int: Hamming similarity score. Example usage: - ```py + ```py >>> sequence1 = jnp.array([1, 2, 3, 4]) >>> sequence2 = jnp.array([1, 2, 4, 4]) >>> similarity = hamming_jax(sequence1, sequence2) >>> print(similarity) - ``` + ``` """ return jnp.sum(sequence1 == sequence2) @@ -186,14 +186,14 @@ def zero_pad_sequences(arr: jnp.array, max_length: int) -> jnp.array: jax.numpy.ndarray: The zero-padded array. Example usage: - ```py + ```py >>> arr = jnp.array([[1, 2, 3], [4, 5, 6]]) >>> max_length = 5 >>> padded_arr = zero_pad_sequences(arr, max_length) >>> print(padded_arr) [[1 2 3 0 0] [4 5 6 0 0]] - ``` + ``` """ current_length = arr.shape[1] num_zeros = max_length - current_length @@ -213,10 +213,10 @@ def entropy(probabilities: jnp.ndarray) -> float: Calculate the entropy of a probability distribution using JAX. Example usage: - ``` + ``` probabilities = jnp.array([0.25, 0.75]) entropy_value = entropy(probabilities) - ``` + ``` Args: probabilities (jnp.ndarray): Array of probability values. @@ -235,10 +235,10 @@ def gini_impurity(probabilities: jnp.ndarray) -> float: Calculate the Gini impurity of a probability distribution using JAX. Example usage: - ``` + ``` probabilities = jnp.array([0.25, 0.75]) gini_value = gini_impurity(probabilities) - ``` + ``` Args: probabilities (jnp.ndarray): Array of probability values. @@ -256,11 +256,11 @@ def kl_divergence(p: jnp.ndarray, q: jnp.ndarray) -> float: Calculate the Kullback-Leibler (KL) divergence between two probability distributions using JAX. Example usage: - ``` + ``` p = jnp.array([0.25, 0.75]) q = jnp.array([0.5, 0.5]) kl_value = kl_divergence(p, q) - ``` + ``` Args: p (jnp.ndarray): Array of probability values for distribution p. @@ -279,11 +279,11 @@ def count_parameters(params: Any) -> int: Count the total number of parameters in a model's parameter dictionary using JAX. Example usage: - ``` + ``` model = MyModel() params = model.init(jax.random.PRNGKey(0), jnp.ones(input_shape)) total_params = count_parameters(params) - ``` + ``` Args: params (Any): Model's parameter dictionary. diff --git a/nanodl/__src/utils/nlp.py b/nanodl/__src/utils/nlp.py index f763615..bfd521b 100644 --- a/nanodl/__src/utils/nlp.py +++ b/nanodl/__src/utils/nlp.py @@ -21,12 +21,12 @@ def rouge( dict: Dictionary containing precision, recall, and F1-score for each n-gram size. Example usage: - ``` + ``` >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] >>> references = ["the cat is on the mat", "the cat sits on the mat"] >>> rouge_scores = rouge(hypotheses, references, [1, 2]) >>> print(rouge_scores) - ``` + ``` """ def ngrams(sequence: List[str], n: int) -> List[str]: @@ -96,12 +96,12 @@ def bleu(hypotheses: List[str], references: List[str], max_ngram: int = 4) -> fl float: BLEU score. Example usage: - ``` + ``` >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] >>> references = ["the cat is on the mat", "the cat sits on the mat"] >>> bleu_score = bleu(hypotheses, references) >>> print(bleu_score) - ``` + ``` """ def ngrams(sequence: List[str], n: int) -> List[str]: @@ -156,12 +156,12 @@ def meteor(hypothesis: str, reference: str) -> float: float: METEOR score. Example usage: - ``` + ``` >>> hypothesis = "the cat is on the mat" >>> reference = "the cat sits on the mat" >>> meteor_score = meteor(hypothesis, reference) >>> print(meteor_score) - ``` + ``` """ def tokenize(sentence): @@ -219,12 +219,12 @@ def cider_score(hypothesis: str, reference: str) -> float: float: CIDEr score. Example usage: - ``` + ``` >>> hypothesis = "the cat is on the mat" >>> reference = "the cat sits on the mat" >>> score = cider_score(hypothesis, reference) >>> print(score) - ``` + ``` """ def tokenize(sentence): @@ -280,11 +280,11 @@ def perplexity(log_probs: List[float]) -> float: float: Perplexity score. Example usage: - ``` + ``` >>> log_probs = [-2.3, -1.7, -0.4] # Example log probabilities >>> perplexity_score = perplexity(log_probs) >>> print(perplexity_score) - ``` + ``` """ log_likelihood = 0.0 word_count = 0 @@ -313,12 +313,12 @@ def word_error_rate(hypotheses: List[int], references: List[int]) -> float: float: Word Error Rate score. Example usage: - ``` + ``` >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] >>> references = ["the cat is on the mat", "the cat sits on the mat"] >>> wer_score = word_error_rate(hypotheses, references) >>> print(wer_score) - ``` + ``` """ def edit_distance(str1, str2): diff --git a/nanodl/__src/utils/random.py b/nanodl/__src/utils/random.py index fd9fd5f..86d0b49 100644 --- a/nanodl/__src/utils/random.py +++ b/nanodl/__src/utils/random.py @@ -18,18 +18,18 @@ def time_rng_key(seed=None) -> jnp.ndarray: def uniform( shape: Tuple[int, ...], - dtype: Any = jnp.float32, - minval: float = 0.0, - maxval: float = 1.0, + minval: Any = 0.0, + maxval: Any = 1.0, seed=None, + dtype: Any = jnp.float32, ) -> jnp.ndarray: """Generate a tensor of uniform random values. Args: shape (Tuple[int, ...]): The shape of the output tensor. dtype (Any, optional): The data type of the output tensor. Defaults to jnp.float32. - minval (float, optional): The lower bound of the uniform distribution. Defaults to 0.0. - maxval (float, optional): The upper bound of the uniform distribution. Defaults to 1.0. + minval (Any, optional): The lower bound of the uniform distribution. Defaults to 0.0. + maxval (Any, optional): The upper bound of the uniform distribution. Defaults to 1.0. Returns: jnp.ndarray: A tensor of uniform random values. diff --git a/nanodl/__src/utils/vision.py b/nanodl/__src/utils/vision.py index 9dc3383..c8c7b7d 100644 --- a/nanodl/__src/utils/vision.py +++ b/nanodl/__src/utils/vision.py @@ -17,11 +17,11 @@ def normalize_images(images: jnp.ndarray) -> jnp.ndarray: jnp.ndarray: Normalized images of the same shape as the input. Example usage: - ``` + ``` >>> images = jnp.array([[[[0.0, 0.5], [1.0, 0.25]]]]) # One image of shape (1, 2, 2, 1) >>> normalized_images = normalize_images(images) >>> print(normalized_images) - ``` + ``` """ mean = images.mean(axis=(1, 2, 3), keepdims=True) std = images.std(axis=(1, 2, 3), keepdims=True) @@ -45,12 +45,12 @@ def random_crop(images: jnp.ndarray, crop_size: int) -> jnp.ndarray: jax.numpy.ndarray: The cropped images, with shape (batch_size, crop_size, crop_size, channels). Example usage: - ``` + ``` >>> images = jnp.ones((10, 100, 100, 3)) # Batch of 10 images of size 100x100 with 3 channels >>> crop_size = 64 >>> cropped_images = random_crop(images, crop_size) >>> print(cropped_images.shape) - ``` + ``` """ key = jax.random.PRNGKey(int(time.time())) _, height, width, _ = images.shape @@ -75,11 +75,11 @@ def gaussian_blur(image: jnp.ndarray, kernel_size: int, sigma: float) -> jnp.nda jnp.ndarray: Blurred image of the same shape as the input. Example usage: - ``` + ``` >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels >>> blurred_image = gaussian_blur(image, kernel_size=3, sigma=1.0) >>> print(blurred_image.shape) - ``` + ``` """ assert kernel_size % 2 == 1, "Kernel size must be odd." ax = jnp.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0) @@ -110,11 +110,11 @@ def sobel_edge_detection(image: jnp.ndarray) -> jnp.ndarray: jnp.ndarray: Image representing the edges, of the same shape as the input. Example usage: - ``` + ``` >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels >>> edges = sobel_edge_detection(image) >>> print(edges.shape) - ``` + ``` """ sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32) sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32) @@ -148,11 +148,11 @@ def adjust_brightness(image: jnp.ndarray, factor: float) -> jnp.ndarray: jnp.ndarray: Brightness-adjusted image of the same shape as the input. Example usage: - ``` + ``` >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels >>> adjusted_image = adjust_brightness(image, factor=1.5) >>> print(adjusted_image.shape) - ``` + ``` """ return jnp.clip(image * factor, 0, 1) @@ -171,11 +171,11 @@ def adjust_contrast(image: jnp.ndarray, factor: float) -> jnp.ndarray: jnp.ndarray: Contrast-adjusted image of the same shape as the input. Example usage: - ``` + ``` >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels >>> adjusted_image = adjust_contrast(image, factor=1.5) >>> print(adjusted_image.shape) - ``` + ``` """ mean = jnp.mean(image, axis=(0, 1), keepdims=True) return jnp.clip((image - mean) * factor + mean, 0, 1) @@ -195,12 +195,12 @@ def flip_image(image: jnp.ndarray, horizontal: jnp.ndarray) -> jnp.ndarray: jnp.ndarray: Flipped image of the same shape as the input. Example usage: - ``` + ``` >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels >>> flipped_image_horizontally = flip_image(image, jnp.array([True])) >>> flipped_image_vertically = flip_image(image, jnp.array([False])) >>> print(flipped_image_horizontally.shape, flipped_image_vertically.shape) - ``` + ``` """ return jnp.where(horizontal, image[:, ::-1, :], image[::-1, :, :]) @@ -223,12 +223,12 @@ def random_flip_image( jnp.ndarray: Randomly flipped image of the same shape as the input. Example usage: - ``` + ``` >>> key = jax.random.PRNGKey(0) >>> image = jnp.ones((5, 5, 3)) # Example image with 3 channels >>> flipped_image = random_flip_image(image, key, jnp.array([True])) >>> print(flipped_image.shape) - ``` + ``` """ flip = jax.random.uniform(key) > 0.5 flip_horizontal = jnp.where(horizontal, image[:, ::-1, :], image) diff --git a/reward_model_weights.pkl b/reward_model_weights.pkl new file mode 100644 index 0000000..b695080 Binary files /dev/null and b/reward_model_weights.pkl differ