diff --git a/examples/plot_layout/layout_swizzle.py b/examples/plot_layout/layout_swizzle.py new file mode 100644 index 000000000..c8bd52600 --- /dev/null +++ b/examples/plot_layout/layout_swizzle.py @@ -0,0 +1,32 @@ +from tilelang.layout import ( + make_full_bank_swizzled_layout, + make_half_bank_swizzled_layout, + make_quarter_bank_swizzled_layout, +) +from tilelang.tools import plot_layout + +element_size = 16 # float16 = 16 bits + + +# ---- Plot the swizzle patterns ---- + +# 1. Quarter-bank (32B) — 1-bit XOR — 8x16 +# Rows 0-3: identity; Rows 4-7: two 8-element halves swap +layout = make_quarter_bank_swizzled_layout(8, 16, element_size) +print(f"Quarter-bank swizzle (8x16, fp16): {layout}") +plot_layout(layout, name="swizzle_quarter_8x16") + +# 2. Half-bank (64B) — 2-bit XOR — 8x32 +layout = make_half_bank_swizzled_layout(8, 32, element_size) +print(f"Half-bank swizzle (8x32, fp16): {layout}") +plot_layout(layout, name="swizzle_half_8x32") + +# 3. Full-bank (128B) — 3-bit XOR — 8x64 +layout = make_full_bank_swizzled_layout(8, 64, element_size) +print(f"Full-bank swizzle (8x64, fp16): {layout}") +plot_layout(layout, name="swizzle_full_8x64") + +# 4. Full-bank (128B) — multi-tile: 32x64 +layout = make_full_bank_swizzled_layout(32, 64, element_size) +print(f"Full-bank swizzle (32x64, fp16): {layout}") +plot_layout(layout, name="swizzle_full_32x64") diff --git a/examples/plot_layout/layout_transform.py b/examples/plot_layout/layout_transform.py new file mode 100644 index 000000000..144b3744f --- /dev/null +++ b/examples/plot_layout/layout_transform.py @@ -0,0 +1,24 @@ +import tilelang.language as T +from tilelang.tools import plot_layout + +# --- Example 1: Simple 2D Transpose (4x4) --- +transpose_layout = T.Layout([4, 4], lambda i, j: (j, i)) +print("Transpose 4x4:", transpose_layout) +plot_layout(transpose_layout, name="transpose_4x4") + +# --- Example 2: Larger Transpose (8x8) --- +transpose_8x8 = T.Layout([8, 8], lambda i, j: (j, i)) +print("Transpose 8x8:", transpose_8x8) +plot_layout(transpose_8x8, name="transpose_8x8") + +# --- Example 3: 3D → 2D reshape + transpose --- +# (i, j, k) with shape [2, 4, 8] → (k, i*4+j) +reshape_layout = T.Layout([2, 4, 8], lambda i, j, k: (k, i * 4 + j)) +print("Reshape 3D [2,4,8] -> [8,8]:", reshape_layout) +plot_layout(reshape_layout, name="reshape_3d_to_2d") + +# --- Example 4: Interleave layout --- +# Even rows from first half, odd rows from second half +interleave = T.Layout([8, 4], lambda i, j: (i % 4 * 2 + i // 4, j)) +print("Interleave [8,4]:", interleave) +plot_layout(interleave, name="interleave_8x4") diff --git a/tilelang/tools/plot_layout.py b/tilelang/tools/plot_layout.py index 887882e92..963931c94 100644 --- a/tilelang/tools/plot_layout.py +++ b/tilelang/tools/plot_layout.py @@ -4,36 +4,100 @@ def plot_layout( - layout: T.Fragment, + layout, save_directory="./tmp", name: str = "layout", - colormap: str = "RdPu", + colormap: str = None, verbose: bool = False, - formats: str | list[str] = "png", + formats: str | list[str] = "pdf", ) -> None: """ - Plot the layout of a buffer. + Plot the layout mapping as a 2D grid visualization. + + Dispatches to Fragment-specific or Layout-specific plotting based on the + type of the layout object. Parameters ---------- - layout : T.Layout - The layout object that describes how indices are mapped. + layout : T.Layout or T.Fragment + The layout object to visualize. save_directory : str, optional - The directory where the output images will be saved (default is "./tmp"). + Output directory (default "./tmp"). name : str, optional - The base name of the output files (default is "layout"). + Base filename for saved images (default "layout"). colormap : str, optional - The colormap to use for visualization (default is "RdPu"). + Matplotlib colormap name. Defaults to "RdPu" for Fragment, "Spectral" for Layout. verbose : bool, optional - If True, prints additional information about the mapping (default is False). + If True, print mapping details. formats : str | list[str], optional - The formats to save the image in (default is "png"). - Returns - ------- - None + Output format(s): "pdf", "png", "svg", "all", or comma-separated (default "pdf"). """ + from tilelang.layout.fragment import Fragment + + if isinstance(layout, Fragment): + _plot_fragment_layout(layout, save_directory, name, colormap or "RdPu", verbose, formats) + elif isinstance(layout, T.Layout): + _plot_layout_map(layout, save_directory, name, colormap or "Spectral", verbose, formats) + else: + raise TypeError(f"Expected T.Layout or T.Fragment, but got {type(layout).__name__}.") + + +def _parse_formats(formats): + """Parse the formats parameter into a list of format strings.""" + if isinstance(formats, str): + formats_str = formats.strip().lower() + if formats_str == "all": + return ["pdf", "png", "svg"] + elif "," in formats_str: + return [f.strip() for f in formats_str.split(",")] + else: + return [formats_str] + else: + raise TypeError( + f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." + ) + + +def _save_plot(plt, save_directory, name, formats): + """Save the current matplotlib figure in the specified format(s).""" import os import pathlib + + formats_list = _parse_formats(formats) + + tmp_directory = pathlib.Path(save_directory) + if not os.path.exists(tmp_directory): + os.makedirs(tmp_directory) + + if "pdf" in formats_list: + pdf_path = tmp_directory / f"{name}.pdf" + plt.savefig(pdf_path, bbox_inches="tight") + print(f"Saved pdf format into {pdf_path}") + + if "png" in formats_list: + png_path = tmp_directory / f"{name}.png" + plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) + print(f"Saved png format into {png_path}") + + if "svg" in formats_list: + svg_path = tmp_directory / f"{name}.svg" + plt.savefig(svg_path, bbox_inches="tight", format="svg") + print(f"Saved svg format into {svg_path}") + + +# --------------------------------------------------------------------------- +# Fragment-specific layout visualization (thread ID + local ID per cell) +# --------------------------------------------------------------------------- + + +def _plot_fragment_layout( + layout: T.Fragment, + save_directory="./tmp", + name: str = "layout", + colormap: str = "RdPu", + verbose: bool = False, + formats: str | list[str] = "pdf", +) -> None: import numpy as np import matplotlib.pyplot as plt import matplotlib.patches as patches @@ -46,8 +110,6 @@ def plot_layout( # Get the total number of threads num_threads = int(layout.get_thread_size()) - import itertools - # Initialize a 2D array to store thread mappings thread_map = np.empty(input_shape, dtype=object) for idx in np.ndindex(thread_map.shape): @@ -177,42 +239,211 @@ def plot_layout( ncols=2, ) - # Create the output directory if it does not exist - tmp_directory = pathlib.Path(save_directory) - if not os.path.exists(tmp_directory): - os.makedirs(tmp_directory) - - # Save the figure in multiple formats plt.tight_layout() + _save_plot(plt, save_directory, name, formats) + plt.close() - if isinstance(formats, str): - formats_str = formats.strip().lower() - if formats_str == "all": - formats_list = ["pdf", "png", "svg"] - elif "," in formats_str: - formats_list = [f.strip() for f in formats_str.split(",")] - else: - formats_list = [formats_str] - else: - raise TypeError( - f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." - ) - # Save the figure - if "pdf" in formats_list: - pdf_path = tmp_directory / f"{name}.pdf" - plt.savefig(pdf_path, bbox_inches="tight") - print(f"Saved pdf format into {pdf_path}") +# --------------------------------------------------------------------------- +# Layout-specific visualization (position mapping, no thread/local ID) +# --------------------------------------------------------------------------- - if "png" in formats_list: - png_path = tmp_directory / f"{name}.png" - plt.savefig(png_path, bbox_inches="tight", transparent=False, dpi=255) - print(f"Saved png format into {png_path}") - if "svg" in formats_list: - svg_path = tmp_directory / f"{name}.svg" - plt.savefig(svg_path, bbox_inches="tight", format="svg") - print(f"Saved svg format into {svg_path}") +def _plot_layout_map( + layout: T.Layout, + save_directory="./tmp", + name: str = "layout", + colormap: str = "Spectral", + verbose: bool = False, + formats: str | list[str] = "pdf", +) -> None: + """ + Visualize a Layout object as a 2D grid showing position mappings. + + The grid represents the output space (viewed as 2D by keeping the last + dimension and flattening all preceding dimensions). Each cell displays the + original input coordinate that maps to that output position. + + Parameters + ---------- + layout : T.Layout + The layout object to visualize. + save_directory : str + Output directory. + name : str + Base filename. + colormap : str + Matplotlib colormap for coloring cells by source position. + verbose : bool + Print mapping details. + formats : str | list[str] + Output format(s). + """ + import functools + import operator + import numpy as np + import matplotlib.pyplot as plt + import matplotlib.patches as patches + + input_shape = [int(v) for v in layout.get_input_shape()] + total_in = functools.reduce(operator.mul, input_shape, 1) + + # -- helpers for N-D → 2-D conversion -------------------------------- + + def _flatten_to_2d(shape): + """Keep last dim, merge all preceding dims into one row dim.""" + if len(shape) <= 1: + return (1, shape[0]) if shape else (1, 1) + return (functools.reduce(operator.mul, shape[:-1], 1), shape[-1]) + + def _nd_to_2d(idx, shape): + """Convert an N-D index to (row, col) in the flattened 2-D view.""" + if len(shape) <= 1: + return (0, idx[0]) if shape else (0, 0) + row = 0 + for k in range(len(shape) - 1): + row = row * shape[k] + idx[k] + return (row, idx[-1]) + + # -- collect all input→output mappings --------------------------------- + + mappings = [] + num_out_dims = None + for in_idx in itertools.product(*[range(d) for d in input_shape]): + out_vals = layout.map_forward_index(list(in_idx)) + out_idx = tuple(int(v) for v in out_vals) + if num_out_dims is None: + num_out_dims = len(out_idx) + mappings.append((tuple(in_idx), out_idx)) + + # determine output shape from actual output indices + output_shape = [0] * num_out_dims + for _, out_idx in mappings: + for k in range(num_out_dims): + output_shape[k] = max(output_shape[k], out_idx[k] + 1) + + out_rows, out_cols = _flatten_to_2d(output_shape) + + if verbose: + print(f"Input shape : {input_shape}") + print(f"Output shape: {output_shape}") + print(f"Grid size : {out_rows} x {out_cols}") + + # -- build the output grid --------------------------------------------- + + grid_labels = [[None] * out_cols for _ in range(out_rows)] + grid_src_flat = np.full((out_rows, out_cols), -1, dtype=int) + + for in_idx, out_idx in mappings: + out_r, out_c = _nd_to_2d(out_idx, output_shape) + # flat source index for colour mapping + src_flat = 0 + for k in range(len(input_shape)): + src_flat = src_flat * input_shape[k] + in_idx[k] + + grid_labels[out_r][out_c] = list(in_idx) + grid_src_flat[out_r, out_c] = src_flat + + if verbose: + print(f" {list(in_idx)} -> {list(out_idx)} -> grid[{out_r}, {out_c}]") + + # -- plotting ---------------------------------------------------------- + + cmap = plt.get_cmap(colormap, max(total_in, 2)) + + # dynamic sizing + max_dim = max(out_rows, out_cols, 1) + cell_size = max(0.5, min(1.2, 16.0 / max_dim)) + + fig_w = cell_size * out_cols + 1.5 + fig_h = cell_size * out_rows + 1.0 + fig, ax = plt.subplots(figsize=(fig_w, fig_h)) + + # font size: adapt to cell size and longest label + sample_label = "[" + ",".join(str(d - 1) for d in input_shape) + "]" + max_label_len = len(sample_label) + cell_pts = cell_size * 72 # approximate cell width in points + base_font = max(5, min(16, cell_pts * 0.9 / max(max_label_len * 0.55, 1))) + + for i in range(out_rows): + for j in range(out_cols): + sf = grid_src_flat[i, j] + if sf >= 0: + color = cmap(sf / max(total_in - 1, 1)) + else: + color = (0.95, 0.95, 0.95, 1.0) + + rect = patches.Rectangle( + (j, i), + 1, + 1, + linewidth=0.8, + edgecolor="#aaaaaa", + facecolor=color, + ) + ax.add_patch(rect) + + coords = grid_labels[i][j] + if coords is not None: + label = "[" + ",".join(str(x) for x in coords) + "]" + r, g, b = color[0], color[1], color[2] + brightness = r * 0.299 + g * 0.587 + b * 0.114 + text_color = "white" if brightness < 0.5 else "black" + ax.text( + j + 0.5, + i + 0.5, + label, + ha="center", + va="center", + color=text_color, + fontsize=base_font, + fontfamily="monospace", + fontweight="bold", + ) + + # axis labels + label_font = max(5, min(10, base_font * 0.85)) + # row labels on the left + for i in range(out_rows): + ax.text(-0.15, i + 0.5, str(i), ha="right", va="center", fontsize=label_font, color="#666666") + # column labels at the bottom + for j in range(out_cols): + ax.text(j + 0.5, out_rows + 0.15, str(j), ha="center", va="top", fontsize=label_font, color="#666666") + + ax.set_xlim(-0.3, out_cols) + ax.set_ylim(-0.1, out_rows + 0.5) + ax.invert_yaxis() + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + for spine in ax.spines.values(): + spine.set_visible(False) + + # outer border + border = patches.Rectangle( + (0, 0), + out_cols, + out_rows, + linewidth=1.5, + edgecolor="#333333", + facecolor="none", + ) + ax.add_patch(border) + + # title: show shape transformation + in_str = "x".join(str(d) for d in input_shape) + out_str = "x".join(str(d) for d in output_shape) + title_font = max(8, min(14, base_font * 1.1)) + ax.set_title(f"[{in_str}] -> [{out_str}]", fontsize=title_font, color="#333333", pad=8) + + plt.tight_layout() + _save_plot(plt, save_directory, name, formats) + plt.close() + + +# --------------------------------------------------------------------------- +# Fragment thread-value (TV) view +# --------------------------------------------------------------------------- def plot_fragment_tv( @@ -222,7 +453,7 @@ def plot_fragment_tv( apply_idx_fn=lambda *args: args, colormap: str = "RdPu", item_scale: float = 0.75, - formats: str | list[str] = "png", + formats: str | list[str] = "pdf", dpi=80, ): """ @@ -242,7 +473,7 @@ def plot_fragment_tv( item_scale : float, optional The scale factor for each item in the plot (default is 0.75). formats : str | list[str], optional - The formats to save the image in (default is "png"). + The formats to save the image in (default is "pdf"). dpi : int, optional The resolution in dots per inch for the saved image (default is 80). """