Merged
65 changes: 39 additions & 26 deletions examples/plot_layout/fragment_mma_load_a.py
@@ -17,53 +17,66 @@ def make_mma_load_base_layout(dtype: str = "float16",
----------
dtype : str
The data type of the matrix.
local_buf : tir.Buffer
The local buffer representing a fragment of a matrix.
matrix : Literal["A", "B"]
The mma operand to be loaded.
transposed : bool
Whether the matrix is transposed, by default False.

Returns
-------
T.Fragment
-        A fragment object that describes how threads and indices
-        in `local_buf` are laid out.
+        Describes how threads and indices in fragment are laid out.

Raises
------
AssertionError
If `local_buf` is not detected to be a fragment buffer.
"""
     from tilelang.intrinsics.mma_layout import (
-        shared_16x16_to_mma_32x8_layout_sr,
-        shared_16x16_to_mma_32x8_layout_rs,
-        shared_16x32_to_mma_32x16_layout,
-        shared_32x16_to_mma_32x16_layout,
+        shared_16x8_to_mma_32x4_layout_sr_a,
+        shared_16x16_to_mma_32x8_layout_sr_a,
+        shared_16x32_to_mma_32x16_layout_sr_a,
+        shared_16x8_to_mma_32x4_layout_sr_b,
+        shared_16x16_to_mma_32x8_layout_sr_b,
+        shared_16x32_to_mma_32x16_layout_sr_b,
     )
assert matrix in ["A", "B"], "matrix should be either A or B"
dtype_bits = DataType(dtype).bits
assert transposed is False, "transposed is not supported yet"
# s represents spatial axis
# r represents reduction axis
# sr represents the two dims are spatial + reduction
# rs represents the two dims are reduction + spatial
-    transform_func_sr: Callable = None
-    transform_func_rs: Callable = None
-    if dtype_bits == 16:
-        transform_func_sr = shared_16x16_to_mma_32x8_layout_sr
-        transform_func_rs = shared_16x16_to_mma_32x8_layout_rs
+    transform_func_sr_a: Callable = None
+    transform_func_sr_b: Callable = None
+    if dtype_bits == 32:
+        transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
+        transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
+    elif dtype_bits == 16:
+        transform_func_sr_a = shared_16x16_to_mma_32x8_layout_sr_a
+        transform_func_sr_b = shared_16x16_to_mma_32x8_layout_sr_b
     elif dtype_bits == 8:
-        transform_func_sr = shared_16x32_to_mma_32x16_layout
-        transform_func_rs = shared_32x16_to_mma_32x16_layout
+        transform_func_sr_a = shared_16x32_to_mma_32x16_layout_sr_a
+        transform_func_sr_b = shared_16x32_to_mma_32x16_layout_sr_b
     else:
         raise ValueError(f"Unsupported dtype {dtype}")

is_sr_conditions = [False]
is_sr_conditions.append(matrix == "A" and not transposed)
is_sr_conditions.append(matrix == "B" and transposed)
is_sr_axis_order = any(is_sr_conditions)

-    transform_func: Callable = transform_func_sr if is_sr_axis_order else transform_func_rs

-    micro_size_s, _, micro_size_r = get_mma_micro_size(dtype)
+    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)

+    # the layout of mma.sync is row.col,
+    # so the B matrix expects a transposed basic layout
+    transform_func: Callable = None
+    if matrix == "A":
+        transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
+            j, i)
+        micro_size_s, micro_size_r = micro_size_x, micro_size_k
+    elif matrix == "B":
+        transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
+            j, i)
+        micro_size_s, micro_size_r = micro_size_k, micro_size_y
Comment on lines +73 to +76
⚠️ Potential issue | 🔴 Critical

Fix B micro-size mapping (s vs r): currently reversed.

For matrix "B", spatial should be y and reduction should be k. Using s=k and r=y breaks 8-bit shapes (which expect a 16x32 sr tile) and can misalign with shared_16x32_to_mma_32x16_layout_sr_b. Apply:

     elif matrix == "B":
         transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
             j, i)
-        micro_size_s, micro_size_r = micro_size_k, micro_size_y
+        micro_size_s, micro_size_r = micro_size_y, micro_size_k

Based on the relevant code snippet: get_mma_micro_size defines (x, y, k), and B’s s=y, r=k.

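To see concretely why fp16 hides the reversal while 8-bit exposes it, here is a small self-contained sketch. The assumed_mma_micro_size helper is hypothetical: it assumes get_mma_micro_size returns (x, y, k) = (16, 16, 256 // dtype_bits), which is consistent with the shared_16x8 / shared_16x16 / shared_16x32 layout names above but is not the library's actual implementation.

    # Hypothetical stand-in for get_mma_micro_size; assumes (x, y, k) =
    # (16, 16, 256 // dtype_bits), matching the shared_16xK layout names.
    def assumed_mma_micro_size(dtype_bits: int):
        return 16, 16, 256 // dtype_bits

    micro_size_x, micro_size_y, micro_size_k = assumed_mma_micro_size(8)  # 8-bit case

    # The reversed mapping for B (s=k, r=y) produces a 32x16 tile ...
    buggy_tile = (micro_size_k, micro_size_y)   # (32, 16)
    # ... but shared_16x32_to_mma_32x16_layout_sr_b indexes a 16x32 (s, r) tile.
    fixed_tile = (micro_size_y, micro_size_k)   # (16, 32)
    assert buggy_tile != fixed_tile

    # For fp16 both orderings collapse to (16, 16), so the bug goes unnoticed there.
    _, y16, k16 = assumed_mma_micro_size(16)
    assert (k16, y16) == (y16, k16) == (16, 16)

Run as-is, the asserts pass: only the non-16-bit paths produce a tile whose shape disagrees with the sr layout it feeds.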

else:
raise ValueError(f"Unsupported matrix {matrix}")

inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

def forward_thread(i: int, j: int) -> int:
@@ -81,7 +94,7 @@ def forward_index(i: int, j: int) -> int:
return local_id

     base_fragment = T.Fragment(
-        [micro_size_r, micro_size_s],
+        [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
         forward_thread_fn=forward_thread,
         forward_index_fn=forward_index,
     )
@@ -109,4 +122,4 @@ def forward_index(i: int, j: int) -> int:
# block layout 128x32
block_layout = warp_layout.repeat([warp_rows, chunk], repeat_on_thread=False, lower_dim_first=False)
print(block_layout)
-# plot_layout(block_layout, name="block_layout")
+plot_layout(block_layout, name="block_layout")
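For readers reproducing the figures, a minimal driver might look like the sketch below. It relies only on names visible in this diff (make_mma_load_base_layout and plot_layout); the printed representation and the saved filename are assumptions, not documented behavior.

    # Minimal usage sketch, assuming the example's API as shown in this diff.
    base_a = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
    print(base_a)  # expected: the 16x16 base fragment for operand A
    plot_layout(base_a, name="base_layout_a")  # assumed to render and save the figure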
1 change: 0 additions & 1 deletion tilelang/intrinsics/mma_macro_generator.py
@@ -490,7 +490,6 @@ def make_mma_load_layout(self,
         transform_func_sr_a: Callable = None
         transform_func_sr_b: Callable = None
         if dtype_bits == 32:
-            ...
             transform_func_sr_a = shared_16x8_to_mma_32x4_layout_sr_a
             transform_func_sr_b = shared_16x8_to_mma_32x4_layout_sr_b
         elif dtype_bits == 16: