-
Notifications
You must be signed in to change notification settings - Fork 452
[Tool] Provide layout visualization tool #1353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b59e6e8
3f8ab91
70519cd
67f27dd
695009b
79b0914
00feb64
3495e81
18c93c4
360854b
bb9d49f
eac27c9
fd25d67
86eb08a
2da20b0
03dc6b8
5d075f2
7b1532d
581beab
376b63c
cf9b3c7
c1230a4
9282abd
d897a95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
|
|
||
| # use pass_configs to enable layout visualization | ||
| @tilelang.jit( | ||
| out_idx=[-1], | ||
| pass_configs={ | ||
| tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True, | ||
| tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg" | ||
| }) | ||
| def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): | ||
|
|
||
| @T.prim_func | ||
| def gemm( | ||
| A: T.Tensor((M, K), dtype), | ||
| B: T.Tensor((K, N), dtype), | ||
| C: T.Tensor((M, N), dtype), | ||
| ): | ||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): | ||
| A_shared = T.alloc_shared((block_M, block_K), dtype) | ||
| B_shared = T.alloc_shared((block_K, block_N), dtype) | ||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||
|
|
||
| T.clear(C_local) | ||
| for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): | ||
| T.copy(A[by * block_M, k * block_K], A_shared) | ||
| T.copy(B[k * block_K, bx * block_N], B_shared) | ||
| T.gemm(A_shared, B_shared, C_local) | ||
|
|
||
| T.copy(C_local, C[by * block_M, bx * block_N]) | ||
|
|
||
| return gemm | ||
|
|
||
|
|
||
| def main(): | ||
| kernel = matmul(128, 128, 128, 32, 32, 32) | ||
|
|
||
| import torch | ||
|
|
||
| a = torch.randn(128, 128).cuda().half() | ||
| b = torch.randn(128, 128).cuda().half() | ||
|
|
||
| c = kernel(a, b) | ||
|
|
||
| ref_c = a @ b | ||
|
|
||
| torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) | ||
| print("All check passed.") | ||
|
|
||
| # print the layout visualization result and save figures to ./tmp. | ||
| ''' | ||
| C_local inferenced layout: | ||
| Shape: [32, 32] -> [8] | ||
| Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 | ||
| Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] | ||
| ''' | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| import tilelang.language as T | ||
| from tvm import tir | ||
| from tvm.tir import PyStmtExprVisitor | ||
|
|
||
| from tvm.tir.transform import prim_func_pass | ||
| from tilelang.tools.plot_layout import plot_layout | ||
|
|
||
|
|
||
| def print_fragment_format(layout: T.Fragment) -> str: | ||
| """ | ||
| Format fragment layout information into a human-readable string. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| layout : T.Fragment | ||
| The fragment layout to format | ||
|
|
||
| Returns | ||
| ------- | ||
| str | ||
| Formatted string showing shape, thread mapping, and index mapping | ||
| """ | ||
| if isinstance(layout, T.Fragment): | ||
| input_shape = layout.get_input_shape() | ||
| output_shape = layout.get_output_shape() | ||
| lines = [ | ||
| f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", | ||
| f" Index: {layout.forward_index}" | ||
| ] | ||
| print("\n".join(lines)) | ||
| else: | ||
| raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") | ||
|
|
||
|
|
||
| @tir.functor.visitor | ||
| class _LayoutVisualVisitor(PyStmtExprVisitor): | ||
| """ | ||
| User-friendly pass which visualizes fragment layouts inferred during compilation. | ||
|
|
||
| In TileLang, Fragment layouts describe: | ||
| - How logical indices (e.g., [i, j]) map to thread IDs | ||
| - How logical indices map to register file locations within each thread | ||
| - The shape transformation from input dimensions to output dimensions | ||
|
|
||
| This pass generates two types of output: | ||
| 1. Textual output: A human-readable description printed to console | ||
| 2. Visual diagrams: Color-coded plots saved to files (PDF, PNG, SVG formats) | ||
|
|
||
| Configuration: | ||
| The pass is controlled by the TL_ENABLE_LAYOUT_VISUALIZATION configuration option. | ||
| The configuration accepts string values: | ||
|
|
||
| - Empty string or not set: Pass does nothing (default, disabled) | ||
| - "png": Generate PNG format only (recommended for quick inspection) | ||
| - "pdf": Generate PDF format only (recommended for documentation) | ||
| - "svg": Generate SVG format only (recommended for web/vector graphics) | ||
| - "all": Generate all formats (PDF, PNG, SVG) | ||
| - "png,svg": Generate multiple formats (comma-separated) | ||
| """ | ||
|
|
||
| def __init__(self, formats: list[str] = ""): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: Fix type annotation inconsistency. Line 61 declares Additionally, there's a critical type mismatch with the caller in To resolve this, decide whether
If - def __init__(self, formats: list[str] = ""):
+ def __init__(self, formats: list[str] | None = None):
super().__init__()
self.layout_found = []
self.processed_layouts = set()
- self.formats_list = [f for f in formats if f != "txt"]
+ self.formats_list = [f for f in (formats or []) if f != "txt"]And update the factory at line 84: -def LayoutVisual(formats: str = ""):
+def LayoutVisual(formats: list[str] | None = None):
🤖 Prompt for AI Agents |
||
| super().__init__() | ||
| self.layout_found = [] | ||
| self.processed_layouts = set() | ||
| self.formats_list = [f for f in formats if f != "txt"] | ||
|
|
||
| def visit_block_(self, op: tir.Block) -> None: | ||
| if "layout_map" in op.annotations: | ||
| layout_map = op.annotations["layout_map"] | ||
|
|
||
| for key, layout in layout_map.items(): | ||
| if isinstance(layout, T.Fragment): | ||
| layout_id = str(layout) | ||
| if layout_id not in self.processed_layouts: | ||
| print(f"{key} inferenced layout:") | ||
| print_fragment_format(layout) | ||
| for fmt in self.formats_list: | ||
| plot_layout(layout, name=f"{key}_layout", formats=fmt) | ||
| self.processed_layouts.add(layout_id) | ||
|
|
||
| # super().visit_block_(op) | ||
|
|
||
|
|
||
| def LayoutVisual(formats: str = ""): | ||
|
|
||
| def pass_fn(func: tir.PrimFunc, mod, ctx): | ||
| _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) | ||
| return func | ||
|
|
||
| return prim_func_pass(pass_fn, opt_level=0) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be better to add some comments for this file |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,6 +67,48 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: | |
| return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False)) | ||
|
|
||
|
|
||
| def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: | ||
| if pass_ctx is None: | ||
| pass_ctx = tilelang.transform.get_pass_context() | ||
| enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False) | ||
| return enabled | ||
|
|
||
|
Comment on lines
+70
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Harden Right now A small defensive tweak keeps bool configs working while handling strings safely: def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
- enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
- return enabled
+ value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
+ if isinstance(value, str):
+ v = value.strip().lower()
+ return bool(v and v != "false")
+ return bool(value)This keeps the default “disabled unless explicitly enabled” behavior while avoiding surprises from accidental string usage. 🤖 Prompt for AI Agents |
||
|
|
||
| def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: | ||
| if pass_ctx is None: | ||
| pass_ctx = tilelang.transform.get_pass_context() | ||
| formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "") | ||
| if not formats_value: | ||
| return ["txt"] | ||
|
|
||
| formats_str = formats_value.strip().lower() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would be better to be designed to support multiple format dumps, and text will be the default one. |
||
| valid_formats = ["txt", "png", "pdf", "svg", "all"] | ||
|
|
||
| if formats_str == "all": | ||
| return ["txt", "png", "pdf", "svg"] | ||
|
|
||
| if "," in formats_str: | ||
| formats_list = [f.strip() for f in formats_str.split(',')] | ||
| else: | ||
| formats_list = [formats_str] | ||
|
|
||
| invalid_formats = [f for f in formats_list if f not in valid_formats] | ||
| if invalid_formats: | ||
| raise ValueError( | ||
| f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. " | ||
| f"Valid formats are: {valid_formats}. " | ||
| f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')" | ||
| ) | ||
| return formats_list | ||
|
|
||
|
|
||
| def LayoutVisual(mod: IRModule) -> None: | ||
| """Apply layout visualization pass if enabled.""" | ||
| if should_enable_layout_visual(): | ||
| formats = get_layout_visual_formats() | ||
| tilelang.analysis.LayoutVisual(formats=formats)(mod) | ||
|
|
||
|
|
||
| def PreLowerSemanticCheck(mod: IRModule) -> None: | ||
| """ | ||
| Check whether the module is valid before lowering. If not, raise a user-friendly error | ||
|
|
@@ -121,6 +163,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: | |
| mod = tilelang.transform.LayoutReducer()(mod) | ||
| # Infer memory layouts for fragments and shared memory | ||
| mod = tilelang.transform.LayoutInference()(mod) | ||
| # Visualize the layout | ||
| LayoutVisual(mod) | ||
| # Lower high-level tile operations to low-level operations | ||
| mod = tilelang.transform.LowerTileOp()(mod) | ||
| # Lower l2 persistent map | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,13 @@ | ||
| from __future__ import annotations | ||
LeiWang1999 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import tilelang.language as T | ||
|
|
||
|
|
||
| def plot_layout(layout: T.Layout, | ||
| def plot_layout(layout: T.Fragment, | ||
| save_directory="./tmp", | ||
| name: str = "layout", | ||
| colormap: str = "RdPu", | ||
| verbose: bool = False) -> None: | ||
| verbose: bool = False, | ||
| formats: str | list[str] = "png") -> None: | ||
LeiWang1999 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| Plot the layout of a buffer. | ||
|
|
||
|
|
@@ -21,7 +23,8 @@ def plot_layout(layout: T.Layout, | |
| The colormap to use for visualization (default is "RdPu"). | ||
| verbose : bool, optional | ||
| If True, prints additional information about the mapping (default is False). | ||
|
|
||
| formats : str | list[str], optional | ||
| The formats to save the image in (default is "png"). | ||
| Returns | ||
| ------- | ||
| None | ||
|
|
@@ -82,6 +85,21 @@ def plot_layout(layout: T.Layout, | |
| raw_colors = [cmap(i) for i in range(num_threads)] | ||
| colors = raw_colors.copy() | ||
|
|
||
| # Show the distribution of registers in each thread of a warp. | ||
| warp_size = 32 | ||
| # Warn if the number of threads is less than the warp size | ||
| if num_threads < warp_size: | ||
| import warnings | ||
| warnings.warn( | ||
| f"Layout visualization has {num_threads} threads, which is less than the warp size ({warp_size}). " | ||
| f"For the best viewing experience, it is recommended to have at least {warp_size} threads.", | ||
| UserWarning, | ||
| stacklevel=2) | ||
| spectral_camp = plt.get_cmap("hsv", warp_size * 6) | ||
|
|
||
| for i in range(min(warp_size, num_threads)): | ||
| colors[i] = spectral_camp(i * 6) | ||
|
|
||
LeiWang1999 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Determine the number of rows and columns in the input shape | ||
| nrows, ncols = input_shape | ||
| # Adjust figure size to maintain square cells | ||
|
|
@@ -191,17 +209,30 @@ def plot_layout(layout: T.Layout, | |
| # Save the figure in multiple formats | ||
| plt.tight_layout() | ||
|
|
||
| # Save as PDF | ||
| pdf_path = tmp_directory / f"{name}.pdf" | ||
| plt.savefig(pdf_path, bbox_inches="tight") | ||
| print(f"Saved pdf format into {pdf_path}") | ||
|
|
||
| # Save as PNG | ||
| png_path = tmp_directory / f"{name}.png" | ||
| plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) | ||
| print(f"Saved png format into {png_path}") | ||
|
|
||
| # Save as SVG | ||
| svg_path = tmp_directory / f"{name}.svg" | ||
| plt.savefig(svg_path, bbox_inches="tight", format="svg") | ||
| print(f"Saved svg format into {svg_path}") | ||
| if isinstance(formats, str): | ||
| formats_str = formats.strip().lower() | ||
| if formats_str == 'all': | ||
| formats_list = ['pdf', 'png', 'svg'] | ||
| elif "," in formats_str: | ||
| formats_list = [f.strip() for f in formats_str.split(',')] | ||
| else: | ||
| formats_list = [formats_str] | ||
| else: | ||
| raise TypeError(f"Expected str, but got {type(formats).__name__}. " | ||
| f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") | ||
|
Comment on lines
+218
to
+222
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The Useful? React with 👍 / 👎. |
||
|
|
||
| # Save the figure | ||
| if 'pdf' in formats_list: | ||
| pdf_path = tmp_directory / f"{name}.pdf" | ||
| plt.savefig(pdf_path, bbox_inches="tight") | ||
| print(f"Saved pdf format into {pdf_path}") | ||
|
|
||
| if 'png' in formats_list: | ||
| png_path = tmp_directory / f"{name}.png" | ||
| plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) | ||
| print(f"Saved png format into {png_path}") | ||
|
|
||
| if 'svg' in formats_list: | ||
| svg_path = tmp_directory / f"{name}.svg" | ||
| plt.savefig(svg_path, bbox_inches="tight", format="svg") | ||
| print(f"Saved svg format into {svg_path}") | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix docstring inconsistency: function prints but claims to return
str.The function docstring (lines 19-22) states it returns a formatted string, but the implementation calls
print()on line 31 and has no return statement.Apply this diff to align the implementation with the documented behavior:
def print_fragment_format(layout: T.Fragment) -> str: """ Format fragment layout information into a human-readable string. Parameters ---------- layout : T.Fragment The fragment layout to format Returns ------- str Formatted string showing shape, thread mapping, and index mapping """ if isinstance(layout, T.Fragment): input_shape = layout.get_input_shape() output_shape = layout.get_output_shape() lines = [ f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}" ] - print("\n".join(lines)) + return "\n".join(lines) else: raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}")Then update the caller on line 77 to print the returned string:
if layout_id not in self.processed_layouts: print(f"{key} layout inference:") - print_fragment_format(layout) + print(print_fragment_format(layout)) plot_layout(layout, name=f"{key}_layout", formats=self.formats) self.processed_layouts.add(layout_id)🧰 Tools
🪛 Ruff (0.14.6)
33-33: Prefer
TypeErrorexception for invalid type(TRY004)
33-33: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents