Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion aiter/ops/gemm_op_a16w16.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def gen_gemm_a16w16_asm_fake_tensors(
A: Tensor,
B: Tensor,
out: Tensor,
semaphore: Tensor,
bias: Optional[Tensor] = None,
splitK: Optional[int] = None,
kernelName: Optional[str] = None,
Expand All @@ -37,13 +38,19 @@ def gemm_a16w16_asm(
A: Tensor,
B: Tensor,
out: Tensor,
semaphore: Tensor,
bias: Optional[Tensor] = None,
splitK: Optional[int] = None,
kernelName: Optional[str] = None,
bpreshuffle: bool = False,
) -> Tensor: ...


@functools.lru_cache(maxsize=1)
def get_semaphore_workspace(device: torch.device) -> Tensor:
return torch.zeros((16, 64), dtype=torch.uint32, device=device)


def gemm_a16w16(
A: Tensor,
B: Tensor,
Expand All @@ -52,4 +59,5 @@ def gemm_a16w16(
splitK: Optional[int] = None,
kernelName: Optional[str] = None,
):
return gemm_a16w16_asm(A, B, out, bias, splitK, kernelName)
sema = get_semaphore_workspace(out.device)
return gemm_a16w16_asm(A, B, out, bias, sema, splitK, kernelName)
14 changes: 12 additions & 2 deletions aiter/tuned_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
import torch.nn.functional as F
from torch import Tensor

from aiter import dtypes, gemm_a16w16_asm, hipb_create_extension, hipb_mm, logger
from aiter import (
dtypes,
gemm_a16w16_asm,
get_semaphore_workspace,
hipb_create_extension,
hipb_mm,
logger,
)
from aiter.jit.core import AITER_CONFIGS, AITER_LOG_TUNED_CONFIG
from aiter.jit.utils.chip_info import get_cu_num, get_gfx
from aiter.jit.utils.torch_guard import torch_compile_guard
Expand Down Expand Up @@ -392,7 +399,10 @@ def asm_gemm(
out_asm = torch.empty(
inp.shape[0], weights.shape[0], dtype=otype, device=inp.device
)
return gemm_a16w16_asm(inp, weights, out_asm, bias, splitK, KernelName, bpreshuffle)
sema = get_semaphore_workspace(out_asm.device)
return gemm_a16w16_asm(
inp, weights, out_asm, sema, bias, splitK, KernelName, bpreshuffle
)


def triton_gemm(
Expand Down
1 change: 1 addition & 0 deletions csrc/include/asm_gemm_a16w16.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
torch::Tensor gemm_a16w16_asm(torch::Tensor& A, // A:[M, K] bf16
torch::Tensor& B, // B:[N, K] bf16
torch::Tensor& out, // Out:[M, N] f32
torch::Tensor& semaphore,
std::optional<torch::Tensor> bias,
std::optional<int> splitK,
std::optional<std::string> kernelName,
Expand Down
57 changes: 29 additions & 28 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ namespace py = pybind11;
py::arg("A"), \
py::arg("B"), \
py::arg("out"), \
py::arg("semaphore"), \
Comment thread
valarLip marked this conversation as resolved.
py::arg("bias") = std::nullopt, \
py::arg("splitK") = std::nullopt, \
py::arg("kernelName") = std::nullopt, \
Expand Down Expand Up @@ -1537,34 +1538,34 @@ namespace py = pybind11;
#define GEMM_COMMON_PYBIND \
m.def("get_padded_m", &getPaddedM, py::arg("M"), py::arg("N"), py::arg("K"), py::arg("gl"));

#define TOP_K_PER_ROW_PYBIND \
m.def("top_k_per_row_prefill", \
&top_k_per_row_prefill, \
py::arg("logits"), \
py::arg("rowStarts"), \
py::arg("rowEnds"), \
py::arg("indices"), \
py::arg("values"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1")); \
m.def("top_k_per_row_decode", \
&top_k_per_row_decode, \
py::arg("logits"), \
py::arg("next_n"), \
py::arg("seqLens"), \
py::arg("indices"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1")); \
m.def("top_k_per_row_decode_fast", \
&top_k_per_row_decode_fast, \
py::arg("logits"), \
py::arg("next_n"), \
py::arg("seqLens"), \
py::arg("indices"), \
py::arg("numRows"), \
py::arg("stride0"), \
#define TOP_K_PER_ROW_PYBIND \
m.def("top_k_per_row_prefill", \
&top_k_per_row_prefill, \
py::arg("logits"), \
py::arg("rowStarts"), \
py::arg("rowEnds"), \
py::arg("indices"), \
py::arg("values"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1")); \
m.def("top_k_per_row_decode", \
&top_k_per_row_decode, \
py::arg("logits"), \
py::arg("next_n"), \
py::arg("seqLens"), \
py::arg("indices"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1")); \
m.def("top_k_per_row_decode_fast", \
&top_k_per_row_decode_fast, \
py::arg("logits"), \
py::arg("next_n"), \
py::arg("seqLens"), \
py::arg("indices"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1"));

#define MLA_METADATA_PYBIND \
Expand Down
Loading