diff --git a/vllm/envs.py b/vllm/envs.py index d29e367bcae8..b3da89cd1760 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -51,9 +51,10 @@ VLLM_CPU_OMP_THREADS_BIND: str = "auto" VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_SGL_KERNEL: bool = False - VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False + VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024 + VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True @@ -97,7 +98,6 @@ VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] - VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE: bool = True VLLM_DISABLE_PYNCCL: bool = False VLLM_USE_OINK_OPS: bool = False VLLM_ROCM_USE_AITER: bool = False @@ -117,6 +117,8 @@ VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True + VLLM_USE_AITER_FUSED: bool = True + VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 @@ -169,7 +171,7 @@ VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( "latency" ) - VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto" + VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm" VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024 VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 @@ -246,8 +248,6 @@ VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False - VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32 - VLLM_XPU_ENABLE_XPU_GRAPH: bool = False def 
get_default_cache_root(): @@ -295,16 +295,6 @@ def use_aot_compile() -> bool: ) -def use_mega_aot_artifact(): - from vllm.utils.torch_utils import is_torch_equal_or_newer - - default_value = ( - "1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0" - ) - - return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1" - - def env_with_choices( env_name: str, default: str | None, @@ -628,7 +618,10 @@ def _get_or_set_default() -> str: # Enable loading compiled models directly from cached standalone compile artifacts # without re-splitting graph modules. This reduces overhead during model # loading by using reconstruct_serializable_fn_from_mega_artifact. - "VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact, + "VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get( + "VLLM_USE_MEGA_AOT_ARTIFACT", "0" + ) + == "1", # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), @@ -719,11 +712,6 @@ def _get_or_set_default() -> str: else None, # (CPU backend only) whether to use SGL kernels, optimized for small batch. "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), - # (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout - # at model load time. Eliminates per-inference layout conversion overhead. - "VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool( - int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1")) - ), # If the env var is set, Ray Compiled Graph uses the specified # channel type to communicate between workers belonging to # different pipeline-parallel stages. @@ -841,6 +829,15 @@ def _get_or_set_default() -> str: ), # Enable SPMD mode for TPU backend. "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), + "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024)) + ), + # Control whether to use fused MoE activation chunking. 
Current chunking + # logic is incompatible with torch.compile and causes IMA. See issue + # https://github.com/vllm-project/vllm/issues/19631. + "VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": lambda: bool( + int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1")) + ), # If set, the OpenAI API server will stay alive even after the underlying # AsyncLLMEngine errors and stops serving requests "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool( @@ -910,9 +907,6 @@ def _get_or_set_default() -> str: "VLLM_DISABLED_KERNELS": lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ["VLLM_DISABLED_KERNELS"].split(","), - "VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool( - int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1")) - ), # Disable pynccl (using torch.distributed instead) "VLLM_DISABLE_PYNCCL": lambda: ( os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") @@ -993,6 +987,14 @@ def _get_or_set_default() -> str: "VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1") ), + # Enable AITER fused kernels for MLA (ROCm only). Covers both the + # prefill path and the decode path. + # Fuses: RoPE + concat + KV cache write (prefill) or BMM + RoPE + + # concat + KV cache write (decode) in ONE kernel + # Enabled by default for AMD GPUs with FP8 support. 
+ "VLLM_USE_AITER_FUSED": lambda: ( + os.getenv("VLLM_USE_AITER_FUSED", "True").lower() in ("true", "1") + ), # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1") @@ -1001,6 +1003,10 @@ def _get_or_set_default() -> str: "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), # Pad the weights for the moe kernel "VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), + # custom paged attention kernel for MI3* cards + "VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: ( + os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1") + ), # Whether to use the shuffled kv cache layout "VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: ( os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1") @@ -1305,9 +1311,14 @@ def _get_or_set_default() -> str: ["throughput", "latency", "masked_gemm"], ), # Flashinfer fused allreduce backend. + # "auto" resolves to "mnnvl", which performs about the same as or better than + # "trtllm", but the "mnnvl" backend does not support fusion with quantization. + # TODO: Default is "trtllm" right now because "mnnvl" has issues with cudagraph: + # https://github.com/vllm-project/vllm/issues/35772 + # Should switch back to "auto" once the issue is resolved. "VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices( "VLLM_FLASHINFER_ALLREDUCE_BACKEND", - "auto", + "trtllm", ["auto", "trtllm", "mnnvl"], ), # Control the workspace buffer size for the FlashInfer backend. 
@@ -1640,14 +1651,6 @@ def _get_or_set_default() -> str: "VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool( int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0")) ), - # NIXL EP environment variables - "VLLM_NIXL_EP_MAX_NUM_RANKS": lambda: int( - os.getenv("VLLM_NIXL_EP_MAX_NUM_RANKS", "32") - ), - # Whether enable XPU graph on Intel GPU - "VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool( - int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0")) - ), } @@ -1784,7 +1787,6 @@ def compile_factors() -> dict[str, object]: "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "VLLM_CPU_KVCACHE_SPACE", "VLLM_CPU_MOE_PREPACK", - "VLLM_ZENTORCH_WEIGHT_PREPACK", "VLLM_TEST_FORCE_LOAD_FORMAT", "VLLM_ENABLE_CUDA_COMPATIBILITY", "VLLM_CUDA_COMPATIBILITY_PATH", diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 0215ec1a0735..5b6cb72ca2fc 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -302,6 +302,11 @@ def __init__( prefix: str = "", use_sparse: bool = False, indexer: object | None = None, + # AITER fused kernel parameters + cos_cache: torch.Tensor | None = None, + sin_cache: torch.Tensor | None = None, + is_neox_style: bool = False, + rotary_emb: torch.nn.Module | None = None, **extra_impl_args, ): super().__init__() @@ -314,6 +319,7 @@ def __init__( self.kv_lora_rank = kv_lora_rank self.kv_b_proj = kv_b_proj self.head_size = kv_lora_rank + qk_rope_head_dim + self.rotary_emb = rotary_emb self.layer_name = prefix self.indexer = indexer @@ -440,6 +446,78 @@ def __init__( and self.kv_b_proj.weight.dtype == torch.bfloat16 ) + # Store RoPE caches for AITER fused kernels + self.cos_cache = cos_cache + self.sin_cache = sin_cache + self.is_neox_style = is_neox_style + + # Check if AITER fused kernels can be used + self.use_aiter_fused = ( + current_platform.is_rocm() + and ( + self.is_aiter_triton_fp4_bmm_enabled + or 
self.is_aiter_triton_fp8_bmm_enabled + ) + and envs.VLLM_USE_AITER_FUSED + and cos_cache is not None + and sin_cache is not None + ) + + if self.use_aiter_fused: + # Import prefill kernel (shared between FP4 and FP8) + try: + from aiter.ops.triton.fusions.fused_kv_cache import ( + fused_qk_rope_cat_and_cache_mla, + ) + + self._fused_prefill_kernel = fused_qk_rope_cat_and_cache_mla + except ImportError as e: + logger.warning_once( + f"AITER fused prefill kernel not available: {e}, " + "falling back to separate ops", + scope="local", + ) + self.use_aiter_fused = False + + # Import FP4 or FP8 decode kernel + if self.use_aiter_fused and self.is_aiter_triton_fp4_bmm_enabled: + try: + from aiter.ops.triton.fusions.fused_bmm_rope_kv_cache import ( + fused_fp4_bmm_rope_cat_and_cache_mla, + ) + + self._fused_decode_kernel = fused_fp4_bmm_rope_cat_and_cache_mla + self._fused_kernel_type = "fp4" + except ImportError as e: + logger.warning_once( + f"AITER fused FP4 decode kernel not available: {e}, " + "falling back to separate ops", + scope="local", + ) + self.use_aiter_fused = False + elif self.use_aiter_fused: + try: + from aiter.ops.triton.fusions.fused_bmm_rope_kv_cache import ( + fused_fp8_bmm_rope_cat_and_cache_mla, + ) + + self._fused_decode_kernel = fused_fp8_bmm_rope_cat_and_cache_mla + self._fused_kernel_type = "fp8" + except ImportError as e: + logger.warning_once( + f"AITER fused FP8 decode kernel not available: {e}, " + "falling back to separate ops", + scope="local", + ) + self.use_aiter_fused = False + + if self.use_aiter_fused: + logger.info( + "AITER fused MLA kernels ENABLED: %s variant " + "(decode: BMM+RoPE+KV, prefill: RoPE+KV)", + self._fused_kernel_type.upper(), + ) + # Attributes for forward_impl method self._vllm_config = get_current_vllm_config() self._chunked_prefill_workspace_size: int | None = None @@ -465,12 +543,21 @@ def forward( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output_shape: torch.Size | None = None, + positions: torch.Tensor | 
None = None, + slot_mapping: torch.Tensor | None = None, + use_fused_path: bool = False, + rotary_emb: torch.nn.Module | None = None, ) -> torch.Tensor: if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) + # Store context for custom ops + forward_context: ForwardContext = get_forward_context() + if positions is not None: + forward_context._positions = positions + forward_context._use_fused_path = use_fused_path + if self.use_direct_call: - forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] @@ -504,6 +591,11 @@ def forward( q, kv_c_normed, k_pe, self_kv_cache, attn_metadata ) else: + # Custom ops path (ROCm) + if slot_mapping is not None: + forward_context.slot_mapping[self.layer_name] = slot_mapping + + # KV cache update (no-op for fused path, writes for unfused path) kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( kv_c_normed, k_pe, @@ -519,6 +611,8 @@ def forward( k_pe, output, self.layer_name, + positions, + slot_mapping, kv_cache_dummy_dep=kv_cache_dummy_dep, ) return output @@ -528,6 +622,8 @@ def forward( kv_c_normed, k_pe, self.layer_name, + positions, + slot_mapping, kv_cache_dummy_dep=kv_cache_dummy_dep, ) @@ -541,9 +637,18 @@ def forward_impl( output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + # AITER fused kernel parameters + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, + use_fused_path: bool | None = None, + rotary_emb: torch.nn.Module | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." 
+ # Default to instance variable if not provided + if use_fused_path is None: + use_fused_path = self.use_aiter_fused + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported for MLA" @@ -600,11 +705,71 @@ def forward_impl( num_mqa_tokens = attn_metadata.num_decode_tokens num_mha_tokens = q.size(0) - num_mqa_tokens + # Trim positions tensor to match actual batch size if needed + if positions is not None: + num_actual_tokens = q.size(0) + if positions.size(0) > num_actual_tokens: + positions = positions[:num_actual_tokens] + + # Retrieve slot_mapping from attn_metadata if not provided + if slot_mapping is None and attn_metadata is not None: + slot_mapping = attn_metadata.slot_mapping + if num_mha_tokens > 0: + # Prefill path: extract prefill slices + prefill_q = q[num_mqa_tokens:] + prefill_k_c_normed = k_c_normed[num_mqa_tokens:] + prefill_k_pe = k_pe[num_mqa_tokens:] + + if ( + self.use_aiter_fused + and rotary_emb is not None + and positions is not None + and slot_mapping is not None + and prefill_q.shape[0] > 0 + ): + # AITER fused prefill: RoPE + KV cache write in single kernel + prefill_positions = positions[num_mqa_tokens:] + prefill_slot_mapping = slot_mapping[num_mqa_tokens:] + + # Split Q into nope and pe components + prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim] + prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] + + # Reshape K to [batch, num_kv_heads, head_dim] + prefill_k_nope_3d = prefill_k_c_normed.view( + -1, self.num_kv_heads, self.kv_lora_rank + ) + prefill_k_pe_3d = prefill_k_pe.squeeze(1).view( + -1, self.num_kv_heads, self.qk_rope_head_dim + ) + + # AITER fused kernel applies RoPE and writes to KV cache + q_fused, _, k_pe_out, _ = self._fused_prefill_kernel( + q_nope=prefill_q_nope, + q_pe=prefill_q_pe, + k_nope=prefill_k_nope_3d, + k_pe=prefill_k_pe_3d, + kv_cache=kv_cache, + slot_mapping=prefill_slot_mapping, + pos=prefill_positions, + 
cos=self.cos_cache, + sin=self.sin_cache, + k_scale=self._k_scale, + is_neox=self.is_neox_style, + num_decode_toks_for_zeros=0, + apply_scale=True, + q_out_dtype=prefill_q.dtype, + ) + + prefill_q[:] = q_fused + prefill_k_pe[:] = k_pe_out + + # Run prefill attention self.impl.forward_mha( - q[num_mqa_tokens:], - k_c_normed[num_mqa_tokens:], - k_pe[num_mqa_tokens:], + prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, attn_metadata, self._k_scale, @@ -612,9 +777,18 @@ def forward_impl( ) if num_mqa_tokens > 0: + # Decode path: extract decode slices mqa_q = q[:num_mqa_tokens] mqa_output_slice = output[:num_mqa_tokens] + # Extract additional slices for AITER fused decode + if self.use_aiter_fused and slot_mapping is not None: + mqa_k_c_normed = k_c_normed[:num_mqa_tokens] + mqa_k_pe = k_pe[:num_mqa_tokens] + if positions is not None: + mqa_positions = positions[:num_mqa_tokens] + mqa_slot_mapping = slot_mapping[:num_mqa_tokens] + mqa_q_nope, mqa_q_pe = mqa_q.split( [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) @@ -629,7 +803,42 @@ def forward_impl( mqa_pe_padded.copy_(mqa_q_pe) mqa_q_pe = mqa_pe_padded - if self.is_aiter_triton_fp4_bmm_enabled: + # Compute positions from seq_lens if not provided + # For decode: position = seq_lens - 1 (0-indexed current position) + if positions is None and attn_metadata.decode is not None: + decode_seq_lens = attn_metadata.decode.seq_lens + positions = decode_seq_lens - 1 + logger.info_once( + "[MLA] Computed positions from decode seq_lens: shape=%s", + positions.shape, + scope="local", + ) + + if self.use_aiter_fused and slot_mapping is not None: + # AITER fused path: RoPE + KV cache write in kernel + assert positions is not None + assert slot_mapping is not None + + logger.info_once( + "Using AITER fused %s decode kernel for MLA", + self._fused_kernel_type.upper(), + scope="local", + ) + + # Fused kernel applies RoPE and writes to KV cache + mqa_ql_nope, mqa_q_pe_rotated = self._run_aiter_fused_decode( + 
mqa_q_nope, # [num_heads, batch, qk_nope_head_dim] + mqa_q_pe, # [batch, num_heads, qk_rope_head_dim] + mqa_k_c_normed, # [batch, kv_lora_rank] + mqa_k_pe, # [batch, 1, qk_rope_head_dim] + kv_cache, + mqa_slot_mapping, + mqa_positions, + ) + mqa_q_pe = mqa_q_pe_rotated + + elif self.is_aiter_triton_fp4_bmm_enabled: + # Unfused FP4 path: RoPE already applied in mla.py from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 mqa_ql_nope = batched_gemm_a16wfp4( @@ -705,12 +914,96 @@ def forward_impl( # v_up projection self._v_up_proj(attn_out, out=mqa_output_slice) + return output_padded + def _run_aiter_fused_decode( + self, + mqa_q_nope: torch.Tensor, + mqa_q_pe: torch.Tensor, + k_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + positions: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Run AITER fused decode operation. + + Fuses: FP4/FP8 BMM + RoPE + concat + KV cache write in ONE kernel. + Chooses FP4 or FP8 variant based on GPU capabilities. + + Args: + mqa_q_nope: [num_heads, batch, qk_nope_head_dim] + mqa_q_pe: [batch, num_heads, qk_rope_head_dim] NO RoPE! + k_c_normed: [batch, kv_lora_rank] + k_pe: [batch, 1, qk_rope_head_dim] NO RoPE! 
+ kv_cache: KV cache tensor + slot_mapping: Slot mapping for cache write + positions: Position IDs for RoPE + + Returns: + mqa_ql_nope: [batch, num_heads, kv_lora_rank] + mqa_q_pe_rotated: [batch, num_heads, qk_rope_head_dim] + """ + # Reshape K to [batch, num_kv_heads, head_dim] + k_nope_3d = k_c_normed.view(-1, self.num_kv_heads, self.kv_lora_rank) + k_rope_3d = k_pe.squeeze(1).view(-1, self.num_kv_heads, self.qk_rope_head_dim) + + # Call FP4 or FP8 fused kernel + if self._fused_kernel_type == "fp4": + q_fused, _, _, _ = self._fused_decode_kernel( + mqa_q_nope, + self.W_K, + self.W_K_scale, + mqa_q_pe, + k_nope_3d, + k_rope_3d, + kv_cache, + slot_mapping, + positions, + self.cos_cache, + self.sin_cache, + y=None, + transpose_bm=True, + prequant=True, + y_scale=None, + k_scale=self._k_scale, + is_neox=self.is_neox_style, + q_out_dtype=mqa_q_nope.dtype, + num_decode_toks_for_zeros=0, + ) + else: # fp8 + q_fused, _, _, _ = self._fused_decode_kernel( + mqa_q_nope, + self.W_K, + self.W_K_scale, + mqa_q_pe, + k_nope_3d, + k_rope_3d, + kv_cache, + slot_mapping, + positions, + self.cos_cache, + self.sin_cache, + group_size=128, + transpose_bm=True, + k_scale=self._k_scale, + is_neox=self.is_neox_style, + q_out_dtype=mqa_q_nope.dtype, + num_decode_toks_for_zeros=0, + ) + + # Split fused output into nope and rope components + mqa_ql_nope = q_fused[..., : self.kv_lora_rank] + mqa_q_pe_rotated = q_fused[..., self.kv_lora_rank :] + + return mqa_ql_nope, mqa_q_pe_rotated + def process_weights_after_loading(self, act_dtype: torch.dtype): # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and + # perform the bmm's in 16-bit, the extra memory overhead of + # this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights( self.kv_b_proj, out_dtype=act_dtype ).T @@ 
-886,14 +1179,44 @@ def unified_mla_attention( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: # kv_cache_dummy_dep is not used but accepting it creates a data dependency # that ensures torch.compile preserves ordering between KV cache update and # attention forward. del kv_cache_dummy_dep - attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) - output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata) + attn_metadata, layer, kv_cache, forward_context = get_attention_context(layer_name) + + # Get rotary_emb from layer (stored during __init__) + rotary_emb = layer.rotary_emb + + # Retrieve slot_mapping from forward_context or attn_metadata + slot_mapping = None + if hasattr(forward_context, "slot_mapping") and isinstance( + forward_context.slot_mapping, dict + ): + slot_mapping = forward_context.slot_mapping.get(layer_name) + + # Fallback to attn_metadata if not in forward_context + if slot_mapping is None and attn_metadata is not None: + slot_mapping = attn_metadata.slot_mapping + + # Use AITER fused path if available (static decision based on layer config) + use_fused_path = layer.use_aiter_fused + + output = layer.forward_impl( + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + positions=positions, + slot_mapping=slot_mapping, + use_fused_path=use_fused_path, + rotary_emb=rotary_emb, + ) return output @@ -903,6 +1226,8 @@ def unified_mla_attention_fake( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(q).contiguous() @@ -925,8 +1250,9 @@ def unified_mla_kv_cache_update( k_scale: torch.Tensor, ) -> torch.Tensor: """ - Returns a dummy that is passed to unified_attention 
to signal a side effect and - the data dependency between them to ensure torch.compile preserves ordering. + Writes KV cache for unfused path. For AITER fused path, returns early + (KV writes handled in forward_impl). + Returns a dummy tensor to signal side effect for torch.compile ordering. """ forward_context = get_forward_context() if forward_context.attn_metadata is None: @@ -934,6 +1260,15 @@ def unified_mla_kv_cache_update( return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) attn_layer = forward_context.no_compile_layers[layer_name] + + # Use static flag (CUDA graph compatible) + if attn_layer.use_aiter_fused: + # AITER fused path: KV cache writes handled in forward_impl + # - Prefill: written after RoPE application + # - Decode: written in fused decode kernel + return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) + + # Unfused path: write KV cache here after mla.py applies RoPE kv_cache = attn_layer.kv_cache slot_mapping = forward_context.slot_mapping @@ -941,7 +1276,8 @@ def unified_mla_kv_cache_update( f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " ) layer_slot_mapping = slot_mapping.get(layer_name) - if layer_slot_mapping is not None: + + if layer_slot_mapping is not None and kv_c_normed.shape[0] > 0: attn_layer.impl.do_kv_cache_update( kv_c_normed, k_pe, @@ -978,6 +1314,8 @@ def unified_mla_attention_with_output( k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, @@ -986,7 +1324,22 @@ def unified_mla_attention_with_output( # that ensures torch.compile preserves ordering between KV cache update and # attention forward. 
del kv_cache_dummy_dep - attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) + attn_metadata, layer, kv_cache, forward_context = get_attention_context(layer_name) + + # Retrieve slot_mapping from attn_metadata if not provided as parameter + if ( + slot_mapping is None + and attn_metadata is not None + and hasattr(attn_metadata, "slot_mapping") + ): + slot_mapping = attn_metadata.slot_mapping + + # Get rotary_emb from layer (stored during __init__) + rotary_emb = layer.rotary_emb + + # Use AITER fused path if available (static decision based on layer config) + use_fused_path = layer.use_aiter_fused + layer.forward_impl( q, kv_c_normed, @@ -996,6 +1349,10 @@ def unified_mla_attention_with_output( output=output, output_scale=output_scale, output_block_scale=output_block_scale, + positions=positions, + slot_mapping=slot_mapping, + use_fused_path=use_fused_path, + rotary_emb=rotary_emb, ) @@ -1005,6 +1362,8 @@ def unified_mla_attention_with_output_fake( k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index e8ed8a5249d1..b3acc89712cb 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -106,7 +106,7 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 def forward( - self, x: torch.Tensor + self, x: torch.Tensor, x_scale: torch.Tensor | None = None ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: # Tier 1: DSV3 specialized kernel if self.allow_dsv3_router_gemm and x.shape[0] <= 16: diff --git 
a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 44fd516f5e5c..d57cf266af5b 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -27,6 +27,7 @@ ) from vllm.model_executor.layers.utils import ( dispatch_unquantized_gemm, + is_layer_moe_router_gate, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -222,8 +223,16 @@ def apply( layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, + input_scale: torch.Tensor | None = None, ) -> torch.Tensor: - if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike(): + assert input_scale is None, ( + "UnquantizedLinearMethod does not support input_scale" + ) + if ( + envs.VLLM_BATCH_INVARIANT + and current_platform.is_cuda_alike() + and is_layer_moe_router_gate(getattr(layer, "prefix", "")) + ): return linear_batch_invariant(x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) @@ -384,11 +393,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def forward( self, x: torch.Tensor, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None - output = self.quant_method.apply(self, x, bias) + output = self.quant_method.apply(self, x, bias, input_scale=x_scale) if not self.return_bias: return output @@ -574,12 +584,15 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor def forward( self, input_, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None # Matrix multiply. 
assert self.quant_method is not None - output_parallel = self.quant_method.apply(self, input_, bias) + output_parallel = self.quant_method.apply( + self, input_, bias, input_scale=x_scale + ) if self.gather_output and self.tp_size > 1: # All-gather across the partitions. @@ -1512,6 +1525,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor def forward( self, input_, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: if self.input_is_parallel: input_parallel = input_ @@ -1523,10 +1537,12 @@ def forward( # Matrix multiply. assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) + # Only fuse bias add into GEMM for rank 0 (ensures bias not + # added multiple times in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, input_parallel, bias_) + output_parallel = self.quant_method.apply( + self, input_parallel, bias_, input_scale=x_scale + ) if self.reduce_results and self.tp_size > 1: output = tensor_model_parallel_all_reduce(output_parallel) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 1d3e987b7e17..e6541e65eb83 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -5,10 +5,93 @@ import torch from vllm.config import CacheConfig +from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention import MLAAttention from vllm.model_executor.layers.quantization import QuantizationConfig +logger = init_logger(__name__) + +# Import AITER ops for fused RMSNorm + FP8 quantization +try: + from aiter import dtypes + from aiter.jit.utils.torch_guard import torch_compile_guard + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + _AITER_AVAILABLE = True 
+except ImportError: + _AITER_AVAILABLE = False + dtypes = None + torch_compile_guard = None + fused_rms_fp8_group_quant = None + + +def _fused_rms_fp8_group_quant_fake( + q_c: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_c: torch.Tensor, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + dtype_quant: torch.dtype | None = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fake implementation for torch.compile/CUDA graphs.""" + if dtype_quant is None: + dtype_quant = dtypes.fp8 + m, n1 = q_c.shape + out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=q_c.device) + out1_bs = torch.empty( + (m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=q_c.device + ) + if transpose_scale: + out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) + out2 = torch.empty_like(kv_c) + return out1_quantized, out1_bs, out2 + + +def _fuse_rmsnorm_quant_impl( + q_c: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_c: torch.Tensor, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + dtype_quant: torch.dtype | None = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused dual RMSNorm + FP8 quantization using AITER. + + Fuses RMSNorm on q_c with FP8 group quantization, and RMSNorm on kv_c + without quantization. 
+ + Returns: + (q_c_quantized, q_c_scale, kv_c_normed) + """ + (q_c_quantized, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant( + q_c, + q_a_layernorm_weight, + q_a_layernorm_variance_epsilon, + kv_c, + kv_a_layernorm_weight, + kv_a_layernorm_variance_epsilon, + group_size, + dtype_quant, + None, + output_unquantized_inp1, + transpose_scale, + ) + return q_c_quantized, q_c_scale, kv_c_normed + + +# Make fusion transparent to compiler (no @torch_compile_guard) +# This allows the compiler to trace through and batch operations efficiently +_fuse_rmsnorm_quant = _fuse_rmsnorm_quant_impl + @dataclass class MLAModules: @@ -87,6 +170,21 @@ def __init__( self.indexer_rope_emb = mla_modules.indexer_rotary_emb self.is_sparse = mla_modules.is_sparse + # Extract RoPE caches for AITER fused kernels + if self.rotary_emb is not None: + # RoPE stores combined cos_sin_cache, need to split it + # Format: [seq_len, rotary_dim] where first half is cos, second half is sin + cos_sin_cache = self.rotary_emb.cos_sin_cache + rotary_dim = self.rotary_emb.rotary_dim + half_dim = rotary_dim // 2 + self.cos_cache = cos_sin_cache[:, :half_dim] + self.sin_cache = cos_sin_cache[:, half_dim:] + self.is_neox_style = self.rotary_emb.is_neox_style + else: + self.cos_cache = None + self.sin_cache = None + self.is_neox_style = False + if self.indexer is not None: assert hasattr(self.indexer, "topk_tokens") self.topk_tokens = self.indexer.topk_tokens @@ -106,10 +204,36 @@ def __init__( kv_b_proj=self.kv_b_proj, use_sparse=self.is_sparse, indexer=self.indexer, + # Pass RoPE caches for AITER fused kernels + cos_cache=self.cos_cache, + sin_cache=self.sin_cache, + is_neox_style=self.is_neox_style, + # Pass RoPE module (static, doesn't change) + rotary_emb=self.rotary_emb, ) self.prefix = prefix + # Enable RMSNorm+Quant fusion when AITER is available with FP8 + self.quant_config = quant_config + self.quant_dtype = None + self.fuse_qknorm_quant = False + + if _AITER_AVAILABLE and quant_config is not 
None: + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + if isinstance(quant_config, Fp8Config): + self.quant_dtype = dtypes.fp8 + self.fuse_qknorm_quant = True + logger.info( + "[MLA_FUSION_INIT] Fusion enabled for %s: " + "AITER available and FP8 quantization detected", + prefix, + ) + + # VERIFICATION: Confirm all_mla_fused_mixed_batch branch is active + logger.warning("MLA.PY ALL_MLA_FUSED_MIXED_BATCH BRANCH ACTIVE - 2026-03-20") + def forward( self, positions: torch.Tensor, @@ -118,6 +242,7 @@ def forward( ) -> torch.Tensor: q_c = None kv_lora = None + q_c_scale = None # Set when fuse_qknorm_quant is enabled if self.q_lora_rank is not None: assert self.fused_qkv_a_proj is not None, ( @@ -130,13 +255,37 @@ def forward( "q_b_proj is required when q_lora_rank is not None" ) + # Step 1: QKV projection (use existing layer) qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] q_c, kv_lora = qkv_lora.split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1, ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] + kv_c, k_pe = kv_lora.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + # Step 2: Apply RMSNorm and optional FP8 quantization + if self.fuse_qknorm_quant: + # Fused RMSNorm + FP8 quantization + q_c_quantized, q_c_scale, kv_c_normed = _fuse_rmsnorm_quant( + q_c, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + kv_c, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + dtype_quant=self.quant_dtype, + group_size=128, + output_unquantized_inp1=False, + transpose_scale=True, + ) + q = self.q_b_proj(q_c_quantized, x_scale=q_c_scale)[0] + else: + # Unfused path: RMSNorm only + q_c = self.q_a_layernorm(q_c) + kv_c_normed = self.kv_a_layernorm(kv_c) + q = self.q_b_proj(q_c)[0] else: assert self.kv_a_proj_with_mqa is not None, ( "kv_a_proj_with_mqa is required when q_lora_rank is None" @@ -146,18 +295,67 @@ def forward( ) kv_lora = 
self.kv_a_proj_with_mqa(hidden_states)[0] q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) + kv_c, k_pe = kv_lora.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + kv_c_normed = self.kv_a_layernorm(kv_c) q = q.view(-1, self.num_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe k_pe = k_pe.unsqueeze(1) + # VERIFY: Log mla.py outputs before RoPE (EAGER MODE ONLY) + # COMMENTED OUT: Breaks torch compile / CUDA graph capture + # from vllm.logger import init_logger + # logger = init_logger(__name__) + # logger.warning( + # f"[VERIFY MLA] BEFORE RoPE: " + # f"q: abs_max={q.float().abs().max().item():.6e}, " + # f"first_3={q[0,0,:3].tolist()}, " + # f"k_pe: abs_max={k_pe.float().abs().max().item():.6e}, " + # f"first_3={k_pe[0,0,:3].tolist()}, " + # f"kv_c_normed: abs_max=" + # f"{kv_c_normed.float().abs().max().item():.6e}, " + # f"first_3={kv_c_normed[0,:3].tolist()}" + # ) + + # STEP 3: Determine if fused path can be used (SINGLE CHECK) + # Check all requirements once and use everywhere + can_use_fused_path = ( + hasattr(self.mla_attn, "use_aiter_fused") + and self.mla_attn.use_aiter_fused # Platform supports fused kernel + and positions is not None # Required for RoPE + and self.rotary_emb is not None # RoPE module available + ) + + # Apply RoPE based on fused vs unfused path if self.rotary_emb is not None: - q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( - positions, q[..., self.qk_nope_head_dim :], k_pe - ) + if can_use_fused_path: + # FUSED PATH: Skip RoPE here, custom op will apply it + # Problem: num_decode_tokens retrieved from forward_context gets + # frozen as a constant when CUDA graph is captured, causing RoPE + # to be applied to wrong tokens (e.g., q[512:] instead of q[1:]) + # Solution: Move RoPE to custom op (splitting op, not compiled) + # where attn_metadata.num_decode_tokens is available dynamically + pass + 
else: + # UNFUSED PATH: Apply RoPE to ALL tokens + q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( + positions, + q[..., self.qk_nope_head_dim :], # Q PE part gets RoPE + k_pe, # K PE gets RoPE + ) + + # Log AFTER RoPE + # logger.warning( + # f"[UNFUSED AFTER ROPE] " + # f"q_pe_abs_max=" + # f"{q[..., self.qk_nope_head_dim:].float().abs().max()" + # f".item():.6e}, " + # f"k_pe_abs_max={k_pe.float().abs().max().item():.6e}, " + # f"k_pe_after_first3={k_pe[0, 0, :3].tolist()}" + # ) if self.indexer and self.is_sparse: _topk_indices = self.indexer( @@ -167,11 +365,27 @@ def forward( if llama_4_scaling is not None: q *= llama_4_scaling + # STEP 4: Store rotary_emb in forward_context for custom ops + # positions is now passed as a parameter to custom ops (no longer + # stored in context). rotary_emb is still stored in context (not + # needed in compiled graph) + from vllm.forward_context import get_forward_context + + forward_context = get_forward_context() + if self.rotary_emb is not None: + forward_context._rotary_emb = self.rotary_emb + + # STEP 5: Pass to mla_attention attn_out = self.mla_attn( - q, + q, # Has RoPE if unfused, NO RoPE if fused kv_c_normed, - k_pe, + k_pe, # Has RoPE if unfused, NO RoPE if fused output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), + positions=positions, + slot_mapping=None, # Retrieved from attn_metadata in mla_attention.py + use_fused_path=can_use_fused_path, # Single flag for entire forward pass + rotary_emb=self.rotary_emb, ) - return self.o_proj(attn_out)[0] + final_out = self.o_proj(attn_out)[0] + return final_out diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 69255a2793cb..72019f5b47f8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -441,6 +441,7 @@ def apply( layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, + input_scale: torch.Tensor | 
None = None, ) -> torch.Tensor: # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. @@ -488,7 +489,7 @@ def apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, + input_scale=input_scale, bias=bias, ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9568d1320bc6..5df5fecf4205 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -400,7 +400,6 @@ def apply( input_scale: torch.Tensor | None = None, bias: torch.Tensor | None = None, ) -> torch.Tensor: - assert input_scale is None # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] @@ -411,20 +410,29 @@ def apply( ) and should_use_deepgemm_for_fp8_linear( output_dtype, weight, self.is_deep_gemm_supported ): + # FlashInfer: does not support pre-quantized input + assert input_scale is None, ( + "FlashInfer FP8 blockscale GEMM does not support pre-quantized input" + ) output = self._run_flashinfer(input_2d, weight, weight_scale) elif should_use_deepgemm_for_fp8_linear( output_dtype, weight, self.is_deep_gemm_supported ): + # DeepGEMM: does not support pre-quantized input + assert input_scale is None, ( + "DeepGEMM FP8 linear does not support pre-quantized input" + ) output = self._run_deepgemm(input_2d, weight, weight_scale) else: + # AITER/Triton/Cutlass: supports pre-quantized input output = self.w8a8_blockscale_op( input_2d, weight, weight_scale, input_scale ) if bias is not None: output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + return output.view(*output_shape) def _run_deepgemm( self, @@ -451,9 +459,15 @@ def _run_cutlass( weight_scale: torch.Tensor, input_scale: torch.Tensor | None = None, ) -> 
torch.Tensor: - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) + if input_scale is None: + # Quantize input if not already quantized + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + output_dtype = input_2d.dtype + else: + # Use pre-quantized FP8 input directly + q_input = input_2d + output_dtype = torch.bfloat16 if self.is_hopper: return torch.ops.vllm.padded_cutlass( q_input, @@ -461,7 +475,7 @@ def _run_cutlass( input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) else: return cutlass_scaled_mm( @@ -470,7 +484,7 @@ def _run_cutlass( input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) def _run_aiter( @@ -495,9 +509,13 @@ def _run_aiter( gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale if input_scale is not None: + # Use pre-quantized FP8 input directly q_input = input_2d + output_dtype = torch.bfloat16 else: + # Quantize input if not already quantized q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton) + output_dtype = input_2d.dtype return gemm_a8w8_blockscale_op( q_input, @@ -505,7 +523,7 @@ def _run_aiter( input_scale, weight_scale, list(self.weight_group_shape), - output_dtype=input_2d.dtype, + output_dtype=output_dtype, ) def _run_triton( @@ -515,16 +533,22 @@ def _run_triton( weight_scale: torch.Tensor, input_scale: torch.Tensor | None = None, ) -> torch.Tensor: - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) + if input_scale is None: + # Quantize input if not already quantized + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + output_dtype = input_2d.dtype + else: + # Use pre-quantized FP8 input directly + q_input = input_2d + output_dtype = torch.bfloat16 return 
torch.ops.vllm.w8a8_triton_block_scaled_mm_func( q_input, weight, input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) def _run_flashinfer(