Skip to content

[ROCm][CI] Fix logprob divergence for TitanML/tiny-mixtral under AITER rms_norm#36101

Merged
gshtras merged 4 commits intovllm-project:mainfrom
ROCm:akaratza_fix_lang_mod_s
Mar 9, 2026
Merged

[ROCm][CI] Fix logprob divergence for TitanML/tiny-mixtral under AITER rms_norm#36101
gshtras merged 4 commits intovllm-project:mainfrom
ROCm:akaratza_fix_lang_mod_s

Conversation

@AndreasKaratzas
Copy link
Copy Markdown
Collaborator

@AndreasKaratzas AndreasKaratzas commented Mar 5, 2026

Two tests were flaking (and consistently failing) for TitanML/tiny-mixtral:

FAILED models/language/generation/test_common.py::test_models[True-True-5-32-TitanML/tiny-mixtral]
FAILED models/language/generation/test_common.py::test_models[False-True-5-32-TitanML/tiny-mixtral]

The error message looks like a token ranking disagreement between HF and vLLM, but it's not a logic bug. It's a numeric precision issue that only surfaces with this particular model.

Root cause

TitanML/tiny-mixtral is randomly initialised and never trained. That means its hidden states going into the unembedding layer have no preferred direction, the resulting logit distribution is nearly uniform across the 32k vocab (98% of max entropy). As a consequence, the gap between the top-1 and top-2 token probabilities is tiny, on the order of O(1/sqrt(V)) \approx 0.006 nats.

AITER's plain rms_norm accumulates the variance term in bfloat16 rather than float32. At typical post-norm magnitudes (~2), that's a rounding error of 2^-6 = 0.015625 per call, which is already larger than the rank gap above. Over 24 decode steps, the error accumulates in the KV cache and eventually flips the argmax. The vLLM and HF outputs diverge not because the logic is wrong, but because the two paths are hitting different sides of a near-tie.

This doesn't affect trained models because their rank-1/rank-2 gap is ~1-3 nats, roughly 100x the per-step AITER error. The same kernel runs fine on everything else in the test suite.

Fix

disable_aiter_plain_rmsnorm() patches dispatch_rocm_rmsnorm_func for the duration of the test so the plain (non-fused) rms_norm path falls back to the native float32 kernel. The fused path (rms_norm2d_with_add, used when with_fused_add=True) is left on AITER, only the plain path is redirected.

This is applied in test_common.py whenever use_rocm_aiter=True and the model is in AITER_MODEL_LIST.

Proof chain

A separate file test_custom.py was written to prove each step of the causal chain. All 5 tests pass.

cc @kenroche @mawong-amd

Proof that AITER `rms_norm` is actually off just by 1 ULP and that it's really the randomly initialized model that is causing this havoc

Test 1: random weights -> near-uniform logits (CPU only)

Constructs synthetic weights matching tiny-mixtral's TruncNormal(0, 0.02) initialiser and measures entropy and the rank-1/rank-2 gap distribution. Proves ~5-9% of token positions sit below 1 ULP of bfloat16.

Test 2: AITER plain rms_norm has non-zero error

Calls rocm_aiter_ops.rms_norm directly against a float32 reference. Asserts max_diff > 0.

Test 3: native rms_norm is bit-exact

Same comparison with the native kernel. Asserts max_diff == 0.0. Confirms the fix fully eliminates the error.

Test 4: real checkpoint is also near-uniform

Downloads the actual lm_head.weight from TitanML/tiny-mixtral and repeats Test 1's assertions with real tensors. Entropy = 98% of max, near-tie fraction = 9.2%.

Test 5: error accumulates layer by layer

Runs AITER and native rms_norm in parallel through the model's norm stack, accumulating separate residual streams. Measures drift at each step.

Expected output:

Accumulated divergence through TitanML/tiny-mixtral residual stream (hidden_size=1024):
  norm             out_max  out_n_diff  out_mean_diff  h_drift_max
  ----------------------------------------------------------------
  0/attn          0.015625         473       0.004333     0.031250  <<< first divergence
  0/ffn           0.015625         432       0.004517     0.062500
  1/attn          0.031250         361       0.005798     0.125000
  1/ffn           0.031250         335       0.005341     0.187500
  final           0.031250         323       0.006226     0.250000

  => First divergence at norm: 0/attn

The h_drift_max column roughly doubles every layer because each norm receives the already-drifted hidden state from the step before. By the final norm, the accumulated drift in the residual stream is 0.25, well above the rank gap of ~0.006 nats.

The `test_custom.py` script used for all of the above
#!/usr/bin/env python3
"""
Proof chain: why TitanML/tiny-mixtral (randomly-initialised) makes logprob
comparison tests diverge when AITER's plain rms_norm is used, and why the
fault is in the model's weight distribution, not in vLLM or AITER.

The five tests below are self-contained and build on each other:

  [1] test_random_weights_produce_near_uniform_logits  (CPU-only, no model)
        Asserts:
          mean entropy    > 0.90 * log(32000) = 9.34 nats    => near-uniform
          mean entropy observed in practice: ~9.55 nats (92% of max)
          trained model entropy for comparison: ~2-4 nats
          fraction of near-ties (gap < 1 ULP of bf16) > 1%   => near-ties
                                                                exist and are
                                                                hit in 24 steps

  [2] test_aiter_dispatch_has_ulp_error  (requires VLLM_ROCM_USE_AITER=1)
        Asserts:
          max |AITER(x) - f32_ref(x)| > 0.0                 => AITER error
                                                                is non-zero

  [3] test_native_dispatch_matches_f32_reference  (requires ROCm)
        Asserts:
          max |native(x) - f32_ref(x)| == 0.0               => native is
                                                                bit-exact

  [4] test_tiny_mixtral_actual_weights_near_uniform  (requires model download)
        Same assertions as [1] but with the real lm_head.weight from the
        TitanML/tiny-mixtral checkpoint instead of synthetic weights.
        Confirms the real checkpoint has the same near-uniform property.

  [5] test_tiny_mixtral_layer_divergence  (requires VLLM_ROCM_USE_AITER=1 + model)
        Simulates the residual stream through the model's layernorm stack.
        Runs AITER and native rms_norm in parallel at each layer and
        accumulates both residual streams independently.
        Prints max diff at every norm and flags the first diverging layer.

The "peculiarity" that exposes the bug:
  [1][4] prove the random model's rank-1/rank-2 gap is < 0.016 nats.
  [2]    proves the AITER per-call error is > 0.0 nats (empirically ~0.016).
  => The error scale and the gap scale are the same.
  => After even ONE decode step the AITER error can already exceed the gap
     and flip the argmax.  After 24 steps (as in test_common.py) divergence
     is nearly certain.

  With a trained model this interaction does not occur: a trained model's
  rank-1/rank-2 gap is ~1-3 nats, which is ~100x the per-step AITER error.
  The same AITER path runs without any observable degradation.

  [3] confirms that switching to the native kernel eliminates the error.
  [5] pinpoints exactly which layer in tiny-mixtral first diverges.
  disable_aiter_plain_rmsnorm() in tests/utils.py applies that switch for
  tests that use vllm_runner.

Run:
    pytest tests/models/language/generation/test_custom.py -v          # [1][3][4]
    VLLM_ROCM_USE_AITER=1 pytest tests/models/language/generation/test_custom.py -v  # all
"""
import math

import pytest
import torch

import vllm.model_executor.layers.layernorm as _layernorm
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.layernorm import rms_norm as _native_rms_norm

EPS = 1e-5
MODEL = "TitanML/tiny-mixtral"


def _load_hf_model():
    """Load MODEL with transformers in bfloat16; skip the test if unavailable."""
    try:
        from transformers import AutoModelForCausalLM
        return AutoModelForCausalLM.from_pretrained(
            MODEL, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
        )
    except Exception as e:
        pytest.skip(f"Could not load {MODEL}: {e}")


def test_random_weights_produce_near_uniform_logits():
    """[1] Prove that randomly-initialised weights produce near-uniform logits.

    Models like TitanML/tiny-mixtral are initialised from TruncNormal(0, 0.02)
    and never trained.  With random weights the residual stream after each
    layer is statistically isotropic: the hidden state going into the
    unembedding has no preferred direction.  The resulting logit vector
    (hidden @ unembedding) is effectively i.i.d. normal, and after softmax
    over V = 32000 tokens the distribution is near-uniform.

    This test constructs that exact situation and asserts two bounds that
    connect directly to the divergence observed in test_common.py:

      (a) mean entropy > 0.90 * log(32000) = 9.34 nats
          The logit scale from (hidden @ unembedding * 0.02) is ~1.28 std,
          which reduces entropy to ~9.55 nats (92% of maximum).  A trained
          model in contrast has entropy of ~2-4 nats.  The threshold of 0.90
          (= 9.34 nats) sits comfortably above trained models and is met by
          any randomly-initialised model with this initialiser.

      (b) fraction of near-ties (rank-1/rank-2 gap < 1 ULP of bfloat16) > 1%
          0.016 nats is 1 ULP of bfloat16 at magnitude ~2 (= 2^-6 = 0.015625).
          The MEAN gap for a random model is ~0.32 nats (the Gaussian order-
          statistic formula gives sigma/sqrt(2*ln(V)) ~= 1.28/4.55 ~= 0.28).
          But the tail of the gap distribution is exponential: about 5% of
          token positions have a gap below 1 ULP.  Over 24 decode steps each
          contributing to the KV cache, the accumulated AITER error is almost
          certain to flip at least one of these near-ties.

    Combined with test_aiter_dispatch_has_ulp_error [2]:
      ~5% of gaps < 0.016  AND  AITER error > 0 per step
      => over 24 steps the accumulated KV-cache error will hit a near-tie
      => argmax is not stable under AITER's bfloat16 accumulation.

    This test is CPU-only and requires no external model download.
    """
    VOCAB = 32_000
    HIDDEN = 4_096
    N = 500  # independent forward-pass samples

    torch.manual_seed(0)
    # Isotropic hidden state: what a randomly initialised residual stream gives.
    hidden = torch.randn(N, HIDDEN)
    # Unembedding drawn from the same initialiser as tiny-mixtral (scale=0.02).
    unembedding = torch.randn(HIDDEN, VOCAB) * 0.02
    logits = hidden @ unembedding

    log_probs = torch.log_softmax(logits.float(), dim=-1)
    entropy = -(log_probs * log_probs.exp()).sum(-1)
    max_entropy = math.log(VOCAB)

    sorted_lp, _ = log_probs.sort(dim=-1, descending=True)
    gaps = sorted_lp[:, 0] - sorted_lp[:, 1]

    # (a) Output distribution is near-uniform (>> trained model entropy of 2-4 nats).
    assert entropy.mean().item() > 0.90 * max_entropy, (
        f"mean entropy {entropy.mean():.3f} should be "
        f">= 0.90 * log({VOCAB}) = {0.90 * max_entropy:.3f} nats"
    )

    # (b) Near-ties (gap < 1 ULP of bfloat16) occur with meaningful frequency.
    # The mean gap is ~0.32 nats, but the tail of the exponential-shaped gap
    # distribution puts ~5% of positions below 1 ULP = 0.016 nats.  Over 24
    # decode steps this makes divergence nearly certain when AITER error is > 0.
    one_ulp_bf16 = 0.016  # 2^-6 at bfloat16 magnitude ~2
    frac_near_ties = (gaps < one_ulp_bf16).float().mean().item()
    assert frac_near_ties > 0.01, (
        f"{100 * frac_near_ties:.1f}% of positions have rank-1/rank-2 gap "
        f"< 1 ULP of bfloat16 ({one_ulp_bf16} nats); expected > 1%. "
        f"Mean gap = {gaps.mean():.3f} nats (trained model: ~1-3 nats). "
        f"See test_aiter_dispatch_has_ulp_error: AITER error > 0 per step, "
        f"so over 24 steps a near-tie is almost certain to be flipped."
    )


@pytest.fixture
def bf16_input():
    torch.manual_seed(0)
    x = torch.randn(1, 512, dtype=torch.bfloat16, device="cuda")
    w = torch.ones(512, dtype=torch.bfloat16, device="cuda")
    x_f32 = x.float()
    ref = (
        x_f32 / (x_f32.pow(2).mean(-1, keepdim=True) + EPS).sqrt() * w.float()
    ).bfloat16()
    return x, w, ref


def test_aiter_dispatch_has_ulp_error(bf16_input):
    """[2] Prove that AITER's plain rms_norm introduces a non-zero error.

    Calls rocm_aiter_ops.rms_norm directly (bypassing the dispatch function
    so the test is independent of the production default).  Skips when AITER
    RMSNorm is not enabled (set VLLM_ROCM_USE_AITER=1 to run).

    AITER accumulates the variance term in bfloat16 instead of float32.
    At the bfloat16 magnitudes typical of post-norm hidden states (~2),
    1 ULP = 2^-6 = 0.015625 nats.  The assertion below proves the error
    is non-zero; the actual value observed is ~0.016 nats, matching that
    1 ULP bound.

    Cross-reference with test_random_weights_produce_near_uniform_logits [1]:
      [1] asserts rank-1/rank-2 gap < 0.016 nats for a random model.
      [2] asserts AITER error > 0.0 nats (empirically ~0.016).
      => The error and the gap are at the same scale.  The argmax is not
         stable: after enough decode steps the accumulated KV-cache drift
         will exceed the gap and flip the winner.
    """
    if not rocm_aiter_ops.is_rmsnorm_enabled():
        pytest.skip("AITER RMSNorm not enabled -- set VLLM_ROCM_USE_AITER=1")
    x, w, ref = bf16_input
    out = rocm_aiter_ops.rms_norm(x, w, EPS)
    max_diff = (out - ref).abs().max().item()
    assert max_diff > 0.0, (
        f"AITER plain rms_norm max error = {max_diff:.6f}; expected > 0. "
        f"The error (~0.016 nats = 1 ULP of bfloat16 at magnitude ~2) is at "
        f"the same scale as the rank-1/rank-2 gap of a random model "
        f"(< 0.016 nats, proved by test_random_weights_produce_near_uniform_logits)."
    )


def test_native_dispatch_matches_f32_reference(bf16_input):
    """[3] Prove that the native rms_norm is bit-exact vs the float32 reference.

    The native kernel promotes to float32 internally for variance accumulation,
    so there is no bfloat16 rounding error.  This test confirms that switching
    from AITER to native (which disable_aiter_plain_rmsnorm() in tests/utils.py
    does at test time) fully eliminates the per-call error proved in [2].
    """
    out = _native_rms_norm(bf16_input[0], bf16_input[1], EPS)
    assert (out - bf16_input[2]).abs().max().item() == 0.0, (
        "native rms_norm should be bit-exact vs the float32 reference"
    )


def test_tiny_mixtral_actual_weights_near_uniform():
    """[4] Same near-uniform check as [1] but with TitanML/tiny-mixtral's real weights.

    [1] proves the property holds for the TruncNormal(0, 0.02) weight family.
    [4] proves it holds for the specific checkpoint -- using the actual lm_head.weight
    tensor from disk, not synthetic weights.

    tiny-mixtral config: hidden_size=1024, vocab_size=32000.
    """
    m = _load_hf_model()
    lm_w = m.lm_head.weight.detach().float()  # (vocab, hidden)
    vocab, hidden = lm_w.shape

    torch.manual_seed(0)
    h = torch.randn(500, hidden)
    log_probs = torch.log_softmax(h @ lm_w.T, dim=-1)
    entropy = -(log_probs * log_probs.exp()).sum(-1)
    max_entropy = math.log(vocab)

    sorted_lp, _ = log_probs.sort(dim=-1, descending=True)
    gaps = sorted_lp[:, 0] - sorted_lp[:, 1]
    frac = (gaps < 0.016).float().mean().item()

    print(
        f"\n[{MODEL}] actual lm_head: "
        f"entropy={entropy.mean():.3f} nats "
        f"({100 * entropy.mean() / max_entropy:.1f}% of max), "
        f"near-tie fraction={100 * frac:.1f}%"
    )

    assert entropy.mean().item() > 0.90 * max_entropy, (
        f"entropy {entropy.mean():.3f} < 0.90 * log({vocab}) = "
        f"{0.90 * max_entropy:.3f} nats"
    )
    assert frac > 0.01, (
        f"{100 * frac:.1f}% near-ties with actual {MODEL} lm_head; expected > 1%"
    )


def test_tiny_mixtral_layer_divergence():
    """[5] Trace accumulated AITER error through tiny-mixtral's layernorm stack.

    Simulates the residual stream by applying each layer's layernorm weights in
    sequence.  At every norm step both the AITER path and the native path receive
    the hidden state that has accumulated from all previous layers in their own
    path, so errors compound realistically.

    Per norm step, prints:
      out_max       max |AITER_out - native_out| across all hidden dimensions
      out_n_diff    number of elements where the two outputs differ
      out_mean_diff mean |diff| over the differing elements only
      h_drift_max   max |h_aiter - h_native| after the residual add
                    (shows how divergence accumulates in the hidden state)
    Marks the first step where out_max > 0.

    tiny-mixtral has 2 decoder layers, each with input_layernorm and
    post_attention_layernorm, plus one final model.norm -- 5 norm calls total.

    Requires VLLM_ROCM_USE_AITER=1.
    """
    if not rocm_aiter_ops.is_rmsnorm_enabled():
        pytest.skip("AITER RMSNorm not enabled -- set VLLM_ROCM_USE_AITER=1")
    if not torch.cuda.is_available():
        pytest.skip("requires CUDA/ROCm")

    m = _load_hf_model()
    eps = getattr(m.config, "rms_norm_eps", EPS)
    hidden_size = m.config.hidden_size  # 1024 for tiny-mixtral

    torch.manual_seed(0)
    h_aiter = torch.randn(1, hidden_size, dtype=torch.bfloat16, device="cuda")
    h_native = h_aiter.clone()

    hdr = (
        f"  {'norm':<12}  {'out_max':>10}  {'out_n_diff':>10}  "
        f"{'out_mean_diff':>13}  {'h_drift_max':>11}"
    )
    print(f"\nAccumulated divergence through {MODEL} residual stream "
          f"(hidden_size={hidden_size}):")
    print(hdr)
    print("  " + "-" * (len(hdr) - 2))

    first = None

    def _step(norm_mod, label):
        nonlocal first, h_aiter, h_native
        w = norm_mod.weight.to(dtype=torch.bfloat16, device="cuda")
        out_a = rocm_aiter_ops.rms_norm(h_aiter, w, eps)
        out_n = _native_rms_norm(h_native, w, eps)

        abs_diff = (out_a - out_n).abs()
        out_max = abs_diff.max().item()
        out_n_diff = (abs_diff > 0).sum().item()
        out_mean = abs_diff[abs_diff > 0].mean().item() if out_n_diff > 0 else 0.0

        h_aiter = out_a + h_aiter
        h_native = out_n + h_native
        h_drift = (h_aiter - h_native).abs().max().item()

        tag = "  <<< first divergence" if (out_max > 0 and first is None) else ""
        print(
            f"  {label:<12}  {out_max:>10.6f}  {out_n_diff:>10d}  "
            f"{out_mean:>13.6f}  {h_drift:>11.6f}{tag}"
        )
        if out_max > 0 and first is None:
            first = label

    for i, layer in enumerate(m.model.layers):
        _step(layer.input_layernorm, f"{i}/attn")
        _step(layer.post_attention_layernorm, f"{i}/ffn")

    _step(m.model.norm, "final")

    assert first is not None, (
        "expected AITER error in at least one norm layer but saw none"
    )
    print(f"\n  => First divergence at norm: {first}")

…R rms_norm

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively addresses a numerical divergence issue with AITER's rms_norm on ROCm for the TitanML/tiny-mixtral model, which was causing test flakiness. The root cause, related to bfloat16 precision in variance accumulation, is well-understood and clearly explained.

The fix is a targeted monkeypatch that swaps the problematic AITER kernel with a more precise native implementation, but only for the specific test case where the issue occurs. This is a clean and appropriate solution for resolving CI instability without affecting production code. The implementation is well-documented and correctly scoped. I have no further comments.

Note: Security Review is unavailable for this PR.

@AndreasKaratzas AndreasKaratzas added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 6, 2026
Copy link
Copy Markdown
Collaborator

@gshtras gshtras left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest opening an issue for ROCm/aiter.git about this even though currently it only affects a certain range of values

@gshtras gshtras merged commit 1e0f917 into vllm-project:main Mar 9, 2026
19 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Mar 9, 2026
@gshtras gshtras deleted the akaratza_fix_lang_mod_s branch March 9, 2026 17:07
cong-or pushed a commit to cong-or/vllm that referenced this pull request Mar 10, 2026
…R rms_norm (vllm-project#36101)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: cong-or <conchubhar.gannon@gmail.com>
ZehaoLu98 added a commit to ZehaoLu98/vllm that referenced this pull request Mar 10, 2026
commit 8d983d7cd661aae1ac8781f67fbbff017db4d0af
Author: Nick Hill <nickhill123@gmail.com>
Date:   Tue Mar 10 14:55:21 2026 -0700

    [Model Runner V2] Add initial CI tests (#36041)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit 65b2f405dca824adad17a42a71c908c6ebbcfd9a
Author: Nick Hill <nhill@redhat.com>
Date:   Tue Mar 10 13:20:02 2026 -0700

    [Core] Simplify core kv-cache blocks initialization logic (#36521)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit 2a68464c5bf1a26821afe76cf49dc53f75b87e98
Author: Nick Hill <nhill@redhat.com>
Date:   Tue Mar 10 11:17:26 2026 -0700

    [Test] `test_async_scheduling.py` improvements (#36340)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit bdd8981dab8d8c6ae88a3f605d04ec5243088e5a
Author: Zhengxu Chen <zhxchen17@fb.com>
Date:   Tue Mar 10 12:34:35 2026 -0400

    [compile] Apply stored functorch config while finalizing loaded artifacts. (#36582)

    Signed-off-by: zhxchen17 <zhxchen17@fb.com>

commit f088a831dd6c35d995c4232cc2462c024c61925b
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   Tue Mar 10 09:30:56 2026 -0700

    [Model Runner V2] Use unpadded num_tokens for PW CUDA graph attn metadata (#36626)

    Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

commit f83b933b84b85ee54121575fc347881b35090616
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Tue Mar 10 16:18:28 2026 +0000

    [CI] Bump `mypy` version to 1.19.1 (#36104)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 82f3f30e266e24b26c46916a8c9daaea7d5e32bd
Author: Pleaplusone <ygan@amd.com>
Date:   Wed Mar 11 00:14:35 2026 +0800

    [ROCm][Perf] Enable `sparse_mla`'s cudagraph on ROCm platform (#35719)

    Signed-off-by: ganyi <ygan@amd.com>

commit 9095cbbfb6f68f3f7abc7f55c74768e9f7b1d0a7
Author: Matthew Bonanni <mbonanni@redhat.com>
Date:   Tue Mar 10 12:14:31 2026 -0400

    [Bugfix][Sparse MLA] report indexer CG support properly (#36519)

    Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>

commit 721ae79f50c5f85b301d05f1db71372b1ca85dd6
Author: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com>
Date:   Tue Mar 10 09:14:27 2026 -0700

    Improvements to wvSplitKrc skinny GEMM solution (#34304)

    Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>

commit aefc59f088665b23c0285c7f77c32b365efaa5dc
Author: AllenDou <allen.dou@hotmail.com>
Date:   Tue Mar 10 23:14:21 2026 +0800

    FunASR model bugfix (#36633)

    Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com>
    Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com>

commit d88f28da05b12bc7d63ebe3dcedf445ecb274343
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Tue Mar 10 15:03:18 2026 +0000

    Fix `hf_override_fn` when it modifies `model_type` (#35200)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 106ff69c4eb4921d33341a96b9c3d6db9d12ba76
Author: Srinivasoo7 <194645829+Srinivasoo7@users.noreply.github.com>
Date:   Tue Mar 10 09:43:40 2026 -0500

    feat(kv-offload): Strategy A — StoreReusedOffloadingManager gates CPU stores on reuse frequency (#35342)

    Signed-off-by: srinivas_oo7 <Sriusa4414@gmail.com>
    Signed-off-by: Sriusa4414@gmail.com
    Signed-off-by: Srinivasoo7 <158864704+Srinivasoo7@users.noreply.github.com>
    Co-authored-by: srinivas_oo7 <sklinkedin0120@gmail.com>
    Co-authored-by: Srinivasoo7 <158864704+Srinivasoo7@users.noreply.github.com>
    Co-authored-by: Or Ozeri <oro@il.ibm.com>

commit ca5fb4bbd85244fafba72fb91523c657025998a3
Author: Jiangyun Zhu <riverclouds.zhu@qq.com>
Date:   Tue Mar 10 22:39:01 2026 +0800

    [Bugfix] Avoid merging empty-only partitions into splitting-op subgraphs (#36595)

    Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>

commit cf88b23749187b9a31406925d3f9e966fc4c566b
Author: Alvin Tang <104285249+alvinttang@users.noreply.github.com>
Date:   Tue Mar 10 22:22:40 2026 +0800

    fix: check HTTP status in batch read_file to prevent silent failures (#36397)

    Signed-off-by: gambletan <ethanchang32@gmail.com>
    Co-authored-by: gambletan <ethanchang32@gmail.com>
    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

commit a3189a08b0d3de44dd6d49c5d883abf29ac1e6fa
Author: wang.yuqi <yuqi.wang@daocloud.io>
Date:   Tue Mar 10 21:32:25 2026 +0800

    [Model] Consolidate score logic by introduce score_type (#36479)

    Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>

commit 409c4e632d58acc7f2a2f66e7554776c78bb65ad
Author: SoluMilken <ypiheyn.imm02g@g2.nctu.edu.tw>
Date:   Tue Mar 10 21:25:37 2026 +0800

    [Misc] fix typo: homogenous-> homogeneous (2 lines change) (#36508)

    Signed-off-by: SoluMilken <ypiheyn.imm02g@g2.nctu.edu.tw>

commit 8850738b700cca34448fbafbc8ac41bcad5a2e17
Author: Raushan Turganbay <raushan@huggingface.co>
Date:   Tue Mar 10 14:20:47 2026 +0100

    [Bugfix] Fix processor signature (#36630)

    Signed-off-by: raushan <raushan@huggingface.co>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 234860399b9d390bf59bfe1f19c2e2304ac5c806
Author: Mark McLoughlin <markmc@redhat.com>
Date:   Tue Mar 10 13:20:41 2026 +0000

    [Frontend][Core] Revert "Add shutdown timeout" (#34730 and #36270) (#36628)

    Signed-off-by: Mark McLoughlin <markmc@redhat.com>

commit c88510083b8d6b4fa7a42ae29bc27ff6adc181ee
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Tue Mar 10 12:05:34 2026 +0000

    Fix Qwen2.5-VL test for Transformers v5 (#36532)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 4ff8c3c8f9ece010a1d0e376f5cc1b468b95f366
Author: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Date:   Tue Mar 10 14:32:20 2026 +0400

    [BUGFIX][Mamba][Qwen3.5] Zero freed SSM cache blocks on GPU (#35219)

    Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>

commit 507ddbe9927f421a1d574b283d1611044859a30d
Author: Chang Su <chang.s.su@oracle.com>
Date:   Tue Mar 10 03:29:59 2026 -0700

    feat(grpc): extract gRPC servicer into smg-grpc-servicer package, add --grpc flag to vllm serve (#36169)

    Signed-off-by: Chang Su <chang.s.su@oracle.com>
    Co-authored-by: Nick Hill <nhill@redhat.com>

commit ddbb0d230a3592106ac9f5f7f4e9a861863fcbee
Author: Nick Hill <nhill@redhat.com>
Date:   Tue Mar 10 00:24:58 2026 -0700

    [Model Runner V2] Fix mm input embeddings lookup (#36588)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit 9efc3bdcd6749f6d0ba26b12aee27cc8829c6f93
Author: Nick Hill <nhill@redhat.com>
Date:   Tue Mar 10 00:23:42 2026 -0700

    [Model Runner V2] Fix `_compute_slot_mappings_kernel` for chunked prefill (#36580)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit 156e33553ccdba940fec83a720290b30d2686ee8
Author: amirkl94 <203507526+amirkl94@users.noreply.github.com>
Date:   Tue Mar 10 08:11:27 2026 +0200

    Fix: Re-Enable EP for trtllm MoE FP8 backend (#36494)

    Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>

commit d0cd736caadafea1ec1721737af432d8b0a7e919
Author: hallerite <git@hallerite.com>
Date:   Mon Mar 9 22:30:51 2026 -0700

    [Bugfix] Fix `RuntimeError: Already borrowed` that degrades VLM serving throughput under concurrent load. (#36557)

    Signed-off-by: hallerite <hallerite@users.noreply.github.com>
    Signed-off-by: hallerite <git@hallerite.com>
    Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

commit 195c9972037034355c5e85207f611aa09023cb66
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Tue Mar 10 05:29:17 2026 +0000

    Fix LFM2 MoE test for Transformers v5 (#36534)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 04b67d8f62cab3a1832df5c6ed840f8a6afccaf9
Author: Zhuohan Li <zhuohan123@gmail.com>
Date:   Mon Mar 9 20:56:54 2026 -0700

    Remove unused disable_fallback field (#36546)

commit 7279374f9108652296a8f38b6f9c7f0585a0cda4
Author: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date:   Mon Mar 9 23:55:58 2026 -0400

    [Perf] Compute maxsim in worker side, reducing redundant copies, 2.7% E2E throughput improvement (#36159)

    Signed-off-by: yewentao256 <zhyanwentao@126.com>

commit 006aea17d7de338ab9f9e13bfe566715782d19a4
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   Mon Mar 9 20:02:02 2026 -0700

    [BugFix] Remove incorrect assert in split_decodes_and_prefills (#36553)

    Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

commit 0836be3b03c9f4a4da7d2eba0d3e8cbe5511f6bf
Author: Hojin Yang <57383540+effortprogrammer@users.noreply.github.com>
Date:   Tue Mar 10 11:59:19 2026 +0900

    [Model] Add HyperCLOVAX-SEED-Think-32B vision-language model support (#31471)

    Signed-off-by: effortprogrammer <yhjhoward7@gmail.com>
    Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

commit 4e95ec111cd179f2ab0f6931bf57663f828a51ec
Author: Ajay Anubolu <124525760+AjAnubolu@users.noreply.github.com>
Date:   Mon Mar 9 19:16:26 2026 -0700

    [Bugfix] Fix Qwen3-Next in_proj_ba weight sharding with TP > 1 (#36242)

    Signed-off-by: AjAnubolu <anuboluajay@gmail.com>

commit 179547d62c73e7174bf42b8ca0a34177ac3a5c9e
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Mon Mar 9 19:55:20 2026 -0500

    [ROCm][CI] Fix ROCm GPT-OSS Eval test group (#36179)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit f85b4eda3a22fedd885ef31650c825d56867587e
Author: youkaichao <youkaichao@gmail.com>
Date:   Tue Mar 10 07:49:47 2026 +0800

    [bugfix] fix nvlink for nixl/ucx (#36475)

    Signed-off-by: youkaichao <youkaichao@gmail.com>

commit 2a194ddd72a0cc5b6c404a694a64197d0c572f5b
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   Mon Mar 9 15:14:51 2026 -0700

    [Model Runner V2] Add model_state inputs to CUDA graph capture (#36544)

    Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

commit 203a7f27dac2197ddcf5bb1cfd105596a19ea990
Author: Shaun Kotek <93727115+shaunkotek@users.noreply.github.com>
Date:   Tue Mar 10 00:11:41 2026 +0200

    add nemotron v3 reasoning parser (#36393)

    Signed-off-by: Shaun Kotek - Nvidia <skotek@nvidia.com>
    Co-authored-by: root <root@gpu-259.slurm-workers-slurm.slurm.svc.cluster.local>

commit 483463f735c41c36a41431044fa537dc4c81fc3c
Author: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Date:   Mon Mar 9 16:58:45 2026 -0400

    [MRV2] Extensible CG dispatch rework  (#35959)

    Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

commit 4e571ce6433b6768950becda40d55cb4f24741ce
Author: Matthew Bonanni <mbonanni@redhat.com>
Date:   Mon Mar 9 14:43:06 2026 -0400

    [MTP][Misc] Clean up dead code (#36507)

    Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>

commit 4ff9b045fe7a9da9b5a7737407ed4e7ef203ffad
Author: Micah Williamson <micah.williamson@amd.com>
Date:   Mon Mar 9 13:27:55 2026 -0500

    [ROCm][CI] Prep Tests For Change To ROCM_ATTN As New Default Backend On ROCm (#36025)

    Signed-off-by: Micah Williamson <micah.williamson@amd.com>

commit 3fd03f1ec29cf9ac20584ad68156fc7279387979
Author: Lucas Kabela <lucaskabela@meta.com>
Date:   Mon Mar 9 11:22:05 2026 -0700

    [BE] Rename `should_torch_compile_mm_vit` to `should_torch_compile_mm_encoder` (#36281)

    Signed-off-by: Lucas Kabela <lucaskabela@meta.com>

commit 10a5f4d53d0dc7390802ad99bf5d27b2423094e9
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   Mon Mar 9 11:17:34 2026 -0700

    [Model Runner V2] Use NamedTuple for `execute_model_state` (#35930)

    Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

commit fe0c085c28dc5703da33ac3c329fb4370a798798
Author: Simon Mo <simon.mo@hey.com>
Date:   Mon Mar 9 11:16:50 2026 -0700

    [Docs] Remove the reo beacon (#36528)

    Co-authored-by: Cursor Agent <cursoragent@cursor.com>

commit 8d6b3d5dda293231c7c2fc9301002113f270a534
Author: Taneem Ibrahim <taneem.ibrahim@gmail.com>
Date:   Mon Mar 9 14:14:11 2026 -0400

    [Misc] Refactored 5 duplicate helper functions that were copied-pasted across multiple parsers (#36436)

    Signed-off-by: Taneem Ibrahim <taneem.ibrahim@gmail.com>

commit 4b87ffbefb3881a0a33f9c1cb7121429bddad666
Author: Copilot <198982749+Copilot@users.noreply.github.com>
Date:   Mon Mar 9 18:04:40 2026 +0000

    [torch.compile] Rename `compile_ranges_split_points` to `compile_ranges_endpoints` (#36027)

    Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
    Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
    Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
    Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

commit fa028207aa9d4baa6cfc4863f6f54c4277884e6e
Author: Shaun Kotek <93727115+shaunkotek@users.noreply.github.com>
Date:   Mon Mar 9 20:01:18 2026 +0200

    Fix/resupport nongated fused moe triton (#36412)

    Signed-off-by: Shaun Kotek - Nvidia <skotek@nvidia.com>
    Signed-off-by: Natan Bagrov <nbagrov@nvidia.com>
    Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
    Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
    Signed-off-by: yewentao256 <zhyanwentao@126.com>
    Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
    Signed-off-by: liweiguang <codingpunk@gmail.com>
    Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
    Signed-off-by: wang.yuqi <noooop@126.com>
    Signed-off-by: Alex Brooks <albrooks@redhat.com>
    Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
    Signed-off-by: cong-or <conchubhar.gannon@gmail.com>
    Signed-off-by: Tushar Shetty <tushar.shetty@abbyy.com>
    Signed-off-by: Tushar Shetty <54362365+tusharshetty61@users.noreply.github.com>
    Signed-off-by: jiang1.li <jiang1.li@intel.com>
    Signed-off-by: zhenwei-intel <zhenwei.liu@intel.com>
    Signed-off-by: Xin Yang <xyangx@amazon.com>
    Signed-off-by: Kevin H. Luu <khluu000@gmail.com>
    Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
    Co-authored-by: nvnbagrov <nbagrov@nvidia.com>
    Co-authored-by: Sage <80211083+sagearc@users.noreply.github.com>
    Co-authored-by: danisereb <daserebrenik@nvidia.com>
    Co-authored-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
    Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
    Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
    Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    Co-authored-by: Weiguang Li <codingpunk@gmail.com>
    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
    Co-authored-by: Li, Jiang <jiang1.li@intel.com>
    Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
    Co-authored-by: Alex Brooks <albrooks@redhat.com>
    Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
    Co-authored-by: cong-or <conchubhar.gannon@gmail.com>
    Co-authored-by: Tushar Shetty <54362365+tusharshetty61@users.noreply.github.com>
    Co-authored-by: liuzhenwei <zhenwei.liu@intel.com>
    Co-authored-by: Xin Yang <105740670+xyang16@users.noreply.github.com>
    Co-authored-by: Kevin H. Luu <khluu000@gmail.com>
    Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

commit d460a18fc656f7fb217b977d4c2ee1003af2a5b6
Author: Russell Bryant <rbryant@redhat.com>
Date:   Mon Mar 9 13:43:42 2026 -0400

    [Docs] Expand --allowed-media-domains security guidance with threat details (#36506)

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 6e956d9eca398005929d29f123607d1029800cc7
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   Mon Mar 9 10:20:13 2026 -0700

    [Model Runner V2] Add dummy profile_cudagraph_memory API (#36520)

    Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

commit 1e0f917b349338ac09377dd277ded5e1e62df77e
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Mon Mar 9 12:07:44 2026 -0500

    [ROCm][CI] Fix logprob divergence for TitanML/tiny-mixtral under AITER rms_norm (#36101)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit c174d54f86aa10e63ae236dc09f05f821134d469
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Mon Mar 9 12:02:41 2026 -0500

    [ROCm][CI] Fix ROCm attention backend validation for head sizes, block sizes, and compute capability checks (#36292)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit 55d27cca55310a04fb82c90d26a5afed90f01de7
Author: SoluMilken <s916526000@gmail.com>
Date:   Tue Mar 10 01:00:12 2026 +0800

    [Misc] fix typo: dependant -> dependent (2 lines change) (#36511)

    Signed-off-by: SoluMilken <ypiheyn.imm02g@g2.nctu.edu.tw>

commit 580864d81eb03d9fb1383e1782636ff6a9425fa2
Author: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Date:   Mon Mar 9 17:50:36 2026 +0100

    [Attention][Perf][Kernel] Replace torch.cat with vectorized CUDA kernel MLA query concat - DeepSeek-V3.2 (#34917)

    Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
    Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>

commit 2b28b9b269e18cfe42c7e945d1da8d1c40989efa
Author: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Date:   Mon Mar 9 17:46:57 2026 +0100

    [Attention][Perf] Optimize cp_gather_and_upconvert_fp8_kv_cache - DeepSeek-v3.2 (#35290)

    Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
    Co-authored-by: Claude <noreply@anthropic.com>

commit 70485a11bd83afa50e6ecc8e9619d9bdd0ff2039
Author: Taoyu Zhu <z609495@gmail.com>
Date:   Tue Mar 10 00:30:35 2026 +0800

    [ROCM] Optimize the fused_topk_bias to use aiter instead of fallback torch ops. (#36253)

    Signed-off-by: zhutaoyu <zhutaoyu97@gmail.com>

commit 74a9f54cdb07eca31036d96390db968b780e44f5
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Mon Mar 9 16:06:19 2026 +0000

    [CI] Fix edge case that could lead to broken docs builds on main (#36515)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 00c4cb5606ae4f7ba80485f4a2756df33a2d4065
Author: Matthew Bonanni <mbonanni@redhat.com>
Date:   Mon Mar 9 11:56:00 2026 -0400

    [Bugfix] Clear stale CG keys after memory profiling (#36416)

    Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>

commit 941e52c29813ed75b3382f2a0d74ad5f168fc046
Author: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date:   Mon Mar 9 11:33:46 2026 -0400

    [Refactor] Simplify `chat_completion_full_generator` for tool parsers (#35634)

    Signed-off-by: yewentao256 <zhyanwentao@126.com>

commit be292b7c14e08e6e6883d5ebee79240d04814159
Author: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date:   Mon Mar 9 11:17:45 2026 -0400

    [Bug] Fix pooling model benchmark script (#36300)

    Signed-off-by: yewentao256 <zhyanwentao@126.com>

commit 77a73458e3ae8b5b7a2a13f78d3a6b4d39b1414d
Author: Matthew Bonanni <mbonanni@redhat.com>
Date:   Mon Mar 9 10:17:14 2026 -0400

    Reapply [Attention] Refactor `check_and_update_config` (#35122)

    Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>

commit 5578f2a4d33b3451203fa5d43e4e6847c00b55c6
Author: Tianyu Guo <guoty9@mail2.sysu.edu.cn>
Date:   Mon Mar 9 22:16:44 2026 +0800

    Support online use_audio_in_video (#36319)

    Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn>
    Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
    Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

commit 3ec2115015334e26b00bb2b4cadc2587138c5948
Author: Cyrus Leung <tlleungac@connect.ust.hk>
Date:   Mon Mar 9 21:03:21 2026 +0800

    [Frontend] Move warmup into Renderer (#36482)

    Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit b0906d8b02681d8d8f0709f0cc730f5fe845b5b1
Author: Isotr0py <mozf@mail2.sysu.edu.cn>
Date:   Mon Mar 9 18:43:44 2026 +0800

    [MM Encoder] Default to use TORCH_SDPA backend for ViT on Volta/Turing GPU (#36472)

    Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

commit aaf5fa9abfb7c265ccfe00480c349870a72b7209
Author: Kevin H. Luu <khluu000@gmail.com>
Date:   Mon Mar 9 03:43:26 2026 -0700

    [ci] Bound openai dependency to 2.24.0 (#36471)

    Signed-off-by: Kevin H. Luu <khluu000@gmail.com>

commit f96c3ab08cc75f18d40892ef59b6f295e71ffe83
Author: Cyrus Leung <tlleungac@connect.ust.hk>
Date:   Mon Mar 9 18:43:23 2026 +0800

    [Deprecation][1/2] Remove items deprecated in v0.18 (#36470)

    Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit dc6b57846686206d6d77fe788f71ab7fe8e568ab
Author: Xin Yang <105740670+xyang16@users.noreply.github.com>
Date:   Sun Mar 8 23:41:01 2026 -0700

    [Kernel] Add fused_sigmoid_gating_delta_rule_update kernel for Qwen3 Next (#35777)

    Signed-off-by: Xin Yang <xyangx@amazon.com>

commit 1bc9c77f6d324bf7b9253b0c78626fbc50286bfb
Author: liuzhenwei <zhenwei.liu@intel.com>
Date:   Mon Mar 9 13:50:27 2026 +0800

    [XPU] Add test script of PD disaggregation (#36434)

    Signed-off-by: zhenwei-intel <zhenwei.liu@intel.com>

commit 65a4da15043f11e86ffcc036f9eb9ad549f0ad17
Author: Alex Brooks <albrooks@redhat.com>
Date:   Sun Mar 8 23:46:23 2026 -0600

    [Frontend] Add Support for MM Encoder/Decoder Beam Search (Online Transcriptions) (#36160)

    Signed-off-by: Alex Brooks <albrooks@redhat.com>

commit 217f27598dbf3cc8ec0765cc3a41b667939ce6bb
Author: Li, Jiang <jiang1.li@intel.com>
Date:   Mon Mar 9 13:06:28 2026 +0800

    [Bugfix] Avoid to replace non-tensor members in cpu model runner (#36430)

    Signed-off-by: jiang1.li <jiang1.li@intel.com>

commit fff3711a244dd9e2915323e31c20768d922e90b5
Author: wang.yuqi <yuqi.wang@daocloud.io>
Date:   Mon Mar 9 11:42:19 2026 +0800

    [Frontend][2/n] Improve pooling entrypoints | embed. (#36110)

    Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
    Signed-off-by: wang.yuqi <noooop@126.com>

commit c4d859c274960d62f0b2ff6e7ac96be452994b55
Author: Tushar Shetty <54362365+tusharshetty61@users.noreply.github.com>
Date:   Mon Mar 9 09:10:16 2026 +0530

    [Bugfix] Skip out-of-stage layers in get_layers_from_vllm_config for pipeline parallel (#36243)

    Signed-off-by: Tushar Shetty <tushar.shetty@abbyy.com>
    Signed-off-by: Tushar Shetty <54362365+tusharshetty61@users.noreply.github.com>

commit 747431044df6b15c7b359b5720cc7368c662c232
Author: cong-or <conchubhar.gannon@gmail.com>
Date:   Mon Mar 9 03:40:12 2026 +0000

    feat(attention): extract KV-cache update from FlexAttention backend (#36263)

    Signed-off-by: cong-or <conchubhar.gannon@gmail.com>

commit d62856b9283b5f5a90e6f135b787e63b5ca3f157
Author: Cyrus Leung <tlleungac@connect.ust.hk>
Date:   Mon Mar 9 11:31:39 2026 +0800

    [Misc] Move processors to `transformers_utils` (#35953)

    Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit bd2659a5660a7c5ccfeb1f1579e4000ed6536250
Author: Alex Brooks <albrooks@redhat.com>
Date:   Sun Mar 8 21:30:49 2026 -0600

    Increase Flexibility for OOV Multimodal Token Handling (#34858)

    Signed-off-by: Alex Brooks <albrooks@redhat.com>

commit 90512b2e8bff5bddca5fca30dc4f0136d682f7d4
Author: Shaun Kotek <93727115+shaunkotek@users.noreply.github.com>
Date:   Mon Mar 9 05:25:21 2026 +0200

    fix: Use iterator as not to store all the file loads in memory at once (#36149)

    Signed-off-by: Shaun Kotek - Nvidia <skotek@nvidia.com>

commit dcf8862fd47624ec48a6e3a06ff2bcc53dc4d4a0
Author: wang.yuqi <yuqi.wang@daocloud.io>
Date:   Mon Mar 9 11:22:53 2026 +0800

    [Examples][1/n] Resettle basic examples. (#35579)

    Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
    Signed-off-by: wang.yuqi <noooop@126.com>
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 43aa3892314f8336f83a9fbe614899ddcf0e1df8
Author: Weiguang Li <codingpunk@gmail.com>
Date:   Mon Mar 9 11:07:29 2026 +0800

    [Bugfix] Fix CPU OMP autobind assertion to use local_world_size (#35815)

    Signed-off-by: liweiguang <codingpunk@gmail.com>
    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
    Co-authored-by: Li, Jiang <jiang1.li@intel.com>

commit 384425f84e314b11076289365277b1c2650ee902
Author: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date:   Sun Mar 8 23:06:22 2026 -0400

    [Dependency] Remove default ray dependency (#36170)

    Signed-off-by: yewentao256 <zhyanwentao@126.com>
    Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

commit a0f44bb6169dcd6225d2efc0a59dd343a8d4a38e
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Mon Mar 9 03:05:24 2026 +0000

    Allow `markdownlint` to run locally (#36398)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit fde4771bbda69f86a58eace1447f3ab5e369b63d
Author: Kunshang Ji <kunshang.ji@intel.com>
Date:   Mon Mar 9 10:09:22 2026 +0800

    [XPU][Doc] update xpu document about triton dependency/conflict issue. (#36301)

    Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>

commit e5ff140216272c529261b02b6fd13fc480713735
Author: Jiangyun Zhu <riverclouds.zhu@qq.com>
Date:   Mon Mar 9 08:27:41 2026 +0800

    [cudagraph] fix cudagraph warning in deepseekv32 (#28044)

    Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>

commit 0a6a3a12906bd581fb2983c81b4d51dc60e0bb4a
Author: danisereb <daserebrenik@nvidia.com>
Date:   Sun Mar 8 22:00:05 2026 +0200

    Add support for ModelOpt MXFP8 MoE models (#35986)

    Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>

commit 4497431df654e46fb1fb5e64bf8611e762ae5d87
Author: Sage <80211083+sagearc@users.noreply.github.com>
Date:   Sun Mar 8 17:35:09 2026 +0200

    [Frontend] Add GPU-less render serving path (`vllm launch render`) (#36166)

commit b7332b058c3b0d8533395b49dea9273aa0973b4e
Author: nvnbagrov <nbagrov@nvidia.com>
Date:   Sun Mar 8 12:04:05 2026 +0200

    [Model] Nano Nemotron VL - fast media preprocessing (#35657)

    Signed-off-by: Natan Bagrov <nbagrov@nvidia.com>

commit 40077ea3defdf2b0997245ca8999097eede2308f
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Sun Mar 8 00:42:24 2026 -0600

    [CI] fix flaky empty responses and add diagnostic assertions in vision chat tests (#36341)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit 5d6aae4577590cd6b6a604f9e74c17c5f234271d
Author: Samuel Shen <slshen@uchicago.edu>
Date:   Sat Mar 7 13:52:48 2026 -0800

    [LMCache MP Patch]: Race Condition + Duplicated Block Ids (#35831)

commit 63298ee17350e4eda3f574eab16286bc405b23a6
Author: Roy Huang <roy.y.huang@gmail.com>
Date:   Sat Mar 7 13:52:35 2026 -0800

    [Bugfix][LMCache][KVConnector] fix potential memory leak in LMCache multiprocess mode (#35931)

commit 2dde535df1b736315e56eace0fa1923fe0beffc5
Author: Richard Zou <zou3519@users.noreply.github.com>
Date:   Sat Mar 7 16:52:11 2026 -0500

    [compile] Split compile/warmup monitoring (#36098)

commit 379689d533642cfc1d3ab2cf4dc02f09a8318a5f
Author: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
Date:   Sat Mar 7 16:51:54 2026 -0500

    [Perf] Support FP8 KV cache for Flashinfer MLA Sparse (#35891)

commit a6be75dbd2a8dd1886da725727ee178f42e3f84f
Author: PatchyTIS <58251192+PatchouliTIS@users.noreply.github.com>
Date:   Sun Mar 8 05:51:37 2026 +0800

    [Core] NGram GPU Implementation compatible with Async Scheduler (#29184)

commit ee54f9cdb91f04350bba0cf11890b02b12c62baa
Author: Micah Williamson <micah.williamson@amd.com>
Date:   Sat Mar 7 15:50:52 2026 -0600

    [ROCm][CI] Accept Different But Valid Output for `test_olmoe_tp` (#35224)

commit fc4657756ff01fec770433530a5dd2a238e7e034
Author: Micah Williamson <micah.williamson@amd.com>
Date:   Sat Mar 7 15:50:17 2026 -0600

    [ROCm][CI] Enable AITER for failing `test_gpt_oss` test case on MI355 (#36174)

commit eebd14651f7618eddda5e79eab2d4ea0cdcc1770
Author: qli88 <qiang.li2@amd.com>
Date:   Sat Mar 7 15:49:56 2026 -0600

    [CI] Enable Crosslayer KV layout tests for ROCm platforms (#35416)

commit ebb9cc5f2b26d73222c08e42b32fcf59e831386c
Author: Matthew Bonanni <mbonanni@redhat.com>
Date:   Sat Mar 7 16:49:23 2026 -0500

    [UX][Startup] Account for CUDA graphs during memory profiling (#30515)

commit 85f50eb41fa43783b64e07d768ba3ac6d4ed7a5a
Author: rahul-sarvam <140298821+rahul-sarvam@users.noreply.github.com>
Date:   Sun Mar 8 01:16:24 2026 +0800

    Adding support to Sarvam's MoE models (#33942)

    Signed-off-by: rahul-sarvam <140298821+rahul-sarvam@users.noreply.github.com>

commit 5261223c2d1082fa3facc99c52fc96c0ebcc041b
Author: Taneem Ibrahim <taneem.ibrahim@gmail.com>
Date:   Sat Mar 7 08:37:01 2026 -0600

    [Misc] Remove duplicate parser registration (#36303)

    Signed-off-by: Taneem Ibrahim <taneem.ibrahim@gmail.com>

commit 00b814ba5a4139910c0824619a8dc6af547e178a
Author: lif <1835304752@qq.com>
Date:   Sat Mar 7 22:09:55 2026 +0800

    [V0 Deprecation] Remove unused swap_space parameter (#36216)

    Signed-off-by: majiayu000 <1835304752@qq.com>
    Co-authored-by: mcelrath

commit ee8a29511fc69e3f0f6291fa6ff1cf6e47f7750d
Author: vllmellm <vllm.ellm@embeddedllm.com>
Date:   Sat Mar 7 17:26:59 2026 +0800

    [Bugfix] Fix compressed-tensors quantization failure for DeepSeek-R1 on MI300x (#36247)

    Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>

commit 755356b3d18d8079b1b115dfd2111dc1accdb764
Author: milesial <milesial@users.noreply.github.com>
Date:   Fri Mar 6 20:27:04 2026 -0800

    feat: expose media_io_kwargs at runtime (#34778)

    Signed-off-by: Alexandre Milesi <milesial@users.noreply.github.com>

commit 58928475e4c1910df28548849734ba30d3ef4580
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Fri Mar 6 21:04:40 2026 -0600

    [ROCm][CI] Making entrypoints more deterministic on ROCm (#36293)

commit 1a9718085c7980443558db1ff4160c58096a3f0e
Author: Mengtao (Martin) Yuan <mengtaoyuan1@gmail.com>
Date:   Fri Mar 6 18:12:07 2026 -0800

    Fix CUDA graph decode capture crash in AITER FlashAttention (#36042)

    Signed-off-by: Martin Yuan <myuan@meta.com>
    Co-authored-by: Martin Yuan <myuan@meta.com>

commit 7eb524e64c4533a5e24909873bb926109f3a4ac7
Author: Kunshang Ji <kunshang.ji@intel.com>
Date:   Sat Mar 7 10:10:33 2026 +0800

    refine `vllm bench throughput --backend hf`  (#35971)

    Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

commit c7f32e08c2e49665621be72f8e83d6433b2564d1
Author: Nick Hill <nickhill123@gmail.com>
Date:   Fri Mar 6 17:24:18 2026 -0800

    [BugFix] Avoid ignored trust_remote_code warnings (#36290)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit b3546865247d5f61025b6fa256fe08c2843f6ea0
Author: Nick Hill <nhill@redhat.com>
Date:   Fri Mar 6 16:58:51 2026 -0800

    [Model Runner V2] Fix warmup for pipeline parallel (#36280)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit 6a18d8789be899a3ca4a07a55bf3383050493d35
Author: Nick Hill <nhill@redhat.com>
Date:   Fri Mar 6 16:39:21 2026 -0800

    [Core] Fix benign error log during normal shutdown (#36270)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>
    Co-authored-by: Mark McLoughlin <markmc@redhat.com>

commit 24a03915f525b88ebc4c36127c3e9ccf56dc21ee
Author: Itay Alroy <75032521+itayalroy@users.noreply.github.com>
Date:   Sat Mar 7 02:36:00 2026 +0200

    mla: don't update kv cache on dummy forwards (#36282)

    Signed-off-by: Itay Alroy <ialroy@nvidia.com>

commit b5e34e1fcaefaf1d28249b6db17c99084ea25b5e
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Fri Mar 6 18:30:39 2026 -0600

    [ROCm][CI] Fixing yaml file for external amd-ci signal (#36284)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit ce8546a12b613085e5d1d0e110f2c970774a1a84
Author: Copilot <198982749+Copilot@users.noreply.github.com>
Date:   Fri Mar 6 23:55:06 2026 +0000

    [docs][torch.compile] Add fusions.md — kernel/operator fusion reference page (#35538)

    Signed-off-by: ProExpertProg <luka.govedic@gmail.com>
    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
    Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
    Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
    Co-authored-by: ProExpertProg <luka.govedic@gmail.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
    Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

commit c188749bcdaa2c72cc3c8a4a28e722af2abc4bb8
Author: Chuan (Richard) Li <chuali@amd.com>
Date:   Fri Mar 6 12:24:03 2026 -0800

    [ROCm] Support MLA with nhead<16 and FP8 KV cache for TP=8 (Kimi K2.5/Linear) (#35850)

    Signed-off-by: Li <chuali@amd.com>

commit 225d1090a0996710a23d58cfcd1d4d2b089cc553
Author: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Date:   Fri Mar 6 13:27:20 2026 -0600

    Enabling some B200-specific tests on MI355 (#35253)

    Signed-off-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
    Signed-off-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>

commit f3c6c9c9d794fac5e74b59bc75da6e9d1921eeac
Author: eellison <elias.ellison@gmail.com>
Date:   Fri Mar 6 13:53:37 2026 -0500

    [CustomOp] CustomOp FusedRMSNormGated (#35877)

    Signed-off-by: Elias Ellison <elias.ellison@gmail.com>
    Signed-off-by: eellison <elias.ellison@gmail.com>

commit 26bd43b52df305c5610efed9e72261d263b9fe75
Author: Nick Hill <nhill@redhat.com>
Date:   Fri Mar 6 08:28:09 2026 -0800

    Revert "[BugFix] Fix engine hanging after KV cache initialization fai… (#36262)

commit 6b625a8807f4c82137c46d58dfb38f8eeef4865c
Author: Travis Johnson <tsjohnso@us.ibm.com>
Date:   Fri Mar 6 09:13:05 2026 -0700

    [Bugfix] Quickfix followups to busy loop removal in #28053 (#36068)

    Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
    Signed-off-by: Nick Hill <nickhill123@gmail.com>
    Co-authored-by: Nick Hill <nickhill123@gmail.com>

commit 54756b61091e3c913436ddd00b9d99e11e7c9a8c
Author: Richard Zou <zou3519@users.noreply.github.com>
Date:   Fri Mar 6 10:17:27 2026 -0500

    [compile] Stop unconditionally patching constrain_to_fx_strides (#36152)

    Signed-off-by: Richard Zou <zou3519@gmail.com>

commit 39f9ea0da4a45e9638937b062f86f03db313a0d8
Author: Raphaël Rialland <36076211+TQCB@users.noreply.github.com>
Date:   Fri Mar 6 15:15:31 2026 +0100

    [Bugfix] Fix `cudagraph_mode:FULL` dispatch (This does not impact `FULL_AND_PIECEWISE` (default)) (#36165)

commit e4ae148a787df846beb194078c35655c44784bd5
Author: Isotr0py <mozf@mail2.sysu.edu.cn>
Date:   Fri Mar 6 22:06:59 2026 +0800

    [Refactor] Modular video loader backend refactoring (#35202)

    Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

commit 1d0c0d209c3de3be2d54cd70c2618472a2fe4929
Author: Isotr0py <mozf@mail2.sysu.edu.cn>
Date:   Fri Mar 6 22:06:45 2026 +0800

    [Misc] Lazy import registered processors (#36024)

    Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
    Co-authored-by: Roger Wang <hey@rogerw.io>

commit fcb73f306ccedb07ff33e3e3696018f66ccd40ea
Author: Chenguang Zheng <645327136@qq.com>
Date:   Fri Mar 6 20:00:09 2026 +0800

    [bugfix] add api process rank in default multimodal request (#36150)

    Signed-off-by: fake0fan <645327136@qq.com>
    Signed-off-by: Chenguang ZHENG <645327136@qq.com>

commit e2090bf3af96843c899d6f5c85d9c12b03b5cabb
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Fri Mar 6 11:50:28 2026 +0000

    [CI] Fix startup error test (#36230)

    A change in engine startup error messages in #35478 caused this test failure.

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 2a00d3241f2c5810f4ba6a3c5fe79f7c76a94900
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Fri Mar 6 03:17:08 2026 -0600

    [CI][MM] Gate vision encoder attention mask to MiniCPM only, fixing Aria regression (#36206)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit 10f4db4dbecaafc8c0af8b36e9e0bc2f186deb2d
Author: Alex Brooks <albrooks@redhat.com>
Date:   Fri Mar 6 02:16:56 2026 -0700

    [Frontend] Add Support for MM Encoder/Decoder Beam Search (Offline) (#36153)

    Signed-off-by: Alex Brooks <albrooks@redhat.com>
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

commit 5b3ba94ab4bd9da739bcc27cdd05505467fa499e
Author: Nicolò Lucchesi <nlucches@redhat.com>
Date:   Fri Mar 6 08:51:21 2026 +0100

    [Core][KVConnector] Support HMA+NixlConnector (#35758)

    Signed-off-by: NickLucche <nlucches@redhat.com>

commit 90f3c01fa4dfc00d13beb8ae758d43365f7ba91f
Author: zhanqiuhu <49648934+ZhanqiuHu@users.noreply.github.com>
Date:   Fri Mar 6 02:50:44 2026 -0500

    [Spec Decode][KV Connector] Fix KV transfer in PD + speculative decoding (#35158)

    Signed-off-by: Claude <noreply@anthropic.com>
    Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
    Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>

commit 807d6803376ff8610efbf9da23f772a5dbd7b5ea
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Fri Mar 6 01:15:12 2026 -0600

    [ROCm][CI] Fix tool use test stability - disable skinny GEMM, prefix caching, eliminate batch variance (#35553)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit 5afb387bd43cef01d68119d017587e689b0729fa
Author: Tyler Michael Smith <tyler@neuralmagic.com>
Date:   Fri Mar 6 01:15:46 2026 -0500

    Change "following fields were present in the request but ignored" log from warn to debug (#36173)

    Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>

commit 43e77e59abcaf0764aa6851fcc2bc9b86d4afdba
Author: Walter Beller-Morales <walterbm@users.noreply.github.com>
Date:   Fri Mar 6 01:15:29 2026 -0500

    [BugFix] avoid infinite loop with VLLM_PORT and get_open_ports_list (#36191)

    Signed-off-by: walterbm <walter.beller.morales@gmail.com>

commit 00bd08edeee5dd4d4c13277c0114a464011acf72
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Mar 6 01:15:19 2026 -0500

    [Security] Respect user trust_remote_code setting in NemotronVL and KimiK25 (#36192)

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 43f10573c9701df093f6523da43cc1a2fac1b3b3
Author: Ajay Anubolu <124525760+AjAnubolu@users.noreply.github.com>
Date:   Thu Mar 5 22:15:12 2026 -0800

    [Bugfix] Fix misleading context length error messages (#36197)

    Signed-off-by: AjAnubolu <anuboluajay@gmail.com>
    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

commit 86e1060b17d9042ab8f7b7baba26b1d6cbc36c2b
Author: Yongye Zhu <zyy1102000@gmail.com>
Date:   Fri Mar 6 01:04:44 2026 -0500

    [Bugfix] Fix inner_dp_world initialization order for multi-node TP (#35892)

    Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
    Signed-off-by: Nick Hill <nickhill123@gmail.com>
    Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
    Co-authored-by: Nick Hill <nickhill123@gmail.com>
    Co-authored-by: Nick Hill <nhill@redhat.com>

commit 27066d1b2bd0dea89d617afa24da611d9a32e36a
Author: Mark McLoughlin <markmc@redhat.com>
Date:   Fri Mar 6 06:04:31 2026 +0000

    [Frontend][Core] Add shutdown timeout - allowing in-flight requests to finish (#34730)

    Signed-off-by: Mark McLoughlin <markmc@redhat.com>
    Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>

commit 57c84ff129de4ab8072bbc9756942650803001ef
Author: cong-or <conchubhar.gannon@gmail.com>
Date:   Fri Mar 6 06:04:09 2026 +0000

    perf: add __slots__ to KVCacheBlock  (#36164)

    Signed-off-by: cong-or <conchubhar.gannon@gmail.com>

commit e68de8adc0301babb3bb3fcd2ddccaf98e7695c8
Author: Xiang Shi <realkevin@tutanota.com>
Date:   Fri Mar 6 14:01:02 2026 +0800

    docs: fix wrong cc in int8.md (#36209)

    Signed-off-by: Xiang Shi <realkevin@tutanota.com>

commit a1ffa56a1e6b644a176c0546053dae01f1823a61
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Thu Mar 5 23:07:29 2026 -0600

    [CI] Fix bge-m3 similarity reference values after *Defination* typo fix (#36208)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit 0a208d1f549a5e35605af5b01685d64cd727b73b
Author: Shiyan Deng <dsy842974287@meta.com>
Date:   Thu Mar 5 20:58:09 2026 -0800

    [BugFix] Fix engine hanging after KV cache initialization failure (#35478)

    Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
    Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>

commit 03a49bb8f0c8ad3472a61ec163167898fda02917
Author: Shiyan Deng <dsy842974287@meta.com>
Date:   Thu Mar 5 20:57:51 2026 -0800

    [Feature] Add --distributed-timeout-seconds CLI option (#36047)

    Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
    Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>

commit 8e87cc57f1b071d69a93b5d5aa27a5841f817739
Author: Shiyan Deng <dsy842974287@meta.com>
Date:   Thu Mar 5 20:57:32 2026 -0800

    [Bug] Fix a corner case in _process_simple_streaming_events (#34754)

    Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
    Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>

commit 6dd302653f82148ad44d9766fdc3daede0ede040
Author: Cyrus Leung <tlleungac@connect.ust.hk>
Date:   Fri Mar 6 12:32:48 2026 +0800

    [Misc] Rename `group_mm_kwargs_by_modality -> group_and_batch_mm_kwargs` (#36158)

    Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit de00ebeac4abddafff9f23bb598a6619b5892261
Author: Cyrus Leung <tlleungac@connect.ust.hk>
Date:   Fri Mar 6 12:25:11 2026 +0800

    [Bugfix] Fix simple Mistral-Small example (#36156)

    Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit 639680d220c9103cf47d63c5ff0ad3885426f487
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Thu Mar 5 22:23:10 2026 -0600

    [ROCm][CI] Adding missing dependencies for Multi-modal models tests (#36177)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit c5362c739fb31c171fd345ed4a83fb0127804aa3
Author: Rohan Potdar <66227218+Rohan138@users.noreply.github.com>
Date:   Thu Mar 5 22:21:06 2026 -0600

    Reenable features for ROCm attention backends (#36185)

    Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>

commit 0a49676fb0e54c9229a39f6304bc88b7d24e0355
Author: Nikhil Gupta <nikhil.gupta2@arm.com>
Date:   Fri Mar 6 03:48:59 2026 +0000

    cpu: aarch64: Upgrade OneDNN for aarch64 to add support for int8 matmul (#36147)

    Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>

commit c012a8c477dd78b4444f22568b2bf1b08f2ad813
Author: Jeffrey Wang <jeffreywang@anyscale.com>
Date:   Thu Mar 5 16:42:21 2026 -0800

    Don't fire ray compatibility webhook when PR or branch is not provided (#36088)

    Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>

commit ebed80a7c8c652ff43b5bd910c8fe35d73bfa786
Author: Dor Huri <92430368+dorhuri123@users.noreply.github.com>
Date:   Fri Mar 6 02:22:43 2026 +0200

    [Performance] Extract KV-cache update from TreeAttention backend (#35384)

    Signed-off-by: dorhuri123 <dor.huri1@live.biu.ac.il>

commit a73af584fe6d4c1c2781d537c35e3cc85f58480b
Author: Nick Hill <nhill@redhat.com>
Date:   Thu Mar 5 14:48:10 2026 -0800

    [Model Runner V2] Fix warmup for very small kvcache and/or blocksizes (#36176)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit a97954b6a8fa41a162ebf58f80a1460a98e0baf0
Author: Zhengxu Chen <zhxchen17@fb.com>
Date:   Thu Mar 5 15:08:12 2026 -0500

    [compile] Consistent compiler config for saved/loaded vllm backends. (#35810)

    Signed-off-by: zhxchen17 <zhxchen17@fb.com>

commit a911f4dd20d0a0fcfee362f096e9c6fd23d59590
Author: Yanhong Li <90665285+yanhong-lbh@users.noreply.github.com>
Date:   Thu Mar 5 11:51:06 2026 -0800

    [Model] Add support for OLMo Hybrid (#32550)

commit 5395471d29f703f19213da629102edc6e9b944be
Author: Russell Bryant <rbryant@redhat.com>
Date:   Thu Mar 5 14:08:48 2026 -0500

    [CI] Add explicit permissions to macOS smoke test workflow (#35775)

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit a57c877f18188cb7bafc0fc5309b6c88fe2a8f66
Author: Frank Wang <41319051+frankwang28@users.noreply.github.com>
Date:   Thu Mar 5 11:05:56 2026 -0800

    [BugFix] Fallback from FA4->FA2 for Batch Invariance (#36059)

    Signed-off-by: frankwang28 <frank.wbb@hotmail.com>

commit f9170209834af0e8e53a6d16ccd17eacc0db2c67
Author: Xin Yang <105740670+xyang16@users.noreply.github.com>
Date:   Thu Mar 5 10:47:53 2026 -0800

    [Perf] Optimize FusedMoEModularKernel output tensor using torch.empty (#35794)

    Signed-off-by: Xin Yang <xyangx@amazon.com>

commit 86483ca7749b3d7a2ae16283a7896c203983f1ef
Author: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Date:   Thu Mar 5 19:49:05 2026 +0200

    [Bugfix] Disable FlashInfer TRTLLM BF16 path for non-gated MoE (#36146)

    Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>

commit b93a9e6f6d91baf59e39089ce8dbf2f2a3f0f6c9
Author: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date:   Thu Mar 5 19:29:30 2026 +0200

    ParakeetProjection.norm = RMSNorm instead of nn.LayerNorm (#36133)

    Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>

commit d8839ef7d964dd98b82e671e743b42754be3350c
Author: Xinyu Chen <xinyu1.chen@intel.com>
Date:   Fri Mar 6 01:19:18 2026 +0800

    [XPU] Enable ModelRunnerV2 on XPU (#36078)

    Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>

commit e998fa76b99a73ba923adeb7457376228269cc9c
Author: Avery Miao <108777392+jjmiao1@users.noreply.github.com>
Date:   Fri Mar 6 01:16:29 2026 +0800

    [BUGFIX]Fix Qwen-Omni models audio max_token_per_item estimation error leading to encoder_cache_size is 0 (#35994)

    Signed-off-by: Miao, Avery <avery.miao@intel.com>

commit 6a895197fafa7069be75ff615709b77546bcec30
Author: Jiayi Yan <66017932+1195343015@users.noreply.github.com>
Date:   Fri Mar 6 01:05:46 2026 +0800

    [Bugfix][CI] fix typos (#34934)

    Signed-off-by: 1195343015 <1195343015@qq.com>
    Signed-off-by: Jiayi Yan <66017932+1195343015@users.noreply.github.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 8c760b6ab6993c6a0d5f639747baefedb4612525
Author: Sage Moore <sage@neuralmagic.com>
Date:   Thu Mar 5 08:51:26 2026 -0800

    [ROCm] Refactor ROCm attention backend selection logic (#35246)

    Signed-off-by: Sage Moore <sage@neuralmagic.com>

commit 3ee68590c7fafe05f1db1f1bee019c7b3a83ec96
Author: AllenDou <allen.dou@hotmail.com>
Date:   Fri Mar 6 00:07:37 2026 +0800

    refactor funasr model. (#36108)

    Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com>
    Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com>
    Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

commit 719634815791ad97cf1e35ad52d4e39e630aeafd
Author: Cyrus Leung <tlleungac@connect.ust.hk>
Date:   Fri Mar 6 00:07:19 2026 +0800

    [Bugfix] Fix Qwen-VL tokenizer implementation (#36140)

    Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit 176c799f4c512daf0904556940fc9a2c938af5ce
Author: Ning Xie <andy.xning@gmail.com>
Date:   Fri Mar 6 00:00:12 2026 +0800

    [openai api] log exception in exception handler (1/N) (#31164)

    Signed-off-by: Andy Xie <andy.xning@gmail.com>

commit 612e7729c2a548a7b6c9baa1821f419909777ffa
Author: Or Ozeri <oro@il.ibm.com>
Date:   Thu Mar 5 16:25:15 2026 +0200

    [KVConnector] Scheduler: Fix num_computed_tokens after async KV load (#34616)

    Signed-off-by: Or Ozeri <oro@il.ibm.com>

commit ecde7af9c492077bbf1bd8df16d941b1b441b60b
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Thu Mar 5 13:59:44 2026 +0000

    Fix import that was moved in Transformers 5.2.0 (#36120)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 8df523351f6e665ea5b07f1b731aa2449d197624
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Thu Mar 5 13:58:16 2026 +0000

    [Docs] Only build docs if `documentation` or `ready` labels are present (#36135)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit b03ff6a96bb090676cab07c432b4b0937abb7011
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Thu Mar 5 07:52:49 2026 -0600

    [CI] Stabilize test_no_args_tool_call and add ROCm-specific server args (#36107)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit ed81d5edd16b0d933d0e1115003c258dcecd991c
Author: Ajay Anubolu <124525760+AjAnubolu@users.noreply.github.com>
Date:   Thu Mar 5 04:14:20 2026 -0800

    [Bugfix] Fix RunAI streamer crash with S3-hosted model paths (#35976)

    Signed-off-by: AjAnubolu <anuboluajay@gmail.com>
    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

commit 3c23ac840e758e7b4ff34752e25d9eac12e4a3da
Author: Shiyan Deng <dsy842974287@meta.com>
Date:   Thu Mar 5 03:37:47 2026 -0800

    [Bugfix] Fix mypy errors in hermes_tool_parser.py (#36114)

    Signed-off-by: Shiyan Deng <dsy842974287@meta.com>

commit a708ef59443377aeda2d8ece804fa1e916881577
Author: cjackal <44624812+cjackal@users.noreply.github.com>
Date:   Thu Mar 5 19:55:31 2026 +0900

    [Misc] Fix SyntaxWarning - invalid escape sequence '\e' (#36020)

    Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>

commit 66a2209645438e9ad20b1bfb8fa4eca219944d46
Author: Kunshang Ji <kunshang.ji@intel.com>
Date:   Thu Mar 5 18:36:39 2026 +0800

    [Hardware] Replace `torch.cuda.synchronize()` api with `torch.accelerator.synchronize` (#36085)

    Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

commit 0bfa229bf1f6b12f215d045f4acb4b9607937f32
Author: Doug Smith <dosmith@redhat.com>
Date:   Thu Mar 5 04:43:50 2026 -0500

    [Release] Include source distribution (sdist) in PyPI uploads (#35136)

    Signed-off-by: dougbtv <dosmith@redhat.com>
    Co-authored-by: Daniele Trifirò <dtrifiro@redhat.com>

commit 7493c51c5532c25e2f2573eb274461e39f7e2a0b
Author: Paco Xu <paco.xu@daocloud.io>
Date:   Thu Mar 5 17:39:50 2026 +0800

    [Docs] add Dynamo/aibrix integration and kubeai/aks link (#32767)

    Signed-off-by: Paco Xu <paco.xu@daocloud.io>

commit ac773bbe8095b4493c258abbf35c2a2d10d2faab
Author: Reagan Lee <96998476+reaganjlee@users.noreply.github.com>
Date:   Thu Mar 5 01:38:25 2026 -0800

    [Docs] Update docs to include mm processor + encoder benchmarks  (#34083)

    Signed-off-by: Reagan <reaganjlee@gmail.com>

commit 48e376a007173910330a8c83f53474b21e4279c0
Author: Christian Munley <cmunley@nvidia.com>
Date:   Thu Mar 5 01:06:57 2026 -0800

    qwen3coder tool parser fix anyOf double encoded parameters (#36032)

    Signed-off-by: Christian Munley <cmunley@nvidia.com>

commit 21eb2c3372fb6447ef36bee44ff7af79a330ffec
Author: Isotr0py <mozf@mail2.sysu.edu.cn>
Date:   Thu Mar 5 16:55:04 2026 +0800

    [Chore] Correct MTP models test registry ordering (#36115)

    Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

commit e2b31243c092e9f4ade5ffe4bf9a5d5ddae06ca7
Author: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com>
Date:   Wed Mar 4 22:24:08 2026 -0800

    [Docs] Update `CacheConfig` block_size docstring to remove inaccurate limit when using CUDA (#35632)

    Signed-off-by: Seiji Eicher <seiji@anyscale.com>

commit c3598d02fa638119ae4ac933850dbcd3d629fa1c
Author: Martin Hickey <martin.hickey@ie.ibm.com>
Date:   Thu Mar 5 06:14:50 2026 +0000

    [Misc] Remove deprecated items that are due for removal (#36006)

    Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>

commit 57c629e9c1ce10ae649c5cb7411770ac31240bb0
Author: Benjamin Chislett <bchislett@nvidia.com>
Date:   Thu Mar 5 01:10:54 2026 -0500

    [Bugfix] Fix block_size for hybrid model MTP (#36036)

    Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>

commit d106bf39f56cdc59d08a84094c0de41a0be9ad0f
Author: zihaoanllm <zihaoan2@amd.com>
Date:   Thu Mar 5 13:44:07 2026 +0800

    [Doc] Add Parallel Draft Models (#35973)

    Signed-off-by: <zihaoan2@amd.com>
    Signed-off-by: zihaoanllm <zihaoan2@amd.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit b0651021e5c042e0893929e1b80cf367c6611708
Author: Yanan Cao <gmagogsfm@users.noreply.github.com>
Date:   Wed Mar 4 21:25:59 2026 -0800

    [Kernel] [Helion] [11/N] Retune configs for silu_mul_fp8 (#36062)

commit f600d5192e287f122b358044f52e17b1d23c06ab
Author: Hanjun Cho <gkswns0531@gmail.com>
Date:   Thu Mar 5 13:57:20 2026 +0900

    [Bugfix] Fix score layer quantization for sequence classification models  - Qwen3 (VL) Reranker (#35849)

    Signed-off-by: Hanjun Cho <gkswns0531@gmail.com>
    Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>

commit 8e7820131ee8d0295e6a533d745f6ca8085baec9
Author: Tianmu Li <tianmu.li@intel.com>
Date:   Wed Mar 4 20:56:49 2026 -0800

    [Perf] Use dummy M for weight prepacking on x86 (#35890)

    Signed-off-by: Li, Tianmu <tianmu.li@intel.com>

commit 0a12cea25f4a0c2a2ce1c145677a7f54545d8d7d
Author: Andrii Skliar <andreyws96@gmail.com>
Date:   Thu Mar 5 05:56:47 2026 +0100

    Order `config.py` in Lexicographical order (#35866)

    Signed-off-by: Andrii Skliar <askliar@nvidia.com>
    Co-authored-by: Andrii Skliar <askliar@nvidia.com>

commit dd6dbd93f8d299ee1e0fdbdd7cd0d41f47a4093f
Author: Zhengxu Chen <zhxchen17@fb.com>
Date:   Wed Mar 4 23:56:30 2026 -0500

    [compile] Fix extra cache save on warm start. (#35921)

    Signed-off-by: zhxchen17 <zhxchen17@fb.com>

commit 26366009c57251998fecf5909b06b5fcd297d072
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Thu Mar 5 04:51:46 2026 +0000

    [CI] Don't leave docs preview comment on closed PRs (#36087)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 16c472abe7e0e77e7924080bd4ed55bdceb86c53
Author: Nick Hill <nhill@redhat.com>
Date:   Wed Mar 4 20:11:59 2026 -0800

    [Core] Move ray-specific WorkerWrapperBase methods to RayWorkerWrapper (#35328)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit 3b23d57c960c77edbc31f9bcae9dcb69a491fd19
Author: daje0601 <73736988+daje0601@users.noreply.github.com>
Date:   Thu Mar 5 11:38:25 2026 +0900

    [Model] Add LoRA support for Whisper models (#29856)

    Signed-off-by: daje0601 <englishmt4118@gmail.com>
    Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>

commit 2f4226fe5280b60c47b4f6f01d9b18ac9cda2038
Author: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Date:   Wed Mar 4 21:13:12 2026 -0500

    [CI] Fix pre-commit mypy issue in main (#36049)

commit 792cbd64ca1ad7b2b3bc927f1a11cf2532f624da
Author: nkm-meta <166880490+nkm-meta@users.noreply.github.com>
Date:   Wed Mar 4 16:50:32 2026 -0800

    Add platform method to enable custom collective ops registration (#34760)

    Signed-off-by: Naina Kuruballi Mahesh <nainakm@meta.com>

commit 2ed4722e26864a212fbd7a48ae663d97318a8887
Author: Zhengxu Chen <zhxchen17@fb.com>
Date:   Wed Mar 4 19:48:36 2026 -0500

    [compile] Reduce log spam from compile. (#36044)

    Signed-off-by: zhxchen17 <zhxchen17@fb.com>

commit a3299c3d1d6c260c35a866599bdf4d3e7b7d84dd
Author: Nick Hill <nhill@redhat.com>
Date:   Wed Mar 4 15:26:35 2026 -0800

    [Model Runner V2] Misc code simplification (#35941)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit 6c21a0c2d75a716fa0b8bcf90b68dd46d2bc7265
Author: Andreas Karatzas <akaratza@amd.com>
Date:   Wed Mar 4 16:48:46 2026 -0600

    [ROCm][CI] Added MI325 mirrors (stage C) (#35239)

    Signed-off-by: Andreas Karatzas <akaratza@amd.com>

commit 562339abc321ac5e86cc7b000ef0734839eea49f
Author: Shanshan Shen <467638484@qq.com>
Date:   Thu Mar 5 06:25:56 2026 +0800

    [Misc] Support OOT linear method registering (#35981)

    Signed-off-by: shen-shanshan <467638484@qq.com>

commit d7adcadb9bf4c7ea240fcc6cc668192bc2260ec0
Author: amitz-nv <203509407+amitz-nv@users.noreply.github.com>
Date:   Thu Mar 5 00:23:51 2026 +0200

    [Bugfix] Fix passing of activation_type to trtllm fused MoE NVFP4 and FP8 (#36017)

    Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com>

commit f678c3f61a2f3f224f29d3574225a6660e818e7e
Author: Simon Mo <simon.mo@hey.com>
Date:   Wed Mar 4 14:05:32 2026 -0800

    [RL] [Weight Sync] Guard IPC update-info pickle deserialization behind insecure serialization flag (#35928)

    Co-authored-by: Cursor Agent <cursoragent@cursor.com>

commit be0a3f7570726ca49cc9b53f9b48175418bddda0
Author: Thomas Parnell <tpa@zurich.ibm.com>
Date:   Wed Mar 4 22:52:44 2026 +0100

    [Bugfix] Fix race in non-blocking num_accepted_tokens GPU->CPU copy (#36013)

    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

commit 17dc9c7fc94534e542b6849192ed382c122d2d08
Author: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date:   Wed Mar 4 20:55:11 2026 +0000

    [CI] Bump `mypy` version (#34950)

    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 7eca85911072b9732293c3d4181e20a4c9394b21
Author: fenypatel99 <133059111+fenypatel99@users.noreply.github.com>
Date:   Wed Mar 4 12:53:38 2026 -0800

    Add PyTorch profiler schedule support with warmup/active iterations (#35240)

commit 636ee223ac976dfc3d4e93b31d33521230810f00
Author: Russell Bryant <rbryant@redhat.com>
Date:   Wed Mar 4 15:27:31 2026 -0500

    [Docs] Document security risks of GPT-OSS Python tool (#35139)

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit b7d59ffce2f951e0ec8d1dc3a2f1e3d27f779906
Author: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Date:   Wed Mar 4 15:13:40 2026 -0500

    [UX] Remove NoOpOffloader log (#35678)

    Signed-off-by: Robert Shaw <robshaw@redhat.com>
    Co-authored-by: Robert Shaw <robshaw@redhat.com>

commit 5569f5218d3b8a08cfbb9fd51c9f01852f16ddbc
Author: Richard Zou <zou3519@users.noreply.github.com>
Date:   Wed Mar 4 15:13:17 2026 -0500

    [torch.compile] Stop lazily compiling (#35472)

    Signed-off-by: Richard Zou <zou3519@gmail.com>

commit 138d891d7f42004c417561050a6813792316b13b
Author: Davina Zaman <davzaman@users.noreply.github.com>
Date:   Wed Mar 4 11:44:39 2026 -0800

    [Docs] Clarify structured outputs configuration for Qwen3 reasoning mode (#32441)

    Signed-off-by: Davina Zaman <davzaman@users.noreply.github.com>
    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit d7166e74c191741065d280441965adc3a9ea89c3
Author: Stefano Castagnetta <stefanocastagnetta@gmail.com>
Date:   Wed Mar 4 20:41:21 2026 +0100

    [CI] Add Blackwell AsyncTP correctness test (#35871)

    Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>

commit 417fd28fb125cbb166ef3ada187d06d0c8dd0d30
Author: Nick Hill <nhill@redhat.com>
Date:   Wed Mar 4 10:53:17 2026 -0800

    [Model Runner V2] Fix pooling (#36019)

    Signed-off-by: Nick Hill <nickhill123@gmail.com>

commit 7faba503c403bc8c562888df3a841b6df104d042
Author: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Date:   Wed Mar 4 20:47:17 2026 +0200

    [Kernel][Mamba] Optimize Mamba2 SSD prefill Triton kernels (#35397)

    Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>

commit bc6be89d16c6a0b3763a3fdc2623b90a9f7da8f1
Author: Hyunkyun Moon <mhg5303@gmail.com>
Date:   Thu Mar 5 03:41:52 2026 +0900

    [Frontend] Add vllm launch command for GPU-less preprocessing serving (#34551)

    Signed-off-by: HyunKyun Moon <mhg5303@gmail.com>

commit 32224f568a6965267ad6d430973bc42c27ded0b1
Author: Maxime Grenu <69890511+cluster2600@users.noreply.github.com>
Date:   Wed Mar 4 19:31:35 2026 +0100

    docs: update CPU Docker images to reference Docker Hub instead of AWS ECR (#34882)

    Signed-off-by: Maxime Grenu <69890511+cluster2600@users.noreply.github.com>
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit f3dc292e9f2cad55f914b7a7ed73e1969174ad77
Author: Abhishek Mathukiya <144843228+abhishkh@users.noreply.github.com>
Date:   Wed Mar 4 13:13:54 2026 -0500

    docs: add version requirement note for --profiler-config flag (#32454)

    Signed-off-by: abhishkh <mathukiya.a@northeastern.edu>

commit 138c5fa1869188ddeffd060ee586ed915d996d70
Author: Chen <zhuchen200245@163.com>
Date:   Wed Mar 4 12:11:34 2026 -0600

    [Docs] Add RunPod GPU deployment guide for vLLM (#34531)

    Signed-off-by: lisperz <zhuchen200245@163.com>
    Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit 2f2c1d73a745d8a38d1a21a5865a7d53d8d616b7
Author: Russell Bryant <rbryant@redhat.com>
Date:   Wed Mar 4 13:01:42 2026 -0500

    [Docs] Upgrade dynamic LoRA warning to admonition block (#35218)

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit fb3e78ab095f48f7f1856176783d29b6652340cf
Author: Bhuminjay Soni <Soni5Happy@gmail.com>
Date:   Wed Mar 4 23:31:16 2026 +0530

    [Feature][CI]: compare `func` & `no_func` outputs in test_functionalization.py  (#35481)

    Signed-off-by: Bhuminjay <bhuminjaysoni@gmail.com>
    Signed-off-by: Bhuminjay Soni <Soni5Happy@gmail.com>
    Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

commit fd3bfe74c972bccc3c7c45cb3be44cb4c3a26090
Author: Michael Yao <haifeng.yao@daocloud.io>
Date:   Thu Mar 5 01:58:59 2026 +0800

    [Docs] Update design/multiprocessing.md (#30677)

    Signed-off-by: windsonsea <haifeng.yao@daocloud.io>

commit bfdb512f111156a8f455dd9f396c1d15ba5bf655
Author: tc-mb <157115220+tc-mb@users.noreply.github.com>
Date:   Thu Mar 5 01:46:17 2026 +0800

    fix minicpmo4.5: fix attn_mask in vit attn && fix resampler pos_emb i… (#34127)

    Signed-off-by: tc-mb <caitianchi@modelbest.cn>
    Co-authored-by: hezhihui <hezhihui@modelbest.cn>

commit d25c1ec3c9706746e7606821101172194c005f0d
Author: Sage <80211083+sagearc@users.noreply.github.com>
Date:   Wed Mar 4 19:45:35 2026 +0200

    docs(cpu): Clarify pre-built wheels requirement for CPU Python-only build (#35090)

    Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>

commit 7cc6058ac69009b7d595c891f0b439d1d6b0351d
Author: Xing Liu <46082449+XingLiu1@users.noreply.github.com>
Date:   Thu Mar 5 01:23:34 2026 +0800

    [Doc] Add MTP docs and update speculative decoding guidance (#35197)

    Signed-off-by: liuxing <945764858@qq.com>

commit 28028dff2fed19e0face08a303b86273d954979a
Author: Manrique Vargas <mv1742@nyu.edu>
Date:   Wed Mar 4 12:15:35 2026 -0500

    fix(docs): use static rdzv backend in multi-node troubleshooting script (#34784)

    Signed-off-by: machov <mv1742@nyu.edu>

commit 3417ba5648b73b8125bdd20a2b9bb11ac35b9ab7
Author: Dr Alex Mitre <bedr10_capacitacion@hotmail.com>
Date:   Wed Mar 4 11:09:19 2026 -0600

    docs: add README for logits_processor examples (#35933)

commit 58cfe0dc44b29ced86cf8a6db069e55faf5d4f7d
Author: Yan Ma <yan.ma@intel.com>
Date:   Thu Mar 5 01:08:05 2026 +0800

    Fix phi4-mm and remove cuda binding (#35964)

    Signed-off-by: Yan Ma <yan.ma@intel.com>

commit e86221deb6859c28325097f4568e6d553ae92e8d
Author: simone-dotolo <84937474+simone-dotolo@users.noreply.github.com>
Date:   Wed Mar 4 18:03:14 2026 +0100

    [Doc] Fix GPU Worker count in Process Count Summary (#36000)

    Signed-off-by: simone-dotolo <simonedotolo@libero.it>
    Signed-off-by: simone-dotolo <84937474+simone-dotolo@users.noreply.github.com>
    Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

commit 289fc48ab73fb1eb610a72b4ddde9694e529bfba
Author: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Date:   Wed Mar 4 18:43:13 2026 +0200

    Use MMEncoderAttention (=use FlashAttention) instead of torch.sdpa in radio.py (#35653)

commit 2f2212e6ccfc01d123879d635d19448f5cc3653c
Author: Christian Pinto <christian.pinto@ibm.com>
Date:   Wed Mar 4 16:01:03 2026 +0000

    Split generic IO Processor plugins tests from Terratorch specific ones (#35756)

    Signed-off-by: Christian Pinto <christian.pinto@ibm.com>

commit 18e01a0a10e37ed7a705b46373b9b004f03b9e6b
Author: Nicolò Lucchesi <nlucches@redhat.com>
Date:   Wed Mar 4 16:12:27 2026 +0100

    [Misc] Add `--attention-backend auto` option (#35738)

    Signed-off-by: NickLucche <nlucches@redhat.com>

commit 6cb901093f3df8e26cbc0a8a0e1a884f4dbaa5ea
Author: sungsoo ha <hasungsoo@gmail.com>
Date:   Wed Mar 4 07:01:57 2026 -0800

    [Core] Add All-to-All communication backend for DCP  (#34883)

    Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
    Signed-off-by: sungsoo ha <hasungsoo@gmail.com>
    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
    Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
    Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

commit ead7bde1ab2ba939f0c3a73b3c829860d82888c8
Author: Cyrus Leung <tlleungac@connect.ust.hk>
Date:   Wed Mar 4 22:47:32 2026 +0800

    [Bugfix] Make `kaldi_native_fbank` optional (#35996)

    Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit 6aa6ad8992a928777f840a843f897ed4cb04c763
Author: Qi Wang <qiwa@nvidia.com>
Date:   Wed Mar 4 06:01:30 2026 -0800

  …
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Mar 12, 2026
…R rms_norm (vllm-project#36101)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
…R rms_norm (vllm-project#36101)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
mklasby pushed a commit to mklasby/vllm that referenced this pull request Apr 1, 2026
…R rms_norm (vllm-project#36101)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

2 participants