added mamba experimental
HMUNACHI committed Feb 23, 2024
1 parent a924069 commit ea7607b
Showing 4 changed files with 28 additions and 27 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -34,10 +34,13 @@ Feedback on any of our discussion, issue, and pull request threads is welcome!

## What's New in version 1.2.0.dev1

- Google's Gemma and MAMBA architectures.
- Google's Gemma architecture.
- Data parallel distributed RLHFPPO and RLHFDPO.
- True random number generators in JAX that do not need the usual verbose key-handling code (examples shown in the next sections; the verbose approach is sketched below for contrast).
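
For contrast, a minimal sketch of the explicit key management that plain JAX normally requires (values and shapes here are illustrative only):

    import jax

    key = jax.random.PRNGKey(0)             # every random op needs an explicit key
    key, subkey = jax.random.split(key)     # keys must be split manually before reuse
    sample = jax.random.normal(subkey, (3,))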

There are experimental features (such as the MAMBA architecture) in the repo that are not yet available via the package, pending tests.
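
Because the Mamba classes are no longer re-exported by the package (see the `nanodl/__init__.py` change below), a minimal sketch of reaching the experimental module directly from a source checkout — this assumes the repository root is on `PYTHONPATH` and is not a documented package API:

    # Hedged sketch: import the experimental Mamba classes straight from the
    # source tree, since they are not re-exported by the nanodl package yet.
    from nanodl.__src.models.mamba import Mamba, MambaBlock, MambaDataParallelTrainer

    # Constructor arguments are omitted here; see the example in the module's docstring.
    # model = Mamba(...)
    # trainer = MambaDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')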

## Quick install

You will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md)
5 changes: 0 additions & 5 deletions nanodl/__init__.py
@@ -107,11 +107,6 @@
UNetResidualBlock
)

from nanodl.__src.models.mamba import (
Mamba,
MambaDataParallelTrainer,
MambaBlock
)

from nanodl.__src.models.transformer import (
Transformer,
nanodl/__src/models/mamba.py
@@ -9,6 +9,7 @@
from flax.training import train_state
from typing import Tuple, Any, Optional, Iterable

########## EXPERIMENTAL ############

class MambaBlock(nn.Module):
"""
@@ -252,8 +253,8 @@ def setup(self):
# You might need to implement a custom method for weight tying or handle it outside the model definition.

def __call__(self,
input_ids: jnp.Array,
training: bool = False) -> jnp.Array:
input_ids: jnp.ndarray,
training: bool = False) -> jnp.ndarray:

x = self.embedding(input_ids)
for layer in self.layers:
@@ -264,6 +265,19 @@ def __call__(self,
return logits


def zero_pad(self, arr, max_length):
current_length = arr.shape[1]
num_zeros = max_length - current_length

if num_zeros > 0:
zeros = jnp.zeros((arr.shape[0], num_zeros), dtype=arr.dtype)
padded_array = jnp.concatenate([arr, zeros], axis=1)
else:
padded_array = arr

return padded_array
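
For context, the `zero_pad` helper added above right-pads the token axis with zeros so every forward pass during decoding sees a `max_length`-wide input. A minimal standalone illustration (made-up values):

    import jax.numpy as jnp

    arr = jnp.array([[5, 7, 9]])                            # shape (1, 3)
    zeros = jnp.zeros((arr.shape[0], 2), dtype=arr.dtype)   # pad 3 -> 5 tokens
    padded = jnp.concatenate([arr, zeros], axis=1)          # [[5, 7, 9, 0, 0]], shape (1, 5)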


def generate(self,
x: Optional[jnp.ndarray] = None,
temperature: float = 1.0,
@@ -276,8 +290,10 @@ def generate(self,
output_sequence = []

# Autoregressive decoding loop
for _ in range(self.max_length):
decoder_output = self.__call__(decoder_input, training=False)[0]
print(self.zero_pad(decoder_input, self.max_length).shape)
for _ in range(self.max_length-1):
decoder_output = self.__call__(self.zero_pad(decoder_input, self.max_length), training=False)[0]
print(decoder_output.shape)
last_token_logits = decoder_output[:, -1, :]
scaled_logits = last_token_logits / temperature
next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1)
@@ -291,7 +307,7 @@ def generate(self,
output_sequence.append(next_token.item())
decoder_input = jnp.concatenate([decoder_input, jnp.array([[next_token]])], axis=1)

if next_token.item() == self.end_token:
if next_token.item() == self.end_token or len(output_sequence) == self.max_length:
break

return jnp.array(output_sequence)
@@ -306,8 +322,8 @@ def generate_batch(self,
decoder_input = x if x is not None else jnp.full((batch_size, 1), self.start_token)
output_sequences = jnp.zeros((batch_size, self.max_length), dtype=jnp.int32)

for i in range(self.max_length):
decoder_output = self.__call__(decoder_input, training=False)[0]
for i in range(self.max_length-1):
decoder_output = self.__call__(self.zero_pad(decoder_input, self.max_length), training=False)[0]
last_token_logits = decoder_output[:, -1, :]
scaled_logits = last_token_logits / temperature
next_token_probabilities = jax.nn.softmax(scaled_logits, axis=-1)
@@ -321,13 +337,12 @@ def generate_batch(self,
output_sequences = output_sequences.at[:, i].set(next_token)
decoder_input = jnp.concatenate([decoder_input, next_token[:, None]], axis=1)

if jnp.all(next_token == self.end_token):
if jnp.all(next_token == self.end_token) or len(output_sequences) == self.max_length:
break

return output_sequences
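
Both decoding loops above scale the last-token logits by temperature and sample from the resulting softmax. A minimal sketch of that sampling step — the helper name and the use of `jax.random.categorical` are illustrative assumptions, not necessarily the exact call used elsewhere in this file:

    import jax
    import jax.numpy as jnp

    def sample_next_token(last_token_logits, temperature, key):
        # Temperature < 1 sharpens the distribution, > 1 flattens it.
        probs = jax.nn.softmax(last_token_logits / temperature, axis=-1)
        # categorical expects (unnormalized) log-probabilities.
        return jax.random.categorical(key, jnp.log(probs), axis=-1)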



class MambaDataParallelTrainer:
"""
Trainer class using data parallelism with JAX.
@@ -531,19 +546,7 @@ def load_params(self, filename: str):

print(outputs.shape)

# Training on data
trainer = MambaDataParallelTrainer(model, dummy_inputs.shape, 'params.pkl')
trainer.train(train_loader=dataloader,
num_epochs=2,
val_loader=dataloader)

print(trainer.evaluate(dataloader))

# Generating from a start token
start_tokens = jnp.array([[123, 456]])

# Remember to load the trained parameters
params = trainer.load_params('params.pkl')
outputs = model.apply({'params': params},
start_tokens,
rngs={'dropout': jax.random.PRNGKey(2)},
Empty file added nanodl/__src/models/rlhf.py
Empty file.
