
Conversation

@tzj-fxz (Contributor) commented Jan 28, 2026

For issue #1697

Summary by CodeRabbit

  • Bug Fixes

    • Relaxed stride validation to accept zero-dimension edge cases and handle symbolic strides more flexibly, reducing false stride errors.
  • Tests

    • Added a new test validating JIT GEMM behavior when a matrix dimension (K) is zero on CUDA.
    • Removed an outdated/commented-out test invocation.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `pre-commit run --all-files` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

coderabbitai bot (Contributor) commented Jan 28, 2026

📝 Walkthrough

Added ArgBinder::RelaxedStrideCheck and applied it across stride-binding code paths to accept zero or symbolic strides; added a new CUDA-backed JIT GEMM test exercising K=0; removed a commented-out test invocation in a language test file.
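For context, a minimal standalone illustration (not from this PR) of why zero-size dimensions trip strict per-dimension stride checks:

import torch

# A GEMM with K = 0 is well-defined: the result has shape (M, N) and is all zeros.
A = torch.randn(4, 0)
B = torch.randn(0, 5)
C = A @ B
print(C.shape, torch.count_nonzero(C).item())  # torch.Size([4, 5]) 0

# For zero-size tensors, the strides a framework reports need not equal the
# "compact" strides (stride[i] == product of trailing sizes) that a strict
# check would assert, since that product collapses to zero.
print(A.stride(), B.stride())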

Changes

  • ArgBinder Stride Validation Relaxation — src/transform/arg_binder.h, src/transform/arg_binder.cc
    Added ArgBinder::RelaxedStrideCheck(...) and replaced stricter per-dimension stride checks with relaxed validation that accepts zero strides and handles symbolic/concrete stride cases across the compact, packed/subtype, and normal binding flows.
  • JIT GEMM Kernel Test — testing/python/issue/test_tilelang_issue_1697.py
    New test module adding matmu_jit_kernel, run_gemm_jit_kernel, and test_gemm_jit_kernel_zero_dim() to compile and run a TileLang JIT GEMM with K=0 on CUDA and validate results against torch.matmul.
  • Test Cleanup — testing/python/language/test_tilelang_language_rand.py
    Removed a commented-out test_rand_1d invocation from the __main__ block.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

enhancement

Suggested reviewers

  • LeiWang1999

Poem

🐇 I hopped through strides both strict and slow,
Found zeros hiding where tensors grow,
I nudged the checks to let kernels start,
A CUDA GEMM with empty K—smart! 🥕

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title clearly and specifically describes the main change: adding stride checks and fixes for tensors with zero-stride arguments, which aligns with all file modifications.


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@tilelang/jit/adapter/tvm_ffi.py`:
- Around lines 171-174: The unconditional diagnostic print in the block that creates self.executable (the code that calls runtime.Executable(self.rt_mod) and prints self.rt_mod.inspect_source()) should be guarded. Change the initialization so the print of self.rt_mod.inspect_source() only runs when a verbosity flag or logger level indicates debug/verbose (e.g., check a self.verbose attribute or use an injected logger), or replace it with a logger.debug call, leaving the runtime.Executable(self.rt_mod) and COMPILE_ARGS logic unchanged so the print does not execute during normal runs.
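
A minimal sketch of the suggested guard (the class name, constructor shape, and self.verbose attribute are illustrative; only runtime.Executable and rt_mod.inspect_source come from the comment above):

import logging

logger = logging.getLogger(__name__)

class TVMFFIAdapterSketch:
    """Illustrative only; mirrors the shape of the fix, not the adapter's API."""

    def __init__(self, runtime, rt_mod, verbose=False):
        self.rt_mod = rt_mod
        self.verbose = verbose
        # Executable construction stays unchanged ...
        self.executable = runtime.Executable(rt_mod)
        # ... but the diagnostic source dump is gated behind verbosity.
        if self.verbose:
            logger.debug(rt_mod.inspect_source())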

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@src/transform/arg_binder.cc`:
- Around lines 988-1016: Remove the unused local variable shape_stride, which is computed but never used. Locate its computation near the stride-handling logic around buffer->strides[k] (the block that calls stride_element_name(k), BinderAddAssert, and BindNullable) and delete the shape_stride declaration and its assignment, keeping only the existing stride checks and BinderAddAssert/BindNullable calls.
- Around lines 942-970: In the kAutoBroadcast loop, the local variable stride_from_shape is initialized but never updated, so the compact stride computation stays at 1. At the end of each loop iteration, multiply stride_from_shape by the current dimension (stride_from_shape = stride_from_shape * buffer->shape[k]) so subsequent iterations accumulate the product of the trailing shapes.
🧹 Nitpick comments (4)
src/transform/arg_binder.cc (1)

830-859: Compile-time warning may not fire for symbolic zero strides.

The LOG(WARNING) on line 837 only triggers when analyzer_.Simplify(expected) evaluates to zero at compile time. For symbolic strides that depend on runtime values, the warning won't fire even if the stride is actually zero at runtime. Consider adding a runtime warning via a tvm_call_packed similar to other error paths, or document that the compile-time warning is best-effort.

Additionally, this pattern is duplicated across four locations (lines 830-859, 890-896, 942-970, 988-1016). Consider extracting a helper function to reduce duplication.

🔧 Example helper function to reduce duplication
// Helper to bind stride with zero-dimension relaxation
void BindStrideWithZeroRelaxation(
    arith::Analyzer* analyzer,
    const PrimExpr& expected,
    const PrimExpr& actual_stride,
    const std::string& element_name,
    std::vector<Stmt>* asserts,
    const PrimExpr& is_null) {
  if (is_zero(analyzer->Simplify(expected))) {
    LOG(WARNING) << "TileLang: Detected zero-dimension in "
                 << element_name << ". Relaxing stride check.";
  }
  PrimExpr cond = (expected == actual_stride) || (expected == 0);
  BinderAddAssert(analyzer, cond, element_name, asserts, is_null);
}
testing/python/transform/test_tilelang_transform_arg_binder.py (3)

8-51: Consider extracting shared kernel definitions to a common module.

The matmu_jit_kernel function is nearly identical to the one in testing/python/jit/test_tilelang_jit_tvm_ffi.py. To reduce duplication and ease maintenance, consider creating a shared test utilities module (e.g., testing/python/common/kernels.py) that both test files can import from.


54-108: Same duplication concern and minor variable shadowing.

This function is also duplicated from test_tilelang_jit_tvm_ffi.py. Additionally, the parameters in_dtype and out_dtype are shadowed by reassignment at lines 87-88. While this works correctly, consider using different variable names for clarity:

♻️ Minor clarity improvement
-    in_dtype = map_torch_type(in_dtype)
-    out_dtype = map_torch_type(out_dtype)
+    torch_in_dtype = map_torch_type(in_dtype)
+    torch_out_dtype = map_torch_type(out_dtype)

-    A = torch.randn(M, K, dtype=in_dtype).cuda()
-    B = torch.randn(K, N, dtype=in_dtype).cuda()
+    A = torch.randn(M, K, dtype=torch_in_dtype).cuda()
+    B = torch.randn(K, N, dtype=torch_in_dtype).cuda()
     ...
-        C = C.to(out_dtype)
+        C = C.to(torch_out_dtype)

111-125: Test correctly exercises zero K dimension; consider expanding coverage.

The test appropriately validates the K=0 edge case which exercises the relaxed stride checking in arg_binder.cc. When K=0, the pipelined loop executes zero iterations, and the output should be zeros (matching T.clear(C_local)).

Consider adding tests for other zero-dimension cases (M=0, N=0) to ensure comprehensive coverage of the stride relaxation feature. Also, the test name could be more specific, e.g., test_gemm_jit_kernel_zero_k_dim.
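
A hypothetical sketch of those companion cases, assuming the positional run_gemm_jit_kernel(M, N, K, block_M, block_N, block_K) signature used by the existing test:

# Hypothetical companions to test_gemm_jit_kernel_zero_dim; the names and
# argument order follow the existing test and are not part of this PR.
def test_gemm_jit_kernel_zero_m_dim():
    run_gemm_jit_kernel(0, 1024, 512, 128, 256, 32)

def test_gemm_jit_kernel_zero_n_dim():
    run_gemm_jit_kernel(512, 0, 512, 128, 256, 32)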

Comment on lines 942 to 970
// Relax stride check: if the expected stride is 0, allow any actual
// stride. This happens when one of the subsequent dimensions is 0.
if (const VarNode *v = buffer->strides[k].as<VarNode>()) {
  auto it = def_map_->find(v);
  if (it != def_map_->end()) {
    PrimExpr expected = it->second;
    if (is_zero(analyzer_.Simplify(expected))) {
      LOG(WARNING) << "TileLang: Detected zero-dimension in "
                   << stride_element_name(k) << ". Relaxing stride check.";
    }
    PrimExpr cond = (expected == stride_val) || (expected == 0);
    BinderAddAssert(&analyzer_, cond, stride_element_name(k), &asserts_,
                    is_null);
  } else {
    BindNullable(buffer->strides[k], stride_val, stride_element_name(k),
                 true, is_null);
  }
} else {
  PrimExpr expected = buffer->strides[k];
  if (is_zero(analyzer_.Simplify(expected))) {
    LOG(WARNING) << "TileLang: Detected zero-dimension in "
                 << stride_element_name(k) << ". Relaxing stride check.";
  }
  PrimExpr cond = (expected == stride_val) || (expected == 0);
  BinderAddAssert(&analyzer_, cond, stride_element_name(k), &asserts_,
                  is_null);
}
coderabbitai bot (Contributor) commented:

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# First, let's get the full context around lines 930-970 in arg_binder.cc
sed -n '920,975p' src/transform/arg_binder.cc

Repository: tile-ai/tilelang

Length of output: 2700


🏁 Script executed:

# Check lines just after the loop to see if stride_from_shape is updated
sed -n '960,980p' src/transform/arg_binder.cc

Repository: tile-ai/tilelang

Length of output: 956


Update stride_from_shape at the end of each loop iteration in the kAutoBroadcast path.

The stride_from_shape variable is initialized to 1 at line 931 but is never updated within the loop (lines 932-968). This causes incorrect stride computation for auto-broadcast tensors with null strides. When iterating backward through dimensions, stride_from_shape should accumulate the product of shape dimensions: stride_from_shape = stride_from_shape * buffer->shape[k]; at the end of each iteration. Without this, compact stride values will always be 1 regardless of tensor dimensions.
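
To make the effect concrete, a standalone Python sketch of the compact-stride accumulation (illustrative only, not TVM code):

def compact_strides(shape):
    """Row-major ("compact") strides, built from the innermost dimension out."""
    strides = [0] * len(shape)
    stride_from_shape = 1
    for k in range(len(shape) - 1, -1, -1):
        strides[k] = stride_from_shape
        # The accumulation step the review flags as missing: without it,
        # every dimension is assigned a stride of 1.
        stride_from_shape *= shape[k]
    return strides

assert compact_strides([2, 3, 4]) == [12, 4, 1]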


@tzj-fxz changed the title from "[BugFix] Stride check and fix for empty tensor in tvm_ffi" to "[BugFix] Stride check and fix for tensors with zero-stride argument" on Jan 29, 2026
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@testing/python/issue/test_tilelang_issue_1697.py`:
- Around lines 90-109: The test unconditionally calls .cuda() via run_gemm_jit_kernel and should be skipped on non-CUDA hosts. Add the `@tilelang.testing.requires_cuda` decorator directly above the definition of test_gemm_jit_kernel_zero_dim so the test is skipped when CUDA is unavailable, keeping the rest of the function and the helper run_gemm_jit_kernel unchanged.

Comment on lines 90 to 109
    A = torch.randn(M, K, dtype=in_dtype).cuda()
    B = torch.randn(K, N, dtype=in_dtype).cuda()

    if trans_A:
        A = A.T
    if trans_B:
        B = B.T

    def ref_program(A, B):
        import torch

        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(out_dtype)
        return C

    ref_C = ref_program(A, B)
    C = matmul_kernel(A, B)

    tilelang.testing.torch_assert_close(C, ref_C, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)

coderabbitai bot (Contributor) commented:

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, locate and read the test file to verify the code at lines 90-109
fd "test_tilelang_issue_1697.py" -x cat -n {}

Repository: tile-ai/tilelang

Length of output: 1941


🏁 Script executed:

# Check for existing CUDA-skip patterns in the test suite
rg -n "cuda.is_available|requires_cuda|pytest.skip.*cuda" --type py -A 2

Repository: tile-ai/tilelang

Length of output: 50372


🏁 Script executed:

# Look for other CUDA guard patterns in test files
rg -n "\.cuda\(\)" testing/ --type py -B 3 -A 1 | head -100

Repository: tile-ai/tilelang

Length of output: 7574


🏁 Script executed:

# Verify the test function definition to check for existing decorators
sed -n '40,45p' testing/python/issue/test_tilelang_issue_1697.py

Repository: tile-ai/tilelang

Length of output: 178


🏁 Script executed:

# Check how other tests in the issue directory handle CUDA requirements
ls -la testing/python/issue/

Repository: tile-ai/tilelang

Length of output: 1889


Add CUDA requirement decorator to the test function.

The test test_gemm_jit_kernel_zero_dim (line 41) unconditionally calls .cuda() in the helper function run_gemm_jit_kernel (lines 33–34) and will fail on systems without CUDA. Add the @tilelang.testing.requires_cuda decorator to the test function to skip it when CUDA is unavailable, matching the pattern used throughout the codebase.

Example fix
+@tilelang.testing.requires_cuda
 def test_gemm_jit_kernel_zero_dim():
     run_gemm_jit_kernel(512, 1024, 0, 128, 256, 32)

LeiWang1999 previously approved these changes Jan 29, 2026
@tzj-fxz merged commit 9ddf577 into tile-ai:main on Jan 29, 2026 (5 of 6 checks passed)
