
Conversation

@LeiWang1999
Member

@LeiWang1999 LeiWang1999 commented Sep 25, 2025

This PR addresses issue #839. It makes certain math intrinsics use the fast‑math versions of operators through TVM’s intrinsic rules.

After some consideration, we decided not to expose __exp at the base operator level and to keep the definition of exp in TVM. The idea is that __exp is a CUDA-specific intrinsic, whereas exp is a standard mathematical operation available on almost all hardware. Therefore, exp should remain defined at the fundamental level, making it easier to map to the appropriate operators across backends.

Once this PR is merged, calls like T.exp, T.sin, etc. will be code-generated into expf / sinf (or the math function for the corresponding precision). Fast-math intrinsics such as __exp will still be accessible internally in TileLang.

Example usage:

# Call the explicit fast-math intrinsic
T.__exp(x)

# Or enable globally via pass configs
pass_configs = {
    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True
}

T.exp  # Will use fast-math version if enabled

We also provide a test for this component.
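
For reference, here is a minimal end-to-end sketch of the global toggle, modeled on the test code in this PR (the kernel shape and the TL_ENABLE_FAST_MATH key are taken from the tests; treat the exact names as assumptions if your version differs):

import tilelang
import tilelang.language as T

M = N = 128

@T.prim_func
def main(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")):
    with T.Kernel(T.ceildiv(N, 32), T.ceildiv(M, 32), threads=128) as (bx, by):
        for i, j in T.Parallel(32, 32):
            # T.exp lowers to expf by default; with fast math enabled it may lower to __expf.
            B[by * 32 + i, bx * 32 + j] = T.exp(A[by * 32 + i, bx * 32 + j])

kernel = tilelang.compile(
    main,
    out_idx=[1],
    target="cuda",
    pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True},
)
print("__expf" in kernel.get_kernel_source())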

Summary by CodeRabbit

  • New Features

    • Fast-math ops added (exp, exp10, log, log2, log10, sin, cos, tan) with language-level fastmath wrappers and CUDA codegen emitting optimized fast-math intrinsics; new CUDA math dispatch for type-specific variants.
    • New CUDA operator backend offering precise, double-precision, and fast execution modes.
  • Tests

    • Added CUDA-targeted tests validating fast-math vs non-fast-math codegen and numeric correctness.
  • Documentation

    • Added comprehensive precision/accuracy comparison report.
  • Chores

    • Updated third-party submodule reference.

…ity in CUDA code. Consolidated multiple definitions into single lines and improved formatting in related test files for better clarity.
@coderabbitai
Contributor

coderabbitai bot commented Sep 25, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds eight fast-math TL intrinsics and Python wrappers, lowers them in CUDA codegen via new math dispatchers, adds CUDA-gated tests and precision-comparison tooling (including a CUDA operator backend), re-exports fastmath in the language package, and updates the TVM submodule reference.

Changes

Cohort | File(s) | Summary
Submodule update | 3rdparty/tvm | Updated submodule commit from 0524f76... to 883e96b....
TL builtins (C++) | src/op/builtin.cc, src/op/builtin.h | Added TL builtin declarations and registrations for __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin (each 1 input, TCallEffectKind = kOpaque).
CUDA codegen mappings (C++) | src/target/codegen_cuda.cc | Added CUDAMath, CUDAFastMath, CUDAFastMathTan dispatchers and lowering for the new intrinsics to emit CUDA math/fast-math symbols depending on dtype/width.
Python fastmath API | tilelang/language/fastmath.py, tilelang/language/__init__.py | New fastmath wrappers that call TL intrinsics via tir.call_intrin; re-exported via tilelang.language.__init__.
CUDA fast-math tests (Python) | testing/python/fastmath/test_mathops_fastmath.py | New CUDA-gated tests that inspect generated CUDA source for fast-math vs non-fast-math variants and run numeric comparisons (single/two-arg ops, abs mapping, explicit fastmath intrinsics).
Precision comparison tooling (Python + docs) | maint/precision/compare_ops.py, maint/precision/README.md | New precision comparison script across backends with data generation, multi-backend execution, error summarization; README updated with per-op error tables.
CUDA operator backend (C++)/bindings | maint/precision/cuda_ops.cu | New CUDA device implementations (precise/double/fast), kernels, launch wrappers, and pybind11 bindings exposing launches for multiple math ops with an enum-based op selector.
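
For orientation, the new language-level wrapper presumably looks roughly like the sketch below; the exact call_intrin signature and the tl.__exp registration name are assumptions inferred from the snippets quoted later in this review, not the actual file contents:

from tvm import tir

def __exp(x):
    """Calculate e**x with fast math.

    Parameters
    ----------
    x : PrimExpr
        Input argument.

    Returns
    -------
    y : PrimExpr
        The result.
    """
    # Assumed dispatch: call the tl.__exp TL builtin, which CUDA codegen
    # lowers to __expf for float32 inputs.
    return tir.call_intrin(x.dtype, "tl.__exp", x)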

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Py as tilelang.fastmath
  participant TIR as TIR CallNode (tl.__*)
  participant CG as CUDA Codegen
  participant CUDA as CUDA Math Runtime

  User->>Py: call __exp(x)
  Py->>TIR: tir.call_intrin(tl.__exp, x)
  TIR->>CG: Lower CallNode(tl.__exp)
  rect #E6F5FF
    note right of CG: select dispatcher based on dtype\n(CUDAFastMath / CUDAMath / CUDAFastMathTan)
    CG->>CUDA: emit mapped symbol (e.g., expf, __expf, exp)
  end
  CUDA-->>User: compute result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • tzj-fxz

Poem

I twitch my ears at functions new,
__exp and friends now hop on cue.
From TL burrow to CUDA light,
fast-math scampers, kernels bright.
A rabbit cheers — bytes take flight! 🐇

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check | ✅ Passed | The title succinctly describes the primary change of disabling the default fastmath intrinsic dispatch and introducing an explicit fastmath operator invocation, which aligns directly with the PR’s goal of modifying math intrinsic handling when fast-math is enabled.
Docstring Coverage | ✅ Passed | Docstring coverage is 84.21%, which is sufficient. The required threshold is 80.00%.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@gemini-code-assist
Contributor

Summary of Changes

Hello @LeiWang1999, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refines the handling of floating-point mathematical operations within TVM's TileLang, particularly concerning fast-math intrinsics. It transitions from an implicit fast-math approach to an explicit one, providing greater control over numerical precision and performance. Users can now explicitly invoke fast-math versions of functions or enable them globally, while standard mathematical functions like 'exp' will default to their non-fast-math, higher-precision counterparts, ensuring broader compatibility and predictable behavior.

Highlights

  • Explicit Fast Math Control: Default fast-math intrinsic dispatch is disabled, requiring explicit opt-in for fast-math versions of operators.
  • Standard 'exp' Behavior: The 'exp' function remains a standard mathematical operation, generating 'expf'/'sinf' (standard precision) by default, rather than a CUDA-specific '__exp'.
  • New Fast Math Intrinsics: Explicit fast-math intrinsics like 'T.__exp()' are introduced and accessible within TileLang.
  • Global Fast Math Toggle: Fast math can be enabled globally via 'pass_configs' ('tilelang.PassConfigKey.ENABLE_FAST_MATH').
  • Comprehensive Testing: A new test suite ('test_mathops_fastmath.py') has been added to verify the correct generation and numerical behavior of both standard and fast-math operations.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@LeiWang1999 LeiWang1999 changed the title [Fast math] Disable default tvm fastmath intrin dispatch and expose explicit fastmath [FastMath] Disable default TVM intrinsic dispatch and add explicit fastmath op to invoke Sep 25, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces explicit fast-math intrinsics and disables the default fast-math dispatch for standard math functions to give users more control. My review focuses on improving code maintainability by reducing duplication in the C++ implementation, fixing a critical documentation error in the new Python fast-math module, and significantly strengthening the new Python tests by adding missing numerical correctness checks and improving existing ones to ensure the new functionality is robust and well-tested.



def __exp(x):
    """Calculate 2**x with fast math

critical

The docstring for __exp states that it calculates 2**x. This is incorrect; exp(x) calculates e**x, while exp2(x) calculates 2**x. This should be corrected to avoid confusion for users of this API.

Suggested change
"""Calculate 2**x with fast math
"""Calculate exp(x) with fast math

Comment on lines 54 to 96
def run_single_arg_mathop_test(mathop_name,
                               mathop_func,
                               M=128,
                               N=128,
                               block_M=32,
                               block_N=32,
                               dtype="float32"):
    """
    Test single-argument mathops.
    T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            for i, j in T.Parallel(block_M, block_N):
                B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
                                                                      bx * block_N + j])

    # Test with FAST_MATH disabled
    kernel_no_fastmath = tilelang.compile(
        main,
        out_idx=[1],
        target="cuda",
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        })

    source_no_fastmath = kernel_no_fastmath.get_kernel_source()

    print(f"\n=== Testing {mathop_name} ===")
    print("FAST_MATH=False:")

    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)

    print(f"✓ {mathop_name} compilation and execution test passed")


high

The run_single_arg_mathop_test function only checks the generated CUDA source code for the presence of fastmath functions, but it doesn't perform any numerical correctness checks. This is a significant gap in testing. Please add a numerical check against a reference implementation (e.g., torch) to ensure the generated code is functionally correct, similar to the checks in other test helpers in this file.

Comment on lines +259 to +264
    if cuda_mathop_name == "exp":
        expected = torch.exp(a)
    elif cuda_mathop_name == "log":
        expected = torch.log(a)
    else:
        expected = b_fastmath  # Just check compilation works

high

The numerical correctness check in run_fastmath_mathop_test is incomplete. It only compares the output of exp and log against a torch reference. For all other functions (__exp10, __log2, __log10, __tan, __cos, __sin), it only checks that the code compiles and runs by comparing the output to itself. The tests should be expanded to compare the results of all fastmath functions against their corresponding torch reference implementations to ensure correctness.

    if cuda_mathop_name == "exp":
        expected = torch.exp(a)
    elif cuda_mathop_name == "exp10":
        expected = torch.pow(10, a)
    elif cuda_mathop_name == "log":
        expected = torch.log(a)
    elif cuda_mathop_name == "log2":
        expected = torch.log2(a)
    elif cuda_mathop_name == "log10":
        expected = torch.log10(a)
    elif cuda_mathop_name == "tan":
        expected = torch.tan(a)
    elif cuda_mathop_name == "cos":
        expected = torch.cos(a)
    elif cuda_mathop_name == "sin":
        expected = torch.sin(a)
    else:
        expected = b_fastmath  # Just check compilation works

Comment on lines +67 to +74
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float() && t.bits() == 32) {
      return "__" + name + 'f';
    } else {
      return CUDAMath::operator()(t, name);
    }
    return "";
  }

medium

The return ""; on line 73 is unreachable because all paths in the if/else statement above it return a value. This dead code should be removed.

  std::string operator()(DataType t, std::string name) const {
    if (t.is_float() && t.bits() == 32) {
      return "__" + name + 'f';
    } else {
      return CUDAMath::operator()(t, name);
    }
  }

Comment on lines +1704 to +1735
  } else if (op->op.same_as(tl::__exp())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "exp");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__exp10())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "exp10");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__log())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "log");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__log2())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "log2");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__log10())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "log10");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__tan())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "tan");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__cos())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "cos");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__sin())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "sin");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";

medium

The logic for handling the different fast math intrinsics is very repetitive. Each else if block contains almost identical code, creating a new CUDAFastMath object and calling it. This could be refactored to reduce duplication and improve maintainability, for example by using a map from the Op to the math function name string and handling them in a single block or a helper function.

    c_fastmath = kernel_fastmath(a, b)

    # Both should produce similar results
    torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)

medium

The numerical test in run_two_arg_mathop_test compares the output of the kernel compiled with FAST_MATH=True against the one with FAST_MATH=False. Since these two-argument functions are not expected to be affected by the fast math flag, this test only confirms that behavior. It doesn't verify correctness against a golden reference. It would be stronger to compare the output against the equivalent torch function (e.g., torch.pow, torch.fmod).

Suggested change
    torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
    if mathop_name == "pow":
        expected = torch.pow(a, b)
    elif mathop_name == "fmod":
        expected = torch.fmod(a, b)
    torch.testing.assert_close(c_no_fastmath, expected, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(c_fastmath, expected, rtol=1e-3, atol=1e-3)


@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""

medium

The docstring for test_mathops_generate_no_fastmath is confusing. It says "Test that our tl.* mathops generate fastmath CUDA code", but the test name and implementation check for the opposite (i.e., that they generate non-fastmath code when the fast math flag is disabled). The docstring should be corrected to reflect the test's actual behavior.

Additionally, there's a misleading comment on line 92 within run_single_arg_mathop_test: # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf). This comment seems incorrect given the test's logic and should be removed or updated.

    """Test that our tl.* mathops generate non-fastmath CUDA code by default."""

@LeiWang1999 LeiWang1999 changed the title [FastMath] Disable default TVM intrinsic dispatch and add explicit fastmath op to invoke [FastMath] Disable default TVM fastmath intrinsic dispatch and add explicit fastmath op to invoke Sep 25, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (5)
src/op/builtin.h (1)

78-86: Avoid double-underscore identifiers in C++ public API

Identifiers beginning with a double underscore are reserved in C/C++. Using __exp, __log, etc., in a public header risks UB/collisions. Consider alternative C++ accessor names (e.g., fast_exp, fast_log, …) and mapping to tl.__exp string ops internally, or at least document the intentional deviation.

If you intend to keep the __ names, please confirm this won’t be compiled on toolchains enforcing strict reserved identifier rules (e.g., some sanitizers/linters). Do you want help sketching a macro that allows the Op registry name to stay tl.__exp while exposing non-__ C++ accessors?

src/op/builtin.cc (1)

43-67: Mark fast‑math intrinsics as pure (not opaque)

These math ops are side‑effect free. Using kOpaque needlessly inhibits CSE/hoisting and other optimizations. Prefer kPure to align with TVM math semantics.

Apply this diff:

-TIR_DEFINE_TL_BUILTIN(__exp).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__exp).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(__exp10).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__exp10).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(__log).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__log).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(__log2).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__log2).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(__log10).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__log10).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(__tan).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__tan).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(__cos).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__cos).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));

-TIR_DEFINE_TL_BUILTIN(__sin).set_num_inputs(1).set_attr<TCallEffectKind>(
-    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+TIR_DEFINE_TL_BUILTIN(__sin).set_num_inputs(1).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kPure));
tilelang/language/__init__.py (1)

29-29: Fix noqa code and consider avoiding star import

Change the noqa to suppress F403 (star import) instead of F401 (unused imports).

Apply this diff:

-from .fastmath import *  # noqa: F401
+from .fastmath import *  # noqa: F403

Optionally, re-export via __all__ instead of a star import to keep linters happy and limit the API surface.

testing/python/fastmath/test_mathops_fastmath.py (2)

259-266: Broaden numerical checks for more ops

You can validate more fast‑math ops without much overhead.

Apply this diff:

-    if cuda_mathop_name == "exp":
+    if cuda_mathop_name == "exp":
         expected = torch.exp(a)
     elif cuda_mathop_name == "log":
         expected = torch.log(a)
+    elif cuda_mathop_name == "log2":
+        expected = torch.log2(a)
+    elif cuda_mathop_name == "log10":
+        expected = torch.log10(a)
     else:
         expected = b_fastmath  # Just check compilation works

22-46: Regex is fine; add a small guard for false positives (optional)

Consider anchoring to \b before the name as well to avoid matching suffixes inside identifiers (very rare here).

Example:

-    fastmath_pattern = rf"__({mathop_name}f?)\b"
-    non_fastmath_pattern = rf"(?<!__)({mathop_name}f?)\b"
+    fastmath_pattern = rf"\b__({mathop_name}f?)\b"
+    non_fastmath_pattern = rf"(?<!__)\b({mathop_name}f?)\b"
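
A standalone illustration of how the tightened patterns behave on a representative CUDA snippet (not part of the test file; the source string here is made up):

import re

source = "B[i] = __expf(A[i]);\nC[i] = expf(A[i]);"
mathop_name = "exp"

fastmath_pattern = rf"\b__({mathop_name}f?)\b"
non_fastmath_pattern = rf"(?<!__)\b({mathop_name}f?)\b"

# The fast-math pattern matches only the __-prefixed call;
# the non-fast-math pattern matches only the plain call.
print(re.findall(fastmath_pattern, source))      # ['expf'] from __expf
print(re.findall(non_fastmath_pattern, source))  # ['expf'] from expf only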
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 15a303d and 2463338.

📒 Files selected for processing (7)
  • 3rdparty/tvm (1 hunks)
  • src/op/builtin.cc (1 hunks)
  • src/op/builtin.h (1 hunks)
  • src/target/codegen_cuda.cc (2 hunks)
  • testing/python/fastmath/test_mathops_fastmath.py (1 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/fastmath.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/language/fastmath.py (2)
src/op/builtin.h (1)
  • tvm (13-344)
tilelang/language/tir/op.py (1)
  • call_intrin (119-144)
testing/python/fastmath/test_mathops_fastmath.py (4)
tilelang/jit/__init__.py (1)
  • compile (33-86)
tilelang/jit/kernel.py (1)
  • out_idx (446-447)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-89)
tilelang/language/tir/op.py (1)
  • exp10 (1870-1883)
🪛 Ruff (0.13.1)
tilelang/language/__init__.py

29-29: from .fastmath import * used; unable to detect undefined names

(F403)


29-29: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/fastmath.py

140-149: __all__ is not sorted

Apply an isort-style sorting to __all__

(RUF022)


141-141: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


142-142: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


143-143: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


144-144: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


145-145: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


146-146: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


147-147: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


148-148: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

🔇 Additional comments (5)
3rdparty/tvm (1)

1-1: Submodule update looks intentional.

Pointer advance to commit 883e96b42ae0df40c2f7194cc932bbcd9d0c5627 aligns with the fast-math integration—no concerns from my side.

src/op/builtin.cc (3)

33-33: LGTM: cuTensorMapType constant

Returning UInt(8, 128) is clear and consistent with expected 128-bit handle size.


19-31: PassConfig options registration matches header

All keys (including enable/disable fast‑math) are registered and consistent with the header and Python PassConfigKey enum.

Please confirm conflict resolution when both TL_ENABLE_FAST_MATH and TL_DISABLE_FAST_MATH are set. If undefined, consider documenting precedence (e.g., ENABLE wins) or asserting mutually exclusive usage.


43-67: Fast-math intrinsics mapping and wrappers verified
codegen_cuda.cc and intrin_rule_cuda.cc handle tl::__exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin for float16/32/64, and tilelang/language/fastmath.py exports all eight.

testing/python/fastmath/test_mathops_fastmath.py (1)

208-247: Confirm CUDA emits __-prefixed intrinsics for all tested ops

Some CUDA fast‑math intrinsics don’t exist for every function/precision (e.g., exp10). If codegen maps to non‑prefixed variants for any, these assertions will fail.

Please run the verification script in my other comment to confirm that all of: exp, exp10, log, log2, log10, tan, cos, sin are emitted with __ in the generated CUDA for float32. If any lack __ support, we can relax the check for that op.
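
A rough sketch of what such a check could look like, reusing the compile-and-inspect pattern from this test file (the T.__exp-style wrapper names are assumed from this PR; adjust if the exported names differ):

import re
import tilelang
import tilelang.language as T

def assert_fastmath_symbol(name, fn, M=128, N=128):
    # Build a trivial elementwise kernel around the explicit fast-math wrapper
    # and check that the generated CUDA source contains a __-prefixed call.
    @T.prim_func
    def main(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")):
        with T.Kernel(T.ceildiv(N, 32), T.ceildiv(M, 32), threads=128) as (bx, by):
            for i, j in T.Parallel(32, 32):
                B[by * 32 + i, bx * 32 + j] = fn(A[by * 32 + i, bx * 32 + j])

    kernel = tilelang.compile(main, out_idx=[1], target="cuda")
    source = kernel.get_kernel_source()
    assert re.search(rf"\b__{name}f?\b", source), f"no __-prefixed call for {name}"

for op_name in ["exp", "exp10", "log", "log2", "log10", "tan", "cos", "sin"]:
    assert_fastmath_symbol(op_name, getattr(T, f"__{op_name}"))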

Comment on lines +1725 to +1732
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "tan");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__cos())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "cos");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__sin())) {

⚠️ Potential issue | 🟠 Major

Use the tan-specific fast-math mapper

We landed CUDAFastMathTan to avoid emitting __tanf (too inaccurate) and to route FP16 through the h* variants, but this branch still instantiates the generic CUDAFastMath, so we end up generating __tanf for float32. That defeats the accuracy fix and contradicts the new helper/comment.

Please switch this site to CUDAFastMathTan.

-  } else if (op->op.same_as(tl::__tan())) {
-    CUDAFastMath math_func;
+  } else if (op->op.same_as(tl::__tan())) {
+    CUDAFastMathTan math_func;
     std::string func_name = math_func(op->dtype, "tan");
     os << func_name << "(" << PrintExpr(op->args[0]) << ")";
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "tan");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__cos())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "cos");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__sin())) {
  } else if (op->op.same_as(tl::__tan())) {
    CUDAFastMathTan math_func;
    std::string func_name = math_func(op->dtype, "tan");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__cos())) {
    CUDAFastMath math_func;
    std::string func_name = math_func(op->dtype, "cos");
    os << func_name << "(" << PrintExpr(op->args[0]) << ")";
  } else if (op->op.same_as(tl::__sin())) {
🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 1725 to 1732, the branch handling tan
currently instantiates CUDAFastMath which emits __tanf for float32; replace that
instantiation with CUDAFastMathTan so the mapper picks the tan-specific
fast-math variants (and routes FP16 to h* variants). Concretely, change the
local type from CUDAFastMath to CUDAFastMathTan, keep the rest of the call to
math_func(op->dtype, "tan") and the PrintExpr usage unchanged, so the emitted
function name uses the tan-specialized mapper.

Comment on lines +90 to +96
print("FAST_MATH=False:")

# Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)

print(f"✓ {mathop_name} compilation and execution test passed")


⚠️ Potential issue | 🟡 Minor

Fix misleading comments and message

This test compiles with FAST_MATH=False and expects non‑fast‑math calls. Update the comments and message to avoid confusion.

Apply this diff:

-    print(f"\n=== Testing {mathop_name} ===")
-    print("FAST_MATH=False:")
-
-    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
-    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
-
-    print(f"✓ {mathop_name} compilation and execution test passed")
+    print(f"\n=== Testing {mathop_name} ===")
+    print("FAST_MATH=False:")
+    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
+    print(f"✓ {mathop_name} compilation test passed")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
print("FAST_MATH=False:")
# Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
print(f"✓ {mathop_name} compilation and execution test passed")
print(f"\n=== Testing {mathop_name} ===")
print("FAST_MATH=False:")
check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
print(f"✓ {mathop_name} compilation test passed")
🤖 Prompt for AI Agents
In testing/python/fastmath/test_mathops_fastmath.py around lines 90 to 96, the
inline comment and print message are misleading because this block compiles with
FAST_MATH=False and therefore expects non-fast-math calls; update the comment to
state that we expect non-fast-math (regular) calls when FAST_MATH is disabled,
and change the print message to clearly reflect that the FAST_MATH=False
compilation and non-fast-math check passed (e.g., "✓ {mathop_name} compilation
and non-fast-math execution test passed"). Ensure the comment and message
wording unambiguously reference FAST_MATH=False and non-fast-math expectations.

Comment on lines 270 to 297
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
    # Based on test results, our tl.* intrinsics actually generate fastmath versions
    # This appears to be the intended behavior
    single_arg_mathops = [
        ("exp", T.exp),
        ("exp2", T.exp2),
        ("exp10", T.exp10),
        ("log", T.log),
        ("log2", T.log2),
        ("log10", T.log10),
        ("sin", T.sin),
        ("cos", T.cos),
        ("tan", T.tan),
        ("sinh", T.sinh),
        ("cosh", T.cosh),
        ("tanh", T.tanh),
        ("atan", T.atan),
        ("sqrt", T.sqrt),
        ("rsqrt", T.rsqrt),
        ("erf", T.erf),
        ("floor", T.floor),
        ("ceil", T.ceil),
        ("trunc", T.trunc),
        ("round", T.round),
        ("nearbyint", T.nearbyint),
    ]

⚠️ Potential issue | 🟡 Minor

Align test docstring with behavior

The docstring says “generate fastmath CUDA code” but the test name and assertions check non‑fast‑math when FAST_MATH=False. Update the docstring accordingly.

Apply this diff:

-"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
-# Based on test results, our tl.* intrinsics actually generate fastmath versions
-# This appears to be the intended behavior
+"""Test that tl.* mathops generate non-fastmath CUDA calls when FAST_MATH=False."""
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
    # Based on test results, our tl.* intrinsics actually generate fastmath versions
    # This appears to be the intended behavior
    single_arg_mathops = [
        ("exp", T.exp),
        ("exp2", T.exp2),
        ("exp10", T.exp10),
        ("log", T.log),
        ("log2", T.log2),
        ("log10", T.log10),
        ("sin", T.sin),
        ("cos", T.cos),
        ("tan", T.tan),
        ("sinh", T.sinh),
        ("cosh", T.cosh),
        ("tanh", T.tanh),
        ("atan", T.atan),
        ("sqrt", T.sqrt),
        ("rsqrt", T.rsqrt),
        ("erf", T.erf),
        ("floor", T.floor),
        ("ceil", T.ceil),
        ("trunc", T.trunc),
        ("round", T.round),
        ("nearbyint", T.nearbyint),
    ]
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
    """Test that tl.* mathops generate non-fastmath CUDA calls when FAST_MATH=False."""
    single_arg_mathops = [
        ("exp", T.exp),
        ("exp2", T.exp2),
        ("exp10", T.exp10),
        ("log", T.log),
        ("log2", T.log2),
        ("log10", T.log10),
        ("sin", T.sin),
        ("cos", T.cos),
        ("tan", T.tan),
        ("sinh", T.sinh),
        ("cosh", T.cosh),
        ("tanh", T.tanh),
        ("atan", T.atan),
        ("sqrt", T.sqrt),
        ("rsqrt", T.rsqrt),
        ("erf", T.erf),
        ("floor", T.floor),
        ("ceil", T.ceil),
        ("trunc", T.trunc),
        ("round", T.round),
        ("nearbyint", T.nearbyint),
    ]
🤖 Prompt for AI Agents
In testing/python/fastmath/test_mathops_fastmath.py around lines 270 to 297, the
test docstring incorrectly states it checks that tl.* mathops "generate fastmath
CUDA code" while the test name and assertions verify non-fast-math behavior when
FAST_MATH=False; update the docstring to reflect that the test ensures tl.*
mathops do NOT generate fastmath CUDA code (i.e., verify non-fast-math
intrinsics are produced) so the docstring matches the test intent and behavior.

Comment on lines +124 to +135
"""Calculate 2**x with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""

⚠️ Potential issue | 🟡 Minor

Fix the __exp docstring to reflect the natural exponent.

tl.__exp computes the natural exponential (e**x), not 2**x. The current docstring is misleading for anyone using the fast-math wrapper.

-def __exp(x):
-    """Calculate 2**x with fast math
+def __exp(x):
+    """Calculate e**x with fast math
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"""Calculate 2**x with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
def __exp(x):
"""Calculate e**x with fast math
Parameters
----------
x : PrimExpr
Input argument.
Returns
-------
y : PrimExpr
The result.
"""
🤖 Prompt for AI Agents
In tilelang/language/fastmath.py around lines 124 to 135, the docstring for
__exp incorrectly states it calculates 2**x; update the docstring to state it
computes the natural exponential e**x (or exp(x)), and adjust the
parameter/returns wording if needed to reflect natural exponent semantics (e.g.,
"Calculate e**x with fast math" and keep parameter x : PrimExpr and return y :
PrimExpr). Ensure the description and summary explicitly mention "natural
exponential" or "e**x" to avoid confusion.

…A lowering in fast math operation tests for CUDA. This simplifies the test setup while maintaining the focus on fast math functionality.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
testing/python/fastmath/test_mathops_fastmath.py (2)

22-29: Tighten regex to avoid false positives and improve robustness

Current patterns can match substrings inside longer identifiers (e.g., "texp(") and rely on capturing groups unnecessarily. Anchor on non-word boundaries and use non-capturing groups.

-def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
+def check_fastmath_usage(source, mathop_name, expect_fastmath=False):
     """Check source for fastmath/non-fastmath versions"""
-    fastmath_pattern = rf"__({mathop_name}f?)\b"
-    non_fastmath_pattern = rf"(?<!__)({mathop_name}f?)\b"
+    fastmath_pattern = rf"(?<!\w)__(?:{mathop_name}f?)\b"
+    non_fastmath_pattern = rf"(?<![_\w])(?:{mathop_name}f?)\b"
 
-    fastmath_matches = re.findall(fastmath_pattern, source)
-    non_fastmath_matches = re.findall(non_fastmath_pattern, source)
+    fastmath_matches = re.findall(fastmath_pattern, source)
+    non_fastmath_matches = re.findall(non_fastmath_pattern, source)

244-247: Avoid tan singularities to reduce flakiness

Clamp inputs for tan to keep away from asymptotes.

-    # Ensure positive values for functions that need them
-    if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
+    # Ensure valid/stable domains
+    if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
         a = torch.abs(a) + 0.1
+    if cuda_mathop_name == "tan":
+        a = torch.clamp(a, -1.0, 1.0)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2463338 and 029ddf3.

📒 Files selected for processing (1)
  • testing/python/fastmath/test_mathops_fastmath.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/fastmath/test_mathops_fastmath.py (2)
tilelang/jit/__init__.py (1)
  • compile (33-86)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-89)
🔇 Additional comments (5)
testing/python/fastmath/test_mathops_fastmath.py (5)

76-94: Add numerical correctness checks for single‑arg ops (compare to torch)

This helper only checks codegen. Add a golden comparison for all ops under test.

     print(f"\n=== Testing {mathop_name} ===")
     print("FAST_MATH=False:")
 
-    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
     check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
 
-    print(f"✓ {mathop_name} compilation and execution test passed")
+    # Numerical correctness against torch
+    torch_dtype = getattr(torch, dtype)
+    a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
+    # Domain restrictions
+    if mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
+        a = torch.abs(a) + 0.1
+    if mathop_name == "tan":
+        a = torch.clamp(a, -1.0, 1.0)
+    b = kernel_no_fastmath(a)
+    if mathop_name == "exp":
+        expected = torch.exp(a)
+    elif mathop_name == "exp2":
+        expected = torch.pow(2, a)
+    elif mathop_name == "exp10":
+        expected = torch.pow(10, a)
+    elif mathop_name == "log":
+        expected = torch.log(a)
+    elif mathop_name == "log2":
+        expected = torch.log2(a)
+    elif mathop_name == "log10":
+        expected = torch.log10(a)
+    elif mathop_name == "sin":
+        expected = torch.sin(a)
+    elif mathop_name == "cos":
+        expected = torch.cos(a)
+    elif mathop_name == "tan":
+        expected = torch.tan(a)
+    elif mathop_name == "sinh":
+        expected = torch.sinh(a)
+    elif mathop_name == "cosh":
+        expected = torch.cosh(a)
+    elif mathop_name == "tanh":
+        expected = torch.tanh(a)
+    elif mathop_name == "atan":
+        expected = torch.atan(a)
+    elif mathop_name == "sqrt":
+        expected = torch.sqrt(a)
+    elif mathop_name == "rsqrt":
+        expected = torch.reciprocal(torch.sqrt(a))
+    elif mathop_name == "erf":
+        expected = torch.erf(a)
+    elif mathop_name == "floor":
+        expected = torch.floor(a)
+    elif mathop_name == "ceil":
+        expected = torch.ceil(a)
+    elif mathop_name == "trunc":
+        expected = torch.trunc(a)
+    elif mathop_name == "round":
+        expected = torch.round(a)
+    elif mathop_name == "nearbyint":
+        expected = torch.round(a)
+    else:
+        expected = b
+    torch.testing.assert_close(b, expected, rtol=1e-3, atol=1e-3)
+    print(f"✓ {mathop_name} compilation and numerical test passed")

87-93: Fix misleading comment and message in FAST_MATH=False block

This block validates non‑fast‑math calls; update the messaging accordingly.

-    print(f"\n=== Testing {mathop_name} ===")
-    print("FAST_MATH=False:")
-
-    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
-    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
-
-    print(f"✓ {mathop_name} compilation and execution test passed")
+    print(f"\n=== Testing {mathop_name} ===")
+    print("FAST_MATH=False:")
+    check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
+    print(f"✓ {mathop_name} compilation test passed")

147-165: Compare two‑arg ops against torch reference, not just each other

Validate correctness vs a golden reference for both FAST_MATH settings.

-    # Both should produce similar results
-    torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
-    print(f"✓ {mathop_name} numerical test passed")
+    # Golden reference
+    if mathop_name == "pow":
+        expected = torch.pow(a, b)
+    elif mathop_name == "fmod":
+        expected = torch.fmod(a, b)
+    else:
+        expected = c_no_fastmath
+    torch.testing.assert_close(c_no_fastmath, expected, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(c_fastmath, expected, rtol=1e-3, atol=1e-3)
+    print(f"✓ {mathop_name} numerical test passed")

250-256: Add golden references for all fast‑math ops

Cover exp10, log2, log10, tan, cos, sin in the comparison.

-    if cuda_mathop_name == "exp":
-        expected = torch.exp(a)
-    elif cuda_mathop_name == "log":
-        expected = torch.log(a)
-    else:
-        expected = b_fastmath  # Just check compilation works
+    if cuda_mathop_name == "exp":
+        expected = torch.exp(a)
+    elif cuda_mathop_name == "exp10":
+        expected = torch.pow(10, a)
+    elif cuda_mathop_name == "log":
+        expected = torch.log(a)
+    elif cuda_mathop_name == "log2":
+        expected = torch.log2(a)
+    elif cuda_mathop_name == "log10":
+        expected = torch.log10(a)
+    elif cuda_mathop_name == "tan":
+        expected = torch.tan(a)
+    elif cuda_mathop_name == "cos":
+        expected = torch.cos(a)
+    elif cuda_mathop_name == "sin":
+        expected = torch.sin(a)
+    else:
+        expected = b_fastmath  # Keep compilation check for unknowns

262-266: Align test docstring with behavior

Docstring claims fast‑math generation, but the test checks non‑fast‑math when FAST_MATH=False.

-@tilelang.testing.requires_cuda
-def test_mathops_generate_no_fastmath():
-    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
-    # Based on test results, our tl.* intrinsics actually generate fastmath versions
-    # This appears to be the intended behavior
+@tilelang.testing.requires_cuda
+def test_mathops_generate_no_fastmath():
+    """Test that tl.* mathops generate non-fastmath CUDA calls when FAST_MATH=False."""

@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
# Based on test results, our tl.* intrinsics actually generate fastmath versions
Member Author


comments should be fixed.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
testing/python/fastmath/test_mathops_fastmath.py (2)

243-246: Add domain constraint for trigonometric functions.

Trigonometric functions should have their inputs bounded to reasonable ranges to avoid extreme values that could cause numerical instability.

     # Ensure positive values for functions that need them
     if cuda_mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
         a = torch.abs(a) + 0.1
+    elif cuda_mathop_name in ["tan", "sin", "cos"]:
+        # Keep trigonometric inputs in reasonable range to avoid extreme values
+        a = torch.clamp(a, -10, 10)

337-338: Parameterize cache control with a fixture or context manager
Calling tilelang.disable_cache() in the test entrypoint applies globally and may leak into other tests. Define a scoped solution (for example, a pytest fixture or context manager that sets TILELANG_DISABLE_CACHE via monkeypatch) to limit cache‐disable to the desired tests.
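
For example, a scoped fixture along these lines (hypothetical names; TILELANG_DISABLE_CACHE is the variable referenced above):

import pytest

@pytest.fixture
def tilelang_cache_disabled(monkeypatch):
    # Scope cache disabling to the tests that request this fixture instead of
    # flipping global state in the test entrypoint.
    monkeypatch.setenv("TILELANG_DISABLE_CACHE", "1")
    yield

def test_fastmath_exp_codegen(tilelang_cache_disabled):
    ...  # compile and inspect as in run_fastmath_mathop_test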

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 029ddf3 and a44e234.

📒 Files selected for processing (1)
  • testing/python/fastmath/test_mathops_fastmath.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/fastmath/test_mathops_fastmath.py (3)
tilelang/jit/__init__.py (1)
  • compile (33-86)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-89)
tilelang/env.py (1)
  • disable_cache (232-233)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: format-check
  • GitHub Check: bot-task
  • GitHub Check: format-check
🔇 Additional comments (5)
testing/python/fastmath/test_mathops_fastmath.py (5)

163-163: Consider validating against PyTorch references for two-arg functions.

The test only compares fastmath vs non-fastmath outputs but doesn't verify correctness against a known good implementation.

Add validation against PyTorch:

     # Both should produce similar results
     torch.testing.assert_close(c_no_fastmath, c_fastmath, rtol=1e-3, atol=1e-3)
+    
+    # Verify against PyTorch reference
+    if mathop_name == "pow":
+        expected = torch.pow(a, b)
+    elif mathop_name == "fmod":
+        expected = torch.fmod(a, b)
+    else:
+        expected = c_no_fastmath
+    
+    torch.testing.assert_close(c_no_fastmath, expected, rtol=1e-5, atol=1e-5)
     print(f"✓ {mathop_name} numerical test passed")

54-94: Enhance test coverage with numerical validation against PyTorch.

The current test only validates the generated CUDA source for the presence/absence of fastmath functions but lacks numerical correctness checks. Add checks against PyTorch reference implementations to ensure functional correctness.

Add numerical validation to make the test more robust:

 def run_single_arg_mathop_test(mathop_name,
                                mathop_func,
                                M=128,
                                N=128,
                                block_M=32,
                                block_N=32,
                                dtype="float32"):
     """
     Test single-argument mathops.
     T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
     """
 
     @T.prim_func
     def main(
             A: T.Tensor((M, N), dtype),
             B: T.Tensor((M, N), dtype),
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
             for i, j in T.Parallel(block_M, block_N):
                 B[by * block_M + i, bx * block_N + j] = mathop_func(A[by * block_M + i,
                                                                       bx * block_N + j])
 
     # Test with FAST_MATH disabled
     kernel_no_fastmath = tilelang.compile(
         main,
         out_idx=[1],
         target="cuda",
         pass_configs={
             tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
         })
 
     source_no_fastmath = kernel_no_fastmath.get_kernel_source()
 
     print(f"\n=== Testing {mathop_name} ===")
     print("FAST_MATH=False:")
 
-    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
     check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)
 
-    print(f"✓ {mathop_name} compilation and execution test passed")
+    # Test numerical correctness
+    torch_dtype = getattr(torch, dtype)
+    a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
+    
+    # Ensure positive values for functions that need them
+    if mathop_name in ["sqrt", "rsqrt", "log", "log2", "log10"]:
+        a = torch.abs(a) + 0.1
+    elif mathop_name in ["atan"]:
+        a = torch.clamp(a, -10, 10)  # Keep in reasonable range
+    
+    b = kernel_no_fastmath(a)
+    
+    # Map to PyTorch equivalent
+    torch_func_map = {
+        "exp": torch.exp,
+        "exp2": lambda x: torch.pow(2, x),
+        "exp10": lambda x: torch.pow(10, x),
+        "log": torch.log,
+        "log2": torch.log2,
+        "log10": torch.log10,
+        "sin": torch.sin,
+        "cos": torch.cos,
+        "tan": torch.tan,
+        "sinh": torch.sinh,
+        "cosh": torch.cosh,
+        "tanh": torch.tanh,
+        "atan": torch.atan,
+        "sqrt": torch.sqrt,
+        "rsqrt": torch.rsqrt,
+        "erf": torch.erf,
+        "floor": torch.floor,
+        "ceil": torch.ceil,
+        "trunc": torch.trunc,
+        "round": torch.round,
+        "nearbyint": torch.round,  # nearbyint is similar to round
+    }
+    
+    if mathop_name in torch_func_map:
+        expected = torch_func_map[mathop_name](a)
+        torch.testing.assert_close(b, expected, rtol=1e-5, atol=1e-5)
+        print(f"✓ {mathop_name} numerical test passed")
+    
+    print(f"✓ {mathop_name} compilation and execution test passed")

251-256: Expand numerical validation for all fastmath functions.

The test only validates exp and log against PyTorch references. Extend validation to cover all fastmath functions.

Apply this diff to validate all fastmath functions:

     # Compare with reference implementation
     if cuda_mathop_name == "exp":
         expected = torch.exp(a)
+    elif cuda_mathop_name == "exp10":
+        expected = torch.pow(10, a)
     elif cuda_mathop_name == "log":
         expected = torch.log(a)
+    elif cuda_mathop_name == "log2":
+        expected = torch.log2(a)
+    elif cuda_mathop_name == "log10":
+        expected = torch.log10(a)
+    elif cuda_mathop_name == "tan":
+        expected = torch.tan(a)
+    elif cuda_mathop_name == "cos":
+        expected = torch.cos(a)
+    elif cuda_mathop_name == "sin":
+        expected = torch.sin(a)
     else:
-        expected = b_fastmath  # Just check compilation works
+        print(f"Warning: No reference implementation for {cuda_mathop_name}, skipping numerical check")
+        return

90-91: Remove misleading comment.

The comment is inconsistent with the test's behavior - when TL_ENABLE_FAST_MATH is False, the test expects non-fastmath functions.

-    # Our tl.* intrinsics actually generate fastmath versions (e.g., __expf)
     check_fastmath_usage(source_no_fastmath, mathop_name, expect_fastmath=False)

264-267: Fix incorrect docstring.

The docstring states the test checks for fastmath code generation, but the test actually verifies non-fastmath behavior when TL_ENABLE_FAST_MATH is False.

-    """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
-    # Based on test results, our tl.* intrinsics actually generate
-    # no fastmath versions
-    # This appears to be the intended behavior
+    """Test that tl.* mathops generate non-fastmath CUDA code when TL_ENABLE_FAST_MATH is False."""

This commit introduces a new Python script and CUDA source file for a precision comparison tool that evaluates the accuracy of various CUDA operations (including division, reciprocal, exponential, logarithmic, and trigonometric functions) across different implementations: CUDA Precise, CUDA Fast, Triton, Triton LibDevice, and TileLang. The tool generates test data, executes the operations, and summarizes the error statistics for each implementation against a double precision reference. Additionally, a README file is added to document the results of the comparisons for various operations.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (6)
maint/precision/README.md (1)

1-11: Document dataset, ranges, and seed used to compute these errors

Readers need the sample size, value domains per op, and RNG seed to reproduce these tables. Add a short preamble summarizing n, ranges for each op (e.g., positive-only for log/sqrt/rsqrt), device/arch, and the reference (“double CUDA precise”) used.

Apply this minimal preface above the first section:

+Precision report methodology
+---------------------------
+- Device/arch: CUDA GPU (set via TORCH_CUDA_ARCH_LIST); kernels launched with 256 threads
+- Sample size: n = 1,000,000 unless stated otherwise
+- Value ranges:
+  - div: x ~ U[-4, 4], y ~ U[1e-3, 4]
+  - log/sqrt/rsqrt/inv_sqrt: x ~ U[1e-3, 4]
+  - others: x ~ U[-4, 4]
+- Seed: 0
+- Reference: double-precision CUDA (“Double Precise”) vs each implementation
+
maint/precision/cuda_ops.cu (2)

124-145: Add default in switch to guard invalid op_type

Defensive default case avoids leaving r uninitialized if an invalid op_type slips through.

Apply to all three kernels:

         switch (op_type) {
             case OP_DIV:        r = precise_div(a, b); break;
             case OP_RECIPROCAL: r = precise_reciprocal(a); break;
             case OP_EXP:        r = precise_exp(a); break;
             case OP_LOG:        r = precise_log(a); break;
             case OP_SIN:        r = precise_sin(a); break;
             case OP_COS:        r = precise_cos(a); break;
             case OP_SQRT:       r = precise_sqrt(a); break;
             case OP_TANH:       r = precise_tanh(a); break;
             case OP_RSQRT:      r = precise_rsqrt(a); break;
             case OP_INV_SQRT:   r = precise_inv_sqrt(a); break;
+            default:            r = a; break;
         }

(Repeat analogous default in double_precise_operator_kernel and fast_operator_kernel.)

Also applies to: 147-168, 170-191


83-122: Consider using CUDA intrinsics over inline PTX for portability

Inline PTX (rcp.approx.f32 / sqrt.approx.f32 / rsqrt.approx.f32) can be brittle across toolchains and architectures; CUDA intrinsics (e.g., __fdividef, __frcp_rn, __fsqrt_rn, __frsqrt_rn) or libdevice calls may be preferable unless you need the explicit approximate opcodes.

Would you like a variant using only CUDA intrinsics to reduce PTX coupling?

maint/precision/compare_ops.py (3)

58-69: Pin compile flags to avoid ambient fast‑math affecting “precise” path

You note “No fast_math flags”; good. Consider explicitly passing conservative flags (e.g., disable reassociation/FTZ/DAZ) if you need strict IEEE behavior for “precise” baselines.

I can propose vetted flags for nvcc/clang if needed.
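
For example, a minimal sketch with torch.utils.cpp_extension.load; the flag choices below are a suggestion rather than vetted settings, and the __file__-based path is only there to keep the snippet self-contained:

import os
from torch.utils.cpp_extension import load

def load_precise_cuda_ops():
    # Build the "precise" extension with conservative nvcc options (no FMA
    # contraction, no flush-to-zero, IEEE-compliant div/sqrt).
    src_dir = os.path.dirname(os.path.abspath(__file__))
    return load(
        name="cuda_ops",
        sources=[os.path.join(src_dir, "cuda_ops.cu")],
        extra_cuda_cflags=[
            "--fmad=false",      # no FMA contraction
            "--ftz=false",       # keep denormals (no flush-to-zero)
            "--prec-div=true",   # IEEE-compliant division
            "--prec-sqrt=true",  # IEEE-compliant sqrt
        ],
    )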


49-56: Shebang without executable bit (EXE001)

Either remove the shebang or set the script executable in repo metadata.


175-210: Drop unused parameter use_fastmath in make_tilelang_unary_kernel

The parameter isn’t used inside the nested kernel; it’s only used in compile pass_configs outside. Removing it avoids ARG001.

-def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool = False):
+def make_tilelang_unary_kernel(M: int, N: int, op_id: int):
@@
-    def tilelang_unary_kernel(
+    def tilelang_unary_kernel(
         A: T.Tensor((M, N), "float32"),
         B: T.Tensor((M, N), "float32"),
     ):
@@
-    return tilelang_unary_kernel
+    return tilelang_unary_kernel

And update the call site:

-        kernel_func = make_tilelang_unary_kernel(M, N, op_id, use_fastmath)
+        kernel_func = make_tilelang_unary_kernel(M, N, op_id)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a44e234 and 0cf3d49.

📒 Files selected for processing (3)
  • maint/precision/README.md (1 hunks)
  • maint/precision/compare_ops.py (1 hunks)
  • maint/precision/cuda_ops.cu (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
maint/precision/compare_ops.py (5)
tilelang/env.py (2)
  • disable_cache (232-233)
  • get (136-139)
tilelang/utils/target.py (1)
  • determine_target (44-84)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (258-299)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-89)
maint/precision/cuda_ops.cu (6)
  • launch_double_precise_operator (208-219)
  • launch_double_precise_operator (208-208)
  • launch_precise_operator (194-205)
  • launch_precise_operator (194-194)
  • launch_fast_operator (222-233)
  • launch_fast_operator (222-222)
🪛 Ruff (0.13.1)
maint/precision/compare_ops.py

1-1: Shebang is present but file is not executable

(EXE001)


175-175: Unused function argument: use_fastmath

(ARG001)


272-272: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


286-286: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)


320-320: Avoid specifying long messages outside the exception class

(TRY003)


373-373: Do not catch blind exception: Exception

(BLE001)


423-423: Do not catch blind exception: Exception

(BLE001)

🪛 GitHub Actions: CI Test on AMD
maint/precision/compare_ops.py

[error] 297-297: E731 Do not assign a lambda expression, use a def


[error] 314-314: E731 Do not assign a lambda expression, use a def


[error] 438-438: SIM114 Combine if branches using logical or operator


[error] 441-441: SIM114 Combine if branches using logical or operator

🪛 GitHub Actions: CI
maint/precision/compare_ops.py

[error] 297-297: E731 Do not assign a lambda expression, use a def


[error] 314-314: E731 Do not assign a lambda expression, use a def


[error] 438-441: SIM114 Combine if branches using logical or operator


[error] 441-447: SIM114 Combine if branches using logical or operator

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task
🔇 Additional comments (3)
maint/precision/compare_ops.py (3)

228-266: Possible OOB in TileLang kernels at edges; add bounds guard or confirm safe memory pass

row/col are computed from ceildiv grids without explicit guards. Unless TileLang inserts masking by default, edge tiles may access OOB.

Can you confirm T.Parallel + current pass configs guarantee safe bounds? If not, add a guard:

-                x = A[row, col]
+                if row < M and col < N:
+                    x = A[row, col]
+                    ...
+                # else: implicitly skip

As per the pass_config keys, TL_DISABLE_SAFE_MEMORY_ACCESS defaults to False, but please verify the generated TIR.


20-28: Be resilient to target parsing failures

determine_target/parse_compute_version may vary; guard and fall back to a sensible arch list (or skip setting TORCH_CUDA_ARCH_LIST).

If needed, I can add a try/except around these calls with a fallback to native arch detection via torch.cuda.get_device_capability().
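
A minimal sketch of that guard, assuming the script's existing TileLang imports; the return_object keyword is an assumption and may need adjusting to the actual signatures:

import os
import torch
from tilelang.utils.target import determine_target
from tilelang.contrib.nvcc import get_target_compute_version

try:
    # Preferred path: let TileLang resolve the target (kwargs assumed, not verified).
    arch = get_target_compute_version(determine_target(return_object=True))
except Exception:
    # Fallback: detect the locally visible device directly.
    major, minor = torch.cuda.get_device_capability()
    arch = f"{major}.{minor}"

os.environ["TORCH_CUDA_ARCH_LIST"] = arch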


151-169: Verify Triton libdevice symbol names

libdevice bindings can differ by Triton version (e.g., sqrt vs sqrt_rn). Please confirm these exact names exist in your pinned Triton. If not, map to available ones.

I can generate a quick probe script to introspect triton.language.extra.libdevice for available symbols.
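
A minimal probe sketch, assuming the pinned Triton exposes triton.language.extra.libdevice under that exact module path:

import triton.language.extra.libdevice as libdevice

# Print every public libdevice binding whose name mentions one of the ops used here,
# so mismatches like sqrt vs. sqrt_rn are easy to spot.
ops_of_interest = ("div", "rcp", "exp", "log", "sin", "cos", "sqrt", "rsqrt", "tanh")
for name in sorted(dir(libdevice)):
    if not name.startswith("_") and any(op in name for op in ops_of_interest):
        print(name)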

Comment on lines +272 to +279
    grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

    if op_id == 0:  # Division - binary operation
        assert y is not None, "Division operation requires second operand"
        triton_binary_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=TRITON_BLOCK_SIZE)
    else:  # Unary operation
        triton_unary_kernel[grid](x, out, x.numel(), op_id, BLOCK_SIZE=TRITON_BLOCK_SIZE)


🛠️ Refactor suggestion | 🟠 Major

Fix E731: replace lambda with def for Triton grid

CI fails with E731. Replace assigned lambda with a def.

-    grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+    def grid(meta):
+        return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
🧰 Tools
🪛 Ruff (0.13.1)

272-272: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)

Comment on lines +286 to +293
    grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

    if op_id == 0:  # Division - binary operation
        assert y is not None, "Division operation requires second operand"
        triton_libdevice_binary_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=TRITON_BLOCK_SIZE)
    else:  # Unary operation
        triton_libdevice_unary_kernel[grid](x, out, x.numel(), op_id, BLOCK_SIZE=TRITON_BLOCK_SIZE)


🛠️ Refactor suggestion | 🟠 Major

Fix E731: replace lambda with def for LibDevice grid

Same as above.

-    grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+    def grid(meta):
+        return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

    def grid(meta):
        return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

    if op_id == 0:  # Division - binary operation
        assert y is not None, "Division operation requires second operand"
        triton_libdevice_binary_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=TRITON_BLOCK_SIZE)
    else:  # Unary operation
        triton_libdevice_unary_kernel[grid](x, out, x.numel(), op_id, BLOCK_SIZE=TRITON_BLOCK_SIZE)
🧰 Tools
🪛 Ruff (0.13.1)

286-286: Do not assign a lambda expression, use a def

Rewrite grid as a def

(E731)

🤖 Prompt for AI Agents
In maint/precision/compare_ops.py around lines 286 to 293, replace the assigned
lambda for grid with a proper function definition to satisfy E731: define def
grid(meta): return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
(keeping the same return shape), then leave the subsequent
triton_libdevice_binary_kernel[grid](...) and
triton_libdevice_unary_kernel[grid](...) calls unchanged so they continue to use
the callable grid.

Comment on lines +332 to +336
    abs_err = (output_double - reference_double).abs()
    rel_err = abs_err / (reference_double.abs().clamp_min(1e-30))
    print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
          f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}")


⚠️ Potential issue | 🔴 Critical

Bug: formatting torch.Tensors with format spec causes TypeError on CUDA tensors

abs_err.max()/mean() et al. are 0-d tensors; formatting with “:.3e” raises on CUDA. Convert to Python floats.

-    print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
-          f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}")
+    max_abs = abs_err.max().item()
+    mean_abs = abs_err.mean().item()
+    max_rel = rel_err.max().item()
+    mean_rel = rel_err.mean().item()
+    print(f"{tag:<32} max abs: {max_abs:.3e}, mean abs: {mean_abs:.3e}, "
+          f"max rel: {max_rel:.3e}, mean rel: {mean_rel:.3e}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

    abs_err = (output_double - reference_double).abs()
    rel_err = abs_err / (reference_double.abs().clamp_min(1e-30))
    max_abs = abs_err.max().item()
    mean_abs = abs_err.mean().item()
    max_rel = rel_err.max().item()
    mean_rel = rel_err.mean().item()
    print(f"{tag:<32} max abs: {max_abs:.3e}, mean abs: {mean_abs:.3e}, "
          f"max rel: {max_rel:.3e}, mean rel: {mean_rel:.3e}")
🤖 Prompt for AI Agents
In maint/precision/compare_ops.py around lines 332 to 336, the formatted
f-string is passing 0-d torch.Tensors (e.g. abs_err.max(), abs_err.mean(),
rel_err.max(), rel_err.mean()) into a float format spec which raises a TypeError
for CUDA tensors; fix by converting those 0-d tensors to Python floats (use
.item()) before formatting so the f-string receives plain floats (e.g. replace
abs_err.max() with abs_err.max().item(), and likewise for mean() and the rel_err
values).

Comment on lines +193 to +205
// Precise version
void launch_precise_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
    int64_t n = x.numel();
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    const float* y_ptr = nullptr;
    if (y.has_value()) {
        y_ptr = y.value().data_ptr<float>();
    }
    precise_operator_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        x.data_ptr<float>(), y_ptr, result.data_ptr<float>(), n, static_cast<OperatorType>(op_type)
    );
}

⚠️ Potential issue | 🟠 Major

Validate dtype, contiguity, device, and shape in launchers

Raw pointer indexing assumes contiguous memory and matching dtypes; current code will silently miscompute on strided inputs or wrong dtypes. Add runtime checks to fail fast and prevent OOB/UB.

Apply these diffs to each launcher:

 void launch_precise_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
+    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
+    TORCH_CHECK(result.is_cuda(), "result must be a CUDA tensor");
+    TORCH_CHECK(x.scalar_type() == at::kFloat && result.scalar_type() == at::kFloat,
+                "precise path expects float32 tensors");
+    TORCH_CHECK(x.is_contiguous() && result.is_contiguous(),
+                "x and result must be contiguous");
+    if (y.has_value()) {
+        TORCH_CHECK(y.value().is_cuda(), "y must be a CUDA tensor");
+        TORCH_CHECK(y.value().scalar_type() == at::kFloat, "y must be float32");
+        TORCH_CHECK(y.value().is_contiguous(), "y must be contiguous");
+        TORCH_CHECK(y.value().numel() == x.numel(), "x and y must have same numel");
+    }
     int64_t n = x.numel();
 void launch_double_precise_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
+    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
+    TORCH_CHECK(result.is_cuda(), "result must be a CUDA tensor");
+    TORCH_CHECK(x.scalar_type() == at::kDouble && result.scalar_type() == at::kDouble,
+                "double-precise path expects float64 tensors");
+    TORCH_CHECK(x.is_contiguous() && result.is_contiguous(),
+                "x and result must be contiguous");
+    if (y.has_value()) {
+        TORCH_CHECK(y.value().is_cuda(), "y must be a CUDA tensor");
+        TORCH_CHECK(y.value().scalar_type() == at::kDouble, "y must be float64");
+        TORCH_CHECK(y.value().is_contiguous(), "y must be contiguous");
+        TORCH_CHECK(y.value().numel() == x.numel(), "x and y must have same numel");
+    }
     int64_t n = x.numel();
 void launch_fast_operator(const at::Tensor& x, const c10::optional<at::Tensor>& y, at::Tensor& result, int op_type) {
+    TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
+    TORCH_CHECK(result.is_cuda(), "result must be a CUDA tensor");
+    TORCH_CHECK(x.scalar_type() == at::kFloat && result.scalar_type() == at::kFloat,
+                "fast path expects float32 tensors");
+    TORCH_CHECK(x.is_contiguous() && result.is_contiguous(),
+                "x and result must be contiguous");
+    if (y.has_value()) {
+        TORCH_CHECK(y.value().is_cuda(), "y must be a CUDA tensor");
+        TORCH_CHECK(y.value().scalar_type() == at::kFloat, "y must be float32");
+        TORCH_CHECK(y.value().is_contiguous(), "y must be contiguous");
+        TORCH_CHECK(y.value().numel() == x.numel(), "x and y must have same numel");
+    }
     int64_t n = x.numel();

Also applies to: 207-219, 221-233

🤖 Prompt for AI Agents
In maint/precision/cuda_ops.cu around lines 193-205 (and similarly for 207-219
and 221-233): add runtime validation before launching the kernel to ensure x and
result are CUDA tensors with scalar_type float32, contiguous, and have the same
numel; if y is provided ensure it is defined, on the same CUDA device,
contiguous, float32, and has numel equal to x.numel(); use TORCH_CHECK (or
AT_ASSERTM) to fail fast with clear messages (e.g. check x.is_cuda(),
result.is_cuda(), x.scalar_type()==at::kFloat, x.is_contiguous(),
result.is_contiguous(), x.numel()==result.numel(), and if (y.has_value())
validate same device, scalar_type, contiguity, and numel) before calling
precise_operator_kernel.

This commit introduces a new precision comparison tool implemented in Python and CUDA, designed to evaluate the accuracy of various mathematical operations (division, reciprocal, exponential, logarithmic, trigonometric, square root, etc.) across different frameworks including CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang. The tool includes functionality for generating test data, executing operations, and summarizing error statistics for each implementation. Additionally, it provides a comprehensive README with error metrics for each operation tested.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (1)
maint/precision/compare_ops.py (1)

76-83: Seed CUDA RNGs for reproducibility (optional).

torch.manual_seed may not cover all CUDA RNG cases. Also seed CUDA.

Apply this diff:

 torch.manual_seed(args.seed)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed_all(args.seed)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0cf3d49 and ce4a1a5.

📒 Files selected for processing (1)
  • maint/precision/compare_ops.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
maint/precision/compare_ops.py (6)
tilelang/env.py (2)
  • disable_cache (232-233)
  • get (136-139)
tilelang/utils/target.py (1)
  • determine_target (44-84)
tilelang/contrib/nvcc.py (1)
  • get_target_compute_version (258-299)
tilelang/jit/__init__.py (2)
  • jit (237-310)
  • compile (33-86)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-89)
maint/precision/cuda_ops.cu (6)
  • launch_double_precise_operator (208-219)
  • launch_double_precise_operator (208-208)
  • launch_precise_operator (194-205)
  • launch_precise_operator (194-194)
  • launch_fast_operator (222-233)
  • launch_fast_operator (222-222)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: bot-task
🔇 Additional comments (4)
maint/precision/compare_ops.py (4)

298-305: Replace assigned lambda with def to satisfy Ruff E731 (duplicate of prior review).

CI will fail with E731. Use a named function.

Apply this diff:

-    grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+    def grid(meta):
+        return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

315-322: Replace assigned lambda with def to satisfy Ruff E731 (duplicate of prior review).

Same fix for LibDevice path.

Apply this diff:

-    grid = lambda meta: ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)
+    def grid(meta):
+        return ((x.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'],)

365-369: Convert 0‑d tensors to Python floats before formatting (duplicate of prior review).

Formatting CUDA 0‑d tensors with “:.3e” raises TypeError. Use .item().

Apply this diff:

-    print(f"{tag:<32} max abs: {abs_err.max():.3e}, mean abs: {abs_err.mean():.3e}, "
-          f"max rel: {rel_err.max():.3e}, mean rel: {rel_err.mean():.3e}")
+    max_abs = abs_err.max().item()
+    mean_abs = abs_err.mean().item()
+    max_rel = rel_err.max().item()
+    mean_rel = rel_err.mean().item()
+    print(f"{tag:<32} max abs: {max_abs:.3e}, mean abs: {mean_abs:.3e}, "
+          f"max rel: {max_rel:.3e}, mean rel: {mean_rel:.3e}")

20-20: Verify top-level export of disable_cache
Cannot locate disable_cache in the local tilelang package; ensure it’s exposed in tilelang/__init__.py or import it directly from tilelang.env.
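
A tolerant import along these lines would cover both cases; this is only a sketch, since whether the top-level re-export exists is exactly what needs verifying:

try:
    from tilelang import disable_cache  # only works if re-exported at the top level
except ImportError:
    from tilelang.env import disable_cache  # available per tilelang/env.py

disable_cache()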

Comment on lines +69 to +73
return load(
name="cuda_ops",
sources=["cuda_ops.cu"],
extra_cuda_cflags=[] # No fast_math flags
)

⚠️ Potential issue | 🔴 Critical

Fix source path for CUDA extension (will fail when run from repo root).

load(..., sources=["cuda_ops.cu"]) assumes CWD == script dir. When run as python maint/precision/compare_ops.py, this file won’t be found. Use an absolute path derived from __file__.

Apply this diff:

-    return load(
-        name="cuda_ops",
-        sources=["cuda_ops.cu"],
-        extra_cuda_cflags=[]  # No fast_math flags
-    )
+    src_dir = os.path.dirname(os.path.abspath(__file__))
+    cu_path = os.path.join(src_dir, "cuda_ops.cu")
+    return load(
+        name="cuda_ops",
+        sources=[cu_path],
+        extra_cuda_cflags=[],  # No fast_math flags
+    )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In maint/precision/compare_ops.py around lines 69 to 73, the CUDA extension load
uses a relative source path ("cuda_ops.cu") which will fail when the script is
run from the repo root; change to compute the absolute path using the script's
directory (e.g., derive dirname = os.path.dirname(__file__) and join with
"cuda_ops.cu") and pass that absolute path in the sources list to load, ensuring
it works regardless of current working directory.

Comment on lines +200 to +224
            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
                x = A[row, col]

                if op_id == 1:  # reciprocal
                    B[row, col] = 1.0 / x
                elif op_id == 2:  # exp
                    B[row, col] = T.exp(x)
                elif op_id == 3:  # log
                    B[row, col] = T.log(x)
                elif op_id == 4:  # sin
                    B[row, col] = T.sin(x)
                elif op_id == 5:  # cos
                    B[row, col] = T.cos(x)
                elif op_id == 6:  # sqrt
                    B[row, col] = T.sqrt(x)
                elif op_id == 7:  # tanh
                    B[row, col] = T.tanh(x)
                elif op_id == 8:  # rsqrt
                    B[row, col] = T.rsqrt(x)
                elif op_id == 9:  # inv_sqrt
                    B[row, col] = 1.0 / T.sqrt(x)
                else:
                    B[row, col] = x  # Default case

⚠️ Potential issue | 🔴 Critical

Out-of-bounds accesses in TileLang unary kernel (missing bounds guards).

row and col can exceed M/N on the edge tiles (ceildiv used). This reads/writes OOB.

Apply this diff to guard loads/stores:

-            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
+            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                 row = by * TILELANG_BLOCK_M + i
                 col = bx * TILELANG_BLOCK_N + j
-                x = A[row, col]
+                if (row < M) and (col < N):
+                    x = A[row, col]
 
-                if op_id == 1:  # reciprocal
-                    B[row, col] = 1.0 / x
-                elif op_id == 2:  # exp
-                    B[row, col] = T.exp(x)
-                elif op_id == 3:  # log
-                    B[row, col] = T.log(x)
-                elif op_id == 4:  # sin
-                    B[row, col] = T.sin(x)
-                elif op_id == 5:  # cos
-                    B[row, col] = T.cos(x)
-                elif op_id == 6:  # sqrt
-                    B[row, col] = T.sqrt(x)
-                elif op_id == 7:  # tanh
-                    B[row, col] = T.tanh(x)
-                elif op_id == 8:  # rsqrt
-                    B[row, col] = T.rsqrt(x)
-                elif op_id == 9:  # inv_sqrt
-                    B[row, col] = 1.0 / T.sqrt(x)
-                else:
-                    B[row, col] = x  # Default case
+                    if op_id == 1:  # reciprocal
+                        B[row, col] = 1.0 / x
+                    elif op_id == 2:  # exp
+                        B[row, col] = T.exp(x)
+                    elif op_id == 3:  # log
+                        B[row, col] = T.log(x)
+                    elif op_id == 4:  # sin
+                        B[row, col] = T.sin(x)
+                    elif op_id == 5:  # cos
+                        B[row, col] = T.cos(x)
+                    elif op_id == 6:  # sqrt
+                        B[row, col] = T.sqrt(x)
+                    elif op_id == 7:  # tanh
+                        B[row, col] = T.tanh(x)
+                    elif op_id == 8:  # rsqrt
+                        B[row, col] = T.rsqrt(x)
+                    elif op_id == 9:  # inv_sqrt
+                        B[row, col] = 1.0 / T.sqrt(x)
+                    else:
+                        B[row, col] = x  # Default case
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
                if (row < M) and (col < N):
                    x = A[row, col]

                    if op_id == 1:  # reciprocal
                        B[row, col] = 1.0 / x
                    elif op_id == 2:  # exp
                        B[row, col] = T.exp(x)
                    elif op_id == 3:  # log
                        B[row, col] = T.log(x)
                    elif op_id == 4:  # sin
                        B[row, col] = T.sin(x)
                    elif op_id == 5:  # cos
                        B[row, col] = T.cos(x)
                    elif op_id == 6:  # sqrt
                        B[row, col] = T.sqrt(x)
                    elif op_id == 7:  # tanh
                        B[row, col] = T.tanh(x)
                    elif op_id == 8:  # rsqrt
                        B[row, col] = T.rsqrt(x)
                    elif op_id == 9:  # inv_sqrt
                        B[row, col] = 1.0 / T.sqrt(x)
                    else:
                        B[row, col] = x  # Default case
🤖 Prompt for AI Agents
In maint/precision/compare_ops.py around lines 200 to 224, the TileLang kernel
computes row/col using ceildiv tiles and may read/write past M/N on edge tiles;
add explicit bounds guards before any load or store: check if row < M and col <
N (or equivalent shape variables) then load x = A[row,col] and perform the op
and write B[row,col]; otherwise skip the operation (or write a safe/default
value) so no OOB access occurs. Ensure the guard wraps both the read and the
write for every op_id branch so the default branch is also protected.

Comment on lines +242 to +247
            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
                x = A[row, col]
                y = B[row, col]
                C[row, col] = x / y  # Division operation

⚠️ Potential issue | 🔴 Critical

Out-of-bounds accesses in TileLang binary kernel (missing bounds guards).

Same issue for the division kernel.

Apply this diff:

-            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
+            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                 row = by * TILELANG_BLOCK_M + i
                 col = bx * TILELANG_BLOCK_N + j
-                x = A[row, col]
-                y = B[row, col]
-                C[row, col] = x / y  # Division operation
+                if (row < M) and (col < N):
+                    x = A[row, col]
+                    y = B[row, col]
+                    C[row, col] = x / y  # Division operation
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

            for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
                row = by * TILELANG_BLOCK_M + i
                col = bx * TILELANG_BLOCK_N + j
                if (row < M) and (col < N):
                    x = A[row, col]
                    y = B[row, col]
                    C[row, col] = x / y  # Division operation

RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…plicit fastmath op to invoke (tile-ai#875)

* Add fast math operations for CUDA: exp, exp10, log, log2, log10, tan, cos, and sin (tile-ai#865)

* Refactor fast math operation definitions for consistency and readability in CUDA code. Consolidated multiple definitions into single lines and improved formatting in related test files for better clarity.

* Remove unnecessary pass configurations for warp specialization and TMA lowering in fast math operation tests for CUDA. This simplifies the test setup while maintaining the focus on fast math functionality.

* Update fastmath tests to reflect that tl.* intrinsics generate no fastmath versions and disable cache in main execution.

* Fix formatting in fastmath test comments for clarity on tl.* intrinsics behavior.

* Add precision comparison tool for CUDA operations

This commit introduces a new Python script and CUDA source file for a precision comparison tool that evaluates the accuracy of various CUDA operations (including division, reciprocal, exponential, logarithmic, and trigonometric functions) across different implementations: CUDA Precise, CUDA Fast, Triton, Triton LibDevice, and TileLang. The tool generates test data, executes the operations, and summarizes the error statistics for each implementation against a double precision reference. Additionally, a README file is added to document the results of the comparisons for various operations.

* Add precision comparison tool for CUDA operations

This commit introduces a new precision comparison tool implemented in Python and CUDA, designed to evaluate the accuracy of various mathematical operations (division, reciprocal, exponential, logarithmic, trigonometric, square root, etc.) across different frameworks including CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang. The tool includes functionality for generating test data, executing operations, and summarizing error statistics for each implementation. Additionally, it provides a comprehensive README with error metrics for each operation tested.

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant