From dc52155c54a68bfcd2b174c1ad72cab367165890 Mon Sep 17 00:00:00 2001 From: Dezhan Tu Date: Fri, 14 Nov 2025 01:02:12 -0800 Subject: [PATCH] Fixed gpt-oss _load_weights_other() parameter position bug (#28715) Summary: Signed-off-by: Dezhan Tu For `_load_weights_other()`, `ep_rank_start` and `ep_rank_end` are passed in the wrong positions, causing expert weight loading to fail under expert parallelism. Test Plan: ``` from vllm import LLM llm = LLM("openai/gpt-oss-120b", tensor_parallel_size=2, enable_expert_parallel=True) output = llm.generate("Hi, vLLM is a") ``` Reviewed By: helunwencser, jackm321 Differential Revision: D87021773 --- vllm/model_executor/models/gpt_oss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 692ef605fe17..328c8c0ac4b7 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -641,8 +641,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ) else: return self._load_weights_other( - ep_rank_end, ep_rank_start, + ep_rank_end, heads_per_rank, head_start, weights,