diff --git a/README.md b/README.md
index d4d806f..660bc34 100644
--- a/README.md
+++ b/README.md
@@ -60,21 +60,46 @@
 import jax
 import jax.numpy as jnp
 from nanodl import time_rng_key
 from nanodl import ArrayDataset, DataLoader
-from nanodl import GPT4, GPTDataParallelTrainer
+from nanodl import GPT4, GPTDataParallelTrainer, Tokenizer
 
-# Generate dummy data
+# Preparing your dataset
 batch_size = 8
-max_length = 10
+max_length = 50
+vocab_size = 1000
+
+text_paths = ['/path/sample1.txt',
+              '/path/sample2.txt',
+              '/path/sample3.txt']
+
+tokenizer = Tokenizer(training_data=text_paths,
+                      vocab_size=vocab_size,
+                      model_type='bpe',
+                      max_sentence_length=max_length)
+
+data = []
+for path in text_paths:
+    with open(path, 'r') as file:
+        text = file.read()
+        # To-Do: preprocess however you wish
+        encoded = list(map(tokenizer.encode, text.splitlines()))
+        data.extend(encoded)
+
+# Pad sequences with 0
+max_length = max(len(seq) for seq in data)
+padded = [seq + [0] * (max_length - len(seq)) for seq in data]
 
-# Replace with actual list of tokenised texts
-data = jnp.ones((101, max_length), dtype=jnp.int32)
+# JAX does not support strings yet, so encode before converting to an array
+data = jnp.array(padded)
 
 # Shift to create next-token prediction dataset
 dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]
 
 # Create dataset and dataloader
 dataset = ArrayDataset(dummy_inputs, dummy_targets)
-dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
+dataloader = DataLoader(dataset,
+                        batch_size=batch_size,
+                        shuffle=True,
+                        drop_last=False)
 
 # model parameters
 hyperparams = {
@@ -83,24 +108,23 @@ hyperparams = {
     'num_heads': 2,
     'feedforward_dim': 256,
     'dropout': 0.1,
-    'vocab_size': 1000,
+    'vocab_size': vocab_size,
     'embed_dim': 256,
     'max_length': max_length,
     'start_token': 0,
     'end_token': 50,
 }
 
-# Initialize inferred GPT4 model
+# Inferred GPT4 model
 model = GPT4(**hyperparams)
-params = model.init(
-    {'params': time_rng_key(),
-     'dropout': time_rng_key()
-     },
-    dummy_inputs)['params']
 
-# Training on data
-trainer = GPTDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
-trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)
+trainer = GPTDataParallelTrainer(model,
+                                 dummy_inputs.shape,
+                                 'params.pkl')
+
+trainer.train(train_loader=dataloader,
+              num_epochs=100,
+              val_loader=dataloader)  # To-Do: replace with actual val data
 
 # Generating from a start token
 start_tokens = jnp.array([[123, 456]])
@@ -111,6 +135,9 @@
 outputs = model.apply({'params': params},
                       start_tokens,
                       rngs={'dropout': time_rng_key()},
                       method=model.generate)
+
+# JAX does not support strings yet, so convert to a list before decoding
+outputs = tokenizer.decode(outputs.tolist())
 ```
 Vision example
@@ -135,18 +162,17 @@
 dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
 # Create diffusion model
 diffusion_model = DiffusionModel(image_size, widths, block_depth)
-params = diffusion_model.init(key, images)
-pred_noises, pred_images = diffusion_model.apply(params, images)
-print(pred_noises.shape, pred_images.shape)
 
 # Training on your data
 trainer = DiffusionDataParallelTrainer(diffusion_model,
                                        input_shape=images.shape,
                                        weights_filename='params.pkl',
                                        learning_rate=1e-4)
+
 trainer.train(dataloader, 10, dataloader)
 
-# Generate some samples
+# Generate some samples: each model is a Flax.linen module
+# Use it as you normally would
 params = trainer.load_params('params.pkl')
 generated_images = diffusion_model.apply({'params': params},
                                          num_images=5,
@@ -192,14 +218,13 @@ hyperparams = {
 
 # Initialize model
 model = Whisper(**hyperparams)
-rngs = {'params': time_rng_key(), 'dropout': time_rng_key()}
-params = model.init(rngs, dummy_inputs, dummy_targets)['params']
 
 # Training on your data
 trainer = WhisperDataParallelTrainer(model,
                                      dummy_inputs.shape,
                                      dummy_targets.shape,
                                      'params.pkl')
+
 trainer.train(dataloader, 2, dataloader)
 
 # Sample inference
@@ -347,4 +372,4 @@ To cite this repository:
   url = {http://github.com/hmunachi/nanodl},
   year = {2024},
 }
-```
\ No newline at end of file
+```