[Bugfix][Model] Fix gpt-oss batch invariance#35404

Merged
yewentao256 merged 2 commits into vllm-project:main from jzakrzew:fix-gpt-oss-batch-invariance
Feb 27, 2026

Conversation

@jzakrzew
Contributor

@jzakrzew jzakrzew commented Feb 26, 2026

Purpose

GPT-OSS is listed as verified in the batch invariance doc, but rerunning the provided tests on an H100 suggests it does not, in fact, work in all of the claimed supported configurations:

# VLLM_TEST_MODEL="openai/gpt-oss-20b" uv run pytest tests/v1/determinism/test_batch_invariance.py -k "bs1_vs_bsN and not MLA"
...

============================================================================================================== short test summary info ==============================================================================================================
FAILED tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[FLASH_ATTN] - Failed: Batch invariance violated in 32/32 prompts. See output above for details.
FAILED tests/v1/determinism/test_batch_invariance.py::test_logprobs_bitwise_batch_invariance_bs1_vs_bsN[TRITON_ATTN] - Failed: Batch invariance violated in 32/32 prompts. See output above for details.
============================================================================================= 2 failed, 15 deselected, 17 warnings in 67.67s (0:01:07) ==============================================================================================

Details of the setup at the end of the description.
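For context, the invariance property the failing test checks can be sketched as follows (a toy illustration only, not the actual vLLM test in test_batch_invariance.py): run each prompt at batch size 1 and again inside a larger batch, then compare the per-prompt outputs bitwise, not approximately.

```python
# Toy sketch of a bs1-vs-bsN bitwise invariance check (illustrative, not
# the real vLLM test). A batch-invariant model must produce bit-identical
# per-prompt outputs whether the prompt runs alone or inside a batch.

def fake_logprobs(prompt: str) -> list[float]:
    # Deterministic stand-in for a model's per-prompt logprob output.
    return [float(ord(c)) / 128.0 for c in prompt]

def run_batch(prompts: list[str]) -> list[list[float]]:
    # Batch-invariant by construction here: each prompt is independent.
    return [fake_logprobs(p) for p in prompts]

def check_bs1_vs_bsN(prompts: list[str]) -> list[str]:
    batched = run_batch(prompts)  # batch size N
    violations = []
    for prompt, out_n in zip(prompts, batched):
        (out_1,) = run_batch([prompt])  # same prompt at batch size 1
        # Exact float equality: any reduction-order difference shows up here.
        if out_1 != out_n:
            violations.append(prompt)
    return violations

print(check_bs1_vs_bsN(["hello", "world"]))  # [] means invariant
```

The real test fails when a batch-size-dependent GEMM reduction order (e.g. in cuBLAS split-K heuristics) makes out_1 and out_n differ in the last bits.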

The root cause appears to be cuBLAS: vLLM does not replace cuBLAS calls aggressively enough in batch invariant mode. (It also makes some incorrect assumptions about cuBLAS behavior on Hopper in batch_invariance.py, but that is out of scope for this PR.)

This PR fixes it by doing two things:

  1. Replacing the direct use of nn.Linear for the router layer with ReplicatedLinear
  2. Making UnquantizedLinearMethod unconditionally call the batch invariant linear kernel when batch invariance mode is enabled on CUDA, instead of doing so only for the router layer. GPT-OSS on H100 needs this for the QKV and output projections too.
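The second change can be sketched like this (hypothetical names; the actual change lives in vLLM's UnquantizedLinearMethod): in batch invariant mode on CUDA, every unquantized linear goes through a GEMM with a fixed, batch-size-independent reduction order, rather than only the MoE router gate.

```python
# Hypothetical sketch of the dispatch change (names are illustrative, not
# the exact vLLM code).

def batch_invariant_gemm(x, w):
    # Fixed reduction order: each output element accumulates left-to-right
    # over K, independent of how many rows (the batch size) x has.
    return [[sum(xi * wi for xi, wi in zip(row, col)) for col in zip(*w)]
            for row in x]

def cublas_gemm(x, w):
    # Stand-in for the fast path, whose internal reduction order (e.g.
    # split-K) may vary with batch size on real hardware.
    return batch_invariant_gemm(x, w)

def apply_linear(x, w, *, batch_invariant: bool, device: str = "cuda"):
    # Before this PR the invariant path was taken only for router gates;
    # after it, all unquantized linears on CUDA take it in invariant mode.
    if batch_invariant and device == "cuda":
        return batch_invariant_gemm(x, w)
    return cublas_gemm(x, w)

print(apply_linear([[1.0, 2.0]], [[3.0, 0.0], [0.0, 4.0]],
                   batch_invariant=True))
```

x is M×K (rows are the batch) and w is K×N; because GPT-OSS on H100 hits the non-invariant path in QKV and output projections as well, gating this per-layer is not enough.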

Test Plan

# VLLM_TEST_MODEL="openai/gpt-oss-20b" uv run pytest tests/v1/determinism/test_batch_invariance.py -k "bs1_vs_bsN and not MLA"

Test Result

...

tests/v1/determinism/test_batch_invariance.py ..                                                                                                                                                                                              [100%]

================================================================================================= 2 passed, 15 deselected, 17 warnings in 100.57s (0:01:40) =============================================================================================

The output of python collect_env.py:
Collecting environment information...
uv is set

Python version               : 3.12.3 (main, Jan 22 2026, 20:57:42) [GCC 13.3.0] (64-bit runtime)
Python platform              : Linux-5.15.0-52-generic-x86_64-with-glibc2.39

==============================
       CUDA / GPU Info
==============================
Is CUDA available            : True
CUDA runtime version         : 13.1.115
CUDA_MODULE_LOADING set to   :
GPU models and configuration : GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version        : 580.105.08
cuDNN version                : Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.17.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.17.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.17.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.17.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.17.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.17.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.17.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.17.1
HIP runtime version          : N/A
MIOpen runtime version       : N/A
Is XNNPACK available         : True

==============================
          CPU Info
==============================
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   48 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          48
On-line CPU(s) list:             0-47
Vendor ID:                       AuthenticAMD
Model name:                      AMD EPYC 7413 24-Core Processor
CPU family:                      25
Model:                           1
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       1
Stepping:                        1
Frequency boost:                 enabled
CPU(s) scaling MHz:              58%
CPU max MHz:                     3630.8101
CPU min MHz:                     1500.0000
BogoMIPS:                        5300.17
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization:                  AMD-V
L1d cache:                       768 KiB (24 instances)
L1i cache:                       768 KiB (24 instances)
L2 cache:                        12 MiB (24 instances)
L3 cache:                        128 MiB (4 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-47
Vulnerability Itlb multihit:     Not affected
Vulnerabil
[pip3] nvidia-cudnn-frontend==1.18.0
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-cufile-cu12==1.13.1.3
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cutlass-dsl==4.4.0
[pip3] nvidia-cutlass-dsl-libs-base==4.4.0
[pip3] nvidia-ml-py==13.590.48
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvshmem-cu12==3.4.5
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] pyzmq==27.1.0
[pip3] torch==2.10.0
[pip3] torchaudio==2.10.0
[pip3] torchvision==0.25.0
[pip3] transformers==4.57.6
[pip3] triton==3.6.0
[conda] Could not collect

==============================
         vLLM Info
==============================
ROCM Version                 : Could not collect
vLLM Version                 : 0.1.dev14246+g845ee348e (git sha: 845ee348e)
vLLM Build Flags:
  CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
        GPU0    NIC0    NIC1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    NODE    0-47    0               N/A
NIC0    NODE     X      PIX
NIC1    NODE    PIX      X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1

==============================
     Environment Variables
==============================
NVIDIA_VISIBLE_DEVICES=void
NVIDIA_REQUIRE_CUDA=cuda>=13.1 brand=unknown,driver>=535,driver<536 brand=grid,driver>=535,driver<536 brand=tesla,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=vapps,driver>=535,driver<536 brand=vpc,driver>=535,driver<536 brand=vcs,driver>=535,driver<536 brand=vws,driver>=535,driver<536 brand=cloudgaming,driver>=535,driver<536 brand=unknown,driver>=550,driver<551 brand=grid,driver>=550,driver<551 brand=tesla,driver>=550,driver<551 brand=nvidia,driver>=550,driver<551 brand=quadro,driver>=550,driver<551 brand=quadrortx,driver>=550,driver<551 brand=nvidiartx,driver>=550,driver<551 brand=vapps,driver>=550,driver<551 brand=vpc,driver>=550,driver<551 brand=vcs,driver>=550,driver<551 brand=vws,driver>=550,driver<551 brand=cloudgaming,driver>=550,driver<551 brand=unknown,driver>=570,driver<571 brand=grid,driver>=570,driver<571 brand=tesla,driver>=570,driver<571 brand=nvidia,driver>=570,driver<571 brand=quadro,driver>=570,driver<571 brand=quadrortx,driver>=570,driver<571 brand=nvidiartx,driver>=570,driver<571 brand=vapps,driver>=570,driver<571 brand=vpc,driver>=570,driver<571 brand=vcs,driver>=570,driver<571 brand=vws,driver>=570,driver<571 brand=cloudgaming,driver>=570,driver<571 brand=unknown,driver>=575,driver<576 brand=grid,driver>=575,driver<576 brand=tesla,driver>=575,driver<576 brand=nvidia,driver>=575,driver<576 brand=quadro,driver>=575,driver<576 brand=quadrortx,driver>=575,driver<576 brand=nvidiartx,driver>=575,driver<576 brand=vapps,driver>=575,driver<576 brand=vpc,driver>=575,driver<576 brand=vcs,driver>=575,driver<576 brand=vws,driver>=575,driver<576 brand=cloudgaming,driver>=575,driver<576 brand=unknown,driver>=580,driver<581 brand=grid,driver>=580,driver<581 brand=tesla,driver>=580,driver<581 brand=nvidia,driver>=580,driver<581 brand=quadro,driver>=580,driver<581 brand=quadrortx,driver>=580,driver<581 
brand=nvidiartx,driver>=580,driver<581 brand=vapps,driver>=580,driver<581 brand=vpc,driver>=580,driver<581 brand=vcs,driver>=580,driver<581 brand=vws,driver>=580,driver<581 brand=cloudgaming,driver>=580,driver<581
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NVIDIA_PRODUCT_NAME=CUDA
CUDA_VERSION=13.1.1
LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64
NVIDIA_CTK_LIBCUDA_DIR=/usr/lib/x86_64-linux-gnu
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_root

@mergify mergify bot added the gpt-oss (Related to GPT-OSS models) and bug (Something isn't working) labels Feb 26, 2026
@jzakrzew
Copy link
Copy Markdown
Contributor Author

cc @yewentao256 @tijyojwad

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a batch invariance bug in the gpt-oss model. The changes are twofold: first, the router layer in gpt_oss.py is updated from torch.nn.Linear to vllm.ReplicatedLinear, which is necessary to integrate it with vLLM's linear layer handling. Second, the batch-invariant linear method in linear.py is now applied to all unquantized linear layers under CUDA when batch invariance is enabled, instead of being restricted to MoE router gates. This generalization is crucial for ensuring deterministic behavior across all relevant layers, including QKV and output projections, as required by the model on H100 GPUs. The changes are well-justified, targeted, and appear to be a solid fix for the reported issue.

Member

@yewentao256 yewentao256 left a comment


LGTM, thanks for the work!

@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Feb 26, 2026
@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 26, 2026
…nt GEMM

Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
@jzakrzew jzakrzew force-pushed the fix-gpt-oss-batch-invariance branch from 41e539f to 3bd318b on February 27, 2026 09:29
@yewentao256 yewentao256 enabled auto-merge (squash) February 27, 2026 15:51
@yewentao256 yewentao256 merged commit 1f3dbd9 into vllm-project:main Feb 27, 2026
64 checks passed
sergey-zinchenko pushed a commit to sergey-zinchenko/vllm that referenced this pull request Mar 1, 2026
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: Sergey Zinchenko <sergey.zinchenko.rnd@gmail.com>
EanWang211123 pushed a commit to EanWang211123/vllm that referenced this pull request Mar 2, 2026
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: EanWang211123 <wangyiheng@sangfor.com.cn>
bhoomit pushed a commit to bhoomit/vllm that referenced this pull request Mar 2, 2026
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
machov added a commit to machov/vllm that referenced this pull request Mar 10, 2026

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Benchmark] Improve UX of sweep scripts (#35600)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [ROCm][CI] Parametrize vision score tests across attention backends with per-backend tolerances (#35571)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

* Add padding support to wvSplitK solution for skinny GEMMs (#33762)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>

* [Feature]Supports Anthropic Thinking Block (#33671)

Signed-off-by: mariohong <mariohong128@gmail.com>
Co-authored-by: zetaohong <i-hongzetao@stepfun.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>

* add io_process_plugin for sparse embedding (#34214)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: Augusto Yao <augusto.yjh@antgroup.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

* [Feat] Add CUDA torch fallbacks for fp8_mqa_logits/fp8_paged_mqa_logits_torch function (#35271)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* custom dataset img support base64 (#35280)

Signed-off-by: xjx <493337577@qq.com>

* Fix Qwen3_5MTP packed_modules_mapping for gate_up_proj (#35581)

* [Fix] Avoid sending image input to other PP ranks (#35405)

Signed-off-by: emricksini-h <emrick.birivoutin@hcompany.ai>
Co-authored-by: Roger Wang <hey@rogerw.io>

* [Benchmark] Avoid unnecessary video download in MMVU (#35618)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Deprecation] Deprecate code in 0.17 as scheduled (#35441)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

* [Chore] Cleanup BNB utilization dead code (#35620)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* [Bugfix] Fix Anthropic API base64 image handling in Messages endpoint (#35557)

Signed-off-by: Martin Vit <martin@voipmonitor.org>

* [Model Runner V2] Add ModelStateInterface [4/N] (#35621)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

* Add TMA support to fused_moe_lora kernel (#32195)

Signed-off-by: gnovack <gnovack@amazon.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>

* [Bugfix][Model] Fix Qwen3.5/Qwen3Next ignoring --dtype flag on older GPUs (#35617)

* Fix typo: implictly -> implicitly in isaac.py docstring (#35646)

* [AMD][CI] Support Triton attention with ExampleConnector (#34931)

Signed-off-by: Ryan Rock <ryan.rock@amd.com>

* [Model Runner V2] Minor refactoring for EncoderRunner (#35628)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

* [Mamba1] - Kernel Level Chunk Alignment for Prefix Caching (#34798)

Signed-off-by: Josephasafg <ajgard7@gmail.com>

* [MISC] Fixing a null reference by removing parallel_utils from mypy EXCLUDE (#35630)

Signed-off-by: Taneem Ibrahim <taneem.ibrahim@gmail.com>

* fix(mxfp4): return is_monolithic=False when LoRA is enabled for Triton backend (#35382)

Signed-off-by: Seungho Yoon <yoonsnowdev@gmail.com>

* [Model Runner V2] Use block table apis for capture inputs (#35671)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

* [Bugfix] Fix dtype mismatch in RMSNormGated.forward_native() during torch.compile (#35256)

Signed-off-by: haosdent <haosdent@gmail.com>

* [torch.compile] Undo the fast_moe_cold_start hack in torch>=2.11 (#35475)

Signed-off-by: Richard Zou <zou3519@gmail.com>

* Revert "[Bugfix] Disable TRTLLM attention with KV transfer enabled (#33192)" (#34832)

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>

* [Attention] FA4 integration (#32974)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

* Fix deprecated v1 config tests (#35327)

Signed-off-by: Jesse Cai <jessecai@fb.com>

* [XPU] fix mxfp4 activation type (#35691)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

* [ROCm] add amd-quark package in requirements for rocm to use quantized models (#35658)

Signed-off-by: Hongxia Yang <hongxiay.yang@amd.com>
Co-authored-by: Hongxia Yang <hongxiay.yang@amd.com>

* [ROCm][CI] Disable skinny GEMMs in language model standard tests to fix non-determinism (#35152)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

* [Kernel] Integrate SM100 MXFP8 blockscaled grouped MM and quant kernels (#34448)

Signed-off-by: EdalatiAli <aliedalati@cohere.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>

* [Rocm][CI] Fix LM Eval Large Models (H100) test group (#34750)

Signed-off-by: charlifu <charlifu@amd.com>

* [CI] Defining extended V1 e2e + engine tests (#35580)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

* [Misc] Bound NIXL upper bound version (#35495)

Signed-off-by: NickLucche <nlucches@redhat.com>

* [CPU][Distributed] Fix Enable _CPUSHMDistributed only when TP/PP ranks share the same SHM group name (#34169)

Signed-off-by: Charles Ashby <charlesa.l@hotmail.com>

* [Misc] Cleanup useless `current_platform` import (#35715)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

* [Feat] Supports Anthropic Messages count_tokens API (#35588)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>

* Fix unresolved-import errors when using Astral's ty by removing src.root (#35681)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>

* [MyPy][BugFix] Check profiler is assigned before calling start() on it  (#35505)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Docs] Add breadcrumbs for better UX (#35749)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Fix Bug]`num_active_loras` always equals to zero  (#34119)

Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>

* [Performance] Extract kv update ops from MLA attention backends (#34627)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Di Wu <dw2761@nyu.edu>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>

* [CI] Fix mypy for vllm/device allocator (#35518)

Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Core] Move test utility to test file (#35672)

Signed-off-by: Turner Jabbour <doubleujabbour@gmail.com>

* [Doc] Improve UX of `--enable-log-requests` (#35723)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [CI][HPU] Pin vllm commit compatible with vllm-gaudi - HPU tests (#35307)

Signed-off-by: PatrykWo <patryk.wolsza@intel.com>

* [CI/Build] Enable Qwen3.5 tests on CI (#35763)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* [BugFix][Model]Fix the garbled code in Ernie4.5-VL caused by fast_moe_cold_start (#35587)

Signed-off-by: wangyafeng <wangyafeng@baidu.com>

* [torch.compile] Improve cold and warm start compile tests (#35709)

Signed-off-by: Richard Zou <zou3519@gmail.com>

* [Spec Decode] Add hidden states extraction system (#33736)

Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>

* [KVConnector] Auto-downgrade to PIECEWISE cudagraph mode for layerwise async ops (#31057)

Signed-off-by: Yashwant Bezawada <yashwant_b@me.com>

* [ci] Add Ray compatibility check informational CI job (#34672)

Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>

* [BUG] Fix rlhf_async example (#35788)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>

* [Bugfix] Fix missing sequence_lengths in qwen3_omni_moe_thinker (#35741)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>

* clean unused cudagraph_batch_sizes (#35552)

Signed-off-by: Boyuan Feng <boyuan@meta.com>

* [MoE][Perf] Wrap DSV3 QKVAProj GEMM in custom op for torch.compile (#35751)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>

* [Bugfix] Fix MM processor test for Qwen3.5 (#35797)

Signed-off-by: Roger Wang <hey@rogerw.io>

* [All Reduce] Change default backend of Flashinfer All Reduce to trtllm (#35793)

Signed-off-by: hjjq <hanjieq@nvidia.com>

* [ROCm][CI] Fix backslash-continuation in pytest marker re-quoting and treat exit code 5 as success (#35798)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

* [Model] Add support for nvidia/llama-nemotron-rerank-vl-1b-v2 (#35735)

Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>

* [XPU][NIXL] Add GPUDirect RDMA support for XPU (#35270)

Signed-off-by: zhenwei-intel <zhenwei.liu@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>

* [Model Runner V2][Perf] align dummy_run tokens to uniform decode for dp cudagraph (#35376)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>

* [Tool Parser] Fix Qwen3Coder streaming parameter loss with speculative decode (#35615)

Signed-off-by: Martin Vit <martin@voipmonitor.org>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* [CI] Temporarily Disable Nightly Failures (#35770)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>

* [ModelRunnerV2] Rename sampler functions and variables for clarity (#35459)

Signed-off-by: Andy Lo <andy@mistral.ai>

* [Docs][Model Runner V2] Add Design Docs (#35819)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

* [V0 deprecation] Remove Swin model (#35821)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* [Model Runner V2] Use ModelState.prepare_attn() for cuda graph capture [5/N] (#35774)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

* [Core] Add optional flags to check for repetitive token patterns in engine output (#35451)

Signed-off-by: aykoppol <aykoppol+git@gmail.com>

* [CI/Build] Automatically patch video metadata for multimodal processor test (#35822)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* [Refactor] Fix maxsim cuda platform and add cli to control it (#35427)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

* [ROCm][CI] Fix Assertion Logic For `test_gpt_oss` (#35806)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>

* [CI/Build] Trigger processor tests on registry update (#35824)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [BugFix] Fix cmake based incremental install (wrong vllm install dir) (#35773)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>

* [MISC] Removed unused function find_all_indices() from tool_parsers/utils.py (#35683)

Signed-off-by: Taneem Ibrahim <taneem.ibrahim@gmail.com>

* [Misc] Fix typos in comments: explict→explicit, paramaters→parameters (#35648)

* Fix TYPE_CHECKING stub defaults in envs.py to match actual runtime defaults (#35645)

* [ROCm] [Release] Change the package from `aiter` to `amd-aiter` (#35198)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>

* add regression test (#35834)

Signed-off-by: hallerite <git@hallerite.com>

* [CI/Build][Intel] Add new performance benchmarks for Intel Gaudi 3 (#31025)

Signed-off-by: Szymon Reginis <sreginis@habana.ai>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>

* [Perf] [Hybrid] Copy num_accepted_tokens in non-blocking way when not using prefix caching (#35442)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

* [CI] Add PPL test for Qwen3.5. (#35853)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [Bugfix] Avoid src/dst as None in irecv/isend_tensor_dict (#35754)

Signed-off-by: jiang1.li <jiang1.li@intel.com>

* [Frontend][1/n] Improve pooling entrypoints | classify. (#35604)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

* [ROCm] [CI] Add new fusion test cases that are relevant to vLLM IR Ops (#34307)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com>

* [BugFix] Add support for MTP num_speculative_tokens > 1 with sparse MLA (#34552)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>

* TRTLLM gen-full attn Test Coverage (#34986)

Signed-off-by: Anshika Ojha <anshikao@nvidia.com>
Co-authored-by: Anshika Ojha <anshikao@gb-nvl-059-compute09.nvidia.com>

* fix: Ensure invalid audio files return 400 error (#34715)

Signed-off-by: Jason Ozuzu <jasonozuzu@cohere.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>

* [CI] Bump `num_speculative_tokens` to 3 in nightly DeepSeek tests (#35882)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>

* [CI] Temporarily Disable Llama4 MoE Refactor Test (#35870)

Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>

* [MoE Refactor] Create MK for TRTLLM Kernels (#32564)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Signed-off-by: Robert Shaw <robertgshaw2@gmail.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>

* [ROCm][Bugfix]: Disable AITER Triton ROPE by default (#35601)

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>

* [ROCm][CI] Fix TP size issue for `test_gpt_oss` (#35887)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>

* [Bugfix] Fix misnamed parameter in compressed_tensors_moe.py (#35813)

Signed-off-by: Bill Nell <bnell@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>

* [Model Runner V2] Fix inputs_embeds=None bug for MM models (#35917)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>

* [CI/Build] Allow mounting AWS credentials for sccache S3 auth (#35912)

Signed-off-by: Amr Mahdi <amrmahdi@meta.com>

* [Model Runner V2] support dp & ep for spec decoding (#35294)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>

* [Core] Move save_tensorized_model logic to Worker (#35825)

Signed-off-by: Nick Hill <nickhill123@gmail.com>

* [Bugfix] Fix coord_socket assertion in DPEngineCoreProc for offline DP mode (#35916)

Signed-off-by: Jaewon Lee <jaewon@meta.com>

* [ROCm][CI] Support async weight transfer example with platform-aware determinism (#35710)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

* Enable bnb for multiple indices weight (#35838)

Signed-off-by: xjx <493337577@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* [Bugfix] Fix EVS implementation for Qwen3 VL (#33607)

Signed-off-by: 2ez4bz <133824995+2ez4bz@users.noreply.github.com>

* [PluggableLayer][MM] Add PluggableLayer for RelPosAttention (#33753)

Signed-off-by: shen-shanshan <467638484@qq.com>

* [model] support FireRedASR2 (#35727)

Signed-off-by: zixiao <shunli.dsl@alibaba-inc.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: zixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* [Chore] Remove debug code in model implementation (#35883)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* [Refactor] Clean up processor kwargs extraction (#35872)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* [Bugfix] Guard mm_token_type_ids kwarg in get_mrope_input_positions (#35711)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>

* [Rocm][CI] Fix ROCm LM Eval Large Models (8 Card) (#35913)

Signed-off-by: charlifu <charlifu@amd.com>

* [BugFix] Support tool_choice=none in the Anthropic API (#35835)

Signed-off-by: ZhongsJie <zhongsjie@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>

* [Bugfix] Improve engine ready timeout error message (#35616)

Signed-off-by: damaozi <1811866786@qq.com>

* [cohere][fix][spec-decode]: fix crash when allowed_token_ids is set without penalties (#35654)

Signed-off-by: kkt-cohere <komal@cohere.com>

* Support Audio Extraction from MP4 Video for Nemotron Nano VL (#35539)

Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Andrii <askliar@nvidia.com>
Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Andrii Skliar <askliar@oci-nrt-cs-001-vscode-01.cm.cluster>
Co-authored-by: Andrii <askliar@nvidia.com>
Co-authored-by: root <root@pool0-03748.cm.cluster>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: root <root@pool0-02416.cm.cluster>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: root <root@pool0-04880.cm.cluster>

* [Core] Remove busy loop from idle buffer readers (#28053)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>

* [Bugfix] Add missing dynamic_arg_dims for Qwen3-ASR torch.compile (#35869)

Signed-off-by: Nathan Price <nathan@abridge.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

* [ROCm][Bugfix] Fall back from CK MXFP4 MoE when GEMM dimensions are unsupported (#35893)

Signed-off-by: Li <chuali@amd.com>

* [Hardware] Replace `torch.cuda.empty_cache` with `torch.accelerator.empty_cache` (#30681)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [XPU] bump vllm-xpu-kernels to v0.1.3 (#35984)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

* [Bugfix] Cap FULL decode cudagraph sizes for Mamba/hybrid models (#34094) (#34571)

Signed-off-by: haosdent <haosdent@gmail.com>
Co-authored-by: zjy0516 <riverclouds.zhu@qq.com>

* [Feature] Add basic metrics for /realtime endpoint (#35500)

Signed-off-by: Thomas Pouget-Abadie <thomaspou@microsoft.com>
Signed-off-by: pougetat <thomas.pougetabadie@gmail.com>
Co-authored-by: Thomas Pouget-Abadie <thomaspou@microsoft.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [MISC] fixed tool_parser mypy errors (#35640)

Signed-off-by: Taneem Ibrahim <taneem.ibrahim@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* [Bugfix][CPUOffloadingManager] Prevent eviction of already-stored blocks in LRU/ARC `prepare_store()` (#35846)

Signed-off-by: Ronen Schaffer <ronen.schaffer@ibm.com>

* [Bugfix][Model] Fix FP8 k_scale/v_scale not loaded for Qwen3-MoE (#35656)

Signed-off-by: raghavan <oneraghavan@gmail.com>

* [BugFix] Fix implicit and incorrect assumption on ECConnector is_producer (#34783)

Signed-off-by: Qi Wang <qiwa@nvidia.com>

* [Bugfix] Make `kaldi_native_fbank` optional (#35996)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* fix(mooncake): resolve HBM leak from stuck WAITING_FOR_REMOTE_KVS requests

Fixes #35943

When KV transfer fails (due to P timeout, crash, or network interruption),
D-side requests remained stuck in WAITING_FOR_REMOTE_KVS state permanently,
causing progressive HBM leak that eventually stalls the decode engine.

Root cause:
1. PullReqMeta.expire_time was never set (unlike P-side SendBlockMeta.expire_time)
2. Transfer errors were silently dropped in process_pulling_result()
3. ZMQ timeout exceptions in receive_kv_from_single_worker() didn't notify scheduler

Solution:
- Set PullReqMeta.expire_time in receive_kv() method (mirrors P-side pattern)
- Add timeout checking in fetch_finished_recving_reqs() to handle expired requests
- Handle err_reqs in process_pulling_result() by adding to finished_recving_reqs
- Handle ZMQ timeout exceptions by marking requests as finished for cleanup

This ensures stuck requests are properly cleaned up and their KV cache blocks
are released, preventing the progressive memory leak.

Tested: Syntax validation passes, no breaking API changes
Signed-off-by: machov <mv1742@nyu.edu>

* fix: address CI feedback - line length and missing metadata attribute

Signed-off-by: machov <mv1742@nyu.edu>

* fix(mooncake): resolve race conditions in KV transfer handling

- Fix critical race condition in fetch_finished_recving_reqs that could cause
  RuntimeError due to concurrent dictionary modification during iteration
- Fix race condition in receive_kv where expire_time was being set on all
  requests to the same value, potentially causing premature timeouts

Addresses code review feedback from gemini-code-assist on PR #36014

* fix(mooncake): address review comments from PR #36014

- Move self.reqs_to_recv.update() into _start_load_kv() so all
  reqs_to_recv operations happen in the same receiver_loop coroutine,
  eliminating race conditions (dtcccc comment at line 1272)
- Remove unnecessary list() copies in fetch_finished_recving_reqs since
  all access to reqs_to_recv is now in the same event loop (dtcccc
  comment at line 984)
- Call process_pulling_result() before ERROR status check so that err_reqs
  are always added to finished_recving_reqs for scheduler cleanup (dtcccc
  comment at line 1111)
- Simplify expired request deletion (no defensive check needed in same loop)

Co-authored-by: machov <43248948+machov@users.noreply.github.com>

* fix(mooncake): address review comments from PR #36014

- Move self.reqs_to_recv.update() into _start_load_kv() so all
  reqs_to_recv operations happen in the same receiver_loop, no race
- Remove unnecessary list() copies in fetch_finished_recving_reqs
  (no need to copy after moving, per reviewer comment)
- Call process_pulling_result() before ERROR status check so err_reqs
  are always added to finished_recving_reqs for scheduler cleanup

Co-authored-by: machov <43248948+machov@users.noreply.github.com>

---------

Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Signed-off-by: Elizabeth Thomas <email2eliza@gmail.com>
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Seungmin Kim <8457324+ehfd@users.noreply.github.com>
Signed-off-by: Andrew Mello <19512127+88plug@users.noreply.github.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: hujiaxin <524446785@qq.com>
Signed-off-by: Emilie1001 <79921183+Emilie1001@users.noreply.github.com>
Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: Jason Li <jasonlizhengjian@gmail.com>
Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com>
Signed-off-by: sfeng33 <4florafeng@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Ofir Zafrir <ofir.zafrir@intel.com>
Signed-off-by: chzhang <chaojun.zhang@intel.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Signed-off-by: KrxGu <krishom70@gmail.com>
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Sophie du Couédic <sop@zurich.ibm.com>
Signed-off-by: Akash kaothalkar <akash.kaothalkar@ibm.com>
Signed-off-by: daowu.hzy <daowu.hzy@alibaba-inc.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: linyueqian <linyueqian@outlook.com>
Signed-off-by: liyongwen <1310439159@qq.com>
Signed-off-by: Josephasafg <ajgard7@gmail.com>
Signed-off-by: 冬马 <chenxinke@cai-inc.com>
Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Doug Lehr <douglehr@amd.com>
Signed-off-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Signed-off-by: qianlihuang <yiliu.dong@qq.com>
Signed-off-by: Runkai Tao <rt572@physics.rutgers.edu>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: stakeswky <stakeswky@users.noreply.github.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: dafrimi <dafrimi@nvidia.com>
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: pkousha <43781676+pkousha@users.noreply.github.com>
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: Andrii <askliar@nvidia.com>
Signed-off-by: Roi Koren <roik@nvidia.com>
Signed-off-by: Daniel Salib <danielsalib@meta.com>
Signed-off-by: angelayi <yiangela7@gmail.com>
Signed-off-by: Chenyaaang <chenyangli@google.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: gnovack <gnovack@amazon.com>
Signed-off-by: Daniel Huang <daniel1.huang@intel.com>
Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
Signed-off-by: Chengyi Nie <cnie@roblox.com>
Signed-off-by: Xingran Wang <wangxingran123456@outlook.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: umut-polat <52835619+umut-polat@users.noreply.github.com>
Signed-off-by: Max Hu <maxhu@nvidia.com>
Signed-off-by: Max Hu <hyoung2991@gmail.com>
Signed-off-by: tibG <naps@qubes.milou>
Signed-off-by: haosdent <haosdent@gmail.com>
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Signed-off-by: Sungwan(Alex) Kim <sw0726.kim@sktelecom.com>
Signed-off-by: fort726 <38447663+fort726@users.noreply.github.com>
Signed-off-by: Huamin Li <3ericli@gmail.com>
Signed-off-by: Martin Hickey <martin.hickey@ie.ibm.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Signed-off-by: raushan <raushan@huggingface.co>
khairulkabir1661 added a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
This check was removed in upstream commit 1f3dbd9 (vllm-project#35404) to fix
gpt-oss batch invariance. The check was too restrictive and prevented
batch invariance from working for non-MoE layers.

It was accidentally re-introduced during our rebase conflict resolution.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Signed-off-by: khairulkabir1661 <khairulkabir1661@users.noreply.github.com>
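As a toy illustration (not vLLM code) of why a GEMM library can break bitwise batch invariance: floating-point addition is non-associative, and libraries like cuBLAS may select different reduction orders (tiling, split-K) depending on the problem shape, so the same logical dot product can round differently at batch size 1 than in a larger batch. The values below are chosen purely to make the rounding difference visible.

```python
# Same inputs, two reduction orders, bitwise-different results.
# This mirrors how cuBLAS kernel/split-K selection can vary with
# batch shape, which is what batch-invariant mode must pin down.
vals = [1e16, 1.0, -1e16, 1.0]

# "bs=1" style: strict left-to-right sequential reduction
seq = 0.0
for v in vals:
    seq += v

# "batched" style: pairwise tree reduction
pairwise = (vals[0] + vals[1]) + (vals[2] + vals[3])

print(seq, pairwise)  # 1.0 vs 0.0
```

The two results differ because `1e16 + 1.0` rounds back to `1e16`, so the `1.0` terms survive or vanish depending on where the large cancelling terms sit in the reduction tree.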
Labels

`bug` Something isn't working · `gpt-oss` Related to GPT-OSS models · `ready` ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done
