
Commit 703e63c

Merge remote-tracking branch 'upstream/main' into update-maint-release

2 parents: 25a9814 + 778b97d
25 files changed: +2683 -340 lines

.github/workflows/dist.yml

Lines changed: 2 additions & 2 deletions
@@ -89,7 +89,7 @@ jobs:
       - name: Upload SDist
         # Not PR to save artifact storage, as SDist is only needed for releases.
         if: github.event_name != 'pull_request' || contains(github.event.pull_request.title, '[Release]')
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
         with:
           name: sdist
           path: dist/*.tar.gz
@@ -172,7 +172,7 @@ jobs:
     timeout-minutes: 15
     steps:
       - name: Download built SDist
-        uses: actions/download-artifact@v5
+        uses: actions/download-artifact@v6
         with:
           # unpacks default artifact into dist/
           # if `name: artifact` is omitted, the action will create extra parent dir

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ repos:
           ^.+\.json$
         )
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.14.1 # sync with requirements-lint.txt
+    rev: v0.14.3 # sync with requirements-lint.txt
     hooks:
       - id: ruff-check
         args: [--fix, --exit-non-zero-on-fix]

examples/gdn/example_chunk_o_bwd.py

Lines changed: 3 additions & 4 deletions
@@ -7,8 +7,6 @@
 import tilelang.language as T
 from tilelang.engine.callback import register_cuda_postproc_callback  # noqa: F401
 
-print(tilelang.__file__)
-
 # Add your fla repository path to sys.path
 # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
 # sys.path.insert(0, "/home/tzj/flash-linear-attention")
@@ -256,8 +254,9 @@ def kernel(
             # for i_kv in T.Parallel(block_DK * block_DV):
             #     dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
             for i_kv in T.Parallel(block_DK * block_DV):
-                i_k, i_v = i_kv // block_DV, i_kv % block_DV
-                dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
+                dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv %
+                                                  block_DV] * dh_shared[i_kv // block_DV,
+                                                                        i_kv % block_DV]
             T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
             dg_last_local[0] += dg_last_fragment_scalar[0]
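For readers skimming the diff: the rewritten loop is behaviorally identical to the removed lines (the i_k / i_v temporaries were inlined); the block still multiplies the two [block_DK, block_DV] tiles elementwise and reduces the products to a scalar. A rough PyTorch sketch of the same math, with hypothetical tile sizes (not part of the commit), for orientation:

import torch

block_DK, block_DV = 64, 64                 # hypothetical tile sizes
h_shared = torch.randn(block_DK, block_DV)
dh_shared = torch.randn(block_DK, block_DV)
# i_kv // block_DV walks the DK axis and i_kv % block_DV the DV axis, so the
# T.Parallel loop is an elementwise product over the tile; T.reduce_sum then
# collapses it to the scalar accumulated into dg_last.
dg_last_contrib = (h_shared * dh_shared).sum()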

requirements-lint.txt

Lines changed: 1 addition & 1 deletion
@@ -3,5 +3,5 @@ pre-commit
 clang-format==21.1.2
 clang-tidy==21.1.1
 codespell[toml]==2.4.1
-ruff==0.14.1
+ruff==0.14.3
 yapf==0.43.0

src/transform/make_packed_api.cc

Lines changed: 1 addition & 1 deletion
@@ -433,7 +433,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
       auto shape_vectorize_expr = [&]() -> PrimExpr {
         PrimExpr result = IntImm(kv.second->DefaultIndexType(), 1);
         result = result * vectorize_dim;
-        result = FloorMod(result, dynamic_alignment);
+        result = FloorMod(result, IntImm(result->dtype, dynamic_alignment));
         return result;
       }();
       shape_checks.emplace_back(AssertStmt(
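The one-line change above makes the modulus constant take the same dtype as `result`, so the FloorMod node no longer mixes int64 and int32 operands when a buffer uses a 64-bit index type. A minimal sketch of the check this sidesteps, assuming TVM's Python TIR bindings are available (tilelang builds on TVM TIR); the values here are hypothetical and not from the patch:

from tvm import tir

# Index math carried out in the buffer's default index dtype (int64 here).
result = tir.IntImm("int64", 1) * tir.IntImm("int64", 4)
# tir.FloorMod(result, tir.IntImm("int32", 16)) would trip the binary-op check
# "Check failed: (a.dtype() == b.dtype()) ... mismatched types. int64 vs. int32".
aligned = tir.FloorMod(result, tir.IntImm(result.dtype, 16))  # dtypes match, as in the fix
print(aligned.dtype)  # int64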
New file

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import torch
+import tilelang
+import tilelang.language as T
+
+
+def test_int64_address():
+
+    @tilelang.jit
+    def set_cache_kernel(
+        S,
+        D,
+        pos_ty='int64',
+        dtype="float32",
+    ):
+
+        @T.prim_func
+        def main(
+                pos: T
+                .Tensor(
+                    [
+                        S,
+                    ], pos_ty
+                ),  # type: ignore `TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32`
+                value: T.Tensor([S, D], dtype),  # type: ignore
+                cache: T.Tensor([S, D], dtype),  # type: ignore
+        ):
+            with T.Kernel(S, threads=128) as bx:
+                slot = pos[bx]
+                for i in T.Parallel(D):
+                    cache[slot, i] = value[bx, i]
+
+        return main
+
+    D = 2
+    S = 10
+    cache = torch.rand((S, D), device="cuda", dtype=torch.float32)
+    value = torch.rand((S, D), device='cuda', dtype=torch.float32)
+    pos_int64 = torch.arange(S, device='cuda', dtype=torch.int64)
+    pos_int32 = torch.arange(S, device='cuda', dtype=torch.int32)
+    kernel_int64 = set_cache_kernel(S, D, 'int64')
+    kernel_int32 = set_cache_kernel(S, D, 'int32')
+    kernel_int64(pos_int64, value, cache)
+    torch.testing.assert_close(cache, value)
+    kernel_int32(pos_int32, value, cache)
+    torch.testing.assert_close(cache, value)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
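In plain PyTorch terms, the kernel in this new test scatters row bx of `value` into row pos[bx] of `cache`, once with an int64 position tensor and once with int32, which is presumably the case the `make_packed_api.cc` dtype fix above targets. A rough CPU reference for the expected result (not part of the commit):

import torch

S, D = 10, 2
value = torch.rand(S, D)
cache = torch.empty(S, D)
pos = torch.arange(S, dtype=torch.int64)  # the test also exercises an int32 variant on the kernel side
# Row bx of `value` lands in row pos[bx] of `cache`.
cache[pos] = value
torch.testing.assert_close(cache, value)  # pos is the identity permutation here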
New file

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import tilelang.testing
+import tilelang
+import torch
+
+
+@tilelang.jit(
+    out_idx=-1,  # create the output tensor during runtime
+    verbose=True,
+)
+def matmul_kernel_jit(
+    M,
+    N,
+    K,
+    block_M,
+    block_N,
+    block_K,
+    trans_A=False,
+    trans_B=True,
+    in_dtype='float16',
+    out_dtype='float32',
+    accum_dtype='float32',
+    num_stages=2,
+    threads=128,
+):
+    A_shape = (K, M) if trans_A else (M, K)
+    B_shape = (N, K) if trans_B else (K, N)
+    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
+            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                if trans_A:
+                    T.copy(A[k * block_K, by * block_M], A_shared)
+                else:
+                    T.copy(A[by * block_M, k * block_K], A_shared)
+                if trans_B:
+                    T.copy(B[bx * block_N, k * block_K], B_shared)
+                else:
+                    T.copy(B[k * block_K, bx * block_N], B_shared)
+                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
+            T.copy(C_local, C[by * block_M, bx * block_N])
+
+    return main
+
+
+def test_par_compile():
+    configs = [
+        (1024, 1024, 1024, 128, 128, 32),
+        (2048, 2048, 2048, 256, 256, 64),
+        (4096, 4096, 4096, 64, 64, 128),
+    ]
+    kernels = matmul_kernel_jit.par_compile(configs)
+    for (M, N, K, _, _, _), kernel in zip(configs, kernels):
+        A = torch.randn(M, K, dtype=torch.float16).cuda()
+        B = torch.randn(N, K, dtype=torch.float16).cuda()
+        ref = (A @ B.T).float()
+        C = kernel(A, B)
+        tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
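This second new test exercises `par_compile`, which takes a batch of configs and returns one compiled kernel per config. For comparison, a serial version would presumably call the decorated function once per config, the way the int64 test above calls `set_cache_kernel(S, D, ...)`; a sketch only, since that call pattern for `matmul_kernel_jit` is an assumption and not shown in this commit:

# Hypothetical serial counterpart to `matmul_kernel_jit.par_compile(configs)`:
kernels = []
for (M, N, K, block_M, block_N, block_K) in configs:
    kernels.append(matmul_kernel_jit(M, N, K, block_M, block_N, block_K))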
