diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 95a8c26a3a..578e35ad4e 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -486,6 +486,122 @@ def _extra_args_set_spec_type(extra_args: Optional[Iterable[str]]) -> bool: return False +def _build_ngram_mod_flags( + caps: Optional[dict], + n_match: int = 24, + n_min: int = 48, + n_max: int = 64, +) -> list[str]: + """Emit the right ngram-mod knob flags for the running llama-server. + + Post-rename builds expose ``--spec-ngram-mod-n-{match,min,max}``; + pre-rename builds expose the legacy ``--spec-ngram-size-n`` / + ``--draft-min`` / ``--draft-max``. ``caps`` comes from + ``probe_server_capabilities``; ``ngram_mod_flavor`` tells us which + set is real (vs a removal-stub entry). Returns ``[]`` when neither + set is available so the caller can drop ngram-mod entirely. + """ + flavor = caps.get("ngram_mod_flavor") if caps else None + if flavor == "new": + return [ + "--spec-ngram-mod-n-match", + str(n_match), + "--spec-ngram-mod-n-min", + str(n_min), + "--spec-ngram-mod-n-max", + str(n_max), + ] + if flavor == "legacy": + # Legacy llama.cpp before the spec arg rename: same knobs lived + # under --spec-ngram-size-n (lookup length) and the generic + # --draft-min / --draft-max (ngram size N range). + return [ + "--spec-ngram-size-n", + str(n_match), + "--draft-min", + str(n_min), + "--draft-max", + str(n_max), + ] + return [] + + +# Canonical Speculative Decoding modes exposed by the Studio chat UI. +# The dropdown renders five options (auto, mtp, ngram, mtp+ngram, off); +# the load API also accepts legacy values that the original Switch and +# external callers emit (default, draft-mtp, ngram-mod, ngram-simple). 
+_CANONICAL_SPEC_MODES = {"auto", "mtp", "ngram", "mtp+ngram", "off", "ngram-simple"} +_LEGACY_SPEC_MODE_MAP = { + "default": "auto", + "draft-mtp": "mtp", + "ngram-mod": "ngram", +} + + +def _canonicalize_spec_mode(value): + """Map any accepted ``speculative_type`` input onto a canonical mode. + + Returns one of ``auto``, ``mtp``, ``ngram``, ``mtp+ngram``, ``off``, + ``ngram-simple``, or ``None`` (callers treat ``None`` as ``auto``). + Unknown strings collapse to ``auto`` so a stale UI value or typo + falls back to the safe platform-aware path. + """ + if value is None: + return None + if not isinstance(value, str): + return None + stripped = value.strip().lower() + if not stripped: + return None + if stripped in _CANONICAL_SPEC_MODES: + return stripped + if stripped in _LEGACY_SPEC_MODE_MAP: + return _LEGACY_SPEC_MODE_MAP[stripped] + # llama.cpp comma-chains are emitted by old persisted state e.g. + # "ngram-mod,draft-mtp"; collapse the most common one explicitly. + pieces = [p.strip() for p in stripped.split(",") if p.strip()] + has_mtp = any(p in ("mtp", "draft-mtp") for p in pieces) + has_ngram = any(p in ("ngram", "ngram-mod") for p in pieces) + if has_mtp and has_ngram: + return "mtp+ngram" + if has_mtp: + return "mtp" + if has_ngram: + return "ngram" + return "auto" + + +def _backfill_usage_from_timings(usage, timings): + """Synthesize ``usage`` from llama-server's ``timings`` when the + OpenAI-style usage block is missing or reports zero tokens. + + The Studio chat UI computes generation t/s from + ``meta.usage.completion_tokens / totalStreamTime``. llama-server + always populates ``timings.predicted_n`` (true decoded count) and + ``timings.prompt_n``, but the ``usage`` field on the final SSE chunk + can be absent or zero on some server builds / streaming + configurations, which makes the UI fall back to wall-clock t/s and + dilute speculative-decoding speedups. 
+ """ + if not timings: + return usage + if usage and usage.get("completion_tokens"): + return usage + predicted_n = timings.get("predicted_n") + prompt_n = timings.get("prompt_n") + if predicted_n is None and prompt_n is None: + return usage + out = dict(usage or {}) + if not out.get("completion_tokens") and predicted_n is not None: + out["completion_tokens"] = predicted_n + if not out.get("prompt_tokens") and prompt_n is not None: + out["prompt_tokens"] = prompt_n + out["total_tokens"] = int(out.get("prompt_tokens") or 0) + int( + out.get("completion_tokens") or 0 + ) + return out + + class LlamaCppBackend: """ Manages a llama-server subprocess for GGUF model inference. @@ -520,6 +636,15 @@ def __init__(self): self._cache_type_kv: Optional[str] = None self._reasoning_default: bool = True self._speculative_type: Optional[str] = None + # Canonical UI-facing mode the user requested: one of + # ``auto``/``mtp``/``ngram``/``mtp+ngram``/``off``/``ngram-simple``. + # Round-tripped through the status API so the dropdown reflects + # the picked mode rather than the resolved internal flag set + # (auto on a 27B MTP GGUF resolves to draft-mtp but the dropdown + # should still read "Auto"). + self._requested_spec_mode: Optional[str] = None + # User-supplied --spec-draft-n-max override (None = platform default). 
+ self._spec_draft_n_max: Optional[int] = None # KV-cache estimation fields (populated by _read_gguf_metadata) self._n_layers: Optional[int] = None self._n_kv_heads: Optional[int] = None @@ -800,6 +925,17 @@ def cache_type_kv(self) -> Optional[str]: def speculative_type(self) -> Optional[str]: return self._speculative_type + @property + def requested_spec_mode(self) -> Optional[str]: + """Canonical UI-facing mode the user requested (see field doc).""" + return self._requested_spec_mode + + @property + def spec_draft_n_max(self) -> Optional[int]: + """User --spec-draft-n-max override active on the load, or None + when the platform default (2 GPU / 3 CPU) is in effect.""" + return self._spec_draft_n_max + # ── Binary discovery ────────────────────────────────────────── @staticmethod @@ -926,11 +1062,31 @@ def probe_server_capabilities( cls, binary: Optional[str] = None ) -> dict[str, object]: """Parse `llama-server --help` for feature flags. Returns - {found, mtp_token, supports_mtp}. mtp_token is "draft-mtp" - (older) or "mtp" (renamed upstream), or None.""" + {found, mtp_token, supports_mtp, ngram_mod_flavor, + supports_ngram_mod, spec_draft_n_max_flag}. + + ``ngram_mod_flavor`` is ``"new"`` when the binary exposes the + post-rename ``--spec-ngram-mod-n-match / -n-min / -n-max`` as + real args, ``"legacy"`` when only the pre-rename + ``--spec-ngram-size-n / --draft-min / --draft-max`` are real + (the rename ships with stub removal entries for the legacy + names; we tell stubs apart by the "argument has been removed" + description), or ``None`` if neither set is usable. + + ``spec_draft_n_max_flag`` is the actual flag name the binary + accepts: ``--spec-draft-n-max`` on post-rename builds, or + ``--draft-max`` on legacy. ``None`` means n_max cannot be set. 
+ """ bin_path = binary or cls._find_llama_server_binary() if not bin_path or not Path(bin_path).is_file(): - return {"found": False, "mtp_token": None, "supports_mtp": False} + return { + "found": False, + "mtp_token": None, + "supports_mtp": False, + "ngram_mod_flavor": None, + "supports_ngram_mod": False, + "spec_draft_n_max_flag": None, + } try: mtime = int(Path(bin_path).stat().st_mtime) except OSError: @@ -941,6 +1097,8 @@ def probe_server_capabilities( return cached mtp_token: Optional[str] = None + ngram_mod_flavor: Optional[str] = None + spec_draft_n_max_flag: Optional[str] = None try: result = subprocess.run( [bin_path, "--help"], @@ -950,6 +1108,52 @@ def probe_server_capabilities( check = False, ) help_text = (result.stdout or "") + "\n" + (result.stderr or "") + # Split into per-flag blocks: each --flag line plus its + # indented continuation lines, so the "argument has been + # removed" description sits with its flag. + blocks: dict[str, str] = {} + current_flags: list[str] = [] + current_desc: list[str] = [] + for line in help_text.splitlines(): + stripped = line.strip() + if stripped.startswith("-") and not line.startswith(" "): + # New flag line; flush previous. + if current_flags: + desc = " ".join(current_desc) + for f in current_flags: + blocks[f] = desc + current_flags = [] + current_desc = [stripped] + # Extract long-form flag tokens from the DECLARATION + # prefix only (comma-separated aliases). Stop at the + # first token that isn't itself a flag, so flag + # references inside descriptions are ignored. + for tok in re.split(r"[,\s]+", stripped): + if tok.startswith("--") and re.match( + r"--[A-Za-z][A-Za-z0-9_-]*$", tok + ): + current_flags.append(tok) + elif tok.startswith("-") and len(tok) > 1: + # short alias like -fa; keep scanning aliases. + continue + else: + # First non-flag token marks end of decl. 
+ break + else: + current_desc.append(stripped) + if current_flags: + desc = " ".join(current_desc) + for f in current_flags: + blocks[f] = desc + + def _is_real(flag: str) -> bool: + """True if the flag exists AND is not a removal stub.""" + desc = blocks.get(flag) + if desc is None: + return False + return "argument has been removed" not in desc + + # MTP token detection from --spec-type line. spec_line = "" for line in help_text.splitlines(): if "--spec-type" in line: @@ -960,6 +1164,30 @@ def probe_server_capabilities( mtp_token = "draft-mtp" elif re.search(r"[|,\[]mtp[|,\]]", spec_line): mtp_token = "mtp" + + # ngram-mod flag flavor. Post-rename builds advertise both + # the new args (real) and the legacy ones (stubs); pre-rename + # builds only have the legacy ones as real. + new_ngram_real = ( + _is_real("--spec-ngram-mod-n-match") + and _is_real("--spec-ngram-mod-n-min") + and _is_real("--spec-ngram-mod-n-max") + ) + legacy_ngram_real = ( + _is_real("--spec-ngram-size-n") + and _is_real("--draft-max") + and _is_real("--draft-min") + ) + if new_ngram_real: + ngram_mod_flavor = "new" + elif legacy_ngram_real: + ngram_mod_flavor = "legacy" + + # n_max flag: prefer post-rename, fall back to legacy. 
+ if _is_real("--spec-draft-n-max"): + spec_draft_n_max_flag = "--spec-draft-n-max" + elif _is_real("--draft-max"): + spec_draft_n_max_flag = "--draft-max" except (OSError, subprocess.SubprocessError) as exc: logger.debug(f"llama-server --help probe failed: {exc}") @@ -967,6 +1195,9 @@ def probe_server_capabilities( "found": True, "mtp_token": mtp_token, "supports_mtp": mtp_token is not None, + "ngram_mod_flavor": ngram_mod_flavor, + "supports_ngram_mod": ngram_mod_flavor is not None, + "spec_draft_n_max_flag": spec_draft_n_max_flag, } cls._capability_cache[cache_key] = info return info @@ -2265,6 +2496,7 @@ def load_model( chat_template_override: Optional[str] = None, cache_type_kv: Optional[str] = None, speculative_type: Optional[str] = None, + spec_draft_n_max: Optional[int] = None, n_threads: Optional[int] = None, n_gpu_layers: Optional[int] = None, # Accepted for caller compat, unused n_parallel: int = 1, @@ -2296,6 +2528,7 @@ def load_model( n_ctx = n_ctx, cache_type_kv = cache_type_kv, speculative_type = speculative_type, + spec_draft_n_max = spec_draft_n_max, chat_template_override = chat_template_override, extra_args = extra_args, is_vision = is_vision, @@ -2643,98 +2876,16 @@ def load_model( # (llama.cpp #22673). Auto-enabled via nextn_predict_layers, # fallback to -MTP in name. GPU: MTP-only. CPU/Mac: chain # with ngram-mod. See unsloth.ai/docs/models/qwen3.6#mtp-guide. 
- _valid_spec_types = {"ngram-simple", "ngram-mod", "draft-mtp"} - normalized_spec = ( - speculative_type.lower().strip() if speculative_type else None - ) - is_mtp_model = bool(self._nextn_predict_layers) or ( - _is_mtp_model_name(model_identifier, model_path) + spec_flags = self._build_speculative_flags( + speculative_type = speculative_type, + spec_draft_n_max = spec_draft_n_max, + extra_args = extra_args, + model_identifier = model_identifier, + model_path = model_path, + gpus = bool(gpus), + binary = binary, ) - user_owns_spec_type = _extra_args_set_spec_type(extra_args) - # Auto-promote unset/"default" to draft-mtp on MTP GGUFs. - # llama.cpp #22673: MTP is compatible with mmproj, so the - # vision gate previously here was wrong. - if ( - is_mtp_model - and not user_owns_spec_type - and normalized_spec in (None, "", "default") - ): - normalized_spec = "draft-mtp" - if user_owns_spec_type: - # User --spec-type wins (it accumulates if repeated). - normalized_spec = None - self._speculative_type = None - if normalized_spec and normalized_spec != "off": - if normalized_spec == "default": - cmd.append("--spec-default") - self._speculative_type = "default" - elif normalized_spec == "draft-mtp": - # Probe binary; fail gracefully on outdated prebuilts. - # Use whichever token the binary advertises - # (older: draft-mtp; renamed upstream: mtp). - caps = self.probe_server_capabilities(binary) - mtp_token = caps.get("mtp_token") if caps else None - if not mtp_token: - logger.warning( - "MTP GGUF detected but llama-server lacks " - "--spec-type mtp/draft-mtp; run " - "`unsloth studio update`. Loading without " - "speculative decoding." - ) - self._speculative_type = None - else: - if gpus: - cmd.extend( - [ - "--spec-type", - mtp_token, - "--spec-draft-n-max", - "6", - ] - ) - else: - # CPU/Mac: chain ngram-mod + MTP in one - # comma-separated --spec-type (not repeated). - # ngram-mod knobs match llama.cpp defaults - # (n-match 24, n-min 48, n-max 64). 
def _build_speculative_flags(
    self,
    *,
    speculative_type: Optional[str],
    spec_draft_n_max: Optional[int],
    extra_args: Optional[List[str]],
    model_identifier: str,
    model_path: Optional[str],
    gpus: bool,
    binary: Optional[str],
) -> List[str]:
    """Return the llama-server flag list for the requested spec mode.

    Args:
        speculative_type: Raw mode string from the caller; canonicalised
            with ``_canonicalize_spec_mode`` (``None`` means ``auto``).
        spec_draft_n_max: User override for the MTP draft length, or
            ``None`` to use the platform default (2 on GPU, 3 otherwise).
        extra_args: User-supplied llama-server args. If they already
            carry ``--spec-type`` nothing is emitted here.
        model_identifier: Model name, used for MTP-name and size detection.
        model_path: Local GGUF path, used for MTP-name detection.
        gpus: True when the load targets at least one GPU.
        binary: llama-server binary to probe for capabilities.

    Returns:
        The flags to append to the llama-server command line (possibly
        empty).

    Side effects: sets ``self._speculative_type`` (resolved internal
    emit), ``self._requested_spec_mode`` (canonical UI mode for the
    status round-trip; ``None`` when the user owns ``--spec-type``), and
    ``self._spec_draft_n_max`` (user override only, and only when an MTP
    variant is emitted; None when the platform default applies).

    Speculative decoding (n-gram self-speculation, zero VRAM cost):
    ngram-mod uses a ~16 MB shared hash pool, constant memory /
    complexity, variable draft lengths. Helps most when the model
    repeats existing text (code refactor, summarisation, reasoning).
    For general chat with low repetition, overhead is ~5 ms.

    Benchmarks from upstream llama.cpp speculative-decoding PRs:
        Scenario                        | Without | With    | Speedup
        gpt-oss-120b code refactor      | 181 t/s | 446 t/s | 2.5x
        Qwen3-235B offloaded            | 12 t/s  | 21 t/s  | 1.8x
        gpt-oss-120b repeat (92% accept)| 181 t/s | 814 t/s | 4.5x

    Sub-3B dense MTP regresses vs spec-off because the draft head's
    per-token cost exceeds the acceptance savings at this scale.
    Q4_K_XL clean bench (each prompt once after an unrelated warmup)
    on B200 + x86 CPU:
        0.8B GPU: draft-mtp n=2 = 0.58x vs OFF; ngram-only = 1.10x
        2B   GPU: draft-mtp n=2 = 0.82x vs OFF; OFF or ngram = 1.00x
        0.8B CPU: chained n=2 = 0.86x vs OFF; ngram-only = 1.19x
        2B   CPU: chained n=2 = 0.83x vs OFF; ngram-only = 1.01x
        4B+ GPU/CPU: spec on is a net win (1.08x-1.46x).
    Auto falls back to ngram-mod (zero-VRAM, near-zero idle cost on
    diverse content); forced MTP variants engage anyway and just log
    a warning per the user's choice.
    """
    flags: List[str] = []
    # Reset; the emit helpers below re-set these for the resolved emission.
    self._spec_draft_n_max = None
    self._speculative_type = None

    # Canonical UI-facing requested mode: auto / mtp / ngram /
    # mtp+ngram / off / ngram-simple. Legacy values are mapped via
    # _canonicalize_spec_mode (default->auto, draft-mtp->mtp,
    # ngram-mod->ngram, "ngram-mod,draft-mtp"->mtp+ngram).
    canonical_mode = _canonicalize_spec_mode(speculative_type)
    is_mtp_model = bool(self._nextn_predict_layers) or (
        _is_mtp_model_name(model_identifier, model_path)
    )
    user_owns_spec_type = _extra_args_set_spec_type(extra_args)
    _mtp_size_b = _extract_model_size_b(model_identifier)
    _mtp_too_small = _mtp_size_b is not None and _mtp_size_b < 3.0

    if user_owns_spec_type:
        # User --spec-type in extra_args wins outright; suppress
        # auto-emit so we don't emit a duplicate / conflicting
        # spec block. Record requested mode as None.
        self._requested_spec_mode = None
        return flags

    effective_mode = canonical_mode or "auto"
    self._requested_spec_mode = effective_mode

    def _resolved_draft_n_max() -> int:
        # User override wins; else platform default (the B200 / x86
        # clean-sweep sweet spot from PR #5582 is n=2 GPU, n=3 CPU;
        # raising past 3 starts to regress on essay-style
        # low-acceptance prompts).
        if spec_draft_n_max is not None:
            n = int(spec_draft_n_max)
            self._spec_draft_n_max = n
            return n
        return 2 if gpus else 3

    def _emit_mtp(*, chain_ngram: bool) -> bool:
        """Append --spec-type mtp[/draft-mtp][,ngram-mod] + n-max."""
        caps = self.probe_server_capabilities(binary)
        mtp_token = caps.get("mtp_token") if caps else None
        if not mtp_token:
            logger.warning(
                "Requested MTP speculative decoding but "
                "llama-server lacks --spec-type mtp/draft-mtp; "
                "run `unsloth studio update`. Loading without "
                "speculative decoding."
            )
            return False
        draft_n_max = _resolved_draft_n_max()
        n_max_flag = caps.get("spec_draft_n_max_flag") or "--spec-draft-n-max"

        ngram_knobs: List[str] = []
        if chain_ngram:
            ngram_knobs = _build_ngram_mod_flags(caps)
            if not ngram_knobs:
                logger.warning(
                    "llama-server lacks ngram-mod tuning "
                    "flags; loading MTP only (no ngram chain)"
                )
        # The chain is only real when the knobs could be emitted; the
        # log label below must describe what was actually emitted, not
        # what was requested.
        chained = bool(ngram_knobs)
        spec_value = f"ngram-mod,{mtp_token}" if chained else mtp_token
        # NOTE(review): when the probe reports the "legacy" flavor,
        # n_max_flag can be --draft-max and the legacy ngram knobs also
        # carry --draft-max, so the flag would appear twice with
        # different values — confirm which one llama-server honours.
        flags.extend(["--spec-type", spec_value, n_max_flag, str(draft_n_max)])
        flags.extend(ngram_knobs)
        self._speculative_type = "draft-mtp"
        chain_label = "chained ngram-mod" if chained else "MTP-only"
        logger.info(f"Spec decoding: {mtp_token} ({chain_label})")
        return True

    def _emit_ngram_mod() -> bool:
        """Append --spec-type ngram-mod + flag-set knobs."""
        ngram_caps = self.probe_server_capabilities(binary)
        ngram_knobs = _build_ngram_mod_flags(ngram_caps)
        flags.extend(["--spec-type", "ngram-mod"])
        if not ngram_knobs:
            logger.warning(
                "llama-server lacks ngram-mod tuning "
                "flags; loading without --spec-ngram-mod-* knobs"
            )
        flags.extend(ngram_knobs)
        self._speculative_type = "ngram-mod"
        logger.info("Spec decoding: ngram-mod")
        return True

    if effective_mode == "off":
        return flags  # nothing to emit
    if effective_mode == "ngram-simple":
        flags.extend(["--spec-type", "ngram-simple"])
        self._speculative_type = "ngram-simple"
        return flags
    if effective_mode == "ngram":
        _emit_ngram_mod()
        return flags
    if effective_mode == "mtp":
        if _mtp_too_small:
            logger.warning(
                f"Forcing MTP on a {_mtp_size_b:.1f}B model; "
                "the bench shows draft-mtp regresses below 3B. "
                "Engaging anyway (user override)."
            )
        elif not is_mtp_model:
            logger.warning(
                "Forcing MTP on a non-MTP GGUF; llama-server may "
                "fall back to spec-off if no nextn head is present. "
                "Engaging anyway (user override)."
            )
        _emit_mtp(chain_ngram = False)
        return flags
    if effective_mode == "mtp+ngram":
        if _mtp_too_small:
            logger.warning(
                f"Forcing MTP+Ngram on a {_mtp_size_b:.1f}B model; "
                "the bench shows the chain regresses below 3B. "
                "Engaging anyway (user override)."
            )
        elif not is_mtp_model:
            logger.warning(
                "Forcing MTP+Ngram on a non-MTP GGUF; llama-server "
                "may fall back to ngram-only if no nextn head is "
                "present. Engaging anyway (user override)."
            )
        _emit_mtp(chain_ngram = True)
        return flags

    # effective_mode == "auto": today's promotion path. llama.cpp
    # #22673: MTP is compatible with mmproj, so there's no vision gate.
    if is_mtp_model and not _mtp_too_small:
        # GPU: MTP-only. CPU/Mac: chain ngram-mod + MTP.
        _emit_mtp(chain_ngram = not gpus)
    elif is_mtp_model and _mtp_too_small:
        # Sub-3B fallback: drop the MTP draft head, keep ngram-mod
        # when the binary supports it.
        _small_caps = self.probe_server_capabilities(binary)
        if _small_caps.get("supports_ngram_mod"):
            logger.info(
                f"MTP GGUF detected but model size {_mtp_size_b:.1f}B "
                "is below the 3B speedup threshold; using ngram-mod "
                "only (zero-VRAM, no draft head). Override via "
                "--spec-type or the Studio Speculative Decoding "
                "dropdown."
            )
            _emit_ngram_mod()
        else:
            logger.info(
                f"MTP GGUF detected but model size {_mtp_size_b:.1f}B "
                "is below the 3B speedup threshold and the bundled "
                "llama-server does not advertise ngram-mod; "
                "auto-disabling speculative decoding."
            )
    else:
        # Non-MTP model: let llama-server choose its default strategy.
        flags.append("--spec-default")
        self._speculative_type = "default"
    return flags
Compare on the resolved spec rather than the requested + # mode so an Auto request that auto-promoted to draft-mtp under + # the hood still bounces a reload when the user changes n_max. if ( - raw_spec in (None, "default") - and _is_mtp_model_name(model_identifier, gguf_path) - and not _extra_args_set_spec_type(extra_args) + self._speculative_type == "draft-mtp" + and spec_draft_n_max is not None + and int(spec_draft_n_max) != (self._spec_draft_n_max or 0) ): - req_spec = "draft-mtp" - backend_spec = _norm(self._speculative_type) or "off" - if req_spec != backend_spec: return False if (self._chat_template_override or None) != (chat_template_override or None): @@ -3194,6 +3570,8 @@ def unload_model(self) -> bool: self._supports_tools = False self._cache_type_kv = None self._speculative_type = None + self._requested_spec_mode = None + self._spec_draft_n_max = None self._n_layers = None self._n_kv_heads = None self._n_kv_heads_by_layer = None @@ -3935,6 +4313,9 @@ def generate_chat_completion( if _stream_done: break # exit outer for if _metadata_usage or _metadata_timings: + _metadata_usage = _backfill_usage_from_timings( + _metadata_usage, _metadata_timings + ) yield { "type": "metadata", "usage": _metadata_usage, @@ -4387,7 +4768,10 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: } ) # Accumulate tokens and timing from this iteration - _fu_r = _iter_usage or {} + _fu_r = ( + _backfill_usage_from_timings(_iter_usage, _iter_timings) + or {} + ) _accumulated_completion_tokens += _fu_r.get( "completion_tokens", 0 ) @@ -4399,7 +4783,10 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: # Content was already streamed. Yield metadata. 
yield {"type": "status", "text": ""} - _fu = _iter_usage or {} + _fu = ( + _backfill_usage_from_timings(_iter_usage, _iter_timings) + or {} + ) _fc = _fu.get("completion_tokens", 0) _fp = _fu.get("prompt_tokens", 0) _tc = _fc + _accumulated_completion_tokens @@ -4489,7 +4876,10 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: ) if content_accum: yield {"type": "content", "text": content_accum} - _fu = _iter_usage or {} + _fu = ( + _backfill_usage_from_timings(_iter_usage, _iter_timings) + or {} + ) _fc = _fu.get("completion_tokens", 0) _fp = _fu.get("prompt_tokens", 0) _tc = _fc + _accumulated_completion_tokens @@ -4523,9 +4913,9 @@ def _strip_tool_markup(text: str, *, final: bool = False) -> str: return # ── Execute tool calls ── - _accumulated_completion_tokens += (_iter_usage or {}).get( - "completion_tokens", 0 - ) + _accumulated_completion_tokens += ( + _backfill_usage_from_timings(_iter_usage, _iter_timings) or {} + ).get("completion_tokens", 0) _it = _iter_timings or {} _accumulated_predicted_ms += _it.get("predicted_ms", 0) _accumulated_predicted_n += _it.get("predicted_n", 0) diff --git a/studio/backend/core/inference/llama_server_args.py b/studio/backend/core/inference/llama_server_args.py index 572ac2ceda..d8b7eb383e 100644 --- a/studio/backend/core/inference/llama_server_args.py +++ b/studio/backend/core/inference/llama_server_args.py @@ -148,6 +148,8 @@ def is_managed_flag(flag: str) -> bool: # MTP path (llama.cpp #22673). 
"--spec-draft-n-max", "--spec-draft-n-min", + "--spec-draft-p-min", + "--spec-draft-p-split", "--spec-ngram-mod-n-match", "--spec-ngram-mod-n-min", "--spec-ngram-mod-n-max", diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index 99d1df37b6..e32d134628 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -70,7 +70,28 @@ def normalize_blank_chat_template_override( ) speculative_type: Optional[str] = Field( None, - description = "Speculative decoding mode for GGUF models (e.g. 'ngram-simple', 'ngram-mod'). Ignored for non-GGUF and vision models.", + description = ( + "Speculative decoding mode for GGUF models. Canonical values: " + "'auto' (platform-aware: MTP on MTP GGUFs, ngram-mod fallback " + "for sub-3B), 'mtp' (force draft-mtp only on both GPU and CPU), " + "'ngram' (force ngram-mod only), 'mtp+ngram' (force " + "ngram-mod+draft-mtp chain on both platforms), 'off' (disabled). " + "Legacy values 'default' (-> auto), 'draft-mtp' (-> mtp), " + "'ngram-mod' (-> ngram), and 'ngram-simple' (kept as-is) are " + "still accepted. Ignored for non-GGUF and vision models." + ), + ) + spec_draft_n_max: Optional[int] = Field( + None, + ge = 1, + le = 16, + description = ( + "Max draft tokens per step for MTP speculative decoding " + "(--spec-draft-n-max). Defaults to 2 on GPU and 3 on CPU/Mac " + "when unset (upstream-bench sweet spot for dense Qwen3.6 MTP " + "quants). Only applied when speculative_type resolves to " + "'mtp' or 'mtp+ngram'." + ), ) llama_extra_args: Optional[List[str]] = Field( None, @@ -218,7 +239,19 @@ class LoadResponse(BaseModel): ) speculative_type: Optional[str] = Field( None, - description = "Active speculative decoding mode (e.g. 
'ngram-simple', 'ngram-mod'), or None if disabled", + description = ( + "Canonical UI-facing requested speculative decoding mode " + "('auto' / 'mtp' / 'ngram' / 'mtp+ngram' / 'off' / " + "'ngram-simple'), round-tripped from the original LoadRequest " + "via _canonicalize_spec_mode. None when no model is loaded." + ), + ) + spec_draft_n_max: Optional[int] = Field( + None, + description = ( + "Active --spec-draft-n-max for MTP speculative decoding, or " + "None when the platform default is in effect." + ), ) @@ -340,7 +373,19 @@ class InferenceStatusResponse(BaseModel): ) speculative_type: Optional[str] = Field( None, - description = "Active speculative decoding mode (e.g. 'ngram-simple', 'ngram-mod'), or None if disabled", + description = ( + "Canonical UI-facing requested speculative decoding mode " + "('auto' / 'mtp' / 'ngram' / 'mtp+ngram' / 'off' / " + "'ngram-simple'), round-tripped from the original LoadRequest. " + "None when no model is loaded." + ), + ) + spec_draft_n_max: Optional[int] = Field( + None, + description = ( + "Active --spec-draft-n-max for MTP speculative decoding, or " + "None when the platform default is in effect." 
+ ), ) llama_cpp_supports_mtp: bool = Field( True, diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index 607245467c..8bb2dc6d80 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -117,6 +117,7 @@ def _friendly_error(exc: Exception) -> str: LlamaCppBackend, _DEFAULT_MAX_TOKENS_FLOOR, _DEFAULT_T_MAX_PREDICT_MS, + _canonicalize_spec_mode, _hf_offline_if_dns_dead, detect_reasoning_flags, ) @@ -143,6 +144,7 @@ def _friendly_error(exc: Exception) -> str: LlamaCppBackend, _DEFAULT_MAX_TOKENS_FLOOR, _DEFAULT_T_MAX_PREDICT_MS, + _canonicalize_spec_mode, _hf_offline_if_dns_dead, detect_reasoning_flags, ) @@ -441,12 +443,17 @@ def _request_matches_loaded_settings( # spec on ``not is_vision``), so treat the request as ``off`` against # the backend's ``None`` to avoid forcing a redundant reload. if llama_backend.is_vision: - req_spec = "off" + req_mode = "off" else: - req_spec = _normalise_settings_str(request.speculative_type) or "off" - backend_spec = _normalise_settings_str(llama_backend.speculative_type) or "off" - if req_spec != backend_spec: + req_mode = _canonicalize_spec_mode(request.speculative_type) or "auto" + backend_mode = llama_backend.requested_spec_mode or "auto" + if req_mode != backend_mode: return False + # spec_draft_n_max only matters when an MTP variant is engaged; None + # means "platform default" and matches whatever the backend chose. 
+ if backend_mode in ("mtp", "mtp+ngram") and request.spec_draft_n_max is not None: + if int(request.spec_draft_n_max) != (llama_backend.spec_draft_n_max or 0): + return False if (request.chat_template_override or None) != ( llama_backend.chat_template_override or None ): @@ -584,7 +591,8 @@ async def load_model( reasoning_always_on = llama_backend.reasoning_always_on, supports_preserve_thinking = llama_backend.supports_preserve_thinking, chat_template = llama_backend.chat_template, - speculative_type = llama_backend.speculative_type, + speculative_type = llama_backend.requested_spec_mode, + spec_draft_n_max = llama_backend.spec_draft_n_max, ) else: if ( @@ -724,7 +732,10 @@ async def load_model( llama_backend.extra_args, strip_context = "max_seq_length" in fields_set, strip_cache = "cache_type_kv" in fields_set, - strip_spec = "speculative_type" in fields_set, + strip_spec = ( + "speculative_type" in fields_set + or "spec_draft_n_max" in fields_set + ), strip_template = "chat_template_override" in fields_set, ) try: @@ -765,6 +776,7 @@ async def load_model( chat_template_override = request.chat_template_override, cache_type_kv = request.cache_type_kv, speculative_type = request.speculative_type, + spec_draft_n_max = request.spec_draft_n_max, n_parallel = _n_parallel, extra_args = extra_llama_args, ) @@ -788,6 +800,7 @@ async def load_model( chat_template_override = request.chat_template_override, cache_type_kv = request.cache_type_kv, speculative_type = request.speculative_type, + spec_draft_n_max = request.spec_draft_n_max, n_parallel = _n_parallel, extra_args = extra_llama_args, ) @@ -846,7 +859,8 @@ async def load_model( supports_tools = llama_backend.supports_tools, cache_type_kv = llama_backend.cache_type_kv, chat_template = llama_backend.chat_template, - speculative_type = llama_backend.speculative_type, + speculative_type = llama_backend.requested_spec_mode, + spec_draft_n_max = llama_backend.spec_draft_n_max, ) # ── Standard path: load via 
Unsloth/transformers ────────── @@ -1345,7 +1359,8 @@ async def get_status( native_context_length = llama_backend.native_context_length, cache_type_kv = llama_backend.cache_type_kv, chat_template_override = llama_backend.chat_template_override, - speculative_type = llama_backend.speculative_type, + speculative_type = llama_backend.requested_spec_mode, + spec_draft_n_max = llama_backend.spec_draft_n_max, llama_cpp_supports_mtp = _supports_mtp, llama_cpp_prebuilt_stale = _stale, llama_cpp_installed_tag = _installed_tag, diff --git a/studio/backend/tests/test_gguf_reload_inheritance.py b/studio/backend/tests/test_gguf_reload_inheritance.py index 4b0b450cb0..1a663725d9 100644 --- a/studio/backend/tests/test_gguf_reload_inheritance.py +++ b/studio/backend/tests/test_gguf_reload_inheritance.py @@ -78,6 +78,7 @@ def _loaded_backend(**overrides): backend._requested_n_ctx = 8192 backend._cache_type_kv = None backend._speculative_type = None + backend._requested_spec_mode = "auto" backend._chat_template_override = None backend._is_vision = False backend._extra_args = None diff --git a/studio/backend/tests/test_llama_cpp_mtp_detection.py b/studio/backend/tests/test_llama_cpp_mtp_detection.py index 7da633201f..4a8276adc0 100644 --- a/studio/backend/tests/test_llama_cpp_mtp_detection.py +++ b/studio/backend/tests/test_llama_cpp_mtp_detection.py @@ -52,6 +52,9 @@ from core.inference.llama_cpp import ( LlamaCppBackend, + _backfill_usage_from_timings, + _build_ngram_mod_flags, + _canonicalize_spec_mode, _extra_args_set_spec_type, _is_mtp_model_name, ) @@ -186,6 +189,10 @@ def _mtp_backend(**overrides): backend._requested_n_ctx = 8192 backend._cache_type_kv = None backend._speculative_type = "draft-mtp" + # Default fixture simulates Auto having auto-promoted to draft-mtp. + # Individual tests override _requested_spec_mode when they want a + # forced mode or the user-supplied --spec-type (extra_args) path.
+ backend._requested_spec_mode = "auto" backend._chat_template_override = None backend._is_vision = False backend._extra_args = None @@ -233,9 +240,16 @@ def test_already_in_target_state_matches_when_request_uses_default_for_mtp_model ) -def test_already_in_target_state_non_mtp_model_unaffected(): - # Promotion is gated on the name; non-MTP must still mismatch req=None. - backend = _mtp_backend(_model_identifier = "unsloth/Qwen3.6-27B-GGUF") +def test_already_in_target_state_auto_request_matches_auto_backend_for_non_mtp_model(): + # Under the requested-mode round-trip model, Auto requested against an + # Auto-recorded backend matches regardless of model name. The underlying + # resolved emission (--spec-default vs draft-mtp) is handled by the + # backend's own load path and reflected in _speculative_type; the + # short-circuit comparison only cares whether the *intent* changed. + backend = _mtp_backend( + _model_identifier = "unsloth/Qwen3.6-27B-GGUF", + _speculative_type = "default", + ) assert ( backend._already_in_target_state( gguf_path = None, @@ -248,7 +262,7 @@ def test_already_in_target_state_non_mtp_model_unaffected(): extra_args = None, is_vision = False, ) - is False + is True ) @@ -308,6 +322,7 @@ def test_already_in_target_state_user_spec_type_override_matches_clean_backend() # User --spec-type none suppressed auto-MTP; repeat /load must not re-promote. backend = _mtp_backend( _speculative_type = None, + _requested_spec_mode = None, _extra_args = ["--spec-type", "none"], ) assert ( @@ -389,12 +404,17 @@ def test_already_in_target_state_vision_mtp_default_matches(): ) -def test_already_in_target_state_vision_non_mtp_unaffected(): - # Vision non-MTP repo (no -MTP marker) must still mismatch req=None - # against a backend running draft-mtp. +def test_already_in_target_state_vision_off_matches_vision_backend(): + # Vision loads silently drop speculative decoding at the route level + # (_request_matches_loaded_settings overrides req to "off"). 
At the + # llama_cpp.py level, _already_in_target_state compares canonical + # requested modes; a vision backend recorded with _requested_spec_mode + # = "off" matches a req of "off" or None+vision. backend = _mtp_backend( _model_identifier = "unsloth/Qwen3-VL-4B-Instruct-GGUF", _is_vision = True, + _speculative_type = None, + _requested_spec_mode = "off", ) assert ( backend._already_in_target_state( @@ -403,12 +423,12 @@ def test_already_in_target_state_vision_non_mtp_unaffected(): hf_variant = "Q4_K_M", n_ctx = 8192, cache_type_kv = None, - speculative_type = None, + speculative_type = "off", chat_template_override = None, extra_args = None, is_vision = True, ) - is False + is True ) @@ -482,10 +502,17 @@ def _make_fake_llama_server(path: Path, help_text: str) -> Path: return path +_NEEDS_BASH = pytest.mark.skipif( + sys.platform == "win32", + reason = "fake llama-server is a bash stub; Windows has no direct executor", +) + + def _clear_caps_cache(): LlamaCppBackend._capability_cache.clear() +@_NEEDS_BASH def test_probe_server_capabilities_detects_draft_mtp(tmp_path): # Original naming from llama.cpp #22673. fake = _make_fake_llama_server( @@ -500,6 +527,7 @@ def test_probe_server_capabilities_detects_draft_mtp(tmp_path): assert caps["supports_mtp"] is True +@_NEEDS_BASH def test_probe_server_capabilities_detects_renamed_mtp(tmp_path): # Renamed upstream: draft-mtp -> mtp. fake = _make_fake_llama_server( @@ -513,6 +541,7 @@ def test_probe_server_capabilities_detects_renamed_mtp(tmp_path): assert caps["supports_mtp"] is True +@_NEEDS_BASH def test_probe_server_capabilities_reports_outdated_binary(tmp_path): # Pre-MTP llama.cpp: only ngram variants. fake = _make_fake_llama_server( @@ -533,6 +562,130 @@ def test_probe_server_capabilities_handles_missing_binary(): assert caps["supports_mtp"] is False +# ngram-mod flag flavor detection (new vs legacy llama-server). 
+ +# Help-text fixtures mirror the actual `llama-server --help` block +# layout (flag on its own line; description indented underneath). +_POST_RENAME_HELP = """\ +--spec-draft-n-max N number of tokens to draft for speculative decoding (default: 16) + (env: LLAMA_ARG_SPEC_DRAFT_N_MAX) +--spec-draft-n-min N minimum number of draft tokens to use for speculative decoding (default: 0) + (env: LLAMA_ARG_SPEC_DRAFT_N_MIN) +--spec-draft-p-min, --draft-p-min P minimum speculative decoding probability (greedy) (default: 0.75) + (env: LLAMA_ARG_SPEC_DRAFT_P_MIN) +--spec-ngram-mod-n-min N minimum number of ngram tokens (default: 48) +--spec-ngram-mod-n-max N maximum number of ngram tokens (default: 64) +--spec-ngram-mod-n-match N ngram-mod lookup length (default: 24) +--spec-type none,draft-simple,draft-mtp,ngram-mod comma-separated list of types of speculative decoding to use + (env: LLAMA_ARG_SPEC_TYPE) +--draft, --draft-n, --draft-max N the argument has been removed. use --spec-draft-n-max or --spec-ngram-mod-n-max + (env: LLAMA_ARG_DRAFT_MAX) +--draft-min, --draft-n-min N the argument has been removed. use --spec-draft-n-min or --spec-ngram-mod-n-min + (env: LLAMA_ARG_DRAFT_MIN) +--spec-ngram-size-n N the argument has been removed. 
use the respective --spec-ngram-*-size-n or --spec-ngram-mod-n-match +""" + +_LEGACY_HELP = """\ +--draft, --draft-n, --draft-max N number of tokens to draft for speculative decoding (default: 8) + (env: LLAMA_ARG_DRAFT_MAX) +--draft-min, --draft-n-min N minimum number of draft tokens to use for speculative decoding (default: 0) + (env: LLAMA_ARG_DRAFT_MIN) +--spec-ngram-size-n N ngram lookup length (default: 24) +--spec-type none,ngram-mod,ngram-simple comma-separated list of types of speculative decoding to use +""" + + +@_NEEDS_BASH +def test_probe_detects_post_rename_ngram_mod_flavor(tmp_path): + fake = _make_fake_llama_server(tmp_path / "llama-server", _POST_RENAME_HELP) + _clear_caps_cache() + caps = LlamaCppBackend.probe_server_capabilities(str(fake)) + assert caps["found"] is True + assert caps["ngram_mod_flavor"] == "new" + assert caps["supports_ngram_mod"] is True + assert caps["spec_draft_n_max_flag"] == "--spec-draft-n-max" + + +@_NEEDS_BASH +def test_probe_detects_legacy_ngram_mod_flavor(tmp_path): + fake = _make_fake_llama_server(tmp_path / "llama-server", _LEGACY_HELP) + _clear_caps_cache() + caps = LlamaCppBackend.probe_server_capabilities(str(fake)) + assert caps["found"] is True + assert caps["ngram_mod_flavor"] == "legacy" + assert caps["supports_ngram_mod"] is True + assert caps["spec_draft_n_max_flag"] == "--draft-max" + + +@_NEEDS_BASH +def test_probe_ignores_removal_stub_descriptions(tmp_path): + # Post-rename binary: legacy flags are present but with + # "argument has been removed" descriptions; must not be detected + # as legacy. + fake = _make_fake_llama_server(tmp_path / "llama-server", _POST_RENAME_HELP) + _clear_caps_cache() + caps = LlamaCppBackend.probe_server_capabilities(str(fake)) + assert caps["ngram_mod_flavor"] == "new" + + +@_NEEDS_BASH +def test_probe_no_ngram_mod_on_minimal_binary(tmp_path): + # Pre-anything: neither set present. 
+ fake = _make_fake_llama_server( + tmp_path / "llama-server", + "--spec-type none\n--threads N\n", + ) + _clear_caps_cache() + caps = LlamaCppBackend.probe_server_capabilities(str(fake)) + assert caps["ngram_mod_flavor"] is None + assert caps["supports_ngram_mod"] is False + + +def test_build_ngram_mod_flags_new(): + flags = _build_ngram_mod_flags({"ngram_mod_flavor": "new"}) + assert flags == [ + "--spec-ngram-mod-n-match", + "24", + "--spec-ngram-mod-n-min", + "48", + "--spec-ngram-mod-n-max", + "64", + ] + + +def test_build_ngram_mod_flags_legacy(): + flags = _build_ngram_mod_flags({"ngram_mod_flavor": "legacy"}) + assert flags == [ + "--spec-ngram-size-n", + "24", + "--draft-min", + "48", + "--draft-max", + "64", + ] + + +def test_build_ngram_mod_flags_empty_when_unsupported(): + assert _build_ngram_mod_flags({"ngram_mod_flavor": None}) == [] + assert _build_ngram_mod_flags(None) == [] + assert _build_ngram_mod_flags({}) == [] + + +def test_build_ngram_mod_flags_respects_custom_values(): + flags = _build_ngram_mod_flags( + {"ngram_mod_flavor": "new"}, n_match = 16, n_min = 24, n_max = 32 + ) + assert flags == [ + "--spec-ngram-mod-n-match", + "16", + "--spec-ngram-mod-n-min", + "24", + "--spec-ngram-mod-n-max", + "32", + ] + + +@_NEEDS_BASH def test_probe_server_capabilities_caches_by_mtime(tmp_path): # Same (path, mtime) -> cache hit. Bumped mtime -> re-probe. fake = _make_fake_llama_server( @@ -555,3 +708,493 @@ def test_probe_server_capabilities_caches_by_mtime(tmp_path): caps2 = LlamaCppBackend.probe_server_capabilities(str(fake)) assert caps2["mtp_token"] == "draft-mtp" assert caps2["supports_mtp"] is True + + +# spec_draft_n_max plumbing (first-class --spec-draft-n-max override). + + +def test_already_in_target_state_matches_when_draft_n_max_unset(): + # None on the request means "platform default"; matches any backend. 
+ backend = _mtp_backend(_spec_draft_n_max = None) + assert ( + backend._already_in_target_state( + gguf_path = None, + model_identifier = "unsloth/Qwen3.6-27B-MTP-GGUF", + hf_variant = "Q4_K_M", + n_ctx = 8192, + cache_type_kv = None, + speculative_type = None, + spec_draft_n_max = None, + chat_template_override = None, + extra_args = None, + is_vision = False, + ) + is True + ) + + +def test_already_in_target_state_matches_when_draft_n_max_equals_backend(): + backend = _mtp_backend(_spec_draft_n_max = 4) + assert ( + backend._already_in_target_state( + gguf_path = None, + model_identifier = "unsloth/Qwen3.6-27B-MTP-GGUF", + hf_variant = "Q4_K_M", + n_ctx = 8192, + cache_type_kv = None, + speculative_type = None, + spec_draft_n_max = 4, + chat_template_override = None, + extra_args = None, + is_vision = False, + ) + is True + ) + + +def test_already_in_target_state_mismatches_when_draft_n_max_differs(): + backend = _mtp_backend(_spec_draft_n_max = 4) + assert ( + backend._already_in_target_state( + gguf_path = None, + model_identifier = "unsloth/Qwen3.6-27B-MTP-GGUF", + hf_variant = "Q4_K_M", + n_ctx = 8192, + cache_type_kv = None, + speculative_type = None, + spec_draft_n_max = 8, + chat_template_override = None, + extra_args = None, + is_vision = False, + ) + is False + ) + + +def test_already_in_target_state_draft_n_max_ignored_when_not_mtp(): + # ngram-mod backend; spec_draft_n_max is MTP-only and must not force + # a reload against a non-MTP active spec. 
+ backend = _mtp_backend( + _speculative_type = "ngram-mod", + _requested_spec_mode = "ngram", + _spec_draft_n_max = None, + ) + assert ( + backend._already_in_target_state( + gguf_path = None, + model_identifier = "unsloth/Qwen3.6-27B-MTP-GGUF", + hf_variant = "Q4_K_M", + n_ctx = 8192, + cache_type_kv = None, + speculative_type = "ngram-mod", + spec_draft_n_max = 8, + chat_template_override = None, + extra_args = None, + is_vision = False, + ) + is True + ) + + +# Sub-3B MTP gate -- tiny dense models regress with the MTP draft +# head, so load_model falls back to ngram-mod (when the binary supports +# it) instead of draft-mtp. The reload-skip mirror must follow the +# same fallback so a sub-3B reload-with-default does not bounce a +# correctly-configured ngram-mod / off backend. + + +def _patch_probe(monkeypatch, ngram_supported): + """Force probe_server_capabilities to a deterministic result so + tests don't depend on whatever llama-server happens to be on PATH.""" + fake = { + "found": True, + "mtp_token": "draft-mtp", + "supports_mtp": True, + "ngram_mod_flavor": "new" if ngram_supported else None, + "supports_ngram_mod": bool(ngram_supported), + "spec_draft_n_max_flag": "--spec-draft-n-max", + } + monkeypatch.setattr( + LlamaCppBackend, + "probe_server_capabilities", + classmethod(lambda cls, binary = None: fake), + ) + monkeypatch.setattr( + LlamaCppBackend, + "_find_llama_server_binary", + classmethod(lambda cls: "/fake/llama-server"), + ) + + +def test_already_in_target_state_sub_3b_falls_back_to_ngram_mod_when_supported( + monkeypatch, +): + # 0.8B MTP request -- load_model would have promoted to ngram-mod + # (no MTP head); reload check must match a ngram-mod backend. 
+ _patch_probe(monkeypatch, ngram_supported = True) + backend = _mtp_backend( + _model_identifier = "unsloth/Qwen3.5-0.8B-MTP-GGUF", + _speculative_type = "ngram-mod", + _spec_draft_n_max = None, + ) + assert ( + backend._already_in_target_state( + gguf_path = None, + model_identifier = "unsloth/Qwen3.5-0.8B-MTP-GGUF", + hf_variant = "Q4_K_M", + n_ctx = 8192, + cache_type_kv = None, + speculative_type = None, + chat_template_override = None, + extra_args = None, + is_vision = False, + ) + is True + ) + + +def test_already_in_target_state_sub_3b_falls_back_to_off_when_no_ngram(monkeypatch): + # 0.8B + binary lacks ngram-mod -> fall back to off. + _patch_probe(monkeypatch, ngram_supported = False) + backend = _mtp_backend( + _model_identifier = "unsloth/Qwen3.5-0.8B-MTP-GGUF", + _speculative_type = None, + _spec_draft_n_max = None, + ) + assert ( + backend._already_in_target_state( + gguf_path = None, + model_identifier = "unsloth/Qwen3.5-0.8B-MTP-GGUF", + hf_variant = "Q4_K_M", + n_ctx = 8192, + cache_type_kv = None, + speculative_type = None, + chat_template_override = None, + extra_args = None, + is_vision = False, + ) + is True + ) + + +def test_already_in_target_state_4b_mtp_request_promotes_as_before(monkeypatch): + # 4B is above the 3B threshold -> auto-promote still applies. + _patch_probe(monkeypatch, ngram_supported = True) + backend = _mtp_backend( + _model_identifier = "unsloth/Qwen3.5-4B-MTP-GGUF", + _speculative_type = "draft-mtp", + _spec_draft_n_max = None, + ) + assert ( + backend._already_in_target_state( + gguf_path = None, + model_identifier = "unsloth/Qwen3.5-4B-MTP-GGUF", + hf_variant = "Q4_K_M", + n_ctx = 8192, + cache_type_kv = None, + speculative_type = None, + chat_template_override = None, + extra_args = None, + is_vision = False, + ) + is True + ) + + +def test_already_in_target_state_2b_falls_back_to_ngram_below_threshold(monkeypatch): + # 2.0B is below the 3B threshold -> ngram-mod fallback, not + # draft-mtp. 
Clean-bench shows 2B regresses with draft-mtp. + _patch_probe(monkeypatch, ngram_supported = True) + backend = _mtp_backend( + _model_identifier = "unsloth/Qwen3.5-2B-MTP-GGUF", + _speculative_type = "ngram-mod", + _spec_draft_n_max = None, + ) + assert ( + backend._already_in_target_state( + gguf_path = None, + model_identifier = "unsloth/Qwen3.5-2B-MTP-GGUF", + hf_variant = "Q4_K_M", + n_ctx = 8192, + cache_type_kv = None, + speculative_type = None, + chat_template_override = None, + extra_args = None, + is_vision = False, + ) + is True + ) + + +# usage backfill from timings (Studio UI t/s widget fix). + + +def test_backfill_usage_from_timings_fills_when_completion_tokens_zero(): + out = _backfill_usage_from_timings( + {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + {"prompt_n": 42, "predicted_n": 128, "predicted_per_second": 100.0}, + ) + assert out["completion_tokens"] == 128 + assert out["prompt_tokens"] == 42 + assert out["total_tokens"] == 170 + + +def test_backfill_usage_from_timings_fills_when_usage_missing(): + out = _backfill_usage_from_timings( + None, + {"prompt_n": 42, "predicted_n": 128, "predicted_per_second": 100.0}, + ) + assert out["completion_tokens"] == 128 + assert out["prompt_tokens"] == 42 + assert out["total_tokens"] == 170 + + +def test_backfill_usage_from_timings_preserves_real_usage(): + # Non-zero completion_tokens means llama-server reported correctly; + # do not overwrite. + real = {"prompt_tokens": 50, "completion_tokens": 200, "total_tokens": 250} + out = _backfill_usage_from_timings(real, {"predicted_n": 999, "prompt_n": 999}) + assert out is real + assert out["completion_tokens"] == 200 + + +def test_backfill_usage_from_timings_passthrough_when_timings_empty(): + assert _backfill_usage_from_timings(None, None) is None + assert _backfill_usage_from_timings(None, {}) is None + usage = {"completion_tokens": 0} + # No timings.predicted_n -> nothing to fill, return as-is. 
+ assert _backfill_usage_from_timings(usage, {"prompt_ms": 5.0}) is usage + + +# ── _canonicalize_spec_mode (pure) ───────────────────────────────── + + +@pytest.mark.parametrize( + "value, expected", + [ + # New canonical values pass through unchanged. + ("auto", "auto"), + ("mtp", "mtp"), + ("ngram", "ngram"), + ("mtp+ngram", "mtp+ngram"), + ("off", "off"), + ("ngram-simple", "ngram-simple"), + # Legacy wire values map onto the new vocabulary. + ("default", "auto"), + ("draft-mtp", "mtp"), + ("ngram-mod", "ngram"), + # Comma-chained legacy values (e.g. from persisted state) collapse + # to the right canonical mode. + ("ngram-mod,draft-mtp", "mtp+ngram"), + ("draft-mtp,ngram-mod", "mtp+ngram"), + ("draft-mtp,mtp", "mtp"), + ("ngram-mod,ngram", "ngram"), + # Case and whitespace are ignored. + (" AUTO ", "auto"), + ("MTP", "mtp"), + ("MTP+Ngram", "mtp+ngram"), + # None / empty / whitespace pass through as None. + (None, None), + ("", None), + (" ", None), + # Non-string inputs collapse to None. + (42, None), + (True, None), + # Unknown strings fall back to "auto" (safe default). 
+ ("bogus", "auto"), + ], +) +def test_canonicalize_spec_mode(value, expected): + assert _canonicalize_spec_mode(value) == expected + + +# ── _build_speculative_flags resolver matrix ────────────────────── + + +def _resolver_backend(monkeypatch, *, ngram_supported = True, mtp_token = "draft-mtp"): + """Backend with a deterministic probe so the resolver is hermetic.""" + fake = { + "found": True, + "mtp_token": mtp_token, + "supports_mtp": bool(mtp_token), + "ngram_mod_flavor": "new" if ngram_supported else None, + "supports_ngram_mod": bool(ngram_supported), + "spec_draft_n_max_flag": "--spec-draft-n-max", + } + monkeypatch.setattr( + LlamaCppBackend, + "probe_server_capabilities", + classmethod(lambda cls, binary = None: fake), + ) + backend = LlamaCppBackend() + backend._nextn_predict_layers = None + return backend + + +def _flags_dict(flags): + """Parse the spec-flag list into a small {flag: value} dict; collapses + repeated flags by keeping the last (only --spec-type can repeat and + never does in our resolver).""" + out = {} + i = 0 + while i < len(flags): + token = flags[i] + if i + 1 < len(flags) and not flags[i + 1].startswith("--"): + out[token] = flags[i + 1] + i += 2 + else: + out[token] = True + i += 1 + return out + + +_MTP_MODEL = "unsloth/Qwen3.6-27B-MTP-GGUF" +_NON_MTP_MODEL = "unsloth/Qwen3-7B-Instruct-GGUF" +_SUB_3B_MTP_MODEL = "unsloth/Qwen3.5-0.8B-MTP-GGUF" + + +@pytest.mark.parametrize( + "requested, gpus, model, expect_spec_type, expect_n_max, expect_ngram_knobs", + [ + # ── auto + MTP model + 3B+: GPU = mtp only, CPU = chain ── + ("auto", True, _MTP_MODEL, "draft-mtp", "2", False), + ("auto", False, _MTP_MODEL, "ngram-mod,draft-mtp", "3", True), + # ── auto + non-MTP: emit --spec-default ── + ("auto", True, _NON_MTP_MODEL, None, None, False), + ("auto", False, _NON_MTP_MODEL, None, None, False), + # ── auto + sub-3B MTP: fallback to ngram-mod ── + ("auto", True, _SUB_3B_MTP_MODEL, "ngram-mod", None, True), + ("auto", False, _SUB_3B_MTP_MODEL, 
"ngram-mod", None, True), + # ── mtp forced: MTP-only on BOTH platforms ── + ("mtp", True, _MTP_MODEL, "draft-mtp", "2", False), + ("mtp", False, _MTP_MODEL, "draft-mtp", "3", False), + # ── mtp forced on sub-3B: engage anyway ── + ("mtp", True, _SUB_3B_MTP_MODEL, "draft-mtp", "2", False), + # ── mtp forced on non-MTP: engage anyway ── + ("mtp", True, _NON_MTP_MODEL, "draft-mtp", "2", False), + # ── ngram forced: ngram-mod alone on BOTH platforms ── + ("ngram", True, _MTP_MODEL, "ngram-mod", None, True), + ("ngram", False, _MTP_MODEL, "ngram-mod", None, True), + ("ngram", True, _NON_MTP_MODEL, "ngram-mod", None, True), + # ── mtp+ngram forced: chain on BOTH platforms ── + ("mtp+ngram", True, _MTP_MODEL, "ngram-mod,draft-mtp", "2", True), + ("mtp+ngram", False, _MTP_MODEL, "ngram-mod,draft-mtp", "3", True), + ("mtp+ngram", True, _SUB_3B_MTP_MODEL, "ngram-mod,draft-mtp", "2", True), + # ── off: nothing emitted ── + ("off", True, _MTP_MODEL, None, None, False), + ("off", False, _MTP_MODEL, None, None, False), + # ── legacy values round-trip to the canonical emission ── + ("default", True, _MTP_MODEL, "draft-mtp", "2", False), + ("draft-mtp", True, _MTP_MODEL, "draft-mtp", "2", False), + ("ngram-mod", True, _MTP_MODEL, "ngram-mod", None, True), + ("ngram-mod,draft-mtp", False, _MTP_MODEL, "ngram-mod,draft-mtp", "3", True), + # ── ngram-simple: pass through ── + ("ngram-simple", True, _MTP_MODEL, "ngram-simple", None, False), + ], +) +def test_build_speculative_flags_matrix( + monkeypatch, + requested, + gpus, + model, + expect_spec_type, + expect_n_max, + expect_ngram_knobs, +): + backend = _resolver_backend(monkeypatch) + flags = backend._build_speculative_flags( + speculative_type = requested, + spec_draft_n_max = None, + extra_args = None, + model_identifier = model, + model_path = None, + gpus = gpus, + binary = "/fake/llama-server", + ) + parsed = _flags_dict(flags) + if expect_spec_type is None: + assert "--spec-type" not in parsed + else: + assert 
parsed.get("--spec-type") == expect_spec_type + if expect_n_max is None: + assert "--spec-draft-n-max" not in parsed + else: + assert parsed.get("--spec-draft-n-max") == expect_n_max + if expect_ngram_knobs: + assert "--spec-ngram-mod-n-match" in parsed + assert "--spec-ngram-mod-n-min" in parsed + assert "--spec-ngram-mod-n-max" in parsed + else: + assert "--spec-ngram-mod-n-match" not in parsed + + +def test_build_speculative_flags_user_extra_args_owns_spec_type(monkeypatch): + # User --spec-type in extra_args bypasses the dropdown entirely. + backend = _resolver_backend(monkeypatch) + flags = backend._build_speculative_flags( + speculative_type = "mtp", # would normally force MTP + spec_draft_n_max = None, + extra_args = ["--spec-type", "ngram-mod"], + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + # No flags emitted by the resolver -- the user's extra_args carries + # the --spec-type, and the resolver records requested_spec_mode = None. + assert flags == [] + assert backend.requested_spec_mode is None + assert backend.speculative_type is None + + +@pytest.mark.parametrize("mode", ["auto", "mtp", "ngram", "mtp+ngram", "off"]) +def test_build_speculative_flags_round_trips_requested_mode(monkeypatch, mode): + # The status round-trip is the contract that lets the UI dropdown + # restore its picked value after reload / refresh. 
+ backend = _resolver_backend(monkeypatch) + backend._build_speculative_flags( + speculative_type = mode, + spec_draft_n_max = None, + extra_args = None, + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + assert backend.requested_spec_mode == mode + + +def test_build_speculative_flags_user_draft_n_max_override(monkeypatch): + backend = _resolver_backend(monkeypatch) + flags = backend._build_speculative_flags( + speculative_type = "mtp", + spec_draft_n_max = 5, + extra_args = None, + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + parsed = _flags_dict(flags) + assert parsed.get("--spec-draft-n-max") == "5" + assert backend.spec_draft_n_max == 5 + + +def test_build_speculative_flags_mtp_token_missing_logs_and_skips(monkeypatch): + # Outdated llama-server with no MTP support: forced MTP must degrade + # to spec-off (warned) rather than emit a bad --spec-type. + backend = _resolver_backend(monkeypatch, mtp_token = None) + flags = backend._build_speculative_flags( + speculative_type = "mtp", + spec_draft_n_max = None, + extra_args = None, + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + assert "--spec-type" not in flags + # _speculative_type stays None (resolved emission was none), but + # _requested_spec_mode still reflects the user's choice. 
+ assert backend.requested_spec_mode == "mtp" + assert backend.speculative_type is None diff --git a/studio/frontend/src/features/chat/chat-settings-sheet.tsx b/studio/frontend/src/features/chat/chat-settings-sheet.tsx index 3703beff0a..1ed0a88b63 100644 --- a/studio/frontend/src/features/chat/chat-settings-sheet.tsx +++ b/studio/frontend/src/features/chat/chat-settings-sheet.tsx @@ -570,6 +570,11 @@ export function ChatSettingsPanel({ const loadedSpeculativeType = useChatRuntimeStore( (s) => s.loadedSpeculativeType, ); + const specDraftNMax = useChatRuntimeStore((s) => s.specDraftNMax); + const setSpecDraftNMax = useChatRuntimeStore((s) => s.setSpecDraftNMax); + const loadedSpecDraftNMax = useChatRuntimeStore( + (s) => s.loadedSpecDraftNMax, + ); const modelRequiresTrustRemoteCode = useChatRuntimeStore( (s) => s.modelRequiresTrustRemoteCode, ); @@ -608,7 +613,8 @@ export function ChatSettingsPanel({ const kvDirty = kvCacheDtype !== loadedKvCacheDtype; const ctxDirty = customContextLength !== null; const specDirty = speculativeType !== loadedSpeculativeType; - const modelSettingsDirty = kvDirty || ctxDirty || specDirty; + const specDraftDirty = specDraftNMax !== loadedSpecDraftNMax; + const modelSettingsDirty = kvDirty || ctxDirty || specDirty || specDraftDirty; const chatTemplateOverride = useChatRuntimeStore( (s) => s.chatTemplateOverride, ); @@ -985,19 +991,81 @@ export function ChatSettingsPanel({ Speculative Decoding - Faster generation with 0% accuracy hit. + Faster generation with 0% accuracy hit. Auto picks + MTP / ngram-mod based on the model and platform. + Pick MTP, Ngram, or MTP+Ngram to force a specific + strategy on both GPU and CPU. - { - setSpeculativeType(checked ? "default" : "off"); - }} - /> +
+ +
+ {(speculativeType === "mtp" || + speculativeType === "mtp+ngram") && ( +
+
+ + Draft Tokens + + + Max MTP draft tokens per step + (--spec-draft-n-max). Lower = less wasted + draft decode; higher = bigger speedup when + acceptance stays high. Default: 2 on GPU, + 3 on CPU/Mac. + +
+ { + const raw = e.target.value; + if (raw === "") { + setSpecDraftNMax(null); + return; + } + const parsed = Number.parseInt(raw, 10); + if (Number.isFinite(parsed)) { + const clamped = Math.max(1, Math.min(16, parsed)); + setSpecDraftNMax(clamped); + } + }} + data-test-id="spec-draft-n-max-input" + aria-label="Speculative decoding draft tokens" + className="h-7 w-[72px] rounded-[10px] border-transparent bg-black/[0.04] dark:bg-white/[0.05] hover:bg-black/[0.06] dark:hover:bg-white/[0.07] px-2 py-0 text-[13px] font-medium text-nav-fg outline-none focus-visible:ring-0" + /> +
+ )} )} {!isGguf && params.checkpoint && ( @@ -1051,6 +1119,7 @@ export function ChatSettingsPanel({ setCustomContextLength(null); setKvCacheDtype(loadedKvCacheDtype); setSpeculativeType(loadedSpeculativeType); + setSpecDraftNMax(loadedSpecDraftNMax); setChatTemplateOverride(loadedChatTemplateOverride); }} className="h-7 px-3 text-[12px] font-medium tracking-nav text-muted-foreground" diff --git a/studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts b/studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts index 03f80c19f2..3f1060edf7 100644 --- a/studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts +++ b/studio/frontend/src/features/chat/hooks/use-chat-model-runtime.ts @@ -141,10 +141,30 @@ function getTrustRemoteCodeRequiredMessage(modelName: string): string { return `${modelName} needs custom code enabled to load. Turn on "Enable custom code" in Chat Settings, then try again.`; } +// Canonicalises any value the backend reports (or persisted state holds) +// onto the five UI-facing modes the Speculative Decoding dropdown +// understands ("auto" / "mtp" / "ngram" / "mtp+ngram" / "off") or null; +// "ngram-simple" is passed through unchanged. Mirrors backend +// _canonicalize_spec_mode so old persisted "default" / "draft-mtp" / "ngram-mod" / chain values round-trip cleanly. function normalizeSpeculativeType(v: string | null | undefined): string | null { if (v == null) return null; - if (v === "default" || v === "off") return v; - return "default"; + const s = String(v).trim().toLowerCase(); + if (!s) return null; + if (s === "auto" || s === "default") return "auto"; + if (s === "off") return "off"; + if (s === "ngram-simple") return "ngram-simple"; + if (s === "mtp" || s === "draft-mtp") return "mtp"; + if (s === "ngram" || s === "ngram-mod") return "ngram"; + if (s === "mtp+ngram") return "mtp+ngram"; + // Comma-chained legacy values (e.g. from older persisted state).
+ const parts = s.split(",").map((p) => p.trim()).filter(Boolean); + const hasMtp = parts.some((p) => p === "mtp" || p === "draft-mtp"); + const hasNgram = parts.some((p) => p === "ngram" || p === "ngram-mod"); + if (hasMtp && hasNgram) return "mtp+ngram"; + if (hasMtp) return "mtp"; + if (hasNgram) return "ngram"; + // Unknown -> safe fallback to Auto so the dropdown stays controlled. + return "auto"; } type LocalReasoningEffort = Extract; @@ -323,6 +343,12 @@ export function useChatModelRuntime() { speculativeType: currentSpecType, loadedSpeculativeType: currentSpecType, }), + ...(statusRes.spec_draft_n_max !== undefined && + prevState.loadedSpecDraftNMax === null && + prevState.specDraftNMax === null && { + specDraftNMax: statusRes.spec_draft_n_max ?? null, + loadedSpecDraftNMax: statusRes.spec_draft_n_max ?? null, + }), ...(statusRes.cache_type_kv !== undefined && prevState.loadedKvCacheDtype === null && { kvCacheDtype: statusRes.cache_type_kv, @@ -528,12 +554,31 @@ export function useChatModelRuntime() { } if (abortCtrl.signal.aborted) throw new Error("Cancelled"); + // Reset Speculative Decoding to Auto whenever the user + // switches to a different model. Spec strategy is a + // per-model decision: a sub-3B non-MTP GGUF that ran with + // "Off" should not carry that choice into a 27B MTP GGUF + // where Auto would auto-promote to draft-mtp. The user can + // still pick a forced mode on the new model; this just + // clears the stale prior-model choice so the backend's + // platform-aware path runs by default. Same applies to + // spec_draft_n_max which is MTP-only. 
+ if (currentCheckpoint && currentCheckpoint !== modelId) { + useChatRuntimeStore.setState({ + speculativeType: null, + loadedSpeculativeType: null, + specDraftNMax: null, + loadedSpecDraftNMax: null, + }); + } + const { chatTemplateOverride, kvCacheDtype, customContextLength, ggufContextLength, speculativeType, + specDraftNMax, activePresetSource, activeGgufVariant, } = useChatRuntimeStore.getState(); @@ -561,6 +606,7 @@ export function useChatModelRuntime() { chat_template_override: effectiveChatTemplateOverride, cache_type_kv: kvCacheDtype, speculative_type: speculativeType, + spec_draft_n_max: specDraftNMax, }); // If cancelled while loading, don't update UI to show @@ -635,6 +681,8 @@ export function useChatModelRuntime() { loadedKvCacheDtype: loadedKv, speculativeType: loadedSpec, loadedSpeculativeType: loadedSpec, + specDraftNMax: loadResponse.spec_draft_n_max ?? null, + loadedSpecDraftNMax: loadResponse.spec_draft_n_max ?? null, customContextLength: keepCustomCtx, defaultChatTemplate: loadResponse.chat_template ?? null, chatTemplateOverride: effectiveChatTemplateOverride, diff --git a/studio/frontend/src/features/chat/stores/chat-runtime-store.ts b/studio/frontend/src/features/chat/stores/chat-runtime-store.ts index 13d3cb533f..e55a25d08d 100644 --- a/studio/frontend/src/features/chat/stores/chat-runtime-store.ts +++ b/studio/frontend/src/features/chat/stores/chat-runtime-store.ts @@ -258,6 +258,9 @@ type ChatRuntimeStore = { loadedKvCacheDtype: string | null; speculativeType: string | null; loadedSpeculativeType: string | null; + /** User --spec-draft-n-max override (null = platform default). 
*/ + specDraftNMax: number | null; + loadedSpecDraftNMax: number | null; loadedIsMultimodal: boolean; customContextLength: number | null; defaultChatTemplate: string | null; @@ -305,6 +308,7 @@ type ChatRuntimeStore = { setToolCallTimeout: (value: number) => void; setKvCacheDtype: (dtype: string | null) => void; setSpeculativeType: (type: string | null) => void; + setSpecDraftNMax: (value: number | null) => void; setCustomContextLength: (v: number | null) => void; setChatTemplateOverride: (template: string | null) => void; setPendingAudio: (base64: string, name: string) => void; @@ -349,8 +353,10 @@ export const useChatRuntimeStore = create((set) => ({ toolCallTimeout: loadInt(TOOL_CALL_TIMEOUT_KEY, 5), kvCacheDtype: null, loadedKvCacheDtype: null, - speculativeType: "default", + speculativeType: "auto", loadedSpeculativeType: null, + specDraftNMax: null, + loadedSpecDraftNMax: null, loadedIsMultimodal: false, customContextLength: null, defaultChatTemplate: null, @@ -457,8 +463,10 @@ export const useChatRuntimeStore = create((set) => ({ toolStatus: null, kvCacheDtype: null, loadedKvCacheDtype: null, - speculativeType: "default", + speculativeType: "auto", loadedSpeculativeType: null, + specDraftNMax: null, + loadedSpecDraftNMax: null, loadedIsMultimodal: false, customContextLength: null, defaultChatTemplate: null, @@ -506,6 +514,7 @@ export const useChatRuntimeStore = create((set) => ({ }), setKvCacheDtype: (kvCacheDtype) => set({ kvCacheDtype }), setSpeculativeType: (speculativeType) => set({ speculativeType }), + setSpecDraftNMax: (specDraftNMax) => set({ specDraftNMax }), setCustomContextLength: (customContextLength) => set({ customContextLength }), setChatTemplateOverride: (chatTemplateOverride) => set({ chatTemplateOverride }), setPendingAudio: (base64, name) => diff --git a/studio/frontend/src/features/chat/types/api.ts b/studio/frontend/src/features/chat/types/api.ts index 3fc5d320df..1e6bcf8b87 100644 --- a/studio/frontend/src/features/chat/types/api.ts +++ 
b/studio/frontend/src/features/chat/types/api.ts @@ -42,7 +42,21 @@ export interface LoadModelRequest { trust_remote_code?: boolean; chat_template_override?: string | null; cache_type_kv?: string | null; + /** + * Speculative decoding mode for GGUF models. Canonical values: + * "auto" (platform-aware: MTP on MTP GGUFs, ngram-mod fallback for + * sub-3B), "mtp" (force draft-mtp only on both GPU and CPU), "ngram" + * (force ngram-mod only), "mtp+ngram" (force ngram-mod + draft-mtp + * chain on both platforms), or "off". Legacy values "default" / + * "draft-mtp" / "ngram-mod" / "ngram-simple" are still accepted by + * the backend. + */ speculative_type?: string | null; + /** + * Override --spec-draft-n-max for MTP speculative decoding. Only + * applied when speculative_type resolves to "mtp" or "mtp+ngram". + */ + spec_draft_n_max?: number | null; } export interface ValidateModelResponse { @@ -118,7 +132,9 @@ export interface LoadModelResponse { supports_tools?: boolean; cache_type_kv?: string | null; chat_template?: string | null; + /** Canonical UI-facing mode the load request resolved to. See LoadModelRequest. */ speculative_type?: string | null; + spec_draft_n_max?: number | null; } export interface UnloadModelRequest { @@ -155,7 +171,9 @@ export interface InferenceStatusResponse { native_context_length?: number | null; cache_type_kv?: string | null; chat_template_override?: string | null; + /** Canonical UI-facing mode currently active. See LoadModelRequest. */ speculative_type?: string | null; + spec_draft_n_max?: number | null; } export interface AudioGenerationResponse {