Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from dataclasses import dataclass
from typing import Optional, Tuple
from functools import partial as functools_partial

import torch
from aiter import (
Expand All @@ -28,15 +29,16 @@
)
from aiter import (
QuantType,
dtypes,
get_hip_quant,
)

# Optionally pull in the fused Triton split+cat GEMM kernels. The names are
# ALWAYS bound (to None on any failure or when Triton GEMM is disabled) so the
# `... is not None` dispatch checks below never raise NameError.
if use_triton_gemm():
    try:
        from aiter.ops.triton.fused_gemm_afp4wfp4_split_cat import fused_gemm_afp4wfp4_preshuffle_split_cat
        from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_preshuffle_split_cat
    except ImportError:  # kernels absent in this aiter build — fall back paths handle None
        fused_gemm_afp4wfp4_preshuffle_split_cat = None
        fused_gemm_a8w8_blockscale_preshuffle_split_cat = None
else:
    fused_gemm_afp4wfp4_preshuffle_split_cat = None
    fused_gemm_a8w8_blockscale_preshuffle_split_cat = None

# from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
# # from aiter.ops.triton.fused_kv_cache import fused_qk_rope_cat_and_cache_mla
Expand Down Expand Up @@ -279,38 +281,41 @@ def _forward_prefill_mha(
self.v_head_dim,
output_dtype
)
else: # FP8 GEMM + split + cat
elif fused_gemm_a8w8_blockscale_preshuffle_split_cat is not None and weight.dtype == dtypes.fp8: # FP8 GEMM + split + cat
weight_shuffled = weight.reshape(
weight.shape[0] // 16,
weight.shape[1] * 16
)

output_dtype = kv_c_normed.dtype

quant_func = functools_partial(
get_hip_quant(QuantType.per_1x128),
transpose_scale=True
)
q_input, x_scale = quant_func(
kv_c_normed,
quant_dtype=dtypes.fp8,
scale=getattr(self.kv_b_proj, "input_scale", None)
)

k, v = fused_gemm_a8w8_blockscale_preshuffle_split_cat(
q_input,
weight_shuffled,
k_rope.expand((-1, self.num_heads, -1)),
x_scale,
weight_scale,
self.qk_nope_head_dim,
self.v_head_dim,
output_dtype
)
else:
kv_nope = self.kv_b_proj(kv_c_normed).view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)

# from aiter.ops.triton.fused_gemm_a8w8_blockscale_split_cat import fused_gemm_a8w8_blockscale_split_cat
# import aiter as rocm_aiter
# from aiter import get_hip_quant
# aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
# from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8

# input = kv_c_normed
# weight = self.kv_b_proj.weight
# block_size = self.kv_b_proj.quant_method.quant_config.weight_block_size
# weight_scale = self.kv_b_proj.weight_scale

# input_2d = input.view(-1, input.shape[-1])
# output_dtype = input.dtype

# if current_platform.is_fp8_fnuz():
# q_input, x_scale = aiter_per1x128_quant(
# input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
# else:
# q_input, x_scale = per_token_group_quant_fp8(
# input_2d, block_size[1], column_major_scales=False)

# k, v = fused_gemm_a8w8_blockscale_split_cat(
# q_input, weight, k_rope.expand((-1, self.num_heads, -1)), x_scale, weight_scale, self.qk_nope_head_dim, self.v_head_dim, output_dtype
# )
else:
kv_nope = self.kv_b_proj(kv_c_normed).view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
Expand Down
49 changes: 38 additions & 11 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,29 @@
from aiter.ops.shuffle import shuffle_weight
from aiter.tuned_gemm import tgemm
from aiter.utility import fp4_utils
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
from atom.utils import envs


def divide(numerator, denominator):
    """Return numerator // denominator, asserting the division is exact."""
    assert (
        numerator % denominator == 0
    ), f"numerator {numerator} denominator {denominator}"
    return numerator // denominator

def use_triton_gemm() -> bool:
    """Whether Triton GEMM kernels are enabled (ATOM_USE_TRITON_GEMM env flag)."""
    enabled = envs.ATOM_USE_TRITON_GEMM
    return enabled

# The Triton FP8 blockscale GEMM is mostly slower than the AITER GEMM, so its
# import stays disabled; keep the name bound to None unconditionally so
# gemm_a8w8_blockscale_preshuffle_impl can test it without a NameError.
# from aiter.ops.triton.gemm_a8w8_blockscale import gemm_a8w8_blockscale_preshuffle
gemm_a8w8_blockscale_preshuffle = None

if use_triton_gemm():
    try:
        from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4_preshuffle
    except ImportError:  # Triton kernel absent in this aiter build
        gemm_afp4wfp4_preshuffle = None
else:
    gemm_afp4wfp4_preshuffle = None

from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
def divide(numerator, denominator):
    """Return numerator // denominator, asserting the division is exact."""
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0, f"numerator {numerator} denominator {denominator}"
    return quotient

def gemm_a4w4_quant_fake(
x: torch.Tensor,
Expand Down Expand Up @@ -132,6 +136,29 @@ def gemm_a4w4_quant(
return y[:m, ...]


def gemm_a8w8_blockscale_preshuffle_fake(x: torch.Tensor, weight: torch.Tensor,
                                         x_scale: torch.Tensor, w_scale: torch.Tensor,
                                         dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Fake (meta) kernel for the compile guard: allocate the GEMM output.

    The result has x's leading dimensions with the last dimension replaced by
    weight's row count, on x's device; contents are uninitialized.
    """
    out_shape = x.shape[:-1] + (weight.shape[0],)
    return torch.empty(out_shape, dtype=dtype, device=x.device)


@torch_compile_guard(gen_fake=gemm_a8w8_blockscale_preshuffle_fake, mutates_args=[])
def gemm_a8w8_blockscale_preshuffle_impl(x: torch.Tensor, weight: torch.Tensor,
                                         x_scale: torch.Tensor, w_scale: torch.Tensor,
                                         dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """Blockscale A8W8 GEMM: Triton preshuffle kernel when available, else AITER.

    Dispatches on whether the optional Triton kernel was imported at module load.
    """
    if gemm_a8w8_blockscale_preshuffle is None:
        # Triton kernel unavailable: use the AITER bpreshuffle GEMM directly.
        return gemm_a8w8_blockscale_bpreshuffle(x, weight, x_scale, w_scale, dtype)
    # The Triton preshuffle kernel expects the weight flattened into groups of
    # 16 rows packed along the last dimension.
    packed_weight = weight.reshape(weight.shape[0] // 16, weight.shape[1] * 16)
    return gemm_a8w8_blockscale_preshuffle(x, packed_weight, x_scale, w_scale, dtype)


class LinearBase(nn.Module):

def __init__(
Expand Down Expand Up @@ -360,7 +387,7 @@ def forward(
if self.bias is not None:
y += self.bias
elif self.quant_type.value == QuantType.per_1x128.value:
y = gemm_a8w8_blockscale_bpreshuffle(
y = gemm_a8w8_blockscale_preshuffle_impl(
x, self.weight, x_scale, self.weight_scale, dtype=otype
)
if self.bias is not None:
Expand Down
Loading