[BugFix] Stride check and fix for tensors with zero-stride argument #1749
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run […] We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough

Added `ArgBinder::RelaxedStrideCheck` and applied it across stride-binding code paths to accept zero or symbolic strides; added a new CUDA-backed JIT GEMM test exercising K=0; removed a commented-out test invocation in a language test file.
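For context on why the relaxation is needed (not part of the PR text, just an illustration): when a tensor dimension is zero, the compact stride derived from the shape is 0 for outer axes, while frameworks may report a different stride for that axis, so a strict equality check spuriously fails. A minimal PyTorch sketch:

```python
import torch

# A row-major (M, K) tensor with K = 0, as in the new K=0 GEMM test.
A = torch.randn(512, 0)
print(A.shape)     # torch.Size([512, 0])

# Compact strides derived from the shape would be (K, 1) == (0, 1),
# but PyTorch clamps zero-sized dims when deriving contiguous strides:
print(A.stride())  # (1, 1) on current PyTorch

# Strict check: expected stride 0 != actual stride 1 -> rejected.
# Relaxed check: accept when expected == actual OR expected == 0,
# since any stride is valid for a dimension of length zero.
```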
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@tilelang/jit/adapter/tvm_ffi.py`:
- Around line 171-174: The unconditional diagnostic print in the block where
self.executable is created (inside the code that calls
runtime.Executable(self.rt_mod) and prints self.rt_mod.inspect_source()) should
be guarded; modify the initialization in the method that sets self.executable so
the print of self.rt_mod.inspect_source() only runs when a verbosity flag or
logger level indicates debug/verbose (e.g., check a self.verbose attribute or
use an injected logger) or replace it with a logger.debug call, leaving the
runtime.Executable(self.rt_mod) and COMPILE_ARGS logic unchanged and ensuring
the print does not execute during normal runs.
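A minimal Python sketch of the suggested guard (the class and method names here are hypothetical stand-ins for the actual code in `tvm_ffi.py`; only `runtime.Executable(self.rt_mod)` and `inspect_source()` come from the review above):

```python
import logging

logger = logging.getLogger(__name__)

class TVMFFIAdapter:  # hypothetical class standing in for the real adapter
    def _create_executable(self):
        # Unchanged behavior: wrap the runtime module in an Executable.
        # (`runtime` is the module already imported in tvm_ffi.py.)
        self.executable = runtime.Executable(self.rt_mod)
        # Guarded diagnostic: replace the unconditional print with a
        # debug-level log so the generated source only appears when
        # verbose/debug logging is enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Generated source:\n%s", self.rt_mod.inspect_source())
```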
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@src/transform/arg_binder.cc`:
- Around line 988-1016: Remove the unused local variable shape_stride that is
computed but never used; locate the computation of shape_stride near the stride
handling logic around buffer->strides[k] (the block that calls
stride_element_name(k), BinderAddAssert, and BindNullable) and delete the
shape_stride declaration and its assignment so the code only keeps the existing
stride checks and BinderAddAssert/BindNullable calls.
- Around line 942-970: In the kAutoBroadcast loop the local variable
stride_from_shape is initialized but never updated, so compact stride
computation stays 1; at the end of each loop iteration multiply
stride_from_shape by the current dimension (i.e., update stride_from_shape =
stride_from_shape * buffer->shape[k]) so subsequent iterations accumulate the
product of shapes; modify the loop that contains stride_from_shape,
BindNullable, and the stride checks (references: stride_from_shape,
kAutoBroadcast, BindNullable, BinderAddAssert) to perform this update before the
next iteration.
🧹 Nitpick comments (4)
src/transform/arg_binder.cc (1)
830-859: Compile-time warning may not fire for symbolic zero strides.

The `LOG(WARNING)` on line 837 only triggers when `analyzer_.Simplify(expected)` evaluates to zero at compile time. For symbolic strides that depend on runtime values, the warning won't fire even if the stride is actually zero at runtime. Consider adding a runtime warning via a `tvm_call_packed` similar to other error paths, or document that the compile-time warning is best-effort.

Additionally, this pattern is duplicated across four locations (lines 830-859, 890-896, 942-970, 988-1016). Consider extracting a helper function to reduce duplication.
🔧 Example helper function to reduce duplication

```cpp
// Helper to bind stride with zero-dimension relaxation
void BindStrideWithZeroRelaxation(arith::Analyzer* analyzer,
                                  const PrimExpr& expected,
                                  const PrimExpr& actual_stride,
                                  const std::string& element_name,
                                  std::vector<Stmt>* asserts,
                                  const PrimExpr& is_null) {
  if (is_zero(analyzer->Simplify(expected))) {
    LOG(WARNING) << "TileLang: Detected zero-dimension in " << element_name
                 << ". Relaxing stride check.";
  }
  PrimExpr cond = (expected == actual_stride) || (expected == 0);
  BinderAddAssert(analyzer, cond, element_name, asserts, is_null);
}
```

testing/python/transform/test_tilelang_transform_arg_binder.py (3)
8-51: Consider extracting shared kernel definitions to a common module.

The `matmu_jit_kernel` function is nearly identical to the one in `testing/python/jit/test_tilelang_jit_tvm_ffi.py`. To reduce duplication and ease maintenance, consider creating a shared test utilities module (e.g., `testing/python/common/kernels.py`) that both test files can import from.
54-108: Same duplication concern and minor variable shadowing.

This function is also duplicated from `test_tilelang_jit_tvm_ffi.py`. Additionally, the parameters `in_dtype` and `out_dtype` are shadowed by reassignment at lines 87-88. While this works correctly, consider using different variable names for clarity:

♻️ Minor clarity improvement
```diff
-    in_dtype = map_torch_type(in_dtype)
-    out_dtype = map_torch_type(out_dtype)
+    torch_in_dtype = map_torch_type(in_dtype)
+    torch_out_dtype = map_torch_type(out_dtype)

-    A = torch.randn(M, K, dtype=in_dtype).cuda()
-    B = torch.randn(K, N, dtype=in_dtype).cuda()
+    A = torch.randn(M, K, dtype=torch_in_dtype).cuda()
+    B = torch.randn(K, N, dtype=torch_in_dtype).cuda()
     ...
-    C = C.to(out_dtype)
+    C = C.to(torch_out_dtype)
```
111-125: Test correctly exercises zero K dimension; consider expanding coverage.

The test appropriately validates the K=0 edge case, which exercises the relaxed stride checking in `arg_binder.cc`. When K=0, the pipelined loop executes zero iterations, and the output should be zeros (matching `T.clear(C_local)`).

Consider adding tests for other zero-dimension cases (M=0, N=0) to ensure comprehensive coverage of the stride relaxation feature; see the sketch below. Also, the test name could be more specific, e.g., `test_gemm_jit_kernel_zero_k_dim`.
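A hedged sketch of that expanded coverage. The `run_gemm_jit_kernel(M, N, K, block_M, block_N, block_K)` signature is inferred from the example call later in this review, and the parameter values are illustrative, not from the PR:

```python
import pytest
import tilelang.testing

@tilelang.testing.requires_cuda
@pytest.mark.parametrize(
    "M,N,K",
    [
        (512, 1024, 0),  # zero K: accumulator must stay cleared
        (0, 1024, 64),   # zero M: empty output rows
        (512, 0, 64),    # zero N: empty output cols
    ],
)
def test_gemm_jit_kernel_zero_dims(M, N, K):
    # Assumed helper from this test file; block sizes follow the
    # existing call run_gemm_jit_kernel(512, 1024, 0, 128, 256, 32).
    run_gemm_jit_kernel(M, N, K, 128, 256, 32)
```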
src/transform/arg_binder.cc (Outdated)
```cpp
// Relax stride check: if the expected stride is 0, allow any actual
// stride. This happens when one of the subsequent dimensions is 0.
if (const VarNode *v = buffer->strides[k].as<VarNode>()) {
  auto it = def_map_->find(v);
  if (it != def_map_->end()) {
    PrimExpr expected = it->second;
    if (is_zero(analyzer_.Simplify(expected))) {
      LOG(WARNING)
          << "TileLang: Detected zero-dimension in "
          << stride_element_name(k) << ". Relaxing stride check.";
    }
    PrimExpr cond = (expected == stride_val) || (expected == 0);
    BinderAddAssert(&analyzer_, cond, stride_element_name(k),
                    &asserts_, is_null);
  } else {
    BindNullable(buffer->strides[k], stride_val,
                 stride_element_name(k), true, is_null);
  }
} else {
  PrimExpr expected = buffer->strides[k];
  if (is_zero(analyzer_.Simplify(expected))) {
    LOG(WARNING) << "TileLang: Detected zero-dimension in "
                 << stride_element_name(k)
                 << ". Relaxing stride check.";
  }
  PrimExpr cond = (expected == stride_val) || (expected == 0);
  BinderAddAssert(&analyzer_, cond, stride_element_name(k), &asserts_,
                  is_null);
}
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, let's get the full context around lines 930-970 in arg_binder.cc
sed -n '920,975p' src/transform/arg_binder.cc
```

Repository: tile-ai/tilelang

Length of output: 2700
🏁 Script executed:

```shell
# Check lines just after the loop to see if stride_from_shape is updated
sed -n '960,980p' src/transform/arg_binder.cc
```

Repository: tile-ai/tilelang

Length of output: 956
Update `stride_from_shape` at the end of each loop iteration in the kAutoBroadcast path.

The `stride_from_shape` variable is initialized to 1 at line 931 but is never updated within the loop (lines 932-968). This causes incorrect stride computation for auto-broadcast tensors with null strides. When iterating backward through dimensions, `stride_from_shape` should accumulate the product of shape dimensions: `stride_from_shape = stride_from_shape * buffer->shape[k];` at the end of each iteration. Without this, compact stride values will always be 1 regardless of tensor dimensions. A small sketch of the intended accumulation follows below.
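To make the intended computation concrete, here is a pure-Python sketch (illustrative only; the actual fix is in C++ in `arg_binder.cc`) of deriving compact strides by accumulating the shape product while iterating dimensions from innermost to outermost:

```python
def compact_strides(shape):
    """Row-major (C-contiguous) strides: stride[k] = prod(shape[k+1:])."""
    strides = [0] * len(shape)
    stride_from_shape = 1
    for k in reversed(range(len(shape))):
        strides[k] = stride_from_shape
        # The step the review says is missing in the C++ loop:
        # accumulate the product of dimensions before the next iteration.
        stride_from_shape *= shape[k]
    return strides

assert compact_strides([512, 1024, 32]) == [32768, 32, 1]
# With a zero dimension, outer strides become 0, which is why the
# binder's relaxed check must accept expected == 0.
assert compact_strides([512, 0, 32]) == [0, 32, 1]
```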
🤖 Prompt for AI Agents
In `@src/transform/arg_binder.cc` around lines 942 - 970, In the kAutoBroadcast
loop the local variable stride_from_shape is initialized but never updated, so
compact stride computation stays 1; at the end of each loop iteration multiply
stride_from_shape by the current dimension (i.e., update stride_from_shape =
stride_from_shape * buffer->shape[k]) so subsequent iterations accumulate the
product of shapes; modify the loop that contains stride_from_shape,
BindNullable, and the stride checks (references: stride_from_shape,
kAutoBroadcast, BindNullable, BinderAddAssert) to perform this update before the
next iteration.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@testing/python/issue/test_tilelang_issue_1697.py`:
- Around line 90-109: The test unconditionally calls .cuda() via
run_gemm_jit_kernel and should be skipped on non-CUDA hosts; add the decorator
`@tilelang.testing.requires_cuda` directly above the test function definition
test_gemm_jit_kernel_zero_dim so the test is skipped when CUDA is unavailable,
keeping the rest of the function and the helper run_gemm_jit_kernel unchanged
and referencing the tilelang.testing.requires_cuda symbol.
```python
A = torch.randn(M, K, dtype=in_dtype).cuda()
B = torch.randn(K, N, dtype=in_dtype).cuda()

if trans_A:
    A = A.T
if trans_B:
    B = B.T

def ref_program(A, B):
    import torch

    C = torch.matmul(A.to(torch.float), B.to(torch.float))
    C = C.to(out_dtype)
    return C

ref_C = ref_program(A, B)
C = matmul_kernel(A, B)

tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
```
🧩 Analysis chain
🏁 Script executed:

```shell
# First, locate and read the test file to verify the code at lines 90-109
fd "test_tilelang_issue_1697.py" -x cat -n {}
```

Repository: tile-ai/tilelang

Length of output: 1941
🏁 Script executed:

```shell
# Check for existing CUDA-skip patterns in the test suite
rg -n "cuda.is_available|requires_cuda|pytest.skip.*cuda" --type py -A 2
```

Repository: tile-ai/tilelang

Length of output: 50372
🏁 Script executed:

```shell
# Look for other CUDA guard patterns in test files
rg -n "\.cuda\(\)" testing/ --type py -B 3 -A 1 | head -100
```

Repository: tile-ai/tilelang

Length of output: 7574
🏁 Script executed:

```shell
# Verify the test function definition to check for existing decorators
sed -n '40,45p' testing/python/issue/test_tilelang_issue_1697.py
```

Repository: tile-ai/tilelang

Length of output: 178
🏁 Script executed:

```shell
# Check how other tests in the issue directory handle CUDA requirements
ls -la testing/python/issue/
```

Repository: tile-ai/tilelang

Length of output: 1889
Add CUDA requirement decorator to the test function.

The test `test_gemm_jit_kernel_zero_dim` (line 41) unconditionally calls `.cuda()` in the helper function `run_gemm_jit_kernel` (lines 33-34) and will fail on systems without CUDA. Add the `@tilelang.testing.requires_cuda` decorator to the test function to skip it when CUDA is unavailable, matching the pattern used throughout the codebase.
Example fix

```diff
+@tilelang.testing.requires_cuda
 def test_gemm_jit_kernel_zero_dim():
     run_gemm_jit_kernel(512, 1024, 0, 128, 256, 32)
```

🤖 Prompt for AI Agents
In `@testing/python/issue/test_tilelang_issue_1697.py` around lines 90 - 109, The
test unconditionally calls .cuda() via run_gemm_jit_kernel and should be skipped
on non-CUDA hosts; add the decorator `@tilelang.testing.requires_cuda` directly
above the test function definition test_gemm_jit_kernel_zero_dim so the test is
skipped when CUDA is unavailable, keeping the rest of the function and the
helper run_gemm_jit_kernel unchanged and referencing the
tilelang.testing.requires_cuda symbol.
For issue #1697
Summary by CodeRabbit

Bug Fixes
- Relaxed the stride check so tensors with zero or symbolic strides (e.g., a zero-sized dimension) are accepted.

Tests
- Added a CUDA-backed JIT GEMM test exercising K=0.