5 changes: 3 additions & 2 deletions examples/attention_sink/benchmark_gqa_sink_fwd.py
@@ -5,6 +5,7 @@
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional


@triton.jit
@@ -94,7 +95,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
_, n_heads_kv, seq_kv, _ = K.shape
BLOCK_M = 64
@@ -130,7 +131,7 @@ def main(
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
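The `Optional[int]` rewrites above are needed because PEP 604 `int | None` unions in annotations that are evaluated at import time only work on Python 3.10+, while this PR keeps 3.9 as the floor. A minimal standalone sketch of the difference (the function name and body are illustrative, not code from this repo):

```python
from typing import Optional

# On Python 3.9, a def whose annotation uses `int | None` raises at import time:
#   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
# `Optional[int]` is the equivalent spelling that works on 3.9.
def attention_mode(window_size: Optional[int] = None) -> str:
    return "sliding-window" if window_size is not None else "full"
```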
5 changes: 3 additions & 2 deletions examples/attention_sink/benchmark_mha_sink_fwd.py
@@ -5,6 +5,7 @@
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
from typing import Optional


@triton.jit
@@ -93,7 +94,7 @@ def triton_kernel(
Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
bs, n_heads, seq_q, head_dim = Q.shape
seq_kv = K.shape[2]
BLOCK_M = 64
@@ -125,7 +126,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
2 changes: 1 addition & 1 deletion examples/attention_sink/example_gqa_sink_bwd_bhsd.py
@@ -444,7 +444,7 @@ def main(BATCH: int = 1,
N_CTX: int = 512,
D_HEAD: int = 64,
groups: int = 2,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
@@ -272,7 +272,7 @@ def main(
seq_kv: int = 256,
dim: int = 128,
groups: int = 8,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False,
):
2 changes: 1 addition & 1 deletion examples/attention_sink/example_mha_sink_bwd_bhsd.py
@@ -440,7 +440,7 @@ def main(BATCH: int = 1,
H: int = 1,
N_CTX: int = 512,
D_HEAD: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16"):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
if window_size is not None:
2 changes: 1 addition & 1 deletion examples/attention_sink/example_mha_sink_fwd_bhsd.py
@@ -253,7 +253,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
@@ -263,7 +263,7 @@ def main(batch: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 128,
window_size: int | None = None,
window_size: Optional[int] = None,
dtype: str = "float16",
tune: bool = False):
torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
5 changes: 2 additions & 3 deletions pyproject.toml
@@ -2,7 +2,7 @@
name = "tilelang"
description = "A tile level programming language to generate high performance code."
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.9"
authors = [{ name = "TileLang Contributors" }, { name = "Tile-AI" }]
maintainers = [{ name = "Lei Wang", email = "[email protected]" }]
license = "MIT"
@@ -14,7 +14,6 @@ classifiers = [
"Operating System :: MacOS",
"Programming Language :: C++",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
@@ -118,7 +117,7 @@ skip = [
]

[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 100
output-format = "full"

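With the 3.8 classifier dropped and `requires-python` raised, the codebase can assume 3.9 semantics but still not 3.10-only syntax, which is why the annotations switch to `typing.Optional` rather than `|`. A small sanity-check snippet (illustrative, not part of the PR):

```python
import sys

# requires-python = ">=3.9" makes pip refuse older interpreters, so 3.9
# features (dict `|` merge, builtin generics like list[int] in evaluated
# annotations) are fair game, while PEP 604 `X | None` unions (3.10) are not.
assert sys.version_info >= (3, 9), "tilelang requires Python 3.9+"
```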
50 changes: 41 additions & 9 deletions src/op/fill.cc
@@ -17,6 +17,7 @@
#include "../transform/loop_partition.h"
#include "../transform/loop_vectorize.h"
#include "builtin.h"
#include "region.h"

namespace tvm {
namespace tl {
@@ -62,7 +63,30 @@ using namespace tir;
Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();

if (args[0]->IsInstance<BufferLoadNode>()) {
// Case 1: Region descriptor call (tl.region)
if (const auto *call = args[0].as<CallNode>()) {
if (call->op.same_as(RegionOp::Get())) {
auto region = RegionOp(call->args, vmap);
node->dst = region->GetBuffer();
node->region = region->GetRanges();
} else if (call->op.same_as(builtin::tvm_access_ptr())) {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
node->region.push_back(Range(0, node->dst->shape[i]));
}
} else {
ICHECK(false) << "Unsupported call op in tl.fill: "
<< Downcast<Op>(call->op)->name;
}

// Case 2: Explicit BufferRegion (legacy path)
} else if (args[0]->IsInstance<BufferRegionNode>()) {
auto region = Downcast<BufferRegion>(args[0]);
node->dst = region->buffer;
node->region = region->region;

// Case 3: Vector/scalar region expressed via BufferLoad indices
} else if (args[0]->IsInstance<BufferLoadNode>()) {
auto buffer_load = Downcast<BufferLoad>(args[0]);
for (const auto &index : buffer_load->indices) {
if (const auto *ramp = index.as<RampNode>()) {
@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
}
}
node->dst = buffer_load->buffer;
// Case 4: Access pointer, fill the full buffer
} else {
node->dst = vmap[GetVarFromAccessPtr(args[0])];
for (int i = 0; i < node->dst->shape.size(); i++) {
@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
<< " != " << node->dst->shape.size();
for (int i = 0; i < node->region.size(); i++) {
// bound check if region is static
if (node->region[i]->min.as<IntImm>()) {
int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
int64_t min = min_imm->value;
ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
}
if (node->region[i]->extent.as<IntImm>()) {
int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
<< "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
// Only perform the upper-bound check when the destination shape
// extent is also statically known. If the shape is symbolic (e.g., Var),
// skip this static check to avoid invalid downcasts.
if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
ICHECK_LE(extent_imm->value, shape_imm->value)
<< "region[" << i << "] = " << extent_imm->value << " > "
<< node->dst->shape[i];
}
}
}
data_ = std::move(node);
@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (int i = 0; i < ndim; i++) {
Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
loop_vars.push_back({region[i], var, IterVarType::kDataPar});
dst_indices.push_back(var);
// Offset the loop induction variable by region min to honor sliced regions
dst_indices.push_back(region[i]->min + var);
}
Stmt body = BufferStore(dst, value, dst_indices);
for (int i = ndim - 1; i >= 0; i--) {
@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
} else {
LOG(FATAL) << "Unsupported scope " << dst.scope();
return Stmt();
}
}

@@ -229,4 +261,4 @@ TIR_REGISTER_TL_OP(Fill, fill)
TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }

} // namespace tl
} // namespace tvm
} // namespace tvm
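Two behavioral points in the rewritten `Fill` constructor and `MakeSIMTLoop`: the static upper-bound check now fires only when both the region extent and the destination shape extent are `IntImm`, so symbolically shaped buffers no longer trigger an invalid downcast, and the generated loop indices are offset by each range's `min` so partial regions fill the correct slice. A minimal Python sketch of that bound check, using plain ints for `IntImm` and `None` for symbolic extents (illustration only, not the TVM API):

```python
def check_fill_region(region, dst_shape):
    """region: list of (min, extent); dst_shape: list of int or None (None = symbolic)."""
    for i, ((r_min, r_extent), dim) in enumerate(zip(region, dst_shape)):
        if isinstance(r_min, int):
            assert r_min >= 0, f"region[{i}] = {r_min} < 0"
        if isinstance(r_extent, int) and isinstance(dim, int):
            # The upper bound is checked only when the shape extent is static too;
            # a symbolic shape (e.g. num_tokens) skips the comparison entirely.
            assert r_extent <= dim, f"region[{i}] = {r_extent} > {dim}"

# A 128-element fill into a symbolically sized 1-D buffer passes the static checks:
check_fill_region([(0, 128)], [None])
```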
40 changes: 29 additions & 11 deletions src/transform/layout_reducer.cc
@@ -14,6 +14,7 @@
#include "../layout/layout.h"
#include "../op/fill.h"
#include "../op/finalize_reducer.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "layout_reducer.h"

@@ -275,17 +276,34 @@ class ReducerLayoutAnnotator : public IRMutatorWithAnalyzer {
auto op = op_ref.CopyOnWrite();
if (op->op.same_as(Fill::Get())) {
ICHECK(!op->args.empty());
if (auto arg0_call = op->args[0].as<Call>();
arg0_call &&
arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
ICHECK(inside_reducer_range_.count(var.value()) == 0)
<< "T.fill on reducer must be enclosed with a T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var.value(),
reducer_info_map_.Get(var.value()).value());
if (auto arg0_call = op->args[0].as<Call>()) {
// Case 1: tl.region(...) — extract buffer var from its first arg
if (arg0_call.value()->op.same_as(RegionOp::Get())) {
ICHECK(!arg0_call.value()->args.empty());
if (auto bl = arg0_call.value()->args[0].as<BufferLoadNode>()) {
Var var = bl->buffer->data;
if (reducer_info_map_.count(var)) {
ICHECK(inside_reducer_range_.count(var) == 0)
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var,
reducer_info_map_.Get(var).value());
}
}
}
// Case 2: builtin.tvm_access_ptr(...) — existing path
else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
ICHECK(inside_reducer_range_.count(var.value()) == 0)
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(
var.value(), reducer_info_map_.Get(var.value()).value());
}
}
}
} else if (op->op.same_as(FinalizeReducerOp::Get())) {
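Because `T.fill(buf[lo:hi], v)` now reaches this pass as a `tl.region(...)` call instead of a `tvm_access_ptr`, the annotator needs a second way to recover the buffer variable it tracks. A rough Python sketch of that dispatch, with tiny stand-in classes for the TIR nodes (names are illustrative, not the TVM API):

```python
from dataclasses import dataclass
from typing import Any, List, Optional

@dataclass
class BufferLoad:      # stand-in for tir.BufferLoad
    buffer_var: str    # the buffer's underlying data Var
    indices: List[Any]

@dataclass
class Call:            # stand-in for tir.Call
    op: str
    args: List[Any]

def fill_target_var(arg: Any) -> Optional[str]:
    """Recover the buffer var a tl.fill target refers to (sketch of the C++ above)."""
    if isinstance(arg, Call):
        if arg.op == "tl.region":
            # New path: tl.region(BufferLoad(buf, ...), access_mask, *extents)
            load = arg.args[0]
            return load.buffer_var if isinstance(load, BufferLoad) else None
        if arg.op == "tir.tvm_access_ptr":
            # Existing path: tvm_access_ptr(dtype, buffer_var, offset, extent, mask)
            return arg.args[1]
    return None
```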
43 changes: 20 additions & 23 deletions testing/python/amd/test_tilelang_test_amd.py
@@ -223,29 +223,26 @@ def ref_program(A, B):
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


@tilelang.testing.requires_rocm
def test_gemm_rs_f16f32f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)


@tilelang.testing.requires_rocm
def test_gemm_rs_bf16f32f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)


@tilelang.testing.requires_rocm
def test_gemm_rs_bf16bf16f32_nt():
run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)

# @tilelang.testing.requires_rocm
# def test_gemm_rs_f16f32f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)

# @tilelang.testing.requires_rocm
# def test_gemm_rs_bf16f32f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)

# @tilelang.testing.requires_rocm
# def test_gemm_rs_bf16bf16f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)

if __name__ == "__main__":
tilelang.testing.main()
53 changes: 53 additions & 0 deletions testing/python/issue/test_tilelang_issue_1008.py
@@ -0,0 +1,53 @@
import torch
import tilelang
import tilelang.testing
from tilelang import language as T


@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
def _fill_with_static_region_kernel():
num_tokens = T.symbolic('num_tokens')

@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
T.fill(x[0:128], 0)

return buggy_kernel


@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},)
def _fill_with_dynamic_region_kernel():
num_tokens = T.symbolic('num_tokens')

@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), 'int64']): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
a, b = T.alloc_var('int'), T.alloc_var('int')
T.fill(x[a:b], 0)

return buggy_kernel


def test_fill_with_static_region_kernel():
kernel = _fill_with_static_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda')
kernel(x)


def test_fill_with_dynamic_region_kernel():
kernel = _fill_with_dynamic_region_kernel()
x = torch.zeros((256,), dtype=torch.int64, device='cuda')
kernel(x)


if __name__ == '__main__':
tilelang.testing.main()
7 changes: 6 additions & 1 deletion tilelang/autotuner/tuner.py
@@ -14,7 +14,12 @@
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, Generic, Literal, Any, ParamSpec, TypeVar)
from typing import (Callable, Generic, Literal, Any, TypeVar)
# Python 3.9 compatibility for ParamSpec
try:
from typing import ParamSpec
except ImportError: # Python < 3.10
from typing_extensions import ParamSpec
from tqdm.auto import tqdm
import logging
import concurrent.futures
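The guarded import follows the usual backport pattern: `ParamSpec` only joined `typing` in Python 3.10, so on the new 3.9 floor it comes from `typing_extensions`. A small self-contained sketch of how such a `ParamSpec` is typically used to type a signature-preserving decorator (hypothetical decorator, not the autotuner's actual code):

```python
from typing import Callable, TypeVar

try:
    from typing import ParamSpec          # Python 3.10+
except ImportError:                        # Python 3.9
    from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")

def logged(fn: Callable[P, R]) -> Callable[P, R]:
    """Wrap fn while preserving its parameter types for static checkers."""
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapper
```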