7 changes: 5 additions & 2 deletions src/flag_gems/runtime/backend/_nvidia/hopper/ops/__init__.py
@@ -1,3 +1,6 @@
# from .mm import mm
import triton

# __all__ = ["mm"]
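# The descriptor-based mm kernel relies on tl.make_tensor_descriptor, so it is only registered on Triton versions that provide it.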
if triton.__version__ >= "3.4":
from .mm import mm # noqa: F401

__all__ = ["*"]
292 changes: 292 additions & 0 deletions src/flag_gems/runtime/backend/_nvidia/hopper/ops/mm.py
@@ -0,0 +1,292 @@
import logging
from functools import lru_cache
from typing import Optional

import torch
import triton
import triton.language as tl

from flag_gems import runtime
from flag_gems.ops.mm_streamk import streamk_mm
from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry, libtuner
from flag_gems.utils import triton_lang_extension as tle


@lru_cache(maxsize=1)
def get_device_info():
try:
device_id = torch_device_fn.current_device()
except Exception:
device_id = 0

try:
props = torch_device_fn.get_device_properties(device_id)
return device_id, props.L2_cache_size, props.multi_processor_count
except Exception:
# fall back to A100 defaults:
# 40 MB of L2 cache and 108 SMs
return device_id, 40 * 1024 * 1024, 108


def get_device_id():
return get_device_info()[0]


def get_l2_cache_size():
return get_device_info()[1]


def get_sm_count():
return get_device_info()[2]


CACHE_USAGE_THRESHOLD = 0.8

logger = logging.getLogger(__name__)


@triton.jit
def prev_multiple_of(a, b):
# the largest multiple of b that is strictly less than a
return tl.cdiv(a, b) * b - b


@libentry()
@libtuner(
configs=runtime.get_tuned_config("mm"),
# Add 'stride_am' and 'stride_bk' to trigger autotune for tensors with the same shape but different strides.
key=["M", "N", "K", "stride_am", "stride_bk"],
strategy=["default", "default", "default", "default", "default"],
warmup=5,
rep=10,
)
@triton.jit
def mm_kernel_general(
A,
B,
C,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
):
# matrix multiplication
pid = tle.program_id(0)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)

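# Fast path: when M, N and K are exact multiples of the block sizes, use tensor-descriptor block loads/stores (the descriptors below hard-code row-major contiguous strides); otherwise fall back to masked pointer arithmetic.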
if M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0:
# offset
offset_am = pid_m * BLOCK_M
offset_bn = pid_n * BLOCK_N
offset_k = 0

a_desc = tl.make_tensor_descriptor(
base=A,
shape=[M, K],
strides=[K, 1],
block_shape=[BLOCK_M, BLOCK_K],
)

# row-major
b_desc = tl.make_tensor_descriptor(
base=B,
shape=[K, N],
strides=[N, 1],
block_shape=[BLOCK_K, BLOCK_N],
)

# column-major
# b_desc = tl.make_tensor_descriptor(
# B,
# shape = [N, K],
# strides = [K, 1],
# block_shape = [BLOCK_N, BLOCK_K],
# )

c_desc = tl.make_tensor_descriptor(
base=C,
shape=[M, N],
strides=[N, 1],
block_shape=[BLOCK_M, BLOCK_N],
)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a = a_desc.load([offset_am.to(tl.int32), offset_k.to(tl.int32)])
b = b_desc.load([offset_k.to(tl.int32), offset_bn.to(tl.int32)])
acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)
offset_k += BLOCK_K

acc = acc.to(a_desc.dtype)
c_desc.store([offset_am.to(tl.int32), offset_bn.to(tl.int32)], acc)

else:
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
prev_multiple = prev_multiple_of(K, BLOCK_K)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
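# Main loop over the K tiles that need no masking; the trailing partial tile is handled by the peeled, masked iteration below.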
for start_k in range(0, prev_multiple, BLOCK_K):
rk = start_k + tl.arange(0, BLOCK_K)
a = tl.load(A + (ram[:, None] * stride_am + rk[None, :] * stride_ak))
b = tl.load(B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn))
if a.dtype != b.dtype:
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

# loop peeling
rk = prev_multiple + tl.arange(0, BLOCK_K)
mask_k = rk < K
a = tl.load(
A + (ram[:, None] * stride_am + rk[None, :] * stride_ak),
mask=mask_k[None, :],
other=0.0,
)
b = tl.load(
B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn),
mask=mask_k[:, None],
other=0.0,
)
if a.dtype != b.dtype:
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
acc += tl.dot(a, b, out_dtype=tl.float32, allow_tf32=False)

acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offsets = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# masked write-back of the output tile
tl.store(offsets, acc, mask=mask)


_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


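# Return whichever dtype appears later in _ordered_datatypes, treating the list as ordered from lowest to highest precedence.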
def get_higher_dtype(a, b):
if a is b:
return a

assert a in _ordered_datatypes
assert b in _ordered_datatypes

for d in _ordered_datatypes:
if a is d:
return b
if b is d:
return a


def general_mm(a, b, c, M, N, K):
logger.debug(
"GEMS MM, [mm scenario]: general, [shape info]: [-, %s, %s, %s](batch, M, N, K), "
"[A column-major]: %s, [B column-major]: %s",
M,
N,
K,
a.stride(0) == 1,
b.stride(0) == 1,
)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
)

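# Kernels that create tensor descriptors with tl.make_tensor_descriptor need a device-side workspace; triton.set_allocator registers how that buffer is allocated (this follows the pattern in Triton's descriptor examples).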
def alloc_fn(size: int, align: int, stream: Optional[int]):
return torch.empty(size, dtype=torch.int8, device=a.device)

triton.set_allocator(alloc_fn)

with torch_device_fn.device(a.device):
mm_kernel_general[grid](
a,
b,
c,
M,
N,
K,
a.stride(0),
a.stride(1),
b.stride(0),
b.stride(1),
c.stride(0),
c.stride(1),
GROUP_M=8,
)
return c


def streamk_scenario(a, b, M, N, K):
# TODO: this may change depending on real benchmark results.
# Currently, the best configuration for stream-K has only been tested on A100 (capability[0] == 8).
# The optimal settings for other devices need to be determined through real testing.
capability = torch_device_fn.get_device_capability(get_device_id())
return (
capability[0] == 8
and a.dtype in [torch.float16, torch.bfloat16]
and b.dtype in [torch.float16, torch.bfloat16]
and K > M * 5
and K > N * 5
)


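# Dispatch: stream-K kernel for deep-K fp16/bf16 problems on compute-capability-8.x devices, general descriptor-based kernel otherwise.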
def mm(a, b):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c_dtype = get_higher_dtype(a.dtype, b.dtype)
c = torch.empty((M, N), device=device, dtype=c_dtype)
# l2_cache_size = get_l2_cache_size()
sm_count = get_sm_count()
if streamk_scenario(a, b, M, N, K):
return streamk_mm(a, b, c, M, N, K, sm_count=sm_count)
else:
return general_mm(a, b, c, M, N, K)


def mm_out(a, b, *, out):
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# l2_cache_size = get_l2_cache_size()
sm_count = get_sm_count()
if streamk_scenario(a, b, M, N, K):
return streamk_mm(a, b, out, M, N, K, sm_count=sm_count)
else:
return general_mm(a, b, out, M, N, K)
48 changes: 48 additions & 0 deletions src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -208,6 +208,54 @@ mm:
BLOCK_K: 64
num_stages: 3
num_warps: 8
- META:
BLOCK_M: 128
BLOCK_N: 256
BLOCK_K: 64
num_stages: 3
num_warps: 8
- META:
BLOCK_M: 256
BLOCK_N: 128
BLOCK_K: 64
num_stages: 3
num_warps: 8
- META:
BLOCK_M: 128
BLOCK_N: 256
BLOCK_K: 32
num_stages: 3
num_warps: 8
- META:
BLOCK_M: 256
BLOCK_N: 128
BLOCK_K: 32
num_stages: 3
num_warps: 8
- META:
BLOCK_M: 128
BLOCK_N: 256
BLOCK_K: 64
num_stages: 4
num_warps: 8
- META:
BLOCK_M: 256
BLOCK_N: 128
BLOCK_K: 64
num_stages: 4
num_warps: 8
- META:
BLOCK_M: 128
BLOCK_N: 256
BLOCK_K: 32
num_stages: 4
num_warps: 8
- META:
BLOCK_M: 256
BLOCK_N: 128
BLOCK_K: 32
num_stages: 4
num_warps: 8
baddbmm:
- META:
TILE_M: 32