
[Feature] TurboQuant: support hybrid models and uniform quantization#39931

Merged
mgoin merged 19 commits into vllm-project:main from JartX:feature/hybrid_turboquant
May 5, 2026

Conversation

@JartX
Contributor

@JartX JartX commented Apr 15, 2026

TurboQuant support for hybrid models

This PR fixes TurboQuant startup for hybrid models such as Qwen3.5, Qwen3-Next, and similar architectures.

Previously, TurboQuant would fail with a NotImplementedError as soon as it encountered Mamba layers. With this change, hybrid models are now handled correctly: TurboQuant is applied only to full_attention layers.

Additional fixes

While enabling proper hybrid support, this PR also fixes three additional issues:

  • Page-size planner mismatch: The hybrid page-size planner was sizing attention pages using the standard formula, which does not match TurboQuant's packed K|V layout. As a result, every TurboQuant attention layer could trigger an assertion in the page merger. The planner now uses the TurboQuant-specific layout.

  • Incorrect backend selection for excluded layers: If a layer was excluded from TurboQuant — for example because it was a skipped layer, a sliding-window layer, or a Mamba layer — the ROCm/CUDA backend selector could still incorrectly force the TURBOQUANT backend. These layers now correctly fall back to the default backend.

  • ROCm flash_attn_varlen_func incompatibility: On ROCm, upstream flash_attn_varlen_func does not accept out=. A lightweight wrapper now detects that case and copies the result only when needed.
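For readers curious what that wrapper looks like in practice, here is a minimal sketch of the shim pattern — the detection-and-copy shape follows the description above, but the exact names and structure are illustrative, not the PR's code:

```python
import inspect

from flash_attn import flash_attn_varlen_func as _fa_varlen

# Probe once: does this build of flash_attn accept an `out=` kwarg?
# (Upstream ROCm builds do not.)
_SUPPORTS_OUT = "out" in inspect.signature(_fa_varlen).parameters


def flash_attn_varlen_func(*args, out=None, **kwargs):
    """Illustrative shim: emulate `out=` where the kernel lacks it."""
    if out is not None and not _SUPPORTS_OUT:
        # Fallback: run without `out=`, then copy only when needed.
        result = _fa_varlen(*args, **kwargs)
        out.copy_(result)
        return out
    if out is not None:
        kwargs["out"] = out
    return _fa_varlen(*args, **kwargs)
```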

Summary

Overall, this makes TurboQuant work reliably on hybrid architectures while preserving the current behavior and baselines for dense models.

Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX force-pushed the feature/hybrid_turboquant branch from 9d6d814 to 8c054e9 on April 15, 2026 17:52
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request enables TurboQuant support for hybrid models (e.g., attention combined with Mamba or linear-attention) by introducing a mechanism to identify full-attention layers and ensuring proper backend fallback for non-quantized layers. It also includes adjustments for TurboQuant's packed KV layout in page size calculations and a compatibility wrapper for Flash Attention on ROCm. I have no feedback to provide as there are no review comments.

@JartX JartX changed the title [Feature] TurboQuant: support hybrid models and uniform quantization Apr 15, 2026
Signed-off-by: JartX <sagformas@epdcenter.es>
@yewentao256 yewentao256 added the ready label (ONLY add when PR is ready to merge/full CI is needed) Apr 15, 2026
@mergify mergify Bot added the nvidia, rocm (Related to AMD ROCm), and v1 labels Apr 15, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 15, 2026
@JartX JartX marked this pull request as draft April 15, 2026 19:38
Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX marked this pull request as ready for review April 15, 2026 22:27
@gaby

gaby commented Apr 16, 2026

@vibhavagarwal5 @mgoin Thoughts on this?

Comment thread vllm/engine/arg_utils.py Outdated
Comment thread vllm/platforms/cuda.py Outdated
Signed-off-by: JartX <sagformas@epdcenter.es>
@JartX JartX requested a review from tlrmchlsmth as a code owner April 16, 2026 22:31
Signed-off-by: JartX <sagformas@epdcenter.es>
@MidasMining

Cross-validation on Nemotron-H hybrid (Mamba+MoE+Attention) at TP=8

Tested this PR's approach on Nemotron-3-Super-120B-AWQ-4bit (88 layers: ~80 Mamba/MoE + 8 full-attention) and Nemotron-Cascade-2-30B-A3B (52 layers: ~44 Mamba/MoE + 8 full-attention) on 8× RTX A4000, vLLM 0.20.0, driver 580.76.05/CUDA 13.0.

This PR overlaps with #41123 but takes a more principled approach to the page-size unification problem. Sharing data both for validation and to help arbitrate between the two PRs.

Validation: this PR's approach is correct

The lcm(tq_page, skip_page) page-size logic in platforms/interface.py is the right fix. Posted a related diagnosis on #41123 noting that the v0.20 change to slot_size_aligned (power-of-2 → even) was the actual root cause of the page-size unification failure on hybrids — but our local "revert to power-of-2" was a workaround, not a fix. Your LCM-based approach is the proper solution because it handles mixed skip+TQ layers without forcing all slot sizes to be powers of 2.
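As a toy illustration of why the LCM rule composes (the page sizes below are invented; only the lcm() idea comes from the PR):

```python
from math import lcm

# Hypothetical per-token page sizes in bytes:
tq_page_bytes = 96     # packed K|V TurboQuant layout
skip_page_bytes = 256  # unquantized (skip) layer, standard formula

# The unified page must tile both layer types exactly, so take the LCM.
# No power-of-2 assumption needed — mixed skip+TQ layers just work.
print(lcm(tq_page_bytes, skip_page_bytes))  # 768
```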

The is_hybrid boundary-protection skip is also better motivated than #41123's preserve-as-is — Nemotron-H's 8 attention layers in 88-layer Super-120B would lose 50% of TQ-eligible layers under default n=2 boundary protection. Disabling it for hybrids is correct.
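Schematically, that hybrid override amounts to something like the following (names invented for illustration; the PR's real logic lives in get_boundary_skip_layers):

```python
def boundary_skip_layers(attn_layers: list[int], n: int = 2,
                         is_hybrid: bool = False) -> set[int]:
    """Illustrative: attention layers to exclude from TurboQuant."""
    if is_hybrid or len(attn_layers) <= 2 * n:
        # Hybrids have few attention layers (e.g. 8 of 88 on Super-120B);
        # skipping n at each boundary would cost half the TQ-eligible set.
        return set()
    # Dense models: protect the first and last n attention layers.
    return set(attn_layers[:n]) | set(attn_layers[-n:])
```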

Test results

Standard 22-check practical bench (zmq, pplns, expert, nightmare, hiveos):

| Model | Config | Quality | Notes |
| --- | --- | --- | --- |
| Cascade-2 BF16 (30B/3B) | turboquant_3bit_nc | 100% | KV cache 1.7M tokens (PR #40941 buffer-share win) |
| Super-120B AWQ-4bit (120B/12B) | turboquant_3bit_nc | 100% | KV cache 587K tokens at gpu_mem_util=0.92 |

Decode throughput on Super-120B:

Note on Qwen3.5 baseline question

Re: gaby's Apr 18 question about Qwen3.5-35B-A3B-GPTQ-W4A16-G32 not being an official model — it's a community quant of the official Qwen3.5-35B-A3B. For Nemotron-H we tested against the upstream NVIDIA AWQ-4bit and BF16 checkpoints, both showing 100% on our 22-check bench when run with this PR's slot_size_aligned-aware page planner.

Hybrid model attention-layer count concern

For Nemotron-H Super-120B specifically: 8 attention layers out of 88 total. With this PR's hybrid path (no boundary skip), all 8 get TQ. Quality is preserved at 100% on our bench. But for models with even fewer attention layers (e.g., extreme-hybrid future architectures), the absence of any boundary protection might surface quality regressions on aggressive presets (turboquant_3bit_nc). Worth flagging in case the maintainer wants a safety knob — current behavior is correct for the models that exist today.

Suggestion vs #41123

Recommend merging this PR over #41123:

  • page_size_bytes via LCM correctly handles mixed-precision (skip layers + TQ layers) without slot_size_aligned assumptions
  • _get_full_attention_layer_indices is a real model-aware fix; #41123's "remove the rejection" alone leaves the page-size error downstream
  • ROCm flash_attn_varlen_func wrapper is a useful adjacent fix that #41123 doesn't address
  • Boundary-skip override for hybrids is correct (we'd lose 50% of TQ-eligible layers on Super-120B under default n=2)

If the page-planner change here lands, our local slot_size_aligned revert becomes unnecessary.

— MidasMining, 8× RTX A4000 / Nemotron-3-Super-120B + Cascade-2 / vLLM v0.20.0 + this PR

@webcodes-cz

Thanks for this work — we were able to use the hybrid TurboQuant path to run Qwen3.6-27B FP8 on a single RTX 5090 32GB.

Our setup:

  • GPU: RTX 5090 32GB / Blackwell
  • Model: Qwen3.6-27B FP8 with lm_head also quantized to FP8
  • KV cache: --kv-cache-dtype turboquant_k8v4
  • --gpu-memory-utilization 0.96
  • --max-model-len 5120
  • --max-num-seqs 3
  • --max-num-batched-tokens 15360
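Assembled into a launch command, those flags would look roughly like this (the model path is a placeholder, not a published checkpoint name):

```
vllm serve <local-qwen3.6-27b-fp8-checkpoint> \
  --kv-cache-dtype turboquant_k8v4 \
  --gpu-memory-utilization 0.96 \
  --max-model-len 5120 \
  --max-num-seqs 3 \
  --max-num-batched-tokens 15360
```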

Observed result:

  • startup: PASS
  • Czech sanity generation: PASS
  • real concurrent test: 3 parallel requests, each around 5065 prompt tokens + 32 completion tokens
  • no OOM

This is a meaningful improvement for our use case. Before combining FP8 lm_head with hybrid TurboQuant, the same model was too tight on 32GB cards: either startup/profiling OOM or too little KV cache. With this PR path, the bottleneck moved into a usable serving envelope.

For reference:

We also tested turboquant_4bit_nc. It worked for a long-context profile (max_model_len=6144, max_num_seqs=2), but turboquant_k8v4 was the better production-quality profile for us.

@jhsmith409
Contributor

Adding a hybrid-MoE long-context data point on the same hardware class (@webcodes-cz's RTX 5090 32 GiB, but pushing context length).

Running the three production-file changes from this PR (engine/arg_utils.py, model_executor/.../turboquant/config.py, platforms/interface.py) as bind-mount overlays on vllm/vllm-openai:v0.20.0-cu130. The arg_utils.py change was applied surgically (just the hybrid-guard removal + the new get_boundary_skip_layers(model_config) call site), since the PR head is based on post-v0.20.0 main.

Model: cyankiwi/Qwen3.6-35B-A3B-AWQ-4bit — Qwen3-Next-style hybrid (30 GDN + 10 full_attention), compressed-tensors W4A16 int4 MoE. Boot picked the layout up correctly:

TQ hybrid: full-attention layers [3, 7, 11, 15, 19, 23, 27, 31, 35, 39]
Using TURBOQUANT attention backend out of potential backends: ['TURBOQUANT'].
Overriding flash_attn_version to 2.

What works: short-/mid-context decode throughput

--kv-cache-dtype turboquant_4bit_nc vs fp8 baseline, both at --max-model-len 240000 --max-num-batched-tokens 32768 --max-num-seqs 16 --gpu-memory-utilization 0.9346. 1000-word generation sweep, 16 unique prompts, ignore_eos=True, max_tokens=1300, one warmup per N:

| N | fp8 agg / per-req dec | TQ4_nc agg / per-req dec | per-req decode Δ |
| --- | --- | --- | --- |
| 1 | 204.9 / 206.3 | 195.4 / 196.8 | −4.6% |
| 2 | 296.6 / 151.3 | 285.6 / 144.3 | −4.6% |
| 4 | 577.0 / 145.8 | 547.5 / 140.0 | −4.0% |
| 8 | 930.8 / 129.4 | 892.7 / 123.8 | −4.3% |
| 16 | 1628.4 / 112.2 | 1541.2 / 106.7 | −4.9% |

KV pool grew 62,880 → 135,168 tokens (+115%). Effective single-seq context (× 4 hybrid factor) ~251K → ~540K. Decode tok/s loses ~4–5% across all batch sizes — clean trade for the pool win.

What doesn't yet work: long-context prefill

Pushing to --max-model-len 256000 and trying a NIAH-style 250K-token prefill hits a deterministic OOM, always at the same line, regardless of total request length:

File "vllm/model_executor/layers/fla/ops/chunk_o.py", line 161, in chunk_fwd_o
    o = torch.empty_like(v)
torch.OutOfMemoryError: Tried to allocate 256.00 MiB. ...

The 256 MiB ask is constant whether the input is 119K, 196K, or 242K tokens — chunk_fwd_o's output allocation is bounded by max-num-batched-tokens (so 256 MiB at 32768, 128 MiB at 16384), not by total request length. Free memory at OOM time is also constant for a given (util, batched-tokens) — about 75 MiB without any margin patch.
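A back-of-envelope check makes that bound plausible — assuming v is bf16 with an effective width of 4096 (the true head layout inside chunk_fwd_o isn't shown in the traceback):

```python
# Hypothetical shape: v ≈ [num_batched_tokens, effective_width] in bf16 (2 B).
width = 4096  # assumed effective value width
for batched_tokens in (32768, 16384):
    mib = batched_tokens * width * 2 / 2**20
    print(batched_tokens, f"-> {mib:.0f} MiB")  # 32768 -> 256 MiB, 16384 -> 128 MiB
```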

Arithmetic from boot vs runtime suggests the startup profile is under-counting:

  • Profile-time activation peak ≈ 1.87 GiB (back-derived from Available KV cache memory boot log).
  • At runtime OOM, PyTorch has ~30 GiB allocated; subtracting weights (22.4) + reserved KV pool (4.12) leaves ~3.5 GiB of activation. ~1.6 GiB more than profile saw.
  • The KV pool got sized to consume the gap, leaving no chunk_fwd_o headroom.

Two structural reasons the profile likely misses it:

  1. Profile is forced eager (gpu_model_runner.py:5366: force_eager=is_profile). Production long-prefill runs through the inductor-compiled graph at compile_ranges_endpoints=[max_num_batched_tokens] and materialises intermediates the eager profile doesn't.
  2. _dummy_run distributes tokens across max_num_seqs (e.g. 16 × 2048 in our config). Production long-prefill is single-seq × full max-num-batched-tokens — a shape the multi-seq profile may not exercise identically.

A gpu_worker.py overlay that adds a hybrid-GDN safety margin to determine_available_memory (1.0 / 1.25 / 1.35 GiB tested) gets close — 1.25 GiB was the sweet spot, leaves 255.69 MiB free at chunk_fwd_o time vs the 256 MiB ask. Short by 0.3 MiB, with 510 MiB classified as "reserved by PyTorch but unallocated" (caching-allocator fragmentation). It's a band-aid, not a fix.

Suggestion

Long-context generation appears well-validated by the existing tests in this PR. The gap is on the prefill side specifically. Would be worth a closer look at the profile-side coverage for hybrid GDN models:

  • Add a profile pass with a single sequence at max_num_batched_tokens length alongside the existing multi-seq pass; take the max peak. Catches shape-sensitive allocation patterns even if the eager-vs-compiled gap remains.
  • If that doesn't close it, a profile-after-compile pass so the inductor allocation pattern is actually measured (more invasive — the ordering with KV cache allocation needs care).
  • chunk_fwd_o's o = torch.empty_like(v) is the hot allocation site to instrument. The v_new it consumes isn't read after chunk_fwd_o returns, so an in-place / aliased output is theoretically possible but needs kernel-level review.
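The first bullet is cheap to express as pseudocode (the dummy_run signature here is an assumption; the real hooks live in gpu_model_runner.py):

```python
def profile_available_memory(dummy_run, total_free: int,
                             max_num_batched_tokens: int, max_num_seqs: int) -> int:
    """Illustrative: take the worse of two activation-peak measurements."""
    # Today's profile shape: tokens spread across max_num_seqs sequences.
    peak_multi = dummy_run(num_tokens=max_num_batched_tokens, num_seqs=max_num_seqs)
    # Long-prefill shape: one sequence consuming the whole token budget.
    peak_single = dummy_run(num_tokens=max_num_batched_tokens, num_seqs=1)
    return total_free - max(peak_multi, peak_single)
```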

Happy to share the bench scripts and full overlay set if useful for reproducing. The TurboQuant decode-side is a clean win — thanks for this PR.

@jhsmith409
Contributor

Follow-up — NIAH bench script, as offered.

Builds a filler text, inserts N synthetic needles (The magic number for FakeCity_NNN is XXXXX.), uses the /tokenize endpoint to size the prompt accurately to a target percentage of max-model-len, sends the request, parses the response, scores recall. Edit HOST and MODEL (the served-model-name) at the top to match your deployment.

TARGET_TOTAL_TOKENS is the knob — set to 0.90 * 256000 here as the practical ceiling we hit with the gpu_worker.py safety-margin overlay; bump back to 0.98 * 256000 to reproduce the deterministic OOM in chunk_fwd_o.

Run:

python3 bench_needle.py
#!/usr/bin/env python3
"""Needle-in-haystack at 98% of 256K context, sweep N=8,16,32,64,128,256.

Each test: build a ~245K-token filler text, insert N needles ("The magic number
for FakeCity_<i> is <5-digit>"), append a question asking for the enumeration,
send the request, parse the response, score how many needles were correctly
retrieved. Uses vLLM's /tokenize endpoint to size the prompt accurately.
"""

import asyncio
import json
import re
import time

import aiohttp

HOST = "http://localhost:8005"
MODEL = "Qwen3.6-35B"
TARGET_TOTAL_TOKENS = int(0.90 * 256000)  # 230,400 — practical ceiling on RTX 5090 + 1.25 GiB safety margin
NEEDLE_COUNTS = [8, 16, 32, 64, 128, 256]
OUTPUT_BUDGET = 9000  # tokens reserved for the model's enumeration response

FILLER_BLOCK = (
    "In the early afternoon the lighthouse keeper noted that the sea was unusually "
    "calm, the gulls had retreated inland, and a faint smell of iron drifted on the "
    "breeze. He recorded the observation, the tide level, the barometric pressure, "
    "and the direction of the wind in his journal, as he had done every day for the "
    "past nineteen years. The keeper was a methodical man, descended from a long line "
    "of methodical lighthouse keepers, and he believed that careful observation was "
    "the foundation of all useful knowledge. The cliffs to the west were limestone, "
    "veined with quartz, and they had stood against the Atlantic for longer than any "
    "human story could measure. Pelicans nested in the rookeries below, their cries "
    "rising and falling in counterpoint to the surf. The keeper made a sandwich, ate "
    "it slowly, and watched the horizon for a sail or a thunderhead or any small "
    "aberration that might give the day a shape distinct from yesterday. None came. "
    "He drank his tea, brewed strong with a slice of lemon, and reflected that nothing "
    "much had happened in nineteen years and that this was, perhaps, the highest form "
    "of accomplishment a lighthouse keeper could aspire to. "
)


def needle_text(i: int) -> tuple[str, str, str]:
    city = f"FakeCity_{i:03d}"
    magic = f"{(i * 314159 + 7919) % 100000:05d}"
    sentence = f" The magic number for {city} is {magic}. "
    return city, magic, sentence


def build_prompt(num_needles: int, filler_chars: int) -> tuple[str, list[tuple[str, str]]]:
    needles = [needle_text(i) for i in range(num_needles)]
    question = (
        "\n\n=== END OF DOCUMENT ===\n\n"
        f"The document above contains exactly {num_needles} hidden facts, each in "
        "the form: 'The magic number for FakeCity_NNN is XXXXX.' List ALL of them. "
        "Output one per line in this exact format (no other text, no preamble):\n"
        "FakeCity_NNN: XXXXX\n\n"
        "Begin enumeration:\n"
    )
    needle_chars = sum(len(n[2]) for n in needles)
    filler_total = max(0, filler_chars - needle_chars - len(question))
    filler = (FILLER_BLOCK * (filler_total // len(FILLER_BLOCK) + 1))[:filler_total]

    # interleave needles evenly through the filler
    chunk = len(filler) // (num_needles + 1) if num_needles else len(filler)
    parts: list[str] = []
    for i in range(num_needles):
        parts.append(filler[i * chunk : (i + 1) * chunk])
        parts.append(needles[i][2])
    parts.append(filler[num_needles * chunk :])
    prompt = "".join(parts) + question
    return prompt, [(c, m) for c, m, _ in needles]


async def tokenize(session: aiohttp.ClientSession, text: str) -> int:
    async with session.post(
        f"{HOST}/tokenize",
        json={"model": MODEL, "prompt": text, "add_special_tokens": False},
    ) as r:
        r.raise_for_status()
        d = await r.json()
        return d.get("count") or len(d.get("tokens", []))


async def size_prompt_to_target(
    session: aiohttp.ClientSession, num_needles: int, target_tokens: int
) -> tuple[str, list[tuple[str, str]], int]:
    """Build a prompt and binary-search filler size to hit target_tokens (within +-200)."""
    lo, hi = 50_000 * 4, 1_500_000  # filler_chars search range
    best = None
    best_diff = float("inf")
    for _ in range(8):
        mid = (lo + hi) // 2
        prompt, needles = build_prompt(num_needles, mid)
        n_tokens = await tokenize(session, prompt)
        diff = abs(n_tokens - target_tokens)
        if diff < 200:
            return prompt, needles, n_tokens
        if n_tokens > target_tokens:
            hi = mid
        else:
            lo = mid
        if diff < best_diff:  # track the closest attempt, not merely the last one
            best_diff = diff
            best = (prompt, needles, n_tokens)
    return best  # type: ignore[return-value]


async def run_one(session: aiohttp.ClientSession, num_needles: int):
    target_input = TARGET_TOTAL_TOKENS - OUTPUT_BUDGET
    print(f"  [N={num_needles}] sizing prompt to ~{target_input} tokens ...", flush=True)
    t0 = time.perf_counter()
    prompt, needles, actual_tokens = await size_prompt_to_target(
        session, num_needles, target_input
    )
    t_size = time.perf_counter() - t0
    print(
        f"  [N={num_needles}] sized in {t_size:.1f}s — input {actual_tokens} tokens, sending ...",
        flush=True,
    )

    payload = {
        "model": MODEL,
        "prompt": prompt,
        "max_tokens": OUTPUT_BUDGET,
        "temperature": 0.0,
        "top_p": 1.0,
        "stream": True,
        "stream_options": {"include_usage": True},
    }
    t_send = time.perf_counter()
    t_first = None
    completion_tokens = 0
    out_chunks: list[str] = []
    async with session.post(f"{HOST}/v1/completions", json=payload) as r:
        r.raise_for_status()
        async for raw in r.content:
            if not raw:
                continue
            line = raw.decode("utf-8", "ignore").strip()
            if not line.startswith("data:"):
                continue
            data = line[5:].strip()
            if data == "[DONE]":
                break
            try:
                ev = json.loads(data)
            except json.JSONDecodeError:
                continue
            if ev.get("choices") and ev["choices"][0].get("text"):
                if t_first is None:
                    t_first = time.perf_counter()
                out_chunks.append(ev["choices"][0]["text"])
            usage = ev.get("usage")
            if usage:
                completion_tokens = usage.get("completion_tokens", completion_tokens)
    t_end = time.perf_counter()
    if t_first is None:
        t_first = t_end

    response = "".join(out_chunks)
    # parse "FakeCity_NNN: XXXXX" lines
    found: dict[str, str] = {}
    for m in re.finditer(r"FakeCity_(\d{3})\s*[:\-]\s*(\d{5})", response):
        city = f"FakeCity_{m.group(1)}"
        found[city] = m.group(2)
    expected = dict(needles)
    correct = sum(1 for c, v in expected.items() if found.get(c) == v)
    recalled = sum(1 for c in expected if c in found)
    extra = [c for c in found if c not in expected]
    wrong = [(c, found[c], expected[c]) for c in expected if c in found and found[c] != expected[c]]

    return {
        "n": num_needles,
        "input_tokens": actual_tokens,
        "output_tokens": completion_tokens,
        "ttft_s": t_first - t_send,
        "decode_s": max(t_end - t_first, 1e-6),
        "wall_s": t_end - t_send,
        "recall_correct": correct,
        "recall_pct": 100.0 * correct / num_needles if num_needles else 0.0,
        "recalled_total": recalled,
        "wrong_value_count": len(wrong),
        "extra_count": len(extra),
        "first_wrong": wrong[:3],
    }


async def main():
    print(
        f"=== needle-in-haystack  HOST={HOST}  MODEL={MODEL}  "
        f"target={TARGET_TOTAL_TOKENS} tokens (98% of 256K), output_budget={OUTPUT_BUDGET} ==="
    )
    timeout = aiohttp.ClientTimeout(total=900)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        for n in NEEDLE_COUNTS:
            try:
                r = await run_one(session, n)
            except Exception as e:
                print(f"  N={n} FAILED: {type(e).__name__}: {e}")
                continue
            print(
                f"  N={r['n']:>3}  in={r['input_tokens']:>6}t  "
                f"out={r['output_tokens']:>5}t  "
                f"TTFT={r['ttft_s']:5.1f}s  decode={r['decode_s']:5.1f}s  "
                f"wall={r['wall_s']:5.1f}s  "
                f"recall={r['recall_correct']:>3}/{r['n']} ({r['recall_pct']:5.1f}%)  "
                f"wrong_val={r['wrong_value_count']}  extra={r['extra_count']}",
                flush=True,
            )


if __name__ == "__main__":
    asyncio.run(main())

Caveat: TARGET_TOTAL_TOKENS was dialled back to 0.90 after the OOM investigation; change it to int(0.98 * 256000) to reproduce the deterministic chunk_fwd_o OOM at the safety-margin ceiling.

mlzy added a commit to focalcrest/vllm-sm70 that referenced this pull request May 3, 2026
TurboQuant KV cache (turboquant_k8v4) failed on hybrid attention/mamba
models due to page_size_bytes mismatch across MambaSpec, TQFullAttentionSpec,
and FullAttentionSpec layers.

Fix ported from upstream PR vllm-project#39931 (JartX fork):
- platforms/interface.py: add TQ branch in _align_hybrid_block_size that
  computes attn_page_size_1_token using TQFullAttentionSpec with lcm()
  for skip layers; handle turboquant in kv_cache_dtype resolution
- turboquant/config.py: rewrite get_boundary_skip_layers(ModelConfig) for
  hybrid support, add _get_full_attention_layer_indices(); hybrid models
  get no boundary protection (too few attention layers)
- engine/arg_utils.py: simplify to single get_boundary_skip_layers() call
- platforms/cuda.py: add TURBOQUANT to SM70 backend list

Cudagraph benchmark (V100 TP8, block_size=1040): 60-62 tok/s,
near-zero overhead vs baseline fp16 ~60.4 tok/s.

Based-on: vllm-project#39931
Member

@mgoin mgoin left a comment


LGTM, thanks for cleaning this up!

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA May 5, 2026
@mgoin mgoin merged commit 4f2af1a into vllm-project:main May 5, 2026
72 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD May 5, 2026
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA May 5, 2026
@xyehya

xyehya commented May 5, 2026

Great job guys

jsboige added a commit to jsboige/vllm that referenced this pull request May 6, 2026
Migrate prod (GPUs 0,1, port 5002) from Qwen3.6-35B-A3B MoE to
Qwen3.6-27B Dense with TurboQuant K8V4 KV cache, after upstream PR
vllm-project#39931 (TurboQuant hybrid model support, commit 4f2af1a) merged
on 2026-05-05.

New artifacts:
- Dockerfile.qwen36-27b-tq: base nightly e47c98e (post-merge) +
  transformers>=5.0 (qwen3_5 dense model_type) + shm_broadcast.py
  patch carried forward (PR vllm-project#40303 OPEN).
- profiles/medium-qwen36-27b.yml: TP=2 (no EP, Dense), TurboQuant
  K8V4, max_model_len 262144, qwen3_coder + qwen3 parsers,
  preserve_thinking default, watchdog sidecar.

Bench (post-warmup, 2026-05-06):
- KV cache: 516K tokens (vs MoE 322K, +60%)
- Decode single-user: 52-54 tok/s (vs MoE 107, -50%)
- Decode thinking: 50.5 tok/s (vs MoE 116.5, -57%)
- Concurrent 5 (aggregate): 189 tok/s (vs MoE 369, -49%)
- Tool call latency: 0.66s (vs MoE 0.47s, +40%)

Speed regressions trip all 3 of the migration plan's "consider
rollback" thresholds (decode <80, concurrent <200, tool >0.6s).
Upstream quality gains (SWE +3.8, Terminal-Bench +7.8, SkillsBench
+19.5) NOT yet locally validated. MoE profile + image retained
for fast rollback (~10-15 min).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
chaojun-zhang pushed a commit to chaojun-zhang/vllm that referenced this pull request May 6, 2026
…llm-project#39931)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Sandermage <sandermage@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
jsboige added a commit to jsboige/vllm that referenced this pull request May 6, 2026
Same-day cutover sequence 2026-05-06, both candidates rejected:
- 27B Dense+TQ K8V4: tripped all 3 perf rollback thresholds (decode -50%,
  5-concurrent -49%, tool +40%); GSM8K/IFEval gains within sampling noise.
- MoE 35B-A3B+TQ K8V4: booted with 1.49M-token KV (+4.6x vs FP8 322K) but
  EngineCore crashes on first chunked-prefill continuation (workspace
  16.31->29.73 MB, turboquant_attn.py:720). Upstream issue vllm#41726
  already filed by jhsmith409, candidate fix PR vllm-project#40798 open. Our repro
  posted as issue comment confirms persistence post-vllm-project#39931 on hybrid MoE.

Restored production: vllm-qwen36-shmpatched:nightly-f6983f01d-patched1 +
--kv-cache-dtype fp8 (Apr 06 baseline, stable since 2026-04-19). All
smoke tests pass (chat, thinking, tool calling).

Files:
- CLAUDE.md: TQ migration section rewritten as REJECTED, current state
  reverted to MoE+FP8, deployment table updated, 2 entries added to
  rejected models list.
- profiles/medium-qwen36-moe.yml: image + kv-cache-dtype reverted with
  inline rationale.
- Dockerfile.qwen36-27b-tq -> Dockerfile.qwen36-tq (renamed generic, used
  for both 27B and MoE TQ attempts; image vllm-qwen36-tq:nightly-e47c98ef-
  patched1 retained for re-test once vllm-project#40798 merges).
- profiles/medium-qwen36-27b.yml -> archives/2026/medium-qwen36-27b.yml.
  rejected-2026-05-06.
- qwen3_benchmark/lmms_results/qwen3.6-27b/: GSM8K + IFEval results
  preserved as evidence for the rejection rationale.

Upstream tracking:
- Issue: vllm-project#41726
- PR (fix): vllm-project#40798
- Our comment: vllm-project#41726 (comment)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request May 7, 2026
…llm-project#39931)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Sandermage <sandermage@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request May 7, 2026
…llm-project#39931)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Sandermage <sandermage@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
libinta pushed a commit to libinta/vllm that referenced this pull request May 8, 2026
…llm-project#39931)

Signed-off-by: JartX <sagformas@epdcenter.es>
Signed-off-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Jim Smith <jhsmith0@me.com>
Co-authored-by: Sandermage <sandermage@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Libin Tang <libin.tang@intel.com>

Labels

nvidia · ready (ONLY add when PR is ready to merge/full CI is needed) · rocm (Related to AMD ROCm) · v1

Projects

AMD: Done
NVIDIA: Done
