-
Notifications
You must be signed in to change notification settings - Fork 331
[Language] Introduce StridedTensor to support non contigious torch inputs
#722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
fd51304
c6cf3de
9ae1f7f
c1a456d
b68882b
1f811d5
de8c4a5
1599ff1
45f3be6
f3a92a0
c8e1a1b
0cfe1f2
6ed0611
056c6a3
a12dfad
b91279e
48d9a8f
ebda917
c5b1a10
0a9d50f
9f26dbf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1689,6 +1689,76 @@ void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) { | |||||||||||||||||||||||||||||||||||||||||||||
| os << "))"; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, | ||||||||||||||||||||||||||||||||||||||||||||||
| std::ostream &os) { // NOLINT(*) | ||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK_EQ(op->indices.size(), 1) | ||||||||||||||||||||||||||||||||||||||||||||||
| << "Load from non-flat memory not supported."; | ||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(!op->predicate.defined()) | ||||||||||||||||||||||||||||||||||||||||||||||
| << "Predicated buffer load is not supported."; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| DataType value_dtype = op->dtype; | ||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr index = op->indices[0]; | ||||||||||||||||||||||||||||||||||||||||||||||
| Var buffer_var = op->buffer->data; | ||||||||||||||||||||||||||||||||||||||||||||||
| DataType element_dtype = op->buffer->dtype; | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| int lanes = op->dtype.lanes(); | ||||||||||||||||||||||||||||||||||||||||||||||
| // delcare type. | ||||||||||||||||||||||||||||||||||||||||||||||
| if (value_dtype.lanes() == element_dtype.lanes()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); | ||||||||||||||||||||||||||||||||||||||||||||||
| HandleVolatileLoads(ref, op, os); | ||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||
| bool can_vector_load = false; | ||||||||||||||||||||||||||||||||||||||||||||||
| arith::PVar<PrimExpr> base; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { | ||||||||||||||||||||||||||||||||||||||||||||||
| const RampNode *ramp = index.as<RampNode>(); | ||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(ramp); | ||||||||||||||||||||||||||||||||||||||||||||||
| can_vector_load = true; | ||||||||||||||||||||||||||||||||||||||||||||||
| // arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); | ||||||||||||||||||||||||||||||||||||||||||||||
| // The condition: {k * coeff + base} divisible by the alignment for any k | ||||||||||||||||||||||||||||||||||||||||||||||
| // if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() | ||||||||||||||||||||||||||||||||||||||||||||||
| // == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||
| // can_vector_load = true; | ||||||||||||||||||||||||||||||||||||||||||||||
| // } | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1729
to
+1734
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. skip for now for dynamic strides. |
||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if (value_dtype.is_float4_e2m1fn() && lanes != 1) { | ||||||||||||||||||||||||||||||||||||||||||||||
| // A float4_e2m1fn element has 4 bits, which is an incomplete byte. | ||||||||||||||||||||||||||||||||||||||||||||||
| // So we cannot vector load it. | ||||||||||||||||||||||||||||||||||||||||||||||
| can_vector_load = false; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| if (can_vector_load) { | ||||||||||||||||||||||||||||||||||||||||||||||
| std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval()); | ||||||||||||||||||||||||||||||||||||||||||||||
| HandleVolatileLoads(ref, op, os); | ||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1743
to
+1745
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainConfirm GetVecLoad availability The vector path relies on GetVecLoad(dtype, buffer, base). Please confirm this overload is declared/defined for CUDA and matches the expected semantics (alignment, address space handling). Run the following to verify declaration in the header and definitions in the codebase: 🏁 Script executed: #!/bin/bash
set -euo pipefail
echo "Searching for VisitExpr_ BufferLoad declaration in codegen_cuda.h"
rg -n -C2 -P 'class\s+CodeGenTileLangCUDA\b(?s).*?\bVisitExpr_\s*\(\s*const\s+BufferLoadNode\s*\*\s*op\s*,\s*std::ostream\s*&\s*os\)' -- src
echo
echo "Searching for GetVecLoad declarations/definitions"
rg -n -C3 -P '\bGetVecLoad\s*\(' -- srcLength of output: 303 To be sure the 🏁 Script executed: #!/bin/bash
set -euo pipefail
echo "Searching for all GetVecLoad occurrences in codegen sources"
rg -n 'GetVecLoad' -g 'src/**/codegen_*.cc' -g 'src/**/codegen_*.h'Length of output: 326 Let’s search for any 🏁 Script executed: #!/bin/bash
set -euo pipefail
echo "Searching for explicit GetVecLoad definitions (e.g., std::string GetVecLoad(...))"
rg -nP 'std::string\s+GetVecLoad\s*\(' -- src
echo
echo "Searching for all GetVecLoad occurrences"
rg -n 'GetVecLoad' -- srcLength of output: 240 Missing • Add a declaration in This must be implemented before merging. 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||
| std::ostringstream svalue_expr; | ||||||||||||||||||||||||||||||||||||||||||||||
| std::string sindex = SSAGetID(PrintExpr(index), index.dtype()); | ||||||||||||||||||||||||||||||||||||||||||||||
| std::string vid = GetVarID(buffer_var.get()); | ||||||||||||||||||||||||||||||||||||||||||||||
| DataType elem_type = op->dtype.element_of(); | ||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < lanes; ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||
| std::ostringstream value_temp; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (!HandleTypeMatch(buffer_var.get(), elem_type)) { | ||||||||||||||||||||||||||||||||||||||||||||||
| value_temp << "(("; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (buffer_var.get()->dtype.is_handle()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| auto it = alloc_storage_scope_.find(buffer_var.get()); | ||||||||||||||||||||||||||||||||||||||||||||||
| if (it != alloc_storage_scope_.end()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| PrintStorageScope(it->second, value_temp); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| PrintType(elem_type, value_temp); | ||||||||||||||||||||||||||||||||||||||||||||||
| value_temp << "*)" << vid << ')'; | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1753
to
+1761
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Avoid printing storage scope in C-style casts unless it’s part of the type In the per-lane fallback, the cast includes the storage scope unconditionally. Elsewhere (e.g., GetBufferRef) this is guarded by IsScopePartOfType() to avoid generating invalid types like (shared float*). Mirror that behavior here to prevent malformed code on targets where storage scope is not encoded in the type. - if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
- value_temp << "((";
- if (buffer_var.get()->dtype.is_handle()) {
- auto it = alloc_storage_scope_.find(buffer_var.get());
- if (it != alloc_storage_scope_.end()) {
- PrintStorageScope(it->second, value_temp);
- }
- }
- PrintType(elem_type, value_temp);
- value_temp << "*)" << vid << ')';
- } else {
- value_temp << vid;
- }
+ if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
+ value_temp << "((";
+ if (buffer_var.get()->dtype.is_handle()) {
+ auto it = alloc_storage_scope_.find(buffer_var.get());
+ if (it != alloc_storage_scope_.end() && IsScopePartOfType()) {
+ PrintStorageScope(it->second, value_temp);
+ }
+ }
+ PrintType(elem_type, value_temp);
+ value_temp << "*)" << vid << ')';
+ } else {
+ value_temp << vid;
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||
| value_temp << vid; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| value_temp << '['; | ||||||||||||||||||||||||||||||||||||||||||||||
| PrintVecElemLoad(sindex, index.dtype(), i, value_temp); | ||||||||||||||||||||||||||||||||||||||||||||||
| value_temp << ']'; | ||||||||||||||||||||||||||||||||||||||||||||||
| PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| os << svalue_expr.str(); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, | ||||||||||||||||||||||||||||||||||||||||||||||
| std::ostream &os) { // NOLINT(*) | ||||||||||||||||||||||||||||||||||||||||||||||
| int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -136,11 +136,23 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_vector_size = gcd_base; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Generate strides if not existed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto strides = buffer->strides; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (buffer->strides.size() == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr stride = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = indices.size() - 1; i >= 0; --i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| strides.push_back(stride); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stride = stride * buffer->shape[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| strides = Array<PrimExpr>{strides.rbegin(), strides.rend()}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Generate and check element offset expression | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(indices.size() == strides.size()) << "Invalid indices and strides"; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr elem_offset = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr stride = 1; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = indices.size() - 1; i >= 0; --i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elem_offset = elem_offset + indices[i] * stride; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stride = stride * buffer->shape[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int i = 0; i < indices.size(); ++i) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elem_offset += indices[i] * strides[i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+140
to
156
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Fix stride reversal and avoid constructing tvm::Array from iterators (compile-time bug).
Refactor to compute Apply this diff: - // Generate strides if not existed
- auto strides = buffer->strides;
- if (buffer->strides.size() == 0) {
- PrimExpr stride = 1;
- for (int i = indices.size() - 1; i >= 0; --i) {
- strides.push_back(stride);
- stride = stride * buffer->shape[i];
- }
- strides = Array<PrimExpr>{strides.rbegin(), strides.rend()};
- }
-
- // Generate and check element offset expression
- ICHECK(indices.size() == strides.size()) << "Invalid indices and strides";
- PrimExpr elem_offset = 0;
- for (int i = 0; i < indices.size(); ++i) {
- elem_offset += indices[i] * strides[i];
- }
+ // Compute element offset (supports both explicit and implicit strides)
+ PrimExpr elem_offset = 0;
+ if (buffer->strides.size() == 0) {
+ PrimExpr stride = 1;
+ for (int i = indices.size() - 1; i >= 0; --i) {
+ elem_offset += indices[i] * stride;
+ stride = stride * buffer->shape[i];
+ }
+ } else {
+ ICHECK(indices.size() == buffer->strides.size()) << "Invalid indices and strides";
+ for (int i = 0; i < indices.size(); ++i) {
+ elem_offset += indices[i] * buffer->strides[i];
+ }
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inner_for_->extent, vector_size_, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -229,10 +241,19 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ICHECK(target_vectorized_size >= 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (target_vectorized_size == 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // bind thread range | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Extent must be divisible | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| 0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // The base offset must be divisible | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (!analyzer->CanProveEqual( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| FloorMod(Substitute(expr, {{var, 0}}), target_vectorized_size), 0)) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Bind thread range | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Var v0("v0"), v1("v1"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| analyzer->Bind(v0, Range(0, target_vectorized_size)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| analyzer->Bind(v1, Range(0, analyzer->Simplify(FloorDiv( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -241,7 +262,8 @@ bool IndiceCanVectorize(PrimExpr expr, Var var, PrimExpr iter_var_size, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Substitute(expr, {{var, v0 + v1 * target_vectorized_size}})); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // This simplify is necessary for thread region specifiled | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // This simplify is necessary for thread region specified | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // optimizations. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| expr_vectorized = analyzer->Simplify(expr_vectorized); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| auto ramp_node = expr_vectorized.as<RampNode>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -28,8 +28,8 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16") | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_idx=[1], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| target="cuda", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pass_configs={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "tl.disable_warp_specialized": True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "tl.disable_tma_lower": True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| b = kernel(a) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -42,5 +42,49 @@ def test_tilelang_copy(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @T.prim_func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def main( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| A: T.StridedTensor((M, N), (NN, 1), dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B: T.Tensor((M, N), dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Initialize Kernel Context | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i, j in T.Parallel(block_M, block_N): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| B[by * block_M + i, bx * block_N + j] = A[by * block_M + i, bx * block_N + j] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return main | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_tilelang_copy_with_stride(M=1024, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| N=1024, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NN=2048, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_M=128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_N=128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| dtype="float16"): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(NN, int): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert NN > N, "NN must be greater than N" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kernel = tilelang.compile( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| program, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| out_idx=[1], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| target="cuda", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pass_configs={ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(NN, T.Var): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| NN = N * 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| a = torch.randn(M, NN, device="cuda", dtype=getattr(torch, dtype)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| b = kernel(a[:, :N]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.testing.assert_close(b, a[:, :N], rtol=1e-2, atol=1e-2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+66
to
+82
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainFix TypeError in isinstance check and allow NN == N
Apply this diff: - if isinstance(NN, int):
- assert NN > N, "NN must be greater than N"
+ if isinstance(NN, int):
+ assert NN >= N, "NN must be greater than or equal to N"
@@
- if isinstance(NN, T.Var):
- NN = N * 2
+ # If NN is symbolic (e.g., tvm.tir.Var), substitute a concrete allocation size.
+ if not isinstance(NN, int):
+ NN = N * 2Optional: If you prefer an explicit type check, import Run this quick scan to ensure similar mistaken 🏁 Script executed: #!/bin/bash
rg -nP -C2 'isinstance\s*\(\s*[^,]+,\s*T\.Var\s*\)'Length of output: 2709 🏁 Script executed: #!/bin/bash
# Display import aliases for T in test_tilelang_language_copy.py
sed -n '1,50p' testing/python/language/test_tilelang_language_copy.py | sed -n '1,50p'Length of output: 1713 🏁 Script executed: #!/bin/bash
# Display import aliases for T in the relevant test files
sed -n '1,20p' testing/python/jit/test_tilelang_jit_gemm_ctypes.py
sed -n '1,20p' testing/python/jit/test_tilelang_jit_gemm_cython.pyLength of output: 753 Fix symbolic-dimension checks and allow NN == N The
Affected files:
Apply these diffs: --- a/testing/python/language/test_tilelang_language_copy.py
+++ b/testing/python/language/test_tilelang_language_copy.py
@@ -66,7 +66,7 @@ def test_tilelang_copy_with_stride():
- if isinstance(NN, int):
- assert NN > N, "NN must be greater than N"
+ if isinstance(NN, int):
+ assert NN >= N, "NN must be greater than or equal to N"
@@ -75,7 +75,8 @@ def test_tilelang_copy_with_stride():
- if isinstance(NN, T.Var):
- NN = N * 2
+ # For any non-int (e.g., tvm.tir.Var), pick a concrete size
+ if not isinstance(NN, int):
+ NN = N * 2--- a/testing/python/jit/test_tilelang_jit_gemm_ctypes.py
+++ b/testing/python/jit/test_tilelang_jit_gemm_ctypes.py
@@ -368,12 +368,12 @@ def test_matmul_ctypes():
- if isinstance(M, T.Var):
- M = 1024
+ if not isinstance(M, int):
+ M = 1024
- if isinstance(N, T.Var):
- N = 1024
+ if not isinstance(N, int):
+ N = 1024
- if isinstance(K, T.Var):
- K = 768
+ if not isinstance(K, int):
+ K = 768--- a/testing/python/jit/test_tilelang_jit_gemm_cython.py
+++ b/testing/python/jit/test_tilelang_jit_gemm_cython.py
@@ -377,12 +377,12 @@ def test_matmul_cython():
- if isinstance(M, T.Var):
- M = 1024
+ if not isinstance(M, int):
+ M = 1024
- if isinstance(N, T.Var):
- N = 1024
+ if not isinstance(N, int):
+ N = 1024
- if isinstance(K, T.Var):
- K = 768
+ if not isinstance(K, int):
+ K = 768
@@ -446,12 +446,12 @@ def test_matmul_cython_outidx_minus1():
- if isinstance(M, T.Var):
- M = 1024
+ if not isinstance(M, int):
+ M = 1024
- if isinstance(N, T.Var):
- N = 1024
+ if not isinstance(N, int):
+ N = 1024
- if isinstance(K, T.Var):
- K = 768
+ if not isinstance(K, int):
+ K = 768📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_tilelang_copy_with_stride(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.symbolic("NN"), block_M=128, block_N=128) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tilelang.testing.main() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Inverted arch selection is consistent; harden HIP detection and verify repo-wide consistency
The inversion matches other files in this PR. To avoid AttributeError on some torch builds, use
getattr:Additionally, to ensure there are no missed spots still using the old mapping, you can scan the repo:
If any legacy patterns appear in contexts that should be updated, align them with this PR’s policy. Centralizing selection via a helper will also simplify future changes.
🏁 Script executed:
Length of output: 2545
Harden HIP detection by using
getattracross all occurrencesTo avoid
AttributeErroron torch builds that lacktorch.version.hip, please replace raw attribute access with a safegetattrcheck in every spot where we currently do:Please update the following locations:
• examples/gemm/example_gemm_autotune.py:19
• examples/analyze/example_conv_analyze.py:99
• examples/analyze/example_gemm_analyze.py:52
• benchmark/matmul/benchmark_matmul.py:56
• benchmark/matmul/benchmark_matmul_intrinsic.py:190
Example diff:
For future maintainability, consider centralizing this logic into a small helper (e.g.
def select_device(): ...) so any further changes only touch one place.🤖 Prompt for AI Agents