diff --git a/examples/plot_layout/fragment_mma_load_a.py b/examples/plot_layout/fragment_mma_load_a.py index b203bc30e..988899448 100644 --- a/examples/plot_layout/fragment_mma_load_a.py +++ b/examples/plot_layout/fragment_mma_load_a.py @@ -17,53 +17,66 @@ def make_mma_load_base_layout(dtype: str = "float16", ---------- dtype : str The data type of the matrix. - local_buf : tir.Buffer - The local buffer representing a fragment of a matrix. + matrix : Literal["A", "B"] + The mma operand to be loaded. + transposed : bool + Whether the matrix is transposed, by default False. Returns ------- T.Fragment - A fragment object that describes how threads and indices - in `local_buf` are laid out. + Describes how threads and indices in fragment are laid out. - Raises - ------ - AssertionError - If `local_buf` is not detected to be a fragment buffer. """ from tilelang.intrinsics.mma_layout import ( - shared_16x16_to_mma_32x8_layout_sr, - shared_16x16_to_mma_32x8_layout_rs, - shared_16x32_to_mma_32x16_layout, - shared_32x16_to_mma_32x16_layout, + shared_16x8_to_mma_32x4_layout_sr_a, + shared_16x16_to_mma_32x8_layout_sr_a, + shared_16x32_to_mma_32x16_layout_sr_a, + shared_16x8_to_mma_32x4_layout_sr_b, + shared_16x16_to_mma_32x8_layout_sr_b, + shared_16x32_to_mma_32x16_layout_sr_b, ) assert matrix in ["A", "B"], "matrix should be either A or B" dtype_bits = DataType(dtype).bits - assert transposed is False, "transposed is not supported yet" # s represents spatial axis # r represents reduction axis # sr represents the two dims are spatial + reduction # rs represents the two dims are reduction + spatial - transform_func_sr: Callable = None - transform_func_rs: Callable = None - if dtype_bits == 16: - transform_func_sr = shared_16x16_to_mma_32x8_layout_sr - transform_func_rs = shared_16x16_to_mma_32x8_layout_rs + transform_func_sr_a: Callable = None + transform_func_sr_b: Callable = None + if dtype_bits == 32: + transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a + transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b + elif dtype_bits == 16: + transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a + transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b elif dtype_bits == 8: - transform_func_sr = shared_16x32_to_mma_32x16_layout - transform_func_rs = shared_32x16_to_mma_32x16_layout + transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a + transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b else: raise ValueError(f"Unsupported dtype {dtype}") + is_sr_conditions = [False] is_sr_conditions.append(matrix == "A" and not transposed) is_sr_conditions.append(matrix == "B" and transposed) is_sr_axis_order = any(is_sr_conditions) - transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs - - micro_size_s, _, micro_size_r = get_mma_micro_size(dtype) + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype) + + # the layout of mma.sync is row.col. + # so the b matrix expected a transposed basic layout + transform_func: Callable = None + if matrix == "A": + transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( + j, i) + micro_size_s, micro_size_r = micro_size_x, micro_size_k + elif matrix == "B": + transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b( + j, i) + micro_size_s, micro_size_r = micro_size_k, micro_size_y + else: + raise ValueError(f"Unsupported matrix {matrix}") - transform_func = transform_func inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32") def forward_thread(i: int, j: int) -> int: @@ -81,7 +94,7 @@ def forward_index(i: int, j: int) -> int: return local_id base_fragment = T.Fragment( - [micro_size_r, micro_size_s], + [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s], forward_thread_fn=forward_thread, forward_index_fn=forward_index, ) @@ -109,4 +122,4 @@ def forward_index(i: int, j: int) -> int: # block layout 128x32 block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False) print(block_layout) -# plot_layout(block_layout, name="block_layout") +plot_layout(block_layout, name="block_layout") diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index cb999ac41..65d2ab0ca 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -490,7 +490,6 @@ def make_mma_load_layout(self, transform_func_sr_a: Callable = None transform_func_sr_b: Callable = None if dtype_bits == 32: - ... transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b elif dtype_bits == 16: