Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Akoumparouli/nemo ux mixtral export #9603

Merged
merged 5 commits into from
Jul 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions nemo/collections/llm/gpt/model/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,122 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
)
def _import_moe_w1_w3(gate_proj, up_proj):
return torch.cat((gate_proj, up_proj), axis=0)


@io.model_exporter(MixtralModel, "hf")
class HFMixtralExporter(io.ModelConnector[MixtralModel, "MixtralForCausalLM"]):
def init(self) -> "MixtralForCausalLM":
from transformers import AutoModelForCausalLM

return AutoModelForCausalLM.from_config(self.config)

def apply(self, output_path: Path) -> Path:
# TODO: Make it work with lazy init
# with torch.device("meta"):
# target = self.init()
Dismissed Show dismissed Hide dismissed
target = self.init()
source, _ = self.nemo_load(str(self))
target = self.convert_state(source, target)

# TODO: Make sure we don't need to do this
target = target.cpu()
target.save_pretrained(output_path)
self.tokenizer.save_pretrained(output_path)

return output_path

def convert_state(self, source, target):
mapping = {
"embedding.word_embeddings.weight": "model.embed_tokens.weight",
"decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
"decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
"decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight",
# MoE
"decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight": "model.layers.*.block_sparse_moe.experts.*.w2.weight",
"decoder.layers.*.mlp.router.weight": "model.layers.*.block_sparse_moe.gate.weight",
# lm-head
"decoder.final_layernorm.weight": "model.norm.weight",
"output_layer.weight": "lm_head.weight",
}

return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_moe_w1_w3])

@property
def tokenizer(self):
return io.load_ckpt(str(self)).model.tokenizer.tokenizer

@property
def config(self) -> "MixtralConfig":
source: MixtralConfig7B = io.load_ckpt(str(self)).model.config

from transformers import MixtralConfig as HfMixtralConfig

return HfMixtralConfig(
num_hidden_layers=source.num_layers,
hidden_size=source.hidden_size,
intermediate_size=source.ffn_hidden_size,
max_position_embeddings=source.max_position_embeddings,
seq_length=source.max_position_embeddings,
# RoPe
rope_theta=source.rotary_base,
# transformer config
num_attention_heads=source.num_attention_heads,
num_key_value_heads=source.num_query_groups,
num_local_experts=config.num_moe_experts,
num_experts_per_tok=config.moe_router_topk,
# norm
rms_norm_eps=source.layernorm_epsilon,
# init
initializer_range=source.init_method_std,
# vocab
vocab_size=self.tokenizer.vocab_size,
)


@io.state_transform(
source_key="decoder.layers.*.self_attention.linear_qkv.weight",
target_key=(
"model.layers.*.self_attn.q_proj.weight",
"model.layers.*.self_attn.k_proj.weight",
"model.layers.*.self_attn.v_proj.weight",
),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
megatron_config = ctx.source.config

head_num = megatron_config.num_attention_heads
num_query_groups = megatron_config.num_query_groups
heads_per_group = head_num // num_query_groups
hidden_size = megatron_config.hidden_size
head_num = megatron_config.num_attention_heads
head_size = hidden_size // head_num
qkv_total_dim = head_num + 2 * num_query_groups

linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
q_slice = torch.cat(
[
torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
for i in range(num_query_groups)
]
)
k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

return q_proj, k_proj, v_proj


@io.state_transform(
source_key="decoder.layers.*.mlp.experts.local_experts.*.linear_fc1.weight",
target_key=(
"model.layers.*.block_sparse_moe.experts.*.w1.weight",
"model.layers.*.block_sparse_moe.experts.*.w3.weight",
),
)
def _export_moe_w1_w3(linear_fc1):
gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

return gate_proj, up_proj
Loading