[Feature] Add more curand operations & support vectorization #1582
📝 Walkthrough
Adds a `generator` parameter to `rng_init`, introduces `rng_rand_float`, generalizes CUDA codegen to track multiple CURAND state types and emit the appropriate float, bit-width, and vectorized intrinsics, relaxes vectorization for `rng_rand`, and updates the Python APIs and tests to exercise multiple generators and outputs.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant User as User (Python)
participant TL as TileLang IR
participant CG as CUDA Codegen
participant CURAND as CURAND (GPU)
rect rgb(220,235,255)
User->>TL: call rng_init(seed, seq, off, generator)
TL->>CG: lower tl.rng_init(generator)
CG->>CURAND: curand_init(..., state pointer of type curand_random_generator_state_type)
CURAND-->>CG: state initialized
end
rect rgb(225,255,230)
User->>TL: call rng_rand_float(bit, dist) / rng_rand()
TL->>CG: emit tl.rng_rand_* intrinsic
CG->>CG: select intrinsic based on curand_random_generator_state_type, bit, dist, lanes
CG->>CURAND: curand_uniform/normal / double / curandN / lane variants using curand_random_generator_state
CURAND-->>CG: random values
CG-->>TL: lowered value returned to User kernel
end
Note over CG: Codegen chooses per-call variant (vectorized lanes & 32/64-bit) using stored state type
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
🧹 Nitpick comments (1)
tilelang/language/__init__.py (1)
108-111: Remove unnecessary `noqa` directives.
Static analysis (Ruff) reports that the `# noqa: F401` directives on lines 108-111 are unused because F401 is not enabled. These can be safely removed.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- src/op/builtin.cc
- src/op/builtin.h
- src/target/codegen_cuda.cc
- src/target/codegen_cuda.h
- src/transform/loop_vectorize.cc
- testing/python/language/test_tilelang_language_rand.py
- tilelang/language/__init__.py
- tilelang/language/random.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
- src/op/builtin.h
- tilelang/language/__init__.py
- tilelang/language/random.py
- src/op/builtin.cc
📚 Learning: 2025-12-18T04:49:52.473Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:49:52.473Z
Learning: Document and maintain test alignment for RNG behavior: In testing/python/language/test_tilelang_language_rand.py, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align with Triton’s behavior, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. Keep this difference consistent in tests, comment on why the four RNG calls compensate for internal RNG differences, and avoid altering this test without revalidating RNG semantics across implementations.
Applied to files:
testing/python/language/test_tilelang_language_rand.py
🧬 Code graph analysis (6)
tilelang/language/__init__.py (1)
tilelang/language/random.py (4)
rng_rand_uniform (51-59), rng_rand_uniform_double (62-70), rng_rand_normal (73-81), rng_rand_normal_double (84-92)
src/transform/loop_vectorize.cc (1)
tilelang/language/random.py (1)
rng_init(6-37)
tilelang/language/random.py (1)
tilelang/language/tir/op.py (1)
call_intrin(120-145)
src/target/codegen_cuda.cc (1)
tilelang/language/random.py (5)
rng_rand (40-48), rng_rand_uniform (51-59), rng_rand_uniform_double (62-70), rng_rand_normal (73-81), rng_rand_normal_double (84-92)
testing/python/language/test_tilelang_language_rand.py (2)
examples/rand/rand_uint.py (2)
tilelang_rand_1d (9-25), rand_kernel (15-23)
tilelang/language/random.py (6)
rng_init (6-37), rng_rand (40-48), rng_rand_uniform (51-59), rng_rand_uniform_double (62-70), rng_rand_normal (73-81), rng_rand_normal_double (84-92)
src/op/builtin.cc (1)
tilelang/language/random.py (6)
rng_init (6-37), rng_rand (40-48), rng_rand_uniform (51-59), rng_rand_uniform_double (62-70), rng_rand_normal (73-81), rng_rand_normal_double (84-92)
🪛 Ruff (0.14.10)
tilelang/language/__init__.py
108-108: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
109-109: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
110-110: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
111-111: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🔇 Additional comments (13)
src/transform/loop_vectorize.cc (1)
155-158: LGTM! Vectorization enabled for RNG generation.
The change correctly allows `rng_rand()` to be vectorized while keeping `rng_init()` scalar. This aligns with the broader PR changes that emit vectorized curand intrinsics (e.g., `curand4`, `curand_uniform4`) during codegen.
src/target/codegen_cuda.h (1)
92-93: LGTM! Generalized RNG state tracking.
The refactoring from a Philox-specific state to generalized `curand_random_generator_state` and `curand_random_generator_state_type` members properly supports multiple RNG generator types (Philox, MRG32k3a, XORWOW).
src/op/builtin.h (1)
134-137: LGTM! New RNG operation declarations.
The four new RNG operation declarations follow the established pattern and clearly indicate their purpose through naming.
src/op/builtin.cc (2)
104-104: LGTM! Updated arity for generator parameter.
The change from 3 to 4 inputs correctly reflects the new `generator` parameter added to `rng_init`.
110-128: LGTM! New RNG operation registrations.
All four new RNG operations are correctly configured with:
- 0 inputs (stateful operations that read from the generator state)
- `kOpaque` effect kind (appropriate for non-deterministic operations)
testing/python/language/test_tilelang_language_rand.py (2)
15-49: LGTM! Comprehensive test coverage for new RNG operations.
The test properly validates:
- All four new RNG distribution functions
- Multiple generator types (MRG32k3a, Philox, XORWOW)
- Each operation writes to a separate output tensor
55-65: LGTM! Parameterized test for multiple generators.
The test correctly parametrizes over different generator types and varying sizes/seeds, providing good coverage of the new generator parameter functionality.
src/target/codegen_cuda.cc (3)
2652-2664: LGTM! Generic RNG state initialization.
The updated `rng_init` implementation correctly:
- Extracts the generator type from the new fourth argument
- Stores both the state variable name and type for later use
- Emits the appropriate `curand_init` call
2666-2682: LGTM! Scalar RNG operation emission.
All new RNG operations correctly emit the corresponding curand function calls using the stored generator state.
3132-3176: LGTM! Vectorized RNG emission with generator-specific intrinsics.
The implementation correctly:
- Checks the generator type before emitting vectorized intrinsics
- Uses Philox-specific vectorized functions (curand4, curand_uniform4, curand_normal4, etc.)
- Falls back to double-precision intrinsics for MRG32k3a and XORWOW
- Maintains the existing scalar broadcast fallback
tilelang/language/random.py (3)
6-37: LGTM! Generator parameter added correctly.
The implementation properly:
- Adds a `generator` parameter with a sensible default
- Validates the generator type against supported CURAND states
- Passes the generator through to the underlying intrinsic
51-59: LGTM! Uniform float RNG helper added.
The new `rng_rand_uniform` function correctly generates uniformly distributed float32 values.
73-81: LGTM! Normal float RNG helper added.
The new `rng_rand_normal` function correctly generates normally distributed float32 values.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tilelang/language/random.py (1)
26-26: Consider using `ValueError` instead of `assert` for input validation.
`assert` statements can be disabled when Python runs with the `-O` (optimize) flag, allowing invalid generator values to pass through silently. Using an explicit `if`/`raise` pattern ensures validation is always enforced.
🔎 Proposed fix
- assert generator in ["curandStateMRG32k3a_t", "curandStatePhilox4_32_10_t", "curandStateXORWOW_t"]
+ allowed_generators = {"curandStateMRG32k3a_t", "curandStatePhilox4_32_10_t", "curandStateXORWOW_t"}
+ if generator not in allowed_generators:
+     raise ValueError(f"generator must be one of {allowed_generators}, got {generator!r}")
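As a standalone illustration of the `if`/`raise` pattern, here is a minimal sketch. The helper name `check_generator` and the constant `ALLOWED_GENERATORS` are hypothetical and not part of TileLang; only the three generator names come from the review above.

```python
# Generator names taken from the review comment; the helper itself is hypothetical.
ALLOWED_GENERATORS = (
    "curandStateMRG32k3a_t",
    "curandStatePhilox4_32_10_t",
    "curandStateXORWOW_t",
)


def check_generator(generator: str) -> str:
    """Return the generator name unchanged, or raise if it is unsupported.

    Unlike `assert`, this check is not stripped when Python runs with -O.
    """
    if generator not in ALLOWED_GENERATORS:
        raise ValueError(
            f"generator must be one of {ALLOWED_GENERATORS}, got {generator!r}"
        )
    return generator
```

A tuple is used here instead of a set so that the error message lists the names in a stable order; either container works for the membership test.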
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tilelang/language/random.py
🧰 Additional context used
🧠 Learnings (2)
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
tilelang/language/random.py
🧬 Code graph analysis (1)
tilelang/language/random.py (1)
tilelang/language/tir/op.py (1)
call_intrin(120-145)
🔇 Additional comments (5)
tilelang/language/random.py (5)
40-48: LGTM!
The function correctly calls the intrinsic and the documentation is accurate.
51-59: LGTM!
Clean implementation with accurate documentation.
62-70: LGTM!
The docstring correctly states "64-bit uniformly distributed double", matching the `float64` return type. The previous review feedback has been addressed.
73-81: LGTM!
Clean implementation with accurate documentation.
84-92: LGTM!
The docstring correctly states "64-bit normally distributed double", matching the `float64` return type. The previous review feedback has been addressed.
generator : StringImm
    Set random generator.
    See https://docs.nvidia.com/cuda/curand/group__DEVICE.html
Docstring type annotation is inaccurate.
The parameter type is listed as StringImm, but the function accepts a plain Python str. Consider updating to str for clarity.
🔎 Proposed fix
- generator : StringImm
- Set random generator.
- See https://docs.nvidia.com/cuda/curand/group__DEVICE.html
+ generator : str
+ The CURAND generator type to use.
+ See https://docs.nvidia.com/cuda/curand/group__DEVICE.html
🤖 Prompt for AI Agents
In tilelang/language/random.py around lines 17 to 19, the docstring lists the
parameter type as `StringImm` but the function actually accepts a plain Python
`str`; update the docstring type annotation to `str` and adjust any surrounding
wording to reflect it accepts a Python string (not an IR/StringImm object),
keeping the link and descriptive text intact.
LeiWang1999
left a comment
Can we merge those `rng_rand*` functions into a single `rng_rand`?
Actionable comments posted: 0
🧹 Nitpick comments (1)
tilelang/language/__init__.py (1)
108-108: Static analysis note: the `noqa` directive may be unnecessary.
The static analysis tool suggests the `noqa` directive is unused. Since this is a valid re-export, the `# noqa: F401` is likely correct but might be redundant depending on your linter configuration.
If your linter doesn't complain without the `noqa`, consider removing it for cleaner code.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- src/op/builtin.cc
- src/op/builtin.h
- src/target/codegen_cuda.cc
- testing/python/language/test_tilelang_language_rand.py
- tilelang/language/__init__.py
- tilelang/language/random.py
🚧 Files skipped from review as they are similar to previous changes (1)
- src/op/builtin.cc
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-12-18T04:49:52.473Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:49:52.473Z
Learning: Document and maintain test alignment for RNG behavior: In testing/python/language/test_tilelang_language_rand.py, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align with Triton’s behavior, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. Keep this difference consistent in tests, comment on why the four RNG calls compensate for internal RNG differences, and avoid altering this test without revalidating RNG semantics across implementations.
Applied to files:
testing/python/language/test_tilelang_language_rand.py
📚 Learning: 2025-12-18T04:50:00.512Z
Learnt from: silentCoder-dev
Repo: tile-ai/tilelang PR: 1464
File: testing/python/language/test_tilelang_language_rand.py:14-14
Timestamp: 2025-12-18T04:50:00.512Z
Learning: In `testing/python/language/test_tilelang_language_rand.py`, the TileLang kernel uses `blk_M = M` (single block) and calls `rng_rand()` four times per element to align results with the Triton implementation, which uses `blk_M = 128` (multiple blocks) and calls the RNG once per element. These differences compensate for internal RNG behavior differences between TileLang and Triton.
Applied to files:
- src/op/builtin.h
- tilelang/language/random.py
- tilelang/language/__init__.py
🧬 Code graph analysis (5)
src/target/codegen_cuda.cc (1)
tilelang/language/random.py (2)
rng_rand (40-48), rng_rand_float (51-68)
testing/python/language/test_tilelang_language_rand.py (2)
examples/rand/rand_uint.py (2)
tilelang_rand_1d (9-25), rand_kernel (15-23)
tilelang/language/random.py (3)
rng_init (6-37), rng_rand (40-48), rng_rand_float (51-68)
src/op/builtin.h (1)
tilelang/language/random.py (1)
rng_rand_float(51-68)
tilelang/language/random.py (1)
tilelang/language/tir/op.py (1)
call_intrin(120-145)
tilelang/language/__init__.py (1)
tilelang/language/random.py (1)
rng_rand_float(51-68)
🪛 Ruff (0.14.10)
tilelang/language/__init__.py
108-108: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with Nightly-ROCm-7.1 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (10)
tilelang/language/__init__.py (1)
105-109: LGTM: New RNG float function properly exported.
The new `rng_rand_float` export is correctly placed alongside the existing RNG functions and follows the established pattern.
src/op/builtin.h (1)
134-134: LGTM: New RNG operation declared correctly.
The `rng_rand_float()` declaration follows the established pattern for builtin operations and is appropriately placed with the other RNG functions.
tilelang/language/random.py (2)
6-6: LGTM: Generator parameter properly integrated.
The new `generator` parameter is well-designed:
- Has a sensible default value (`curandStatePhilox4_32_10_t`)
- Includes proper validation against allowed generator types
- Is correctly propagated to the underlying intrinsic call
Note: A past review comment flagged the docstring type annotation as `StringImm` when it should be `str`, but this appears to have been addressed in commit e548fea.
Also applies to: 17-19, 26-26, 37-37
51-68: LGTM: New rng_rand_float function is well-implemented.
The function properly:
- Validates `bit` parameter (32 or 64)
- Validates `dist` parameter (uniform or normal)
- Constructs the appropriate intrinsic call with correct return type
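The validation steps listed above can be sketched in plain Python. This is an illustration only: the helper name `rng_rand_float_dtype` is hypothetical, and the real `rng_rand_float` builds an intrinsic call rather than returning a string. The defaults (`bit=32`, `dist="uniform"`) follow the test description in this review, where `T.rng_rand_float()` yields a uniform `float32`.

```python
def rng_rand_float_dtype(bit: int = 32, dist: str = "uniform") -> str:
    """Hypothetical helper: validate arguments and name the return dtype.

    Mirrors the checks described in the review (bit in {32, 64},
    dist in {"uniform", "normal"}); it does not emit any intrinsic.
    """
    if bit not in (32, 64):
        raise ValueError(f"bit must be 32 or 64, got {bit!r}")
    if dist not in ("uniform", "normal"):
        raise ValueError(f"dist must be 'uniform' or 'normal', got {dist!r}")
    # 32-bit requests map to float32, 64-bit requests to float64.
    return f"float{bit}"
```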
src/target/codegen_cuda.cc (3)
2654-2664: LGTM: RNG state tracking generalized to support multiple generators.
The refactoring from `curand_philox_state` to `curand_random_generator_state` with the accompanying `curand_random_generator_state_type` properly enables support for multiple CURAND generator types. The initialization correctly extracts the generator type from the intrinsic arguments.
2669-2676: LGTM: rng_rand_float properly implemented.
The handler correctly:
- Constructs the appropriate `curand_<dist>` or `curand_<dist>_double` function name based on bit width
- Uses the generic RNG state variable
- Handles both 32-bit and 64-bit float generation
3125-3176: Code correctly implements supported CURAND API combinations with appropriate fallback.
The vectorized RNG logic accurately reflects NVIDIA CURAND's actual API capabilities. Philox4_32_10 provides more vectorized variants (`curand4`, `curand_uniform4`, `curand_normal4`, `curand_uniform2_double`, `curand_normal2`, `curand_normal2_double`) than MRG32k3a/XORWOW (`curand_normal2`, `curand_normal2_double` only), which matches NVIDIA's documented API. For unsupported combinations, the code gracefully falls back to scalar generation with `make_<type>()` to replicate values across lanes. All three supported generators (Philox4_32_10, MRG32k3a, XORWOW) are tested, and the implementation is sound.
Likely an incorrect or invalid review comment.
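The selection rule described in this comment can be modeled as a table lookup with a scalar fallback. The sketch below is a Python illustration, not the C++ codegen: the function names, the `(dist, bit, lanes)` key layout, and the `"int"` label for raw `rng_rand()` output are assumptions made for the example. The intrinsic names and the per-generator availability are taken from the comment above.

```python
PHILOX = "curandStatePhilox4_32_10_t"

# Only the two normal-distribution pairs are listed for MRG32k3a and XORWOW.
_NORMAL_PAIRS = {
    ("normal", 32, 2): "curand_normal2",
    ("normal", 64, 2): "curand_normal2_double",
}

# (dist, bit, lanes) -> vectorized intrinsic, per generator state type.
VECTOR_INTRINSICS = {
    PHILOX: {
        ("int", 32, 4): "curand4",
        ("uniform", 32, 4): "curand_uniform4",
        ("normal", 32, 4): "curand_normal4",
        ("uniform", 64, 2): "curand_uniform2_double",
        **_NORMAL_PAIRS,
    },
    "curandStateMRG32k3a_t": dict(_NORMAL_PAIRS),
    "curandStateXORWOW_t": dict(_NORMAL_PAIRS),
}


def scalar_intrinsic(dist: str, bit: int) -> str:
    """Build curand, curand_<dist>, or curand_<dist>_double from dist and bit width."""
    if dist == "int":
        return "curand"
    return f"curand_{dist}" + ("_double" if bit == 64 else "")


def select_intrinsic(generator: str, dist: str, bit: int, lanes: int):
    """Return (intrinsic_name, needs_broadcast).

    needs_broadcast is True when no vectorized variant exists and a scalar
    value would have to be replicated across lanes.
    """
    name = VECTOR_INTRINSICS.get(generator, {}).get((dist, bit, lanes))
    if name is not None:
        return name, False
    return scalar_intrinsic(dist, bit), lanes > 1
```

For example, a 4-lane uniform `float32` request resolves to `curand_uniform4` under Philox, while the same request under XORWOW resolves to scalar `curand_uniform` with the broadcast flag set.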
testing/python/language/test_tilelang_language_rand.py (3)
9-9: LGTM: Generator parameter properly integrated.
The `generator` parameter is correctly added with a sensible default and properly passed through to `T.rng_init()`.
Also applies to: 24-24
15-21: Test expansion properly exercises new RNG functionality.
The kernel now generates five different outputs:
- A: `uint32` via `T.rng_rand()`
- B: `float32` via `T.rng_rand_float()` (uniform, default)
- C: `float64` via `T.rng_rand_float(bit=64)` (uniform)
- D: `float32` via `T.rng_rand_float(dist="normal")`
- E: `float64` via `T.rng_rand_float(bit=64, dist="normal")`
This provides good coverage of the new API.
💡 Optional: Consider consolidating loops
The five separate parallel loops could potentially be combined into a single loop that populates all five tensors, reducing code duplication:
    for i, j in T.Parallel(threads, num_per_thread):
        offsets = (bx * threads + i) * num_per_thread
        idx = offsets + j
        if idx < M:
            A[idx] = T.rng_rand()
            B[idx] = T.rng_rand_float()
            C[idx] = T.rng_rand_float(bit=64)
            D[idx] = T.rng_rand_float(dist="normal")
            E[idx] = T.rng_rand_float(bit=64, dist="normal")
However, the current structure may be intentional for testing purposes or to maintain consistency with existing patterns.
Also applies to: 30-49
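The index arithmetic in the suggested loop can be checked outside TileLang. The following plain-Python sketch enumerates every `idx` the guarded loop would write. It assumes the grid has `ceil(M / (threads * num_per_thread))` blocks, which the review does not state, and the function name `covered_indices` is hypothetical.

```python
import math


def covered_indices(M: int, threads: int, num_per_thread: int) -> list:
    """List every idx written by the guarded loop, across all blocks."""
    per_block = threads * num_per_thread
    num_blocks = math.ceil(M / per_block)  # assumed grid size
    written = []
    for bx in range(num_blocks):
        for i in range(threads):
            for j in range(num_per_thread):
                offsets = (bx * threads + i) * num_per_thread
                idx = offsets + j
                if idx < M:  # guard against the ragged final block
                    written.append(idx)
    return written
```

Under that assumption each element in `[0, M)` is written exactly once, so a single consolidated loop would populate all five tensors without gaps or overlaps.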
55-59: LGTM: Test parameterization covers multiple generator types.
The test now validates three different CURAND generators:
- `curandStateMRG32k3a_t`
- `curandStatePhilox4_32_10_t`
- `curandStateXORWOW_t`
This ensures the generator parameter works correctly across different CURAND backends.
Based on learnings: The test structure aligns with the documented pattern where TileLang tests compensate for internal RNG behavior differences.
Also applies to: 60-65