[FastMath] Disable default TVM fastmath intrinsic dispatch and add explicit fastmath op to invoke #875
Conversation
…ity in CUDA code. Consolidated multiple definitions into single lines and improved formatting in related test files for better clarity.
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds eight fast-math TL intrinsics and Python wrappers, lowers them in CUDA codegen via new math dispatchers, adds CUDA-gated tests and precision-comparison tooling (including a CUDA operator backend), re-exports fastmath in the language package, and updates the TVM submodule reference.

Changes
Sequence Diagram(s)

sequenceDiagram
autonumber
actor User
participant Py as tilelang.fastmath
participant TIR as TIR CallNode (tl.__*)
participant CG as CUDA Codegen
participant CUDA as CUDA Math Runtime
User->>Py: call __exp(x)
Py->>TIR: tir.call_intrin(tl.__exp, x)
TIR->>CG: Lower CallNode(tl.__exp)
rect #E6F5FF
note right of CG: select dispatcher based on dtype\n(CUDAFastMath / CUDAMath / CUDAFastMathTan)
CG->>CUDA: emit mapped symbol (e.g., expf, __expf, exp)
end
CUDA-->>User: compute result
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers
Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's format checks before requesting review. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refines the handling of floating-point mathematical operations within TVM's TileLang, particularly concerning fast-math intrinsics. It transitions from an implicit fast-math approach to an explicit one, providing greater control over numerical precision and performance. Users can now explicitly invoke fast-math versions of functions or enable them globally, while standard mathematical functions like 'exp' will default to their non-fast-math, higher-precision counterparts, ensuring broader compatibility and predictable behavior.

Highlights
Code Review
This pull request introduces explicit fast-math intrinsics and disables the default fast-math dispatch for standard math functions to give users more control. My review focuses on improving code maintainability by reducing duplication in the C++ implementation, fixing a critical documentation error in the new Python fast-math module, and significantly strengthening the new Python tests by adding missing numerical correctness checks and improving existing ones to ensure the new functionality is robust and well-tested.
def __exp(x):
    """Calculate 2**x with fast math
def run_single_arg_mathop_test(mathop_name,
                               mathop_func,
                               M=128,
                               N=128,
                               block_M=32,
                               block_N=32,
                               dtype="float32"):
    """
    Test single-argument mathops.
    T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
                                                                      bx * block_N + j])

    # Test with FAST_MATH disabled
    kernel_no_fastmath = tilelang.compile(
        main,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        })

    source_no_fastmath = kernel_no_fastmath.get_kernel_source()

    print(f"\n=== Testing {mathop_name} ===")
    print("FAST_MATH=False:")

    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)

    print(f"✓ {mathop_name} compilation and execution test passed")
The run_single_arg_mathop_test function only checks the generated CUDA source code for the presence of fastmath functions, but it doesn't perform any numerical correctness checks. This is a significant gap in testing. Please add a numerical check against a reference implementation (e.g., torch) to ensure the generated code is functionally correct, similar to the checks in other test helpers in this file.
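A minimal sketch of such a check, assuming the M, N, dtype, mathop_name and kernel_no_fastmath names from run_single_arg_mathop_test above and an op that has a same-named torch counterpart (the exact tolerances and domain handling are illustrative, not the project's code):

import torch

torch_dtype = getattr(torch, dtype)
a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
if mathop_name in ("sqrt", "rsqrt", "log", "log2", "log10"):
    a = torch.abs(a) + 0.1  # keep inputs inside the op's domain
b = kernel_no_fastmath(a)
expected = getattr(torch, mathop_name)(a)  # e.g. torch.exp for "exp"
torch.testing.assert_close(b, expected, rtol=1e-3, atol=1e-3)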
if cuda_mathop_name == "exp":
    expected = torch.exp(a)
elif cuda_mathop_name == "log":
    expected = torch.log(a)
else:
    expected = b_fastmath  # Just check compilation works
The numerical correctness check in run_fastmath_mathop_test is incomplete. It only compares the output of exp and log against a torch reference. For all other functions (__exp10, __log2, __log10, __tan, __cos, __sin), it only checks that the code compiles and runs by comparing the output to itself. The tests should be expanded to compare the results of all fastmath functions against their corresponding torch reference implementations to ensure correctness.
if cuda_mathop_name == "exp":
    expected = torch.exp(a)
elif cuda_mathop_name == "exp10":
    expected = torch.pow(10, a)
elif cuda_mathop_name == "log":
    expected = torch.log(a)
elif cuda_mathop_name == "log2":
    expected = torch.log2(a)
elif cuda_mathop_name == "log10":
    expected = torch.log10(a)
elif cuda_mathop_name == "tan":
    expected = torch.tan(a)
elif cuda_mathop_name == "cos":
    expected = torch.cos(a)
elif cuda_mathop_name == "sin":
    expected = torch.sin(a)
else:
    expected = b_fastmath  # Just check compilation works

std::string operator()(DataType t, std::string name) const {
  if (t.is_float() && t.bits() == 32) {
    return "__" + name + 'f';
  } else {
    return CUDAMath::operator()(t, name);
  }
  return "";
}
The return ""; on line 73 is unreachable because all paths in the if/else statement above it return a value. This dead code should be removed.
std::string operator()(DataType t, std::string name) const {
  if (t.is_float() && t.bits() == 32) {
    return "__" + name + 'f';
  } else {
    return CUDAMath::operator()(t, name);
  }
}

} else if (op->op.same_as(tl::__exp())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "exp");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__exp10())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "exp10");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__log())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "log");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__log2())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "log2");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__log10())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "log10");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__tan())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "tan");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__cos())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "cos");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__sin())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "sin");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
The logic for handling the different fast math intrinsics is very repetitive. Each else if block contains almost identical code, creating a new CUDAFastMath object and calling it. This could be refactored to reduce duplication and improve maintainability, for example by using a map from the Op to the math function name string and handling them in a single block or a helper function.
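One possible shape for that refactor, sketched here only as an illustration. Op, tl::__exp() and friends, CUDAFastMath, CUDAFastMathTan and PrintExpr are names taken from the surrounding code; the table and loop layout, and the assumption that the visitor can return after emitting the call, are mine:

// Illustrative table-driven dispatch for the tl.__* fast-math intrinsics.
static const std::vector<std::pair<Op, std::string>> kFastMathOps = {
    {tl::__exp(), "exp"},   {tl::__exp10(), "exp10"}, {tl::__log(), "log"},
    {tl::__log2(), "log2"}, {tl::__log10(), "log10"}, {tl::__tan(), "tan"},
    {tl::__cos(), "cos"},   {tl::__sin(), "sin"}};
for (const auto& entry : kFastMathOps) {
  if (op->op.same_as(entry.first)) {
    // __tan would use CUDAFastMathTan instead, per the separate review note below.
    CUDAFastMath math_func;
    os << math_func(op->dtype, entry.second) << "(" << PrintExpr(op->args[0]) << ")";
    return;  // or break, depending on how the surrounding visitor is structured
  }
}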
c_fastmath = kernel_fastmath(a, b)

# Both should produce similar results
torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
The numerical test in run_two_arg_mathop_test compares the output of the kernel compiled with FAST_MATH=True against the one with FAST_MATH=False. Since these two-argument functions are not expected to be affected by the fast math flag, this test only confirms that behavior. It doesn't verify correctness against a golden reference. It would be stronger to compare the output against the equivalent torch function (e.g., torch.pow, torch.fmod).
-torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
+if mathop_name == "pow":
+    expected = torch.pow(a, b)
+elif mathop_name == "fmod":
+    expected = torch.fmod(a, b)
+torch.testing.assert_close(c_no_fastmath, expected, rtol=1e-3, atol=1e-3)
+torch.testing.assert_close(c_fastmath, expected, rtol=1e-3, atol=1e-3)
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
The docstring for test_mathops_generate_no_fastmath is confusing. It says "Test that our tl.* mathops generate fastmath CUDA code", but the test name and implementation check for the opposite (i.e., that they generate non-fastmath code when the fast math flag is disabled). The docstring should be corrected to reflect the test's actual behavior.
Additionally, there's a misleading comment on line 92 within run_single_arg_mathop_test: # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf). This comment seems incorrect given the test's logic and should be removed or updated.

"""Test that our tl.* mathops generate non-fastmath CUDA code by default."""
Actionable comments posted: 4
🧹 Nitpick comments (5)
src/op/builtin.h (1)
78-86: Avoid double-underscore identifiers in C++ public API

Identifiers beginning with a double underscore are reserved in C/C++. Using __exp, __log, etc., in a public header risks UB/collisions. Consider alternative C++ accessor names (e.g., fast_exp, fast_log, …) and mapping to the tl.__exp string ops internally, or at least document the intentional deviation.

If you intend to keep the __ names, please confirm this won't be compiled on toolchains enforcing strict reserved identifier rules (e.g., some sanitizers/linters). Do you want help sketching a macro that allows the Op registry name to stay tl.__exp while exposing non-__ C++ accessors?
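A hedged sketch of that macro idea, assuming the Op / Op::Get API already used elsewhere in builtin.cc; the macro and accessor names (TL_DEFINE_FAST_MATH_OP, fast_exp) are hypothetical:

// Keep "tl.__exp" as the registry name while exposing a non-reserved C++ accessor.
#define TL_DEFINE_FAST_MATH_OP(FuncName, RegistryName)  \
  const Op &FuncName() {                                \
    static const Op &op = Op::Get("tl." RegistryName);  \
    return op;                                          \
  }

TL_DEFINE_FAST_MATH_OP(fast_exp, "__exp")  // C++ callers use fast_exp()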
src/op/builtin.cc (1)

43-67: Mark fast-math intrinsics as pure (not opaque)

These math ops are side-effect free. Using kOpaque needlessly inhibits CSE/hoisting and other optimizations. Prefer kPure to align with TVM math semantics.

Apply this diff:
-TIR_DEFINE_TL_BUILTIN(__exp).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__exp).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));

(and the same kOpaque → kPure change for __exp10, __log, __log2, __log10, __tan, __cos and __sin)

tilelang/language/__init__.py (1)
29-29: Fix noqa code and consider avoiding star import

Change the noqa to suppress F403 (star import) instead of F401 (unused imports).
Apply this diff:
-from .fastmath import *  # noqa: F401
+from .fastmath import *  # noqa: F403

Optionally, re-export via __all__ instead of a star import to keep linters happy and limit the API surface.
testing/python/fastmath/test_mathops_fastmath.py (2)
259-266: Broaden numerical checks for more ops

You can validate more fast-math ops without much overhead.
Apply this diff:
-    if cuda_mathop_name == "exp":
+    if cuda_mathop_name == "exp":
         expected = torch.exp(a)
     elif cuda_mathop_name == "log":
         expected = torch.log(a)
+    elif cuda_mathop_name == "log2":
+        expected = torch.log2(a)
+    elif cuda_mathop_name == "log10":
+        expected = torch.log10(a)
     else:
         expected = b_fastmath  # Just check compilation works
22-46: Regex is fine; add a small guard for false positives (optional)

Consider anchoring with \b before the name as well to avoid matching suffixes inside identifiers (very rare here).

Example:
-    fastmath_pattern = rf"__({mathop_name}f?)\b"
-    non_fastmath_pattern = rf"(?<!__)({mathop_name}f?)\b"
+    fastmath_pattern = rf"\b__({mathop_name}f?)\b"
+    non_fastmath_pattern = rf"(?<!__)\b({mathop_name}f?)\b"
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
3rdparty/tvm (1 hunks)
src/op/builtin.cc (1 hunks)
src/op/builtin.h (1 hunks)
src/target/codegen_cuda.cc (2 hunks)
testing/python/fastmath/test_mathops_fastmath.py (1 hunks)
tilelang/language/__init__.py (1 hunks)
tilelang/language/fastmath.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/fastmath.py (2)
src/op/builtin.h (1)
tvm (13-344)
tilelang/language/tir/op.py (1)
call_intrin(119-144)
testing/python/fastmath/test_mathops_fastmath.py (4)
tilelang/jit/__init__.py (1)
compile (33-86)
tilelang/jit/kernel.py (1)
out_idx (446-447)
tilelang/transform/pass_config.py (1)
PassConfigKey (6-89)
tilelang/language/tir/op.py (1)
exp10(1870-1883)
🪛 Ruff (0.13.1)
tilelang/language/__init__.py
29-29: from .fastmath import * used; unable to detect undefined names
(F403)
29-29: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
tilelang/language/fastmath.py
140-149: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
141-141: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
142-142: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
143-143: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
144-144: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
145-145: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
146-146: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
147-147: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
148-148: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
🔇 Additional comments (5)
3rdparty/tvm (1)
1-1: Submodule update looks intentional.

Pointer advance to commit 883e96b42ae0df40c2f7194cc932bbcd9d0c5627 aligns with the fast-math integration; no concerns from my side.

src/op/builtin.cc (3)
33-33: LGTM: cuTensorMapType constant

Returning UInt(8, 128) is clear and consistent with the expected 128-bit handle size.
19-31: PassConfig options registration matches header

All keys (including enable/disable fast-math) are registered and consistent with the header and Python PassConfigKey enum.
Please confirm conflict resolution when both TL_ENABLE_FAST_MATH and TL_DISABLE_FAST_MATH are set. If undefined, consider documenting precedence (e.g., ENABLE wins) or asserting mutually exclusive usage.
43-67: Fast-math intrinsics mapping and wrappers verified
codegen_cuda.cc and intrin_rule_cuda.cc handle tl::__exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin for float16/32/64, and tilelang/language/fastmath.py exports all eight.

testing/python/fastmath/test_mathops_fastmath.py (1)
208-247: Confirm CUDA emits __-prefixed intrinsics for all tested ops

Some CUDA fast-math intrinsics don't exist for every function/precision (e.g., exp10). If codegen maps to non-prefixed variants for any, these assertions will fail.
Please run the verification script in my other comment to confirm that all of exp, exp10, log, log2, log10, tan, cos, sin are emitted with __ in the generated CUDA for float32. If any lack __ support, we can relax the check for that op.
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "tan");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__cos())) {
  CUDAFastMath math_func;
  std::string func_name = math_func(op->dtype, "cos");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::__sin())) {
Use the tan-specific fast-math mapper
We landed CUDAFastMathTan to avoid emitting __tanf (too inaccurate) and to route FP16 through the h* variants, but this branch still instantiates the generic CUDAFastMath, so we end up generating __tanf for float32. That defeats the accuracy fix and contradicts the new helper/comment.
Please switch this site to CUDAFastMathTan.
- } else if (op->op.same_as(tl::__tan())) {
- CUDAFastMath math_func;
+ } else if (op->op.same_as(tl::__tan())) {
+ CUDAFastMathTan math_func;
std::string func_name = math_func(op->dtype, "tan");
  os << func_name << "(" << PrintExpr(op->args[0]) << ")";
🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 1725 to 1732, the branch handling tan
currently instantiates CUDAFastMath which emits __tanf for float32; replace that
instantiation with CUDAFastMathTan so the mapper picks the tan-specific
fast-math variants (and routes FP16 to h* variants). Concretely, change the
local type from CUDAFastMath to CUDAFastMathTan, keep the rest of the call to
math_func(op->dtype, "tan") and the PrintExpr usage unchanged, so the emitted
function name uses the tan-specialized mapper.
print("FAST_MATH=False:")

# Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)

print(f"✓ {mathop_name} compilation and execution test passed")
Fix misleading comments and message
This test compiles with FAST_MATH=False and expects non‑fast‑math calls. Update the comments and message to avoid confusion.
Apply this diff:
- print(f"\n=== Testing {mathop_name} ===")
- print("FAST_MATH=False:")
-
- # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
- check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
-
- print(f"✓ {mathop_name} compilation and execution test passed")
+ print(f"\n=== Testing {mathop_name} ===")
+ print("FAST_MATH=False:")
+ check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
+    print(f"✓ {mathop_name} compilation test passed")
🤖 Prompt for AI Agents
In testing/python/fastmath/test_mathops_fastmath.py around lines 90 to 96, the
inline comment and print message are misleading because this block compiles with
FAST_MATH=False and therefore expects non-fast-math calls; update the comment to
state that we expect non-fast-math (regular) calls when FAST_MATH is disabled,
and change the print message to clearly reflect that the FAST_MATH=False
compilation and non-fast-math check passed (e.g., "✓ {mathop_name} compilation
and non-fast-math execution test passed"). Ensure the comment and message
wording unambiguously reference FAST_MATH=False and non-fast-math expectations.
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
    # Based on test results, our tl.* intrinsics actually generate fastmath versions
    # This appears to be the intended behavior
    single_arg_mathops = [
        ("exp", T.exp),
        ("exp2", T.exp2),
        ("exp10", T.exp10),
        ("log", T.log),
        ("log2", T.log2),
        ("log10", T.log10),
        ("sin", T.sin),
        ("cos", T.cos),
        ("tan", T.tan),
        ("sinh", T.sinh),
        ("cosh", T.cosh),
        ("tanh", T.tanh),
        ("atan", T.atan),
        ("sqrt", T.sqrt),
        ("rsqrt", T.rsqrt),
        ("erf", T.erf),
        ("floor", T.floor),
        ("ceil", T.ceil),
        ("trunc", T.trunc),
        ("round", T.round),
        ("nearbyint", T.nearbyint),
    ]
Align test docstring with behavior
The docstring says “generate fastmath CUDA code” but the test name and assertions check non‑fast‑math when FAST_MATH=False. Update the docstring accordingly.
Apply this diff:
-"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
-# Based on test results, our tl.* intrinsics actually generate fastmath versions
-# This appears to be the intended behavior
+"""Test that tl.* mathops generate non-fastmath CUDA calls when FAST_MATH=False."""
🤖 Prompt for AI Agents
In testing/python/fastmath/test_mathops_fastmath.py around lines 270 to 297, the
test docstring incorrectly states it checks that tl.* mathops "generate fastmath
CUDA code" while the test name and assertions verify non-fast-math behavior when
FAST_MATH=False; update the docstring to reflect that the test ensures tl.*
mathops do NOT generate fastmath CUDA code (i.e., verify non-fast-math
intrinsics are produced) so the docstring matches the test intent and behavior.
"""Calculate 2**x with fast math

Parameters
----------
x : PrimExpr
    Input argument.

Returns
-------
y : PrimExpr
    The result.
"""
Fix the __exp docstring to reflect the natural exponent.
tl.__exp computes the natural exponential (e**x), not 2**x. The current docstring is misleading for anyone using the fast-math wrapper.
-def __exp(x):
- """Calculate 2**x with fast math
+def __exp(x):
+    """Calculate e**x with fast math
🤖 Prompt for AI Agents
In tilelang/language/fastmath.py around lines 124 to 135, the docstring for
__exp incorrectly states it calculates 2**x; update the docstring to state it
computes the natural exponential e**x (or exp(x)), and adjust the
parameter/returns wording if needed to reflect natural exponent semantics (e.g.,
"Calculate e**x with fast math" and keep parameter x : PrimExpr and return y :
PrimExpr). Ensure the description and summary explicitly mention "natural
exponential" or "e**x" to avoid confusion.
…A lowering in fast math operation tests for CUDA. This simplifies the test setup while maintaining the focus on fast math functionality.
Actionable comments posted: 0
🧹 Nitpick comments (2)
testing/python/fastmath/test_mathops_fastmath.py (2)
22-29: Tighten regex to avoid false positives and improve robustness

Current patterns can match substrings inside longer identifiers (e.g., "texp(") and rely on capturing groups unnecessarily. Anchor on non-word boundaries and use non-capturing groups.
-def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
+def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
     """Check source for fastmath/non-fastmath versions"""
-    fastmath_pattern = rf"__({mathop_name}f?)\b"
-    non_fastmath_pattern = rf"(?<!__)({mathop_name}f?)\b"
+    fastmath_pattern = rf"(?<!\w)__(?:{mathop_name}f?)\b"
+    non_fastmath_pattern = rf"(?<![_\w])(?:{mathop_name}f?)\b"
-    fastmath_matches = re.findall(fastmath_pattern, source)
-    non_fastmath_matches = re.findall(non_fastmath_pattern, source)
+    fastmath_matches = re.findall(fastmath_pattern, source)
+    non_fastmath_matches = re.findall(non_fastmath_pattern, source)
244-247: Avoid tan singularities to reduce flakiness

Clamp inputs for tan to keep away from asymptotes.
-    # Ensure positive values for functions that need them
-    if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
+    # Ensure valid/stable domains
+    if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
         a = torch.abs(a) + 0.1
+    if cuda_mathop_name == "tan":
+        a = torch.clamp(a, -1.0, 1.0)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/fastmath/test_mathops_fastmath.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/fastmath/test_mathops_fastmath.py (2)
tilelang/jit/__init__.py (1)
compile (33-86)
tilelang/transform/pass_config.py (1)
PassConfigKey (6-89)
🔇 Additional comments (5)
testing/python/fastmath/test_mathops_fastmath.py (5)
76-94: Add numerical correctness checks for single-arg ops (compare to torch)

This helper only checks codegen. Add a golden comparison for all ops under test.
print(f"\n=== Testing {mathop_name} ===") print("FAST_MATH=False:") - # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf) check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False) - print(f"✓ {mathop_name} compilation and execution test passed") + # Numerical correctness against torch + torch_dtype = getattr(torch, dtype) + a = torch.randn(M, N, device="cuda", dtype=torch_dtype) + # Domain restrictions + if mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]: + a = torch.abs(a) + 0.1 + if mathop_name == "tan": + a = torch.clamp(a, -1.0, 1.0) + b = kernel_no_fastmath(a) + if mathop_name == "exp": + expected = torch.exp(a) + elif mathop_name == "exp2": + expected = torch.pow(2, a) + elif mathop_name == "exp10": + expected = torch.pow(10, a) + elif mathop_name == "log": + expected = torch.log(a) + elif mathop_name == "log2": + expected = torch.log2(a) + elif mathop_name == "log10": + expected = torch.log10(a) + elif mathop_name == "sin": + expected = torch.sin(a) + elif mathop_name == "cos": + expected = torch.cos(a) + elif mathop_name == "tan": + expected = torch.tan(a) + elif mathop_name == "sinh": + expected = torch.sinh(a) + elif mathop_name == "cosh": + expected = torch.cosh(a) + elif mathop_name == "tanh": + expected = torch.tanh(a) + elif mathop_name == "atan": + expected = torch.atan(a) + elif mathop_name == "sqrt": + expected = torch.sqrt(a) + elif mathop_name == "rsqrt": + expected = torch.reciprocal(torch.sqrt(a)) + elif mathop_name == "erf": + expected = torch.erf(a) + elif mathop_name == "floor": + expected = torch.floor(a) + elif mathop_name == "ceil": + expected = torch.ceil(a) + elif mathop_name == "trunc": + expected = torch.trunc(a) + elif mathop_name == "round": + expected = torch.round(a) + elif mathop_name == "nearbyint": + expected = torch.round(a) + else: + expected = b + torch.testing.assert_close(b, expected, rtol=1e-3, atol=1e-3) + print(f"✓ {mathop_name} compilation and numerical test passed")
87-93: Fix misleading comment and message in FAST_MATH=False block

This block validates non-fast-math calls; update the messaging accordingly.
-    print(f"\n=== Testing {mathop_name} ===")
-    print("FAST_MATH=False:")
-
-    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
-    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
-
-    print(f"✓ {mathop_name} compilation and execution test passed")
+    print(f"\n=== Testing {mathop_name} ===")
+    print("FAST_MATH=False:")
+    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
+    print(f"✓ {mathop_name} compilation test passed")
147-165: Compare two-arg ops against torch reference, not just each other

Validate correctness vs a golden reference for both FAST_MATH settings.
-    # Both should produce similar results
-    torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
-    print(f"✓ {mathop_name} numerical test passed")
+    # Golden reference
+    if mathop_name == "pow":
+        expected = torch.pow(a, b)
+    elif mathop_name == "fmod":
+        expected = torch.fmod(a, b)
+    else:
+        expected = c_no_fastmath
+    torch.testing.assert_close(c_no_fastmath, expected, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(c_fastmath, expected, rtol=1e-3, atol=1e-3)
+    print(f"✓ {mathop_name} numerical test passed")
250-256: Add golden references for all fast-math ops

Cover exp10, log2, log10, tan, cos, sin in the comparison.
-    if cuda_mathop_name == "exp":
-        expected = torch.exp(a)
-    elif cuda_mathop_name == "log":
-        expected = torch.log(a)
-    else:
-        expected = b_fastmath  # Just check compilation works
+    if cuda_mathop_name == "exp":
+        expected = torch.exp(a)
+    elif cuda_mathop_name == "exp10":
+        expected = torch.pow(10, a)
+    elif cuda_mathop_name == "log":
+        expected = torch.log(a)
+    elif cuda_mathop_name == "log2":
+        expected = torch.log2(a)
+    elif cuda_mathop_name == "log10":
+        expected = torch.log10(a)
+    elif cuda_mathop_name == "tan":
+        expected = torch.tan(a)
+    elif cuda_mathop_name == "cos":
+        expected = torch.cos(a)
+    elif cuda_mathop_name == "sin":
+        expected = torch.sin(a)
+    else:
+        expected = b_fastmath  # Keep compilation check for unknowns
262-266: Align test docstring with behavior

Docstring claims fast-math generation, but the test checks non-fast-math when FAST_MATH=False.
-@tilelang.testing.requires_cuda
-def test_mathops_generate_no_fastmath():
-    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
-    # Based on test results, our tl.* intrinsics actually generate fastmath versions
-    # This appears to be the intended behavior
+@tilelang.testing.requires_cuda
+def test_mathops_generate_no_fastmath():
+    """Test that tl.* mathops generate non-fastmath CUDA calls when FAST_MATH=False."""
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
    # Based on test results, our tl.* intrinsics actually generate fastmath versions
comments should be fixed.
…tmath versions and disable cache in main execution.
Actionable comments posted: 0
🧹 Nitpick comments (2)
testing/python/fastmath/test_mathops_fastmath.py (2)
243-246: Add domain constraint for trigonometric functions.

Trigonometric functions should have their inputs bounded to reasonable ranges to avoid extreme values that could cause numerical instability.
     # Ensure positive values for functions that need them
     if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
         a = torch.abs(a) + 0.1
+    elif cuda_mathop_name in ["tan", "sin", "cos"]:
+        # Keep trigonometric inputs in reasonable range to avoid extreme values
+        a = torch.clamp(a, -10, 10)
337-338: Parameterize cache control with a fixture or context manager
Calling tilelang.disable_cache() in the test entrypoint applies globally and may leak into other tests. Define a scoped solution (for example, a pytest fixture or context manager that sets TILELANG_DISABLE_CACHE via monkeypatch) to limit cache-disable to the desired tests.
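A hedged sketch of that scoped alternative, assuming the TILELANG_DISABLE_CACHE environment variable referenced in tilelang/env.py is what tilelang.disable_cache() toggles; the fixture name is hypothetical:

import pytest


@pytest.fixture
def tilelang_cache_disabled(monkeypatch):
    # Scoped to the tests that request this fixture; restored automatically afterwards.
    monkeypatch.setenv("TILELANG_DISABLE_CACHE", "1")
    yield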
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
testing/python/fastmath/test_mathops_fastmath.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/fastmath/test_mathops_fastmath.py (3)
tilelang/jit/__init__.py (1)
compile (33-86)
tilelang/transform/pass_config.py (1)
PassConfigKey (6-89)
tilelang/env.py (1)
disable_cache(232-233)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: format-check
- GitHub Check: bot-task
- GitHub Check: format-check
🔇 Additional comments (5)
testing/python/fastmath/test_mathops_fastmath.py (5)
163-163: Consider validating against PyTorch references for two-arg functions.

The test only compares fastmath vs non-fastmath outputs but doesn't verify correctness against a known good implementation.
Add validation against PyTorch:
     # Both should produce similar results
     torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
+
+    # Verify against PyTorch reference
+    if mathop_name == "pow":
+        expected = torch.pow(a, b)
+    elif mathop_name == "fmod":
+        expected = torch.fmod(a, b)
+    else:
+        expected = c_no_fastmath
+
+    torch.testing.assert_close(c_no_fastmath, expected, rtol=1e-5, atol=1e-5)
     print(f"✓ {mathop_name} numerical test passed")
54-94: Enhance test coverage with numerical validation against PyTorch.

The current test only validates the generated CUDA source for the presence/absence of fastmath functions but lacks numerical correctness checks. Add checks against PyTorch reference implementations to ensure functional correctness.
Add numerical validation to make the test more robust:
def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): """ Test single-argument mathops. T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) """ @T.prim_func def main( A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): for i, j in T.Parallel(block_M, block_N): B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i, bx * block_N + j]) # Test with FAST_MATH disabled kernel_no_fastmath = tilelang.compile( main, out_idx=[1], target="cuda", pass_configs={ tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False, }) source_no_fastmath = kernel_no_fastmath.get_kernel_source() print(f"\n=== Testing {mathop_name} ===") print("FAST_MATH=False:") - # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf) check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False) - print(f"✓ {mathop_name} compilation and execution test passed") + # Test numerical correctness + torch_dtype = getattr(torch, dtype) + a = torch.randn(M, N, device="cuda", dtype=torch_dtype) + + # Ensure positive values for functions that need them + if mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]: + a = torch.abs(a) + 0.1 + elif mathop_name in ["atan"]: + a = torch.clamp(a, -10, 10) # Keep in reasonable range + + b = kernel_no_fastmath(a) + + # Map to PyTorch equivalent + torch_func_map = { + "exp": torch.exp, + "exp2": lambda x: torch.pow(2, x), + "exp10": lambda x: torch.pow(10, x), + "log": torch.log, + "log2": torch.log2, + "log10": torch.log10, + "sin": torch.sin, + "cos": torch.cos, + "tan": torch.tan, + "sinh": torch.sinh, + "cosh": torch.cosh, + "tanh": torch.tanh, + "atan": torch.atan, + "sqrt": torch.sqrt, + "rsqrt": torch.rsqrt, + "erf": torch.erf, + "floor": torch.floor, + "ceil": torch.ceil, + "trunc": torch.trunc, + "round": torch.round, + "nearbyint": torch.round, # nearbyint is similar to round + } + + if mathop_name in torch_func_map: + expected = torch_func_map[mathop_name](a) + torch.testing.assert_close(b, expected, rtol=1e-5, atol=1e-5) + print(f"✓ {mathop_name} numerical test passed") + + print(f"✓ {mathop_name} compilation and execution test passed")
251-256: Expand numerical validation for all fastmath functions.

The test only validates exp and log against PyTorch references. Extend validation to cover all fastmath functions.

Apply this diff to validate all fastmath functions:
     # Compare with reference implementation
     if cuda_mathop_name == "exp":
         expected = torch.exp(a)
+    elif cuda_mathop_name == "exp10":
+        expected = torch.pow(10, a)
     elif cuda_mathop_name == "log":
         expected = torch.log(a)
+    elif cuda_mathop_name == "log2":
+        expected = torch.log2(a)
+    elif cuda_mathop_name == "log10":
+        expected = torch.log10(a)
+    elif cuda_mathop_name == "tan":
+        expected = torch.tan(a)
+    elif cuda_mathop_name == "cos":
+        expected = torch.cos(a)
+    elif cuda_mathop_name == "sin":
+        expected = torch.sin(a)
     else:
-        expected = b_fastmath  # Just check compilation works
+        print(f"Warning: No reference implementation for {cuda_mathop_name}, skipping numerical check")
+        return
90-91: Remove misleading comment.

The comment is inconsistent with the test's behavior: when TL_ENABLE_FAST_MATH is False, the test expects non-fastmath functions.

-    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
     check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
264-267: Fix incorrect docstring.

The docstring states the test checks for fastmath code generation, but the test actually verifies non-fastmath behavior when TL_ENABLE_FAST_MATH is False.

-    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
-    # Based on test results, our tl.* intrinsics actually generate
-    # no fastmath versions
-    # This appears to be the intended behavior
+    """Test that tl.* mathops generate non-fastmath CUDA code when TL_ENABLE_FAST_MATH is False."""
This commit introduces a new Python script and CUDA source file for a precision comparison tool that evaluates the accuracy of various CUDA operations (including division, reciprocal, exponential, logarithmic, and trigonometric functions) across different implementations: CUDA Precise, CUDA Fast, Triton, Triton LibDevice, and TileLang. The tool generates test data, executes the operations, and summarizes the error statistics for each implementation against a double precision reference. Additionally, a README file is added to document the results of the comparisons for various operations.
Actionable comments posted: 4
🧹 Nitpick comments (6)
maint/precision/README.md (1)
1-11: Document dataset, ranges, and seed used to compute these errors

Readers need the sample size, value domains per op, and RNG seed to reproduce these tables. Add a short preamble summarizing n, ranges for each op (e.g., positive-only for log/sqrt/rsqrt), device/arch, and the reference ("double CUDA precise") used.
Apply this minimal preface above the first section:
+Precision report methodology
+---------------------------
+- Device/arch: CUDA GPU (set via TORCH_CUDA_ARCH_LIST); kernels launched with 256 threads
+- Sample size: n = 1,000,000 unless stated otherwise
+- Value ranges:
+  - div: x ~ U[-4, 4], y ~ U[1e-3, 4]
+  - log/sqrt/rsqrt/inv_sqrt: x ~ U[1e-3, 4]
+  - others: x ~ U[-4, 4]
+- Seed: 0
+- Reference: double-precision CUDA ("Double Precise") vs each implementation

maint/precision/cuda_ops.cu (2)
124-145: Add default in switch to guard invalid op_type

A defensive default case avoids leaving r uninitialized if an invalid op_type slips through.
Apply to all three kernels:
 switch (op_type) {
   case OP_DIV: r = precise_div(a, b); break;
   case OP_RECIPROCAL: r = precise_reciprocal(a); break;
   case OP_EXP: r = precise_exp(a); break;
   case OP_LOG: r = precise_log(a); break;
   case OP_SIN: r = precise_sin(a); break;
   case OP_COS: r = precise_cos(a); break;
   case OP_SQRT: r = precise_sqrt(a); break;
   case OP_TANH: r = precise_tanh(a); break;
   case OP_RSQRT: r = precise_rsqrt(a); break;
   case OP_INV_SQRT: r = precise_inv_sqrt(a); break;
+  default: r = a; break;
 }

(Repeat the analogous default in double_precise_operator_kernel and fast_operator_kernel.)
Also applies to: 147-168, 170-191
83-122: Consider using CUDA intrinsics over inline PTX for portability

Inline PTX (rcp.approx.f32/sqrt.approx.f32/rsqrt.approx.f32) can be brittle across toolchains/arch; CUDA fast-math intrinsics (__frcp_rn not available, but __fdividef, __fsqrt_rn, etc.) or libdevice may be preferable unless you need explicit opcodes.
Would you like a variant using only CUDA intrinsics to reduce PTX coupling?
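A hedged sketch of what that intrinsic-only variant could look like; __fdividef, __fsqrt_rn and rsqrtf are standard CUDA device functions, while the fast_* wrapper names only mirror the style of cuda_ops.cu and are illustrative:

// Intrinsic-only fast paths, avoiding inline PTX.
__device__ __forceinline__ float fast_reciprocal(float a) {
  return __fdividef(1.0f, a);  // fast approximate divide
}

__device__ __forceinline__ float fast_sqrt(float a) {
  return __fsqrt_rn(a);        // round-to-nearest sqrt intrinsic
}

__device__ __forceinline__ float fast_rsqrt(float a) {
  return rsqrtf(a);            // reciprocal square root
}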
maint/precision/compare_ops.py (3)
58-69: Pin compile flags to avoid ambient fast-math affecting the "precise" path

You note "No fast_math flags"; good. Consider explicitly passing conservative flags (e.g., disable reassociation/FTZ/DAZ) if you need strict IEEE behavior for "precise" baselines.
I can propose vetted flags for nvcc/clang if needed.
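For illustration, a hedged example of such flags, assuming the CUDA extension is built via torch.utils.cpp_extension.load (the actual build path used by compare_ops.py is not shown here); the nvcc options keep division/sqrt IEEE-compliant and disable FTZ and FMA contraction:

from torch.utils.cpp_extension import load

cuda_ops = load(
    name="cuda_ops_precise",  # hypothetical module name
    sources=["cuda_ops.cu"],
    extra_cuda_cflags=[
        "--prec-div=true",
        "--prec-sqrt=true",
        "--ftz=false",
        "--fmad=false",
    ],
)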
49-56: Shebang without executable bit (EXE001)

Either remove the shebang or set the script executable in repo metadata.
175-210: Drop unused parameter use_fastmath in make_tilelang_unary_kernel

The parameter isn't used inside the nested kernel; it's only used in compile pass_configs outside. Removing it avoids ARG001.
-def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool = False):
+def make_tilelang_unary_kernel(M: int, N: int, op_id: int):
@@
-    def tilelang_unary_kernel(
+    def tilelang_unary_kernel(
         A: T.Tensor((M, N), "float32"),
         B: T.Tensor((M, N), "float32"),
     ):
@@
-    return tilelang_unary_kernel
+    return tilelang_unary_kernel

And update the call site:
-    kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath)
+    kernel_func = make_tilelang_unary_kernel(M, N, op_id)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
maint/precision/README.md (1 hunks)
maint/precision/compare_ops.py (1 hunks)
maint/precision/cuda_ops.cu (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
maint/precision/compare_ops.py (5)
tilelang/env.py (2)
disable_cache (232-233)
get (136-139)
tilelang/utils/target.py (1)
determine_target (44-84)
tilelang/contrib/nvcc.py (1)
get_target_compute_version (258-299)
tilelang/transform/pass_config.py (1)
PassConfigKey (6-89)
maint/precision/cuda_ops.cu (6)
launch_double_precise_operator (208-219)
launch_double_precise_operator (208-208)
launch_precise_operator (194-205)
launch_precise_operator (194-194)
launch_fast_operator (222-233)
launch_fast_operator (222-222)
🪛 Ruff (0.13.1)
maint/precision/compare_ops.py
1-1: Shebang is present but file is not executable
(EXE001)
175-175: Unused function argument: use_fastmath
(ARG001)
272-272: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
286-286: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
320-320: Avoid specifying long messages outside the exception class
(TRY003)
373-373: Do not catch blind exception: Exception
(BLE001)
423-423: Do not catch blind exception: Exception
(BLE001)
🪛 GitHub Actions: CI Test on AMD
maint/precision/compare_ops.py
[error] 297-297: E731 Do not assign a lambda expression, use a def
[error] 314-314: E731 Do not assign a lambda expression, use a def
[error] 438-438: SIM114 Combine if branches using logical or operator
[error] 441-441: SIM114 Combine if branches using logical or operator
🪛 GitHub Actions: CI
maint/precision/compare_ops.py
[error] 297-297: E731 Do not assign a lambda expression, use a def
[error] 314-314: E731 Do not assign a lambda expression, use a def
[error] 438-441: SIM114 Combine if branches using logical or operator
[error] 441-447: SIM114 Combine if branches using logical or operator
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (3)
maint/precision/compare_ops.py (3)
228-266: Possible OOB in TileLang kernels at edges; add bounds guard or confirm safe memory pass

row/col are computed from ceildiv grids without explicit guards. Unless TileLang inserts masking by default, edge tiles may access OOB.
Can you confirm T.Parallel + current pass configs guarantee safe bounds? If not, add a guard:
-    x = A[row, col]
+    if row < M and col < N:
+        x = A[row, col]
+        ...
+    # else: implicitly skip

As per pass_config keys, TL_DISABLE_SAFE_MEMORY_ACCESS defaults to False, but please verify generated TIR. Based on learnings
20-28: Be resilient to target parsing failures

determine_target/parse_compute_version may vary; guard and fall back to a sensible arch list (or skip setting TORCH_CUDA_ARCH_LIST).
If needed, I can add a try/except around these calls with a fallback to native arch detection via torch.cuda.get_device_capability().
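A hedged sketch of that fallback; torch.cuda.get_device_capability is a real PyTorch API, while how compare_ops.py currently sets the variable is assumed from the comment above:

import os
import torch

try:
    major, minor = torch.cuda.get_device_capability()
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", f"{major}.{minor}")
except Exception:
    pass  # leave the variable unset and let the build pick its default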
151-169: Verify Triton libdevice symbol names

libdevice bindings can differ by Triton version (e.g., sqrt vs sqrt_rn). Please confirm these exact names exist in your pinned Triton. If not, map to available ones.
I can generate a quick probe script to introspect triton.language.extra.libdevice for available symbols.
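A minimal probe along those lines (assumes Triton is importable in the target environment):

import triton.language.extra.libdevice as libdevice

# Print the wrappers this Triton build actually exposes so the bindings
# used in compare_ops.py can be checked before running the comparison.
print(sorted(name for name in dir(libdevice) if not name.startswith("_")))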
grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

if op_id == 0:  # Division - binary operation
    assert y is not None, "Division operation requires second operand"
    triton_binary_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=TRITON_BLOCK_SIZE)
else:  # Unary operation
    triton_unary_kernel[grid](x, out, x.numel(), op_id, BLOCK_SIZE=TRITON_BLOCK_SIZE)
🛠️ Refactor suggestion | 🟠 Major
Fix E731: replace lambda with def for Triton grid
CI fails with E731. Replace assigned lambda with a def.
- grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+ def grid(meta):
+    return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

🧰 Tools
🪛 Ruff (0.13.1)
272-272: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
| grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],) | ||
|
|
||
| if op_id == 0: # Division - binary operation | ||
| assert y is not None, "Division operation requires second operand" | ||
| triton_libdevice_binary_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=TRITON_BLOCK_SIZE) | ||
| else: # Unary operation | ||
| triton_libdevice_unary_kernel[grid](x, out, x.numel(), op_id, BLOCK_SIZE=TRITON_BLOCK_SIZE) | ||
|
|
🛠️ Refactor suggestion | 🟠 Major
Fix E731: replace lambda with def for LibDevice grid
Same as above.
- grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+ def grid(meta):
+    return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
🧰 Tools
🪛 Ruff (0.13.1)
286-286: Do not assign a lambda expression, use a def
Rewrite grid as a def
(E731)
🤖 Prompt for AI Agents
In maint/precision/compare_ops.py around lines 286 to 293, replace the assigned
lambda for grid with a proper function definition to satisfy E731: define def
grid(meta): return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
(keeping the same return shape), then leave the subsequent
triton_libdevice_binary_kernel[grid](...) and
triton_libdevice_unary_kernel[grid](...) calls unchanged so they continue to use
the callable grid.
    abs_err = (output_double - reference_double).abs()
    rel_err = abs_err / (reference_double.abs().clamp_min(1e-30))
    print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
          f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}")
Bug: formatting torch.Tensors with format spec causes TypeError on CUDA tensors
abs_err.max()/mean() et al. are 0-d tensors; formatting with “:.3e” raises on CUDA. Convert to Python floats.
- print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
- f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}")
+ max_abs = abs_err.max().item()
+ mean_abs = abs_err.mean().item()
+ max_rel = rel_err.max().item()
+ mean_rel = rel_err.mean().item()
+ print(f"{tag:<32} max abs: {max_abs:.3e}, mean abs: {mean_abs:.3e}, "
+ f"max rel: {max_rel:.3e}, mean rel: {mean_rel:.3e}")📝 Committable suggestion
🤖 Prompt for AI Agents
In maint/precision/compare_ops.py around lines 332 to 336, the formatted
f-string is passing 0-d torch.Tensors (e.g. abs_err.max(), abs_err.mean(),
rel_err.max(), rel_err.mean()) into a float format spec which raises a TypeError
for CUDA tensors; fix by converting those 0-d tensors to Python floats (use
.item()) before formatting so the f-string receives plain floats (e.g. replace
abs_err.max() with abs_err.max().item(), and likewise for mean() and the rel_err
values).
// Precise version
void launch_precise_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
    int64_t n = x.numel();
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    const float* y_ptr = nullptr;
    if (y.has_value()) {
        y_ptr = y.value().data_ptr<float>();
    }
    precise_operator_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(), y_ptr, result.data_ptr<float>(), n, static_cast<OperatorType>(op_type)
    );
}
Validate dtype, contiguity, device, and shape in launchers
Raw pointer indexing assumes contiguous memory and matching dtypes; current code will silently miscompute on strided inputs or wrong dtypes. Add runtime checks to fail fast and prevent OOB/UB.
Apply these diffs to each launcher:
void launch_precise_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
+ TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
+ TORCH_CHECK(result.is_cuda(), "result must be a CUDA tensor");
+ TORCH_CHECK(x.scalar_type() == at::kFloat && result.scalar_type() == at::kFloat,
+ "precise path expects float32 tensors");
+ TORCH_CHECK(x.is_contiguous() && result.is_contiguous(),
+ "x and result must be contiguous");
+ if (y.has_value()) {
+ TORCH_CHECK(y.value().is_cuda(), "y must be a CUDA tensor");
+ TORCH_CHECK(y.value().scalar_type() == at::kFloat, "y must be float32");
+ TORCH_CHECK(y.value().is_contiguous(), "y must be contiguous");
+ TORCH_CHECK(y.value().numel() == x.numel(), "x and y must have same numel");
+ }
    int64_t n = x.numel();

void launch_double_precise_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
+ TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
+ TORCH_CHECK(result.is_cuda(), "result must be a CUDA tensor");
+ TORCH_CHECK(x.scalar_type() == at::kDouble && result.scalar_type() == at::kDouble,
+ "double-precise path expects float64 tensors");
+ TORCH_CHECK(x.is_contiguous() && result.is_contiguous(),
+ "x and result must be contiguous");
+ if (y.has_value()) {
+ TORCH_CHECK(y.value().is_cuda(), "y must be a CUDA tensor");
+ TORCH_CHECK(y.value().scalar_type() == at::kDouble, "y must be float64");
+ TORCH_CHECK(y.value().is_contiguous(), "y must be contiguous");
+ TORCH_CHECK(y.value().numel() == x.numel(), "x and y must have same numel");
+ }
    int64_t n = x.numel();

void launch_fast_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
+ TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
+ TORCH_CHECK(result.is_cuda(), "result must be a CUDA tensor");
+ TORCH_CHECK(x.scalar_type() == at::kFloat && result.scalar_type() == at::kFloat,
+ "fast path expects float32 tensors");
+ TORCH_CHECK(x.is_contiguous() && result.is_contiguous(),
+ "x and result must be contiguous");
+ if (y.has_value()) {
+ TORCH_CHECK(y.value().is_cuda(), "y must be a CUDA tensor");
+ TORCH_CHECK(y.value().scalar_type() == at::kFloat, "y must be float32");
+ TORCH_CHECK(y.value().is_contiguous(), "y must be contiguous");
+ TORCH_CHECK(y.value().numel() == x.numel(), "x and y must have same numel");
+ }
    int64_t n = x.numel();

Also applies to: 207-219, 221-233
🤖 Prompt for AI Agents
In maint/precision/cuda_ops.cu around lines 193-205 (and similarly for 207-219
and 221-233): add runtime validation before launching the kernel to ensure x and
result are CUDA tensors with scalar_type float32, contiguous, and have the same
numel; if y is provided ensure it is defined, on the same CUDA device,
contiguous, float32, and has numel equal to x.numel(); use TORCH_CHECK (or
AT_ASSERTM) to fail fast with clear messages (e.g. check x.is_cuda(),
result.is_cuda(), x.scalar_type()==at::kFloat, x.is_contiguous(),
result.is_contiguous(), x.numel()==result.numel(), and if (y.has_value())
validate same device, scalar_type, contiguity, and numel) before calling
precise_operator_kernel.
This commit introduces a new precision comparison tool implemented in Python and CUDA, designed to evaluate the accuracy of various mathematical operations (division, reciprocal, exponential, logarithmic, trigonometric, square root, etc.) across different frameworks including CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang. The tool includes functionality for generating test data, executing operations, and summarizing error statistics for each implementation. Additionally, it provides a comprehensive README with error metrics for each operation tested.
Actionable comments posted: 3
🧹 Nitpick comments (1)
maint/precision/compare_ops.py (1)
76-83: Seed CUDA RNGs for reproducibility (optional).
torch.manual_seed may not cover all CUDA RNG cases. Also seed CUDA. Apply this diff:

 torch.manual_seed(args.seed)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed_all(args.seed)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
maint/precision/compare_ops.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
maint/precision/compare_ops.py (6)
tilelang/env.py (2)
disable_cache (232-233), get (136-139)
tilelang/utils/target.py (1)
determine_target (44-84)
tilelang/contrib/nvcc.py (1)
get_target_compute_version (258-299)
tilelang/jit/__init__.py (2)
jit (237-310), compile (33-86)
tilelang/transform/pass_config.py (1)
PassConfigKey (6-89)
maint/precision/cuda_ops.cu (6)
launch_double_precise_operator (208-219), launch_double_precise_operator (208-208), launch_precise_operator (194-205), launch_precise_operator (194-194), launch_fast_operator (222-233), launch_fast_operator (222-222)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: bot-task
🔇 Additional comments (4)
maint/precision/compare_ops.py (4)
298-305: Replace assigned lambda with def to satisfy Ruff E731 (duplicate of prior review). CI will fail with E731. Use a named function.

Apply this diff:

-    grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+    def grid(meta):
+        return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
315-322: Replace assigned lambda with def to satisfy Ruff E731 (duplicate of prior review). Same fix for the LibDevice path.

Apply this diff:

-    grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+    def grid(meta):
+        return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
365-369: Convert 0-d tensors to Python floats before formatting (duplicate of prior review). Formatting CUDA 0-d tensors with ":.3e" raises TypeError. Use .item().

Apply this diff:

-    print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
-          f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}")
+    max_abs = abs_err.max().item()
+    mean_abs = abs_err.mean().item()
+    max_rel = rel_err.max().item()
+    mean_rel = rel_err.mean().item()
+    print(f"{tag:<32} max abs: {max_abs:.3e}, mean abs: {mean_abs:.3e}, "
+          f"max rel: {max_rel:.3e}, mean rel: {mean_rel:.3e}")
20-20: Verify top-level export of disable_cache

Cannot locate disable_cache in the local tilelang package; ensure it's exposed in tilelang/__init__.py or import it directly from tilelang.env.
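A tiny probe sketch for this check (illustrative only; it assumes the module layout from the code-graph listing above and simply reports where the symbol is exposed):

import importlib

for module_name in ("tilelang", "tilelang.env"):
    module = importlib.import_module(module_name)
    found = hasattr(module, "disable_cache")
    print(f"{module_name}: disable_cache {'found' if found else 'not found'}")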
    return load(
        name="cuda_ops",
        sources=["cuda_ops.cu"],
        extra_cuda_cflags=[]  # No fast_math flags
    )
Fix source path for CUDA extension (will fail when run from repo root).
load(..., sources=["cuda_ops.cu"]) assumes CWD == script dir. When run as python maint/precision/compare_ops.py, this file won’t be found. Use an absolute path derived from __file__.
Apply this diff:
- return load(
- name="cuda_ops",
- sources=["cuda_ops.cu"],
- extra_cuda_cflags=[] # No fast_math flags
- )
+ src_dir = os.path.dirname(os.path.abspath(__file__))
+ cu_path = os.path.join(src_dir, "cuda_ops.cu")
+ return load(
+ name="cuda_ops",
+ sources=[cu_path],
+ extra_cuda_cflags=[], # No fast_math flags
+     )

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In maint/precision/compare_ops.py around lines 69 to 73, the CUDA extension load
uses a relative source path ("cuda_ops.cu") which will fail when the script is
run from the repo root; change to compute the absolute path using the script's
directory (e.g., derive dirname = os.path.dirname(__file__) and join with
"cuda_ops.cu") and pass that absolute path in the sources list to load, ensuring
it works regardless of current working directory.
for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
    row = by * TILELANG_BLOCK_M + i
    col = bx * TILELANG_BLOCK_N + j
    x = A[row, col]

    if op_id == 1:  # reciprocal
        B[row, col] = 1.0 / x
    elif op_id == 2:  # exp
        B[row, col] = T.exp(x)
    elif op_id == 3:  # log
        B[row, col] = T.log(x)
    elif op_id == 4:  # sin
        B[row, col] = T.sin(x)
    elif op_id == 5:  # cos
        B[row, col] = T.cos(x)
    elif op_id == 6:  # sqrt
        B[row, col] = T.sqrt(x)
    elif op_id == 7:  # tanh
        B[row, col] = T.tanh(x)
    elif op_id == 8:  # rsqrt
        B[row, col] = T.rsqrt(x)
    elif op_id == 9:  # inv_sqrt
        B[row, col] = 1.0 / T.sqrt(x)
    else:
        B[row, col] = x  # Default case
Out-of-bounds accesses in TileLang unary kernel (missing bounds guards).
row and col can exceed M/N on the edge tiles (ceildiv used). This reads/writes OOB.
Apply this diff to guard loads/stores:
-        for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
+        for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
             row = by * TILELANG_BLOCK_M + i
             col = bx * TILELANG_BLOCK_N + j
-            x = A[row, col]
+            if (row < M) and (col < N):
+                x = A[row, col]
-            if op_id == 1:  # reciprocal
-                B[row, col] = 1.0 / x
-            elif op_id == 2:  # exp
-                B[row, col] = T.exp(x)
-            elif op_id == 3:  # log
-                B[row, col] = T.log(x)
-            elif op_id == 4:  # sin
-                B[row, col] = T.sin(x)
-            elif op_id == 5:  # cos
-                B[row, col] = T.cos(x)
-            elif op_id == 6:  # sqrt
-                B[row, col] = T.sqrt(x)
-            elif op_id == 7:  # tanh
-                B[row, col] = T.tanh(x)
-            elif op_id == 8:  # rsqrt
-                B[row, col] = T.rsqrt(x)
-            elif op_id == 9:  # inv_sqrt
-                B[row, col] = 1.0 / T.sqrt(x)
-            else:
-                B[row, col] = x  # Default case
+                if op_id == 1:  # reciprocal
+                    B[row, col] = 1.0 / x
+                elif op_id == 2:  # exp
+                    B[row, col] = T.exp(x)
+                elif op_id == 3:  # log
+                    B[row, col] = T.log(x)
+                elif op_id == 4:  # sin
+                    B[row, col] = T.sin(x)
+                elif op_id == 5:  # cos
+                    B[row, col] = T.cos(x)
+                elif op_id == 6:  # sqrt
+                    B[row, col] = T.sqrt(x)
+                elif op_id == 7:  # tanh
+                    B[row, col] = T.tanh(x)
+                elif op_id == 8:  # rsqrt
+                    B[row, col] = T.rsqrt(x)
+                elif op_id == 9:  # inv_sqrt
+                    B[row, col] = 1.0 / T.sqrt(x)
+                else:
+                    B[row, col] = x  # Default case
🤖 Prompt for AI Agents
In maint/precision/compare_ops.py around lines 200 to 224, the TileLang kernel
computes row/col using ceildiv tiles and may read/write past M/N on edge tiles;
add explicit bounds guards before any load or store: check if row < M and col <
N (or equivalent shape variables) then load x = A[row,col] and perform the op
and write B[row,col]; otherwise skip the operation (or write a safe/default
value) so no OOB access occurs. Ensure the guard wraps both the read and the
write for every op_id branch so the default branch is also protected.
for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
    row = by * TILELANG_BLOCK_M + i
    col = bx * TILELANG_BLOCK_N + j
    x = A[row, col]
    y = B[row, col]
    C[row, col] = x / y  # Division operation
Out-of-bounds accesses in TileLang binary kernel (missing bounds guards).
Same issue for the division kernel.
Apply this diff:
-        for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
+        for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
             row = by * TILELANG_BLOCK_M + i
             col = bx * TILELANG_BLOCK_N + j
-            x = A[row, col]
-            y = B[row, col]
-            C[row, col] = x / y  # Division operation
+            if (row < M) and (col < N):
+                x = A[row, col]
+                y = B[row, col]
+                C[row, col] = x / y  # Division operation
…plicit fastmath op to invoke (tile-ai#875)

* Add fast math operations for CUDA: exp, exp10, log, log2, log10, tan, cos, and sin (tile-ai#865)
* Refactor fast math operation definitions for consistency and readability in CUDA code. Consolidated multiple definitions into single lines and improved formatting in related test files for better clarity.
* Remove unnecessary pass configurations for warp specialization and TMA lowering in fast math operation tests for CUDA. This simplifies the test setup while maintaining the focus on fast math functionality.
* Update fastmath tests to reflect that tl.* intrinsics generate no fastmath versions and disable cache in main execution.
* Fix formatting in fastmath test comments for clarity on tl.* intrinsics behavior.
* Add precision comparison tool for CUDA operations. This commit introduces a new Python script and CUDA source file for a precision comparison tool that evaluates the accuracy of various CUDA operations (including division, reciprocal, exponential, logarithmic, and trigonometric functions) across different implementations: CUDA Precise, CUDA Fast, Triton, Triton LibDevice, and TileLang. The tool generates test data, executes the operations, and summarizes the error statistics for each implementation against a double precision reference. Additionally, a README file is added to document the results of the comparisons for various operations.
* Add precision comparison tool for CUDA operations. This commit introduces a new precision comparison tool implemented in Python and CUDA, designed to evaluate the accuracy of various mathematical operations (division, reciprocal, exponential, logarithmic, trigonometric, square root, etc.) across different frameworks including CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang. The tool includes functionality for generating test data, executing operations, and summarizing error statistics for each implementation. Additionally, it provides a comprehensive README with error metrics for each operation tested.
This PR addresses issue #839, in which TVM's intrinsic rules made certain math intrinsics dispatch to the fast-math versions of operators by default.
After some consideration, we decided not to expose __exp at the base operator level and to keep the exp definition in TVM. The idea is that __exp is a CUDA-specific intrinsic, whereas exp is a standard mathematical operation available on almost all hardware. Therefore, exp should remain defined at the fundamental level, making it easier to adapt to different operators across backends.

Once this PR is merged, calls such as T.exp, T.sin, etc. will be code-generated into expf/sinf, the corresponding-precision math functions. Fast-math intrinsics such as __exp will still be accessible within TileLang.

Example usage:
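A minimal usage sketch, assuming the usual TileLang kernel scaffolding (T.prim_func, T.Kernel, T.Tensor) and illustrative shapes; the fast-math wrapper T.__exp is the explicit intrinsic added here, while T.exp keeps the precise path:

import tilelang.language as T


def elementwise_exp(M=1024, N=1024, block_M=64, block_N=64, dtype="float32"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                row = by * block_M + i
                col = bx * block_N + j
                # T.exp(A[row, col]) lowers to the precise expf;
                # T.__exp explicitly requests the CUDA fast-math __expf path.
                B[row, col] = T.__exp(A[row, col])

    return main

Compiling and launching the kernel then follows the same flow as the existing fastmath tests.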
We also provide a test for this component.
Summary by CodeRabbit
New Features
Tests
Documentation
Chores