
Conversation

@Rachmanino (Collaborator) commented Oct 28, 2025

This pull request improves CUDA code generation for bfloat16 and other floating-point types, particularly for min/max operations and constant printing. The changes ensure correct handling of bfloat16, float16, and float32/float64 types, and include the necessary headers only when bfloat16 support is required.

Enhanced min/max operation support

  • Added custom VisitExpr_ methods for MinNode and MaxNode in CodeGenTileLangCUDA, ensuring correct CUDA code generation for bfloat16 (uses custom functions), float16 (uses __hmin/__hmax), and float32/64 (uses standard min/max).
  • Implemented custom min and max functions for bfloat16_t in cuda_bf16_wrapper.h to support comparisons, replacing conditional compilation with unconditional inclusion.
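
A minimal sketch of this dispatch, following the final CUTLASS-based version described later in the thread; the PrintExpr helper and the CodeGenC base visitor are assumed from TVM's codegen conventions rather than copied from the merged diff:

void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) {
  DataType t = op->dtype;
  // Scalar bf16/fp16: the standard min() has no overload for these types,
  // so emit the CUTLASS helper instead.
  if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
    os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
    return;
  }
  // Scalar float32/float64: the standard min() is sufficient.
  if (t.is_float() && t.is_scalar() && (t.bits() == 32 || t.bits() == 64)) {
    os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
    return;
  }
  // Integers, vectors, and any other width: defer to the base CodeGenC visitor.
  CodeGenC::VisitExpr_(op, os);
}

The MaxNode visitor mirrors this with cutlass::fast_max and max().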

Improved bfloat16 constant printing

  • Updated PrintConst to correctly handle bfloat16 constants, including special cases for infinity and NaN, and to mark when bfloat16 support is needed.
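
A hedged sketch of the bf16/fp16 branch of PrintConst, assembled from the snippets quoted in the reviews below; the function signature and the enable_bf16_/MarkConst helpers follow the names used in that discussion and may differ from the merged code:

static void PrintConst(const FloatImmNode *op, std::ostream &os,
                       CodeGenTileLangCUDA *p) {
  if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
    std::ostringstream temp;
    std::string cuda_type = op->dtype.is_bfloat16() ? "bfloat16_t" : "half_t";
    if (std::isinf(op->value)) {
      if (op->value < 0) temp << "-";
      temp << "std::numeric_limits<" << cuda_type << ">::infinity()";
    } else if (std::isnan(op->value)) {
      temp << "std::numeric_limits<" << cuda_type << ">::quiet_NaN()";
    } else {
      // Hex-float literal plus a human-readable comment with the decimal value.
      temp << cuda_type << '(' << std::hexfloat << op->value << 'f'
           << "/*" << std::scientific << op->value << "*/" << ')';
    }
    if (op->dtype.is_bfloat16()) p->enable_bf16_ = true;  // mark bf16 usage for header emission
    p->MarkConst(temp.str());
    os << temp.str();
    return;
  }
  // ... float32/float64 constants continue through the existing switch ...
}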

Conditional header inclusion

  • Changed header inclusion logic in Finish() to include cuda_bf16_fallbacks.cuh only if bfloat16 is used, removing the previous preprocessor guard.
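
A sketch of the corresponding emission in Finish(); the guarded include matches the snippet quoted later in the review (lines 283-286), while the surrounding structure and the call through CodeGenC::Finish() are assumed from TVM's codegen layout:

std::string CodeGenTileLangCUDA::Finish() {
  if (enable_bf16_) {
    decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
  }
  // ... other feature-gated includes and declarations ...
  return CodeGenC::Finish();
}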

Minor clarifications

  • Clarified comments and switch logic for float types in constant printing.

Summary by CodeRabbit

  • New Features

    • CUDA code generation now recognizes and emits min/max operations across bfloat16, float16, float32, and float64, using optimized scalar implementations where applicable.
    • Public visitor support added so min/max expressions are handled during CUDA codegen.
  • Bug Fixes / Improvements

    • FP8→float conversion flow short-circuited to avoid redundant processing.
    • Improved serialization and handling of bfloat16/float16 constants, preserving infinities/NaNs and tracking numeric constant usage.

@github-actions commented:

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai bot (Contributor) commented Oct 28, 2025

Walkthrough

Added CUDA codegen visitors for MinNode/MaxNode with type-dispatched emissions (cutlass::fast_min/fast_max for scalar bf16/fp16, std::min/max for scalar float32/float64, fallback to base CodeGenC). Updated Cast handling to early-return after FP8→float conversions and improved bf16/float16 constant serialization.

Changes

  • Header declarations (src/target/codegen_cuda.h): Added visitor declarations VisitExpr_(const MinNode *op, std::ostream &os) final; and VisitExpr_(const MaxNode *op, std::ostream &os) final; in CodeGenTileLangCUDA.
  • Min/Max implementation (src/target/codegen_cuda.cc): Implemented MinNode/MaxNode visitors with type dispatch: scalar bf16/float16 → cutlass::fast_min/cutlass::fast_max; scalar float32/float64 → std::min/std::max; other types → delegate to base CodeGenC.
  • Cast & constant handling (src/target/codegen_cuda.cc): After FP8→float32/float16 conversion paths, emit intermediate sret and return early. Consolidated bf16/float16 constant printing to handle inf/NaN via std::numeric_limits, mark constants, and update bf16 usage; other constant behavior unchanged.

Sequence Diagram(s)

sequenceDiagram
    participant Visitor as CodeGenTileLangCUDA Visitor
    participant TypeCheck as Operand Type Check
    participant Cutlass as cutlass::fast_min/max
    participant Std as std::min/max
    participant Base as Base CodeGenC

    Visitor->>TypeCheck: inspect operand types
    alt scalar bf16 or float16
        TypeCheck->>Cutlass: emit cutlass::fast_min/fast_max
        Cutlass-->>Visitor: generated expression
    else scalar float32 or float64
        TypeCheck->>Std: emit std::min/std::max
        Std-->>Visitor: generated expression
    else other types
        TypeCheck->>Base: delegate to base CodeGenC visitor
        Base-->>Visitor: generated expression
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Inspect Min/Max type-dispatch branches for correctness (scalar vs vector, bf16/float16 interactions).
  • Verify FP8→float early-return emits correct sret and preserves required downstream logic.
  • Validate bf16/float16 constant serialization, NaN/inf handling, and marking.

Possibly related PRs

Suggested reviewers

  • xwhzz
  • tzj-fxz

Poem

🐇 I hop through kernels, swift and spry,
Min and Max now leap and fly,
Cutlass trims halves, std keeps the rest,
Constants tidy, casts do their best,
A rabbit cheers this codegen sky! 🚀

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title "[BugFix] Implement bfloat16 support in CUDA code generation with min/max functions and inf/nan values" accurately reflects the main changes: custom VisitExpr_ methods for MinNode and MaxNode, handling of bfloat16 infinity and NaN constants, and bfloat16 operations via CUTLASS APIs. It is specific enough that a teammate scanning the history would immediately understand the primary objective.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b2bb08b and 065feed.

📒 Files selected for processing (1)
  • src/target/codegen_cuda.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (10)
  • VisitExpr_ (611-639)
  • VisitExpr_ (611-611)
  • VisitExpr_ (759-976)
  • VisitExpr_ (759-759)
  • VisitExpr_ (1039-1052)
  • VisitExpr_ (1039-1039)
  • VisitExpr_ (1054-1169)
  • VisitExpr_ (1054-1055)
  • VisitExpr_ (1217-1220)
  • VisitExpr_ (1217-1218)
🔇 Additional comments (4)
src/target/codegen_cuda.cc (4)

1020-1021: LGTM: Consistent early return pattern after FP8 conversion.

The early return after emitting the FP8 conversion result follows the established pattern used for other vectorized conversions (half, bfloat16) in this function. This prevents incorrect fallthrough to the generic element-wise cast logic.


1039-1060: Type-dispatched min() emission looks correct.

The implementation properly handles:

  • bfloat16/float16 scalar → cutlass::fast_min (since standard min doesn't support these)
  • float32/float64 scalar → standard min()
  • All other types → delegates to base CodeGenC

This addresses the past review concern about non-standard float bit widths by falling through to the base implementation for any unhandled cases.


1062-1083: LGTM: MaxNode implementation mirrors MinNode correctly.

The type-dispatched emission for MaxNode follows the same correct pattern as MinNode, using cutlass::fast_max for bf16/fp16 scalar types and standard max() for float32/64.


2589-2614: No critical issues found. The std::numeric_limits specializations are available.

CUTLASS provides std::numeric_limits specializations for cutlass::half_t and cutlass::bfloat16_t, and CUDA's libcu++ and CUB/CCCL also provide std::numeric_limits support for extended floating types including __half and __nv_bfloat16. The generated code at lines 2599–2605 emits half_t and bfloat16_t type names (via PrintType), which will resolve to these aliased types when the generated code is compiled with the appropriate CUDA/CUTLASS headers.

No changes required.



@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
src/target/codegen_cuda.cc (2)

1038-1067: Prefer fminf/fmin for float32/64 to avoid macro/overload pitfalls and match IEEE semantics.

Using bare min can bind unintended overloads/macros and diverge for NaNs/±0. Keep bf16/half branches as-is, switch f32/f64 to fminf/fmin.

Apply:

-  // For float32 and float64 scalar, use standard fmin functions
+  // For float32 and float64 scalar, use IEEE fmin variants
   if (t.is_float() && t.is_scalar()) {
-    if (t.bits() == 32 || t.bits() == 64) {
-      os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
-    }
+    if (t.bits() == 32) {
+      os << "fminf(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    } else if (t.bits() == 64) {
+      os << "fmin(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    }
     return;
   }

1069-1098: Similarly, use fmaxf/fmax for float32/64.

Align with IEEE fmax semantics and avoid bare max.

Apply:

-  // For float32 and float64 scalar, use standard max functions
+  // For float32 and float64 scalar, use IEEE fmax variants
   if (t.is_float() && t.is_scalar()) {
-    if (t.bits() == 32 || t.bits() == 64) {
-      os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
-    }
+    if (t.bits() == 32) {
+      os << "fmaxf(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    } else if (t.bits() == 64) {
+      os << "fmax(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    }
     return;
   }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 60567ba and 7b30347.

📒 Files selected for processing (3)
  • src/target/codegen_cuda.cc (4 hunks)
  • src/target/codegen_cuda.h (1 hunks)
  • src/tl_templates/cuda/cuda_bf16_wrapper.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/target/codegen_cuda.h (1)
src/target/codegen_cuda.cc (22)
  • VisitExpr_ (887-1036)
  • VisitExpr_ (887-887)
  • VisitExpr_ (1038-1067)
  • VisitExpr_ (1038-1038)
  • VisitExpr_ (1069-1098)
  • VisitExpr_ (1069-1069)
  • VisitExpr_ (1303-2235)
  • VisitExpr_ (1303-1303)
  • VisitExpr_ (2368-2382)
  • VisitExpr_ (2368-2368)
  • VisitExpr_ (2384-2452)
  • VisitExpr_ (2384-2385)
  • VisitExpr_ (2454-2602)
  • VisitExpr_ (2454-2455)
  • VisitExpr_ (2672-2675)
  • VisitExpr_ (2672-2673)
  • op (217-232)
  • op (217-217)
  • op (1858-1860)
  • op (1858-1858)
  • op (1861-1863)
  • op (1861-1861)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (10)
  • VisitExpr_ (611-639)
  • VisitExpr_ (611-611)
  • VisitExpr_ (759-976)
  • VisitExpr_ (759-759)
  • VisitExpr_ (1039-1052)
  • VisitExpr_ (1039-1039)
  • VisitExpr_ (1054-1169)
  • VisitExpr_ (1054-1055)
  • VisitExpr_ (1217-1220)
  • VisitExpr_ (1217-1218)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (3)
src/target/codegen_cuda.h (1)

54-55: Good addition: explicit Min/Max visitors

Declarations align with the .cc implementations and keep the override surface clear. No further nits.

src/tl_templates/cuda/cuda_bf16_wrapper.h (1)

21-23: Unconditional includes look fine; verify device math availability.

The new min/max use fminf/fmaxf after the refactor above; ensure math functions are available via existing includes (cuda headers pulled by common.h). If not, include <math_functions.h>.

src/target/codegen_cuda.cc (1)

2608-2626: bf16 constants: enable bf16 path for all cases; INF/NAN macros.

  • Set enable_bf16_ for finite constants too (see diff above) to guarantee required headers get emitted.
  • Confirm CUDART_INF_BF16/CUDART_NAN_BF16 are available via included CUDA headers in your fallback chain; otherwise add need_math_constants_h_ or an explicit cuda_bf16 include in the Finish path.

Comment on lines 24 to 30
TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
  return (a < b) ? a : b;
}

TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
  return (a > b) ? a : b;
}

⚠️ Potential issue | 🟠 Major

bf16 min/max should mirror IEEE fmin/fmax semantics (NaN and signed zero).

The ternary + operator< approach returns NaN when one operand is NaN and may return +0 instead of −0. Use fminf/fmaxf on promoted float and convert back to bf16.

Apply:

-TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
-  return (a < b) ? a : b;
-}
-
-TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
-  return (a > b) ? a : b;
-}
+TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
+  float af = __bfloat162float(a);
+  float bf = __bfloat162float(b);
+  return __float2bfloat16(fminf(af, bf));
+}
+
+TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
+  float af = __bfloat162float(a);
+  float bf = __bfloat162float(b);
+  return __float2bfloat16(fmaxf(af, bf));
+}
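
As a side note on the NaN behavior flagged above, a small host-side check (not part of this PR) shows how the ternary form's result depends on operand order while fminf does not:

#include <cmath>
#include <cstdio>

int main() {
  float q = std::nanf("");
  float x = 1.0f;
  // Ternary form: the result depends on operand order when one side is NaN.
  float t1 = (q < x) ? q : x;  // NaN < 1 is false, so this yields x (1.0)
  float t2 = (x < q) ? x : q;  // 1 < NaN is false, so this yields q (NaN)
  std::printf("ternary: %f %f\n", t1, t2);
  // fminf returns the non-NaN operand regardless of order (C99/IEEE 754).
  std::printf("fminf:   %f %f\n", std::fminf(q, x), std::fminf(x, q));
  return 0;
}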

Optional: avoid Windows macro collisions

+#if defined(_WIN32)
+#  ifdef min
+#    undef min
+#  endif
+#  ifdef max
+#    undef max
+#  endif
+#endif
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
-  return (a < b) ? a : b;
-}
-
-TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
-  return (a > b) ? a : b;
-}
+TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
+  float af = __bfloat162float(a);
+  float bf = __bfloat162float(b);
+  return __float2bfloat16(fminf(af, bf));
+}
+
+TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
+  float af = __bfloat162float(a);
+  float bf = __bfloat162float(b);
+  return __float2bfloat16(fmaxf(af, bf));
+}
🤖 Prompt for AI Agents
In src/tl_templates/cuda/cuda_bf16_wrapper.h around lines 24 to 30, the current
ternary implementations of min/max use operator< which does not follow IEEE
fmin/fmax semantics for NaN and signed zero; replace them by promoting the
bfloat16 operands to float, call fminf/fmaxf to preserve NaN propagation and
signed-zero behavior, then convert the result back to bfloat16. Also guard
against Windows macro collisions by undefining min/max macros (or use
parentheses around fminf/fmaxf) before defining these functions.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
src/target/codegen_cuda.cc (1)

2621-2625: Critical: Finite bfloat16 constants don't set enable_bf16_ flag.

This is the same issue flagged in the previous review. Finite bfloat16 constants (lines 2621-2624) don't set enable_bf16_ = true, while infinity and NaN cases do (lines 2616, 2619). This means kernels that only use finite bf16 constants won't include the necessary headers, causing compilation failures.

Apply this diff to fix:

     } else {
       temp << "bfloat16_t";
       temp << '(' << std::hexfloat << op->value << 'f';
       temp << "/*" << std::scientific << op->value << "*/";
       temp << ')';
+      p->enable_bf16_ = true;
     }
     p->MarkConst(temp.str());
     os << temp.str();
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7b30347 and 5929302.

📒 Files selected for processing (1)
  • src/target/codegen_cuda.cc (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (2)
src/target/codegen_hip.cc (10)
  • VisitExpr_ (611-639)
  • VisitExpr_ (611-611)
  • VisitExpr_ (759-976)
  • VisitExpr_ (759-759)
  • VisitExpr_ (1039-1052)
  • VisitExpr_ (1039-1039)
  • VisitExpr_ (1054-1169)
  • VisitExpr_ (1054-1055)
  • VisitExpr_ (1217-1220)
  • VisitExpr_ (1217-1218)
src/tl_templates/cpp/half.hpp (4)
  • isinf (2676-2676)
  • isinf (5290-5290)
  • isnan (2677-2677)
  • isnan (5298-5298)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)

Comment on lines 283 to 286

if (enable_bf16_) {
  decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
}

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify that enable_bf16_ is set for all bfloat16 usage.

The conditional inclusion is good practice, but it depends on enable_bf16_ being correctly set whenever bfloat16 types or operations are used. A previous review identified that finite bfloat16 constants don't set this flag (see lines 2621-2624), which could cause missing headers.

Verify that all bfloat16 code paths set enable_bf16_ = true:


🏁 Script executed:

#!/bin/bash
# Check all locations where bfloat16 types are handled
rg -n -C3 'is_bfloat16\(\)|\.is_bfloat16|bfloat16_t' src/target/codegen_cuda.cc

Length of output: 4166


Add enable_bf16_ = true; for finite bfloat16 constants at lines 2620-2624.

The conditional header inclusion at lines 283-286 depends on enable_bf16_ being set. However, the finite bfloat16 constant branch (lines 2620-2624) generates bfloat16_t casts without setting this flag. This causes kernels using only finite bfloat16 constants to skip the required header while still using the bfloat16_t type, leading to compilation failures.

Add p->enable_bf16_ = true; in the else branch at line 2620.

🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 2620 to 2624 and 283 to 286, the
finite bfloat16 constant branch creates bfloat16_t casts but does not set
enable_bf16_, causing the conditional header inclusion to be skipped and leading
to compilation failures; fix this by adding p->enable_bf16_ = true; in the else
branch around line 2620 so that the cuda_bf16_fallbacks.cuh header is included
when bfloat16_t is emitted.

@Rachmanino (Collaborator, Author) commented:

I've just refactored the code to directly use CUTLASS APIs.

LeiWang1999 previously approved these changes Oct 28, 2025
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (3)
src/target/codegen_cuda.cc (3)

1072-1077: Fix control flow: float scalars with non-standard bit widths generate no code.

Same issue as in VisitExpr_(const MinNode*) - when t.is_float() && t.is_scalar() is true but bits() is neither 32 nor 64, the function returns without emitting any code.

Apply this diff:

   // For float32 and float64 scalar, use standard max functions
   if (t.is_float() && t.is_scalar()) {
     if (t.bits() == 32 || t.bits() == 64) {
       os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+      return;
     }
-    return;
   }
   
   // For all other scalar types (int, uint), use default implementation
   CodeGenC::VisitExpr_(op, os);

1050-1055: Fix control flow: float scalars with non-standard bit widths generate no code.

When t.is_float() && t.is_scalar() is true but bits() is neither 32 nor 64, the function returns without emitting any code. This results in missing min operations for float8, float16 (if it somehow reaches here), or other float types.

Apply this diff to fix the control flow:

   // For float32 and float64 scalar, use standard min functions
   if (t.is_float() && t.is_scalar()) {
     if (t.bits() == 32 || t.bits() == 64) {
       os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+      return;
     }
-    return;
   }
   
   // For all other scalar types (int, uint), use default implementation
   CodeGenC::VisitExpr_(op, os);

This ensures non-32/64-bit float scalars fall through to the base implementation.


2591-2608: Set enable_bf16_ flag for bfloat16 constants to ensure header inclusion.

When handling bfloat16 constants (including infinity and NaN), the code doesn't set p->enable_bf16_ = true. The conditional header inclusion depends on this flag being set, so kernels that only use bfloat16 constants will fail to compile with missing type definitions.

Apply this diff:

   if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
     std::ostringstream temp;
     if (std::isinf(op->value)) {
       if (op->value < 0) {
         temp << "-";
       }
       temp << "std::numeric_limits<" << op->dtype << ">::infinity()";
+      if (op->dtype.is_bfloat16()) {
+        p->enable_bf16_ = true;
+      }
     } else if (std::isnan(op->value)) {
       temp << "std::numeric_limits<" << op->dtype << ">::quiet_NaN()";
+      if (op->dtype.is_bfloat16()) {
+        p->enable_bf16_ = true;
+      }
     } else {
       temp << op->dtype;
       temp << '(' << std::hexfloat << op->value << 'f';
       temp << "/*" << std::scientific << op->value << "*/";
       temp << ')';
+      if (op->dtype.is_bfloat16()) {
+        p->enable_bf16_ = true;
+      }
     }
     p->MarkConst(temp.str());
     os << temp.str();
🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)

1066-1066: Minor style: Missing space in condition.

Line 1066 has t.is_float16())&& which should be t.is_float16()) && for consistency with coding style.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5929302 and bc51577.

📒 Files selected for processing (1)
  • src/target/codegen_cuda.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (10)
  • VisitExpr_ (611-639)
  • VisitExpr_ (611-611)
  • VisitExpr_ (759-976)
  • VisitExpr_ (759-759)
  • VisitExpr_ (1039-1052)
  • VisitExpr_ (1039-1039)
  • VisitExpr_ (1054-1169)
  • VisitExpr_ (1054-1055)
  • VisitExpr_ (1217-1220)
  • VisitExpr_ (1217-1218)
🔇 Additional comments (1)
src/target/codegen_cuda.cc (1)

1020-1021: LGTM: Early return prevents fall-through.

The early return after FP8 conversion correctly prevents the code from continuing to the element-wise cast fallback.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
src/target/codegen_cuda.cc (1)

2595-2607: Fix invalid C++ type identifiers in std::numeric_limits calls.

When op->dtype is streamed (lines 2599, 2601, 2603), it outputs TVM's dtype representation (e.g., "bfloat16", "float16") rather than valid CUDA type names (e.g., "bfloat16_t", "half_t"). This generates invalid C++ code like std::numeric_limits<bfloat16>::infinity() that won't compile.

Replace dtype streaming with explicit CUDA type names, following the pattern used at lines 347 and 404:

   if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
     std::ostringstream temp;
+    std::string cuda_type = op->dtype.is_bfloat16() ? "bfloat16_t" : "half_t";
     if (std::isinf(op->value)) {
       if (op->value < 0) {
         temp << "-";
       }
-      temp << "std::numeric_limits<" << op->dtype << ">::infinity()";
+      temp << "std::numeric_limits<" << cuda_type << ">::infinity()";
     } else if (std::isnan(op->value)) {
-      temp << "std::numeric_limits<" << op->dtype << ">::quiet_NaN()";
+      temp << "std::numeric_limits<" << cuda_type << ">::quiet_NaN()";
     } else {
-      temp << op->dtype;
+      temp << cuda_type;
       temp << '(' << std::hexfloat << op->value << 'f';
       temp << "/*" << std::scientific << op->value << "*/";
       temp << ')';
     }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bc51577 and b2bb08b.

📒 Files selected for processing (1)
  • src/target/codegen_cuda.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (10)
  • VisitExpr_ (611-639)
  • VisitExpr_ (611-611)
  • VisitExpr_ (759-976)
  • VisitExpr_ (759-759)
  • VisitExpr_ (1039-1052)
  • VisitExpr_ (1039-1039)
  • VisitExpr_ (1054-1169)
  • VisitExpr_ (1054-1055)
  • VisitExpr_ (1217-1220)
  • VisitExpr_ (1217-1218)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (1)
src/target/codegen_cuda.cc (1)

1039-1083: LGTM! Well-structured min/max visitors for CUDA types.

The type-dispatched implementations correctly route scalar bfloat16/float16 to CUTLASS optimized functions, scalar float32/64 to standard min/max, and other types to the base implementation. The control flow is clear and complete.

Comment on lines +2608 to +2609
p->MarkConst(temp.str());
os << temp.str();

🛠️ Refactor suggestion | 🟠 Major

Set enable flags for consistency when emitting bf16/fp16 constants.

When emitting bfloat16_t and half_t type names in constant expressions, the corresponding feature flags should be set for consistency with PrintType (lines 345, 402) and to signal that CUDA-specific types are in use.

       temp << ')';
     }
+    if (op->dtype.is_bfloat16()) {
+      p->enable_bf16_ = true;
+    } else {
+      p->enable_fp16_ = true;
+    }
     p->MarkConst(temp.str());
     os << temp.str();
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-    p->MarkConst(temp.str());
-    os << temp.str();
+    if (op->dtype.is_bfloat16()) {
+      p->enable_bf16_ = true;
+    } else {
+      p->enable_fp16_ = true;
+    }
+    p->MarkConst(temp.str());
+    os << temp.str();
🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 2608-2609, when emitting bf16/fp16
constant type names you must set the corresponding CUDA feature flags before
marking and writing the constant: if the emitted name is "bfloat16_t" set
p->enable_bf16_ = true; if the emitted name is "half_t" set p->enable_fp16_ =
true; then call p->MarkConst(temp.str()) and os << temp.str() so PrintType and
other code see the correct feature usage.

@LeiWang1999 merged commit c70b269 into tile-ai:main Oct 28, 2025 (5 of 6 checks passed)
RubiaCx pushed a commit to RubiaCx/tilelang that referenced this pull request Nov 24, 2025
…max functions and inf/nan values (tile-ai#1143)

* Implement bfloat16 support in CUDA code generation with min/max functions and inf/nan values

* refactor

* fix prev typo

* bugfix

* lint

* bugfix
