Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions aiter/ops/triton/batched_gemm_a16wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
_get_config,
)
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton.utils.common_utils import deserialize_str
from aiter.ops.triton.gemm_a16wfp4 import (
get_splitk,
)
from aiter.jit.utils.torch_guard import torch_compile_guard

_LOGGER = AiterTritonLogger()

Expand All @@ -26,17 +28,36 @@ def set_use_gemm_splitk_bf16(value: bool):
_USE_GEMM_SPLITK_BF16 = value


def batched_gemm_a16wfp4_fake_tensor(
    x: torch.Tensor,
    w: torch.Tensor,
    w_scales: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.bfloat16,
    y: Optional[torch.Tensor] = None,
    config: Optional[str] = None,
    transpose_bm: Optional[bool] = False,
    prequant: Optional[bool] = True,
    y_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Shape-only stand-in for batched_gemm_a16wfp4.

    Mirrors the real op's signature so it can be used as its fake-tensor
    generator. No computation is performed.

    Args:
        x (torch.Tensor): 3-D input; its first two dims become the first two
            output dims.
        w (torch.Tensor): 3-D weight; its middle dim becomes the last output dim.
        w_scales (torch.Tensor): Unused here; kept for signature parity.
        dtype (Optional[torch.dtype]): dtype of the tensor allocated when
            ``y`` is not supplied.
        y (Optional[torch.Tensor]): Caller-provided output; returned as-is.
        config, transpose_bm, prequant, y_scale: Unused here; kept for
            signature parity.

    Returns:
        torch.Tensor: ``y`` if given, otherwise an uninitialized tensor of
        shape (x.shape[0], x.shape[1], w.shape[1]) on ``x``'s device.
    """
    if y is not None:
        return y

    # NOTE(review): transpose_bm is not consulted here, while the real op
    # branches on it when allocating y -- confirm the fake shape matches.
    batch, rows, _ = x.shape
    _, cols, _ = w.shape
    out_shape = (batch, rows, cols)
    return torch.empty(out_shape, dtype=dtype, device=x.device)


@torch_compile_guard(gen_fake=batched_gemm_a16wfp4_fake_tensor)
def batched_gemm_a16wfp4(
x,
w,
w_scales,
dtype: Optional[float] = torch.bfloat16,
x: torch.Tensor,
w: torch.Tensor,
w_scales: torch.Tensor,
dtype: Optional[torch.dtype] = torch.bfloat16,
y: Optional[torch.Tensor] = None,
config: Optional[dict] = None,
config: Optional[str] = None,
transpose_bm: Optional[bool] = False,
prequant: Optional[bool] = True,
y_scale: Optional[torch.Tensor] = None,
):
) -> torch.Tensor:
"""
Computes batched FP4 matrix multiplication Y[i] = X[i] @ W[i]^T with active activation quantization.
X is quantized to MXFP4 during computation, W is pre-quantized FP4.
Expand Down Expand Up @@ -72,6 +93,8 @@ def batched_gemm_a16wfp4(

if config is None:
config = _get_config(M, N, K)
else:
config = deserialize_str(config)

if y is None:
if transpose_bm:
Expand Down
7 changes: 4 additions & 3 deletions aiter/ops/triton/batched_gemm_afp4wfp4_pre_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

from typing import Optional
import torch
import triton
import aiter.ops.triton.utils._triton.arch_info as arch_info
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton.utils.common_utils import serialize_dict
from aiter.ops.triton.batched_gemm_a16wfp4 import (
batched_gemm_a16wfp4,
)
Expand All @@ -32,6 +31,8 @@ def batched_gemm_afp4wfp4_pre_quant(
_LOGGER.info(
"batched_gemm_afp4wfp4_pre_quant will be deprecated in future AITER release, please switch to batched_gemm_a16wfp4"
)

config_hashable = serialize_dict(config) if config else None
return batched_gemm_a16wfp4(
x, w, w_scales, dtype, y, config, transpose_bm=False, prequant=True
x, w, w_scales, dtype, y, config_hashable, transpose_bm=False, prequant=True
)
72 changes: 51 additions & 21 deletions aiter/ops/triton/gemm_a16w16_atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,34 @@
_get_config,
)
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_str
from aiter.jit.utils.torch_guard import torch_compile_guard

_LOGGER = AiterTritonLogger()


def gemm_a16w16_atomic(
x,
w,
dtype: Optional[float] = torch.bfloat16,
def gemm_a16w16_atomic_fake_tensor(
    x: torch.Tensor,
    w: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.bfloat16,
    y: Optional[torch.Tensor] = None,
    config: Optional[str] = None,
) -> torch.Tensor:
    """
    Shape-only stand-in for the atomic 16-bit GEMM (Y = X @ W^T).

    Args:
        x (torch.Tensor): Input matrix with shape (M, K).
        w (torch.Tensor): Weight matrix with shape (N, K), i.e. not yet
            transposed.
        dtype (Optional[torch.dtype]): dtype of the tensor allocated when
            ``y`` is not supplied.
        y (Optional[torch.Tensor]): Caller-provided output; returned as-is.
        config (Optional[str]): Unused here; kept for signature parity.

    Returns:
        torch.Tensor: ``y`` if given, otherwise a zero-filled (M, N) tensor on
        ``x``'s device.
    """
    if y is not None:
        return y

    M, _ = x.shape
    # w is laid out as (N, K), so N is its leading dim. Reading the trailing
    # dim would yield K and produce a wrong (M, K) fake output shape.
    N, _ = w.shape
    return torch.zeros((M, N), dtype=dtype, device=x.device)
Comment thread
mqhc2020 marked this conversation as resolved.

Returns:
torch.Tensor: Output with shape (M, N).
"""

@torch_compile_guard(gen_fake=gemm_a16w16_atomic_fake_tensor)
def gemm_a16w16_atomic_(
x: torch.Tensor,
w: torch.Tensor,
dtype: Optional[torch.dtype] = torch.bfloat16,
y: Optional[torch.Tensor] = None,
config: Optional[str] = None,
) -> torch.Tensor:
_LOGGER.info(
f"GEMM_A16W16_ATOMIC: x.shape={tuple(x.shape)}, w.shape={tuple(w.shape)} "
)
Expand All @@ -50,6 +50,9 @@ def gemm_a16w16_atomic(

if config is None:
config = _get_config(M, N, K)
else:
config = deserialize_str(config)

# For compatibility reasons, these keys may not exist in the config
# TODO: This needs to be embedded in the configs later
if "NUM_KSPLIT" not in config:
Expand Down Expand Up @@ -89,3 +92,30 @@ def gemm_a16w16_atomic(
)

return y


def gemm_a16w16_atomic(
    x: torch.Tensor,
    w: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.bfloat16,
    y: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
) -> torch.Tensor:
    """
    16-bit GEMM Y = X @ W^T whose split-K partial results are combined with atomics.

    Public, dict-accepting entry point: the config dict is turned into a
    string with serialize_dict and the work is delegated to
    gemm_a16w16_atomic_.

    Args:
        x (torch.Tensor): Input matrix with shape (M, K).
        w (torch.Tensor): Weight matrix with shape (N, K), internally transposed.
        dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16).
            Note: BF16 atomic aggregation may have slight precision loss.
        y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N).
            Must be zero-initialized for split-K (NUM_KSPLIT > 1).
        config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N,
            BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, cache_modifier). ``None`` and
            an empty dict are both forwarded as ``None``.

    Returns:
        torch.Tensor: Output with shape (M, N).
    """
    if config:
        config_str = serialize_dict(config)
    else:
        # Falsy config (None or {}) is forwarded as None.
        config_str = None
    return gemm_a16w16_atomic_(x, w, dtype, y, config_str)
35 changes: 27 additions & 8 deletions aiter/ops/triton/gemm_a16wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
from typing import Optional
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils._triton.arch_info as arch_info
from aiter.ops.triton.quant import _mxfp4_quant_op
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton.utils.common_utils import deserialize_str
from aiter.ops.triton._triton_kernels.gemm_a16wfp4 import (
_gemm_a16wfp4_kernel,
_get_config,
Expand All @@ -18,20 +17,38 @@
from aiter.ops.triton.gemm_afp4wfp4 import (
get_splitk,
)
from aiter.jit.utils.torch_guard import torch_compile_guard


_LOGGER = AiterTritonLogger()


def gemm_a16wfp4_fake_tensor(
    x: torch.Tensor,
    w: torch.Tensor,
    w_scales: torch.Tensor,
    atomic_add: bool = False,
    dtype: Optional[torch.dtype] = torch.bfloat16,
    y: Optional[torch.Tensor] = None,
    config: Optional[str] = None,
) -> torch.Tensor:
    """
    Shape-only stand-in for gemm_a16wfp4.

    Args:
        x (torch.Tensor): 2-D input; its leading dim is the output row count.
        w (torch.Tensor): 2-D weight; its leading dim is the output column count.
        w_scales (torch.Tensor): Unused here; kept for signature parity.
        atomic_add (bool): Unused here; kept for signature parity.
        dtype (Optional[torch.dtype]): dtype of the tensor allocated when
            ``y`` is not supplied.
        y (Optional[torch.Tensor]): Caller-provided output; returned as-is.
        config (Optional[str]): Unused here; kept for signature parity.

    Returns:
        torch.Tensor: ``y`` if given, otherwise a zero-filled
        (x.shape[0], w.shape[0]) tensor on ``x``'s device.
    """
    if y is not None:
        return y

    rows, _ = x.shape
    cols, _ = w.shape
    out_shape = (rows, cols)
    return torch.zeros(out_shape, dtype=dtype, device=x.device)


@torch_compile_guard(gen_fake=gemm_a16wfp4_fake_tensor)
def gemm_a16wfp4(
x,
w,
w_scales,
x: torch.Tensor,
w: torch.Tensor,
w_scales: torch.Tensor,
atomic_add: bool = False,
dtype: Optional[float] = torch.bfloat16,
dtype: Optional[torch.dtype] = torch.bfloat16,
y: Optional[torch.Tensor] = None,
config: Optional[dict] = None,
):
config: Optional[str] = None,
) -> torch.Tensor:
"""
Computes the matmul Y = X x W
W is an e2m1 fp4 tensor and w_scales is an e8m0 tensor.
Expand Down Expand Up @@ -62,6 +79,8 @@ def gemm_a16wfp4(

if config is None:
config = _get_config(M, N, K)
else:
config = deserialize_str(config)

if y is None:
if atomic_add:
Expand Down
54 changes: 43 additions & 11 deletions aiter/ops/triton/gemm_afp4wfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from typing import Optional
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils._triton.arch_info as arch_info
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton.utils.common_utils import serialize_dict, deserialize_str
from aiter.ops.triton._triton_kernels.gemm_afp4wfp4 import (
_gemm_afp4wfp4_kernel,
_gemm_afp4wfp4_kernel_preshuffle_scales,
_gemm_afp4wfp4_preshuffle_kernel,
_gemm_afp4wfp4_reduce_kernel,
_get_config,
)
from .utils.core import AITER_TRITON_CONFIGS_PATH
from aiter.jit.utils.torch_guard import torch_compile_guard

import os
from aiter.utility.triton.triton_metadata_redirect import AOTMetadataContext
Expand Down Expand Up @@ -63,16 +63,34 @@ def get_splitk(K: int, BLOCK_SIZE_K: int, NUM_KSPLIT: int):
return SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT


def gemm_afp4wfp4(
x,
w,
x_scales,
w_scales,
dtype: Optional[float] = torch.bfloat16,
def gemm_afp4wfp4_fake_tensor(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.bfloat16,
    y: Optional[torch.Tensor] = None,
    config: Optional[str] = None,
    skip_reduce: Optional[bool] = False,
) -> torch.Tensor:
    """
    Shape-only stand-in for gemm_afp4wfp4_.

    Args:
        x (torch.Tensor): 2-D input; its leading dim is the output row count.
        w (torch.Tensor): 2-D weight; its leading dim is the output column count.
        x_scales (torch.Tensor): Unused here; kept for signature parity.
        w_scales (torch.Tensor): Unused here; kept for signature parity.
        dtype (Optional[torch.dtype]): dtype of the tensor allocated when
            ``y`` is not supplied.
        y (Optional[torch.Tensor]): Caller-provided output; returned as-is.
        config (Optional[str]): Unused here; kept for signature parity.
        skip_reduce (Optional[bool]): Unused here; kept for signature parity.

    Returns:
        torch.Tensor: ``y`` if given, otherwise an uninitialized
        (x.shape[0], w.shape[0]) tensor on ``x``'s device.
    """
    if y is not None:
        return y

    # NOTE(review): skip_reduce is ignored here -- confirm the real op's
    # output shape is still (M, N) when skip_reduce=True.
    M, _ = x.shape
    N, _ = w.shape
    return torch.empty((M, N), dtype=dtype, device=x.device)


@torch_compile_guard(gen_fake=gemm_afp4wfp4_fake_tensor)
def gemm_afp4wfp4_(
x: torch.Tensor,
w: torch.Tensor,
x_scales: torch.Tensor,
w_scales: torch.Tensor,
dtype: Optional[torch.dtype] = torch.bfloat16,
y: Optional[torch.Tensor] = None,
config: Optional[str] = None,
skip_reduce: Optional[bool] = False,
) -> torch.Tensor:
"""
Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights.

Expand All @@ -91,7 +109,6 @@ def gemm_afp4wfp4(
Returns:
torch.Tensor: Output with shape (M, N).
"""

_LOGGER.info(
f"GEMM_AFPWFP4: x.shape={tuple(x.shape)} w.shape={tuple(w.shape)} x_scale={tuple(x_scales.shape)} w_scale={tuple(w_scales.shape)} "
)
Expand All @@ -106,6 +123,8 @@ def gemm_afp4wfp4(

if config is None:
config = _get_config(M, N, K)
else:
config = deserialize_str(config)

if config["NUM_KSPLIT"] > 1:
SPLITK_BLOCK_SIZE, BLOCK_SIZE_K, NUM_KSPLIT = get_splitk(
Expand Down Expand Up @@ -527,3 +546,16 @@ def gemm_afp4wfp4_preshuffled_weight_scales(
"gemm_afp4wfp4_preshuffled_weight_scales will be deprecated in future AITER release, please switch to gemm_afp4wfp4_preshuffle"
)
return gemm_afp4wfp4_preshuffle(x, w, x_scales, w_scales, dtype, y, config, use_aot)


def gemm_afp4wfp4(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scales: torch.Tensor,
    w_scales: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.bfloat16,
    y: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
    skip_reduce: Optional[bool] = False,
) -> torch.Tensor:
    """
    Computes matrix multiplication Y = X @ W^T with FP4 activations and FP4 weights.

    Public, dict-accepting entry point: the config dict is turned into a
    string with serialize_dict and the work is delegated to gemm_afp4wfp4_.

    Args:
        x (torch.Tensor): Activation tensor, forwarded unchanged.
        w (torch.Tensor): Weight tensor, forwarded unchanged.
        x_scales (torch.Tensor): Scales for ``x``, forwarded unchanged.
        w_scales (torch.Tensor): Scales for ``w``, forwarded unchanged.
        dtype (Optional[torch.dtype]): Output datatype.
        y (Optional[torch.Tensor]): Optional pre-allocated output tensor.
        config (Optional[dict]): Kernel tuning parameters. ``None`` and an
            empty dict are both forwarded as ``None``.
        skip_reduce (Optional[bool]): Forwarded unchanged to gemm_afp4wfp4_.
            Kept on the public wrapper so callers of the previous signature,
            which accepted it, keep working.

    Returns:
        torch.Tensor: Whatever gemm_afp4wfp4_ returns.
    """
    config_hashable = serialize_dict(config) if config else None
    return gemm_afp4wfp4_(
        x,
        w,
        x_scales,
        w_scales,
        dtype,
        y,
        config_hashable,
        skip_reduce=skip_reduce,
    )
13 changes: 7 additions & 6 deletions aiter/ops/triton/gemm_afp4wfp4_pre_quant_atomic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

from typing import Optional
import torch
import triton
import triton.language as tl
from aiter.ops.triton.utils.logger import AiterTritonLogger
from aiter.ops.triton.utils.common_utils import serialize_dict
from aiter.ops.triton.gemm_a16wfp4 import (
gemm_a16wfp4,
)
Expand All @@ -14,14 +13,16 @@


def gemm_afp4wfp4_pre_quant(
    x: torch.Tensor,
    w: torch.Tensor,
    w_scales: torch.Tensor,
    dtype: Optional[torch.dtype] = torch.bfloat16,
    y: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
) -> torch.Tensor:
    """
    Deprecated alias that forwards to gemm_a16wfp4 with atomic_add=True.

    Logs a deprecation notice, turns ``config`` into a string with
    serialize_dict (gemm_a16wfp4 takes its config as a string), and delegates.

    Args:
        x (torch.Tensor): Input tensor, forwarded unchanged.
        w (torch.Tensor): Weight tensor, forwarded unchanged.
        w_scales (torch.Tensor): Scales for ``w``, forwarded unchanged.
        dtype (Optional[torch.dtype]): Output datatype.
        y (Optional[torch.Tensor]): Optional pre-allocated output tensor.
        config (Optional[dict]): Kernel tuning parameters. ``None`` and an
            empty dict are both forwarded as ``None``.

    Returns:
        torch.Tensor: Whatever gemm_a16wfp4 returns.
    """
    _LOGGER.info(
        "gemm_afp4wfp4_pre_quant will be deprecated in future AITER release, please switch to gemm_a16wfp4"
    )

    config_hashable = serialize_dict(config) if config else None
    # Fourth positional argument of gemm_a16wfp4 is atomic_add.
    return gemm_a16wfp4(x, w, w_scales, True, dtype, y, config_hashable)
9 changes: 9 additions & 0 deletions aiter/ops/triton/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch
import triton
import json


def prev_power_of_2(x: int) -> int:
Expand Down Expand Up @@ -34,3 +35,11 @@ def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor:
if x.stride(-1) == 1:
return x
return x.contiguous()


def serialize_dict(d: dict) -> str:
    """
    Encode a dict as a JSON string.

    Keys are emitted in sorted order, so two dicts that compare equal always
    produce the same string regardless of insertion order.

    Args:
        d (dict): JSON-serializable mapping with mutually comparable keys.

    Returns:
        str: JSON text that deserialize_str / json.loads turns back into an
        equal dict.

    Raises:
        TypeError: If ``d`` contains a value json cannot encode.
    """
    # sort_keys gives a canonical form: without it the output depends on
    # insertion order and equal configs would yield different strings.
    return json.dumps(d, sort_keys=True)


def deserialize_str(s: str) -> dict:
    """
    Decode a JSON string back into a Python object.

    Args:
        s (str): JSON text, e.g. as produced by json.dumps.

    Returns:
        dict: The decoded value (a dict when ``s`` encodes a JSON object).
    """
    decoded = json.loads(s)
    return decoded