From f72df49a33a563a91604c18c363240b1d6279b45 Mon Sep 17 00:00:00 2001 From: Henry Ndubuaku Date: Sun, 12 May 2024 19:37:27 +0100 Subject: [PATCH] Pushed v1.2.4Dev1 codes --- README.md | 74 +- nanodl/__init__.py | 247 +++--- .../__src/{layers => classical}/__init__.py | 0 .../__src/{sklearn_gpu => classical}/bayes.py | 23 +- .../{sklearn_gpu => classical}/clustering.py | 112 ++- .../dimensionality_reduction.py | 32 +- nanodl/__src/classical/dsp.py | 129 +++ .../{sklearn_gpu => classical}/regression.py | 30 +- nanodl/__src/experimental/bitlinear.py | 72 ++ nanodl/__src/{models => experimental}/gat.py | 113 +-- nanodl/__src/experimental/kan.py | 227 ++++++ .../mamba.py} | 440 ++++++----- nanodl/__src/{models => experimental}/rlhf.py | 262 ++++--- .../{utils => experimental}/tokenizer.py | 55 +- nanodl/__src/layers/general.py | 34 - nanodl/__src/{layers => models}/attention.py | 466 ++++++----- nanodl/__src/models/clip.py | 487 +++++++----- nanodl/__src/models/diffusion.py | 356 +++++---- nanodl/__src/models/gemma.py | 417 +++++----- nanodl/__src/models/gpt.py | 640 ++++++++------- nanodl/__src/models/ijepa.py | 523 ++++++++----- nanodl/__src/models/lamda.py | 475 ++++++----- nanodl/__src/models/llama.py | 478 ++++++----- nanodl/__src/models/mistral.py | 739 ++++++++++-------- nanodl/__src/models/mixer.py | 243 +++--- nanodl/__src/models/reward.py | 197 ++--- nanodl/__src/models/t5.py | 555 +++++++------ nanodl/__src/models/transformer.py | 733 +++++++++-------- nanodl/__src/models/vit.py | 354 +++++---- nanodl/__src/models/whisper.py | 569 ++++++++------ nanodl/__src/sklearn_gpu/__init__.py | 0 nanodl/__src/sklearn_gpu/dsp.py | 126 --- nanodl/__src/utils/data.py | 41 +- nanodl/__src/utils/ml.py | 87 ++- nanodl/__src/utils/nlp.py | 185 +++-- nanodl/__src/utils/random.py | 210 +++-- nanodl/__src/utils/vision.py | 43 +- setup.py | 56 +- tests/test_models.py | 324 +++----- tests/test_random.py | 31 +- tests/test_sklearn_gpu.py | 22 +- tests/test_utils.py | 106 ++- 42 files changed, 5803 insertions(+), 4510 deletions(-) rename nanodl/__src/{layers => classical}/__init__.py (100%) rename nanodl/__src/{sklearn_gpu => classical}/bayes.py (85%) rename nanodl/__src/{sklearn_gpu => classical}/clustering.py (61%) rename nanodl/__src/{sklearn_gpu => classical}/dimensionality_reduction.py (78%) create mode 100644 nanodl/__src/classical/dsp.py rename nanodl/__src/{sklearn_gpu => classical}/regression.py (94%) create mode 100644 nanodl/__src/experimental/bitlinear.py rename nanodl/__src/{models => experimental}/gat.py (69%) create mode 100644 nanodl/__src/experimental/kan.py rename nanodl/__src/{models/mamba_experimental.py => experimental/mamba.py} (63%) rename nanodl/__src/{models => experimental}/rlhf.py (64%) rename nanodl/__src/{utils => experimental}/tokenizer.py (83%) delete mode 100644 nanodl/__src/layers/general.py rename nanodl/__src/{layers => models}/attention.py (53%) delete mode 100644 nanodl/__src/sklearn_gpu/__init__.py delete mode 100644 nanodl/__src/sklearn_gpu/dsp.py diff --git a/README.md b/README.md index 660bc34..d50b196 100644 --- a/README.md +++ b/README.md @@ -15,18 +15,18 @@ Each model is purposefully contained in a file without inter-file dependencies. Developing and training transformer-based models is typically resource-intensive and time-consuming and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features: - A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch. -- An extensive selection of models like Gemma, LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications. -- Data-parallel distributed trainers includding RLHF so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops. +- An extensive selection of models like Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, CLIP etc. +- Data-parallel distributed trainers models on multiple GPUs or TPUs, without the need for manual training loops. - Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective. -- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development. -- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU. -- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models. +- Layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development. +- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc. - True random number generators in Jax which do not need the verbose code. - A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU, Tokenizer etc. - Each model is contained in a single file with no external dependencies, so the source code can also be easily used. - True random number generators in Jax which do not need the verbose code (examples shown in next sections). -There are experimental features (like MAMBA architecture and RLHF) in the repo which are not available via the package, pending tests. +There are experimental and/or unfinished features (like MAMBA, KAN, BitNet, GAT and RLHF) +in the repo which are not yet available via the package, but can be copied from this repo. Feedback on any of our discussion, issue and pull request threads are welcomed! Please report any feature requests, issues, questions or concerns in the [Discord](https://discord.gg/3u9vumJEmz), or just let us know what you're working on! @@ -58,7 +58,6 @@ We provide various example usages of the nanodl API. ```py import jax import jax.numpy as jnp -from nanodl import time_rng_key from nanodl import ArrayDataset, DataLoader from nanodl import GPT4, GPTDataParallelTrainer, Tokenizer @@ -67,29 +66,8 @@ batch_size = 8 max_length = 50 vocab_size = 1000 -text_paths = ['/path/sample1.txt', - '/path/sample2.txt', - '/path/sample3.txt'] - -tokenizer = Tokenizer(training_data=text_paths, - vocab_size=vocab_size, - model_type='bpe', - max_sentence_length=max_length) - -data = [] -for path in text_paths: - with open(path, 'r') as file: - text = file.read() - # To-Do: preprocess however you wish - encoded = list(map(tokenizer.encode, text)) - data.extend(encoded) - -# Pad sequences with 0 -max_length = max(len(seq) for seq in data) -padded = [seq + [0] * (max_length - len(seq)) for seq in data] - -# Jax does not support strings yet, encode before converting to array -data = jnp.array(padded) +# Create random data +data = nanodl.uniform(shape=(batch, max_length)) # Shift to create next-token prediction dataset dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:] @@ -124,7 +102,7 @@ trainer = GPTDataParallelTrainer(model, trainer.train(train_loader=dataloader, num_epochs=100, - val_loader=dataloader) #To Do: replace with actual val data + val_loader=dataloader) # use actual val data # Generating from a start token start_tokens = jnp.array([[123, 456]]) @@ -133,11 +111,8 @@ start_tokens = jnp.array([[123, 456]]) params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': time_rng_key()}, + rngs={'dropout': nanodl.time_rng_key()}, method=model.generate) - -# Jax does not support strings yet, convert to list before decoding -outputs = tokenizer.decode(outputs.tolist()) ``` Vision example @@ -145,7 +120,6 @@ Vision example ```py import jax import jax.numpy as jnp -from nanodl import time_rng_key from nanodl import ArrayDataset, DataLoader from nanodl import DiffusionModel, DiffusionDataParallelTrainer @@ -154,7 +128,7 @@ block_depth = 2 batch_size = 8 widths = [32, 64, 128] input_shape = (101, image_size, image_size, 3) -images = jax.random.normal(time_rng_key(), input_shape) +images = nanodl.normal(shape=input_shape) # Use your own images dataset = ArrayDataset(images) @@ -169,7 +143,7 @@ trainer = DiffusionDataParallelTrainer(diffusion_model, weights_filename='params.pkl', learning_rate=1e-4) -trainer.train(dataloader, 10, dataloader) +trainer.train(dataloader, 10, dataloader) # use actual val data # Generate some samples: Each model is a Flax.linen module # Use as you normally would @@ -185,7 +159,6 @@ Audio example ```py import jax import jax.numpy as jnp -from nanodl import time_rng_key from nanodl import ArrayDataset, DataLoader from nanodl import Whisper, WhisperDataParallelTrainer @@ -287,7 +260,7 @@ params = trainer.load_params('reward_model_weights.pkl') # Call as you would a regular Flax model rewards = reward_model.apply({'params': params}, dummy_chosen, - rngs={'dropout': time_rng_key()}) + rngs={'dropout': nanodl.time_rng_key()}) ``` PCA example @@ -297,7 +270,7 @@ import jax from nanodl import PCA # Use actual data -data = jax.random.normal(jax.random.key(0), (1000, 10)) +data = nanodl.normal(shape=(1000, 10)) # Initialise and train PCA model pca = PCA(n_components=2) @@ -313,24 +286,7 @@ original_data = pca.inverse_transform(transformed_data) X_sampled = pca.sample(n_samples=1000, key=None) ``` -NanoDL provides random module which abstracts away Jax's intricacies. -It generates truly random variables by using the current timestamp as seed. - -```py -import jax - -# Jax example -key = jax.random.PRNGKey(0) -jax_array = jax.random.uniform(key, shape=(3, 3)) - -# NanoDL example -jax_array = nanodl.uniform(shape=(3, 3)) - -# For reproducability, use seed -jax_array = nanodl.uniform(shape=(3, 3), seed=0) -``` - -This is the first iteration of this project, roughness is expected, and contributions are therefore highly encouraged! +This is still in dev, works great but roughness is expected, and contributions are therefore highly encouraged! - Make your changes without changing the design patterns. - Write tests for your changes if necessary. diff --git a/nanodl/__init__.py b/nanodl/__init__.py index 2042f5b..6361b2a 100644 --- a/nanodl/__init__.py +++ b/nanodl/__init__.py @@ -1,158 +1,118 @@ -__version__ = "1.2.3.dev1" - -from nanodl.__src.sklearn_gpu.bayes import NaiveBayesClassifier -from nanodl.__src.sklearn_gpu.dimensionality_reduction import PCA -from nanodl.__src.sklearn_gpu.clustering import KMeans, GaussianMixtureModel -from nanodl.__src.utils.tokenizer import Tokenizer -from nanodl.__src.utils.random import * - -from nanodl.__src.sklearn_gpu.regression import ( - LinearRegression, - LogisticRegression, - GaussianProcess -) - -from nanodl.__src.models.gat import ( - GAT, - GraphAttentionLayer -) - -from nanodl.__src.models.t5 import ( - T5, - T5DataParallelTrainer, - T5Encoder, - T5Decoder, - T5EncoderBlock, - T5DecoderBlock +__version__ = "1.2.4.dev1" + +from nanodl.__src.classical.bayes import NaiveBayesClassifier +from nanodl.__src.classical.clustering import GaussianMixtureModel, KMeans +from nanodl.__src.classical.dimensionality_reduction import PCA +from nanodl.__src.classical.regression import ( + GaussianProcess, + LinearRegression, + LogisticRegression, ) - -from nanodl.__src.models.vit import ( - ViT, - ViTDataParallelTrainer, - ViTBlock, - ViTEncoder, - PatchEmbedding +from nanodl.__src.experimental.gat import GAT, GraphAttentionLayer +from nanodl.__src.models.attention import ( + GatedMultiHeadAttention, + HierarchicalMultiHeadAttention, + LocalMultiHeadAttention, + MultiQueryAttention, + RotaryMultiHeadAttention, ) - from nanodl.__src.models.clip import ( CLIP, CLIPDataParallelTrainer, ImageEncoder, + SelfMultiHeadAttention, TextEncoder, - SelfMultiHeadAttention ) - -from nanodl.__src.models.lamda import ( - LaMDA, - LaMDADataParallelTrainer, - LaMDABlock, - LaMDADecoder, - RelativeMultiHeadAttention -) - -from nanodl.__src.models.mixer import ( - Mixer, - MixerDataParallelTrainer, - MixerBlock, - MixerEncoder +from nanodl.__src.models.diffusion import ( + DiffusionDataParallelTrainer, + DiffusionModel, + UNet, + UNetDownBlock, + UNetResidualBlock, + UNetUpBlock, ) - -from nanodl.__src.models.llama import ( - LlaMA2, - LlaMADataParallelTrainer, - RotaryPositionalEncoding, - LlaMA2Decoder, - LlaMA2DecoderBlock, - GroupedRotaryMultiHeadAttention +from nanodl.__src.models.gemma import ( + Gemma, + GemmaDataParallelTrainer, + GemmaDecoder, + GemmaDecoderBlock, ) - from nanodl.__src.models.gpt import ( GPT3, GPT4, - GPTDataParallelTrainer, GPT3Block, - GPT4Block, GPT3Decoder, + GPT4Block, GPT4Decoder, - PositionWiseFFN + GPTDataParallelTrainer, + PositionWiseFFN, +) +from nanodl.__src.models.ijepa import IJEPA, IJEPADataParallelTrainer, IJEPADataSampler +from nanodl.__src.models.lamda import ( + LaMDA, + LaMDABlock, + LaMDADataParallelTrainer, + LaMDADecoder, + RelativeMultiHeadAttention, +) +from nanodl.__src.models.llama import ( + GroupedRotaryMultiHeadAttention, + Llama3, + Llama3Decoder, + Llama3DecoderBlock, + LlamaDataParallelTrainer, + RotaryPositionalEncoding, ) - from nanodl.__src.models.mistral import ( + GroupedRotaryShiftedWindowMultiHeadAttention, Mistral, MistralDataParallelTrainer, MistralDecoder, MistralDecoderBlock, - GroupedRotaryShiftedWindowMultiHeadAttention -) - -from nanodl.__src.models.mistral import ( Mixtral, MixtralDecoder, MixtralDecoderBlock, - GroupedRotaryShiftedWindowMultiHeadAttention ) - -from nanodl.__src.models.whisper import ( - Whisper, - WhisperDataParallelTrainer, - WhisperSpeechEncoder, - WhisperSpeechEncoderBlock +from nanodl.__src.models.mixer import ( + Mixer, + MixerBlock, + MixerDataParallelTrainer, + MixerEncoder, ) - -from nanodl.__src.models.diffusion import ( - DiffusionModel, - DiffusionDataParallelTrainer, - UNet, - UNetDownBlock, - UNetUpBlock, - UNetResidualBlock +from nanodl.__src.models.reward import RewardDataParallelTrainer, RewardModel +from nanodl.__src.models.t5 import ( + T5, + T5DataParallelTrainer, + T5Decoder, + T5DecoderBlock, + T5Encoder, + T5EncoderBlock, ) - - from nanodl.__src.models.transformer import ( - Transformer, - TransformerDataParallelTrainer, - TransformerEncoder, - TransformerDecoderBlock, + AddNorm, + MultiHeadAttention, PositionalEncoding, PositionWiseFFN, TokenAndPositionEmbedding, - MultiHeadAttention, - AddNorm -) - -from nanodl.__src.models.gemma import ( - Gemma, - GemmaDataParallelTrainer, - GemmaDecoder, - GemmaDecoderBlock -) - -from nanodl.__src.models.reward import ( - RewardModel, - RewardDataParallelTrainer -) - -from nanodl.__src.models.ijepa import ( - IJEPA, - IJEPADataParallelTrainer, - IJEPADataSampler + Transformer, + TransformerDataParallelTrainer, + TransformerDecoderBlock, + TransformerEncoder, ) - -from nanodl.__src.layers.attention import ( - MultiQueryAttention, - LocalMultiHeadAttention, - HierarchicalMultiHeadAttention, - GatedMultiHeadAttention, - RotaryMultiHeadAttention +from nanodl.__src.models.vit import ( + PatchEmbedding, + ViT, + ViTBlock, + ViTDataParallelTrainer, + ViTEncoder, ) - -from nanodl.__src.utils.data import ( - Dataset, - ArrayDataset, - DataLoader +from nanodl.__src.models.whisper import ( + Whisper, + WhisperDataParallelTrainer, + WhisperSpeechEncoder, + WhisperSpeechEncoderBlock, ) - +from nanodl.__src.utils.data import ArrayDataset, DataLoader, Dataset from nanodl.__src.utils.ml import ( batch_cosine_similarities, batch_pearsonr, @@ -164,19 +124,18 @@ jaccard, kl_divergence, mean_reciprocal_rank, - zero_pad_sequences + zero_pad_sequences, ) - -from nanodl.__src.utils.nlp import( +from nanodl.__src.utils.nlp import ( bleu, cider_score, meteor, perplexity, rouge, - word_error_rate + word_error_rate, ) - -from nanodl.__src.utils.vision import( +from nanodl.__src.utils.random import * +from nanodl.__src.utils.vision import ( adjust_brightness, adjust_contrast, flip_image, @@ -187,7 +146,6 @@ sobel_edge_detection, ) - __all__ = [ # Sklearn GPU "NaiveBayesClassifier", @@ -197,7 +155,6 @@ "LinearRegression", "LogisticRegression", "GaussianProcess", - # Models "IJEPA", "IJEPADataParallelTrainer", @@ -206,7 +163,7 @@ "GemmaDataParallelTrainer", "GemmaDecoder", "GemmaDecoderBlock", - "GAT", + "GAT", "GraphAttentionLayer", "T5", "T5DataParallelTrainer", @@ -233,11 +190,11 @@ "MixerDataParallelTrainer", "MixerBlock", "MixerEncoder", - "LlaMA2", - "LlaMADataParallelTrainer", + "Llama3", + "LlamaDataParallelTrainer", "RotaryPositionalEncoding", - "LlaMA2Decoder", - "LlaMA2DecoderBlock", + "Llama3Decoder", + "Llama3DecoderBlock", "GroupedRotaryMultiHeadAttention", "GPT3", "GPT4", @@ -276,12 +233,10 @@ "TokenAndPositionEmbedding", "MultiHeadAttention", "AddNorm", - # Utilities - "Dataset", - "ArrayDataset", + "Dataset", + "ArrayDataset", "DataLoader", - "Tokenizer", "batch_cosine_similarities", "batch_pearsonr", "classification_scores", @@ -312,7 +267,6 @@ "HierarchicalMultiHeadAttention", "GatedMultiHeadAttention", "RotaryMultiHeadAttention", - # Random "time_rng_key", "uniform", @@ -336,31 +290,37 @@ import importlib import sys + def check_library_installed(lib_name): try: return importlib.import_module(lib_name) except ImportError: raise ImportError(f"{lib_name} is not installed or improperly installed.") + def test_flax(flax): model = flax.linen.Dense(features=10) + def test_jax(jax): arr = jax.numpy.array([1, 2, 3]) result = jax.numpy.sum(arr) + def test_optax(optax): optimizer = optax.sgd(learning_rate=0.1) + def test_einops(einops): - arr = einops.rearrange([1, 2, 3], 'a b c -> b a c') + arr = einops.rearrange([1, 2, 3], "a b c -> b a c") + def main(): try: - flax = check_library_installed('flax') - jax = check_library_installed('jax') - optax = check_library_installed('optax') - einops = check_library_installed('einops') + flax = check_library_installed("flax") + jax = check_library_installed("jax") + optax = check_library_installed("optax") + einops = check_library_installed("einops") test_flax(flax) test_jax(jax) @@ -373,5 +333,6 @@ def main(): print(f"An error occurred while verifying Jax/Flax/Optax installation: {e}") sys.exit(1) + if __name__ == "__main__": main() diff --git a/nanodl/__src/layers/__init__.py b/nanodl/__src/classical/__init__.py similarity index 100% rename from nanodl/__src/layers/__init__.py rename to nanodl/__src/classical/__init__.py diff --git a/nanodl/__src/sklearn_gpu/bayes.py b/nanodl/__src/classical/bayes.py similarity index 85% rename from nanodl/__src/sklearn_gpu/bayes.py rename to nanodl/__src/classical/bayes.py index 9536c70..ecd5fcc 100644 --- a/nanodl/__src/sklearn_gpu/bayes.py +++ b/nanodl/__src/classical/bayes.py @@ -1,8 +1,12 @@ +from typing import Tuple + import jax import jax.numpy as jnp -from typing import Tuple -def fit_naive_bayes(X: jnp.ndarray, y: jnp.ndarray, num_classes: int) -> Tuple[jnp.ndarray, jnp.ndarray]: + +def fit_naive_bayes( + X: jnp.ndarray, y: jnp.ndarray, num_classes: int +) -> Tuple[jnp.ndarray, jnp.ndarray]: class_priors = jnp.zeros(num_classes) feature_probs = jnp.zeros((num_classes, X.shape[1])) @@ -15,8 +19,11 @@ def fit_naive_bayes(X: jnp.ndarray, y: jnp.ndarray, num_classes: int) -> Tuple[j return class_priors, feature_probs + @jax.jit -def predict_naive_bayes(X: jnp.ndarray, class_priors: jnp.ndarray, feature_probs: jnp.ndarray) -> jnp.ndarray: +def predict_naive_bayes( + X: jnp.ndarray, class_priors: jnp.ndarray, feature_probs: jnp.ndarray +) -> jnp.ndarray: # Calculate log probabilities for features log_feature_probs = jnp.log(feature_probs) log_feature_probs_neg = jnp.log(1 - feature_probs) @@ -26,14 +33,18 @@ def predict_naive_bayes(X: jnp.ndarray, class_priors: jnp.ndarray, feature_probs expanded_log_feature_probs_neg = log_feature_probs_neg[:, None, :] # Calculate log probabilities for each sample and class - log_probs = jnp.sum(expanded_log_feature_probs * X + expanded_log_feature_probs_neg * (1 - X), axis=2) + log_probs = jnp.sum( + expanded_log_feature_probs * X + expanded_log_feature_probs_neg * (1 - X), + axis=2, + ) log_probs += jnp.log(class_priors)[:, None] return jnp.argmax(log_probs, axis=0) -@jax.jit + def accuracy(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> float: return jnp.mean(y_true == y_pred) + class NaiveBayesClassifier: """ Naive Bayes classifier using JAX. @@ -64,4 +75,4 @@ def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> None: self.class_priors, self.feature_probs = fit_naive_bayes(X, y, self.num_classes) def predict(self, X: jnp.ndarray) -> jnp.ndarray: - return predict_naive_bayes(X, self.class_priors, self.feature_probs) \ No newline at end of file + return predict_naive_bayes(X, self.class_priors, self.feature_probs) diff --git a/nanodl/__src/sklearn_gpu/clustering.py b/nanodl/__src/classical/clustering.py similarity index 61% rename from nanodl/__src/sklearn_gpu/clustering.py rename to nanodl/__src/classical/clustering.py index 774703a..e9e2599 100644 --- a/nanodl/__src/sklearn_gpu/clustering.py +++ b/nanodl/__src/classical/clustering.py @@ -1,7 +1,6 @@ import jax import jax.numpy as jnp -from jax import random, ops -from typing import Optional + class KMeans: """ @@ -25,41 +24,37 @@ class KMeans: ``` """ - def __init__(self, - k: int, - num_iters: int = 100, - random_seed: int = 0) -> None: + def __init__(self, k: int, num_iters: int = 100, random_seed: int = 0) -> None: self.k = k self.num_iters = num_iters self.random_seed = random_seed self.centroids = None self.clusters = None - def initialize_centroids(self, - X: jnp.ndarray) -> jnp.ndarray: - + def initialize_centroids(self, X: jnp.ndarray) -> jnp.ndarray: + indices = jnp.arange(X.shape[0]) - selected = jax.random.choice(jax.random.PRNGKey(self.random_seed), - indices, - shape=(self.k,), - replace=False) + selected = jax.random.choice( + jax.random.PRNGKey(self.random_seed), + indices, + shape=(self.k,), + replace=False, + ) return X[selected] - def assign_clusters(self, - X: jnp.ndarray, - centroids: jnp.ndarray) -> jnp.ndarray: - - distances = jnp.sqrt(((X[:, jnp.newaxis, :] - centroids[jnp.newaxis, :, :]) ** 2).sum(axis=2)) + def assign_clusters(self, X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray: + + distances = jnp.sqrt( + ((X[:, jnp.newaxis, :] - centroids[jnp.newaxis, :, :]) ** 2).sum(axis=2) + ) return jnp.argmin(distances, axis=1) - def update_centroids(self, X: jnp.ndarray, - clusters: jnp.ndarray) -> jnp.ndarray: - + def update_centroids(self, X: jnp.ndarray, clusters: jnp.ndarray) -> jnp.ndarray: + return jnp.array([X[clusters == i].mean(axis=0) for i in range(self.k)]) - def fit(self, - X: jnp.ndarray) -> None: - + def fit(self, X: jnp.ndarray) -> None: + self.centroids = self.initialize_centroids(X) for _ in range(self.num_iters): self.clusters = self.assign_clusters(X, self.centroids) @@ -68,18 +63,17 @@ def fit(self, break self.centroids = new_centroids - def predict(self, - X: jnp.ndarray) -> jnp.ndarray: + def predict(self, X: jnp.ndarray) -> jnp.ndarray: if self.centroids is None: raise ValueError("Model not yet trained. Call 'fit' with training data.") return self.assign_clusters(X, self.centroids) - + class GaussianMixtureModel: """ Gaussian Mixture Model implemented in JAX. - This class represents a Gaussian Mixture Model (GMM) for clustering and density estimation. + This class represents a Gaussian Mixture Model (GMM) for clustering and density estimation. It uses the Expectation-Maximization (EM) algorithm for fitting the model to data. Attributes: @@ -105,11 +99,9 @@ class GaussianMixtureModel: ``` """ - def __init__(self, - n_components: int, - tol: float = 1e-3, - max_iter: int = 100, - seed: int = 0) -> None: + def __init__( + self, n_components: int, tol: float = 1e-3, max_iter: int = 100, seed: int = 0 + ) -> None: self.n_components = n_components self.tol = tol self.max_iter = max_iter @@ -118,18 +110,16 @@ def __init__(self, self.weights = None self.seed = seed - def fit(self, - X: jnp.ndarray) -> None: + def fit(self, X: jnp.ndarray) -> None: _, n_features = X.shape rng = jax.random.PRNGKey(self.seed) - # Step 1: Initialization self.means = jax.random.normal(rng, (self.n_components, n_features)) self.covariances = jnp.array([jnp.eye(n_features)] * self.n_components) self.weights = jnp.ones(self.n_components) / self.n_components log_likelihood = 0 - for iteration in range(self.max_iter): + for _ in range(self.max_iter): responsibilities = self._e_step(X) self._m_step(X, responsibilities) @@ -138,41 +128,49 @@ def fit(self, break log_likelihood = new_log_likelihood - def _e_step(self, - X: jnp.ndarray) -> jnp.ndarray: + def _e_step(self, X: jnp.ndarray) -> jnp.ndarray: responsibilities = jnp.zeros((X.shape[0], self.n_components)) for k in range(self.n_components): - responsibilities = responsibilities.at[:, k].set(self.weights[k] * self._multivariate_gaussian(X, self.means[k], self.covariances[k])) + responsibilities = responsibilities.at[:, k].set( + self.weights[k] + * self._multivariate_gaussian(X, self.means[k], self.covariances[k]) + ) responsibilities /= responsibilities.sum(axis=1, keepdims=True) return responsibilities - def _m_step(self, - X: jnp.ndarray, - responsibilities: jnp.ndarray) -> None: + def _m_step(self, X: jnp.ndarray, responsibilities: jnp.ndarray) -> None: n_samples = X.shape[0] for k in range(self.n_components): Nk = responsibilities[:, k].sum() - self.means = self.means.at[k].set((1 / Nk) * jnp.dot(responsibilities[:, k], X)) + self.means = self.means.at[k].set( + (1 / Nk) * jnp.dot(responsibilities[:, k], X) + ) diff = X - self.means[k] - self.covariances = self.covariances.at[k].set((1 / Nk) * jnp.dot(responsibilities[:, k] * diff.T, diff)) + self.covariances = self.covariances.at[k].set( + (1 / Nk) * jnp.dot(responsibilities[:, k] * diff.T, diff) + ) self.weights = self.weights.at[k].set(Nk / n_samples) - def _multivariate_gaussian(self, - X: jnp.ndarray, - mean: jnp.ndarray, - cov: jnp.ndarray) -> jnp.ndarray: + def _multivariate_gaussian( + self, X: jnp.ndarray, mean: jnp.ndarray, cov: jnp.ndarray + ) -> jnp.ndarray: n = X.shape[1] diff = X - mean - return jnp.exp(-0.5 * jnp.sum(jnp.dot(diff, jnp.linalg.inv(cov)) * diff, axis=1)) / (jnp.sqrt((2 * jnp.pi) ** n * jnp.linalg.det(cov))) + return jnp.exp( + -0.5 * jnp.sum(jnp.dot(diff, jnp.linalg.inv(cov)) * diff, axis=1) + ) / (jnp.sqrt((2 * jnp.pi) ** n * jnp.linalg.det(cov))) - def _compute_log_likelihood(self, - X: jnp.ndarray) -> float: + def _compute_log_likelihood(self, X: jnp.ndarray) -> float: log_likelihood = 0 for k in range(self.n_components): - log_likelihood += jnp.sum(jnp.log(self.weights[k] * self._multivariate_gaussian(X, self.means[k], self.covariances[k]))) + log_likelihood += jnp.sum( + jnp.log( + self.weights[k] + * self._multivariate_gaussian(X, self.means[k], self.covariances[k]) + ) + ) return log_likelihood - - def predict(self, - X: jnp.ndarray) -> jnp.ndarray: + + def predict(self, X: jnp.ndarray) -> jnp.ndarray: responsibilities = self._e_step(X) - return jnp.argmax(responsibilities, axis=1) \ No newline at end of file + return jnp.argmax(responsibilities, axis=1) diff --git a/nanodl/__src/sklearn_gpu/dimensionality_reduction.py b/nanodl/__src/classical/dimensionality_reduction.py similarity index 78% rename from nanodl/__src/sklearn_gpu/dimensionality_reduction.py rename to nanodl/__src/classical/dimensionality_reduction.py index 33adece..7a5704d 100644 --- a/nanodl/__src/sklearn_gpu/dimensionality_reduction.py +++ b/nanodl/__src/classical/dimensionality_reduction.py @@ -1,6 +1,8 @@ +from typing import Optional + import jax import jax.numpy as jnp -from typing import Optional + class PCA: """ @@ -35,40 +37,36 @@ class PCA: print(X_sampled.shape, original_data.shape, transformed_data.shape) """ - def __init__(self, - n_components: int): - + def __init__(self, n_components: int): + self.n_components = n_components self.components = None self.mean = None - def fit(self, - X: jnp.ndarray) -> None: - + def fit(self, X: jnp.ndarray) -> None: + self.mean = jnp.mean(X, axis=0) X_centered = X - self.mean cov_matrix = jnp.cov(X_centered, rowvar=False) eigvals, eigvecs = jnp.linalg.eigh(cov_matrix) sorted_indices = jnp.argsort(eigvals)[::-1] sorted_eigvecs = eigvecs[:, sorted_indices] - self.components = sorted_eigvecs[:, :self.n_components] + self.components = sorted_eigvecs[:, : self.n_components] - def transform(self, - X: jnp.ndarray) -> jnp.ndarray: + def transform(self, X: jnp.ndarray) -> jnp.ndarray: X_centered = X - self.mean return jnp.dot(X_centered, self.components) - def inverse_transform(self, - X_transformed: jnp.ndarray) -> jnp.ndarray: + def inverse_transform(self, X_transformed: jnp.ndarray) -> jnp.ndarray: return jnp.dot(X_transformed, self.components.T) + self.mean - def sample(self, - n_samples:int=1, - key: Optional[jnp.ndarray] = None) -> jnp.ndarray: - + def sample( + self, n_samples: int = 1, key: Optional[jnp.ndarray] = None + ) -> jnp.ndarray: + if key is None: key = jax.random.PRNGKey(0) z = jax.random.normal(key, (n_samples, self.n_components)) X_sampled = self.inverse_transform(z) - return X_sampled \ No newline at end of file + return X_sampled diff --git a/nanodl/__src/classical/dsp.py b/nanodl/__src/classical/dsp.py new file mode 100644 index 0000000..c27674b --- /dev/null +++ b/nanodl/__src/classical/dsp.py @@ -0,0 +1,129 @@ +from typing import Tuple + +import jax.numpy as jnp +from jax import random + + +def fastica( + X: jnp.ndarray, n_components: jnp.ndarray, max_iter: int = 1000, tol: float = 1e-4 +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + Perform Independent Component Analysis (ICA) on the input data using the FastICA algorithm. + + Parameters: + X : jax.numpy.ndarray + The input data matrix, where each row represents a data point, and each column represents a different signal. + The input data should be a 2D jax.numpy array with shape (n_samples, n_features). + n_components : int + The number of independent components to extract. This should be less than or equal to the number of features in the input data. + max_iter : int, optional + The maximum number of iterations for the optimization process. The default value is 1000 iterations. + tol : float, optional + The tolerance for convergence. The optimization process stops when the maximum absolute change in the diagonal elements of the + unmixing matrix from one iteration to the next is less than this tolerance. The default value is 1e-4. + + Returns: + S : jax.numpy.ndarray + The separated independent components. This is a 2D jax.numpy array with shape (n_components, n_samples), where each row represents + a different independent component, and each column represents a data point. + W : jax.numpy.ndarray + The unmixing matrix. This is a 2D jax.numpy array with shape (n_components, n_features), representing the estimated inverse of the + mixing matrix. It is used to transform the input data back into the independent components. + whitening_matrix : jax.numpy.ndarray + The whitening matrix used to whiten the input data. This is a 2D jax.numpy array with shape (n_features, n_features), used to decorrelate + the input data and make its covariance matrix the identity matrix. + + Description: + The FastICA algorithm aims to separate the mixed input signals into statistically independent components. The function first whitens the input + data to decorrelate it and normalize its variance. Then, it initializes a random unmixing matrix and uses an optimization process to find + the optimal unmixing matrix that maximizes the independence of the source signals. + + The optimization process involves iteratively updating the unmixing matrix based on the non-linear function (`tanh` in this case) applied + to the transformed data (`WX`). The process stops when the unmixing matrix converges according to the specified tolerance (`tol`) or when the + maximum number of iterations (`max_iter`) is reached. + + Once the optimal unmixing matrix is found, the function applies it to the whitened data to obtain the separated independent components. + + Example usage: + # Set random seed + jax.random.PRNGKey(42) + + # Generate synthetic source signals + n_samples = 2000 + time = jnp.linspace(0, 8, n_samples) + s1 = jnp.sin(2 * time) + s2 = jnp.sign(jnp.sin(3 * time)) + + # Combine the sources with a mixing matrix + A = jnp.array([[1, 1], [0.5, 2]]) + X = jnp.dot(A, jnp.array([s1, s2])) + + # Perform ICA + n_components = 2 + S, W, whitening_matrix = fastica(X.T, n_components) + + # Plot the results + plt.figure(figsize=(12, 8)) + + plt.subplot(3, 1, 1) + plt.title('Original Source Signals') + plt.plot(time, s1, label='Source 1 (Sine Wave)') + plt.plot(time, s2, label='Source 2 (Square Wave)') + plt.legend() + + plt.subplot(3, 1, 2) + plt.title('Mixed Signals') + plt.plot(time, X[0], label='Mixed Signal 1') + plt.plot(time, X[1], label='Mixed Signal 2') + plt.legend() + + plt.subplot(3, 1, 3) + plt.title('Separated Signals (Using ICA)') + plt.plot(time, S[0], label='Separated Signal 1') + plt.plot(time, S[1], label='Separated Signal 2') + plt.legend() + s + plt.tight_layout() + plt.show() + """ + # Calculate the covariance matrix and perform eigenvalue decomposition + cov_matrix = jnp.cov(X, rowvar=False) + eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix) + + # Sort the eigenvalues and eigenvectors + idx = jnp.argsort(eigenvalues)[::-1] + eigenvalues = eigenvalues[idx] + eigenvectors = eigenvectors[:, idx] + + # Create the whitening matrix + D = jnp.diag(1.0 / jnp.sqrt(eigenvalues)) + whitening_matrix = jnp.dot(eigenvectors, D) + X_whitened = jnp.dot(X, whitening_matrix) + + # Initialize unmixing matrix with random values + rng = random.PRNGKey(0) # Set a seed for reproducibility + W = random.normal(rng, (n_components, n_components)) + + # Perform FastICA algorithm + for _ in range(max_iter): + WX = jnp.dot(X_whitened, W.T) + g = jnp.tanh(WX) + g_prime = 1 - g**2 + W_new = (jnp.dot(X_whitened.T, g) / X.shape[0]) - jnp.diag( + g_prime.mean(axis=0) + ).dot(W) + + # Orthogonalize the unmixing matrix + W_new, _ = jnp.linalg.qr(W_new) + + # Check for convergence + if jnp.max(jnp.abs(jnp.abs(jnp.diag(jnp.dot(W_new, W.T))) - 1)) < tol: + W = W_new + break + + W = W_new + + # Calculate the separated independent components + S = jnp.dot(W, X_whitened.T) + + return S, W, whitening_matrix diff --git a/nanodl/__src/sklearn_gpu/regression.py b/nanodl/__src/classical/regression.py similarity index 94% rename from nanodl/__src/sklearn_gpu/regression.py rename to nanodl/__src/classical/regression.py index 3ec737d..803d48e 100644 --- a/nanodl/__src/sklearn_gpu/regression.py +++ b/nanodl/__src/classical/regression.py @@ -1,6 +1,8 @@ +from typing import Callable, Tuple + import jax import jax.numpy as jnp -from typing import Callable, Tuple + class LinearRegression: """ @@ -33,6 +35,7 @@ class LinearRegression: print("Learned Bias:", learned_bias) ``` """ + def __init__(self, input_dim, output_dim): self.input_dim = input_dim self.output_dim = output_dim @@ -63,7 +66,6 @@ def fit(self, x_data, y_data, learning_rate=0.1, num_epochs=100): def get_params(self): return self.params - class LogisticRegression: @@ -97,6 +99,7 @@ class LogisticRegression: print("Predictions:", predictions) ``` """ + def __init__(self, input_dim): self.input_dim = input_dim self.key = jax.random.PRNGKey(0) @@ -129,7 +132,7 @@ def fit(self, x_data, y_data, learning_rate=0.1, num_epochs=100): def predict(self, x_data): return self.logistic_regression(self.params, x_data) - + class GaussianProcess: """ @@ -165,30 +168,29 @@ def rbf_kernel(x1, x2, length_scale=1.0): mean, covariance = gp.predict(X_new) """ - def __init__(self, - kernel: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], - noise: float = 1e-3): - + def __init__( + self, + kernel: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], + noise: float = 1e-3, + ): + self.kernel = kernel self.noise = noise self.X = None self.y = None self.K = None - def fit(self, - X: jnp.ndarray, - y: jnp.ndarray) -> None: - + def fit(self, X: jnp.ndarray, y: jnp.ndarray) -> None: + self.X = X self.y = y self.K = self.kernel(self.X, self.X) + jnp.eye(len(X)) * self.noise - def predict(self, - X_new: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + def predict(self, X_new: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: K_inv = jnp.linalg.inv(self.K) K_s = self.kernel(self.X, X_new) K_ss = self.kernel(X_new, X_new) mu_s = jnp.dot(K_s.T, jnp.dot(K_inv, self.y)) cov_s = K_ss - jnp.dot(K_s.T, jnp.dot(K_inv, K_s)) - return mu_s, cov_s \ No newline at end of file + return mu_s, cov_s diff --git a/nanodl/__src/experimental/bitlinear.py b/nanodl/__src/experimental/bitlinear.py new file mode 100644 index 0000000..b0f2796 --- /dev/null +++ b/nanodl/__src/experimental/bitlinear.py @@ -0,0 +1,72 @@ +import jax +import jax.numpy as jnp +from flax import linen as nn + + +class BitLinear(nn.Module): + """ + Implements a linear transformation layer with quantization for both activations and weights, + optimized for low-bit inference. The layer is designed to operate in two modes: training and inference. + During training, the activations and weights are quantized using separate quantization functions, + aiming to simulate low-bit operations and reduce the quantization error. For inference, a more + aggressive quantization scheme is applied to both activations and weights, potentially different + from the training quantization, to maximize performance and efficiency on low-bit hardware. + + Attributes: + output_features (int): The number of output features. + kernel_init (callable): A function to initialize the weights. Default is LeCun normal initializer. + """ + + output_features: int + kernel_init: callable = nn.initializers.lecun_normal() + + @nn.compact + def __call__(self, x, training=False): + w = self.param("kernel", self.kernel_init, (x.shape[-1], self.output_features)) + + if not training: + x_quant, x_scale = self.fused_activation_norm_quant(x) + + # HELP: How run externally on params at once for efficiency + # Quantising weigts all over again each call is repeated work + # This can be done on params dict using jax tree utils. + # Albeit the weight scale for quantisation needs to be utilised at inference + # Its easy to bypassed on its own by passing the weight scale during a call + # This will be a module in various transformer models in my project (NanoDL) + # Is there a way to achieve this without complication my existing codebase? + w, w_scale = self.inference_weight_quant(w) + + return self.inference_lowbit_matmul(x_quant, w) / w_scale / x_scale + + x_norm = self.rmsnorm(x) + x_quant = x_norm + jax.lax.stop_gradient(self.activation_quant(x_norm) - x_norm) + w_quant = w + jax.lax.stop_gradient(self.weight_quant(w) - w) + return jnp.dot(x_quant, w_quant) + + def rmsnorm(self, x): + return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-5) + + def activation_quant(self, x): + scale = 127.0 / jnp.max(jnp.abs(x), axis=-1, keepdims=True).clip(min=1e-5) + y = jnp.round(x * scale).clip(-128, 127) / scale + return y + + def weight_quant(self, w): + scale = 1.0 / jnp.mean(jnp.abs(w)).clip(min=1e-5) + u = jnp.round(w * scale).clip(-1, 1) / scale + return u + + def fused_activation_norm_quant(self, x): + x_norm = self.rmsnorm(x) + scale = 127.0 / jnp.max(jnp.abs(x_norm), axis=-1, keepdims=True).clip(min=1e-5) + x_quant = jnp.round(x_norm * scale).clip(-128, 127) / scale + return x_quant, scale + + def inference_weight_quant(self, w): + scale = jnp.abs(w).mean().clip(min=1e-5) + u = jnp.sign(w - w.mean()) * scale + return u, scale + + # Help: how to implement lowbit matmul kernel for efficiency that can be integrated into Flax model + def inference_lowbit_matmul(self, x, w): + return jnp.dot(x, w) diff --git a/nanodl/__src/models/gat.py b/nanodl/__src/experimental/gat.py similarity index 69% rename from nanodl/__src/models/gat.py rename to nanodl/__src/experimental/gat.py index 7200e40..789dad3 100644 --- a/nanodl/__src/models/gat.py +++ b/nanodl/__src/experimental/gat.py @@ -1,8 +1,7 @@ -import jax, flax, optax, time +import jax import jax.numpy as jnp from flax import linen as nn -from flax.training import train_state -from typing import Any, Tuple, Optional, Iterable + class GraphAttentionLayer(nn.Module): """ @@ -29,6 +28,7 @@ class GraphAttentionLayer(nn.Module): Returns: jnp.ndarray: The output node features after the attention mechanism. If `concat` is True, applies a non-linearity (LeakyReLU); otherwise, returns the linear combination of features directly. Shape is (N, out_features). """ + in_features: int out_features: int dropout_rate: float @@ -36,34 +36,35 @@ class GraphAttentionLayer(nn.Module): concat: bool = True @nn.compact - def __call__(self, - x: jnp.ndarray, - adj: jnp.ndarray, - training: bool) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray, adj: jnp.ndarray, training: bool) -> jnp.ndarray: + + W = self.param( + "W", + jax.nn.initializers.glorot_uniform(), + (self.in_features, self.out_features), + ) - W = self.param('W', jax.nn.initializers.glorot_uniform(), - (self.in_features, self.out_features)) - - a = self.param('a', jax.nn.initializers.glorot_uniform(), - (2 * self.out_features, 1)) + a = self.param( + "a", jax.nn.initializers.glorot_uniform(), (2 * self.out_features, 1) + ) h = jnp.dot(x, W) - h = nn.Dropout(rate=self.dropout_rate, - deterministic=not training)(h) + h = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(h) N = h.shape[0] - a_input = jnp.concatenate([h[:, None, :].repeat(N, axis=1), - h[None, :, :].repeat(N, axis=0)], axis=2) - - e = nn.leaky_relu(jnp.dot(a_input, a).squeeze(-1), - negative_slope=self.alpha) + a_input = jnp.concatenate( + [h[:, None, :].repeat(N, axis=1), h[None, :, :].repeat(N, axis=0)], axis=2 + ) + + e = nn.leaky_relu(jnp.dot(a_input, a).squeeze(-1), negative_slope=self.alpha) zero_vec = -9e15 * jnp.ones_like(e) attention = jnp.where(adj > 0, e, zero_vec) attention = nn.softmax(attention, axis=1) - attention = nn.Dropout(rate=self.dropout_rate, - deterministic=not training)(attention) + attention = nn.Dropout(rate=self.dropout_rate, deterministic=not training)( + attention + ) h_prime = jnp.matmul(attention, h) @@ -75,13 +76,13 @@ def __call__(self, class GAT(nn.Module): """ - Graph Attention Networks (GATs) are a type of neural network designed for graph-structured data. - The key feature of GATs is the use of attention mechanisms to weigh the importance of nodes' neighbors. - This allows GATs to focus on the most relevant parts of the graph structure when learning node representations. - In GATs, each node aggregates information from its neighbors, but not all neighbors contribute equally. - The attention mechanism computes weights that determine the importance of each neighbor's features to the target node. + Graph Attention Networks (GATs) are a type of neural network designed for graph-structured data. + The key feature of GATs is the use of attention mechanisms to weigh the importance of nodes' neighbors. + This allows GATs to focus on the most relevant parts of the graph structure when learning node representations. + In GATs, each node aggregates information from its neighbors, but not all neighbors contribute equally. + The attention mechanism computes weights that determine the importance of each neighbor's features to the target node. These weights are learned during training and are based on the features of the nodes involved. - GATs can handle graphs with varying sizes and connectivity patterns, making them suitable for a wide range of applications, + GATs can handle graphs with varying sizes and connectivity patterns, making them suitable for a wide range of applications, including social network analysis, recommendation systems, and molecular structure analysis. Example usage: @@ -105,11 +106,11 @@ class GAT(nn.Module): adj = jax.random.bernoulli(key, 0.3, (num_nodes, num_nodes)) # Random adjacency matrix # Initialize the GAT model - model = GAT(nfeat=num_features, - nhid=8, - nclass=nclass, - dropout_rate=0.5, - alpha=0.2, + model = GAT(nfeat=num_features, + nhid=8, + nclass=nclass, + dropout_rate=0.5, + alpha=0.2, nheads=3) # Initialize the model parameters @@ -138,6 +139,7 @@ class GAT(nn.Module): Returns: jnp.ndarray: The output node features after passing through the GAT model. Shape is (N, nclass), representing the class scores for each node. """ + nfeat: int nhid: int nclass: int @@ -146,24 +148,31 @@ class GAT(nn.Module): nheads: int @nn.compact - def __call__(self, - x: jnp.ndarray, - adj: jnp.ndarray, - training: bool = False) -> jnp.ndarray: - - heads = [GraphAttentionLayer(self.nfeat, - self.nhid, - dropout_rate=self.dropout_rate, - alpha=self.alpha, concat=True) for _ in range(self.nheads)] - + def __call__( + self, x: jnp.ndarray, adj: jnp.ndarray, training: bool = False + ) -> jnp.ndarray: + + heads = [ + GraphAttentionLayer( + self.nfeat, + self.nhid, + dropout_rate=self.dropout_rate, + alpha=self.alpha, + concat=True, + ) + for _ in range(self.nheads) + ] + x = jnp.concatenate([head(x, adj, training) for head in heads], axis=1) - - x = nn.Dropout(rate=self.dropout_rate, - deterministic=not training)(x) - - out_att = GraphAttentionLayer(self.nhid * self.nheads, - self.nclass, - dropout_rate=self.dropout_rate, - alpha=self.alpha, concat=False) - - return out_att(x, adj, training) \ No newline at end of file + + x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x) + + out_att = GraphAttentionLayer( + self.nhid * self.nheads, + self.nclass, + dropout_rate=self.dropout_rate, + alpha=self.alpha, + concat=False, + ) + + return out_att(x, adj, training) diff --git a/nanodl/__src/experimental/kan.py b/nanodl/__src/experimental/kan.py new file mode 100644 index 0000000..7ff74d9 --- /dev/null +++ b/nanodl/__src/experimental/kan.py @@ -0,0 +1,227 @@ +import jax +import jax.numpy as jnp +from flax import linen as nn +from jax.scipy.special import logsumexp +from jax import random + + +class KANLinear(nn.Module): + """ + KANLinear is a class that represents a linear layer in a Kernelized Attention Network (KAN). + It uses B-splines to model the attention mechanism, which allows for more flexibility than traditional attention mechanisms. + + Attributes: + in_features (int): The number of input features. + out_features (int): The number of output features. + grid_size (int): The size of the grid used for the B-splines. Default is 5. + spline_order (int): The order of the B-splines. Default is 3. + scale_noise (float): The scale of the noise added to the B-splines. Default is 0.1. + scale_base (float): The scale of the base weights. Default is 1.0. + scale_spline (float): The scale of the spline weights. Default is 1.0. + enable_standalone_scale_spline (bool): Whether to enable standalone scaling of the spline weights. Default is True. + base_activation (callable): The activation function to use for the base weights. Default is nn.silu. + grid_eps (float): The epsilon value used for the grid. Default is 0.02. + grid_range (list): The range of the grid. Default is [-1, 1]. + """ + + in_features: int + out_features: int + grid_size: int = 5 + spline_order: int = 3 + scale_noise: float = 0.1 + scale_base: float = 1.0 + scale_spline: float = 1.0 + enable_standalone_scale_spline: bool = True + base_activation: callable = nn.silu + grid_eps: float = 0.02 + grid_range: list = [-1, 1] + + def setup(self): + h = (self.grid_range[1] - self.grid_range[0]) / self.grid_size + grid = jnp.tile( + jnp.arange(-self.spline_order, self.grid_size + self.spline_order + 1) * h + + self.grid_range[0], + (self.in_features, 1), + ) + self.grid = self.param("grid", grid.shape, nn.initializers.zeros) + + self.base_weight = self.param( + "base_weight", + (self.out_features, self.in_features), + nn.initializers.kaiming_uniform(), + ) + self.spline_weight = self.param( + "spline_weight", + (self.out_features, self.in_features, self.grid_size + self.spline_order), + nn.initializers.zeros, + ) + if self.enable_standalone_scale_spline: + self.spline_scaler = self.param( + "spline_scaler", + (self.out_features, self.in_features), + nn.initializers.kaiming_uniform(), + ) + + self.reset_parameters() + + def reset_parameters(self): + self.base_weight = ( + nn.initializers.kaiming_uniform()( + self.base_weight.shape, self.base_weight.dtype + ) + * self.scale_base + ) + noise = ( + ( + random.uniform( + jax.random.PRNGKey(0), + (self.grid_size + 1, self.in_features, self.out_features), + ) + - 1 / 2 + ) + * self.scale_noise + / self.grid_size + ) + self.spline_weight = self.curve2coeff( + self.grid.T[self.spline_order : -self.spline_order], noise + ) * (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) + if self.enable_standalone_scale_spline: + self.spline_scaler = ( + nn.initializers.kaiming_uniform()( + self.spline_scaler.shape, self.spline_scaler.dtype + ) + * self.scale_spline + ) + + def b_splines(self, x): + grid = self.grid + x = x[..., None] + bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).astype(x.dtype) + for k in range(1, self.spline_order + 1): + bases = (x - grid[:, : -(k + 1)]) / ( + grid[:, k:-1] - grid[:, : -(k + 1)] + ) * bases[..., :-1] + (grid[:, k + 1 :] - x) / ( + grid[:, k + 1 :] - grid[:, 1:(-k)] + ) * bases[ + ..., 1: + ] + return bases + + def curve2coeff(self, x, y): + A = self.b_splines(x).transpose((1, 0, 2)) + B = y.transpose((1, 0, 2)) + solution = jnp.linalg.lstsq(A, B)[0] + result = solution.transpose((2, 0, 1)) + return result + + @property + def scaled_spline_weight(self): + return self.spline_weight * ( + self.spline_scaler[..., None] + if self.enable_standalone_scale_spline + else 1.0 + ) + + def __call__(self, x): + base_output = jnp.dot(self.base_activation(x), self.base_weight.T) + spline_output = jnp.dot( + self.b_splines(x).reshape(x.shape[0], -1), + self.scaled_spline_weight.reshape(self.out_features, -1).T, + ) + return base_output + spline_output + + def update_grid(self, x, margin=0.01): + batch = x.shape[0] + + splines = self.b_splines(x).transpose((1, 0, 2)) + orig_coeff = self.scaled_spline_weight.transpose((1, 2, 0)) + unreduced_spline_output = jnp.matmul(splines, orig_coeff).transpose((1, 0, 2)) + + x_sorted = jnp.sort(x, axis=0) + grid_adaptive = x_sorted[ + jnp.linspace(0, batch - 1, self.grid_size + 1, dtype=jnp.int64) + ] + + uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size + grid_uniform = ( + jnp.arange(self.grid_size + 1, dtype=jnp.float32)[..., None] * uniform_step + + x_sorted[0] + - margin + ) + + grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive + grid = jnp.concatenate( + [ + grid[:1] + - uniform_step * jnp.arange(self.spline_order, 0, -1)[..., None], + grid, + grid[-1:] + + uniform_step * jnp.arange(1, self.spline_order + 1)[..., None], + ], + axis=0, + ) + + self.grid = grid.T + self.spline_weight = self.curve2coeff(x, unreduced_spline_output) + + def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): + l1_fake = jnp.mean(jnp.abs(self.spline_weight), axis=-1) + regularization_loss_activation = jnp.sum(l1_fake) + p = l1_fake / regularization_loss_activation + regularization_loss_entropy = -jnp.sum(p * jnp.log(p)) + return ( + regularize_activation * regularization_loss_activation + + regularize_entropy * regularization_loss_entropy + ) + + +class KAN(nn.Module): + """ + KAN is a class that represents a Kernelized Attention Network (KAN). + It is a type of neural network that uses a kernelized attention mechanism, which allows for more flexibility than traditional attention mechanisms. + + Attributes: + layers_hidden (list): A list of integers representing the number of hidden units in each layer. + """ + + layers_hidden: list + grid_size: int = 5 + spline_order: int = 3 + scale_noise: float = 0.1 + scale_base: float = 1.0 + scale_spline: float = 1.0 + base_activation: callable = nn.silu + grid_eps: float = 0.02 + grid_range: list = [-1, 1] + + def setup(self): + self.layers = [ + KANLinear( + in_features, + out_features, + grid_size=self.grid_size, + spline_order=self.spline_order, + scale_noise=self.scale_noise, + scale_base=self.scale_base, + scale_spline=self.scale_spline, + base_activation=self.base_activation, + grid_eps=self.grid_eps, + grid_range=self.grid_range, + ) + for in_features, out_features in zip( + self.layers_hidden, self.layers_hidden[1:] + ) + ] + + def __call__(self, x, update_grid=False): + for layer in self.layers: + if update_grid: + layer.update_grid(x) + x = layer(x) + return x + + def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): + return sum( + layer.regularization_loss(regularize_activation, regularize_entropy) + for layer in self.layers + ) diff --git a/nanodl/__src/models/mamba_experimental.py b/nanodl/__src/experimental/mamba.py similarity index 63% rename from nanodl/__src/models/mamba_experimental.py rename to nanodl/__src/experimental/mamba.py index fdcf320..e162738 100644 --- a/nanodl/__src/models/mamba_experimental.py +++ b/nanodl/__src/experimental/mamba.py @@ -1,15 +1,15 @@ -import jax -import flax -import time import math -import optax -import jax.numpy as jnp +import time +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from einops import einsum from flax.training import train_state -from typing import Tuple, Any, Optional, Iterable -########## EXPERIMENMTAL ############ class MambaBlock(nn.Module): """ @@ -17,7 +17,7 @@ class MambaBlock(nn.Module): convolution, and dense layers to process input sequences. This block is designed for sequence modeling tasks and includes specialized components like selective scan for dynamic computation. - + Attributes: d_inner (int): Dimensionality of the inner dense layer. d_conv (int): Size of the convolution kernel. @@ -28,34 +28,37 @@ class MambaBlock(nn.Module): bias (bool): Flag indicating whether to use bias in dense layers. conv_bias (bool): Flag indicating whether to use bias in the convolution layer. """ + d_inner: int d_conv: int dt_rank: int d_state: int d_model: int seq_len: int - bias: bool - conv_bias: bool + bias: bool + conv_bias: bool def setup(self): self.norm = nn.RMSNorm(self.d_model) self.in_proj = nn.Dense(features=self.d_inner * 2, use_bias=self.bias) - self.conv1d = nn.Conv(features=self.seq_len, - kernel_size=(self.d_conv,), - strides=(1,), - padding='SAME', - use_bias=self.conv_bias, - feature_group_count=self.d_inner) - + self.conv1d = nn.Conv( + features=self.seq_len, + kernel_size=(self.d_conv,), + strides=(1,), + padding="SAME", + use_bias=self.conv_bias, + feature_group_count=self.d_inner, + ) + self.x_proj = nn.Dense(features=self.dt_rank + self.d_state * 2, use_bias=False) self.dt_proj = nn.Dense(features=self.d_inner, use_bias=True) self.out_proj = nn.Dense(features=self.d_model, use_bias=self.bias) - + # Parameter initialization A = jnp.tile(jnp.arange(1, self.d_state + 1), (self.d_inner, 1)) - self.A_log = self.variable('params', 'A_log', lambda: jnp.log(A)) - self.D = self.variable('params', 'D', lambda: jnp.ones((self.d_inner,))) + self.A_log = self.variable("params", "A_log", lambda: jnp.log(A)) + self.D = self.variable("params", "D", lambda: jnp.ones((self.d_inner,))) def __call__(self, inputs: jnp.ndarray): u = self.norm(inputs) @@ -64,58 +67,58 @@ def __call__(self, inputs: jnp.ndarray): x_and_res = self.in_proj(u) x, res = jnp.split(x_and_res, 2, axis=-1) x = jnp.transpose(x, (0, 2, 1)) - x = self.conv1d(x)[:, :, :u.shape[1]] + x = self.conv1d(x)[:, :, : u.shape[1]] x = jnp.transpose(x, (0, 2, 1)) x = nn.silu(x) x_dbl = self.x_proj(u) - delta, B, C = jnp.split(x_dbl, indices_or_sections=[self.dt_rank, - self.dt_rank + self.d_state], - axis=-1) + delta, B, C = jnp.split( + x_dbl, + indices_or_sections=[self.dt_rank, self.dt_rank + self.d_state], + axis=-1, + ) delta = nn.softplus(self.dt_proj(delta)) y = self.selective_scan(x, delta, A, B, C, D) y = y * nn.silu(res) return self.out_proj(y) + inputs - def selective_scan(self, - u: jnp.ndarray, - delta: jnp.ndarray, - A: jnp.ndarray, - B: jnp.ndarray, - C: jnp.ndarray, - D: jnp.ndarray) -> jnp.ndarray: - + def selective_scan( + self, + u: jnp.ndarray, + delta: jnp.ndarray, + A: jnp.ndarray, + B: jnp.ndarray, + C: jnp.ndarray, + D: jnp.ndarray, + ) -> jnp.ndarray: + b, l, d_in = u.shape n = A.shape[1] - - deltaA = jnp.exp(einsum( - delta, A, - 'b l d_in, d_in n -> b l d_in n')) - - deltaB_u = einsum( - delta, B, u, - 'b l d_in, b l n, b l d_in -> b l d_in n') + + deltaA = jnp.exp(einsum(delta, A, "b l d_in, d_in n -> b l d_in n")) + + deltaB_u = einsum(delta, B, u, "b l d_in, b l n, b l d_in -> b l d_in n") x = jnp.zeros((b, d_in, n)) - ys = [] - + ys = [] + for i in range(l): x = deltaA[:, i] * x + deltaB_u[:, i] - y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in') + y = einsum(x, C[:, i, :], "b d_in n, b n -> b d_in") ys.append(y) - + return jnp.stack(ys, axis=1) + u * D - + class Mamba(nn.Module): """ - MAMBA is an advanced ML model renowned for its exceptional linear-time processing efficiency, - which notably enhances its inference speed to outperform traditional Transformer models by up to five times in throughput. - Unlike conventional models that struggle with long sequence lengths, MAMBA demonstrates a linear scalability with sequence length, - maintaining or even improving its performance with sequences that extend up to a million elements. - This attribute makes MAMBA a highly versatile and efficient backbone for a variety of sequence modeling tasks across different domains, - including but not limited to language processing, audio analysis, and genomic studies. - + MAMBA is an advanced ML model renowned for its exceptional linear-time processing efficiency, + which notably enhances its inference speed to outperform traditional Transformer models by up to five times in throughput. + Unlike conventional models that struggle with long sequence lengths, MAMBA demonstrates a linear scalability with sequence length, + maintaining or even improving its performance with sequences that extend up to a million elements. + This attribute makes MAMBA a highly versatile and efficient backbone for a variety of sequence modeling tasks across different domains, + including but not limited to language processing, audio analysis, and genomic studies. + Attributes: vocab_size (int): The size of the vocabulary. n_layer (int): The number of MambaBlock layers. @@ -152,9 +155,9 @@ class Mamba(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -168,12 +171,12 @@ class Mamba(nn.Module): 'vocab_size': 100, 'expand': 2, 'n_layer': 2, - 'd_conv': 3, - 'dt_rank': 16, - 'd_state': 8, + 'd_conv': 3, + 'dt_rank': 16, + 'd_state': 8, 'd_model': 64, 'dropout': 0.2, - 'bias':True, + 'bias':True, 'conv_bias': True, 'max_length': max_length, 'start_token': 0, @@ -184,20 +187,20 @@ class Mamba(nn.Module): model = Mamba(**hyperparams) rngs = jax.random.PRNGKey(0) rngs, dropout_rng = jax.random.split(rngs) - params = model.init({'params': rngs, 'dropout': dropout_rng}, + params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data trainer = MambaDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -205,15 +208,16 @@ class Mamba(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + vocab_size: int n_layer: int d_conv: int @@ -224,46 +228,48 @@ class Mamba(nn.Module): max_length: int start_token: int end_token: int - dropout: float + dropout: float bias: bool = True conv_bias: bool = True - dt_rank: int = 'auto' + dt_rank: int = "auto" def setup(self): self.d_inner = int(self.expand * self.d_model) - - if self.dt_rank == 'auto': + + if self.dt_rank == "auto": self.dt_rank = math.ceil(self.d_model / 16) self.embedding = nn.Embed(self.vocab_size, self.d_model) - self.layers = [MambaBlock(d_inner=self.d_inner, - d_conv=self.d_conv, - dt_rank=self.dt_rank, - d_state=self.d_state, - d_model=self.d_model, - seq_len=self.max_length, - bias=self.bias, - conv_bias=self.conv_bias) for _ in range(self.n_layer)] - + self.layers = [ + MambaBlock( + d_inner=self.d_inner, + d_conv=self.d_conv, + dt_rank=self.dt_rank, + d_state=self.d_state, + d_model=self.d_model, + seq_len=self.max_length, + bias=self.bias, + conv_bias=self.conv_bias, + ) + for _ in range(self.n_layer) + ] + self.norm_f = nn.RMSNorm(self.d_model) self.dropout1 = nn.Dropout(self.dropout) self.lm_head = nn.Dense(features=self.vocab_size, use_bias=False) # Note: Flax doesn't support parameter sharing like PyTorch's weight tying directly. # You might need to implement a custom method for weight tying or handle it outside the model definition. - def __call__(self, - input_ids: jnp.ndarray, - training: bool = False) -> jnp.ndarray: - + def __call__(self, input_ids: jnp.ndarray, training: bool = False) -> jnp.ndarray: + x = self.embedding(input_ids) for layer in self.layers: x = self.dropout1(layer(x), deterministic=not training) - + x = self.norm_f(x) logits = self.lm_head(x) return logits - def zero_pad(self, arr, max_length): current_length = arr.shape[1] @@ -276,13 +282,14 @@ def zero_pad(self, arr, max_length): padded_array = arr return padded_array - - def generate(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + def generate( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> Tuple[jnp.ndarray]: + if x is not None: assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" @@ -291,8 +298,10 @@ def generate(self, # Autoregressive decoding loop print(self.zero_pad(decoder_input, self.max_length).shape) - for _ in range(self.max_length-1): - decoder_output = self.__call__(self.zero_pad(decoder_input, self.max_length), training=False)[0] + for _ in range(self.max_length - 1): + decoder_output = self.__call__( + self.zero_pad(decoder_input, self.max_length), training=False + )[0] print(decoder_output.shape) last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature @@ -301,29 +310,43 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) - - if next_token.item() == self.end_token or len(output_sequence) == self.max_length: + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) + + if ( + next_token.item() == self.end_token + or len(output_sequence) == self.max_length + ): break return jnp.array(output_sequence) - - def generate_batch(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> jnp.ndarray: + batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) - for i in range(self.max_length-1): - decoder_output = self.__call__(self.zero_pad(decoder_input, self.max_length), training=False)[0] + for i in range(self.max_length - 1): + decoder_output = self.__call__( + self.zero_pad(decoder_input, self.max_length), training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -332,12 +355,19 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) - - if jnp.all(next_token == self.end_token) or len(output_sequences) == self.max_length: + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) + + if ( + jnp.all(next_token == self.end_token) + or len(output_sequences) == self.max_length + ): break return output_sequences @@ -346,7 +376,7 @@ def generate_batch(self, class MambaDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -365,13 +395,16 @@ class MambaDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: - + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: + self.model = model self.params = None self.params_path = params_path @@ -379,51 +412,61 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(MambaDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(MambaDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + MambaDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + MambaDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, - jnp.ones(input_shape, dtype=jnp.int32))['params'] + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -432,35 +475,36 @@ def train(self, batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, inputs, rngs={"dropout": jax.random.PRNGKey(2)} + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: @@ -471,32 +515,34 @@ def evaluate(self, loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) return self.params - + import jax import jax.numpy as jnp + from nanodl import ArrayDataset, DataLoader -#from nanodl import Mamba, MambaDataParallelTrainer + +# from nanodl import Mamba, MambaDataParallelTrainer # Generate dummy data batch_size = 8 max_length = 128 # Replace with actual tokenised data -data = jnp.ones((101, max_length+1), dtype=jnp.int16) +data = jnp.ones((101, max_length + 1), dtype=jnp.int16) # Shift to create next-token prediction dataset dummy_inputs = data[:, :-1] @@ -504,10 +550,7 @@ def load_params(self, filename: str): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) -dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) +dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) # How to loop through dataloader for batch in dataloader: @@ -517,38 +560,37 @@ def load_params(self, filename: str): # model parameters hyperparams = { - 'vocab_size': 100, - 'expand': 2, - 'n_layer': 2, - 'd_conv': 3, - 'dt_rank': 16, - 'd_state': 8, - 'd_model': 64, - 'dropout': 0.2, - 'bias':True, - 'conv_bias': True, - 'max_length': max_length, - 'start_token': 0, - 'end_token': 50, + "vocab_size": 100, + "expand": 2, + "n_layer": 2, + "d_conv": 3, + "dt_rank": 16, + "d_state": 8, + "d_model": 64, + "dropout": 0.2, + "bias": True, + "conv_bias": True, + "max_length": max_length, + "start_token": 0, + "end_token": 50, } # Initialize model model = Mamba(**hyperparams) rngs = jax.random.PRNGKey(0) rngs, dropout_rng = jax.random.split(rngs) -params = model.init({'params': rngs, 'dropout': dropout_rng}, - dummy_inputs)['params'] +params = model.init({"params": rngs, "dropout": dropout_rng}, dummy_inputs)["params"] # Call as you would a Jax/Flax model -outputs = model.apply({'params': params}, - dummy_inputs, - rngs={'dropout': dropout_rng}) +outputs = model.apply({"params": params}, dummy_inputs, rngs={"dropout": dropout_rng}) print(outputs.shape) start_tokens = jnp.array([[123, 456]]) -outputs = model.apply({'params': params}, - start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, - method=model.generate) -print(outputs) \ No newline at end of file +outputs = model.apply( + {"params": params}, + start_tokens, + rngs={"dropout": jax.random.PRNGKey(2)}, + method=model.generate, +) +print(outputs) diff --git a/nanodl/__src/models/rlhf.py b/nanodl/__src/experimental/rlhf.py similarity index 64% rename from nanodl/__src/models/rlhf.py rename to nanodl/__src/experimental/rlhf.py index 0586691..5cb3819 100644 --- a/nanodl/__src/models/rlhf.py +++ b/nanodl/__src/experimental/rlhf.py @@ -1,14 +1,15 @@ -import jax -import flax -import time import copy -import optax -import jax.numpy as jnp +import time +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import Tuple, Any, Optional, Iterable -# Still in active development + class RLHF(nn.Module): policy_network: Any reference: bool = False @@ -18,57 +19,58 @@ def setup(self) -> None: self.dense2 = nn.Dense(256) self.dense3 = nn.Dense(1) - def __call__(self, - x: jnp.ndarray, - training: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray]: - + def __call__( + self, x: jnp.ndarray, training: bool = False + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + logits = self.policy_network(x, training=training) log_probs = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) probs = jnp.exp(log_probs) - rng = jax.random.PRNGKey(int(time.time())) + rng = jax.random.PRNGKey(int(time.time())) action = jax.random.categorical(rng, log_probs, axis=-1) entropy = -jnp.sum(probs * log_probs, axis=-1) action_log_probs = jnp.take_along_axis(log_probs, action[:, None], axis=-1) value = self.get_value(x) if not self.reference else None return action, action_log_probs, entropy, value - + def get_value(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray: hidden = self.policy_network(x, training=training, drop_last_layer=True) hidden = nn.relu(self.dense1(hidden)) hidden = nn.relu(self.dense2(hidden)) value = nn.tanh(self.dense3(hidden)) return value - + def generate(self, x: jnp.ndarray) -> jnp.ndarray: return self.policy_network.generate(x) - + def generate_batch(self, x: jnp.ndarray) -> jnp.ndarray: return self.policy_network.generate_batch(x) class PPODataParallelTrainer: - def __init__(self, - rlhf_main: Any, - rlhf_ref: Any, - reward_model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - gamma: float = 0.99, - beta: float = 0.2, - lam: float = 0.95, - ent_coef: float = 0.01, - vf_coef: float = 0.5, - learning_rate: float = 1e-4, - params_path: Optional[str] = None, - sft_params_path: Optional[str] = None, - reward_params_path: Optional[str] = None, - ) -> None: - + def __init__( + self, + rlhf_main: Any, + rlhf_ref: Any, + reward_model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + gamma: float = 0.99, + beta: float = 0.2, + lam: float = 0.95, + ent_coef: float = 0.01, + vf_coef: float = 0.5, + learning_rate: float = 1e-4, + params_path: Optional[str] = None, + sft_params_path: Optional[str] = None, + reward_params_path: Optional[str] = None, + ) -> None: + self.rlhf_main = rlhf_main self.reward_model = reward_model self.rlhf_ref = rlhf_ref - self.gamma = gamma + self.gamma = gamma self.lam = lam self.beta = beta self.epsilon = 1.0e-8 @@ -80,67 +82,77 @@ def __init__(self, self.params_path = params_path self.sft_params = self.load_params(sft_params_path) - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - reward_params = self.reward_model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))['params'] + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + reward_params = self.reward_model.init( + rngs, jnp.ones(input_shape, dtype=jnp.int32) + )["params"] self.reward_params = self.load_params(reward_params_path, params=reward_params) self.num_parameters = None self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(PPODataParallelTrainer.train_step, axis_name='devices') + self.train_step = jax.pmap( + PPODataParallelTrainer.train_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - - - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.rlhf_main.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))['params'] - params['policy_network']['decoder'] = self.sft_params['decoder'] + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: + + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.rlhf_main.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))[ + "params" + ] + params["policy_network"]["decoder"] = self.sft_params["decoder"] self.ref_params = copy.deepcopy(params) if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.rlhf_main.apply, - params=params, - tx=optax.adam(learning_rate)) - + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.rlhf_main.apply, params=params, tx=optax.adam(learning_rate) + ) + return jax.device_put_replicated(state, jax.local_devices()) - - def compute_agent_objective(self, model_logits, sft_logits, reward_score, gamma, beta): - ratio = nn.log_softmax(model_logits, axis=-1) - nn.log_softmax(sft_logits, axis=-1) + def compute_agent_objective( + self, model_logits, sft_logits, reward_score, gamma, beta + ): + ratio = nn.log_softmax(model_logits, axis=-1) - nn.log_softmax( + sft_logits, axis=-1 + ) left = jnp.mean(reward_score - beta * ratio.mean(axis=-1)) right = gamma * nn.log_softmax(model_logits, axis=-1).mean(axis=-1) return left + right - + def advantage_and_return(self, rewards, values): rewards = jnp.expand_dims(rewards, axis=0) values = jnp.expand_dims(values, axis=0) - + gen_len = rewards.shape[1] lastgaelam = 0 advantages_reversed = [] - + for t in reversed(range(gen_len)): nextvalues = values[:, t + 1] if t < gen_len - 1 else 0.0 delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] lastgaelam = delta + self.gamma * self.lam * lastgaelam advantages_reversed.append(lastgaelam) - + # Reversing and stacking to create the correct shape for advantages advantages = jnp.vstack(advantages_reversed[::-1]).T returns = advantages + values advantages = jnp.squeeze(advantages, axis=0) returns = jnp.squeeze(returns, axis=0) return advantages, returns - + def calculate_loss(self, logprobs, values, entropies, ref_logprobs, rewards): ratio = jnp.exp(logprobs - ref_logprobs) clipped_ratio = jnp.clip(ratio, 1 - self.epsilon, 1 + self.epsilon) @@ -151,47 +163,60 @@ def calculate_loss(self, logprobs, values, entropies, ref_logprobs, rewards): pg_loss = jnp.minimum(pg_loss_1, pg_loss_2).mean() loss = pg_loss - self.ent_coef * entropies.mean() + self.vf_coef * value_loss return loss - + def get_ref_log_probs(self, inputs: jnp.ndarray) -> jnp.ndarray: - return self.rlhf_ref.apply({'params': self.ref_params}, - inputs, training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - + return self.rlhf_ref.apply( + {"params": self.ref_params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + def get_rewards(self, inputs: jnp.ndarray) -> jnp.ndarray: - responses = self.rlhf_main.apply({'params': self.params}, - inputs, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}, - method=self.rlhf_main.generate_batch) - return self.reward_model.apply({'params': self.reward_params}, - responses, - training=False, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - - def train_step(self, - state: Any, - inputs: jnp.ndarray, - ref_log_probs: jnp.ndarray, - rewards: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + responses = self.rlhf_main.apply( + {"params": self.params}, + inputs, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + method=self.rlhf_main.generate_batch, + ) + return self.reward_model.apply( + {"params": self.reward_params}, + responses, + training=False, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + + def train_step( + self, + state: Any, + inputs: jnp.ndarray, + ref_log_probs: jnp.ndarray, + rewards: jnp.ndarray, + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - _, action_log_probs, entropy, value = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - - - - return self.calculate_loss(action_log_probs, value, entropy, ref_log_probs, rewards) - + _, action_log_probs, entropy, value = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + + return self.calculate_loss( + action_log_probs, value, entropy, ref_log_probs, rewards + ) + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -202,46 +227,51 @@ def train(self, batch_size = inputs.shape[0] batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) - ref_log_probs = ref_log_probs.reshape((self.num_devices, batch_size_per_device, -1)) + ref_log_probs = ref_log_probs.reshape( + (self.num_devices, batch_size_per_device, -1) + ) rewards = rewards.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - ref_log_probs=ref_log_probs, - rewards=rewards) + self.state, loss = self.train_step( + state=self.state, + inputs=inputs, + ref_log_probs=ref_log_probs, + rewards=rewards, + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + def merge_params(self, untrained_params, trained_params): updated_untrained_params = jax.tree_map( - lambda untrained, trained: trained if untrained.shape == trained.shape else untrained, - untrained_params, - trained_params) + lambda untrained, trained: ( + trained if untrained.shape == trained.shape else untrained + ), + untrained_params, + trained_params, + ) return updated_untrained_params def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str, params=None): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: params = self.params if params is None else params self.params = flax.serialization.from_bytes(params, f.read()) return self.params - - # from nanodl import ArrayDataset, DataLoader @@ -304,12 +334,12 @@ def load_params(self, filename: str, params=None): # rlhf_ref = RLHF(model, reference=True) # dataset = ArrayDataset(dummy_chosen) # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) -# trainer = PPODataParallelTrainer(rlhf_model, -# rlhf_ref, -# reward_model, -# dummy_inputs.shape, +# trainer = PPODataParallelTrainer(rlhf_model, +# rlhf_ref, +# reward_model, +# dummy_inputs.shape, # rlhf_params_path, # sft_params_path=model_params_path, # reward_params_path=reward_params_path) -# trainer.train(dataloader, 2) \ No newline at end of file +# trainer.train(dataloader, 2) diff --git a/nanodl/__src/utils/tokenizer.py b/nanodl/__src/experimental/tokenizer.py similarity index 83% rename from nanodl/__src/utils/tokenizer.py rename to nanodl/__src/experimental/tokenizer.py index 65ecf7d..0f08e82 100644 --- a/nanodl/__src/utils/tokenizer.py +++ b/nanodl/__src/experimental/tokenizer.py @@ -1,26 +1,28 @@ import os from typing import List, Optional + from sentencepiece import SentencePieceProcessor, SentencePieceTrainer + class Tokenizer: """ A tokenizer class that utilizes SentencePiece to encode and decode text. - + This class can be initialized with either an existing SentencePiece model or a dataset to train a new model. It provides methods to encode a string to a list of token ids and decode a list of token ids back to a string. - + Attributes: sp_model (SentencePieceProcessor): The SentencePiece processor. n_words (int): Number of words in the vocabulary. bos_id (int): Token id for the beginning of a sentence. eos_id (int): Token id for the end of a sentence. pad_id (int): Token id for padding. - + Example usage: - + Training a new model and encoding/decoding a string: - + ```python # Initialize tokenizer with training data and train a new model. text_paths = ['/Users/mac1/Desktop/nanodl/nanodl/__src/utils/sample.txt'] @@ -29,44 +31,47 @@ class Tokenizer: vocab_size=100, model_type='bpe', max_sentence_length=50) - + # Encode a sentence. encoded_sentence = tokenizer.encode('Hello, world!') print(f'Encoded: {encoded_sentence}') - + # Decode the encoded sentence. decoded_sentence = tokenizer.decode(encoded_sentence) print(f'Decoded: {decoded_sentence}') ``` - + Loading an existing model and encoding/decoding a string: - + ```python # Initialize tokenizer with a pre-trained model. tokenizer = Tokenizer(model_path='path/to/model.model') - + # Encode a sentence. encoded_sentence = tokenizer.encode('Hello, world!') print(f'Encoded: {encoded_sentence}') - + # Decode the encoded sentence. decoded_sentence = tokenizer.decode(encoded_sentence) print(f'Decoded: {decoded_sentence}') ``` """ - def __init__(self, - training_data: List[str] = None, - vocab_size: int = None, - model_type: str = "bpe", - max_sentence_length: int = 512, - model_path: Optional[str] = None): - + + def __init__( + self, + training_data: List[str] = None, + vocab_size: int = None, + model_type: str = "bpe", + max_sentence_length: int = 512, + model_path: Optional[str] = None, + ): + if model_path and os.path.isfile(model_path): # Load an existing model self.sp_model = SentencePieceProcessor(model_file=model_path) elif training_data and all(os.path.isfile(f) for f in training_data): # Train a new model using a list of data files - input_files = ','.join(training_data) + input_files = ",".join(training_data) model_prefix = "trained_model" SentencePieceTrainer.train( input=input_files, @@ -78,7 +83,9 @@ def __init__(self, self.sp_model = SentencePieceProcessor(model_file=f"{model_prefix}.model") else: - raise ValueError("Must provide either a model_path or a non-empty training_data list") + raise ValueError( + "Must provide either a model_path or a non-empty training_data list" + ) # Initialize token IDs self.n_words: int = self.sp_model.vocab_size() @@ -88,10 +95,7 @@ def __init__(self, assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() - def encode(self, - s: str, - bos: bool = True, - eos: bool = False) -> List[int]: + def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]: """Converts a string into a list of tokens.""" assert isinstance(s, str) t = self.sp_model.encode(s) @@ -101,7 +105,6 @@ def encode(self, t = t + [self.eos_id] return t - def decode(self, - t: List[int]) -> str: + def decode(self, t: List[int]) -> str: """Converts a list of tokens back into a string.""" return self.sp_model.decode(t) diff --git a/nanodl/__src/layers/general.py b/nanodl/__src/layers/general.py deleted file mode 100644 index 94411be..0000000 --- a/nanodl/__src/layers/general.py +++ /dev/null @@ -1,34 +0,0 @@ -import jax -import time -import jax.numpy as jnp -from jax import random - -def dropout(x: jnp.ndarray, - rate: float, - training: bool = False) -> jnp.ndarray: - """Apply dropout to input tensor. - - Args: - x (jnp.ndarray): Input tensor. - rate (float): Dropout rate, must be between 0 and 1. - training (bool, optional): Whether to apply dropout. - If False, returns input tensor unchanged. Defaults to False. - - Raises: - ValueError: If dropout rate is not in [0, 1). - - Returns: - jnp.ndarray: Tensor after applying dropout. - """ - if not training: - return x - - if not 0 <= rate < 1: - raise ValueError("Dropout rate must be in the range [0, 1).") - - if rate == 0: - return x - - keep_prob = 1 - rate - mask = jax.random.bernoulli(random.PRNGKey(int(time.time())), keep_prob, x.shape) - return jax.lax.select(mask, x / keep_prob, jnp.zeros_like(x)) \ No newline at end of file diff --git a/nanodl/__src/layers/attention.py b/nanodl/__src/models/attention.py similarity index 53% rename from nanodl/__src/layers/attention.py rename to nanodl/__src/models/attention.py index 156d1e3..2e7c0db 100644 --- a/nanodl/__src/layers/attention.py +++ b/nanodl/__src/models/attention.py @@ -1,6 +1,7 @@ +import flax.linen as nn import jax import jax.numpy as jnp -import flax.linen as nn + class MultiQueryAttention(nn.Module): """Multi-Query Attention module. @@ -18,44 +19,50 @@ class MultiQueryAttention(nn.Module): hidden_dim (int): The output dimension of the attention module. num_heads (int): The number of parallel attention heads. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads def setup(self): # To ensure dimensions are compatible assert self.hidden_dim % self.num_heads <= 0 - self.query_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.key_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + self.query_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__( + self, inputs: jnp.ndarray, context: jnp.ndarray, mask: jnp.ndarray = None + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) key = jnp.repeat(key, self.num_heads, axis=-1) value = jnp.repeat(value, self.num_heads, axis=-1) - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = value.shape[1] context_length = key.shape[1] @@ -63,27 +70,38 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - -class RotaryPositionalEncoding(): +class RotaryPositionalEncoding: def __init__(self, dim_model: int): super().__init__() self.dim_model = dim_model - inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model)) + inv_freq = 1.0 / ( + 10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model) + ) self.inv_freq = inv_freq self._seq_len_cached = None @@ -113,12 +131,14 @@ def apply_rotary_pos_emb(self, x, cos, sin): return (x * cos) + (self.rotate_half(x) * sin) def __call__(self, q, k): - self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + k, seq_dimension=-2 + ) return ( self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)[0], self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)[0], ) - + class RotaryMultiHeadAttention(nn.Module): """Rotary Multi-Head Attention module. @@ -138,42 +158,48 @@ class RotaryMultiHeadAttention(nn.Module): hidden_dim (int): The output dimension of the attention module. num_heads (int): The number of parallel attention heads. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads def setup(self): # Because the Query is determined from a context, project separately - self.query_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.key_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.rope = RotaryPositionalEncoding(self.hidden_dim*self.num_heads) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + self.query_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.rope = RotaryPositionalEncoding(self.hidden_dim * self.num_heads) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__( + self, inputs: jnp.ndarray, context: jnp.ndarray, mask: jnp.ndarray = None + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - query, key = self.rope(query, key) # Encode query and key with RoPE - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + query, key = self.rope(query, key) # Encode query and key with RoPE + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = value.shape[1] context_length = key.shape[1] @@ -181,17 +207,27 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights @@ -215,58 +251,71 @@ class GatedMultiHeadAttention(nn.Module): hidden_dim (int): The output dimension of the attention module. num_heads (int): The number of parallel attention heads. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads def setup(self): # Because the Query is determined from a context, project separately - self.query_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.key_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) + self.query_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) self.gate = nn.Dense(features=1) - - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + def __call__( + self, inputs: jnp.ndarray, context: jnp.ndarray, mask: jnp.ndarray = None + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - context_vectors, attention = self.attention_function(query,key,value,mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - - def attention_function(self, query, key, value,mask=None): + + def attention_function(self, query, key, value, mask=None): input_length = value.shape[1] context_length = key.shape[1] head_dim = query.shape[-1] // self.num_heads dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) probabilities = jax.nn.sigmoid(self.gate(value_heads)) - booleans = jax.random.bernoulli(jax.random.PRNGKey(0), probabilities) + booleans = jax.random.bernoulli(jax.random.PRNGKey(0), probabilities) gate = jnp.where(booleans, 1.0, 0.0) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) attention_scores * gate if mask is not None: @@ -274,9 +323,11 @@ def attention_function(self, query, key, value,mask=None): attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - + class HierarchicalMultiHeadAttention(nn.Module): """Hierarchical Multi-Head Attention module. @@ -302,51 +353,62 @@ class HierarchicalMultiHeadAttention(nn.Module): hidden_dim (int): The output dimension of the attention module. num_heads (int): The number of parallel attention heads. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads def setup(self): # Because the Query is determined from a context, project separately - self.word_query_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.word_key_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.word_value_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.word_output = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.sentence_query_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.sentence_key_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.sentence_value_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.sentence_output = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - word_inputs: jnp.ndarray, - word_context: jnp.ndarray, - sentence_inputs: jnp.ndarray, - sentence_context: jnp.ndarray, - word_mask: jnp.ndarray = None, - sentence_mask: jnp.ndarray = None) -> tuple: + self.word_query_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.word_key_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.word_value_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.word_output = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.sentence_query_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.sentence_key_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.sentence_value_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.sentence_output = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__( + self, + word_inputs: jnp.ndarray, + word_context: jnp.ndarray, + sentence_inputs: jnp.ndarray, + sentence_context: jnp.ndarray, + word_mask: jnp.ndarray = None, + sentence_mask: jnp.ndarray = None, + ) -> tuple: """Computes the hierarchical multi-head attention. Args: @@ -368,22 +430,20 @@ def __call__(self, word_queries = self.word_query_projection(word_inputs) word_keys = self.word_key_projection(word_context) word_values = self.word_value_projection(word_context) - word_attention, word_context_vectors = self.attention_function(word_queries, - word_keys, - word_values, - mask=word_mask) - + word_attention, word_context_vectors = self.attention_function( + word_queries, word_keys, word_values, mask=word_mask + ) + sentence_queries = self.sentence_query_projection(sentence_inputs) sentence_keys = self.sentence_key_projection(sentence_context) sentence_values = self.sentence_value_projection(sentence_context) - sentence_attention, sentence_context_vectors = self.attention_function(sentence_queries, - sentence_keys, - sentence_values, - mask=sentence_mask) + sentence_attention, sentence_context_vectors = self.attention_function( + sentence_queries, sentence_keys, sentence_values, mask=sentence_mask + ) word_outputs = self.word_output(word_context_vectors) sentence_outputs = self.sentence_output(sentence_context_vectors) return word_outputs, sentence_outputs, word_attention, sentence_attention - + def attention_function(self, query, key, value, mask=None): input_length = value.shape[1] context_length = key.shape[1] @@ -391,21 +451,30 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - class LocalMultiHeadAttention(nn.Module): """Local Multi-Head Attention module. @@ -425,32 +494,35 @@ class LocalMultiHeadAttention(nn.Module): window_size (int, optional): The size of the local attention window. Default is 3. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads - window_size : int = 3 + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads + window_size: int = 3 def setup(self): # Because the Query is determined from a context, project separately - self.query_projection = nn.Dense(self.hidden_dim*self.num_headsm, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.key_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim*self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray) -> tuple: + self.query_projection = nn.Dense( + self.hidden_dim * self.num_headsm, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim * self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__(self, inputs: jnp.ndarray, context: jnp.ndarray) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) @@ -458,10 +530,12 @@ def __call__(self, local_mask = self.create_local_attention_mask(query.shape[1], key.shape[1]) - context_vectors, attention = self.attention_function(query,key,value,mask=local_mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=local_mask + ) outputs = self.output(context_vectors) return outputs, attention - + def create_local_attention_mask(self, input_length, context_length): # Create a matrix with shape (input_length, context_length) mask = jnp.ones((input_length, context_length)) @@ -473,7 +547,7 @@ def create_local_attention_mask(self, input_length, context_length): mask = mask.at[i, :start].set(0) mask = mask.at[i, end:].set(0) return mask - + def attention_function(self, query, key, value, mask=None): input_length = value.shape[1] context_length = key.shape[1] @@ -481,15 +555,25 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) - return attended_values, attention_weights \ No newline at end of file + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) + return attended_values, attention_weights diff --git a/nanodl/__src/models/clip.py b/nanodl/__src/models/clip.py index a2da98a..31601d6 100644 --- a/nanodl/__src/models/clip.py +++ b/nanodl/__src/models/clip.py @@ -1,12 +1,12 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp -import flax.linen as nn +from typing import Any, Iterable, Optional, Tuple +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import Any, Iterable, Optional, Tuple, Dict class PositionalEncoding(nn.Module): @@ -23,19 +23,27 @@ class PositionalEncoding(nn.Module): setup(): Initializes the positional encoding matrix based on the provided attributes. __call__(x: jnp.ndarray): Adds positional encodings to the input embeddings. """ + num_embeddings: int features: int def setup(self): positional_encoding = jnp.zeros((self.features, self.num_embeddings)) position = jnp.arange(0, self.features, dtype=jnp.float32)[:, None] - div_term = jnp.exp(jnp.arange(0, self.num_embeddings, 2) * (-jnp.log(10000.0) / self.num_embeddings)) - positional_encoding = positional_encoding.at[:, 0::2].set(jnp.sin(position * div_term)) - positional_encoding = positional_encoding.at[:, 1::2].set(jnp.cos(position * div_term)) + div_term = jnp.exp( + jnp.arange(0, self.num_embeddings, 2) + * (-jnp.log(10000.0) / self.num_embeddings) + ) + positional_encoding = positional_encoding.at[:, 0::2].set( + jnp.sin(position * div_term) + ) + positional_encoding = positional_encoding.at[:, 1::2].set( + jnp.cos(position * div_term) + ) self.positional_encoding = positional_encoding.T def __call__(self, x): - x = x + self.positional_encoding[:x.shape[1]] + x = x + self.positional_encoding[: x.shape[1]] return x @@ -55,18 +63,25 @@ class TokenAndPositionEmbedding(nn.Module): setup(): Initializes token and positional embeddings. __call__(x: jnp.ndarray): Applies token embeddings and adds positional information to the input sequence. """ - max_len : int - vocab_size : int - embed_dim : int - learned_position : bool - + + max_len: int + vocab_size: int + embed_dim: int + learned_position: bool + def setup(self): - self.token_embeddings = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim) + self.token_embeddings = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) if self.learned_position: - self.position_embeddings = nn.Embed(num_embeddings=self.max_len, features=self.embed_dim) + self.position_embeddings = nn.Embed( + num_embeddings=self.max_len, features=self.embed_dim + ) else: - self.position_embeddings = PositionalEncoding(num_embeddings=self.max_len, features=self.embed_dim) + self.position_embeddings = PositionalEncoding( + num_embeddings=self.max_len, features=self.embed_dim + ) def __call__(self, x): x = self.token_embeddings(x) @@ -74,7 +89,6 @@ def __call__(self, x): return x + self.position_embeddings(jnp.arange(x.shape[1])) else: return x + self.position_embeddings(x) - class SelfMultiHeadAttention(nn.Module): @@ -92,29 +106,32 @@ class SelfMultiHeadAttention(nn.Module): __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ - hidden_dim : int - num_heads : int + + hidden_dim: int + num_heads: int def setup(self): # Stack all weight matrices together for efficiency - self.projection = nn.Dense(3*self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + self.projection = nn.Dense( + 3 * self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__(self, inputs: jnp.ndarray, mask: jnp.ndarray = None) -> tuple: projections = self.projection(inputs) query, key, value = jnp.array_split(projections, 3, axis=-1) - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = query.shape[1] context_length = key.shape[1] @@ -122,19 +139,28 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - class PositionWiseFFN(nn.Module): @@ -151,12 +177,17 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + num_hiddens: int num_outputs: int def setup(self): - self.dense1 = nn.Dense(self.num_hiddens, kernel_init=nn.initializers.xavier_uniform()) - self.dense2 = nn.Dense(self.num_outputs, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.num_hiddens, kernel_init=nn.initializers.xavier_uniform() + ) + self.dense2 = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(nn.gelu(self.dense1(X))) @@ -174,15 +205,14 @@ class AddNorm(nn.Module): Methods: __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ + dropout: int @nn.compact - def __call__(self, - X: jnp.ndarray, - Y: jnp.ndarray, - training=False) -> jnp.ndarray: + def __call__(self, X: jnp.ndarray, Y: jnp.ndarray, training=False) -> jnp.ndarray: return nn.LayerNorm()( - nn.Dropout(self.dropout)(Y, deterministic=not training) + X) + nn.Dropout(self.dropout)(Y, deterministic=not training) + X + ) class EncoderBlock(nn.Module): @@ -201,22 +231,23 @@ class EncoderBlock(nn.Module): setup(): Initializes the attention, feed-forward network, and normalization layers. __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the encoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention = SelfMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads) + self.attention = SelfMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.ff = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: attended_x, attention = self.attention(x, mask=mask) x = self.add_norm1(x, attended_x, training) ff_output = self.ff(x) @@ -245,32 +276,31 @@ class TextEncoder(nn.Module): setup(): Initializes the embedding layer and the encoder blocks. __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the transformer encoder. """ + num_layers: int hidden_dim: int num_heads: int feedforward_dim: int dropout: float - max_len : int - vocab_size : int - embed_dim : int - learned_position : bool = True - + max_len: int + vocab_size: int + embed_dim: int + learned_position: bool = True def setup(self): - self.embedding = TokenAndPositionEmbedding(self.max_len, - self.vocab_size, - self.embed_dim, - self.learned_position) - self.layers = [EncoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) - for _ in range(self.num_layers)] - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: + self.embedding = TokenAndPositionEmbedding( + self.max_len, self.vocab_size, self.embed_dim, self.learned_position + ) + self.layers = [ + EncoderBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: attention_maps = [] x = self.embedding(x) for layer in self.layers: @@ -293,18 +323,21 @@ class PatchEmbedding(nn.Module): __call__(x: jnp.ndarray): Extracts patches from the input images and applies patch embedding. extract_patches(images: jnp.ndarray): Extracts and flattens patches from input images. """ + patch_size: Tuple[int, int] - embed_dim: int + embed_dim: int @nn.compact def __call__(self, x): x = nn.Dense(self.embed_dim)(self.extract_patches(x)) - return x + nn.Embed(num_embeddings=x.shape[1], features=x.shape[2])(jnp.arange(x.shape[1])) + return x + nn.Embed(num_embeddings=x.shape[1], features=x.shape[2])( + jnp.arange(x.shape[1]) + ) def extract_patches(self, images: jnp.ndarray) -> jnp.ndarray: if len(images.shape) != 4: raise ValueError("Input images should have shape (batch_size, H, W, C)") - + batch_size, h, w, c = images.shape ph, pw = self.patch_size @@ -316,11 +349,13 @@ def extract_patches(self, images: jnp.ndarray) -> jnp.ndarray: num_patches_w = w // pw # Reshape the images into patches and flatten each patch - patches = jnp.reshape(images, (batch_size, num_patches_h, ph, num_patches_w, pw, c)) + patches = jnp.reshape( + images, (batch_size, num_patches_h, ph, num_patches_w, pw, c) + ) patches = jnp.transpose(patches, (0, 1, 3, 2, 4, 5)) patches = jnp.reshape(patches, (batch_size, -1, ph * pw * c)) return patches - + class ImageEncoder(nn.Module): """ @@ -340,6 +375,7 @@ class ImageEncoder(nn.Module): setup(): Initializes the patch embedding and encoder blocks. __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input images through the vision transformer encoder. """ + patch_size: Tuple[int, int] num_layers: int hidden_dim: int @@ -348,37 +384,36 @@ class ImageEncoder(nn.Module): dropout: float def setup(self): - self.embedding = PatchEmbedding(self.patch_size, - self.feedforward_dim) - - self.layers = [EncoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) - for _ in range(self.num_layers)] - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: + self.embedding = PatchEmbedding(self.patch_size, self.feedforward_dim) + + self.layers = [ + EncoderBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: attention_maps = [] x = self.embedding(x) for layer in self.layers: x, attention = layer(x, mask=mask, training=training) attention_maps.append(attention) return x, jnp.array(attention_maps) - + class CLIP(nn.Module): """ - CLIP (Contrastive Language-Image Pretraining) is designed to understand and connect vision and language. - Its motivation arises from the need to bridge the gap between textual and visual information processing in AI. - CLIP's architecture is based on a vision-language transformer, - which is pretrained on a large corpus of text and images from the internet, - allowing it to learn associations between text and visuals. - Unlike traditional models that are pretrained on single-modal data, CLIP can perform a wide range of tasks, - including image classification, zero-shot object recognition, and even generating textual descriptions for images. - CLIP's versatility and performance stem from its ability to encode and compare text and image representations directly, + CLIP (Contrastive Language-Image Pretraining) is designed to understand and connect vision and language. + Its motivation arises from the need to bridge the gap between textual and visual information processing in AI. + CLIP's architecture is based on a vision-language transformer, + which is pretrained on a large corpus of text and images from the internet, + allowing it to learn associations between text and visuals. + Unlike traditional models that are pretrained on single-modal data, CLIP can perform a wide range of tasks, + including image classification, zero-shot object recognition, and even generating textual descriptions for images. + CLIP's versatility and performance stem from its ability to encode and compare text and image representations directly, enabling it to generalize well across various vision and language tasks while minimizing the need for task-specific fine-tuning. Args: @@ -401,8 +436,8 @@ class CLIP(nn.Module): - encode_image(images): Encodes image data using the image encoder. - embed_text(texts): Embeds text data into the shared embedding space. - embed_image(images): Embeds image data into the shared embedding space. - - Note: + + Note: Text input shape: (batch_size, max_length, embed_dim) Image input shape: (batch_size, height, width, channels) Image shape after patch embedding: (batch_size, sequence_length, embed_dim) @@ -417,18 +452,18 @@ class CLIP(nn.Module): # Dummy data parameters batch_size = 8 - max_length = 50 - vocab_size = 1000 - embed_dim = 256 - patch_size = (16, 16) + max_length = 50 + vocab_size = 1000 + embed_dim = 256 + patch_size = (16, 16) # Generate dummy text and image data dummy_texts = jnp.ones((batch_size, max_length), dtype=jnp.int32) dummy_images = jnp.ones((batch_size, 224, 224, 3)) dataset = ArrayDataset(dummy_texts, dummy_images) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # CLIP model parameters @@ -453,24 +488,25 @@ class CLIP(nn.Module): loss = clip_model.apply({'params': params}, dummy_texts, dummy_images) # Training on your data - trainer = CLIPDataParallelTrainer(clip_model, - dummy_texts.shape, + trainer = CLIPDataParallelTrainer(clip_model, + dummy_texts.shape, dummy_images.shape, 'params.pkl') trainer.train(dataloader, 2) # Sample encodings - image_encodings = clip_model.apply({'params': params}, + image_encodings = clip_model.apply({'params': params}, images = dummy_images, - method=clip_model.encode_image) + method=clip_model.encode_image) print(image_encodings.shape) # Sample embeddings - image_embeddings = clip_model.apply({'params': params}, + image_embeddings = clip_model.apply({'params': params}, images = dummy_images, - method=clip_model.embed_image) + method=clip_model.embed_image) print(image_embeddings.shape) ``` """ + dropout: float num_heads: int feedforward_dim: int @@ -479,9 +515,9 @@ class CLIP(nn.Module): image_patch_size: int hidden_dim_image: int num_layers_images: int - max_len : int - vocab_size : int - embed_dim : int + max_len: int + vocab_size: int + embed_dim: int def setup(self): """ @@ -503,87 +539,72 @@ def setup(self): hidden_dim=self.hidden_dim_image, num_heads=self.num_heads, feedforward_dim=self.feedforward_dim, - dropout=self.dropout + dropout=self.dropout, ) self.text_pooler = nn.Dense(self.embed_dim) self.image_pooler = nn.Dense(self.embed_dim) - self.temperature = self.param('temperature', nn.initializers.zeros, ()) + self.temperature = self.param("temperature", nn.initializers.zeros, ()) + + def __call__( + self, texts: jnp.ndarray, images: jnp.ndarray, training: bool = False + ) -> Tuple[jnp.ndarray, jnp.ndarray, float]: - def __call__(self, - texts: jnp.ndarray, - images: jnp.ndarray, - training: bool = False) -> Tuple[jnp.ndarray, jnp.ndarray, float]: - text_latents, _ = self.text_encoder(texts, training=training) image_latents, _ = self.image_encoder(images, training=training) text_embedding = self.text_pooler(jnp.mean(text_latents, axis=1)) image_embedding = self.image_pooler(jnp.mean(image_latents, axis=1)) return self.clip_loss(text_embedding, image_embedding) - - def clip_loss(self, - text_embeddings: jnp.ndarray, - image_embeddings: jnp.ndarray) -> float: - + + def clip_loss( + self, text_embeddings: jnp.ndarray, image_embeddings: jnp.ndarray + ) -> float: + def l2_normalise(x): return x / jnp.linalg.norm(x, axis=-1, keepdims=True) def cross_entropy(preds, targets): return (-targets * jax.nn.log_softmax(preds)).sum(axis=1).mean() - + text_embeddings = l2_normalise(text_embeddings) image_embeddings = l2_normalise(image_embeddings) - similarity_matrix = image_embeddings @ text_embeddings.T / (self.temperature + 0.00001) + similarity_matrix = ( + image_embeddings @ text_embeddings.T / (self.temperature + 0.00001) + ) labels = jnp.arange(similarity_matrix.shape[0]) image_loss = cross_entropy(similarity_matrix, labels) text_loss = cross_entropy(similarity_matrix.T, labels) return (image_loss + text_loss) / 2 + def get_attention_maps( + self, texts: jnp.ndarray, images: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: - def get_attention_maps(self, - texts: jnp.ndarray, - images: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - _, text_attention = self.text_encoder(texts, training=False) _, image_attention = self.image_encoder(images, training=False) return text_attention, image_attention - def encode_text(self, - texts: jnp.ndarray) -> jnp.ndarray: - + def encode_text(self, texts: jnp.ndarray) -> jnp.ndarray: + return self.text_encoder(texts)[0] - def encode_image(self, - images: jnp.ndarray) -> jnp.ndarray: - + def encode_image(self, images: jnp.ndarray) -> jnp.ndarray: + return self.image_encoder(images)[0] - def embed_text(self, - texts: jnp.ndarray) -> jnp.ndarray: - - return self.text_pooler( - jnp.mean( - self.text_encoder(texts)[0], - axis=1 - ) - ) + def embed_text(self, texts: jnp.ndarray) -> jnp.ndarray: - def embed_image(self, - images: jnp.ndarray) -> jnp.ndarray: - - return self.image_pooler( - jnp.mean( - self.image_encoder(images)[0], - axis=1 - ) - ) + return self.text_pooler(jnp.mean(self.text_encoder(texts)[0], axis=1)) + def embed_image(self, images: jnp.ndarray) -> jnp.ndarray: + + return self.image_pooler(jnp.mean(self.image_encoder(images)[0], axis=1)) class CLIPDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -603,13 +624,16 @@ class CLIPDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - text_input_shape: Tuple[int, ...], - image_input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + text_input_shape: Tuple[int, ...], + image_input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -617,82 +641,111 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(CLIPDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(CLIPDataParallelTrainer.evaluation_step, axis_name='devices') - self.state = self.create_train_state(learning_rate, text_input_shape, image_input_shape) - print(f'Number of accelerators: {self.num_devices}') - - - def create_train_state(self, learning_rate: float, - text_input_shape: Tuple[int, ...], - image_input_shape: Tuple[int, ...]) -> Any: + self.train_step = jax.pmap( + CLIPDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + CLIPDataParallelTrainer.evaluation_step, axis_name="devices" + ) + self.state = self.create_train_state( + learning_rate, text_input_shape, image_input_shape + ) + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, + learning_rate: float, + text_input_shape: Tuple[int, ...], + image_input_shape: Tuple[int, ...], + ) -> Any: rng = jax.random.PRNGKey(0) - params = self.model.init(rng, jnp.ones(text_input_shape, dtype=jnp.int32), jnp.ones(image_input_shape))['params'] + params = self.model.init( + rng, + jnp.ones(text_input_shape, dtype=jnp.int32), + jnp.ones(image_input_shape), + )["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - texts: jnp.ndarray, - images: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - grad_fn = jax.value_and_grad(lambda params: state.apply_fn({'params': params}, - texts, - images, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))})) + def train_step( + state: Any, texts: jnp.ndarray, images: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + grad_fn = jax.value_and_grad( + lambda params: state.apply_fn( + {"params": params}, + texts, + images, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + ) loss, grads = grad_fn(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 for texts, images in train_loader: batch_size = texts.shape[0] batch_size_per_device = batch_size // self.num_devices - texts = texts.reshape((self.num_devices, batch_size_per_device, texts.shape[1])) - images = images.reshape((self.num_devices, batch_size_per_device, images.shape[1], images.shape[2], images.shape[3])) + texts = texts.reshape( + (self.num_devices, batch_size_per_device, texts.shape[1]) + ) + images = images.reshape( + ( + self.num_devices, + batch_size_per_device, + images.shape[1], + images.shape[2], + images.shape[3], + ) + ) self.state, loss = self.train_step(self.state, texts, images) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - return - + return + return + @staticmethod - def evaluation_step(state: Any, - texts: jnp.ndarray, - images: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - forward_fn = lambda params: state.apply_fn({'params': params}, texts, images) + def evaluation_step( + state: Any, texts: jnp.ndarray, images: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + forward_fn = lambda params: state.apply_fn({"params": params}, texts, images) return forward_fn(state.params) - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for texts, images in test_loader: @@ -703,15 +756,15 @@ def evaluate(self, loss = self.evaluation_step(self.state, texts, images) total_loss += jnp.mean(loss) count += 1 - + return total_loss / count def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/diffusion.py b/nanodl/__src/models/diffusion.py index 63b3abc..6019f39 100644 --- a/nanodl/__src/models/diffusion.py +++ b/nanodl/__src/models/diffusion.py @@ -1,12 +1,12 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp -import flax.linen as nn +from typing import Any, Iterable, List, Optional, Tuple +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import Any, Iterable, Optional, Tuple, Dict, List class SinusoidalEmbedding(nn.Module): @@ -24,6 +24,7 @@ class SinusoidalEmbedding(nn.Module): setup(): Initializes the layer by computing the angular speeds for the sinusoidal functions based on the specified frequency range. __call__(x: jnp.ndarray): Generates the sinusoidal embeddings for the input positions. """ + embedding_dims: int embedding_min_frequency: float embedding_max_frequency: float @@ -36,9 +37,12 @@ def setup(self): self.angular_speeds = 2.0 * jnp.pi * frequencies def __call__(self, x): - embeddings = jnp.concatenate([jnp.sin(self.angular_speeds * x), jnp.cos(self.angular_speeds * x)], axis=-1) + embeddings = jnp.concatenate( + [jnp.sin(self.angular_speeds * x), jnp.cos(self.angular_speeds * x)], + axis=-1, + ) return embeddings - + class UNetResidualBlock(nn.Module): """ @@ -52,17 +56,17 @@ class UNetResidualBlock(nn.Module): Methods: __call__(x: jnp.ndarray): Processes the input tensor through the residual block and returns the result. """ + width: int @nn.compact - def __call__(self, - x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: input_width = x.shape[-1] # Define layers convolution_1 = nn.Conv(self.width, kernel_size=(1, 1)) - convolution_2 = nn.Conv(self.width, kernel_size=(3, 3), padding='SAME') - convolution_3 = nn.Conv(self.width, kernel_size=(3, 3), padding='SAME') + convolution_2 = nn.Conv(self.width, kernel_size=(3, 3), padding="SAME") + convolution_3 = nn.Conv(self.width, kernel_size=(3, 3), padding="SAME") norm = nn.GroupNorm(num_groups=2, epsilon=1e-5, use_bias=False, use_scale=False) # Residual connection @@ -76,7 +80,7 @@ def __call__(self, x = convolution_3(x) return x + residual - + class UNetDownBlock(nn.Module): """ @@ -92,14 +96,16 @@ class UNetDownBlock(nn.Module): setup(): Initializes the sequence of residual blocks. __call__(x: jnp.ndarray): Processes the input tensor through the down-sampling block and returns the result. """ + width: int block_depth: int def setup(self): - self.residual_blocks = [UNetResidualBlock(self.width) for _ in range(self.block_depth)] + self.residual_blocks = [ + UNetResidualBlock(self.width) for _ in range(self.block_depth) + ] - def __call__(self, - x: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: for block in self.residual_blocks: x = block(x) x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) @@ -120,18 +126,19 @@ class UNetUpBlock(nn.Module): setup(): Initializes the sequence of residual blocks. __call__(x: jnp.ndarray, skip: jnp.ndarray): Processes the input tensor and a skip connection from the encoding pathway through the up-sampling block and returns the result. """ + width: int block_depth: int def setup(self): - self.residual_blocks = [UNetResidualBlock(self.width) for _ in range(self.block_depth)] + self.residual_blocks = [ + UNetResidualBlock(self.width) for _ in range(self.block_depth) + ] - def __call__(self, - x: jnp.ndarray, - skip: jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray, skip: jnp.ndarray) -> jnp.ndarray: B, H, W, C = x.shape upsampled_shape = (B, H * 2, W * 2, C) - x = jax.image.resize(x, shape=upsampled_shape, method='bilinear') + x = jax.image.resize(x, shape=upsampled_shape, method="bilinear") x = jnp.concatenate([x, skip], axis=-1) for block in self.residual_blocks: x = block(x) @@ -156,6 +163,7 @@ class UNet(nn.Module): setup(): Initializes the U-Net architecture including the sinusoidal embedding layer, down-sampling blocks, residual blocks, and up-sampling blocks. __call__(noisy_images: jnp.ndarray, noise_variances: jnp.ndarray): Processes noisy images and their associated noise variances through the U-Net and returns the denoised images. """ + image_size: Tuple[int, int] widths: List[int] block_depth: int @@ -164,20 +172,35 @@ class UNet(nn.Module): embed_max_freq: float def setup(self): - self.sinusoidal_embedding = SinusoidalEmbedding(self.embed_dims, self.embed_min_freq, self.embed_max_freq) - self.down_blocks = [UNetDownBlock(width, self.block_depth) for width in self.widths[:-1]] - self.residual_blocks = [UNetResidualBlock(self.widths[-1]) for _ in range(self.block_depth)] - self.up_blocks = [UNetUpBlock(width, self.block_depth) for width in reversed(self.widths[:-1])] + self.sinusoidal_embedding = SinusoidalEmbedding( + self.embed_dims, self.embed_min_freq, self.embed_max_freq + ) + self.down_blocks = [ + UNetDownBlock(width, self.block_depth) for width in self.widths[:-1] + ] + self.residual_blocks = [ + UNetResidualBlock(self.widths[-1]) for _ in range(self.block_depth) + ] + self.up_blocks = [ + UNetUpBlock(width, self.block_depth) for width in reversed(self.widths[:-1]) + ] self.convolution_1 = nn.Conv(self.widths[0], kernel_size=(1, 1)) - self.convolution_2 = nn.Conv(3, kernel_size=(1, 1), kernel_init=nn.initializers.zeros) + self.convolution_2 = nn.Conv( + 3, kernel_size=(1, 1), kernel_init=nn.initializers.zeros + ) + + def __call__( + self, noisy_images: jnp.ndarray, noise_variances: jnp.ndarray + ) -> jnp.ndarray: - def __call__(self, - noisy_images: jnp.ndarray, - noise_variances: jnp.ndarray) -> jnp.ndarray: - e = self.sinusoidal_embedding(noise_variances) - upsampled_shape = (noisy_images.shape[0], self.image_size[0], self.image_size[1], self.embed_dims) - e = jax.image.resize(e, upsampled_shape, method='nearest') + upsampled_shape = ( + noisy_images.shape[0], + self.image_size[0], + self.image_size[1], + self.embed_dims, + ) + e = jax.image.resize(e, upsampled_shape, method="nearest") x = self.convolution_1(noisy_images) x = jnp.concatenate([x, e], axis=-1) @@ -195,8 +218,8 @@ def __call__(self, outputs = self.convolution_2(x) return outputs - - + + class DiffusionModel(nn.Module): """ Implements a diffusion model for image generation using JAX. @@ -237,11 +260,11 @@ class DiffusionModel(nn.Module): images = jax.random.normal(key, input_shape) # Use your own images - dataset = ArrayDataset(images) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) + dataset = ArrayDataset(images) + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + drop_last=False) # Create diffusion model diffusion_model = DiffusionModel(image_size, widths, block_depth) @@ -251,22 +274,23 @@ class DiffusionModel(nn.Module): # Training on your data # Note: saved params are often different from training weights, use the saved params for generation - trainer = DiffusionDataParallelTrainer(diffusion_model, - input_shape=images.shape, - weights_filename='params.pkl', + trainer = DiffusionDataParallelTrainer(diffusion_model, + input_shape=images.shape, + weights_filename='params.pkl', learning_rate=1e-4) trainer.train(dataloader, 10, dataloader) print(trainer.evaluate(dataloader)) # Generate some samples params = trainer.load_params('params.pkl') - generated_images = diffusion_model.apply({'params': params}, - num_images=5, - diffusion_steps=5, + generated_images = diffusion_model.apply({'params': params}, + num_images=5, + diffusion_steps=5, method=diffusion_model.generate) print(generated_images.shape) ``` """ + image_size: int widths: List[int] block_depth: int @@ -277,15 +301,18 @@ class DiffusionModel(nn.Module): embed_max_freq: float = 1000.0 def setup(self): - self.unet = UNet(image_size=(self.image_size, self.image_size), - widths=self.widths, - block_depth=self.block_depth, - embed_dims=self.embed_dims, - embed_min_freq=self.embed_min_freq, - embed_max_freq=self.embed_max_freq) - - def diffusion_schedule(self, - diffusion_times: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + self.unet = UNet( + image_size=(self.image_size, self.image_size), + widths=self.widths, + block_depth=self.block_depth, + embed_dims=self.embed_dims, + embed_min_freq=self.embed_min_freq, + embed_max_freq=self.embed_max_freq, + ) + + def diffusion_schedule( + self, diffusion_times: jnp.ndarray + ) -> Tuple[jnp.ndarray, jnp.ndarray]: start_angle = jnp.arccos(self.max_signal_rate) end_angle = jnp.arccos(self.min_signal_rate) diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle) @@ -293,29 +320,34 @@ def diffusion_schedule(self, noise_rates = jnp.sin(diffusion_angles) return noise_rates, signal_rates - def denoise(self, - noisy_images: jnp.ndarray, - noise_rates: jnp.ndarray, - signal_rates: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: - pred_noises = self.unet(noisy_images, noise_rates ** 2) + def denoise( + self, + noisy_images: jnp.ndarray, + noise_rates: jnp.ndarray, + signal_rates: jnp.ndarray, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + pred_noises = self.unet(noisy_images, noise_rates**2) pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates return pred_noises, pred_images - def __call__(self, - images: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: + def __call__(self, images: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: key = jax.random.PRNGKey(int(time.time())) - noises = jax.random.normal(key, shape=(images.shape[0], self.image_size, self.image_size, 3)) + noises = jax.random.normal( + key, shape=(images.shape[0], self.image_size, self.image_size, 3) + ) batch_size = images.shape[0] - diffusion_times = jax.random.uniform(key, shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0) + diffusion_times = jax.random.uniform( + key, shape=(batch_size, 1, 1, 1), minval=0.0, maxval=1.0 + ) noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) noisy_images = signal_rates * images + noise_rates * noises pred_noises, pred_images = self.denoise(noisy_images, noise_rates, signal_rates) return pred_noises, pred_images - def reverse_diffusion(self, - initial_noise: jnp.ndarray, - diffusion_steps: int) -> jnp.ndarray: - + def reverse_diffusion( + self, initial_noise: jnp.ndarray, diffusion_steps: int + ) -> jnp.ndarray: + num_images = initial_noise.shape[0] step_size = 1.0 / diffusion_steps next_noisy_images = initial_noise @@ -323,30 +355,33 @@ def reverse_diffusion(self, for step in range(diffusion_steps): diffusion_times = jnp.ones((num_images, 1, 1, 1)) - step * step_size noise_rates, signal_rates = self.diffusion_schedule(diffusion_times) - pred_noises, pred_images = self.denoise(next_noisy_images, noise_rates, signal_rates) + pred_noises, pred_images = self.denoise( + next_noisy_images, noise_rates, signal_rates + ) next_diffusion_times = diffusion_times - step_size - next_noise_rates, next_signal_rates = self.diffusion_schedule(next_diffusion_times) - next_noisy_images = (next_signal_rates * pred_images + next_noise_rates * pred_noises) + next_noise_rates, next_signal_rates = self.diffusion_schedule( + next_diffusion_times + ) + next_noisy_images = ( + next_signal_rates * pred_images + next_noise_rates * pred_noises + ) return pred_images - - def generate(self, - num_images: int = 1, - diffusion_steps: int = 20) -> jnp.ndarray: - + + def generate(self, num_images: int = 1, diffusion_steps: int = 20) -> jnp.ndarray: + key = jax.random.PRNGKey(int(time.time())) - noises = jax.random.normal(key, shape=(num_images, - self.image_size, - self.image_size, - 3)) - + noises = jax.random.normal( + key, shape=(num_images, self.image_size, self.image_size, 3) + ) + return self.reverse_diffusion(noises, diffusion_steps) class DiffusionDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -365,12 +400,15 @@ class DiffusionDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-4, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-4, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -378,50 +416,60 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(DiffusionDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(DiffusionDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + DiffusionDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + DiffusionDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, jnp.ones(input_shape))['params'] + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - images: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step(state: Any, images: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): key = jax.random.PRNGKey(int(time.time())) noises = jax.random.normal(key, shape=images.shape) - pred_noises, pred_images = state.apply_fn({'params': params}, - images, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return jnp.mean(jnp.square(pred_noises - noises)) + jnp.mean(jnp.square(pred_images - images)) - + pred_noises, pred_images = state.apply_fn( + {"params": params}, + images, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return jnp.mean(jnp.square(pred_noises - noises)) + jnp.mean( + jnp.square(pred_images - images) + ) + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -429,72 +477,82 @@ def train(self, images = images[0] if len(images) == 1 else images batch_size = images.shape[0] batch_size_per_device = batch_size // self.num_devices - images = images.reshape((self.num_devices, - batch_size_per_device, - images.shape[1], - images.shape[2], - images.shape[3])) - self.state, loss = self.train_step(state=self.state, - images=images) + images = images.reshape( + ( + self.num_devices, + batch_size_per_device, + images.shape[1], + images.shape[2], + images.shape[3], + ) + ) + self.state, loss = self.train_step(state=self.state, images=images) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - images: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def evaluation_step(state: Any, images: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + key = jax.random.PRNGKey(int(time.time())) noises = jax.random.normal(key, shape=images.shape) - pred_noises, pred_images = state.apply_fn({'params': state.params}, - images, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return jnp.mean(jnp.square(pred_noises - noises)) + jnp.mean(jnp.square(pred_images - images)) - - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + pred_noises, pred_images = state.apply_fn( + {"params": state.params}, + images, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return jnp.mean(jnp.square(pred_noises - noises)) + jnp.mean( + jnp.square(pred_images - images) + ) + + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for images in test_loader: images = images[0] if len(images) == 1 else images batch_size = images.shape[0] batch_size_per_device = batch_size // self.num_devices - images = images.reshape((self.num_devices, - batch_size_per_device, - images.shape[1], - images.shape[2], - images.shape[3])) + images = images.reshape( + ( + self.num_devices, + batch_size_per_device, + images.shape[1], + images.shape[2], + images.shape[3], + ) + ) loss = self.evaluation_step(self.state, images) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss - + def get_ema_weights(self, params, ema=0.999): def func(x): return x * ema + (1 - ema) * x + return jax.tree_util.tree_map(func, params) def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) self.params = self.get_ema_weights(self.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/gemma.py b/nanodl/__src/models/gemma.py index 62bef81..b73d1c0 100644 --- a/nanodl/__src/models/gemma.py +++ b/nanodl/__src/models/gemma.py @@ -1,13 +1,15 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import Tuple, Any, Optional, Iterable -class RotaryPositionalEncoding(): + +class RotaryPositionalEncoding: """ Implements rotary positional encoding (RoPE) for transformers, enhancing their ability to capture sequence order. @@ -22,10 +24,13 @@ class RotaryPositionalEncoding(): apply_rotary_pos_emb(x, cos, sin): Applies the rotary positional encoding to the input embeddings. __call__(q, k): Applies rotary positional encoding to query and key tensors in attention mechanisms. """ + def __init__(self, dim_model: int): super().__init__() self.dim_model = dim_model - inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model)) + inv_freq = 1.0 / ( + 10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model) + ) self.inv_freq = inv_freq self._seq_len_cached = None self._cos_cached = None @@ -54,12 +59,14 @@ def apply_rotary_pos_emb(self, x, cos, sin): return (x * cos) + (self.rotate_half(x) * sin) def __call__(self, q, k): - self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + k, seq_dimension=-2 + ) return ( self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)[0], self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)[0], ) - + class GroupedRotaryMultiHeadAttention(nn.Module): """ @@ -78,53 +85,62 @@ class GroupedRotaryMultiHeadAttention(nn.Module): process_group(query, key, value, mask): Processes a single group of heads through rotary positional encoding and attention. attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads - num_groups : int # Number of groups to split the heads into + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads + num_groups: int # Number of groups to split the heads into def setup(self): - self.query_projection = nn.Dense(self.hidden_dim // self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros, - ) - self.key_projection = nn.Dense(self.hidden_dim // (self.num_heads * self.num_groups), - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim // (self.num_heads * self.num_groups), - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) + self.query_projection = nn.Dense( + self.hidden_dim // self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim // (self.num_heads * self.num_groups), + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim // (self.num_heads * self.num_groups), + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) self.rope = RotaryPositionalEncoding(self.hidden_dim // self.num_groups) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + def __call__( + self, inputs: jnp.ndarray, context: jnp.ndarray, mask: jnp.ndarray = None + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - + # Break query into groups and transpose to (num_groups, batch_size, seq_len, dims) # This will allow vmapping over the groups for parallelization - grouped_query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_groups, -1)) + grouped_query = jnp.reshape( + query, (query.shape[0], query.shape[1], self.num_groups, -1) + ) grouped_query = jnp.repeat(grouped_query, self.num_heads, axis=-1) grouped_query = jnp.transpose(grouped_query, (2, 0, 1, 3)) # Repeat the key and values key = jnp.repeat(key, self.num_heads, axis=-1) value = jnp.repeat(value, self.num_heads, axis=-1) - vectorized_process_group = jax.vmap(self.process_group, in_axes=(0, None, None, None)) + vectorized_process_group = jax.vmap( + self.process_group, in_axes=(0, None, None, None) + ) results = vectorized_process_group(grouped_query, key, value, mask) # Merge the groups back together context_vectors = jnp.concatenate(results[0], axis=-1) return self.output(context_vectors), results[1] - + def process_group(self, query, key, value, mask): query, key = self.rope(query, key) return self.attention_function(query, key, value, mask=mask) @@ -135,17 +151,27 @@ def attention_function(self, query, key, value, mask=None): head_dim = query.shape[-1] // self.num_heads dim_key = key.shape[-1] - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights @@ -161,6 +187,7 @@ class GemmaMLP(nn.Module): to before the gating mechanism and the subsequent projection back to the original dimensionality. """ + hidden_size: int intermediate_size: int @@ -195,6 +222,7 @@ class GemmaDecoderBlock(nn.Module): causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. __call__(x, training): Processes the input tensor through the Gemma decoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int @@ -202,34 +230,35 @@ class GemmaDecoderBlock(nn.Module): num_groups: int def setup(self): - self.attention = GroupedRotaryMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads, - num_groups=self.num_groups) + self.attention = GroupedRotaryMultiHeadAttention( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_groups=self.num_groups, + ) self.feed_forward = GemmaMLP(self.feedforward_dim, self.hidden_dim) self.norm1 = nn.RMSNorm() self.norm2 = nn.RMSNorm() self.dropout1 = nn.Dropout(self.dropout) self.dropout2 = nn.Dropout(self.dropout) - def causal_mask(self, - batch_size: int, - destination_dim: int, - source_dim: int) -> jnp.ndarray: - + def causal_mask( + self, batch_size: int, destination_dim: int, source_dim: int + ) -> jnp.ndarray: + # Create index tensors for the source and destination dimensions idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, :, :] - return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + return jnp.broadcast_to( + mask, (batch_size, self.num_heads, destination_dim, source_dim) + ) + + def __call__(self, x: jnp.ndarray, training: bool = False) -> tuple: - def __call__(self, - x: jnp.ndarray, - training: bool = False) -> tuple: - mask = self.causal_mask(x.shape[0], x.shape[1], x.shape[1]) x = self.norm1(x) @@ -244,7 +273,7 @@ def __call__(self, return x, jnp.array(attention) - + class GemmaDecoder(nn.Module): """ Implements the decoder component of the LLaMA2 model. @@ -265,6 +294,7 @@ class GemmaDecoder(nn.Module): setup(): Initializes the components of the LLaMA2 decoder. __call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 decoder. """ + num_layers: int hidden_dim: int num_heads: int @@ -275,23 +305,27 @@ class GemmaDecoder(nn.Module): embed_dim: float def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [GemmaDecoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.num_groups) for _ in range(self.num_layers)] - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + GemmaDecoderBlock( + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.num_groups, + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) for layer in self.layers: @@ -302,7 +336,6 @@ def __call__(self, x = self.outputs(x) return x, jnp.array(attention_maps) - class Gemma(nn.Module): @@ -328,8 +361,8 @@ class Gemma(nn.Module): generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. - LlaMA is built upon the transformer architecture, incorporating enhancements inspired by recent advancements in the field of large language models. - These improvements are drawn from various sources, such as GPT-3, PaLM, and GPT-Neo. Notable modifications include the adoption of pre-normalization for enhanced training stability, + LlaMA is built upon the transformer architecture, incorporating enhancements inspired by recent advancements in the field of large language models. + These improvements are drawn from various sources, such as GPT-3, PaLM, and GPT-Neo. Notable modifications include the adoption of pre-normalization for enhanced training stability, employing the RMSNorm normalization function. Additionally, the ReLU non-linearity is replaced with the SwiGLU activation function, which is a variant of the GLU activation function. Absolute positional embeddings are replaced with rotary positional embeddings (RoPE), implemented at each layer of the network. For specific hyper-parameter details, refer to Table 2 in the document. @@ -353,9 +386,9 @@ class Gemma(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -386,15 +419,15 @@ class Gemma(nn.Module): params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data trainer = GemmaDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -402,15 +435,16 @@ class Gemma(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + num_layers: int num_heads: int num_groups: int @@ -424,32 +458,31 @@ class Gemma(nn.Module): end_token: int def setup(self): - - self.decoder = GemmaDecoder(self.num_layers, - self.hidden_dim, - self.num_heads, - self.num_groups, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> jnp.ndarray: - - - return self.decoder(x=x, - training=training, - drop_last_layer=drop_last_layer)[0] - - - def generate(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + + self.decoder = GemmaDecoder( + self.num_layers, + self.hidden_dim, + self.num_heads, + self.num_groups, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + ) + + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> jnp.ndarray: + + return self.decoder(x=x, training=training, drop_last_layer=drop_last_layer)[0] + + def generate( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> Tuple[jnp.ndarray]: + if x is not None: assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" @@ -466,25 +499,34 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> jnp.ndarray: + batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) for i in range(self.max_length): @@ -497,10 +539,14 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break @@ -508,11 +554,10 @@ def generate_batch(self, return output_sequences - class GemmaDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -531,12 +576,15 @@ class GemmaDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -544,51 +592,61 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(GemmaDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(GemmaDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + GemmaDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + GemmaDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, - jnp.ones(input_shape, dtype=jnp.int32))['params'] + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -597,35 +655,36 @@ def train(self, batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, inputs, rngs={"dropout": jax.random.PRNGKey(2)} + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: @@ -636,16 +695,16 @@ def evaluate(self, loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/gpt.py b/nanodl/__src/models/gpt.py index 5cd8b08..7f76ea4 100644 --- a/nanodl/__src/models/gpt.py +++ b/nanodl/__src/models/gpt.py @@ -1,11 +1,12 @@ -import jax import time +from typing import Any, Iterable, Optional, Tuple + import flax -import optax -import jax.numpy as jnp import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import List, Tuple, Any, Optional, Iterable class SelfMultiHeadAttention(nn.Module): @@ -23,30 +24,33 @@ class SelfMultiHeadAttention(nn.Module): __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ - hidden_dim : int - num_heads : int + + hidden_dim: int + num_heads: int def setup(self): # Stack all weight matrices together for efficiency - self.projection = nn.Dense(3*self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + self.projection = nn.Dense( + 3 * self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__(self, inputs: jnp.ndarray, mask: jnp.ndarray = None) -> tuple: projections = self.projection(inputs) query, key, value = jnp.array_split(projections, 3, axis=-1) - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = query.shape[1] context_length = key.shape[1] @@ -54,19 +58,29 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) - - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) + + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - + class PositionWiseFFN(nn.Module): """ @@ -82,13 +96,18 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + num_hiddens: int num_outputs: int def setup(self): - self.dense1 = nn.Dense(self.num_hiddens, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.num_hiddens, kernel_init=nn.initializers.xavier_uniform() + ) self.activation = GEGLU(self.num_hiddens) - self.dense2 = nn.Dense(self.num_outputs, kernel_init=nn.initializers.xavier_uniform()) + self.dense2 = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(self.activation(self.dense1(X))) @@ -102,18 +121,20 @@ class GEGLU(nn.Module): Args: output_dim (int): Output dimension of the GLU layer. """ + output_dim: int def setup(self): - self.dense = nn.Dense(self.output_dim * 2, - kernel_init=nn.initializers.xavier_uniform()) + self.dense = nn.Dense( + self.output_dim * 2, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, inputs): x = self.dense(inputs) x, gate = x[..., : self.output_dim], x[..., self.output_dim :] tanh_res = jnp.tanh(gate * 0.7978845608 * (1 + 0.044715 * (gate**2))) return x * 0.5 * gate * (1 + tanh_res) - + class GPT3Block(nn.Module): """ @@ -132,14 +153,19 @@ class GPT3Block(nn.Module): causal_mask(batch_size, destination_dim, source_dim): Creates a causal mask to ensure that predictions for a position can depend only on known outputs at earlier positions. __call__(x, mask=None, training=False): Defines the computation performed at every call of the GPT-3 block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention1 = SelfMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) - self.attention2 = SelfMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) + self.attention1 = SelfMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) + self.attention2 = SelfMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.norm1 = nn.LayerNorm() self.norm2 = nn.LayerNorm() @@ -148,25 +174,25 @@ def setup(self): self.dropout2 = nn.Dropout(self.dropout) self.dropout3 = nn.Dropout(self.dropout) - def causal_mask(self, - batch_size: int, - destination_dim: int, - source_dim: int) -> jnp.ndarray: - + def causal_mask( + self, batch_size: int, destination_dim: int, source_dim: int + ) -> jnp.ndarray: + idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, :, :] - return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + return jnp.broadcast_to( + mask, (batch_size, self.num_heads, destination_dim, source_dim) + ) + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - mask = self.causal_mask(x.shape[0], x.shape[1], x.shape[1]) x = self.norm1(x) @@ -185,7 +211,7 @@ def __call__(self, x += attended_x return x, jnp.array(attention1), jnp.array(attention2) - + class GPT3Decoder(nn.Module): """ @@ -206,6 +232,7 @@ class GPT3Decoder(nn.Module): setup(): Initializes the components of the GPT-3 decoder including the embedding layer, GPT-3 blocks, and the output layer. __call__(x, mask, training, drop_last_layer): Processes the input tensor through the GPT-3 decoder, generating predictions for the next token in the sequence. """ + num_layers: int hidden_dim: int num_heads: int @@ -215,23 +242,27 @@ class GPT3Decoder(nn.Module): embed_dim: int def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [GPT3Block(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) for _ in range(self.num_layers)] - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + GPT3Block( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False, - drop_last_layer: bool = False) -> tuple: - + + def __call__( + self, + x: jnp.ndarray, + mask: jnp.ndarray = None, + training: bool = False, + drop_last_layer: bool = False, + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -242,9 +273,9 @@ def __call__(self, if not drop_last_layer: x = self.outputs(x) - + return x, jnp.array(attention_maps), jnp.array(cross_attention_maps) - + class GPT3(nn.Module): """ @@ -270,16 +301,16 @@ class GPT3(nn.Module): generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively, starting from an optional initial sequence. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. - The motivation behind GPT is to create a highly effective language model that can understand and generate human-like text. + The motivation behind GPT is to create a highly effective language model that can understand and generate human-like text. Its architecture is a decoder-only transformer trained on next-token prediction and generates autoregressively duting training. - It's pre-trained on a massive amount of text data, which allows it to learn the patterns and nuances of language. - GPT's strength lies in its ability to generalize this knowledge to perform a wide range of natural language processing tasks without the need for extensive task-specific training, + It's pre-trained on a massive amount of text data, which allows it to learn the patterns and nuances of language. + GPT's strength lies in its ability to generalize this knowledge to perform a wide range of natural language processing tasks without the need for extensive task-specific training, making it a powerful tool for various applications in language understanding and generation. GPT3 uses prelayer normalisation opposed to classic transformers Note: - This implementation excludes the modified initialization which accounts for the accumulation on the residual path with model depth. - Such an intialisation involves scaling the weights of residual layers at initialization by a factor of 1/√N where N is the number of residual layers. + This implementation excludes the modified initialization which accounts for the accumulation on the residual path with model depth. + Such an intialisation involves scaling the weights of residual layers at initialization by a factor of 1/√N where N is the number of residual layers. Rather we use 'Xavier' initialization (https://proceedings.mlr.press/v9/glorot10a.html) for the weights and 'zeros' for the biases. @@ -303,9 +334,9 @@ class GPT3(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -335,15 +366,15 @@ class GPT3(nn.Module): params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -351,15 +382,16 @@ class GPT3(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + num_layers: int hidden_dim: int num_heads: int @@ -372,33 +404,32 @@ class GPT3(nn.Module): end_token: int def setup(self): - self.decoder = GPT3Decoder(self.num_layers, - self.embed_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim) - - - def __call__(self, - x: jnp.ndarray, - training: bool = True, - drop_last_layer: bool = False) -> jnp.ndarray: - - return self.decoder(x=x, - training=training, - drop_last_layer=drop_last_layer)[0] - - - def generate(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + self.decoder = GPT3Decoder( + self.num_layers, + self.embed_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + ) + + def __call__( + self, x: jnp.ndarray, training: bool = True, drop_last_layer: bool = False + ) -> jnp.ndarray: + + return self.decoder(x=x, training=training, drop_last_layer=drop_last_layer)[0] + + def generate( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> Tuple[jnp.ndarray]: + if x is not None: assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" - + decoder_input = x if x is not None else jnp.array([[self.start_token]]) output_sequence = [] @@ -411,25 +442,34 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> jnp.ndarray: + batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) for i in range(self.max_length): @@ -442,16 +482,20 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break return output_sequences - + class SparseMixtureOfExperts(nn.Module): """ @@ -488,25 +532,27 @@ class SparseMixtureOfExperts(nn.Module): tensor has the same batch and sequence length dimensions as the input tensor, but the last dimension is equal to num_outputs. """ + num_hiddens: int num_outputs: int num_experts: int - top_k: int # Number of top experts to use each pass + top_k: int # Number of top experts to use each pass def setup(self): - self.experts = [PositionWiseFFN(self.num_hiddens, - self.num_outputs) for _ in range(self.num_experts) - ] - self.gate = nn.Dense(self.num_experts, - kernel_init=nn.initializers.xavier_uniform() - ) - self.dense_final = nn.Dense(self.num_outputs, - kernel_init=nn.initializers.xavier_uniform() - ) + self.experts = [ + PositionWiseFFN(self.num_hiddens, self.num_outputs) + for _ in range(self.num_experts) + ] + self.gate = nn.Dense( + self.num_experts, kernel_init=nn.initializers.xavier_uniform() + ) + self.dense_final = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: gating_weights = nn.softmax(self.gate(X), axis=-1) - top_k_indices = jnp.argsort(gating_weights, axis=-1)[..., -self.top_k:] + top_k_indices = jnp.argsort(gating_weights, axis=-1)[..., -self.top_k :] expert_outputs = jnp.stack([expert(X) for expert in self.experts], axis=2) # Select only the top K expert outputs @@ -516,10 +562,14 @@ def __call__(self, X: jnp.ndarray) -> jnp.ndarray: top_k_expert_outputs = expert_outputs[batch_indices, seq_indices, top_k_indices] # Compute the gating weights for the selected top K experts - top_k_gating_weights = jnp.take_along_axis(gating_weights, top_k_indices, axis=-1) - mixed_expert_output = jnp.sum(top_k_gating_weights[..., None] * top_k_expert_outputs, axis=2) + top_k_gating_weights = jnp.take_along_axis( + gating_weights, top_k_indices, axis=-1 + ) + mixed_expert_output = jnp.sum( + top_k_gating_weights[..., None] * top_k_expert_outputs, axis=2 + ) return self.dense_final(mixed_expert_output) - + class GPT4Block(nn.Module): """ @@ -540,6 +590,7 @@ class GPT4Block(nn.Module): causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. __call__(x, mask, training): Processes the input tensor through the GPT-4 block. """ + hidden_dim: int num_heads: int feedforward_dim: int @@ -548,12 +599,15 @@ class GPT4Block(nn.Module): top_k: int def setup(self): - self.attention1 = SelfMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) - self.attention2 = SelfMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) - self.feed_forward = SparseMixtureOfExperts(self.feedforward_dim, - self.hidden_dim, - self.num_experts, - self.top_k) + self.attention1 = SelfMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) + self.attention2 = SelfMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) + self.feed_forward = SparseMixtureOfExperts( + self.feedforward_dim, self.hidden_dim, self.num_experts, self.top_k + ) self.norm1 = nn.LayerNorm() self.norm2 = nn.LayerNorm() self.norm3 = nn.LayerNorm() @@ -561,26 +615,26 @@ def setup(self): self.dropout2 = nn.Dropout(self.dropout) self.dropout3 = nn.Dropout(self.dropout) - def causal_mask(self, - batch_size: int, - destination_dim: int, - source_dim: int) -> jnp.ndarray: - + def causal_mask( + self, batch_size: int, destination_dim: int, source_dim: int + ) -> jnp.ndarray: + # Create index tensors for the source and destination dimensions idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, :, :] - return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + return jnp.broadcast_to( + mask, (batch_size, self.num_heads, destination_dim, source_dim) + ) + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - mask = self.causal_mask(x.shape[0], x.shape[1], x.shape[1]) x = self.norm1(x) @@ -599,7 +653,7 @@ def __call__(self, x += attended_x return x, jnp.array(attention1), jnp.array(attention2) - + class GPT4Decoder(nn.Module): """ @@ -622,6 +676,7 @@ class GPT4Decoder(nn.Module): setup(): Initializes the components of the GPT-4 decoder. __call__(x, mask, training, drop_last_layer): Processes the input tensor through the GPT-4 decoder. """ + num_layers: int hidden_dim: int num_heads: int @@ -633,25 +688,32 @@ class GPT4Decoder(nn.Module): top_k: int def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [GPT4Block(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.num_experts, - self.top_k) for _ in range(self.num_layers)] - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + GPT4Block( + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.num_experts, + self.top_k, + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False, - drop_last_layer: bool = False) -> tuple: - + + def __call__( + self, + x: jnp.ndarray, + mask: jnp.ndarray = None, + training: bool = False, + drop_last_layer: bool = False, + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -662,9 +724,9 @@ def __call__(self, if not drop_last_layer: x = self.outputs(x) - + return x, jnp.array(attention_maps), jnp.array(cross_attention_maps) - + class GPT4(nn.Module): """ @@ -712,9 +774,9 @@ class GPT4(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -744,15 +806,15 @@ class GPT4(nn.Module): params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -760,15 +822,16 @@ class GPT4(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + num_layers: int hidden_dim: int num_heads: int @@ -783,35 +846,34 @@ class GPT4(nn.Module): top_k: int = 2 def setup(self): - self.decoder = GPT4Decoder(self.num_layers, - self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim, - self.num_experts, - self.top_k) - - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> jnp.ndarray: - - return self.decoder(x=x, - training=training, - drop_last_layer=drop_last_layer)[0] - - - def generate(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + self.decoder = GPT4Decoder( + self.num_layers, + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + self.num_experts, + self.top_k, + ) + + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> jnp.ndarray: + + return self.decoder(x=x, training=training, drop_last_layer=drop_last_layer)[0] + + def generate( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> Tuple[jnp.ndarray]: + if x is not None: assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" - + decoder_input = x if x is not None else jnp.array([[self.start_token]]) output_sequence = [] @@ -825,25 +887,34 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> jnp.ndarray: + batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) for i in range(self.max_length): @@ -856,21 +927,25 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break return output_sequences - + class GPTDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -889,12 +964,15 @@ class GPTDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -902,50 +980,61 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(GPTDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(GPTDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + GPTDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + GPTDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))['params'] + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: + + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -954,35 +1043,36 @@ def train(self, batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, inputs, rngs={"dropout": jax.random.PRNGKey(2)} + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: @@ -993,16 +1083,16 @@ def evaluate(self, loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/ijepa.py b/nanodl/__src/models/ijepa.py index 6b8b5ed..345c218 100644 --- a/nanodl/__src/models/ijepa.py +++ b/nanodl/__src/models/ijepa.py @@ -1,12 +1,13 @@ -import jax -import flax import time +from typing import Any, Iterable, List, Optional, Tuple + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp import optax from einops import rearrange -import jax.numpy as jnp -import flax.linen as nn from flax.training import train_state -from typing import List, Tuple, Any, Optional, Dict, Iterable class PatchEmbedding(nn.Module): @@ -21,25 +22,28 @@ class PatchEmbedding(nn.Module): embed_dim (int): Dimension of the embeddings for the patches. """ - image_size:int - patch_size:int - embed_dim:int - num_channels:int + + image_size: int + patch_size: int + embed_dim: int + num_channels: int def setup(self): self.num_patches = (self.image_size**2) // (self.patch_size**2) - + # Use sliding window from conv layer implementation to avoid "splitting" the image. self.proj = nn.Conv( - features=self.embed_dim, + features=self.embed_dim, kernel_size=(self.patch_size, self.patch_size), - strides=self.patch_size, + strides=self.patch_size, padding="VALID", ) - def __call__(self, x:jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: x = self.proj(x) - x = jnp.reshape(x, (x.shape[0], -1, self.embed_dim)) # (batch_size, num_patches, embed_dim) + x = jnp.reshape( + x, (x.shape[0], -1, self.embed_dim) + ) # (batch_size, num_patches, embed_dim) return x @@ -54,16 +58,16 @@ class PositionalEmbedding(nn.Module): num_patches (int): Number of patches in an image which is dependent on the patch size. """ - embed_dim:int - num_patches:int + + embed_dim: int + num_patches: int def setup(self): self.embedding = nn.Embed( - num_embeddings=self.num_patches, - features=self.embed_dim + num_embeddings=self.num_patches, features=self.embed_dim ) - def __call__(self, x:jnp.ndarray) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: positions = jnp.arange(x.shape[1])[jnp.newaxis, :].repeat(x.shape[0], axis=0) embed = self.embedding(positions) x = x + embed @@ -81,17 +85,23 @@ class MultiHeadedAttention(nn.Module): num_heads (int): Number of attention heads. """ - embed_dim:int - num_heads:int + + embed_dim: int + num_heads: int def setup(self): - self.attn_proj = nn.Dense(3 * self.embed_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - self.out_proj = nn.Dense(self.embed_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - def __call__(self, x:jnp.ndarray) -> jnp.ndarray: + self.attn_proj = nn.Dense( + 3 * self.embed_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.out_proj = nn.Dense( + self.embed_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: qkv = self.attn_proj(x) query, key, value = jnp.array_split(qkv, 3, axis=-1) query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_heads, -1)) @@ -100,13 +110,15 @@ def __call__(self, x:jnp.ndarray) -> jnp.ndarray: query = jnp.permute_dims(query, (0, 2, 1, 3)) key = jnp.permute_dims(key, (0, 2, 1, 3)) value = jnp.permute_dims(value, (0, 2, 1, 3)) - attn_weights = jnp.matmul(query, key.transpose(0, 1, 3, 2)) / (self.embed_dim **.5) + attn_weights = jnp.matmul(query, key.transpose(0, 1, 3, 2)) / ( + self.embed_dim**0.5 + ) attn_weights = nn.softmax(attn_weights, -1) attn = jnp.matmul(attn_weights, value) - attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim)) + attn = jnp.reshape(attn, (query.shape[0], -1, self.embed_dim)) attn = self.out_proj(attn) return attn, attn_weights - + class TransformerEncoderBlock(nn.Module): """ @@ -121,20 +133,23 @@ class TransformerEncoderBlock(nn.Module): dropout_p (float): Dropout rate. """ - embed_dim:int - num_heads:int - feed_forward_dim:int - dropout_p:float + + embed_dim: int + num_heads: int + feed_forward_dim: int + dropout_p: float def setup(self): self.norm1 = nn.LayerNorm() self.norm2 = nn.LayerNorm() - self.ff = nn.Sequential([ - nn.Dense(self.feed_forward_dim), - lambda x: nn.gelu(x), - nn.Dense(self.embed_dim) - ]) + self.ff = nn.Sequential( + [ + nn.Dense(self.feed_forward_dim), + lambda x: nn.gelu(x), + nn.Dense(self.embed_dim), + ] + ) self.attn = MultiHeadedAttention( embed_dim=self.embed_dim, @@ -143,7 +158,7 @@ def setup(self): self.dropout = nn.Dropout(self.dropout_p) - def __call__(self, x:jnp.ndarray, training:bool) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray: x_, attn_weights = self.attn(self.norm1(x)) x = x + x_ x = self.dropout(x, deterministic=not training) @@ -164,13 +179,14 @@ class TransformerEncoder(nn.Module): embed_dim (int): Dimensionality of inputs and outputs. num_layers (int): Number of encoder blocks. feed_forward_dim (int): Dimension of the feed-forward network. - + """ - dropout:float - num_heads:int - embed_dim:int - num_layers:int - feed_forward_dim:int + + dropout: float + num_heads: int + embed_dim: int + num_layers: int + feed_forward_dim: int def setup(self): self.layers = [ @@ -178,24 +194,24 @@ def setup(self): embed_dim=self.embed_dim, num_heads=self.num_heads, feed_forward_dim=self.feed_forward_dim, - dropout_p=self.dropout - ) for _ in range(self.num_layers) + dropout_p=self.dropout, + ) + for _ in range(self.num_layers) ] - - def __call__(self, x:jnp.ndarray, training:bool) -> jnp.ndarray: + def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray: attn_maps = [] for layer in self.layers: x, attn_weights = layer(x, training=training) attn_maps.append(attn_weights) return x, jnp.array(attn_maps) - + class IJEPA(nn.Module): """ Implements the IJEPA architecture for non-generative self-supervised learning. Ref: "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" by Mahmoud Assran et al. - + This module consists of three ViTs / Transformer Encoders; A context and target encoder and an embedding predictor. The embedding predictor is trained to predict the outputs of the target encoder given the outputs of the context encoder. @@ -218,8 +234,8 @@ class IJEPA(nn.Module): # Dummy data parameters batch_size = 8 - embed_dim = 256 - patch_size = 16 + embed_dim = 256 + patch_size = 16 image_size = 256 M=4 @@ -235,9 +251,9 @@ class IJEPA(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) data_sampler = IJEPADataSampler( @@ -265,10 +281,10 @@ class IJEPA(nn.Module): params = model.init(rngs, dummy_inputs, dummy_context_masks, dummy_target_masks)['params'] outputs, _ = model.apply( - {'params': params}, - dummy_inputs, - dummy_context_mask, - dummy_target_mask, + {'params': params}, + dummy_inputs, + dummy_context_mask, + dummy_target_mask, rngs=rngs ) @@ -281,56 +297,56 @@ class IJEPA(nn.Module): """ image_size: int - num_channels:int - patch_size:int - embed_dim:int - num_heads:int - num_layers:int - dropout_p:float - predictor_num_heads:int - predictor_bottleneck:int - predictor_num_layers:int - share_patch_embedding:bool = True + num_channels: int + patch_size: int + embed_dim: int + num_heads: int + num_layers: int + dropout_p: float + predictor_num_heads: int + predictor_bottleneck: int + predictor_num_layers: int + share_patch_embedding: bool = True def setup(self): self.num_patches = (self.image_size**2) // (self.patch_size**2) - self.feed_forward_dim = self.embed_dim*4 - self.predictor_feed_forward_dim = self.predictor_bottleneck*4 + self.feed_forward_dim = self.embed_dim * 4 + self.predictor_feed_forward_dim = self.predictor_bottleneck * 4 - create_patch_embedding = lambda:PatchEmbedding( + create_patch_embedding = lambda: PatchEmbedding( image_size=self.image_size, patch_size=self.patch_size, embed_dim=self.embed_dim, num_channels=self.num_channels, - ) + ) - if self.share_patch_embedding: # We could have the context and target decoder share the patch emebddings + if ( + self.share_patch_embedding + ): # We could have the context and target decoder share the patch emebddings patch_embedding = create_patch_embedding() self.patch_embedding = { "context": patch_embedding, - "target": patch_embedding + "target": patch_embedding, } - else: # Or have them learn different patch embeddings + else: # Or have them learn different patch embeddings self.patch_embedding = { "context": create_patch_embedding(), - "target": create_patch_embedding() + "target": create_patch_embedding(), } # because the positional embedding is constant, doesn't need to be shared. self.positional_embedding = PositionalEmbedding( - embed_dim=self.embed_dim, - num_patches=self.num_patches + embed_dim=self.embed_dim, num_patches=self.num_patches ) - self.context_encoder = TransformerEncoder( dropout=self.dropout_p, num_heads=self.num_heads, embed_dim=self.embed_dim, num_layers=self.num_layers, - feed_forward_dim=self.feed_forward_dim + feed_forward_dim=self.feed_forward_dim, ) self.target_encoder = TransformerEncoder( @@ -338,8 +354,7 @@ def setup(self): num_heads=self.num_heads, embed_dim=self.embed_dim, num_layers=self.num_layers, - feed_forward_dim=self.feed_forward_dim - + feed_forward_dim=self.feed_forward_dim, ) self.embedding_predictor = TransformerEncoder( @@ -347,13 +362,22 @@ def setup(self): num_heads=self.predictor_num_heads, embed_dim=self.predictor_bottleneck, num_layers=self.predictor_num_layers, - feed_forward_dim=self.predictor_feed_forward_dim + feed_forward_dim=self.predictor_feed_forward_dim, ) self.to_predictor_embed = nn.Dense(self.predictor_bottleneck) self.to_encoder_embed = nn.Dense(self.embed_dim) - def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndarray, training:bool=False) -> Tuple[List[Tuple[jnp.ndarray, jnp.ndarray]], List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]]: + def __call__( + self, + x: jnp.ndarray, + context_mask: jnp.ndarray, + target_mask: jnp.ndarray, + training: bool = False, + ) -> Tuple[ + List[Tuple[jnp.ndarray, jnp.ndarray]], + List[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]], + ]: x_context = self.patch_embedding["context"](x) x_context = self.positional_embedding(x_context) x_target = self.patch_embedding["target"](x) @@ -363,38 +387,49 @@ def __call__(self, x:jnp.ndarray, context_mask:jnp.ndarray, target_mask:jnp.ndar attn_weights = [] for m in range(context_mask.shape[1]): - context, context_attn_weights = self.context_encoder(x_context, training=training) - context = context * jnp.expand_dims(context_mask[:, m], -1) # (N, num_patches, E) - target, target_attn_weights = self.target_encoder(x_target, training=training) - target = target * jnp.expand_dims(target_mask[:, m], -1) # (N, num_patches, E) + context, context_attn_weights = self.context_encoder( + x_context, training=training + ) + context = context * jnp.expand_dims( + context_mask[:, m], -1 + ) # (N, num_patches, E) + target, target_attn_weights = self.target_encoder( + x_target, training=training + ) + target = target * jnp.expand_dims( + target_mask[:, m], -1 + ) # (N, num_patches, E) predicted_embeddings, embed_attn_weights = self.embedding_predictor( - self.to_predictor_embed(context), - training=training + self.to_predictor_embed(context), training=training ) predicted_embeddings = self.to_encoder_embed(predicted_embeddings) - predicted_embeddings = predicted_embeddings * jnp.expand_dims(target_mask[:, m], -1) + predicted_embeddings = predicted_embeddings * jnp.expand_dims( + target_mask[:, m], -1 + ) outputs.append((predicted_embeddings, target)) - attn_weights.append((context_attn_weights, target_attn_weights, embed_attn_weights)) + attn_weights.append( + (context_attn_weights, target_attn_weights, embed_attn_weights) + ) return (outputs, attn_weights) class IJEPADataSampler: - to_scale:Any = lambda self, x, a, b: (b-a) * x + a - random_key:int = 0 + to_scale: Any = lambda self, x, a, b: (b - a) * x + a + random_key: int = 0 random_key = jax.random.PRNGKey(random_key) def __init__( self, - image_size:int = 256, - patch_size:int = 16, - M:int = 4, - context_scale_range:tuple = (.85, 1), - target_scale_range:tuple = (.15, .2), - target_aspect_ratio_range:tuple = (.75, 1.5), - ): + image_size: int = 256, + patch_size: int = 16, + M: int = 4, + context_scale_range: tuple = (0.85, 1), + target_scale_range: tuple = (0.15, 0.2), + target_aspect_ratio_range: tuple = (0.75, 1.5), + ): self.image_size = image_size self.patch_size = patch_size @@ -402,7 +437,7 @@ def __init__( self.context_scale_range = context_scale_range self.target_scale_range = target_scale_range self.target_aspect_ratio_range = target_aspect_ratio_range - + self.h = image_size // patch_size self.w = image_size // patch_size @@ -410,25 +445,25 @@ def sample_target_block_scale(self) -> Tuple[int, int]: scale = self.to_scale( jax.random.uniform(self.random_key), self.target_scale_range[0], - self.target_scale_range[1] + self.target_scale_range[1], ) context_scale = self.to_scale( jax.random.uniform(self.random_key), self.context_scale_range[0], - self.context_scale_range[1] + self.context_scale_range[1], ) aspect_ratio = self.to_scale( jax.random.uniform(self.random_key), self.target_aspect_ratio_range[0], - self.target_aspect_ratio_range[1] + self.target_aspect_ratio_range[1], ) target_mask_scale = int(self.h * self.w * scale * context_scale) - target_h = int((target_mask_scale * aspect_ratio)**.5) - target_w = int((target_mask_scale / aspect_ratio)**.5) + target_h = int((target_mask_scale * aspect_ratio) ** 0.5) + target_w = int((target_mask_scale / aspect_ratio) ** 0.5) if target_h >= self.h: target_h -= target_h - self.h - 1 @@ -436,31 +471,46 @@ def sample_target_block_scale(self) -> Tuple[int, int]: target_w -= target_w - self.w - 1 return target_h, target_w - - def sample_context_target_blocks(self, h:int, w:int) -> Tuple[jnp.ndarray, jnp.ndarray]: - context_mask = jnp.ones((self.M, self.image_size, self.image_size)) + + def sample_context_target_blocks( + self, h: int, w: int + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + context_mask = jnp.ones((self.M, self.image_size, self.image_size)) target_mask = jnp.zeros((self.M, self.image_size, self.image_size)) for m in range(self.M): top = jax.random.randint(self.random_key, (), 0, self.h - h) left = jax.random.randint(self.random_key, (), 0, self.w - w) - context_mask = context_mask.at[m, - top*self.patch_size: (top+h)*self.patch_size, - left*self.patch_size: (left+w)*self.patch_size].set(0) - - target_mask = target_mask.at[m, - top*self.patch_size: (top+h)*self.patch_size, - left*self.patch_size: (left+w)*self.patch_size].set(1) - - context_mask = rearrange(context_mask, "m (p1 h) (p2 w) -> m (h w) (p1 p2)", p1=self.patch_size, p2=self.patch_size) - target_mask = rearrange(target_mask, "m (p1 h) (p2 w) -> m (h w) (p1 p2)", p1=self.patch_size, p2=self.patch_size) + context_mask = context_mask.at[ + m, + top * self.patch_size : (top + h) * self.patch_size, + left * self.patch_size : (left + w) * self.patch_size, + ].set(0) + + target_mask = target_mask.at[ + m, + top * self.patch_size : (top + h) * self.patch_size, + left * self.patch_size : (left + w) * self.patch_size, + ].set(1) + + context_mask = rearrange( + context_mask, + "m (p1 h) (p2 w) -> m (h w) (p1 p2)", + p1=self.patch_size, + p2=self.patch_size, + ) + target_mask = rearrange( + target_mask, + "m (p1 h) (p2 w) -> m (h w) (p1 p2)", + p1=self.patch_size, + p2=self.patch_size, + ) context_mask = jnp.any(context_mask == 1, axis=-1) target_mask = jnp.any(target_mask == 0, axis=-1) - + return context_mask, target_mask - def __call__(self) -> Tuple[jnp.ndarray, jnp.ndarray]: h, w = self.sample_target_block_scale() @@ -471,14 +521,15 @@ def __call__(self) -> Tuple[jnp.ndarray, jnp.ndarray]: class IJEPADataParallelTrainer: def __init__( - self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename:str, - data_sampler: IJEPADataSampler, - learning_rate:float = 1e-4, - params_path: Optional[str] = None) -> None: - + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + data_sampler: IJEPADataSampler, + learning_rate: float = 1e-4, + params_path: Optional[str] = None, + ) -> None: + self.model = model self.params = None self.params_path = params_path @@ -487,125 +538,165 @@ def __init__( self.weights_filename = weights_filename self.data_sampler = data_sampler self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(IJEPADataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(IJEPADataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + IJEPADataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + IJEPADataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') + print(f"Number of accelerators: {self.num_devices}") + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} context_mask, target_mask = self.data_sampler() context_mask = jnp.repeat(context_mask[jnp.newaxis], input_shape[0], axis=0) target_mask = jnp.repeat(target_mask[jnp.newaxis], input_shape[0], axis=0) - params = self.model.init(rngs, jnp.ones(input_shape), context_mask, target_mask)['params'] + params = self.model.init( + rngs, jnp.ones(input_shape), context_mask, target_mask + )["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) @staticmethod - def train_step(state: Any, - images: jnp.ndarray, - context_mask: jnp.ndarray, - target_mask: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, + images: jnp.ndarray, + context_mask: jnp.ndarray, + target_mask: jnp.ndarray, + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): outputs, _ = state.apply_fn( - {'params': params}, + {"params": params}, images, context_mask=context_mask, target_mask=target_mask, training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))} + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, ) - losses = jnp.array([ - jnp.mean(jnp.square(outputs[i][0] - outputs[i][1])) for i in range(len(outputs)) - ]) + losses = jnp.array( + [ + jnp.mean(jnp.square(outputs[i][0] - outputs[i][1])) + for i in range(len(outputs)) + ] + ) return jnp.mean(losses) - + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 for images in train_loader: images = images[0] if len(images) == 1 else images - + batch_size = images.shape[0] batch_size_per_device = batch_size // self.num_devices - images = images.reshape((self.num_devices, batch_size_per_device, images.shape[1], images.shape[2], images.shape[3])) + images = images.reshape( + ( + self.num_devices, + batch_size_per_device, + images.shape[1], + images.shape[2], + images.shape[3], + ) + ) context_mask, target_mask = self.data_sampler() context_mask = jnp.repeat(context_mask[jnp.newaxis], batch_size, axis=0) target_mask = jnp.repeat(target_mask[jnp.newaxis], batch_size, axis=0) - context_mask = context_mask.reshape((self.num_devices, batch_size_per_device, context_mask.shape[1], context_mask.shape[2])) - target_mask = target_mask.reshape((self.num_devices, batch_size_per_device, target_mask.shape[1], target_mask.shape[2])) + context_mask = context_mask.reshape( + ( + self.num_devices, + batch_size_per_device, + context_mask.shape[1], + context_mask.shape[2], + ) + ) + target_mask = target_mask.reshape( + ( + self.num_devices, + batch_size_per_device, + target_mask.shape[1], + target_mask.shape[2], + ) + ) - self.state, loss = self.train_step(state=self.state, - images=images, - context_mask=context_mask, - target_mask=target_mask + self.state, loss = self.train_step( + state=self.state, + images=images, + context_mask=context_mask, + target_mask=target_mask, ) - + total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - images: jnp.ndarray, - context_mask: jnp.ndarray, - target_mask: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: + def evaluation_step( + state: Any, + images: jnp.ndarray, + context_mask: jnp.ndarray, + target_mask: jnp.ndarray, + ) -> Tuple[Any, jnp.ndarray]: outputs, _ = state.apply_fn( - {'params': state.params}, + {"params": state.params}, images, context_mask=context_mask, target_mask=target_mask, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))} + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, ) - losses = jnp.array([ - jnp.mean(jnp.square(outputs[i][0] - outputs[i][1])) for i in range(len(outputs)) - ]) + losses = jnp.array( + [ + jnp.mean(jnp.square(outputs[i][0] - outputs[i][1])) + for i in range(len(outputs)) + ] + ) return jnp.mean(losses) + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - total_loss = 0.0 count = 0 for images in test_loader: @@ -613,46 +704,58 @@ def evaluate(self, batch_size = images.shape[0] batch_size_per_device = batch_size // self.num_devices - images = images.reshape((self.num_devices, batch_size_per_device, images.shape[1], images.shape[2], images.shape[3])) + images = images.reshape( + ( + self.num_devices, + batch_size_per_device, + images.shape[1], + images.shape[2], + images.shape[3], + ) + ) context_mask, target_mask = self.data_sampler() context_mask = jnp.repeat(context_mask[jnp.newaxis], batch_size, axis=0) target_mask = jnp.repeat(target_mask[jnp.newaxis], batch_size, axis=0) - context_mask = context_mask.reshape(( - self.num_devices, - batch_size_per_device, - context_mask.shape[1], - context_mask.shape[2] - )) + context_mask = context_mask.reshape( + ( + self.num_devices, + batch_size_per_device, + context_mask.shape[1], + context_mask.shape[2], + ) + ) - target_mask = target_mask.reshape(( - self.num_devices, - batch_size_per_device, - target_mask.shape[1], - target_mask.shape[2] - )) + target_mask = target_mask.reshape( + ( + self.num_devices, + batch_size_per_device, + target_mask.shape[1], + target_mask.shape[2], + ) + ) loss = self.evaluation_step( - state=self.state, - images=images, + state=self.state, + images=images, context_mask=context_mask, - target_mask=target_mask + target_mask=target_mask, ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/lamda.py b/nanodl/__src/models/lamda.py index aebb566..18e59ef 100644 --- a/nanodl/__src/models/lamda.py +++ b/nanodl/__src/models/lamda.py @@ -1,11 +1,12 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import List, Tuple, Any, Optional, Dict, Iterable class RelativeMultiHeadAttention(nn.Module): @@ -23,53 +24,70 @@ class RelativeMultiHeadAttention(nn.Module): __call__(inputs, context, mask, clip): Processes the input and context tensors through the relative multi-head attention mechanism. attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors, incorporating relative position information. """ - hidden_dim : int - num_heads : int + + hidden_dim: int + num_heads: int def setup(self): # Because the Query is determined from a context, project separately - self.query_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.key_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None, - clip: int = 3) -> tuple: + self.query_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__( + self, + inputs: jnp.ndarray, + context: jnp.ndarray, + mask: jnp.ndarray = None, + clip: int = 3, + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - query_relative_positions = jnp.expand_dims(jnp.arange(query.shape[2]), axis=0) + query_relative_positions = jnp.expand_dims(jnp.arange(query.shape[2]), axis=0) query_relative_positions -= jnp.expand_dims(jnp.arange(query.shape[1]), axis=1) - query_relative_positions = jnp.where(query_relative_positions < clip, query_relative_positions, clip) - query_relative_positions = jnp.where(query_relative_positions > -clip, query_relative_positions, -clip) + query_relative_positions = jnp.where( + query_relative_positions < clip, query_relative_positions, clip + ) + query_relative_positions = jnp.where( + query_relative_positions > -clip, query_relative_positions, -clip + ) query += query_relative_positions - value_relative_positions = jnp.expand_dims(jnp.arange(value.shape[2]), axis=0) + value_relative_positions = jnp.expand_dims(jnp.arange(value.shape[2]), axis=0) value_relative_positions -= jnp.expand_dims(jnp.arange(value.shape[1]), axis=1) - value_relative_positions = jnp.where(value_relative_positions < clip, value_relative_positions, clip) - value_relative_positions = jnp.where(value_relative_positions > -clip, value_relative_positions, -clip) + value_relative_positions = jnp.where( + value_relative_positions < clip, value_relative_positions, clip + ) + value_relative_positions = jnp.where( + value_relative_positions > -clip, value_relative_positions, -clip + ) value += value_relative_positions - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = query.shape[1] context_length = key.shape[1] @@ -77,19 +95,29 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) - - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) + + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - + class PositionWiseFFN(nn.Module): """ @@ -105,17 +133,22 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + num_hiddens: int num_outputs: int def setup(self): - self.dense1 = nn.Dense(self.num_hiddens, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.num_hiddens, kernel_init=nn.initializers.xavier_uniform() + ) self.activation = GEGLU(self.num_hiddens) - self.dense2 = nn.Dense(self.num_outputs, kernel_init=nn.initializers.xavier_uniform()) + self.dense2 = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(self.activation(self.dense1(X))) - + class AddNorm(nn.Module): """ @@ -129,17 +162,16 @@ class AddNorm(nn.Module): Methods: __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ + dropout: int @nn.compact - def __call__(self, - X: jnp.ndarray, - Y: jnp.ndarray, - training=False) -> jnp.ndarray: - + def __call__(self, X: jnp.ndarray, Y: jnp.ndarray, training=False) -> jnp.ndarray: + return nn.LayerNorm()( - nn.Dropout(self.dropout)(Y, deterministic=not training) + X) - + nn.Dropout(self.dropout)(Y, deterministic=not training) + X + ) + class GEGLU(nn.Module): """ @@ -149,18 +181,20 @@ class GEGLU(nn.Module): Args: output_dim (int): Output dimension of the GLU layer. """ + output_dim: int def setup(self): - self.dense = nn.Dense(self.output_dim * 2, - kernel_init=nn.initializers.xavier_uniform()) + self.dense = nn.Dense( + self.output_dim * 2, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, inputs): x = self.dense(inputs) x, gate = x[..., : self.output_dim], x[..., self.output_dim :] tanh_res = jnp.tanh(gate * 0.7978845608 * (1 + 0.044715 * (gate**2))) return x * 0.5 * gate * (1 + tanh_res) - + class LaMDABlock(nn.Module): """ @@ -179,39 +213,44 @@ class LaMDABlock(nn.Module): causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. __call__(x, mask, training): Processes the input tensor through the LaMDA block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention1 = RelativeMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) - self.attention2 = RelativeMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) + self.attention1 = RelativeMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) + self.attention2 = RelativeMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) self.add_norm3 = AddNorm(self.dropout) - def causal_mask(self, - batch_size: int, - destination_dim: int, - source_dim: int) -> jnp.ndarray: - + def causal_mask( + self, batch_size: int, destination_dim: int, source_dim: int + ) -> jnp.ndarray: + # Create index tensors for the source and destination dimensions idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, :, :] - return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + return jnp.broadcast_to( + mask, (batch_size, self.num_heads, destination_dim, source_dim) + ) + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - mask = self.causal_mask(x.shape[0], x.shape[1], x.shape[1]) attended_x, attention1 = self.attention1(x, x, mask=mask) @@ -224,7 +263,7 @@ def __call__(self, x = self.add_norm3(x, linear_output, training) return x, jnp.array(attention1), jnp.array(attention2) - + class LaMDADecoder(nn.Module): """ @@ -245,6 +284,7 @@ class LaMDADecoder(nn.Module): setup(): Initializes the components of the LaMDA decoder. __call__(x, mask, training, drop_last_layer): Processes the input tensor through the LaMDA decoder. """ + num_layers: int hidden_dim: int num_heads: int @@ -254,23 +294,27 @@ class LaMDADecoder(nn.Module): embed_dim: float def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [LaMDABlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) for _ in range(self.num_layers)] - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + LaMDABlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False, - drop_last_layer: bool = False) -> tuple: - + + def __call__( + self, + x: jnp.ndarray, + mask: jnp.ndarray = None, + training: bool = False, + drop_last_layer: bool = False, + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -278,12 +322,12 @@ def __call__(self, x, attention, cross_attention = layer(x, mask=mask, training=training) attention_maps.append(attention) cross_attention_maps.append(cross_attention) - + if not drop_last_layer: x = self.outputs(x) - + return x, jnp.array(attention_maps), jnp.array(cross_attention_maps) - + class LaMDA(nn.Module): """ @@ -309,15 +353,15 @@ class LaMDA(nn.Module): generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. - LaMBDA, which stands for "Language Model for Dialogue Applications," is a deep learning model developed by Google. - Its primary motivation lies in addressing the limitations of existing conversational AI models, such as GPT-3, - by explicitly targeting dialogue applications. LaMBDA's architecture is designed to excel in multi-turn conversations, - offering improvements in several key aspects. It incorporates features like context windowing, which enables it to remember and track information over longer dialogues, - and provides better control over generating detailed responses. LaMBDA also introduces a more controllable prompt engineering mechanism, - allowing users to instruct the model more precisely for various dialogue tasks. Overall, LaMBDA represents a significant step forward in the development of conversational AI models, + LaMBDA, which stands for "Language Model for Dialogue Applications," is a deep learning model developed by Google. + Its primary motivation lies in addressing the limitations of existing conversational AI models, such as GPT-3, + by explicitly targeting dialogue applications. LaMBDA's architecture is designed to excel in multi-turn conversations, + offering improvements in several key aspects. It incorporates features like context windowing, which enables it to remember and track information over longer dialogues, + and provides better control over generating detailed responses. LaMBDA also introduces a more controllable prompt engineering mechanism, + allowing users to instruct the model more precisely for various dialogue tasks. Overall, LaMBDA represents a significant step forward in the development of conversational AI models, offering enhanced performance and usability in real-world dialogue applications. - Note: + Note: This is the architecture for LaMDA itself for now, the system is a lot more complex. At inference, LaMDA makes use of a single model to perform multiple tasks. it generates potential responses, which are then filtered for safety, grounded on an external knowledge source, and re-ranked to find the highest-quality response. @@ -341,9 +385,9 @@ class LaMDA(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -373,15 +417,15 @@ class LaMDA(nn.Module): params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data trainer = LaMDADataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -389,15 +433,16 @@ class LaMDA(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + num_layers: int num_heads: int hidden_dim: int @@ -410,32 +455,32 @@ class LaMDA(nn.Module): end_token: int def setup(self): - self.decoder = LaMDADecoder(self.num_layers, - self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> jnp.ndarray: - - return self.decoder(x=x, - training=training, - drop_last_layer=drop_last_layer)[0] - - - def generate(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + self.decoder = LaMDADecoder( + self.num_layers, + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + ) + + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> jnp.ndarray: + + return self.decoder(x=x, training=training, drop_last_layer=drop_last_layer)[0] + + def generate( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> Tuple[jnp.ndarray]: + if x is not None: assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" - + decoder_input = x if x is not None else jnp.array([[self.start_token]]) output_sequence = [] @@ -449,25 +494,34 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> jnp.ndarray: + batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) for i in range(self.max_length): @@ -480,21 +534,25 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break return output_sequences - + class LaMDADataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -513,12 +571,15 @@ class LaMDADataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -526,50 +587,61 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(LaMDADataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(LaMDADataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + LaMDADataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + LaMDADataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))['params'] + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -578,35 +650,36 @@ def train(self, batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, inputs, rngs={"dropout": jax.random.PRNGKey(2)} + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: @@ -617,16 +690,16 @@ def evaluate(self, loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/llama.py b/nanodl/__src/models/llama.py index 57d8323..e24e4fb 100644 --- a/nanodl/__src/models/llama.py +++ b/nanodl/__src/models/llama.py @@ -1,14 +1,15 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import Tuple, Any, Optional, Iterable -class RotaryPositionalEncoding(): +class RotaryPositionalEncoding: """ Implements rotary positional encoding (RoPE) for transformers, enhancing their ability to capture sequence order. @@ -23,10 +24,13 @@ class RotaryPositionalEncoding(): apply_rotary_pos_emb(x, cos, sin): Applies the rotary positional encoding to the input embeddings. __call__(q, k): Applies rotary positional encoding to query and key tensors in attention mechanisms. """ + def __init__(self, dim_model: int): super().__init__() self.dim_model = dim_model - inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model)) + inv_freq = 1.0 / ( + 10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model) + ) self.inv_freq = inv_freq self._seq_len_cached = None self._cos_cached = None @@ -55,12 +59,14 @@ def apply_rotary_pos_emb(self, x, cos, sin): return (x * cos) + (self.rotate_half(x) * sin) def __call__(self, q, k): - self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + k, seq_dimension=-2 + ) return ( self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)[0], self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)[0], ) - + class GroupedRotaryMultiHeadAttention(nn.Module): """ @@ -79,53 +85,62 @@ class GroupedRotaryMultiHeadAttention(nn.Module): process_group(query, key, value, mask): Processes a single group of heads through rotary positional encoding and attention. attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads - num_groups : int # Number of groups to split the heads into + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads + num_groups: int # Number of groups to split the heads into def setup(self): - self.query_projection = nn.Dense(self.hidden_dim // self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros, - ) - self.key_projection = nn.Dense(self.hidden_dim // (self.num_heads * self.num_groups), - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim // (self.num_heads * self.num_groups), - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) + self.query_projection = nn.Dense( + self.hidden_dim // self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim // (self.num_heads * self.num_groups), + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim // (self.num_heads * self.num_groups), + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) self.rope = RotaryPositionalEncoding(self.hidden_dim // self.num_groups) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + def __call__( + self, inputs: jnp.ndarray, context: jnp.ndarray, mask: jnp.ndarray = None + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - + # Break query into groups and transpose to (num_groups, batch_size, seq_len, dims) # This will allow vmapping over the groups for parallelization - grouped_query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_groups, -1)) + grouped_query = jnp.reshape( + query, (query.shape[0], query.shape[1], self.num_groups, -1) + ) grouped_query = jnp.repeat(grouped_query, self.num_heads, axis=-1) grouped_query = jnp.transpose(grouped_query, (2, 0, 1, 3)) # Repeat the key and values key = jnp.repeat(key, self.num_heads, axis=-1) value = jnp.repeat(value, self.num_heads, axis=-1) - vectorized_process_group = jax.vmap(self.process_group, in_axes=(0, None, None, None)) + vectorized_process_group = jax.vmap( + self.process_group, in_axes=(0, None, None, None) + ) results = vectorized_process_group(grouped_query, key, value, mask) # Merge the groups back together context_vectors = jnp.concatenate(results[0], axis=-1) return self.output(context_vectors), results[1] - + def process_group(self, query, key, value, mask): query, key = self.rope(query, key) return self.attention_function(query, key, value, mask=mask) @@ -136,19 +151,29 @@ def attention_function(self, query, key, value, mask=None): head_dim = query.shape[-1] // self.num_heads dim_key = key.shape[-1] - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - + class PositionWiseFFN(nn.Module): """ @@ -164,21 +189,26 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + hidden_dim: int dim: int def setup(self): - self.dense1 = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.hidden_dim, kernel_init=nn.initializers.xavier_uniform() + ) self.dense2 = nn.Dense(self.dim, kernel_init=nn.initializers.xavier_uniform()) - self.dense3 = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform()) + self.dense3 = nn.Dense( + self.hidden_dim, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(nn.silu(self.dense1(X) * self.dense3(X))) - - -class LlaMA2DecoderBlock(nn.Module): + + +class Llama3DecoderBlock(nn.Module): """ - Implements a decoder block for the LLaMA2 model, incorporating grouped rotary positional embeddings. + Implements a decoder block for the Llama3 model, incorporating grouped rotary positional embeddings. This block is designed to enhance the model's ability to understand and generate text by using grouped rotary positional embeddings for more nuanced positional encoding, alongside traditional transformer mechanisms like self-attention and feed-forward layers. @@ -190,10 +220,11 @@ class LlaMA2DecoderBlock(nn.Module): num_groups (int): Number of groups for the grouped rotary positional embeddings. Methods: - setup(): Initializes the components of the LLaMA2 decoder block. + setup(): Initializes the components of the Llama3 decoder block. causal_mask(batch_size, destination_dim, source_dim): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism. - __call__(x, training): Processes the input tensor through the LLaMA2 decoder block. + __call__(x, training): Processes the input tensor through the Llama3 decoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int @@ -201,12 +232,16 @@ class LlaMA2DecoderBlock(nn.Module): num_groups: int def setup(self): - self.attention1 = GroupedRotaryMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads, - num_groups=self.num_groups) - self.attention2 = GroupedRotaryMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads, - num_groups=self.num_groups) + self.attention1 = GroupedRotaryMultiHeadAttention( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_groups=self.num_groups, + ) + self.attention2 = GroupedRotaryMultiHeadAttention( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_groups=self.num_groups, + ) self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.norm1 = nn.RMSNorm() self.norm2 = nn.RMSNorm() @@ -215,25 +250,24 @@ def setup(self): self.dropout2 = nn.Dropout(self.dropout) self.dropout3 = nn.Dropout(self.dropout) - def causal_mask(self, - batch_size: int, - destination_dim: int, - source_dim: int) -> jnp.ndarray: - + def causal_mask( + self, batch_size: int, destination_dim: int, source_dim: int + ) -> jnp.ndarray: + # Create index tensors for the source and destination dimensions idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, :, :] - return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + return jnp.broadcast_to( + mask, (batch_size, self.num_heads, destination_dim, source_dim) + ) + + def __call__(self, x: jnp.ndarray, training: bool = False) -> tuple: - def __call__(self, - x: jnp.ndarray, - training: bool = False) -> tuple: - mask = self.causal_mask(x.shape[0], x.shape[1], x.shape[1]) x = self.norm1(x) @@ -253,15 +287,15 @@ def __call__(self, return x, jnp.array(attention1), jnp.array(attention2) - -class LlaMA2Decoder(nn.Module): + +class Llama3Decoder(nn.Module): """ - Implements the decoder component of the LLaMA2 model. + Implements the decoder component of the Llama3 model. - The decoder is composed of multiple LLaMA2DecoderBlocks, processing sequences of tokens to generate text. It includes an embedding layer to convert tokens into vectors and an output layer to predict the next token in the sequence. + The decoder is composed of multiple Llama3DecoderBlocks, processing sequences of tokens to generate text. It includes an embedding layer to convert tokens into vectors and an output layer to predict the next token in the sequence. Attributes: - num_layers (int): Number of LLaMA2DecoderBlocks in the decoder. + num_layers (int): Number of Llama3DecoderBlocks in the decoder. hidden_dim (int): Dimensionality of the input and output features for the blocks. num_heads (int): Number of attention heads in each block. num_groups (int): Number of groups for the grouped rotary positional embeddings in each block. @@ -271,9 +305,10 @@ class LlaMA2Decoder(nn.Module): embed_dim (float): Dimensionality of the token embeddings. Methods: - setup(): Initializes the components of the LLaMA2 decoder. - __call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 decoder. + setup(): Initializes the components of the Llama3 decoder. + __call__(x, training, drop_last_layer): Processes the input tensor through the Llama3 decoder. """ + num_layers: int hidden_dim: int num_heads: int @@ -284,23 +319,27 @@ class LlaMA2Decoder(nn.Module): embed_dim: float def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [LlaMA2DecoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.num_groups) for _ in range(self.num_layers)] - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + Llama3DecoderBlock( + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.num_groups, + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -313,17 +352,16 @@ def __call__(self, x = self.outputs(x) return x, jnp.array(attention_maps), jnp.array(cross_attention_maps) - -class LlaMA2(nn.Module): +class Llama3(nn.Module): """ - Implements the LLaMA2 model for text generation, featuring grouped rotary positional embeddings. + Implements the Llama3 model for text generation, featuring grouped rotary positional embeddings. - LLaMA2 enhances the transformer architecture by incorporating grouped rotary positional embeddings within its decoder blocks, aiming to improve the model's understanding of positional context and its ability to generate coherent and contextually relevant text. + Llama3 enhances the transformer architecture by incorporating grouped rotary positional embeddings within its decoder blocks, aiming to improve the model's understanding of positional context and its ability to generate coherent and contextually relevant text. Attributes: - num_layers (int): Number of layers (blocks) in the LLaMA2 model. + num_layers (int): Number of layers (blocks) in the Llama3 model. num_heads (int): Number of attention heads in each block. num_groups (int): Number of groups for the grouped rotary positional embeddings in each block. hidden_dim (int): Dimensionality of the input and output features for the blocks. @@ -336,13 +374,13 @@ class LlaMA2(nn.Module): end_token (int): Token that indicates the end of a generated sequence. Methods: - setup(): Initializes the LLaMA2 model including the decoder component. - __call__(x, training, drop_last_layer): Processes the input tensor through the LLaMA2 model. + setup(): Initializes the Llama3 model including the decoder component. + __call__(x, training, drop_last_layer): Processes the input tensor through the Llama3 model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. - LlaMA is built upon the transformer architecture, incorporating enhancements inspired by recent advancements in the field of large language models. - These improvements are drawn from various sources, such as GPT-3, PaLM, and GPT-Neo. Notable modifications include the adoption of pre-normalization for enhanced training stability, + Llama is built upon the transformer architecture, incorporating enhancements inspired by recent advancements in the field of large language models. + These improvements are drawn from various sources, such as GPT-3, PaLM, and GPT-Neo. Notable modifications include the adoption of pre-normalization for enhanced training stability, employing the RMSNorm normalization function. Additionally, the ReLU non-linearity is replaced with the SwiGLU activation function, which is a variant of the GLU activation function. Absolute positional embeddings are replaced with rotary positional embeddings (RoPE), implemented at each layer of the network. For specific hyper-parameter details, refer to Table 2 in the document. @@ -351,7 +389,7 @@ class LlaMA2(nn.Module): import jax import jax.numpy as jnp from nanodl import ArrayDataset, DataLoader - from nanodl import LlaMA2, LlaMADataParallelTrainer + from nanodl import Llama3, LlamaDataParallelTrainer # Generate dummy data batch_size = 8 @@ -366,9 +404,9 @@ class LlaMA2(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -393,21 +431,21 @@ class LlaMA2(nn.Module): } # Initialize model - model = LlaMA2(**hyperparams) + model = Llama3(**hyperparams) rngs = jax.random.PRNGKey(0) rngs, dropout_rng = jax.random.split(rngs) params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data - trainer = LlaMADataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer = LlamaDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -415,15 +453,16 @@ class LlaMA2(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + num_layers: int num_heads: int num_groups: int @@ -437,32 +476,31 @@ class LlaMA2(nn.Module): end_token: int def setup(self): - - self.decoder = LlaMA2Decoder(self.num_layers, - self.hidden_dim, - self.num_heads, - self.num_groups, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> jnp.ndarray: - - - return self.decoder(x=x, - training=training, - drop_last_layer=drop_last_layer)[0] - - - def generate(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + + self.decoder = Llama3Decoder( + self.num_layers, + self.hidden_dim, + self.num_heads, + self.num_groups, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + ) + + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> jnp.ndarray: + + return self.decoder(x=x, training=training, drop_last_layer=drop_last_layer)[0] + + def generate( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> Tuple[jnp.ndarray]: + if x is not None: assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" @@ -479,25 +517,34 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> jnp.ndarray: + batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) for i in range(self.max_length): @@ -510,10 +557,14 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break @@ -521,11 +572,10 @@ def generate_batch(self, return output_sequences - -class LlaMADataParallelTrainer: +class LlamaDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -544,12 +594,15 @@ class LlaMADataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -557,51 +610,61 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(LlaMADataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(LlaMADataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + LlamaDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + LlamaDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, - jnp.ones(input_shape, dtype=jnp.int32))['params'] + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -610,35 +673,36 @@ def train(self, batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, inputs, rngs={"dropout": jax.random.PRNGKey(2)} + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: @@ -649,16 +713,16 @@ def evaluate(self, loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/mistral.py b/nanodl/__src/models/mistral.py index b04bb46..352d9df 100644 --- a/nanodl/__src/models/mistral.py +++ b/nanodl/__src/models/mistral.py @@ -1,14 +1,15 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import Tuple, Any, Optional, Iterable -class RotaryPositionalEncoding(): +class RotaryPositionalEncoding: """ Implements rotary positional encoding (RoPE) for transformers, enhancing their ability to capture sequence order. @@ -23,10 +24,13 @@ class RotaryPositionalEncoding(): apply_rotary_pos_emb(x, cos, sin): Applies the rotary positional encoding to the input embeddings. __call__(q, k): Applies rotary positional encoding to query and key tensors in attention mechanisms. """ + def __init__(self, dim_model: int): super().__init__() self.dim_model = dim_model - inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model)) + inv_freq = 1.0 / ( + 10000 ** (jnp.arange(0, dim_model, 2, dtype=jnp.float32) / dim_model) + ) self.inv_freq = inv_freq self._seq_len_cached = None self._cos_cached = None @@ -55,12 +59,14 @@ def apply_rotary_pos_emb(self, x, cos, sin): return (x * cos) + (self.rotate_half(x) * sin) def __call__(self, q, k): - self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2) + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + k, seq_dimension=-2 + ) return ( self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)[0], self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)[0], ) - + class GroupedRotaryShiftedWindowMultiHeadAttention(nn.Module): """ @@ -83,64 +89,72 @@ class GroupedRotaryShiftedWindowMultiHeadAttention(nn.Module): attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors within each window. causal_mask(shape): Generates a causal mask to ensure autoregressive properties in the self-attention mechanism within windows. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads - num_groups : int # Number of groups to split the heads into + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads + num_groups: int # Number of groups to split the heads into window_size: int shift_size: int def setup(self): - self.query_projection = nn.Dense(self.hidden_dim // self.num_heads, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros, - ) - self.key_projection = nn.Dense(self.hidden_dim // (self.num_heads * self.num_groups), - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim // (self.num_heads * self.num_groups), - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) + self.query_projection = nn.Dense( + self.hidden_dim // self.num_heads, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim // (self.num_heads * self.num_groups), + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim // (self.num_heads * self.num_groups), + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) self.rope = RotaryPositionalEncoding(self.hidden_dim // self.num_groups) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray) -> tuple: + def __call__( + self, inputs: jnp.ndarray, context: jnp.ndarray, mask: jnp.ndarray + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - + # Break query into groups and transpose to (num_groups, batch_size, seq_len, dims) # This will allow vmapping over the groups for parallelization - grouped_query = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_groups, -1)) + grouped_query = jnp.reshape( + query, (query.shape[0], query.shape[1], self.num_groups, -1) + ) grouped_query = jnp.repeat(grouped_query, self.num_heads, axis=-1) grouped_query = jnp.transpose(grouped_query, (2, 0, 1, 3)) # Repeat the key and values key = jnp.repeat(key, self.num_heads, axis=-1) value = jnp.repeat(value, self.num_heads, axis=-1) - vectorized_process_group = jax.vmap(self.process_group, in_axes=(0, None, None, None)) + vectorized_process_group = jax.vmap( + self.process_group, in_axes=(0, None, None, None) + ) results = vectorized_process_group(grouped_query, key, value, mask) # Merge the groups back together context_vectors = jnp.concatenate(results[0], axis=-1) return self.output(context_vectors), results[1] - + def process_group(self, query, key, value, mask): query, key = self.rope(query, key) query_windows = self.window_partition(query) key_windows = self.window_partition(key) value_windows = self.window_partition(value) - attention_windows, attention_maps = self.attention_function(query_windows, - key_windows, - value_windows, - mask) + attention_windows, attention_maps = self.attention_function( + query_windows, key_windows, value_windows, mask + ) attention_windows = jnp.roll(attention_windows, -self.shift_size, axis=1) merged = attention_windows.transpose((1, 0, 2, 3)) @@ -148,9 +162,15 @@ def process_group(self, query, key, value, mask): def window_partition(self, x): B, N, C = x.shape - assert N % self.window_size == 0, "Sequence length must be a multiple of the window size" - windows = jnp.reshape(x, (B, -1, self.window_size, C)) # (batch_size, num_windows, window_size, dim) - windows = windows.transpose((1, 0, 2, 3)) # Transpose to (num_windows, batch_size, window_size, dim) + assert ( + N % self.window_size == 0 + ), "Sequence length must be a multiple of the window size" + windows = jnp.reshape( + x, (B, -1, self.window_size, C) + ) # (batch_size, num_windows, window_size, dim) + windows = windows.transpose( + (1, 0, 2, 3) + ) # Transpose to (num_windows, batch_size, window_size, dim) return windows def attention_function(self, query, key, value, mask): @@ -160,11 +180,21 @@ def attention_function(self, query, key, value, mask): dim_key = key.shape[-1] # Split keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], query.shape[1], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], key.shape[1], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], value.shape[1], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, + (query.shape[0], query.shape[1], self.num_heads, input_length, head_dim), + ) + key_heads = jnp.reshape( + key, (key.shape[0], key.shape[1], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, + (value.shape[0], value.shape[1], self.num_heads, context_length, head_dim), + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 2, 4, 3)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 2, 4, 3) + ) / jnp.sqrt(dim_key) if mask is not None: mask = self.causal_mask(attention_scores.shape) @@ -173,23 +203,27 @@ def attention_function(self, query, key, value, mask): attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) attended_values = attended_values.transpose(0, 1, 3, 2, 4) - attended_values = jnp.reshape(attended_values, (query.shape[0], query.shape[1], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, + (query.shape[0], query.shape[1], input_length, query.shape[-1]), + ) return attended_values, attention_weights - - def causal_mask(self, - shape: Tuple[int, ...]) -> jnp.ndarray: - + + def causal_mask(self, shape: Tuple[int, ...]) -> jnp.ndarray: + # Create index tensors for the source and destination dimensions source_dim, destination_dim = shape[-2], shape[-2] idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, None, :, :] - return jnp.broadcast_to(mask, (shape[0], shape[1], shape[2], destination_dim, source_dim)) - + return jnp.broadcast_to( + mask, (shape[0], shape[1], shape[2], destination_dim, source_dim) + ) + class PositionWiseFFN(nn.Module): """ @@ -205,18 +239,23 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + hidden_dim: int dim: int def setup(self): - self.dense1 = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.hidden_dim, kernel_init=nn.initializers.xavier_uniform() + ) self.dense2 = nn.Dense(self.dim, kernel_init=nn.initializers.xavier_uniform()) - self.dense3 = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform()) + self.dense3 = nn.Dense( + self.hidden_dim, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(nn.silu(self.dense1(X) * self.dense3(X))) - - + + class MistralDecoderBlock(nn.Module): """ Implements a decoder block for the Mistral model, incorporating grouped rotary shifted window multi-head attention. @@ -236,6 +275,7 @@ class MistralDecoderBlock(nn.Module): setup(): Initializes the components of the Mistral decoder block. __call__(x, training): Processes the input tensor through the Mistral decoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int @@ -245,18 +285,22 @@ class MistralDecoderBlock(nn.Module): shift_size: int def setup(self): - self.attention1 = GroupedRotaryShiftedWindowMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads, - num_groups=self.num_groups, - window_size=self.window_size, - shift_size=self.shift_size) - - self.attention2 = GroupedRotaryShiftedWindowMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads, - num_groups=self.num_groups, - window_size=self.window_size, - shift_size=self.shift_size) - + self.attention1 = GroupedRotaryShiftedWindowMultiHeadAttention( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_groups=self.num_groups, + window_size=self.window_size, + shift_size=self.shift_size, + ) + + self.attention2 = GroupedRotaryShiftedWindowMultiHeadAttention( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_groups=self.num_groups, + window_size=self.window_size, + shift_size=self.shift_size, + ) + self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.norm1 = nn.RMSNorm() self.norm2 = nn.RMSNorm() @@ -265,10 +309,8 @@ def setup(self): self.dropout2 = nn.Dropout(self.dropout) self.dropout3 = nn.Dropout(self.dropout) - def __call__(self, - x: jnp.ndarray, - training: bool = False) -> tuple: - + def __call__(self, x: jnp.ndarray, training: bool = False) -> tuple: + x = self.norm1(x) attended_x, attention1 = self.attention1(x, x, mask=True) x = self.dropout1(x, deterministic=not training) @@ -309,6 +351,7 @@ class MistralDecoder(nn.Module): setup(): Initializes the components of the Mistral decoder. __call__(x, training, drop_last_layer): Processes the input tensor through the Mistral decoder. """ + num_layers: int hidden_dim: int num_heads: int @@ -321,25 +364,29 @@ class MistralDecoder(nn.Module): shift_size: int def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [MistralDecoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.num_groups, - self.window_size, - self.shift_size) for _ in range(self.num_layers)] - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + MistralDecoderBlock( + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.num_groups, + self.window_size, + self.shift_size, + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -352,7 +399,6 @@ def __call__(self, x = self.outputs(x) return x, jnp.array(attention_maps), jnp.array(cross_attention_maps) - class Mistral(nn.Module): @@ -381,13 +427,13 @@ class Mistral(nn.Module): __call__(x, training, drop_last_layer): Processes the input tensor through the Mistral model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. - - Mistral 7B is a large language model (LLM) designed for enhanced efficiency and performance. It utilizes Grouped-Query Attention (GQA) to achieve quicker inference times. - It incorporates Sliding Window Attention (SWA), enabling it to efficiently process sequences of any length while minimizing the cost of inference. + + Mistral 7B is a large language model (LLM) designed for enhanced efficiency and performance. It utilizes Grouped-Query Attention (GQA) to achieve quicker inference times. + It incorporates Sliding Window Attention (SWA), enabling it to efficiently process sequences of any length while minimizing the cost of inference. Additionally, the ReLU non-linearity is replaced with the SwiGLU activation function, which is a variant of the GLU activation function. Absolute positional embeddings are replaced with rotary positional embeddings (RoPE), implemented at each layer of the network. For specific hyper-parameter details, refer to Table 2 in the document. - Mixtral is an architectural upgrade within Mistral. Leverages "Sparse Mixture-of-Experts" (MoE). Each layer has 8 expert groups, + Mixtral is an architectural upgrade within Mistral. Leverages "Sparse Mixture-of-Experts" (MoE). Each layer has 8 expert groups, but a "router network" selects only 2 relevant ones per token, reducing active calculations and boosting efficiency. Example usage: @@ -410,9 +456,9 @@ class Mistral(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -445,15 +491,15 @@ class Mistral(nn.Module): params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data trainer = MistralDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -461,15 +507,16 @@ class Mistral(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + num_layers: int num_heads: int num_groups: int @@ -485,29 +532,28 @@ class Mistral(nn.Module): shift_size: int def setup(self): - self.decoder = MistralDecoder(self.num_layers, - self.hidden_dim, - self.num_heads, - self.num_groups, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim, - self.window_size, - self.shift_size) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> jnp.ndarray: - - return self.decoder(x=x, - training=training, - drop_last_layer=drop_last_layer)[0] - + self.decoder = MistralDecoder( + self.num_layers, + self.hidden_dim, + self.num_heads, + self.num_groups, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + self.window_size, + self.shift_size, + ) + + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> jnp.ndarray: + + return self.decoder(x=x, training=training, drop_last_layer=drop_last_layer)[0] + def zero_pad(self, arr, max_length): - current_length = arr.shape[1] - num_zeros = max_length - current_length + current_length = arr.shape[1] + num_zeros = max_length - current_length if num_zeros > 0: zeros = jnp.zeros((arr.shape[0], num_zeros), dtype=arr.dtype) @@ -516,13 +562,14 @@ def zero_pad(self, arr, max_length): padded_array = arr return padded_array - - def generate(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + def generate( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> Tuple[jnp.ndarray]: + if x is not None: assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" @@ -531,7 +578,9 @@ def generate(self, # Autoregressive decoding loop for _ in range(self.max_length - 1): - decoder_output = self.decoder(self.zero_pad(decoder_input, self.max_length), training=False)[0] + decoder_output = self.decoder( + self.zero_pad(decoder_input, self.max_length), training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -539,29 +588,40 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: + def generate_batch( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> jnp.ndarray: batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) - for i in range(self.max_length-1): - decoder_output = self.decoder(self.zero_pad(decoder_input, self.max_length), training=False)[0] + for i in range(self.max_length - 1): + decoder_output = self.decoder( + self.zero_pad(decoder_input, self.max_length), training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -570,10 +630,14 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break @@ -616,34 +680,40 @@ class SparseMixtureOfExperts(nn.Module): tensor has the same batch and sequence length dimensions as the input tensor, but the last dimension is equal to num_outputs. """ + num_hiddens: int num_outputs: int num_experts: int = 8 top_k: int = 2 # Number of top experts to use def setup(self): - self.experts = [PositionWiseFFN(self.num_hiddens, - self.num_outputs) for _ in range(self.num_experts) - ] - self.gate = nn.Dense(self.num_experts, - kernel_init=nn.initializers.xavier_uniform() - ) - self.dense_final = nn.Dense(self.num_outputs, - kernel_init=nn.initializers.xavier_uniform() - ) + self.experts = [ + PositionWiseFFN(self.num_hiddens, self.num_outputs) + for _ in range(self.num_experts) + ] + self.gate = nn.Dense( + self.num_experts, kernel_init=nn.initializers.xavier_uniform() + ) + self.dense_final = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: gating_weights = nn.softmax(self.gate(X), axis=-1) - top_k_indices = jnp.argsort(gating_weights, axis=-1)[..., -self.top_k:] + top_k_indices = jnp.argsort(gating_weights, axis=-1)[..., -self.top_k :] expert_outputs = jnp.stack([expert(X) for expert in self.experts], axis=2) batch_size, seq_length, _ = X.shape batch_indices = jnp.arange(batch_size)[:, None, None] seq_indices = jnp.arange(seq_length)[None, :, None] top_k_expert_outputs = expert_outputs[batch_indices, seq_indices, top_k_indices] - top_k_gating_weights = jnp.take_along_axis(gating_weights, top_k_indices, axis=-1) - mixed_expert_output = jnp.sum(top_k_gating_weights[..., None] * top_k_expert_outputs, axis=2) + top_k_gating_weights = jnp.take_along_axis( + gating_weights, top_k_indices, axis=-1 + ) + mixed_expert_output = jnp.sum( + top_k_gating_weights[..., None] * top_k_expert_outputs, axis=2 + ) return self.dense_final(mixed_expert_output) - + class MixtralDecoderBlock(nn.Module): """ @@ -664,6 +734,7 @@ class MixtralDecoderBlock(nn.Module): setup(): Initializes the components of the Mixtral decoder block. __call__(x, training): Processes the input tensor through the Mixtral decoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int @@ -673,19 +744,25 @@ class MixtralDecoderBlock(nn.Module): shift_size: int def setup(self): - self.attention1 = GroupedRotaryShiftedWindowMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads, - num_groups=self.num_groups, - window_size=self.window_size, - shift_size=self.shift_size) - - self.attention2 = GroupedRotaryShiftedWindowMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads, - num_groups=self.num_groups, - window_size=self.window_size, - shift_size=self.shift_size) - - self.feed_forward = SparseMixtureOfExperts(self.feedforward_dim, self.hidden_dim) + self.attention1 = GroupedRotaryShiftedWindowMultiHeadAttention( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_groups=self.num_groups, + window_size=self.window_size, + shift_size=self.shift_size, + ) + + self.attention2 = GroupedRotaryShiftedWindowMultiHeadAttention( + hidden_dim=self.hidden_dim, + num_heads=self.num_heads, + num_groups=self.num_groups, + window_size=self.window_size, + shift_size=self.shift_size, + ) + + self.feed_forward = SparseMixtureOfExperts( + self.feedforward_dim, self.hidden_dim + ) self.norm1 = nn.RMSNorm() self.norm2 = nn.RMSNorm() self.norm3 = nn.RMSNorm() @@ -693,10 +770,8 @@ def setup(self): self.dropout2 = nn.Dropout(self.dropout) self.dropout3 = nn.Dropout(self.dropout) - def __call__(self, - x: jnp.ndarray, - training: bool = False) -> tuple: - + def __call__(self, x: jnp.ndarray, training: bool = False) -> tuple: + x = self.norm1(x) attended_x, attention1 = self.attention1(x, x, mask=True) x = self.dropout1(x, deterministic=not training) @@ -714,7 +789,7 @@ def __call__(self, return x, jnp.array(attention1), jnp.array(attention2) - + class MixtralDecoder(nn.Module): """ Implements the decoder component of the Mixtral model. @@ -737,6 +812,7 @@ class MixtralDecoder(nn.Module): setup(): Initializes the components of the Mixtral decoder. __call__(x, training, drop_last_layer): Processes the input tensor through the Mixtral decoder. """ + num_layers: int hidden_dim: int num_heads: int @@ -749,25 +825,29 @@ class MixtralDecoder(nn.Module): shift_size: int def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [MixtralDecoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.num_groups, - self.window_size, - self.shift_size) for _ in range(self.num_layers)] - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + MixtralDecoderBlock( + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.num_groups, + self.window_size, + self.shift_size, + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -780,7 +860,6 @@ def __call__(self, x = self.outputs(x) return x, jnp.array(attention_maps), jnp.array(cross_attention_maps) - class Mixtral(nn.Module): @@ -809,7 +888,7 @@ class Mixtral(nn.Module): __call__(x, training, drop_last_layer): Processes the input tensor through the Mixtral model. generate(x, temperature, deterministic): Generates a sequence of tokens autoregressively. generate_batch(x, temperature, deterministic): Generates sequences of tokens for a batch of initial sequences autoregressively. - + Example usage: ``` import jax @@ -830,9 +909,9 @@ class Mixtral(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -865,15 +944,15 @@ class Mixtral(nn.Module): params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data trainer = MistralDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -881,15 +960,16 @@ class Mixtral(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + num_layers: int num_heads: int num_groups: int @@ -905,26 +985,25 @@ class Mixtral(nn.Module): shift_size: int def setup(self): - self.decoder = MixtralDecoder(self.num_layers, - self.hidden_dim, - self.num_heads, - self.num_groups, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim, - self.window_size, - self.shift_size) - - def __call__(self, - x: jnp.ndarray, - training: bool = False, - drop_last_layer: bool = False) -> jnp.ndarray: - - return self.decoder(x=x, - training=training, - drop_last_layer=drop_last_layer)[0] - + self.decoder = MixtralDecoder( + self.num_layers, + self.hidden_dim, + self.num_heads, + self.num_groups, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + self.window_size, + self.shift_size, + ) + + def __call__( + self, x: jnp.ndarray, training: bool = False, drop_last_layer: bool = False + ) -> jnp.ndarray: + + return self.decoder(x=x, training=training, drop_last_layer=drop_last_layer)[0] + def zero_pad(self, arr, max_length): current_length = arr.shape[1] num_zeros = max_length - current_length @@ -936,13 +1015,14 @@ def zero_pad(self, arr, max_length): padded_array = arr return padded_array - - def generate(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + def generate( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> Tuple[jnp.ndarray]: + if x is not None: assert x.shape[0] == 1, "Batch size must be 1, else use generate_batch()" @@ -950,8 +1030,10 @@ def generate(self, output_sequence = [] # Autoregressive decoding loop - for _ in range(self.max_length-1): - decoder_output = self.decoder(self.zero_pad(decoder_input, self.max_length), training=False)[0] + for _ in range(self.max_length - 1): + decoder_output = self.decoder( + self.zero_pad(decoder_input, self.max_length), training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -959,29 +1041,43 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) - - if next_token.item() == self.end_token or len(output_sequence) == self.max_length: + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) + + if ( + next_token.item() == self.end_token + or len(output_sequence) == self.max_length + ): break return jnp.array(output_sequence) - - def generate_batch(self, - x: Optional[jnp.ndarray] = None, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, + x: Optional[jnp.ndarray] = None, + temperature: float = 1.0, + deterministic: bool = False, + ) -> jnp.ndarray: + batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) - for i in range(self.max_length-1): - decoder_output = self.decoder(self.zero_pad(decoder_input, self.max_length), training=False)[0] + for i in range(self.max_length - 1): + decoder_output = self.decoder( + self.zero_pad(decoder_input, self.max_length), training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -990,21 +1086,28 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) - - if jnp.all(next_token == self.end_token) or len(output_sequences) == self.max_length: + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) + + if ( + jnp.all(next_token == self.end_token) + or len(output_sequences) == self.max_length + ): break return output_sequences - - + + class MistralDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -1023,12 +1126,15 @@ class MistralDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -1036,51 +1142,61 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(MistralDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(MistralDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + MistralDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + MistralDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, - jnp.ones(input_shape, dtype=jnp.int32))['params'] + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: + + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -1089,35 +1205,36 @@ def train(self, batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, inputs, rngs={"dropout": jax.random.PRNGKey(2)} + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: @@ -1128,16 +1245,16 @@ def evaluate(self, loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/mixer.py b/nanodl/__src/models/mixer.py index e86914a..53544cc 100644 --- a/nanodl/__src/models/mixer.py +++ b/nanodl/__src/models/mixer.py @@ -1,11 +1,13 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import List, Tuple, Any, Optional, Dict, Iterable + class PatchEmbedding(nn.Module): """ @@ -21,18 +23,21 @@ class PatchEmbedding(nn.Module): __call__(x: jnp.ndarray): Extracts patches from the input images and applies patch embedding. extract_patches(images: jnp.ndarray): Extracts and flattens patches from input images. """ + patch_size: Tuple[int, int] - embed_dim: int + embed_dim: int @nn.compact def __call__(self, x): x = nn.Dense(self.embed_dim)(self.extract_patches(x)) - return x + nn.Embed(num_embeddings=x.shape[1], features=x.shape[2])(jnp.arange(x.shape[1])) + return x + nn.Embed(num_embeddings=x.shape[1], features=x.shape[2])( + jnp.arange(x.shape[1]) + ) def extract_patches(self, images: jnp.ndarray) -> jnp.ndarray: if len(images.shape) != 4: raise ValueError("Input images should have shape (batch_size, H, W, C)") - + batch_size, h, w, c = images.shape ph, pw = self.patch_size @@ -44,7 +49,9 @@ def extract_patches(self, images: jnp.ndarray) -> jnp.ndarray: num_patches_w = w // pw # Reshape the images into patches and flatten each patch - patches = jnp.reshape(images, (batch_size, num_patches_h, ph, num_patches_w, pw, c)) + patches = jnp.reshape( + images, (batch_size, num_patches_h, ph, num_patches_w, pw, c) + ) patches = jnp.transpose(patches, (0, 1, 3, 2, 4, 5)) patches = jnp.reshape(patches, (batch_size, -1, ph * pw * c)) return patches @@ -59,13 +66,14 @@ class MixerBlock(nn.Module): Methods: __call__(x): Processes the input tensor through the Mixer block. """ + @nn.compact def __call__(self, x): # Create a skip connection skip = x.copy() x = nn.LayerNorm()(x) x = jnp.transpose(x, axes=(0, 2, 1)) - x = nn.gelu(nn.Dense(x.shape[-1])(x)) + x = nn.gelu(nn.Dense(x.shape[-1])(x)) x = jnp.transpose(x, axes=(0, 2, 1)) + skip skip = x.copy() x = nn.LayerNorm()(x) @@ -90,6 +98,7 @@ class MixerEncoder(nn.Module): setup(): Initializes the components of the MixerEncoder. __call__(x, training): Processes the input tensor through the encoder. """ + patch_size: Tuple[int, int] num_layers: int hidden_dim: int @@ -98,18 +107,14 @@ class MixerEncoder(nn.Module): dropout: float def setup(self): - self.embedding = PatchEmbedding(self.patch_size, - self.feedforward_dim) - - self.layers = [MixerBlock() - for _ in range(self.num_layers)] - + self.embedding = PatchEmbedding(self.patch_size, self.feedforward_dim) + + self.layers = [MixerBlock() for _ in range(self.num_layers)] + self.dropout_layer = nn.Dropout(self.dropout) - def __call__(self, - x: jnp.ndarray, - training: bool = False) -> tuple: - + def __call__(self, x: jnp.ndarray, training: bool = False) -> tuple: + x = self.embedding(x) for layer in self.layers: x = layer(x) @@ -135,11 +140,11 @@ class Mixer(nn.Module): Methods: setup(): Initializes the components of the Mixer model. __call__(x, training): Processes the input tensor through the model and produces class logits. - - MLP Mixers are a recent architectural innovation in the field of deep learning, introduced to address the limitations of traditional Convolutional Neural Networks (CNNs) and Transformers. - The motivation behind MLP Mixers arises from the need to handle diverse data types and leverage multi-modal information efficiently. Unlike transformers that rely on self-attention mechanisms, - MLP Mixers employ a simple yet powerful approach using Multi-Layer Perceptrons (MLPs) to process data. This architecture is designed to work with sequences, images, or even a combination of both, - making it versatile for a wide range of tasks. MLP Mixers have demonstrated strong performance in various applications, including image classification, natural language understanding, and cross-modal learning, + + MLP Mixers are a recent architectural innovation in the field of deep learning, introduced to address the limitations of traditional Convolutional Neural Networks (CNNs) and Transformers. + The motivation behind MLP Mixers arises from the need to handle diverse data types and leverage multi-modal information efficiently. Unlike transformers that rely on self-attention mechanisms, + MLP Mixers employ a simple yet powerful approach using Multi-Layer Perceptrons (MLPs) to process data. This architecture is designed to work with sequences, images, or even a combination of both, + making it versatile for a wide range of tasks. MLP Mixers have demonstrated strong performance in various applications, including image classification, natural language understanding, and cross-modal learning, showcasing their potential in handling different modalities and promoting model efficiency and scalability in deep learning. Example usage: @@ -151,26 +156,26 @@ class Mixer(nn.Module): # Dummy data parameters batch_size = 8 - max_length = 50 - n_outputs = 5 - embed_dim = 256 - patch_size = (16, 16) + max_length = 50 + n_outputs = 5 + embed_dim = 256 + patch_size = (16, 16) # Generate data dummy_inputs = jnp.ones((batch_size, 224, 224, 3)) key = jax.random.PRNGKey(10) - dummy_labels = jax.random.randint(key, - shape=(batch_size,), - minval=0, + dummy_labels = jax.random.randint(key, + shape=(batch_size,), + minval=0, maxval=n_outputs-1) # Create dataset and dataloader - dataset = ArrayDataset(dummy_inputs, + dataset = ArrayDataset(dummy_inputs, dummy_labels) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # model parameters @@ -196,6 +201,7 @@ class Mixer(nn.Module): trainer.train(dataloader, 10, dataloader) ``` """ + patch_size: Tuple[int, int] num_layers: int hidden_dim: int @@ -211,23 +217,21 @@ def setup(self): hidden_dim=self.hidden_dim, num_heads=self.num_heads, feedforward_dim=self.feedforward_dim, - dropout=self.dropout + dropout=self.dropout, ) self.dropout_layer = nn.Dropout(self.dropout) self.output = nn.Dense(self.n_outputs) - def __call__(self, - x: jnp.ndarray, - training: bool = False) -> tuple: + def __call__(self, x: jnp.ndarray, training: bool = False) -> tuple: x = self.encoder(x=x, training=training) x = self.dropout_layer(x, deterministic=not training) - return self.output(x[:,0,:]), x + return self.output(x[:, 0, :]), x class MixerDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -246,12 +250,15 @@ class MixerDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -259,107 +266,137 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(MixerDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(MixerDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + MixerDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + MixerDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, jnp.ones(input_shape))['params'] + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: + + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))})[0] - return -jnp.mean(jax.vmap(jax.nn.log_softmax)(logits)[jnp.arange(targets.size), targets]) - + logits = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + )[0] + return -jnp.mean( + jax.vmap(jax.nn.log_softmax)(logits)[jnp.arange(targets.size), targets] + ) + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 for inputs, targets in train_loader: batch_size = inputs.shape[0] batch_size_per_device = batch_size // self.num_devices - inputs = inputs.reshape((self.num_devices, batch_size_per_device, inputs.shape[1], inputs.shape[2], inputs.shape[3])) + inputs = inputs.reshape( + ( + self.num_devices, + batch_size_per_device, + inputs.shape[1], + inputs.shape[2], + inputs.shape[3], + ) + ) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)})[0] - return -jnp.mean(jax.vmap(jax.nn.log_softmax)(logits)[jnp.arange(targets.size), targets]) - - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, inputs, rngs={"dropout": jax.random.PRNGKey(2)} + )[0] + return -jnp.mean( + jax.vmap(jax.nn.log_softmax)(logits)[jnp.arange(targets.size), targets] + ) + + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: batch_size = inputs.shape[0] batch_size_per_device = batch_size // self.num_devices - inputs = inputs.reshape((self.num_devices, batch_size_per_device, inputs.shape[1], inputs.shape[2], inputs.shape[3])) + inputs = inputs.reshape( + ( + self.num_devices, + batch_size_per_device, + inputs.shape[1], + inputs.shape[2], + inputs.shape[3], + ) + ) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/reward.py b/nanodl/__src/models/reward.py index 32401ef..17550c9 100644 --- a/nanodl/__src/models/reward.py +++ b/nanodl/__src/models/reward.py @@ -1,18 +1,19 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import Tuple, Any, Optional, Iterable class RewardModel(nn.Module): """ The RewardModel estimates the reward or value of a given input sequence, - typically used in reinforcement learning frameworks for natural language processing tasks. - It uses the last hidden state of a transformer-based model to generate a scalar reward prediction, + typically used in reinforcement learning frameworks for natural language processing tasks. + It uses the last hidden state of a transformer-based model to generate a scalar reward prediction, guiding the agent's behavior by evaluating the desirability or utility of its generated outputs. Example: @@ -30,9 +31,9 @@ class RewardModel(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_chosen, dummy_rejected) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # model parameters @@ -62,32 +63,31 @@ class RewardModel(nn.Module): # Call as you would a regular Flax model rngs = jax.random.PRNGKey(0) rngs, dropout_rng = jax.random.split(rngs) - rewards = reward_model.apply({'params': params}, - dummy_chosen, + rewards = reward_model.apply({'params': params}, + dummy_chosen, rngs={'dropout': dropout_rng}) print(rewards.shape) ``` """ + model: nn.Module dim: int dropout: float @nn.compact - def __call__(self, - x: jnp.ndarray, - training: bool = False): - + def __call__(self, x: jnp.ndarray, training: bool = False): + x = self.model(x, training=training, drop_last_layer=True) x = nn.Dropout(rate=self.dropout)(x, deterministic=not training) x = nn.Dense(1)(x) return nn.sigmoid(x)[:, -1, 0] - + class RewardDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -107,14 +107,17 @@ class RewardDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None, - model_params_path: Optional[str] = None) -> None: - + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + model_params_path: Optional[str] = None, + ) -> None: + self.model = model self.params = None self.params_path = params_path @@ -123,19 +126,21 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(RewardDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(RewardDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + RewardDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + RewardDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, - jnp.ones(input_shape, dtype=jnp.int32))['params'] + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape, dtype=jnp.int32))["params"] if self.params_path is not None: params = self.load_params(self.params_path) @@ -144,40 +149,48 @@ def create_train_state(self, model_params = self.load_params(self.model_params_path) params = self.merge_params(model_params, params) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - chosen: jnp.ndarray, - rejected: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, chosen: jnp.ndarray, rejected: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - chosen_rewards = state.apply_fn({'params': params}, - chosen, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - - rejected_rewards = state.apply_fn({'params': params}, - rejected, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - + chosen_rewards = state.apply_fn( + {"params": params}, + chosen, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + + rejected_rewards = state.apply_fn( + {"params": params}, + rejected, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean() - + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -185,36 +198,41 @@ def train(self, batch_size = chosen.shape[0] batch_size_per_device = batch_size // self.num_devices chosen = chosen.reshape((self.num_devices, batch_size_per_device, -1)) - rejected = rejected.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - chosen=chosen, - rejected=rejected) + rejected = rejected.reshape( + (self.num_devices, batch_size_per_device, -1) + ) + self.state, loss = self.train_step( + state=self.state, chosen=chosen, rejected=rejected + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - chosen: jnp.ndarray, - rejected: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - chosen_rewards = state.apply_fn({'params': state.params}, chosen, rngs={'dropout': jax.random.PRNGKey(2)}) - rejected_rewards = state.apply_fn({'params': state.params}, rejected, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, chosen: jnp.ndarray, rejected: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + chosen_rewards = state.apply_fn( + {"params": state.params}, chosen, rngs={"dropout": jax.random.PRNGKey(2)} + ) + rejected_rewards = state.apply_fn( + {"params": state.params}, rejected, rngs={"dropout": jax.random.PRNGKey(2)} + ) return -jnp.log(jax.nn.sigmoid(chosen_rewards - rejected_rewards)).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for chosen, rejected in test_loader: @@ -225,23 +243,26 @@ def evaluate(self, loss = self.evaluation_step(self.state, chosen, rejected) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss - + def merge_params(untrained_params, trained_params): updated_untrained_params = jax.tree_map( - lambda untrained, trained: trained if untrained.shape == trained.shape else untrained, - untrained_params, - trained_params) + lambda untrained, trained: ( + trained if untrained.shape == trained.shape else untrained + ), + untrained_params, + trained_params, + ) return updated_untrained_params def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/t5.py b/nanodl/__src/models/t5.py index 4a779a9..e65d832 100644 --- a/nanodl/__src/models/t5.py +++ b/nanodl/__src/models/t5.py @@ -1,11 +1,12 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import List, Tuple, Any, Optional, Dict, Iterable class RelativeMultiHeadAttention(nn.Module): @@ -23,53 +24,70 @@ class RelativeMultiHeadAttention(nn.Module): __call__(inputs, context, mask, clip): Processes the input and context tensors through the relative multi-head attention mechanism. attention_function(query, key, value, mask): Computes the attention scores and applies them to the value vectors, incorporating relative position information. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads def setup(self): # Because the Query is determined from a context, project separately - self.query_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.key_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None, - clip: int = 3) -> tuple: + self.query_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__( + self, + inputs: jnp.ndarray, + context: jnp.ndarray, + mask: jnp.ndarray = None, + clip: int = 3, + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - query_relative_positions = jnp.expand_dims(jnp.arange(query.shape[2]), axis=0) + query_relative_positions = jnp.expand_dims(jnp.arange(query.shape[2]), axis=0) query_relative_positions -= jnp.expand_dims(jnp.arange(query.shape[1]), axis=1) - query_relative_positions = jnp.where(query_relative_positions < clip, query_relative_positions, clip) - query_relative_positions = jnp.where(query_relative_positions > -clip, query_relative_positions, -clip) + query_relative_positions = jnp.where( + query_relative_positions < clip, query_relative_positions, clip + ) + query_relative_positions = jnp.where( + query_relative_positions > -clip, query_relative_positions, -clip + ) query += query_relative_positions - value_relative_positions = jnp.expand_dims(jnp.arange(value.shape[2]), axis=0) + value_relative_positions = jnp.expand_dims(jnp.arange(value.shape[2]), axis=0) value_relative_positions -= jnp.expand_dims(jnp.arange(value.shape[1]), axis=1) - value_relative_positions = jnp.where(value_relative_positions < clip, value_relative_positions, clip) - value_relative_positions = jnp.where(value_relative_positions > -clip, value_relative_positions, -clip) + value_relative_positions = jnp.where( + value_relative_positions < clip, value_relative_positions, clip + ) + value_relative_positions = jnp.where( + value_relative_positions > -clip, value_relative_positions, -clip + ) value += value_relative_positions - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = query.shape[1] context_length = key.shape[1] @@ -77,19 +95,29 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) - - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) + + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - + class PositionWiseFFN(nn.Module): """ @@ -105,17 +133,22 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + num_hiddens: int num_outputs: int def setup(self): - self.dense1 = nn.Dense(self.num_hiddens, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.num_hiddens, kernel_init=nn.initializers.xavier_uniform() + ) self.activation = nn.gelu - self.dense2 = nn.Dense(self.num_outputs, kernel_init=nn.initializers.xavier_uniform()) + self.dense2 = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(self.activation(self.dense1(X))) - + class AddNorm(nn.Module): """ @@ -129,17 +162,16 @@ class AddNorm(nn.Module): Methods: __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ + dropout: int @nn.compact - def __call__(self, - X: jnp.ndarray, - Y: jnp.ndarray, - training=False) -> jnp.ndarray: - + def __call__(self, X: jnp.ndarray, Y: jnp.ndarray, training=False) -> jnp.ndarray: + return nn.LayerNorm()( - nn.Dropout(self.dropout)(Y, deterministic=not training) + X) - + nn.Dropout(self.dropout)(Y, deterministic=not training) + X + ) + class T5EncoderBlock(nn.Module): """ @@ -157,30 +189,31 @@ class T5EncoderBlock(nn.Module): setup(): Initializes the components of the T5 encoder block. __call__(x, mask, training): Processes the input tensor through the encoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention = RelativeMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads) + self.attention = RelativeMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.linear = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: + attended_x, attention = self.attention(x, x, mask=mask) x = self.add_norm1(x, attended_x, training) linear_output = self.linear(x) x = self.add_norm2(x, linear_output, training) return x, attention - - + + class T5Encoder(nn.Module): """ Implements the encoder component of the T5 model. @@ -200,6 +233,7 @@ class T5Encoder(nn.Module): setup(): Initializes the components of the T5 encoder. __call__(x, mask, training): Processes the input tensor through the encoder. """ + num_layers: int hidden_dim: int num_heads: int @@ -208,29 +242,29 @@ class T5Encoder(nn.Module): vocab_size: float embed_dim: float - def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [T5EncoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) - for _ in range(self.num_layers)] - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + T5EncoderBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) for layer in self.layers: x, attention = layer(x, mask=mask, training=training) attention_maps.append(attention) return x, jnp.array(attention_maps) - + class T5DecoderBlock(nn.Module): """ @@ -248,39 +282,44 @@ class T5DecoderBlock(nn.Module): setup(): Initializes the components of the T5 decoder block. __call__(x, context, training): Processes the input tensor through the decoder block, incorporating context from the encoder. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention1 = RelativeMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) - self.attention2 = RelativeMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) + self.attention1 = RelativeMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) + self.attention2 = RelativeMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) self.add_norm3 = AddNorm(self.dropout) - def causal_mask(self, - batch_size: int, - destination_dim: int, - source_dim: int) -> jnp.ndarray: - + def causal_mask( + self, batch_size: int, destination_dim: int, source_dim: int + ) -> jnp.ndarray: + # Create index tensors for the source and destination dimensions idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, :, :] - return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + return jnp.broadcast_to( + mask, (batch_size, self.num_heads, destination_dim, source_dim) + ) + + def __call__( + self, x: jnp.ndarray, context: jnp.ndarray, training: bool = False + ) -> tuple: - def __call__(self, - x: jnp.ndarray, - context: jnp.ndarray, - training: bool = False) -> tuple: - mask = self.causal_mask(x.shape[0], x.shape[1], context.shape[1]) attended_x, attention1 = self.attention1(x, x) @@ -291,9 +330,9 @@ def __call__(self, linear_output = self.feed_forward(x) x = self.add_norm3(x, linear_output, training) - + return x, jnp.array(attention1), jnp.array(attention2) - + class T5Decoder(nn.Module): """ @@ -314,6 +353,7 @@ class T5Decoder(nn.Module): setup(): Initializes the components of the T5 decoder. __call__(x, context, training): Processes the input tensor through the decoder, incorporating context from the encoder. """ + num_layers: int hidden_dim: int num_heads: int @@ -322,24 +362,24 @@ class T5Decoder(nn.Module): vocab_size: float embed_dim: float - def setup(self): - self.embedding = nn.Embed(num_embeddings=self.vocab_size, - features=self.embed_dim) - - self.layers = [T5DecoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) for _ in range(self.num_layers)] - + self.embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) + + self.layers = [ + T5DecoderBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - def __call__(self, - x: jnp.ndarray, - context: jnp.ndarray, - training: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, context: jnp.ndarray, training: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -347,9 +387,13 @@ def __call__(self, x, attention, cross_attention = layer(x, context, training=training) attention_maps.append(attention) cross_attention_maps.append(cross_attention) - return self.outputs(x), jnp.array(attention_maps), jnp.array(cross_attention_maps) - - + return ( + self.outputs(x), + jnp.array(attention_maps), + jnp.array(cross_attention_maps), + ) + + class T5(nn.Module): """ Implements the T5 model for text-to-text tasks, such as translation, summarization, and question answering. @@ -373,11 +417,11 @@ class T5(nn.Module): __call__(x, y, training): Processes the input tensor through the T5 model, generating predictions. generate(x, temperature, deterministic): Generates output sequences from input sequences. generate_batch(x, temperature, deterministic): Generates output sequences for a batch of input sequences. - - T5, which stands for Text-to-Text Transfer Transformer, is an influential deep learning architecture introduced by Google Research. - Its motivation stems from the idea of unifying various natural language processing tasks into a single framework to achieve greater model simplicity and efficiency. - T5 reimagines tasks as text-to-text problems, where both inputs and outputs are represented as text. - This consistent formulation allows T5 to perform an astonishingly wide range of tasks, + + T5, which stands for Text-to-Text Transfer Transformer, is an influential deep learning architecture introduced by Google Research. + Its motivation stems from the idea of unifying various natural language processing tasks into a single framework to achieve greater model simplicity and efficiency. + T5 reimagines tasks as text-to-text problems, where both inputs and outputs are represented as text. + This consistent formulation allows T5 to perform an astonishingly wide range of tasks, from translation and summarization to question-answering and document classification, by adjusting the input and output formats accordingly. The architecture is roughly equivalent to the original Transformer proposed by Vaswani et al. (2017) with the exception of removing the Layer Norm bias, placing the layer @@ -403,9 +447,9 @@ class T5(nn.Module): # Create dataset and dataloader dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -432,25 +476,25 @@ class T5(nn.Module): model = T5(**hyperparams) rngs = jax.random.PRNGKey(0) rngs, dropout_rng = jax.random.split(rngs) - params = model.init({'params': rngs, 'dropout': dropout_rng}, + params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs, dummy_targets)['params'] # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, + outputs = model.apply({'params': params}, + dummy_inputs, dummy_targets, rngs={'dropout': dropout_rng}) print(outputs.shape) # Training on data - trainer = T5DataParallelTrainer(model, - dummy_inputs.shape, + trainer = T5DataParallelTrainer(model, + dummy_inputs.shape, dummy_targets.shape, 'params.pkl') - trainer.train(train_loader=dataloader, - num_epochs=2, + trainer.train(train_loader=dataloader, + num_epochs=2, val_loader=dataloader) print(trainer.evaluate(dataloader)) @@ -458,15 +502,16 @@ class T5(nn.Module): # Generating from a start token start_tokens = jnp.array([[123, 456]]) - # Remember to load the trained parameters + # Remember to load the trained parameters params = trainer.load_params('params.pkl') outputs = model.apply({'params': params}, start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, + rngs={'dropout': jax.random.PRNGKey(2)}, method=model.generate) print(outputs) ``` """ + num_layers: int num_heads: int hidden_dim: int @@ -479,45 +524,46 @@ class T5(nn.Module): end_token: int def setup(self): - self.encoder = T5Encoder(self.num_layers, - self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim) - - self.decoder = T5Decoder(self.num_layers, - self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout, - self.vocab_size, - self.embed_dim) - - def __call__(self, - x: jnp.ndarray, - y: jnp.ndarray, - training: bool = False) -> jnp.ndarray: - + self.encoder = T5Encoder( + self.num_layers, + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + ) + + self.decoder = T5Decoder( + self.num_layers, + self.hidden_dim, + self.num_heads, + self.feedforward_dim, + self.dropout, + self.vocab_size, + self.embed_dim, + ) + + def __call__( + self, x: jnp.ndarray, y: jnp.ndarray, training: bool = False + ) -> jnp.ndarray: + z = self.encoder(x=x, training=training)[0] return self.decoder(x=y, context=z, training=training)[0] - - def generate(self, - x: jnp.ndarray, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + def generate( + self, x: jnp.ndarray, temperature: float = 1.0, deterministic: bool = False + ) -> Tuple[jnp.ndarray]: + encoded_sequence = self.encoder(x=x, training=False)[0] decoder_input = x if x is not None else jnp.array([[self.start_token]]) output_sequence = [] # Autoregressive decoding loop for _ in range(self.max_length): - decoder_output = self.decoder(x=decoder_input, - context=encoded_sequence, - training=False)[0] + decoder_output = self.decoder( + x=decoder_input, context=encoded_sequence, training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -525,35 +571,41 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) print(decoder_input.shape, jnp.array([[next_token]]).shape) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: jnp.ndarray, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, x: jnp.ndarray, temperature: float = 1.0, deterministic: bool = False + ) -> jnp.ndarray: + # Encode the input sequence encoded_sequence = self.encoder(x=x, training=False)[0] batch_size = x.shape[0] if x is not None else 1 - decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token) + decoder_input = ( + x if x is not None else jnp.full((batch_size, 1), self.start_token) + ) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) for i in range(self.max_length): - decoder_output = self.decoder(x=decoder_input, - context=encoded_sequence, - training=False)[0] + decoder_output = self.decoder( + x=decoder_input, context=encoded_sequence, training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -562,10 +614,14 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break @@ -573,11 +629,10 @@ def generate_batch(self, return output_sequences - class T5DataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -597,13 +652,16 @@ class T5DataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - target_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -611,54 +669,69 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(T5DataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(T5DataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + T5DataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + T5DataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape, target_shape) - print(f'Number of accelerators: {self.num_devices}') - - - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...], - target_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, - jnp.ones(input_shape, dtype=jnp.int32), - jnp.ones(target_shape, dtype=jnp.int32))['params'] + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, + learning_rate: float, + input_shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + ) -> Any: + + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init( + rngs, + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(target_shape, dtype=jnp.int32), + )["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - targets, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + targets, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -667,35 +740,39 @@ def train(self, batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, targets, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, + inputs, + targets, + rngs={"dropout": jax.random.PRNGKey(2)}, + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: @@ -706,16 +783,16 @@ def evaluate(self, loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/transformer.py b/nanodl/__src/models/transformer.py index 4171a6f..f0baf15 100644 --- a/nanodl/__src/models/transformer.py +++ b/nanodl/__src/models/transformer.py @@ -1,11 +1,12 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import List, Tuple, Any, Optional, Dict, Iterable class PositionalEncoding(nn.Module): @@ -22,19 +23,27 @@ class PositionalEncoding(nn.Module): setup(): Initializes the positional encoding matrix based on the provided attributes. __call__(x: jnp.ndarray): Adds positional encodings to the input embeddings. """ + num_embeddings: int features: int def setup(self): positional_encoding = jnp.zeros((self.features, self.num_embeddings)) position = jnp.arange(0, self.features, dtype=jnp.float32)[:, None] - div_term = jnp.exp(jnp.arange(0, self.num_embeddings, 2) * (-jnp.log(10000.0) / self.num_embeddings)) - positional_encoding = positional_encoding.at[:, 0::2].set(jnp.sin(position * div_term)) - positional_encoding = positional_encoding.at[:, 1::2].set(jnp.cos(position * div_term)) + div_term = jnp.exp( + jnp.arange(0, self.num_embeddings, 2) + * (-jnp.log(10000.0) / self.num_embeddings) + ) + positional_encoding = positional_encoding.at[:, 0::2].set( + jnp.sin(position * div_term) + ) + positional_encoding = positional_encoding.at[:, 1::2].set( + jnp.cos(position * div_term) + ) self.positional_encoding = positional_encoding.T def __call__(self, x): - x = x + self.positional_encoding[:x.shape[1]] + x = x + self.positional_encoding[: x.shape[1]] return x @@ -54,18 +63,25 @@ class TokenAndPositionEmbedding(nn.Module): setup(): Initializes token and positional embeddings. __call__(x: jnp.ndarray): Applies token embeddings and adds positional information to the input sequence. """ - max_len : int - vocab_size : int - embed_dim : int - learned_position : bool - + + max_len: int + vocab_size: int + embed_dim: int + learned_position: bool + def setup(self): - self.token_embeddings = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim) + self.token_embeddings = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) if self.learned_position: - self.position_embeddings = nn.Embed(num_embeddings=self.max_len, features=self.embed_dim) + self.position_embeddings = nn.Embed( + num_embeddings=self.max_len, features=self.embed_dim + ) else: - self.position_embeddings = PositionalEncoding(num_embeddings=self.max_len, features=self.embed_dim) + self.position_embeddings = PositionalEncoding( + num_embeddings=self.max_len, features=self.embed_dim + ) def __call__(self, x): x = self.token_embeddings(x) @@ -73,7 +89,7 @@ def __call__(self, x): return x + self.position_embeddings(jnp.arange(x.shape[1])) else: return x + self.position_embeddings(x) - + class MultiHeadAttention(nn.Module): """ @@ -90,39 +106,45 @@ class MultiHeadAttention(nn.Module): __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ - hidden_dim : int - num_heads : int + + hidden_dim: int + num_heads: int def setup(self): # Because the Query is determined from a context, project separately - self.query_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.key_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + self.query_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__( + self, inputs: jnp.ndarray, context: jnp.ndarray, mask: jnp.ndarray = None + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = query.shape[1] context_length = key.shape[1] @@ -130,19 +152,29 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - + class PositionWiseFFN(nn.Module): """ @@ -158,17 +190,22 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + num_hiddens: int num_outputs: int def setup(self): - self.dense1 = nn.Dense(self.num_hiddens, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.num_hiddens, kernel_init=nn.initializers.xavier_uniform() + ) self.activation = nn.gelu - self.dense2 = nn.Dense(self.num_outputs, kernel_init=nn.initializers.xavier_uniform()) + self.dense2 = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(self.activation(self.dense1(X))) - + class AddNorm(nn.Module): """ @@ -182,16 +219,15 @@ class AddNorm(nn.Module): Methods: __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ + dropout: int @nn.compact - def __call__(self, - X: jnp.ndarray, - Y: jnp.ndarray, - training=False) -> jnp.ndarray: + def __call__(self, X: jnp.ndarray, Y: jnp.ndarray, training=False) -> jnp.ndarray: return nn.LayerNorm()( - nn.Dropout(self.dropout)(Y, deterministic=not training) + X) - + nn.Dropout(self.dropout)(Y, deterministic=not training) + X + ) + class TransformerEncoderBlock(nn.Module): """ @@ -209,29 +245,30 @@ class TransformerEncoderBlock(nn.Module): setup(): Initializes the attention, feed-forward network, and normalization layers. __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the encoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention = MultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads) + self.attention = MultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.linear = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: attended_x, attention = self.attention(x, x, mask=mask) x = self.add_norm1(x, attended_x, training) linear_output = self.linear(x) x = self.add_norm2(x, linear_output, training) return x, attention - - + + class TransformerEncoder(nn.Module): """ Implements a transformer encoder for text. @@ -253,40 +290,39 @@ class TransformerEncoder(nn.Module): setup(): Initializes the embedding layer and the encoder blocks. __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the transformer encoder. """ + num_layers: int hidden_dim: int num_heads: int feedforward_dim: int dropout: float - max_len : int - vocab_size : int - embed_dim : int - learned_position : bool = True - + max_len: int + vocab_size: int + embed_dim: int + learned_position: bool = True def setup(self): - self.embedding = TokenAndPositionEmbedding(self.max_len, - self.vocab_size, - self.embed_dim, - self.learned_position) - - self.layers = [TransformerEncoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) - for _ in range(self.num_layers)] - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: + self.embedding = TokenAndPositionEmbedding( + self.max_len, self.vocab_size, self.embed_dim, self.learned_position + ) + + self.layers = [ + TransformerEncoderBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: attention_maps = [] x = self.embedding(x) for layer in self.layers: x, attention = layer(x, mask=mask, training=training) attention_maps.append(attention) return x, jnp.array(attention_maps) - + class TransformerDecoderBlock(nn.Module): """ @@ -304,39 +340,44 @@ class TransformerDecoderBlock(nn.Module): setup(): Initializes the components of the Transformer decoder block. __call__(x, context, training): Processes the input tensor through the decoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention1 = MultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) - self.attention2 = MultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) + self.attention1 = MultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) + self.attention2 = MultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) self.add_norm3 = AddNorm(self.dropout) - def causal_mask(self, - batch_size: int, - destination_dim: int, - source_dim: int) -> jnp.ndarray: - + def causal_mask( + self, batch_size: int, destination_dim: int, source_dim: int + ) -> jnp.ndarray: + # Create index tensors for the source and destination dimensions idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, :, :] - return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + return jnp.broadcast_to( + mask, (batch_size, self.num_heads, destination_dim, source_dim) + ) + + def __call__( + self, x: jnp.ndarray, context: jnp.ndarray, training: bool = False + ) -> tuple: - def __call__(self, - x: jnp.ndarray, - context: jnp.ndarray, - training: bool = False) -> tuple: - mask = self.causal_mask(x.shape[0], x.shape[1], context.shape[1]) attended_x, attention1 = self.attention1(x, x) @@ -347,9 +388,9 @@ def __call__(self, linear_output = self.feed_forward(x) x = self.add_norm3(x, linear_output, training) - + return x, jnp.array(attention1), jnp.array(attention2) - + class TransformerDecoder(nn.Module): """ @@ -372,36 +413,35 @@ class TransformerDecoder(nn.Module): setup(): Initializes the components of the Transformer decoder. __call__(x, context, training): Processes the input tensor through the decoder. """ + num_layers: int hidden_dim: int num_heads: int feedforward_dim: int dropout: float - max_len : int - vocab_size : int - embed_dim : int - learned_position : bool = True - + max_len: int + vocab_size: int + embed_dim: int + learned_position: bool = True def setup(self): - self.embedding = TokenAndPositionEmbedding(self.max_len, - self.vocab_size, - self.embed_dim, - self.learned_position) - - self.layers = [TransformerDecoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) for _ in range(self.num_layers)] - + self.embedding = TokenAndPositionEmbedding( + self.max_len, self.vocab_size, self.embed_dim, self.learned_position + ) + + self.layers = [ + TransformerDecoderBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - def __call__(self, - x: jnp.ndarray, - context: jnp.ndarray, - training: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, context: jnp.ndarray, training: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -409,123 +449,128 @@ def __call__(self, x, attention, cross_attention = layer(x, context, training=training) attention_maps.append(attention) cross_attention_maps.append(cross_attention) - return self.outputs(x), jnp.array(attention_maps), jnp.array(cross_attention_maps) - - -class Transformer(nn.Module): - """ - Implements the Transformer model for sequence-to-sequence tasks, such as translation and text generation. + return ( + self.outputs(x), + jnp.array(attention_maps), + jnp.array(cross_attention_maps), + ) - The Transformer model utilizes an encoder-decoder architecture. The encoder captures contextual information from the input sequence, and the decoder generates the output sequence based on this context. - Attributes: - num_layers (int): Number of layers in both the encoder and decoder. - num_heads (int): Number of attention heads in each layer. - hidden_dim (int): Dimensionality of the input and output features for the layers. - feedforward_dim (int): Dimensionality of the inner layer of the feed-forward networks in the layers. - dropout (float): Dropout rate used for regularization. - vocab_size (float): Size of the vocabulary. - embed_dim (float): Dimensionality of the token embeddings. - max_length (int): Maximum length of the generated sequences. - start_token (int): Token used to start the generation process. - end_token (int): Token that indicates the end of a generated sequence. +class Transformer(nn.Module): + """ + Implements the Transformer model for sequence-to-sequence tasks, such as translation and text generation. + + The Transformer model utilizes an encoder-decoder architecture. The encoder captures contextual information from the input sequence, and the decoder generates the output sequence based on this context. + + Attributes: + num_layers (int): Number of layers in both the encoder and decoder. + num_heads (int): Number of attention heads in each layer. + hidden_dim (int): Dimensionality of the input and output features for the layers. + feedforward_dim (int): Dimensionality of the inner layer of the feed-forward networks in the layers. + dropout (float): Dropout rate used for regularization. + vocab_size (float): Size of the vocabulary. + embed_dim (float): Dimensionality of the token embeddings. + max_length (int): Maximum length of the generated sequences. + start_token (int): Token used to start the generation process. + end_token (int): Token that indicates the end of a generated sequence. + + Methods: + setup(): Initializes the Transformer model including both the encoder and decoder components. + __call__(x, y, training): Processes the input tensor through the Transformer model, generating predictions. + generate(x, temperature, deterministic): Generates output sequences from input sequences. + generate_batch(x, temperature, deterministic): Generates output sequences for a batch of input sequences. + + Transformers are a groundbreaking class of deep learning models originally introduced in the paper "Attention Is All You Need" by Vaswani et al. + Their motivation stems from addressing limitations in previous sequence-to-sequence models and enabling more efficient and parallelizable training. + The key innovation of transformers is the self-attention mechanism, which allows the model to weigh the importance of different parts of the input sequence during processing. + This architecture has had a profound impact on natural language processing and has been adapted for a wide range of tasks, including machine translation, text generation, image captioning, and more. + Transformers have become the foundation for various state-of-the-art models, including BERT, GPT, and Transformer, which have achieved remarkable results across multiple domains, showcasing the power of attention-based architectures in deep learning. + + Example usage: + ``` + import jax + import jax.numpy as jnp + from nanodl import ArrayDataset, DataLoader + from nanodl import Transformer, TransformerDataParallelTrainer + + # Generate dummy data + batch_size = 8 + max_length = 10 + + # Replace with actual tokenised data + data = jnp.ones((101, max_length+1), dtype=jnp.int32) + + # Shift to create next-token prediction dataset + dummy_inputs = data[:, :-1] + dummy_targets = data[:, 1:] + + # Create dataset and dataloader + dataset = ArrayDataset(dummy_inputs, dummy_targets) + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, + drop_last=False) + + # How to loop through dataloader + for batch in dataloader: + x, y = batch + print(x.shape, y.shape) + break - Methods: - setup(): Initializes the Transformer model including both the encoder and decoder components. - __call__(x, y, training): Processes the input tensor through the Transformer model, generating predictions. - generate(x, temperature, deterministic): Generates output sequences from input sequences. - generate_batch(x, temperature, deterministic): Generates output sequences for a batch of input sequences. - - Transformers are a groundbreaking class of deep learning models originally introduced in the paper "Attention Is All You Need" by Vaswani et al. - Their motivation stems from addressing limitations in previous sequence-to-sequence models and enabling more efficient and parallelizable training. - The key innovation of transformers is the self-attention mechanism, which allows the model to weigh the importance of different parts of the input sequence during processing. - This architecture has had a profound impact on natural language processing and has been adapted for a wide range of tasks, including machine translation, text generation, image captioning, and more. - Transformers have become the foundation for various state-of-the-art models, including BERT, GPT, and Transformer, which have achieved remarkable results across multiple domains, showcasing the power of attention-based architectures in deep learning. - - Example usage: - ``` - import jax - import jax.numpy as jnp - from nanodl import ArrayDataset, DataLoader - from nanodl import Transformer, TransformerDataParallelTrainer - - # Generate dummy data - batch_size = 8 - max_length = 10 - - # Replace with actual tokenised data - data = jnp.ones((101, max_length+1), dtype=jnp.int32) - - # Shift to create next-token prediction dataset - dummy_inputs = data[:, :-1] - dummy_targets = data[:, 1:] - - # Create dataset and dataloader - dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, - drop_last=False) - - # How to loop through dataloader - for batch in dataloader: - x, y = batch - print(x.shape, y.shape) - break - - # model parameters - hyperparams = { - 'num_layers': 1, - 'hidden_dim': 256, - 'num_heads': 2, - 'feedforward_dim': 256, - 'dropout': 0.1, - 'vocab_size': 1000, - 'embed_dim': 256, - 'max_length': max_length, - 'start_token': 0, - 'end_token': 50, - } - - # Initialize model - model = Transformer(**hyperparams) - rngs = jax.random.PRNGKey(0) - rngs, dropout_rng = jax.random.split(rngs) - params = model.init({'params': rngs, 'dropout': dropout_rng}, - dummy_inputs, - dummy_targets)['params'] - - # Call as you would a Jax/Flax model - outputs = model.apply({'params': params}, - dummy_inputs, - dummy_targets, - rngs={'dropout': dropout_rng}) - print(outputs.shape) - - # Training on data - trainer = TransformerDataParallelTrainer(model, - dummy_inputs.shape, - dummy_targets.shape, - 'params.pkl') - - trainer.train(train_loader=dataloader, - num_epochs=2, - val_loader=dataloader) - - print(trainer.evaluate(dataloader)) - - # Generating from a start token - start_tokens = jnp.array([[123, 456]]) - - # Remember to load the trained parameters - params = trainer.load_params('params.pkl') - outputs = model.apply({'params': params}, - start_tokens, - rngs={'dropout': jax.random.PRNGKey(2)}, - method=model.generate) - print(outputs) -``` + # model parameters + hyperparams = { + 'num_layers': 1, + 'hidden_dim': 256, + 'num_heads': 2, + 'feedforward_dim': 256, + 'dropout': 0.1, + 'vocab_size': 1000, + 'embed_dim': 256, + 'max_length': max_length, + 'start_token': 0, + 'end_token': 50, + } + + # Initialize model + model = Transformer(**hyperparams) + rngs = jax.random.PRNGKey(0) + rngs, dropout_rng = jax.random.split(rngs) + params = model.init({'params': rngs, 'dropout': dropout_rng}, + dummy_inputs, + dummy_targets)['params'] + + # Call as you would a Jax/Flax model + outputs = model.apply({'params': params}, + dummy_inputs, + dummy_targets, + rngs={'dropout': dropout_rng}) + print(outputs.shape) + + # Training on data + trainer = TransformerDataParallelTrainer(model, + dummy_inputs.shape, + dummy_targets.shape, + 'params.pkl') + + trainer.train(train_loader=dataloader, + num_epochs=2, + val_loader=dataloader) + + print(trainer.evaluate(dataloader)) + + # Generating from a start token + start_tokens = jnp.array([[123, 456]]) + + # Remember to load the trained parameters + params = trainer.load_params('params.pkl') + outputs = model.apply({'params': params}, + start_tokens, + rngs={'dropout': jax.random.PRNGKey(2)}, + method=model.generate) + print(outputs) + ``` """ + num_layers: int num_heads: int hidden_dim: int @@ -548,7 +593,7 @@ def setup(self): vocab_size=self.vocab_size, embed_dim=self.embed_dim, ) - + self.decoder = TransformerDecoder( hidden_dim=self.hidden_dim, num_heads=self.num_heads, @@ -559,30 +604,27 @@ def setup(self): vocab_size=self.vocab_size, embed_dim=self.embed_dim, ) - - def __call__(self, - x: jnp.ndarray, - y: jnp.ndarray, - training: bool = False) -> jnp.ndarray: - + + def __call__( + self, x: jnp.ndarray, y: jnp.ndarray, training: bool = False + ) -> jnp.ndarray: + z = self.encoder(x=x, training=training)[0] return self.decoder(x=y, context=z, training=training)[0] - - def generate(self, - x: jnp.ndarray, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + def generate( + self, x: jnp.ndarray, temperature: float = 1.0, deterministic: bool = False + ) -> Tuple[jnp.ndarray]: + encoded_sequence = self.encoder(x=x, training=False)[0] decoder_input = jnp.array([[self.start_token]]) output_sequence = [] # Autoregressive decoding loop for _ in range(self.max_length): - decoder_output = self.decoder(x=decoder_input, - context=encoded_sequence, - training=False)[0] + decoder_output = self.decoder( + x=decoder_input, context=encoded_sequence, training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -590,32 +632,36 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: jnp.ndarray, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, x: jnp.ndarray, temperature: float = 1.0, deterministic: bool = False + ) -> jnp.ndarray: + encoded_sequence = self.encoder(x=x, training=False)[0] batch_size = x.shape[0] if x is not None else 1 decoder_input = jnp.full((batch_size, 1), self.start_token) output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) for i in range(self.max_length): - decoder_output = self.decoder(x=decoder_input, - context=encoded_sequence, - training=False)[0] + decoder_output = self.decoder( + x=decoder_input, context=encoded_sequence, training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -624,10 +670,14 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break @@ -635,11 +685,10 @@ def generate_batch(self, return output_sequences - class TransformerDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -659,13 +708,16 @@ class TransformerDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - target_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -673,54 +725,69 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(TransformerDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(TransformerDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + TransformerDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + TransformerDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape, target_shape) - print(f'Number of accelerators: {self.num_devices}') - - - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...], - target_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, - jnp.ones(input_shape, dtype=jnp.int32), - jnp.ones(target_shape, dtype=jnp.int32))['params'] + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, + learning_rate: float, + input_shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + ) -> Any: + + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init( + rngs, + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(target_shape, dtype=jnp.int32), + )["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - targets, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + targets, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 @@ -729,35 +796,39 @@ def train(self, batch_size_per_device = batch_size // self.num_devices inputs = inputs.reshape((self.num_devices, batch_size_per_device, -1)) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, targets, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, + inputs, + targets, + rngs={"dropout": jax.random.PRNGKey(2)}, + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: @@ -768,16 +839,16 @@ def evaluate(self, loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/vit.py b/nanodl/__src/models/vit.py index fc4bcc5..4c5606e 100644 --- a/nanodl/__src/models/vit.py +++ b/nanodl/__src/models/vit.py @@ -1,11 +1,13 @@ -import jax import time +from typing import Any, Iterable, Optional, Tuple + import flax -import optax -import jax.numpy as jnp import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import Tuple, Any, Optional, Iterable + class PatchEmbedding(nn.Module): """ @@ -21,18 +23,21 @@ class PatchEmbedding(nn.Module): __call__(x: jnp.ndarray): Extracts patches from the input images and applies patch embedding. extract_patches(images: jnp.ndarray): Extracts and flattens patches from input images. """ + patch_size: Tuple[int, int] - embed_dim: int + embed_dim: int @nn.compact def __call__(self, x): x = nn.Dense(self.embed_dim)(self.extract_patches(x)) - return x + nn.Embed(num_embeddings=x.shape[1], features=x.shape[2])(jnp.arange(x.shape[1])) + return x + nn.Embed(num_embeddings=x.shape[1], features=x.shape[2])( + jnp.arange(x.shape[1]) + ) def extract_patches(self, images: jnp.ndarray) -> jnp.ndarray: if len(images.shape) != 4: raise ValueError("Input images should have shape (batch_size, H, W, C)") - + batch_size, h, w, c = images.shape ph, pw = self.patch_size @@ -44,11 +49,13 @@ def extract_patches(self, images: jnp.ndarray) -> jnp.ndarray: num_patches_w = w // pw # Reshape the images into patches and flatten each patch - patches = jnp.reshape(images, (batch_size, num_patches_h, ph, num_patches_w, pw, c)) + patches = jnp.reshape( + images, (batch_size, num_patches_h, ph, num_patches_w, pw, c) + ) patches = jnp.transpose(patches, (0, 1, 3, 2, 4, 5)) patches = jnp.reshape(patches, (batch_size, -1, ph * pw * c)) return patches - + class SelfMultiHeadAttention(nn.Module): """ @@ -65,30 +72,33 @@ class SelfMultiHeadAttention(nn.Module): __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads def setup(self): # Stack all weight matrices together for efficiency - self.projection = nn.Dense(3*self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - + self.projection = nn.Dense( + 3 * self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) - def __call__(self, - inputs: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + def __call__(self, inputs: jnp.ndarray, mask: jnp.ndarray = None) -> tuple: projections = self.projection(inputs) query, key, value = jnp.array_split(projections, 3, axis=-1) - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = query.shape[1] context_length = key.shape[1] @@ -96,19 +106,29 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - + class PositionWiseFFN(nn.Module): """ @@ -124,12 +144,17 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + num_hiddens: int num_outputs: int def setup(self): - self.dense1 = nn.Dense(self.num_hiddens, kernel_init=nn.initializers.xavier_uniform()) - self.dense2 = nn.Dense(self.num_outputs, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.num_hiddens, kernel_init=nn.initializers.xavier_uniform() + ) + self.dense2 = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(nn.gelu(self.dense1(X))) @@ -147,13 +172,11 @@ class AddNorm(nn.Module): Methods: __call__(X: jnp.ndarray, Y: jnp.ndarray, training=False): Applies dropout to the output of a sublayer (Y), adds it to the original input (X), and applies layer normalization. """ + dropout: int @nn.compact - def __call__(self, - X: jnp.ndarray, - Y: jnp.ndarray, - training=False) -> jnp.ndarray: + def __call__(self, X: jnp.ndarray, Y: jnp.ndarray, training=False) -> jnp.ndarray: """ Apply AddNorm to input tensors. Args: @@ -164,8 +187,9 @@ def __call__(self, jnp.ndarray: Output tensor after applying AddNorm. """ return nn.LayerNorm()( - nn.Dropout(self.dropout)(Y, deterministic=not training) + X) - + nn.Dropout(self.dropout)(Y, deterministic=not training) + X + ) + class ViTBlock(nn.Module): """ @@ -183,29 +207,30 @@ class ViTBlock(nn.Module): setup(): Initializes the attention, feed-forward network, and normalization layers. __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input through the encoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention = SelfMultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads) + self.attention = SelfMultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.ff = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: + attended_x, attention = self.attention(x, mask=mask) x = self.add_norm1(x, attended_x, training) ff_output = self.ff(x) x = self.add_norm2(x, ff_output, training) return x, attention - + class ViTEncoder(nn.Module): """ @@ -225,6 +250,7 @@ class ViTEncoder(nn.Module): setup(): Initializes the patch embedding and encoder blocks. __call__(x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False): Processes the input images through the vision transformer encoder. """ + patch_size: Tuple[int, int] num_layers: int hidden_dim: int @@ -233,20 +259,19 @@ class ViTEncoder(nn.Module): dropout: float def setup(self): - self.embedding = PatchEmbedding(self.patch_size, - self.feedforward_dim) - - self.layers = [ViTBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) - for _ in range(self.num_layers)] - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - + self.embedding = PatchEmbedding(self.patch_size, self.feedforward_dim) + + self.layers = [ + ViTBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) for layer in self.layers: @@ -272,14 +297,14 @@ class ViT(nn.Module): Methods: setup(): Initializes the components of the ViTEncoder. __call__(x, mask, training): Processes the input tensor through the encoder, returning encoded features and attention maps. - - Vision Transformers, or ViTs, have emerged as a groundbreaking architectural paradigm in computer vision and deep learning. - The motivation behind Vision Transformers lies in the desire to extend the success of transformers, - originally designed for natural language processing, to visual data. These models aim to replace - or complement traditional Convolutional Neural Networks (CNNs) in image-related tasks. ViTs employ a self-attention mechanism - to capture global dependencies among pixels or patches of an image, which helps them understand context and relationships between different regions effectively. - By utilizing pretraining on large-scale image datasets, ViTs have achieved remarkable performance in image classification, object detection, image generation, and various other computer vision tasks. - Their modular design, scalability, and ability to handle both local and global information have made Vision Transformers a significant advancement in the field, + + Vision Transformers, or ViTs, have emerged as a groundbreaking architectural paradigm in computer vision and deep learning. + The motivation behind Vision Transformers lies in the desire to extend the success of transformers, + originally designed for natural language processing, to visual data. These models aim to replace + or complement traditional Convolutional Neural Networks (CNNs) in image-related tasks. ViTs employ a self-attention mechanism + to capture global dependencies among pixels or patches of an image, which helps them understand context and relationships between different regions effectively. + By utilizing pretraining on large-scale image datasets, ViTs have achieved remarkable performance in image classification, object detection, image generation, and various other computer vision tasks. + Their modular design, scalability, and ability to handle both local and global information have made Vision Transformers a significant advancement in the field, offering promising avenues for future research and applications in computer vision. Example usage: @@ -291,26 +316,26 @@ class ViT(nn.Module): # Dummy data parameters batch_size = 8 - max_length = 50 - n_outputs = 5 - embed_dim = 256 - patch_size = (16, 16) + max_length = 50 + n_outputs = 5 + embed_dim = 256 + patch_size = (16, 16) # Generate data dummy_inputs = jnp.ones((batch_size, 224, 224, 3)) key = jax.random.PRNGKey(10) - dummy_labels = jax.random.randint(key, - shape=(batch_size,), - minval=0, + dummy_labels = jax.random.randint(key, + shape=(batch_size,), + minval=0, maxval=n_outputs-1) # Create dataset and dataloader - dataset = ArrayDataset(dummy_inputs, + dataset = ArrayDataset(dummy_inputs, dummy_labels) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # model parameters @@ -336,6 +361,7 @@ class ViT(nn.Module): trainer.train(dataloader, 10, dataloader) ``` """ + patch_size: Tuple[int, int] num_layers: int hidden_dim: int @@ -351,25 +377,24 @@ def setup(self): hidden_dim=self.hidden_dim, num_heads=self.num_heads, feedforward_dim=self.feedforward_dim, - dropout=self.dropout + dropout=self.dropout, ) self.dropout_layer = nn.Dropout(self.dropout) self.output = nn.Dense(self.n_outputs) - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: + x, attention_maps = self.encoder(x=x, mask=mask, training=training) x = self.dropout_layer(x, deterministic=not training) - return self.output(x[:,0,:]), x, attention_maps + return self.output(x[:, 0, :]), x, attention_maps class ViTDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -388,12 +413,15 @@ class ViTDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -401,107 +429,137 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(ViTDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(ViTDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + ViTDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + ViTDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape) - print(f'Number of accelerators: {self.num_devices}') - + print(f"Number of accelerators: {self.num_devices}") - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, jnp.ones(input_shape))['params'] + def create_train_state( + self, learning_rate: float, input_shape: Tuple[int, ...] + ) -> Any: + + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init(rngs, jnp.ones(input_shape))["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))})[0] - return -jnp.mean(jax.vmap(jax.nn.log_softmax)(logits)[jnp.arange(targets.size), targets]) - + logits = state.apply_fn( + {"params": params}, + inputs, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + )[0] + return -jnp.mean( + jax.vmap(jax.nn.log_softmax)(logits)[jnp.arange(targets.size), targets] + ) + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 for inputs, targets in train_loader: batch_size = inputs.shape[0] batch_size_per_device = batch_size // self.num_devices - inputs = inputs.reshape((self.num_devices, batch_size_per_device, inputs.shape[1], inputs.shape[2], inputs.shape[3])) + inputs = inputs.reshape( + ( + self.num_devices, + batch_size_per_device, + inputs.shape[1], + inputs.shape[2], + inputs.shape[3], + ) + ) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, rngs={'dropout': jax.random.PRNGKey(2)})[0] - return -jnp.mean(jax.vmap(jax.nn.log_softmax)(logits)[jnp.arange(targets.size), targets]) - - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, inputs, rngs={"dropout": jax.random.PRNGKey(2)} + )[0] + return -jnp.mean( + jax.vmap(jax.nn.log_softmax)(logits)[jnp.arange(targets.size), targets] + ) + + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: batch_size = inputs.shape[0] batch_size_per_device = batch_size // self.num_devices - inputs = inputs.reshape((self.num_devices, batch_size_per_device, inputs.shape[1], inputs.shape[2], inputs.shape[3])) + inputs = inputs.reshape( + ( + self.num_devices, + batch_size_per_device, + inputs.shape[1], + inputs.shape[2], + inputs.shape[3], + ) + ) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/models/whisper.py b/nanodl/__src/models/whisper.py index 20706b5..437cd51 100644 --- a/nanodl/__src/models/whisper.py +++ b/nanodl/__src/models/whisper.py @@ -1,11 +1,12 @@ -import jax -import flax import time -import optax -import jax.numpy as jnp +from typing import Any, Iterable, Optional, Tuple + +import flax import flax.linen as nn +import jax +import jax.numpy as jnp +import optax from flax.training import train_state -from typing import List, Tuple, Any, Optional, Dict, Iterable class SpeechEmbedding(nn.Module): @@ -18,23 +19,27 @@ class SpeechEmbedding(nn.Module): __call__(x): Processes the input audio tensor through the convolutional layers and adds sinusoidal embeddings. sinusoidal_embedding(x, max_position): Generates sinusoidal embeddings based on the sequence length and hidden dimension of the input. """ + @nn.compact - def __call__(self, - x: jnp.ndarray) -> jnp.ndarray: - x = nn.gelu(nn.Conv(features=x.shape[-1], kernel_size=(3,), padding='SAME')(x)) - x = nn.gelu(nn.Conv(features=x.shape[-1], kernel_size=(3,), strides=(2,), padding='SAME')(x)) + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = nn.gelu(nn.Conv(features=x.shape[-1], kernel_size=(3,), padding="SAME")(x)) + x = nn.gelu( + nn.Conv( + features=x.shape[-1], kernel_size=(3,), strides=(2,), padding="SAME" + )(x) + ) return jnp.concatenate((x, self.sinusoidal_embedding(x)), axis=-2) - - def sinusoidal_embedding(self, - x: jnp.ndarray, - max_position: int = 10000) -> jnp.ndarray: + + def sinusoidal_embedding( + self, x: jnp.ndarray, max_position: int = 10000 + ) -> jnp.ndarray: batch_size, seq_len, hidden_dim = x.shape positions = jnp.arange(seq_len)[:, None] angles = (jnp.arange(hidden_dim) / hidden_dim)[None, :] encodings = jnp.sin(positions / jnp.power(max_position, angles))[None, :, :] encodings = jnp.repeat(encodings, batch_size, axis=0) return x + encodings - + class PositionalEncoding(nn.Module): """ @@ -50,19 +55,27 @@ class PositionalEncoding(nn.Module): setup(): Initializes the positional encoding matrix based on the provided attributes. __call__(x: jnp.ndarray): Adds positional encodings to the input embeddings. """ + num_embeddings: int features: int def setup(self): positional_encoding = jnp.zeros((self.features, self.num_embeddings)) position = jnp.arange(0, self.features, dtype=jnp.float32)[:, None] - div_term = jnp.exp(jnp.arange(0, self.num_embeddings, 2) * (-jnp.log(10000.0) / self.num_embeddings)) - positional_encoding = positional_encoding.at[:, 0::2].set(jnp.sin(position * div_term)) - positional_encoding = positional_encoding.at[:, 1::2].set(jnp.cos(position * div_term)) + div_term = jnp.exp( + jnp.arange(0, self.num_embeddings, 2) + * (-jnp.log(10000.0) / self.num_embeddings) + ) + positional_encoding = positional_encoding.at[:, 0::2].set( + jnp.sin(position * div_term) + ) + positional_encoding = positional_encoding.at[:, 1::2].set( + jnp.cos(position * div_term) + ) self.positional_encoding = positional_encoding.T def __call__(self, x): - x = x + self.positional_encoding[:x.shape[1]] + x = x + self.positional_encoding[: x.shape[1]] return x @@ -74,18 +87,25 @@ class TokenAndPositionEmbedding(nn.Module): vocab_size (int): Vocabulary size. embed_dim (int): Embedding dimension. """ - max_len : int - vocab_size : int - embed_dim : int - learned_position : bool - + + max_len: int + vocab_size: int + embed_dim: int + learned_position: bool + def setup(self): - self.token_embeddings = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim) + self.token_embeddings = nn.Embed( + num_embeddings=self.vocab_size, features=self.embed_dim + ) if self.learned_position: - self.position_embeddings = nn.Embed(num_embeddings=self.max_len, features=self.embed_dim) + self.position_embeddings = nn.Embed( + num_embeddings=self.max_len, features=self.embed_dim + ) else: - self.position_embeddings = PositionalEncoding(num_embeddings=self.max_len, features=self.embed_dim) + self.position_embeddings = PositionalEncoding( + num_embeddings=self.max_len, features=self.embed_dim + ) def __call__(self, x): x = self.token_embeddings(x) @@ -93,7 +113,7 @@ def __call__(self, x): return x + self.position_embeddings(jnp.arange(x.shape[1])) else: return x + self.position_embeddings(x) - + class MultiHeadAttention(nn.Module): """ @@ -110,40 +130,46 @@ class MultiHeadAttention(nn.Module): __call__(inputs: jnp.ndarray, mask: jnp.ndarray = None): Processes the input tensor through the multi-head self-attention mechanism. attention_function(query, key, value, mask=None): Computes the attention scores and applies them to the value vectors. """ - hidden_dim : int # Output dimension - num_heads : int # Number of parallel heads + + hidden_dim: int # Output dimension + num_heads: int # Number of parallel heads def setup(self): # Because the Query is determined from a context, project separately - self.query_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.key_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.value_projection = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros - ) - self.output = nn.Dense(self.hidden_dim, - kernel_init=nn.initializers.xavier_uniform(), - bias_init=nn.initializers.zeros) - - - def __call__(self, - inputs: jnp.ndarray, - context: jnp.ndarray, - mask: jnp.ndarray = None) -> tuple: + self.query_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.key_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.value_projection = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + self.output = nn.Dense( + self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + bias_init=nn.initializers.zeros, + ) + + def __call__( + self, inputs: jnp.ndarray, context: jnp.ndarray, mask: jnp.ndarray = None + ) -> tuple: query = self.query_projection(inputs) key = self.key_projection(context) value = self.value_projection(context) - context_vectors, attention = self.attention_function(query,key, value, mask=mask) + context_vectors, attention = self.attention_function( + query, key, value, mask=mask + ) outputs = self.output(context_vectors) return outputs, attention - + def attention_function(self, query, key, value, mask=None): input_length = query.shape[1] context_length = key.shape[1] @@ -151,19 +177,29 @@ def attention_function(self, query, key, value, mask=None): dim_key = key.shape[-1] # Split queries, keys, and values into heads - query_heads = jnp.reshape(query, (query.shape[0], self.num_heads, input_length, head_dim)) - key_heads = jnp.reshape(key, (key.shape[0], self.num_heads, context_length, head_dim)) - value_heads = jnp.reshape(value, (value.shape[0], self.num_heads, context_length, head_dim)) + query_heads = jnp.reshape( + query, (query.shape[0], self.num_heads, input_length, head_dim) + ) + key_heads = jnp.reshape( + key, (key.shape[0], self.num_heads, context_length, head_dim) + ) + value_heads = jnp.reshape( + value, (value.shape[0], self.num_heads, context_length, head_dim) + ) - attention_scores = jnp.matmul(query_heads, key_heads.transpose(0, 1, 3, 2)) / jnp.sqrt(dim_key) + attention_scores = jnp.matmul( + query_heads, key_heads.transpose(0, 1, 3, 2) + ) / jnp.sqrt(dim_key) if mask is not None: attention_scores = attention_scores * mask attention_weights = jax.nn.softmax(attention_scores, axis=-1) attended_values = jnp.matmul(attention_weights, value_heads) - attended_values = jnp.reshape(attended_values, (query.shape[0], input_length, query.shape[-1])) + attended_values = jnp.reshape( + attended_values, (query.shape[0], input_length, query.shape[-1]) + ) return attended_values, attention_weights - + class PositionWiseFFN(nn.Module): """ @@ -179,17 +215,22 @@ class PositionWiseFFN(nn.Module): setup(): Initializes the two linear layers. __call__(X: jnp.ndarray): Applies the position-wise feed-forward network to the input tensor. """ + num_hiddens: int num_outputs: int def setup(self): - self.dense1 = nn.Dense(self.num_hiddens, kernel_init=nn.initializers.xavier_uniform()) + self.dense1 = nn.Dense( + self.num_hiddens, kernel_init=nn.initializers.xavier_uniform() + ) self.activation = nn.gelu - self.dense2 = nn.Dense(self.num_outputs, kernel_init=nn.initializers.xavier_uniform()) + self.dense2 = nn.Dense( + self.num_outputs, kernel_init=nn.initializers.xavier_uniform() + ) def __call__(self, X: jnp.ndarray) -> jnp.ndarray: return self.dense2(self.activation(self.dense1(X))) - + class AddNorm(nn.Module): """ @@ -198,16 +239,15 @@ class AddNorm(nn.Module): Args: dropout (float): Dropout rate for the residual connection. """ + dropout: int @nn.compact - def __call__(self, - X: jnp.ndarray, - Y: jnp.ndarray, - training=False) -> jnp.ndarray: + def __call__(self, X: jnp.ndarray, Y: jnp.ndarray, training=False) -> jnp.ndarray: return nn.LayerNorm()( - nn.Dropout(self.dropout)(Y, deterministic=not training) + X) - + nn.Dropout(self.dropout)(Y, deterministic=not training) + X + ) + class WhisperSpeechEncoderBlock(nn.Module): """ @@ -225,30 +265,31 @@ class WhisperSpeechEncoderBlock(nn.Module): setup(): Initializes the components of the WhisperSpeechEncoderBlock. __call__(x, mask, training): Processes the input tensor through the encoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention = MultiHeadAttention(hidden_dim=self.hidden_dim, - num_heads=self.num_heads) + self.attention = MultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.linear = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: + attended_x, attention = self.attention(x, x, mask=mask) x = self.add_norm1(x, attended_x, training) linear_output = self.linear(x) x = self.add_norm2(x, linear_output, training) return x, attention - - + + class WhisperSpeechEncoder(nn.Module): """ Implements the encoder component of the Whisper Speech model. @@ -266,6 +307,7 @@ class WhisperSpeechEncoder(nn.Module): setup(): Initializes the components of the WhisperSpeechEncoder. __call__(x, mask, training): Processes the input audio tensor through the encoder, returning encoded features and attention maps. """ + num_layers: int hidden_dim: int num_heads: int @@ -274,25 +316,25 @@ class WhisperSpeechEncoder(nn.Module): def setup(self): self.embedding = SpeechEmbedding() - - self.layers = [WhisperSpeechEncoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) - for _ in range(self.num_layers)] - - def __call__(self, - x: jnp.ndarray, - mask: jnp.ndarray = None, - training: bool = False) -> tuple: - + + self.layers = [ + WhisperSpeechEncoderBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + + def __call__( + self, x: jnp.ndarray, mask: jnp.ndarray = None, training: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) for layer in self.layers: x, attention = layer(x, mask=mask, training=training) attention_maps.append(attention) return x, jnp.array(attention_maps) - + class WhisperTextDecoderBlock(nn.Module): """ @@ -310,39 +352,44 @@ class WhisperTextDecoderBlock(nn.Module): setup(): Initializes the components of the Transformer decoder block. __call__(x, context, training): Processes the input tensor through the decoder block. """ + hidden_dim: int num_heads: int feedforward_dim: int dropout: float def setup(self): - self.attention1 = MultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) - self.attention2 = MultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads) + self.attention1 = MultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) + self.attention2 = MultiHeadAttention( + hidden_dim=self.hidden_dim, num_heads=self.num_heads + ) self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim) self.add_norm1 = AddNorm(self.dropout) self.add_norm2 = AddNorm(self.dropout) self.add_norm3 = AddNorm(self.dropout) - def causal_mask(self, - batch_size: int, - destination_dim: int, - source_dim: int) -> jnp.ndarray: - + def causal_mask( + self, batch_size: int, destination_dim: int, source_dim: int + ) -> jnp.ndarray: + # Create index tensors for the source and destination dimensions idx_source = jnp.arange(destination_dim)[:, None] idx_destination = jnp.arange(source_dim) mask = idx_source >= idx_destination - source_dim + destination_dim - mask = mask.astype(jnp.int32) + mask = mask.astype(jnp.int32) # Expand dimensions to match the required output shape mask = mask[None, None, :, :] - return jnp.broadcast_to(mask, (batch_size, self.num_heads, destination_dim, source_dim)) + return jnp.broadcast_to( + mask, (batch_size, self.num_heads, destination_dim, source_dim) + ) + + def __call__( + self, x: jnp.ndarray, context: jnp.ndarray, training: bool = False + ) -> tuple: - def __call__(self, - x: jnp.ndarray, - context: jnp.ndarray, - training: bool = False) -> tuple: - mask = self.causal_mask(x.shape[0], x.shape[1], context.shape[1]) attended_x, attention1 = self.attention1(x, x) @@ -353,9 +400,9 @@ def __call__(self, linear_output = self.feed_forward(x) x = self.add_norm3(x, linear_output, training) - + return x, jnp.array(attention1), jnp.array(attention2) - + class WhisperTextDecoder(nn.Module): """ @@ -384,31 +431,29 @@ class WhisperTextDecoder(nn.Module): num_heads: int feedforward_dim: int dropout: float - max_len : int - vocab_size : int - embed_dim : int - learned_position : bool = True - + max_len: int + vocab_size: int + embed_dim: int + learned_position: bool = True def setup(self): - self.embedding = TokenAndPositionEmbedding(self.max_len, - self.vocab_size, - self.embed_dim, - self.learned_position) - - self.layers = [WhisperTextDecoderBlock(self.hidden_dim, - self.num_heads, - self.feedforward_dim, - self.dropout) for _ in range(self.num_layers)] - + self.embedding = TokenAndPositionEmbedding( + self.max_len, self.vocab_size, self.embed_dim, self.learned_position + ) + + self.layers = [ + WhisperTextDecoderBlock( + self.hidden_dim, self.num_heads, self.feedforward_dim, self.dropout + ) + for _ in range(self.num_layers) + ] + self.outputs = nn.Dense(self.vocab_size) - - def __call__(self, - x: jnp.ndarray, - context: jnp.ndarray, - training: bool = False) -> tuple: - + def __call__( + self, x: jnp.ndarray, context: jnp.ndarray, training: bool = False + ) -> tuple: + attention_maps = [] x = self.embedding(x) cross_attention_maps = [] @@ -416,9 +461,13 @@ def __call__(self, x, attention, cross_attention = layer(x, context, training=training) attention_maps.append(attention) cross_attention_maps.append(cross_attention) - return self.outputs(x), jnp.array(attention_maps), jnp.array(cross_attention_maps) - - + return ( + self.outputs(x), + jnp.array(attention_maps), + jnp.array(cross_attention_maps), + ) + + class Whisper(nn.Module): """ Implements the Whisper model for speech-to-text tasks, such as speech recognition and transcription. @@ -441,7 +490,7 @@ class Whisper(nn.Module): setup(): Initializes the Whisper model including both the encoder and decoder components. __call__(x, y, training): Processes the input audio tensor through the Whisper model, generating textual predictions. generate(x, temperature, deterministic): Generates textual output from input audio sequences. - + Whisper uses an encoder-decoder Transformer (Vaswani et al., 2017) as this, All audio is re-sampled to 16,000 Hz, and an 80-channel logmagnitude Mel spectrogram representation is computed on 25-millisecond windows with a stride of 10 milliseconds. For feature normalization, we globally scale the input to be between -1 and 1 with approximately zero mean across the pre-training dataset. The encoder processes this input representation with a small stem consisting of two convolution layers with a filter width of 3 and the GELU activation @@ -460,19 +509,19 @@ class Whisper(nn.Module): # Dummy data parameters batch_size = 8 max_length = 50 - embed_dim = 256 - vocab_size = 1000 + embed_dim = 256 + vocab_size = 1000 # Generate data: replace with actual tokenised/quantised data dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32) dummy_inputs = jnp.ones((101, max_length, embed_dim)) - dataset = ArrayDataset(dummy_inputs, + dataset = ArrayDataset(dummy_inputs, dummy_targets) - dataloader = DataLoader(dataset, - batch_size=batch_size, - shuffle=True, + dataloader = DataLoader(dataset, + batch_size=batch_size, + shuffle=True, drop_last=False) # How to loop through dataloader @@ -503,9 +552,9 @@ class Whisper(nn.Module): print(outputs.shape) # Training on your data - trainer = WhisperDataParallelTrainer(model, - dummy_inputs.shape, - dummy_targets.shape, + trainer = WhisperDataParallelTrainer(model, + dummy_inputs.shape, + dummy_targets.shape, 'params.pkl') trainer.train(dataloader, 2, dataloader) @@ -513,14 +562,15 @@ class Whisper(nn.Module): params = trainer.load_params('params.pkl') # for more than one sample, use model.generate_batch - transcripts = model.apply({'params': params}, - dummy_inputs[:1], - rngs=rngs, + transcripts = model.apply({'params': params}, + dummy_inputs[:1], + rngs=rngs, method=model.generate) print(transcripts) ``` """ + num_layers: int num_heads: int hidden_dim: int @@ -538,9 +588,9 @@ def setup(self): num_heads=self.num_heads, num_layers=self.num_layers, feedforward_dim=self.feedforward_dim, - dropout=self.dropout + dropout=self.dropout, ) - + self.decoder = WhisperTextDecoder( hidden_dim=self.hidden_dim, num_heads=self.num_heads, @@ -551,20 +601,18 @@ def setup(self): vocab_size=self.vocab_size, embed_dim=self.embed_dim, ) - - def __call__(self, - x: jnp.ndarray, - y: jnp.ndarray, - training: bool = False) -> jnp.ndarray: - + + def __call__( + self, x: jnp.ndarray, y: jnp.ndarray, training: bool = False + ) -> jnp.ndarray: + z = self.encoder(x=x, training=training)[0] return self.decoder(x=y, context=z, training=training)[0] - - def generate(self, - x: jnp.ndarray, - temperature: float = 1.0, - deterministic: bool = False) -> Tuple[jnp.ndarray]: - + + def generate( + self, x: jnp.ndarray, temperature: float = 1.0, deterministic: bool = False + ) -> Tuple[jnp.ndarray]: + # Encode the input sequence encoded_sequence = self.encoder(x=x, training=False)[0] @@ -573,9 +621,9 @@ def generate(self, # Autoregressive decoding loop for _ in range(self.max_length): - decoder_output = self.decoder(x=decoder_input, - context=encoded_sequence, - training=False)[0] + decoder_output = self.decoder( + x=decoder_input, context=encoded_sequence, training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -583,23 +631,27 @@ def generate(self, if deterministic: next_token = jnp.argmax(next_token_probabilities, axis=-1) else: - next_token = jax.random.categorical(jax.random.PRNGKey(int(time.time())), next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + jax.random.PRNGKey(int(time.time())), + next_token_probabilities, + axis=-1, + ) next_token = next_token[0] output_sequence.append(next_token.item()) - decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, jnp.array([[next_token]])], axis=1 + ) if next_token.item() == self.end_token: break return jnp.array(output_sequence) - - def generate_batch(self, - x: jnp.ndarray, - temperature: float = 1.0, - deterministic: bool = False) -> jnp.ndarray: - + def generate_batch( + self, x: jnp.ndarray, temperature: float = 1.0, deterministic: bool = False + ) -> jnp.ndarray: + # Encode the input sequence encoded_sequence = self.encoder(x=x, training=False)[0] @@ -608,9 +660,9 @@ def generate_batch(self, output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32) for i in range(self.max_length): - decoder_output = self.decoder(decoder_input, - context=encoded_sequence, - training=False)[0] + decoder_output = self.decoder( + decoder_input, context=encoded_sequence, training=False + )[0] last_token_logits = decoder_output[:, -1, :] scaled_logits = last_token_logits / temperature next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1) @@ -619,10 +671,14 @@ def generate_batch(self, next_token = jnp.argmax(next_token_probabilities, axis=-1) else: key = jax.random.PRNGKey(int(time.time())) - next_token = jax.random.categorical(key, next_token_probabilities, axis=-1) + next_token = jax.random.categorical( + key, next_token_probabilities, axis=-1 + ) output_sequences = output_sequences.at[:, i].set(next_token) - decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1) + decoder_input = jnp.concatenate( + [decoder_input, next_token[:, None]], axis=1 + ) if jnp.all(next_token == self.end_token): break @@ -630,11 +686,10 @@ def generate_batch(self, return output_sequences - class WhisperDataParallelTrainer: """ Trainer class using data parallelism with JAX. - This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). + This trainer leverages JAX's `pmap` for parallel training across multiple devices (GPUs/TPUs). It handles the model training loop, including gradient computation, parameter updates, and evaluation. Attributes: @@ -654,13 +709,16 @@ class WhisperDataParallelTrainer: save_params(): Saves the model parameters to a file. load_params(filename): Loads model parameters from a file. """ - def __init__(self, - model: Any, - input_shape: Tuple[int, ...], - target_shape: Tuple[int, ...], - weights_filename: str, - learning_rate: float = 1e-5, - params_path: Optional[str] = None) -> None: + + def __init__( + self, + model: Any, + input_shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + weights_filename: str, + learning_rate: float = 1e-5, + params_path: Optional[str] = None, + ) -> None: self.model = model self.params = None self.params_path = params_path @@ -668,111 +726,144 @@ def __init__(self, self.best_val_loss = float("inf") self.weights_filename = weights_filename self.num_devices = jax.local_device_count() - self.train_step = jax.pmap(WhisperDataParallelTrainer.train_step, axis_name='devices') - self.evaluation_step = jax.pmap(WhisperDataParallelTrainer.evaluation_step, axis_name='devices') + self.train_step = jax.pmap( + WhisperDataParallelTrainer.train_step, axis_name="devices" + ) + self.evaluation_step = jax.pmap( + WhisperDataParallelTrainer.evaluation_step, axis_name="devices" + ) self.state = self.create_train_state(learning_rate, input_shape, target_shape) - print(f'Number of accelerators: {self.num_devices}') - - - def create_train_state(self, - learning_rate: float, - input_shape: Tuple[int, ...], - target_shape: Tuple[int, ...]) -> Any: - - rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} - params = self.model.init(rngs, - jnp.ones(input_shape, dtype=jnp.int32), - jnp.ones(target_shape, dtype=jnp.int32))['params'] + print(f"Number of accelerators: {self.num_devices}") + + def create_train_state( + self, + learning_rate: float, + input_shape: Tuple[int, ...], + target_shape: Tuple[int, ...], + ) -> Any: + + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = self.model.init( + rngs, + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(target_shape, dtype=jnp.int32), + )["params"] if self.params_path is not None: params = self.load_params(self.params_path) - self.num_parameters = sum(param.size for param in jax.tree_util.tree_leaves(params)) - print(f'Number of parameters: {self.num_parameters}') - state = train_state.TrainState.create(apply_fn=self.model.apply, - params=params, - tx=optax.adam(learning_rate)) + self.num_parameters = sum( + param.size for param in jax.tree_util.tree_leaves(params) + ) + print(f"Number of parameters: {self.num_parameters}") + state = train_state.TrainState.create( + apply_fn=self.model.apply, params=params, tx=optax.adam(learning_rate) + ) return jax.device_put_replicated(state, jax.local_devices()) - + @staticmethod - def train_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - + def train_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + def loss_fn(params): - logits = state.apply_fn({'params': params}, - inputs, - targets, - training=True, - rngs={'dropout': jax.random.PRNGKey(int(time.time()))}) - return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - + logits = state.apply_fn( + {"params": params}, + inputs, + targets, + training=True, + rngs={"dropout": jax.random.PRNGKey(int(time.time()))}, + ) + return optax.softmax_cross_entropy_with_integer_labels( + logits, targets + ).mean() + loss, grads = jax.value_and_grad(loss_fn)(state.params) state = state.apply_gradients(grads=grads) return state, loss - def train(self, - train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], - num_epochs: int, - val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None) -> None: - + def train( + self, + train_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]], + num_epochs: int, + val_loader: Optional[Iterable[Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> None: + for epoch in range(num_epochs): total_loss = 0.0 count = 0 for inputs, targets in train_loader: batch_size = inputs.shape[0] batch_size_per_device = batch_size // self.num_devices - inputs = inputs.reshape((self.num_devices, batch_size_per_device, inputs.shape[1], inputs.shape[2])) + inputs = inputs.reshape( + ( + self.num_devices, + batch_size_per_device, + inputs.shape[1], + inputs.shape[2], + ) + ) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) - self.state, loss = self.train_step(state=self.state, - inputs=inputs, - targets=targets) + self.state, loss = self.train_step( + state=self.state, inputs=inputs, targets=targets + ) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count - print(f'Epoch {epoch+1}, Train Loss: {mean_loss}') + print(f"Epoch {epoch+1}, Train Loss: {mean_loss}") if val_loader is not None: val_loss = self.evaluate(val_loader) - print(f'Epoch {epoch+1}, Val Loss: {val_loss}') + print(f"Epoch {epoch+1}, Val Loss: {val_loss}") if val_loss < self.best_val_loss: self.best_val_loss = val_loss print("New best validation score achieved, saving model...") self.save_params() - return - + return + @staticmethod - def evaluation_step(state: Any, - inputs: jnp.ndarray, - targets: jnp.ndarray) -> Tuple[Any, jnp.ndarray]: - - logits = state.apply_fn({'params': state.params}, inputs, targets, rngs={'dropout': jax.random.PRNGKey(2)}) + def evaluation_step( + state: Any, inputs: jnp.ndarray, targets: jnp.ndarray + ) -> Tuple[Any, jnp.ndarray]: + + logits = state.apply_fn( + {"params": state.params}, + inputs, + targets, + rngs={"dropout": jax.random.PRNGKey(2)}, + ) return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - def evaluate(self, - test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: - + def evaluate(self, test_loader: Iterable[Tuple[jnp.ndarray, jnp.ndarray]]) -> None: + total_loss = 0.0 count = 0 for inputs, targets in test_loader: batch_size = inputs.shape[0] batch_size_per_device = batch_size // self.num_devices - inputs = inputs.reshape((self.num_devices, batch_size_per_device, inputs.shape[1], inputs.shape[2])) + inputs = inputs.reshape( + ( + self.num_devices, + batch_size_per_device, + inputs.shape[1], + inputs.shape[2], + ) + ) targets = targets.reshape((self.num_devices, batch_size_per_device, -1)) loss = self.evaluation_step(self.state, inputs, targets) total_loss += jnp.mean(loss) count += 1 - + mean_loss = total_loss / count return mean_loss def save_params(self) -> None: self.params = flax.jax_utils.unreplicate(self.state.params) - with open(self.weights_filename, 'wb') as f: + with open(self.weights_filename, "wb") as f: f.write(flax.serialization.to_bytes(self.params)) def load_params(self, filename: str): - with open(filename, 'rb') as f: + with open(filename, "rb") as f: self.params = flax.serialization.from_bytes(self.params, f.read()) - return self.params \ No newline at end of file + return self.params diff --git a/nanodl/__src/sklearn_gpu/__init__.py b/nanodl/__src/sklearn_gpu/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/nanodl/__src/sklearn_gpu/dsp.py b/nanodl/__src/sklearn_gpu/dsp.py deleted file mode 100644 index 592e81a..0000000 --- a/nanodl/__src/sklearn_gpu/dsp.py +++ /dev/null @@ -1,126 +0,0 @@ -import jax.numpy as jnp -from jax import random -from typing import Tuple - -def fastica(X: jnp.ndarray, - n_components: jnp.ndarray, - max_iter: int = 1000, - tol: float = 1e-4) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: - """ - Perform Independent Component Analysis (ICA) on the input data using the FastICA algorithm. - - Parameters: - X : jax.numpy.ndarray - The input data matrix, where each row represents a data point, and each column represents a different signal. - The input data should be a 2D jax.numpy array with shape (n_samples, n_features). - n_components : int - The number of independent components to extract. This should be less than or equal to the number of features in the input data. - max_iter : int, optional - The maximum number of iterations for the optimization process. The default value is 1000 iterations. - tol : float, optional - The tolerance for convergence. The optimization process stops when the maximum absolute change in the diagonal elements of the - unmixing matrix from one iteration to the next is less than this tolerance. The default value is 1e-4. - - Returns: - S : jax.numpy.ndarray - The separated independent components. This is a 2D jax.numpy array with shape (n_components, n_samples), where each row represents - a different independent component, and each column represents a data point. - W : jax.numpy.ndarray - The unmixing matrix. This is a 2D jax.numpy array with shape (n_components, n_features), representing the estimated inverse of the - mixing matrix. It is used to transform the input data back into the independent components. - whitening_matrix : jax.numpy.ndarray - The whitening matrix used to whiten the input data. This is a 2D jax.numpy array with shape (n_features, n_features), used to decorrelate - the input data and make its covariance matrix the identity matrix. - - Description: - The FastICA algorithm aims to separate the mixed input signals into statistically independent components. The function first whitens the input - data to decorrelate it and normalize its variance. Then, it initializes a random unmixing matrix and uses an optimization process to find - the optimal unmixing matrix that maximizes the independence of the source signals. - - The optimization process involves iteratively updating the unmixing matrix based on the non-linear function (`tanh` in this case) applied - to the transformed data (`WX`). The process stops when the unmixing matrix converges according to the specified tolerance (`tol`) or when the - maximum number of iterations (`max_iter`) is reached. - - Once the optimal unmixing matrix is found, the function applies it to the whitened data to obtain the separated independent components. - - Example usage: - # Set random seed - jax.random.PRNGKey(42) - - # Generate synthetic source signals - n_samples = 2000 - time = jnp.linspace(0, 8, n_samples) - s1 = jnp.sin(2 * time) - s2 = jnp.sign(jnp.sin(3 * time)) - - # Combine the sources with a mixing matrix - A = jnp.array([[1, 1], [0.5, 2]]) - X = jnp.dot(A, jnp.array([s1, s2])) - - # Perform ICA - n_components = 2 - S, W, whitening_matrix = fastica(X.T, n_components) - - # Plot the results - plt.figure(figsize=(12, 8)) - - plt.subplot(3, 1, 1) - plt.title('Original Source Signals') - plt.plot(time, s1, label='Source 1 (Sine Wave)') - plt.plot(time, s2, label='Source 2 (Square Wave)') - plt.legend() - - plt.subplot(3, 1, 2) - plt.title('Mixed Signals') - plt.plot(time, X[0], label='Mixed Signal 1') - plt.plot(time, X[1], label='Mixed Signal 2') - plt.legend() - - plt.subplot(3, 1, 3) - plt.title('Separated Signals (Using ICA)') - plt.plot(time, S[0], label='Separated Signal 1') - plt.plot(time, S[1], label='Separated Signal 2') - plt.legend() - - plt.tight_layout() - plt.show() - """ - # Calculate the covariance matrix and perform eigenvalue decomposition - cov_matrix = jnp.cov(X, rowvar=False) - eigenvalues, eigenvectors = jnp.linalg.eigh(cov_matrix) - - # Sort the eigenvalues and eigenvectors - idx = jnp.argsort(eigenvalues)[::-1] - eigenvalues = eigenvalues[idx] - eigenvectors = eigenvectors[:, idx] - - # Create the whitening matrix - D = jnp.diag(1.0 / jnp.sqrt(eigenvalues)) - whitening_matrix = jnp.dot(eigenvectors, D) - X_whitened = jnp.dot(X, whitening_matrix) - - # Initialize unmixing matrix with random values - rng = random.PRNGKey(0) # Set a seed for reproducibility - W = random.normal(rng, (n_components, n_components)) - - # Perform FastICA algorithm - for _ in range(max_iter): - WX = jnp.dot(X_whitened, W.T) - g = jnp.tanh(WX) - g_prime = 1 - g ** 2 - W_new = (jnp.dot(X_whitened.T, g) / X.shape[0]) - jnp.diag(g_prime.mean(axis=0)).dot(W) - - # Orthogonalize the unmixing matrix - W_new, _ = jnp.linalg.qr(W_new) - - # Check for convergence - if jnp.max(jnp.abs(jnp.abs(jnp.diag(jnp.dot(W_new, W.T))) - 1)) < tol: - W = W_new - break - - W = W_new - - # Calculate the separated independent components - S = jnp.dot(W, X_whitened.T) - - return S, W, whitening_matrix \ No newline at end of file diff --git a/nanodl/__src/utils/data.py b/nanodl/__src/utils/data.py index ccdca2c..c004d9c 100644 --- a/nanodl/__src/utils/data.py +++ b/nanodl/__src/utils/data.py @@ -1,14 +1,16 @@ -import jax import collections -import jax.numpy as jnp -from typing import Iterator from dataclasses import dataclass +from typing import Iterator + +import jax +import jax.numpy as jnp # This script modifies the JAX DataLoader from the following repository: # JAX DataLoader by Birkhoff G. (https://birkhoffg.github.io/jax-dataloader/) # Accessed on [Date you accessed the repository, e.g., February 4, 2024] # This DataLoader implementation is used for efficient data loading in JAX-based machine learning projects. + class Dataset: """ A PyTorch-like Dataset class for JAX. @@ -59,8 +61,9 @@ class ArrayDataset(Dataset): """ def __init__(self, *arrays: jnp.array): - assert all(arrays[0].shape[0] == arr.shape[0] for arr in arrays), \ - "All arrays must have the same first dimension." + assert all( + arrays[0].shape[0] == arr.shape[0] for arr in arrays + ), "All arrays must have the same first dimension." self.arrays = arrays def __len__(self): @@ -94,7 +97,14 @@ class DataLoader: ``` """ - def __init__(self, dataset: Dataset, batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, **kwargs): + def __init__( + self, + dataset: Dataset, + batch_size: int = 1, + shuffle: bool = False, + drop_last: bool = False, + **kwargs + ): self.dataset = dataset self.batch_size = batch_size self.shuffle = shuffle @@ -102,14 +112,14 @@ def __init__(self, dataset: Dataset, batch_size: int = 1, shuffle: bool = False, self.keys = PRNGSequence(seed=Config.default().global_seed) self.data_len = len(dataset) # Length of the dataset - self.indices = jnp.arange(self.data_len) # available indices in the dataset + self.indices = jnp.arange(self.data_len) # available indices in the dataset self.pose = 0 # record the current position in the dataset self._shuffle() def _shuffle(self): if self.shuffle: self.indices = jax.random.permutation(next(self.keys), self.indices) - + def _stop_iteration(self): self.pose = 0 self._shuffle() @@ -119,17 +129,19 @@ def __len__(self): if self.drop_last: batches = len(self.dataset) // self.batch_size # get the floor of division else: - batches = -(len(self.dataset) // -self.batch_size) # get the ceil of division + batches = -( + len(self.dataset) // -self.batch_size + ) # get the ceil of division return batches def __next__(self): if self.pose + self.batch_size <= self.data_len: - batch_indices = self.indices[self.pose: self.pose + self.batch_size] + batch_indices = self.indices[self.pose : self.pose + self.batch_size] batch_data = self.dataset[batch_indices] self.pose += self.batch_size return batch_data elif self.pose < self.data_len and not self.drop_last: - batch_indices = self.indices[self.pose:] + batch_indices = self.indices[self.pose :] batch_data = self.dataset[batch_indices] self.pose += self.batch_size return batch_data @@ -138,7 +150,7 @@ def __next__(self): def __iter__(self): return self - + @dataclass class Config: @@ -149,6 +161,7 @@ class Config: def default(cls): return cls(rng_reserve_size=1, global_seed=42) + class PRNGSequence(Iterator[jax.random.PRNGKey]): """ An Iterator of Jax PRNGKey (minimal version of `haiku.PRNGSequence`). @@ -175,8 +188,8 @@ def reserve(self, num): new_keys = tuple(jax.random.split(self._key, num + 1)) self._key = new_keys[0] self._subkeys.extend(new_keys[1:]) - + def __next__(self): if not self._subkeys: self.reserve(Config.default().rng_reserve_size) - return self._subkeys.popleft() \ No newline at end of file + return self._subkeys.popleft() diff --git a/nanodl/__src/utils/ml.py b/nanodl/__src/utils/ml.py index 88e454e..d0c41e3 100644 --- a/nanodl/__src/utils/ml.py +++ b/nanodl/__src/utils/ml.py @@ -1,11 +1,13 @@ +from typing import Any, List + import jax import jax.numpy as jnp -from typing import List, Any @jax.jit -def batch_cosine_similarities(source: jnp.ndarray, - candidates: jnp.ndarray) -> jnp.ndarray: +def batch_cosine_similarities( + source: jnp.ndarray, candidates: jnp.ndarray +) -> jnp.ndarray: """ Calculate cosine similarities between a source vector and a batch of candidate vectors. @@ -25,13 +27,13 @@ def batch_cosine_similarities(source: jnp.ndarray, ``` """ dot_products = jnp.einsum("ij,j->i", candidates, source) - norm_source = jnp.sqrt(jnp.einsum('i,i->', source, source)) - norm_candidates = jnp.sqrt(jnp.einsum('ij,ij->i', candidates, candidates)) + norm_source = jnp.sqrt(jnp.einsum("i,i->", source, source)) + norm_candidates = jnp.sqrt(jnp.einsum("ij,ij->i", candidates, candidates)) return dot_products / (norm_source * norm_candidates) + @jax.jit -def batch_pearsonr(x: jnp.ndarray, - y: jnp.ndarray) -> jnp.ndarray: +def batch_pearsonr(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """ Calculate batch Pearson correlation coefficient between two sets of vectors. @@ -55,20 +57,21 @@ def batch_pearsonr(x: jnp.ndarray, x = x - jnp.expand_dims(x.mean(axis=1), axis=-1) y = y - jnp.expand_dims(y.mean(axis=1), axis=-1) numerator = jnp.sum(x * y, axis=-1) - sum_of_squares_x = jnp.einsum('ij,ij -> i', x, x) - sum_of_squares_y = jnp.einsum('ij,ij -> i', y, y) + sum_of_squares_x = jnp.einsum("ij,ij -> i", x, x) + sum_of_squares_y = jnp.einsum("ij,ij -> i", y, y) denominator = jnp.sqrt(sum_of_squares_x * sum_of_squares_y) return numerator / denominator + @jax.jit def classification_scores(labels: jnp.ndarray, preds: jnp.ndarray) -> jnp.ndarray: """ Calculate classification evaluation scores using JAX. - + Args: labels (jnp.ndarray): Array of true labels. preds (jnp.ndarray): Array of predicted labels. - + Returns: jnp.ndarray: Array containing accuracy, precision, recall, and F1-score. @@ -90,11 +93,12 @@ def classification_scores(labels: jnp.ndarray, preds: jnp.ndarray) -> jnp.ndarra f1 = 2 * (precision * recall) / (precision + recall) return jnp.array([accuracy, precision, recall, f1]) + @jax.jit def mean_reciprocal_rank(predictions: jnp.ndarray) -> float: """ Calculate the Mean Reciprocal Rank (MRR) for a list of ranked predictions using JAX. - + Example usage: ``` predictions = jnp.array([ @@ -104,11 +108,11 @@ def mean_reciprocal_rank(predictions: jnp.ndarray) -> float: ]) mrr_score = mean_reciprocal_rank(predictions) ``` - + Args: predictions (jnp.ndarray): 2D array where each row contains ranked predictions and the "correct" prediction is indicated by a specific index. - + Returns: float: Mean Reciprocal Rank (MRR) score. """ @@ -118,8 +122,8 @@ def mean_reciprocal_rank(predictions: jnp.ndarray) -> float: mean_mrr = jnp.mean(reciprocal_ranks) return mean_mrr -def jaccard(sequence1: List, - sequence2: List) -> float: + +def jaccard(sequence1: List, sequence2: List) -> float: """ Calculate Jaccard similarity between two sequences. @@ -142,9 +146,9 @@ def jaccard(sequence1: List, denominator = len(set(sequence1).union(sequence2)) return numerator / denominator + @jax.jit -def hamming(sequence1: jnp.ndarray, - sequence2: jnp.ndarray) -> int: +def hamming(sequence1: jnp.ndarray, sequence2: jnp.ndarray) -> int: """ Calculate Hamming similarity between two sequences using JAX. @@ -165,13 +169,13 @@ def hamming(sequence1: jnp.ndarray, """ return jnp.sum(sequence1 == sequence2) -def zero_pad_sequences(arr: jnp.array, - max_length: int) -> jnp.array: + +def zero_pad_sequences(arr: jnp.array, max_length: int) -> jnp.array: """ Zero-pad the given array to the specified maximum length along axis=1. - This function pads the input array with zeros along the second dimension (axis=1) - until it reaches the specified maximum length. If the array is already longer + This function pads the input array with zeros along the second dimension (axis=1) + until it reaches the specified maximum length. If the array is already longer than the maximum length, it is returned as is. Args: @@ -191,8 +195,8 @@ def zero_pad_sequences(arr: jnp.array, [4 5 6 0 0]] ``` """ - current_length = arr.shape[1] - num_zeros = max_length - current_length + current_length = arr.shape[1] + num_zeros = max_length - current_length if num_zeros > 0: zeros = jnp.zeros((arr.shape[0], num_zeros), dtype=arr.dtype) @@ -202,20 +206,21 @@ def zero_pad_sequences(arr: jnp.array, return padded_array + @jax.jit def entropy(probabilities: jnp.ndarray) -> float: """ Calculate the entropy of a probability distribution using JAX. - + Example usage: ``` probabilities = jnp.array([0.25, 0.75]) entropy_value = entropy(probabilities) ``` - + Args: probabilities (jnp.ndarray): Array of probability values. - + Returns: float: Entropy value. """ @@ -223,65 +228,67 @@ def entropy(probabilities: jnp.ndarray) -> float: entropy_value = -jnp.sum(probabilities * log_probs) return entropy_value + @jax.jit def gini_impurity(probabilities: jnp.ndarray) -> float: """ Calculate the Gini impurity of a probability distribution using JAX. - + Example usage: ``` probabilities = jnp.array([0.25, 0.75]) gini_value = gini_impurity(probabilities) ``` - + Args: probabilities (jnp.ndarray): Array of probability values. - + Returns: float: Gini impurity value. """ - gini_value = 1 - jnp.sum(probabilities ** 2) + gini_value = 1 - jnp.sum(probabilities**2) return gini_value + @jax.jit -def kl_divergence(p: jnp.ndarray, - q: jnp.ndarray) -> float: +def kl_divergence(p: jnp.ndarray, q: jnp.ndarray) -> float: """ Calculate the Kullback-Leibler (KL) divergence between two probability distributions using JAX. - + Example usage: ``` p = jnp.array([0.25, 0.75]) q = jnp.array([0.5, 0.5]) kl_value = kl_divergence(p, q) ``` - + Args: p (jnp.ndarray): Array of probability values for distribution p. q (jnp.ndarray): Array of probability values for distribution q. - + Returns: float: KL divergence value. """ kl_value = jnp.sum(p * jnp.log2(p / q)) return kl_value + @jax.jit def count_parameters(params: Any) -> int: """ Count the total number of parameters in a model's parameter dictionary using JAX. - + Example usage: ``` model = MyModel() params = model.init(jax.random.PRNGKey(0), jnp.ones(input_shape)) total_params = count_parameters(params) ``` - + Args: params (Any): Model's parameter dictionary. - + Returns: int: Total number of parameters. """ - return sum(x.size for x in jax.tree_leaves(params)) \ No newline at end of file + return sum(x.size for x in jax.tree_leaves(params)) diff --git a/nanodl/__src/utils/nlp.py b/nanodl/__src/utils/nlp.py index 94d7810..f763615 100644 --- a/nanodl/__src/utils/nlp.py +++ b/nanodl/__src/utils/nlp.py @@ -1,14 +1,13 @@ import re -import json -import numpy as np -import jax.numpy as jnp from collections import Counter -from typing import Dict, Any, Union, List +from typing import List + +import numpy as np -def rouge(hypotheses: List[str], - references: List[str], - ngram_sizes: List[int] = [1, 2]) -> dict: +def rouge( + hypotheses: List[str], references: List[str], ngram_sizes: List[int] = [1, 2] +) -> dict: """ Calculate the ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric. ROUGE-F1 = (Precision + Recall) / (2⋅Precision⋅Recall) @@ -17,10 +16,10 @@ def rouge(hypotheses: List[str], hypotheses (List[str]): List of hypothesis sentences. references (List[str]): List of reference sentences. ngram_sizes (List[int], optional): List of n-gram sizes. Default is [1, 2]. - + Returns: dict: Dictionary containing precision, recall, and F1-score for each n-gram size. - + Example usage: ``` >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] @@ -29,21 +28,30 @@ def rouge(hypotheses: List[str], >>> print(rouge_scores) ``` """ + def ngrams(sequence: List[str], n: int) -> List[str]: - return [tuple(sequence[i:i+n]) for i in range(len(sequence) - n + 1)] - + return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)] + def precision_recall_f1(hypothesis_tokens, reference_tokens, n): hypothesis_ngrams = set(ngrams(hypothesis_tokens, n)) reference_ngrams = set(ngrams(reference_tokens, n)) - + common_ngrams = hypothesis_ngrams.intersection(reference_ngrams) - - precision = len(common_ngrams) / len(hypothesis_ngrams) if len(hypothesis_ngrams) > 0 else 0.0 - recall = len(common_ngrams) / len(reference_ngrams) if len(reference_ngrams) > 0 else 0.0 - + + precision = ( + len(common_ngrams) / len(hypothesis_ngrams) + if len(hypothesis_ngrams) > 0 + else 0.0 + ) + recall = ( + len(common_ngrams) / len(reference_ngrams) + if len(reference_ngrams) > 0 + else 0.0 + ) + f1 = 2 * (precision * recall) / (precision + recall + 1e-12) return precision, recall, f1 - + rouge_scores = {} for n in ngram_sizes: total_precision = 0.0 @@ -52,41 +60,41 @@ def precision_recall_f1(hypothesis_tokens, reference_tokens, n): for hypothesis, reference in zip(hypotheses, references): hypothesis_tokens = hypothesis.split() reference_tokens = reference.split() - - precision, recall, f1 = precision_recall_f1(hypothesis_tokens, reference_tokens, n) + + precision, recall, f1 = precision_recall_f1( + hypothesis_tokens, reference_tokens, n + ) total_precision += precision total_recall += recall total_f1 += f1 - + average_precision = total_precision / len(hypotheses) average_recall = total_recall / len(hypotheses) average_f1 = total_f1 / len(hypotheses) - - rouge_scores[f'ROUGE-{n}'] = { - 'precision': average_precision, - 'recall': average_recall, - 'f1': average_f1 + + rouge_scores[f"ROUGE-{n}"] = { + "precision": average_precision, + "recall": average_recall, + "f1": average_f1, } - + return rouge_scores -def bleu(hypotheses: List[str], - references: List[str], - max_ngram: int = 4) -> float: +def bleu(hypotheses: List[str], references: List[str], max_ngram: int = 4) -> float: """ Calculate the BLEU (Bilingual Evaluation Understudy) metric. BLEU = (BP) * (exp(sum(wn * log(pn)))) where BP = brevity penalty, wn = weight for n-gram precision, and pn = n-gram precision - + Args: hypotheses (List[str]): List of hypothesis sentences. references (List[str]): List of reference sentences. max_ngram (int, optional): Maximum n-gram size to consider. Default is 4. - + Returns: float: BLEU score. - + Example usage: ``` >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] @@ -95,56 +103,58 @@ def bleu(hypotheses: List[str], >>> print(bleu_score) ``` """ + def ngrams(sequence: List[str], n: int) -> List[str]: - return [tuple(sequence[i:i+n]) for i in range(len(sequence) - n + 1)] - + return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)] + def modified_precision(hypothesis_tokens, reference_tokens, n): hypothesis_ngrams = ngrams(hypothesis_tokens, n) reference_ngrams = ngrams(reference_tokens, n) - + hypothesis_ngram_counts = Counter(hypothesis_ngrams) reference_ngram_counts = Counter(reference_ngrams) - + common_ngrams = hypothesis_ngram_counts & reference_ngram_counts common_count = sum(common_ngrams.values()) - + if len(hypothesis_ngrams) == 0: return 0.0 else: precision = common_count / len(hypothesis_ngrams) return precision - + brevity_penalty = np.exp(min(0, 1 - len(hypotheses) / len(references))) bleu_scores = [] - + for n in range(1, max_ngram + 1): ngram_precisions = [] for hypothesis, reference in zip(hypotheses, references): hypothesis_tokens = hypothesis.split() reference_tokens = reference.split() - + precision = modified_precision(hypothesis_tokens, reference_tokens, n) ngram_precisions.append(precision) - + geometric_mean = np.exp(np.mean(np.log(np.clip(ngram_precisions, 1e-10, None)))) bleu_scores.append(geometric_mean) - - final_bleu = brevity_penalty * np.exp(np.mean(np.log(np.clip(bleu_scores, 1e-10, None)))) + + final_bleu = brevity_penalty * np.exp( + np.mean(np.log(np.clip(bleu_scores, 1e-10, None))) + ) return final_bleu -def meteor(hypothesis: str, - reference: str) -> float: +def meteor(hypothesis: str, reference: str) -> float: """ Calculates the METEOR score between a reference and hypothesis sentence. - + Args: reference (str): The reference sentence. hypothesis (str): The hypothesis sentence. - + Returns: float: METEOR score. - + Example usage: ``` >>> hypothesis = "the cat is on the mat" @@ -153,27 +163,31 @@ def meteor(hypothesis: str, >>> print(meteor_score) ``` """ - + def tokenize(sentence): - return re.findall(r'\w+', sentence.lower()) - + return re.findall(r"\w+", sentence.lower()) + def stemming(token): return token.lower() - + def exact_matching(reference_tokens, hypothesis_tokens): return sum(1 for token in hypothesis_tokens if token in reference_tokens) - + def stemmed_matching(reference_tokens, hypothesis_tokens): stemmed_reference = [stemming(token) for token in reference_tokens] stemmed_hypothesis = [stemming(token) for token in hypothesis_tokens] return sum(1 for token in stemmed_hypothesis if token in stemmed_reference) - + def precision_recall_f1(match_count, hypothesis_length, reference_length): precision = match_count / hypothesis_length if hypothesis_length > 0 else 0 recall = match_count / reference_length if reference_length > 0 else 0 - f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0 + f1 = ( + 2 * precision * recall / (precision + recall) + if precision + recall > 0 + else 0 + ) return precision, recall, f1 - + reference_tokens = tokenize(reference) hypothesis_tokens = tokenize(hypothesis) @@ -193,18 +207,17 @@ def precision_recall_f1(match_count, hypothesis_length, reference_length): return meteor_score -def cider_score(hypothesis: str, - reference: str) -> float: +def cider_score(hypothesis: str, reference: str) -> float: """ Calculates the CIDEr score between a reference and hypothesis sentence. - + Args: reference (str): The reference sentence. hypothesis (str): The hypothesis sentence. - + Returns: float: CIDEr score. - + Example usage: ``` >>> hypothesis = "the cat is on the mat" @@ -213,11 +226,12 @@ def cider_score(hypothesis: str, >>> print(score) ``` """ + def tokenize(sentence): - return re.findall(r'\w+', sentence.lower()) - + return re.findall(r"\w+", sentence.lower()) + def ngrams(tokens, n): - return [tuple(tokens[i:i+n]) for i in range(len(tokens)-n+1)] + return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)] reference_tokens = tokenize(reference) hypothesis_tokens = tokenize(hypothesis) @@ -226,28 +240,30 @@ def ngrams(tokens, n): weights = [1.0] * max_n # Weights for different n-gram sizes cider_scores = [] - for n in range(1, max_n+1): + for n in range(1, max_n + 1): ref_ngrams = ngrams(reference_tokens, n) hyp_ngrams = ngrams(hypothesis_tokens, n) - + ref_ngram_freq = Counter(ref_ngrams) hyp_ngram_freq = Counter(hyp_ngrams) - + common_ngrams = set(ref_ngrams) & set(hyp_ngrams) - + if len(common_ngrams) == 0: cider_scores.append(0) continue - - precision = sum(min(ref_ngram_freq[ngram], hyp_ngram_freq[ngram]) for ngram in common_ngrams) / len(hyp_ngrams) + + precision = sum( + min(ref_ngram_freq[ngram], hyp_ngram_freq[ngram]) for ngram in common_ngrams + ) / len(hyp_ngrams) ref_ngram_freq_sum = sum(ref_ngram_freq[ngram] for ngram in common_ngrams) hyp_ngram_freq_sum = sum(hyp_ngram_freq[ngram] for ngram in common_ngrams) recall = ref_ngram_freq_sum / len(ref_ngrams) - + cider_scores.append((precision * recall) / (precision + recall) * 2) - + avg_cider_score = np.average(cider_scores, weights=weights) - + return avg_cider_score @@ -255,14 +271,14 @@ def perplexity(log_probs: List[float]) -> float: """ Calculate the perplexity of a sequence using a list of log probabilities. Perplexity = 2^(-average log likelihood) - where average log likelihood = total log likelihood / total word count + where average log likelihood = total log likelihood / total word count Args: log_probs (List[float]): List of log probabilities for each predicted word. - + Returns: float: Perplexity score. - + Example usage: ``` >>> log_probs = [-2.3, -1.7, -0.4] # Example log probabilities @@ -274,7 +290,9 @@ def perplexity(log_probs: List[float]) -> float: word_count = 0 for i in range(len(log_probs) - 1): - predicted_log_prob = log_probs[i] # Replace this with your language model's log probability + predicted_log_prob = log_probs[ + i + ] # Replace this with your language model's log probability log_likelihood += predicted_log_prob word_count += 1 @@ -283,7 +301,7 @@ def perplexity(log_probs: List[float]) -> float: return perplexity_score -def word_error_rate(hypotheses: List[int], references: List[int]) -> float: +def word_error_rate(hypotheses: List[int], references: List[int]) -> float: """ Calculate the Word Error Rate (WER) metric. @@ -293,7 +311,7 @@ def word_error_rate(hypotheses: List[int], references: List[int]) -> float: Returns: float: Word Error Rate score. - + Example usage: ``` >>> hypotheses = ["the cat is on the mat", "there is a cat on the mat"] @@ -302,6 +320,7 @@ def word_error_rate(hypotheses: List[int], references: List[int]) -> float: >>> print(wer_score) ``` """ + def edit_distance(str1, str2): len_str1 = len(str1) len_str2 = len(str2) @@ -318,9 +337,9 @@ def edit_distance(str1, str2): for j in range(1, len_str2 + 1): cost = 0 if str1[i - 1] == str2[j - 1] else 1 dp[i][j] = min( - dp[i - 1][j] + 1, # Deletion - dp[i][j - 1] + 1, # Insertion - dp[i - 1][j - 1] + cost # Substitution or no operation + dp[i - 1][j] + 1, # Deletion + dp[i][j - 1] + 1, # Insertion + dp[i - 1][j - 1] + cost, # Substitution or no operation ) return dp[len_str1][len_str2] @@ -333,4 +352,4 @@ def edit_distance(str1, str2): total_reference_length += len(ref.split()) wer_score = total_edit_distance / total_reference_length - return wer_score \ No newline at end of file + return wer_score diff --git a/nanodl/__src/utils/random.py b/nanodl/__src/utils/random.py index 8d836b2..fd9fd5f 100644 --- a/nanodl/__src/utils/random.py +++ b/nanodl/__src/utils/random.py @@ -1,8 +1,10 @@ -import jax import time +from typing import Any, Tuple, Union + +import jax import jax.numpy as jnp from jax import random -from typing import Any, Union, Tuple + def time_rng_key(seed=None) -> jnp.ndarray: """Generate a JAX random key based on the current UNIX timestamp. @@ -13,11 +15,14 @@ def time_rng_key(seed=None) -> jnp.ndarray: key = int(time.time()) if seed is None else seed return random.PRNGKey(key) -def uniform(shape: Tuple[int, ...], - dtype: Any = jnp.float32, - minval: float = 0.0, - maxval: float = 1.0, - seed=None) -> jnp.ndarray: + +def uniform( + shape: Tuple[int, ...], + dtype: Any = jnp.float32, + minval: float = 0.0, + maxval: float = 1.0, + seed=None, +) -> jnp.ndarray: """Generate a tensor of uniform random values. Args: @@ -29,15 +34,12 @@ def uniform(shape: Tuple[int, ...], Returns: jnp.ndarray: A tensor of uniform random values. """ - return random.uniform(time_rng_key(seed), - shape, - dtype=dtype, - minval=minval, - maxval=maxval) - -def normal(shape: Tuple[int, ...], - dtype: Any = jnp.float32, - seed=None) -> jnp.ndarray: + return random.uniform( + time_rng_key(seed), shape, dtype=dtype, minval=minval, maxval=maxval + ) + + +def normal(shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None) -> jnp.ndarray: """Generate a tensor of normal random values. Args: @@ -47,12 +49,10 @@ def normal(shape: Tuple[int, ...], Returns: jnp.ndarray: A tensor of normal random values. """ - return random.normal(time_rng_key(seed), - shape, dtype=dtype) + return random.normal(time_rng_key(seed), shape, dtype=dtype) -def bernoulli(p: float, - shape: Tuple[int, ...] = (), - seed=None) -> jnp.ndarray: + +def bernoulli(p: float, shape: Tuple[int, ...] = (), seed=None) -> jnp.ndarray: """Generate random boolean values with a given probability. Args: @@ -64,10 +64,10 @@ def bernoulli(p: float, """ return random.bernoulli(time_rng_key(seed), p, shape) -def categorical(logits: jnp.ndarray, - axis: int = -1, - shape: Tuple[int, ...] = (), - seed=None) -> jnp.ndarray: + +def categorical( + logits: jnp.ndarray, axis: int = -1, shape: Tuple[int, ...] = (), seed=None +) -> jnp.ndarray: """Draw samples from a categorical distribution. Args: @@ -78,16 +78,12 @@ def categorical(logits: jnp.ndarray, Returns: jnp.ndarray: The sampled indices with the specified shape. """ - return random.categorical(time_rng_key(seed), - logits, - axis=axis, - shape=shape) - -def randint(shape: Tuple[int, ...], - minval: int, - maxval: int, - dtype: str = 'int32', - seed=None) -> jnp.ndarray: + return random.categorical(time_rng_key(seed), logits, axis=axis, shape=shape) + + +def randint( + shape: Tuple[int, ...], minval: int, maxval: int, dtype: str = "int32", seed=None +) -> jnp.ndarray: """Generate random integers between minval (inclusive) and maxval (exclusive). Args: @@ -99,15 +95,10 @@ def randint(shape: Tuple[int, ...], Returns: jnp.ndarray: A tensor of random integers. """ - return random.randint(time_rng_key(seed), - shape, - minval, - maxval, - dtype=dtype) - -def permutation(x: Union[int, jnp.ndarray], - axis: int = 0, - seed=None) -> jnp.ndarray: + return random.randint(time_rng_key(seed), shape, minval, maxval, dtype=dtype) + + +def permutation(x: Union[int, jnp.ndarray], axis: int = 0, seed=None) -> jnp.ndarray: """Randomly permute a sequence, or return a permuted range. Args: @@ -123,9 +114,8 @@ def permutation(x: Union[int, jnp.ndarray], else: return random.permutation(time_rng_key(seed), x, axis=axis) -def gumbel(shape: Tuple[int, ...], - dtype: Any = jnp.float32, - seed=None) -> jnp.ndarray: + +def gumbel(shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None) -> jnp.ndarray: """Draw samples from a Gumbel distribution. Args: @@ -137,12 +127,15 @@ def gumbel(shape: Tuple[int, ...], """ return random.gumbel(time_rng_key(seed), shape, dtype=dtype) -def choice(a: Union[int, jnp.ndarray], - shape: Tuple[int, ...] = (), - replace: bool = True, - p: Union[None, jnp.ndarray] = None, - axis: int = 0, - seed=None) -> jnp.ndarray: + +def choice( + a: Union[int, jnp.ndarray], + shape: Tuple[int, ...] = (), + replace: bool = True, + p: Union[None, jnp.ndarray] = None, + axis: int = 0, + seed=None, +) -> jnp.ndarray: """Randomly choose elements from a given 1-D array. Args: @@ -157,16 +150,12 @@ def choice(a: Union[int, jnp.ndarray], """ if isinstance(a, int): a = jnp.arange(a) - return random.choice(time_rng_key(seed), - a, - shape=shape, - replace=replace, - p=p, - axis=axis) - -def bits(shape: Tuple[int, ...], - dtype: Any = jnp.uint32, - seed=None) -> jnp.ndarray: + return random.choice( + time_rng_key(seed), a, shape=shape, replace=replace, p=p, axis=axis + ) + + +def bits(shape: Tuple[int, ...], dtype: Any = jnp.uint32, seed=None) -> jnp.ndarray: """Generate random bits. Args: @@ -178,9 +167,10 @@ def bits(shape: Tuple[int, ...], """ return random.bits(time_rng_key(seed), shape, dtype=dtype) -def exponential(shape: Tuple[int, ...], - dtype: Any = jnp.float32, - seed=None) -> jnp.ndarray: + +def exponential( + shape: Tuple[int, ...], dtype: Any = jnp.float32, seed=None +) -> jnp.ndarray: """Draw samples from an exponential distribution. Args: @@ -192,11 +182,10 @@ def exponential(shape: Tuple[int, ...], """ return random.exponential(time_rng_key(seed), shape, dtype=dtype) -def triangular(left: float, - right: float, - mode: float, - shape: Tuple[int, ...] = (), - seed=None) -> jnp.ndarray: + +def triangular( + left: float, right: float, mode: float, shape: Tuple[int, ...] = (), seed=None +) -> jnp.ndarray: """Draw samples from a triangular distribution. Args: @@ -210,11 +199,14 @@ def triangular(left: float, """ return random.triangular(time_rng_key(seed), left, right, mode, shape) -def truncated_normal(lower: float, - upper: float, - shape: Tuple[int, ...] = (), - dtype: Any = jnp.float32, - seed=None) -> jnp.ndarray: + +def truncated_normal( + lower: float, + upper: float, + shape: Tuple[int, ...] = (), + dtype: Any = jnp.float32, + seed=None, +) -> jnp.ndarray: """Draw samples from a truncated normal distribution. Args: @@ -226,16 +218,12 @@ def truncated_normal(lower: float, Returns: jnp.ndarray: A tensor of samples from a truncated normal distribution. """ - return random.truncated_normal(time_rng_key(seed), - lower, - upper, - shape, - dtype) - -def poisson(lam: float, - shape: Tuple[int, ...] = (), - dtype: Any = jnp.int32, - seed=None) -> jnp.ndarray: + return random.truncated_normal(time_rng_key(seed), lower, upper, shape, dtype) + + +def poisson( + lam: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.int32, seed=None +) -> jnp.ndarray: """Draw samples from a Poisson distribution. Args: @@ -246,15 +234,12 @@ def poisson(lam: float, Returns: jnp.ndarray: A tensor of samples from a Poisson distribution. """ - return random.poisson(time_rng_key(seed), - lam, - shape=shape, - dtype=dtype) - -def geometric(p: float, - shape: Tuple[int, ...] = (), - dtype: Any = jnp.int32, - seed=None) -> jnp.ndarray: + return random.poisson(time_rng_key(seed), lam, shape=shape, dtype=dtype) + + +def geometric( + p: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.int32, seed=None +) -> jnp.ndarray: """Draw samples from a geometric distribution. Args: @@ -265,15 +250,12 @@ def geometric(p: float, Returns: jnp.ndarray: A tensor of samples from a geometric distribution. """ - return random.geometric(time_rng_key(seed), - p, - shape=shape, - dtype=dtype) - -def gamma(a: float, - shape: Tuple[int, ...] = (), - dtype: Any = jnp.float32, - seed=None) -> jnp.ndarray: + return random.geometric(time_rng_key(seed), p, shape=shape, dtype=dtype) + + +def gamma( + a: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.float32, seed=None +) -> jnp.ndarray: """Draw samples from a gamma distribution. Args: @@ -284,15 +266,12 @@ def gamma(a: float, Returns: jnp.ndarray: A tensor of samples from a gamma distribution. """ - return random.gamma(time_rng_key(seed), - a, - shape=shape, - dtype=dtype) - -def chisquare(df: float, - shape: Tuple[int, ...] = (), - dtype: Any = jnp.float32, - seed=None) -> jnp.ndarray: + return random.gamma(time_rng_key(seed), a, shape=shape, dtype=dtype) + + +def chisquare( + df: float, shape: Tuple[int, ...] = (), dtype: Any = jnp.float32, seed=None +) -> jnp.ndarray: """Draw samples from a chi-square distribution. Args: @@ -303,7 +282,4 @@ def chisquare(df: float, Returns: jnp.ndarray: A tensor of samples from a chi-square distribution. """ - return random.chisquare(time_rng_key(seed), - df, - shape=shape, - dtype=dtype) \ No newline at end of file + return random.chisquare(time_rng_key(seed), df, shape=shape, dtype=dtype) diff --git a/nanodl/__src/utils/vision.py b/nanodl/__src/utils/vision.py index 908c31f..9dc3383 100644 --- a/nanodl/__src/utils/vision.py +++ b/nanodl/__src/utils/vision.py @@ -1,6 +1,8 @@ +import time + import jax import jax.numpy as jnp -import time + @jax.jit def normalize_images(images: jnp.ndarray) -> jnp.ndarray: @@ -26,8 +28,7 @@ def normalize_images(images: jnp.ndarray) -> jnp.ndarray: return (images - mean) / (std + 1e-5) -def random_crop(images: jnp.ndarray, - crop_size: int) -> jnp.ndarray: +def random_crop(images: jnp.ndarray, crop_size: int) -> jnp.ndarray: """ Randomly crop a batch of images to a specified size using JAX. @@ -61,9 +62,7 @@ def random_crop(images: jnp.ndarray, return crops -def gaussian_blur(image: jnp.ndarray, - kernel_size: int, - sigma: float) -> jnp.ndarray: +def gaussian_blur(image: jnp.ndarray, kernel_size: int, sigma: float) -> jnp.ndarray: """ Apply Gaussian blur to a multi-channel image. @@ -89,7 +88,13 @@ def gaussian_blur(image: jnp.ndarray, kernel = kernel / jnp.sum(kernel) # Apply convolution to each channel - blurred_image = jnp.stack([jax.scipy.signal.convolve2d(image[:, :, i], kernel, mode='same') for i in range(image.shape[2])], axis=-1) + blurred_image = jnp.stack( + [ + jax.scipy.signal.convolve2d(image[:, :, i], kernel, mode="same") + for i in range(image.shape[2]) + ], + axis=-1, + ) return blurred_image @@ -115,18 +120,22 @@ def sobel_edge_detection(image: jnp.ndarray) -> jnp.ndarray: sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32) def apply_sobel(channel): - gx = jax.scipy.signal.convolve2d(channel, sobel_x, mode='same') - gy = jax.scipy.signal.convolve2d(channel, sobel_y, mode='same') + gx = jax.scipy.signal.convolve2d(channel, sobel_x, mode="same") + gy = jax.scipy.signal.convolve2d(channel, sobel_y, mode="same") return jnp.sqrt(gx**2 + gy**2) # Apply Sobel filter to each channel and sum the results - edges = jnp.sum(jnp.stack([apply_sobel(image[:, :, i]) for i in range(image.shape[2])], axis=-1), axis=-1) + edges = jnp.sum( + jnp.stack( + [apply_sobel(image[:, :, i]) for i in range(image.shape[2])], axis=-1 + ), + axis=-1, + ) return edges @jax.jit -def adjust_brightness(image: jnp.ndarray, - factor: float) -> jnp.ndarray: +def adjust_brightness(image: jnp.ndarray, factor: float) -> jnp.ndarray: """ Adjust the brightness of an image. @@ -179,7 +188,7 @@ def flip_image(image: jnp.ndarray, horizontal: jnp.ndarray) -> jnp.ndarray: Args: image (jnp.ndarray): Input image of shape (H, W, C). - horizontal (jnp.ndarray): If True (jax.numpy.array with a single True value), flip horizontally; + horizontal (jnp.ndarray): If True (jax.numpy.array with a single True value), flip horizontally; otherwise, flip vertically. Returns: @@ -197,9 +206,9 @@ def flip_image(image: jnp.ndarray, horizontal: jnp.ndarray) -> jnp.ndarray: @jax.jit -def random_flip_image(image: jnp.ndarray, - key: jax.random.PRNGKey, - horizontal: jnp.ndarray) -> jnp.ndarray: +def random_flip_image( + image: jnp.ndarray, key: jax.random.PRNGKey, horizontal: jnp.ndarray +) -> jnp.ndarray: """ Randomly flip an image horizontally or vertically using JAX. @@ -224,4 +233,4 @@ def random_flip_image(image: jnp.ndarray, flip = jax.random.uniform(key) > 0.5 flip_horizontal = jnp.where(horizontal, image[:, ::-1, :], image) flip_vertical = jnp.where(horizontal, image, image[::-1, :, :]) - return jnp.where(flip, flip_horizontal, flip_vertical) \ No newline at end of file + return jnp.where(flip, flip_horizontal, flip_vertical) diff --git a/setup.py b/setup.py index ccfedb8..7d5bc93 100644 --- a/setup.py +++ b/setup.py @@ -1,36 +1,36 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( - name='nanodl', - version='1.2.3.dev1', - author='Henry Ndubuaku', - author_email='ndubuakuhenry@gmail.com', - description='A Jax-based library for designing and training transformer models from scratch.', - long_description=open('README.md').read(), - long_description_content_type='text/markdown', - url='https://github.com/hmunachi/nanodl', + name="nanodl", + version="1.2.4.dev1", + author="Henry Ndubuaku", + author_email="ndubuakuhenry@gmail.com", + description="A Jax-based library for designing and training transformer models from scratch.", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/hmunachi/nanodl", packages=find_packages(), install_requires=[ - 'flax', - 'jax', - 'jaxlib', - 'optax', - 'einops', - 'sentencepiece', + "flax", + "jax", + "jaxlib", + "optax", + "einops", + "sentencepiece", ], classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', - 'Intended Audience :: Education', - 'Topic :: Software Development :: Build Tools', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: Education", + "Topic :: Software Development :: Build Tools", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], - keywords='transformers jax machine learning deep learning pytorch tensorflow', - python_requires='>=3.7', + keywords="transformers jax machine learning deep learning pytorch tensorflow", + python_requires=">=3.7", ) diff --git a/tests/test_models.py b/tests/test_models.py index f078302..f5c8cbe 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,9 +1,11 @@ +import unittest + import jax import jax.numpy as jnp -import unittest from nanodl import * + class TestTextBasedModels(unittest.TestCase): def setUp(self): self.batch_size = 8 @@ -12,24 +14,23 @@ def setUp(self): self.embed_dim = 256 self.data = jnp.arange( - self.batch_size * self.max_length, - dtype=jnp.int32 - ).reshape((self.batch_size, self.max_length)) - + self.batch_size * self.max_length, dtype=jnp.int32 + ).reshape((self.batch_size, self.max_length)) + self.dummy_inputs = self.data[:, :-1] self.dummy_targets = self.data[:, 1:] self.hyperparams = { - 'num_layers': 1, - 'hidden_dim': self.embed_dim, - 'num_heads': 2, - 'feedforward_dim': self.embed_dim, - 'dropout': 0.1, - 'vocab_size': self.vocab_size, - 'embed_dim': self.embed_dim, - 'max_length': self.max_length, - 'start_token': 0, - 'end_token': 50, + "num_layers": 1, + "hidden_dim": self.embed_dim, + "num_heads": 2, + "feedforward_dim": self.embed_dim, + "dropout": 0.1, + "vocab_size": self.vocab_size, + "embed_dim": self.embed_dim, + "max_length": self.max_length, + "start_token": 0, + "end_token": 50, } def test_t5_model(self): @@ -53,84 +54,57 @@ def test_gpt3_model(self): self._test_decoder_only_model(model) def test_mistral_model(self): - model = Mistral(**self.hyperparams, - num_groups=2, - window_size=5, - shift_size=2) + model = Mistral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2) self._test_decoder_only_model(model) def test_mixtral_model(self): - model = Mixtral(**self.hyperparams, - num_groups=2, - window_size=5, - shift_size=2) + model = Mixtral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2) self._test_decoder_only_model(model) def test_llama_model(self): - model = LlaMA2(**self.hyperparams, - num_groups=2) + model = Llama3(**self.hyperparams, num_groups=2) self._test_decoder_only_model(model) def test_gemma_model(self): - model = Gemma(**self.hyperparams, - num_groups=2) + model = Gemma(**self.hyperparams, num_groups=2) self._test_decoder_only_model(model) def _test_encoder_decoder_model(self, model): - rngs = { - 'params': jax.random.key(0), - 'dropout': jax.random.key(1) - - } - params = model.init( - rngs, - self.dummy_inputs, - self.dummy_targets - )['params'] - + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = model.init(rngs, self.dummy_inputs, self.dummy_targets)["params"] + outputs = model.apply( - {'params': params}, - self.dummy_inputs, - self.dummy_targets, - rngs=rngs) - + {"params": params}, self.dummy_inputs, self.dummy_targets, rngs=rngs + ) + self.assertEqual( - outputs.shape, - (self.batch_size, self.max_length - 1, self.vocab_size) - ) - + outputs.shape, (self.batch_size, self.max_length - 1, self.vocab_size) + ) + def _test_decoder_only_model(self, model): - rngs = { - 'params': jax.random.key(0), - 'dropout': jax.random.key(1) - - } - params = model.init( - rngs, - self.dummy_inputs - )['params'] - - outputs = model.apply( - {'params': params}, - self.dummy_inputs, - rngs=rngs) - + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + params = model.init(rngs, self.dummy_inputs)["params"] + + outputs = model.apply({"params": params}, self.dummy_inputs, rngs=rngs) + self.assertEqual( - outputs.shape, - (self.batch_size, self.max_length - 1, self.vocab_size) - ) - + outputs.shape, (self.batch_size, self.max_length - 1, self.vocab_size) + ) + def test_reward_model(self): - model = RewardModel(Mixtral(**self.hyperparams, - num_groups=2, - window_size=5, - shift_size=2), dim=self.hyperparams['hidden_dim'], dropout=0.1) + model = RewardModel( + Mixtral(**self.hyperparams, num_groups=2, window_size=5, shift_size=2), + dim=self.hyperparams["hidden_dim"], + dropout=0.1, + ) rngs = jax.random.PRNGKey(0) rngs, dropout_rng = jax.random.split(rngs) - params = model.init({'params': rngs, 'dropout': dropout_rng}, self.dummy_inputs)['params'] - rewards = model.apply({'params': params}, - self.dummy_inputs, - rngs={'dropout': dropout_rng}) + params = model.init( + {"params": rngs, "dropout": dropout_rng}, self.dummy_inputs + )["params"] + rewards = model.apply( + {"params": params}, self.dummy_inputs, rngs={"dropout": dropout_rng} + ) assert rewards.shape == (self.batch_size,) @@ -144,11 +118,8 @@ def setUp(self): key = jax.random.PRNGKey(10) self.dummy_labels = jax.random.randint( - key, - shape=(self.batch_size,), - minval=0, - maxval=self.n_outputs-1 - ) + key, shape=(self.batch_size,), minval=0, maxval=self.n_outputs - 1 + ) self.hyperparams = { "dropout": 0.1, @@ -157,7 +128,7 @@ def setUp(self): "patch_size": self.patch_size, "hidden_dim": self.embed_dim, "num_layers": 4, - "n_outputs": self.n_outputs + "n_outputs": self.n_outputs, } def test_vit_model(self): @@ -169,26 +140,13 @@ def test_mixer_model(self): self._test_model(model) def _test_model(self, model): - rngs = { - 'params': jax.random.key(0), - 'dropout': jax.random.key(1) - } - - params = model.init( - rngs, - self.dummy_inputs - )['params'] - - outputs = model.apply( - {'params': params}, - self.dummy_inputs, - rngs=rngs - )[0] - - self.assertEqual( - outputs.shape, - (self.batch_size, self.n_outputs) - ) + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + + params = model.init(rngs, self.dummy_inputs)["params"] + + outputs = model.apply({"params": params}, self.dummy_inputs, rngs=rngs)[0] + + self.assertEqual(outputs.shape, (self.batch_size, self.n_outputs)) class TestCLIPModel(unittest.TestCase): @@ -211,25 +169,17 @@ def setUp(self): "num_layers_images": 4, "max_len": self.max_length, "vocab_size": self.vocab_size, - "embed_dim": self.embed_dim + "embed_dim": self.embed_dim, } self.model = CLIP(**self.clip_params) def test_clip_model_initialization_and_processing(self): rng = jax.random.PRNGKey(0) - params = self.model.init( - rng, - self.dummy_texts, - self.dummy_images - )['params'] - - loss = self.model.apply( - {'params': params}, - self.dummy_texts, - self.dummy_images - ) - + params = self.model.init(rng, self.dummy_texts, self.dummy_images)["params"] + + loss = self.model.apply({"params": params}, self.dummy_texts, self.dummy_images) + self.assertIsNotNone(loss) @@ -241,50 +191,38 @@ def setUp(self): self.vocab_size = 1000 self.dummy_targets = jnp.arange( - self.batch_size * self.max_length, - dtype=jnp.int32 - ).reshape((self.batch_size, self.max_length)) - + self.batch_size * self.max_length, dtype=jnp.int32 + ).reshape((self.batch_size, self.max_length)) + self.dummy_inputs = jnp.ones((self.batch_size, self.max_length, self.embed_dim)) self.hyperparams = { - 'num_layers': 1, - 'hidden_dim': self.embed_dim, - 'num_heads': 2, - 'feedforward_dim': self.embed_dim, - 'dropout': 0.1, - 'vocab_size': self.vocab_size, - 'embed_dim': self.embed_dim, - 'max_length': self.max_length, - 'start_token': 0, - 'end_token': 50, + "num_layers": 1, + "hidden_dim": self.embed_dim, + "num_heads": 2, + "feedforward_dim": self.embed_dim, + "dropout": 0.1, + "vocab_size": self.vocab_size, + "embed_dim": self.embed_dim, + "max_length": self.max_length, + "start_token": 0, + "end_token": 50, } self.model = Whisper(**self.hyperparams) def test_whisper_model_initialization_and_processing(self): - rngs = { - 'params': jax.random.key(0), - 'dropout': jax.random.key(1) - } - - params = self.model.init( - rngs, - self.dummy_inputs, - self.dummy_targets - )['params'] - + rngs = {"params": jax.random.key(0), "dropout": jax.random.key(1)} + + params = self.model.init(rngs, self.dummy_inputs, self.dummy_targets)["params"] + outputs = self.model.apply( - {'params': params}, - self.dummy_inputs, - self.dummy_targets, - rngs=rngs - ) - + {"params": params}, self.dummy_inputs, self.dummy_targets, rngs=rngs + ) + self.assertEqual( - outputs.shape, - (self.batch_size, self.max_length, self.vocab_size) - ) + outputs.shape, (self.batch_size, self.max_length, self.vocab_size) + ) class TestDiffusionModel(unittest.TestCase): @@ -294,12 +232,8 @@ def setUp(self): self.block_depth = 2 self.input_shape = (3, self.image_size, self.image_size, 3) self.images = jax.random.normal(jax.random.PRNGKey(0), self.input_shape) - - self.model = DiffusionModel( - self.image_size, - self.widths, - self.block_depth - ) + + self.model = DiffusionModel(self.image_size, self.widths, self.block_depth) def test_diffusion_model_initialization_and_processing(self): params = self.model.init(jax.random.PRNGKey(0), self.images) @@ -315,39 +249,27 @@ def setUp(self): self.nclass = 3 self.x = jax.random.normal( - jax.random.PRNGKey(0), - (self.num_nodes, self.num_features) - ) - + jax.random.PRNGKey(0), (self.num_nodes, self.num_features) + ) + self.adj = jax.random.bernoulli( - jax.random.PRNGKey(0),0.3, - (self.num_nodes, self.num_nodes) - ) - + jax.random.PRNGKey(0), 0.3, (self.num_nodes, self.num_nodes) + ) + self.model = GAT( - nfeat=self.num_features, - nhid=8, - nclass=self.nclass, - dropout_rate=0.5, - alpha=0.2, - nheads=3 - ) + nfeat=self.num_features, + nhid=8, + nclass=self.nclass, + dropout_rate=0.5, + alpha=0.2, + nheads=3, + ) def test_gat_model_initialization_and_processing(self): - params = self.model.init( - jax.random.key(0), - self.x, - self.adj, - training=False - ) - - output = self.model.apply( - params, - self.x, - self.adj, - training=False - ) - + params = self.model.init(jax.random.key(0), self.x, self.adj, training=False) + + output = self.model.apply(params, self.x, self.adj, training=False) + self.assertEqual(output.shape, (self.num_nodes, self.nclass)) @@ -363,14 +285,13 @@ def setUp(self): self.num_layers = 2 self.predictor_num_layers = 1 self.dropout_p = 0 - self.num_patches = (self.image_size ** 2) / (self.patch_size ** 2) - + self.num_patches = (self.image_size**2) / (self.patch_size**2) self.x = jax.random.normal( - jax.random.PRNGKey(0), - (1, self.image_size, self.image_size, self.num_channels) - ) - + jax.random.PRNGKey(0), + (1, self.image_size, self.image_size, self.num_channels), + ) + self.model = IJEPA( image_size=self.image_size, num_channels=self.num_channels, @@ -385,11 +306,9 @@ def setUp(self): ) self.data_sampler = IJEPADataSampler( - image_size=self.image_size, - M=4, - patch_size=self.patch_size + image_size=self.image_size, M=4, patch_size=self.patch_size ) - + def test_ijepa_data_sampling(self): context_mask, target_mask = self.data_sampler() self.assertEqual(context_mask.shape, (4, self.num_patches)) @@ -399,24 +318,25 @@ def test_ijepa_model_initialization_and_processing(self): context_mask, target_mask = self.data_sampler() params = self.model.init( - jax.random.key(0), - self.x, + jax.random.key(0), + self.x, context_mask[jnp.newaxis], target_mask[jnp.newaxis], - training=False + training=False, ) - - outputs , _ = self.model.apply( - params, + + outputs, _ = self.model.apply( + params, self.x, context_mask[jnp.newaxis], - target_mask[jnp.newaxis], - training=False + target_mask[jnp.newaxis], + training=False, ) self.assertEqual(len(outputs), 4) self.assertEqual(outputs[0][0].shape, (1, self.num_patches, self.embed_dim)) self.assertEqual(outputs[0][0].shape, outputs[0][1].shape) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_random.py b/tests/test_random.py index d0940df..ebcbebd 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -1,18 +1,36 @@ import unittest + import jax.numpy as jnp + from nanodl import ( - time_rng_key, uniform, normal, bernoulli, categorical, randint, - permutation, gumbel, choice, bits, exponential, - triangular, truncated_normal, poisson, geometric, gamma, - chisquare + bernoulli, + bits, + categorical, + chisquare, + choice, + exponential, + gamma, + geometric, + gumbel, + normal, + permutation, + poisson, + randint, + time_rng_key, + triangular, + truncated_normal, + uniform, ) + class TestRandomFunctions(unittest.TestCase): def test_time_rng_key(self): key1 = time_rng_key(seed=42) key2 = time_rng_key(seed=42) - self.assertTrue(jnp.array_equal(key1, key2), "Keys should be equal for the same seed") + self.assertTrue( + jnp.array_equal(key1, key2), "Keys should be equal for the same seed" + ) def test_uniform(self): result = uniform((2, 3)) @@ -94,5 +112,6 @@ def test_chisquare(self): self.assertEqual(result.shape, (2, 2)) self.assertEqual(result.dtype, jnp.float32) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_sklearn_gpu.py b/tests/test_sklearn_gpu.py index 58966b4..a1b7b4b 100644 --- a/tests/test_sklearn_gpu.py +++ b/tests/test_sklearn_gpu.py @@ -1,7 +1,8 @@ +import unittest + import jax import jax.numpy as jnp -import unittest from nanodl import * @@ -18,7 +19,9 @@ def test_naive_bayes_classifier(self): classifier.fit(self.X, self.y) predictions = classifier.predict(self.X) self.assertEqual(predictions.shape, (self.num_samples,)) - self.assertTrue(jnp.all(predictions >= 0) and jnp.all(predictions < self.num_classes)) + self.assertTrue( + jnp.all(predictions >= 0) and jnp.all(predictions < self.num_classes) + ) class TestKClustering(unittest.TestCase): @@ -27,9 +30,8 @@ def setUp(self): self.num_samples = 300 self.num_features = 2 self.X = jax.random.normal( - jax.random.PRNGKey(0), - (self.num_samples, self.num_features) - ) + jax.random.PRNGKey(0), (self.num_samples, self.num_features) + ) def test_kmeans_fit_predict(self): kmeans = KMeans(k=self.k) @@ -97,10 +99,14 @@ def test_gaussian_process(self): def rbf_kernel(x1, x2, length_scale=1.0): diff = x1[:, None] - x2 return jnp.exp(-0.5 * jnp.sum(diff**2, axis=-1) / length_scale**2) + num_samples = 100 input_dim = 1 X_train = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim)) - y_train = jnp.sin(X_train) + jax.random.normal(jax.random.PRNGKey(0), (num_samples, 1)) * 0.1 + y_train = ( + jnp.sin(X_train) + + jax.random.normal(jax.random.PRNGKey(0), (num_samples, 1)) * 0.1 + ) gp = GaussianProcess(kernel=rbf_kernel, noise=1e-3) gp.fit(X_train, y_train) X_new = jax.random.normal(jax.random.PRNGKey(0), (num_samples, input_dim)) @@ -109,5 +115,5 @@ def rbf_kernel(x1, x2, length_scale=1.0): self.assertEqual(covariance.shape, (num_samples, num_samples)) -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index f4312ad..12d525b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,8 @@ +import unittest + import jax import jax.numpy as jnp -import unittest from nanodl import * @@ -10,8 +11,10 @@ def test_dataset_length(self): class DummyDataset(Dataset): def __init__(self, data): self.data = data + def __len__(self): return len(self.data) + def __getitem__(self, index): return self.data[index] @@ -22,8 +25,10 @@ def test_dataset_getitem(self): class DummyDataset(Dataset): def __init__(self, data): self.data = data + def __len__(self): return len(self.data) + def __getitem__(self, index): return self.data[index] @@ -31,38 +36,28 @@ def __getitem__(self, index): item = dataset[5] self.assertEqual(item, 5) + class TestArrayDataset(unittest.TestCase): def test_array_dataset_length(self): - dataset = ArrayDataset( - jnp.array([1, 2, 3]), - jnp.array([4, 5, 6]) - ) + dataset = ArrayDataset(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) self.assertEqual(len(dataset), 3) def test_array_dataset_getitem(self): - dataset = ArrayDataset( - jnp.array([1, 2, 3]), - jnp.array([4, 5, 6]) - ) + dataset = ArrayDataset(jnp.array([1, 2, 3]), jnp.array([4, 5, 6])) item = dataset[1] self.assertEqual(item, (2, 5)) + class TestDataLoader(unittest.TestCase): def test_data_loader_length(self): - dataset = ArrayDataset( - jnp.ones((1001, 256, 256)), - jnp.ones((1001, 256, 256)) - ) + dataset = ArrayDataset(jnp.ones((1001, 256, 256)), jnp.ones((1001, 256, 256))) dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=False) self.assertEqual(len(dataloader), 101) def test_data_loader_iteration(self): - dataset = ArrayDataset( - jnp.ones((1001, 256, 256)), - jnp.ones((1001, 256, 256)) - ) + dataset = ArrayDataset(jnp.ones((1001, 256, 256)), jnp.ones((1001, 256, 256))) dataloader = DataLoader(dataset, batch_size=10, shuffle=True, drop_last=True) - for a,b in dataloader: + for a, b in dataloader: self.assertEqual(a.shape, (10, 256, 256)) self.assertEqual(b.shape, (10, 256, 256)) @@ -79,7 +74,7 @@ def test_batch_pearsonr(self): x = jnp.array([[1, 2, 3], [4, 5, 6]]) y = jnp.array([[6, 5, 4], [2, 6, 8]]) correlations = batch_pearsonr(x, y) - expected_results = jnp.array([-1.0, 1.0, 1.0]) + expected_results = jnp.array([-1.0, 1.0, 1.0]) self.assertTrue(jnp.allclose(correlations, expected_results)) def test_classification_scores(self): @@ -90,11 +85,13 @@ def test_classification_scores(self): self.assertTrue(jnp.allclose(scores, expected_results)) def test_mean_reciprocal_rank(self): - predictions = jnp.array([ - [0, 1, 2], # "correct" prediction at index 0 - [1, 0, 2], # "correct" prediction at index 1 - [2, 1, 0] # "correct" prediction at index 2 - ]) + predictions = jnp.array( + [ + [0, 1, 2], # "correct" prediction at index 0 + [1, 0, 2], # "correct" prediction at index 1 + [2, 1, 0], # "correct" prediction at index 2 + ] + ) mrr_score = mean_reciprocal_rank(predictions) self.assertAlmostEqual(mrr_score, 0.61111116) @@ -158,18 +155,28 @@ def setUp(self): def test_rouge(self): rouge_scores = rouge(self.hypotheses, self.references, [1, 2]) - expected_scores = {'ROUGE-1': {'precision': 0.7857142857142857, - 'recall': 0.9, - 'f1': 0.8333333333328402}, - 'ROUGE-2': {'precision': 0.6666666666666666, - 'recall': 0.7, - 'f1': 0.6818181818176838}} + expected_scores = { + "ROUGE-1": { + "precision": 0.7857142857142857, + "recall": 0.9, + "f1": 0.8333333333328402, + }, + "ROUGE-2": { + "precision": 0.6666666666666666, + "recall": 0.7, + "f1": 0.6818181818176838, + }, + } + def assert_nested_dicts_equal(dict1, dict2): for key in dict1.keys(): if isinstance(dict1[key], dict) and isinstance(dict2[key], dict): assert_nested_dicts_equal(dict1[key], dict2[key]) elif dict1[key] != dict2[key]: - raise AssertionError(f"Values for key '{key}' are not equal: {dict1[key]} != {dict2[key]}") + raise AssertionError( + f"Values for key '{key}' are not equal: {dict1[key]} != {dict2[key]}" + ) + assert_nested_dicts_equal(rouge_scores, expected_scores) def test_bleu(self): @@ -200,39 +207,39 @@ def test_word_error_rate(self): class TestVisionFunctions(unittest.TestCase): def test_normalize_images(self): - images = jnp.array([[[[0.0, 0.5], [1.0, 0.25]]]]) + images = jnp.array([[[[0.0, 0.5], [1.0, 0.25]]]]) normalized_images = normalize_images(images) self.assertAlmostEqual(normalized_images.mean(), 0.0, places=3) self.assertAlmostEqual(normalized_images.std(), 1.0, places=3) def test_random_crop(self): - images = jnp.ones((10, 100, 100, 3)) + images = jnp.ones((10, 100, 100, 3)) crop_size = 64 cropped_images = random_crop(images, crop_size) self.assertEqual(cropped_images.shape, (10, crop_size, crop_size, 3)) def test_gaussian_blur(self): - image = jnp.ones((5, 5, 3)) + image = jnp.ones((5, 5, 3)) blurred_image = gaussian_blur(image, kernel_size=3, sigma=1.0) self.assertEqual(blurred_image.shape, (5, 5, 3)) def test_sobel_edge_detection(self): - image = jnp.ones((5, 5, 3)) + image = jnp.ones((5, 5, 3)) edges = sobel_edge_detection(image) self.assertEqual(edges.shape, (5, 5)) def test_adjust_brightness(self): - image = jnp.ones((5, 5, 3)) + image = jnp.ones((5, 5, 3)) adjusted_image = adjust_brightness(image, factor=1.5) self.assertEqual(adjusted_image.shape, (5, 5, 3)) def test_adjust_contrast(self): - image = jnp.ones((5, 5, 3)) + image = jnp.ones((5, 5, 3)) adjusted_image = adjust_contrast(image, factor=1.5) self.assertEqual(adjusted_image.shape, (5, 5, 3)) def test_flip_image(self): - image = jnp.ones((5, 5, 3)) + image = jnp.ones((5, 5, 3)) flipped_image_horizontally = flip_image(image, jnp.array([True])) flipped_image_vertically = flip_image(image, jnp.array([False])) self.assertEqual(flipped_image_horizontally.shape, (5, 5, 3)) @@ -240,27 +247,10 @@ def test_flip_image(self): def test_random_flip_image(self): key = jax.random.PRNGKey(0) - image = jnp.ones((5, 5, 3)) + image = jnp.ones((5, 5, 3)) flipped_image = random_flip_image(image, key, jnp.array([True])) self.assertEqual(flipped_image.shape, (5, 5, 3)) -class TestTokenizerEncodingDecoding(unittest.TestCase): - def setUp(self): - """Set up the tokenizer with specific training data.""" - text_paths = ['tests/files/sample.txt'] - self.tokenizer = Tokenizer(training_data=text_paths, - vocab_size=100, - model_type='bpe', - max_sentence_length=50) - - def test_encode_decode(self): - """Test that encoding followed by decoding returns the original sentence.""" - test_sentence = "Hello, test" - encoded_sentence = self.tokenizer.encode(test_sentence) - decoded_sentence = self.tokenizer.decode(encoded_sentence) - self.assertEqual(test_sentence, decoded_sentence) - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main()