Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/py_itfs_cu/asm_gemm_a16w16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ torch::Tensor gemm_a16w16_asm(torch::Tensor& A,
int gdy = (Mdim + SUBM - 1) / SUBM;
int gdz = selectedksplit;

TORCH_CHECK(gdx <= 16, __func__, " gdx (", gdx, ") must be <= 16"); // 16 = 512/32
TORCH_CHECK(gdy <= 16, __func__, " gdy (", gdy, ") must be <= 16"); // 16 = 512/32

// semaphore.fill_(selectedksplit);
args.ptr_semaphore = (void*)semaphore.data_ptr<uint32_t>();
Expand Down
6 changes: 4 additions & 2 deletions gradlib/gradlib/GemmTuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import torch.nn.functional as F

import aiter
from aiter import dtypes, logger
from aiter import dtypes, gemm_a16w16_asm, get_semaphore_workspace, logger
Comment thread
amd-ruitang3 marked this conversation as resolved.
Outdated
from aiter.jit.core import AITER_CONFIG_GEMM_BF16, get_asm_dir
from aiter.jit.utils.chip_info import get_cu_num, get_gfx
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 as triton_gemm_a16w16
from aiter.utility.base_tuner import GemmCommonTuner
from aiter.utility.mp_tuner import mp_tuner
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 as triton_gemm_a16w16

aiter.hipb_create_extension()

Expand Down Expand Up @@ -59,10 +59,12 @@ def call_hipb_mm(
def run_gemm_bf16_asm(
inp, w, out, bias=None, splitK=None, kernelName=None, bpreshuffle=False
):
sema = get_semaphore_workspace(inp.device)
return aiter.gemm_a16w16_asm(
inp,
w,
out,
sema,
bias=bias,
splitK=splitK,
kernelName=kernelName,
Expand Down
Loading