Merged
32 commits
551a64f  add grok-2 support (CISC, Aug 24, 2025)
301ba77  type fix (CISC, Aug 24, 2025)
cf87c76  type fix (CISC, Aug 24, 2025)
f582b84  type fix (CISC, Aug 24, 2025)
711ab17  "fix" vocab for invalid sequences (CISC, Aug 24, 2025)
8edece8  fix expert tensor mapping and spaces in vocab (CISC, Aug 24, 2025)
3ef6cf5  add chat template (CISC, Aug 25, 2025)
25e4e5f  fix norm tensor mapping (CISC, Aug 25, 2025)
4a53f13  rename layer_out_norm to ffn_post_norm (CISC, Aug 25, 2025)
e0a0024  ensure ffn_post_norm is mapped (CISC, Aug 25, 2025)
d7efed8  fix experts merging (CISC, Aug 25, 2025)
92266e9  remove erroneous FFN_GATE entry (CISC, Aug 26, 2025)
6b3f775  concatenate split tensors and add more metadata (CISC, Aug 26, 2025)
c556663  process all expert layers and try cat instead of hstack (CISC, Aug 27, 2025)
9f86876  add support for community BPE vocab (CISC, Aug 27, 2025)
5d4e407  fix expert feed forward length and ffn_down concat (CISC, Aug 27, 2025)
3e83c64  commit this too (CISC, Aug 27, 2025)
b1627ce  add ffn_up/gate/down, unsure if sequence is right (CISC, Aug 28, 2025)
00481af  add ffn_gate/down/up to tensor names (CISC, Aug 28, 2025)
2e8b67b  correct residual moe (still not working) (CISC, Aug 30, 2025)
94bcbbf  mess-- (CISC, Aug 30, 2025)
b7675ea  fix embedding scale being applied twice (CISC, Sep 1, 2025)
6cf16aa  add built in chat template (CISC, Sep 1, 2025)
4abde12  change beta fast for grok if default value (CISC, Sep 3, 2025)
705f84a  remove spm vocab in favor of community bpe vocab (CISC, Sep 3, 2025)
a8fa83f  change attention temp length metadata type to integer (CISC, Sep 3, 2025)
05b52fa  update attention temp length metadata (CISC, Sep 3, 2025)
b7bfc9a  remove comment (CISC, Sep 3, 2025)
c0d755c  Merge branch 'master' into cisc/grok-2 (CISC, Sep 3, 2025)
0408a4f  replace M_SQRT2 with std::sqrt(2) (CISC, Sep 3, 2025)
ed4d8f2  Merge branch 'master' into cisc/grok-2 (CISC, Sep 8, 2025)
d032a1b  add yarn metadata, move defaults to hparams (CISC, Sep 13, 2025)
97 changes: 74 additions & 23 deletions convert_hf_to_gguf.py
@@ -735,6 +735,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
# ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
res = "qwen2"
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
res = "grok-2"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe"
@@ -2682,57 +2685,105 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield (new_name, data_torch)


@ModelBase.register("GrokForCausalLM")
@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
class GrokModel(TextModel):
model_arch = gguf.MODEL_ARCH.GROK

def set_vocab(self):
self._set_vocab_sentencepiece()
if (self.dir_model / 'tokenizer.model').is_file():
self._set_vocab_sentencepiece()
return

if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
sys.exit(1)

self._set_vocab_gpt2()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def set_gguf_parameters(self):
super().set_gguf_parameters()

_experts: list[dict[str, Tensor]] | None = None
self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)

if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]

if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)

# Treat "original" as "yarn", seems to have been a mistake
if self.hparams.get("rope_type") in ("yarn", "original"):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])

if temp_len := self.hparams.get("attn_temperature_len"):
self.gguf_writer.add_attn_temperature_length(temp_len)

self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])

_experts: list[dict[str, list[Tensor]]] | None = None
_cur_expert = ""

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
tensors: list[tuple[str, Tensor]] = []
is_expert = ".moe." in name or ".block_sparse_moe.experts." in name

if not is_expert:
tensors.append((self.map_tensor_name(name), data_torch))

# process the experts separately
if name.find(".moe.") != -1:
if is_expert or self._cur_expert:
n_experts = self.hparams["num_local_experts"]

assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch
# concatenate split tensors
if name in self._experts[bid]:
self._cur_expert = name
self._experts[bid][name].append(data_torch)
return []
elif is_expert:
self._cur_expert = name
self._experts[bid][name] = [data_torch]
return []
else:
self._cur_expert = ""

if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
for bid in range(self.block_count):
if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
datas: list[Tensor] = []

# merge the experts into a single 3d tensor
for wid in ["linear", "linear_1", "linear_v"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
if ename not in self._experts[bid]:
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
tensor_list = self._experts[bid][ename]
datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
del self._experts[bid][ename]

for xid in range(n_experts):
ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)

data_torch = torch.stack(datas, dim=0)
merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"

merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"

new_name = self.map_tensor_name(merged_name)
new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []
yield (new_name, data_torch)

return [(self.map_tensor_name(name), data_torch)]
yield from tensors


@ModelBase.register("DbrxForCausalLM")
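Note on the expert-merging path above: each expert's projection weights are buffered per layer (and appended when a tensor arrives split across shards), then the per-expert matrices are concatenated along the split dimension and stacked into one 3D tensor whose merged name maps onto the ffn_*_exps GGUF tensors. A minimal standalone sketch of that idea, with hypothetical tensor names and shapes rather than the converter's exact buffering logic:

import torch

n_experts = 8          # hypothetical expert count
shards_per_expert = 2  # e.g. a weight split across two checkpoint shards

# collect shards per expert, keyed by tensor name (as the converter does per layer)
experts: dict[str, list[torch.Tensor]] = {
    f"moe.{xid}.linear.weight": [torch.randn(16, 32) for _ in range(shards_per_expert)]
    for xid in range(n_experts)
}

datas: list[torch.Tensor] = []
for xid in range(n_experts):
    shards = experts[f"moe.{xid}.linear.weight"]
    # 1) concatenate the split shards of one expert (the dim depends on the projection)
    datas.append(torch.cat(shards, dim=0) if len(shards) > 1 else shards[0])

# 2) stack all experts into a single [n_experts, rows, cols] tensor
merged = torch.stack(datas, dim=0)
print(merged.shape)  # torch.Size([8, 32, 32]) -> one blk.N.ffn_*_exps-style tensor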
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -158,6 +158,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
{"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
{"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
]


4 changes: 4 additions & 0 deletions gguf-py/gguf/constants.py
@@ -110,6 +110,7 @@ class LLM:
LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
SWIN_NORM = "{arch}.swin_norm"
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
@@ -145,6 +146,8 @@ class Attention:
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
OUTPUT_SCALE = "{arch}.attention.output_scale"
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
@@ -1110,6 +1113,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.GPTNEOX: [
9 changes: 9 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -730,6 +730,9 @@ def add_logit_scale(self, value: float) -> None:
def add_attn_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

def add_router_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

def add_final_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

@@ -826,6 +829,12 @@ def add_sliding_window(self, value: int) -> None:
def add_attention_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)

def add_attn_output_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)

def add_attn_temperature_length(self, value: int) -> None:
self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)

def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

8 changes: 6 additions & 2 deletions gguf-py/gguf/tensor_mapping.py
@@ -135,6 +135,7 @@ class TensorNameMap:
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"model.layers.{bid}.pre_attn_norm", # grok-2
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
@@ -272,6 +273,7 @@ class TensorNameMap:
"transformer.layer.{bid}.sa_layer_norm", # distillbert
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
"model.layers.{bid}.post_attn_norm", # grok-2
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
),

@@ -306,6 +308,7 @@ class TensorNameMap:
"h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"model.layers.{bid}.pre_moe_norm", # grok-2
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.pre_ff_layernorm", # jamba granite-hybrid
@@ -325,10 +328,11 @@

# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
"model.layers.{bid}.feed_forward.up_proj",
"model.layers.{bid}.post_moe_norm", # grok-2
),

MODEL_TENSOR.FFN_GATE_INP: (
7 changes: 7 additions & 0 deletions src/llama-arch.cpp
@@ -137,6 +137,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_ROUTER_LOGIT_SOFTCAPPING, "%s.router_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
{ LLM_KV_SWIN_NORM, "%s.swin_norm" },
{ LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
@@ -167,6 +168,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },

@@ -395,12 +398,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
},
3 changes: 3 additions & 0 deletions src/llama-arch.h
@@ -141,6 +141,7 @@ enum llm_kv {
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
LLM_KV_SWIN_NORM,
LLM_KV_RESCALE_EVERY_N_LAYERS,
@@ -171,6 +172,8 @@ enum llm_kv {
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

17 changes: 17 additions & 0 deletions src/llama-chat.cpp
@@ -70,6 +70,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
};

llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -204,6 +205,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
} else if (tmpl_contains("<seed:bos>")) {
return LLM_CHAT_TEMPLATE_SEED_OSS;
} else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
return LLM_CHAT_TEMPLATE_GROK_2;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -763,6 +766,20 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<seed:bos>assistant\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "System: " << trim(message->content) << "<|separator|>\n\n";
} else if (role == "user") {
ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
} else if (role == "assistant") {
ss << "Assistant: " << message->content << "<|separator|>\n\n";
}
}
if (add_ass) {
ss << "Assistant:";
}
} else {
// template not supported
return -1;
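For illustration, the grok-2 template above renders a conversation as "System:" / "Human:" / "Assistant:" turns, each terminated by <|separator|> and a blank line, with a trailing "Assistant:" when add_ass is set. With made-up message contents, the rendered prompt would look roughly like:

System: You are a helpful assistant.<|separator|>

Human: Hello!<|separator|>

Assistant: Hi there!<|separator|>

Assistant: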
1 change: 1 addition & 0 deletions src/llama-chat.h
@@ -50,6 +50,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_SEED_OSS,
LLM_CHAT_TEMPLATE_GROK_2,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

4 changes: 4 additions & 0 deletions src/llama-context.cpp
@@ -69,6 +69,10 @@ llama_context::llama_context(
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}

if (model.arch == LLM_ARCH_GROK && params.yarn_beta_fast == 32.0f) {
cparams.yarn_beta_fast = 8.0f;
}

cparams.yarn_attn_factor *= hparams.rope_attn_factor;

if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
6 changes: 3 additions & 3 deletions src/llama-graph.cpp
@@ -1290,14 +1290,14 @@ ggml_tensor * llm_graph_context::build_attn_mha(

if (arch == LLM_ARCH_GROK) {
// need to do the following:
// multiply by attn_output_multiplyer of 0.08838834764831845
// multiply by attn_output_multiplier
// and then :
// kq = 30 * tanh(kq / 30)
// before the softmax below

kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
cb(kq, "kq_tanh", il);
kq = ggml_scale(ctx0, kq, 30);
kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
cb(kq, "kq_scaled", il);
}

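The change above only swaps the hard-coded Grok-1 constants for hparams: the pre-softmax logits are still soft-capped as c * tanh(kq * s / c), with s = f_attn_out_scale and c = f_attn_logit_softcapping. A quick standalone check (plain Python, constants copied from the removed lines) that the parameterized form reproduces the old expression:

import math

s = 0.08838834764831845  # Grok-1 attn_output_multiplier, now hparams.f_attn_out_scale
c = 30.0                 # attention logit softcap, now hparams.f_attn_logit_softcapping

kq = 123.456  # arbitrary pre-softmax logit
old = 30.0 * math.tanh(kq * (0.08838834764831845 / 30.0))  # previous hard-coded form
new = c * math.tanh(kq * (s / c))                          # parameterized form
assert math.isclose(old, new)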
9 changes: 7 additions & 2 deletions src/llama-hparams.h
@@ -81,8 +81,9 @@ struct llama_hparams {
float f_norm_rms_eps;
float f_norm_group_eps;

float f_attn_logit_softcapping = 50.0f;
float f_final_logit_softcapping = 30.0f;
float f_attn_logit_softcapping = 50.0f;
float f_router_logit_softcapping = 30.0f;
float f_final_logit_softcapping = 30.0f;

// for RWKV
uint32_t rescale_every_n_layers = 0;
@@ -135,6 +136,10 @@ struct llama_hparams {
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;

// grok-2
float f_attn_out_scale = 0.0f;
uint32_t attn_temp_length = 0;

bool causal_attn = true;
bool use_alibi = false;
bool attn_soft_cap = false;