Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions src/op/fill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "region.h"

namespace tvm {
namespace tl {
Expand Down Expand Up @@ -62,7 +63,30 @@ using namespace tir;
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();

if (args[0]->IsInstance<BufferLoadNode>()) {
// Case 1: Region descriptor call (tl.region)
if (const auto *call = args[0].as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}

// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;

// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
Expand All @@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
Expand All @@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
<< " != " << node->dst->shape.size();
for (int i = 0; i < node->region.size(); i++) {
// bound check if region is static
if (node->region[i]->min.as<IntImm>()) {
int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
int64_t min = min_imm->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
}
if (node->region[i]->extent.as<IntImm>()) {
int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
<< "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
// Only perform the upper-bound check when the destination shape
// extent is also statically known. If the shape is symbolic (e.g., Var),
// skip this static check to avoid invalid downcasts.
if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
ICHECK_LE(extent_imm->value, shape_imm->value)
<< "region[" << i << "] = " << extent_imm->value << " > "
<< node->dst->shape[i];
}
}
}
data_ = std::move(node);
Expand Down Expand Up @@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (int i = 0; i < ndim; i++) {
Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
loop_vars.push_back({region[i], var, IterVarType::kDataPar});
dst_indices.push_back(var);
// Offset the loop induction variable by region min to honor sliced regions
dst_indices.push_back(region[i]->min + var);
}
Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) {
Expand Down Expand Up @@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
return Stmt();
}
}

Expand Down Expand Up @@ -229,4 +261,4 @@ TIR_REGISTER_TL_OP(Fill, fill)
TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }

} // namespace tl
} // namespace tvm
} // namespace tvm
51 changes: 51 additions & 0 deletions testing/python/issue/test_tilelang_issue_1008.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},
)
def _fill_with_static_region_kernel():
num_tokens = T.symbolic('num_tokens')

@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens, ), 'int64']):
with T.Kernel(num_tokens, threads=128) as pid:
T.fill(x[0:128], 0)

return buggy_kernel

@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},
)
def _fill_with_dynamic_region_kernel():
num_tokens = T.symbolic('num_tokens')

@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens, ), 'int64']):
with T.Kernel(num_tokens, threads=128) as pid:
a, b = T.alloc_var('int'), T.alloc_var('int')
T.fill(x[a:b], 0)

Comment on lines +34 to +36
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Initialize dynamic slice bounds before calling T.fill.

T.alloc_var only gives you unconstrained scalars; unless you bind them immediately, they carry garbage when the kernel runs. Feeding x[a:b] into tl.fill right after allocation leaves the start/end of the region undefined, so the generated kernel can write outside x. Please derive the bounds from known expressions (e.g., reuse num_tokens) instead of uninitialized vars.

-            a, b = T.alloc_var('int'), T.alloc_var('int')
-            T.fill(x[a:b], 0)
+            start = T.max(num_tokens - 128, 0)
+            T.fill(x[start:num_tokens], 0)

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In testing/python/issue/test_tilelang_issue_1008.py around lines 35 to 37, the
dynamic slice bounds a and b are allocated with T.alloc_var but never
initialized before calling T.fill(x[a:b], 0), leaving start/end undefined and
risking out-of-bounds writes; initialize those alloc_var scalars (e.g., set
a.value = 0 and b.value = num_tokens or assign them from an existing expression)
or replace x[a:b] with a slice derived directly from a known expression like
x[0:num_tokens] so the kernel receives concrete bounds before calling T.fill.

return buggy_kernel

def test_fill_with_static_region_kernel():
kernel = _fill_with_static_region_kernel()
x = torch.zeros((256, ), dtype=torch.int64, device='cuda')
kernel(x)

def test_fill_with_dynamic_region_kernel():
kernel = _fill_with_dynamic_region_kernel()
x = torch.zeros((256, ), dtype=torch.int64, device='cuda')
kernel(x)

if __name__ == '__main__':
tilelang.testing.main()
32 changes: 29 additions & 3 deletions tilelang/language/fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
from tvm import tir
from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load
from tilelang.language.utils import (
buffer_to_tile_region,
buffer_region_to_tile_region,
buffer_load_to_tile_region,
)


def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.PrimExpr):
"""Fill a buffer or buffer region with a specified value.

Args:
Expand All @@ -16,9 +21,30 @@ def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
Returns:
A TVM intrinsic call that performs the fill operation
"""
# Normalize Var with let value to its underlying object
if isinstance(buffer, tir.Var) and has_let_value(buffer):
buffer = get_let_value(buffer)

# Convert to a tl.region descriptor (PrimExpr) with write access
region_call = None
if isinstance(buffer, tir.Buffer):
buffer = buffer.access_ptr("w") # Get write pointer if input is a Buffer
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value)
region_call = buffer_to_tile_region(buffer, "w")
elif isinstance(buffer, tir.BufferRegion):
extents = [r.extent for r in buffer.region]
region_call = buffer_region_to_tile_region(buffer, "w", extents)
elif isinstance(buffer, tir.BufferLoad):
region = get_buffer_region_from_load(buffer)
if region is not None:
extents = [r.extent for r in region.region]
region_call = buffer_region_to_tile_region(region, "w", extents)
else:
# Fallback: treat element access as 1-extent per dim
region_call = buffer_load_to_tile_region(buffer, "w", [1] * len(buffer.indices))
else:
# As-is fallback (rare): pass through for downstream handling
region_call = buffer

return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), region_call, value)


def clear(buffer: tir.Buffer | tir.Var):
Expand Down
Loading