Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions src/op/fill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "region.h"

namespace tvm {
namespace tl {
Expand Down Expand Up @@ -62,7 +63,30 @@ using namespace tir;
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();

if (args[0]->IsInstance<BufferLoadNode>()) {
// Case 1: Region descriptor call (tl.region)
if (const auto *call = args[0].as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}

// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;

// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
Expand All @@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
Expand All @@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
<< " != " << node->dst->shape.size();
for (int i = 0; i < node->region.size(); i++) {
// bound check if region is static
if (node->region[i]->min.as<IntImm>()) {
int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
int64_t min = min_imm->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
}
if (node->region[i]->extent.as<IntImm>()) {
int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
<< "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
// Only perform the upper-bound check when the destination shape
// extent is also statically known. If the shape is symbolic (e.g., Var),
// skip this static check to avoid invalid downcasts.
if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
ICHECK_LE(extent_imm->value, shape_imm->value)
<< "region[" << i << "] = " << extent_imm->value << " > "
<< node->dst->shape[i];
}
}
}
data_ = std::move(node);
Expand Down Expand Up @@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (int i = 0; i < ndim; i++) {
Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
loop_vars.push_back({region[i], var, IterVarType::kDataPar});
dst_indices.push_back(var);
// Offset the loop induction variable by region min to honor sliced regions
dst_indices.push_back(region[i]->min + var);
}
Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) {
Expand Down Expand Up @@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
return Stmt();
}
}

Expand Down Expand Up @@ -229,4 +261,4 @@ TIR_REGISTER_TL_OP(Fill, fill)
TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }

} // namespace tl
} // namespace tvm
} // namespace tvm
53 changes: 53 additions & 0 deletions testing/python/issue/test_tilelang_issue_1008.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
def _fill_with_static_region_kernel():
num_tokens = T.symbolic('num_tokens')

@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
T.fill(x[0:128], 0)

return buggy_kernel


@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
def _fill_with_dynamic_region_kernel():
num_tokens = T.symbolic('num_tokens')

@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
a, b = T.alloc_var('int'), T.alloc_var('int')
T.fill(x[a:b], 0)

return buggy_kernel


def test_fill_with_static_region_kernel():
kernel = _fill_with_static_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda')
kernel(x)


def test_fill_with_dynamic_region_kernel():
kernel = _fill_with_dynamic_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda')
kernel(x)


if __name__ == '__main__':
tilelang.testing.main()
7 changes: 6 additions & 1 deletion tilelang/autotuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, Generic, Literal, Any, ParamSpec, TypeVar)
from typing import (Callable, Generic, Literal, Any, TypeVar)
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tqdm.auto import tqdm
import logging
import concurrent.futures
Expand Down
6 changes: 5 additions & 1 deletion tilelang/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
Callable,
Generic,
Iterable,
ParamSpec,
TypeVar,
overload,
Literal,
)
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tilelang.language.v2 import PrimFunc
from tilelang.jit.adapter.utils import is_metal_target
Expand Down
7 changes: 6 additions & 1 deletion tilelang/jit/kernel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from __future__ import annotations
from typing import Any, Callable, Generic, Literal, ParamSpec, TypeVar
from typing import Any, Callable, Generic, Literal, TypeVar
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec

from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target
Expand Down
32 changes: 29 additions & 3 deletions tilelang/language/fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@
from tvm import tir
from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load
from tilelang.language.utils import (
buffer_to_tile_region,
buffer_region_to_tile_region,
buffer_load_to_tile_region,
)


def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.PrimExpr):
"""Fill a buffer or buffer region with a specified value.

Args:
Expand All @@ -16,9 +21,30 @@ def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
Returns:
A TVM intrinsic call that performs the fill operation
"""
# Normalize Var with let value to its underlying object
if isinstance(buffer, tir.Var) and has_let_value(buffer):
buffer = get_let_value(buffer)

# Convert to a tl.region descriptor (PrimExpr) with write access
region_call = None
if isinstance(buffer, tir.Buffer):
buffer = buffer.access_ptr("w") # Get write pointer if input is a Buffer
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value)
region_call = buffer_to_tile_region(buffer, "w")
elif isinstance(buffer, tir.BufferRegion):
extents = [r.extent for r in buffer.region]
region_call = buffer_region_to_tile_region(buffer, "w", extents)
elif isinstance(buffer, tir.BufferLoad):
region = get_buffer_region_from_load(buffer)
if region is not None:
extents = [r.extent for r in region.region]
region_call = buffer_region_to_tile_region(region, "w", extents)
else:
# Fallback: treat element access as 1-extent per dim
region_call = buffer_load_to_tile_region(buffer, "w", [1] * len(buffer.indices))
else:
# As-is fallback (rare): pass through for downstream handling
region_call = buffer

return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), region_call, value)


def clear(buffer: tir.Buffer | tir.Var):
Expand Down
7 changes: 6 additions & 1 deletion tilelang/language/v2/ast.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import Callable, ContextManager, Generic, Iterable, Any, Literal, ParamSpec, TypeVar
from typing import Callable, ContextManager, Generic, Iterable, Any, Literal, TypeVar
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
import inspect
# from .utils import get_ast, get_compiled_object
from . import utils
Expand Down
13 changes: 10 additions & 3 deletions tilelang/language/v2/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from tvm.tir import Buffer
from tvm.script.ir_builder import tir, IRBuilder
from tvm.tir.expr import EqualOp, FloatImm, IntImm, NotEqualOp, PrimExpr, StringImm, Var
from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, ParamSpec, Self, TypeVar, ForwardRef
from typing import TYPE_CHECKING, Callable, ContextManager, Any, Generic, TypeVar, ForwardRef, Union
# Python 3.9 compatibility for ParamSpec and Self
try:
from typing import ParamSpec, Self
except ImportError: # Python < 3.11 for Self, < 3.10 for ParamSpec
from typing_extensions import ParamSpec, Self
from . import dtypes as dt
import threading
import logging
Expand Down Expand Up @@ -95,8 +100,10 @@ class BreakFrame(Frame):
...


ContinueOrBreak = ContinueFrame | BreakFrame
AnyFrame = tir.frame.IRBuilderFrame | Frame
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
# Use tuple for isinstance checks and typing.Union for annotations/aliases
ContinueOrBreak = (ContinueFrame, BreakFrame)
AnyFrame = Union[tir.frame.IRBuilderFrame, Frame]

TIR_CONTROL_FRAME = (
tir.frame.WhileFrame,
Expand Down
5 changes: 3 additions & 2 deletions tilelang/language/v2/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from tvm import ir
import torch
import ctypes
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
from tvm import tir
import tvm.script.ir_builder.tir._ffi_api as tb_ffi

dtype = tvm.DataType
AnyDType = ir.Type | str | type | torch.dtype | dtype
# Python 3.9 compatibility: avoid PEP 604 unions at runtime
AnyDType = Union[ir.Type, str, type, torch.dtype, dtype]

_dtype_cvt = [
(None, 'handle', ctypes.c_long, 'long', None), # use long to repr void*
Expand Down
Loading