diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a88544c1c0f9..a09666b65a99 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -357,8 +357,11 @@ def forward( if self.use_output: if output_shape is None: + # Handle both 2D [num_tokens, hidden] and + # 3D [num_tokens, heads, head_dim] query + num_tokens = query.shape[0] output_shape = torch.Size( - (*query.shape[:-1], self.num_heads * self.head_size_v) + (num_tokens, self.num_heads * self.head_size_v) ) output_shape = output_shape if output_shape is not None else query.shape output = torch.empty(output_shape, dtype=output_dtype, device=query.device) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 08e1f4d444ee..3dcd9a84a139 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -180,7 +180,19 @@ def get_fp8_moe_backend( scope="local", ) - if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant: + # Determine if we should use DeepGEMM (top-level enable switch) + # - If explicitly set by user, respect their choice + # - If the platform does not support DeepGEMM, disable it + # This helps avoid warning messages on unsupported platforms. + use_deep_gemm = envs.VLLM_USE_DEEP_GEMM + if not is_deep_gemm_supported(): + use_deep_gemm = False + logger.info_once( + "DeepGEMM is disabled because the platform does not support it.", scope="local", ) + + if use_deep_gemm and moe_use_deep_gemm and block_quant: if not has_deep_gemm(): logger.warning_once( "DeepGEMM backend requested but not available.", scope="local"