[BugFix] Implement bfloat16 support in CUDA code generation with min/max functions and inf/nan values #1143
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Added CUDA codegen visitors for MinNode/MaxNode with type-dispatched emissions (cutlass::fast_min/fast_max for scalar bf16/fp16, std::min/max for scalar float32/float64, fallback to base CodeGenC). Updated Cast handling to early-return after FP8→float conversions and improved bf16/float16 constant serialization.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Visitor as CodeGenTileLangCUDA Visitor
    participant TypeCheck as Operand Type Check
    participant Cutlass as cutlass::fast_min/max
    participant Std as std::min/max
    participant Base as Base CodeGenC

    Visitor->>TypeCheck: inspect operand types
    alt scalar bf16 or float16
        TypeCheck->>Cutlass: emit cutlass::fast_min/fast_max
        Cutlass-->>Visitor: generated expression
    else scalar float32 or float64
        TypeCheck->>Std: emit std::min/std::max
        Std-->>Visitor: generated expression
    else other types
        TypeCheck->>Base: delegate to base CodeGenC visitor
        Base-->>Visitor: generated expression
    end
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
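As a reading aid, here is a minimal sketch of the dispatch the walkthrough describes, shown for the Min case (assumed shape only; the actual visitor in src/target/codegen_cuda.cc may differ in details such as vector handling, and the emitted float32/float64 call is written here as a bare min):

```cpp
// Sketch of the type-dispatched Min emission (Max is symmetric).
void CodeGenTileLangCUDA::VisitExpr_(const MinNode* op, std::ostream& os) {
  DataType t = op->dtype;
  if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
    // Scalar bf16/fp16: route through the CUTLASS helpers.
    os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
    return;
  }
  if (t.is_float() && t.is_scalar() && (t.bits() == 32 || t.bits() == 64)) {
    // Scalar fp32/fp64: standard min.
    os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
    return;
  }
  // Everything else (ints, vectors, ...): fall back to the base visitor.
  CodeGenC::VisitExpr_(op, os);
}
```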
Actionable comments posted: 2
🧹 Nitpick comments (2)
src/target/codegen_cuda.cc (2)
1038-1067: Prefer fminf/fmin for float32/64 to avoid macro/overload pitfalls and match IEEE semantics.

Using bare min can bind unintended overloads/macros and diverge for NaNs/±0. Keep bf16/half branches as-is, switch f32/f64 to fminf/fmin.
Apply:

```diff
-  // For float32 and float64 scalar, use standard fmin functions
+  // For float32 and float64 scalar, use IEEE fmin variants
   if (t.is_float() && t.is_scalar()) {
-    if (t.bits() == 32 || t.bits() == 64) {
-      os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
-    }
+    if (t.bits() == 32) {
+      os << "fminf(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    } else if (t.bits() == 64) {
+      os << "fmin(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    }
     return;
   }
```
1069-1098: Similarly, use fmaxf/fmax for float32/64.

Align with IEEE fmax semantics and avoid bare max.
Apply:

```diff
-  // For float32 and float64 scalar, use standard max functions
+  // For float32 and float64 scalar, use IEEE fmax variants
   if (t.is_float() && t.is_scalar()) {
-    if (t.bits() == 32 || t.bits() == 64) {
-      os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
-    }
+    if (t.bits() == 32) {
+      os << "fmaxf(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    } else if (t.bits() == 64) {
+      os << "fmax(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+    }
     return;
   }
```
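To make the NaN point in these two nitpicks concrete, here is a small host-side illustration (not part of the PR; the same semantics apply to CUDA's device-side fminf):

```cpp
#include <cmath>
#include <cstdio>

// Comparison-based "min", the pattern the review warns can diverge for NaN.
static float naive_min(float a, float b) { return (a < b) ? a : b; }

int main() {
  float n = nanf("");
  // Any ordered comparison with NaN is false, so the ternary returns its second operand:
  printf("naive_min(1.0f, NaN) = %f\n", naive_min(1.0f, n));  // prints nan
  // IEEE 754 fmin returns the non-NaN operand instead:
  printf("fminf(1.0f, NaN)     = %f\n", fminf(1.0f, n));      // prints 1.000000
  return 0;
}
```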
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/target/codegen_cuda.cc (4 hunks)
src/target/codegen_cuda.h (1 hunks)
src/tl_templates/cuda/cuda_bf16_wrapper.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/target/codegen_cuda.h (1)
src/target/codegen_cuda.cc (22)
VisitExpr_ (887-1036), VisitExpr_ (887-887), VisitExpr_ (1038-1067), VisitExpr_ (1038-1038), VisitExpr_ (1069-1098), VisitExpr_ (1069-1069), VisitExpr_ (1303-2235), VisitExpr_ (1303-1303), VisitExpr_ (2368-2382), VisitExpr_ (2368-2368), VisitExpr_ (2384-2452), VisitExpr_ (2384-2385), VisitExpr_ (2454-2602), VisitExpr_ (2454-2455), VisitExpr_ (2672-2675), VisitExpr_ (2672-2673), op (217-232), op (217-217), op (1858-1860), op (1858-1858), op (1861-1863), op (1861-1861)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (10)
VisitExpr_ (611-639), VisitExpr_ (611-611), VisitExpr_ (759-976), VisitExpr_ (759-759), VisitExpr_ (1039-1052), VisitExpr_ (1039-1039), VisitExpr_ (1054-1169), VisitExpr_ (1054-1055), VisitExpr_ (1217-1220), VisitExpr_ (1217-1218)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (3)
src/target/codegen_cuda.h (1)
54-55: Good addition: explicit Min/Max visitors.

Declarations align with the .cc implementations and keep the override surface clear. No further nits.
src/tl_templates/cuda/cuda_bf16_wrapper.h (1)
21-23: Unconditional includes look fine; verify device math availability.

The new min/max use fminf/fmaxf after the refactor above; ensure math functions are available via existing includes (cuda headers pulled by common.h). If not, include <math_functions.h>.
src/target/codegen_cuda.cc (1)
2608-2626: bf16 constants: enable bf16 path for all cases; INF/NAN macros.
- Set enable_bf16_ for finite constants too (see diff above) to guarantee required headers get emitted.
- Confirm CUDART_INF_BF16/CUDART_NAN_BF16 are available via included CUDA headers in your fallback chain; otherwise add need_math_constants_h_ or an explicit cuda_bf16 include in the Finish path.
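For reference, the two macros named above come from the toolkit's cuda_bf16.h; a tiny device-side sketch of their intended use (treat their availability as an assumption to verify, exactly as the comment asks, since older toolkits may not define them):

```cpp
#include <cuda_bf16.h>

// Minimal check that the bf16 special-value macros are usable in device code.
__device__ __nv_bfloat16 bf16_special(bool want_nan) {
  return want_nan ? CUDART_NAN_BF16 : CUDART_INF_BF16;  // NaN / +inf bfloat16 constants
}
```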
```cpp
TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
  return (a < b) ? a : b;
}

TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
  return (a > b) ? a : b;
}
```
bf16 min/max should mirror IEEE fmin/fmax semantics (NaN and signed zero).
The ternary + operator< approach can return NaN when one operand is NaN (IEEE fmin/fmax return the non-NaN operand instead) and may return +0 instead of −0. Use fminf/fmaxf on promoted float and convert back to bf16.
Apply:

```diff
-TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
-  return (a < b) ? a : b;
-}
-
-TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
-  return (a > b) ? a : b;
-}
+TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
+  float af = __bfloat162float(a);
+  float bf = __bfloat162float(b);
+  return __float2bfloat16(fminf(af, bf));
+}
+
+TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
+  float af = __bfloat162float(a);
+  float bf = __bfloat162float(b);
+  return __float2bfloat16(fmaxf(af, bf));
+}
```

Optional: avoid Windows macro collisions

```diff
+#if defined(_WIN32)
+# ifdef min
+# undef min
+# endif
+# ifdef max
+# undef max
+# endif
+#endif
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
-  return (a < b) ? a : b;
-}
-TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
-  return (a > b) ? a : b;
-}
+TL_DEVICE bfloat16_t min(const bfloat16_t a, const bfloat16_t b) {
+  float af = __bfloat162float(a);
+  float bf = __bfloat162float(b);
+  return __float2bfloat16(fminf(af, bf));
+}
+TL_DEVICE bfloat16_t max(const bfloat16_t a, const bfloat16_t b) {
+  float af = __bfloat162float(a);
+  float bf = __bfloat162float(b);
+  return __float2bfloat16(fmaxf(af, bf));
+}
```
🤖 Prompt for AI Agents
In src/tl_templates/cuda/cuda_bf16_wrapper.h around lines 24 to 30, the current
ternary implementations of min/max use operator< which does not follow IEEE
fmin/fmax semantics for NaN and signed zero; replace them by promoting the
bfloat16 operands to float, call fminf/fmaxf to preserve NaN propagation and
signed-zero behavior, then convert the result back to bfloat16. Also guard
against Windows macro collisions by undefining min/max macros (or use
parentheses around fminf/fmaxf) before defining these functions.
Actionable comments posted: 3
♻️ Duplicate comments (1)
src/target/codegen_cuda.cc (1)
2621-2625: Critical: Finite bfloat16 constants don't set the enable_bf16_ flag.

This is the same issue flagged in the previous review. Finite bfloat16 constants (lines 2621-2624) don't set enable_bf16_ = true, while the infinity and NaN cases do (lines 2616, 2619). This means kernels that only use finite bf16 constants won't include the necessary headers, causing compilation failures.

Apply this diff to fix:

```diff
       } else {
         temp << "bfloat16_t";
         temp << '(' << std::hexfloat << op->value << 'f';
         temp << "/*" << std::scientific << op->value << "*/";
         temp << ')';
+        p->enable_bf16_ = true;
       }
       p->MarkConst(temp.str());
       os << temp.str();
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/target/codegen_cuda.cc (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (2)
src/target/codegen_hip.cc (10)
VisitExpr_ (611-639), VisitExpr_ (611-611), VisitExpr_ (759-976), VisitExpr_ (759-759), VisitExpr_ (1039-1052), VisitExpr_ (1039-1039), VisitExpr_ (1054-1169), VisitExpr_ (1054-1055), VisitExpr_ (1217-1220), VisitExpr_ (1217-1218)
src/tl_templates/cpp/half.hpp (4)
isinf (2676-2676), isinf (5290-5290), isnan (2677-2677), isnan (5298-5298)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
src/target/codegen_cuda.cc
Outdated

```cpp
if (enable_bf16_) {
  decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
}
```
🧩 Analysis chain
Verify that enable_bf16_ is set for all bfloat16 usage.
The conditional inclusion is good practice, but it depends on enable_bf16_ being correctly set whenever bfloat16 types or operations are used. A previous review identified that finite bfloat16 constants don't set this flag (see lines 2621-2624), which could cause missing headers.
Verify that all bfloat16 code paths set enable_bf16_ = true:
🏁 Script executed:

```bash
#!/bin/bash
# Check all locations where bfloat16 types are handled
rg -n -C3 'is_bfloat16\(\)|\.is_bfloat16|bfloat16_t' src/target/codegen_cuda.cc
```

Length of output: 4166
Add enable_bf16_ = true; for finite bfloat16 constants at lines 2620-2624.
The conditional header inclusion at lines 283-286 depends on enable_bf16_ being set. However, the finite bfloat16 constant branch (lines 2620-2624) generates bfloat16_t casts without setting this flag. This causes kernels using only finite bfloat16 constants to skip the required header while still using the bfloat16_t type, leading to compilation failures.
Add p->enable_bf16_ = true; in the else branch at line 2620.
🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 2620 to 2624 and 283 to 286, the
finite bfloat16 constant branch creates bfloat16_t casts but does not set
enable_bf16_, causing the conditional header inclusion to be skipped and leading
to compilation failures; fix this by adding p->enable_bf16_ = true; in the else
branch around line 2620 so that the cuda_bf16_fallbacks.cuh header is included
when bfloat16_t is emitted.
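For reference, a minimal sketch of how the flag gates the header (assumed shape; the real Finish() in codegen_cuda.cc emits several other conditional includes before delegating):

```cpp
std::string CodeGenTileLangCUDA::Finish() {
  if (enable_bf16_) {
    // Only emitted when some bf16 type, constant, or operation was actually generated.
    decl_stream << "#include <tl_templates/cuda/cuda_bf16_fallbacks.cuh>\n";
  }
  return CodeGenC::Finish();  // assembles the declaration stream plus the generated body
}
```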
I've just refactored the code to directly use CUTLASS APIs.
Actionable comments posted: 1
♻️ Duplicate comments (3)
src/target/codegen_cuda.cc (3)
1072-1077: Fix control flow: float scalars with non-standard bit widths generate no code.

Same issue as in VisitExpr_(const MinNode*): when t.is_float() && t.is_scalar() is true but bits() is neither 32 nor 64, the function returns without emitting any code.

Apply this diff:

```diff
   // For float32 and float64 scalar, use standard max functions
   if (t.is_float() && t.is_scalar()) {
     if (t.bits() == 32 || t.bits() == 64) {
       os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+      return;
     }
-    return;
   }
   // For all other scalar types (int, uint), use default implementation
   CodeGenC::VisitExpr_(op, os);
```
1050-1055: Fix control flow: float scalars with non-standard bit widths generate no code.

When t.is_float() && t.is_scalar() is true but bits() is neither 32 nor 64, the function returns without emitting any code. This results in missing min operations for float8, float16 (if it somehow reaches here), or other float types.

Apply this diff to fix the control flow:

```diff
   // For float32 and float64 scalar, use standard min functions
   if (t.is_float() && t.is_scalar()) {
     if (t.bits() == 32 || t.bits() == 64) {
       os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
+      return;
     }
-    return;
   }
   // For all other scalar types (int, uint), use default implementation
   CodeGenC::VisitExpr_(op, os);
```

This ensures non-32/64-bit float scalars fall through to the base implementation.
2591-2608: Set the enable_bf16_ flag for bfloat16 constants to ensure header inclusion.

When handling bfloat16 constants (including infinity and NaN), the code doesn't set p->enable_bf16_ = true. The conditional header inclusion depends on this flag being set, so kernels that only use bfloat16 constants will fail to compile with missing type definitions.

Apply this diff:

```diff
   if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
     std::ostringstream temp;
     if (std::isinf(op->value)) {
       if (op->value < 0) {
         temp << "-";
       }
       temp << "std::numeric_limits<" << op->dtype << ">::infinity()";
+      if (op->dtype.is_bfloat16()) {
+        p->enable_bf16_ = true;
+      }
     } else if (std::isnan(op->value)) {
       temp << "std::numeric_limits<" << op->dtype << ">::quiet_NaN()";
+      if (op->dtype.is_bfloat16()) {
+        p->enable_bf16_ = true;
+      }
     } else {
       temp << op->dtype;
       temp << '(' << std::hexfloat << op->value << 'f';
       temp << "/*" << std::scientific << op->value << "*/";
       temp << ')';
+      if (op->dtype.is_bfloat16()) {
+        p->enable_bf16_ = true;
+      }
     }
     p->MarkConst(temp.str());
     os << temp.str();
```
🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)
1066-1066: Minor style: missing space in condition.

Line 1066 has `t.is_float16())&&`, which should be `t.is_float16()) &&` for consistency with the coding style.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/target/codegen_cuda.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (10)
VisitExpr_ (611-639), VisitExpr_ (611-611), VisitExpr_ (759-976), VisitExpr_ (759-759), VisitExpr_ (1039-1052), VisitExpr_ (1039-1039), VisitExpr_ (1054-1169), VisitExpr_ (1054-1055), VisitExpr_ (1217-1220), VisitExpr_ (1217-1218)
🔇 Additional comments (1)
src/target/codegen_cuda.cc (1)
1020-1021: LGTM: Early return prevents fall-through.

The early return after FP8 conversion correctly prevents the code from continuing to the element-wise cast fallback.
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/target/codegen_cuda.cc (1)
2595-2607: Fix invalid C++ type identifiers in std::numeric_limits calls.

When op->dtype is streamed (lines 2599, 2601, 2603), it outputs TVM's dtype representation (e.g., "bfloat16", "float16") rather than valid CUDA type names (e.g., "bfloat16_t", "half_t"). This generates invalid C++ code like std::numeric_limits<bfloat16>::infinity() that won't compile.

Replace dtype streaming with explicit CUDA type names, following the pattern used at lines 347 and 404:

```diff
   if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
     std::ostringstream temp;
+    std::string cuda_type = op->dtype.is_bfloat16() ? "bfloat16_t" : "half_t";
     if (std::isinf(op->value)) {
       if (op->value < 0) {
         temp << "-";
       }
-      temp << "std::numeric_limits<" << op->dtype << ">::infinity()";
+      temp << "std::numeric_limits<" << cuda_type << ">::infinity()";
     } else if (std::isnan(op->value)) {
-      temp << "std::numeric_limits<" << op->dtype << ">::quiet_NaN()";
+      temp << "std::numeric_limits<" << cuda_type << ">::quiet_NaN()";
     } else {
-      temp << op->dtype;
+      temp << cuda_type;
       temp << '(' << std::hexfloat << op->value << 'f';
       temp << "/*" << std::scientific << op->value << "*/";
       temp << ')';
     }
```
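To make the before/after concrete, here is a hand-written snippet of the kind of expression the fixed codegen emits (illustrative only; it assumes bfloat16_t and half_t are the CUDA-side types from the TileLang/CUTLASS headers and that std::numeric_limits is specialized for them, which the generated code relies on):

```cpp
#include <limits>

// Assumes the TileLang-generated prelude already declares bfloat16_t / half_t.
__device__ void fill_specials(bfloat16_t* out_bf16, half_t* out_fp16) {
  out_bf16[0] = std::numeric_limits<bfloat16_t>::infinity();   // valid: bfloat16_t is a real C++ type
  out_bf16[1] = std::numeric_limits<bfloat16_t>::quiet_NaN();
  out_fp16[0] = std::numeric_limits<half_t>::infinity();
  // std::numeric_limits<bfloat16>::infinity() would not compile:
  // "bfloat16" is only TVM's dtype string, not a declared C++ type.
}
```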
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/target/codegen_cuda.cc (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (10)
VisitExpr_ (611-639), VisitExpr_ (611-611), VisitExpr_ (759-976), VisitExpr_ (759-759), VisitExpr_ (1039-1052), VisitExpr_ (1039-1039), VisitExpr_ (1054-1169), VisitExpr_ (1054-1055), VisitExpr_ (1217-1220), VisitExpr_ (1217-1218)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (1)
src/target/codegen_cuda.cc (1)
1039-1083: LGTM! Well-structured min/max visitors for CUDA types.

The type-dispatched implementations correctly route scalar bfloat16/float16 to CUTLASS optimized functions, scalar float32/64 to standard min/max, and other types to the base implementation. The control flow is clear and complete.
```cpp
p->MarkConst(temp.str());
os << temp.str();
```
🛠️ Refactor suggestion | 🟠 Major
Set enable flags for consistency when emitting bf16/fp16 constants.
When emitting bfloat16_t and half_t type names in constant expressions, the corresponding feature flags should be set for consistency with PrintType (lines 345, 402) and to signal that CUDA-specific types are in use.
```diff
       temp << ')';
     }
+    if (op->dtype.is_bfloat16()) {
+      p->enable_bf16_ = true;
+    } else {
+      p->enable_fp16_ = true;
+    }
     p->MarkConst(temp.str());
     os << temp.str();
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-p->MarkConst(temp.str());
-os << temp.str();
+if (op->dtype.is_bfloat16()) {
+  p->enable_bf16_ = true;
+} else {
+  p->enable_fp16_ = true;
+}
+p->MarkConst(temp.str());
+os << temp.str();
```
🤖 Prompt for AI Agents
In src/target/codegen_cuda.cc around lines 2608-2609, when emitting bf16/fp16
constant type names you must set the corresponding CUDA feature flags before
marking and writing the constant: if the emitted name is "bfloat16_t" set
p->enable_bf16_ = true; if the emitted name is "half_t" set p->enable_fp16_ =
true; then call p->MarkConst(temp.str()) and os << temp.str() so PrintType and
other code see the correct feature usage.
…max functions and inf/nan values (tile-ai#1143)

* Implement bfloat16 support in CUDA code generation with min/max functions and inf/nan values
* refactor
* fix prev typo
* bugfix
* lint
* bugfix
This pull request improves CUDA code generation for bfloat16 and other floating-point types, particularly for min/max operations and constant printing. The changes ensure correct handling of bfloat16, float16, and float32/64 types, and conditionally include necessary headers only when bfloat16 support is required.
Enhanced min/max operation support

* Added VisitExpr_ methods for MinNode and MaxNode in CodeGenTileLangCUDA, ensuring correct CUDA code generation for bfloat16 (uses custom functions), float16 (uses __hmin/__hmax), and float32/64 (uses standard min/max). [1] [2]
* Added min and max functions for bfloat16_t in cuda_bf16_wrapper.h to support comparisons, replacing conditional compilation with unconditional inclusion.

Improved bfloat16 constant printing

* Updated PrintConst to correctly handle bfloat16 constants, including special cases for infinity and NaN, and to mark when bfloat16 support is needed.

Conditional header inclusion

* Updated Finish() to include cuda_bf16_fallbacks.cuh only if bfloat16 is used, removing the previous preprocessor guard.

Minor clarifications
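For orientation, a rough sketch of the expression shapes the updated codegen produces (hand-written illustration, not actual TileLang output; bfloat16_t, cutlass::fast_min/fast_max, and the numeric_limits specializations are assumed to come from the CUTLASS-backed TileLang headers in the generated prelude):

```cpp
__device__ void example(bfloat16_t a, bfloat16_t b, float x, float y, bfloat16_t* out) {
  out[0] = cutlass::fast_min(a, b);                       // scalar bf16/fp16 min
  out[1] = cutlass::fast_max(a, b);                       // scalar bf16/fp16 max
  float f = min(x, y);                                    // scalar fp32/fp64 keeps plain min/max
  out[2] = std::numeric_limits<bfloat16_t>::infinity();   // bf16 +inf constant
  out[3] = std::numeric_limits<bfloat16_t>::quiet_NaN();  // bf16 NaN constant
  out[4] = bfloat16_t(f);                                 // finite constant-style cast
}
```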
Summary by CodeRabbit
New Features
Bug Fixes / Improvements