Commit

added more models
HMUNACHI committed Mar 4, 2024
1 parent ea7607b commit 0802890
Showing 6 changed files with 335 additions and 13 deletions.
README.md (78 changes: 68 additions & 10 deletions)
@@ -21,7 +21,7 @@ Developing and training transformer-based models is typically resource-intensive

- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
- An extensive selection of models like Gemma, LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- - Data-parallel distributed trainers includding RLHFPPO and RLHFDPO so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops.
+ - Data-parallel distributed trainers including RLHF so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops.
- Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to scikit-learn on GPU (a usage sketch follows this list).
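
For instance, the classical models follow a familiar fit/predict workflow. A minimal sketch, assuming a KMeans constructor with a `k` parameter and `fit`/`predict` methods (the exact signatures are assumptions, not verbatim from this commit; see the PCA example further down for the pattern this mirrors):

```py
import jax
from nanodl import KMeans

# Dummy data: 100 points in 4 dimensions
data = jax.random.normal(jax.random.PRNGKey(0), (100, 4))

kmeans = KMeans(k=3)           # 'k' is an assumed parameter name
kmeans.fit(data)               # fit cluster centroids to the data
labels = kmeans.predict(data)  # assign each point to its nearest centroid
print(labels.shape)            # expected: (100,)
```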
@@ -35,7 +35,7 @@ Feedback on any of our discussion, issue and pull request threads is welcome!
## What's New in version 1.2.0.dev1

- Google's Gemma architecture (a usage sketch follows this list).
- - Data parallel distributed RLHFPPO and RLHFDPO.
+ - Data parallel distributed Reward and RLHF trainers.
- True random number generators in Jax which do not need the usual verbose key-handling code (examples shown in later sections).
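
A minimal sketch of the new Gemma API, assuming it follows the same hyperparameter-dict pattern as the Mistral example later in this README (the config values and parameter names below are assumptions, not taken from this commit):

```py
from nanodl import Gemma

# Hypothetical small config, mirroring the Mistral example below
hyperparams = {
    'num_layers': 1,
    'hidden_dim': 256,
    'num_heads': 2,
    'feedforward_dim': 256,
    'dropout': 0.1,
    'vocab_size': 1000,
    'embed_dim': 256,
    'max_length': 10,
    'start_token': 0,
    'end_token': 50,
    'num_groups': 2,  # grouped-query attention groups; assumed for Gemma
}

# Instantiate; training would then use a data-parallel trainer,
# as in the Reward Model example below
model = Gemma(**hyperparams)
```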

There are experimental features (like the MAMBA architecture) in the repo which are not available via the package,
@@ -62,7 +62,7 @@ pip install nanodl

## What does nanodl look like?

- We provide various examples using the nanodl API: language, vision and audio, starting with an LLM.
+ We provide various example usages of the nanodl API.

```py
import jax
```

@@ -139,7 +139,9 @@

```py
                      method=model.generate)
print(outputs)
```

Vision example

```py
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -186,6 +188,7 @@ print(generated_images.shape)
```

Audio example

```py
import jax
import jax.numpy as jnp
```

@@ -210,12 +213,6 @@

```py
dataloader = DataLoader(dataset,
                        shuffle=True,
                        drop_last=False)

- # How to loop through dataloader
- for batch in dataloader:
-     x, y = batch
-     print(x.shape, y.shape)
-     break

# model parameters
hyperparams = {
    'num_layers': 1,
```

@@ -256,7 +253,67 @@

```py
transcripts = model.apply({'params': params},
print(transcripts)
```

Reward Model example for RLHF

```py
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer

# Generate dummy data
batch_size = 8
max_length = 10

# Replace with actual tokenised data
dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)

# Create dataset and dataloader
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=False)

# model parameters
hyperparams = {
    'num_layers': 1,
    'hidden_dim': 256,
    'num_heads': 2,
    'feedforward_dim': 256,
    'dropout': 0.1,
    'vocab_size': 1000,
    'embed_dim': 256,
    'max_length': max_length,
    'start_token': 0,
    'end_token': 50,
    'num_groups': 2,
    'window_size': 5,
    'shift_size': 2
}

# Initialize reward model from Mistral
model = Mistral(**hyperparams)
reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)

# Train the reward model
trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')
trainer.train(dataloader, 5, dataloader)
params = trainer.load_params('reward_model_weights.pkl')

# Call as you would a regular Flax model
rngs = jax.random.PRNGKey(0)
rngs, dropout_rng = jax.random.split(rngs)
rewards = reward_model.apply({'params': params},
                             dummy_chosen,
                             rngs={'dropout': dropout_rng})

print(rewards.shape)
```
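
Although this commit does not spell it out, reward models like this are typically trained with a pairwise ranking objective: the trainer pushes the score of each chosen sequence above that of its rejected counterpart, which is why the dataset pairs dummy_chosen with dummy_rejected.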

PCA example

```py
import jax
from nanodl import PCA
```

@@ -272,6 +329,7 @@

```py
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```
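
The middle of this example is collapsed in the diff view above. A sketch of what the collapsed portion presumably does, inferred from the printed variable names; the constructor and method names are assumptions, not verbatim from this commit:

```py
import jax
from nanodl import PCA

# Dummy data: 1000 samples with 10 features
data = jax.random.normal(jax.random.PRNGKey(0), (1000, 10))

pca = PCA(n_components=2)  # 'n_components' is an assumed parameter name
pca.fit(data)
transformed_data = pca.transform(data)                    # project onto components
original_data = pca.inverse_transform(transformed_data)   # reconstruct inputs
X_sampled = pca.sample(n_samples=1000)                    # assumed sampling helper
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```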

NanoDL provides a random module which abstracts away Jax's intricacies.
It generates truly random variables by using the current timestamp as the seed.

```py
# Jax example
key = random.PRNGKey(0)
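jax_array = random.uniform(key, shape=(3,))

# NanoDL example: a sketch, not verbatim from this commit. It assumes nanodl
# exposes jax.random-style samplers that seed themselves from the current
# timestamp (as described above) and accept an optional seed for reproducibility.
jax_array = nanodl.uniform(shape=(3,))
jax_array = nanodl.uniform(shape=(3,), seed=0)  # fixed seed for reproducibility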
```

@@ -289,7 +347,7 @@ This is the first iteration of this project, roughness is expected, contribution
- Raise the issue/discussion to get second opinions
- Fork the repository
- Create a branch
- - Make your changes without ruining the design patterns
+ - Make your changes without changing the design patterns
- Write tests for your changes if necessary
- Install locally with `pip install -e .`
- Run tests with `python -m unittest discover -s tests`
nanodl/__init__.py (10 changes: 7 additions & 3 deletions)
@@ -127,6 +127,11 @@
    GemmaDecoderBlock
)

+ from nanodl.__src.models.reward import (
+     RewardModel,
+     RewardDataParallelTrainer
+ )

from nanodl.__src.utils.data import (
    Dataset,
    ArrayDataset,
@@ -232,13 +237,12 @@
"Mixtral",
"MixtralDecoder",
"MixtralDecoderBlock",
"Mamba",
"MambaDataParallelTrainer",
"MambaBlock"
"Whisper",
"WhisperDataParallelTrainer",
"WhisperSpeechEncoder",
"WhisperSpeechEncoderBlock",
"RewardModel",
"RewardDataParallelTrainer",
"DiffusionModel",
"DiffusionDataParallelTrainer",
"UNet",