Merge pull request #18 from HMUNACHI/dev
Some patches
HMUNACHI committed Mar 12, 2024
2 parents 8508e7a + cf3f8d7 commit 8aecd88
Showing 3 changed files with 68 additions and 66 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -32,7 +32,7 @@ Developing and training transformer-based models is typically resource-intensive

Feedback on any of our discussion, issue, and pull request threads is welcome! Please report any feature requests, issues, questions or concerns in the [discussion forum](https://github.com/hmunachi/nanodl/discussions), or just let us know what you're working on! In case you want to reach out directly, we're at [email protected].

## What's New in version 1.2.0.dev1
## What's New in version 1.2.1.dev1

- Google's Gemma architecture.
- Reward model wrapper and data-parallel distributed reward trainer.
@@ -330,9 +330,11 @@ NanoDL provides a random module which abstracts away JAX's intricacies.
It removes the need for explicit PRNG key management by seeding generation from the current timestamp.

```py
import jax

# Jax example
key = random.PRNGKey(0)
jax_array = random.uniform(key, shape=(3, 3))
key = jax.random.PRNGKey(0)
jax_array = jax.random.uniform(key, shape=(3, 3))

# NanoDL example
jax_array = nanodl.uniform(shape=(3, 3))
```
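
As a side note on the snippet above: the timestamp-seeding behaviour that nanodl's random module is described as providing can be sketched in plain JAX as below. The `timestamp_uniform` helper is hypothetical and for illustration only; it is not nanodl's actual implementation.

```py
import time

import jax

# Hypothetical helper: derive a PRNG key from the current timestamp so the
# caller does not manage keys explicitly (illustrative only, not nanodl's code).
def timestamp_uniform(shape):
    key = jax.random.PRNGKey(int(time.time()))
    return jax.random.uniform(key, shape=shape)

sample = timestamp_uniform((3, 3))  # matches the (3, 3) example above
```
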
2 changes: 1 addition & 1 deletion nanodl/__init__.py
@@ -1,4 +1,4 @@
__version__ = "1.2.0.dev1"
__version__ = "1.2.1.dev1"

from nanodl.__src.sklearn_gpu.bayes import NaiveBayesClassifier
from nanodl.__src.sklearn_gpu.dimensionality_reduction import PCA
124 changes: 62 additions & 62 deletions nanodl/__src/models/rlhf.py
@@ -8,7 +8,7 @@
from flax.training import train_state
from typing import Tuple, Any, Optional, Iterable


# Still in active development
class RLHF(nn.Module):
policy_network: Any
reference: bool = False
@@ -244,72 +244,72 @@ def load_params(self, filename: str, params=None):



from nanodl import ArrayDataset, DataLoader
from nanodl import Gemma, GemmaDataParallelTrainer
from nanodl import RewardModel, RewardDataParallelTrainer
# from nanodl import RLHF, PPODataParallelTrainer

batch_size = 8
max_length = 10
model_params_path = 'base_params.pkl'
rlhf_params_path = 'rlhf_params.pkl'
reward_params_path = 'reward_params.pkl'
# from nanodl import ArrayDataset, DataLoader
# from nanodl import Gemma, GemmaDataParallelTrainer
# from nanodl import RewardModel, RewardDataParallelTrainer
# # from nanodl import RLHF, PPODataParallelTrainer

# model parameters
hyperparams = {
'num_layers': 1,
'hidden_dim': 128,
'num_heads': 2,
'feedforward_dim': 128,
'dropout': 0.1,
'vocab_size': 200,
'embed_dim': 128,
'max_length': max_length,
'start_token': 0,
'end_token': 50,
'num_groups': 2,
}
# batch_size = 8
# max_length = 10
# model_params_path = 'base_params.pkl'
# rlhf_params_path = 'rlhf_params.pkl'
# reward_params_path = 'reward_params.pkl'

print('Step 1: Pretraining')
# Replace with actual tokenised data
data = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_inputs = data[:, :-1]
dummy_targets = data[:, 1:]
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
model = Gemma(**hyperparams)
# trainer = GemmaDataParallelTrainer(model, dummy_inputs.shape, model_params_path)
# trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)
# # model parameters
# hyperparams = {
# 'num_layers': 1,
# 'hidden_dim': 128,
# 'num_heads': 2,
# 'feedforward_dim': 128,
# 'dropout': 0.1,
# 'vocab_size': 200,
# 'embed_dim': 128,
# 'max_length': max_length,
# 'start_token': 0,
# 'end_token': 50,
# 'num_groups': 2,
# }

print('\nStep 2: Supervised Fine-Tuning')
# Replace with actual tokenised data
dummy_prompt = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)
# dataset = ArrayDataset(dummy_prompt, dummy_chosen)
# print('Step 1: Pretraining')
# # Replace with actual tokenised data
# data = jnp.ones((101, max_length), dtype=jnp.int32)
# dummy_inputs = data[:, :-1]
# dummy_targets = data[:, 1:]
# dataset = ArrayDataset(dummy_inputs, dummy_targets)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# model = Gemma(**hyperparams)
# trainer = GemmaDataParallelTrainer(model, dummy_prompt.shape, model_params_path)
# trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)
# # trainer = GemmaDataParallelTrainer(model, dummy_inputs.shape, model_params_path)
# # trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)

print('\nStep 3: Train a reward model')
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
reward_model = RewardModel(Gemma(**hyperparams), dim=hyperparams['hidden_dim'], dropout=0.1)
# trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, reward_params_path)
# trainer.train(dataloader, 2, dataloader)
# print('\nStep 2: Supervised Fine-Tuning')
# # Replace with actual tokenised data
# dummy_prompt = jnp.ones((101, max_length), dtype=jnp.int32)
# dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
# dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)
# # dataset = ArrayDataset(dummy_prompt, dummy_chosen)
# # dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# # model = Gemma(**hyperparams)
# # trainer = GemmaDataParallelTrainer(model, dummy_prompt.shape, model_params_path)
# # trainer.train(train_loader=dataloader, num_epochs=2, val_loader=dataloader)

print('\nStep 4: Train the RLHF model via PPO, using a reference model and the reward model.')
rlhf_model = RLHF(model)
rlhf_ref = RLHF(model, reference=True)
dataset = ArrayDataset(dummy_chosen)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
trainer = PPODataParallelTrainer(rlhf_model,
rlhf_ref,
reward_model,
dummy_inputs.shape,
rlhf_params_path,
sft_params_path=model_params_path,
reward_params_path=reward_params_path)
# print('\nStep 3: Train a reward model')
# dataset = ArrayDataset(dummy_chosen, dummy_rejected)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# reward_model = RewardModel(Gemma(**hyperparams), dim=hyperparams['hidden_dim'], dropout=0.1)
# # trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, reward_params_path)
# # trainer.train(dataloader, 2, dataloader)

# print('\nStep 4: Train the RLHF model via PPO, using a reference model and the reward model.')
# rlhf_model = RLHF(model)
# rlhf_ref = RLHF(model, reference=True)
# dataset = ArrayDataset(dummy_chosen)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# trainer = PPODataParallelTrainer(rlhf_model,
# rlhf_ref,
# reward_model,
# dummy_inputs.shape,
# rlhf_params_path,
# sft_params_path=model_params_path,
# reward_params_path=reward_params_path)

trainer.train(dataloader, 2)
# trainer.train(dataloader, 2)
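
The reward-model step in the example above trains on chosen/rejected pairs. For context, the standard pairwise (Bradley-Terry) objective for such data can be sketched in plain JAX as below; this is generic RLHF practice, not necessarily what `RewardDataParallelTrainer` implements internally.

```py
import jax
import jax.numpy as jnp

# Generic pairwise (Bradley-Terry) reward loss: push the reward of the chosen
# response above the reward of the rejected one. Illustrative sketch only.
def pairwise_reward_loss(chosen_scores: jnp.ndarray,
                         rejected_scores: jnp.ndarray) -> jnp.ndarray:
    return -jnp.mean(jax.nn.log_sigmoid(chosen_scores - rejected_scores))

# Dummy scalar rewards for a batch of 8 preference pairs.
loss = pairwise_reward_loss(jnp.ones((8,)), jnp.zeros((8,)))
```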

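Step 4 pairs the policy with a frozen reference model because PPO-style RLHF typically subtracts a KL penalty from the learned reward, keeping the fine-tuned policy close to the reference. A minimal sketch of that shaped reward, assuming per-token log-probabilities, is below; the function name and `kl_coef` value are illustrative, not nanodl's API.

```py
import jax.numpy as jnp

# Shaped per-token reward used in PPO-style RLHF: learned reward minus a KL
# penalty against the frozen reference policy (illustrative sketch only).
def penalized_reward(reward: jnp.ndarray,
                     policy_logprobs: jnp.ndarray,
                     reference_logprobs: jnp.ndarray,
                     kl_coef: float = 0.1) -> jnp.ndarray:
    kl_estimate = policy_logprobs - reference_logprobs
    return reward - kl_coef * kl_estimate

# Dummy batch of 8 sequences of length 10.
shaped = penalized_reward(jnp.ones((8, 10)), jnp.zeros((8, 10)), jnp.zeros((8, 10)))
```
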