diff --git a/studio/backend/core/inference/llama_cpp.py b/studio/backend/core/inference/llama_cpp.py index 260e675a73..bd55346f4b 100644 --- a/studio/backend/core/inference/llama_cpp.py +++ b/studio/backend/core/inference/llama_cpp.py @@ -522,11 +522,50 @@ def _build_ngram_mod_flags( return [] -# Canonical Speculative Decoding modes exposed by the Studio chat UI. -# The dropdown renders five options (auto, mtp, ngram, mtp+ngram, off); -# the load API also accepts legacy values that the original Switch and -# external callers emit (default, draft-mtp, ngram-mod, ngram-simple). -_CANONICAL_SPEC_MODES = {"auto", "mtp", "ngram", "mtp+ngram", "off", "ngram-simple"} +def _build_ngram_map_k_flags( + caps: Optional[dict], + *, + variant: str, + size_n: int = 16, + size_m: int = 24, + min_hits: int = 1, +) -> list[str]: + """Emit knob flags for ``ngram-map-k`` / ``ngram-map-k4v`` if the + bundled llama-server advertises them (added upstream in #23269 and + sibling commits). Returns ``[]`` when the binary doesn't. + """ + cap_key = ( + "supports_ngram_map_k4v" + if variant == "ngram-map-k4v" + else "supports_ngram_map_k" + ) + if not (caps and caps.get(cap_key)): + return [] + prefix = f"--spec-{variant}" + return [ + f"{prefix}-size-n", + str(size_n), + f"{prefix}-size-m", + str(size_m), + f"{prefix}-min-hits", + str(min_hits), + ] + + +# Canonical Speculative Decoding modes exposed by the Studio chat UI +# (5-mode dropdown) plus power-user n-gram variants accepted via the +# load API. ``ngram-map-k`` / ``ngram-map-k4v`` are not in the dropdown +# but the resolver emits them correctly when sent via API. +_CANONICAL_SPEC_MODES = { + "auto", + "mtp", + "ngram", + "mtp+ngram", + "off", + "ngram-simple", + "ngram-map-k", + "ngram-map-k4v", +} _LEGACY_SPEC_MODE_MAP = { "default": "auto", "draft-mtp": "mtp", @@ -538,9 +577,10 @@ def _canonicalize_spec_mode(value): """Map any accepted ``speculative_type`` input onto a canonical mode. Returns one of ``auto``, ``mtp``, ``ngram``, ``mtp+ngram``, ``off``, - ``ngram-simple``, or ``None`` (callers treat ``None`` as ``auto``). - Unknown strings collapse to ``auto`` so a stale UI value or typo - falls back to the safe platform-aware path. + ``ngram-simple``, ``ngram-map-k``, ``ngram-map-k4v``, or ``None`` + (callers treat ``None`` as ``auto``). Unknown strings collapse to + ``auto`` so a stale UI value or typo falls back to the safe + platform-aware path. """ if value is None: return None @@ -641,6 +681,10 @@ def __init__(self): self._requested_spec_mode: Optional[str] = None # User-supplied --spec-draft-n-max override (None = platform default). self._spec_draft_n_max: Optional[int] = None + # User-supplied --spec-draft-p-min override (None = platform default). + # Only emitted for MTP / MTP+Ngram modes; the flag became functional + # in llama.cpp #23269 (default upstream dropped to 0.0). + self._spec_draft_p_min: Optional[float] = None # KV-cache estimation fields (populated by _read_gguf_metadata) self._n_layers: Optional[int] = None self._n_kv_heads: Optional[int] = None @@ -932,6 +976,12 @@ def spec_draft_n_max(self) -> Optional[int]: when the platform default (6 GPU / 3 CPU) is in effect.""" return self._spec_draft_n_max + @property + def spec_draft_p_min(self) -> Optional[float]: + """User --spec-draft-p-min override on the load, or None when + the llama-server default (0.0 since #23269) is in effect.""" + return self._spec_draft_p_min + # ── Binary discovery ────────────────────────────────────────── @staticmethod @@ -1059,7 +1109,8 @@ def probe_server_capabilities( ) -> dict[str, object]: """Parse `llama-server --help` for feature flags. Returns {found, mtp_token, supports_mtp, ngram_mod_flavor, - supports_ngram_mod, spec_draft_n_max_flag}. + supports_ngram_mod, spec_draft_n_max_flag, spec_draft_p_min_flag, + supports_ngram_map_k, supports_ngram_map_k4v}. ``ngram_mod_flavor`` is ``"new"`` when the binary exposes the post-rename ``--spec-ngram-mod-n-match / -n-min / -n-max`` as @@ -1072,6 +1123,13 @@ def probe_server_capabilities( ``spec_draft_n_max_flag`` is the actual flag name the binary accepts: ``--spec-draft-n-max`` on post-rename builds, or ``--draft-max`` on legacy. ``None`` means n_max cannot be set. + + ``spec_draft_p_min_flag`` is the flag name when the binary + advertises ``--spec-draft-p-min`` (functional from llama.cpp + #23269 onward; before that it was a knob with no effect on + MTP). ``supports_ngram_map_k`` / ``supports_ngram_map_k4v`` + signal that the binary accepts these power-user spec types + with their own knob sets. """ bin_path = binary or cls._find_llama_server_binary() if not bin_path or not Path(bin_path).is_file(): @@ -1082,6 +1140,9 @@ def probe_server_capabilities( "ngram_mod_flavor": None, "supports_ngram_mod": False, "spec_draft_n_max_flag": None, + "spec_draft_p_min_flag": None, + "supports_ngram_map_k": False, + "supports_ngram_map_k4v": False, } try: mtime = int(Path(bin_path).stat().st_mtime) @@ -1095,6 +1156,9 @@ def probe_server_capabilities( mtp_token: Optional[str] = None ngram_mod_flavor: Optional[str] = None spec_draft_n_max_flag: Optional[str] = None + spec_draft_p_min_flag: Optional[str] = None + supports_ngram_map_k = False + supports_ngram_map_k4v = False try: result = subprocess.run( [bin_path, "--help"], @@ -1184,6 +1248,24 @@ def _is_real(flag: str) -> bool: spec_draft_n_max_flag = "--spec-draft-n-max" elif _is_real("--draft-max"): spec_draft_n_max_flag = "--draft-max" + + # p_min flag: present on builds that include the MTP cleanup + # (llama.cpp #23269) or earlier draft-model spec impls. + if _is_real("--spec-draft-p-min"): + spec_draft_p_min_flag = "--spec-draft-p-min" + + # ngram-map-k / ngram-map-k4v: power-user spec types added + # alongside ngram-mod. Each carries its own knob triplet. + supports_ngram_map_k = ( + _is_real("--spec-ngram-map-k-size-n") + and _is_real("--spec-ngram-map-k-size-m") + and _is_real("--spec-ngram-map-k-min-hits") + ) + supports_ngram_map_k4v = ( + _is_real("--spec-ngram-map-k4v-size-n") + and _is_real("--spec-ngram-map-k4v-size-m") + and _is_real("--spec-ngram-map-k4v-min-hits") + ) except (OSError, subprocess.SubprocessError) as exc: logger.debug(f"llama-server --help probe failed: {exc}") @@ -1194,6 +1276,9 @@ def _is_real(flag: str) -> bool: "ngram_mod_flavor": ngram_mod_flavor, "supports_ngram_mod": ngram_mod_flavor is not None, "spec_draft_n_max_flag": spec_draft_n_max_flag, + "spec_draft_p_min_flag": spec_draft_p_min_flag, + "supports_ngram_map_k": supports_ngram_map_k, + "supports_ngram_map_k4v": supports_ngram_map_k4v, } cls._capability_cache[cache_key] = info return info @@ -2502,6 +2587,7 @@ def load_model( cache_type_kv: Optional[str] = None, speculative_type: Optional[str] = None, spec_draft_n_max: Optional[int] = None, + spec_draft_p_min: Optional[float] = None, n_threads: Optional[int] = None, n_gpu_layers: Optional[int] = None, # Accepted for caller compat, unused n_parallel: int = 1, @@ -2534,6 +2620,7 @@ def load_model( cache_type_kv = cache_type_kv, speculative_type = speculative_type, spec_draft_n_max = spec_draft_n_max, + spec_draft_p_min = spec_draft_p_min, chat_template_override = chat_template_override, extra_args = extra_args, is_vision = is_vision, @@ -2915,6 +3002,7 @@ def load_model( spec_flags = self._build_speculative_flags( speculative_type = speculative_type, spec_draft_n_max = spec_draft_n_max, + spec_draft_p_min = spec_draft_p_min, extra_args = extra_args, model_identifier = model_identifier, model_path = model_path, @@ -3258,6 +3346,7 @@ def _build_speculative_flags( *, speculative_type: Optional[str], spec_draft_n_max: Optional[int], + spec_draft_p_min: Optional[float] = None, extra_args: Optional[List[str]], model_identifier: str, model_path: Optional[str], @@ -3299,6 +3388,7 @@ def _build_speculative_flags( flags: List[str] = [] # Reset; emit branches re-set on the resolved emission. self._spec_draft_n_max = None + self._spec_draft_p_min = None self._speculative_type = None # Canonical UI-facing requested mode: auto / mtp / ngram / @@ -3334,6 +3424,22 @@ def _resolved_draft_n_max() -> int: return n return 2 if gpus else 3 + def _maybe_emit_p_min(caps: Optional[dict]) -> None: + """Append --spec-draft-p-min if the user supplied a value AND + the binary advertises the flag (functional from #23269).""" + if spec_draft_p_min is None: + return + p_flag = (caps or {}).get("spec_draft_p_min_flag") + if not p_flag: + logger.warning( + "llama-server lacks --spec-draft-p-min; ignoring " + "spec_draft_p_min override. Run `unsloth studio update`." + ) + return + p_min = float(spec_draft_p_min) + self._spec_draft_p_min = p_min + flags.extend([p_flag, f"{p_min:.4f}"]) + def _emit_mtp(*, chain_ngram: bool) -> bool: """Append --spec-type mtp[/draft-mtp][,ngram-mod] + n-max.""" caps = self.probe_server_capabilities(binary) @@ -3376,6 +3482,7 @@ def _emit_mtp(*, chain_ngram: bool) -> bool: str(draft_n_max), ] ) + _maybe_emit_p_min(caps) self._speculative_type = "draft-mtp" chain_label = "chained ngram-mod" if chain_ngram else "MTP-only" logger.info(f"Spec decoding: {mtp_token} ({chain_label})") @@ -3405,6 +3512,25 @@ def _emit_ngram_mod() -> bool: if effective_mode == "ngram": _emit_ngram_mod() return flags + if effective_mode in ("ngram-map-k", "ngram-map-k4v"): + map_caps = self.probe_server_capabilities(binary) + supported = map_caps.get( + "supports_ngram_map_k4v" + if effective_mode == "ngram-map-k4v" + else "supports_ngram_map_k" + ) + if not supported: + logger.warning( + f"llama-server lacks --spec-type {effective_mode}; " + "run `unsloth studio update`. Loading without " + "speculative decoding." + ) + return flags + flags.extend(["--spec-type", effective_mode]) + flags.extend(_build_ngram_map_k_flags(map_caps, variant = effective_mode)) + self._speculative_type = effective_mode + logger.info(f"Spec decoding: {effective_mode}") + return flags if effective_mode == "mtp": if _mtp_too_small: logger.warning( @@ -3480,6 +3606,7 @@ def _already_in_target_state( is_vision: bool, gguf_path: Optional[str] = None, spec_draft_n_max: Optional[int] = None, + spec_draft_p_min: Optional[float] = None, ) -> bool: """True iff the live server already satisfies these load kwargs. @@ -3540,6 +3667,16 @@ def _norm(value): ): return False + # Same treatment for spec_draft_p_min: only relevant when MTP is + # the resolved emit. None on either side means "default" (0.0 + # since llama.cpp #23269) and matches anything. + if ( + self._speculative_type == "draft-mtp" + and spec_draft_p_min is not None + and abs(float(spec_draft_p_min) - (self._spec_draft_p_min or 0.0)) > 1e-6 + ): + return False + if (self._chat_template_override or None) != (chat_template_override or None): return False @@ -3608,6 +3745,7 @@ def unload_model(self) -> bool: self._speculative_type = None self._requested_spec_mode = None self._spec_draft_n_max = None + self._spec_draft_p_min = None self._n_layers = None self._n_kv_heads = None self._n_kv_heads_by_layer = None diff --git a/studio/backend/models/inference.py b/studio/backend/models/inference.py index f8bf219846..5ccb2ab447 100644 --- a/studio/backend/models/inference.py +++ b/studio/backend/models/inference.py @@ -71,14 +71,16 @@ def normalize_blank_chat_template_override( speculative_type: Optional[str] = Field( None, description = ( - "Speculative decoding mode for GGUF models. Canonical values: " - "'auto' (platform-aware: MTP on MTP GGUFs, ngram-mod fallback " - "for sub-3B), 'mtp' (force draft-mtp only on both GPU and CPU), " - "'ngram' (force ngram-mod only), 'mtp+ngram' (force " - "ngram-mod+draft-mtp chain on both platforms), 'off' (disabled). " - "Legacy values 'default' (-> auto), 'draft-mtp' (-> mtp), " - "'ngram-mod' (-> ngram), and 'ngram-simple' (kept as-is) are " - "still accepted. Ignored for non-GGUF and vision models." + "Speculative decoding mode for GGUF models. Canonical values " + "exposed by the Studio dropdown: 'auto' (platform-aware: MTP " + "on MTP GGUFs, ngram-mod fallback for sub-3B), 'mtp' (force " + "draft-mtp only on both GPU and CPU), 'ngram' (force " + "ngram-mod only), 'mtp+ngram' (force ngram-mod+draft-mtp " + "chain on both platforms), 'off' (disabled). Power-user " + "spec types accepted via API: 'ngram-simple', 'ngram-map-k', " + "'ngram-map-k4v'. Legacy values 'default' (-> auto), " + "'draft-mtp' (-> mtp), 'ngram-mod' (-> ngram) are still " + "accepted. Ignored for non-GGUF and vision models." ), ) spec_draft_n_max: Optional[int] = Field( @@ -93,6 +95,19 @@ def normalize_blank_chat_template_override( "'mtp' or 'mtp+ngram'." ), ) + spec_draft_p_min: Optional[float] = Field( + None, + ge = 0.0, + le = 1.0, + description = ( + "Min draft probability for MTP speculative decoding " + "(--spec-draft-p-min). Drafts with predicted probability " + "below this threshold are rejected. Defaults to 0.0 (no " + "filtering) since llama.cpp #23269; before that the flag " + "existed but was non-functional. Only applied when " + "speculative_type resolves to 'mtp' or 'mtp+ngram'." + ), + ) llama_extra_args: Optional[List[str]] = Field( None, description = ( @@ -242,8 +257,9 @@ class LoadResponse(BaseModel): description = ( "Canonical UI-facing requested speculative decoding mode " "('auto' / 'mtp' / 'ngram' / 'mtp+ngram' / 'off' / " - "'ngram-simple'), round-tripped from the original LoadRequest " - "via _canonicalize_spec_mode. None when no model is loaded." + "'ngram-simple' / 'ngram-map-k' / 'ngram-map-k4v'), " + "round-tripped from the original LoadRequest via " + "_canonicalize_spec_mode. None when no model is loaded." ), ) spec_draft_n_max: Optional[int] = Field( @@ -253,6 +269,13 @@ class LoadResponse(BaseModel): "None when the platform default is in effect." ), ) + spec_draft_p_min: Optional[float] = Field( + None, + description = ( + "Active --spec-draft-p-min for MTP speculative decoding, or " + "None when the llama-server default (0.0) is in effect." + ), + ) class UnloadResponse(BaseModel): @@ -376,8 +399,9 @@ class InferenceStatusResponse(BaseModel): description = ( "Canonical UI-facing requested speculative decoding mode " "('auto' / 'mtp' / 'ngram' / 'mtp+ngram' / 'off' / " - "'ngram-simple'), round-tripped from the original LoadRequest. " - "None when no model is loaded." + "'ngram-simple' / 'ngram-map-k' / 'ngram-map-k4v'), " + "round-tripped from the original LoadRequest. None when no " + "model is loaded." ), ) spec_draft_n_max: Optional[int] = Field( @@ -387,6 +411,13 @@ class InferenceStatusResponse(BaseModel): "None when the platform default is in effect." ), ) + spec_draft_p_min: Optional[float] = Field( + None, + description = ( + "Active --spec-draft-p-min for MTP speculative decoding, or " + "None when the llama-server default (0.0) is in effect." + ), + ) llama_cpp_supports_mtp: bool = Field( True, description = ( diff --git a/studio/backend/routes/inference.py b/studio/backend/routes/inference.py index cde86998d2..aec753f0e2 100644 --- a/studio/backend/routes/inference.py +++ b/studio/backend/routes/inference.py @@ -504,11 +504,21 @@ def _request_matches_loaded_settings( backend_mode = llama_backend.requested_spec_mode or "auto" if req_mode != backend_mode: return False - # spec_draft_n_max only matters when an MTP variant is engaged; None - # means "platform default" and matches whatever the backend chose. + # spec_draft_n_max / spec_draft_p_min only matter when an MTP variant + # is engaged; None on the request means "platform default" and + # matches whatever the backend chose. if backend_mode in ("mtp", "mtp+ngram") and request.spec_draft_n_max is not None: if int(request.spec_draft_n_max) != (llama_backend.spec_draft_n_max or 0): return False + if ( + backend_mode in ("mtp", "mtp+ngram") + and request.spec_draft_p_min is not None + and abs( + float(request.spec_draft_p_min) - (llama_backend.spec_draft_p_min or 0.0) + ) + > 1e-6 + ): + return False if (request.chat_template_override or None) != ( llama_backend.chat_template_override or None ): @@ -649,6 +659,7 @@ async def load_model( chat_template = llama_backend.chat_template, speculative_type = llama_backend.requested_spec_mode, spec_draft_n_max = llama_backend.spec_draft_n_max, + spec_draft_p_min = llama_backend.spec_draft_p_min, ) else: if ( @@ -780,6 +791,7 @@ async def load_model( strip_spec = ( "speculative_type" in fields_set or "spec_draft_n_max" in fields_set + or "spec_draft_p_min" in fields_set ), strip_template = "chat_template_override" in fields_set, ) @@ -822,6 +834,7 @@ async def load_model( cache_type_kv = request.cache_type_kv, speculative_type = request.speculative_type, spec_draft_n_max = request.spec_draft_n_max, + spec_draft_p_min = request.spec_draft_p_min, n_parallel = _n_parallel, extra_args = extra_llama_args, ) @@ -846,6 +859,7 @@ async def load_model( cache_type_kv = request.cache_type_kv, speculative_type = request.speculative_type, spec_draft_n_max = request.spec_draft_n_max, + spec_draft_p_min = request.spec_draft_p_min, n_parallel = _n_parallel, extra_args = extra_llama_args, ) @@ -906,6 +920,7 @@ async def load_model( chat_template = llama_backend.chat_template, speculative_type = llama_backend.requested_spec_mode, spec_draft_n_max = llama_backend.spec_draft_n_max, + spec_draft_p_min = llama_backend.spec_draft_p_min, ) # ── Standard path: load via Unsloth/transformers ────────── @@ -1395,6 +1410,7 @@ async def get_status( chat_template_override = llama_backend.chat_template_override, speculative_type = llama_backend.requested_spec_mode, spec_draft_n_max = llama_backend.spec_draft_n_max, + spec_draft_p_min = llama_backend.spec_draft_p_min, llama_cpp_supports_mtp = _supports_mtp, llama_cpp_prebuilt_stale = _stale, llama_cpp_installed_tag = _installed_tag, diff --git a/studio/backend/tests/test_llama_cpp_mtp_detection.py b/studio/backend/tests/test_llama_cpp_mtp_detection.py index 4a8276adc0..a6b0f9a9f9 100644 --- a/studio/backend/tests/test_llama_cpp_mtp_detection.py +++ b/studio/backend/tests/test_llama_cpp_mtp_detection.py @@ -1011,7 +1011,15 @@ def test_canonicalize_spec_mode(value, expected): # ── _build_speculative_flags resolver matrix ────────────────────── -def _resolver_backend(monkeypatch, *, ngram_supported = True, mtp_token = "draft-mtp"): +def _resolver_backend( + monkeypatch, + *, + ngram_supported = True, + mtp_token = "draft-mtp", + p_min_flag = "--spec-draft-p-min", + supports_ngram_map_k = True, + supports_ngram_map_k4v = True, +): """Backend with a deterministic probe so the resolver is hermetic.""" fake = { "found": True, @@ -1020,6 +1028,9 @@ def _resolver_backend(monkeypatch, *, ngram_supported = True, mtp_token = "draft "ngram_mod_flavor": "new" if ngram_supported else None, "supports_ngram_mod": bool(ngram_supported), "spec_draft_n_max_flag": "--spec-draft-n-max", + "spec_draft_p_min_flag": p_min_flag, + "supports_ngram_map_k": supports_ngram_map_k, + "supports_ngram_map_k4v": supports_ngram_map_k4v, } monkeypatch.setattr( LlamaCppBackend, @@ -1198,3 +1209,132 @@ def test_build_speculative_flags_mtp_token_missing_logs_and_skips(monkeypatch): # _requested_spec_mode still reflects the user's choice. assert backend.requested_spec_mode == "mtp" assert backend.speculative_type is None + + +# ── spec_draft_p_min ──────────────────────────────────────────── + + +@pytest.mark.parametrize("mode", ["mtp", "mtp+ngram"]) +def test_build_speculative_flags_emits_p_min_for_mtp_modes(monkeypatch, mode): + backend = _resolver_backend(monkeypatch) + flags = backend._build_speculative_flags( + speculative_type = mode, + spec_draft_n_max = None, + spec_draft_p_min = 0.25, + extra_args = None, + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + parsed = _flags_dict(flags) + assert parsed.get("--spec-draft-p-min") == "0.2500" + assert backend.spec_draft_p_min == 0.25 + + +@pytest.mark.parametrize("mode", ["auto", "ngram", "off", "ngram-simple"]) +def test_build_speculative_flags_does_not_emit_p_min_for_non_mtp_modes( + monkeypatch, mode +): + backend = _resolver_backend(monkeypatch) + flags = backend._build_speculative_flags( + speculative_type = mode, + spec_draft_n_max = None, + spec_draft_p_min = 0.25, + extra_args = None, + # Use a non-MTP model so "auto" doesn't promote to draft-mtp + # and accidentally pick up the p-min emission. + model_identifier = _NON_MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + assert "--spec-draft-p-min" not in flags + assert backend.spec_draft_p_min is None + + +def test_build_speculative_flags_p_min_skipped_when_binary_lacks_flag(monkeypatch): + # Older llama-server (no #23269): p_min flag is None in the probe. + # Resolver should log a warning and skip emission rather than crash. + backend = _resolver_backend(monkeypatch, p_min_flag = None) + flags = backend._build_speculative_flags( + speculative_type = "mtp", + spec_draft_n_max = None, + spec_draft_p_min = 0.5, + extra_args = None, + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + assert "--spec-draft-p-min" not in flags + assert backend.spec_draft_p_min is None + + +def test_build_speculative_flags_p_min_emitted_via_auto_when_promoted(monkeypatch): + # Auto on an MTP GGUF auto-promotes to draft-mtp; p_min should still + # flow through because the resolved emission is MTP. + backend = _resolver_backend(monkeypatch) + flags = backend._build_speculative_flags( + speculative_type = "auto", + spec_draft_n_max = None, + spec_draft_p_min = 0.1, + extra_args = None, + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + parsed = _flags_dict(flags) + assert parsed.get("--spec-type") == "draft-mtp" + assert parsed.get("--spec-draft-p-min") == "0.1000" + + +# ── ngram-map-k / ngram-map-k4v ────────────────────────────────── + + +@pytest.mark.parametrize("variant", ["ngram-map-k", "ngram-map-k4v"]) +def test_build_speculative_flags_ngram_map_variants(monkeypatch, variant): + backend = _resolver_backend(monkeypatch) + flags = backend._build_speculative_flags( + speculative_type = variant, + spec_draft_n_max = None, + extra_args = None, + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + parsed = _flags_dict(flags) + assert parsed.get("--spec-type") == variant + assert parsed.get(f"--spec-{variant}-size-n") == "16" + assert parsed.get(f"--spec-{variant}-size-m") == "24" + assert parsed.get(f"--spec-{variant}-min-hits") == "1" + assert backend.requested_spec_mode == variant + assert backend.speculative_type == variant + # n-max/p-min are MTP-only knobs; ngram-map-* variants don't carry them. + assert "--spec-draft-n-max" not in flags + assert "--spec-draft-p-min" not in flags + + +def test_build_speculative_flags_ngram_map_k_skipped_when_unsupported(monkeypatch): + backend = _resolver_backend(monkeypatch, supports_ngram_map_k = False) + flags = backend._build_speculative_flags( + speculative_type = "ngram-map-k", + spec_draft_n_max = None, + extra_args = None, + model_identifier = _MTP_MODEL, + model_path = None, + gpus = True, + binary = "/fake/llama-server", + ) + assert "--spec-type" not in flags + assert backend.speculative_type is None + # Requested mode still round-trips so the UI / status can surface it. + assert backend.requested_spec_mode == "ngram-map-k" + + +def test_canonicalize_spec_mode_recognises_ngram_map_variants(): + assert _canonicalize_spec_mode("ngram-map-k") == "ngram-map-k" + assert _canonicalize_spec_mode("ngram-map-k4v") == "ngram-map-k4v" + assert _canonicalize_spec_mode("NGRAM-MAP-K4V") == "ngram-map-k4v" diff --git a/studio/frontend/src/features/chat/chat-settings-sheet.tsx b/studio/frontend/src/features/chat/chat-settings-sheet.tsx index 1ed0a88b63..7fe51d4d98 100644 --- a/studio/frontend/src/features/chat/chat-settings-sheet.tsx +++ b/studio/frontend/src/features/chat/chat-settings-sheet.tsx @@ -575,6 +575,11 @@ export function ChatSettingsPanel({ const loadedSpecDraftNMax = useChatRuntimeStore( (s) => s.loadedSpecDraftNMax, ); + const specDraftPMin = useChatRuntimeStore((s) => s.specDraftPMin); + const setSpecDraftPMin = useChatRuntimeStore((s) => s.setSpecDraftPMin); + const loadedSpecDraftPMin = useChatRuntimeStore( + (s) => s.loadedSpecDraftPMin, + ); const modelRequiresTrustRemoteCode = useChatRuntimeStore( (s) => s.modelRequiresTrustRemoteCode, ); @@ -614,7 +619,9 @@ export function ChatSettingsPanel({ const ctxDirty = customContextLength !== null; const specDirty = speculativeType !== loadedSpeculativeType; const specDraftDirty = specDraftNMax !== loadedSpecDraftNMax; - const modelSettingsDirty = kvDirty || ctxDirty || specDirty || specDraftDirty; + const specPMinDirty = specDraftPMin !== loadedSpecDraftPMin; + const modelSettingsDirty = + kvDirty || ctxDirty || specDirty || specDraftDirty || specPMinDirty; const chatTemplateOverride = useChatRuntimeStore( (s) => s.chatTemplateOverride, ); @@ -1004,6 +1011,7 @@ export function ChatSettingsPanel({ setSpeculativeType(v); if (v !== "mtp" && v !== "mtp+ngram") { setSpecDraftNMax(null); + setSpecDraftPMin(null); } }} > @@ -1066,6 +1074,46 @@ export function ChatSettingsPanel({ /> )} + {(speculativeType === "mtp" || + speculativeType === "mtp+ngram") && ( +