-
Notifications
You must be signed in to change notification settings - Fork 446
[Enhancement] Improve plot_layout visualization for Layouts #1811
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Enhancement] Improve plot_layout visualization for Layouts #1811
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
|
Caution Review failedThe pull request is closed. 📝 WalkthroughWalkthroughRefactors plot_layout to dispatch between Fragment and Layout rendering, adds format parsing and multi-format saving helpers, and introduces two new example scripts for swizzle and transform layout visualizations. Changes
Sequence DiagramsequenceDiagram
participant User
participant plot_layout as plot_layout()
participant TypeCheck as "isinstance(layout)"
participant FragPath as _plot_fragment_layout()
participant LayoutPath as _plot_layout_map()
participant Parser as _parse_formats()
participant Saver as _save_plot()
participant MPL as matplotlib
User->>plot_layout: call(layout, name, save_directory, colormap, formats, verbose)
plot_layout->>TypeCheck: determine layout type
alt Fragment
TypeCheck->>FragPath: render fragment cells (thread/local ids)
FragPath->>MPL: draw plot (colormap default "RdPu")
else Layout
TypeCheck->>LayoutPath: render layout map (positions)
LayoutPath->>MPL: draw plot (colormap default "Spectral")
else Invalid
TypeCheck-->>User: raise TypeError
end
MPL->>Parser: hand off plt object
Parser->>Saver: provide parsed formats list
Saver->>Saver: ensure save_directory exists
Saver->>MPL: savefig in each requested format
Saver-->>User: files saved
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In `@examples/plot_layout/README.md`:
- Around line 107-122: The two section titles "Layout Example (1D/2D)" and
"Swizzle Layout Examples" are plain text and should be Markdown headings; update
the README so those lines are prefixed with heading markers (e.g., add "## " or
"### " before "Layout Example (1D/2D)" and before "Swizzle Layout Examples") to
ensure they render as proper section headings and improve document structure.
In `@tilelang/tools/plot_layout.py`:
- Around line 263-265: The assertion currently compares each element to itself
due to variable shadowing (local_id used both as the reference and as the
generator loop variable); update the assertion in plot_layout.py so it compares
the chosen reference local_id to each element in local_ids (e.g., assert
all(local_id == lid for lid in local_ids)) or rename the generator variable to
avoid shadowing; ensure the assertion uses a different loop variable name (like
lid) so mismatched local IDs are detected.
- Around line 8-16: The _parse_formats function's signature allows str |
list[str] but only handles str; update it to accept list inputs by detecting
isinstance(formats, list) and normalizing each entry (strip and lower),
expanding "all" to ["pdf","png","svg"] if any element equals "all", and
validating that every element is a string (raise TypeError for invalid element
types), or alternatively change the type hint to just str if you intend to only
support strings; make this change in the _parse_formats function so callers
passing list[str] are properly handled.
- Around line 236-238: The T.Fragment visualization branch currently unpacks
input_shape with nrows, ncols = input_shape which fails for 1D shapes; update
that branch to mirror the T.Layout handling by checking the length of
input_shape and when it's 1 set nrows = input_shape[0] and ncols = 1 (otherwise
unpack as before), so the T.Fragment code (the block that defines nrows, ncols
for fragment visualization) handles 1D inputs the same way as the T.Layout
branch.
🧹 Nitpick comments (3)
tilelang/tools/plot_layout.py (3)
36-39:_linear_to_input_viewsilently produces wrong results for >2D input shapes.This function only handles 1D and 2D. For ≥3D, Line 39 computes a row/col using only
input_shape[1], ignoring higher dimensions. While the caller inplot_layoutguards against >2D, a future caller won't get any warning. Consider adding a guard or an assertion.Proposed defensive check
def _linear_to_input_view(linear_idx: int, input_shape: list[int]) -> list[int]: + assert len(input_shape) <= 2, f"_linear_to_input_view only supports 1D/2D, got {input_shape}" if len(input_shape) == 1: return [linear_idx] return [linear_idx // input_shape[1], linear_idx % input_shape[1]]
102-104: Redundantint()calls flagged by Ruff (RUF046).
math.ceilalready returnsintin Python 3. The outerint()wrappers on Lines 104 and 118 are unnecessary.Cleanup
- row_step = 1 if nrows <= 16 else int(math.ceil(nrows / 12)) + row_step = 1 if nrows <= 16 else math.ceil(nrows / 12) ... - col_step = 1 if ncols <= 16 else int(math.ceil(ncols / 12)) + col_step = 1 if ncols <= 16 else math.ceil(ncols / 12)Also applies to: 117-118
305-314:T.Layoutbranch: 1D input withncols=1creates a tall single-column grid.For a 1D layout with large input (e.g., shape
[128]), this produces a 128×1 figure. This may not be the best visual representation. Consider allowing an optional 1D horizontal mode or wrapping. Low priority since it works correctly.
examples/plot_layout/README.md
Outdated
| Layout Example (1D/2D) | ||
|
|
||
| ```python | ||
| import tilelang.language as T | ||
| from tilelang.tools import plot_layout | ||
|
|
||
| # 1D layout: reverse the index order | ||
| layout_1d_reverse = T.Layout([8], lambda i: 7 - i) | ||
| plot_layout(layout_1d_reverse, name="layout_1d_reverse") | ||
|
|
||
| # 2D layout: transpose | ||
| layout_2d_transpose = T.Layout([4, 6], lambda i, j: (j, i)) | ||
| plot_layout(layout_2d_transpose, name="layout_2d_transpose") | ||
| ``` | ||
|
|
||
| Swizzle Layout Examples |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New section titles are missing Markdown heading markers.
"Layout Example (1D/2D)" (Line 107) and "Swizzle Layout Examples" (Line 122) will render as plain body text rather than section headings. Prefix them with ## or ### for proper document structure.
Proposed fix
-Layout Example (1D/2D)
+## Layout Example (1D/2D)
...
-Swizzle Layout Examples
+## Swizzle Layout Examples📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| Layout Example (1D/2D) | |
| ```python | |
| import tilelang.language as T | |
| from tilelang.tools import plot_layout | |
| # 1D layout: reverse the index order | |
| layout_1d_reverse = T.Layout([8], lambda i: 7 - i) | |
| plot_layout(layout_1d_reverse, name="layout_1d_reverse") | |
| # 2D layout: transpose | |
| layout_2d_transpose = T.Layout([4, 6], lambda i, j: (j, i)) | |
| plot_layout(layout_2d_transpose, name="layout_2d_transpose") | |
| ``` | |
| Swizzle Layout Examples | |
| ## Layout Example (1D/2D) | |
🤖 Prompt for AI Agents
In `@examples/plot_layout/README.md` around lines 107 - 122, The two section
titles "Layout Example (1D/2D)" and "Swizzle Layout Examples" are plain text and
should be Markdown headings; update the README so those lines are prefixed with
heading markers (e.g., add "## " or "### " before "Layout Example (1D/2D)" and
before "Swizzle Layout Examples") to ensure they render as proper section
headings and improve document structure.
tilelang/tools/plot_layout.py
Outdated
| def _parse_formats(formats: str | list[str]) -> list[str]: | ||
| if isinstance(formats, str): | ||
| formats_str = formats.strip().lower() | ||
| if formats_str == "all": | ||
| return ["pdf", "png", "svg"] | ||
| if "," in formats_str: | ||
| return [f.strip() for f in formats_str.split(",")] | ||
| return [formats_str] | ||
| raise TypeError(f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
_parse_formats rejects list[str] despite the type hint accepting it.
The type annotation is str | list[str], but the function body only handles str. Passing a list[str] falls through to the raise TypeError on Line 16. Either handle list inputs or narrow the type hint.
🐛 Proposed fix
def _parse_formats(formats: str | list[str]) -> list[str]:
if isinstance(formats, str):
formats_str = formats.strip().lower()
if formats_str == "all":
return ["pdf", "png", "svg"]
if "," in formats_str:
return [f.strip() for f in formats_str.split(",")]
return [formats_str]
+ if isinstance(formats, list):
+ return [f.strip().lower() for f in formats]
raise TypeError(f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _parse_formats(formats: str | list[str]) -> list[str]: | |
| if isinstance(formats, str): | |
| formats_str = formats.strip().lower() | |
| if formats_str == "all": | |
| return ["pdf", "png", "svg"] | |
| if "," in formats_str: | |
| return [f.strip() for f in formats_str.split(",")] | |
| return [formats_str] | |
| raise TypeError(f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") | |
| def _parse_formats(formats: str | list[str]) -> list[str]: | |
| if isinstance(formats, str): | |
| formats_str = formats.strip().lower() | |
| if formats_str == "all": | |
| return ["pdf", "png", "svg"] | |
| if "," in formats_str: | |
| return [f.strip() for f in formats_str.split(",")] | |
| return [formats_str] | |
| if isinstance(formats, list): | |
| return [f.strip().lower() for f in formats] | |
| raise TypeError(f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") |
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 16-16: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@tilelang/tools/plot_layout.py` around lines 8 - 16, The _parse_formats
function's signature allows str | list[str] but only handles str; update it to
accept list inputs by detecting isinstance(formats, list) and normalizing each
entry (strip and lower), expanding "all" to ["pdf","png","svg"] if any element
equals "all", and validating that every element is a string (raise TypeError for
invalid element types), or alternatively change the type hint to just str if you
intend to only support strings; make this change in the _parse_formats function
so callers passing list[str] are properly handled.
tilelang/tools/plot_layout.py
Outdated
|
|
||
| # Determine the number of rows and columns in the input shape | ||
| nrows, ncols = input_shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
cat -n tilelang/tools/plot_layout.py | sed -n '220,250p'Repository: tile-ai/tilelang
Length of output: 1678
🏁 Script executed:
cat -n tilelang/tools/plot_layout.py | sed -n '300,320p'Repository: tile-ai/tilelang
Length of output: 1085
🏁 Script executed:
rg -n "Fragment\s*\(" tilelang/tools/plot_layout.py -A2 -B2Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for Fragment instantiation across the codebase
rg -n "Fragment\s*\(" --type=py -A2 | head -80Repository: tile-ai/tilelang
Length of output: 7009
Add 1D input shape handling to the T.Fragment visualization branch.
Line 238 unpacks input_shape as nrows, ncols = input_shape, which crashes with ValueError if input_shape has length ≠ 2. The T.Layout branch (lines 309–314) handles 1D gracefully by converting it to 2D (nrows, ncols = input_shape[0], 1). Apply similar logic to the T.Fragment branch to ensure consistency and prevent crashes if 1D fragments are visualized.
🤖 Prompt for AI Agents
In `@tilelang/tools/plot_layout.py` around lines 236 - 238, The T.Fragment
visualization branch currently unpacks input_shape with nrows, ncols =
input_shape which fails for 1D shapes; update that branch to mirror the T.Layout
handling by checking the length of input_shape and when it's 1 set nrows =
input_shape[0] and ncols = 1 (otherwise unpack as before), so the T.Fragment
code (the block that defines nrows, ncols for fragment visualization) handles 1D
inputs the same way as the T.Layout branch.
tilelang/tools/plot_layout.py
Outdated
| local_id = local_ids[0] | ||
| # assert local id in local_ids is equal | ||
| assert all(local_id == local_id for local_id in local_ids) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: assertion is a tautology — always passes.
local_id on Line 263 is shadowed by the loop variable in the generator on Line 265. The expression local_id == local_id compares each element to itself, which is always True. This means the assertion never catches mismatched local IDs.
🐛 Proposed fix
local_id = local_ids[0]
- # assert local id in local_ids is equal
- assert all(local_id == local_id for local_id in local_ids)
+ # assert all local_ids are equal
+ assert all(lid == local_id for lid in local_ids)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| local_id = local_ids[0] | |
| # assert local id in local_ids is equal | |
| assert all(local_id == local_id for local_id in local_ids) | |
| local_id = local_ids[0] | |
| # assert all local_ids are equal | |
| assert all(lid == local_id for lid in local_ids) |
🤖 Prompt for AI Agents
In `@tilelang/tools/plot_layout.py` around lines 263 - 265, The assertion
currently compares each element to itself due to variable shadowing (local_id
used both as the reference and as the generator loop variable); update the
assertion in plot_layout.py so it compares the chosen reference local_id to each
element in local_ids (e.g., assert all(local_id == lid for lid in local_ids)) or
rename the generator variable to avoid shadowing; ensure the assertion uses a
different loop variable name (like lid) so mismatched local IDs are detected.
…s for visualization. Update parameters for colormap and formats, and introduce helper functions for format parsing and saving plots. Improve documentation for clarity on usage and expected input types.
e3403d4 to
ce08dfe
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@tilelang/tools/plot_layout.py`:
- Around line 240-243: _plot_fragment_layout currently calls plt.tight_layout()
and _save_plot(plt, save_directory, name, formats) but does not close the
figure, causing a matplotlib figure resource leak; mirror the behavior in
_plot_layout_map by calling plt.close() (or plt.close(fig) if a Figure object is
available) immediately after _save_plot to free resources. Ensure the close call
is placed inside the same scope as _save_plot so every invocation of
_plot_fragment_layout releases the figure.
🧹 Nitpick comments (4)
tilelang/tools/plot_layout.py (3)
10-10: UseOptional[str]instead of implicitNonedefault forcolormap.PEP 484 prohibits implicit
Optional. This should beOptional[str] = Noneorstr | None = None.Proposed fix
- colormap: str = None, + colormap: str | None = None,
359-359: Unused variablefigfromplt.subplots.Prefix with underscore to signal intent per convention and silence the linter warning.
Proposed fix
- fig, ax = plt.subplots(figsize=(fig_w, fig_h)) + _fig, ax = plt.subplots(figsize=(fig_w, fig_h))
61-85:_save_plotuses an if-chain instead of iterating overformats_list.Only
"pdf","png", and"svg"are handled. If_parse_formatsreturns an unrecognised format string (e.g."jpg"), it is silently ignored with no feedback.Consider either validating formats in
_parse_formatsor falling back to a genericplt.savefigcall for unknown formats, so users aren't confused by silent no-ops.examples/plot_layout/layout_swizzle.py (1)
17-18: Module-levelelement_sizeandvector_sizeare unused.Each factory function accepts its own
element_sizeparameter (defaulting to 16) and computesvsinternally, so these module-level variables are dead code.Proposed fix
-element_size = 16 # float16 = 16 bits -vector_size = 128 // element_size # = 8 elements per 128-bit vector
…eplace inline 2D swizzle functions with calls to `make_full_bank_swizzled_layout`, `make_half_bank_swizzled_layout`, and `make_quarter_bank_swizzled_layout` for improved clarity and maintainability in layout generation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR enhances plot_layout to visualize both Fragment and Layout mappings, adds new layout visualization examples (transformations + bank-swizzle), and consolidates multi-format saving with a PDF default.
Changes:
- Added type-based dispatch in
plot_layoutto support bothT.FragmentandT.Layoutvisualizations. - Refactored output saving into a consolidated multi-format flow and changed default output to PDF.
- Added new example scripts for layout transformations and shared-memory bank-swizzle patterns.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
tilelang/tools/plot_layout.py |
Adds Layout visualization path, centralizes save/format handling, defaults outputs to PDF. |
examples/plot_layout/layout_transform.py |
New examples demonstrating transpose/reshape/interleave layout mappings. |
examples/plot_layout/layout_swizzle.py |
New examples visualizing quarter/half/full bank swizzle layouts, including multi-tile. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| save_directory="./tmp", | ||
| name: str = "layout", | ||
| colormap: str = "RdPu", | ||
| colormap: str = None, |
Copilot
AI
Feb 8, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
colormap defaults to None but is annotated as str. To keep typing accurate (and avoid mypy/pyright complaints), change the annotation to str | None (or Optional[str]) consistently across the public plot_layout signature.
| colormap: str = None, | |
| colormap: str | None = None, |
| # determine output shape from actual output indices | ||
| output_shape = [0] * num_out_dims | ||
| for _, out_idx in mappings: | ||
| for k in range(num_out_dims): | ||
| output_shape[k] = max(output_shape[k], out_idx[k] + 1) | ||
|
|
Copilot
AI
Feb 8, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output_shape is inferred from the max observed mapped indices. Layouts are allowed to have an explicit get_output_shape(), and the inferred max can be smaller than the true output shape (e.g., if the mapping is not surjective or intentionally leaves holes), which makes the visualization misleading. Prefer using layout.get_output_shape() (casting to ints) as the primary source of truth, and only fall back to max-observed inference if the output shape cannot be materialized.
| # determine output shape from actual output indices | |
| output_shape = [0] * num_out_dims | |
| for _, out_idx in mappings: | |
| for k in range(num_out_dims): | |
| output_shape[k] = max(output_shape[k], out_idx[k] + 1) | |
| # determine output shape, preferring an explicit layout.get_output_shape() | |
| output_shape = None | |
| # Try to use the layout's declared output shape, if available. | |
| try: | |
| if hasattr(layout, "get_output_shape"): | |
| declared_shape = layout.get_output_shape() | |
| if declared_shape is not None: | |
| output_shape = [int(d) for d in declared_shape] | |
| except Exception: | |
| # If anything goes wrong, fall back to inferring from observed indices. | |
| output_shape = None | |
| # Fallback: infer output shape from actual output indices | |
| if output_shape is None: | |
| output_shape = [0] * num_out_dims | |
| for _, out_idx in mappings: | |
| for k in range(num_out_dims): | |
| output_shape[k] = max(output_shape[k], out_idx[k] + 1) |
| if isinstance(formats, str): | ||
| formats_str = formats.strip().lower() | ||
| if formats_str == "all": | ||
| return ["pdf", "png", "svg"] | ||
| elif "," in formats_str: | ||
| return [f.strip() for f in formats_str.split(",")] | ||
| else: | ||
| return [formats_str] | ||
| else: | ||
| raise TypeError( | ||
| f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." | ||
| ) | ||
|
|
||
|
|
Copilot
AI
Feb 8, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
formats is annotated and documented as str | list[str], but _parse_formats only accepts str and raises TypeError for lists/tuples. This breaks callers that pass ['pdf','png'] (which the public API advertises). Either update _parse_formats to accept iterables of strings (normalizing/validating each entry, handling 'all' and comma-separated items), or change the public type hints/docs to str only for consistency.
| if isinstance(formats, str): | |
| formats_str = formats.strip().lower() | |
| if formats_str == "all": | |
| return ["pdf", "png", "svg"] | |
| elif "," in formats_str: | |
| return [f.strip() for f in formats_str.split(",")] | |
| else: | |
| return [formats_str] | |
| else: | |
| raise TypeError( | |
| f"Expected str, but got {type(formats).__name__}. Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." | |
| ) | |
| def _parse_single_format(fmt: str) -> list[str]: | |
| """Normalize a single string formats specification into a list.""" | |
| fmt_str = fmt.strip().lower() | |
| if fmt_str == "all": | |
| return ["pdf", "png", "svg"] | |
| if "," in fmt_str: | |
| return [f.strip() for f in fmt_str.split(",") if f.strip()] | |
| return [fmt_str] | |
| # Single string input (existing behavior) | |
| if isinstance(formats, str): | |
| return _parse_single_format(formats) | |
| # Iterable of strings, e.g. list/tuple/set | |
| if isinstance(formats, (list, tuple, set)): | |
| result: list[str] = [] | |
| for fmt in formats: | |
| if not isinstance(fmt, str): | |
| raise TypeError( | |
| f"Expected each format to be a str, but got {type(fmt).__name__}. " | |
| "Please pass strings like 'png', 'pdf', 'svg', 'all', or 'png,pdf'." | |
| ) | |
| result.extend(_parse_single_format(fmt)) | |
| return result | |
| # Unsupported type | |
| raise TypeError( | |
| f"Expected str or iterable of str, but got {type(formats).__name__}. " | |
| "Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf', " | |
| "or an iterable of such strings, e.g. ['pdf', 'png']." | |
| ) |
… are closed after saving in plot_layout.py for better resource management.
as title.
Summary by CodeRabbit
New Features
Chores