Showing 12 changed files with 1,241 additions and 6 deletions.
@@ -20,17 +20,24 @@ Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/)
Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:

- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
- An extensive selection of models like Gemma, LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- Data-parallel distributed trainers, including RLHFPPO and RLHFDPO, so developers can efficiently train large-scale models on multiple GPUs or TPUs without writing manual training loops.
- Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU.
- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
- True random number generators in Jax which do not need verbose key-handling code.
- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU etc.
- Each model is contained in a single file with no external dependencies, so the source code can also be easily used on its own.

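As a taste of the custom layers listed above, a rotary position embedding (RoPE) can be written in a few lines of plain JAX. This is an independent toy sketch for illustration, not nanodl's actual source:

```python
import jax.numpy as jnp


def rope(x: jnp.ndarray, base: float = 10000.0) -> jnp.ndarray:
    """Apply rotary position embeddings to x of shape (seq_len, dim), dim even."""
    seq_len, dim = x.shape
    # One rotation frequency per pair of channels, decaying geometrically.
    inv_freq = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))
    angles = jnp.arange(seq_len)[:, None] * inv_freq[None, :]  # (seq_len, dim/2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    # Rotate each (x1, x2) channel pair by its position-dependent angle.
    rotated = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return rotated.reshape(seq_len, dim)


q = jnp.ones((4, 8))
print(rope(q).shape)  # (4, 8)
```

Because RoPE is a pure rotation, it preserves vector norms and leaves position 0 unchanged, which makes it easy to sanity-check.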
Feedback on any of our discussion, issue, and pull request threads is welcome! Please report any feature requests, issues, questions, or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! To reach out directly, we're at [email protected].

## What's New in version 1.2.0.dev1

- Google's Gemma and MAMBA architectures.
- Data-parallel distributed RLHFPPO and RLHFDPO trainers.
- True random number generators in Jax which do not need verbose key-handling code (examples shown in the next sections).

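The RLHFDPO trainer mentioned above optimises the standard direct-preference objective. A minimal sketch of that loss in plain JAX (illustrative only, not nanodl's trainer code; the argument names are hypothetical) looks like this:

```python
import jax
import jax.numpy as jnp


def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # DPO objective: prefer chosen over rejected completions, measured
    # relative to a frozen reference model, with beta as the strength.
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    return -jnp.mean(jax.nn.log_sigmoid(beta * (chosen_ratio - rejected_ratio)))


# Toy log-probabilities for a batch of 3 preference pairs.
loss = dpo_loss(jnp.array([-1.0, -1.2, -0.8]), jnp.array([-2.0, -1.8, -2.5]),
                jnp.array([-1.1, -1.3, -0.9]), jnp.array([-1.9, -1.7, -2.4]))
print(float(loss) > 0)  # True
```

Since `log_sigmoid` is always negative, the loss is always positive, and it shrinks as the policy assigns relatively higher likelihood to the chosen completions.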
## Quick install

You will need Python 3.9 or later, and a working [JAX](https://github.com/google/jax/blob/main/README.md) installation.
@@ -260,7 +267,19 @@ X_sampled = pca.sample(n_samples=1000, key=None)
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```

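The `fit`/`transform`/`sample` pattern used in the PCA example above can be sketched in a few lines of JAX. This is a hypothetical minimal version for illustration, not nanodl's implementation:

```python
import jax
import jax.numpy as jnp


class PCA:
    """Minimal PCA sketch (illustrative; nanodl's version may differ)."""

    def __init__(self, n_components: int):
        self.n_components = n_components

    def fit(self, X):
        self.mean_ = X.mean(axis=0)
        # SVD of the centred data: rows of Vt are the principal directions.
        _, S, Vt = jnp.linalg.svd(X - self.mean_, full_matrices=False)
        self.components_ = Vt[: self.n_components]
        self.explained_std_ = S[: self.n_components] / jnp.sqrt(X.shape[0] - 1)
        return self

    def transform(self, X):
        return (X - self.mean_) @ self.components_.T

    def sample(self, n_samples: int, key=None):
        # Draw latent codes from a Gaussian matched to the projected data.
        key = key if key is not None else jax.random.PRNGKey(0)
        z = jax.random.normal(key, (n_samples, self.n_components)) * self.explained_std_
        return z @ self.components_ + self.mean_


X = jax.random.normal(jax.random.PRNGKey(42), (200, 5))
pca = PCA(n_components=2).fit(X)
print(pca.transform(X).shape, pca.sample(100).shape)  # (200, 2) (100, 5)
```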
NanoDL provides a random module which abstracts away Jax's explicit PRNG-key handling. It generates truly random variables by using the current timestamp as the seed.

```py
# Jax example
key = random.PRNGKey(0)
jax_array = random.uniform(key, shape=(3, 3))

# NanoDL example
jax_array = nanodl.uniform(shape=(3, 3))

# For reproducibility, pass an explicit seed
jax_array = nanodl.uniform(shape=(3, 3), seed=0)
```

# Contribution

This is the first iteration of this project; roughness is expected, and contributions are therefore highly encouraged! Follow the recommended steps:
@@ -0,0 +1,34 @@
import time

import jax
import jax.numpy as jnp
from jax import random


def dropout(x: jnp.ndarray,
            rate: float,
            training: bool = False) -> jnp.ndarray:
    """Apply dropout to input tensor.

    Args:
        x (jnp.ndarray): Input tensor.
        rate (float): Dropout rate, must be in the range [0, 1).
        training (bool, optional): Whether to apply dropout.
            If False, returns the input tensor unchanged. Defaults to False.

    Raises:
        ValueError: If dropout rate is not in [0, 1).

    Returns:
        jnp.ndarray: Tensor after applying dropout.
    """
    if not training:
        return x

    if not 0 <= rate < 1:
        raise ValueError("Dropout rate must be in the range [0, 1).")

    if rate == 0:
        return x

    keep_prob = 1 - rate
    # Seed the Bernoulli mask with the current timestamp so each call draws
    # a mask without an explicit key (calls within the same second share a seed).
    mask = jax.random.bernoulli(random.PRNGKey(int(time.time())), keep_prob, x.shape)
    # Inverted dropout: scale survivors by 1/keep_prob so the expected
    # value of the output matches the input.
    return jax.lax.select(mask, x / keep_prob, jnp.zeros_like(x))
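A quick, self-contained demo of how this dropout function behaves; the logic is repeated in condensed form so the snippet runs on its own:

```python
import time

import jax
import jax.numpy as jnp
from jax import random


def dropout(x, rate, training=False):
    # Condensed copy of the file's logic, for a runnable demo.
    if not training or rate == 0:
        return x
    if not 0 <= rate < 1:
        raise ValueError("Dropout rate must be in the range [0, 1).")
    keep_prob = 1 - rate
    mask = jax.random.bernoulli(random.PRNGKey(int(time.time())), keep_prob, x.shape)
    return jax.lax.select(mask, x / keep_prob, jnp.zeros_like(x))


x = jnp.ones((4, 4))

# Inference mode: input passes through unchanged.
print(bool(jnp.array_equal(dropout(x, 0.5), x)))  # True

# Training mode: every entry is either dropped to 0 or scaled by 1/keep_prob = 2.0.
y = dropout(x, 0.5, training=True)
print(bool(jnp.all((y == 0) | (y == 2.0))))  # True
```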