diff --git a/vllm/envs.py b/vllm/envs.py index d29e367bcae8..830e8589c328 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -51,9 +51,10 @@ VLLM_CPU_OMP_THREADS_BIND: str = "auto" VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_SGL_KERNEL: bool = False - VLLM_ZENTORCH_WEIGHT_PREPACK: bool = True VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") VLLM_XLA_CHECK_RECOMPILATION: bool = False + VLLM_FUSED_MOE_CHUNK_SIZE: int = 16 * 1024 + VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING: bool = True VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: Literal["auto", "nccl", "shm"] = "auto" VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = False VLLM_USE_RAY_WRAPPED_PP_COMM: bool = True @@ -97,7 +98,6 @@ VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] - VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE: bool = True VLLM_DISABLE_PYNCCL: bool = False VLLM_USE_OINK_OPS: bool = False VLLM_ROCM_USE_AITER: bool = False @@ -117,6 +117,9 @@ VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True + VLLM_USE_AITER_FUSED: bool = True + VLLM_USE_AITER_PREFILL_FUSED: bool = True + VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 @@ -169,7 +172,7 @@ VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( "latency" ) - VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto" + VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm" VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024 VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 @@ -246,8 +249,6 @@ VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False - VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32 - VLLM_XPU_ENABLE_XPU_GRAPH: bool = False def get_default_cache_root(): @@ -295,16 +296,6 @@ def use_aot_compile() -> bool: ) -def use_mega_aot_artifact(): - from vllm.utils.torch_utils import is_torch_equal_or_newer - - default_value = ( - "1" if is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile() else "0" - ) - - return os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT", default_value) == "1" - - def env_with_choices( env_name: str, default: str | None, @@ -628,7 +619,10 @@ def _get_or_set_default() -> str: # Enable loading compiled models directly from cached standalone compile artifacts # without re-splitting graph modules. This reduces overhead during model # loading by using reconstruct_serializable_fn_from_mega_artifact. - "VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact, + "VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get( + "VLLM_USE_MEGA_AOT_ARTIFACT", "0" + ) + == "1", # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), @@ -719,11 +713,6 @@ def _get_or_set_default() -> str: else None, # (CPU backend only) whether to use SGL kernels, optimized for small batch. "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), - # (Zen CPU backend) eagerly prepack weights into ZenDNN blocked layout - # at model load time. Eliminates per-inference layout conversion overhead. - "VLLM_ZENTORCH_WEIGHT_PREPACK": lambda: bool( - int(os.getenv("VLLM_ZENTORCH_WEIGHT_PREPACK", "1")) - ), # If the env var is set, Ray Compiled Graph uses the specified # channel type to communicate between workers belonging to # different pipeline-parallel stages. @@ -841,6 +830,15 @@ def _get_or_set_default() -> str: ), # Enable SPMD mode for TPU backend. "VLLM_XLA_USE_SPMD": lambda: bool(int(os.getenv("VLLM_XLA_USE_SPMD", "0"))), + "VLLM_FUSED_MOE_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(16 * 1024)) + ), + # Control whether to use fused MoE activation chunking. Current chunking + # logic is incompatible with torch.compile and causes IMA. See issue + # https://github.com/vllm-project/vllm/issues/19631. + "VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING": lambda: bool( + int(os.getenv("VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING", "1")) + ), # If set, the OpenAI API server will stay alive even after the underlying # AsyncLLMEngine errors and stops serving requests "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": lambda: bool( @@ -910,9 +908,6 @@ def _get_or_set_default() -> str: "VLLM_DISABLED_KERNELS": lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ["VLLM_DISABLED_KERNELS"].split(","), - "VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool( - int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1")) - ), # Disable pynccl (using torch.distributed instead) "VLLM_DISABLE_PYNCCL": lambda: ( os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") @@ -993,6 +988,19 @@ def _get_or_set_default() -> str: "VLLM_ROCM_USE_AITER_TRITON_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_TRITON_GEMM", "True").lower() in ("true", "1") ), + # Enable AITER fused decode kernel for MLA (ROCm only, decode path only) + # Enable AITER fused kernels for MLA (ROCm only, prefill and decode) + # Fuses: RoPE + concat + KV cache write (prefill) or BMM + RoPE + + # concat + KV cache write (decode) in ONE kernel + # By default is enabled for AMD GPUs with FP8 support. + "VLLM_USE_AITER_FUSED": lambda: ( + os.getenv("VLLM_USE_AITER_FUSED", "True").lower() in ("true", "1") + ), + # AITER fused RoPE + KV cache write for prefill tokens + # By default is enabled when VLLM_USE_AITER_FUSED is enabled. + "VLLM_USE_AITER_PREFILL_FUSED": lambda: ( + os.getenv("VLLM_USE_AITER_PREFILL_FUSED", "True").lower() in ("true", "1") + ), # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1") @@ -1001,6 +1009,10 @@ def _get_or_set_default() -> str: "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), # Pad the weights for the moe kernel "VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), + # custom paged attention kernel for MI3* cards + "VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: ( + os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1") + ), # Whether to use the shuffled kv cache layout "VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: ( os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1") @@ -1305,9 +1317,14 @@ def _get_or_set_default() -> str: ["throughput", "latency", "masked_gemm"], ), # Flashinfer fused allreduce backend. + # "auto" will default to "mnnvl", which performs mostly same/better than "trtllm". + # But "mnnvl" backend does not support fuse with quantization. + # TODO: Default is "trtllm" right now because "mnnvl" has issues with cudagraph: + # https://github.com/vllm-project/vllm/issues/35772 + # Should switch back to "auto" if the issue is resolved. "VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices( "VLLM_FLASHINFER_ALLREDUCE_BACKEND", - "auto", + "trtllm", ["auto", "trtllm", "mnnvl"], ), # Control the workspace buffer size for the FlashInfer backend. @@ -1640,14 +1657,6 @@ def _get_or_set_default() -> str: "VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool( int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0")) ), - # NIXL EP environment variables - "VLLM_NIXL_EP_MAX_NUM_RANKS": lambda: int( - os.getenv("VLLM_NIXL_EP_MAX_NUM_RANKS", "32") - ), - # Whether enable XPU graph on Intel GPU - "VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool( - int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0")) - ), } @@ -1784,7 +1793,6 @@ def compile_factors() -> dict[str, object]: "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "VLLM_CPU_KVCACHE_SPACE", "VLLM_CPU_MOE_PREPACK", - "VLLM_ZENTORCH_WEIGHT_PREPACK", "VLLM_TEST_FORCE_LOAD_FORMAT", "VLLM_ENABLE_CUDA_COMPATIBILITY", "VLLM_CUDA_COMPATIBILITY_PATH", diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index 0215ec1a0735..9abf121c9b8f 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -302,6 +302,12 @@ def __init__( prefix: str = "", use_sparse: bool = False, indexer: object | None = None, + # RoPE caches for AITER fused kernels + cos_cache: torch.Tensor | None = None, + sin_cache: torch.Tensor | None = None, + is_neox_style: bool = False, + # RoPE module (static, doesn't change) + rotary_emb: torch.nn.Module | None = None, **extra_impl_args, ): super().__init__() @@ -314,6 +320,9 @@ def __init__( self.kv_lora_rank = kv_lora_rank self.kv_b_proj = kv_b_proj self.head_size = kv_lora_rank + qk_rope_head_dim + # Store rotary_emb module as class attribute + # (static, shared across all forwards) + self.rotary_emb = rotary_emb self.layer_name = prefix self.indexer = indexer @@ -440,6 +449,66 @@ def __init__( and self.kv_b_proj.weight.dtype == torch.bfloat16 ) + # Store RoPE caches for AITER fused kernels + self.cos_cache = cos_cache + self.sin_cache = sin_cache + self.is_neox_style = is_neox_style + + # Detect if AITER fused decode kernel can be used (AMD GPU only) + # Support both FP4 and FP8 variants based on GPU capabilities + self.use_aiter_fused = ( + current_platform.is_rocm() # AMD GPU only + and ( + self.is_aiter_triton_fp4_bmm_enabled + or self.is_aiter_triton_fp8_bmm_enabled + ) # FP4 or FP8 BMM available + and envs.VLLM_USE_AITER_FUSED # Feature flag enabled + and cos_cache is not None # RoPE caches available + and sin_cache is not None + ) + + if self.use_aiter_fused: + # Use unified RoPE + KV cache kernel for both prefill and decode + # Separate BMM will be used for decode (no BMM fusion) + try: + from aiter.ops.triton.fusions.fused_kv_cache import ( + fused_qk_rope_cat_and_cache_mla, + ) + + self._fused_rope_kv_kernel = fused_qk_rope_cat_and_cache_mla + + # Set kernel type for BMM selection + if self.is_aiter_triton_fp4_bmm_enabled: + self._fused_kernel_type = "fp4" + else: + self._fused_kernel_type = "fp8" + except ImportError as e: + logger.warning_once( + f"AITER fused RoPE+KV cache kernel not available: {e}, " + "falling back to separate ops", + scope="local", + ) + self.use_aiter_fused = False + + # Log when AITER fused kernels are enabled + if self.use_aiter_fused: + logger.info( + "AITER unified RoPE+KV fusion ENABLED for prefill+decode, " + "using %s BMM for decode", + self._fused_kernel_type.upper(), + ) + + # Enable prefill fusion (RoPE + KV cache write for prefill tokens) + # Same kernel now used for decode too + self.use_aiter_rope_kv_fused = ( + self.use_aiter_fused and envs.VLLM_USE_AITER_PREFILL_FUSED + ) + + if self.use_aiter_rope_kv_fused: + logger.info( + "AITER unified RoPE+KV fusion ENABLED (same kernel for prefill+decode)" + ) + # Attributes for forward_impl method self._vllm_config = get_current_vllm_config() self._chunked_prefill_workspace_size: int | None = None @@ -465,12 +534,21 @@ def forward( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output_shape: torch.Size | None = None, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, + use_fused_path: bool = False, + rotary_emb: torch.nn.Module | None = None, ) -> torch.Tensor: if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) + # Store AITER fusion parameters in forward_context for custom ops + forward_context: ForwardContext = get_forward_context() + if positions is not None: + forward_context._positions = positions + forward_context._use_fused_path = use_fused_path + if self.use_direct_call: - forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] @@ -504,6 +582,10 @@ def forward( q, kv_c_normed, k_pe, self_kv_cache, attn_metadata ) else: + # Custom ops path (ROCm AITER) + if slot_mapping is not None: + forward_context.slot_mapping[self.layer_name] = slot_mapping + kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( kv_c_normed, k_pe, @@ -519,6 +601,8 @@ def forward( k_pe, output, self.layer_name, + positions, + slot_mapping, kv_cache_dummy_dep=kv_cache_dummy_dep, ) return output @@ -528,6 +612,8 @@ def forward( kv_c_normed, k_pe, self.layer_name, + positions, + slot_mapping, kv_cache_dummy_dep=kv_cache_dummy_dep, ) @@ -541,9 +627,20 @@ def forward_impl( output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, + rope_applied: bool | None = None, + use_fused_path: bool | None = None, + rotary_emb: torch.nn.Module | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." + # Derive fusion flags from instance variables if not provided + if rope_applied is None: + rope_applied = not self.use_aiter_fused + if use_fused_path is None: + use_fused_path = self.use_aiter_fused + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported for MLA" @@ -600,11 +697,104 @@ def forward_impl( num_mqa_tokens = attn_metadata.num_decode_tokens num_mha_tokens = q.size(0) - num_mqa_tokens + # Fix positions tensor size to match actual batch + # positions may have extra padding that doesn't match q.size(0) + if positions is not None: + num_actual_tokens = q.size(0) + if positions.size(0) > num_actual_tokens: + positions = positions[:num_actual_tokens] + + # Retrieve slot_mapping from attn_metadata if not provided + # This is needed for both prefill and decode KV cache writes + if slot_mapping is None and attn_metadata is not None: + slot_mapping = attn_metadata.slot_mapping + + # Apply unified RoPE+KV fusion to entire batch (decode + prefill) + if ( + self.use_aiter_fused + and self.use_aiter_rope_kv_fused + and rotary_emb is not None + and positions is not None + and slot_mapping is not None + ): + # Single unified kernel call applies RoPE and writes KV cache + # for the entire batch. This must be done outside CUDA graph + # where num_mqa_tokens from attn_metadata is dynamically available. + logger.info_once( + "Using AITER unified RoPE+KV fusion (single call) for entire batch", + scope="local", + ) + + # Split Q into nope and pe components + q_nope = q[..., : self.qk_nope_head_dim] + q_pe = q[..., self.qk_nope_head_dim :] + + # Reshape K: [batch, dim] -> [batch, num_kv_heads, dim] + k_nope_3d = k_c_normed.view(-1, self.num_kv_heads, self.kv_lora_rank) + k_pe_3d = k_pe.squeeze(1).view(-1, self.num_kv_heads, self.qk_rope_head_dim) + + # Call unified kernel for entire batch + # num_decode_toks_for_zeros tells kernel to handle first + # num_mqa_tokens specially + q_fused, _, k_pe_out, _ = self._fused_rope_kv_kernel( + q_nope=q_nope, + q_pe=q_pe, + k_nope=k_nope_3d, + k_pe=k_pe_3d, + kv_cache=kv_cache, + slot_mapping=slot_mapping, + pos=positions, + cos=self.cos_cache, + sin=self.sin_cache, + k_scale=self._k_scale, + is_neox=self.is_neox_style, + num_decode_toks_for_zeros=num_mqa_tokens, + apply_scale=True, + q_out_dtype=q.dtype, + ) + + # Update tensors with fused results (RoPE applied, KV cache written) + q[:] = q_fused + k_pe[:] = k_pe_out + if num_mha_tokens > 0: + # Prefill path: process prefill tokens + prefill_q = q[num_mqa_tokens:] + prefill_k_c_normed = k_c_normed[num_mqa_tokens:] + prefill_k_pe = k_pe[num_mqa_tokens:] + + # Apply RoPE and write KV cache if not using unified fusion + if ( + self.use_aiter_fused + and rotary_emb is not None + and positions is not None + and not (self.use_aiter_rope_kv_fused and slot_mapping is not None) + ): + # Unfused path: apply RoPE separately + prefill_positions = positions[num_mqa_tokens:] + prefill_q[..., self.qk_nope_head_dim :], prefill_k_pe = rotary_emb( + prefill_positions, + prefill_q[..., self.qk_nope_head_dim :], + prefill_k_pe, + ) + + # Write prefill KV to cache + if slot_mapping is not None: + prefill_slot_mapping = slot_mapping[num_mqa_tokens:] + self.impl.do_kv_cache_update( + prefill_k_c_normed, + prefill_k_pe, + kv_cache, + prefill_slot_mapping, + self.kv_cache_dtype, + self._k_scale, + ) + + # Run prefill attention self.impl.forward_mha( - q[num_mqa_tokens:], - k_c_normed[num_mqa_tokens:], - k_pe[num_mqa_tokens:], + prefill_q, + prefill_k_c_normed, + prefill_k_pe, kv_cache, attn_metadata, self._k_scale, @@ -612,15 +802,18 @@ def forward_impl( ) if num_mqa_tokens > 0: + # Extract decode slices mqa_q = q[:num_mqa_tokens] mqa_output_slice = output[:num_mqa_tokens] - mqa_q_nope, mqa_q_pe = mqa_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) + # Split Q for unfused path (fused path does this later) + if not (self.use_aiter_rope_kv_fused and slot_mapping is not None): + mqa_q_nope, mqa_q_pe = mqa_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) - # Convert from (B, N, P) to (N, B, P) - mqa_q_nope = mqa_q_nope.transpose(0, 1) + # Convert from (B, N, P) to (N, B, P) + mqa_q_nope = mqa_q_nope.transpose(0, 1) if self.q_pad_num_heads is not None: B, N, L = mqa_q_pe.shape @@ -629,6 +822,54 @@ def forward_impl( mqa_pe_padded.copy_(mqa_q_pe) mqa_q_pe = mqa_pe_padded + # Compute positions from seq_lens if not provided + # For decode tokens, position = seq_lens - 1 + # (current position in sequence) + # This matches the logic in prepare_pos_seq_lens_kernel + # where pos = num_computed_tokens + # and seq_len = num_computed_tokens + query_len, + # so pos = seq_len - query_len + # For decode (query_len=1): pos = seq_len - 1 + if positions is None and attn_metadata.decode is not None: + # Get decode sequence lengths for decode tokens only + decode_seq_lens = attn_metadata.decode.seq_lens + # Position is current sequence length - 1 (0-indexed) + positions = decode_seq_lens - 1 + logger.info_once( + "[MLA] Computed positions from decode seq_lens: shape=%s", + positions.shape, + scope="local", + ) + + # CUDA graph compatible: Use STATIC flag + # self.use_aiter_fused is class attribute, same for all + # batches. num_mqa_tokens > 0 is dynamic but OK - PyTorch + # handles control flow in graphs + + # Extract decode Q from already-processed batch (if fused path was used) + # RoPE+KV was already applied in the unified call above + if self.use_aiter_rope_kv_fused and slot_mapping is not None: + # Extract decode portion from already RoPE'd batch + # mqa_q was extracted earlier as q[:num_mqa_tokens] + # It now has RoPE applied from the unified kernel call + + # Log when decode uses unified fusion results + logger.info_once( + "Decode using unified RoPE+KV results, running %s BMM", + self._fused_kernel_type.upper(), + scope="local", + ) + + # mqa_q already extracted earlier, just split it + # mqa_q: [batch, num_heads, qk_nope_head_dim + qk_rope_head_dim] + # Note: RoPE already applied by unified kernel + mqa_q_nope, mqa_q_pe = mqa_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Transpose Q nope: [batch, num_heads, dim] -> [num_heads, batch, dim] + mqa_q_nope = mqa_q_nope.transpose(0, 1) + if self.is_aiter_triton_fp4_bmm_enabled: from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 @@ -705,12 +946,14 @@ def forward_impl( # v_up projection self._v_up_proj(attn_out, out=mqa_output_slice) + return output_padded def process_weights_after_loading(self, act_dtype: torch.dtype): # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and + # perform the bmm's in 16-bit, the extra memory overhead of + # this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights( self.kv_b_proj, out_dtype=act_dtype ).T @@ -886,14 +1129,59 @@ def unified_mla_attention( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: # kv_cache_dummy_dep is not used but accepting it creates a data dependency # that ensures torch.compile preserves ordering between KV cache update and # attention forward. del kv_cache_dummy_dep - attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) - output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata) + attn_metadata, layer, kv_cache, forward_context = get_attention_context(layer_name) + + # positions and slot_mapping come from parameters (passed through compiled graph) + # rotary_emb retrieved from layer (stored as class attribute during __init__) + rotary_emb = layer.rotary_emb + + # Retrieve slot_mapping from forward_context or attn_metadata + slot_mapping = None + if hasattr(forward_context, "slot_mapping") and isinstance( + forward_context.slot_mapping, dict + ): + slot_mapping = forward_context.slot_mapping.get(layer_name) + + # Fallback: get slot_mapping from attn_metadata if not in forward_context + # This happens with torch.compile when forward_context doesn't persist + if slot_mapping is None and attn_metadata is not None: + slot_mapping = attn_metadata.slot_mapping + + logger.info_once( + f"[unified_mla_attention] RETRIEVED: " + f"positions={'exists' if positions is not None else 'None'}, " + f"slot_mapping={'exists' if slot_mapping is not None else 'None'}, " + f"layer={layer_name}", + scope="local", + ) + + # Determine rope_applied and use_fused_path from layer config + # STATIC decision based on whether AITER kernels available + # Assumptions: rotary_emb always exists, positions always provided + # Therefore: use_aiter_fused is the sole deciding factor + use_fused_path = layer.use_aiter_fused + rope_applied = not use_fused_path + + output = layer.forward_impl( + q, + kv_c_normed, + k_pe, + kv_cache, + attn_metadata, + positions=positions, + slot_mapping=slot_mapping, + rope_applied=rope_applied, + use_fused_path=use_fused_path, + rotary_emb=rotary_emb, + ) return output @@ -903,6 +1191,8 @@ def unified_mla_attention_fake( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(q).contiguous() @@ -925,8 +1215,11 @@ def unified_mla_kv_cache_update( k_scale: torch.Tensor, ) -> torch.Tensor: """ - Returns a dummy that is passed to unified_attention to signal a side effect and - the data dependency between them to ensure torch.compile preserves ordering. + Write KV cache for UNFUSED path only. + For fused path, KV cache writes happen in forward_impl: + - Prefill tokens: after RoPE applied (line 806) + - Decode tokens: in fused kernel (line 890+) + Returns a dummy tensor to signal side effect for torch.compile ordering. """ forward_context = get_forward_context() if forward_context.attn_metadata is None: @@ -934,6 +1227,14 @@ def unified_mla_kv_cache_update( return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) attn_layer = forward_context.no_compile_layers[layer_name] + + # Check if AITER fused kernels are available (static decision) + if attn_layer.use_aiter_fused: + # FUSED path: Skip KV write here, forward_impl handles it + # (Prefill: after RoPE, Decode: in fused kernel) + return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) + + # UNFUSED path: Write all tokens to KV cache here kv_cache = attn_layer.kv_cache slot_mapping = forward_context.slot_mapping @@ -941,7 +1242,8 @@ def unified_mla_kv_cache_update( f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " ) layer_slot_mapping = slot_mapping.get(layer_name) - if layer_slot_mapping is not None: + + if layer_slot_mapping is not None and kv_c_normed.shape[0] > 0: attn_layer.impl.do_kv_cache_update( kv_c_normed, k_pe, @@ -978,6 +1280,8 @@ def unified_mla_attention_with_output( k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, @@ -986,7 +1290,24 @@ def unified_mla_attention_with_output( # that ensures torch.compile preserves ordering between KV cache update and # attention forward. del kv_cache_dummy_dep - attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) + attn_metadata, layer, kv_cache, forward_context = get_attention_context(layer_name) + + # If slot_mapping is None, retrieve it from attn_metadata as fallback + # (happens when called from mla.py which doesn't have slot_mapping) + if ( + slot_mapping is None + and attn_metadata is not None + and hasattr(attn_metadata, "slot_mapping") + ): + slot_mapping = attn_metadata.slot_mapping + + # Retrieve rotary_emb from layer (stored as class attribute during __init__) + rotary_emb = layer.rotary_emb + + # Determine whether to use AITER fused path based on layer config + use_fused_path = layer.use_aiter_fused + rope_applied = not use_fused_path + layer.forward_impl( q, kv_c_normed, @@ -996,6 +1317,11 @@ def unified_mla_attention_with_output( output=output, output_scale=output_scale, output_block_scale=output_block_scale, + positions=positions, + slot_mapping=slot_mapping, + rope_applied=rope_applied, + use_fused_path=use_fused_path, + rotary_emb=rotary_emb, ) @@ -1005,6 +1331,8 @@ def unified_mla_attention_with_output_fake( k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, + slot_mapping: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, @@ -2503,7 +2831,12 @@ def _compute_prefill_context( if use_fp8_prefill or _kv_b_proj_w_dtype != current_platform.fp8_dtype(): kv_c_normed = kv_c_normed.to(_kv_b_proj_w_dtype) - k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) + # Extract k_pe from workspace + # workspace shape: 2D [toks, dim] (from gather ops) + # Ensure k_pe: [toks, num_kv_heads=1, qk_rope_head_dim] + k_pe = workspace[:toks][..., self.kv_lora_rank :] + if k_pe.ndim == 2: + k_pe = k_pe.unsqueeze(1) # [toks, pe_dim] -> [toks, 1, pe_dim] kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index e8ed8a5249d1..b3acc89712cb 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -106,7 +106,7 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 def forward( - self, x: torch.Tensor + self, x: torch.Tensor, x_scale: torch.Tensor | None = None ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: # Tier 1: DSV3 specialized kernel if self.allow_dsv3_router_gemm and x.shape[0] <= 16: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 44fd516f5e5c..c36e513463d5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -222,7 +222,11 @@ def apply( layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, + input_scale: torch.Tensor | None = None, ) -> torch.Tensor: + assert input_scale is None, ( + "UnquantizedLinearMethod does not support input_scale" + ) if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike(): return linear_batch_invariant(x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) @@ -384,11 +388,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def forward( self, x: torch.Tensor, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None - output = self.quant_method.apply(self, x, bias) + output = self.quant_method.apply(self, x, bias, input_scale=x_scale) if not self.return_bias: return output @@ -574,12 +579,15 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor def forward( self, input_, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None # Matrix multiply. assert self.quant_method is not None - output_parallel = self.quant_method.apply(self, input_, bias) + output_parallel = self.quant_method.apply( + self, input_, bias, input_scale=x_scale + ) if self.gather_output and self.tp_size > 1: # All-gather across the partitions. @@ -1512,6 +1520,7 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor def forward( self, input_, + x_scale: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: if self.input_is_parallel: input_parallel = input_ @@ -1523,10 +1532,12 @@ def forward( # Matrix multiply. assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) + # Only fuse bias add into GEMM for rank 0 (ensures bias not + # added multiple times in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, input_parallel, bias_) + output_parallel = self.quant_method.apply( + self, input_parallel, bias_, input_scale=x_scale + ) if self.reduce_results and self.tp_size > 1: output = tensor_model_parallel_all_reduce(output_parallel) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index 1d3e987b7e17..21a5b737f23b 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -5,10 +5,103 @@ import torch from vllm.config import CacheConfig +from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention import MLAAttention from vllm.model_executor.layers.quantization import QuantizationConfig +logger = init_logger(__name__) + +# Try to import AITER ops for fused kernels +try: + from aiter import dtypes + from aiter.jit.utils.torch_guard import torch_compile_guard + from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant + + _AITER_AVAILABLE = True +except ImportError: + _AITER_AVAILABLE = False + dtypes = None + torch_compile_guard = None + fused_rms_fp8_group_quant = None + + +def _fused_rms_fp8_group_quant_fake( + q_c: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_c: torch.Tensor, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + dtype_quant: torch.dtype | None = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fake implementation for torch.compile/CUDA graphs. + + Returns tuple: (out1_quantized, out1_bs, out2) + """ + if dtype_quant is None: + dtype_quant = dtypes.fp8 + m, n1 = q_c.shape + out1_quantized = torch.empty((m, n1), dtype=dtype_quant, device=q_c.device) + out1_bs = torch.empty( + (m, (n1 + group_size - 1) // group_size), dtype=torch.float32, device=q_c.device + ) + if transpose_scale: + out1_bs = out1_bs.transpose(0, 1).contiguous().view(*out1_bs.shape) + out2 = torch.empty_like(kv_c) + # Return tuple for ATOM-style pattern + return out1_quantized, out1_bs, out2 + + +def _fuse_rmsnorm_quant_impl( + q_c: torch.Tensor, + q_a_layernorm_weight: torch.Tensor, + q_a_layernorm_variance_epsilon: float, + kv_c: torch.Tensor, + kv_a_layernorm_weight: torch.Tensor, + kv_a_layernorm_variance_epsilon: float, + dtype_quant: torch.dtype | None = None, + group_size: int = 128, + output_unquantized_inp1: bool = False, + transpose_scale: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fused dual RMSNorm + FP8 quantization using AITER (ATOM pattern). + + Fuses: + 1. RMSNorm on q_c + 2. FP8 group quantization on q_c + 3. RMSNorm on kv_c (without quantization) + + Based on ATOM's implementation in deepseek_v2.py:245-280 + + Returns: + (q_c_quantized, q_c_scale, kv_c_normed) + + Uses @torch_compile_guard decorator for CUDA graph compatibility. + """ + (q_c_quantized, q_c_scale), _, kv_c_normed, _ = fused_rms_fp8_group_quant( + q_c, + q_a_layernorm_weight, + q_a_layernorm_variance_epsilon, + kv_c, + kv_a_layernorm_weight, + kv_a_layernorm_variance_epsilon, + group_size, + dtype_quant, + None, + output_unquantized_inp1, + transpose_scale, + ) + return q_c_quantized, q_c_scale, kv_c_normed + + +# Make fusion transparent to compiler (no @torch_compile_guard) +# This allows the compiler to trace through and batch operations efficiently +_fuse_rmsnorm_quant = _fuse_rmsnorm_quant_impl + @dataclass class MLAModules: @@ -87,6 +180,21 @@ def __init__( self.indexer_rope_emb = mla_modules.indexer_rotary_emb self.is_sparse = mla_modules.is_sparse + # Extract RoPE caches for AITER fused kernels + if self.rotary_emb is not None: + # RoPE stores combined cos_sin_cache, need to split it + # Format: [seq_len, rotary_dim] where first half is cos, second half is sin + cos_sin_cache = self.rotary_emb.cos_sin_cache + rotary_dim = self.rotary_emb.rotary_dim + half_dim = rotary_dim // 2 + self.cos_cache = cos_sin_cache[:, :half_dim] + self.sin_cache = cos_sin_cache[:, half_dim:] + self.is_neox_style = self.rotary_emb.is_neox_style + else: + self.cos_cache = None + self.sin_cache = None + self.is_neox_style = False + if self.indexer is not None: assert hasattr(self.indexer, "topk_tokens") self.topk_tokens = self.indexer.topk_tokens @@ -106,10 +214,35 @@ def __init__( kv_b_proj=self.kv_b_proj, use_sparse=self.is_sparse, indexer=self.indexer, + # Pass RoPE caches for AITER fused kernels + cos_cache=self.cos_cache, + sin_cache=self.sin_cache, + is_neox_style=self.is_neox_style, + # Pass RoPE module (static, doesn't change) + rotary_emb=self.rotary_emb, ) self.prefix = prefix + # Determine if RMSNorm+Quant fusion should be enabled + # Fusion requires AITER and FP8 quantization + self.quant_config = quant_config + self.quant_dtype = None + self.fuse_qknorm_quant = False + + if _AITER_AVAILABLE and quant_config is not None: + # Check if quant_config is FP8 + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + if isinstance(quant_config, Fp8Config): + self.quant_dtype = dtypes.fp8 + self.fuse_qknorm_quant = True + logger.info( + "[MLA_FUSION_INIT] Fusion enabled for %s: " + "AITER available and FP8 quantization detected", + prefix, + ) + def forward( self, positions: torch.Tensor, @@ -118,6 +251,7 @@ def forward( ) -> torch.Tensor: q_c = None kv_lora = None + q_c_scale = None # For FP8 quantized path if self.q_lora_rank is not None: assert self.fused_qkv_a_proj is not None, ( @@ -130,13 +264,37 @@ def forward( "q_b_proj is required when q_lora_rank is not None" ) + # QKV projection qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] q_c, kv_lora = qkv_lora.split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1, ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] + kv_c, k_pe = kv_lora.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + + # Apply RMSNorm and optional FP8 quantization fusion + if self.fuse_qknorm_quant: + # Fused RMSNorm + FP8 quantization + q_c_quantized, q_c_scale, kv_c_normed = _fuse_rmsnorm_quant( + q_c, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + kv_c, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + dtype_quant=self.quant_dtype, + group_size=128, + output_unquantized_inp1=False, + transpose_scale=True, + ) + q = self.q_b_proj(q_c_quantized, x_scale=q_c_scale)[0] + else: + # Unfused path: standard RMSNorm + q_c = self.q_a_layernorm(q_c) + kv_c_normed = self.kv_a_layernorm(kv_c) + q = self.q_b_proj(q_c)[0] else: assert self.kv_a_proj_with_mqa is not None, ( "kv_a_proj_with_mqa is required when q_lora_rank is None" @@ -146,18 +304,36 @@ def forward( ) kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) + kv_c, k_pe = kv_lora.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + kv_c_normed = self.kv_a_layernorm(kv_c) q = q.view(-1, self.num_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe k_pe = k_pe.unsqueeze(1) + # Determine if AITER fused RoPE+KV path can be used + can_use_fused_path = ( + hasattr(self.mla_attn, "use_aiter_fused") + and self.mla_attn.use_aiter_fused + and positions is not None + and self.rotary_emb is not None + ) + + # Apply RoPE if not using fused path if self.rotary_emb is not None: - q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( - positions, q[..., self.qk_nope_head_dim :], k_pe - ) + if can_use_fused_path: + # Fused path: RoPE applied in unified kernel + pass + else: + # Unfused path: apply RoPE here + q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( + positions, + q[..., self.qk_nope_head_dim :], + k_pe, + ) if self.indexer and self.is_sparse: _topk_indices = self.indexer( @@ -167,11 +343,23 @@ def forward( if llama_4_scaling is not None: q *= llama_4_scaling + # Store rotary_emb in forward_context for custom ops + from vllm.forward_context import get_forward_context + + forward_context = get_forward_context() + if self.rotary_emb is not None: + forward_context._rotary_emb = self.rotary_emb + attn_out = self.mla_attn( q, kv_c_normed, k_pe, output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), + positions=positions, + slot_mapping=None, + use_fused_path=can_use_fused_path, + rotary_emb=self.rotary_emb, ) - return self.o_proj(attn_out)[0] + final_out = self.o_proj(attn_out)[0] + return final_out diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 69255a2793cb..9ae347d08feb 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -441,6 +441,7 @@ def apply( layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, + input_scale: torch.Tensor | None = None, ) -> torch.Tensor: # if batch invariant mode is enabled, prefer DeepGEMM FP8 path # we will use BF16 dequant when DeepGEMM is not supported. @@ -451,7 +452,9 @@ def apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, + input_scale=input_scale + if input_scale is not None + else layer.input_scale, bias=bias, ) else: @@ -488,7 +491,7 @@ def apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale_inv, - input_scale=layer.input_scale, + input_scale=input_scale, bias=bias, ) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9568d1320bc6..3aa1b77fb1f2 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -399,32 +399,44 @@ def apply( weight_scale: torch.Tensor, input_scale: torch.Tensor | None = None, bias: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, ) -> torch.Tensor: - assert input_scale is None # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] - output_dtype = input.dtype + # Use provided output_dtype, or default based on whether input is + # pre-quantized (bfloat16) or not (input.dtype) + if output_dtype is None: + output_dtype = input.dtype if input_scale is None else torch.bfloat16 if should_use_flashinfer_for_blockscale_fp8_gemm( self.is_flashinfer_supported, output_dtype, input_2d, weight ) and should_use_deepgemm_for_fp8_linear( output_dtype, weight, self.is_deep_gemm_supported ): + # FlashInfer: does not support pre-quantized input + assert input_scale is None, ( + "FlashInfer FP8 blockscale GEMM does not support pre-quantized input" + ) output = self._run_flashinfer(input_2d, weight, weight_scale) elif should_use_deepgemm_for_fp8_linear( output_dtype, weight, self.is_deep_gemm_supported ): + # DeepGEMM: does not support pre-quantized input + assert input_scale is None, ( + "DeepGEMM FP8 linear does not support pre-quantized input" + ) output = self._run_deepgemm(input_2d, weight, weight_scale) else: - output = self.w8a8_blockscale_op( - input_2d, weight, weight_scale, input_scale + # AITER/Triton/Cutlass: supports pre-quantized input + output = self.w8a8_blockscale_op( # type: ignore[call-arg] + input_2d, weight, weight_scale, input_scale, output_dtype ) if bias is not None: output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + return output.view(*output_shape) def _run_deepgemm( self, @@ -450,10 +462,19 @@ def _run_cutlass( weight: torch.Tensor, weight_scale: torch.Tensor, input_scale: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, ) -> torch.Tensor: - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) + if input_scale is None: + # Quantize input if not already quantized + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + if output_dtype is None: + output_dtype = input_2d.dtype + else: + # Use pre-quantized FP8 input directly + q_input = input_2d + if output_dtype is None: + output_dtype = torch.bfloat16 if self.is_hopper: return torch.ops.vllm.padded_cutlass( q_input, @@ -461,7 +482,7 @@ def _run_cutlass( input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) else: return cutlass_scaled_mm( @@ -470,7 +491,7 @@ def _run_cutlass( input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) def _run_aiter( @@ -479,6 +500,7 @@ def _run_aiter( weight: torch.Tensor, weight_scale: torch.Tensor, input_scale: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, ) -> torch.Tensor: assert self.act_quant_group_shape == GroupShape(1, 128) @@ -495,9 +517,15 @@ def _run_aiter( gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale if input_scale is not None: + # Use pre-quantized FP8 input directly q_input = input_2d + if output_dtype is None: + output_dtype = torch.bfloat16 else: + # Quantize input if not already quantized q_input, input_scale = self.input_quant_op(input_2d, use_triton=use_triton) + if output_dtype is None: + output_dtype = input_2d.dtype return gemm_a8w8_blockscale_op( q_input, @@ -505,7 +533,7 @@ def _run_aiter( input_scale, weight_scale, list(self.weight_group_shape), - output_dtype=input_2d.dtype, + output_dtype=output_dtype, ) def _run_triton( @@ -514,17 +542,26 @@ def _run_triton( weight: torch.Tensor, weight_scale: torch.Tensor, input_scale: torch.Tensor | None = None, + output_dtype: torch.dtype | None = None, ) -> torch.Tensor: - assert input_scale is None - assert self.input_quant_op is not None - q_input, input_scale = self.input_quant_op(input_2d) + if input_scale is None: + # Quantize input if not already quantized + assert self.input_quant_op is not None + q_input, input_scale = self.input_quant_op(input_2d) + if output_dtype is None: + output_dtype = input_2d.dtype + else: + # Use pre-quantized FP8 input directly + q_input = input_2d + if output_dtype is None: + output_dtype = torch.bfloat16 return torch.ops.vllm.w8a8_triton_block_scaled_mm_func( q_input, weight, input_scale, weight_scale, list(self.weight_group_shape), - input_2d.dtype, + output_dtype, ) def _run_flashinfer(