-
Notifications
You must be signed in to change notification settings - Fork 425
[Refactor] Phase out the primitives folder since its design has been merged into tileop #1429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
WalkthroughMoved and re-exported GemmWarpPolicy to Changes
Sequence Diagram(s)(omitted) Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches✅ Passed checks (3 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
Updated the `from_warp_partition` method in the `GemmWarpPolicy` class to return the type `GemmWarpPolicy` instead of a string, enhancing type safety and clarity in the codebase. Removed an unnecessary blank line for improved readability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/tileop/base.py (1)
161-186: Fix misleading docstring examples (and optionally drop the redundant forward-ref quotes).
The examples callfrom_block_row_cols, but the method isfrom_warp_partition, so the docstring is currently wrong.@classmethod - def from_warp_partition(cls, m_warp: int, n_warp: int) -> "GemmWarpPolicy": + def from_warp_partition(cls, m_warp: int, n_warp: int) -> GemmWarpPolicy: @@ - >>> GemmWarpPolicy.from_block_row_cols(4, 1) # All warps in rows + >>> GemmWarpPolicy.from_warp_partition(4, 1) # All warps in rows GemmWarpPolicy.FullRow - >>> GemmWarpPolicy.from_block_row_cols(1, 4) # All warps in columns + >>> GemmWarpPolicy.from_warp_partition(1, 4) # All warps in columns GemmWarpPolicy.FullCol - >>> GemmWarpPolicy.from_block_row_cols(2, 2) # Balanced distribution + >>> GemmWarpPolicy.from_warp_partition(2, 2) # Balanced distribution GemmWarpPolicy.Squaretilelang/tileop/gemm_sp/__init__.py (1)
10-55: Fix incorrectGemmWarpPolicyimport for FFI field annotation.The
policyfield (line 54) is a TVM FFI Object that should be annotated with theGemmWarpPolicyfromtilelang.ir(a Node/Scriptable class), not theIntEnumfromtilelang.tileop.baseimported on line 11. The C++ implementation constructspolicyas aGemmWarpPolicyNodeObjectRef, which maps to the IR version. Replace the import or add a qualified reference to avoid the name collision and ensure the type annotation correctly reflects the actual FFI field type.
🧹 Nitpick comments (1)
tilelang/language/__init__.py (1)
62-63: Import reorganization is correct.The imports are properly split:
GemmWarpPolicynow comes from the new centralized locationtilelang.tileop.base, while thegemmfunctions continue to be imported from the local.gemmmodule. This correctly exposes both through thetilelang.languagenamespace.Optional cleanup: Static analysis indicates the
# noqa: F401directives may be unnecessary if the F401 rule isn't enabled in your linter configuration. Consider removing them for cleaner code, or keep them as defensive annotations for re-exports.-from tilelang.tileop.base import GemmWarpPolicy # noqa: F401 -from .gemm import gemm, gemm_v1, gemm_v2 # noqa: F401 +from tilelang.tileop.base import GemmWarpPolicy +from .gemm import gemm, gemm_v1, gemm_v2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (15)
examples/amd/example_amd_flash_attn_bwd.py(1 hunks)examples/amd/example_amd_flash_attn_fwd.py(1 hunks)testing/python/primitives/test_tilelang_primitives_mma.py(0 hunks)tilelang/language/__init__.py(1 hunks)tilelang/language/experimental/gemm_sp.py(1 hunks)tilelang/language/gemm.py(1 hunks)tilelang/primitives/__init__.py(0 hunks)tilelang/primitives/gemm/__init__.py(0 hunks)tilelang/primitives/gemm/gemm_mma.py(0 hunks)tilelang/tileop/__init__.py(1 hunks)tilelang/tileop/base.py(1 hunks)tilelang/tileop/gemm/__init__.py(0 hunks)tilelang/tileop/gemm/gemm_base.py(1 hunks)tilelang/tileop/gemm_sp/__init__.py(1 hunks)tilelang/tileop/gemm_sp/gemm_sp_base.py(1 hunks)
💤 Files with no reviewable changes (5)
- tilelang/tileop/gemm/init.py
- tilelang/primitives/init.py
- tilelang/primitives/gemm/gemm_mma.py
- tilelang/primitives/gemm/init.py
- testing/python/primitives/test_tilelang_primitives_mma.py
🧰 Additional context used
🧬 Code graph analysis (7)
tilelang/tileop/gemm_sp/gemm_sp_base.py (2)
tilelang/tileop/base.py (1)
GemmWarpPolicy(6-186)tilelang/ir.py (1)
GemmWarpPolicy(26-33)
examples/amd/example_amd_flash_attn_bwd.py (1)
tilelang/tileop/base.py (1)
GemmWarpPolicy(6-186)
tilelang/tileop/__init__.py (2)
tilelang/tileop/base.py (1)
GemmWarpPolicy(6-186)tilelang/ir.py (1)
GemmWarpPolicy(26-33)
examples/amd/example_amd_flash_attn_fwd.py (1)
tilelang/tileop/base.py (1)
GemmWarpPolicy(6-186)
tilelang/language/gemm.py (2)
tilelang/tileop/base.py (1)
GemmWarpPolicy(6-186)tilelang/ir.py (1)
GemmWarpPolicy(26-33)
tilelang/language/experimental/gemm_sp.py (3)
tilelang/tileop/base.py (1)
GemmWarpPolicy(6-186)tilelang/ir.py (1)
GemmWarpPolicy(26-33)src/op/gemm.h (4)
GemmWarpPolicy(59-83)GemmWarpPolicy(64-68)GemmWarpPolicy(70-74)GemmWarpPolicy(76-82)
tilelang/tileop/gemm/gemm_base.py (2)
tilelang/tileop/base.py (1)
GemmWarpPolicy(6-186)tilelang/ir.py (1)
GemmWarpPolicy(26-33)
🪛 Ruff (0.14.8)
tilelang/tileop/__init__.py
1-1: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/__init__.py
62-62: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
63-63: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🔇 Additional comments (5)
tilelang/language/experimental/gemm_sp.py (1)
3-6: Import path update looks correct and consistent with the refactor.The change from
tilelang.primitivestotilelang.tileop.baseforGemmWarpPolicyaligns with the module reorganization.tilelang/tileop/gemm/gemm_base.py (1)
6-6: Import path update looks correct.The import is correctly updated to use the new centralized location at
tilelang.tileop.base. This aligns with the PR's goal of consolidating GEMM-related types into the tileop module.examples/amd/example_amd_flash_attn_bwd.py (1)
5-5: Import path update is correct and consistent.The import is properly updated to use
tilelang.tileop.base. The file correctly uses both the direct import (GemmWarpPolicy.FullRowin the forward kernel) and the module-qualified access (T.GemmWarpPolicy.FullRowin the backward kernel), both of which resolve to the same relocated type.tilelang/language/gemm.py (1)
4-4: Import path update is correct.The import is properly updated to the new centralized location. The
GemmWarpPolicytype is consistently used throughout the file for type annotations and default parameter values, and this change maintains full backward compatibility for the public API.examples/amd/example_amd_flash_attn_fwd.py (1)
5-5: Import path update is correct.The import is properly updated to
tilelang.tileop.base, consistent with the codebase-wide refactoring.
| from .base import GemmWarpPolicy # noqa: F401 | ||
| from .gemm import GemmPy # noqa: F401 | ||
| from .gemm_sp import GemmSPPy # noqa: F401 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the stale noqa to satisfy Ruff (RUF100) while keeping the re-export.
-from .base import GemmWarpPolicy # noqa: F401
+from .base import GemmWarpPolicy📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| from .base import GemmWarpPolicy # noqa: F401 | |
| from .gemm import GemmPy # noqa: F401 | |
| from .gemm_sp import GemmSPPy # noqa: F401 | |
| from .base import GemmWarpPolicy | |
| from .gemm import GemmPy # noqa: F401 | |
| from .gemm_sp import GemmSPPy # noqa: F401 |
| from .base import GemmWarpPolicy # noqa: F401 | |
| from .gemm import GemmPy # noqa: F401 | |
| from .gemm_sp import GemmSPPy # noqa: F401 | |
| from .base import GemmWarpPolicy | |
| from .gemm import GemmPy | |
| from .gemm_sp import GemmSPPy |
🧰 Tools
🪛 Ruff (0.14.8)
1-1: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
2-2: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
3-3: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🤖 Prompt for AI Agents
In tilelang/tileop/__init__.py lines 1-3, remove the stale "# noqa: F401"
comments and keep the imports as real exports, then add an explicit __all__ =
["GemmWarpPolicy", "GemmPy", "GemmSPPy"] so the names are considered used and
the re-exports remain while satisfying Ruff (RUF100).
| from tilelang.utils.language import is_shared, is_fragment | ||
| from tilelang.ir import GemmWarpPolicy | ||
| from tilelang.tileop.base import GemmWarpPolicy | ||
| from tvm.ir.base import Node | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# 1) Find any remaining imports of the IR-side GemmWarpPolicy
echo "=== Searching for imports from tilelang.ir GemmWarpPolicy ==="
rg -n --type=py 'from\s+tilelang\.ir\s+import\s+GemmWarpPolicy' || echo "No matches found"
# 2) Find call sites that might expect the IR method signature compute_warp_partition
echo -e "\n=== Searching for compute_warp_partition calls ==="
rg -n --type=py 'compute_warp_partition\s*\(' -B2 -A3 || echo "No matches found"
# 3) Find explicit references to tilelang.ir.GemmWarpPolicy
echo -e "\n=== Searching for explicit IR.GemmWarpPolicy references ==="
rg -n --type=py 'tilelang\.ir\.GemmWarpPolicy|from tilelang\.ir import.*GemmWarpPolicy' || echo "No matches found"
# 4) Check if tilelang.ir.GemmWarpPolicy still exists
echo -e "\n=== Checking if tilelang/ir module has GemmWarpPolicy ==="
find . -name "*.py" -path "*/tilelang/ir*" -type f | head -20Repository: tile-ai/tilelang
Length of output: 8294
🏁 Script executed:
#!/bin/bash
# Check the GemmWarpPolicy in tileop/base.py
echo "=== GemmWarpPolicy in tileop/base.py ==="
rg -n --type=py -A15 'class GemmWarpPolicy' tilelang/tileop/base.py || true
# Check how self.policy is defined and typed in gemm_base.py or similar
echo -e "\n=== Checking policy attribute definition in GemmBase ==="
rg -n --type=py 'self\.policy' tilelang/tileop/base.py | head -20 || true
# Check the actual GemmWarpPolicy definitions in tilelang/ir.py
echo -e "\n=== GemmWarpPolicy definitions in tilelang/ir.py ==="
rg -n --type=py -B3 -A8 'class GemmWarpPolicy|class GemmSPWarpPolicy' tilelang/ir.py || trueRepository: tile-ai/tilelang
Length of output: 1593
🏁 Script executed:
#!/bin/bash
# Check GemmSPBase class and its parent to understand self.policy
echo "=== GemmSPBase class definition ==="
rg -n --type=py -B2 -A20 'class GemmSPBase' tilelang/tileop/gemm_sp/gemm_sp_base.py || true
# Check GemmBase class definition
echo -e "\n=== GemmBase class definition ==="
rg -n --type=py -B2 -A30 'class GemmBase' tilelang/tileop/base.py | head -50 || true
# Look for self.policy assignment in GemmSPBase or GemmBase
echo -e "\n=== self.policy assignments in codebase ==="
rg -n --type=py 'self\.policy\s*=' tilelang/tileop/ -B2 -A2 || trueRepository: tile-ai/tilelang
Length of output: 936
🏁 Script executed:
#!/bin/bash
# Check full gemm_sp_base.py file
echo "=== Full gemm_sp_base.py content ==="
cat -n tilelang/tileop/gemm_sp/gemm_sp_base.py || true
# Search for @property in the codebase for policy
echo -e "\n=== @property definitions related to policy ==="
rg -n --type=py -B1 -A5 '@property' tilelang/tileop/gemm_sp/gemm_sp_base.py || true
# Check parent class definitions
echo -e "\n=== GemmBase class in tileop/base.py ==="
rg -n --type=py -B2 -A50 'class GemmBase' tilelang/tileop/base.py | head -80 || trueRepository: tile-ai/tilelang
Length of output: 6790
🏁 Script executed:
#!/bin/bash
# Check git history to see what the original import was
echo "=== Checking git log for the import change ==="
git log --oneline -20 tilelang/tileop/gemm_sp/gemm_sp_base.py 2>/dev/null || echo "Git history not available"
# Verify what gemm_sp_node.policy actually returns
echo -e "\n=== Checking gemm_sp_node type definition ==="
rg -n --type=py 'gemm_sp_node\.policy' tilelang/ -B2 -A2 || trueRepository: tile-ai/tilelang
Length of output: 402
Incorrect type annotation: GemmWarpPolicy import should be from tilelang.ir, not tilelang.tileop.base.
The @property policy on line 130 declares return type GemmWarpPolicy but imports it from tilelang.tileop.base (an IntEnum with values 0, 1, 2). However, the actual runtime value self.gemm_sp_node.policy returns the IR-side GemmWarpPolicy (a Node/Scriptable object with compute_warp_partition method). All call sites in gemm_sp_mma.py and elsewhere invoke self.policy.compute_warp_partition(...), which requires the IR-side type. The type annotation now misleads IDEs and type checkers about the actual type.
Fix: Change line 6 to from tilelang.ir import GemmWarpPolicy to match the actual return type and usage.
🤖 Prompt for AI Agents
In tilelang/tileop/gemm_sp/gemm_sp_base.py around lines 5 to 8, the type
annotation imports GemmWarpPolicy from tilelang.tileop.base but the property
policy actually returns the IR-side GemmWarpPolicy (a Node with methods like
compute_warp_partition); change the import to from tilelang.ir import
GemmWarpPolicy so the annotated type matches runtime usage and
IDE/type-checkers; update the single import line accordingly and ensure no other
references rely on the old IntEnum import.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/tl_templates/cuda/atomic.h (1)
609-648: Pre-existing: PTX instruction ordering inconsistency inAtomicAddx4Ret.Not introduced by this PR, but worth noting: the PTX instructions in
AtomicAddx4Retuseatom.global.gpu.release(lines 621, 629, 638) whileAtomicAddx4usesatom.release.gpu.global(lines 581, 589, 598).Per PTX ISA, the correct ordering is
.sem.scope.space(e.g.,atom.release.gpu.global). TheAtomicAddx4Retvariant appears to have the modifiers in the wrong order. Consider addressing this in a follow-up.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
src/tl_templates/cuda/atomic.h(5 hunks)tilelang/language/atomic.py(4 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: lucifer1004
Repo: tile-ai/tilelang PR: 1256
File: testing/python/jit/test_tilelang_jit_gemm_nvrtc.py:55-115
Timestamp: 2025-11-14T07:56:11.098Z
Learning: In `testing/python/jit/test_tilelang_jit_gemm_nvrtc.py`, the global function `tilelang_callback_cuda_postproc` registered via `tvm.register_global_func(..., override=True)` is intentionally not restored after the test completes, as the persistent behavior is expected.
🧬 Code graph analysis (1)
tilelang/language/atomic.py (1)
tilelang/language/tir/op.py (2)
call_extern(173-195)address_of(464-480)
🔇 Additional comments (9)
tilelang/language/atomic.py (4)
59-68: Pointer-based extern calls look correct.The change from passing
dstdirectly toT.address_of(dst)aligns with the updated CUDAAtomicMaxsignature that now expectsT1 *refinstead ofT1 &ref. The implementation correctly handles both the no-memory-order and memory-order variants.
110-119: LGTM!Consistent with
atomic_max- correctly passes pointer viaT.address_of(dst)for both memory-order variants.
340-340: LGTM!Correctly updated to pass
T.address_of(src)matching the pointer-basedAtomicLoad(T *ref, ...)signature in the CUDA header.
393-393: LGTM!Correctly updated to pass
T.address_of(dst)for the destination while keepingsrcas a value, matching theAtomicStore(T1 *ref, T2 value, ...)signature.src/tl_templates/cuda/atomic.h (5)
48-77: Pointer-based signature change looks correct.The conversion from
T1 &reftoT1 *refis properly implemented:
- Parameter changed to pointer type
- Internal assignment simplified from
T1 *address = &ref;toT1 *address = ref;- The rest of the function logic (CAS loop for half/bf16, atomic_ref for other types) correctly uses the pointer
79-108: LGTM!Consistent pointer-based implementation matching
AtomicMax. The return semantics are preserved correctly.
110-170: LGTM!Both
AtomicMinandAtomicMinRetcorrectly updated to pointer-based signatures with the same pattern as the max variants.
693-700: LGTM!Clean conversion to pointer-based signature. The
cuda::atomic_ref<T, cuda::thread_scope_device> aref(*ref)correctly dereferences the pointer.
702-711: LGTM!Correctly updated to pointer-based signature, consistent with
AtomicLoad.
This pull request refactors the codebase to move the
GemmWarpPolicyand related GEMM primitive logic from theprimitivesmodule to a newtileopmodule, and removes unused or redundant code. It updates all relevant imports and removes obsolete files and test cases. The changes help to clarify module boundaries and simplify the codebase.Module refactoring and import updates:
GemmWarpPolicyfromtilelang.primitives.gemm.basetotilelang.tileop.base, and updated all imports in the codebase to reference the new location. (examples/amd/example_amd_flash_attn_bwd.py,examples/amd/example_amd_flash_attn_fwd.py,tilelang/language/__init__.py,tilelang/language/experimental/gemm_sp.py,tilelang/language/gemm.py,tilelang/tileop/__init__.py,tilelang/tileop/base.py) [1] [2] [3] [4] [5] [6] [7]Code cleanup and removal:
tilelang/primitives/gemmdirectory, including thebase.pyandgemm_mma.pyfiles, and their associated logic, as this functionality has been moved or is no longer needed. (tilelang/primitives/gemm/__init__.py,tilelang/primitives/gemm/gemm_mma.py) [1] [2]tilelang/primitives/__init__.pyfile as it is no longer required after the refactor.Test and example updates:
testing/python/primitives/test_tilelang_primitives_mma.pyas it depended on the now-removed primitives GEMM implementation.testing/python/language/test_tilelang_language_atomic_add.pyto directly runrun_atomic_maxinstead of the removedtilelang.testing.main().Summary by CodeRabbit
Refactor
Tests
Breaking Changes
✏️ Tip: You can customize this high-level summary in your review settings.