Commit

Pushed v1.2.4Dev1 codes
HMUNACHI committed May 12, 2024
1 parent bce1cdb commit f72df49
Showing 42 changed files with 5,803 additions and 4,510 deletions.
74 changes: 15 additions & 59 deletions README.md
Each model is purposefully contained in a file without inter-file dependencies.
Developing and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:

- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
- An extensive selection of models like Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, CLIP, etc.
- Data-parallel distributed trainers for training models on multiple GPUs or TPUs, without the need for manual training loops.
- Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.
- Custom layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes, etc., akin to SciKit Learn on GPU.
- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LlaMa2, to craft unique hybrid transformer models.
- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU, Tokenizer etc.
- Each model is contained in a single file with no external dependencies, so the source code can also be easily used.
- True random number generators in Jax which do not require verbose key-handling code (examples shown in the sections below).

There are experimental and/or unfinished features (like MAMBA, KAN, BitNet, GAT, and RLHF)
in the repo which are not yet available via the package, but which can be copied from this repo.
Feedback on any of our discussion, issue, and pull request threads is welcome!
Please report any feature requests, issues, questions or concerns in the [Discord](https://discord.gg/3u9vumJEmz),
or just let us know what you're working on!
We provide various example usages of the nanodl API.
```py
import jax
import jax.numpy as jnp
import nanodl
from nanodl import ArrayDataset, DataLoader
from nanodl import GPT4, GPTDataParallelTrainer, Tokenizer

batch_size = 8
max_length = 50
vocab_size = 1000

text_paths = ['/path/sample1.txt',
              '/path/sample2.txt',
              '/path/sample3.txt']

tokenizer = Tokenizer(training_data=text_paths,
                      vocab_size=vocab_size,
                      model_type='bpe',
                      max_sentence_length=max_length)

data = []
for path in text_paths:
    with open(path, 'r') as file:
        text = file.read()
    # To-Do: preprocess however you wish, then encode line by line
    encoded = list(map(tokenizer.encode, text.splitlines()))
    data.extend(encoded)

# Pad sequences with 0
max_length = max(len(seq) for seq in data)
padded = [seq + [0] * (max_length - len(seq)) for seq in data]

# Jax does not support strings yet, encode before converting to array
data = jnp.array(padded)
# Alternatively, create random dummy data (cast to integer token IDs before use):
# data = nanodl.uniform(shape=(batch_size, max_length))

# Shift to create next-token prediction dataset
dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]
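
# Wrap the arrays for training. This is a sketch: the DataLoader keyword
# arguments below are assumptions and may differ from the actual nanodl API.
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=False)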

# ... (GPT4 model instantiation elided in this excerpt)
trainer = GPTDataParallelTrainer(model,
                                 input_shape=dummy_inputs.shape,
                                 weights_filename='params.pkl')  # arguments assumed

trainer.train(train_loader=dataloader,
              num_epochs=100,
              val_loader=dataloader)  # use actual validation data

# Generating from a start token
start_tokens = jnp.array([[123, 456]])
params = trainer.load_params('params.pkl')
outputs = model.apply({'params': params},
                      start_tokens,
                      rngs={'dropout': nanodl.time_rng_key()},
                      method=model.generate)

# Jax does not support strings yet, convert to list before decoding
outputs = tokenizer.decode(outputs.tolist())
```

Vision example

```py
import jax
import jax.numpy as jnp
import nanodl
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer

image_size = 32  # assumed value; the original definition is elided in this excerpt
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
input_shape = (101, image_size, image_size, 3)
images = nanodl.normal(shape=input_shape)

# Use your own images
dataset = ArrayDataset(images)
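
# Sketch of the elided setup: a dataloader and the diffusion model itself.
# The DataLoader keyword arguments and the DiffusionModel constructor below
# are assumptions; consult the nanodl docs for the exact interface.
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=False)
diffusion_model = DiffusionModel(image_size, widths, block_depth)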

trainer = DiffusionDataParallelTrainer(diffusion_model,
                                       weights_filename='params.pkl',
                                       learning_rate=1e-4)

trainer.train(dataloader, 10, dataloader)  # use actual validation data

# Generate some samples: Each model is a Flax.linen module
# Use as you normally would
```

Audio example
```py
import jax
import jax.numpy as jnp
import nanodl
from nanodl import ArrayDataset, DataLoader
from nanodl import Whisper, WhisperDataParallelTrainer

# ... (Whisper data preparation and training elided in this excerpt;
# the lines below are from the reward-model example further down the README)
params = trainer.load_params('reward_model_weights.pkl')
# Call as you would a regular Flax model
rewards = reward_model.apply({'params': params},
                             dummy_chosen,
                             rngs={'dropout': nanodl.time_rng_key()})
```

PCA example
```py
import jax
import nanodl
from nanodl import PCA

# Use actual data
data = nanodl.normal(shape=(1000, 10))

# Initialise and train PCA model
pca = PCA(n_components=2)
# Fit and transform (these two lines are assumed; the originals are elided in this excerpt)
pca.fit(data)
transformed_data = pca.transform(data)

# Reconstruct data from the transformed representation
original_data = pca.inverse_transform(transformed_data)
X_sampled = pca.sample(n_samples=1000, key=None)
```
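
The other classical ML models listed above follow a similar scikit-learn-style workflow. Below is a minimal sketch for KMeans under assumed names: the constructor argument `k` and the `fit`/`predict` methods are modeled on the PCA example and may differ from the actual nanodl interface.

```py
import nanodl
from nanodl import KMeans

# Use actual data
data = nanodl.normal(shape=(1000, 10))

# Initialise and fit a KMeans model (the constructor argument name is assumed)
kmeans = KMeans(k=3)
kmeans.fit(data)

# Assign each point to a cluster (the predict method name is assumed)
clusters = kmeans.predict(data)
```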

NanoDL provides a random module which abstracts away Jax's intricacies.
It generates truly random variables by using the current timestamp as the seed.

```py
import jax
import nanodl

# Jax example
key = jax.random.PRNGKey(0)
jax_array = jax.random.uniform(key, shape=(3, 3))

# NanoDL example
jax_array = nanodl.uniform(shape=(3, 3))

# For reproducibility, pass a seed
jax_array = nanodl.uniform(shape=(3, 3), seed=0)
```

This project is still in development; it works well, but some roughness is expected, and contributions are therefore highly encouraged!

- Make your changes without changing the design patterns.
- Write tests for your changes if necessary.