nemo ux mixtral 8x22b config (NVIDIA#9977)
* nemo ux mixtral 8x22b config

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add mixtral 8x22b recipe

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* add note

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* fix type hint

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

* fix type hint

Signed-off-by: Alexandros Koumparoulis <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
2 people authored and XuesongYang committed Jan 18, 2025
1 parent d814f04 commit b4dbe17
Showing 4 changed files with 110 additions and 6 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -38,6 +38,7 @@
    MistralConfig7B,
    MistralModel,
    MixtralConfig8x7B,
    MixtralConfig8x22B,
    MixtralModel,
    gpt_data_step,
    gpt_forward_step,
@@ -54,6 +55,7 @@
    "MistralConfig7B",
    "MistralModel",
    "MixtralConfig8x7B",
    "MixtralConfig8x22B",
    "MixtralModel",
    "LlamaConfig",
    "Llama2Config7B",
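With these exports in place, the new config can be pulled straight from the collection's top-level namespace. A minimal sketch (not part of the commit), assuming a NeMo build that includes this change:

    # Sketch only: relies on the exports added to nemo/collections/llm/__init__.py above.
    from nemo.collections.llm import MixtralConfig8x22B

    cfg = MixtralConfig8x22B()
    print(cfg.num_layers, cfg.hidden_size, cfg.num_moe_experts)  # 56 6144 8, per the dataclass defaults in mixtral.py below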
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/__init__.py
@@ -29,7 +29,7 @@
    LlamaModel,
)
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel

__all__ = [
    "GPTConfig",
48 changes: 43 additions & 5 deletions nemo/collections/llm/gpt/model/mixtral.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
import torch.nn.functional as F
@@ -52,10 +52,44 @@ class MixtralConfig8x7B(GPTConfig):
    params_dtype: torch.dtype = torch.bfloat16


@dataclass
class MixtralConfig8x22B(GPTConfig):
    """
    Config for Mixtral-8x22B model
    Official announcement: https://mistral.ai/news/mixtral-8x22b/
    """

    normalization: str = "RMSNorm"
    activation_func: Callable = F.silu
    position_embedding_type: str = "rope"
    add_bias_linear: bool = False
    gated_linear_unit: bool = True
    apply_query_key_layer_scaling: bool = False  # TODO: Should this be True?

    num_layers: int = 56
    hidden_size: int = 6144
    num_attention_heads: int = 48
    num_query_groups: int = 8
    ffn_hidden_size: int = 16384
    max_position_embeddings: int = 65536
    seq_length: int = 4096  # 65536
    # MoE
    num_moe_experts: int = 8
    moe_router_topk: int = 2

    init_method_std: float = 0.02
    layernorm_epsilon: float = 1e-5
    # rotary
    rotary_percent: float = 0  # TODO: @akoumparouli: is this correct?
    rotary_base: float = 1000000
    bf16: bool = True
    params_dtype: torch.dtype = torch.bfloat16


class MixtralModel(GPTModel):
    def __init__(
        self,
        config: Optional[MixtralConfig8x7B] = None,
        config: Optional[Union[MixtralConfig8x7B, MixtralConfig8x22B]] = None,
        optim: Optional[OptimizerModule] = None,
        tokenizer: Optional["TokenizerSpec"] = None,
        model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
@@ -107,11 +141,14 @@ def tokenizer(self) -> "AutoTokenizer":
        return AutoTokenizer(str(self))

    @property
    def config(self) -> MixtralConfig8x7B:
    def config(self) -> MixtralConfig8x7B | MixtralConfig8x22B:
        from transformers import MixtralConfig as HfMixtralConfig

        config = HfMixtralConfig.from_pretrained(str(self))
        return MixtralConfig8x7B(
        config_cls = MixtralConfig8x7B
        if '8x22b' in str(self).lower():
            config_cls = MixtralConfig8x22B
        return config_cls(
            bf16=getattr(config, "torch_dtype", None) == torch.bfloat16,
            activation_func=F.silu,
            # network
@@ -241,7 +278,8 @@ def tokenizer(self):

    @property
    def config(self) -> "MixtralConfig":
        source: MixtralConfig7B = io.load_ckpt(str(self)).model.config
        # Either MixtralConfig8x7B or MixtralConfig8x22B
        source: MixtralConfig8x7B = io.load_ckpt(str(self)).model.config

        from transformers import MixtralConfig as HfMixtralConfig

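For context, a minimal construction sketch based on the signatures in this diff; it is not part of the commit, and the optional optim, tokenizer, and model_transform arguments are simply left at their defaults:

    # Sketch only: builds the Lightning module around the 8x22B config added above.
    from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel

    cfg = MixtralConfig8x22B(seq_length=4096)  # same settings the 4k recipe below uses
    model = MixtralModel(cfg)                  # optim / tokenizer / model_transform stay None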
64 changes: 64 additions & 0 deletions nemo/collections/llm/recipes/mixtral_8x22b_4k.py
@@ -0,0 +1,64 @@
import pytorch_lightning as pl

from nemo import lightning as nl
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.api import squad
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel
from nemo.collections.llm.peft.api import gpt_lora
from nemo.collections.llm.recipes.log.default import default_log
from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing
from nemo.collections.llm.utils import Partial, factory

NAME = "mixtral_8x22b_4k"


@factory(name=NAME)
def model() -> pl.LightningModule:
    return MixtralModel(MixtralConfig8x22B(seq_length=4096))


@factory(name=NAME)
def trainer(devices=8) -> nl.Trainer:
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=8,
        sequence_parallel=True,
    )

    return nl.Trainer(
        devices=devices,
        max_steps=100,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )


@factory(name=NAME + "_hf")
def hf_resume() -> nl.AutoResume:
    return nl.AutoResume(import_path="hf://mistralai/Mixtral-8x22B-v0.1")


@factory(name=NAME, for_task="llm.pretrain")
def pretrain_recipe() -> Partial:
    return Partial(
        pretrain,
        model=model,
        trainer=trainer,
        data=squad,
        log=default_log,
        optim=adam_with_cosine_annealing,
    )


@factory(name=NAME, for_task="llm.finetune")
def finetune_recipe() -> Partial:
    return Partial(
        finetune,
        model=model,
        trainer=trainer,
        data=squad,
        log=default_log,
        optim=adam_with_cosine_annealing,
        peft=gpt_lora,
        resume=hf_resume,
    )
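These factories are registered under the name mixtral_8x22b_4k for llm.pretrain and llm.finetune. Roughly, pretrain_recipe() amounts to the direct call sketched below, written as if appended to mixtral_8x22b_4k.py above; note that invoking squad, default_log, and adam_with_cosine_annealing with no arguments is an assumption about those helpers, not something this commit shows:

    # Sketch only: hand-written equivalent of the pretrain recipe, bypassing the factory registry.
    from nemo.collections.llm.api import pretrain

    if __name__ == "__main__":
        pretrain(
            model=model(),                       # MixtralModel(MixtralConfig8x22B(seq_length=4096))
            trainer=trainer(devices=8),          # tensor-parallel 8, bf16-mixed
            data=squad(),                        # assumed zero-argument factory call
            log=default_log(),                   # assumed zero-argument factory call
            optim=adam_with_cosine_annealing(),  # assumed zero-argument factory call
        )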
