diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 4b41eef2a..f94d99100 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -56,7 +56,9 @@ class TensorCoreIntrinEmitter: } # Represent the thread binding in the form of (tx, warp_n, warp_m) - is_m_first = False + is_m_first: bool = False + warp_rows: int = 1 + warp_cols: int = 1 def __init__( self, diff --git a/tilelang/intrinsics/tcgen05_macro_generator.py b/tilelang/intrinsics/tcgen05_macro_generator.py index 00a492cab..95a095264 100644 --- a/tilelang/intrinsics/tcgen05_macro_generator.py +++ b/tilelang/intrinsics/tcgen05_macro_generator.py @@ -114,29 +114,7 @@ def _assign_b_shared_layout(self, layout: Layout): return self def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): - warp_row_tiles = self.warp_row_tiles - warp_col_tiles = self.warp_col_tiles - # For tcgen05, warp_row_tiles is 8 as we can use .ws to support m32 - assert warp_row_tiles >= 8, f"warp_row_tiles must be greater than 8, got {warp_row_tiles}" - assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8, got {warp_row_tiles}" - - # TODO: We should not apply constraints here. The warp partition is fakely calculated as - # (1, warps) but only one warp will issue the tcgen5mma. - assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}" - assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}" - - # four warps per block - self.warp_rows = warp_row_tiles // 8 - if warp_col_tiles % 16 == 0: - self.n_dim = 16 - self.micro_size_y = 16 - self.warp_cols = warp_col_tiles // 16 - else: - # must be divisible by 8 - self.n_dim = 8 - self.micro_size_y = 8 - self.warp_cols = warp_col_tiles // 8 - + # tcgen05 does not use the per-warp partition (only one warp issues the MMA), so warp_rows, warp_cols, n_dim and micro_size_y are no longer derived here self.micro_size_x = m_dim self.micro_size_k = k_dim