Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 147 additions & 9 deletions studio/backend/core/inference/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,11 +522,50 @@ def _build_ngram_mod_flags(
return []


# Canonical Speculative Decoding modes exposed by the Studio chat UI.
# The dropdown renders five options (auto, mtp, ngram, mtp+ngram, off);
# the load API also accepts legacy values that the original Switch and
# external callers emit (default, draft-mtp, ngram-mod, ngram-simple).
_CANONICAL_SPEC_MODES = {"auto", "mtp", "ngram", "mtp+ngram", "off", "ngram-simple"}
def _build_ngram_map_k_flags(
caps: Optional[dict],
*,
variant: str,
size_n: int = 16,
size_m: int = 24,
min_hits: int = 1,
) -> list[str]:
"""Emit knob flags for ``ngram-map-k`` / ``ngram-map-k4v`` if the
bundled llama-server advertises them (added upstream in #23269 and
sibling commits). Returns ``[]`` when the binary doesn't.
"""
cap_key = (
"supports_ngram_map_k4v"
if variant == "ngram-map-k4v"
else "supports_ngram_map_k"
)
if not (caps and caps.get(cap_key)):
return []
prefix = f"--spec-{variant}"
return [
f"{prefix}-size-n",
str(size_n),
f"{prefix}-size-m",
str(size_m),
f"{prefix}-min-hits",
str(min_hits),
]


# Canonical Speculative Decoding modes exposed by the Studio chat UI
# (5-mode dropdown) plus power-user n-gram variants accepted via the
# load API. ``ngram-map-k`` / ``ngram-map-k4v`` are not in the dropdown
# but the resolver emits them correctly when sent via API.
_CANONICAL_SPEC_MODES = {
"auto",
"mtp",
"ngram",
"mtp+ngram",
"off",
"ngram-simple",
"ngram-map-k",
"ngram-map-k4v",
}
_LEGACY_SPEC_MODE_MAP = {
"default": "auto",
"draft-mtp": "mtp",
Expand All @@ -538,9 +577,10 @@ def _canonicalize_spec_mode(value):
"""Map any accepted ``speculative_type`` input onto a canonical mode.

Returns one of ``auto``, ``mtp``, ``ngram``, ``mtp+ngram``, ``off``,
``ngram-simple``, or ``None`` (callers treat ``None`` as ``auto``).
Unknown strings collapse to ``auto`` so a stale UI value or typo
falls back to the safe platform-aware path.
``ngram-simple``, ``ngram-map-k``, ``ngram-map-k4v``, or ``None``
(callers treat ``None`` as ``auto``). Unknown strings collapse to
``auto`` so a stale UI value or typo falls back to the safe
platform-aware path.
"""
if value is None:
return None
Expand Down Expand Up @@ -641,6 +681,10 @@ def __init__(self):
self._requested_spec_mode: Optional[str] = None
# User-supplied --spec-draft-n-max override (None = platform default).
self._spec_draft_n_max: Optional[int] = None
# User-supplied --spec-draft-p-min override (None = platform default).
# Only emitted for MTP / MTP+Ngram modes; the flag became functional
# in llama.cpp #23269 (default upstream dropped to 0.0).
self._spec_draft_p_min: Optional[float] = None
# KV-cache estimation fields (populated by _read_gguf_metadata)
self._n_layers: Optional[int] = None
self._n_kv_heads: Optional[int] = None
Expand Down Expand Up @@ -932,6 +976,12 @@ def spec_draft_n_max(self) -> Optional[int]:
when the platform default (6 GPU / 3 CPU) is in effect."""
return self._spec_draft_n_max

@property
def spec_draft_p_min(self) -> Optional[float]:
"""User --spec-draft-p-min override on the load, or None when
the llama-server default (0.0 since #23269) is in effect."""
return self._spec_draft_p_min

# ── Binary discovery ──────────────────────────────────────────

@staticmethod
Expand Down Expand Up @@ -1059,7 +1109,8 @@ def probe_server_capabilities(
) -> dict[str, object]:
"""Parse `llama-server --help` for feature flags. Returns
{found, mtp_token, supports_mtp, ngram_mod_flavor,
supports_ngram_mod, spec_draft_n_max_flag}.
supports_ngram_mod, spec_draft_n_max_flag, spec_draft_p_min_flag,
supports_ngram_map_k, supports_ngram_map_k4v}.

``ngram_mod_flavor`` is ``"new"`` when the binary exposes the
post-rename ``--spec-ngram-mod-n-match / -n-min / -n-max`` as
Expand All @@ -1072,6 +1123,13 @@ def probe_server_capabilities(
``spec_draft_n_max_flag`` is the actual flag name the binary
accepts: ``--spec-draft-n-max`` on post-rename builds, or
``--draft-max`` on legacy. ``None`` means n_max cannot be set.

``spec_draft_p_min_flag`` is the flag name when the binary
advertises ``--spec-draft-p-min`` (functional from llama.cpp
#23269 onward; before that it was a knob with no effect on
MTP). ``supports_ngram_map_k`` / ``supports_ngram_map_k4v``
signal that the binary accepts these power-user spec types
with their own knob sets.
"""
bin_path = binary or cls._find_llama_server_binary()
if not bin_path or not Path(bin_path).is_file():
Expand All @@ -1082,6 +1140,9 @@ def probe_server_capabilities(
"ngram_mod_flavor": None,
"supports_ngram_mod": False,
"spec_draft_n_max_flag": None,
"spec_draft_p_min_flag": None,
"supports_ngram_map_k": False,
"supports_ngram_map_k4v": False,
}
try:
mtime = int(Path(bin_path).stat().st_mtime)
Expand All @@ -1095,6 +1156,9 @@ def probe_server_capabilities(
mtp_token: Optional[str] = None
ngram_mod_flavor: Optional[str] = None
spec_draft_n_max_flag: Optional[str] = None
spec_draft_p_min_flag: Optional[str] = None
supports_ngram_map_k = False
supports_ngram_map_k4v = False
try:
result = subprocess.run(
[bin_path, "--help"],
Expand Down Expand Up @@ -1184,6 +1248,24 @@ def _is_real(flag: str) -> bool:
spec_draft_n_max_flag = "--spec-draft-n-max"
elif _is_real("--draft-max"):
spec_draft_n_max_flag = "--draft-max"

# p_min flag: present on builds that include the MTP cleanup
# (llama.cpp #23269) or earlier draft-model spec impls.
if _is_real("--spec-draft-p-min"):
spec_draft_p_min_flag = "--spec-draft-p-min"

# ngram-map-k / ngram-map-k4v: power-user spec types added
# alongside ngram-mod. Each carries its own knob triplet.
supports_ngram_map_k = (
_is_real("--spec-ngram-map-k-size-n")
and _is_real("--spec-ngram-map-k-size-m")
and _is_real("--spec-ngram-map-k-min-hits")
)
supports_ngram_map_k4v = (
_is_real("--spec-ngram-map-k4v-size-n")
and _is_real("--spec-ngram-map-k4v-size-m")
and _is_real("--spec-ngram-map-k4v-min-hits")
)
except (OSError, subprocess.SubprocessError) as exc:
logger.debug(f"llama-server --help probe failed: {exc}")

Expand All @@ -1194,6 +1276,9 @@ def _is_real(flag: str) -> bool:
"ngram_mod_flavor": ngram_mod_flavor,
"supports_ngram_mod": ngram_mod_flavor is not None,
"spec_draft_n_max_flag": spec_draft_n_max_flag,
"spec_draft_p_min_flag": spec_draft_p_min_flag,
"supports_ngram_map_k": supports_ngram_map_k,
"supports_ngram_map_k4v": supports_ngram_map_k4v,
}
cls._capability_cache[cache_key] = info
return info
Expand Down Expand Up @@ -2502,6 +2587,7 @@ def load_model(
cache_type_kv: Optional[str] = None,
speculative_type: Optional[str] = None,
spec_draft_n_max: Optional[int] = None,
spec_draft_p_min: Optional[float] = None,
n_threads: Optional[int] = None,
n_gpu_layers: Optional[int] = None, # Accepted for caller compat, unused
n_parallel: int = 1,
Expand Down Expand Up @@ -2534,6 +2620,7 @@ def load_model(
cache_type_kv = cache_type_kv,
speculative_type = speculative_type,
spec_draft_n_max = spec_draft_n_max,
spec_draft_p_min = spec_draft_p_min,
chat_template_override = chat_template_override,
extra_args = extra_args,
is_vision = is_vision,
Expand Down Expand Up @@ -2915,6 +3002,7 @@ def load_model(
spec_flags = self._build_speculative_flags(
speculative_type = speculative_type,
spec_draft_n_max = spec_draft_n_max,
spec_draft_p_min = spec_draft_p_min,
extra_args = extra_args,
model_identifier = model_identifier,
model_path = model_path,
Expand Down Expand Up @@ -3258,6 +3346,7 @@ def _build_speculative_flags(
*,
speculative_type: Optional[str],
spec_draft_n_max: Optional[int],
spec_draft_p_min: Optional[float] = None,
extra_args: Optional[List[str]],
model_identifier: str,
model_path: Optional[str],
Expand Down Expand Up @@ -3299,6 +3388,7 @@ def _build_speculative_flags(
flags: List[str] = []
# Reset; emit branches re-set on the resolved emission.
self._spec_draft_n_max = None
self._spec_draft_p_min = None
self._speculative_type = None

# Canonical UI-facing requested mode: auto / mtp / ngram /
Expand Down Expand Up @@ -3334,6 +3424,22 @@ def _resolved_draft_n_max() -> int:
return n
return 2 if gpus else 3

def _maybe_emit_p_min(caps: Optional[dict]) -> None:
"""Append --spec-draft-p-min if the user supplied a value AND
the binary advertises the flag (functional from #23269)."""
if spec_draft_p_min is None:
return
p_flag = (caps or {}).get("spec_draft_p_min_flag")
if not p_flag:
logger.warning(
"llama-server lacks --spec-draft-p-min; ignoring "
"spec_draft_p_min override. Run `unsloth studio update`."
)
return
p_min = float(spec_draft_p_min)
self._spec_draft_p_min = p_min
flags.extend([p_flag, f"{p_min:.4f}"])

def _emit_mtp(*, chain_ngram: bool) -> bool:
"""Append --spec-type mtp[/draft-mtp][,ngram-mod] + n-max."""
caps = self.probe_server_capabilities(binary)
Expand Down Expand Up @@ -3376,6 +3482,7 @@ def _emit_mtp(*, chain_ngram: bool) -> bool:
str(draft_n_max),
]
)
_maybe_emit_p_min(caps)
self._speculative_type = "draft-mtp"
chain_label = "chained ngram-mod" if chain_ngram else "MTP-only"
logger.info(f"Spec decoding: {mtp_token} ({chain_label})")
Expand Down Expand Up @@ -3405,6 +3512,25 @@ def _emit_ngram_mod() -> bool:
if effective_mode == "ngram":
_emit_ngram_mod()
return flags
if effective_mode in ("ngram-map-k", "ngram-map-k4v"):
map_caps = self.probe_server_capabilities(binary)
supported = map_caps.get(
"supports_ngram_map_k4v"
if effective_mode == "ngram-map-k4v"
else "supports_ngram_map_k"
)
Comment on lines +3517 to +3521

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic for determining the capability key can be simplified using an f-string, which would make the code more concise and easier to read.

            supported = map_caps.get(f"supports_{effective_mode.replace('-', '_')}")

if not supported:
logger.warning(
f"llama-server lacks --spec-type {effective_mode}; "
"run `unsloth studio update`. Loading without "
"speculative decoding."
)
return flags
flags.extend(["--spec-type", effective_mode])
flags.extend(_build_ngram_map_k_flags(map_caps, variant = effective_mode))
self._speculative_type = effective_mode
logger.info(f"Spec decoding: {effective_mode}")
return flags
if effective_mode == "mtp":
if _mtp_too_small:
logger.warning(
Expand Down Expand Up @@ -3480,6 +3606,7 @@ def _already_in_target_state(
is_vision: bool,
gguf_path: Optional[str] = None,
spec_draft_n_max: Optional[int] = None,
spec_draft_p_min: Optional[float] = None,
) -> bool:
"""True iff the live server already satisfies these load kwargs.

Expand Down Expand Up @@ -3540,6 +3667,16 @@ def _norm(value):
):
return False

# Same treatment for spec_draft_p_min: only relevant when MTP is
# the resolved emit. None on either side means "default" (0.0
# since llama.cpp #23269) and matches anything.
if (
self._speculative_type == "draft-mtp"
and spec_draft_p_min is not None
and abs(float(spec_draft_p_min) - (self._spec_draft_p_min or 0.0)) > 1e-6
):
return False

if (self._chat_template_override or None) != (chat_template_override or None):
return False

Expand Down Expand Up @@ -3608,6 +3745,7 @@ def unload_model(self) -> bool:
self._speculative_type = None
self._requested_spec_mode = None
self._spec_draft_n_max = None
self._spec_draft_p_min = None
self._n_layers = None
self._n_kv_heads = None
self._n_kv_heads_by_layer = None
Expand Down
55 changes: 43 additions & 12 deletions studio/backend/models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,16 @@ def normalize_blank_chat_template_override(
speculative_type: Optional[str] = Field(
None,
description = (
"Speculative decoding mode for GGUF models. Canonical values: "
"'auto' (platform-aware: MTP on MTP GGUFs, ngram-mod fallback "
"for sub-3B), 'mtp' (force draft-mtp only on both GPU and CPU), "
"'ngram' (force ngram-mod only), 'mtp+ngram' (force "
"ngram-mod+draft-mtp chain on both platforms), 'off' (disabled). "
"Legacy values 'default' (-> auto), 'draft-mtp' (-> mtp), "
"'ngram-mod' (-> ngram), and 'ngram-simple' (kept as-is) are "
"still accepted. Ignored for non-GGUF and vision models."
"Speculative decoding mode for GGUF models. Canonical values "
"exposed by the Studio dropdown: 'auto' (platform-aware: MTP "
"on MTP GGUFs, ngram-mod fallback for sub-3B), 'mtp' (force "
"draft-mtp only on both GPU and CPU), 'ngram' (force "
"ngram-mod only), 'mtp+ngram' (force ngram-mod+draft-mtp "
"chain on both platforms), 'off' (disabled). Power-user "
"spec types accepted via API: 'ngram-simple', 'ngram-map-k', "
"'ngram-map-k4v'. Legacy values 'default' (-> auto), "
"'draft-mtp' (-> mtp), 'ngram-mod' (-> ngram) are still "
"accepted. Ignored for non-GGUF and vision models."
),
)
spec_draft_n_max: Optional[int] = Field(
Expand All @@ -93,6 +95,19 @@ def normalize_blank_chat_template_override(
"'mtp' or 'mtp+ngram'."
),
)
spec_draft_p_min: Optional[float] = Field(
None,
ge = 0.0,
le = 1.0,
description = (
"Min draft probability for MTP speculative decoding "
"(--spec-draft-p-min). Drafts with predicted probability "
"below this threshold are rejected. Defaults to 0.0 (no "
"filtering) since llama.cpp #23269; before that the flag "
"existed but was non-functional. Only applied when "
"speculative_type resolves to 'mtp' or 'mtp+ngram'."
),
)
llama_extra_args: Optional[List[str]] = Field(
None,
description = (
Expand Down Expand Up @@ -242,8 +257,9 @@ class LoadResponse(BaseModel):
description = (
"Canonical UI-facing requested speculative decoding mode "
"('auto' / 'mtp' / 'ngram' / 'mtp+ngram' / 'off' / "
"'ngram-simple'), round-tripped from the original LoadRequest "
"via _canonicalize_spec_mode. None when no model is loaded."
"'ngram-simple' / 'ngram-map-k' / 'ngram-map-k4v'), "
"round-tripped from the original LoadRequest via "
"_canonicalize_spec_mode. None when no model is loaded."
),
)
spec_draft_n_max: Optional[int] = Field(
Expand All @@ -253,6 +269,13 @@ class LoadResponse(BaseModel):
"None when the platform default is in effect."
),
)
spec_draft_p_min: Optional[float] = Field(
None,
description = (
"Active --spec-draft-p-min for MTP speculative decoding, or "
"None when the llama-server default (0.0) is in effect."
),
)


class UnloadResponse(BaseModel):
Expand Down Expand Up @@ -376,8 +399,9 @@ class InferenceStatusResponse(BaseModel):
description = (
"Canonical UI-facing requested speculative decoding mode "
"('auto' / 'mtp' / 'ngram' / 'mtp+ngram' / 'off' / "
"'ngram-simple'), round-tripped from the original LoadRequest. "
"None when no model is loaded."
"'ngram-simple' / 'ngram-map-k' / 'ngram-map-k4v'), "
"round-tripped from the original LoadRequest. None when no "
"model is loaded."
),
)
spec_draft_n_max: Optional[int] = Field(
Expand All @@ -387,6 +411,13 @@ class InferenceStatusResponse(BaseModel):
"None when the platform default is in effect."
),
)
spec_draft_p_min: Optional[float] = Field(
None,
description = (
"Active --spec-draft-p-min for MTP speculative decoding, or "
"None when the llama-server default (0.0) is in effect."
),
)
llama_cpp_supports_mtp: bool = Field(
True,
description = (
Expand Down
Loading
Loading