Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
HMUNACHI committed Mar 19, 2024
1 parent 833aebe commit 1a4d354
Showing 1 changed file with 33 additions and 63 deletions.
96 changes: 33 additions & 63 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,32 +62,23 @@ We provide various example usages of the nanodl API.
```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import GPT4, GPTDataParallelTrainer

# Generate dummy data
batch_size = 8
max_length = 10

# Replace with actual tokenised data
# Replace with actual list of tokenised texts
data = jnp.ones((101, max_length), dtype=jnp.int32)

# Shift to create next-token prediction dataset
dummy_inputs = data[:, :-1]
dummy_targets = data[:, 1:]
dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]

# Create dataset and dataloader
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
drop_last=False)

# How to loop through dataloader
for batch in dataloader:
x, y = batch
print(x.shape, y.shape)
break
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# model parameters
hyperparams = {
Expand All @@ -103,25 +94,17 @@ hyperparams = {
'end_token': 50,
}

# Initialize model
# Initialize inferred GPT4 model
model = GPT4(**hyperparams)
rngs = jax.random.PRNGKey(0)
rngs, dropout_rng = jax.random.split(rngs)
params = model.init({'params': rngs, 'dropout': dropout_rng}, dummy_inputs)['params']

# Call as you would a Jax/Flax model
outputs = model.apply({'params': params},
dummy_inputs,
rngs={'dropout': dropout_rng})
print(outputs.shape)
params = model.init(
{'params': time_rng_key(),
'dropout': time_rng_key()
},
dummy_inputs)['params']

# Training on data
trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
trainer.train(train_loader=dataloader,
num_epochs=2,
val_loader=dataloader)

print(trainer.evaluate(dataloader))
trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)

# Generating from a start token
start_tokens = jnp.array([[123, 456]])
Expand All @@ -130,33 +113,29 @@ start_tokens = jnp.array([[123, 456]])
params = trainer.load_params('params.pkl')
outputs = model.apply({'params': params},
start_tokens,
rngs={'dropout': jax.random.PRNGKey(2)},
rngs={'dropout': time_rng_key()},
method=model.generate)
print(outputs)
```

Vision example

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer

image_size = 32
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
key = jax.random.PRNGKey(0)
input_shape = (101, image_size, image_size, 3)
images = jax.random.normal(key, input_shape)
images = jax.random.normal(time_rng_key(), input_shape)

# Use your own images
dataset = ArrayDataset(images)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
drop_last=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# Create diffusion model
diffusion_model = DiffusionModel(image_size, widths, block_depth)
Expand All @@ -165,28 +144,26 @@ pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)

# Training on your data
# Note: saved params are often different from training weights, use the saved params for generation
trainer = DiffusionDataParallelTrainer(diffusion_model,
input_shape=images.shape,
weights_filename='params.pkl',
learning_rate=1e-4)
trainer.train(dataloader, 10, dataloader)
print(trainer.evaluate(dataloader))

# Generate some samples
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params},
num_images=5,
diffusion_steps=5,
method=diffusion_model.generate)
print(generated_images.shape)
```

Audio example

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Whisper, WhisperDataParallelTrainer

Expand All @@ -200,13 +177,8 @@ vocab_size = 1000
dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_inputs = jnp.ones((101, max_length, embed_dim))

dataset = ArrayDataset(dummy_inputs,
dummy_targets)

dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
drop_last=False)
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# model parameters
hyperparams = {
Expand All @@ -224,10 +196,8 @@ hyperparams = {

# Initialize model
model = Whisper(**hyperparams)
rngs = {'params': jax.random.key(0), 'dropout': jax.random.key(1)}
rngs = {'params': time_rng_key(), 'dropout': time_rng_key()}
params = model.init(rngs, dummy_inputs, dummy_targets)['params']
outputs = model.apply({'params': params}, dummy_inputs, dummy_targets, rngs=rngs)
print(outputs.shape)

# Training on your data
trainer = WhisperDataParallelTrainer(model,
Expand All @@ -239,20 +209,19 @@ trainer.train(dataloader, 2, dataloader)
# Sample inference
params = trainer.load_params('params.pkl')

# for more than one sample, use model.generate_batch
# for more than one sample, often use model.generate_batch
transcripts = model.apply({'params': params},
dummy_inputs[:1],
rngs=rngs,
method=model.generate)

print(transcripts)
```

Reward Model example for RLHF

```py
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer

Expand All @@ -266,10 +235,7 @@ dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)

# Create dataset and dataloader
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=True,
drop_last=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# model parameters
hyperparams = {
Expand Down Expand Up @@ -298,13 +264,9 @@ trainer.train(dataloader, 5, dataloader)
params = trainer.load_params('reward_model_weights.pkl')

# Call as you would a regular Flax model
rngs = jax.random.PRNGKey(0)
rngs, dropout_rng = jax.random.split(rngs)
rewards = reward_model.apply({'params': params},
dummy_chosen,
rngs={'dropout': dropout_rng})

print(rewards.shape)
rngs={'dropout': time_rng_key()})
```

PCA example
Expand All @@ -313,13 +275,21 @@ PCA example
import jax
from nanodl import PCA

# Use actual data
data = jax.random.normal(jax.random.key(0), (1000, 10))

# Initialise and train PCA model
pca = PCA(n_components=2)
pca.fit(data)

# Get PCA transforms
transformed_data = pca.transform(data)

# Get reverse transforms
original_data = pca.inverse_transform(transformed_data)

# Sample from the distribution
X_sampled = pca.sample(n_samples=1000, key=None)
print(X_sampled.shape, original_data.shape, transformed_data.shape)
```

NanoDL provides random module which abstracts away Jax's intricacies.
Expand Down Expand Up @@ -371,7 +341,7 @@ Following the success of Phi models, the long-term goal is to build and train na
while ensuring they compete with the original models in performance, with total
number of parameters not exceeding 1B. Trained weights will be made available via this library.
Any form of sponsorship, funding, grants or contribution will help with training resources.
You can sponsor via the tag on the user profile, or reach out via [email protected].
You can sponsor via the user profile tag or reach out via [email protected].

## Citing nanodl

Expand Down

0 comments on commit 1a4d354

Please sign in to comment.