-
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from HMUNACHI/dev
NanoDL 1.2.0.dev1
- Loading branch information
Showing
16 changed files
with
2,906 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,17 +20,26 @@ Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/) | |
Developing and training transformer-based models is typically resource-intensive and time-consuming and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features: | ||
|
||
- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch. | ||
- An extensive selection of models like LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications. | ||
- Data-parallel distributed trainers so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops. | ||
- An extensive selection of models like Gemma, LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications. | ||
- Data-parallel distributed trainers includding RLHF so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops. | ||
- Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective. | ||
- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development. | ||
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU. | ||
- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models. | ||
- True random number generators in Jax which do not need the verbose code. | ||
- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc. | ||
- Each model is contained in a single file with no external dependencies, so the source code can also be easily used. | ||
|
||
Feedback on any of our discussion, issue and pull request threads are welcomed! Please report any feature requests, issues, questions or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! In case you want to reach out directly, we're at [email protected]. | ||
|
||
## What's New in version 1.2.0.dev1 | ||
|
||
- Google's Gemma architecture. | ||
- Reward model wrapper and data-parallel distributed reward trainer. | ||
- True random number generators in Jax which do not need the verbose code (examples shown in next sections). | ||
|
||
There are experimental features (like MAMBA architecture and RLHF) in the repo which is not available via the package, pending tests. | ||
|
||
## Quick install | ||
|
||
You will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md) | ||
|
@@ -52,7 +61,7 @@ pip install nanodl | |
|
||
## What does nanodl look like? | ||
|
||
We provide various examples using the nanodl API: language, vision and audio, starting with an LLM. | ||
We provide various example usages of the nanodl API. | ||
|
||
```py | ||
import jax | ||
|
@@ -129,7 +138,9 @@ outputs = model.apply({'params': params}, | |
method=model.generate) | ||
print(outputs) | ||
``` | ||
|
||
Vision example | ||
|
||
```py | ||
import jax | ||
import jax.numpy as jnp | ||
|
@@ -176,6 +187,7 @@ print(generated_images.shape) | |
``` | ||
|
||
Audio example | ||
|
||
```py | ||
import jax | ||
import jax.numpy as jnp | ||
|
@@ -200,12 +212,6 @@ dataloader = DataLoader(dataset, | |
shuffle=True, | ||
drop_last=False) | ||
|
||
# How to loop through dataloader | ||
for batch in dataloader: | ||
x, y = batch | ||
print(x.shape, y.shape) | ||
break | ||
|
||
# model parameters | ||
hyperparams = { | ||
'num_layers': 1, | ||
|
@@ -246,7 +252,67 @@ transcripts = model.apply({'params': params}, | |
print(transcripts) | ||
``` | ||
|
||
Reward Model example for RLHF | ||
|
||
```py | ||
import jax | ||
import jax.numpy as jnp | ||
from nanodl import ArrayDataset, DataLoader | ||
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer | ||
|
||
# Generate dummy data | ||
batch_size = 8 | ||
max_length = 10 | ||
|
||
# Replace with actual tokenised data | ||
dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32) | ||
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32) | ||
|
||
# Create dataset and dataloader | ||
dataset = ArrayDataset(dummy_chosen, dummy_rejected) | ||
dataloader = DataLoader(dataset, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
drop_last=False) | ||
|
||
# model parameters | ||
hyperparams = { | ||
'num_layers': 1, | ||
'hidden_dim': 256, | ||
'num_heads': 2, | ||
'feedforward_dim': 256, | ||
'dropout': 0.1, | ||
'vocab_size': 1000, | ||
'embed_dim': 256, | ||
'max_length': max_length, | ||
'start_token': 0, | ||
'end_token': 50, | ||
'num_groups': 2, | ||
'window_size': 5, | ||
'shift_size': 2 | ||
} | ||
|
||
# Initialize reward model from Mistral | ||
model = Mistral(**hyperparams) | ||
reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1) | ||
|
||
# Train the reward model | ||
trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl') | ||
trainer.train(dataloader, 5, dataloader) | ||
params = trainer.load_params('reward_model_weights.pkl') | ||
|
||
# Call as you would a regular Flax model | ||
rngs = jax.random.PRNGKey(0) | ||
rngs, dropout_rng = jax.random.split(rngs) | ||
rewards = reward_model.apply({'params': params}, | ||
dummy_chosen, | ||
rngs={'dropout': dropout_rng}) | ||
|
||
print(rewards.shape) | ||
``` | ||
|
||
PCA example | ||
|
||
```py | ||
import jax | ||
from nanodl import PCA | ||
|
@@ -260,14 +326,27 @@ X_sampled = pca.sample(n_samples=1000, key=None) | |
print(X_sampled.shape, original_data.shape, transformed_data.shape) | ||
``` | ||
|
||
# Contribution | ||
NanoDL provides random module which abstracts away Jax's intricacies. | ||
It generates truly random variables by using the current timestamp as seed. | ||
|
||
```py | ||
# Jax example | ||
key = random.PRNGKey(0) | ||
jax_array = random.uniform(key, shape=(3, 3)) | ||
|
||
# NanoDL example | ||
jax_array = nanodl.uniform(shape=(3, 3)) | ||
|
||
# For reproducability, use seed | ||
jax_array = nanodl.uniform(shape=(3, 3), seed=0) | ||
``` | ||
|
||
This is the first iteration of this project, roughness is expected, contributions are therefore highly encouraged! Follow the recommended steps: | ||
|
||
- Raise the issue/discussion to get second opinions | ||
- Fork the repository | ||
- Create a branch | ||
- Make your changes without ruining the design patterns | ||
- Make your changes without changing the design patterns | ||
- Write tests for your changes if necessary | ||
- Install locally with `pip install -e .` | ||
- Run tests with `python -m unittest discover -s tests` | ||
|
@@ -279,16 +358,11 @@ Contributions can be made in various forms: | |
- Fixing bugs. | ||
- Implementing papers. | ||
- Writing high-coverage tests. | ||
- OPtimizing existing codes. | ||
- Optimizing existing codes. | ||
- Experimenting and submitting real-world examples to the examples section. | ||
- Reporting bugs. | ||
- Responding to reported issues. | ||
|
||
Coming features include: | ||
- Reinforcement Learning With Human Feedback (RLHF). | ||
- Tokenizers. | ||
- Code optimisations. | ||
|
||
To follow up or share thoughts, follow [here](https://forms.gle/vwveb9SKdPYywHx9A) | ||
|
||
## Sponsorships | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.