mixtral export (#9603)
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Jul 5, 2024
1 parent 10768ae commit d4a32d0
Showing 1 changed file with 119 additions and 0 deletions.
119 changes: 119 additions & 0 deletions nemo/collections/llm/gpt/model/mixtral.py
@@ -186,3 +186,122 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v):
)
def _import_moe_w1_w3(gate_proj, up_proj):
    return torch.cat((gate_proj, up_proj), axis=0)


@io.model_exporter(MixtralModel, "hf")
class HFMixtralExporter(io.ModelConnector[MixtralModel, "MixtralForCausalLM"]):
    def init(self) -> "MixtralForCausalLM":
        from transformers import AutoModelForCausalLM

        return AutoModelForCausalLM.from_config(self.config)

    def apply(self, output_path: Path) -> Path:
        # TODO: Make it work with lazy init
        # with torch.device("meta"):
        #     target = self.init()
        target = self.init()
        source, _ = self.nemo_load(str(self))
        target = self.convert_state(source, target)

        # TODO: Make sure we don't need to do this
        target = target.cpu()
        target.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        return output_path

    def convert_state(self, source, target):
        mapping = {
            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
            "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
            "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
            "decoder.layers.*.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight",
            # MoE
            "decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight": "model.layers.*.block_sparse_moe.experts.*.w2.weight",
            "decoder.layers.*.mlp.router.weight": "model.layers.*.block_sparse_moe.gate.weight",
            # lm-head
            "decoder.final_layernorm.weight": "model.norm.weight",
            "output_layer.weight": "lm_head.weight",
        }

        return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_moe_w1_w3])

    @property
    def tokenizer(self):
        return io.load_ckpt(str(self)).model.tokenizer.tokenizer

    @property
    def config(self) -> "MixtralConfig":
        source: MixtralConfig7B = io.load_ckpt(str(self)).model.config

        from transformers import MixtralConfig as HfMixtralConfig

        return HfMixtralConfig(
            num_hidden_layers=source.num_layers,
            hidden_size=source.hidden_size,
            intermediate_size=source.ffn_hidden_size,
            max_position_embeddings=source.max_position_embeddings,
            seq_length=source.max_position_embeddings,
            # RoPE
            rope_theta=source.rotary_base,
            # transformer config
            num_attention_heads=source.num_attention_heads,
            # GQA: Megatron's num_query_groups is the number of KV heads
            num_key_value_heads=source.num_query_groups,
            # MoE
            num_local_experts=source.num_moe_experts,
            num_experts_per_tok=source.moe_router_topk,
            # norm
            rms_norm_eps=source.layernorm_epsilon,
            # init
            initializer_range=source.init_method_std,
            # vocab
            vocab_size=self.tokenizer.vocab_size,
        )


@io.state_transform(
    source_key="decoder.layers.*.self_attention.linear_qkv.weight",
    target_key=(
        "model.layers.*.self_attn.q_proj.weight",
        "model.layers.*.self_attn.k_proj.weight",
        "model.layers.*.self_attn.v_proj.weight",
    ),
)
def _export_qkv(ctx: io.TransformCTX, linear_qkv):
    megatron_config = ctx.source.config

    head_num = megatron_config.num_attention_heads
    num_query_groups = megatron_config.num_query_groups
    heads_per_group = head_num // num_query_groups
    hidden_size = megatron_config.hidden_size
    head_size = hidden_size // head_num
    qkv_total_dim = head_num + 2 * num_query_groups

    linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size])
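    # Example layout with illustrative values: for num_attention_heads=8 and
    # num_query_groups=2, heads_per_group=4 and qkv_total_dim=12; each query
    # group is packed as [q q q q k v], so q_slice=[0,1,2,3,6,7,8,9],
    # k_slice=[4,10], v_slice=[5,11].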
    q_slice = torch.cat(
        [
            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

    q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
    k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
    v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

    return q_proj, k_proj, v_proj


@io.state_transform(
    source_key="decoder.layers.*.mlp.experts.local_experts.*.linear_fc1.weight",
    target_key=(
        "model.layers.*.block_sparse_moe.experts.*.w1.weight",
        "model.layers.*.block_sparse_moe.experts.*.w3.weight",
    ),
)
def _export_moe_w1_w3(linear_fc1):
    # Megatron stores gate_proj (w1) and up_proj (w3) concatenated in linear_fc1
    # (see _import_moe_w1_w3 above); split them back apart for the HF layout.
    gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)

    return gate_proj, up_proj
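
To read the wildcard mapping in convert_state: each "*" stands for an index (the layer number, and for the expert entries also the expert id) that is matched between the NeMo and Hugging Face parameter names. A minimal sketch of that correspondence, assuming simple left-to-right substitution (the expand helper is illustrative only, not NeMo's io.apply_transforms implementation):

def expand(source_key: str, target_key: str, *indices: int):
    # Substitute concrete indices for the '*' placeholders, left to right.
    src, tgt = source_key, target_key
    for idx in map(str, indices):
        src = src.replace("*", idx, 1)
        tgt = tgt.replace("*", idx, 1)
    return src, tgt

print(expand(
    "decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight",
    "model.layers.*.block_sparse_moe.experts.*.w2.weight",
    0, 3,
))
# ('decoder.layers.0.mlp.experts.local_experts.3.linear_fc2.weight',
#  'model.layers.0.block_sparse_moe.experts.3.w2.weight')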

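A minimal sketch of driving the exporter directly, assuming the connector can be constructed from a checkpoint path string (as the str(self) calls above suggest) and that the paths shown are illustrative:

from pathlib import Path

from nemo.collections.llm.gpt.model.mixtral import HFMixtralExporter

# Convert a NeMo Mixtral checkpoint to Hugging Face format (paths are illustrative).
exporter = HFMixtralExporter("/checkpoints/mixtral-8x7b-nemo")
hf_dir = exporter.apply(Path("/checkpoints/mixtral-8x7b-hf"))  # writes weights and tokenizer
print(f"Hugging Face checkpoint written to {hf_dir}")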