added features
HMUNACHI committed Feb 22, 2024
1 parent 18c7f8e commit d2f816f
Showing 12 changed files with 1,241 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -23,6 +23,7 @@ __pycache__/
# Ignore configuration files with sensitive information
config.ini
secrets.yaml
params.pkl

# Ignore user-specific files
/userdata/
25 changes: 22 additions & 3 deletions README.md
@@ -20,17 +20,24 @@ Author: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/)
Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:

- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
- An extensive selection of models like LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- Data-parallel distributed trainers so developers can efficiently train large-scale models on multiple GPUs or TPUs, without the need for manual training loops.
- An extensive selection of models like Gemma, LlaMa2, Mistral, Mixtral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- Data-parallel distributed trainers, including RLHFPPO and RLHFDPO, so developers can efficiently train large-scale models on multiple GPUs or TPUs without the need for manual training loops.
- Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to SciKit Learn on GPU.
- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
- True random number generators in Jax which do not require the usual verbose key-handling code.
- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU, etc.
- Each model is contained in a single file with no external dependencies, so the source code can also be easily used.

Feedback on any of our discussion, issue, and pull request threads is welcome! Please report any feature requests, issues, questions, or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! In case you want to reach out directly, we're at [email protected].

## What's New in version 1.2.0.dev1

- Google's Gemma and MAMBA architectures.
- Data-parallel distributed RLHFPPO and RLHFDPO trainers.
- True random number generators in Jax which do not require the usual verbose key-handling code (examples shown in the next sections).

## Quick install

You will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md)
@@ -260,7 +267,19 @@ X_sampled = pca.sample(n_samples=1000, key=None)
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```

# Contribution
NanoDL provides a random module which abstracts away Jax's intricacies.
It generates truly random variables by using the current timestamp as the seed.
```py
# Jax example
from jax import random

key = random.PRNGKey(0)
jax_array = random.uniform(key, shape=(3, 3))

# NanoDL example
import nanodl

jax_array = nanodl.uniform(shape=(3, 3))

# For reproducibility, use a seed
jax_array = nanodl.uniform(shape=(3, 3), seed=0)
```
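
The other random wrappers exported by NanoDL (`normal`, `bernoulli`, `randint`, `permutation`, and so on) can be used the same way. The sketch below is an assumption that they follow the same `shape`/`seed` keyword pattern demonstrated by `uniform` above, not a documented signature.
```py
# Hedged sketch: assumes the other wrappers mirror uniform's shape/seed keywords.
import nanodl

noise = nanodl.normal(shape=(3, 3))                  # fresh samples on every call
fixed_noise = nanodl.normal(shape=(3, 3), seed=42)   # reproducible samples
```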

This is the first iteration of this project, so roughness is expected and contributions are therefore highly encouraged! Follow the recommended steps:

36 changes: 34 additions & 2 deletions nanodl/__init__.py
@@ -1,8 +1,9 @@
__version__ = "1.0.0.dev1"
__version__ = "1.2.0.dev1"

from nanodl.__src.sklearn_gpu.bayes import NaiveBayesClassifier
from nanodl.__src.sklearn_gpu.dimensionality_reduction import PCA
from nanodl.__src.sklearn_gpu.clustering import KMeans, GaussianMixtureModel
from nanodl.__src.utils.random import *

from nanodl.__src.sklearn_gpu.regression import (
LinearRegression,
@@ -118,6 +119,13 @@
AddNorm
)

from nanodl.__src.models.gemma import (
Gemma,
GemmaDataParallelTrainer,
GemmaDecoder,
GemmaDecoderBlock
)

from nanodl.__src.utils.data import (
Dataset,
ArrayDataset,
@@ -170,6 +178,10 @@
"GaussianProcess",

# Models
"Gemma",
"GemmaDataParallelTrainer",
"GemmaDecoder",
"GemmaDecoderBlock",
"GAT",
"GraphAttentionLayer",
"T5",
@@ -267,7 +279,27 @@
"normalize_images",
"random_crop",
"random_flip_image",
"sobel_edge_detection"
"sobel_edge_detection",

# Random
"time_rng_key",
"uniform",
"normal",
"bernoulli",
"categorical",
"randint",
"permutation",
"gumbel",
"choice",
"binomial",
"bits",
"exponential",
"triangular",
"truncated_normal",
"poisson",
"geometric",
"gamma",
"chisquare",
]

import importlib
Empty file added nanodl/__src/layers/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions nanodl/__src/layers/general.py
@@ -0,0 +1,34 @@
import jax
import time
import jax.numpy as jnp
from jax import random

def dropout(x: jnp.ndarray,
rate: float,
training: bool = False) -> jnp.ndarray:
"""Apply dropout to input tensor.
Args:
x (jnp.ndarray): Input tensor.
rate (float): Dropout rate, must be between 0 and 1.
training (bool, optional): Whether to apply dropout.
If False, returns input tensor unchanged. Defaults to False.
Raises:
ValueError: If dropout rate is not in [0, 1).
Returns:
jnp.ndarray: Tensor after applying dropout.
"""
if not training:
return x

if not 0 <= rate < 1:
raise ValueError("Dropout rate must be in the range [0, 1).")

if rate == 0:
return x

keep_prob = 1 - rate
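    # Seed the PRNG with the current timestamp so each call draws a fresh dropout mask.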
mask = jax.random.bernoulli(random.PRNGKey(int(time.time())), keep_prob, x.shape)
return jax.lax.select(mask, x / keep_prob, jnp.zeros_like(x))
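
For reference, a minimal usage sketch of this helper. The import path is assumed from the file location shown above and is not part of NanoDL's public API.
```py
import jax.numpy as jnp
from nanodl.__src.layers.general import dropout  # assumed path, per the file location above

x = jnp.ones((2, 4))
out_train = dropout(x, rate=0.5, training=True)   # ~50% of elements zeroed, rest scaled by 1/0.5
out_eval = dropout(x, rate=0.5, training=False)   # returned unchanged when not training
print(out_train.shape, out_eval.shape)            # (2, 4) (2, 4)
```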
