Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 113 additions & 1 deletion vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer

logger = init_logger(__name__)


def is_aiter_found() -> bool:
from importlib.util import find_spec
Expand Down Expand Up @@ -352,6 +355,72 @@ def _rocm_aiter_gemm_w8a8_blockscale_fake(
return Y


@functools.lru_cache(maxsize=1)
def _initialize_hipblaslt() -> None:
    """Create aiter's hipBLASLt extension exactly once per process.

    The ``lru_cache`` decorator on a zero-argument function acts as a
    run-once guard: the first call performs the extension setup, every
    later call is a cache hit and does nothing.
    """
    # Imported lazily so merely importing this module does not require aiter.
    from aiter import hipb_create_extension

    hipb_create_extension()


def _gemm_weight_bpreshuffle_impl(
    input: torch.Tensor,  # [M, K]
    weight: torch.Tensor,  # [K, N]
    bias: torch.Tensor | None = None,  # [N]
    out_dtype: torch.dtype | None = None,
    scale_a: torch.Tensor | None = None,  # None, (1,) or (M,1)
    scale_b: torch.Tensor | None = None,  # None, (1,) or (1,N)
) -> torch.Tensor:
    """Run a GEMM through aiter's hipBLASLt binding on a pre-shuffled weight.

    Args:
        input: Activation tensor; leading batch dims (if any) are flattened
            into M before the 2-D GEMM and restored afterwards.
        weight: Weight tensor that has already been layout-shuffled
            (``bpreshuffle=True`` is passed to the kernel).
        bias: Optional per-output-channel bias.
        out_dtype: Output dtype; defaults to bfloat16, and bfloat16 is the
            only dtype the bpreshuffle path supports.
        scale_a / scale_b: Optional quantization scales forwarded to hipb_mm.

    Returns:
        The GEMM result with the same leading dims as ``input`` and a final
        dim of N.
    """
    _initialize_hipblaslt()

    if out_dtype is None:
        out_dtype = torch.bfloat16

    # Fix: the first string segment had a stray ``f`` prefix with no
    # placeholders (ruff F541); only the second segment interpolates.
    assert out_dtype == torch.bfloat16, (
        "hip_bpreshuffle_gemm only supports bfloat16 output dtype"
        f", you have passed in {out_dtype}"
    )

    # hipb_mm expects a 2-D input: fold any leading batch dims into M.
    if input.dim() >= 3:
        inp_view = input.view(-1, input.size(-1))
        batched = True
    else:
        inp_view = input
        batched = False

    from aiter import hipb_mm

    output = hipb_mm(
        inp_view,
        weight,
        # NOTE(review): -1 presumably lets hipBLASLt auto-select a
        # solution — confirm against aiter's hipb_mm documentation.
        solution_index=-1,
        bias=bias,
        out_dtype=out_dtype,
        scaleA=scale_a,
        scaleB=scale_b,
        scaleOut=None,
        bpreshuffle=True,
    )

    if batched:
        # Restore the original leading dims with the output feature dim N.
        output = output.view(*input.shape[:-1], weight.shape[1])

    return output


def _gemm_weight_bpreshuffle_fake(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
out_dtype: torch.dtype | None = None,
scale_a: torch.Tensor | None = None,
scale_b: torch.Tensor | None = None,
) -> torch.Tensor:
if out_dtype is None:
out_dtype = torch.bfloat16
return torch.empty(
*input.shape[:-1], weight.shape[1], dtype=out_dtype, device=input.device
)


def _rocm_aiter_rms_norm_impl(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
Expand Down Expand Up @@ -409,6 +478,7 @@ def _rocm_aiter_rmsnorm2d_fwd_with_add_fake(
class rocm_aiter_ops:
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
_LINEAR_SHUFFLE_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_SHUFFLE
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
Expand Down Expand Up @@ -438,6 +508,16 @@ def is_linear_fp8_enaled(cls) -> bool:
""" "Verifies device specs and availability of env variable."""
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()

@classmethod
@if_aiter_supported
def is_linear_shuffle_enabled(cls) -> bool:
    """Verifies device specs and availability of env variable."""
    # Fix: the docstring opened with a garbled `""" "` (stray quote).
    # NOTE(review): the explicit _AITER_ENABLED check may be redundant if
    # is_linear_enabled() already requires it — confirm and simplify.
    return (
        cls._AITER_ENABLED
        and cls.is_linear_enabled()
        and cls._LINEAR_SHUFFLE_ENABLED
    )

@classmethod
@if_aiter_supported
def is_rmsnorm_enabled(cls) -> bool:
Expand Down Expand Up @@ -570,6 +650,14 @@ def register_ops_once() -> None:
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="gemm_weight_bpreshuffle",
op_func=_gemm_weight_bpreshuffle_impl,
mutates_args=[],
fake_impl=_gemm_weight_bpreshuffle_fake,
dispatch_key=current_platform.dispatch_key,
)

direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=_rocm_aiter_rms_norm_impl,
Expand Down Expand Up @@ -629,6 +717,30 @@ def gemm_w8a8_blockscale(
A, B, As, Bs, output_dtype
)

@staticmethod
def gemm_weight_bpreshuffle(
    input: torch.Tensor,  # [M, K]
    weight: torch.Tensor,  # [K, N]
    bias: torch.Tensor | None = None,  # [N]
    out_dtype: torch.dtype | None = None,
    scale_a: torch.Tensor | None = None,  # None, (1,) or (M,1)
    scale_b: torch.Tensor | None = None,  # None, (1,) or (1,N)
) -> torch.Tensor:
    """Dispatch to the registered ``gemm_weight_bpreshuffle`` custom op.

    Thin wrapper over ``torch.ops.vllm.gemm_weight_bpreshuffle`` so callers
    go through the custom-op machinery (fake impl during tracing/compile,
    the real aiter hipBLASLt kernel at runtime).
    """
    return torch.ops.vllm.gemm_weight_bpreshuffle(
        input=input,
        weight=weight,
        bias=bias,
        out_dtype=out_dtype,
        scale_a=scale_a,
        scale_b=scale_b,
    )

@staticmethod
def gemm_weight_can_shuffle(n: int, k: int, layout: tuple[int, int]) -> bool:
IN, IK = layout
BK = IK * 2
return (n % IN == 0) and (k % BK == 0)

@staticmethod
def fused_moe(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -908,7 +1020,7 @@ def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:

@staticmethod
def shuffle_weight(
self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> torch.Tensor:
from aiter.ops.shuffle import shuffle_weight

Expand Down
7 changes: 7 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_LINEAR_SHUFFLE: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
Expand Down Expand Up @@ -901,6 +902,12 @@ def get_vllm_port() -> int | None:
"VLLM_ROCM_USE_AITER_LINEAR": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1")
),
# Whether to use aiter linear with weights shuffled.
# This flag is only used during development and will be removed in the future.
# Enabled by default.
"VLLM_ROCM_USE_AITER_LINEAR_SHUFFLE": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_LINEAR_SHUFFLE", "True").lower() in ("true", "1")
),
# Whether to use aiter moe ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MOE": lambda: (
Expand Down
22 changes: 21 additions & 1 deletion vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
Expand Down Expand Up @@ -225,19 +226,38 @@ def create_weights(
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

self.weight_shuffled = False

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Post-load weight fixups: CPU GEMM dispatch and optional aiter shuffle.

    When the aiter linear-shuffle path is enabled and the weight's (N, K)
    dims are tile-divisible, the weight is layout-shuffled, transposed, and
    re-registered; ``layer.weight_shuffled`` records whether that happened
    so the matching GEMM dispatch can be selected later.
    """
    if current_platform.is_cpu():
        from vllm.model_executor.layers.utils import dispatch_cpu_unquantized_gemm

        dispatch_cpu_unquantized_gemm(layer, remove_weight=True)

    if rocm_aiter_ops.is_linear_shuffle_enabled():
        weight = layer.weight
        layout = (16, 16)

        # Shuffle only when both dims divide the tile layout; otherwise the
        # plain (unshuffled) GEMM path is kept.
        # NOTE(review): assumes layer.weight is [N, K] (nn.Linear layout) —
        # confirm, since gemm_weight_can_shuffle takes (n, k).
        if rocm_aiter_ops.gemm_weight_can_shuffle(
            weight.shape[0], weight.shape[1], layout
        ):
            shuffled_weight = rocm_aiter_ops.shuffle_weight(weight, layout).t()
            self.weight_shuffled = True
            # Re-register so the shuffled tensor replaces the original
            # parameter under the same name.
            layer.register_parameter(
                "weight", Parameter(shuffled_weight.data, requires_grad=False)
            )

    # Mirror the flag onto the layer so per-layer state is queryable even if
    # this method object is shared across layers.
    layer.weight_shuffled = self.weight_shuffled

def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run the unquantized GEMM for ``layer`` on input ``x``.

    Fix: read the shuffle flag from the layer rather than from ``self``.
    A single linear-method instance can serve many layers, so the
    instance-level ``self.weight_shuffled`` only reflects the *last* layer
    processed; ``process_weights_after_loading`` already mirrors the
    correct per-layer value onto ``layer.weight_shuffled``.
    """
    return dispatch_unquantized_gemm(
        rocm_aiter_weight_shuffled=getattr(layer, "weight_shuffled", False)
    )(layer, x, layer.weight, bias)


class LinearBase(CustomOp):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,30 @@ def process_weights_after_loading(self, layer) -> None:
weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
)

if (
self.use_aiter_and_is_supported
and rocm_aiter_ops.is_linear_shuffle_enabled()
):
layout = (16, 16)
use_swizzle_gemm = rocm_aiter_ops.gemm_weight_can_shuffle(
weight.shape[0], weight.shape[1], layout=layout
)

self.use_aiter_and_is_supported = (
self.use_aiter_and_is_supported and use_swizzle_gemm
)

if self.use_aiter_and_is_supported:
weight = rocm_aiter_ops.shuffle_weight(weight, layout)
weight_scale = weight_scale.t()

self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape,
rocm_aiter_weight_shuffled=True,
)

weight = weight.t()

elif self.strategy == QuantizationStrategy.BLOCK:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def forward_native(
x: torch.Tensor,
scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None,
):
) -> tuple[torch.Tensor, torch.Tensor]:
if self.is_group_quant:
assert scale is None, "Group quantization is always dynamic"
return self._quantize_group_native(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from torch.nn import Parameter

from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Expand Down Expand Up @@ -96,6 +97,23 @@ def process_weights_after_loading(self, layer) -> None:
weight_scale = layer.weight_scale.data
if self.act_quant_group_shape == GroupShape.PER_TOKEN:
weight_scale = weight_scale.view(-1, 1)

if rocm_aiter_ops.is_linear_shuffle_enabled():
layout = (16, 16)
use_swizzle_gemm = rocm_aiter_ops.gemm_weight_can_shuffle(
weight.shape[0], weight.shape[1], layout=layout
)

if use_swizzle_gemm:
weight = rocm_aiter_ops.shuffle_weight(weight, layout)
weight_scale = weight_scale.t()

self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_quant_group_shape,
rocm_aiter_weight_shuffled=True,
)

layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
Expand Down
Loading