Commit 6a951af: Update README.md
HMUNACHI committed Mar 27, 2024 (1 parent: 0d4b28d)
Showing 1 changed file with 48 additions and 23 deletions: README.md

```
import jax
import jax.numpy as jnp
from nanodl import time_rng_key
from nanodl import ArrayDataset, DataLoader
from nanodl import GPT4, GPTDataParallelTrainer, Tokenizer

# Preparing your dataset
batch_size = 8
max_length = 50
vocab_size = 1000

text_paths = ['/path/sample1.txt',
              '/path/sample2.txt',
              '/path/sample3.txt']

tokenizer = Tokenizer(training_data=text_paths,
                      vocab_size=vocab_size,
                      model_type='bpe',
                      max_sentence_length=max_length)
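# Hypothetical round trip (API assumed to be SentencePiece-style):
#   ids = tokenizer.encode('hello world')   # e.g. [412, 87]
#   tokenizer.decode(ids)                   # 'hello world'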

data = []
for path in text_paths:
    with open(path, 'r') as file:
        text = file.read()
    # To-Do: preprocess however you wish
    # Encode line by line; mapping encode over the raw string
    # would tokenise one character at a time
    encoded = [tokenizer.encode(line) for line in text.splitlines() if line]
    data.extend(encoded)

# Pad sequences with 0
max_length = max(len(seq) for seq in data)
padded = [seq + [0] * (max_length - len(seq)) for seq in data]
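# e.g. [[5, 2], [7]] with max_length 3 becomes [[5, 2, 0], [7, 0, 0]]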

# Jax does not support strings yet, encode before converting to array
data = jnp.array(padded)

# Shift to create next-token prediction dataset
dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]
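# e.g. [t0, t1, t2, t3] yields inputs [t0, t1, t2] and targets [t1, t2, t3]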

# Create dataset and dataloader
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        drop_last=False)
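# Each batch is an (inputs, targets) pair; with drop_last=False the final
# batch may be smaller than batch_size (assumes the DataLoader yields tuples)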

# Model parameters
hyperparams = {
    'num_heads': 2,
    'feedforward_dim': 256,
    'dropout': 0.1,
    'vocab_size': vocab_size,
    'embed_dim': 256,
    'max_length': max_length,
    'start_token': 0,
    'end_token': 50,
}

# Inferred GPT4 model
model = GPT4(**hyperparams)
params = model.init({'params': time_rng_key(),
                     'dropout': time_rng_key()},
                    dummy_inputs)['params']

# Training on data
trainer = GPTDataParallelTrainer(model,
                                 dummy_inputs.shape,
                                 'params.pkl')

trainer.train(train_loader=dataloader,
              num_epochs=100,
              val_loader=dataloader)  # To-Do: replace with actual val data

# Generating from a start token
start_tokens = jnp.array([[123, 456]])
outputs = model.apply({'params': params},
                      start_tokens,
                      rngs={'dropout': time_rng_key()},
                      method=model.generate)

# Jax does not support strings yet, convert to list before decoding
outputs = tokenizer.decode(outputs.tolist())
```
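After training, the saved weights can be reloaded for inference. A minimal sketch, assuming `GPTDataParallelTrainer.load_params` mirrors the diffusion trainer used in the vision example below, and with a hypothetical prompt:

```
# Load the best checkpoint written by the trainer (assumed API,
# mirroring trainer.load_params in the vision example below)
params = trainer.load_params('params.pkl')

# Encode a hypothetical prompt and generate as above
prompt = jnp.array([tokenizer.encode('Once upon a time')])
outputs = model.apply({'params': params},
                      prompt,
                      rngs={'dropout': time_rng_key()},
                      method=model.generate)
print(tokenizer.decode(outputs.tolist()))
```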

Vision example
```
# (dataset and dataloader construction elided; see the LLM example above)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)

# Create diffusion model
diffusion_model = DiffusionModel(image_size, widths, block_depth)
params = diffusion_model.init(key, images)
pred_noises, pred_images = diffusion_model.apply(params, images)
print(pred_noises.shape, pred_images.shape)

# Training on your data
trainer = DiffusionDataParallelTrainer(diffusion_model,
                                       input_shape=images.shape,
                                       weights_filename='params.pkl',
                                       learning_rate=1e-4)

trainer.train(dataloader, 10, dataloader)

# Generate some samples: each model is a flax.linen module,
# use it as you normally would
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params},
                                         num_images=5,
                                         method=diffusion_model.generate)
```

Audio example

```
# (imports, dummy data and hyperparams construction elided in this diff)

# Initialize model
model = Whisper(**hyperparams)
rngs = {'params': time_rng_key(), 'dropout': time_rng_key()}
params = model.init(rngs, dummy_inputs, dummy_targets)['params']

# Training on your data
trainer = WhisperDataParallelTrainer(model,
                                     dummy_inputs.shape,
                                     dummy_targets.shape,
                                     'params.pkl')

trainer.train(dataloader, 2, dataloader)

# Sample inference
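# A minimal sketch, assuming Whisper's generate API mirrors the GPT4
# example above (an assumption, not confirmed by this diff):
# params = trainer.load_params('params.pkl')
# transcripts = model.apply({'params': params},
#                           dummy_inputs[:1],
#                           rngs={'dropout': time_rng_key()},
#                           method=model.generate)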
```

To cite this repository:

```
  url = {https://github.com/hmunachi/nanodl},
  year = {2024},
}
```
