diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 8d35aa65738b..dde71d9df9ba 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -6,9 +6,12 @@ import torch import vllm.envs as envs +from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer +logger = init_logger(__name__) + def is_aiter_found() -> bool: from importlib.util import find_spec @@ -352,6 +355,72 @@ def _rocm_aiter_gemm_w8a8_blockscale_fake( return Y +@functools.lru_cache(maxsize=1) +def _initialize_hipblaslt(): + from aiter import hipb_create_extension + + hipb_create_extension() + + +def _gemm_weight_bpreshuffle_impl( + input: torch.Tensor, # [M, K] + weight: torch.Tensor, # [K, N] + bias: torch.Tensor | None = None, # [N] + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, # None, (1,) or (M,1) + scale_b: torch.Tensor | None = None, # None, (1,) or (1,N) +) -> torch.Tensor: + _initialize_hipblaslt() + + if out_dtype is None: + out_dtype = torch.bfloat16 + + assert out_dtype == torch.bfloat16, ( + f"hip_bpreshuffle_gemm only supports bfloat16 output dtype" + f", you have passed in {out_dtype}" + ) + if input.dim() >= 3: + inp_view = input.view(-1, input.size(-1)) + batched = True + else: + inp_view = input + batched = False + + from aiter import hipb_mm + + output = hipb_mm( + inp_view, + weight, + solution_index=-1, + bias=bias, + out_dtype=out_dtype, + scaleA=scale_a, + scaleB=scale_b, + scaleOut=None, + bpreshuffle=True, + ) + + if batched: + output = output.view(*input.shape[:-1], weight.shape[1]) + + return output + + +def _gemm_weight_bpreshuffle_fake( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, + scale_b: torch.Tensor | None = None, +) -> torch.Tensor: + if out_dtype is None: + out_dtype = torch.bfloat16 + return torch.empty( + *input.shape[:-1], weight.shape[1], dtype=out_dtype, device=input.device + ) + + def _rocm_aiter_rms_norm_impl( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: @@ -409,6 +478,7 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake( class rocm_aiter_ops: _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR + _LINEAR_SHUFFLE_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_SHUFFLE _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA @@ -438,6 +508,16 @@ def is_linear_fp8_enaled(cls) -> bool: """ "Verifies device specs and availability of env variable.""" return cls.is_linear_enabled() and current_platform.is_fp8_fnuz() + @classmethod + @if_aiter_supported + def is_linear_shuffle_enabled(cls) -> bool: + """ "Verifies device specs and availability of env variable.""" + return ( + cls._AITER_ENABLED + and cls.is_linear_enabled() + and cls._LINEAR_SHUFFLE_ENABLED + ) + @classmethod @if_aiter_supported def is_rmsnorm_enabled(cls) -> bool: @@ -570,6 +650,14 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) + direct_register_custom_op( + op_name="gemm_weight_bpreshuffle", + op_func=_gemm_weight_bpreshuffle_impl, + mutates_args=[], + fake_impl=_gemm_weight_bpreshuffle_fake, + dispatch_key=current_platform.dispatch_key, + ) + direct_register_custom_op( op_name="rocm_aiter_rms_norm", op_func=_rocm_aiter_rms_norm_impl, @@ -629,6 +717,30 @@ def gemm_w8a8_blockscale( A, B, As, Bs, output_dtype ) + @staticmethod + def gemm_weight_bpreshuffle( + input: torch.Tensor, # [M, K] + weight: torch.Tensor, # [K, N] + bias: torch.Tensor | None = None, # [N] + out_dtype: torch.dtype | None = None, + scale_a: torch.Tensor | None = None, # None, (1,) or (M,1) + scale_b: torch.Tensor | None = None, # None, (1,) or (1,N) + ) -> torch.Tensor: + return torch.ops.vllm.gemm_weight_bpreshuffle( + input=input, + weight=weight, + bias=bias, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + ) + + @staticmethod + def gemm_weight_can_shuffle(n: int, k: int, layout: tuple[int, int]) -> bool: + IN, IK = layout + BK = IK * 2 + return (n % IN == 0) and (k % BK == 0) + @staticmethod def fused_moe( hidden_states: torch.Tensor, @@ -908,7 +1020,7 @@ def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool: @staticmethod def shuffle_weight( - self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) + tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) ) -> torch.Tensor: from aiter.ops.shuffle import shuffle_weight diff --git a/vllm/envs.py b/vllm/envs.py index 52178e5f5250..7385b95200ff 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -104,6 +104,7 @@ VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True + VLLM_ROCM_USE_AITER_LINEAR_SHUFFLE: bool = True VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True @@ -901,6 +902,12 @@ def get_vllm_port() -> int | None: "VLLM_ROCM_USE_AITER_LINEAR": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1") ), + # Whether to use aiter linear with weights shuffled. + # This flag is only used in development. It will be removed in the future + # By default is enabled. + "VLLM_ROCM_USE_AITER_LINEAR_SHUFFLE": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_LINEAR_SHUFFLE", "True").lower() in ("true", "1") + ), # Whether to use aiter moe ops. # By default is enabled. "VLLM_ROCM_USE_AITER_MOE": lambda: ( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index dfcc601a1c53..87e2d79cd2c5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -8,6 +8,7 @@ import torch from torch.nn.parameter import Parameter, UninitializedParameter +from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import ( divide, get_tensor_model_parallel_rank, @@ -225,19 +226,38 @@ def create_weights( layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) + self.weight_shuffled = False + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if current_platform.is_cpu(): from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm dispatch_cpu_unquantized_gemm(layer, remove_weight=True) + if rocm_aiter_ops.is_linear_shuffle_enabled(): + weight = layer.weight + layout = (16, 16) + + if rocm_aiter_ops.gemm_weight_can_shuffle( + weight.shape[0], weight.shape[1], layout + ): + shuffled_weight = rocm_aiter_ops.shuffle_weight(weight, layout).t() + self.weight_shuffled = True + layer.register_parameter( + "weight", Parameter(shuffled_weight.data, requires_grad=False) + ) + + layer.weight_shuffled = self.weight_shuffled + def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) + return dispatch_unquantized_gemm( + rocm_aiter_weight_shuffled=self.weight_shuffled + )(layer, x, layer.weight, bias) class LinearBase(CustomOp): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 6da136cbc8f6..456205745494 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -148,6 +148,30 @@ def process_weights_after_loading(self, layer) -> None: weight, weight_scale, input_scale = process_fp8_weight_channel_strategy( layer.weight, layer.weight_scale, getattr(layer, "input_scale", None) ) + + if ( + self.use_aiter_and_is_supported + and rocm_aiter_ops.is_linear_shuffle_enabled() + ): + layout = (16, 16) + use_swizzle_gemm = rocm_aiter_ops.gemm_weight_can_shuffle( + weight.shape[0], weight.shape[1], layout=layout + ) + + self.use_aiter_and_is_supported = ( + self.use_aiter_and_is_supported and use_swizzle_gemm + ) + + if self.use_aiter_and_is_supported: + weight = rocm_aiter_ops.shuffle_weight(weight, layout) + weight_scale = weight_scale.t() + + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape, + rocm_aiter_weight_shuffled=True, + ) + weight = weight.t() elif self.strategy == QuantizationStrategy.BLOCK: diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py index 7ded8eea7906..eef0a38a4e90 100644 --- a/vllm/model_executor/layers/quantization/input_quant_fp8.py +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -97,7 +97,7 @@ def forward_native( x: torch.Tensor, scale: torch.Tensor | None = None, scale_ub: torch.Tensor | None = None, - ): + ) -> tuple[torch.Tensor, torch.Tensor]: if self.is_group_quant: assert scale is None, "Group quantization is always dynamic" return self._quantize_group_native(x) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 1e5ee93b61f2..6c94a5e636b5 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,6 +7,7 @@ import torch from torch.nn import Parameter +from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -96,6 +97,23 @@ def process_weights_after_loading(self, layer) -> None: weight_scale = layer.weight_scale.data if self.act_quant_group_shape == GroupShape.PER_TOKEN: weight_scale = weight_scale.view(-1, 1) + + if rocm_aiter_ops.is_linear_shuffle_enabled(): + layout = (16, 16) + use_swizzle_gemm = rocm_aiter_ops.gemm_weight_can_shuffle( + weight.shape[0], weight.shape[1], layout=layout + ) + + if use_swizzle_gemm: + weight = rocm_aiter_ops.shuffle_weight(weight, layout) + weight_scale = weight_scale.t() + + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_quant_group_shape, + rocm_aiter_weight_shuffled=True, + ) + layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter layer.weight_scale = Parameter(weight_scale, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 7fe902807a74..b1cc29846c6b 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -8,6 +8,7 @@ from vllm import _custom_ops as ops from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import CompilationMode, get_current_vllm_config from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape @@ -357,8 +358,43 @@ def torch_channelwise_w8a8_scaled_mm( return output.to(out_dtype).view(*output_shape) +def aiter_ptpc_w8a8_scaled_mm_bpreshuffled( + *, + qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + bias: torch.Tensor, + output_shape: list, +) -> torch.Tensor: + # K_size = qinput.shape[-1] + # if K_size == 8192: + # output = rocm_aiter_ops.gemm_weight_bpreshuffle_fp8_ck( + # input=qinput, + # weight=weight.t(), + # bias=bias, + # out_dtype=out_dtype, + # scale_a=scale_a, + # scale_b=scale_b.t(), + # ) + # return output.view(*output_shape) + + return rocm_aiter_ops.gemm_weight_bpreshuffle( + input=qinput, + weight=weight, + bias=bias, + out_dtype=out_dtype, + scale_a=scale_a, + scale_b=scale_b, + ).view(*output_shape) + + def dispatch_w8a8_scaled_mm( - preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool + preferred_backend: str, + per_tensor_weights: bool, + per_tensor_activations: bool, + rocm_aiter_weight_shuffled: bool = False, ) -> Callable[..., torch.Tensor]: if per_tensor_weights and per_tensor_activations: if preferred_backend == "rocm": @@ -374,12 +410,11 @@ def dispatch_w8a8_scaled_mm( return cutlass_w8a8_scaled_mm # If torch.scaled_mm supports per-channel (weights) per-token (inputs) - if ( - not per_tensor_weights - and not per_tensor_activations - and USE_ROWWISE_TORCH_SCALED_MM - ): - return torch_per_token_w8a8_scaled_mm + if not per_tensor_weights and not per_tensor_activations: + if rocm_aiter_weight_shuffled: + return aiter_ptpc_w8a8_scaled_mm_bpreshuffled + if USE_ROWWISE_TORCH_SCALED_MM: + return torch_per_token_w8a8_scaled_mm # Normally, torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token return torch_channelwise_w8a8_scaled_mm @@ -400,6 +435,7 @@ def __init__( act_quant_static: bool, act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, pad_output: bool | None = None, + rocm_aiter_weight_shuffled: bool | None = None, ): if current_platform.is_rocm(): self.preferred_backend = "rocm" @@ -432,6 +468,10 @@ def __init__( num_token_padding=self.output_padding, ) + self.rocm_aiter_weight_shuffled = ( + False if rocm_aiter_weight_shuffled is None else rocm_aiter_weight_shuffled + ) + def apply( self, input: torch.Tensor, @@ -477,7 +517,10 @@ def apply( # TODO(luka) do this dispatch during init (after ScaledMM refactor) w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( - self.preferred_backend, per_tensor_weights, per_tensor_activations + self.preferred_backend, + per_tensor_weights, + per_tensor_activations, + self.rocm_aiter_weight_shuffled, ) return w8a8_scaled_mm_func( diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index b17bdd0b7207..f0836d17904e 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -8,6 +8,7 @@ from vllm import _custom_ops as ops from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.utils.torch_utils import direct_register_custom_op @@ -182,6 +183,20 @@ def rocm_unquantized_gemm( ) +def aiter_unquantized_gemm_bpreshuffled( + layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, +) -> torch.Tensor: + return rocm_aiter_ops.gemm_weight_bpreshuffle( + input=x, + weight=weight, + bias=bias, + out_dtype=x.dtype, + ) + + def check_cpu_sgl_kernel(n: int, k: int, dtype: torch.dtype) -> bool: return ( torch._C._cpu._is_amx_tile_supported() @@ -241,8 +256,12 @@ def cpu_unquantized_gemm( return layer.cpu_linear(x, weight, bias) -def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: +def dispatch_unquantized_gemm( + rocm_aiter_weight_shuffled: bool = False, +) -> Callable[..., torch.Tensor]: if current_platform.is_rocm(): + if rocm_aiter_weight_shuffled: + return aiter_unquantized_gemm_bpreshuffled return rocm_unquantized_gemm elif current_platform.is_cpu(): return cpu_unquantized_gemm