diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b17ad7335c7..0ba01a525ac8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -82,6 +82,7 @@ repos: entry: tools/pre_commit/shellcheck.sh language: script types: [shell] + exclude: '^(\.buildkite/scripts/run-multi-node-test\.sh|tests/v1/kv_connector/nixl_integration/spec_decode_acceptance_test\.sh)$' - id: png-lint name: Lint PNG exports from excalidraw entry: tools/pre_commit/png-lint.sh diff --git a/.shellcheckrc b/.shellcheckrc index f3b6eedf8d90..8eee799dfad1 100644 --- a/.shellcheckrc +++ b/.shellcheckrc @@ -6,4 +6,4 @@ # SC2155 (warning): Declare and assign separately to avoid masking return values. # SC2164 (warning): Use 'cd ... || exit' or 'cd ... || return' in case cd fails. # -disable=SC1091,SC2004,SC2129,SC2155,SC2164 +disable=SC1091,SC2004,SC2129,SC2155,SC2164,SC2089,SC2090,SC2086,SC2046,SC2048,SC2206 diff --git a/CMakeLists.txt b/CMakeLists.txt index afc02f7fbbbe..ec6bf1489cbb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -362,8 +362,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # marlin arches for fp8 input # - sm80 doesn't support fp8 computation # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction - # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) - cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0/12.1 (e.g. RTX 50x0, GB10) + cuda_archs_loose_intersection(MARLIN_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}") # marlin arches for other files cuda_archs_loose_intersection(MARLIN_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") @@ -781,6 +781,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SRCS "${DSV3_FUSED_A_GEMM_SRC}" CUDA_ARCHS "${DSV3_FUSED_A_GEMM_ARCHS}") list(APPEND VLLM_EXT_SRC ${DSV3_FUSED_A_GEMM_SRC}) + target_compile_definitions(${VLLM_EXT_NAME} PRIVATE ENABLE_DSV3_FUSED_A_GEMM) message(STATUS "Building dsv3_fused_a_gemm for archs: ${DSV3_FUSED_A_GEMM_ARCHS}") else() message(STATUS "Not building dsv3_fused_a_gemm as no compatible archs found " @@ -1049,8 +1050,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # moe marlin arches for fp8 input # - sm80 doesn't support fp8 computation # - sm90 and sm100 don't support QMMA.16832.F32.E4M3.E4M3 SAAS instruction - # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0 (e.g. RTX 50x0) - cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0" "${CUDA_ARCHS}") + # so we only enable fp8 computation for SM89 (e.g. RTX 40x0) and 12.0/12.1 (e.g. RTX 50x0, GB10) + cuda_archs_loose_intersection(MARLIN_MOE_FP8_ARCHS "8.9;12.0;12.1" "${CUDA_ARCHS}") # moe marlin arches for other files cuda_archs_loose_intersection(MARLIN_MOE_OTHER_ARCHS "7.5;8.0+PTX" "${CUDA_ARCHS}") if (MARLIN_MOE_OTHER_ARCHS) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3bc69c7bb892..747c7e105c70 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -239,10 +239,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization ops #ifndef USE_ROCM + #ifdef ENABLE_DSV3_FUSED_A_GEMM // DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens). ops.def( "dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); - // conditionally compiled so impl registration is in source file + // conditionally compiled so impl registration is in source file + #endif // Quantized GEMM for AWQ. ops.def( diff --git a/tests/evals/gsm8k/gsm8k_eval.py b/tests/evals/gsm8k/gsm8k_eval.py index 647c149ef5fd..65fb6d34ee8e 100644 --- a/tests/evals/gsm8k/gsm8k_eval.py +++ b/tests/evals/gsm8k/gsm8k_eval.py @@ -83,6 +83,8 @@ async def call_vllm_api( stop: list[str] | None = None, url: str | None = None, seed: int | None = None, + model: str = "gpt-oss-120b", + api_key: str | None = None, ) -> tuple[str, int]: """Call vLLM's OpenAI-compatible completions endpoint. @@ -90,6 +92,7 @@ async def call_vllm_api( Tuple of (response_text, completion_tokens) """ data = { + "model": model, "prompt": prompt, "temperature": temperature, "max_tokens": max_tokens, @@ -98,8 +101,14 @@ async def call_vllm_api( if seed is not None: data["seed"] = seed + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + try: - async with session.post(f"{url}/v1/completions", json=data) as response: + async with session.post( + f"{url}/v1/completions", json=data, headers=headers + ) as response: response.raise_for_status() result = await response.json() text = result["choices"][0]["text"] @@ -177,6 +186,8 @@ def evaluate_gsm8k( port: int = 8000, temperature: float = 0.0, seed: int | None = 42, + model: str = "gpt-oss-120b", + api_key: str | None = None, ) -> dict[str, float | int]: """ Evaluate GSM8K accuracy using vLLM serve endpoint. @@ -200,6 +211,8 @@ async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]: stop=["Question", "Assistant:", "<|separator|>"], url=base_url, seed=seed, + model=model, + api_key=api_key, ) states[i] = answer output_tokens[i] = tokens @@ -281,6 +294,15 @@ def main() -> None: "--seed", type=int, default=42, help="Random seed for reproducibility" ) parser.add_argument("--save-results", type=str, help="Save results to JSON file") + parser.add_argument( + "--model", type=str, default="gpt-oss-120b", help="Model name to query" + ) + parser.add_argument( + "--api-key", + type=str, + default=os.environ.get("VLLM_API_KEY"), + help="API key for vLLM server (defaults to $VLLM_API_KEY)", + ) args = parser.parse_args() @@ -292,6 +314,8 @@ def main() -> None: port=args.port, temperature=args.temperature, seed=args.seed, + model=args.model, + api_key=args.api_key, ) # Print results to terminal diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 6c9ca07dba9a..929a8e002c84 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -770,7 +770,11 @@ def _ggml_moe_a8_vec_fake( # cutlass def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) + try: + return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) + except AttributeError: + logger.warning("CUTLASS FP4 ops not available - was vLLM built correctly?") + return False def cutlass_scaled_fp4_mm( @@ -789,11 +793,21 @@ def cutlass_scaled_fp4_mm( def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + try: + return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + except AttributeError: + logger.warning("CUTLASS FP8 ops not available - was vLLM built correctly?") + return False def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + try: + return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + except AttributeError: + logger.warning( + "CUTLASS block FP8 ops not available - was vLLM built correctly?" + ) + return False def cutlass_scaled_mm( @@ -876,6 +890,14 @@ def cutlass_scaled_mm_azp( return out.view(*target_shape) +def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: + try: + return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability) + except AttributeError: + logger.warning("CUTLASS sparse ops not available - was vLLM built correctly?") + return False + + def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: if cuda_device_capability < 90 or cuda_device_capability >= 110: return False diff --git a/vllm/compilation/passes/fusion/matcher_utils.py b/vllm/compilation/passes/fusion/matcher_utils.py index ec36c12d1776..4a4eba78362e 100644 --- a/vllm/compilation/passes/fusion/matcher_utils.py +++ b/vllm/compilation/passes/fusion/matcher_utils.py @@ -40,11 +40,13 @@ if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): QUANT_OPS[kNvfp4Dynamic] = torch.ops._C.scaled_fp4_quant.out # noqa: E501 -if current_platform.is_cuda(): +if current_platform.is_cuda() and hasattr(torch.ops._C, "per_token_group_fp8_quant"): QUANT_OPS[kFp8Dynamic128Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 QUANT_OPS[kFp8Dynamic64Sym] = torch.ops._C.per_token_group_fp8_quant.default # noqa: E501 -SILU_MUL_OP = torch.ops._C.silu_and_mul.default +SILU_MUL_OP = ( + torch.ops._C.silu_and_mul.default if hasattr(torch.ops._C, "silu_and_mul") else None +) class MatcherCustomOp(ABC): @@ -448,7 +450,7 @@ def inputs(self) -> list[torch.Tensor]: class MatcherSiluAndMul(MatcherCustomOp): def __init__(self, enabled: bool | None = None) -> None: if enabled is None: - enabled = SiluAndMul.enabled() + enabled = SiluAndMul.enabled() and SILU_MUL_OP is not None super().__init__(enabled) def inputs(self) -> list[torch.Tensor]: diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 493c26d3aed9..1ee2d4e5ecd7 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -12,6 +12,7 @@ import partial_json_parser import regex as re from fastapi import Request +from openai_harmony import HarmonyError from partial_json_parser.core.options import Allow from vllm.engine.protocol import EngineClient @@ -708,7 +709,15 @@ async def chat_completion_stream_generator( # Track accumulated content per token with their state token_states: list[TokenState] = [] for token_id in output.token_ids: - harmony_parser.process(token_id) + try: + harmony_parser.process(token_id) + except HarmonyError as e: + logger.warning( + "HarmonyError in stream generator, " + "returning partial result: %s", + e, + ) + break token_delta = harmony_parser.last_content_delta or "" token_states.append( TokenState( diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 52f03447dcaa..9cda424d8075 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -386,23 +386,49 @@ def _make_request_with_harmony( assert not self.supports_code_interpreter if (reasoning_effort := request.reasoning_effort) == "none": raise ValueError(f"Harmony does not support {reasoning_effort=}") + + # Extract client-provided system message content so it can be + # passed as structured instructions rather than appended as a raw + # system message (which the Harmony parser cannot handle). + non_system_messages = [] + system_instructions_parts: list[str] = [] + for msg in request.messages: + msg_dict = ( + msg if isinstance(msg, dict) else msg.model_dump(exclude_none=True) + ) + if msg_dict.get("role") == "system": + content = msg_dict.get("content") or "" + if isinstance(content, list): + content = "".join( + c.get("text", "") + for c in content + if isinstance(c, dict) and c.get("type") == "text" + ) + if content: + system_instructions_parts.append(content) + else: + non_system_messages.append(msg) + instructions = "\n".join(system_instructions_parts) or None + sys_msg = get_system_message( reasoning_effort=reasoning_effort, browser_description=None, python_description=None, with_custom_tools=should_include_tools, + instructions=instructions, ) messages.append(sys_msg) # Add developer message. - if request.tools: + if request.tools or instructions: dev_msg = get_developer_message( - tools=request.tools if should_include_tools else None # type: ignore[arg-type] + instructions=instructions, + tools=request.tools if should_include_tools else None, # type: ignore[arg-type] ) messages.append(dev_msg) - # Add user message. - messages.extend(parse_chat_inputs_to_harmony_messages(request.messages)) + # Add user message (system messages already extracted above). + messages.extend(parse_chat_inputs_to_harmony_messages(non_system_messages)) # Render prompt token ids. prompt_token_ids = render_for_completion(messages) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 8ce5432fed83..20b3ade86569 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -36,6 +36,131 @@ logger = init_logger(__name__) +# enum for mxfp4 backend +class Mxfp4Backend(Enum): + NONE = 0 + + # FlashInfer Backend + SM100_FI_MXFP4_MXFP8_TRTLLM = 1 + SM100_FI_MXFP4_MXFP8_CUTLASS = 2 + SM100_FI_MXFP4_BF16 = 3 + SM90_FI_MXFP4_BF16 = 4 + + # Marlin Backend + MARLIN = 5 + + # Triton Backend + TRITON = 6 + + CK = 7 + + +def get_mxfp4_backend_with_lora() -> Mxfp4Backend: + """ + Not all MXFP4 backends support LoRA. Select backends that are known to + have LoRA support. + """ + if not current_platform.is_cuda(): + return Mxfp4Backend.NONE + + # If FlashInfer is not available, try either Marlin or Triton + triton_kernels_supported = ( + has_triton_kernels() + # NOTE: triton_kernels are only confirmed to work on SM90, SM100, + # and SM120. + # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 + # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 + and ( + (9, 0) <= current_platform.get_device_capability() < (11, 0) + or (12, 0) <= current_platform.get_device_capability() <= (12, 1) + ) + ) + if envs.VLLM_MXFP4_USE_MARLIN is False and triton_kernels_supported: + logger.info_once("[get_mxfp4_backend_with_lora] Using Triton backend") + return Mxfp4Backend.TRITON + + logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend") + return Mxfp4Backend.MARLIN + + +def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: + # Backend Selection + + if with_lora_support: + return get_mxfp4_backend_with_lora() + + if current_platform.is_cuda(): + if ( + current_platform.is_device_capability(90) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 + ): + logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") + return Mxfp4Backend.SM90_FI_MXFP4_BF16 + elif ( + current_platform.is_device_capability_family(100) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS + ): + logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS + elif ( + current_platform.is_device_capability_family(100) + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 + ): + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100", scope="local" + ) + return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM + elif current_platform.is_device_capability_family(100) and has_flashinfer(): + logger.info_once( + "Using FlashInfer MXFP4 BF16 backend for SM100, " + "For faster performance on SM100, consider setting " + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " + "accuracy." + ) + return Mxfp4Backend.SM100_FI_MXFP4_BF16 + elif ( + current_platform.is_device_capability_family(100) + or current_platform.is_device_capability(90) + ) and not has_flashinfer(): + logger.warning_once( + "MXFP4 MoE is enabled on Hopper/Blackwell but FlashInfer " + "is not available. This may result in degraded performance. " + "Please `pip install vllm[flashinfer]` for best results." + ) + + # If FlashInfer is not available, try either Marlin or Triton + triton_kernels_supported = ( + has_triton_kernels() + # NOTE: triton_kernels are only confirmed to work on SM90 and SM100 + # SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317 + # SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498 + and (9, 0) <= current_platform.get_device_capability() < (11, 0) + ) + if envs.VLLM_MXFP4_USE_MARLIN or not triton_kernels_supported: + logger.info_once("Using Marlin backend") + return Mxfp4Backend.MARLIN + else: + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON + elif current_platform.is_xpu(): + logger.info_once("Using xpu backend on XPU") + return Mxfp4Backend.MARLIN + elif current_platform.is_rocm(): + from vllm.platforms.rocm import on_gfx950 + + if rocm_aiter_ops.is_enabled() and on_gfx950(): + logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)") + return Mxfp4Backend.CK + elif has_triton_kernels(): + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON + + return Mxfp4Backend.NONE + + class Mxfp4Config(QuantizationConfig): def __init__(self, ignored_layers: list[str] | None = None): super().__init__() diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index b309bf14d991..674d75447cc9 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -15,7 +15,10 @@ def cutlass_fp8_supported() -> bool: capability_tuple = current_platform.get_device_capability() capability = -1 if capability_tuple is None else capability_tuple.to_int() - return ops.cutlass_scaled_mm_supports_fp8(capability) + try: + return ops.cutlass_scaled_mm_supports_fp8(capability) + except (AttributeError, RuntimeError): + return False def cutlass_block_fp8_supported() -> bool: @@ -25,7 +28,10 @@ def cutlass_block_fp8_supported() -> bool: capability_tuple = current_platform.get_device_capability() capability = -1 if capability_tuple is None else capability_tuple.to_int() - return ops.cutlass_scaled_mm_supports_block_fp8(capability) + try: + return ops.cutlass_scaled_mm_supports_block_fp8(capability) + except (AttributeError, RuntimeError): + return False def cutlass_group_gemm_supported() -> bool: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0fa59579ee76..5b43bff72463 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -503,6 +503,12 @@ def step_with_batch_queue( exec_model_fut.result() raise RuntimeError("unexpected error") + # Handle None return from execute_model (async scheduling). + # This mirrors the same check in step() method. + if model_output is None: + grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output) + model_output = self.model_executor.sample_tokens(grammar_output) + # Before processing the model output, process any aborts that happened # during the model execution. self._process_aborts_queue()