Merge pull request #13 from HMUNACHI/dev
NanoDL 1.2.0.dev1
HMUNACHI committed Mar 11, 2024
2 parents 18c7f8e + 64484e6 commit a95e62d
Showing 16 changed files with 2,906 additions and 29 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -23,6 +23,9 @@ __pycache__/
# Ignore configuration files with sensitive information
config.ini
secrets.yaml
params.pkl
base_params.pkl
reward_params.pkl

# Ignore user-specific files
/userdata/
108 changes: 91 additions & 17 deletions README.md
@@ -20,17 +20,26 @@ Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/)
Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:

- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
- An extensive selection of models like LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- Data-parallel distributed trainers so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops.
- An extensive selection of models like Gemma, LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- Data-parallel distributed trainers, including a reward trainer for RLHF, so developers can efficiently train large-scale models on multiple GPUs or TPUs without the need for manual training loops.
- Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU.
- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
- True random number generators in Jax that do not require the verbose key-handling code (see the plain-JAX sketch just after this list).
- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc.
- Each model is contained in a single file with no external dependencies, so the source code can also be easily used.
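
As context for the bullet on true random number generators above, here is a minimal plain-JAX sketch (standard JAX code, shown purely for comparison) of the explicit key bookkeeping that nanodl's random helpers remove; the nanodl equivalents are demonstrated later in this README.

```py
import jax

# Plain JAX: every random call needs an explicit PRNG key, and keys must be
# split manually so that no key is accidentally reused across calls.
key = jax.random.PRNGKey(42)
key, subkey1, subkey2 = jax.random.split(key, 3)

weights = jax.random.normal(subkey1, shape=(3, 3))
mask = jax.random.bernoulli(subkey2, p=0.5, shape=(3, 3))
print(weights.shape, mask.shape)
```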

Feedback on any of our discussion, issue, and pull request threads is welcome! Please report any feature requests, issues, questions or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! In case you want to reach out directly, we're at [email protected].

## What's New in version 1.2.0.dev1

- Google's Gemma architecture.
- Reward model wrapper and data-parallel distributed reward trainer.
- True random number generators in Jax which do not need the verbose key-handling code (examples shown in the next sections).

There are experimental features (like the MAMBA architecture and RLHF) in the repo which are not yet available via the package, pending tests.
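
Since this release adds Gemma, below is a minimal, hypothetical sketch of how it would be used. It assumes Gemma follows the same hyperparameter-dict constructor and data-parallel trainer pattern as the other decoder models shown in the examples further down; the exact argument names (including `num_groups`) are assumptions, so verify them against `nanodl/__src/models/gemma.py`.

```py
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import Gemma, GemmaDataParallelTrainer

# Dummy tokenised data; replace with real token IDs
max_length = 10
dummy_inputs = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)

dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=False)

# Hypothetical hyperparameters, mirroring the decoder examples further down;
# the exact keys are assumptions -- check the Gemma source for the real ones.
hyperparams = {
    'num_layers': 1,
    'hidden_dim': 256,
    'num_heads': 2,
    'feedforward_dim': 256,
    'dropout': 0.1,
    'vocab_size': 1000,
    'embed_dim': 256,
    'max_length': max_length,
    'start_token': 0,
    'end_token': 50,
    'num_groups': 2,
}

model = Gemma(**hyperparams)
trainer = GemmaDataParallelTrainer(model, dummy_inputs.shape, 'gemma_params.pkl')
trainer.train(dataloader, 2, dataloader)  # (train_loader, epochs, val_loader), as in the other trainers
```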

## Quick install

You will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md)
@@ -52,7 +61,7 @@ pip install nanodl

## What does nanodl look like?

We provide various examples using the nanodl API: language, vision and audio, starting with an LLM.
We provide various example usages of the nanodl API.

```py
import jax
@@ -129,7 +138,9 @@ outputs = model.apply({'params': params},
method=model.generate)
print(outputs)
```

Vision example

```py
import jax
import jax.numpy as jnp
@@ -176,6 +187,7 @@ print(generated_images.shape)
```

Audio example

```py
import jax
import jax.numpy as jnp
@@ -200,12 +212,6 @@ dataloader = DataLoader(dataset,
shuffle=True,
drop_last=False)

# How to loop through dataloader
for batch in dataloader:
x, y = batch
print(x.shape, y.shape)
break

# model parameters
hyperparams = {
'num_layers': 1,
@@ -246,7 +252,67 @@ transcripts = model.apply({'params': params},
print(transcripts)
```

Reward Model example for RLHF

```py
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer

# Generate dummy data
batch_size = 8
max_length = 10

# Replace with actual tokenised data
dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)

# Create dataset and dataloader
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
drop_last=False)

# model parameters
hyperparams = {
'num_layers': 1,
'hidden_dim': 256,
'num_heads': 2,
'feedforward_dim': 256,
'dropout': 0.1,
'vocab_size': 1000,
'embed_dim': 256,
'max_length': max_length,
'start_token': 0,
'end_token': 50,
'num_groups': 2,
'window_size': 5,
'shift_size': 2
}

# Initialize reward model from Mistral
model = Mistral(**hyperparams)
reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)

# Train the reward model
trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')
trainer.train(dataloader, 5, dataloader)
params = trainer.load_params('reward_model_weights.pkl')

# Call as you would a regular Flax model
rngs = jax.random.PRNGKey(0)
rngs, dropout_rng = jax.random.split(rngs)
rewards = reward_model.apply({'params': params},
dummy_chosen,
rngs={'dropout': dropout_rng})

print(rewards.shape)
```

PCA example

```py
import jax
from nanodl import PCA
@@ -260,14 +326,27 @@ X_sampled = pca.sample(n_samples=1000, key=None)
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```

NanoDL provides a random module which abstracts away Jax's explicit PRNG-key handling.
It generates fresh random values on each call by seeding with the current timestamp, so no explicit key management is needed.

```py
from jax import random
import nanodl

# Jax example: an explicit PRNG key is required
key = random.PRNGKey(0)
jax_array = random.uniform(key, shape=(3, 3))

# NanoDL example: no key management needed
jax_array = nanodl.uniform(shape=(3, 3))

# For reproducibility, pass a seed
jax_array = nanodl.uniform(shape=(3, 3), seed=0)
```

## Contribution

This is the first iteration of this project, so some roughness is expected; contributions are therefore highly encouraged! Follow the recommended steps:

- Raise the issue/discussion to get second opinions
- Fork the repository
- Create a branch
- Make your changes without ruining the design patterns
- Make your changes without changing the design patterns
- Write tests for your changes if necessary
- Install locally with `pip install -e .`
- Run tests with `python -m unittest discover -s tests`
@@ -279,16 +358,11 @@ Contributions can be made in various forms:
- Fixing bugs.
- Implementing papers.
- Writing high-coverage tests.
- OPtimizing existing codes.
- Optimizing existing codes.
- Experimenting and submitting real-world examples to the examples section.
- Reporting bugs.
- Responding to reported issues.

Coming features include:
- Reinforcement Learning With Human Feedback (RLHF).
- Tokenizers.
- Code optimisations.

To follow up or share thoughts, use [this form](https://forms.gle/vwveb9SKdPYywHx9A)

## Sponsorships
61 changes: 59 additions & 2 deletions nanodl/__init__.py
@@ -1,8 +1,9 @@
__version__ = "1.0.0.dev1"
__version__ = "1.2.0.dev1"

from nanodl.__src.sklearn_gpu.bayes import NaiveBayesClassifier
from nanodl.__src.sklearn_gpu.dimensionality_reduction import PCA
from nanodl.__src.sklearn_gpu.clustering import KMeans, GaussianMixtureModel
from nanodl.__src.utils.random import *

from nanodl.__src.sklearn_gpu.regression import (
LinearRegression,
@@ -106,6 +107,7 @@
UNetResidualBlock
)


from nanodl.__src.models.transformer import (
Transformer,
TransformerDataParallelTrainer,
@@ -118,6 +120,26 @@
AddNorm
)

from nanodl.__src.models.gemma import (
Gemma,
GemmaDataParallelTrainer,
GemmaDecoder,
GemmaDecoderBlock
)

from nanodl.__src.models.reward import (
RewardModel,
RewardDataParallelTrainer
)

from nanodl.__src.layers.attention import (
MultiQueryAttention,
LocalMultiHeadAttention,
HierarchicalMultiHeadAttention,
GatedMultiHeadAttention,
RotaryMultiHeadAttention
)

from nanodl.__src.utils.data import (
Dataset,
ArrayDataset,
@@ -170,6 +192,10 @@
"GaussianProcess",

# Models
"Gemma",
"GemmaDataParallelTrainer",
"GemmaDecoder",
"GemmaDecoderBlock",
"GAT",
"GraphAttentionLayer",
"T5",
@@ -223,6 +249,8 @@
"WhisperDataParallelTrainer",
"WhisperSpeechEncoder",
"WhisperSpeechEncoderBlock",
"RewardModel",
"RewardDataParallelTrainer",
"DiffusionModel",
"DiffusionDataParallelTrainer",
"UNet",
@@ -267,7 +295,32 @@
"normalize_images",
"random_crop",
"random_flip_image",
"sobel_edge_detection"
"sobel_edge_detection",
"MultiQueryAttention",
"LocalMultiHeadAttention",
"HierarchicalMultiHeadAttention",
"GatedMultiHeadAttention",
"RotaryMultiHeadAttention",

# Random
"time_rng_key",
"uniform",
"normal",
"bernoulli",
"categorical",
"randint",
"permutation",
"gumbel",
"choice",
"binomial",
"bits",
"exponential",
"triangular",
"truncated_normal",
"poisson",
"geometric",
"gamma",
"chisquare",
]

import importlib
@@ -289,11 +342,15 @@ def test_jax(jax):
def test_optax(optax):
optimizer = optax.sgd(learning_rate=0.1)

def test_einops(einops):
arr = einops.rearrange([[1, 2], [3, 4]], 'a b -> b a')

def main():
try:
flax = check_library_installed('flax')
jax = check_library_installed('jax')
optax = check_library_installed('optax')
einops = check_library_installed('einops')

test_flax(flax)
test_jax(jax)
Empty file added nanodl/__src/layers/__init__.py
Empty file.
