From b59e6e8dfd07ae2f6b91ed0dc94820d84ce10a34 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Thu, 27 Nov 2025 19:54:56 +0800 Subject: [PATCH 01/21] Provide layout visualization tool Adds a layout visualization tool to TileLang, which helps users understand and debug the layout transformations applied during compilation. This tool visualizes the memory layout of tensors at different stages of the compilation process, allowing developers to identify potential inefficiencies and optimize their code for better performance. The visualization can be enabled via a pass config option. --- examples/gemm/example_gemm.py | 2 +- src/op/builtin.cc | 1 + src/op/builtin.h | 1 + tilelang/__init__.py | 1 + tilelang/engine/lower.py | 4 +++ tilelang/engine/phase.py | 9 ++++++ tilelang/tools/__init__.py | 1 + tilelang/tools/layout_visual.py | 49 +++++++++++++++++++++++++++++++ tilelang/tools/plot_layout.py | 9 +++++- tilelang/transform/pass_config.py | 3 ++ 10 files changed, 78 insertions(+), 2 deletions(-) create mode 100644 tilelang/tools/layout_visual.py diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index f18cd388a..32d24b587 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -2,7 +2,7 @@ import tilelang.language as T -@tilelang.jit(out_idx=[-1]) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL: True}) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/src/op/builtin.cc b/src/op/builtin.cc index e7e86f2f5..f7d558301 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -34,6 +34,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLayoutVisual, Bool); DataType cuTensorMapType() { return DataType::UInt(8, 128); } diff --git a/src/op/builtin.h b/src/op/builtin.h index f5c7d9edc..86370b472 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -51,6 +51,7 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma"; static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; static constexpr const char *kStorageRewriteDetectInplace = "tl.storage_rewrite_detect_inplace"; +static constexpr const char *kEnableLayoutVisual = "tl.enable_layout_visual"; /*! * \brief Whether to disable dynamic tail split * diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 2eae5cdb7..4dd62cfa0 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -137,6 +137,7 @@ def _load_tile_lang_lib(): transform, # noqa: F401 language, # noqa: F401 engine, # noqa: F401 + tools, ) from .autotuner import autotune # noqa: F401 from .transform import PassConfigKey # noqa: F401 diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 88d89dcc2..5f6a2f099 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -16,6 +16,7 @@ from tilelang.engine.param import KernelParam, CompiledArtifact from tilelang.utils.target import determine_target from tilelang.engine.phase import ( + LayoutVisual, PreLowerSemanticCheck, LowerAndLegalize, OptimizeForTarget, @@ -249,6 +250,9 @@ def lower( # Phase 1: Lower and legalize the IR mod = LowerAndLegalize(mod, target) + # Visualize the layout + LayoutVisual(mod) + # Phase 2: Optimize the IR for the target mod = OptimizeForTarget(mod, target) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 35c16a438..02a3827a9 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -66,6 +66,15 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: pass_ctx = tilelang.transform.get_pass_context() return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False)) +def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL, False)) + +def LayoutVisual(mod: IRModule) -> None: + """Apply layout visualization pass if enabled.""" + if should_enable_layout_visual(): + tilelang.tools.LayoutVisual()(mod) def PreLowerSemanticCheck(mod: IRModule) -> None: """ diff --git a/tilelang/tools/__init__.py b/tilelang/tools/__init__.py index 7a8bde514..06a009d85 100644 --- a/tilelang/tools/__init__.py +++ b/tilelang/tools/__init__.py @@ -1,2 +1,3 @@ from .plot_layout import plot_layout # noqa: F401 from .Analyzer import * +from .layout_visual import LayoutVisual # noqa: F401 diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py new file mode 100644 index 000000000..1fcb39133 --- /dev/null +++ b/tilelang/tools/layout_visual.py @@ -0,0 +1,49 @@ +import tvm +import tilelang +import tilelang.language as T +from tvm import tir +from tvm.tir import PyStmtExprVisitor + +from tvm.tir.transform import prim_func_pass +from tilelang.tools.plot_layout import plot_layout + +def print_layout_format(layout: T.Fragment) -> str: + input_shape = layout.get_input_shape() + output_shape = layout.get_output_shape() + lines = [ + f" Shape: {input_shape} -> {output_shape}", + f" Thread: {layout.forward_thread}", + f" Index: {layout.forward_index}" + ] + + return "\n".join(lines) + +@tir.functor.visitor +class _LayoutVisualVisitor(PyStmtExprVisitor): + def __init__(self): + super().__init__() + self.layout_found = [] + self.processed_layouts = set() + + def visit_block_(self, op: tir.Block) -> None: + if "layout_map" in op.annotations: + layout_map = op.annotations["layout_map"] + + for key, layout in layout_map.items(): + if isinstance(layout, T.Fragment): + layout_id = str(layout) + if layout_id not in self.processed_layouts: + print(f"{key} layout inference:") + print(print_layout_format(layout)) + plot_layout(layout, name=f"{key}_layout") + self.processed_layouts.add(layout_id) + + self.visit_stmt(op.body) + + +def LayoutVisual(): + def pass_fn(func: tir.PrimFunc, mod, ctx): + _LayoutVisualVisitor().visit_stmt(func.body) + return func + + return prim_func_pass(pass_fn, opt_level=0) \ No newline at end of file diff --git a/tilelang/tools/plot_layout.py b/tilelang/tools/plot_layout.py index 291da2571..69e5636f4 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -1,7 +1,7 @@ import tilelang.language as T -def plot_layout(layout: T.Layout, +def plot_layout(layout: T.Fragment, save_directory="./tmp", name: str = "layout", colormap: str = "RdPu", @@ -82,6 +82,13 @@ def plot_layout(layout: T.Layout, raw_colors = [cmap(i) for i in range(num_threads)] colors = raw_colors.copy() + # Show the distribution of registers in each thread of a warp. + warp_size = 32 + spectral_camp = plt.get_cmap("hsv", warp_size * 6) + for i in range(warp_size): + colors[i] = spectral_camp(i * 6) + + # Determine the number of rows and columns in the input shape nrows, ncols = input_shape # Adjust figure size to maintain square cells diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index a1edb881d..6fb02bd2f 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -69,6 +69,9 @@ class PassConfigKey(str, Enum): TL_FORCE_LET_INLINE = "tl.force_let_inline" """Force TileLang to inline let bindings during simplification. Default: False""" + TL_ENABLE_LAYOUT_VISUAL = "tl.enable_layout_visual" + """Enable layout inference visualization. Default: False""" + TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace" """Control StorageRewrite inplace detection. From 70519cd93bd2bc16ba834000377d6f08038bd923 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Thu, 27 Nov 2025 19:57:44 +0800 Subject: [PATCH 02/21] format --- examples/gemm/example_gemm.py | 2 +- tilelang/__init__.py | 2 +- tilelang/engine/phase.py | 6 +++++- tilelang/tools/layout_visual.py | 19 ++++++++++--------- tilelang/tools/plot_layout.py | 1 - 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index 32d24b587..f18cd388a 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -2,7 +2,7 @@ import tilelang.language as T -@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL: True}) +@tilelang.jit(out_idx=[-1]) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/tilelang/__init__.py b/tilelang/__init__.py index 4dd62cfa0..75a92eab6 100644 --- a/tilelang/__init__.py +++ b/tilelang/__init__.py @@ -137,7 +137,7 @@ def _load_tile_lang_lib(): transform, # noqa: F401 language, # noqa: F401 engine, # noqa: F401 - tools, + tools, # noqa: F401 ) from .autotuner import autotune # noqa: F401 from .transform import PassConfigKey # noqa: F401 diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 5a157c488..4f546e07e 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -66,16 +66,20 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: pass_ctx = tilelang.transform.get_pass_context() return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False)) + def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() - return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL, False)) + return bool( + pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL, False)) + def LayoutVisual(mod: IRModule) -> None: """Apply layout visualization pass if enabled.""" if should_enable_layout_visual(): tilelang.tools.LayoutVisual()(mod) + def PreLowerSemanticCheck(mod: IRModule) -> None: """ Check whether the module is valid before lowering. If not, raise a user-friendly error diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index 1fcb39133..c78614439 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -1,5 +1,3 @@ -import tvm -import tilelang import tilelang.language as T from tvm import tir from tvm.tir import PyStmtExprVisitor @@ -7,24 +5,26 @@ from tvm.tir.transform import prim_func_pass from tilelang.tools.plot_layout import plot_layout + def print_layout_format(layout: T.Fragment) -> str: input_shape = layout.get_input_shape() output_shape = layout.get_output_shape() lines = [ - f" Shape: {input_shape} -> {output_shape}", - f" Thread: {layout.forward_thread}", + f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}" ] return "\n".join(lines) + @tir.functor.visitor class _LayoutVisualVisitor(PyStmtExprVisitor): + def __init__(self): super().__init__() self.layout_found = [] self.processed_layouts = set() - + def visit_block_(self, op: tir.Block) -> None: if "layout_map" in op.annotations: layout_map = op.annotations["layout_map"] @@ -37,13 +37,14 @@ def visit_block_(self, op: tir.Block) -> None: print(print_layout_format(layout)) plot_layout(layout, name=f"{key}_layout") self.processed_layouts.add(layout_id) - + self.visit_stmt(op.body) - + def LayoutVisual(): + def pass_fn(func: tir.PrimFunc, mod, ctx): _LayoutVisualVisitor().visit_stmt(func.body) return func - - return prim_func_pass(pass_fn, opt_level=0) \ No newline at end of file + + return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/tools/plot_layout.py b/tilelang/tools/plot_layout.py index 69e5636f4..5f29d6c2f 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -87,7 +87,6 @@ def plot_layout(layout: T.Fragment, spectral_camp = plt.get_cmap("hsv", warp_size * 6) for i in range(warp_size): colors[i] = spectral_camp(i * 6) - # Determine the number of rows and columns in the input shape nrows, ncols = input_shape From 67f27dd34c49e448f6bb3f2fc84fb3b07b59dbcc Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Fri, 28 Nov 2025 00:33:49 +0800 Subject: [PATCH 03/21] add layout visual example --- .../visual_layout_inference.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 examples/visual_layout_inference/visual_layout_inference.py diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py new file mode 100644 index 000000000..90e7c91a4 --- /dev/null +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -0,0 +1,56 @@ +import tilelang +import tilelang.language as T + + +# use pass_configs to enable layout visualization +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL: True}) +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + + @T.prim_func + def gemm( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm + + +def main(): + kernel = matmul(128, 128, 128, 32, 32, 32) + + import torch + + a = torch.randn(128, 128).cuda().half() + b = torch.randn(128, 128).cuda().half() + + c = kernel(a, b) + + ref_c = a @ b + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + # print the layout visualization result and save figures to ./tmp. + ''' + C_local layout inference: + Shape: [32, 32] -> [8] + Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 + Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] + ''' + + +if __name__ == "__main__": + main() From 695009bd60a9ef480f161cd121c7d174aa6d1476 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sat, 29 Nov 2025 17:28:47 +0800 Subject: [PATCH 04/21] Adds vis extra with matplotlib dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 706cd5290..1e4c59f92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ # mldtypes should be greater than 0.5.1 # if you want to enable fp4 fp4 = ["ml-dtypes>=0.5.1"] +vis = ["matplotlib"] [build-system] requires = ["cython>=3.0.0", "scikit-build-core"] From 79b0914736f9725e37be669e89bd1e1c3f34d626 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sat, 29 Nov 2025 18:19:32 +0800 Subject: [PATCH 05/21] rafactor pass config name --- src/op/builtin.cc | 2 +- src/op/builtin.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/op/builtin.cc b/src/op/builtin.cc index ab5d7f2bc..c73d6415c 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -34,7 +34,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool); -TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLayoutVisual, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLayoutVisualization, Bool); DataType cuTensorMapType() { return DataType::UInt(8, 128); } diff --git a/src/op/builtin.h b/src/op/builtin.h index 3d4f73937..1e9b1d1f0 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -51,7 +51,7 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma"; static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; static constexpr const char *kStorageRewriteDetectInplace = "tl.storage_rewrite_detect_inplace"; -static constexpr const char *kEnableLayoutVisual = "tl.enable_layout_visual"; +static constexpr const char *kEnableLayoutVisualization = "tl.enable_layout_visualization"; /*! * \brief Whether to disable dynamic tail split * From 00feb64527ae4cab7cf0ceff4f0ab42c3771fa6a Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sat, 29 Nov 2025 18:22:19 +0800 Subject: [PATCH 06/21] fix lint --- examples/visual_layout_inference/visual_layout_inference.py | 3 ++- src/op/builtin.h | 3 ++- tilelang/engine/phase.py | 3 ++- tilelang/transform/pass_config.py | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index 90e7c91a4..1949ed390 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -3,7 +3,8 @@ # use pass_configs to enable layout visualization -@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL: True}) +@tilelang.jit( + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: True}) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/src/op/builtin.h b/src/op/builtin.h index 1e9b1d1f0..a94801649 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -51,7 +51,8 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma"; static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; static constexpr const char *kStorageRewriteDetectInplace = "tl.storage_rewrite_detect_inplace"; -static constexpr const char *kEnableLayoutVisualization = "tl.enable_layout_visualization"; +static constexpr const char *kEnableLayoutVisualization = + "tl.enable_layout_visualization"; /*! * \brief Whether to disable dynamic tail split * diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 4f546e07e..b0f6ca3b6 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -71,7 +71,8 @@ def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() return bool( - pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL, False)) + pass_ctx and + pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION, False)) def LayoutVisual(mod: IRModule) -> None: diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 6fb02bd2f..64d1b55dc 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -69,7 +69,7 @@ class PassConfigKey(str, Enum): TL_FORCE_LET_INLINE = "tl.force_let_inline" """Force TileLang to inline let bindings during simplification. Default: False""" - TL_ENABLE_LAYOUT_VISUAL = "tl.enable_layout_visual" + TL_ENABLE_LAYOUT_VISUALIZATION = "tl.enable_layout_visualization" """Enable layout inference visualization. Default: False""" TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace" From 3495e81cba22403cf4d751b37bad118a147b1408 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sun, 30 Nov 2025 11:46:00 +0800 Subject: [PATCH 07/21] Enables configurable layout visualization formats Allows users to specify the output formats (png, pdf, svg) for layout visualization through a pass config option. This change provides more flexibility in how layout visualizations are generated, allowing users to choose the formats that best suit their needs. It also fixes a bug where layout visualization was not correctly disabled when the config option was set to "false". --- .../visual_layout_inference.py | 4 +- src/op/builtin.cc | 2 +- tilelang/engine/phase.py | 8 +-- tilelang/tools/layout_visual.py | 23 ++++++-- tilelang/tools/plot_layout.py | 52 +++++++++++++------ tilelang/transform/pass_config.py | 7 ++- 6 files changed, 67 insertions(+), 29 deletions(-) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index 1949ed390..27c712b64 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -1,10 +1,10 @@ import tilelang import tilelang.language as T - +tilelang.disable_cache() # use pass_configs to enable layout visualization @tilelang.jit( - out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: True}) + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: "False"}) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/src/op/builtin.cc b/src/op/builtin.cc index c73d6415c..a9f9f541f 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -34,7 +34,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool); -TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLayoutVisualization, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLayoutVisualization, ffi::String); DataType cuTensorMapType() { return DataType::UInt(8, 128); } diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index b0f6ca3b6..1fd7dc12e 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -70,9 +70,11 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() - return bool( - pass_ctx and - pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION, False)) + + config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) + + config_str = str(config_value).strip().lower() + return bool(config_str and config_str != "false") def LayoutVisual(mod: IRModule) -> None: diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index c78614439..9bc399f13 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -1,3 +1,4 @@ +import tilelang import tilelang.language as T from tvm import tir from tvm.tir import PyStmtExprVisitor @@ -10,7 +11,8 @@ def print_layout_format(layout: T.Fragment) -> str: input_shape = layout.get_input_shape() output_shape = layout.get_output_shape() lines = [ - f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", + f" Shape: {input_shape} -> {output_shape}", + f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}" ] @@ -20,10 +22,11 @@ def print_layout_format(layout: T.Fragment) -> str: @tir.functor.visitor class _LayoutVisualVisitor(PyStmtExprVisitor): - def __init__(self): + def __init__(self, formats: str = "png"): super().__init__() self.layout_found = [] self.processed_layouts = set() + self.formats = formats def visit_block_(self, op: tir.Block) -> None: if "layout_map" in op.annotations: @@ -35,16 +38,26 @@ def visit_block_(self, op: tir.Block) -> None: if layout_id not in self.processed_layouts: print(f"{key} layout inference:") print(print_layout_format(layout)) - plot_layout(layout, name=f"{key}_layout") + plot_layout(layout, name=f"{key}_layout", formats=self.formats) self.processed_layouts.add(layout_id) self.visit_stmt(op.body) - def LayoutVisual(): def pass_fn(func: tir.PrimFunc, mod, ctx): - _LayoutVisualVisitor().visit_stmt(func.body) + pass_ctx = tilelang.transform.get_pass_context() + config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) + + config_str = str(config_value).strip().lower() + if not config_str or config_str == "false": + return func + elif config_str == "true": + formats = "all" + else: + formats = config_str + + _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/tools/plot_layout.py b/tilelang/tools/plot_layout.py index 5f29d6c2f..00b208e8f 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -1,3 +1,4 @@ +from torch import Value import tilelang.language as T @@ -5,7 +6,8 @@ def plot_layout(layout: T.Fragment, save_directory="./tmp", name: str = "layout", colormap: str = "RdPu", - verbose: bool = False) -> None: + verbose: bool = False, + formats: str | list[str] = "png") -> None: """ Plot the layout of a buffer. @@ -21,9 +23,10 @@ def plot_layout(layout: T.Fragment, The colormap to use for visualization (default is "RdPu"). verbose : bool, optional If True, prints additional information about the mapping (default is False). - + formats : str | list[str], optional + The formats to save the image in (default is "png"). Returns - ------- + -------s None """ import os @@ -197,17 +200,32 @@ def plot_layout(layout: T.Fragment, # Save the figure in multiple formats plt.tight_layout() - # Save as PDF - pdf_path = tmp_directory / f"{name}.pdf" - plt.savefig(pdf_path, bbox_inches="tight") - print(f"Saved pdf format into {pdf_path}") - - # Save as PNG - png_path = tmp_directory / f"{name}.png" - plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) - print(f"Saved png format into {png_path}") - - # Save as SVG - svg_path = tmp_directory / f"{name}.svg" - plt.savefig(svg_path, bbox_inches="tight", format="svg") - print(f"Saved svg format into {svg_path}") + if isinstance(formats, str): + formats_str = formats.strip().lower() + if formats_str == 'all': + formats_list = ['pdf', 'png', 'svg'] + elif "," in formats_str: + formats_list = [f.strip() for f in formats_str.split(',')] + else: + formats_list = [formats_str] + else: + raise TypeError( + f"Expected str, but got {type(formats).__name__}. " + f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." + ) + + # Save the figure + if 'pdf' in formats_list: + pdf_path = tmp_directory / f"{name}.pdf" + plt.savefig(pdf_path, bbox_inches="tight") + print(f"Saved pdf format into {pdf_path}") + + if 'png' in formats_list: + png_path = tmp_directory / f"{name}.png" + plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) + print(f"Saved png format into {png_path}") + + if 'svg' in formats_list: + svg_path = tmp_directory / f"{name}.svg" + plt.savefig(svg_path, bbox_inches="tight", format="svg") + print(f"Saved svg format into {svg_path}") \ No newline at end of file diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 64d1b55dc..3e4931862 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -70,7 +70,12 @@ class PassConfigKey(str, Enum): """Force TileLang to inline let bindings during simplification. Default: False""" TL_ENABLE_LAYOUT_VISUALIZATION = "tl.enable_layout_visualization" - """Enable layout inference visualization. Default: False""" + """Enable layout inference visualization. Accepts string values: + - "" or "false": disabled (default) + - "true" or "all": enabled, generate all formats (pdf, png, svg) + - "png", "pdf", "svg": enabled, generate specified format + - "png,svg": enabled, generate multiple formats (comma-separated) + """"" TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace" """Control StorageRewrite inplace detection. From 18c93c459a70abe9e46875f2d6a0fc9ce71e05eb Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sun, 30 Nov 2025 12:01:42 +0800 Subject: [PATCH 08/21] Adds visual layout inference tool docs --- docs/tutorials/debug_tools_for_tilelang.md | 24 +++++++++++++++++++ .../visual_layout_inference.py | 2 ++ pyproject.toml | 1 + tilelang/engine/phase.py | 2 +- tilelang/tools/layout_visual.py | 9 +++---- tilelang/tools/plot_layout.py | 10 ++++---- tilelang/transform/pass_config.py | 2 +- 7 files changed, 38 insertions(+), 12 deletions(-) diff --git a/docs/tutorials/debug_tools_for_tilelang.md b/docs/tutorials/debug_tools_for_tilelang.md index e18b13279..1fd513777 100644 --- a/docs/tutorials/debug_tools_for_tilelang.md +++ b/docs/tutorials/debug_tools_for_tilelang.md @@ -171,6 +171,30 @@ The output messages will include something like: msg='hello world' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): 0 ``` +### Visual Layout Inference For TileLang + The **Visual Layout Inference** tool automatically generates visual diagrams that illustrate the mapping between logical indices, thread IDs, and register file locations. + +When TileLang performs layout inference, it determines how fragment buffers are distributed across threads. The visual layout tool captures this information and generates: +1. **Textual output**: A human-readable description of the layout mapping +2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping + +The visual layout inference tool is controlled through the `TL_ENABLE_LAYOUT_VISUALIZATION` pass configuration. By default, visualization is **disabled** to avoid performance overhead during compilation. + +`TL_ENABLE_LAYOUT_VISUALIZATION` accepts string values to control output formats: +- "True" or "all": Enabled, generates all formats (PDF, PNG, SVG) +- "png": Generate PNG format only +- "pdf": Generate PDF format only +- "svg": Generate SVG format only + +The output messages will include something like: +``` +C_local layout inference: + Shape: [32, 32] -> [8] + Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 + Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] +``` + + ## Conclusion By carefully examining intermediate representations (IR) before final code generation—and by leveraging runtime printing through `T.print`—one can quickly diagnose where index calculations, copy logic, or other kernel operations deviate from the intended behavior. This two-pronged approach (inspecting IR transformations and using runtime prints) is often sufficient for resolving generation and correctness issues in TileLang programs. diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index 27c712b64..7b275a833 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -2,6 +2,8 @@ import tilelang.language as T tilelang.disable_cache() + + # use pass_configs to enable layout visualization @tilelang.jit( out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: "False"}) diff --git a/pyproject.toml b/pyproject.toml index 1e4c59f92..d793fb1b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ # mldtypes should be greater than 0.5.1 # if you want to enable fp4 fp4 = ["ml-dtypes>=0.5.1"] +# if you want to enable layout inference visualization vis = ["matplotlib"] [build-system] diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 1fd7dc12e..6f40c6245 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -70,7 +70,7 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() - + config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) config_str = str(config_value).strip().lower() diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index 9bc399f13..66e08294d 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -11,8 +11,7 @@ def print_layout_format(layout: T.Fragment) -> str: input_shape = layout.get_input_shape() output_shape = layout.get_output_shape() lines = [ - f" Shape: {input_shape} -> {output_shape}", - f" Thread: {layout.forward_thread}", + f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}" ] @@ -43,11 +42,13 @@ def visit_block_(self, op: tir.Block) -> None: self.visit_stmt(op.body) + def LayoutVisual(): def pass_fn(func: tir.PrimFunc, mod, ctx): pass_ctx = tilelang.transform.get_pass_context() - config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) + config_value = pass_ctx.config.get( + tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) config_str = str(config_value).strip().lower() if not config_str or config_str == "false": @@ -56,7 +57,7 @@ def pass_fn(func: tir.PrimFunc, mod, ctx): formats = "all" else: formats = config_str - + _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func diff --git a/tilelang/tools/plot_layout.py b/tilelang/tools/plot_layout.py index 00b208e8f..0d96b8129 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -1,4 +1,4 @@ -from torch import Value +from __future__ import annotations import tilelang.language as T @@ -209,10 +209,8 @@ def plot_layout(layout: T.Fragment, else: formats_list = [formats_str] else: - raise TypeError( - f"Expected str, but got {type(formats).__name__}. " - f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." - ) + raise TypeError(f"Expected str, but got {type(formats).__name__}. " + f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") # Save the figure if 'pdf' in formats_list: @@ -228,4 +226,4 @@ def plot_layout(layout: T.Fragment, if 'svg' in formats_list: svg_path = tmp_directory / f"{name}.svg" plt.savefig(svg_path, bbox_inches="tight", format="svg") - print(f"Saved svg format into {svg_path}") \ No newline at end of file + print(f"Saved svg format into {svg_path}") diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 3e4931862..db090aafe 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -75,7 +75,7 @@ class PassConfigKey(str, Enum): - "true" or "all": enabled, generate all formats (pdf, png, svg) - "png", "pdf", "svg": enabled, generate specified format - "png,svg": enabled, generate multiple formats (comma-separated) - """"" + """ "" TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace" """Control StorageRewrite inplace detection. From 360854b50ffded5039c6ebca50fd3367caac7679 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sun, 30 Nov 2025 12:12:09 +0800 Subject: [PATCH 09/21] fix lint --- tilelang/tools/layout_visual.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index 66e08294d..c2c6f516f 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -4,7 +4,6 @@ from tvm.tir import PyStmtExprVisitor from tvm.tir.transform import prim_func_pass -from tilelang.tools.plot_layout import plot_layout def print_layout_format(layout: T.Fragment) -> str: @@ -37,6 +36,7 @@ def visit_block_(self, op: tir.Block) -> None: if layout_id not in self.processed_layouts: print(f"{key} layout inference:") print(print_layout_format(layout)) + from tilelang.tools.plot_layout import plot_layout plot_layout(layout, name=f"{key}_layout", formats=self.formats) self.processed_layouts.add(layout_id) From bb9d49fda7c5cf101edea3ab2f347d78783c3019 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sun, 30 Nov 2025 12:20:09 +0800 Subject: [PATCH 10/21] fix lint --- tilelang/engine/phase.py | 6 +++++- tilelang/tools/layout_visual.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 6f40c6245..192f51992 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -71,7 +71,11 @@ def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() - config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) + config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, + "") + + if config_value is None: + return False config_str = str(config_value).strip().lower() return bool(config_str and config_str != "false") diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index c2c6f516f..66e08294d 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -4,6 +4,7 @@ from tvm.tir import PyStmtExprVisitor from tvm.tir.transform import prim_func_pass +from tilelang.tools.plot_layout import plot_layout def print_layout_format(layout: T.Fragment) -> str: @@ -36,7 +37,6 @@ def visit_block_(self, op: tir.Block) -> None: if layout_id not in self.processed_layouts: print(f"{key} layout inference:") print(print_layout_format(layout)) - from tilelang.tools.plot_layout import plot_layout plot_layout(layout, name=f"{key}_layout", formats=self.formats) self.processed_layouts.add(layout_id) From eac27c93c59185fff70ecffba5a7bb1e90a0376d Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sun, 30 Nov 2025 15:44:33 +0800 Subject: [PATCH 11/21] Rafactor configurable layout visualization formats --- .../visual_layout_inference.py | 2 +- tilelang/engine/phase.py | 11 ++++++++++- tilelang/tools/layout_visual.py | 12 +++++------- tilelang/transform/pass_config.py | 4 ++-- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index 7b275a833..fea511c58 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -6,7 +6,7 @@ # use pass_configs to enable layout visualization @tilelang.jit( - out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: "False"}) + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: "png, pdf"}) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 192f51992..595f9484e 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -78,7 +78,16 @@ def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: return False config_str = str(config_value).strip().lower() - return bool(config_str and config_str != "false") + valid_formats = ["png", "pdf", "svg", "all"] + formats_list = [f.strip() for f in config_str.split(",")] + + invalid_formats = [fmt for fmt in formats_list if fmt not in valid_formats] + if invalid_formats: + raise ValueError( + f"Invalid formats for TL_ENABLE_LAYOUT_VISUALIZATION: {invalid_formats}. " + f"Valid formats are: {valid_formats}. " + f"You can choose one of the valid formats or a comma-separated list of formats.") + return True def LayoutVisual(mod: IRModule) -> None: diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index 66e08294d..ed69d4acc 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -48,15 +48,13 @@ def LayoutVisual(): def pass_fn(func: tir.PrimFunc, mod, ctx): pass_ctx = tilelang.transform.get_pass_context() config_value = pass_ctx.config.get( - tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) + tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, "") - config_str = str(config_value).strip().lower() - if not config_str or config_str == "false": + if config_value is None: return func - elif config_str == "true": - formats = "all" - else: - formats = config_str + + config_str = str(config_value).strip().lower() + formats = config_str _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index db090aafe..6823abd9d 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -71,8 +71,8 @@ class PassConfigKey(str, Enum): TL_ENABLE_LAYOUT_VISUALIZATION = "tl.enable_layout_visualization" """Enable layout inference visualization. Accepts string values: - - "" or "false": disabled (default) - - "true" or "all": enabled, generate all formats (pdf, png, svg) + - "" or: disabled (default) + - "all": enabled, generate all formats (pdf, png, svg) - "png", "pdf", "svg": enabled, generate specified format - "png,svg": enabled, generate multiple formats (comma-separated) """ "" From fd25d67f8d950e0311ca5f35481d35147a06ad16 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sun, 30 Nov 2025 15:59:58 +0800 Subject: [PATCH 12/21] fix lint --- examples/visual_layout_inference/visual_layout_inference.py | 4 +--- tilelang/engine/phase.py | 4 ++-- tilelang/tools/layout_visual.py | 6 +----- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index fea511c58..7eabe741a 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -1,12 +1,10 @@ import tilelang import tilelang.language as T -tilelang.disable_cache() - # use pass_configs to enable layout visualization @tilelang.jit( - out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: "png, pdf"}) + out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: "png"}) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 595f9484e..759387a34 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -74,10 +74,10 @@ def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, "") - if config_value is None: + config_str = str(config_value).strip().lower() + if not config_str: return False - config_str = str(config_value).strip().lower() valid_formats = ["png", "pdf", "svg", "all"] formats_list = [f.strip() for f in config_str.split(",")] diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index ed69d4acc..963c1ff27 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -50,13 +50,9 @@ def pass_fn(func: tir.PrimFunc, mod, ctx): config_value = pass_ctx.config.get( tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, "") - if config_value is None: - return func - config_str = str(config_value).strip().lower() - formats = config_str - _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) + _LayoutVisualVisitor(formats=config_str).visit_stmt(func.body) return func return prim_func_pass(pass_fn, opt_level=0) From 86eb08a0632b7a0e49d1cf39bfd7c51c67470382 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sun, 30 Nov 2025 20:01:16 +0800 Subject: [PATCH 13/21] fix typo --- docs/tutorials/debug_tools_for_tilelang.md | 2 +- tilelang/tools/plot_layout.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/debug_tools_for_tilelang.md b/docs/tutorials/debug_tools_for_tilelang.md index 1fd513777..6203c8062 100644 --- a/docs/tutorials/debug_tools_for_tilelang.md +++ b/docs/tutorials/debug_tools_for_tilelang.md @@ -181,7 +181,7 @@ When TileLang performs layout inference, it determines how fragment buffers are The visual layout inference tool is controlled through the `TL_ENABLE_LAYOUT_VISUALIZATION` pass configuration. By default, visualization is **disabled** to avoid performance overhead during compilation. `TL_ENABLE_LAYOUT_VISUALIZATION` accepts string values to control output formats: -- "True" or "all": Enabled, generates all formats (PDF, PNG, SVG) +- "all": Enabled, generates all formats (PDF, PNG, SVG) - "png": Generate PNG format only - "pdf": Generate PDF format only - "svg": Generate SVG format only diff --git a/tilelang/tools/plot_layout.py b/tilelang/tools/plot_layout.py index 0d96b8129..a75fe3495 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -26,7 +26,7 @@ def plot_layout(layout: T.Fragment, formats : str | list[str], optional The formats to save the image in (default is "png"). Returns - -------s + ------- None """ import os From 2da20b0c3f5ff7834d63d1ec000c5ba08a106e90 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Mon, 1 Dec 2025 20:32:49 +0800 Subject: [PATCH 14/21] add some comments --- tilelang/tools/layout_visual.py | 55 ++++++++++++++++++++++++++++----- 1 file changed, 47 insertions(+), 8 deletions(-) diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index 963c1ff27..4fdea3dba 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -7,19 +7,58 @@ from tilelang.tools.plot_layout import plot_layout -def print_layout_format(layout: T.Fragment) -> str: - input_shape = layout.get_input_shape() - output_shape = layout.get_output_shape() - lines = [ - f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", - f" Index: {layout.forward_index}" - ] +def print_fragment_format(layout: T.Fragment) -> str: + """ + Format fragment layout information into a human-readable string. + + Parameters + ---------- + layout : T.Fragment + The fragment layout to format + + Returns + ------- + str + Formatted string showing shape, thread mapping, and index mapping + """ + if isinstance(layout, T.Fragment): + input_shape = layout.get_input_shape() + output_shape = layout.get_output_shape() + lines = [ + f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", + f" Index: {layout.forward_index}" + ] + else: + raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") return "\n".join(lines) @tir.functor.visitor class _LayoutVisualVisitor(PyStmtExprVisitor): + """ + User-friendly pass which visualizes fragment layouts inferred during compilation. + + In TileLang, Fragment layouts describe: + - How logical indices (e.g., [i, j]) map to thread IDs + - How logical indices map to register file locations within each thread + - The shape transformation from input dimensions to output dimensions + + This pass generates two types of output: + 1. Textual output: A human-readable description printed to console + 2. Visual diagrams: Color-coded plots saved to files (PDF, PNG, SVG formats) + + Configuration: + The pass is controlled by the TL_ENABLE_LAYOUT_VISUALIZATION configuration option. + The configuration accepts string values: + + - Empty string or not set: Pass does nothing (default, disabled) + - "png": Generate PNG format only (recommended for quick inspection) + - "pdf": Generate PDF format only (recommended for documentation) + - "svg": Generate SVG format only (recommended for web/vector graphics) + - "all": Generate all formats (PDF, PNG, SVG) + - "png,svg": Generate multiple formats (comma-separated) + """ def __init__(self, formats: str = "png"): super().__init__() @@ -36,7 +75,7 @@ def visit_block_(self, op: tir.Block) -> None: layout_id = str(layout) if layout_id not in self.processed_layouts: print(f"{key} layout inference:") - print(print_layout_format(layout)) + print(print_fragment_format(layout)) plot_layout(layout, name=f"{key}_layout", formats=self.formats) self.processed_layouts.add(layout_id) From 03dc6b861cbd74908b8d4b6059627f24c19dbd37 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Mon, 1 Dec 2025 20:48:22 +0800 Subject: [PATCH 15/21] fix lints --- tilelang/tools/layout_visual.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index 4fdea3dba..8d207e040 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -28,11 +28,10 @@ def print_fragment_format(layout: T.Fragment) -> str: f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", f" Index: {layout.forward_index}" ] + print("\n".join(lines)) else: raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") - return "\n".join(lines) - @tir.functor.visitor class _LayoutVisualVisitor(PyStmtExprVisitor): @@ -75,7 +74,7 @@ def visit_block_(self, op: tir.Block) -> None: layout_id = str(layout) if layout_id not in self.processed_layouts: print(f"{key} layout inference:") - print(print_fragment_format(layout)) + print_fragment_format(layout) plot_layout(layout, name=f"{key}_layout", formats=self.formats) self.processed_layouts.add(layout_id) From 5d075f2c20c50e894786d2987973dd6252aec6b3 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Mon, 1 Dec 2025 20:55:58 +0800 Subject: [PATCH 16/21] add some warnings for user --- tilelang/tools/plot_layout.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tilelang/tools/plot_layout.py b/tilelang/tools/plot_layout.py index a75fe3495..06e01f489 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -87,8 +87,17 @@ def plot_layout(layout: T.Fragment, # Show the distribution of registers in each thread of a warp. warp_size = 32 + # Warn if the number of threads is less than the warp size + if num_threads < warp_size: + import warnings + warnings.warn( + f"Layout visualization has {num_threads} threads, which is less than the warp size ({warp_size}). " + f"For the best viewing experience, it is recommended to have at least {warp_size} threads.", + UserWarning, + stacklevel=2) spectral_camp = plt.get_cmap("hsv", warp_size * 6) - for i in range(warp_size): + + for i in range(min(warp_size, num_threads)): colors[i] = spectral_camp(i * 6) # Determine the number of rows and columns in the input shape From 7b1532d0ece1246f20942b4191805b07733b85d4 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Mon, 1 Dec 2025 21:07:11 +0800 Subject: [PATCH 17/21] Moves layout visualization --- tilelang/engine/lower.py | 4 ---- tilelang/engine/phase.py | 2 ++ 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 5f6a2f099..88d89dcc2 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -16,7 +16,6 @@ from tilelang.engine.param import KernelParam, CompiledArtifact from tilelang.utils.target import determine_target from tilelang.engine.phase import ( - LayoutVisual, PreLowerSemanticCheck, LowerAndLegalize, OptimizeForTarget, @@ -250,9 +249,6 @@ def lower( # Phase 1: Lower and legalize the IR mod = LowerAndLegalize(mod, target) - # Visualize the layout - LayoutVisual(mod) - # Phase 2: Optimize the IR for the target mod = OptimizeForTarget(mod, target) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 759387a34..10ff57c30 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -152,6 +152,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tilelang.transform.LayoutReducer()(mod) # Infer memory layouts for fragments and shared memory mod = tilelang.transform.LayoutInference()(mod) + # Visualize the layout + LayoutVisual(mod) # Lower high-level tile operations to low-level operations mod = tilelang.transform.LowerTileOp()(mod) # Lower l2 persistent map From cf9b3c7c155e0786573919b8833e111f1ac36b18 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Fri, 5 Dec 2025 13:38:39 +0800 Subject: [PATCH 18/21] Refactors layout visualization pass configuration Updates the layout visualization pass configuration to use boolean flag for enabling and a string for specifying formats. --- .../visual_layout_inference.py | 6 ++++- src/op/builtin.cc | 3 ++- src/op/builtin.h | 6 +++-- tilelang/engine/phase.py | 26 ++++++++++--------- tilelang/tools/layout_visual.py | 20 +++++--------- tilelang/transform/pass_config.py | 15 ++++++----- 6 files changed, 40 insertions(+), 36 deletions(-) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index 7eabe741a..e92b7345e 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -4,7 +4,11 @@ # use pass_configs to enable layout visualization @tilelang.jit( - out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION: "png"}) + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True, + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "pdf" + }) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @T.prim_func diff --git a/src/op/builtin.cc b/src/op/builtin.cc index a9f9f541f..260ba6fa9 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -34,7 +34,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool); TVM_REGISTER_PASS_CONFIG_OPTION(kStorageRewriteDetectInplace, Bool); -TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLayoutVisualization, ffi::String); +TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationEnable, Bool); +TVM_REGISTER_PASS_CONFIG_OPTION(kLayoutVisualizationFormats, String); DataType cuTensorMapType() { return DataType::UInt(8, 128); } diff --git a/src/op/builtin.h b/src/op/builtin.h index a94801649..ea861d067 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -51,8 +51,10 @@ static constexpr const char *kDisableWGMMA = "tl.disable_wgmma"; static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect"; static constexpr const char *kStorageRewriteDetectInplace = "tl.storage_rewrite_detect_inplace"; -static constexpr const char *kEnableLayoutVisualization = - "tl.enable_layout_visualization"; +static constexpr const char *kLayoutVisualizationEnable = + "tl.layout_visualization_enable"; +static constexpr const char *kLayoutVisualizationFormats = + "tl.layout_visualization_formats"; /*! * \brief Whether to disable dynamic tail split * diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 7cde12f49..e4646fbe9 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -70,30 +70,32 @@ def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() + enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False) + return enabled - config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, - "") - config_str = str(config_value).strip().lower() - if not config_str: - return False +def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: + if pass_ctx is None: + pass_ctx = tilelang.transform.get_pass_context() + formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "") + if not formats_value: + return "" + formats_str = formats_value.strip().lower() valid_formats = ["png", "pdf", "svg", "all"] - formats_list = [f.strip() for f in config_str.split(",")] - - invalid_formats = [fmt for fmt in formats_list if fmt not in valid_formats] - if invalid_formats: + if formats_str not in valid_formats: raise ValueError( - f"Invalid formats for TL_ENABLE_LAYOUT_VISUALIZATION: {invalid_formats}. " + f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {formats_str}. " f"Valid formats are: {valid_formats}. " f"You can choose one of the valid formats or a comma-separated list of formats.") - return True + return formats_str def LayoutVisual(mod: IRModule) -> None: """Apply layout visualization pass if enabled.""" if should_enable_layout_visual(): - tilelang.tools.LayoutVisual()(mod) + formats = get_layout_visual_formats() + tilelang.tools.LayoutVisual(formats=formats)(mod) def PreLowerSemanticCheck(mod: IRModule) -> None: diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index 8d207e040..3711e8cb7 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -1,4 +1,3 @@ -import tilelang import tilelang.language as T from tvm import tir from tvm.tir import PyStmtExprVisitor @@ -59,11 +58,11 @@ class _LayoutVisualVisitor(PyStmtExprVisitor): - "png,svg": Generate multiple formats (comma-separated) """ - def __init__(self, formats: str = "png"): + def __init__(self, formats: str = ""): super().__init__() self.layout_found = [] self.processed_layouts = set() - self.formats = formats + self.formats = formats.strip().lower() if formats else "" def visit_block_(self, op: tir.Block) -> None: if "layout_map" in op.annotations: @@ -75,22 +74,17 @@ def visit_block_(self, op: tir.Block) -> None: if layout_id not in self.processed_layouts: print(f"{key} layout inference:") print_fragment_format(layout) - plot_layout(layout, name=f"{key}_layout", formats=self.formats) + if self.formats: + plot_layout(layout, name=f"{key}_layout", formats=self.formats) self.processed_layouts.add(layout_id) - self.visit_stmt(op.body) + # super().visit_block_(op) -def LayoutVisual(): +def LayoutVisual(formats: str = ""): def pass_fn(func: tir.PrimFunc, mod, ctx): - pass_ctx = tilelang.transform.get_pass_context() - config_value = pass_ctx.config.get( - tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, "") - - config_str = str(config_value).strip().lower() - - _LayoutVisualVisitor(formats=config_str).visit_stmt(func.body) + _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func return prim_func_pass(pass_fn, opt_level=0) diff --git a/tilelang/transform/pass_config.py b/tilelang/transform/pass_config.py index 6823abd9d..92adcb42c 100644 --- a/tilelang/transform/pass_config.py +++ b/tilelang/transform/pass_config.py @@ -69,13 +69,14 @@ class PassConfigKey(str, Enum): TL_FORCE_LET_INLINE = "tl.force_let_inline" """Force TileLang to inline let bindings during simplification. Default: False""" - TL_ENABLE_LAYOUT_VISUALIZATION = "tl.enable_layout_visualization" - """Enable layout inference visualization. Accepts string values: - - "" or: disabled (default) - - "all": enabled, generate all formats (pdf, png, svg) - - "png", "pdf", "svg": enabled, generate specified format - - "png,svg": enabled, generate multiple formats (comma-separated) - """ "" + TL_LAYOUT_VISUALIZATION_ENABLE = "tl.layout_visualization_enable" + """Enable layout inference visualization. Default: False""" + + TL_LAYOUT_VISUALIZATION_FORMATS = "tl.layout_visualization_formats" + """Layout visualization formats. + Acceptable values: "pdf", "png", "svg", "all" + + """ TL_STORAGE_REWRITE_DETECT_INPLACE = "tl.storage_rewrite_detect_inplace" """Control StorageRewrite inplace detection. From c1230a47f1cd8640c30e0dae5efb20d5e2158f44 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Fri, 5 Dec 2025 23:07:59 +0800 Subject: [PATCH 19/21] Enables multiple layout visualization formats --- .../visual_layout_inference.py | 4 ++-- tilelang/engine/phase.py | 23 ++++++++++++++----- tilelang/tools/layout_visual.py | 8 +++---- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/examples/visual_layout_inference/visual_layout_inference.py b/examples/visual_layout_inference/visual_layout_inference.py index e92b7345e..3677d4754 100644 --- a/examples/visual_layout_inference/visual_layout_inference.py +++ b/examples/visual_layout_inference/visual_layout_inference.py @@ -7,7 +7,7 @@ out_idx=[-1], pass_configs={ tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE: True, - tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "pdf" + tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg" }) def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): @@ -50,7 +50,7 @@ def main(): # print the layout visualization result and save figures to ./tmp. ''' - C_local layout inference: + C_local inferenced layout: Shape: [32, 32] -> [8] Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index e4646fbe9..307c2fbd1 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -79,16 +79,27 @@ def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: pass_ctx = tilelang.transform.get_pass_context() formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "") if not formats_value: - return "" + return ["txt"] formats_str = formats_value.strip().lower() - valid_formats = ["png", "pdf", "svg", "all"] - if formats_str not in valid_formats: + valid_formats = ["txt", "png", "pdf", "svg", "all"] + + if formats_str == "all": + return ["txt", "png", "pdf", "svg"] + + if "," in formats_str: + formats_list = [f.strip() for f in formats_str.split(',')] + else: + formats_list = [formats_str] + + invalid_formats = [f for f in formats_list if f not in valid_formats] + if invalid_formats: raise ValueError( - f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {formats_str}. " + f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. " f"Valid formats are: {valid_formats}. " - f"You can choose one of the valid formats or a comma-separated list of formats.") - return formats_str + f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')" + ) + return formats_list def LayoutVisual(mod: IRModule) -> None: diff --git a/tilelang/tools/layout_visual.py b/tilelang/tools/layout_visual.py index 3711e8cb7..f27eed7d2 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/tools/layout_visual.py @@ -58,11 +58,11 @@ class _LayoutVisualVisitor(PyStmtExprVisitor): - "png,svg": Generate multiple formats (comma-separated) """ - def __init__(self, formats: str = ""): + def __init__(self, formats: list[str] = ""): super().__init__() self.layout_found = [] self.processed_layouts = set() - self.formats = formats.strip().lower() if formats else "" + self.formats_list = [f for f in formats if f != "txt"] def visit_block_(self, op: tir.Block) -> None: if "layout_map" in op.annotations: @@ -74,8 +74,8 @@ def visit_block_(self, op: tir.Block) -> None: if layout_id not in self.processed_layouts: print(f"{key} layout inference:") print_fragment_format(layout) - if self.formats: - plot_layout(layout, name=f"{key}_layout", formats=self.formats) + for fmt in self.formats_list: + plot_layout(layout, name=f"{key}_layout", formats=fmt) self.processed_layouts.add(layout_id) # super().visit_block_(op) From 9282abd9879a4a7508e70ab4f5dbb94338c6ffd1 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Fri, 5 Dec 2025 23:12:17 +0800 Subject: [PATCH 20/21] Updates layout visualization docs --- docs/tutorials/debug_tools_for_tilelang.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/tutorials/debug_tools_for_tilelang.md b/docs/tutorials/debug_tools_for_tilelang.md index 6203c8062..f8dfaab82 100644 --- a/docs/tutorials/debug_tools_for_tilelang.md +++ b/docs/tutorials/debug_tools_for_tilelang.md @@ -178,17 +178,19 @@ When TileLang performs layout inference, it determines how fragment buffers are 1. **Textual output**: A human-readable description of the layout mapping 2. **Visual diagrams**: Color-coded plots showing the thread-to-data mapping -The visual layout inference tool is controlled through the `TL_ENABLE_LAYOUT_VISUALIZATION` pass configuration. By default, visualization is **disabled** to avoid performance overhead during compilation. +The visual layout inference tool is controlled through the `TL_LAYOUT_VISUALIZATION_ENABLE` and `TL_LAYOUT_VISUALIZATION_FORMATS` pass configuration. By default, `TL_LAYOUT_VISUALIZATION_ENABLE` is **disabled** to avoid performance overhead during compilation. -`TL_ENABLE_LAYOUT_VISUALIZATION` accepts string values to control output formats: -- "all": Enabled, generates all formats (PDF, PNG, SVG) +When enabled, `TL_LAYOUT_VISUALIZATION_FORMATS` accepts string values to control output formats: +- "txt": Text output only (same as default) +- "all": Generates all formats (TXT, PDF, PNG, SVG) - "png": Generate PNG format only - "pdf": Generate PDF format only - "svg": Generate SVG format only +- "txt,svg": Generate multiple formats (comma-separated) in addition to text output -The output messages will include something like: +The output messages of "txt" will include something like: ``` -C_local layout inference: +C_local inferenced layout: Shape: [32, 32] -> [8] Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] From d897a95105e631878be3201148d04aa7f68c25b7 Mon Sep 17 00:00:00 2001 From: Cunxiao2002 Date: Sat, 6 Dec 2025 13:08:40 +0800 Subject: [PATCH 21/21] Moves layout visualization to analysis --- tilelang/analysis/__init__.py | 1 + tilelang/{tools => analysis}/layout_visual.py | 2 +- tilelang/engine/phase.py | 2 +- tilelang/tools/__init__.py | 1 - 4 files changed, 3 insertions(+), 3 deletions(-) rename tilelang/{tools => analysis}/layout_visual.py (98%) diff --git a/tilelang/analysis/__init__.py b/tilelang/analysis/__init__.py index 33ccded64..4e4090d80 100644 --- a/tilelang/analysis/__init__.py +++ b/tilelang/analysis/__init__.py @@ -3,3 +3,4 @@ from .ast_printer import ASTPrinter # noqa: F401 from .nested_loop_checker import NestedLoopChecker # noqa: F401 from .fragment_loop_checker import FragmentLoopChecker # noqa: F401 +from .layout_visual import LayoutVisual # noqa: F401 diff --git a/tilelang/tools/layout_visual.py b/tilelang/analysis/layout_visual.py similarity index 98% rename from tilelang/tools/layout_visual.py rename to tilelang/analysis/layout_visual.py index f27eed7d2..782b9126d 100644 --- a/tilelang/tools/layout_visual.py +++ b/tilelang/analysis/layout_visual.py @@ -72,7 +72,7 @@ def visit_block_(self, op: tir.Block) -> None: if isinstance(layout, T.Fragment): layout_id = str(layout) if layout_id not in self.processed_layouts: - print(f"{key} layout inference:") + print(f"{key} inferenced layout:") print_fragment_format(layout) for fmt in self.formats_list: plot_layout(layout, name=f"{key}_layout", formats=fmt) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 307c2fbd1..e9eac8ac2 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -106,7 +106,7 @@ def LayoutVisual(mod: IRModule) -> None: """Apply layout visualization pass if enabled.""" if should_enable_layout_visual(): formats = get_layout_visual_formats() - tilelang.tools.LayoutVisual(formats=formats)(mod) + tilelang.analysis.LayoutVisual(formats=formats)(mod) def PreLowerSemanticCheck(mod: IRModule) -> None: diff --git a/tilelang/tools/__init__.py b/tilelang/tools/__init__.py index 06a009d85..7a8bde514 100644 --- a/tilelang/tools/__init__.py +++ b/tilelang/tools/__init__.py @@ -1,3 +1,2 @@ from .plot_layout import plot_layout # noqa: F401 from .Analyzer import * -from .layout_visual import LayoutVisual # noqa: F401