Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions python/tvm/dlight/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,27 +219,31 @@ def get_extent(loop_rv: tir.schedule.LoopRV):
else:
len_ty = min(len_s, 4)

# Use `split_k` to prevent too large shared memory usage
split_k: int = 4

_, tx = sch.split(r, [None, len_tx], preserve_unit_iters=True)
# Schedule the RF block
rf = sch.rfactor(tx, 0)
batch, bx, r, tx, _ = sch.get_loops(rf)
sch.reorder(bx, tx, r)
ro, ri = sch.split(r, [split_k, None], preserve_unit_iters=True)
bx, ty = sch.split(bx, [None, len_ty], preserve_unit_iters=True)

sch.bind(batch, "blockIdx.y")
sch.bind(bx, "blockIdx.x")
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
unit = sch.add_unit_loop(r)
sch.annotate(unit, "pragma_auto_unroll_max_step", unroll_number)
sch.annotate(unit, "pragma_unroll_explicit", 1)
sch.annotate(ro, "pragma_auto_unroll_max_step", unroll_number)
sch.annotate(ro, "pragma_unroll_explicit", 1)

if target.kind.name == "cuda":
# Cache read the vector
def cache_shared(index: int):
block: tir.Block = sch.get(rf)
type_bytes: int = get_bytes(block.reads[index].buffer.dtype)
cache = sch.cache_read(rf, index, "shared")
sch.compute_at(cache, unit, preserve_unit_loops=True)
sch.compute_at(cache, ro, preserve_unit_loops=True)
fused = sch.fuse(*sch.get_loops(cache)[5:])
loop: tir.For = sch.get(fused)
vec_length = vec_bytes // type_bytes
Expand All @@ -256,7 +260,7 @@ def cache_local(index: int):
type_bytes: int = get_bytes(block.reads[index].buffer.dtype)
vec_length = vec_bytes // type_bytes
cache = sch.cache_read(rf, index, "local")
sch.compute_at(cache, r, preserve_unit_loops=True)
sch.compute_at(cache, ri, preserve_unit_loops=True)
fused = sch.fuse(*sch.get_loops(cache)[6:])
loop: tir.For = sch.get(fused)
if isinstance(loop.extent, tir.IntImm) and loop.extent.value % vec_length == 0:
Expand All @@ -273,7 +277,7 @@ def cache_local(index: int):
# TODO: cache scale buffer in Decode-GEMV to shared memory

sch.set_scope(rf, 0, "local")
sch.decompose_reduction(rf, r)
sch.decompose_reduction(rf, ro)
# Schedule the write back block
sch.reverse_compute_at(block, ty, preserve_unit_loops=True)
_, _, _, tx, *s = sch.get_loops(block)
Expand Down
80 changes: 41 additions & 39 deletions tests/python/dlight/test_gpu_gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,40 +97,40 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p
for ax1_fused_0 in T.thread_binding(n, thread="blockIdx.x"):
for ax1_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
for ax2_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
for u in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.block("NT_matmul_rf_init"):
vax2_fused_1, v0 = T.axis.remap("SS", [ax2_fused_1, ax0_fused])
v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0)
for ax2_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_ax1_ax2_ax3_fused_0 in range(1):
for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
for ax0_ax1_ax2_ax3_fused_3 in T.vectorized(4):
for ax0_ax1_ax2_ax3_fused_3 in T.vectorized(1):
with T.block("lv1637_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(32, ax0_fused)
v2 = T.axis.spatial(1, 0)
v3 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused_0 * 128 + ax0_ax1_ax2_ax3_fused_1 * 128 + ax0_ax1_ax2_ax3_fused_2 * 4 + ax0_ax1_ax2_ax3_fused_3)
v3 = T.axis.spatial(128, ax2_fused_0_0 * 32 + ax0_ax1_ax2_ax3_fused_0 * 32 + ax0_ax1_ax2_ax3_fused_1 * 32 + ax0_ax1_ax2_ax3_fused_2 + ax0_ax1_ax2_ax3_fused_3)
T.reads(lv1637[v0, v1, v2, v3])
T.writes(lv1637_shared[v0, v1, v2, v3])
lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3]
with T.block("NT_matmul_rf_init"):
vax2_fused_1, v0 = T.axis.remap("SS", [ax2_fused_1, ax0_fused])
v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0)
for ax2_fused_0 in range(4):
for ax2_fused_0_1 in range(1):
for ax0_ax1_ax2_ax3_fused in T.vectorized(1):
with T.block("lv1637_shared_local"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(32, ax0_fused)
v2 = T.axis.spatial(1, 0)
v3 = T.axis.spatial(128, ax2_fused_0 * 32 + ax2_fused_1)
v3 = T.axis.spatial(128, ax2_fused_0_0 * 32 + ax2_fused_1)
T.reads(lv1637_shared[v0, v1, v2, v3])
T.writes(lv1637_shared_local[v0, v1, v2, v3])
lv1637_shared_local[v0, v1, v2, v3] = lv1637_shared[v0, v1, v2, v3]
for u_1 in range(1):
for u in range(1):
with T.block("NT_matmul_rf_update"):
vax2_fused_1, v0 = T.axis.remap("SS", [ax2_fused_1, ax0_fused])
v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1)
vax2_fused_0 = T.axis.reduce(4, ax2_fused_0)
vax2_fused_0 = T.axis.reduce(4, ax2_fused_0_0 + ax2_fused_0_1)
T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1], lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1], lv1638[0, v0, v1, vax2_fused_0 * 32 + vax2_fused_1])
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] + lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1] * lv1638[0, v0, v1, vax2_fused_0 * 32 + vax2_fused_1]
Expand Down Expand Up @@ -186,39 +186,40 @@ def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 12
for ax0_fused_0 in T.thread_binding(2752, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
for ax1_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
for u in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_ax1_ax2_fused_0 in range(2):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
for ax1_0_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_ax1_ax2_fused_0 in range(1):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_3 in T.vectorized(8):
for ax0_ax1_ax2_fused_3 in T.vectorized(4):
with T.block("lv1654_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(1, 0)
v2 = T.axis.spatial(4096, ax0_ax1_ax2_fused_0 * 2048 + ax0_ax1_ax2_fused_1 * 256 + ax0_ax1_ax2_fused_2 * 8 + ax0_ax1_ax2_fused_3)
v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128 + ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3)
T.reads(lv1654[v0, v1, v2])
T.writes(lv1654_shared[v0, v1, v2])
lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2]
with T.block("NT_matmul_rf_init"):
vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
for ax1_0_fused_0 in range(16):
for ax1_0_fused_0_1 in range(4):
for ax0_ax1_ax2_fused_0 in range(1):
for ax0_ax1_ax2_fused_1 in T.vectorized(8):
with T.block("lv1654_shared_local"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(1, 0)
v2 = T.axis.spatial(4096, ax1_0_fused_0 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
T.reads(lv1654_shared[v0, v1, v2])
T.writes(lv1654_shared_local[v0, v1, v2])
lv1654_shared_local[v0, v1, v2] = lv1654_shared[v0, v1, v2]
for ax1_1 in range(8):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1)
vax1_0_fused_0, vax1_1 = T.axis.remap("RR", [ax1_0_fused_0, ax1_1])
vax1_0_fused_0 = T.axis.reduce(16, ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1)
vax1_1 = T.axis.reduce(8, ax1_1)
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1], lv571[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv572[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
Expand Down Expand Up @@ -278,39 +279,40 @@ def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 12
for ax0_fused_0 in T.thread_binding(4000, thread="blockIdx.x"):
for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
for ax1_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
for u in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_ax1_ax2_fused_0 in range(2):
with T.block("NT_matmul_rf_init"):
vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
for ax1_0_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax0_ax1_ax2_fused_0 in range(1):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_2 in T.thread_binding(32, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_3 in T.vectorized(8):
for ax0_ax1_ax2_fused_3 in T.vectorized(4):
with T.block("lv3216_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(1, 0)
v2 = T.axis.spatial(4096, ax0_ax1_ax2_fused_0 * 2048 + ax0_ax1_ax2_fused_1 * 256 + ax0_ax1_ax2_fused_2 * 8 + ax0_ax1_ax2_fused_3)
v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128 + ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3)
T.reads(lv3216[v0, v1, v2])
T.writes(lv3216_shared[v0, v1, v2])
lv3216_shared[v0, v1, v2] = lv3216[v0, v1, v2]
with T.block("NT_matmul_rf_init"):
vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1)
T.reads()
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
for ax1_0_fused_0 in range(16):
for ax1_0_fused_0_1 in range(4):
for ax0_ax1_ax2_fused_0 in range(1):
for ax0_ax1_ax2_fused_1 in T.vectorized(8):
with T.block("lv3216_shared_local"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(1, 0)
v2 = T.axis.spatial(4096, ax1_0_fused_0 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
T.reads(lv3216_shared[v0, v1, v2])
T.writes(lv3216_shared_local[v0, v1, v2])
lv3216_shared_local[v0, v1, v2] = lv3216_shared[v0, v1, v2]
for ax1_1 in range(8):
with T.block("NT_matmul_rf_update"):
vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1)
vax1_0_fused_0, vax1_1 = T.axis.remap("RR", [ax1_0_fused_0, ax1_1])
vax1_0_fused_0 = T.axis.reduce(16, ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1)
vax1_1 = T.axis.reduce(8, ax1_1)
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1], lv771[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv772[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
Expand Down