Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tilelang/intrinsics/mma_macro_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ class TensorCoreIntrinEmitter:
}

# Represent the thread binding in the form of (tx, warp_n, warp_m)
is_m_first = False
is_m_first: bool = False
warp_rows: int = 1
warp_cols: int = 1

def __init__(
self,
Expand Down
24 changes: 1 addition & 23 deletions tilelang/intrinsics/tcgen05_macro_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,29 +114,7 @@ def _assign_b_shared_layout(self, layout: Layout):
return self

def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
warp_row_tiles = self.warp_row_tiles
warp_col_tiles = self.warp_col_tiles
# For tcgen05, warp_row_tiles is 8 as we can use .ws to support m32
assert warp_row_tiles >= 8, f"warp_row_tiles must be greater than 8, got {warp_row_tiles}"
assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8, got {warp_row_tiles}"

# TODO: We should not apply constraints here. The warp partition is fakely calculated as
# (1, warps) but only one warp will issue the tcgen5mma.
assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}"
assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"

# four warps per block
self.warp_rows = warp_row_tiles // 8
if warp_col_tiles % 16 == 0:
self.n_dim = 16
self.micro_size_y = 16
self.warp_cols = warp_col_tiles // 16
else:
# must be divisible by 8
self.n_dim = 8
self.micro_size_y = 8
self.warp_cols = warp_col_tiles // 8

# tcgen05 doesn't care about warp partitioning
self.micro_size_x = m_dim
self.micro_size_k = k_dim

Expand Down
Loading