nemo ux mixtral 8x22b config #9977

Merged: 6 commits, Aug 6, 2024
2 changes: 2 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -38,6 +38,7 @@
MistralConfig7B,
MistralModel,
MixtralConfig8x7B,
MixtralConfig8x22B,
MixtralModel,
gpt_data_step,
gpt_forward_step,
@@ -54,6 +55,7 @@
"MistralConfig7B",
"MistralModel",
"MixtralConfig8x7B",
"MixtralConfig8x22B",
"MixtralModel",
"LlamaConfig",
"Llama2Config7B",
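With these exports in place, the new config is importable from the top-level collection alongside the existing 8x7B one. A minimal sketch, assuming only the names added to __all__ above:

from nemo.collections.llm import MixtralConfig8x22B, MixtralModel

# Both Mixtral configs derive from GPTConfig, so the 8x22B variant can be used
# anywhere a GPT-style config is expected.
cfg = MixtralConfig8x22B()
print(cfg.num_layers, cfg.hidden_size, cfg.num_moe_experts)  # 56 6144 8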
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/__init__.py
@@ -29,7 +29,7 @@
LlamaModel,
)
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel

__all__ = [
"GPTConfig",
48 changes: 43 additions & 5 deletions nemo/collections/llm/gpt/model/mixtral.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
import torch.nn.functional as F
@@ -51,10 +51,44 @@ class MixtralConfig8x7B(GPTConfig):
params_dtype: torch.dtype = torch.bfloat16


@dataclass
class MixtralConfig8x22B(GPTConfig):
"""
Config for Mixtral-8x22B model
Official announcement: https://mistral.ai/news/mixtral-8x22b/
"""

normalization: str = "RMSNorm"
activation_func: Callable = F.silu
position_embedding_type: str = "rope"
add_bias_linear: bool = False
gated_linear_unit: bool = True
apply_query_key_layer_scaling: bool = False # TODO: Should this be True?

num_layers: int = 56
hidden_size: int = 6144
num_attention_heads: int = 48
num_query_groups: int = 8
ffn_hidden_size: int = 16384
max_position_embeddings: int = 65536
seq_length: int = 4096 # 65536
# MoE
num_moe_experts: int = 8
moe_router_topk: int = 2

init_method_std: float = 0.02
layernorm_epsilon: float = 1e-5
# rotary
rotary_percent: float = 0 # TODO: @akoumparouli: is this correct?
rotary_base: float = 1000000
bf16: bool = True
params_dtype: torch.dtype = torch.bfloat16


class MixtralModel(GPTModel):
def __init__(
self,
config: Optional[MixtralConfig8x7B] = None,
config: Optional[Union[MixtralConfig8x7B, MixtralConfig8x22B]] = None,
optim: Optional[OptimizerModule] = None,
tokenizer: Optional["TokenizerSpec"] = None,
model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
@@ -106,11 +140,14 @@ def tokenizer(self) -> "AutoTokenizer":
return AutoTokenizer(str(self))

@property
def config(self) -> MixtralConfig8x7B:
def config(self) -> MixtralConfig8x7B | MixtralConfig8x22B:
from transformers import MixtralConfig as HfMixtralConfig

config = HfMixtralConfig.from_pretrained(str(self))
return MixtralConfig8x7B(
config_cls = MixtralConfig8x7B
if '8x22b' in str(self).lower():
config_cls = MixtralConfig8x22B
return config_cls(
bf16=getattr(config, "torch_dtype", None) == torch.bfloat16,
activation_func=F.silu,
# network
@@ -239,7 +276,8 @@ def tokenizer(self):

@property
def config(self) -> "MixtralConfig":
source: MixtralConfig7B = io.load_ckpt(str(self)).model.config
# Either MixtralConfig8x7B or MixtralConfig8x22B
source: MixtralConfig8x7B = io.load_ckpt(str(self)).model.config

from transformers import MixtralConfig as HfMixtralConfig

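For reference, a rough sketch of how the new config plugs into MixtralModel, based only on the constructor signature shown above (optimizer, tokenizer, and model_transform are optional and omitted; the explicit seq_length override is illustrative):

from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel

# The dataclass defaults carry the 8x22B architecture: 56 layers, 6144 hidden
# size, 48 attention heads with 8 query groups, and 8 experts with top-2 routing.
config = MixtralConfig8x22B(seq_length=4096)
model = MixtralModel(config=config)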
64 changes: 64 additions & 0 deletions nemo/collections/llm/recipes/mixtral_8x22b_4k.py
@@ -0,0 +1,64 @@
import pytorch_lightning as pl

from nemo import lightning as nl
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.api import squad
from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel
from nemo.collections.llm.peft.api import gpt_lora
from nemo.collections.llm.recipes.log.default import default_log
from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing
from nemo.collections.llm.utils import Partial, factory

NAME = "mixtral_8x22b_4k"


@factory(name=NAME)
def model() -> pl.LightningModule:
return MixtralModel(MixtralConfig8x22B(seq_length=4096))


@factory(name=NAME)
def trainer(devices=8) -> nl.Trainer:
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=8,
sequence_parallel=True,
)

return nl.Trainer(
devices=devices,
max_steps=100,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)


@factory(name=NAME + "_hf")
def hf_resume() -> nl.AutoResume:
return nl.AutoResume(import_path="hf://mistralai/Mixtral-8x22B-v0.1")


@factory(name=NAME, for_task="llm.pretrain")
def pretrain_recipe() -> Partial:
return Partial(
pretrain,
model=model,
trainer=trainer,
data=squad,
log=default_log,
optim=adam_with_cosine_annealing,
)


@factory(name=NAME, for_task="llm.finetune")
def finetune_recipe() -> Partial:
return Partial(
finetune,
model=model,
trainer=trainer,
data=squad,
log=default_log,
optim=adam_with_cosine_annealing,
peft=gpt_lora,
resume=hf_resume,
)
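As a usage note, these factories are intended to be resolved by the recipe machinery, but the pieces they build can also be inspected directly; a minimal sketch, assuming the factory-decorated functions remain plain callables:

from nemo.collections.llm.recipes import mixtral_8x22b_4k as recipe

# Mixtral 8x22B at 4k sequence length, tensor model parallel size 8 with
# sequence parallelism, bf16 mixed precision on 8 GPUs.
mixtral = recipe.model()
trainer = recipe.trainer(devices=8)

# pretrain_recipe() and finetune_recipe() return Partial objects that bind the
# model, trainer, data, logging, and optimizer factories for llm.pretrain and
# llm.finetune respectively.
pretrain_partial = recipe.pretrain_recipe()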