From cd9058652e8ee96d94f8522e37c54751c4861af9 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 18 May 2026 15:16:21 +0200 Subject: [PATCH 1/8] StepFun 3.5 MTP --- common/speculative.cpp | 13 ++ conversion/step3.py | 109 ++++++++++- convert_hf_to_gguf.py | 5 +- gguf-py/gguf/constants.py | 7 + scripts/fix_step35_mtp_metadata.py | 237 +++++++++++++++++++++++ src/llama-context.cpp | 9 + src/llama-context.h | 1 + src/llama-cparams.h | 5 + src/llama-ext.h | 10 + src/models/models.h | 4 + src/models/step35.cpp | 300 +++++++++++++++++++++++++++-- 11 files changed, 677 insertions(+), 23 deletions(-) create mode 100755 scripts/fix_step35_mtp_metadata.py diff --git a/common/speculative.cpp b/common/speculative.cpp index 253a5ececbb..3900fe08a33 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -665,6 +665,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); } + // First draft step uses the first MTP block (step 0). Archs with a + // single MTP block ignore this; multi-block archs (Step-3.5-Flash) use + // it to round-robin across their N MTP layers. + llama_set_mtp_step(ctx_dft, 0); + int ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); @@ -729,6 +734,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { break; } + // Step i+1: feed the i-th sampled draft token into the (i+1)-th + // MTP block. Multi-block archs round-robin via mtp_step % N. + llama_set_mtp_step(ctx_dft, (uint32_t)(i + 1)); + // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { @@ -739,6 +748,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { ++i; } + // Reset MTP step so a subsequent non-MTP decode on this context doesn't + // inherit a stale offset. + llama_set_mtp_step(ctx_dft, 0); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { diff --git a/conversion/step3.py b/conversion/step3.py index ba867fb831b..bd0adb3f6ad 100644 --- a/conversion/step3.py +++ b/conversion/step3.py @@ -99,6 +99,24 @@ class Step3VLTextModel(Qwen3Model): class Step35Model(TextModel): model_arch = gguf.MODEL_ARCH.STEP35 + # --mtp / --no-mtp toggles (see convert_hf_to_gguf.py main()). + # Unlike Qwen3.5 which stores MTP under a `mtp.*` namespace, Step3.5 just + # appends MTP layers at `model.layers.{num_hidden_layers + i}`; these flags + # filter by layer index instead of by name prefix. + no_mtp: bool = False + mtp_only: bool = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # NextN/MTP layers are appended past num_hidden_layers; extend the + # tensor map to cover them so the MTP block's tensors get correctly + # indexed names. When --no-mtp drops the MTP blocks, fall back to the + # base num_hidden_layers so we don't reserve unused slots. + n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) + if n_nextn > 0 and not self.no_mtp: + self.block_count = int(self.hparams["num_hidden_layers"]) + n_nextn + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + def set_gguf_parameters(self): rope_theta = self.hparams.get("rope_theta") if isinstance(rope_theta, list): @@ -119,8 +137,25 @@ def set_gguf_parameters(self): n_head_swa = attn_other.get("num_attention_heads", n_head_base) n_kv_swa = attn_other.get("num_attention_groups", n_kv_base) - layer_types = layer_types[: self.block_count] - partial_rotary_factors = partial_rotary_factors[: self.block_count] + n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) + + # The Step3p5 HF checkpoint stores layer_types/partial_rotary_factors + # entries for the MTP blocks past num_hidden_layers; preserve them so + # the MTP layer's attention shape, SWA flag, and partial RoPE dim are + # set correctly. Pad with full-attention defaults if the checkpoint + # truncated them. + def _pad(arr, n, default): + arr = list(arr) + if len(arr) < n: + arr = arr + [default] * (n - len(arr)) + return arr[:n] + + layer_types = _pad(layer_types, self.block_count, "full_attention") + partial_rotary_factors = _pad( + partial_rotary_factors, + self.block_count, + 0.5, # full_attention default for Step3p5 + ) assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types] kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types] @@ -157,14 +192,25 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5)) - # Optional per-layer SwiGLU clamps. + # Optional per-layer SwiGLU clamps. MTP layers default to no clamping (0.0). if (limits := self.hparams.get("swiglu_limits")) is not None: - limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]] + limits_f = _pad( + [0.0 if v is None else float(v) for v in limits], + self.block_count, + 0.0, + ) self.gguf_writer.add_swiglu_clamp_exp(limits_f) if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None: - limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]] + limits_shared_f = _pad( + [0.0 if v is None else float(v) for v in limits_shared], + self.block_count, + 0.0, + ) self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f) + if n_nextn > 0 and not self.no_mtp: + self.gguf_writer.add_nextn_predict_layers(n_nextn) + @classmethod def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: name, gen = item @@ -175,13 +221,41 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca return super().filter_tensors((name, gen)) + def _is_mtp_layer(self, bid: int | None) -> bool: + if bid is None: + return False + n_main = int(self.hparams.get("num_hidden_layers", self.block_count)) + return bid >= n_main + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - # remove mtp layers - if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None: - il = int(m.group(1)) - n_main = int(self.hparams.get("num_hidden_layers", self.block_count)) - if il >= n_main: + is_mtp = self._is_mtp_layer(bid) + + # --no-mtp: drop the appended MTP block(s) entirely. + if is_mtp and self.no_mtp: + return + # --mtp: keep ONLY MTP-block tensors plus the shared embeddings/norm/lm_head + # (so the resulting GGUF carries just the draft head). + if self.mtp_only and not is_mtp and bid is not None: + return + if self.mtp_only and bid is None: + # Top-level tensors: keep only shared embeddings/norm/lm_head. + keep = name in ( + "model.embed_tokens.weight", "model.norm.weight", "lm_head.weight", + ) + if not keep: return + + # The checkpoint nests the per-MTP-layer shared head under + # `model.layers.{N+i}.transformer.shared_head.{norm,output}.weight`; + # strip the `transformer.` infix and rename `output` → `head` so the + # existing NEXTN_SHARED_HEAD_{NORM,HEAD} tensor mapping picks them up. + # Mirrors vllm's `_rewrite_spec_layer_name` (step3p5_mtp.py). + if is_mtp: + if ".transformer." in name: + name = name.replace(".transformer.", ".") + if "shared_head.output" in name: + name = name.replace("shared_head.output", "shared_head.head") + if name.endswith("norm.weight"): data_torch += 1.0 @@ -190,6 +264,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): yield from super().modify_tensors(data_torch, name, bid) + def prepare_metadata(self, vocab_only: bool): + from_dir = self.fname_out.is_dir() + super().prepare_metadata(vocab_only=vocab_only) + + # Mirror Qwen3.5's behavior: when emitting a draft-only file into a + # directory, prefix with "mtp-" so it doesn't collide with the trunk. + if not self.mtp_only or not from_dir: + return + + output_type: str = self.ftype.name.partition("_")[2] + fname_default: str = gguf.naming_convention( + self.metadata.name, self.metadata.basename, self.metadata.finetune, + self.metadata.version, size_label=None, output_type=output_type, model_type=None) + self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf" + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: # Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3"). # llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS). diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 85527553563..cd19eebdfa3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -251,8 +251,9 @@ def main() -> None: if args.mtp or args.no_mtp: from conversion.qwen import _Qwen35MtpMixin - if not issubclass(model_class, _Qwen35MtpMixin): - logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today") + from conversion.step3 import Step35Model + if not (issubclass(model_class, _Qwen35MtpMixin) or issubclass(model_class, Step35Model)): + logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 and Step3.5 text variants today") sys.exit(1) if args.no_mtp: model_class.no_mtp = True diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b4dfd58382d..cca5f79d52b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3987,6 +3987,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, + # NextN/MTP tensors (Step3p5 draft head) + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.LLAMA_EMBED: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/scripts/fix_step35_mtp_metadata.py b/scripts/fix_step35_mtp_metadata.py new file mode 100755 index 00000000000..855347f81f7 --- /dev/null +++ b/scripts/fix_step35_mtp_metadata.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +""" +Fix Step-3.5 GGUF metadata for MTP support. + +Old (pre-MTP) Step-3.5 GGUFs were written with `step35.block_count = num_hidden_layers` +and per-layer arrays sized to the same length, so the appended MTP blocks have +no metadata slot. This script: + + * sets `step35.block_count = num_hidden_layers + num_nextn_predict_layers` + * appends `num_nextn_predict_layers` entries to every known per-layer array + (head_count, head_count_kv, sliding_window_pattern, swiglu.clamp_exp, + swiglu.clamp_shexp) so the C++ loader's length check passes + * writes `step35.nextn_predict_layers` + * copies all tensors over unchanged + +Defaults assume the MTP blocks are `sliding_attention` (Step-3.5-Flash): +head_count=96, head_count_kv=8, swa=True, swiglu_clamp_exp=0, swiglu_clamp_shexp=0. +Override with the per-array flags below if your model differs. + +Run conversion with the up-to-date `conversion/step3.py` to produce a correct +GGUF in the first place; this script exists only to retrofit older outputs. +""" +from __future__ import annotations + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Any + +from tqdm import tqdm + +# Allow running from a llama.cpp checkout without installing gguf-py. +if "NO_LOCAL_GGUF" not in os.environ: + repo_root = Path(__file__).resolve().parent.parent + sys.path.insert(0, str(repo_root / "gguf-py")) + +import gguf # noqa: E402 + +logger = logging.getLogger("fix-step35-mtp-metadata") + + +# Per-layer metadata keys (step35-specific) we know how to extend. +PER_LAYER_KEYS: dict[str, gguf.GGUFValueType] = { + "step35.attention.head_count": gguf.GGUFValueType.UINT32, + "step35.attention.head_count_kv": gguf.GGUFValueType.UINT32, + "step35.attention.sliding_window_pattern": gguf.GGUFValueType.BOOL, + "step35.swiglu_clamp_exp": gguf.GGUFValueType.FLOAT32, + "step35.swiglu_clamp_shexp": gguf.GGUFValueType.FLOAT32, +} + +BLOCK_COUNT_KEY = "step35.block_count" +NEXTN_LAYERS_KEY = "step35.nextn_predict_layers" + + +def get_field_contents(reader: gguf.GGUFReader, key: str) -> Any: + field = reader.get_field(key) + return field.contents() if field else None + + +def field_main_type(reader: gguf.GGUFReader, key: str) -> gguf.GGUFValueType | None: + field = reader.get_field(key) + if field is None or not field.types: + return None + return field.types[0] + + +def make_extended_value( + reader: gguf.GGUFReader, + key: str, + sub_type: gguf.GGUFValueType, + n_total: int, + mtp_value: Any, + n_mtp: int, +) -> list[Any] | None: + """Return the new array for ``key`` (length ``n_total``) or None if the key + is absent and no broadcast is needed.""" + field = reader.get_field(key) + if field is None: + return None + + main_type = field.types[0] + if main_type == gguf.GGUFValueType.ARRAY: + current: list[Any] = list(field.contents()) + if len(current) >= n_total: + logger.info(" %s already length %d (>= %d), trimming", key, len(current), n_total) + return current[:n_total] + pad = n_total - len(current) + logger.info(" %s: extending %d -> %d (appending %d × %r)", key, len(current), n_total, pad, mtp_value) + return current + [mtp_value] * pad + + # Scalar — broadcast to length n_total, replacing the trailing n_mtp entries. + scalar = field.contents() + logger.info(" %s: scalar %r -> array of %d (last %d = %r)", key, scalar, n_total, n_mtp, mtp_value) + return [scalar] * (n_total - n_mtp) + [mtp_value] * n_mtp + + +def _bool(s: str) -> bool: + t = s.strip().lower() + if t in ("1", "true", "yes", "y"): return True + if t in ("0", "false", "no", "n"): return False + raise argparse.ArgumentTypeError(f"expected boolean, got {s!r}") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("input", type=Path, help="input GGUF (Step-3.5, trunk-only)") + parser.add_argument("output", type=Path, help="output GGUF (overwritten if exists)") + parser.add_argument("--n-mtp", type=int, default=3, + help="number of MTP blocks to append (default: 3, matching Step-3.5-Flash)") + + # Per-MTP-layer values. Defaults are correct for Step-3.5-Flash whose MTP + # blocks are `sliding_attention` type. + parser.add_argument("--head-count", type=int, default=96, + help="MTP layer head_count (default 96, sliding_attention)") + parser.add_argument("--head-count-kv", type=int, default=8, + help="MTP layer head_count_kv (default 8)") + parser.add_argument("--swa", type=_bool, default=True, + help="MTP layer sliding-window flag (default true)") + parser.add_argument("--swiglu-clamp-exp", type=float, default=0.0, + help="MTP layer swiglu clamp exp (default 0.0 = no clamp)") + parser.add_argument("--swiglu-clamp-shexp", type=float, default=0.0, + help="MTP layer swiglu clamp shexp (default 0.0)") + + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, + format="%(levelname)s: %(message)s") + + if args.n_mtp <= 0: + logger.error("--n-mtp must be > 0; got %d", args.n_mtp) + sys.exit(1) + + logger.info("Reading %s", args.input) + reader = gguf.GGUFReader(args.input, "r") + + arch = get_field_contents(reader, gguf.Keys.General.ARCHITECTURE) + if arch != "step35": + logger.error("Expected arch 'step35', got %r — this script is Step-3.5-specific.", arch) + sys.exit(1) + + block_count_field = reader.get_field(BLOCK_COUNT_KEY) + if block_count_field is None: + logger.error("Missing %s in input GGUF.", BLOCK_COUNT_KEY) + sys.exit(1) + block_count = int(block_count_field.contents()) + + existing_nextn = get_field_contents(reader, NEXTN_LAYERS_KEY) + if existing_nextn is not None: + # block_count already includes the MTP blocks; back them out so the + # math below is independent of whether the input is old-style + # (block_count = trunk) or new-style (block_count = trunk + nextn). + n_main = block_count - int(existing_nextn) + logger.info("Input declares nextn_predict_layers=%d; trunk has %d main blocks.", + existing_nextn, n_main) + else: + n_main = block_count + logger.info("Input has no nextn_predict_layers key; treating block_count=%d as the trunk count.", + block_count) + + n_total = n_main + args.n_mtp + logger.info("Block count: trunk=%d, MTP=%d -> total=%d", n_main, args.n_mtp, n_total) + + per_key_value: dict[str, Any] = { + "step35.attention.head_count": args.head_count, + "step35.attention.head_count_kv": args.head_count_kv, + "step35.attention.sliding_window_pattern": args.swa, + "step35.swiglu_clamp_exp": args.swiglu_clamp_exp, + "step35.swiglu_clamp_shexp": args.swiglu_clamp_shexp, + } + + # Pre-compute the new array values so we can write them via the writer's + # standard add_key_value() path. + new_arrays: dict[str, list[Any]] = {} + for key, sub_type in PER_LAYER_KEYS.items(): + new_val = make_extended_value(reader, key, sub_type, n_total, per_key_value[key], args.n_mtp) + if new_val is not None: + new_arrays[key] = new_val + + logger.info("Writing %s", args.output) + writer = gguf.GGUFWriter( + path = args.output, + arch = arch, + endianess = reader.endianess, + ) + + # Pass 1: copy every existing KV except those we're rewriting. + rewritten = set(new_arrays.keys()) | {BLOCK_COUNT_KEY, NEXTN_LAYERS_KEY} + for field in reader.fields.values(): + if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith("GGUF."): + continue + if field.name in rewritten: + continue + val_type = field.types[0] + sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None + writer.add_key_value(field.name, field.contents(), val_type, sub_type=sub_type) + + # Pass 2: rewritten metadata. + writer.add_uint32(BLOCK_COUNT_KEY, n_total) + writer.add_uint32(NEXTN_LAYERS_KEY, args.n_mtp) + + for key, values in new_arrays.items(): + sub_type = PER_LAYER_KEYS[key] + writer.add_key_value(key, values, gguf.GGUFValueType.ARRAY, sub_type=sub_type) + + # Tensors: copy unchanged. + total_bytes = 0 + for tensor in reader.tensors: + total_bytes += tensor.n_bytes + writer.add_tensor_info( + tensor.name, + tensor.data.shape, + tensor.data.dtype, + tensor.data.nbytes, + tensor.tensor_type, + ) + + bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) + + writer.write_header_to_file() + writer.write_kv_data_to_file() + writer.write_ti_data_to_file() + + for tensor in reader.tensors: + writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess) + bar.update(tensor.n_bytes) + + writer.close() + bar.close() + logger.info("Done. Wrote %s with block_count=%d, nextn_predict_layers=%d.", + args.output, n_total, args.n_mtp) + + +if __name__ == "__main__": + main() diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ad36c06667d..e8ba8497565 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -67,6 +67,7 @@ llama_context::llama_context( cparams.embeddings = params.embeddings; cparams.embeddings_pre_norm = false; cparams.embeddings_pre_norm_masked = false; + cparams.mtp_step = 0; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -1105,6 +1106,10 @@ void llama_context::set_embeddings_pre_norm(bool value, bool masked) { cparams.embeddings_pre_norm_masked = masked; } +void llama_context::set_mtp_step(uint32_t step) { + cparams.mtp_step = step; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -3600,6 +3605,10 @@ float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { return ctx->get_embeddings_pre_norm_ith(i); } +void llama_set_mtp_step(llama_context * ctx, uint32_t step) { + ctx->set_mtp_step(step); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/src/llama-context.h b/src/llama-context.h index d03f681d4a1..6de8f6df6cb 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -111,6 +111,7 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_pre_norm(bool value, bool masked); + void set_mtp_step(uint32_t step); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 20ec59fe335..f4faac6d080 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -30,6 +30,11 @@ struct llama_cparams { bool embeddings; bool embeddings_pre_norm; // also extract the hidden state before the final output norm bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 + + // MTP draft-step index, used by archs with num_nextn_predict_layers > 1 to + // round-robin across MTP blocks (matches vllm's spec_step_idx). The graph + // builder selects `il = n_main + (mtp_step % nextn_predict_layers)`. + uint32_t mtp_step; bool causal_attn; bool offload_kqv; bool flash_attn; diff --git a/src/llama-ext.h b/src/llama-ext.h index edfa71c207c..3c6fa3eb5e7 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -104,3 +104,13 @@ LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); + +// +// MTP draft-step index (round-robin selector across MTP blocks) +// + +// Set the MTP draft-step index for the next llama_decode call. Used by archs +// with num_nextn_predict_layers > 1 to round-robin across their MTP blocks +// (matches vllm's spec_step_idx). Pass step = 0 for the first draft token, +// step = 1 for the second, etc. The graph builder reads cparams.mtp_step. +LLAMA_API void llama_set_mtp_step(struct llama_context * ctx, uint32_t step); diff --git a/src/models/models.h b/src/models/models.h index 5251e2d8280..cbef040870b 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1913,5 +1913,9 @@ struct llama_model_step35 : public llama_model_base { graph(const llama_model & model, const llm_graph_params & params); }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/src/models/step35.cpp b/src/models/step35.cpp index 3b68e68707a..c62f6f0eb28 100644 --- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -26,20 +26,36 @@ void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); - switch (hparams.n_layer) { + // NextN/MTP (Step3p5): extra decoder block appended beyond the main stack. + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 45: type = LLM_TYPE_196B_A11B; break; default: type = LLM_TYPE_UNKNOWN; } } -void llama_model_step35::load_arch_tensors(llama_model_loader &) { +void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + // Trunk-only: the GGUF declares MTP layers in metadata but the actual MTP + // tensors live in a separate file (e.g. user split target/draft). Mark + // MTP tensors NOT_REQUIRED so the trunk loads cleanly. + const std::string mtp_probe = "blk." + std::to_string(n_main) + ".nextn.eh_proj.weight"; + const bool trunk_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight(mtp_probe.c_str()) == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + const int mtp_flags = trunk_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, trunk_flags); // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. @@ -51,14 +67,14 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { n_rot_max = n_rot; } - for (int i = 0; i < n_layer; ++i) { + auto load_block_trunk = [&](int i, int flags) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -70,13 +86,13 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, flags); // head-wise attention gate (Step35 self_attn.g_proj) layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // dense MLP (leading dense blocks) layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -95,10 +111,74 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + }; + + auto load_block_mtp = [&](int i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the + // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). + // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags); + + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags); + + // dense MLP (leading dense blocks) — present if the MTP block isn't MoE + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_step35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -111,7 +191,9 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; const uint32_t n_head_l = hparams.n_head(il); @@ -198,8 +280,8 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "attn_proj", il); } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -257,6 +339,13 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; @@ -267,3 +356,192 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Step3p5 (MoE) +llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "STEP35 MTP requires nextn_predict_layers > 0"); + + // Round-robin across MTP blocks at draft step boundaries. Matches vllm's + // `current_step_idx = spec_step_idx % num_mtp_layers` (step3p5_mtp.py). + // The first MTP block lives at layer index `n_main`; the speculative + // driver bumps `cparams.mtp_step` between AR iterations. + const int n_main = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const int step_offset = (int) (cparams.mtp_step % hparams.nextn_predict_layers); + const int il = n_main + step_offset; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + // mtp_block: full Step3p5 decoder layer (attention with optional head-wise gate, then MoE/dense FFN) + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Qcur, "mtp_Qcur", il); + cb(Kcur, "mtp_Kcur", il); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + if (layer.attn_q_norm) { + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + } + if (layer.attn_k_norm) { + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + } + + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = hparams.n_rot(il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "mtp_Qcur_pos", il); + cb(Kcur, "mtp_Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "mtp_attn_out", il); + + // head-wise attention gate: sigmoid(g_proj(x)) + if (layer.wqkv_gate) { + ggml_tensor * gate = build_lora_mm(layer.wqkv_gate, cur); // [n_head_l, n_tokens] + cb(gate, "mtp_attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "mtp_attn_gate_sigmoid", il); + + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "mtp_attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "mtp_attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "mtp_attn_gated", il); + } + + cur = build_lora_mm(layer.wo, attn_out, layer.wo_s); + cb(cur, "mtp_attn_proj", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_inp = cur; + cur = build_norm(cur, layer.ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_ffn_norm", il); + + // FFN: dense MLP or MoE (mirrors trunk path) + if (layer.ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + layer.ffn_up, layer.ffn_up_b, nullptr, + layer.ffn_gate, layer.ffn_gate_b, nullptr, + layer.ffn_down, layer.ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + } else { + ggml_tensor * moe_out = build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "mtp_ffn_moe_out", il); + + ggml_tensor * sh_out = build_ffn(cur, + layer.ffn_up_shexp, nullptr, nullptr, + layer.ffn_gate_shexp, nullptr, nullptr, + layer.ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "mtp_ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "mtp_ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "STEP35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "STEP35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} From c06411bc18cd957a9c6d7de041d9019bdacbea21 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Wed, 20 May 2026 15:00:48 +0200 Subject: [PATCH 2/8] Simplify to single layer --- common/speculative.cpp | 8 +++++--- src/models/step35.cpp | 44 +++++++++++++++++++++++++++--------------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 3900fe08a33..853356f85d0 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -734,9 +734,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { break; } - // Step i+1: feed the i-th sampled draft token into the (i+1)-th - // MTP block. Multi-block archs round-robin via mtp_step % N. - llama_set_mtp_step(ctx_dft, (uint32_t)(i + 1)); + // Single-block-MTP-only: every AR step reuses the first MTP block + // (Qwen MTP / vLLM single-MTP-layer style). mtp_step stays at 0; + // trailing MTP blocks loaded from the GGUF are ignored at + // runtime, and pruned GGUFs (block 0 only) work the same way. + llama_set_mtp_step(ctx_dft, 0); // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); diff --git a/src/models/step35.cpp b/src/models/step35.cpp index c62f6f0eb28..caf18c743ff 100644 --- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -113,7 +113,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); }; - auto load_block_mtp = [&](int i) { + auto load_block_mtp = [&](int i, bool is_first_mtp) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); @@ -123,7 +123,14 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags); + // + // Only the FIRST MTP block (i == n_main) is required for the + // single-block MTP runtime; trailing MTP blocks are always tolerated + // as missing so pruned GGUFs (block 0 only) load cleanly. Override + // mtp_flags to NOT_REQUIRED for those. + const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -134,12 +141,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); // dense MLP (leading dense blocks) — present if the MTP block isn't MoE layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -159,9 +166,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); // NextN-specific tensors that define the MTP block. - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags); + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); @@ -170,8 +177,13 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { for (int i = 0; i < (int) n_main; ++i) { load_block_trunk(i, trunk_flags); } + // Only the first MTP block (i == n_main) is required at runtime — the + // single-block-MTP graph in build_arch_graph always uses that one. + // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with + // all MTP layers still works) but tolerated when absent via the pruning + // path. See scripts/prune_step35_extra_mtp.py for the pruner. for (int i = (int) n_main; i < n_layer; ++i) { - load_block_mtp(i); + load_block_mtp(i, /*is_first_mtp=*/ i == (int) n_main); } } @@ -362,13 +374,13 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr : llm_graph_context(params) { GGML_ASSERT(hparams.nextn_predict_layers > 0 && "STEP35 MTP requires nextn_predict_layers > 0"); - // Round-robin across MTP blocks at draft step boundaries. Matches vllm's - // `current_step_idx = spec_step_idx % num_mtp_layers` (step3p5_mtp.py). - // The first MTP block lives at layer index `n_main`; the speculative - // driver bumps `cparams.mtp_step` between AR iterations. - const int n_main = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; - const int step_offset = (int) (cparams.mtp_step % hparams.nextn_predict_layers); - const int il = n_main + step_offset; + // Single-block MTP only: always run the first trained MTP block (Qwen + // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to + // be a much deeper refactor than this PR justifies; the trailing MTP + // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just + // block 0) also work — see load_arch_tensors below and + // scripts/prune_step35_extra_mtp.py. + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); From 89c1be6013ececdb76cc7cca6cb68ab2fabf4836 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Wed, 20 May 2026 16:04:00 +0200 Subject: [PATCH 3/8] Rollback core changes --- common/speculative.cpp | 15 --- scripts/prune_step35_extra_mtp.py | 204 ++++++++++++++++++++++++++++++ src/llama-context.cpp | 9 -- src/llama-context.h | 1 - src/llama-cparams.h | 5 - src/llama-ext.h | 10 -- 6 files changed, 204 insertions(+), 40 deletions(-) create mode 100755 scripts/prune_step35_extra_mtp.py diff --git a/common/speculative.cpp b/common/speculative.cpp index 853356f85d0..253a5ececbb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -665,11 +665,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); } - // First draft step uses the first MTP block (step 0). Archs with a - // single MTP block ignore this; multi-block archs (Step-3.5-Flash) use - // it to round-robin across their N MTP layers. - llama_set_mtp_step(ctx_dft, 0); - int ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); @@ -734,12 +729,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { break; } - // Single-block-MTP-only: every AR step reuses the first MTP block - // (Qwen MTP / vLLM single-MTP-layer style). mtp_step stays at 0; - // trailing MTP blocks loaded from the GGUF are ignored at - // runtime, and pruned GGUFs (block 0 only) work the same way. - llama_set_mtp_step(ctx_dft, 0); - // evaluate the drafted tokens on the draft model ret = llama_decode(ctx_dft, batch); if (ret != 0) { @@ -750,10 +739,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { ++i; } - // Reset MTP step so a subsequent non-MTP decode on this context doesn't - // inherit a stale offset. - llama_set_mtp_step(ctx_dft, 0); - for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { diff --git a/scripts/prune_step35_extra_mtp.py b/scripts/prune_step35_extra_mtp.py new file mode 100755 index 00000000000..001ed4a6123 --- /dev/null +++ b/scripts/prune_step35_extra_mtp.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +""" +Prune a Step-3.5 GGUF down to just the first MTP block. + +The runtime only uses the first MTP block (single-block-MTP, Qwen/vLLM +style); trailing MTP blocks are loaded with TENSOR_NOT_REQUIRED so a pruned +GGUF works without further surgery. This script does the surgery: it drops +all tensors for blocks `blk.{n_main+1}..blk.{n_total-1}`, rewrites the +per-layer metadata arrays + block_count + nextn_predict_layers so the loader +sees the slimmer model, and writes a new GGUF. Saves ~one MTP block of +weights per pruned block (single-digit GB on Step-3.5-Flash). + + n_main = block_count - nextn_predict_layers (transformer trunk) + keep = blk.0 .. blk.n_main (trunk + first MTP block) + drop = blk.{n_main+1} .. blk.{block_count-1} + +After pruning the output has block_count = n_main + 1 and +nextn_predict_layers = 1. + +Example: + ./scripts/prune_step35_extra_mtp.py step3p5-flash-full.gguf step3p5-flash-mtp1.gguf +""" +from __future__ import annotations + +import argparse +import logging +import os +import re +import sys +from pathlib import Path +from typing import Any + +from tqdm import tqdm + +# Allow running from a llama.cpp checkout without installing gguf-py. +if "NO_LOCAL_GGUF" not in os.environ: + repo_root = Path(__file__).resolve().parent.parent + sys.path.insert(0, str(repo_root / "gguf-py")) + +import gguf # noqa: E402 + +logger = logging.getLogger("prune-step35-extra-mtp") + +# Per-layer metadata keys (step35-specific) that we need to trim along with +# block_count. Matches the schema fix_step35_mtp_metadata.py knows about. +PER_LAYER_KEYS: dict[str, gguf.GGUFValueType] = { + "step35.attention.head_count": gguf.GGUFValueType.UINT32, + "step35.attention.head_count_kv": gguf.GGUFValueType.UINT32, + "step35.attention.sliding_window_pattern": gguf.GGUFValueType.BOOL, + "step35.swiglu_clamp_exp": gguf.GGUFValueType.FLOAT32, + "step35.swiglu_clamp_shexp": gguf.GGUFValueType.FLOAT32, +} + +BLOCK_COUNT_KEY = "step35.block_count" +NEXTN_LAYERS_KEY = "step35.nextn_predict_layers" + +# Tensors with a "blk.." prefix belong to block N. Anything that doesn't +# match (token_embd, output_norm, output, ...) is kept unconditionally. +BLOCK_TENSOR_RE = re.compile(r"^blk\.(\d+)\.") + + +def get_field_contents(reader: gguf.GGUFReader, key: str) -> Any: + field = reader.get_field(key) + return field.contents() if field else None + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("input", type=Path, help="input GGUF (Step-3.5 with full MTP)") + parser.add_argument("output", type=Path, help="output GGUF (overwritten if exists)") + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, + format="%(levelname)s: %(message)s") + + logger.info("Reading %s", args.input) + reader = gguf.GGUFReader(args.input, "r") + + arch = get_field_contents(reader, gguf.Keys.General.ARCHITECTURE) + if arch != "step35": + logger.error("Expected arch 'step35', got %r — this script is Step-3.5-specific.", arch) + sys.exit(1) + + block_count_field = reader.get_field(BLOCK_COUNT_KEY) + if block_count_field is None: + logger.error("Missing %s in input GGUF.", BLOCK_COUNT_KEY) + sys.exit(1) + block_count = int(block_count_field.contents()) + + nextn_field = reader.get_field(NEXTN_LAYERS_KEY) + if nextn_field is None: + logger.error("Input has no %s — nothing to prune (run fix_step35_mtp_metadata.py first if needed).", + NEXTN_LAYERS_KEY) + sys.exit(1) + n_mtp = int(nextn_field.contents()) + + if n_mtp <= 1: + logger.info("nextn_predict_layers=%d already <= 1; nothing to prune. Copying input unchanged.", n_mtp) + + n_main = block_count - n_mtp + keep_first = n_main # first MTP block (index) + keep_last = n_main # inclusive — only one MTP block kept + drop_first = n_main + 1 + drop_last = block_count - 1 # inclusive + + n_total_new = n_main + 1 # new block_count + n_mtp_new = 1 # new nextn_predict_layers + + logger.info( + "Pruning plan: trunk %d blocks (0..%d), keep MTP block %d, drop MTP blocks %d..%d -> new block_count=%d", + n_main, n_main - 1 if n_main > 0 else -1, keep_first, drop_first, drop_last, n_total_new, + ) + + # Per-layer arrays: trim to length n_total_new (drop the trailing n_mtp-1 entries). + new_arrays: dict[str, list[Any]] = {} + for key in PER_LAYER_KEYS: + field = reader.get_field(key) + if field is None: + continue + if field.types[0] != gguf.GGUFValueType.ARRAY: + # Scalar — irrelevant for per-layer trimming, will be copied as-is. + continue + current: list[Any] = list(field.contents()) + if len(current) <= n_total_new: + logger.info(" %s: already length %d (<= %d), leaving as-is", key, len(current), n_total_new) + continue + trimmed = current[:n_total_new] + logger.info(" %s: trimming length %d -> %d", key, len(current), n_total_new) + new_arrays[key] = trimmed + + logger.info("Writing %s", args.output) + writer = gguf.GGUFWriter( + path = args.output, + arch = arch, + endianess = reader.endianess, + ) + + # Pass 1: copy every existing KV except those we're rewriting. + rewritten = set(new_arrays.keys()) | {BLOCK_COUNT_KEY, NEXTN_LAYERS_KEY} + for field in reader.fields.values(): + if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith("GGUF."): + continue + if field.name in rewritten: + continue + val_type = field.types[0] + sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None + writer.add_key_value(field.name, field.contents(), val_type, sub_type=sub_type) + + # Pass 2: rewritten metadata. + writer.add_uint32(BLOCK_COUNT_KEY, n_total_new) + writer.add_uint32(NEXTN_LAYERS_KEY, n_mtp_new) + + for key, values in new_arrays.items(): + sub_type = PER_LAYER_KEYS[key] + writer.add_key_value(key, values, gguf.GGUFValueType.ARRAY, sub_type=sub_type) + + # Tensors: copy those that belong to a kept block (or to no block at all). + kept_tensors = [] + dropped_count = 0 + dropped_bytes = 0 + for tensor in reader.tensors: + m = BLOCK_TENSOR_RE.match(tensor.name) + if m is not None: + blk_idx = int(m.group(1)) + if blk_idx > keep_last: + dropped_count += 1 + dropped_bytes += tensor.n_bytes + logger.debug(" drop %s (blk.%d)", tensor.name, blk_idx) + continue + kept_tensors.append(tensor) + + total_bytes = sum(t.n_bytes for t in kept_tensors) + logger.info("Tensors: keeping %d (%.2f GB), dropping %d (%.2f GB)", + len(kept_tensors), total_bytes / (1 << 30), + dropped_count, dropped_bytes / (1 << 30)) + + for tensor in kept_tensors: + writer.add_tensor_info( + tensor.name, + tensor.data.shape, + tensor.data.dtype, + tensor.data.nbytes, + tensor.tensor_type, + ) + + bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) + + writer.write_header_to_file() + writer.write_kv_data_to_file() + writer.write_ti_data_to_file() + + for tensor in kept_tensors: + writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess) + bar.update(tensor.n_bytes) + + writer.close() + bar.close() + logger.info("Done. Wrote %s (block_count=%d, nextn_predict_layers=%d).", + args.output, n_total_new, n_mtp_new) + + +if __name__ == "__main__": + main() diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e8ba8497565..ad36c06667d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -67,7 +67,6 @@ llama_context::llama_context( cparams.embeddings = params.embeddings; cparams.embeddings_pre_norm = false; cparams.embeddings_pre_norm_masked = false; - cparams.mtp_step = 0; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -1106,10 +1105,6 @@ void llama_context::set_embeddings_pre_norm(bool value, bool masked) { cparams.embeddings_pre_norm_masked = masked; } -void llama_context::set_mtp_step(uint32_t step) { - cparams.mtp_step = step; -} - void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -3605,10 +3600,6 @@ float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { return ctx->get_embeddings_pre_norm_ith(i); } -void llama_set_mtp_step(llama_context * ctx, uint32_t step) { - ctx->set_mtp_step(step); -} - bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/src/llama-context.h b/src/llama-context.h index 6de8f6df6cb..d03f681d4a1 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -111,7 +111,6 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_pre_norm(bool value, bool masked); - void set_mtp_step(uint32_t step); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index f4faac6d080..20ec59fe335 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -30,11 +30,6 @@ struct llama_cparams { bool embeddings; bool embeddings_pre_norm; // also extract the hidden state before the final output norm bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 - - // MTP draft-step index, used by archs with num_nextn_predict_layers > 1 to - // round-robin across MTP blocks (matches vllm's spec_step_idx). The graph - // builder selects `il = n_main + (mtp_step % nextn_predict_layers)`. - uint32_t mtp_step; bool causal_attn; bool offload_kqv; bool flash_attn; diff --git a/src/llama-ext.h b/src/llama-ext.h index 3c6fa3eb5e7..edfa71c207c 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -104,13 +104,3 @@ LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); - -// -// MTP draft-step index (round-robin selector across MTP blocks) -// - -// Set the MTP draft-step index for the next llama_decode call. Used by archs -// with num_nextn_predict_layers > 1 to round-robin across their MTP blocks -// (matches vllm's spec_step_idx). Pass step = 0 for the first draft token, -// step = 1 for the second, etc. The graph builder reads cparams.mtp_step. -LLAMA_API void llama_set_mtp_step(struct llama_context * ctx, uint32_t step); From e444ae70f1152a8329786a33cd00a65b5770968a Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Fri, 29 May 2026 12:27:27 +0200 Subject: [PATCH 4/8] fix flake8 errors --- scripts/fix_step35_mtp_metadata.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/fix_step35_mtp_metadata.py b/scripts/fix_step35_mtp_metadata.py index 855347f81f7..4ddeae4b820 100755 --- a/scripts/fix_step35_mtp_metadata.py +++ b/scripts/fix_step35_mtp_metadata.py @@ -98,8 +98,10 @@ def make_extended_value( def _bool(s: str) -> bool: t = s.strip().lower() - if t in ("1", "true", "yes", "y"): return True - if t in ("0", "false", "no", "n"): return False + if t in ("1", "true", "yes", "y"): + return True + if t in ("0", "false", "no", "n"): + return False raise argparse.ArgumentTypeError(f"expected boolean, got {s!r}") From 82dec8d95927a2b43cf14cab369b092ae675ef58 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Fri, 29 May 2026 13:40:59 +0200 Subject: [PATCH 5/8] Remove scripts --- scripts/fix_step35_mtp_metadata.py | 239 ----------------------------- scripts/prune_step35_extra_mtp.py | 204 ------------------------ 2 files changed, 443 deletions(-) delete mode 100755 scripts/fix_step35_mtp_metadata.py delete mode 100755 scripts/prune_step35_extra_mtp.py diff --git a/scripts/fix_step35_mtp_metadata.py b/scripts/fix_step35_mtp_metadata.py deleted file mode 100755 index 4ddeae4b820..00000000000 --- a/scripts/fix_step35_mtp_metadata.py +++ /dev/null @@ -1,239 +0,0 @@ -#!/usr/bin/env python3 -""" -Fix Step-3.5 GGUF metadata for MTP support. - -Old (pre-MTP) Step-3.5 GGUFs were written with `step35.block_count = num_hidden_layers` -and per-layer arrays sized to the same length, so the appended MTP blocks have -no metadata slot. This script: - - * sets `step35.block_count = num_hidden_layers + num_nextn_predict_layers` - * appends `num_nextn_predict_layers` entries to every known per-layer array - (head_count, head_count_kv, sliding_window_pattern, swiglu.clamp_exp, - swiglu.clamp_shexp) so the C++ loader's length check passes - * writes `step35.nextn_predict_layers` - * copies all tensors over unchanged - -Defaults assume the MTP blocks are `sliding_attention` (Step-3.5-Flash): -head_count=96, head_count_kv=8, swa=True, swiglu_clamp_exp=0, swiglu_clamp_shexp=0. -Override with the per-array flags below if your model differs. - -Run conversion with the up-to-date `conversion/step3.py` to produce a correct -GGUF in the first place; this script exists only to retrofit older outputs. -""" -from __future__ import annotations - -import argparse -import logging -import os -import sys -from pathlib import Path -from typing import Any - -from tqdm import tqdm - -# Allow running from a llama.cpp checkout without installing gguf-py. -if "NO_LOCAL_GGUF" not in os.environ: - repo_root = Path(__file__).resolve().parent.parent - sys.path.insert(0, str(repo_root / "gguf-py")) - -import gguf # noqa: E402 - -logger = logging.getLogger("fix-step35-mtp-metadata") - - -# Per-layer metadata keys (step35-specific) we know how to extend. -PER_LAYER_KEYS: dict[str, gguf.GGUFValueType] = { - "step35.attention.head_count": gguf.GGUFValueType.UINT32, - "step35.attention.head_count_kv": gguf.GGUFValueType.UINT32, - "step35.attention.sliding_window_pattern": gguf.GGUFValueType.BOOL, - "step35.swiglu_clamp_exp": gguf.GGUFValueType.FLOAT32, - "step35.swiglu_clamp_shexp": gguf.GGUFValueType.FLOAT32, -} - -BLOCK_COUNT_KEY = "step35.block_count" -NEXTN_LAYERS_KEY = "step35.nextn_predict_layers" - - -def get_field_contents(reader: gguf.GGUFReader, key: str) -> Any: - field = reader.get_field(key) - return field.contents() if field else None - - -def field_main_type(reader: gguf.GGUFReader, key: str) -> gguf.GGUFValueType | None: - field = reader.get_field(key) - if field is None or not field.types: - return None - return field.types[0] - - -def make_extended_value( - reader: gguf.GGUFReader, - key: str, - sub_type: gguf.GGUFValueType, - n_total: int, - mtp_value: Any, - n_mtp: int, -) -> list[Any] | None: - """Return the new array for ``key`` (length ``n_total``) or None if the key - is absent and no broadcast is needed.""" - field = reader.get_field(key) - if field is None: - return None - - main_type = field.types[0] - if main_type == gguf.GGUFValueType.ARRAY: - current: list[Any] = list(field.contents()) - if len(current) >= n_total: - logger.info(" %s already length %d (>= %d), trimming", key, len(current), n_total) - return current[:n_total] - pad = n_total - len(current) - logger.info(" %s: extending %d -> %d (appending %d × %r)", key, len(current), n_total, pad, mtp_value) - return current + [mtp_value] * pad - - # Scalar — broadcast to length n_total, replacing the trailing n_mtp entries. - scalar = field.contents() - logger.info(" %s: scalar %r -> array of %d (last %d = %r)", key, scalar, n_total, n_mtp, mtp_value) - return [scalar] * (n_total - n_mtp) + [mtp_value] * n_mtp - - -def _bool(s: str) -> bool: - t = s.strip().lower() - if t in ("1", "true", "yes", "y"): - return True - if t in ("0", "false", "no", "n"): - return False - raise argparse.ArgumentTypeError(f"expected boolean, got {s!r}") - - -def main() -> None: - parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("input", type=Path, help="input GGUF (Step-3.5, trunk-only)") - parser.add_argument("output", type=Path, help="output GGUF (overwritten if exists)") - parser.add_argument("--n-mtp", type=int, default=3, - help="number of MTP blocks to append (default: 3, matching Step-3.5-Flash)") - - # Per-MTP-layer values. Defaults are correct for Step-3.5-Flash whose MTP - # blocks are `sliding_attention` type. - parser.add_argument("--head-count", type=int, default=96, - help="MTP layer head_count (default 96, sliding_attention)") - parser.add_argument("--head-count-kv", type=int, default=8, - help="MTP layer head_count_kv (default 8)") - parser.add_argument("--swa", type=_bool, default=True, - help="MTP layer sliding-window flag (default true)") - parser.add_argument("--swiglu-clamp-exp", type=float, default=0.0, - help="MTP layer swiglu clamp exp (default 0.0 = no clamp)") - parser.add_argument("--swiglu-clamp-shexp", type=float, default=0.0, - help="MTP layer swiglu clamp shexp (default 0.0)") - - parser.add_argument("--verbose", action="store_true") - args = parser.parse_args() - - logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, - format="%(levelname)s: %(message)s") - - if args.n_mtp <= 0: - logger.error("--n-mtp must be > 0; got %d", args.n_mtp) - sys.exit(1) - - logger.info("Reading %s", args.input) - reader = gguf.GGUFReader(args.input, "r") - - arch = get_field_contents(reader, gguf.Keys.General.ARCHITECTURE) - if arch != "step35": - logger.error("Expected arch 'step35', got %r — this script is Step-3.5-specific.", arch) - sys.exit(1) - - block_count_field = reader.get_field(BLOCK_COUNT_KEY) - if block_count_field is None: - logger.error("Missing %s in input GGUF.", BLOCK_COUNT_KEY) - sys.exit(1) - block_count = int(block_count_field.contents()) - - existing_nextn = get_field_contents(reader, NEXTN_LAYERS_KEY) - if existing_nextn is not None: - # block_count already includes the MTP blocks; back them out so the - # math below is independent of whether the input is old-style - # (block_count = trunk) or new-style (block_count = trunk + nextn). - n_main = block_count - int(existing_nextn) - logger.info("Input declares nextn_predict_layers=%d; trunk has %d main blocks.", - existing_nextn, n_main) - else: - n_main = block_count - logger.info("Input has no nextn_predict_layers key; treating block_count=%d as the trunk count.", - block_count) - - n_total = n_main + args.n_mtp - logger.info("Block count: trunk=%d, MTP=%d -> total=%d", n_main, args.n_mtp, n_total) - - per_key_value: dict[str, Any] = { - "step35.attention.head_count": args.head_count, - "step35.attention.head_count_kv": args.head_count_kv, - "step35.attention.sliding_window_pattern": args.swa, - "step35.swiglu_clamp_exp": args.swiglu_clamp_exp, - "step35.swiglu_clamp_shexp": args.swiglu_clamp_shexp, - } - - # Pre-compute the new array values so we can write them via the writer's - # standard add_key_value() path. - new_arrays: dict[str, list[Any]] = {} - for key, sub_type in PER_LAYER_KEYS.items(): - new_val = make_extended_value(reader, key, sub_type, n_total, per_key_value[key], args.n_mtp) - if new_val is not None: - new_arrays[key] = new_val - - logger.info("Writing %s", args.output) - writer = gguf.GGUFWriter( - path = args.output, - arch = arch, - endianess = reader.endianess, - ) - - # Pass 1: copy every existing KV except those we're rewriting. - rewritten = set(new_arrays.keys()) | {BLOCK_COUNT_KEY, NEXTN_LAYERS_KEY} - for field in reader.fields.values(): - if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith("GGUF."): - continue - if field.name in rewritten: - continue - val_type = field.types[0] - sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None - writer.add_key_value(field.name, field.contents(), val_type, sub_type=sub_type) - - # Pass 2: rewritten metadata. - writer.add_uint32(BLOCK_COUNT_KEY, n_total) - writer.add_uint32(NEXTN_LAYERS_KEY, args.n_mtp) - - for key, values in new_arrays.items(): - sub_type = PER_LAYER_KEYS[key] - writer.add_key_value(key, values, gguf.GGUFValueType.ARRAY, sub_type=sub_type) - - # Tensors: copy unchanged. - total_bytes = 0 - for tensor in reader.tensors: - total_bytes += tensor.n_bytes - writer.add_tensor_info( - tensor.name, - tensor.data.shape, - tensor.data.dtype, - tensor.data.nbytes, - tensor.tensor_type, - ) - - bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - - writer.write_header_to_file() - writer.write_kv_data_to_file() - writer.write_ti_data_to_file() - - for tensor in reader.tensors: - writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess) - bar.update(tensor.n_bytes) - - writer.close() - bar.close() - logger.info("Done. Wrote %s with block_count=%d, nextn_predict_layers=%d.", - args.output, n_total, args.n_mtp) - - -if __name__ == "__main__": - main() diff --git a/scripts/prune_step35_extra_mtp.py b/scripts/prune_step35_extra_mtp.py deleted file mode 100755 index 001ed4a6123..00000000000 --- a/scripts/prune_step35_extra_mtp.py +++ /dev/null @@ -1,204 +0,0 @@ -#!/usr/bin/env python3 -""" -Prune a Step-3.5 GGUF down to just the first MTP block. - -The runtime only uses the first MTP block (single-block-MTP, Qwen/vLLM -style); trailing MTP blocks are loaded with TENSOR_NOT_REQUIRED so a pruned -GGUF works without further surgery. This script does the surgery: it drops -all tensors for blocks `blk.{n_main+1}..blk.{n_total-1}`, rewrites the -per-layer metadata arrays + block_count + nextn_predict_layers so the loader -sees the slimmer model, and writes a new GGUF. Saves ~one MTP block of -weights per pruned block (single-digit GB on Step-3.5-Flash). - - n_main = block_count - nextn_predict_layers (transformer trunk) - keep = blk.0 .. blk.n_main (trunk + first MTP block) - drop = blk.{n_main+1} .. blk.{block_count-1} - -After pruning the output has block_count = n_main + 1 and -nextn_predict_layers = 1. - -Example: - ./scripts/prune_step35_extra_mtp.py step3p5-flash-full.gguf step3p5-flash-mtp1.gguf -""" -from __future__ import annotations - -import argparse -import logging -import os -import re -import sys -from pathlib import Path -from typing import Any - -from tqdm import tqdm - -# Allow running from a llama.cpp checkout without installing gguf-py. -if "NO_LOCAL_GGUF" not in os.environ: - repo_root = Path(__file__).resolve().parent.parent - sys.path.insert(0, str(repo_root / "gguf-py")) - -import gguf # noqa: E402 - -logger = logging.getLogger("prune-step35-extra-mtp") - -# Per-layer metadata keys (step35-specific) that we need to trim along with -# block_count. Matches the schema fix_step35_mtp_metadata.py knows about. -PER_LAYER_KEYS: dict[str, gguf.GGUFValueType] = { - "step35.attention.head_count": gguf.GGUFValueType.UINT32, - "step35.attention.head_count_kv": gguf.GGUFValueType.UINT32, - "step35.attention.sliding_window_pattern": gguf.GGUFValueType.BOOL, - "step35.swiglu_clamp_exp": gguf.GGUFValueType.FLOAT32, - "step35.swiglu_clamp_shexp": gguf.GGUFValueType.FLOAT32, -} - -BLOCK_COUNT_KEY = "step35.block_count" -NEXTN_LAYERS_KEY = "step35.nextn_predict_layers" - -# Tensors with a "blk.." prefix belong to block N. Anything that doesn't -# match (token_embd, output_norm, output, ...) is kept unconditionally. -BLOCK_TENSOR_RE = re.compile(r"^blk\.(\d+)\.") - - -def get_field_contents(reader: gguf.GGUFReader, key: str) -> Any: - field = reader.get_field(key) - return field.contents() if field else None - - -def main() -> None: - parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument("input", type=Path, help="input GGUF (Step-3.5 with full MTP)") - parser.add_argument("output", type=Path, help="output GGUF (overwritten if exists)") - parser.add_argument("--verbose", action="store_true") - args = parser.parse_args() - - logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO, - format="%(levelname)s: %(message)s") - - logger.info("Reading %s", args.input) - reader = gguf.GGUFReader(args.input, "r") - - arch = get_field_contents(reader, gguf.Keys.General.ARCHITECTURE) - if arch != "step35": - logger.error("Expected arch 'step35', got %r — this script is Step-3.5-specific.", arch) - sys.exit(1) - - block_count_field = reader.get_field(BLOCK_COUNT_KEY) - if block_count_field is None: - logger.error("Missing %s in input GGUF.", BLOCK_COUNT_KEY) - sys.exit(1) - block_count = int(block_count_field.contents()) - - nextn_field = reader.get_field(NEXTN_LAYERS_KEY) - if nextn_field is None: - logger.error("Input has no %s — nothing to prune (run fix_step35_mtp_metadata.py first if needed).", - NEXTN_LAYERS_KEY) - sys.exit(1) - n_mtp = int(nextn_field.contents()) - - if n_mtp <= 1: - logger.info("nextn_predict_layers=%d already <= 1; nothing to prune. Copying input unchanged.", n_mtp) - - n_main = block_count - n_mtp - keep_first = n_main # first MTP block (index) - keep_last = n_main # inclusive — only one MTP block kept - drop_first = n_main + 1 - drop_last = block_count - 1 # inclusive - - n_total_new = n_main + 1 # new block_count - n_mtp_new = 1 # new nextn_predict_layers - - logger.info( - "Pruning plan: trunk %d blocks (0..%d), keep MTP block %d, drop MTP blocks %d..%d -> new block_count=%d", - n_main, n_main - 1 if n_main > 0 else -1, keep_first, drop_first, drop_last, n_total_new, - ) - - # Per-layer arrays: trim to length n_total_new (drop the trailing n_mtp-1 entries). - new_arrays: dict[str, list[Any]] = {} - for key in PER_LAYER_KEYS: - field = reader.get_field(key) - if field is None: - continue - if field.types[0] != gguf.GGUFValueType.ARRAY: - # Scalar — irrelevant for per-layer trimming, will be copied as-is. - continue - current: list[Any] = list(field.contents()) - if len(current) <= n_total_new: - logger.info(" %s: already length %d (<= %d), leaving as-is", key, len(current), n_total_new) - continue - trimmed = current[:n_total_new] - logger.info(" %s: trimming length %d -> %d", key, len(current), n_total_new) - new_arrays[key] = trimmed - - logger.info("Writing %s", args.output) - writer = gguf.GGUFWriter( - path = args.output, - arch = arch, - endianess = reader.endianess, - ) - - # Pass 1: copy every existing KV except those we're rewriting. - rewritten = set(new_arrays.keys()) | {BLOCK_COUNT_KEY, NEXTN_LAYERS_KEY} - for field in reader.fields.values(): - if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith("GGUF."): - continue - if field.name in rewritten: - continue - val_type = field.types[0] - sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None - writer.add_key_value(field.name, field.contents(), val_type, sub_type=sub_type) - - # Pass 2: rewritten metadata. - writer.add_uint32(BLOCK_COUNT_KEY, n_total_new) - writer.add_uint32(NEXTN_LAYERS_KEY, n_mtp_new) - - for key, values in new_arrays.items(): - sub_type = PER_LAYER_KEYS[key] - writer.add_key_value(key, values, gguf.GGUFValueType.ARRAY, sub_type=sub_type) - - # Tensors: copy those that belong to a kept block (or to no block at all). - kept_tensors = [] - dropped_count = 0 - dropped_bytes = 0 - for tensor in reader.tensors: - m = BLOCK_TENSOR_RE.match(tensor.name) - if m is not None: - blk_idx = int(m.group(1)) - if blk_idx > keep_last: - dropped_count += 1 - dropped_bytes += tensor.n_bytes - logger.debug(" drop %s (blk.%d)", tensor.name, blk_idx) - continue - kept_tensors.append(tensor) - - total_bytes = sum(t.n_bytes for t in kept_tensors) - logger.info("Tensors: keeping %d (%.2f GB), dropping %d (%.2f GB)", - len(kept_tensors), total_bytes / (1 << 30), - dropped_count, dropped_bytes / (1 << 30)) - - for tensor in kept_tensors: - writer.add_tensor_info( - tensor.name, - tensor.data.shape, - tensor.data.dtype, - tensor.data.nbytes, - tensor.tensor_type, - ) - - bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) - - writer.write_header_to_file() - writer.write_kv_data_to_file() - writer.write_ti_data_to_file() - - for tensor in kept_tensors: - writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess) - bar.update(tensor.n_bytes) - - writer.close() - bar.close() - logger.info("Done. Wrote %s (block_count=%d, nextn_predict_layers=%d).", - args.output, n_total_new, n_mtp_new) - - -if __name__ == "__main__": - main() From ded7052a13d3ebe5b56feb51810507684c9030df Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 1 Jun 2026 12:49:44 +0200 Subject: [PATCH 6/8] modify to convention --- conversion/step3.py | 63 +++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/conversion/step3.py b/conversion/step3.py index bd0adb3f6ad..f6240bea1fe 100644 --- a/conversion/step3.py +++ b/conversion/step3.py @@ -99,12 +99,13 @@ class Step3VLTextModel(Qwen3Model): class Step35Model(TextModel): model_arch = gguf.MODEL_ARCH.STEP35 - # --mtp / --no-mtp toggles (see convert_hf_to_gguf.py main()). - # Unlike Qwen3.5 which stores MTP under a `mtp.*` namespace, Step3.5 just - # appends MTP layers at `model.layers.{num_hidden_layers + i}`; these flags - # filter by layer index instead of by name prefix. - no_mtp: bool = False - mtp_only: bool = False + # The --mtp / --no-mtp toggles are ModelBase.mtp_only / no_mtp (set in + # convert_hf_to_gguf.py main()). Unlike Qwen3.5, which stores MTP under a + # `mtp.*` namespace, Step3.5 appends MTP layers at + # `model.layers.{num_hidden_layers + i}`, so we filter them by layer index. + # The trunk layer count is captured before indexing so the classmethod + # filter_tensors can tell the appended MTP block(s) apart from the trunk. + _n_main_layers: int | None = None def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -117,6 +118,13 @@ def __init__(self, *args, **kwargs): self.block_count = int(self.hparams["num_hidden_layers"]) + n_nextn self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + def index_tensors(self, remote_hf_model_id: str | None = None): + # filter_tensors is a classmethod and can't reach self.hparams; stash + # the trunk layer count here (before indexing runs) so it can detect + # the appended MTP layers by index. + type(self)._n_main_layers = int(self.hparams["num_hidden_layers"]) + return super().index_tensors(remote_hf_model_id=remote_hf_model_id) + def set_gguf_parameters(self): rope_theta = self.hparams.get("rope_theta") if isinstance(rope_theta, list): @@ -219,31 +227,19 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca if name.endswith(".moe.router_bias"): name += ".bias" - return super().filter_tensors((name, gen)) - - def _is_mtp_layer(self, bid: int | None) -> bool: - if bid is None: - return False - n_main = int(self.hparams.get("num_hidden_layers", self.block_count)) - return bid >= n_main - - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): - is_mtp = self._is_mtp_layer(bid) + # Step3.5 appends the MTP block(s) past num_hidden_layers. + assert cls._n_main_layers is not None + is_mtp = (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None and int(m.group(1)) >= cls._n_main_layers # --no-mtp: drop the appended MTP block(s) entirely. - if is_mtp and self.no_mtp: - return - # --mtp: keep ONLY MTP-block tensors plus the shared embeddings/norm/lm_head - # (so the resulting GGUF carries just the draft head). - if self.mtp_only and not is_mtp and bid is not None: - return - if self.mtp_only and bid is None: - # Top-level tensors: keep only shared embeddings/norm/lm_head. - keep = name in ( - "model.embed_tokens.weight", "model.norm.weight", "lm_head.weight", - ) - if not keep: - return + if is_mtp and cls.no_mtp: + return None + # --mtp: keep ONLY MTP-block tensors plus the shared embeddings/norm/ + # lm_head (so the resulting GGUF carries just the draft head). + if cls.mtp_only and not is_mtp and name not in ( + "model.embed_tokens.weight", "model.norm.weight", "lm_head.weight", + ): + return None # The checkpoint nests the per-MTP-layer shared head under # `model.layers.{N+i}.transformer.shared_head.{norm,output}.weight`; @@ -251,11 +247,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): # existing NEXTN_SHARED_HEAD_{NORM,HEAD} tensor mapping picks them up. # Mirrors vllm's `_rewrite_spec_layer_name` (step3p5_mtp.py). if is_mtp: - if ".transformer." in name: - name = name.replace(".transformer.", ".") - if "shared_head.output" in name: - name = name.replace("shared_head.output", "shared_head.head") + name = name.replace(".transformer.", ".") + name = name.replace("shared_head.output", "shared_head.head") + return super().filter_tensors((name, gen)) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): if name.endswith("norm.weight"): data_torch += 1.0 From 66bc6a32679b662c6f670c3f2ee99ba14ea8789c Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Mon, 1 Jun 2026 13:44:52 +0200 Subject: [PATCH 7/8] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sigbjørn Skjæret --- conversion/step3.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/conversion/step3.py b/conversion/step3.py index f6240bea1fe..ca61f73eafe 100644 --- a/conversion/step3.py +++ b/conversion/step3.py @@ -115,14 +115,16 @@ def __init__(self, *args, **kwargs): # base num_hidden_layers so we don't reserve unused slots. n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) if n_nextn > 0 and not self.no_mtp: - self.block_count = int(self.hparams["num_hidden_layers"]) + n_nextn + self.block_count += n_nextn self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) def index_tensors(self, remote_hf_model_id: str | None = None): # filter_tensors is a classmethod and can't reach self.hparams; stash # the trunk layer count here (before indexing runs) so it can detect # the appended MTP layers by index. - type(self)._n_main_layers = int(self.hparams["num_hidden_layers"]) + hparams = {**self.hparams, **self.hparams.get("text_config", {})} + key = next((k for k in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"] if k in hparams), None) + type(self)._n_main_layers = hparams.get(key) return super().index_tensors(remote_hf_model_id=remote_hf_model_id) def set_gguf_parameters(self): @@ -221,7 +223,9 @@ def _pad(arr, n, default): @classmethod def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: - name, gen = item + if (titem := super().filter_tensors(item)) is None: + return None + name, gen = titem # Map router bias (expert selection bias) to a GGUF bias tensor if name.endswith(".moe.router_bias"): @@ -250,7 +254,7 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca name = name.replace(".transformer.", ".") name = name.replace("shared_head.output", "shared_head.head") - return super().filter_tensors((name, gen)) + return name, gen def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): if name.endswith("norm.weight"): From 0bbad317555c7a7af859117cb6ae88cf6c4ca545 Mon Sep 17 00:00:00 2001 From: Piotr Wilkin Date: Mon, 1 Jun 2026 14:13:41 +0200 Subject: [PATCH 8/8] dos2unix --- conversion/step3.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/conversion/step3.py b/conversion/step3.py index ca61f73eafe..59758ee0ad9 100644 --- a/conversion/step3.py +++ b/conversion/step3.py @@ -115,16 +115,16 @@ def __init__(self, *args, **kwargs): # base num_hidden_layers so we don't reserve unused slots. n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0)) if n_nextn > 0 and not self.no_mtp: - self.block_count += n_nextn + self.block_count += n_nextn self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) def index_tensors(self, remote_hf_model_id: str | None = None): # filter_tensors is a classmethod and can't reach self.hparams; stash # the trunk layer count here (before indexing runs) so it can detect # the appended MTP layers by index. - hparams = {**self.hparams, **self.hparams.get("text_config", {})} - key = next((k for k in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"] if k in hparams), None) - type(self)._n_main_layers = hparams.get(key) + hparams = {**self.hparams, **self.hparams.get("text_config", {})} + key = next((k for k in ["n_layers", "num_hidden_layers", "n_layer", "num_layers"] if k in hparams), None) + type(self)._n_main_layers = hparams.get(key) return super().index_tensors(remote_hf_model_id=remote_hf_model_id) def set_gguf_parameters(self): @@ -223,9 +223,9 @@ def _pad(arr, n, default): @classmethod def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: - if (titem := super().filter_tensors(item)) is None: - return None - name, gen = titem + if (titem := super().filter_tensors(item)) is None: + return None + name, gen = titem # Map router bias (expert selection bias) to a GGUF bias tensor if name.endswith(".moe.router_bias"): @@ -254,7 +254,7 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca name = name.replace(".transformer.", ".") name = name.replace("shared_head.output", "shared_head.head") - return name, gen + return name, gen def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): if name.endswith("norm.weight"):