Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions docs/tutorials/debug_tools_for_tilelang.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,32 @@ The output messages will include something like:
msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0
```

### Visual Layout Inference For TileLang
The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations.

When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates:
1. **Textual output**: A human-readable description of the layout mapping
2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping

The visual layout inference tool is controlled through two pass configuration options, `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS`. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation.

When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts a string that selects the output formats. The textual description is always printed; the value controls which image files are generated in addition to it:
- "txt": Text output only (this is also the default when the option is not set)
- "all": Text output plus all image formats (PDF, PNG, SVG)
- "png": Text output plus a PNG image
- "pdf": Text output plus a PDF image
- "svg": Text output plus an SVG image
- "txt,svg": Multiple formats as a comma-separated list (here: text output plus an SVG image)

The output messages of "txt" will include something like:
```
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
```


## Conclusion

By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs.
Expand Down
61 changes: 61 additions & 0 deletions examples/visual_layout_inference/visual_layout_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import tilelang
import tilelang.language as T


# use pass_configs to enable layout visualization
# JIT-compile the kernel factory below. The pass_configs entries switch on the
# layout visualization pass and request the "svg" output format.
# NOTE(review): out_idx=[-1] presumably marks the last kernel argument (C) as an
# output that the runtime allocates and returns -- main() calls the kernel with
# only A and B -- confirm against the tilelang.jit documentation.
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True,
        tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg"
    })
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    # Build and return a tiled GEMM prim_func for an (M, K) x (K, N) problem.
    # block_M / block_N / block_K are the tile sizes, dtype is the element type of
    # A, B and C, and accum_dtype is the element type of the accumulator fragment.

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # One kernel instance per (block_N, block_M) output tile, 128 threads each.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # Staging buffers for the current A and B tiles.
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            # Accumulator fragment; this is the buffer whose inferred layout shows
            # up as "C_local" in the visualization output.
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # Walk the K dimension one block_K tile at a time (pipelined, 3 stages).
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # presumably accumulates A_shared @ B_shared into C_local -- confirm
                T.gemm(A_shared, B_shared, C_local)

            # Write the finished tile back to its position in C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def main():
    """Compile the example GEMM, run it on random inputs, and check the result.

    Building the kernel goes through compilation with layout visualization
    enabled (see the ``pass_configs`` on ``matmul``); the expected textual
    output is reproduced at the end of this function for reference.
    """
    # 128x128x128 problem computed with 32x32x32 tiles.
    gemm_kernel = matmul(128, 128, 128, 32, 32, 32)

    import torch

    # Random half-precision operands on the GPU.
    lhs = torch.randn(128, 128).cuda().half()
    rhs = torch.randn(128, 128).cuda().half()

    result = gemm_kernel(lhs, rhs)

    # Reference product computed by PyTorch.
    expected = lhs @ rhs

    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
    print("All check passed.")

    # print the layout visualization result and save figures to ./tmp.
    '''
C_local inferenced layout:
Shape: [32, 32] -> [8]
Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2
Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
'''


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ dependencies = [
# mldtypes should be greater than 0.5.1
# if you want to enable fp4
fp4 = ["ml-dtypes>=0.5.1"]
# if you want to enable layout inference visualization
vis = ["matplotlib"]

[build-system]
requires = ["cython>=3.0.0", "scikit-build-core"]
Expand Down
2 changes: 2 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String);

DataType cuTensorMapType() { return DataType::UInt(8, 128); }

Expand Down
4 changes: 4 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
static constexpr const char *kStorageRewriteDetectInplace =
"tl.storage_rewrite_detect_inplace";
static constexpr const char *kLayoutVisualizationEnable =
"tl.layout_visualization_enable";
static constexpr const char *kLayoutVisualizationFormats =
"tl.layout_visualization_formats";
/*!
* \brief Whether to disable dynamic tail split
*
Expand Down
1 change: 1 addition & 0 deletions tilelang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def _load_tile_lang_lib():
transform, # noqa: F401
language, # noqa: F401
engine, # noqa: F401
tools, # noqa: F401
)
from .autotuner import autotune # noqa: F401
from .transform import PassConfigKey # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions tilelang/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .ast_printer import ASTPrinter # noqa: F401
from .nested_loop_checker import NestedLoopChecker # noqa: F401
from .fragment_loop_checker import FragmentLoopChecker # noqa: F401
from .layout_visual import LayoutVisual # noqa: F401
90 changes: 90 additions & 0 deletions tilelang/analysis/layout_visual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import tilelang.language as T
from tvm import tir
from tvm.tir import PyStmtExprVisitor

from tvm.tir.transform import prim_func_pass
from tilelang.tools.plot_layout import plot_layout


def print_fragment_format(layout: T.Fragment) -> str:
    """
    Print a human-readable description of a fragment layout and return it.

    Parameters
    ----------
    layout : T.Fragment
        The fragment layout to format

    Returns
    -------
    str
        The text that was printed: one line each for the shape mapping
        (input shape -> output shape), the thread mapping, and the index mapping

    Raises
    ------
    ValueError
        If ``layout`` is not a ``T.Fragment``
    """
    if not isinstance(layout, T.Fragment):
        raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}")

    input_shape = layout.get_input_shape()
    output_shape = layout.get_output_shape()
    text = "\n".join([
        f" Shape: {input_shape} -> {output_shape}",
        f" Thread: {layout.forward_thread}",
        f" Index: {layout.forward_index}",
    ])
    # Keep printing: callers that ignore the return value rely on this function
    # to emit the description itself. Returning the text as well makes the
    # implementation match the declared ``str`` return type.
    print(text)
    return text

Comment on lines +9 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix docstring inconsistency: function prints but claims to return str.

The function docstring (lines 19-22) states it returns a formatted string, but the implementation calls print() on line 31 and has no return statement.

Apply this diff to align the implementation with the documented behavior:

 def print_fragment_format(layout: T.Fragment) -> str:
     """
     Format fragment layout information into a human-readable string.
 
     Parameters
     ----------
     layout : T.Fragment
         The fragment layout to format
 
     Returns
     -------
     str
         Formatted string showing shape, thread mapping, and index mapping
     """
     if isinstance(layout, T.Fragment):
         input_shape = layout.get_input_shape()
         output_shape = layout.get_output_shape()
         lines = [
             f"  Shape: {input_shape} -> {output_shape}", f"  Thread: {layout.forward_thread}",
             f"  Index:  {layout.forward_index}"
         ]
-        print("\n".join(lines))
+        return "\n".join(lines)
     else:
         raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}")

Then update the caller on line 77 to print the returned string:

                     if layout_id not in self.processed_layouts:
                         print(f"{key} layout inference:")
-                        print_fragment_format(layout)
+                        print(print_fragment_format(layout))
                         plot_layout(layout, name=f"{key}_layout", formats=self.formats)
                         self.processed_layouts.add(layout_id)

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.6)

33-33: Prefer TypeError exception for invalid type

(TRY004)


33-33: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/tools/layout_visual.py around lines 10 to 34, the function
print_fragment_format's docstring states it returns a formatted string but the
implementation prints the string and returns None; change the implementation to
build and return the formatted string instead of calling print (construct the
same 'lines' and return "\n".join(lines)), and remove the print(); then update
the caller at line 77 to print(print_fragment_format(...)) so the caller
displays the returned string.


@tir.functor.visitor
class _LayoutVisualVisitor(PyStmtExprVisitor):
"""
User-friendly pass which visualizes fragment layouts inferred during compilation.

In TileLang, Fragment layouts describe:
- How logical indices (e.g., [i, j]) map to thread IDs
- How logical indices map to register file locations within each thread
- The shape transformation from input dimensions to output dimensions

This pass generates two types of output:
1. Textual output: A human-readable description printed to console
2. Visual diagrams: Color-coded plots saved to files (PDF, PNG, SVG formats)

Configuration:
The pass is enabled by the TL_LAYOUT_VISUALIZATION_ENABLE configuration option, and the
output formats are selected by the TL_LAYOUT_VISUALIZATION_FORMATS configuration option,
which accepts string values (the textual output is always printed):

- Empty string or not set: Textual output only (default)
- "txt": Textual output only
- "png": Also generate PNG format (recommended for quick inspection)
- "pdf": Also generate PDF format (recommended for documentation)
- "svg": Also generate SVG format (recommended for web/vector graphics)
- "all": Also generate all image formats (PDF, PNG, SVG)
- "png,svg": Also generate multiple formats (comma-separated)
"""

def __init__(self, formats: list[str] = ""):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: Fix type annotation inconsistency.

Line 61 declares formats: list[str] = "" which is invalid—a list[str] type cannot have a string default. Looking at line 65, the code iterates [f for f in formats if f != "txt"], treating formats as an iterable of strings. However, this creates ambiguity: if formats="png,svg" (a string), the list comprehension yields ['p', 'n', 'g', ',', 's', 'v', 'g'] (individual characters), not ['png', 'svg'].

Additionally, there's a critical type mismatch with the caller in tilelang/engine/phase.py line 109, which passes a list[str] (from get_layout_visual_formats).

To resolve this, decide whether formats should be:

  1. A list[str] (as the caller provides), or
  2. A str (as the outer LayoutVisual function signature at line 84 suggests)

If formats should be a list, apply this diff:

-    def __init__(self, formats: list[str] = ""):
+    def __init__(self, formats: list[str] | None = None):
         super().__init__()
         self.layout_found = []
         self.processed_layouts = set()
-        self.formats_list = [f for f in formats if f != "txt"]
+        self.formats_list = [f for f in (formats or []) if f != "txt"]

And update the factory at line 84:

-def LayoutVisual(formats: str = ""):
+def LayoutVisual(formats: list[str] | None = None):

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tilelang/tools/layout_visual.py around line 61, the constructor currently
declares formats: list[str] = "" which is invalid and mismatches callers; change
the parameter to accept an optional list (e.g., formats: list[str] | None =
None), guard inside the ctor to set formats = [] when None (avoiding a mutable
default), and keep the existing list comprehension that filters out "txt"; also
update the factory function signature at line 84 to accept and forward a
list[str] (not a str) so callers like tilelang/engine/phase.py that pass a
list[str] match the type.

super().__init__()
self.layout_found = []
self.processed_layouts = set()
self.formats_list = [f for f in formats if f != "txt"]

def visit_block_(self, op: tir.Block) -> None:
    """Report every not-yet-seen fragment layout in the block's ``layout_map``.

    For each ``T.Fragment`` entry whose string form has not been recorded in
    ``self.processed_layouts``, the textual description is printed and one plot
    is requested per entry of ``self.formats_list``. Blocks without a
    ``layout_map`` annotation are ignored.
    """
    annotations = op.annotations
    if "layout_map" not in annotations:
        return

    for buffer_key, layout in annotations["layout_map"].items():
        if not isinstance(layout, T.Fragment):
            continue

        # The string form of the layout is the de-duplication key.
        signature = str(layout)
        if signature in self.processed_layouts:
            continue

        print(f"{buffer_key} inferenced layout:")
        print_fragment_format(layout)
        for fmt in self.formats_list:
            plot_layout(layout, name=f"{buffer_key}_layout", formats=fmt)
        self.processed_layouts.add(signature)

    # NOTE(review): the base-class traversal is left disabled, as in the original;
    # presumably this means children of this block are not visited -- confirm.
    # super().visit_block_(op)


def LayoutVisual(formats: "str | list[str]" = ""):
    """
    Create a PrimFunc pass that reports the fragment layouts found in a function.

    Parameters
    ----------
    formats : str or list of str, optional
        Output formats to request. A string may name a single format or a
        comma-separated list (e.g. ``"png,svg"``); a list is forwarded as is.
        ``"txt"`` entries are ignored by the visitor when plotting, because the
        textual description is always printed. The default (empty) requests
        no plots.

    Returns
    -------
    A ``prim_func_pass`` (``opt_level=0``) that runs the layout visitor over
    ``func.body`` and returns the function unchanged.
    """
    # The visitor iterates over ``formats`` element by element, so a bare string
    # would be consumed one character at a time ("png" -> "p", "n", "g") and no
    # plot would ever match. Split strings into format names up front.
    if isinstance(formats, str):
        formats = [name.strip().lower() for name in formats.split(",") if name.strip()]

    def pass_fn(func: tir.PrimFunc, mod, ctx):
        _LayoutVisualVisitor(formats=formats).visit_stmt(func.body)
        return func

    return prim_func_pass(pass_fn, opt_level=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be better to add some comments for this file

44 changes: 44 additions & 0 deletions tilelang/engine/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,48 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool:
return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))


def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool:
    """Return whether layout visualization is enabled in the pass configuration.

    Parameters
    ----------
    pass_ctx : PassContext | None
        Context to read from; when ``None`` the current pass context is used.

    Returns
    -------
    bool
        ``True`` if ``TL_LAYOUT_VISUALIZATION_ENABLE`` is set to a truthy value,
        ``False`` otherwise (including when the option is not set).
    """
    if pass_ctx is None:
        pass_ctx = tilelang.transform.get_pass_context()
    enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
    # Coerce so the declared ``bool`` return type holds regardless of the object
    # the config stores.
    return bool(enabled)

Comment on lines +70 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Harden should_enable_layout_visual against string-valued configs.

Right now enabled is returned directly from PassContext.config. If a user mistakenly sets "false" (string) instead of False (bool), if should_enable_layout_visual(): will still evaluate truthy and enable visualization.

A small defensive tweak keeps bool configs working while handling strings safely:

 def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool:
     if pass_ctx is None:
         pass_ctx = tilelang.transform.get_pass_context()
-    enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
-    return enabled
+    value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
+    if isinstance(value, str):
+        v = value.strip().lower()
+        return bool(v and v != "false")
+    return bool(value)

This keeps the default “disabled unless explicitly enabled” behavior while avoiding surprises from accidental string usage.

🤖 Prompt for AI Agents
In tilelang/engine/phase.py around lines 70 to 75, the function returns the raw
value from pass_ctx.config which can be a string like "false" and still evaluate
truthy; change the logic to coerce and validate the config into a strict
boolean: fetch the raw value, if it is a bool return it, if it is a str
interpret only common true values (e.g. "true", "1", "yes") as True
(case-insensitive) and everything else as False, and default to False when the
key is missing or value is None.


def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "")
if not formats_value:
return ["txt"]

formats_str = formats_value.strip().lower()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be better to be designed to support multiple format dumps, and text will be the default one.

valid_formats = ["txt", "png", "pdf", "svg", "all"]

if formats_str == "all":
return ["txt", "png", "pdf", "svg"]

if "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(',')]
else:
formats_list = [formats_str]

invalid_formats = [f for f in formats_list if f not in valid_formats]
if invalid_formats:
raise ValueError(
f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. "
f"Valid formats are: {valid_formats}. "
f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')"
)
return formats_list


def LayoutVisual(mod: IRModule) -> None:
    """Run the layout visualization pass over ``mod`` when it is enabled.

    Does nothing unless ``should_enable_layout_visual()`` is true; otherwise the
    requested formats are read from the pass configuration and handed to the
    analysis pass, which is then applied to ``mod``. The result of the pass is
    discarded.
    """
    if not should_enable_layout_visual():
        return

    requested_formats = get_layout_visual_formats()
    tilelang.analysis.LayoutVisual(formats=requested_formats)(mod)


def PreLowerSemanticCheck(mod: IRModule) -> None:
"""
Check whether the module is valid before lowering. If not, raise a user-friendly error
Expand Down Expand Up @@ -121,6 +163,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LayoutReducer()(mod)
# Infer memory layouts for fragments and shared memory
mod = tilelang.transform.LayoutInference()(mod)
# Visualize the layout
LayoutVisual(mod)
# Lower high-level tile operations to low-level operations
mod = tilelang.transform.LowerTileOp()(mod)
# Lower l2 persistent map
Expand Down
65 changes: 48 additions & 17 deletions tilelang/tools/plot_layout.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations
import tilelang.language as T


def plot_layout(layout: T.Layout,
def plot_layout(layout: T.Fragment,
save_directory="./tmp",
name: str = "layout",
colormap: str = "RdPu",
verbose: bool = False) -> None:
verbose: bool = False,
formats: str | list[str] = "png") -> None:
"""
Plot the layout of a buffer.

Expand All @@ -21,7 +23,8 @@ def plot_layout(layout: T.Layout,
The colormap to use for visualization (default is "RdPu").
verbose : bool, optional
If True, prints additional information about the mapping (default is False).

formats : str | list[str], optional
The formats to save the image in (default is "png").
Returns
-------
None
Expand Down Expand Up @@ -82,6 +85,21 @@ def plot_layout(layout: T.Layout,
raw_colors = [cmap(i) for i in range(num_threads)]
colors = raw_colors.copy()

# Show the distribution of registers in each thread of a warp.
warp_size = 32
# Warn if the number of threads is less than the warp size
if num_threads < warp_size:
import warnings
warnings.warn(
f"Layout visualization has {num_threads} threads, which is less than the warp size ({warp_size}). "
f"For the best viewing experience, it is recommended to have at least {warp_size} threads.",
UserWarning,
stacklevel=2)
spectral_camp = plt.get_cmap("hsv", warp_size * 6)

for i in range(min(warp_size, num_threads)):
colors[i] = spectral_camp(i * 6)

# Determine the number of rows and columns in the input shape
nrows, ncols = input_shape
# Adjust figure size to maintain square cells
Expand Down Expand Up @@ -191,17 +209,30 @@ def plot_layout(layout: T.Layout,
# Save the figure in multiple formats
plt.tight_layout()

# Save as PDF
pdf_path = tmp_directory / f"{name}.pdf"
plt.savefig(pdf_path, bbox_inches="tight")
print(f"Saved pdf format into {pdf_path}")

# Save as PNG
png_path = tmp_directory / f"{name}.png"
plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255)
print(f"Saved png format into {png_path}")

# Save as SVG
svg_path = tmp_directory / f"{name}.svg"
plt.savefig(svg_path, bbox_inches="tight", format="svg")
print(f"Saved svg format into {svg_path}")
if isinstance(formats, str):
formats_str = formats.strip().lower()
if formats_str == 'all':
formats_list = ['pdf', 'png', 'svg']
elif "," in formats_str:
formats_list = [f.strip() for f in formats_str.split(',')]
else:
formats_list = [formats_str]
else:
raise TypeError(f"Expected str, but got {type(formats).__name__}. "
f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.")
Comment on lines +218 to +222

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Accept documented list inputs for output formats

The formats parameter is annotated and documented as str | list[str], but the new parsing block immediately raises a TypeError for any non‑string input, so callers following the API and passing a list like ["png", "pdf"] cannot use the feature and the tool fails before plotting.

Useful? React with 👍 / 👎.


# Save the figure
if 'pdf' in formats_list:
pdf_path = tmp_directory / f"{name}.pdf"
plt.savefig(pdf_path, bbox_inches="tight")
print(f"Saved pdf format into {pdf_path}")

if 'png' in formats_list:
png_path = tmp_directory / f"{name}.png"
plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255)
print(f"Saved png format into {png_path}")

if 'svg' in formats_list:
svg_path = tmp_directory / f"{name}.svg"
plt.savefig(svg_path, bbox_inches="tight", format="svg")
print(f"Saved svg format into {svg_path}")
9 changes: 9 additions & 0 deletions tilelang/transform/pass_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ class PassConfigKey(str, Enum):
TL_FORCE_LET_INLINE = "tl.force_let_inline"
"""Force TileLang to inline let bindings during simplification. Default: False"""

TL_LAYOUT_VISUALIZATION_ENABLE = "tl.layout_visualization_enable"
"""Enable layout inference visualization. Default: False"""

TL_LAYOUT_VISUALIZATION_FORMATS = "tl.layout_visualization_formats"
"""Layout visualization formats.
Acceptable values: "txt", "pdf", "png", "svg", "all", or a comma-separated
list of formats such as "txt,svg". Default: "txt" (text output only)

"""

TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace"
"""Control StorageRewrite inplace detection.

Expand Down
Loading