diff --git a/src/op/fill.cc b/src/op/fill.cc
index 055e64053..83b0842dc 100644
--- a/src/op/fill.cc
+++ b/src/op/fill.cc
@@ -17,6 +17,7 @@
 #include "../transform/loop_partition.h"
 #include "../transform/loop_vectorize.h"
 #include "builtin.h"
+#include "region.h"
 
 namespace tvm {
 namespace tl {
@@ -62,7 +63,30 @@ using namespace tir;
 
 Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();
-  if (args[0]->IsInstance<BufferLoadNode>()) {
+  // Case 1: Region descriptor call (tl.region)
+  if (const auto *call = args[0].as<CallNode>()) {
+    if (call->op.same_as(RegionOp::Get())) {
+      auto region = RegionOp(call->args, vmap);
+      node->dst = region->GetBuffer();
+      node->region = region->GetRanges();
+    } else if (call->op.same_as(builtin::tvm_access_ptr())) {
+      node->dst = vmap[GetVarFromAccessPtr(args[0])];
+      for (int i = 0; i < node->dst->shape.size(); i++) {
+        node->region.push_back(Range(0, node->dst->shape[i]));
+      }
+    } else {
+      ICHECK(false) << "Unsupported call op in tl.fill: "
+                    << Downcast<Op>(call->op)->name;
+    }
+
+    // Case 2: Explicit BufferRegion (legacy path)
+  } else if (args[0]->IsInstance<BufferRegionNode>()) {
+    auto region = Downcast<BufferRegion>(args[0]);
+    node->dst = region->buffer;
+    node->region = region->region;
+
+    // Case 3: Vector/scalar region expressed via BufferLoad indices
+  } else if (args[0]->IsInstance<BufferLoadNode>()) {
     auto buffer_load = Downcast<BufferLoad>(args[0]);
     for (const auto &index : buffer_load->indices) {
       if (const auto *ramp = index.as<RampNode>()) {
@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
       }
     }
     node->dst = buffer_load->buffer;
+    // Case 4: Access pointer, fill the full buffer
   } else {
     node->dst = vmap[GetVarFromAccessPtr(args[0])];
     for (int i = 0; i < node->dst->shape.size(); i++) {
@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
       << " != " << node->dst->shape.size();
   for (int i = 0; i < node->region.size(); i++) {
     // bound check if region is static
-    if (node->region[i]->min.as<IntImmNode>()) {
-      int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
+    if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
+      int64_t min = min_imm->value;
       ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
     }
-    if (node->region[i]->extent.as<IntImmNode>()) {
-      int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
-      ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
-          << "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
+    if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
+      // Only perform the upper-bound check when the destination shape
+      // extent is also statically known. If the shape is symbolic (e.g., Var),
+      // skip this static check to avoid invalid downcasts.
+      if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
+        ICHECK_LE(extent_imm->value, shape_imm->value)
+            << "region[" << i << "] = " << extent_imm->value << " > "
+            << node->dst->shape[i];
+      }
     }
   }
   data_ = std::move(node);
@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   for (int i = 0; i < ndim; i++) {
     Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
     loop_vars.push_back({region[i], var, IterVarType::kDataPar});
-    dst_indices.push_back(var);
+    // Offset the loop induction variable by region min to honor sliced regions
+    dst_indices.push_back(region[i]->min + var);
   }
   Stmt body = BufferStore(dst, value, dst_indices);
   for (int i = ndim - 1; i >= 0; i--) {
@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return vectorized_thread_loop;
   } else {
     LOG(FATAL) << "Unsupported scope " << dst.scope();
+    return Stmt();
   }
 }
 
@@ -229,4 +261,4 @@ TIR_REGISTER_TL_OP(Fill, fill)
 TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }
 
 } // namespace tl
-} // namespace tvm
\ No newline at end of file
+} // namespace tvm
diff --git a/testing/python/issue/test_tilelang_issue_1008.py b/testing/python/issue/test_tilelang_issue_1008.py
new file mode 100644
index 000000000..395593d8c
--- /dev/null
+++ b/testing/python/issue/test_tilelang_issue_1008.py
@@ -0,0 +1,53 @@
+import torch
+import tilelang
+import tilelang.testing
+from tilelang import language as T
+
+
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+    },)
+def _fill_with_static_region_kernel():
+    num_tokens = T.symbolic('num_tokens')
+
+    @T.prim_func
+    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
+        with T.Kernel(num_tokens, threads=128) as _:
+            T.fill(x[0:128], 0)
+
+    return buggy_kernel
+
+
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+    },)
+def _fill_with_dynamic_region_kernel():
+    num_tokens = T.symbolic('num_tokens')
+
+    @T.prim_func
+    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
+        with T.Kernel(num_tokens, threads=128) as _:
+            a, b = T.alloc_var('int'), T.alloc_var('int')
+            T.fill(x[a:b], 0)
+
+    return buggy_kernel
+
+
+def test_fill_with_static_region_kernel():
+    kernel = _fill_with_static_region_kernel()
+    x = torch.zeros((256,), dtype=torch.int64, device='cuda')
+    kernel(x)
+
+
+def test_fill_with_dynamic_region_kernel():
+    kernel = _fill_with_dynamic_region_kernel()
+    x = torch.zeros((256,), dtype=torch.int64, device='cuda')
+    kernel(x)
+
+
+if __name__ == '__main__':
+    tilelang.testing.main()
diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py
index 95ef26746..74aeb2648 100644
--- a/tilelang/language/fill.py
+++ b/tilelang/language/fill.py
@@ -4,9 +4,14 @@ from tvm import tir
 from tilelang.language import has_let_value, get_let_value
 from tilelang.utils.language import get_buffer_region_from_load
+from tilelang.language.utils import (
+    buffer_to_tile_region,
+    buffer_region_to_tile_region,
+    buffer_load_to_tile_region,
+)
 
 
-def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
+def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.PrimExpr):
     """Fill a buffer or buffer region with a specified value.
 
     Args:
@@ -16,9 +21,30 @@ def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
         value: The value to fill the buffer with
 
     Returns:
        A TVM intrinsic call that performs the fill operation
    """
+    # Normalize Var with let value to its underlying object
+    if isinstance(buffer, tir.Var) and has_let_value(buffer):
+        buffer = get_let_value(buffer)
+
+    # Convert to a tl.region descriptor (PrimExpr) with write access
+    region_call = None
     if isinstance(buffer, tir.Buffer):
-        buffer = buffer.access_ptr("w")  # Get write pointer if input is a Buffer
-    return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value)
+        region_call = buffer_to_tile_region(buffer, "w")
+    elif isinstance(buffer, tir.BufferRegion):
+        extents = [r.extent for r in buffer.region]
+        region_call = buffer_region_to_tile_region(buffer, "w", extents)
+    elif isinstance(buffer, tir.BufferLoad):
+        region = get_buffer_region_from_load(buffer)
+        if region is not None:
+            extents = [r.extent for r in region.region]
+            region_call = buffer_region_to_tile_region(region, "w", extents)
+        else:
+            # Fallback: treat element access as 1-extent per dim
+            region_call = buffer_load_to_tile_region(buffer, "w", [1] * len(buffer.indices))
+    else:
+        # As-is fallback (rare): pass through for downstream handling
+        region_call = buffer
+
+    return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), region_call, value)
 
 
 def clear(buffer: tir.Buffer | tir.Var):