added MAMBA architecture
HMUNACHI committed Feb 23, 2024
1 parent d2f816f commit a924069
Showing 4 changed files with 543 additions and 38 deletions.
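For readers unfamiliar with the architecture being added: Mamba-style blocks (Gu and Dao, 2023) are built around a selective state-space recurrence scanned over the sequence. The snippet below is a minimal plain-JAX sketch of that recurrence; the function name, argument names, and shapes are illustrative assumptions for this page and are not the API of nanodl's MambaBlock, which this view does not show.

import jax
import jax.numpy as jnp

def selective_ssm(x, A, B, C, delta):
    # Illustrative shapes (assumed, not nanodl's): x (L, D), A (D, N), B and C (L, N), delta (L, D).
    # Discretized recurrence: h_t = A_bar_t * h_{t-1} + B_bar_t * x_t, y_t = C_t . h_t
    A_bar = jnp.exp(delta[..., None] * A[None, :, :])   # (L, D, N), zero-order-hold discretization
    B_bar = delta[..., None] * B[:, None, :]            # (L, D, N)

    def step(h, inputs):
        A_t, B_t, C_t, x_t = inputs
        h = A_t * h + B_t * x_t[:, None]                 # (D, N) hidden state per channel
        y = (h * C_t[None, :]).sum(-1)                   # (D,) readout
        return h, y

    h0 = jnp.zeros((x.shape[1], A.shape[1]))
    _, ys = jax.lax.scan(step, h0, (A_bar, B_bar, C, x))
    return ys                                            # (L, D)

# Tiny smoke test with random parameters.
L, D, N = 16, 8, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (L, D))
A = -jnp.exp(jax.random.normal(keys[1], (D, N)))         # negative real part for stability
B = jax.random.normal(keys[2], (L, N))
C = jax.random.normal(keys[3], (L, N))
delta = jax.nn.softplus(jax.random.normal(keys[4], (L, D)))
print(selective_ssm(x, A, B, C, delta).shape)            # (16, 8)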
13 changes: 13 additions & 0 deletions nanodl/__init__.py
@@ -107,6 +107,12 @@
UNetResidualBlock
)

from nanodl.__src.models.mamba import (
Mamba,
MambaDataParallelTrainer,
MambaBlock
)

from nanodl.__src.models.transformer import (
Transformer,
TransformerDataParallelTrainer,
@@ -231,6 +237,9 @@
"Mixtral",
"MixtralDecoder",
"MixtralDecoderBlock",
"Mamba",
"MambaDataParallelTrainer",
"MambaBlock"
"Whisper",
"WhisperDataParallelTrainer",
"WhisperSpeechEncoder",
@@ -321,11 +330,15 @@ def test_jax(jax):
def test_optax(optax):
optimizer = optax.sgd(learning_rate=0.1)

def test_einops(einops):
    # Minimal smoke test: fold a length-6 vector into a 2x3 grid.
    import numpy as np
    arr = einops.rearrange(np.arange(6), '(h w) -> h w', h=2)

def main():
try:
flax = check_library_installed('flax')
jax = check_library_installed('jax')
optax = check_library_installed('optax')
einops = check_library_installed('einops')

test_flax(flax)
test_jax(jax)
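The effect of the __init__.py changes above is to re-export the new Mamba classes at the package root. A minimal sketch of the two equivalent import paths this enables (the diff does not show the classes' constructor arguments, so none are assumed here):

# Once this commit is installed, both import paths resolve to the same classes;
# the second is the internal module the new re-export points at.
from nanodl import Mamba, MambaBlock, MambaDataParallelTrainer
from nanodl.__src.models.mamba import Mamba as MambaDirect

assert Mamba is MambaDirect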
4 changes: 1 addition & 3 deletions nanodl/__src/models/gemma.py
@@ -307,9 +307,7 @@ def __call__(self,

class Gemma(nn.Module):
"""
-    Implements the LLaMA2 model for text generation, featuring grouped rotary positional embeddings.
-    LLaMA2 enhances the transformer architecture by incorporating grouped rotary positional embeddings within its decoder blocks, aiming to improve the model's understanding of positional context and its ability to generate coherent and contextually relevant text.
+    Implements the Gemma model for text generation, featuring GQA + RMSNorm + RoPE.
Attributes:
num_layers (int): Number of layers (blocks) in the LLaMA2 model.
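The new docstring line names grouped-query attention (GQA), RMSNorm, and rotary position embeddings (RoPE). As a quick reference, here is a minimal Flax sketch of RMSNorm, the simplest of the three; it is a generic illustration of the technique, not necessarily how nanodl's gemma.py implements it.

import jax
import jax.numpy as jnp
from flax import linen as nn

class RMSNorm(nn.Module):
    """Root-mean-square layer norm: rescale by 1/RMS(x), then apply a learned gain."""
    dim: int
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x):
        # Normalize over the feature axis without centering (no mean subtraction).
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        gain = self.param('gain', nn.initializers.ones, (self.dim,))
        return (x / rms) * gain

# Example: normalize a (batch, seq, features) activation tensor.
x = jnp.ones((2, 16, 64))
norm = RMSNorm(dim=64)
params = norm.init(jax.random.PRNGKey(0), x)
y = norm.apply(params, x)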
(Diffs for the remaining changed files, including the new nanodl/__src/models/mamba.py that the imports above reference, did not load in this view.)
