Skip to content

Commit

Permalink
llama : support OLMoE (#9462)
Browse files Browse the repository at this point in the history
  • Loading branch information
2015aroras authored Sep 16, 2024
1 parent 95ca851 commit 0aadac1
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 15 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
- [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B)
- [x] [OLMo](https://allenai.org/olmo)
- [x] [OLMoE](https://huggingface.co/allenai/OLMoE-1B-7B-0924)
- [x] [Granite models](https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330)
- [x] [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) + [Pythia](https://github.com/EleutherAI/pythia)
- [x] [Snowflake-Arctic MoE](https://huggingface.co/collections/Snowflake/arctic-66290090abe542894a5ac520)
Expand Down
60 changes: 60 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2998,6 +2998,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@Model.register("OlmoeForCausalLM")
class OlmoeModel(Model):
model_arch = gguf.MODEL_ARCH.OLMOE

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_layer_norm_rms_eps(1e-5)
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)

_experts: list[dict[str, Tensor]] | None = None

# Copied from: Qwen2MoeModel
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

# Copied from: Qwen2MoeModel
def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@Model.register("JinaBertModel", "JinaBertForMaskedLM")
class JinaBertV2Model(BertModel):
model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
Expand Down
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ class MODEL_ARCH(IntEnum):
COMMAND_R = auto()
DBRX = auto()
OLMO = auto()
OLMOE = auto()
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK2 = auto()
Expand Down Expand Up @@ -375,6 +376,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.OLMOE: "olmoe",
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
Expand Down Expand Up @@ -1027,6 +1029,23 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.OLMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
],
MODEL_ARCH.OPENELM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
30 changes: 15 additions & 15 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TensorNameMap:
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron
"model.embed_tokens", # llama-hf nemotron olmoe
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
Expand Down Expand Up @@ -54,7 +54,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
Expand All @@ -66,7 +66,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2
"model.norm", # llama-hf baichuan internlm2 olmoe
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
Expand Down Expand Up @@ -98,7 +98,7 @@ class TensorNameMap:
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf nemotron
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
Expand Down Expand Up @@ -142,7 +142,7 @@ class TensorNameMap:

# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
Expand All @@ -154,7 +154,7 @@ class TensorNameMap:

# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
Expand All @@ -167,7 +167,7 @@ class TensorNameMap:

# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
Expand All @@ -185,7 +185,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
Expand Down Expand Up @@ -229,7 +229,7 @@ class TensorNameMap:
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
"layers.{bid}.ffn_norm", # llama-pth
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
Expand All @@ -253,7 +253,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"model.layers.{bid}.mlp.gate", # qwen2moe
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
),
Expand Down Expand Up @@ -295,7 +295,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe (merged)
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
),

MODEL_TENSOR.FFN_UP_SHEXP: (
Expand Down Expand Up @@ -327,7 +327,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe (merged)
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
),

MODEL_TENSOR.FFN_GATE_SHEXP: (
Expand Down Expand Up @@ -367,7 +367,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe (merged)
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
),

MODEL_TENSOR.FFN_DOWN_SHEXP: (
Expand All @@ -378,7 +378,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.q_norm", # cohere
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
Expand All @@ -387,7 +387,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.k_norm", # cohere
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm
Expand Down
Loading

0 comments on commit 0aadac1

Please sign in to comment.