-
Notifications
You must be signed in to change notification settings - Fork 446
[Tool] Provide layout visualization tool #1353
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Adds a layout visualization tool to TileLang, which helps users understand and debug the layout transformations applied during compilation. This tool visualizes the memory layout of tensors at different stages of the compilation process, allowing developers to identify potential inefficiencies and optimize their code for better performance. The visualization can be enabled via a pass config option.
|
Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughA new layout visualization feature is added to TileLang, enabling visual and textual output of inferred layouts during compilation. Pass config options control enablement and output formats, with new Python API functions querying and applying visualization. A core analysis pass scans IR for layout annotations and generates plots via an enhanced plot_layout utility supporting multiple formats (txt, png, pdf, svg). Changes
Sequence DiagramsequenceDiagram
participant Compiler as Compilation Pipeline
participant Phase as LowerAndLegalize
participant Detector as should_enable_layout_visual
participant Formatter as get_layout_visual_formats
participant Visitor as LayoutVisual Pass
participant IR as IR Module
participant Plotter as plot_layout
Compiler->>Phase: invoke LowerAndLegalize(mod)
Phase->>Phase: LayoutInference(mod)
Phase->>Detector: check enable?
Detector-->>Phase: True/False
alt Visualization Enabled
Phase->>Formatter: get formats
Formatter-->>Phase: ["png", "pdf", ...] or ["txt"]
Phase->>Visitor: LayoutVisual(mod)
Visitor->>IR: scan for layout_map annotations
loop For each Fragment layout
Visitor->>Visitor: print_fragment_format(layout)
Visitor->>Plotter: plot_layout(layout, formats)
Plotter-->>Plotter: save txt/png/pdf/svg
end
else Visualization Disabled
Phase->>Phase: continue (no-op)
end
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~22 minutes Areas requiring close attention:
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/tools/plot_layout.py (2)
79-90: Guard warp coloring whennum_threads < 32The loop overriding
colors[i]for a fixedwarp_size = 32will raise anIndexErroriflayout.get_thread_size()returns fewer than 32 threads, sincecolorshas lengthnum_threads.Consider clamping the warp size to the actual number of threads:
- # Show the distribution of registers in each thread of a warp. - warp_size = 32 - spectral_camp = plt.get_cmap("hsv", warp_size * 6) - for i in range(warp_size): - colors[i] = spectral_camp(i * 6) + # Show the distribution of registers in each thread of a warp. + warp_size = min(32, num_threads) + spectral_camp = plt.get_cmap("hsv", warp_size * 6) + for i in range(warp_size): + colors[i] = spectral_camp(i * 6)
4-8: Align docstring withT.Fragmentand fix no-op assertionTwo small cleanups here:
The function now takes
layout: T.Fragment, but the docstring still sayslayout : T.Layout. Updating the docstring keeps the public API description accurate.The assertion on
local_idscurrently always passes because it compares each element to itself.You could address both with:
-def plot_layout(layout: T.Fragment, +def plot_layout(layout: T.Fragment, save_directory="./tmp", name: str = "layout", colormap: str = "RdPu", verbose: bool = False) -> None: @@ - layout : T.Layout - The layout object that describes how indices are mapped. + layout : T.Fragment + The fragment layout object that describes how indices are mapped. @@ - local_id = local_ids[0] - # assert local id in local_ids is equal - assert all(local_id == local_id for local_id in local_ids) + local_id = local_ids[0] + # assert all local IDs in local_ids are equal + assert all(local_id == v for v in local_ids)Also applies to: 14-24, 118-125
🧹 Nitpick comments (4)
tilelang/engine/lower.py (1)
18-23: Layout visualization hook placement looks correctImporting
LayoutVisualand invoking it once afterLowerAndLegalizeand beforeOptimizeForTargetis a good spot for debugging layouts; the pass is gated internally so normal compilation cost is unaffected.If you want to make the behavior clearer to readers, you could tweak the comment to e.g. “Visualize layout when TL_ENABLE_LAYOUT_VISUAL is set”, but this is optional.
Also applies to: 253-255
tilelang/tools/__init__.py (1)
1-3: Remove redundantnoqaonLayoutVisualimportRuff reports the
# noqa: F401here as unused. Since this import is a valid public re-export from the tools package, you can safely drop the directive.-from .layout_visual import LayoutVisual # noqa: F401 +from .layout_visual import LayoutVisualtilelang/__init__.py (1)
135-141: Drop unusednoqaontoolspackage importThe
toolssubmodule import is a legitimate part of the public API. Ruff flags the attached# noqa: F401as unused here, so you can simplify:-from . import ( - analysis, # noqa: F401 - transform, # noqa: F401 - language, # noqa: F401 - engine, # noqa: F401 - tools, # noqa: F401 -) +from . import ( + analysis, # noqa: F401 + transform, # noqa: F401 + language, # noqa: F401 + engine, # noqa: F401 + tools, +)tilelang/tools/layout_visual.py (1)
23-26: Remove unused attribute.The
layout_foundlist is initialized but never used in the visitor. Consider removing it to clean up the code.Apply this diff:
def __init__(self): super().__init__() - self.layout_found = [] self.processed_layouts = set()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
examples/gemm/example_gemm.py(1 hunks)src/op/builtin.cc(1 hunks)src/op/builtin.h(1 hunks)tilelang/__init__.py(1 hunks)tilelang/engine/lower.py(2 hunks)tilelang/engine/phase.py(1 hunks)tilelang/tools/__init__.py(1 hunks)tilelang/tools/layout_visual.py(1 hunks)tilelang/tools/plot_layout.py(2 hunks)tilelang/transform/pass_config.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
tilelang/engine/phase.py (3)
tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/transform/pass_config.py (1)
PassConfigKey(6-147)tilelang/tools/layout_visual.py (1)
LayoutVisual(44-50)
tilelang/engine/lower.py (2)
tilelang/engine/phase.py (1)
LayoutVisual(77-80)tilelang/tools/layout_visual.py (1)
LayoutVisual(44-50)
tilelang/tools/__init__.py (2)
tilelang/engine/phase.py (1)
LayoutVisual(77-80)tilelang/tools/layout_visual.py (1)
LayoutVisual(44-50)
examples/gemm/example_gemm.py (2)
tilelang/jit/kernel.py (1)
out_idx(598-599)tilelang/transform/pass_config.py (1)
PassConfigKey(6-147)
tilelang/tools/plot_layout.py (2)
tilelang/layout/fragment.py (1)
Fragment(12-212)src/layout/layout.cc (2)
Fragment(496-518)Fragment(520-530)
🪛 Ruff (0.14.6)
tilelang/tools/layout_visual.py
46-46: Unused function argument: mod
(ARG001)
46-46: Unused function argument: ctx
(ARG001)
tilelang/tools/__init__.py
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/__init__.py
140-140: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (9)
tilelang/transform/pass_config.py (1)
72-74: New pass-config key is consistent and well-documented
TL_ENABLE_LAYOUT_VISUALaligns with existing TL_* config keys and the underlying"tl.enable_layout_visual"string; the docstring is clear.src/op/builtin.cc (1)
19-38: Pass-config option registration for layout visual is consistent
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableLayoutVisual, Bool);follows the existing pattern for TL pass options and matches the corresponding string key and enum entry.src/op/builtin.h (1)
33-55: New layout-visual config key matches the rest of the stack
kEnableLayoutVisual’s string value"tl.enable_layout_visual"is consistent with the PythonPassConfigKey.TL_ENABLE_LAYOUT_VISUALand the C++ registration inbuiltin.cc.examples/gemm/example_gemm.py (1)
5-7: Example wiring ofTL_ENABLE_LAYOUT_VISUALlooks goodUsing
pass_configs={tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUAL: True}in the@tilelang.jitdecorator is a clear way to showcase the new layout visualization feature without altering the GEMM kernel itself.tilelang/engine/phase.py (2)
70-75: LGTM! Consistent with existing patterns.The implementation follows the same pattern as other enablement checks in this file (e.g.,
should_force_let_inline), correctly retrieving the pass context and checking the configuration flag.
77-81: LGTM! Appropriate design for a visualization pass.The function correctly gates the visualization behind the configuration flag and follows the pattern of other side-effect-only passes (like
PreLowerSemanticCheckat line 83) that don't modify or return the module.tilelang/tools/layout_visual.py (3)
9-18: LGTM! Clean formatting helper.The function provides clear, informative formatting of layout information for visualization output.
28-42: LGTM! Sound visitor implementation.The
visit_block_method correctly:
- Checks for the layout_map annotation
- Filters for Fragment instances
- Deduplicates layouts using string representation
- Prints and visualizes each unique layout
- Continues tree traversal
44-50: LGTM! Correctly implements visualization pass.The function properly creates a
prim_func_passthat visualizes layouts without modifying the IR. The unusedmodandctxparameters flagged by static analysis are part of the required pass interface signature and can be safely ignored.
5809550 to
70519cd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/tools/plot_layout.py (2)
4-14: Docstring parameter type is inconsistent with actual signature.The function signature specifies
layout: T.Fragment, but the docstring on line 14 still referencesT.Layout. Update the docstring to match the signature.Parameters ---------- - layout : T.Layout + layout : T.Fragment The layout object that describes how indices are mapped.
122-124: Tautological assertion always passes.The assertion
all(local_id == local_id for local_id in local_ids)compares each element to itself, which is alwaysTrue. It should compare each element tolocal_ids[0]to verify they are all equal.local_id = local_ids[0] # assert local id in local_ids is equal - assert all(local_id == local_id for local_id in local_ids) + assert all(lid == local_id for lid in local_ids)
🧹 Nitpick comments (3)
tilelang/tools/plot_layout.py (1)
85-89: Hardcoded warp size may not apply to all targets.The warp size is hardcoded to 32, which is CUDA-specific. For AMD GPUs, the wavefront size is 64. Consider making this configurable or deriving it from the target architecture if this tool is intended to support multiple backends.
tilelang/tools/layout_visual.py (2)
46-48: Prefix unused parameters with underscore.The
modandctxparameters are required by theprim_func_passsignature but unused in this pass. Prefix them with_to indicate they are intentionally ignored.- def pass_fn(func: tir.PrimFunc, mod, ctx): + def pass_fn(func: tir.PrimFunc, _mod, _ctx): _LayoutVisualVisitor().visit_stmt(func.body) return func
32-39: Consider using logging instead of print for consistency.The direct
print()calls work for a visualization tool, but using the module's logger would be more consistent with the rest of the codebase and allow users to control verbosity via log levels.+import logging + +logger = logging.getLogger(__name__) + ... if layout_id not in self.processed_layouts: - print(f"{key} layout inference:") - print(print_layout_format(layout)) + logger.info(f"{key} layout inference:") + logger.info(print_layout_format(layout)) plot_layout(layout, name=f"{key}_layout")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
tilelang/__init__.py(1 hunks)tilelang/engine/phase.py(1 hunks)tilelang/tools/layout_visual.py(1 hunks)tilelang/tools/plot_layout.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/tools/plot_layout.py (1)
tilelang/layout/fragment.py (1)
Fragment(12-212)
tilelang/tools/layout_visual.py (3)
tilelang/tools/plot_layout.py (1)
plot_layout(4-213)tilelang/layout/layout.py (2)
get_input_shape(57-66)get_output_shape(68-77)tilelang/engine/phase.py (1)
LayoutVisual(77-80)
tilelang/engine/phase.py (2)
tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/tools/layout_visual.py (1)
LayoutVisual(44-50)
🪛 Ruff (0.14.6)
tilelang/tools/layout_visual.py
46-46: Unused function argument: mod
(ARG001)
46-46: Unused function argument: ctx
(ARG001)
tilelang/__init__.py
140-140: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (2)
tilelang/__init__.py (1)
140-140: LGTM!The
toolsmodule export follows the established pattern for other submodules in this file. The static analysis hint about unusednoqadirective appears to be a false positive since the import is intentionally re-exported for public API access.tilelang/engine/phase.py (1)
70-80: LGTM!The implementation follows the established pattern for configuration-gated passes in this file. The
should_enable_layout_visualfunction correctly retrieves the pass context when not provided, andLayoutVisualproperly gates the visualization pass behind the configuration flag.
tilelang/tools/layout_visual.py
Outdated
| def __init__(self): | ||
| super().__init__() | ||
| self.layout_found = [] | ||
| self.processed_layouts = set() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
Remove unused layout_found attribute.
The layout_found list is initialized but never used anywhere in the class. Remove it to avoid confusion.
def __init__(self):
super().__init__()
- self.layout_found = []
self.processed_layouts = set()📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def __init__(self): | |
| super().__init__() | |
| self.layout_found = [] | |
| self.processed_layouts = set() | |
| def __init__(self): | |
| super().__init__() | |
| self.processed_layouts = set() |
🤖 Prompt for AI Agents
In tilelang/tools/layout_visual.py around lines 23 to 26, the __init__ defines
self.layout_found which is never used; remove that attribute initialization
(delete the self.layout_found = [] line) so only self.processed_layouts = set()
remains, and run tests/lint to ensure no references to layout_found exist
elsewhere.
|
@Cunxiao2002 thanks, would you mind providing an example for us to reproduce the tool. |
|
@LeiWang1999 ok, I will. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
examples/visual_layout_inference/visual_layout_inference.py (1)
47-52: Consider using regular comments instead of a string literal.The triple-quoted string is not assigned to anything, so it's created and discarded at runtime. Using regular comments would be more idiomatic and have zero runtime overhead.
- # print the layout visualization result and save figures to ./tmp. - ''' - C_local layout inference: - Shape: [32, 32] -> [8] - Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 - Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] - ''' + # The layout visualization result is saved to ./tmp. + # Example output for C_local layout inference: + # Shape: [32, 32] -> [8] + # Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 + # Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/visual_layout_inference/visual_layout_inference.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/visual_layout_inference/visual_layout_inference.py (1)
tilelang/transform/pass_config.py (1)
PassConfigKey(6-147)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (3)
examples/visual_layout_inference/visual_layout_inference.py (3)
1-6: LGTM!The imports and decorator usage are correct. The
pass_configsparameter properly enables layout visualization viaTL_ENABLE_LAYOUT_VISUAL.
9-28: LGTM!The GEMM kernel implementation is correct with proper shared memory tiling, fragment allocation for the accumulator, and 3-stage pipelining for latency hiding.
31-45: LGTM!The main function correctly instantiates the kernel, runs validation against PyTorch's matmul, and uses appropriate tolerances for float16 precision.
oraluben
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When trying to run the example the first time, it failed with the missing matplotlib. I suggest add that in pyproject.toml as an extra dependency as below.
[project.optional-dependencies]
vis = ["matplotlib"]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/op/builtin.cc(1 hunks)src/op/builtin.h(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (1)
src/op/builtin.cc (1)
37-37: Layout‑visualization pass config registration is consistent
kEnableLayoutVisualizationis registered as aBoolalongside the other flags using the same pattern; the placement and usage look correct from this file’s perspective.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Some minor comments:
- currently it generates 3 different format (pdf, png, and svg). Maybe to use pass key to control that to reduce generation time? (e.g.
Truefor all or one,'png'and'svg'for specific format? - Maybe to have some doc in readme? (e.g. we can install dependencies with
pip install tilelang[vis])
Allows users to specify the output formats (png, pdf, svg) for layout visualization through a pass config option. This change provides more flexibility in how layout visualizations are generated, allowing users to choose the formats that best suit their needs. It also fixes a bug where layout visualization was not correctly disabled when the config option was set to "false".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
♻️ Duplicate comments (1)
tilelang/tools/layout_visual.py (1)
24-29: Remove unusedlayout_foundfield.
self.layout_foundis initialized but never used;processed_layoutsnow handles deduplication. Drop the unused attribute to avoid confusion.def __init__(self, formats: str = "png"): super().__init__() - self.layout_found = [] self.processed_layouts = set() self.formats = formats
🧹 Nitpick comments (6)
tilelang/transform/pass_config.py (1)
72-78: Clean up stray""after the docstring.The new config key and its docstring look good, but the trailing
""" ""leaves an extra empty string literal in the class body and is confusing. Please drop the""so this remains a single docstring block.tilelang/tools/layout_visual.py (1)
48-61: Optionally silence Ruff ARG001 by marking unused pass args.
modandctxare part of the required TVM prim_func_pass callback signature, but they’re unused because you read the current PassContext directly. If you want to satisfy Ruff without changing behavior, just rename them:- def pass_fn(func: tir.PrimFunc, mod, ctx): + def pass_fn(func: tir.PrimFunc, _mod, _ctx):tilelang/tools/plot_layout.py (3)
26-31: Align docstring with the updated API (T.Fragment, returns section).Two minor inconsistencies:
- The parameter docs still say
layout : T.Layout, but the signature now takeslayout: T.Fragment.- The “Returns” header line currently reads
-------s.Consider updating to:
- layout : T.Layout + layout : T.Fragment @@ - Returns - -------s + Returns + -------
121-131: Fix ineffective assertion onlocal_ids.This assertion always passes:
local_id = local_ids[0] # assert local id in local_ids is equal assert all(local_id == local_id for local_id in local_ids)The lambda compares each element to itself instead of to the first element. If you want to ensure all entries are equal to
local_ids[0], do:- local_id = local_ids[0] - # assert local id in local_ids is equal - assert all(local_id == local_id for local_id in local_ids) + local_id = local_ids[0] + # assert every local id in local_ids matches the first one + assert all(li == local_id for li in local_ids)
203-213:formatsimplementation doesn’t match thestr | list[str]type hint.The signature and docstring advertise
formats: str | list[str], but the code only acceptsstrand raises aTypeErrorfor any non-str:if isinstance(formats, str): ... else: raise TypeError("Expected str, but got ...")Either:
- Implement
list[str]support (e.g., normalize a list/tuple intoformats_listwith per-entry.strip().lower()), or- Narrow the type hint and docstring to
stronly and adjust the error message accordingly.Right now, callers following the annotation and passing a list will be rejected unexpectedly.
docs/tutorials/debug_tools_for_tilelang.md (1)
174-197: Tighten docs: code fence language and mention multi-format values.Two small doc polish points:
- The example layout output block (lines 190‑195) is fenced without a language; consider annotating it as
textfor better rendering and to satisfy markdownlint.- The pass config enum doc allows comma-separated formats (e.g.,
"png,svg"). If you intend to support that, it would be helpful to mention it here alongside"all","png","pdf","svg"so the docs and implementation stay in sync.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
docs/tutorials/debug_tools_for_tilelang.md(1 hunks)examples/visual_layout_inference/visual_layout_inference.py(1 hunks)pyproject.toml(1 hunks)src/op/builtin.cc(1 hunks)tilelang/engine/phase.py(1 hunks)tilelang/tools/layout_visual.py(1 hunks)tilelang/tools/plot_layout.py(4 hunks)tilelang/transform/pass_config.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- pyproject.toml
- examples/visual_layout_inference/visual_layout_inference.py
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/engine/phase.py (3)
tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/transform/pass_config.py (1)
PassConfigKey(6-152)tilelang/tools/layout_visual.py (1)
LayoutVisual(46-64)
tilelang/tools/plot_layout.py (1)
tilelang/layout/fragment.py (1)
Fragment(12-212)
tilelang/tools/layout_visual.py (5)
tilelang/tools/plot_layout.py (1)
plot_layout(5-229)tilelang/layout/layout.py (2)
get_input_shape(57-66)get_output_shape(68-77)tilelang/engine/phase.py (1)
LayoutVisual(80-83)tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/transform/pass_config.py (1)
PassConfigKey(6-152)
🪛 markdownlint-cli2 (0.18.1)
docs/tutorials/debug_tools_for_tilelang.md
190-190: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🪛 Ruff (0.14.6)
tilelang/tools/plot_layout.py
212-213: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tools/layout_visual.py
48-48: Unused function argument: mod
(ARG001)
48-48: Unused function argument: ctx
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
🔇 Additional comments (1)
src/op/builtin.cc (1)
37-37: Registration ofkEnableLayoutVisualizationlooks correct.The new pass config option is registered consistently with the existing ones, and using
ffi::Stringis appropriate for a string-valued flag.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
♻️ Duplicate comments (2)
tilelang/tools/layout_visual.py (2)
23-27: Remove unusedlayout_foundattribute.The
self.layout_foundlist on line 25 is initialized but never used anywhere in the class. Remove it to avoid confusion.Apply this diff:
def __init__(self, formats: str = "png"): super().__init__() - self.layout_found = [] self.processed_layouts = set() self.formats = formats
48-61: Fix default behavior forTL_ENABLE_LAYOUT_VISUALIZATIONwhen config is unset.The current implementation has a critical bug in the config handling logic:
- Line 50-51:
config.get()without a default returnsNonewhen the key is missing- Line 53:
str(None)becomes"none"- Line 54: The check
if not config_str or config_str == "false"fails because"none"is neither empty nor"false"- Line 59: Falls through to
formats = "none"and proceeds to run the visitorThis enables visualization by default when the config is unset, contradicting the documented default behavior (disabled). It also forces matplotlib imports even when users never enabled the feature.
Apply this diff to provide an explicit default:
def pass_fn(func: tir.PrimFunc, mod, ctx): pass_ctx = tilelang.transform.get_pass_context() config_value = pass_ctx.config.get( - tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) + tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, "") config_str = str(config_value).strip().lower() if not config_str or config_str == "false": return func elif config_str == "true": formats = "all" else: formats = config_str _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func
🧹 Nitpick comments (1)
tilelang/tools/layout_visual.py (1)
48-48: Consider using underscore prefix for intentionally unused parameters.The
modandctxparameters are required by the TVMprim_func_passAPI signature but are not used in this implementation. Consider renaming them to_modand_ctxto indicate they are intentionally unused and suppress the static analysis warning.Apply this diff:
- def pass_fn(func: tir.PrimFunc, mod, ctx): + def pass_fn(func: tir.PrimFunc, _mod, _ctx): pass_ctx = tilelang.transform.get_pass_context()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/tools/layout_visual.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/tools/layout_visual.py (5)
tilelang/layout/layout.py (2)
get_input_shape(57-66)get_output_shape(68-77)tilelang/tools/plot_layout.py (1)
plot_layout(5-229)tilelang/engine/phase.py (1)
LayoutVisual(80-83)tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/transform/pass_config.py (1)
PassConfigKey(6-152)
🪛 Ruff (0.14.6)
tilelang/tools/layout_visual.py
48-48: Unused function argument: mod
(ARG001)
48-48: Unused function argument: ctx
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build SDist
- GitHub Check: Quick Lint
🔇 Additional comments (2)
tilelang/tools/layout_visual.py (2)
1-7: LGTM!The imports are appropriate for the layout visualization functionality.
9-17: LGTM!The formatting logic correctly extracts and displays layout information.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (2)
tilelang/tools/layout_visual.py (2)
24-28: Remove unusedlayout_foundattribute.The
layout_foundlist is initialized but never used anywhere in the class.Apply this diff:
def __init__(self, formats: str = "png"): super().__init__() - self.layout_found = [] self.processed_layouts = set() self.formats = formats
48-62: Fix default behavior to disable visualization when config is unset.When the config key is not set,
config_valuebecomesNone, which is converted to the string"none"on line 53. Since"none"is truthy and not equal to"false", the code falls through to theelsebranch (line 59) and setsformats = "none", causing the visitor to run even though visualization should be disabled by default. This contradicts the documented behavior and forces matplotlib imports unnecessarily.Apply this diff to provide an explicit default:
def pass_fn(func: tir.PrimFunc, mod, ctx): pass_ctx = tilelang.transform.get_pass_context() - config_value = pass_ctx.config.get( - tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value) + config_value = pass_ctx.config.get( + tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, "") config_str = str(config_value).strip().lower() if not config_str or config_str == "false": return func elif config_str == "true": formats = "all" else: formats = config_str _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func
🧹 Nitpick comments (1)
tilelang/tools/layout_visual.py (1)
48-48: Prefix unused arguments with underscore.The arguments
modandctxare required by the TVMprim_func_passsignature but are not used in the function body. Prefix them with_to indicate they are intentionally unused and silence the linter warnings.Apply this diff:
- def pass_fn(func: tir.PrimFunc, mod, ctx): + def pass_fn(func: tir.PrimFunc, _mod, _ctx):
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tilelang/engine/phase.py(1 hunks)tilelang/tools/layout_visual.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/engine/phase.py (4)
tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/transform/pass_config.py (1)
PassConfigKey(6-152)tilelang/engine/lower.py (1)
lower(214-272)tilelang/tools/layout_visual.py (1)
LayoutVisual(46-64)
🪛 Ruff (0.14.6)
tilelang/tools/layout_visual.py
48-48: Unused function argument: mod
(ARG001)
48-48: Unused function argument: ctx
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
🔇 Additional comments (2)
tilelang/engine/phase.py (2)
70-81: LGTM! Default config behavior is correct.The function correctly disables visualization by default when the config is unset (line 74-75 provides
""as default). The logic properly handles all cases: unset config, explicitNone,"false","true", and custom format strings.Optional cleanup: The
Nonecheck on lines 77-78 is defensive but technically redundant since the default value on line 74-75 is"", notNone. The check would only trigger if someone explicitly sets the config value toNone, which is an edge case. You may consider removing it for clarity, but keeping it doesn't cause any issues.
84-87: LGTM! Gating logic is correct.The function properly gates visualization behind the
should_enable_layout_visual()check, ensuring thattilelang.tools.LayoutVisual()is only invoked when explicitly enabled.
|
I didn't know it's related to tvm's type🤔it seems not to make much sense to have string True/False |
|
@oraluben It seems |
|
@codex review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # Show the distribution of registers in each thread of a warp. | ||
| warp_size = 32 | ||
| spectral_camp = plt.get_cmap("hsv", warp_size * 6) | ||
| for i in range(warp_size): | ||
| colors[i] = spectral_camp(i * 6) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard warp coloring when thread count < warp size
The new warp-aware coloring loop assumes at least 32 threads, but colors is sized to num_threads; when a layout has fewer than 32 threads (common for small tiles), iterating for i in range(warp_size) writes past the end of the list and raises IndexError, so visualization fails before any plots are saved.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @Cunxiao2002
| else: | ||
| formats_list = [formats_str] | ||
| else: | ||
| raise TypeError(f"Expected str, but got {type(formats).__name__}. " | ||
| f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accept documented list inputs for output formats
The formats parameter is annotated and documented as str | list[str], but the new parsing block immediately raises a TypeError for any non‑string input, so callers following the API and passing a list like ["png", "pdf"] cannot use the feature and the tool fails before plotting.
Useful? React with 👍 / 👎.
tilelang/tools/layout_visual.py
Outdated
| from tilelang.tools.plot_layout import plot_layout | ||
|
|
||
|
|
||
| def print_layout_format(layout: T.Fragment) -> str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This only works for fragment layout, I think we should rename it into print_fragment_format and do some type check there.
| _LayoutVisualVisitor(formats=config_str).visit_stmt(func.body) | ||
| return func | ||
|
|
||
| return prim_func_pass(pass_fn, opt_level=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be better to add some comments for this file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (4)
tilelang/tools/layout_visual.py (3)
64-64: Remove unused attribute.The
layout_foundlist is initialized but never referenced anywhere in the class.Apply this diff:
def __init__(self, formats: str = "png"): super().__init__() - self.layout_found = [] self.processed_layouts = set() self.formats = formats
1-8: Matplotlib import prevents graceful degradation when the feature is disabled.Importing
plot_layoutat module level (line 7) will fail immediately if matplotlib is not installed, even whenTL_ENABLE_LAYOUT_VISUALIZATIONis unset or false. Since matplotlib is documented as an optional dependency, this breaks users who haven't installed it.Move the import inside
visit_block_and add error handling:import tilelang import tilelang.language as T from tvm import tir from tvm.tir import PyStmtExprVisitor from tvm.tir.transform import prim_func_pass -from tilelang.tools.plot_layout import plot_layoutAnd in the visitor method:
if isinstance(layout, T.Fragment): layout_id = str(layout) if layout_id not in self.processed_layouts: print(f"{key} layout inference:") print_fragment_format(layout) - plot_layout(layout, name=f"{key}_layout", formats=self.formats) + try: + from tilelang.tools.plot_layout import plot_layout + plot_layout(layout, name=f"{key}_layout", formats=self.formats) + except ImportError: + print("Warning: matplotlib not installed, skipping plot generation") self.processed_layouts.add(layout_id)
84-96: Add early return when visualization is disabled.The pass always constructs the visitor and processes the function body, even when visualization is disabled (empty or "false" config). This wastes resources and imports matplotlib unconditionally.
Apply this diff to skip execution when disabled:
def LayoutVisual(): def pass_fn(func: tir.PrimFunc, mod, ctx): pass_ctx = tilelang.transform.get_pass_context() config_value = pass_ctx.config.get( tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, "") config_str = str(config_value).strip().lower() + + # Early return if visualization is disabled + if not config_str or config_str == "false": + return func + + # Determine output formats + if config_str == "true": + formats = "all" + else: + formats = config_str - _LayoutVisualVisitor(formats=config_str).visit_stmt(func.body) + _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return func return prim_func_pass(pass_fn, opt_level=0)tilelang/tools/plot_layout.py (1)
212-222: Function signature promiseslist[str]support but implementation rejects it.The type hint on line 10 declares
formats: str | list[str], but lines 220-222 raiseTypeErrorfor any non-string value. Callers following the documented API and passing["png", "pdf"]will encounter a runtime error.Either accept lists as documented:
if isinstance(formats, str): formats_str = formats.strip().lower() if formats_str == 'all': formats_list = ['pdf', 'png', 'svg'] elif "," in formats_str: formats_list = [f.strip() for f in formats_str.split(',')] else: formats_list = [formats_str] + elif isinstance(formats, list): + formats_list = [f.strip().lower() for f in formats] else: raise TypeError(f"Expected str or list[str], but got {type(formats).__name__}. " - f"Please pass a string like 'png', 'pdf', 'svg', 'all', or 'png,pdf'.") + f"Please pass a string like 'png', 'pdf', 'svg', 'all', 'png,pdf', or a list like ['png', 'pdf'].")Or narrow the type hint to
stronly if list support isn't needed:-def plot_layout(layout: T.Fragment, - save_directory="./tmp", - name: str = "layout", - colormap: str = "RdPu", - verbose: bool = False, - formats: str | list[str] = "png") -> None: +def plot_layout(layout: T.Fragment, + save_directory="./tmp", + name: str = "layout", + colormap: str = "RdPu", + verbose: bool = False, + formats: str = "png") -> None:
🧹 Nitpick comments (1)
tilelang/tools/layout_visual.py (1)
86-86: Prefix unused callback arguments with underscore.The
modandctxparameters are required by theprim_func_passcallback signature but are unused. By convention, prefix them with_to signal they're intentionally unused.Apply this diff:
- def pass_fn(func: tir.PrimFunc, mod, ctx): + def pass_fn(func: tir.PrimFunc, _mod, _ctx): pass_ctx = tilelang.transform.get_pass_context()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tilelang/tools/layout_visual.py(1 hunks)tilelang/tools/plot_layout.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/tools/plot_layout.py (3)
tilelang/layout/fragment.py (1)
Fragment(12-212)src/layout/layout.cc (2)
Fragment(496-518)Fragment(520-530)tilelang/language/kernel.py (1)
num_threads(221-225)
tilelang/tools/layout_visual.py (3)
tilelang/tools/plot_layout.py (1)
plot_layout(5-238)tilelang/engine/phase.py (1)
LayoutVisual(93-96)tilelang/transform/__init__.py (1)
get_pass_context(12-14)
🪛 Ruff (0.14.6)
tilelang/tools/plot_layout.py
221-222: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tools/layout_visual.py
33-33: Prefer TypeError exception for invalid type
(TRY004)
33-33: Avoid specifying long messages outside the exception class
(TRY003)
86-86: Unused function argument: mod
(ARG001)
86-86: Unused function argument: ctx
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (1)
tilelang/tools/plot_layout.py (1)
88-101: Good addition: warp-size guard and user warning.The warning on lines 91-97 helpfully alerts users when thread count is below the warp size, and the
min(warp_size, num_threads)guard on line 100 prevents theIndexErrorthat would occur whennum_threads < 32.
| def print_fragment_format(layout: T.Fragment) -> str: | ||
| """ | ||
| Format fragment layout information into a human-readable string. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| layout : T.Fragment | ||
| The fragment layout to format | ||
|
|
||
| Returns | ||
| ------- | ||
| str | ||
| Formatted string showing shape, thread mapping, and index mapping | ||
| """ | ||
| if isinstance(layout, T.Fragment): | ||
| input_shape = layout.get_input_shape() | ||
| output_shape = layout.get_output_shape() | ||
| lines = [ | ||
| f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", | ||
| f" Index: {layout.forward_index}" | ||
| ] | ||
| print("\n".join(lines)) | ||
| else: | ||
| raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix docstring inconsistency: function prints but claims to return str.
The function docstring (lines 19-22) states it returns a formatted string, but the implementation calls print() on line 31 and has no return statement.
Apply this diff to align the implementation with the documented behavior:
def print_fragment_format(layout: T.Fragment) -> str:
"""
Format fragment layout information into a human-readable string.
Parameters
----------
layout : T.Fragment
The fragment layout to format
Returns
-------
str
Formatted string showing shape, thread mapping, and index mapping
"""
if isinstance(layout, T.Fragment):
input_shape = layout.get_input_shape()
output_shape = layout.get_output_shape()
lines = [
f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}",
f" Index: {layout.forward_index}"
]
- print("\n".join(lines))
+ return "\n".join(lines)
else:
raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}")Then update the caller on line 77 to print the returned string:
if layout_id not in self.processed_layouts:
print(f"{key} layout inference:")
- print_fragment_format(layout)
+ print(print_fragment_format(layout))
plot_layout(layout, name=f"{key}_layout", formats=self.formats)
self.processed_layouts.add(layout_id)Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.6)
33-33: Prefer TypeError exception for invalid type
(TRY004)
33-33: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tilelang/tools/layout_visual.py around lines 10 to 34, the function
print_fragment_format's docstring states it returns a formatted string but the
implementation prints the string and returns None; change the implementation to
build and return the formatted string instead of calling print (construct the
same 'lines' and return "\n".join(lines)), and remove the print(); then update
the caller at line 77 to print(print_fragment_format(...)) so the caller
displays the returned string.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tilelang/engine/phase.py (1)
93-96: Optional: Consider passing pass_ctx to avoid redundant lookups.Currently,
should_enable_layout_visual()retrieves the pass context on every call. SinceLowerAndLegalize(line 175) already retrievespass_ctx, you could pass it through to eliminate redundant lookups:-def LayoutVisual(mod: IRModule) -> None: +def LayoutVisual(mod: IRModule, pass_ctx: PassContext | None = None) -> None: """Apply layout visualization pass if enabled.""" - if should_enable_layout_visual(): + if should_enable_layout_visual(pass_ctx): tilelang.tools.LayoutVisual()(mod)Then update the call site at line 156:
- LayoutVisual(mod) + LayoutVisual(mod, pass_ctx)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/engine/phase.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/engine/phase.py (3)
tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/transform/pass_config.py (1)
PassConfigKey(6-152)tilelang/tools/layout_visual.py (1)
LayoutVisual(84-96)
🪛 Ruff (0.14.6)
tilelang/engine/phase.py
86-89: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Quick Lint
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build SDist
🔇 Additional comments (1)
tilelang/engine/phase.py (1)
155-156: LGTM!The placement after
LayoutInference()is appropriate - visualizing layouts after they've been inferred is the correct sequence.
tilelang/engine/phase.py
Outdated
| def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: | ||
| if pass_ctx is None: | ||
| pass_ctx = tilelang.transform.get_pass_context() | ||
|
|
||
| config_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, | ||
| "") | ||
|
|
||
| config_str = str(config_value).strip().lower() | ||
| if not config_str: | ||
| return False | ||
|
|
||
| valid_formats = ["png", "pdf", "svg", "all"] | ||
| formats_list = [f.strip() for f in config_str.split(",")] | ||
|
|
||
| invalid_formats = [fmt for fmt in formats_list if fmt not in valid_formats] | ||
| if invalid_formats: | ||
| raise ValueError( | ||
| f"Invalid formats for TL_ENABLE_LAYOUT_VISUALIZATION: {invalid_formats}. " | ||
| f"Valid formats are: {valid_formats}. " | ||
| f"You can choose one of the valid formats or a comma-separated list of formats.") | ||
| return True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical: Config key mismatch between validation and visualization pass.
The validation here uses tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, but the actual visualization pass in tilelang/tools/layout_visual.py (line 87) reads tilelang.PassConfigKey.TT_ENABLE_LAYOUT_VISUALIZATION.value. This typo means:
should_enable_layout_visual()validates one config key- The underlying pass reads a different config key
- Layout visualization won't activate even when correctly configured
Fix the key in tilelang/tools/layout_visual.py to match:
# In tilelang/tools/layout_visual.py, line 87
config_value = pass_ctx.config.get(
tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value, "")🧰 Tools
🪛 Ruff (0.14.6)
86-89: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tilelang/tools/layout_visual.py around line 87, the pass reads the wrong
config key `TT_ENABLE_LAYOUT_VISUALIZATION`; change it to use
`TL_ENABLE_LAYOUT_VISUALIZATION` (i.e., replace the `TT_` key reference with
`TL_`) so it matches the validator in tilelang/engine/phase.py and the rest of
the code; ensure the get(...) call uses
tilelang.PassConfigKey.TL_ENABLE_LAYOUT_VISUALIZATION.value and keep the same
default value behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tilelang/engine/phase.py (1)
70-90: Config parsing/validation is solid; consider centralizing formats and tightening the error messageThe gating logic now correctly:
- Defaults to disabled when the config is unset/empty.
- Normalizes to lowercase, supports comma-separated lists, and validates against the documented set of formats.
Two small, non-blocking suggestions:
valid_formatsis hard-coded here and also described intilelang/transform/pass_config.pyand used in the plotting tools. Consider centralizing this list (e.g., a shared helper/constant) to avoid future drift.- Ruff’s TRY003 warns about long error messages; you could shorten and structure the message a bit to keep it concise, e.g. by joining the invalid/valid sets into short comma‑separated strings.
Example tweak:
- invalid_formats = [fmt for fmt in formats_list if fmt not in valid_formats] - if invalid_formats: - raise ValueError( - f"Invalid formats for TL_ENABLE_LAYOUT_VISUALIZATION: {invalid_formats}. " - f"Valid formats are: {valid_formats}. " - f"You can choose one of the valid formats or a comma-separated list of formats.") + invalid_formats = [fmt for fmt in formats_list if fmt not in valid_formats] + if invalid_formats: + invalid = ", ".join(invalid_formats) + valid = ", ".join(valid_formats) + raise ValueError( + f"Invalid TL_ENABLE_LAYOUT_VISUALIZATION format(s): {invalid}. " + f"Expected one or more of: {valid}." + )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/engine/phase.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/engine/phase.py (3)
tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/transform/pass_config.py (1)
PassConfigKey(6-152)tilelang/tools/layout_visual.py (1)
LayoutVisual(84-96)
🪛 Ruff (0.14.6)
tilelang/engine/phase.py
86-89: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: Quick Lint
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
🔇 Additional comments (2)
tilelang/engine/phase.py (2)
93-96: Wrapper gating aroundtilelang.tools.LayoutVisuallooks correct
LayoutVisual(mod)cleanly delegates totilelang.tools.LayoutVisualonly whenshould_enable_layout_visual()is true, so the default pipeline stays unchanged and visualization is opt‑in via pass config. No issues from my side here.
153-154: Hooking visualization afterLayoutInferenceis a good insertion pointCalling
LayoutVisual(mod)immediately afterLayoutInferencegives the tool access to fully inferred layouts while keeping the pass optional (no‑op when disabled). This ordering shouldn’t affect the rest of the lowering pipeline and keeps visualization clearly separated from semantic transforms.
Updates the layout visualization pass configuration to use boolean flag for enabling and a string for specifying formats.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (3)
tilelang/tools/layout_visual.py (3)
61-66: Remove unusedlayout_foundattribute.
self.layout_foundis never read;self.processed_layoutsalready tracks seen layouts. Dropping the unused list avoids confusion:def __init__(self, formats: str = ""): super().__init__() - self.layout_found = [] self.processed_layouts = set() self.formats = formats.strip().lower() if formats else ""
9-32: Alignprint_fragment_formatbehavior with its type hint and docs.Current behavior:
- Signature and docstring say it returns
str, but it callsprint()and returnsNone.- On wrong type it raises
ValueError, though this is really a type error.To make this easier to use and test, consider:
- Returning the formatted string instead of printing, and letting callers handle
print().- Raising
TypeErrorfor non‑T.Fragmentinputs.Example:
def print_fragment_format(layout: T.Fragment) -> str: @@ - if isinstance(layout, T.Fragment): - input_shape = layout.get_input_shape() - output_shape = layout.get_output_shape() - lines = [ - f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", - f" Index: {layout.forward_index}" - ] - print("\n".join(lines)) - else: - raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") + if isinstance(layout, T.Fragment): + input_shape = layout.get_input_shape() + output_shape = layout.get_output_shape() + lines = [ + f" Shape: {input_shape} -> {output_shape}", + f" Thread: {layout.forward_thread}", + f" Index: {layout.forward_index}", + ] + return "\n".join(lines) + raise TypeError(f"Expected T.Fragment, but got {type(layout).__name__}")Call site would then
print(print_fragment_format(layout)).
1-7: Makeplot_layoutimport resilient to missing matplotlib and avoid hard failure.Importing
plot_layoutat module import time tightly couples this file to matplotlib. If matplotlib is an optional dependency, importingtilelang.tools.layout_visual(or anything that re‑exports it) will raiseImportErroreven when visualization is disabled.A safer pattern is:
- Lazily import
plot_layoutonly whenself.formatsis set and we actually want plots.- Wrap the import (and optional call) in
try/except ImportErrorto degrade gracefully.For example:
-from tilelang.tools.plot_layout import plot_layout @@ - print_fragment_format(layout) - if self.formats: - plot_layout(layout, name=f"{key}_layout", formats=self.formats) + print_fragment_format(layout) + if self.formats: + try: + from tilelang.tools.plot_layout import plot_layout + plot_layout( + layout, + name=f"{key}_layout", + formats=self.formats, + ) + except ImportError: + print( + "Warning: matplotlib not installed; " + "skipping layout plots." + )This keeps the rest of TileLang usable without forcing matplotlib.
Also applies to: 71-79
🧹 Nitpick comments (2)
tilelang/tools/layout_visual.py (2)
35-59: Update docstring to match config keys and consider recursing into nested blocks.Two points here:
Config naming mismatch in docs
The docstring still refers toTL_ENABLE_LAYOUT_VISUALIZATIONand a single string‑valued option, but the engine wiring usesTL_LAYOUT_VISUALIZATION_ENABLE(bool) andTL_LAYOUT_VISUALIZATION_FORMATSfor formats. It would be less confusing to update this description to match the actual keys and their types.Visitor does not recurse into nested blocks
visit_block_is overriding the base implementation, butsuper().visit_block_(op)is commented out. That means only blocks with a directlayout_mapannotation at the entry point are processed; any nested blocks with their ownlayout_mapwill be skipped.If
LayoutInferencecan attachlayout_mapto inner blocks, you likely want:def visit_block_(self, op: tir.Block) -> None: if "layout_map" in op.annotations:@@
# super().visit_block_(op)
super().visit_block_(op)If the current behavior (only top‑level blocks) is intentional, a short comment explaining that assumption would help future readers.Also applies to: 67-82
84-90: Minor: clarifyformatsexpectation and unusedmod/ctxwarnings.
formatsis passed straight through to the visitor and ultimately toplot_layout, which acceptsstr | list[str]. If you keep it string‑only, you may want to document that here (or relax the type annotation on_LayoutVisualVisitor.__init__to match).pass_fn’smodandctxparameters are unused but required by theprim_func_passcallback signature. If Ruff’sARG001warnings are noisy, you can rename them to_mod,_ctxto make the intent explicit.Both are nits and don’t block functionality.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
examples/visual_layout_inference/visual_layout_inference.py(1 hunks)src/op/builtin.cc(1 hunks)src/op/builtin.h(1 hunks)tilelang/engine/phase.py(2 hunks)tilelang/tools/layout_visual.py(1 hunks)tilelang/transform/pass_config.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- tilelang/transform/pass_config.py
- examples/visual_layout_inference/visual_layout_inference.py
- src/op/builtin.cc
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/tools/layout_visual.py (3)
tilelang/tools/plot_layout.py (1)
plot_layout(5-238)tilelang/layout/layout.py (2)
get_input_shape(57-66)get_output_shape(68-77)tilelang/engine/phase.py (1)
LayoutVisual(94-98)
🪛 Ruff (0.14.7)
tilelang/engine/phase.py
87-90: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tools/layout_visual.py
32-32: Prefer TypeError exception for invalid type
(TRY004)
32-32: Avoid specifying long messages outside the exception class
(TRY003)
86-86: Unused function argument: mod
(ARG001)
86-86: Unused function argument: ctx
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
🔇 Additional comments (3)
src/op/builtin.h (1)
54-57: Layout visualization config keys look consistent and well‑scoped.The new
kLayoutVisualizationEnable/kLayoutVisualizationFormatsconstants follow the existing naming pattern and provide a clean hook for the Python pass‑config layer. No issues from the header side.tilelang/engine/phase.py (2)
70-75:should_enable_layout_visualcorrectly switches to a typed bool flag.Using
PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLEwith a boolean default aligns with TVM’s typed pass‑config options and avoids the previous string"true"/"false"issues. If you expect any legacy string values, you might considerbool(...)normalization, but as long as the key is registered asBoolthis is fine.
94-99: Layout visualization is correctly wired into the lowering pipeline.Invoking
LayoutVisual(mod)immediately afterLayoutInference(mod)is a sensible place: you see inferred layouts before further lowering, and the helper itself is a no‑op unlessTL_LAYOUT_VISUALIZATION_ENABLEis true. With the formats helper fixed as above, this integration looks good.Also applies to: 155-156
tilelang/engine/phase.py
Outdated
| def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: | ||
| if pass_ctx is None: | ||
| pass_ctx = tilelang.transform.get_pass_context() | ||
| formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "") | ||
| if not formats_value: | ||
| return "" | ||
|
|
||
| formats_str = formats_value.strip().lower() | ||
| valid_formats = ["png", "pdf", "svg", "all"] | ||
| if formats_str not in valid_formats: | ||
| raise ValueError( | ||
| f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {formats_str}. " | ||
| f"Valid formats are: {valid_formats}. " | ||
| f"You can choose one of the valid formats or a comma-separated list of formats.") | ||
| return formats_str |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix get_layout_visual_formats type and support comma‑separated formats as documented.
Current issues:
- The function is annotated to return
list[str]but actually returns""or a singlestr; callers (LayoutVisual) also treat it as a string. - Validation only allows
"png","pdf","svg","all"exactly. A documented value like"png,svg"will fail withValueError, even thoughplot_layoutcan handle comma‑separated strings.
A minimal fix is:
- Change the return type to
str. - Normalize to string, allow empty string,
"all", or any comma‑separated subset of{"png","pdf","svg"}. - Keep returning the original normalized string so
plot_layoutcan interpret it.
Example:
-def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]:
+def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> str:
@@
- formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "")
- if not formats_value:
- return ""
-
- formats_str = formats_value.strip().lower()
- valid_formats = ["png", "pdf", "svg", "all"]
- if formats_str not in valid_formats:
- raise ValueError(
- f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {formats_str}. "
- f"Valid formats are: {valid_formats}. "
- f"You can choose one of the valid formats or a comma-separated list of formats.")
- return formats_str
+ formats_value = pass_ctx.config.get(
+ tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, ""
+ )
+ formats_str = str(formats_value).strip().lower()
+ if not formats_str:
+ return ""
+
+ # Allow: "all" or comma-separated subset of {"png", "pdf", "svg"}.
+ if formats_str == "all":
+ return "all"
+
+ valid_single = {"png", "pdf", "svg"}
+ tokens = [fmt.strip() for fmt in formats_str.split(",") if fmt.strip()]
+ invalid = [fmt for fmt in tokens if fmt not in valid_single]
+ if invalid:
+ raise ValueError(
+ f"Invalid formats in TL_LAYOUT_VISUALIZATION_FORMATS: {', '.join(invalid)}. "
+ f"Valid values are 'all' or a comma-separated subset of {sorted(valid_single)}."
+ )
+ return formats_strThis matches the documented “comma‑separated list” behavior and keeps the API consistent with plot_layout.
🧰 Tools
🪛 Ruff (0.14.7)
87-90: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In tilelang/engine/phase.py around lines 77-91, the function signature and
behavior are inconsistent: it is annotated to return list[str] but returns
strings and rejects comma-separated values; update the function to return a
normalized str (change annotation to -> str), accept empty string, the literal
"all", or a comma-separated list of formats drawn from {"png","pdf","svg"}
(normalize to lowercase and strip whitespace), validate each token and raise
ValueError only if any token is invalid, and finally return the normalized
string (empty, "all", or the comma-joined normalized tokens) so callers like
plot_layout can parse it as documented.
|
|
||
| # print the layout visualization result and save figures to ./tmp. | ||
| ''' | ||
| C_local layout inference: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be better to be C_local inferenced layout:
| if not formats_value: | ||
| return "" | ||
|
|
||
| formats_str = formats_value.strip().lower() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be better to be designed to support multiple format dumps, and text will be the default one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (5)
tilelang/tools/layout_visual.py (3)
63-63: Remove unused attribute.
self.layout_found = []is initialized but never referenced elsewhere in the class.Apply this diff:
def __init__(self, formats: list[str] = ""): super().__init__() - self.layout_found = [] self.processed_layouts = set()
6-6: Lazy-load matplotlib-dependent import.Importing
plot_layoutat module level prevents this module from loading when matplotlib is absent, even if visualization is disabled. Since matplotlib is documented as optional, defer the import until plotting is actually needed.Move the import into the loop at lines 77-78 and wrap in a try-except:
import tilelang.language as T from tvm import tir from tvm.tir import PyStmtExprVisitor from tvm.tir.transform import prim_func_pass -from tilelang.tools.plot_layout import plot_layoutThen update lines 77-78:
print_fragment_format(layout) - for fmt in self.formats_list: - plot_layout(layout, name=f"{key}_layout", formats=fmt) + try: + from tilelang.tools.plot_layout import plot_layout + for fmt in self.formats_list: + plot_layout(layout, name=f"{key}_layout", formats=fmt) + except ImportError: + print("Warning: matplotlib not installed, skipping plot generation") self.processed_layouts.add(layout_id)
9-33: Align implementation with docstring.The docstring (lines 18-22) states the function returns a formatted string, but line 30 prints the string and implicitly returns
None.Apply this diff to return the string:
- print("\n".join(lines)) + return "\n".join(lines)Then update the caller at line 76:
print(f"{key} layout inference:") - print_fragment_format(layout) + print(print_fragment_format(layout))docs/tutorials/debug_tools_for_tilelang.md (1)
192-197: Add language specifier to code block.The fenced code block lacks a language identifier, affecting syntax highlighting and accessibility tools.
Apply this diff:
The output messages of "txt" will include something like: -``` +```text C_local inferenced layout: Shape: [32, 32] -> [8] Thread: _j // 16 * 64 + _i // 16 * 32 + _i % 8 * 4 + _j % 8 // 2 Index: [_j % 16 // 8 * 4 + _i % 16 // 8 * 2 + _j % 2] ```Based on static analysis hints.
tilelang/engine/phase.py (1)
77-102: Critical type mismatch:get_layout_visual_formatsreturnslist[str]but caller expectsstr.Line 109 passes the returned
list[str]totilelang.tools.LayoutVisual(formats=formats), butlayout_visual.pyline 84 declaresdef LayoutVisual(formats: str = ""). This type mismatch will cause a runtime error.Additionally, the default value
["txt"](line 82) and the expansion of"all"to include"txt"(line 88) conflict withlayout_visual.pyline 65, which explicitly filters out"txt":[f for f in formats if f != "txt"]. This suggests"txt"is a sentinel for text-only output, not a format to pass to the visualization pass.Solution 1 (recommended): Change the return type to
strand return a comma-separated string that the visualization module can parse:-def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> list[str]: +def get_layout_visual_formats(pass_ctx: PassContext | None = None) -> str: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() formats_value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS, "") if not formats_value: - return ["txt"] + return "" formats_str = formats_value.strip().lower() valid_formats = ["txt", "png", "pdf", "svg", "all"] if formats_str == "all": - return ["txt", "png", "pdf", "svg"] + return "png,pdf,svg" if "," in formats_str: formats_list = [f.strip() for f in formats_str.split(',')] else: formats_list = [formats_str] invalid_formats = [f for f in formats_list if f not in valid_formats] if invalid_formats: raise ValueError( f"Invalid formats for TL_LAYOUT_VISUALIZATION_FORMATS: {invalid_formats}. " f"Valid formats are: {valid_formats}. " f"You can choose one of the valid formats or a comma-separated list of formats.(e.g., 'txt,png,pdf')" ) - return formats_list + # Filter out 'txt' as it's handled separately (text output is always enabled) + filtered = [f for f in formats_list if f != "txt"] + return ",".join(filtered)Solution 2: Update
layout_visual.pyto acceptlist[str](see comment on that file).
🧹 Nitpick comments (1)
tilelang/tools/layout_visual.py (1)
86-86: Remove unused function parameters.The
pass_fnclosure does not usemodorctxparameters.Apply this diff:
- def pass_fn(func: tir.PrimFunc, mod, ctx): + def pass_fn(func: tir.PrimFunc, _mod, _ctx): _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return funcBased on learnings from static analysis hints.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
docs/tutorials/debug_tools_for_tilelang.md(1 hunks)examples/visual_layout_inference/visual_layout_inference.py(1 hunks)tilelang/engine/phase.py(2 hunks)tilelang/tools/layout_visual.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/visual_layout_inference/visual_layout_inference.py
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/engine/phase.py (3)
tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/transform/pass_config.py (1)
PassConfigKey(6-153)tilelang/tools/layout_visual.py (1)
LayoutVisual(84-90)
🪛 markdownlint-cli2 (0.18.1)
docs/tutorials/debug_tools_for_tilelang.md
192-192: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🪛 Ruff (0.14.7)
tilelang/engine/phase.py
97-101: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/tools/layout_visual.py
32-32: Prefer TypeError exception for invalid type
(TRY004)
32-32: Avoid specifying long messages outside the exception class
(TRY003)
86-86: Unused function argument: mod
(ARG001)
86-86: Unused function argument: ctx
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build SDist
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
🔇 Additional comments (3)
tilelang/engine/phase.py (2)
166-167: LGTM! Appropriate placement for visualization pass.Calling
LayoutVisual(mod)immediately afterLayoutInferenceensures the visualization captures the inferred layouts before subsequent transformations. This placement is logical and aligns with the tool's purpose.
105-109: Unable to verify type consistency due to repository access limitations.The original review comment references a type mismatch between
get_layout_visual_formats()andtilelang.tools.LayoutVisual, but the repository could not be accessed to confirm the claim. Without examining the function signatures and implementations, the validity of this concern cannot be determined.docs/tutorials/debug_tools_for_tilelang.md (1)
174-198: Well-structured documentation for the visualization feature.The documentation clearly explains:
- The purpose and output types (textual + visual diagrams)
- Configuration via
TL_LAYOUT_VISUALIZATION_ENABLEandTL_LAYOUT_VISUALIZATION_FORMATS- Supported format values and their meanings
- Example output
This provides users with sufficient information to enable and use the layout visualization tool effectively.
| - "png,svg": Generate multiple formats (comma-separated) | ||
| """ | ||
|
|
||
| def __init__(self, formats: list[str] = ""): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Critical: Fix type annotation inconsistency.
Line 61 declares formats: list[str] = "" which is invalid—a list[str] type cannot have a string default. Looking at line 65, the code iterates [f for f in formats if f != "txt"], treating formats as an iterable of strings. However, this creates ambiguity: if formats="png,svg" (a string), the list comprehension yields ['p', 'n', 'g', ',', 's', 'v', 'g'] (individual characters), not ['png', 'svg'].
Additionally, there's a critical type mismatch with the caller in tilelang/engine/phase.py line 109, which passes a list[str] (from get_layout_visual_formats).
To resolve this, decide whether formats should be:
- A
list[str](as the caller provides), or - A
str(as the outerLayoutVisualfunction signature at line 84 suggests)
If formats should be a list, apply this diff:
- def __init__(self, formats: list[str] = ""):
+ def __init__(self, formats: list[str] | None = None):
super().__init__()
self.layout_found = []
self.processed_layouts = set()
- self.formats_list = [f for f in formats if f != "txt"]
+ self.formats_list = [f for f in (formats or []) if f != "txt"]And update the factory at line 84:
-def LayoutVisual(formats: str = ""):
+def LayoutVisual(formats: list[str] | None = None):Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In tilelang/tools/layout_visual.py around line 61, the constructor currently
declares formats: list[str] = "" which is invalid and mismatches callers; change
the parameter to accept an optional list (e.g., formats: list[str] | None =
None), guard inside the ctor to set formats = [] when None (avoiding a mutable
default), and keep the existing list comprehension that filters out "txt"; also
update the factory function signature at line 84 to accept and forward a
list[str] (not a str) so callers like tilelang/engine/phase.py that pass a
list[str] match the type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
♻️ Duplicate comments (4)
tilelang/analysis/layout_visual.py (3)
9-32: Fixprint_fragment_formatdocstring vs implementation and narrow the exception type.The function is annotated and documented as returning a
str, but it currently prints and returnsNone. Also, type validation is better expressed asTypeError.Suggested change:
-def print_fragment_format(layout: T.Fragment) -> str: +def print_fragment_format(layout: T.Fragment) -> str: @@ - if isinstance(layout, T.Fragment): - input_shape = layout.get_input_shape() - output_shape = layout.get_output_shape() - lines = [ - f" Shape: {input_shape} -> {output_shape}", f" Thread: {layout.forward_thread}", - f" Index: {layout.forward_index}" - ] - print("\n".join(lines)) - else: - raise ValueError(f"Expected T.Fragment, but got {type(layout).__name__}") + if isinstance(layout, T.Fragment): + input_shape = layout.get_input_shape() + output_shape = layout.get_output_shape() + lines = [ + f" Shape: {input_shape} -> {output_shape}", + f" Thread: {layout.forward_thread}", + f" Index: {layout.forward_index}", + ] + return "\n".join(lines) + raise TypeError(f"Expected T.Fragment, but got {type(layout).__name__}")And update the caller below to actually print the string:
- print(f"{key} inferenced layout:") - print_fragment_format(layout) + print(f"{key} inferenced layout:") + print(print_fragment_format(layout))
61-66: Fixformatstyping/default and align with engine’sget_layout_visual_formats.
__init__declaresformats: list[str] = "", which is an invalid default and conflicts with the actual usage:
tilelang.engine.phase.get_layout_visual_formats()returnslist[str], which is passed through unchanged.- The comprehension
[f for f in formats if f != "txt"]will iterate characters whenformatsis a string (e.g.,"png"→["p","n","g"]).- The factory
LayoutVisualis still annotated as takingstr, but it is now effectively receiving alist[str].Recommend standardizing on
list[str]throughout:- def __init__(self, formats: list[str] = ""): + def __init__(self, formats: list[str] | None = None): super().__init__() - self.layout_found = [] self.processed_layouts = set() - self.formats_list = [f for f in formats if f != "txt"] + formats = formats or [] + self.formats_list = [fmt for fmt in formats if fmt != "txt"]This:
- Removes the invalid string default.
- Matches
get_layout_visual_formats()’s return type.- Drops the unused
layout_foundfield.
If you still want to support the older string-based API, you could additionally normalize astrto a singleton list before the list comprehension.
84-90: AlignLayoutVisualsignature and mark unusedpass_fnargs.There’s still a mismatch between this factory and how it’s used:
- Here,
formatsis annotated asstr, defaulting to"".- In
tilelang.engine.phase.LayoutVisual,formatsis alist[str]fromget_layout_visual_formats().To keep the API consistent with the engine helper and the visitor:
-def LayoutVisual(formats: str = ""): +def LayoutVisual(formats: list[str] | None = None): @@ - def pass_fn(func: tir.PrimFunc, mod, ctx): - _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) + def pass_fn(func: tir.PrimFunc, _mod, _ctx): + _LayoutVisualVisitor(formats=formats).visit_stmt(func.body) return funcRenaming to
_mod,_ctxalso resolves the “unused argument” warnings without changing behavior.tilelang/engine/phase.py (1)
105-110: Type mismatch between engineLayoutVisualand analysisLayoutVisual.Here,
formatsis alist[str]fromget_layout_visual_formats(), buttilelang.analysis.LayoutVisualis currently typed asLayoutVisual(formats: str = "")and forwardsformatsdirectly to_LayoutVisualVisitor.Once you adopt the list-based signature in
tilelang/analysis/layout_visual.LayoutVisual, this call will be consistent. After that change, this block itself is fine and does not need further adjustment:def LayoutVisual(mod: IRModule) -> None: """Apply layout visualization pass if enabled.""" if should_enable_layout_visual(): formats = get_layout_visual_formats() tilelang.analysis.LayoutVisual(formats=formats)(mod)
🧹 Nitpick comments (5)
tilelang/analysis/layout_visual.py (2)
35-59: Update configuration docstring to match actual config keys and behavior.The docstring still references
TL_ENABLE_LAYOUT_VISUALIZATIONand string-valued config controlling both enablement and formats, but the current implementation usesTL_LAYOUT_VISUALIZATION_ENABLE(boolean) plusTL_LAYOUT_VISUALIZATION_FORMATShandled intilelang.engine.phase.get_layout_visual_formats.Please revise the “Configuration” section to:
- Name the correct keys (
TL_LAYOUT_VISUALIZATION_ENABLE,TL_LAYOUT_VISUALIZATION_FORMATS).- Clarify that enable/disable is controlled by the
*_ENABLEboolean.- Note that formats are resolved in
get_layout_visual_formatsand passed here as a list of strings (with"txt"implying “text only”).
67-81: Consider restoring recursive traversal and minor wording fix.Two small points here:
super().visit_block_(op)is commented out, so the visitor will only process the current block and not recurse into nested blocks that might also carrylayout_mapannotations. If layout inference can attach annotations deeper in the tree, you probably want the recursive behavior back:- # super().visit_block_(op) + super().visit_block_(op)
- The message
"{key} inferenced layout:"is a bit off grammatically;"inferred layout"reads better:- print(f"{key} inferenced layout:") + print(f"{key} inferred layout:")tilelang/analysis/__init__.py (1)
6-6: Export looks good;noqais optional.Re-exporting
LayoutVisualfromtilelang.analysisis consistent with the other analysis tools. Note that Ruff reports# noqa: F401as unused for this file; you can drop it here (and possibly above) if you want to satisfy that lint, but it’s purely stylistic.tilelang/engine/phase.py (2)
77-102: Formats helper behavior is reasonable; consider minor tightening of validation/message.The overall behavior of
get_layout_visual_formatslooks good:
- Defaults to
["txt"]when unset (text-only).- Accepts
"all"and expands to["txt", "png", "pdf", "svg"].- Accepts comma-separated values and normalizes them.
Two optional refinements:
- You can avoid treating
"all"as a valid token informats_listby keepingvalid_formatsto the concrete formats only:- valid_formats = ["txt", "png", "pdf", "svg", "all"] + valid_formats = ["txt", "png", "pdf", "svg"] @@ - invalid_formats = [f for f in formats_list if f not in valid_formats] + invalid_formats = [f for f in formats_list if f not in valid_formats]
"all"would already have returned early at Line 87, so it doesn’t need to be invalid_formats.
- The error message is quite long for an inline
ValueError. If you care about Ruff’sTRY003hint, consider shortening it a bit or moving more detail into a custom exception type.
166-167: Layout visualization placement in the pipeline looks appropriate.Running
LayoutVisual(mod)immediately afterLayoutInference(mod)is a sensible spot: all fragment/shared-memory layouts have been inferred, and no lowering passes have erased or reshaped them yet. Since the visualization pass is gated by config and leaves the IR unmodified, this shouldn’t affect correctness.If you expect very large modules, you might later consider an additional guard (e.g., max number of layouts) to avoid generating an excessive number of plots, but that’s purely an optimization concern.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
tilelang/analysis/__init__.py(1 hunks)tilelang/analysis/layout_visual.py(1 hunks)tilelang/engine/phase.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/analysis/__init__.py (2)
tilelang/analysis/layout_visual.py (1)
LayoutVisual(84-90)tilelang/engine/phase.py (1)
LayoutVisual(105-109)
tilelang/engine/phase.py (2)
tilelang/transform/__init__.py (1)
get_pass_context(12-14)tilelang/analysis/layout_visual.py (1)
LayoutVisual(84-90)
tilelang/analysis/layout_visual.py (3)
tilelang/tools/plot_layout.py (1)
plot_layout(5-238)tilelang/layout/layout.py (2)
get_input_shape(57-66)get_output_shape(68-77)tilelang/engine/phase.py (1)
LayoutVisual(105-109)
🪛 Ruff (0.14.7)
tilelang/analysis/__init__.py
6-6: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/engine/phase.py
97-101: Avoid specifying long messages outside the exception class
(TRY003)
tilelang/analysis/layout_visual.py
32-32: Prefer TypeError exception for invalid type
(TRY004)
32-32: Avoid specifying long messages outside the exception class
(TRY003)
86-86: Unused function argument: mod
(ARG001)
86-86: Unused function argument: ctx
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Build wheels for Python 3.9 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.9 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.9 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build SDist
| def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool: | ||
| if pass_ctx is None: | ||
| pass_ctx = tilelang.transform.get_pass_context() | ||
| enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False) | ||
| return enabled | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Harden should_enable_layout_visual against string-valued configs.
Right now enabled is returned directly from PassContext.config. If a user mistakenly sets "false" (string) instead of False (bool), if should_enable_layout_visual(): will still evaluate truthy and enable visualization.
A small defensive tweak keeps bool configs working while handling strings safely:
def should_enable_layout_visual(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
- enabled = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
- return enabled
+ value = pass_ctx.config.get(tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_ENABLE, False)
+ if isinstance(value, str):
+ v = value.strip().lower()
+ return bool(v and v != "false")
+ return bool(value)This keeps the default “disabled unless explicitly enabled” behavior while avoiding surprises from accidental string usage.
🤖 Prompt for AI Agents
In tilelang/engine/phase.py around lines 70 to 75, the function returns the raw
value from pass_ctx.config which can be a string like "false" and still evaluate
truthy; change the logic to coerce and validate the config into a strict
boolean: fetch the raw value, if it is a bool return it, if it is a str
interpret only common true values (e.g. "true", "1", "yes") as True
(case-insensitive) and everything else as False, and default to False when the
key is missing or value is None.
* [Example] Add GQA decoding kernel with varlen page table (#1265) * [Example] Add page table for gqa decode * [Example] Page table for varlen decoding * [Lint] * [Refactor] Remove redundant code * [Lint] * [Lint] * [Lint] * [Refactor] add support for numpy dtype conversion (#1255) * add typing stub for tir.ir * remove idents * minor update * [Refactor] add numpy conversion for dtype * fix lint error * remove unused np.float_ in dtype conversion * fix type in np.int_ * fix typo * minor fix * remove debug files * [EXAMPLE] In the flash attention example keep the max of all blocks seen in scores_max numerical stability (#1148) * Keep the max of all blocks seen in scores_max for stability * ruff formatting * [Docs] Improve Installation Guide (#1270) * [Docs] Improve installation guide * address comments * [Enhancement] Keep max score attention across blocks in FlashAttention for better numerical stablity (#1269) * Implement max score retention across blocks in FlashAttention for improved stability * fix manual pipeline parameters * Update examples/flash_attention/example_gqa_fwd_varlen.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * fix typo * more * fix a previous typo --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * [Bugfix] Fix multiple cg defination when using T.sync_grid (#1272) * [Minor] Remove from __future__ import annotations for python 3.8 (#1273) * [BugFix] Adding extra parameters into autotune hashkey (#1274) * [BugFix] Adding extra parameters into autotune hashkey * lint * None check * check serializable * Fix various issues under `int64_t` static and dynamic shape. (#1218) * Fix various issues under int64_t static and dynamic shape. * Resolve reviewed issues. * Add unit test. * fix --------- Co-authored-by: LeiWang1999 <[email protected]> * Bug fix for Gated Delta Net benchmark script (#1267) * fix argument order for fla chunk_gated_delta_rule_fwd_h * explicit import assert_similar from utils * rename utils module to avoid name clash * set store_final_state and save_new_value to True * fix --------- Co-authored-by: LeiWang1999 <[email protected]> * [Bugfix] Minor fix for some cases (#1278) * [Language] Add shape check in `T.view/reshape` (#1277) * [Language] Add shape check in T.view/reshape * address comments * [FFI] Use tvm ffi as the default execution backend (#1259) * [Refactor] Update FFI type handling and simplify argument management * Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity. * Updated function registration in `runtime.cc` to utilize canonical names for better consistency. * Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled. * Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection. * Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity. * [Update] Sync TVM submodule and enhance kernel source handling * Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes. * Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging. * Commented out the main execution call in test files to prevent unintended execution during testing. * Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues. * Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends. * [Refactor] Clean up imports and improve code formatting * Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code. * Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency. * Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality. * Update execution backend options and improve resolution logic - Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target. - Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions. - Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target. - Updated documentation to reflect changes in execution backend options and their defaults. * lint fix * fix * Enhance argument handling in CUDA and HIP runtime modules - Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime. - Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers. - Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks. * lint fix * lint fix * lint fix * lint fix * minor fix * fix * recover check * Refactor argument binding and validation in `arg_binder.cc` - Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers. - Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards. - Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling. - Minor adjustments in test files to streamline kernel execution and improve readability. * lint fix * stride fix * minor fix * fix * lint fix * lint fix * Add CUDA stream access policy window helpers and integrate with L2 persistent cache management - Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage. - Updated runtime files to include new FFI packed functions for managing stream attributes. - Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown. - Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source. * check with symbolic * support null ptr * Update CMakeLists and lower.py for code generation and subproject status - Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support. - Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility. - Marked the TVM subproject as dirty to indicate local modifications. * lint fix * Update comments for clarity in quickstart.py * [Bugfix] Supply missing `T.print` for bool type (#1279) * fix for bool dtype * lint fix * fix * ci fix * [Fix] Fix memory leak bug (#1281) * add typing stub for tir.ir * remove idents * minor update * [Refactor] add numpy conversion for dtype * fix lint error * remove unused np.float_ in dtype conversion * fix type in np.int_ * fix typo * minor fix * remove debug files * fix memory leak bug * fix lint error * add comments * fix lint error * remove duplicated, because tilelang doesn't dependent deprecated * [Enhancement] Enhance CUDA compilation by integrating pass context configuration (#1283) - Updated the `tilelang_callback_cuda_compile` function to accept a `pass_config` parameter, allowing for more flexible compilation options. - Introduced handling for fast math and PTXAS options based on the provided pass configuration. - Modified the CUDA build process in `rt_mod_cuda.cc` to utilize the current pass context, improving the integration of compilation settings. - Refactored NVCC command construction to use a dedicated function for better clarity and maintainability. * Fix the bug in issue #1266 (#1284) Co-authored-by: cheeryBloosm <[email protected]> * [Language][UX] Nested loop checker in pre-lowering stage (#1288) * [Language][UX] Nested loop checker in pre-lowering stage * rename * comment * address comments * [Compatibility] Support CUDA 11.3 (#1290) * [Feat] Add support for using `T.Tensor(n * 2 + 1)` in function annotation (#1285) * [Feature] Add support for A: T.Tensor(n + 1) and A: T.Tensor(2*n) * issue fix * fix * fix * decreate nproc for debugging --------- Co-authored-by: Lei Wang <[email protected]> * [Feat] add support for passing reference in T.Var annotation (#1291) * [Enhancement] Shared Memory Size Can be Dynamic (#1294) * bugfix * lint fix * test * lint fix * increate procs * recover * [Fix] Remove unused let_bindings_ in CodeGenC to fix #1300 (#1305) * [Feat] add missing support of uint32x2 * [Feat] Add `T.Ref` annotation and tests * fix lint error * minor update for error message on twice decl * Remove unused let_bindings_ in CodeGenC to fix #1300 * [Bugfix] Fallback to the old AtomicAdd implementation for legacy architectures (#1306) * [Fix] Fix frame scope error in T.macro (#1308) * [Fix] Fix #1307 by adding macro inside function * fix lint error * add comments and fix lint error * Remove debug print from enter_frame method Removed debug print statement from enter_frame method. --------- Co-authored-by: Lei Wang <[email protected]> * [WIP] support more dtypes for tcgen05 (#1229) support ld with pack for fp32 dtype add dump add tempalte expand remove unused dtype and change to rebased apis * Improve memory access safety and `T.assume` handling (#1292) * Improve memory access safety and T.assume handling * Improve memory access safety and T.assume handling * bugfix * lint fix * bugfix * bugfix * refactor legalize safe memory access pass --------- Co-authored-by: Lei Wang <[email protected]> * [Bugfix] Fix autotune cache (#1315) * [Refactor] Backup Analyzer to get the appropriate arith informations (#1311) * [Refactor] Update Vectorization Functions to Accept Analyzer Parameter - Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization. - Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness. - Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities. * [Fix] Corrected PostOrderVisit call in loop_vectorize.cc - Updated the PostOrderVisit function to analyze the body of the loop node instead of the node itself, ensuring proper handling of nested loops during vectorization analysis. * fix * lint fix * fix * Revert "[WIP] support more dtypes for tcgen05 (#1229)" (#1323) This reverts commit 0d101c110f74ebf2ef8c11a5ece9dfb314b48baa. Co-authored-by: Zhiwen Mo <[email protected]> * [CI]: Bump actions/checkout from 5 to 6 (#1319) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [CI]: Bump pypa/cibuildwheel from 3.2 to 3.3 (#1318) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [Installation] Fix building using customized TVM path (#1326) * [Release] Allow developer with write permission to trigger wheel release (#1322) * [Feat] Support warp reduce (#1316) * [Feat] Support warp reduce * lint * add test * lint * [Enhancement] Support more dtype in `T.print` (#1329) * [Enhancement] Support more dtype in `T.print` * upd * upd * [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape (#1321) * [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape * remove debug lines * remove rubbish * Fix decorator syntax for atomic_different_memory_orders_program --------- Co-authored-by: Lei Wang <[email protected]> * [Fix] fix wrong uint narrowing bug in tvm in #1310 (#1320) * [Refactor] Disable strided buffer load inside tvm (#1301) (#1332) * [Refactor] Moving `NormalizeToBufferRegion` and `MakeAccessPtrFromRegion` to utils (#1333) * Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse. * lint fix * [Fix] Fix bug copying from or to local buffer (#1304) (#1324) * [Fix] fix copy from or to local buffer (#1304) * fix lint error * minor fix testing script * [Language][UX] Semantic check for parallel fragment access (#1338) * Add unit tests for T.assume (#1341) * Add test for T.assume * Add unit test for T.assume * Add unit test for T.assume * Add unit tests for T.assume * Remove debug print for kernel source Remove print statement for kernel source in tests. * Update test_tilelang_language_assume.py --------- Co-authored-by: Lei Wang <[email protected]> * [Feat] Extend LegalizeNegativeIndex to support buffer store stmts (#1339) This commit enhances the LegalizeNegativeIndex transformation pass to handle both buffer load and store operations with negative indices and adds some test cases. * [Refactor] Phaseout vmap for Tile Operators (#1334) * Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse. * lint fix * Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations. * fix * Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions. * fix * fix * test fix * lint fix * Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management. * fix * lint fix * fix * fix * test fix * lint fix * lint fix * minor fix * fix --------- Co-authored-by: Zhiwen Mo <[email protected]> * [Enhancement] add more dtype and fix mma.ws for fp16 for tcgen05 (#1327) * feat: add fp8 variants; add placeholder for fp6/fp4 in meta support ld with pack for fp32 dtype add dump add tempalte expand remove unused dtype and change to rebased apis * fix: when atom-m!=128, enable_ws * fix: typo in tcgen05 meta; dispatch in gemm sm100 * [Refactor] Enhance CopyNode's IterVar Creation and Range Handling (#1346) * [Refactor] Enhance CopyNode's IterVar Creation and Range Handling This commit refines the `MakeIterVars` method in `CopyNode` to select base ranges based on memory scope levels, ensuring that the chosen ranges are not smaller than the original source ranges. Additionally, it updates the Python `copy` function to clarify range handling, including broadcasting logic and extent alignment. These changes improve the robustness and clarity of the copy operation's implementation. * test fix * [Fix] Fix missing `not` rewrite in frontend (#1348) * [Enhancement] Add support for k_pack in gemm_mfma (#1344) * add support for k_pack * support benchmark on ROCm * fix format * Add sparse fine-tuning kernel for deepseek sparse attention to example (#1296) * [EXAMPLE] add example for dsa sparse finetuning * [Refactor] * [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder (#1352) * [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase. * [Enhancement] Update matmul kernel and optimize argument binding This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code. * lint fix * [Enhancement] Add tensor checks documentation and improve argument binding assertions This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code. * [Enhancement] Update .gitignore and refine matmul kernel for improved performance This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users. * lint fix * lint fix * [Refactor] Simplify tensor_null_test function and remove ptr_null_test This commit refactors the tensor_null_test function by adding a with_bias parameter and removing the ptr_null_test function, which was previously unused. The run_test function is updated to reflect these changes, streamlining the testing process for tensor operations. * lint fix * fix * [Refactor] Simplify index sign state handling in LegalizeNegativeIndex (#1354) This commit refines the logic for determining the sign state of indices in the LegalizeNegativeIndex transformation. It prioritizes vector patterns, specifically Ramp and Broadcast nodes, to avoid compile-time lane queries. The handling of scalar indices is also streamlined, ensuring clearer diagnostics when non-negativity cannot be proven. These changes enhance the robustness and clarity of index handling in the transformation pass. * [Enhancement] Improve error handling and assertion messages across runtime and argument binding (#1356) This commit enhances the error handling mechanisms in the runtime by introducing CPU-safe runtime helpers and refining assertion messages in the CodeGenCHost and ArgBinder. It includes structured packed error messages for various conditions, improving clarity in diagnostics. Additionally, the CMake configuration is updated to always include necessary runtime helpers, ensuring consistent error reporting. The changes aim to provide clearer feedback during runtime errors and improve the overall robustness of the argument binding process. * [Bugfix] Disable floordiv optimization due to integer overflow risk (#1355) * disable overflow-prone floordiv optimization in lower_intrin.cc * disable overflow-prone floordiv optimization in lower_intrin.cc * [Bugfix] Fix the jit_kernel issue (#1357) * [Bugfix] Fix the jit_kernel issue * Update README.md --------- Co-authored-by: Lei Wang <[email protected]> * [Refactor] Update Fragment Indexing in ParallelOpNode's InferLayout Method (#1359) This commit refines the Fragment creation process in the InferLayout method of ParallelOpNode. It removes the unnecessary forward_index array and utilizes default fragment indexing for consistency with other operations. Additionally, it binds the thread range to enhance comparability across different operations. * [Analysis] Enhance NestedLoopChecker with tile op cases (#1358) * [Analysis] Enhance NestedLoopChecker with tile op cases * fix tileop issue * [Language] support `T.gemm_sp_v2` on sm80 and sm89 (#1056) * [misc] add a cpp side wrapper for gemm_sp_py * [misc] typing * [IR] bind GemmSPWarpPolicy * [chore] add wrapper code * [IR] fix GemmSPWarpPolicy * [codegen] apply ptxas instructions * [intrinsic] add typical (unused) mma layout * [template] add uint16 debug func * [intrinsic] add b matrix layout * [gemm_sp] enable fp16/bf16 on sm8x * [layout] refactor fp16/bf16 layout * [gemm_sp] enable int8 * [chore] update test case dtype * [gemm_sp] enable fp32 * [layout] refactor layouts * [intrinsic] enable ldmatrix for mat A * [layout] enable ldsm for matrix b * [layout] add ldmatrix for fp32 and fp8 * [chore] refine * [chore] refactor * [chore] add fp8 efactor * [chore] refactor * [chore] add remove negative zero util * [example] add a custom compress kernel * [chore] minor update * [test] refactor gemm_sp test * [refactor] make metadata layout func * [example] add option for using cutlass layout * [doc] add a gemm_sp doc * [doc] minor polish * [chore] remove unused * [bugfix] fix non replicate b case * [test] refactor * [chore] add a check * [bugfix] fix util bug * [wip] init a new test case for v2 * [chore] minor refactor * [chore] minor update * [bugfix] enable 16bit rs * [language] enable rs * [language] enable gemm_sp_sr * [language] enable gemm_sp_rr * [test] enable more tests * [tvm] update ffi binding * [chore] remove print * [chore] fix benchmark script * [lint] precommit lint * [chore] apply feedback * [test] use arch 8.0 * [chore] rollback ::ordered_metadata for backward compatibility * [bugfix] fix captialized * [example] keep gemm_sp on hopper * [test] fix no fp8 normal kernel * [test] reduce matmul size to satisfy accum error * [test] use cal_diff for assertion * [bugfix] expand float8 type * [lib] add make_int4 for short type * [language] add transpose E * [bugfix] fix wrong var * [format] format * [chore] refactor binding * [chore] fix wrong passing var * [Bugfix] Update TIR registration for GemmSPPy to use tile operation (#1361) * [Enhancement] Implement dynamic unroll factor in CUDA code generation (#1360) * [Enhancement] Implement dynamic unroll factor in CUDA code generation This commit introduces support for specifying a dynamic unroll factor in the CUDA code generation. The `unroll_factor` map is added to store unroll factors for loop variables, allowing for more flexible and optimized loop unrolling. Additionally, the `unroll` function is integrated into the loop language, enabling users to define unroll factors directly in their code. This enhancement improves performance by allowing tailored unrolling strategies based on specific loop characteristics. * lint fix * [Bugfix] Correct initialization of non-zero counters in custom compress kernel and update TIR registration for gemm_sp_py to use the correct tile operation * [CI] [pre-commit.ci] autoupdate (#1362) updates: - [github.com/pre-commit/mirrors-clang-format: v21.1.2 → v21.1.6](https://github.com/pre-commit/mirrors-clang-format/compare/v21.1.2...v21.1.6) - [github.com/astral-sh/ruff-pre-commit: v0.14.3 → v0.14.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.14.3...v0.14.7) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Bugfix] Remove debug print in PyStmtFunctionVisitor (#1363) * [Debug] Always include line info in NVCC command for improved profiling and mapping (#1364) * [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py (#1365) * [Enhancement] Add DISABLE_CACHE environment variables (#1368) * [Refactor]: Remove useless include in atomicadd_vectorize.h (#1371) * [Refactor] Generalize fp8 process (#1372) * [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py * [Enhancement] Extend support for float8 data types in GEMM operations - Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`. - Refactored condition checks in `checkWgmma` methods to simplify float8 type handling. - Adjusted test cases to ensure compatibility with the new float8 types in tile language examples. * lint fix * [Layout] Enhance Free Layout Inference (#1375) * [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py * [Enhancement] Extend support for float8 data types in GEMM operations - Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`. - Refactored condition checks in `checkWgmma` methods to simplify float8 type handling. - Adjusted test cases to ensure compatibility with the new float8 types in tile language examples. * lint fix * [Enhancement] Add injective layout detection and exception handling - Introduced `DetectInjective` method in `FragmentNode` to check for injective layouts. - Added `LoopLayoutInjectiveException` to handle errors related to non-injective layouts. - Updated `InferLayout` methods in `ParallelOpNode` to utilize injective checks and log relevant information. - Refactored layout inference queue management to use `std::deque` for improved performance and added prioritization logic for buffer layouts. * remove debug print * remove debug print * remove debug print * minor layout fix * fix for T.view * [Enhancement] Improve injective layout detection in FragmentNode - Updated the `DetectInjective` method to handle symbolic dimensions more effectively by introducing a mechanism to collect symbolic shapes and adjust the detection level accordingly. - Added logging for cases where the layout detection falls back to NoCheck due to symbolic dimensions. - Minor update to the test file to include the tilelang testing module. * [Refactor] Simplify layout inference for bulk copy operations - Removed unnecessary conditions for bulk load/store operations in the layout inference logic. - Streamlined the handling of layout application for bulk copy instances to enhance clarity and maintainability. * remove debug print * [Enhancement] Introduce layout-related exceptions and improve error handling - Added `LayoutConflictException` and `LoopLayoutInjectiveException` classes for better exception management in layout operations. - Updated `InferLayout` method in `ParallelOpNode` to throw `LoopLayoutInjectiveException` with detailed error information when injective layout checks fail. - Removed redundant exception class definitions from `parallel.h` to streamline code organization. * [Enhancement] Introduce buffer var lca analysis for pass plan buffer allocations (#1376) * Update submodule TVM to latest commit and add PlanAndUpdateBufferAllocationLocation function to transform module - Updated the TVM submodule to commit 3a32b763. - Added a new function `PlanAndUpdateBufferAllocationLocation` in the transform module to facilitate buffer allocation planning within PrimFuncs. * Refactor buffer allocation code for improved readability and consistency - Updated formatting and spacing in `plan_update_buffer_allocation_location.cc` for better code clarity. - Standardized the use of pointer and reference syntax across various class methods. - Enhanced comments for better understanding of buffer allocation logic. - Removed unnecessary lines and improved overall code structure. * Refactor buffer allocation checks for improved clarity - Replaced size checks with empty checks for `ffi::Array<Buffer>` in `plan_update_buffer_allocation_location.cc` to enhance code readability. - Updated conditions in multiple methods to use `empty()` instead of comparing size to zero, streamlining the logic. * [Tool] Provide layout visualization tool (#1353) * Provide layout visualization tool Adds a layout visualization tool to TileLang, which helps users understand and debug the layout transformations applied during compilation. This tool visualizes the memory layout of tensors at different stages of the compilation process, allowing developers to identify potential inefficiencies and optimize their code for better performance. The visualization can be enabled via a pass config option. * format * add layout visual example * Adds vis extra with matplotlib dependency * rafactor pass config name * fix lint * Enables configurable layout visualization formats Allows users to specify the output formats (png, pdf, svg) for layout visualization through a pass config option. This change provides more flexibility in how layout visualizations are generated, allowing users to choose the formats that best suit their needs. It also fixes a bug where layout visualization was not correctly disabled when the config option was set to "false". * Adds visual layout inference tool docs * fix lint * fix lint * Rafactor configurable layout visualization formats * fix lint * fix typo * add some comments * fix lints * add some warnings for user * Moves layout visualization * Refactors layout visualization pass configuration Updates the layout visualization pass configuration to use boolean flag for enabling and a string for specifying formats. * Enables multiple layout visualization formats * Updates layout visualization docs * Moves layout visualization to analysis * [Release] Relax constraint of tvm-ffi to compatible version (#1373) Co-authored-by: LeiWang1999 <[email protected]> * [Language] Tilelang LazyJIT Experimental Version (#1337) * initial step * modify builder * scratch version of new frontend * write some tests * add many tests * add typing stub for tir.ir * remove idents * minor update * minor update * First version of jitv2 (renamed to LazyJIT) * fix pre-commit error * minor fix * fix lint error * fix lint error * Fix conditional check for PrimFunc instance --------- Co-authored-by: Lei Wang <[email protected]> * [Builder] Enhance variable name binding and scope management (#1378) - Improved handling of TVM Var/Buffer names to prevent out-of-scope errors when reusing Python names across different for-frames. - Added assertions to ensure variables are defined within the correct control flow frame, enhancing error checking and code reliability. * [Bugfix] make cuda driver api compat with cuda12/13, along with tests (#1379) * [Fix] typo in cuda attr (#1380) * [Bugfix] make cuda driver api compat with cuda12/13, along with tests * fix typo in cudaDevAttr * [Language V2] Minor fix for complex annotations (#1381) * [Release] Bump Version into 0.1.7 (#1377) * Update VERSION to 0.1.7 * Update Python version in distribution scripts to support CPython 3.9 and log output * [Typing] Enhance compatibility for advanced typing features in Python (#1382) - Updated `allocate.py` and `annot.py` to improve compatibility with Python 3.9 and later by conditionally importing advanced typing features such as `TypeVarTuple`, `Unpack`, and `ParamSpec`. - Added fallback imports from `typing_extensions` for environments using earlier Python versions. - Improved handling of generic alias detection to ensure consistent behavior across different Python versions. * [Bugfix][Build] Update CMake configuration to remove project root injection for sys.path (#1385) * [Build] Update CMake configuration for tilelang_cython_wrapper installation - Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib. - Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules. - Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects. * [Build] Standardize output directories for tilelang libraries - Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds. - This change enhances organization and ensures that all build artifacts are located in a unified directory structure. * [BugFix] Fix split kernel layout bug of GQA decode (#1386) * [BugFix] Fix split kernel layout bug of GQA decode * [BugFix] Avoid local with Parallel; use robust fragment instead * [Enhancement] Add debug output methods for Layout and Fragment classes (#1392) * [Doc] Update logging docs (#1395) * [Enhancement] Refactor inflight computing to support dynamic pipeline extents (#1399) * [Build] Update CMake configuration for tilelang_cython_wrapper installation - Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib. - Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules. - Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects. * [Build] Standardize output directories for tilelang libraries - Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds. - This change enhances organization and ensures that all build artifacts are located in a unified directory structure. * [Refactor] Update TVM subproject and enhance pipeline loop handling - Updated the TVM subproject to commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0. - Added new fields to `PipelineAnnotation` and `RewrittenBlockInfo` structures to track original statement indices and improve async state management. - Refactored `EmitImpl` and `PopulateWaitCounts` methods to enhance clarity and functionality, including better handling of commit groups and wait counts. - Simplified access index calculations and strengthened analyzer constraints for loop bounds. * [Cleanup] Remove license block and unused includes from inject_pipeline.cc - Eliminated the Apache license block from the top of the file to streamline the code. - Removed unused include directives for memory and stringstream to enhance code clarity and reduce unnecessary dependencies. * [Refactor] Enhance transformation pipeline and test execution - Added an additional Simplify transformation in the InjectSoftwarePipeline to improve optimization. - Updated the test file to call `test_trival_pipeline()` directly, commenting out the previous main execution for better test isolation. * [AMD] Fix 3 bugs when build docker on amd mi3x gpu (#1401) * [Typo] Fix tilelang link in README.md (#1402) * [Dependency] Update apache-tvm-ffi version to >=0.1.2 (#1400) * [Dependency] Update apache-tvm-ffi version to >=0.1.2 in project files * [Dependency] Update subproject commit for TVM to latest version afc07935 * [Enhancement] Add support for optional step parameter in loop constructs - Updated loop creation functions to accept an optional step parameter, enhancing flexibility in loop definitions. - Modified ForFrame implementations to utilize the new step parameter across various loop types including serial, parallel, and pipelined loops. - Adjusted related vectorization transformations to accommodate the step parameter, ensuring consistent behavior in loop vectorization processes. * lint fix * [AMD] Enable FA2 fwd on AMD MI300X (#1406) * enable FA2 on AMD MI300X * make lint happy * [TypoFix] fix typo for SM120 (#1408) * [Doc] Minor documentation update (#1410) * [Dependency] Add torch-c-dlpack-ext to project requirements (#1403) * [Dependency] Add torch-c-dlpack-ext to project requirements * Added torch-c-dlpack-ext to both pyproject.toml and requirements.txt to provide prebuilt torch extensions, which may prevent JIT compilation on first import of TVM FFI. * [Build] Update manylinux images in project configuration * Changed the manylinux image for x86_64 from "manylinux2014" to "manylinux_2_28" in both pyproject.toml and the Dockerfile to align with updated standards for compatibility and performance. * [Build] Update CUDA repository configuration in pyproject.toml * Changed the package manager command from `yum-config-manager` to `dnf config-manager` for adding the CUDA repository, ensuring compatibility with newer systems. * fix * [Build] Update CUDA repository to RHEL 8 * Changed the CUDA repository configuration in both pyproject.toml and the manylinux Dockerfile from RHEL 7 to RHEL 8, ensuring compatibility with newer systems. * test: run out of space * use cu130 to reduce size * upd * upd comment * upd --------- Co-authored-by: Your Name <[email protected]> * [Dependency] Update TVM subproject to latest commit 2b1ead1a (#1412) * [Enhancement] Introduce `T.__ldg` (#1414) * [Enhancement] Add __ldg intrinsic for CUDA read-only cache loads * Introduced the __ldg intrinsic to enable explicit read-only cached loads from global memory in CUDA. * Updated the corresponding documentation and added support in both CUDA and HIP code generation. * Enhanced the Python interface for __ldg to accept BufferLoad and Buffer types, improving usability. * [Enhancement] Update formatting and linting rules in pyproject.toml; minor test adjustment * Added new formatting rules in pyproject.toml to enforce consistent code style, including hanging indents and argument splitting. * Updated test_tilelang_language_intrinsics_codegen.py to improve readability by adding a blank line before the main execution block. * Refactored error messages in builtin.py for better clarity and consistency, ensuring proper formatting in function definitions and raising ValueErrors. * lint fix * [Enhancement] Improve vectorization invariant check (#1398) * Improve loop vectorize * Improve loop vectorize * Improve loop vectorize * Improve loop vectorize * Improve loop vectorize * Add some vectorize tests and comments * [Lint] Phaseout Yapf format and embrace ruff format (#1417) * [Atomic] Use ptr for atomicAdd dst instead of reference (#1425) * [Enhancement] Update AtomicAdd function signature to accept pointer to destination * Modified AtomicAdd in CUDA to take a pointer instead of a reference for the destination argument. * Updated related code in atomicadd_vectorize.cc to ensure compatibility with the new signature. * Adjusted Python interface in atomic.py to pass the destination by pointer, aligning with device function requirements. * [Enhancement] Refactor AtomicAddRet function signature to accept pointer * Updated AtomicAddRet in both CUDA and HIP to take a pointer instead of a reference for the address argument, improving consistency with the AtomicAdd function. * Adjusted the implementation to ensure proper reinterpretation of the address type for atomic operations. * lint fix * [Enhancement] Refactor AtomicAddNode::MakeSIMTLoop to use destination pointer * Updated the MakeSIMTLoop function to build a pointer to the destination element using tvm_access_ptr instead of loading the destination value directly. * Simplified the handling of source and destination predicates, improving clarity and maintainability of the code. * Ensured compatibility with the new pointer-based approach for atomic operations. * lint fix * test fix * lint fix * [CUDA] Add read-only parameter annotation for CUDA codegen (#1416) * [Enhancement] Add read-only parameter annotation for CUDA codegen * Introduced the `AnnotateReadOnlyParams` transformation to annotate read-only handle parameters in PrimFuncs, enabling the generation of `const` qualifiers in CUDA codegen. * Updated `PrintFunctionSignature` and `AddFunction` methods to utilize the new attribute `tl.readonly_param_indices`, enhancing performance by allowing read-only cache loads. * Modified the optimization pipeline to include the new annotation step, improving the overall efficiency of the code generation process. * lint fix * [Dependency] Update apache-tvm-ffi version to >=0.1.3 * Updated the version of apache-tvm-ffi in pyproject.toml, requirements.txt, and requirements-dev.txt to ensure compatibility with the latest features and fixes. * Made adjustments in CUDA and HIP template files to use `const` qualifiers for global pointer parameters, enhancing code safety and clarity. * lint fix * [Enhancement] Refactor ReadWriteMarker for improved parameter handling * Updated the ReadWriteMarker class to accept a set of parameter or data variables, enhancing its ability to track written variables. * Introduced a new method, ResolveDataVarFromPtrArg, to resolve underlying buffer data from pointer-like arguments, improving accuracy in identifying written variables. * Modified the MarkReadOnlyParams function to gather handle parameters and their corresponding buffer data variables, streamlining the process of determining read-only parameters. * Enhanced the logic for identifying written variables to account for aliased data variables, ensuring comprehensive tracking of modifications. * lint fix * Update tma_load function to use const qualifier for global memory pointer * Changed the parameter type of gmem_ptr in the tma_load function from void* to void const* to enhance type safety and clarity in memory operations. * This modification ensures that the function correctly handles read-only global memory pointers, aligning with best practices in CUDA programming. * Remove commented-out code and reorder transformations in OptimizeForTarget function for clarity * Refactor buffer marking logic in annotate_read_only_params.cc to improve accuracy in identifying written variables. Update OptimizeForTarget function to reorder transformations for better clarity. * [Refactor] Phase out the primitives folder since its design has been merged into tileop (#1429) * Phase out primitives * revert changes * Refactor GemmWarpPolicy method signature for clarity Updated the `from_warp_partition` method in the `GemmWarpPolicy` class to return the type `GemmWarpPolicy` instead of a string, enhancing type safety and clarity in the codebase. Removed an unnecessary blank line for improved readability. * fix * [CI]: Bump actions/upload-artifact from 5 to 6 (#1431) Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 5 to 6. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [CI]: Bump actions/download-artifact from 6 to 7 (#1432) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 6 to 7. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v6...v7) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [Bugfix] Convey `compile_flags` to ffi compilation path with pass_configs (#1434) * [Enhancement] Add device compile flags support in pass configuration * Introduced `kDeviceCompileFlags` option in the pass configuration to allow additional device compiler flags for CUDA compilation. * Updated the `tilelang_callback_cuda_compile` function to merge extra flags from the pass configuration, enhancing flexibility in compiler options. * Modified the `JITKernel` class to handle device compile flags appropriately, ensuring they are included during compilation. * Documented the new pass configuration key for clarity on usage and expected input formats. * lint fix * [Refactor] Simplify compile_flags handling in JIT functions * Removed redundant string check for compile_flags in the compile, jit, and lazy_jit functions, ensuring compile_flags is consistently treated as a list. * Updated the JITKernel class to handle compile_flags as a list when a string is provided, enhancing code clarity and maintainability. * lint fix * fix * [Enhancement] Improve buffer usage tracking in MakePackedAPI (#1435) * Added detailed logging for data and shape variable parameters during buffer usage detection in the MakePackedAPI function. * Refactored the UsedBufferDetector to differentiate between used parameters by data and shape variables, enhancing clarity in buffer management. * Updated logic to ensure minimal carrier buffers are selected for shape symbols, improving the efficiency of parameter handling. * [Enhancement] Improve InjectAssumes logic and make assumes work after SplitHostDevice (#1405) * [Refactor] Refactor InjectAssumes logic and make assumes work after SplitHostDevice * address comments * fix * fix submodule * fix * fix 3rdparty * [Enhancement] Include PrimFunc name in memory cache logs for better debugging (#1437) * Added the `get_prim_func_name` utility to extract human-readable function names from TVM PrimFuncs. * Updated memory cache logging in `AutoTuner` and `KernelCache` classes to include the kernel name, improving clarity during cache hits. * Enhanced debug logging to provide more informative messages when checking disk cache for kernels. * [CI] Update lint dependencies and fix lint on trunk (#1433) * [CI] Update pre-commit hooks * [Lint] Pass correct `exclude-header-filter` to `clang-tidy` * [Lint] Download latest `run-clang-tidy` script * [CI] Show compile commands * [CI] Add output grouping to GHA * [Lint] Re-order pre-commit hooks * [Enhancement] Refactor vectorization checks in loop_vectorize (#1440) * Introduced a new function, IsExprInvariantInVectorBoundary, to encapsulate the logic for checking if an expression is invariant within vector boundaries, improving code clarity and reusability. * Updated the existing vectorization logic to utilize this new function, streamlining the process of determining vectorization feasibility based on boundary conditions. * Enhanced comments for better understanding of the vectorization criteria and mathematical rationale behind the checks. * Enhance vectorized conversion support (#1438) * [Feature] Support region as input of T.cumsum (#1426) * [Feature] Support region as input of T.cumsum - Extend T.cumsum to accept BufferRegion and BufferLoad inputs in addition to Buffer - This enables operations on buffer slices/regions like: T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0) - Update cumsum_fragment to handle region inputs properly - Add comprehensive tests for 1D and 2D region inputs including normal and reverse modes Fixes #879 * Fix formatting and add docstring for cumsum_fragment - Add comprehensive docstring for cumsum_fragment function - Format code according to ruff style guidelines * Fix CodeRabbit review issues - Fix negative dimension bounds check (dim < -len(shape) instead of dim <= -len(shape)) - Add src/dst shape compatibility validation for out-of-place cumsum - Update copy() type annotation to accept BufferRegion as dst parameter - Fix test in-place mutation issues by using out-of-place cumsum operations - Add non-divisible size test cases for tail region coverage * Fix out-of-bounds access in region tests - Add bounds clamping using T.min() for chunk_end calculations - Prevents accessing beyond tensor bounds for non-divisible sizes - Matches reference implementation behavior - Fixes both 1D and 2D region test cases * Fix region test: use simple slice expressions instead of T.min() - Remove T.min() which cannot be used directly in slice indices - Use chunk_start + chunk_size form instead - Rely on system's automatic bounds checking for non-divisible sizes - Update comments to reflect this approach * Fix cumsum region: use region extents in lowering and update tests for shared memory * Simplify fragment scope check using is_fragment() --------- Co-authored-by: LeiWang1999 <[email protected]> * [Fix] Fix analyzer bind conflicting (#1446) * [Refactor] Reduce direct dependency on PyTorch due to its limited type support (#1444) * [Enhancement] Update KernelParam to use tvm.DataType directly and add torch_dtype conversion method - Changed dtype in KernelParam from torch.dtype to tvm.DataType to support a wider range of data types and prevent information loss during conversions. - Added a new method, torch_dtype, to convert tvm.DataType back to torch.dtype for tensor creation. - Updated various adapters to utilize the new torch_dtype method for parameter type conversion during initialization. * [Enhancement] Refactor CUDA type handling and add support for FP4 and FP8 types - Renamed functions for clarity: GetFP8Type, GetFP6Type, and GetFP4Type are now GetTileLangFP8Type, GetTileLangFP6Type, and GetTileLangFP4Type respectively. - Enhanced FP4 type handling to support additional lane sizes (2, 4, 8, 16, 32, 64). - Updated CUDA code generation to include new FP8 and FP4 types, ensuring proper type handling in PrintType and related functions. - Introduced new structures for FP8 types in cuda_fp8.h to facilitate better memory management and type packing. - Added methods in KernelParam and tensor utilities to recognize and handle float4 types, improving compatibility with PyTorch. - Enhanced logging for debugging purposes in various CUDA functions to track type handling and memory operations more effectively. * lint fix * Remove unnecessary logging statements from CUDA code generation and delete obsolete matrix multiplication test file. * [Enhancement] Add support for FP4 and FP8 types in CUDA code generation - Enhanced PrintVecElemLoad and PrintVecElemStore functions to handle new FP4 types. - Updated arg_binder to allow float4 to match int8 at runtime, improving compatibility with PyTorch. - Modified loop_vectorize to account for buffer dtype lanes in vectorization calculations. - Refactored tensor type mapping to support new float4 and float8 types, ensuring correct type handling in tensor operations. - Added tests for FP4 and FP8 copy operations to validate functionality and integration with existing workflows. --------- Co-authored-by: Zhiwen Mo <[email protected]> * [Refactor] Use `pytest.mark.parameterize` to speedup parallel testing (#1447) * Refactor GEMM tests to use parameterized pytest fixtures - Converted multiple test cases for GEMM operations in `test_tilelang_tilelibrary_gemm_sp.py` to use `pytest.mark.parametrize` for better maintainability and readability. - Similar refactoring applied to `test_tilelang_tilelibrary_gemm_sp_v2.py`, consolidating test cases for `run_gemm_ss`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` into parameterized tests. - This change reduces code duplication and enhances the clarity of test configurations. * Update testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * [Docs] Improve installation instructions for developers (#1450) * [Feat] Integrate Z3 in TVM Arith Analyzer (#1367) * [Bugfix] Improve autotune from elementwise_add function in examples (#1445) * Remove JIT decorator from elementwise_add function in examples * fix kernel compilation without autotune * Refactor main function to accept parameters and update tests for autotune option * Refactor autotune test function for morden style * [Language] Introduce `T.annotate_restrict_buffers` (#1428) * [Enhancement] Introduce non-restrict parameter support in code generation - Added a new PrimFunc-level attribute `tl.non_restrict_params` to specify handle Vars that should not be marked with the restrict qualifier during code generation. - Updated `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to handle non-restrict parameters, ensuring proper treatment of overlapping buffer aliases. - Implemented a new annotation function `annotate_restrict_buffers` to facilitate the marking of buffer parameters as non-restrict. - Enhanced the `SplitHostDevice` transformation to propagate non-restrict parameters from host to device functions. - Added a new transform function `HoistNonRestrictParams` to manage non-restrict parameters effectively. * [Enhancement] Improve HoistNonRestrictParams transformation - Updated the HoistNonRestrictParams function to recursively collect all `tl.non_restrict_params` annotations from nested blocks, enhancing flexibility in annotation placement. - Introduced a new NonRestrictCollector class to manage the collection and deduplication of non-restrict parameters. - Modified the SplitHostDevice transformation to remove the non-restrict attribute from the host-side PrimFunc after propagation to device kernels. - Adjusted the LowerAndLegalize function to directly apply the HoistNonRestrictParams transformation without exception handling, streamlining the process. * [Refactor] Simplify non-restrict parameter handling in code generation - Removed unnecessary normalization logic and associated data structures from `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP`. - Streamlined the handling of non-restrict parameters by directly inserting them into the `non_restrict` set, improving code clarity and maintainability. - Updated conditional checks to eliminate redundant checks against normalized names, enhancing performance and readability. * [Dependency] Update TVM subproject to latest commit 68aa8461 - Updated the TVM subproject to the latest commit, ensuring compatibility with recent changes and improvements. - Refactored non-restrict parameter handling in `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to enhance code clarity and maintainability. - Adjusted the `SplitHostDevice` transformation to streamline the propagation of non-restrict parameters. * fix * [Analyzer] Require loop extent > 0 when entering loop (#1451) * Updat ROCm CI to Nightly-ROCm-7.1 (#1449) * [Enhancement] Update examples and tests for improved type handling functionality (#1448) * [Enhancement] Update examples and tests for improved type handling and functionality - Enhanced various example scripts to support new data types and improve compatibility with PyTorch. - Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling. - Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management. - Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling. * [Refactor] Update accumulation data type to float32 across examples - Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability. - This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board. * [Refactor] Standardize data type usage across benchmark scripts - Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling. - Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard. - Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules. * [Refactor] Standardize data type usage in templates and scripts - Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity. - Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates. - This change aims to streamline type handling and improve compatibility with existing workflows. * [Refactor] Standardize data type usage in examples and benchmarks - Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability. - Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard. - Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples. * [Refactor] Import dtypes from language.v2 module - Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase. - This change aims to streamline data type management and improve overall code clarity. * fix * [Refactor] Standardize data type usage across scripts - Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity. - Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability. - This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase. * [Refactor] Update data type handling for consistency and clarity - Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency. - Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase. - Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling. - This refactor aims to streamline data type management and improve overall code clarity and maintainability. * [Enhancement] Improve data type handling and error messaging - Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation. - Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs. - Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience. * [Fix] Correct boolean flag in GEMM SP test case - Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case. - This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation. * [Refactor] Standardize data type usage across scripts - Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability. - Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase. - This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability. * [Refactor] Standardize data type usage in various modules - Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability. - Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase. - This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations. * [Refactor] Update argument parsing for data types in benchmarks - Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float. - This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability. * [Refactor] Update data type handling in benchmark and example scripts - Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency. - Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase. - This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples. * [Refactor] Fix data type conversion in multiple scripts - Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts. - This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations. * [Refactor] Update float8 data type usage across multiple scripts - Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling. - This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations. * [Refactor] Enhance float8 data type handling in CUDA code generation - Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic. - Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase. - Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation. - This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy. * [Refactor] Streamline float8 data type handling in CUDA and related modules - Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks. - Updated layout inference for float8 types to improve clarity and maintainability across the implementation. - This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy. * [Refactor] Remove unnecessary cache disabling in float8 example script - Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code. - This change enhances clarity and maintainability of the example script without affecting its functionality. * [Refactor] Update data type usage in debug print tests - Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type. - This update enhances consistency in data type handling within the test suite, improving clarity and maintainability. * lint fix * Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples * Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines. * [Issue Template] Enable blank issues in GitHub issue template(#1453) * [CI] Moved the clang-tidy step to after pip install (#1456) * [Bug] Fix tvm build script when patchelf is not found #1459) * [Analyzer] Fix floordiv & floormod bug in z3 prover (#1458) * fix floordiv & floormod in z3 prover * fix lint error * [Cache] Rename sparse compress cache directory (#1460) * Enhance cache directory structure by including version information in sparse.py to ensure separate caches for different versions. * Fix formatting in sparse.py by adding a newline for improved readability and consistency. * [Language]Adds a random number generation capability through curand_kernel (#1461) * add curand.{curand_init, curand} * run format.sh * add default value for curand_init & add test for curand * Update testing/python/language/test_rand.py Remove unused thread binding Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * remove unused library * enable tilelang cache for testing * run format.sh * Revert "run format.sh" This reverts commit 5afaff782f31cdf653e2c45b469da8dead228b8a. * Revert "enable tilelang cache for testing" This reverts commit c277a43e77938bd88d47a108dd1bd65734d4a1ae. * Revert "remove unused library" This reverts commit 568ad20611f039380113937fd131151a2bffd801. * run format.sh * ensure FreshName for __philox_state * ensure FreshName for __philox_state …
* Enhance threadblock swizzle templates with default offset parameter and streamline parser.py for better readability * [Cache] Rename sparse compress cache directory * Temporarily exclude sink tests from non-distributed example tests in CI to address timeout issues * [DeepEP] Move deepep benchmark to example and allow compatible with new version DeepEP * [Feat] Enhance `T.st` to support intra-node store to peer's symm memory * use strided loop to simplify get_dispatch a bit * [Feat] Support warp reduce operators * draft notify dispatch * rename and refactor `T.barrier/sync_blocks` * fix prev typo * [Feat] Add `get_device_tensor` function and related test * support elect_one_sync() and add test * draft dispatch * suupport ld, st, warp_sync, continue and add test * support warp vote and add test * support device-side wait_ne * refactor T.wait_* and refine dispatch test logic * intra-node dispatch test passed * draft combine * support massage-only debug print * intra-node combine test passed * unify dispatch, migrate topk_idx to u64, support cached dispatch * Refactor to pre-alloc buffers and expose interface, add benchmark * remove redundant test * update doc * use int4 vectorization for dispatch * use comm_stream for comm kernels * optimze dispatch perf via skipping tensor validation * add dispatch benchmark result * make rank as an argument of the kernel * use cuda postproc for vectorization in combine * support int4 ld/st ptx in cuda template * [Feat] Support auto vectorization for ld/st to optimize combine to surpass deepep * lint * upd doc * make ci happy * fix review issues * fix import error * Add DeepEP submodule and installation script for CI * fix ci bug * [Sync] Merge mainstream TileLang TVM-FFI features into TileScale (#47) * [Example] Add GQA decoding kernel with varlen page table (#1265) * [Example] Add page table for gqa decode * [Example] Page table for varlen decoding * [Lint] * [Refactor] Remove redundant code * [Lint] * [Lint] * [Lint] * [Refactor] add support for numpy dtype conversion (#1255) * add typing stub for tir.ir * remove idents * minor update * [Refactor] add numpy conversion for dtype * fix lint error * remove unused np.float_ in dtype conversion * fix type in np.int_ * fix typo * minor fix * remove debug files * [EXAMPLE] In the flash attention example keep the max of all blocks seen in scores_max numerical stability (#1148) * Keep the max of all blocks seen in scores_max for stability * ruff formatting * [Docs] Improve Installation Guide (#1270) * [Docs] Improve installation guide * address comments * [Enhancement] Keep max score attention across blocks in FlashAttention for better numerical stablity (#1269) * Implement max score retention across blocks in FlashAttention for improved stability * fix manual pipeline parameters * Update examples/flash_attention/example_gqa_fwd_varlen.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * fix typo * more * fix a previous typo --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * [Bugfix] Fix multiple cg defination when using T.sync_grid (#1272) * [Minor] Remove from __future__ import annotations for python 3.8 (#1273) * [BugFix] Adding extra parameters into autotune hashkey (#1274) * [BugFix] Adding extra parameters into autotune hashkey * lint * None check * check serializable * Fix various issues under `int64_t` static and dynamic shape. (#1218) * Fix various issues under int64_t static and dynamic shape. * Resolve reviewed issues. * Add unit test. * fix --------- Co-authored-by: LeiWang1999 <[email protected]> * Bug fix for Gated Delta Net benchmark script (#1267) * fix argument order for fla chunk_gated_delta_rule_fwd_h * explicit import assert_similar from utils * rename utils module to avoid name clash * set store_final_state and save_new_value to True * fix --------- Co-authored-by: LeiWang1999 <[email protected]> * [Bugfix] Minor fix for some cases (#1278) * [Language] Add shape check in `T.view/reshape` (#1277) * [Language] Add shape check in T.view/reshape * address comments * [FFI] Use tvm ffi as the default execution backend (#1259) * [Refactor] Update FFI type handling and simplify argument management * Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity. * Updated function registration in `runtime.cc` to utilize canonical names for better consistency. * Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled. * Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection. * Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity. * [Update] Sync TVM submodule and enhance kernel source handling * Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes. * Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging. * Commented out the main execution call in test files to prevent unintended execution during testing. * Introduced `tilelang.disable_cache()` in various test files to streamline testing and avoid cache-related issues. * Refactored kernel source retrieval methods to improve clarity and consistency across different execution backends. * [Refactor] Clean up imports and improve code formatting * Removed unused import of `tilelang.testing` in `test_example_blocksparse_gemm.py` to streamline the code. * Reformatted several lines in `arg_binder.cc`, `make_packed_api.cc`, `tvm_ffi.py`, and `adapter.py` for improved readability and consistency. * Updated comments and spacing in `tvm_ffi.py` to enhance clarity without altering functionality. * Update execution backend options and improve resolution logic - Changed default execution backend from "cython" to "auto" in multiple locations to allow automatic selection based on the target. - Expanded the list of supported execution backends to include "torch" and "nvrtc" across various classes and functions. - Enhanced backend resolution logic in `KernelCache` and `AutoTuner` to ensure appropriate backend selection based on the target. - Updated documentation to reflect changes in execution backend options and their defaults. * lint fix * fix * Enhance argument handling in CUDA and HIP runtime modules - Updated `ExtractFuncInfo` in `rt_mod_cuda.cc` and `rt_mod_hip.cc` to map boolean argument types to int32, ensuring compatibility with device runtime. - Refactored `BindDLTensor` in `arg_binder.cc` to improve null handling and validation checks for DLTensor parameters, utilizing expression-level guards to prevent dereferencing null pointers. - Enhanced error checking for buffer shape, strides, and data fields, ensuring robust handling of optional inputs and maintaining consistency across various checks. * lint fix * lint fix * lint fix * lint fix * minor fix * fix * recover check * Refactor argument binding and validation in `arg_binder.cc` - Improved null handling and validation checks in `BindDLTensor`, ensuring safe dereferencing of pointers. - Enhanced consistency checks for buffer shape, strides, and data fields, utilizing expression-level guards. - Updated `MakePackedAPI` to maintain code clarity and consistency in argument handling. - Minor adjustments in test files to streamline kernel execution and improve readability. * lint fix * stride fix * minor fix * fix * lint fix * lint fix * Add CUDA stream access policy window helpers and integrate with L2 persistent cache management - Introduced functions to set and reset the CUDA stream access policy window, allowing for better control over L2 cache usage. - Updated runtime files to include new FFI packed functions for managing stream attributes. - Modified lower_hopper_intrin to incorporate prologue and epilogue statements for L2 cache setup and teardown. - Enhanced tests to verify the inclusion of new FFI calls in the generated kernel source. * check with symbolic * support null ptr * Update CMakeLists and lower.py for code generation and subproject status - Added `codegen_c_host.cc` to the list of source files in CMakeLists.txt for improved code generation support. - Updated the function call in `lower.py` to use `target.build.tilelang_c` for C target host code generation, enhancing compatibility. - Marked the TVM subproject as dirty to indicate local modifications. * lint fix * Update comments for clarity in quickstart.py * [Bugfix] Supply missing `T.print` for bool type (#1279) * fix for bool dtype * lint fix * fix * ci fix * [Fix] Fix memory leak bug (#1281) * add typing stub for tir.ir * remove idents * minor update * [Refactor] add numpy conversion for dtype * fix lint error * remove unused np.float_ in dtype conversion * fix type in np.int_ * fix typo * minor fix * remove debug files * fix memory leak bug * fix lint error * add comments * fix lint error * remove duplicated, because tilelang doesn't dependent deprecated * [Enhancement] Enhance CUDA compilation by integrating pass context configuration (#1283) - Updated the `tilelang_callback_cuda_compile` function to accept a `pass_config` parameter, allowing for more flexible compilation options. - Introduced handling for fast math and PTXAS options based on the provided pass configuration. - Modified the CUDA build process in `rt_mod_cuda.cc` to utilize the current pass context, improving the integration of compilation settings. - Refactored NVCC command construction to use a dedicated function for better clarity and maintainability. * Fix the bug in issue #1266 (#1284) Co-authored-by: cheeryBloosm <[email protected]> * [Language][UX] Nested loop checker in pre-lowering stage (#1288) * [Language][UX] Nested loop checker in pre-lowering stage * rename * comment * address comments * [Compatibility] Support CUDA 11.3 (#1290) * [Feat] Add support for using `T.Tensor(n * 2 + 1)` in function annotation (#1285) * [Feature] Add support for A: T.Tensor(n + 1) and A: T.Tensor(2*n) * issue fix * fix * fix * decreate nproc for debugging --------- Co-authored-by: Lei Wang <[email protected]> * [Feat] add support for passing reference in T.Var annotation (#1291) * [Enhancement] Shared Memory Size Can be Dynamic (#1294) * bugfix * lint fix * test * lint fix * increate procs * recover * [Fix] Remove unused let_bindings_ in CodeGenC to fix #1300 (#1305) * [Feat] add missing support of uint32x2 * [Feat] Add `T.Ref` annotation and tests * fix lint error * minor update for error message on twice decl * Remove unused let_bindings_ in CodeGenC to fix #1300 * [Bugfix] Fallback to the old AtomicAdd implementation for legacy architectures (#1306) * [Fix] Fix frame scope error in T.macro (#1308) * [Fix] Fix #1307 by adding macro inside function * fix lint error * add comments and fix lint error * Remove debug print from enter_frame method Removed debug print statement from enter_frame method. --------- Co-authored-by: Lei Wang <[email protected]> * [WIP] support more dtypes for tcgen05 (#1229) support ld with pack for fp32 dtype add dump add tempalte expand remove unused dtype and change to rebased apis * Improve memory access safety and `T.assume` handling (#1292) * Improve memory access safety and T.assume handling * Improve memory access safety and T.assume handling * bugfix * lint fix * bugfix * bugfix * refactor legalize safe memory access pass --------- Co-authored-by: Lei Wang <[email protected]> * [Bugfix] Fix autotune cache (#1315) * [Refactor] Backup Analyzer to get the appropriate arith informations (#1311) * [Refactor] Update Vectorization Functions to Accept Analyzer Parameter - Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization. - Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness. - Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities. * [Fix] Corrected PostOrderVisit call in loop_vectorize.cc - Updated the PostOrderVisit function to analyze the body of the loop node instead of the node itself, ensuring proper handling of nested loops during vectorization analysis. * fix * lint fix * fix * Revert "[WIP] support more dtypes for tcgen05 (#1229)" (#1323) This reverts commit 0d101c110f74ebf2ef8c11a5ece9dfb314b48baa. Co-authored-by: Zhiwen Mo <[email protected]> * [CI]: Bump actions/checkout from 5 to 6 (#1319) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [CI]: Bump pypa/cibuildwheel from 3.2 to 3.3 (#1318) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [Installation] Fix building using customized TVM path (#1326) * [Release] Allow developer with write permission to trigger wheel release (#1322) * [Feat] Support warp reduce (#1316) * [Feat] Support warp reduce * lint * add test * lint * [Enhancement] Support more dtype in `T.print` (#1329) * [Enhancement] Support more dtype in `T.print` * upd * upd * [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape (#1321) * [BugFix] Use BufferRegion in tl.cumsum to infer buffer shape * remove debug lines * remove rubbish * Fix decorator syntax for atomic_different_memory_orders_program --------- Co-authored-by: Lei Wang <[email protected]> * [Fix] fix wrong uint narrowing bug in tvm in #1310 (#1320) * [Refactor] Disable strided buffer load inside tvm (#1301) (#1332) * [Refactor] Moving `NormalizeToBufferRegion` and `MakeAccessPtrFromRegion` to utils (#1333) * Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse. * lint fix * [Fix] Fix bug copying from or to local buffer (#1304) (#1324) * [Fix] fix copy from or to local buffer (#1304) * fix lint error * minor fix testing script * [Language][UX] Semantic check for parallel fragment access (#1338) * Add unit tests for T.assume (#1341) * Add test for T.assume * Add unit test for T.assume * Add unit test for T.assume * Add unit tests for T.assume * Remove debug print for kernel source Remove print statement for kernel source in tests. * Update test_tilelang_language_assume.py --------- Co-authored-by: Lei Wang <[email protected]> * [Feat] Extend LegalizeNegativeIndex to support buffer store stmts (#1339) This commit enhances the LegalizeNegativeIndex transformation pass to handle both buffer load and store operations with negative indices and adds some test cases. * [Refactor] Phaseout vmap for Tile Operators (#1334) * Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse. * lint fix * Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations. * fix * Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions. * fix * fix * test fix * lint fix * Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management. * fix * lint fix * fix * fix * test fix * lint fix * lint fix * minor fix * fix --------- Co-authored-by: Zhiwen Mo <[email protected]> * [Enhancement] add more dtype and fix mma.ws for fp16 for tcgen05 (#1327) * feat: add fp8 variants; add placeholder for fp6/fp4 in meta support ld with pack for fp32 dtype add dump add tempalte expand remove unused dtype and change to rebased apis * fix: when atom-m!=128, enable_ws * fix: typo in tcgen05 meta; dispatch in gemm sm100 * [Refactor] Enhance CopyNode's IterVar Creation and Range Handling (#1346) * [Refactor] Enhance CopyNode's IterVar Creation and Range Handling This commit refines the `MakeIterVars` method in `CopyNode` to select base ranges based on memory scope levels, ensuring that the chosen ranges are not smaller than the original source ranges. Additionally, it updates the Python `copy` function to clarify range handling, including broadcasting logic and extent alignment. These changes improve the robustness and clarity of the copy operation's implementation. * test fix * [Fix] Fix missing `not` rewrite in frontend (#1348) * [Enhancement] Add support for k_pack in gemm_mfma (#1344) * add support for k_pack * support benchmark on ROCm * fix format * Add sparse fine-tuning kernel for deepseek sparse attention to example (#1296) * [EXAMPLE] add example for dsa sparse finetuning * [Refactor] * [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder (#1352) * [Refactor] Improve assertion handling in CodeGenCHost and ArgBinder This commit refines the assertion message generation in CodeGenCHost by optimizing the handling of equality checks and reducing buffer size for error messages. Additionally, it enhances the ArgBinder by introducing a nullable guard mechanism for assertions, allowing for more precise error handling when binding arguments. The changes improve the clarity and efficiency of assertion handling across the codebase. * [Enhancement] Update matmul kernel and optimize argument binding This commit enhances the matmul kernel by introducing additional tensor parameters and refining the pipeline stages for improved performance. It also updates the argument binding mechanism to include a flag indicating whether buffers are used, enhancing the efficiency of buffer management. Furthermore, the optimization phase in the engine is improved by adding a simplification step, ensuring better performance and clarity in the generated code. * lint fix * [Enhancement] Add tensor checks documentation and improve argument binding assertions This commit introduces a new documentation page for host-side tensor checks, detailing the automatic validations performed by TileLang on kernel arguments. It enhances the ArgBinder by adding assertions for non-null pointers when arguments are used, improving error handling. Additionally, the optimization phase in the engine is updated to include a simplification step, ensuring better performance and clarity in the generated code. * [Enhancement] Update .gitignore and refine matmul kernel for improved performance This commit adds host checks logs to the .gitignore file to prevent unnecessary log files from being tracked. Additionally, it refines the matmul kernel by adjusting pipeline stages, updating tensor parameters, and enhancing argument handling for better performance. The changes also include improved error messages in the argument binding process, ensuring clearer diagnostics for users. * lint fix * lint fix * [Refactor] Simplify tensor_null_test function and remove ptr_null_test This commit refactors the tensor_null_test function by adding a with_bias parameter and removing the ptr_null_test function, which was previously unused. The run_test function is updated to reflect these changes, streamlining the testing process for tensor operations. * lint fix * fix * [Refactor] Simplify index sign state handling in LegalizeNegativeIndex (#1354) This commit refines the logic for determining the sign state of indices in the LegalizeNegativeIndex transformation. It prioritizes vector patterns, specifically Ramp and Broadcast nodes, to avoid compile-time lane queries. The handling of scalar indices is also streamlined, ensuring clearer diagnostics when non-negativity cannot be proven. These changes enhance the robustness and clarity of index handling in the transformation pass. * [Enhancement] Improve error handling and assertion messages across runtime and argument binding (#1356) This commit enhances the error handling mechanisms in the runtime by introducing CPU-safe runtime helpers and refining assertion messages in the CodeGenCHost and ArgBinder. It includes structured packed error messages for various conditions, improving clarity in diagnostics. Additionally, the CMake configuration is updated to always include necessary runtime helpers, ensuring consistent error reporting. The changes aim to provide clearer feedback during runtime errors and improve the overall robustness of the argument binding process. * [Bugfix] Disable floordiv optimization due to integer overflow risk (#1355) * disable overflow-prone floordiv optimization in lower_intrin.cc * disable overflow-prone floordiv optimization in lower_intrin.cc * [Bugfix] Fix the jit_kernel issue (#1357) * [Bugfix] Fix the jit_kernel issue * Update README.md --------- Co-authored-by: Lei Wang <[email protected]> * [Refactor] Update Fragment Indexing in ParallelOpNode's InferLayout Method (#1359) This commit refines the Fragment creation process in the InferLayout method of ParallelOpNode. It removes the unnecessary forward_index array and utilizes default fragment indexing for consistency with other operations. Additionally, it binds the thread range to enhance comparability across different operations. * [Analysis] Enhance NestedLoopChecker with tile op cases (#1358) * [Analysis] Enhance NestedLoopChecker with tile op cases * fix tileop issue * [Language] support `T.gemm_sp_v2` on sm80 and sm89 (#1056) * [misc] add a cpp side wrapper for gemm_sp_py * [misc] typing * [IR] bind GemmSPWarpPolicy * [chore] add wrapper code * [IR] fix GemmSPWarpPolicy * [codegen] apply ptxas instructions * [intrinsic] add typical (unused) mma layout * [template] add uint16 debug func * [intrinsic] add b matrix layout * [gemm_sp] enable fp16/bf16 on sm8x * [layout] refactor fp16/bf16 layout * [gemm_sp] enable int8 * [chore] update test case dtype * [gemm_sp] enable fp32 * [layout] refactor layouts * [intrinsic] enable ldmatrix for mat A * [layout] enable ldsm for matrix b * [layout] add ldmatrix for fp32 and fp8 * [chore] refine * [chore] refactor * [chore] add fp8 efactor * [chore] refactor * [chore] add remove negative zero util * [example] add a custom compress kernel * [chore] minor update * [test] refactor gemm_sp test * [refactor] make metadata layout func * [example] add option for using cutlass layout * [doc] add a gemm_sp doc * [doc] minor polish * [chore] remove unused * [bugfix] fix non replicate b case * [test] refactor * [chore] add a check * [bugfix] fix util bug * [wip] init a new test case for v2 * [chore] minor refactor * [chore] minor update * [bugfix] enable 16bit rs * [language] enable rs * [language] enable gemm_sp_sr * [language] enable gemm_sp_rr * [test] enable more tests * [tvm] update ffi binding * [chore] remove print * [chore] fix benchmark script * [lint] precommit lint * [chore] apply feedback * [test] use arch 8.0 * [chore] rollback ::ordered_metadata for backward compatibility * [bugfix] fix captialized * [example] keep gemm_sp on hopper * [test] fix no fp8 normal kernel * [test] reduce matmul size to satisfy accum error * [test] use cal_diff for assertion * [bugfix] expand float8 type * [lib] add make_int4 for short type * [language] add transpose E * [bugfix] fix wrong var * [format] format * [chore] refactor binding * [chore] fix wrong passing var * [Bugfix] Update TIR registration for GemmSPPy to use tile operation (#1361) * [Enhancement] Implement dynamic unroll factor in CUDA code generation (#1360) * [Enhancement] Implement dynamic unroll factor in CUDA code generation This commit introduces support for specifying a dynamic unroll factor in the CUDA code generation. The `unroll_factor` map is added to store unroll factors for loop variables, allowing for more flexible and optimized loop unrolling. Additionally, the `unroll` function is integrated into the loop language, enabling users to define unroll factors directly in their code. This enhancement improves performance by allowing tailored unrolling strategies based on specific loop characteristics. * lint fix * [Bugfix] Correct initialization of non-zero counters in custom compress kernel and update TIR registration for gemm_sp_py to use the correct tile operation * [CI] [pre-commit.ci] autoupdate (#1362) updates: - [github.com/pre-commit/mirrors-clang-format: v21.1.2 → v21.1.6](https://github.com/pre-commit/mirrors-clang-format/compare/v21.1.2...v21.1.6) - [github.com/astral-sh/ruff-pre-commit: v0.14.3 → v0.14.7](https://github.com/astral-sh/ruff-pre-commit/compare/v0.14.3...v0.14.7) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [Bugfix] Remove debug print in PyStmtFunctionVisitor (#1363) * [Debug] Always include line info in NVCC command for improved profiling and mapping (#1364) * [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py (#1365) * [Enhancement] Add DISABLE_CACHE environment variables (#1368) * [Refactor]: Remove useless include in atomicadd_vectorize.h (#1371) * [Refactor] Generalize fp8 process (#1372) * [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py * [Enhancement] Extend support for float8 data types in GEMM operations - Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`. - Refactored condition checks in `checkWgmma` methods to simplify float8 type handling. - Adjusted test cases to ensure compatibility with the new float8 types in tile language examples. * lint fix * [Layout] Enhance Free Layout Inference (#1375) * [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py * [Enhancement] Extend support for float8 data types in GEMM operations - Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`. - Refactored condition checks in `checkWgmma` methods to simplify float8 type handling. - Adjusted test cases to ensure compatibility with the new float8 types in tile language examples. * lint fix * [Enhancement] Add injective layout detection and exception handling - Introduced `DetectInjective` method in `FragmentNode` to check for injective layouts. - Added `LoopLayoutInjectiveException` to handle errors related to non-injective layouts. - Updated `InferLayout` methods in `ParallelOpNode` to utilize injective checks and log relevant information. - Refactored layout inference queue management to use `std::deque` for improved performance and added prioritization logic for buffer layouts. * remove debug print * remove debug print * remove debug print * minor layout fix * fix for T.view * [Enhancement] Improve injective layout detection in FragmentNode - Updated the `DetectInjective` method to handle symbolic dimensions more effectively by introducing a mechanism to collect symbolic shapes and adjust the detection level accordingly. - Added logging for cases where the layout detection falls back to NoCheck due to symbolic dimensions. - Minor update to the test file to include the tilelang testing module. * [Refactor] Simplify layout inference for bulk copy operations - Removed unnecessary conditions for bulk load/store operations in the layout inference logic. - Streamlined the handling of layout application for bulk copy instances to enhance clarity and maintainability. * remove debug print * [Enhancement] Introduce layout-related exceptions and improve error handling - Added `LayoutConflictException` and `LoopLayoutInjectiveException` classes for better exception management in layout operations. - Updated `InferLayout` method in `ParallelOpNode` to throw `LoopLayoutInjectiveException` with detailed error information when injective layout checks fail. - Removed redundant exception class definitions from `parallel.h` to streamline code organization. * [Enhancement] Introduce buffer var lca analysis for pass plan buffer allocations (#1376) * Update submodule TVM to latest commit and add PlanAndUpdateBufferAllocationLocation function to transform module - Updated the TVM submodule to commit 3a32b763. - Added a new function `PlanAndUpdateBufferAllocationLocation` in the transform module to facilitate buffer allocation planning within PrimFuncs. * Refactor buffer allocation code for improved readability and consistency - Updated formatting and spacing in `plan_update_buffer_allocation_location.cc` for better code clarity. - Standardized the use of pointer and reference syntax across various class methods. - Enhanced comments for better understanding of buffer allocation logic. - Removed unnecessary lines and improved overall code structure. * Refactor buffer allocation checks for improved clarity - Replaced size checks with empty checks for `ffi::Array<Buffer>` in `plan_update_buffer_allocation_location.cc` to enhance code readability. - Updated conditions in multiple methods to use `empty()` instead of comparing size to zero, streamlining the logic. * [Tool] Provide layout visualization tool (#1353) * Provide layout visualization tool Adds a layout visualization tool to TileLang, which helps users understand and debug the layout transformations applied during compilation. This tool visualizes the memory layout of tensors at different stages of the compilation process, allowing developers to identify potential inefficiencies and optimize their code for better performance. The visualization can be enabled via a pass config option. * format * add layout visual example * Adds vis extra with matplotlib dependency * rafactor pass config name * fix lint * Enables configurable layout visualization formats Allows users to specify the output formats (png, pdf, svg) for layout visualization through a pass config option. This change provides more flexibility in how layout visualizations are generated, allowing users to choose the formats that best suit their needs. It also fixes a bug where layout visualization was not correctly disabled when the config option was set to "false". * Adds visual layout inference tool docs * fix lint * fix lint * Rafactor configurable layout visualization formats * fix lint * fix typo * add some comments * fix lints * add some warnings for user * Moves layout visualization * Refactors layout visualization pass configuration Updates the layout visualization pass configuration to use boolean flag for enabling and a string for specifying formats. * Enables multiple layout visualization formats * Updates layout visualization docs * Moves layout visualization to analysis * [Release] Relax constraint of tvm-ffi to compatible version (#1373) Co-authored-by: LeiWang1999 <[email protected]> * [Language] Tilelang LazyJIT Experimental Version (#1337) * initial step * modify builder * scratch version of new frontend * write some tests * add many tests * add typing stub for tir.ir * remove idents * minor update * minor update * First version of jitv2 (renamed to LazyJIT) * fix pre-commit error * minor fix * fix lint error * fix lint error * Fix conditional check for PrimFunc instance --------- Co-authored-by: Lei Wang <[email protected]> * [Builder] Enhance variable name binding and scope management (#1378) - Improved handling of TVM Var/Buffer names to prevent out-of-scope errors when reusing Python names across different for-frames. - Added assertions to ensure variables are defined within the correct control flow frame, enhancing error checking and code reliability. * [Bugfix] make cuda driver api compat with cuda12/13, along with tests (#1379) * [Fix] typo in cuda attr (#1380) * [Bugfix] make cuda driver api compat with cuda12/13, along with tests * fix typo in cudaDevAttr * [Language V2] Minor fix for complex annotations (#1381) * [Release] Bump Version into 0.1.7 (#1377) * Update VERSION to 0.1.7 * Update Python version in distribution scripts to support CPython 3.9 and log output * [Typing] Enhance compatibility for advanced typing features in Python (#1382) - Updated `allocate.py` and `annot.py` to improve compatibility with Python 3.9 and later by conditionally importing advanced typing features such as `TypeVarTuple`, `Unpack`, and `ParamSpec`. - Added fallback imports from `typing_extensions` for environments using earlier Python versions. - Improved handling of generic alias detection to ensure consistent behavior across different Python versions. * [Bugfix][Build] Update CMake configuration to remove project root injection for sys.path (#1385) * [Build] Update CMake configuration for tilelang_cython_wrapper installation - Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib. - Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules. - Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects. * [Build] Standardize output directories for tilelang libraries - Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds. - This change enhances organization and ensures that all build artifacts are located in a unified directory structure. * [BugFix] Fix split kernel layout bug of GQA decode (#1386) * [BugFix] Fix split kernel layout bug of GQA decode * [BugFix] Avoid local with Parallel; use robust fragment instead * [Enhancement] Add debug output methods for Layout and Fragment classes (#1392) * [Doc] Update logging docs (#1395) * [Enhancement] Refactor inflight computing to support dynamic pipeline extents (#1399) * [Build] Update CMake configuration for tilelang_cython_wrapper installation - Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib. - Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules. - Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects. * [Build] Standardize output directories for tilelang libraries - Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds. - This change enhances organization and ensures that all build artifacts are located in a unified directory structure. * [Refactor] Update TVM subproject and enhance pipeline loop handling - Updated the TVM subproject to commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0. - Added new fields to `PipelineAnnotation` and `RewrittenBlockInfo` structures to track original statement indices and improve async state management. - Refactored `EmitImpl` and `PopulateWaitCounts` methods to enhance clarity and functionality, including better handling of commit groups and wait counts. - Simplified access index calculations and strengthened analyzer constraints for loop bounds. * [Cleanup] Remove license block and unused includes from inject_pipeline.cc - Eliminated the Apache license block from the top of the file to streamline the code. - Removed unused include directives for memory and stringstream to enhance code clarity and reduce unnecessary dependencies. * [Refactor] Enhance transformation pipeline and test execution - Added an additional Simplify transformation in the InjectSoftwarePipeline to improve optimization. - Updated the test file to call `test_trival_pipeline()` directly, commenting out the previous main execution for better test isolation. * [AMD] Fix 3 bugs when build docker on amd mi3x gpu (#1401) * [Typo] Fix tilelang link in README.md (#1402) * [Dependency] Update apache-tvm-ffi version to >=0.1.2 (#1400) * [Dependency] Update apache-tvm-ffi version to >=0.1.2 in project files * [Dependency] Update subproject commit for TVM to latest version afc07935 * [Enhancement] Add support for optional step parameter in loop constructs - Updated loop creation functions to accept an optional step parameter, enhancing flexibility in loop definitions. - Modified ForFrame implementations to utilize the new step parameter across various loop types including serial, parallel, and pipelined loops. - Adjusted related vectorization transformations to accommodate the step parameter, ensuring consistent behavior in loop vectorization processes. * lint fix * [AMD] Enable FA2 fwd on AMD MI300X (#1406) * enable FA2 on AMD MI300X * make lint happy * [TypoFix] fix typo for SM120 (#1408) * [Doc] Minor documentation update (#1410) * [Dependency] Add torch-c-dlpack-ext to project requirements (#1403) * [Dependency] Add torch-c-dlpack-ext to project requirements * Added torch-c-dlpack-ext to both pyproject.toml and requirements.txt to provide prebuilt torch extensions, which may prevent JIT compilation on first import of TVM FFI. * [Build] Update manylinux images in project configuration * Changed the manylinux image for x86_64 from "manylinux2014" to "manylinux_2_28" in both pyproject.toml and the Dockerfile to align with updated standards for compatibility and performance. * [Build] Update CUDA repository configuration in pyproject.toml * Changed the package manager command from `yum-config-manager` to `dnf config-manager` for adding the CUDA repository, ensuring compatibility with newer systems. * fix * [Build] Update CUDA repository to RHEL 8 * Changed the CUDA repository configuration in both pyproject.toml and the manylinux Dockerfile from RHEL 7 to RHEL 8, ensuring compatibility with newer systems. * test: run out of space * use cu130 to reduce size * upd * upd comment * upd --------- Co-authored-by: Your Name <[email protected]> * [Dependency] Update TVM subproject to latest commit 2b1ead1a (#1412) * [Enhancement] Introduce `T.__ldg` (#1414) * [Enhancement] Add __ldg intrinsic for CUDA read-only cache loads * Introduced the __ldg intrinsic to enable explicit read-only cached loads from global memory in CUDA. * Updated the corresponding documentation and added support in both CUDA and HIP code generation. * Enhanced the Python interface for __ldg to accept BufferLoad and Buffer types, improving usability. * [Enhancement] Update formatting and linting rules in pyproject.toml; minor test adjustment * Added new formatting rules in pyproject.toml to enforce consistent code style, including hanging indents and argument splitting. * Updated test_tilelang_language_intrinsics_codegen.py to improve readability by adding a blank line before the main execution block. * Refactored error messages in builtin.py for better clarity and consistency, ensuring proper formatting in function definitions and raising ValueErrors. * lint fix * [Enhancement] Improve vectorization invariant check (#1398) * Improve loop vectorize * Improve loop vectorize * Improve loop vectorize * Improve loop vectorize * Improve loop vectorize * Add some vectorize tests and comments * [Lint] Phaseout Yapf format and embrace ruff format (#1417) * [Atomic] Use ptr for atomicAdd dst instead of reference (#1425) * [Enhancement] Update AtomicAdd function signature to accept pointer to destination * Modified AtomicAdd in CUDA to take a pointer instead of a reference for the destination argument. * Updated related code in atomicadd_vectorize.cc to ensure compatibility with the new signature. * Adjusted Python interface in atomic.py to pass the destination by pointer, aligning with device function requirements. * [Enhancement] Refactor AtomicAddRet function signature to accept pointer * Updated AtomicAddRet in both CUDA and HIP to take a pointer instead of a reference for the address argument, improving consistency with the AtomicAdd function. * Adjusted the implementation to ensure proper reinterpretation of the address type for atomic operations. * lint fix * [Enhancement] Refactor AtomicAddNode::MakeSIMTLoop to use destination pointer * Updated the MakeSIMTLoop function to build a pointer to the destination element using tvm_access_ptr instead of loading the destination value directly. * Simplified the handling of source and destination predicates, improving clarity and maintainability of the code. * Ensured compatibility with the new pointer-based approach for atomic operations. * lint fix * test fix * lint fix * [CUDA] Add read-only parameter annotation for CUDA codegen (#1416) * [Enhancement] Add read-only parameter annotation for CUDA codegen * Introduced the `AnnotateReadOnlyParams` transformation to annotate read-only handle parameters in PrimFuncs, enabling the generation of `const` qualifiers in CUDA codegen. * Updated `PrintFunctionSignature` and `AddFunction` methods to utilize the new attribute `tl.readonly_param_indices`, enhancing performance by allowing read-only cache loads. * Modified the optimization pipeline to include the new annotation step, improving the overall efficiency of the code generation process. * lint fix * [Dependency] Update apache-tvm-ffi version to >=0.1.3 * Updated the version of apache-tvm-ffi in pyproject.toml, requirements.txt, and requirements-dev.txt to ensure compatibility with the latest features and fixes. * Made adjustments in CUDA and HIP template files to use `const` qualifiers for global pointer parameters, enhancing code safety and clarity. * lint fix * [Enhancement] Refactor ReadWriteMarker for improved parameter handling * Updated the ReadWriteMarker class to accept a set of parameter or data variables, enhancing its ability to track written variables. * Introduced a new method, ResolveDataVarFromPtrArg, to resolve underlying buffer data from pointer-like arguments, improving accuracy in identifying written variables. * Modified the MarkReadOnlyParams function to gather handle parameters and their corresponding buffer data variables, streamlining the process of determining read-only parameters. * Enhanced the logic for identifying written variables to account for aliased data variables, ensuring comprehensive tracking of modifications. * lint fix * Update tma_load function to use const qualifier for global memory pointer * Changed the parameter type of gmem_ptr in the tma_load function from void* to void const* to enhance type safety and clarity in memory operations. * This modification ensures that the function correctly handles read-only global memory pointers, aligning with best practices in CUDA programming. * Remove commented-out code and reorder transformations in OptimizeForTarget function for clarity * Refactor buffer marking logic in annotate_read_only_params.cc to improve accuracy in identifying written variables. Update OptimizeForTarget function to reorder transformations for better clarity. * [Refactor] Phase out the primitives folder since its design has been merged into tileop (#1429) * Phase out primitives * revert changes * Refactor GemmWarpPolicy method signature for clarity Updated the `from_warp_partition` method in the `GemmWarpPolicy` class to return the type `GemmWarpPolicy` instead of a string, enhancing type safety and clarity in the codebase. Removed an unnecessary blank line for improved readability. * fix * [CI]: Bump actions/upload-artifact from 5 to 6 (#1431) Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 5 to 6. - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [CI]: Bump actions/download-artifact from 6 to 7 (#1432) Bumps [actions/download-artifact](https://github.com/actions/download-artifact) from 6 to 7. - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v6...v7) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * [Bugfix] Convey `compile_flags` to ffi compilation path with pass_configs (#1434) * [Enhancement] Add device compile flags support in pass configuration * Introduced `kDeviceCompileFlags` option in the pass configuration to allow additional device compiler flags for CUDA compilation. * Updated the `tilelang_callback_cuda_compile` function to merge extra flags from the pass configuration, enhancing flexibility in compiler options. * Modified the `JITKernel` class to handle device compile flags appropriately, ensuring they are included during compilation. * Documented the new pass configuration key for clarity on usage and expected input formats. * lint fix * [Refactor] Simplify compile_flags handling in JIT functions * Removed redundant string check for compile_flags in the compile, jit, and lazy_jit functions, ensuring compile_flags is consistently treated as a list. * Updated the JITKernel class to handle compile_flags as a list when a string is provided, enhancing code clarity and maintainability. * lint fix * fix * [Enhancement] Improve buffer usage tracking in MakePackedAPI (#1435) * Added detailed logging for data and shape variable parameters during buffer usage detection in the MakePackedAPI function. * Refactored the UsedBufferDetector to differentiate between used parameters by data and shape variables, enhancing clarity in buffer management. * Updated logic to ensure minimal carrier buffers are selected for shape symbols, improving the efficiency of parameter handling. * [Enhancement] Improve InjectAssumes logic and make assumes work after SplitHostDevice (#1405) * [Refactor] Refactor InjectAssumes logic and make assumes work after SplitHostDevice * address comments * fix * fix submodule * fix * fix 3rdparty * [Enhancement] Include PrimFunc name in memory cache logs for better debugging (#1437) * Added the `get_prim_func_name` utility to extract human-readable function names from TVM PrimFuncs. * Updated memory cache logging in `AutoTuner` and `KernelCache` classes to include the kernel name, improving clarity during cache hits. * Enhanced debug logging to provide more informative messages when checking disk cache for kernels. * [CI] Update lint dependencies and fix lint on trunk (#1433) * [CI] Update pre-commit hooks * [Lint] Pass correct `exclude-header-filter` to `clang-tidy` * [Lint] Download latest `run-clang-tidy` script * [CI] Show compile commands * [CI] Add output grouping to GHA * [Lint] Re-order pre-commit hooks * [Enhancement] Refactor vectorization checks in loop_vectorize (#1440) * Introduced a new function, IsExprInvariantInVectorBoundary, to encapsulate the logic for checking if an expression is invariant within vector boundaries, improving code clarity and reusability. * Updated the existing vectorization logic to utilize this new function, streamlining the process of determining vectorization feasibility based on boundary conditions. * Enhanced comments for better understanding of the vectorization criteria and mathematical rationale behind the checks. * Enhance vectorized conversion support (#1438) * [Feature] Support region as input of T.cumsum (#1426) * [Feature] Support region as input of T.cumsum - Extend T.cumsum to accept BufferRegion and BufferLoad inputs in addition to Buffer - This enables operations on buffer slices/regions like: T.cumsum(InputG_fragment[i * chunk_size:(i + 1) * chunk_size], dim=0) - Update cumsum_fragment to handle region inputs properly - Add comprehensive tests for 1D and 2D region inputs including normal and reverse modes Fixes #879 * Fix formatting and add docstring for cumsum_fragment - Add comprehensive docstring for cumsum_fragment function - Format code according to ruff style guidelines * Fix CodeRabbit review issues - Fix negative dimension bounds check (dim < -len(shape) instead of dim <= -len(shape)) - Add src/dst shape compatibility validation for out-of-place cumsum - Update copy() type annotation to accept BufferRegion as dst parameter - Fix test in-place mutation issues by using out-of-place cumsum operations - Add non-divisible size test cases for tail region coverage * Fix out-of-bounds access in region tests - Add bounds clamping using T.min() for chunk_end calculations - Prevents accessing beyond tensor bounds for non-divisible sizes - Matches reference implementation behavior - Fixes both 1D and 2D region test cases * Fix region test: use simple slice expressions instead of T.min() - Remove T.min() which cannot be used directly in slice indices - Use chunk_start + chunk_size form instead - Rely on system's automatic bounds checking for non-divisible sizes - Update comments to reflect this approach * Fix cumsum region: use region extents in lowering and update tests for shared memory * Simplify fragment scope check using is_fragment() --------- Co-authored-by: LeiWang1999 <[email protected]> * [Fix] Fix analyzer bind conflicting (#1446) * [Refactor] Reduce direct dependency on PyTorch due to its limited type support (#1444) * [Enhancement] Update KernelParam to use tvm.DataType directly and add torch_dtype conversion method - Changed dtype in KernelParam from torch.dtype to tvm.DataType to support a wider range of data types and prevent information loss during conversions. - Added a new method, torch_dtype, to convert tvm.DataType back to torch.dtype for tensor creation. - Updated various adapters to utilize the new torch_dtype method for parameter type conversion during initialization. * [Enhancement] Refactor CUDA type handling and add support for FP4 and FP8 types - Renamed functions for clarity: GetFP8Type, GetFP6Type, and GetFP4Type are now GetTileLangFP8Type, GetTileLangFP6Type, and GetTileLangFP4Type respectively. - Enhanced FP4 type handling to support additional lane sizes (2, 4, 8, 16, 32, 64). - Updated CUDA code generation to include new FP8 and FP4 types, ensuring proper type handling in PrintType and related functions. - Introduced new structures for FP8 types in cuda_fp8.h to facilitate better memory management and type packing. - Added methods in KernelParam and tensor utilities to recognize and handle float4 types, improving compatibility with PyTorch. - Enhanced logging for debugging purposes in various CUDA functions to track type handling and memory operations more effectively. * lint fix * Remove unnecessary logging statements from CUDA code generation and delete obsolete matrix multiplication test file. * [Enhancement] Add support for FP4 and FP8 types in CUDA code generation - Enhanced PrintVecElemLoad and PrintVecElemStore functions to handle new FP4 types. - Updated arg_binder to allow float4 to match int8 at runtime, improving compatibility with PyTorch. - Modified loop_vectorize to account for buffer dtype lanes in vectorization calculations. - Refactored tensor type mapping to support new float4 and float8 types, ensuring correct type handling in tensor operations. - Added tests for FP4 and FP8 copy operations to validate functionality and integration with existing workflows. --------- Co-authored-by: Zhiwen Mo <[email protected]> * [Refactor] Use `pytest.mark.parameterize` to speedup parallel testing (#1447) * Refactor GEMM tests to use parameterized pytest fixtures - Converted multiple test cases for GEMM operations in `test_tilelang_tilelibrary_gemm_sp.py` to use `pytest.mark.parametrize` for better maintainability and readability. - Similar refactoring applied to `test_tilelang_tilelibrary_gemm_sp_v2.py`, consolidating test cases for `run_gemm_ss`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` into parameterized tests. - This change reduces code duplication and enhances the clarity of test configurations. * Update testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * [Docs] Improve installation instructions for developers (#1450) * [Feat] Integrate Z3 in TVM Arith Analyzer (#1367) * [Bugfix] Improve autotune from elementwise_add function in examples (#1445) * Remove JIT decorator from elementwise_add function in examples * fix kernel compilation without autotune * Refactor main function to accept parameters and update tests for autotune option * Refactor autotune test function for morden style * [Language] Introduce `T.annotate_restrict_buffers` (#1428) * [Enhancement] Introduce non-restrict parameter support in code generation - Added a new PrimFunc-level attribute `tl.non_restrict_params` to specify handle Vars that should not be marked with the restrict qualifier during code generation. - Updated `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to handle non-restrict parameters, ensuring proper treatment of overlapping buffer aliases. - Implemented a new annotation function `annotate_restrict_buffers` to facilitate the marking of buffer parameters as non-restrict. - Enhanced the `SplitHostDevice` transformation to propagate non-restrict parameters from host to device functions. - Added a new transform function `HoistNonRestrictParams` to manage non-restrict parameters effectively. * [Enhancement] Improve HoistNonRestrictParams transformation - Updated the HoistNonRestrictParams function to recursively collect all `tl.non_restrict_params` annotations from nested blocks, enhancing flexibility in annotation placement. - Introduced a new NonRestrictCollector class to manage the collection and deduplication of non-restrict parameters. - Modified the SplitHostDevice transformation to remove the non-restrict attribute from the host-side PrimFunc after propagation to device kernels. - Adjusted the LowerAndLegalize function to directly apply the HoistNonRestrictParams transformation without exception handling, streamlining the process. * [Refactor] Simplify non-restrict parameter handling in code generation - Removed unnecessary normalization logic and associated data structures from `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP`. - Streamlined the handling of non-restrict parameters by directly inserting them into the `non_restrict` set, improving code clarity and maintainability. - Updated conditional checks to eliminate redundant checks against normalized names, enhancing performance and readability. * [Dependency] Update TVM subproject to latest commit 68aa8461 - Updated the TVM subproject to the latest commit, ensuring compatibility with recent changes and improvements. - Refactored non-restrict parameter handling in `CodeGenTileLangCPP`, `CodeGenTileLangCUDA`, and `CodeGenTileLangHIP` to enhance code clarity and maintainability. - Adjusted the `SplitHostDevice` transformation to streamline the propagation of non-restrict parameters. * fix * [Analyzer] Require loop extent > 0 when entering loop (#1451) * Updat ROCm CI to Nightly-ROCm-7.1 (#1449) * [Enhancement] Update examples and tests for improved type handling functionality (#1448) * [Enhancement] Update examples and tests for improved type handling and functionality - Enhanced various example scripts to support new data types and improve compatibility with PyTorch. - Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling. - Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management. - Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling. * [Refactor] Update accumulation data type to float32 across examples - Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability. - This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board. * [Refactor] Standardize data type usage across benchmark scripts - Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling. - Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard. - Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules. * [Refactor] Standardize data type usage in templates and scripts - Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity. - Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates. - This change aims to streamline type handling and improve compatibility with existing workflows. * [Refactor] Standardize data type usage in examples and benchmarks - Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability. - Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard. - Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples. * [Refactor] Import dtypes from language.v2 module - Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase. - This change aims to streamline data type management and improve overall code clarity. * fix * [Refactor] Standardize data type usage across scripts - Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity. - Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability. - This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase. * [Refactor] Update data type handling for consistency and clarity - Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency. - Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase. - Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling. - This refactor aims to streamline data type management and improve overall code clarity and maintainability. * [Enhancement] Improve data type handling and error messaging - Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation. - Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs. - Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience. * [Fix] Correct boolean flag in GEMM SP test case - Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case. - This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation. * [Refactor] Standardize data type usage across scripts - Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability. - Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase. - This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability. * [Refactor] Standardize data type usage in various modules - Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability. - Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase. - This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations. * [Refactor] Update argument parsing for data types in benchmarks - Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float. - This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability. * [Refactor] Update data type handling in benchmark and example scripts - Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency. - Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase. - This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples. * [Refactor] Fix data type conversion in multiple scripts - Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts. - This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations. * [Refactor] Update float8 data type usage across multiple scripts - Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling. - This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations. * [Refactor] Enhance float8 data type handling in CUDA code generation - Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic. - Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase. - Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation. - This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy. * [Refactor] Streamline float8 data type handling in CUDA and related modules - Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks. - Updated layout inference for float8 types to improve clarity and maintainability across the implementation. - This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy. * [Refactor] Remove unnecessary cache disabling in float8 example script - Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code. - This change enhances clarity and maintainability of the example script without affecting its functionality. * [Refactor] Update data type usage in debug print tests - Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type. - This update…
Summary by CodeRabbit
Release Notes
New Features
Documentation
✏️ Tip: You can customize this high-level summary in your review settings.