Merged
Changes from 5 commits
5 changes: 3 additions & 2 deletions examples/attention_sink/benchmark_gqa_sink_fwd.py
@@ -5,6 +5,7 @@
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional


@triton.jit
@@ -94,7 +95,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
@@ -130,7 +131,7 @@ def main(
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
5 changes: 3 additions & 2 deletions examples/attention_sink/benchmark_mha_sink_fwd.py
@@ -5,6 +5,7 @@
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional


@triton.jit
@@ -93,7 +94,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
@@ -125,7 +126,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
2 changes: 1 addition & 1 deletion examples/attention_sink/example_gqa_sink_bwd_bhsd.py
@@ -444,7 +444,7 @@ def main(BATCH: int = 1,
N_CTX: int = 512,
D_HEAD: int = 64,
groups: int = 2,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
Original file line number Diff line number Diff line change
@@ -272,7 +272,7 @@ def main(
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
2 changes: 1 addition & 1 deletion examples/attention_sink/example_mha_sink_bwd_bhsd.py
@@ -440,7 +440,7 @@ def main(BATCH: int = 1,
H: int = 1,
N_CTX: int = 512,
D_HEAD: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
2 changes: 1 addition & 1 deletion examples/attention_sink/example_mha_sink_fwd_bhsd.py
@@ -253,7 +253,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
Original file line number Diff line number Diff line change
@@ -263,7 +263,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -2,7 +2,7 @@
name = "tilelang"
description = "A tile level programming language to generate high performance code."
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
authors = [{ name = "TileLang Contributors" }, { name = "Tile-AI" }]
maintainers = [{ name = "Lei Wang", email = "[email protected]" }]
license = "MIT"
@@ -14,7 +14,6 @@ classifiers = [
"Operating System :: MacOS",
"Programming Language :: C++",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
@@ -118,7 +117,7 @@ skip = [
]

[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 100
output-format = "full"

50 changes: 41 additions & 9 deletions src/op/fill.cc
@@ -17,6 +17,7 @@
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "region.h"

namespace tvm {
namespace tl {
@@ -62,7 +63,30 @@ using namespace tir;
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();

if (args[0]->IsInstance<BufferLoadNode>()) {
// Case 1: Region descriptor call (tl.region)
if (const auto *call = args[0].as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}

// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;

// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
@@ -95,14 +120,19 @@
<< " != " << node->dst->shape.size();
for (int i = 0; i < node->region.size(); i++) {
// bound check if region is static
if (node->region[i]->min.as<IntImm>()) {
int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
int64_t min = min_imm->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
}
if (node->region[i]->extent.as<IntImm>()) {
int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
<< "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
// Only perform the upper-bound check when the destination shape
// extent is also statically known. If the shape is symbolic (e.g., Var),
// skip this static check to avoid invalid downcasts.
if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
ICHECK_LE(extent_imm->value, shape_imm->value)
<< "region[" << i << "] = " << extent_imm->value << " > "
<< node->dst->shape[i];
}
}
}
data_ = std::move(node);
@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (int i = 0; i < ndim; i++) {
Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
loop_vars.push_back({region[i], var, IterVarType::kDataPar});
dst_indices.push_back(var);
// Offset the loop induction variable by region min to honor sliced regions
dst_indices.push_back(region[i]->min + var);
}
Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) {
@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
return Stmt();
}
}

@@ -229,4 +261,4 @@ TIR_REGISTER_TL_OP(Fill, fill)
TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }

} // namespace tl
} // namespace tvm
} // namespace tvm
40 changes: 29 additions & 11 deletions src/transform/layout_reducer.cc
@@ -14,6 +14,7 @@
#include "../layout/layout.h"
#include "../op/fill.h"
#include "../op/finalize_reducer.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "layout_reducer.h"

@@ -275,17 +276,34 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
auto op = op_ref.CopyOnWrite();
if (op->op.same_as(Fill::Get())) {
ICHECK(!op->args.empty());
if (auto arg0_call = op->args[0].as<Call>();
arg0_call &&
arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
ICHECK(inside_reducer_range_.count(var.value()) == 0)
<< "T.fill on reducer must be enclosed with a T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var.value(),
reducer_info_map_.Get(var.value()).value());
if (auto arg0_call = op->args[0].as<Call>()) {
// Case 1: tl.region(...) — extract buffer var from its first arg
if (arg0_call.value()->op.same_as(RegionOp::Get())) {
ICHECK(!arg0_call.value()->args.empty());
if (auto bl = arg0_call.value()->args[0].as<BufferLoadNode>()) {
Var var = bl->buffer->data;
if (reducer_info_map_.count(var)) {
ICHECK(inside_reducer_range_.count(var) == 0)
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var,
reducer_info_map_.Get(var).value());
}
}
}
// Case 2: builtin.tvm_access_ptr(...) — existing path
else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
ICHECK(inside_reducer_range_.count(var.value()) == 0)
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(
var.value(), reducer_info_map_.Get(var.value()).value());
}
}
}
} else if (op->op.same_as(FinalizeReducerOp::Get())) {
53 changes: 53 additions & 0 deletions testing/python/issue/test_tilelang_issue_1008.py
@@ -0,0 +1,53 @@
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _fill_with_static_region_kernel():
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            T.fill(x[0:128], 0)

    return buggy_kernel


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _fill_with_dynamic_region_kernel():
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']):  # noqa: F821
        with T.Kernel(num_tokens, threads=128) as _:
            a, b = T.alloc_var('int'), T.alloc_var('int')
            T.fill(x[a:b], 0)

    return buggy_kernel


def test_fill_with_static_region_kernel():
    kernel = _fill_with_static_region_kernel()
    x = torch.zeros((256,), dtype=torch.int64, device='cuda')
    kernel(x)


def test_fill_with_dynamic_region_kernel():
    kernel = _fill_with_dynamic_region_kernel()
    x = torch.zeros((256,), dtype=torch.int64, device='cuda')
    kernel(x)


if __name__ == '__main__':
    tilelang.testing.main()
7 changes: 6 additions & 1 deletion tilelang/autotuner/tuner.py
@@ -14,7 +14,12 @@
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, Generic, Literal, Any, ParamSpec, TypeVar)
from typing import (Callable, Generic, Literal, Any, TypeVar)
# Python 3.9 compatibility for ParamSpec
try:
    from typing import ParamSpec
except ImportError:  # Python < 3.10
    from typing_extensions import ParamSpec
from tqdm.auto import tqdm
import logging
import concurrent.futures
8 changes: 6 additions & 2 deletions tilelang/jit/__init__.py
@@ -11,12 +11,16 @@
    Any,
    Callable,
    Generic,
    Iterable,
    ParamSpec,
    TypeVar,
    overload,
    Literal,
)
from collections.abc import Iterable
# Python 3.9 compatibility for ParamSpec
try:
    from typing import ParamSpec
except ImportError:  # Python < 3.10
    from typing_extensions import ParamSpec
from tilelang import tvm as tvm
from tilelang.language.v2 import PrimFunc
from tilelang.jit.adapter.utils import is_metal_target
7 changes: 6 additions & 1 deletion tilelang/jit/kernel.py
@@ -1,5 +1,10 @@
from __future__ import annotations
from typing import Any, Callable, Generic, Literal, ParamSpec, TypeVar
from typing import Any, Callable, Generic, Literal, TypeVar
# Python 3.9 compatibility for ParamSpec
try:
    from typing import ParamSpec
except ImportError:  # Python < 3.10
    from typing_extensions import ParamSpec

from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target
Expand Down