Fix fp4 pointer arithmetic in CUDA codegen #1524
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Pre-simplified index division in `GetBufferRef` for 4-bit, 1-bit, and fp4 scalar types using
Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/target/codegen_cuda.cc (1)
1329-1338: Fix looks correct, but consider refactoring to reduce duplication.
The fp4 pointer arithmetic fix correctly addresses the issue where scalar fp4 pointer arithmetic was treating each 4-bit element as a full byte. The division by 2 accounts for packing 2 fp4 elements per byte.
However, there's code duplication with lines 1322-1323. Consider refactoring:
🔎 Suggested refactoring to reduce duplication
```diff
   } else if (t == buffer_element_dtype) {
     os << buffer_str << "[" << index_str << "]";
   } else {
-    // Fix fp4 pointer arithmetic: fp4 elements are 4-bit packed 2 per byte.
-    // fp4* + n incorrectly advances n bytes (skipping 2n elements).
-    int div_factor = 1;
-    if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) {
-      div_factor = 2;
-    }
-    index_str =
-        PrintExpr(arith::Analyzer().Simplify(truncdiv(index, div_factor)));
-
+    // Fix fp4 pointer arithmetic: scalar fp4 elements are 4-bit packed, 2 per byte.
+    // Since fp4* + n advances n bytes, divide index by 2 to get correct byte offset.
+    if (buffer_element_dtype.is_float4() && buffer_element_dtype.lanes() == 1) {
+      index_str = PrintExpr(arith::Analyzer().Simplify(truncdiv(index, 2)));
+    }
     os << "*" << ptr_cast(t) << "(" << buffer_str << " + " << index_str << ")";
   }
```
This eliminates the `div_factor` variable and makes the logic more explicit.
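To illustrate why the element index is halved, here is a minimal standalone sketch of packed 4-bit addressing. It is not TileLang code: `Fp4ByteOffset` and `LoadFp4Bits` are hypothetical helpers, and the nibble order (even index in the low nibble) is an assumption made only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not part of TileLang): byte offset of the packed
// 4-bit element at `index`, with two elements stored per byte.
constexpr std::size_t Fp4ByteOffset(std::size_t index) { return index / 2; }

// Hypothetical helper: return the raw 4-bit pattern of element `index`.
// Assumption: even indices occupy the low nibble, odd indices the high
// nibble. The layout used by a real fp4 storage type may differ.
inline std::uint8_t LoadFp4Bits(const std::uint8_t* base, std::size_t index) {
  assert(base != nullptr);
  const std::uint8_t byte = base[Fp4ByteOffset(index)];
  return (index % 2 == 0) ? static_cast<std::uint8_t>(byte & 0x0F)
                          : static_cast<std::uint8_t>(byte >> 4);
}
```

With a two-byte buffer `{0x21, 0x43}`, elements 2 and 3 both resolve to byte offset 1. Advancing a byte-sized pointer by the raw element index would instead step past the end of the buffer for index 2, which is the kind of over-advance the review describes.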
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/target/codegen_cuda.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (1)
tilelang/language/tir/op.py (1)
truncdiv (3119-3142)
🔇 Additional comments (2)
src/target/codegen_cuda.cc (2)
1322-1325: LGTM: Pre-simplification improves code generation.
The pre-simplification of the index division using `arith::Analyzer().Simplify(truncdiv(index, div_factor))` is a good approach. This allows the analyzer to fold constants and simplify expressions before code generation, potentially producing more efficient pointer arithmetic code.
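The effect of folding the division before printing can be sketched with a toy model. This is not the TVM analyzer or its API: `AffineIndex` and `PrintHalvedIndex` are hypothetical names, and the model only handles indices of the form `scale * var + offset` with a positive scale and a non-negative offset.

```cpp
#include <cassert>
#include <string>

// Toy stand-in for a symbolic index of the form scale * var + offset.
// Hypothetical type for illustration only; not a TVM or TileLang class.
struct AffineIndex {
  long scale;
  std::string var;
  long offset;
};

// Print truncdiv(idx, 2). When both coefficients are even, the division
// is exact and can be folded away before any code is emitted.
inline std::string PrintHalvedIndex(const AffineIndex& idx) {
  assert(idx.scale > 0 && idx.offset >= 0);
  auto print = [](long scale, const std::string& var, long offset) {
    std::string s = (scale == 1) ? var : std::to_string(scale) + " * " + var;
    if (offset != 0) s += " + " + std::to_string(offset);
    return s;
  };
  if (idx.scale % 2 == 0 && idx.offset % 2 == 0) {
    return print(idx.scale / 2, idx.var, idx.offset / 2);
  }
  // Otherwise leave the division in the emitted expression.
  return "(" + print(idx.scale, idx.var, idx.offset) + ") / 2";
}
```

In this model, an index of `2 * i + 4` prints as `i + 2`, while `i + 3` keeps its explicit `/ 2`. A real simplifier covers far more cases; the sketch only shows why simplifying before printing can yield shorter pointer expressions.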
2189-2189: Minor formatting fix.
Removed extraneous whitespace; no functional change.
Summary by CodeRabbit
Bug Fixes
Documentation