Merged
144 changes: 134 additions & 10 deletions python/tvm/dlight/gpu/matmul.py
@@ -27,7 +27,7 @@
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV

from ..base import analysis
from ..base import analysis, BlockInfo, IterInfo
from .base import GPUScheduleRule


@@ -273,6 +273,32 @@ def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]:
)


def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) -> BlockInfo:
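    """Summarize a block's iterators (kind, var, extent, loop RV) and whether it performs a reduction."""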
def _iter_kind(loop: tir.IterVar) -> str:
return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O")

def _is_reduction_block(block: tir.schedule.BlockRV):
for iter_var in sch.get(block).iter_vars:
if _iter_kind(iter_var) == "R":
return True
return False

return BlockInfo(
name=sch.get(block).name_hint,
iters=[
IterInfo(
kind=_iter_kind(iter_var),
var=iter_var.var,
dom=iter_var.dom.extent,
loop_rv=loop_rv,
)
for loop_rv, iter_var in zip(sch.get_loops(block), sch.get(block).iter_vars)
],
block_rv=block,
reduction_block=_is_reduction_block(block),
)
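
# For illustration, a minimal sketch (not part of the change) of what the helper reports
# for a plain NN-layout matmul; the TVMScript kernel below is hypothetical and assumes
# get_block_info is importable from this module.
#
#     import tvm
#     from tvm.script import tir as T
#
#     @T.prim_func
#     def _demo_matmul(a: T.handle, b: T.handle, c: T.handle):
#         A = T.match_buffer(a, (128, 128), "float32")
#         B = T.match_buffer(b, (128, 128), "float32")
#         C = T.match_buffer(c, (128, 128), "float32")
#         for i, j, k in T.grid(128, 128, 128):
#             with T.block("matmul"):
#                 vi, vj, vk = T.axis.remap("SSR", [i, j, k])
#                 with T.init():
#                     C[vi, vj] = T.float32(0)
#                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
#
#     _sch = tvm.tir.Schedule(_demo_matmul)
#     _info = get_block_info(_sch, _sch.get_block("matmul"))
#     print([it.kind for it in _info.iters])  # expected: ['S', 'S', 'R']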


def get_reduction_blocks(sch, blocks) -> bool:
# Get the main computation block
def is_reduction(block: BlockRV) -> bool:
@@ -914,17 +940,19 @@ def get_configs(self, target: Target) -> Config:
storage_align=True,
inner_x=False,
)
elif target.kind.name == "opencl" and "android" in str(target.host):
elif target.kind.name == "opencl" and (
("android" in str(target.host)) or ("windows" in str(target.host))
):
return Matmul.Config(
block_size_x=8,
block_size_y=16,
block_size_x=32,
block_size_y=8,
vthread_x=1,
vthread_y=1,
micro_size_x=8,
micro_size_y=2,
micro_size_k=16,
vector_size=8,
unroll=64,
unroll=4,
use_shared=False,
storage_align=False,
inner_x=True,
@@ -941,6 +969,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
return None
sch = tir.Schedule(func)
config = self.get_configs(target)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)

@@ -953,9 +982,22 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
index_maps = get_index_map(block_stmt)
if index_maps is None:
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

main_block_info = get_block_info(sch, main_block)
iter_infos = main_block_info.iters

# Check whether this is an inner reduction by inspecting the innermost index of the last read buffer
def is_inner_reduction(block_stmt, iter_infos):
end_it = block_stmt.reads[-1].region[-1].min
return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == "R"

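# If the innermost index of the last read buffer is the reduction axis (e.g. B[j, k] in an
# NT layout), the reduction is "inner" and the default schedule below is used; if it is a
# spatial axis (e.g. B[k, j]), the reduction is "outer" for that buffer and, on OpenCL
# targets, the dedicated sch_outer_reduction path is tried first.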
if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, iter_infos):
ret = self.sch_outer_reduction(sch, config, main_block, blocks)
if ret is not None:
return ret

# Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
@@ -994,10 +1036,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
except: # pylint: disable=bare-except
pass

# Step 2. Get schedule config.
config = self.get_configs(target)

# Step 3. Schedule matmul
# Step 2. Schedule matmul
y_kernel_size = config.vthread_y * config.block_size_y * config.micro_size_y
x_kernel_size = config.vthread_x * config.block_size_x * config.micro_size_x
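# e.g. with the new OpenCL config above: x = 1 * 32 * 8 = 256, y = 1 * 8 * 2 = 16 per block tile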
if config.inner_x:
@@ -1075,3 +1114,88 @@ def _cooperative_fetch(index, vec_len):

sch.decompose_reduction(main_block, ko)
return sch

def sch_outer_reduction(
self,
sch: tir.Schedule,
config: Config,
reduction_block: tir.schedule.BlockRV,
blocks: List[tir.schedule.BlockRV],
) -> Optional[tir.Schedule]:
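"""Schedule used on OpenCL targets when the matmul's reduction axis is not the innermost
index of the last read buffer (outer reduction); returns None if the loop structure
does not match."""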
reduction_loops = sch.get_loops(reduction_block)
if len(reduction_loops) != 4:
return None

mb, ms, n, k = reduction_loops
if not (
isinstance(sch.get(n).extent, tir.IntImm)
and isinstance(sch.get(mb).extent, tir.IntImm)
and isinstance(sch.get(ms).extent, tir.Var)
):
return None

Threads_X, Threads_Y, VecSize, Unroll_M = (
config.block_size_x,
config.block_size_y,
config.vector_size,
config.unroll,
)

is_dequant_block = len(blocks) > 1
if is_dequant_block:
compute_block, dequant_block, matmul_block = blocks
sch.compute_inline(compute_block)
else:
(matmul_block,) = blocks

m = sch.fuse(mb, ms)

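# Pad so that the fused M dimension and the N dimension split evenly into the
# thread / unroll / vector tiles below.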
sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * VecSize, 1])

rmat_block, wmat_block = (
sch.get_producers(matmul_block)[0],
sch.get_consumers(matmul_block)[0],
)
mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M])
no, ni, nv = sch.split(n, [None, Threads_X, VecSize])
k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 8])
sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)
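# Resulting nest (outer to inner): no, mo -> blockIdx; ni, mi -> threadIdx; k0..k3 tile the
# reduction; mu, nv form the per-thread micro-tile (nv is vectorized below).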

sch.compute_at(rmat_block, k0)
if is_dequant_block:
sch.compute_at(dequant_block, k3)
sch.reverse_compute_at(wmat_block, mi)
sch.set_scope(rmat_block, 0, "shared")
sch.set_scope(matmul_block, 0, "local")
if is_dequant_block:
sch.set_scope(dequant_block, 0, "local")

sch.bind(mo, "blockIdx.y")
sch.bind(no, "blockIdx.x")
sch.bind(mi, "threadIdx.y")
sch.bind(ni, "threadIdx.x")
sch.vectorize(sch.get_loops(matmul_block)[-1])
if is_dequant_block:
sch.vectorize(sch.get_loops(dequant_block)[-1])

# Co-operative Memory Fetch
ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize])
sch.bind(ro, "threadIdx.x")
sch.vectorize(rv)

wv = sch.get_loops(wmat_block)[-1]
sch.vectorize(wv)

# Scale and Quant Cache
if is_dequant_block:
qb = sch.cache_read(dequant_block, 0, "local")
sb = sch.cache_read(dequant_block, 1, "local")
sch.compute_at(sb, k1)
sch.compute_at(qb, k2)
sch.set_scope(sb, 0, "local")
sch.set_scope(qb, 0, "local")
sch.vectorize(sch.get_loops(qb)[-1])
sch.vectorize(sch.get_loops(sb)[-1])

sch.decompose_reduction(matmul_block, k0)
return sch
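
# For context, a minimal usage sketch (not part of the change): "mod" is assumed to be an
# IRModule containing a matmul PrimFunc, and the Adreno-style OpenCL target is only an
# example that exercises the new config branch.
#
#     import tvm
#     from tvm import dlight as dl
#
#     target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-android")
#     with target:
#         mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)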