116 changes: 116 additions & 0 deletions testing/python/jit/test_tilelang_jit_nullptr.py
@@ -0,0 +1,116 @@
import torch
from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import tilelang.language as T
from tilelang.utils import map_torch_type


@tl.jit
def ptr_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            a_ptr: T.ptr,
            b_ptr: T.ptr,
            c_ptr: T.ptr,
            bias_ptr: T.ptr,
            m: T.int32,
            n: T.int32,
            k: T.int32,
            with_bias: T.bool,
    ):
        A = T.make_tensor(a_ptr, (m, k), dtype)
        B = T.make_tensor(b_ptr, (n, k), dtype)
        C = T.make_tensor(c_ptr, (m, n), accum_dtype)
        Bias = T.make_tensor(bias_ptr, (n,), accum_dtype)

        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(k, block_K), num_stages=3):
                # Copy tiles of A and B into shared memory
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            if with_bias:
                for i, j in T.Parallel(block_M, block_N):
                    C_local[i, j] += Bias[bx * block_N + j]

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


@tl.jit
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
            Bias: T.Tensor((N,), accum_dtype),
            with_bias: T.bool,
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tiles of A and B into shared memory
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            if with_bias:
                for i, j in T.Parallel(block_M, block_N):
                    C_local[i, j] += Bias[bx * block_N + j]

            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    func = ptr_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)

    a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
    b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
    c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
    d = torch.randn(N, device="cuda", dtype=map_torch_type(accum_dtype))

    func(a, b, c, None, M, N, K, False)

    ref_no_bias = (a @ b.T).to(map_torch_type(accum_dtype))
    ref_with_bias = ref_no_bias + d

    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)

    func(a, b, c, d, M, N, K, True)

    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)

    func = tensor_null_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
    func(a, b, c, None, False)
    torch.testing.assert_close(c, ref_no_bias, atol=1e-2, rtol=1e-2)
    func(a, b, c, d, True)
    torch.testing.assert_close(c, ref_with_bias, atol=1e-2, rtol=1e-2)


def test_nullptr():
    run_test(1024, 1024, 1024, 128, 128, 32)


if __name__ == "__main__":
    tilelang.testing.main()
2 changes: 2 additions & 0 deletions tilelang/jit/adapter/cython/cython_wrapper.pyx
@@ -251,6 +251,8 @@ cdef class CythonKernelWrapper:
                if dtype not in dtype_to_ctype:
                    raise ValueError(f"Unsupported tensor dtype: {dtype}")
                call_args.append(dtype_to_ctype[dtype](tensor))
            elif tensor is None:
                call_args.append(ctypes.c_void_p(0))
            else:
                raise ValueError(f"Unsupported tensor type: {type(tensor)}")

Comment on lines 251 to 258

P1: Handle optional tensors consistently across adapters

Adding elif tensor is None: call_args.append(ctypes.c_void_p(0)) makes the Cython path accept null inputs, but the other adapters (e.g. the ctypes and nvrtc wrappers) still always access .shape/.stride on every argument when building dynamic symbolics. Passing None for a T.ptr/T.Tensor with those backends will therefore still raise an AttributeError instead of forwarding a null pointer, despite the feature claiming generic support for None. To avoid backend‑dependent failures, the same None handling or a guard in their dynamic dimension logic is needed.
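
A minimal sketch of the kind of guard described above, assuming nothing about the real tilelang adapter code: the helper names, the `dynamic_symbolic_map` layout, and the argument-packing flow are illustrative only.

```python
import ctypes


def resolve_dynamic_shapes(args, dynamic_symbolic_map):
    """Resolve dynamic shape symbols while tolerating None arguments."""
    resolved = {}
    for symbol, (arg_idx, dim_idx) in dynamic_symbolic_map.items():
        tensor = args[arg_idx]
        if tensor is None:
            # A null input carries no shape information; skip it rather
            # than dereferencing .shape on None.
            continue
        resolved[symbol] = tensor.shape[dim_idx]
    return resolved


def pack_call_arg(tensor):
    """Map one Python-side argument to the C calling convention."""
    if tensor is None:
        # Forward a null pointer, mirroring the Cython wrapper change.
        return ctypes.c_void_p(0)
    return ctypes.c_void_p(tensor.data_ptr())
```

Whether a missing symbol should be silently skipped or reported is a design choice; if the kernel actually reads that dimension, raising a clear error is preferable to an AttributeError.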


4 changes: 2 additions & 2 deletions tilelang/language/allocate.py
@@ -14,11 +14,11 @@
with the appropriate memory scope.
"""

from __future__ import annotations
from tilelang import tvm as tvm
from tvm.script import tir as T
from tvm.tir import PrimExpr
from tvm.script.parser.tir import block_attr
from typing import Union


def alloc_shared(shape, dtype, scope="shared.dyn"):
@@ -67,7 +67,7 @@ def alloc_fragment(shape, dtype, scope="local.fragment"):
    return T.alloc_buffer(shape, dtype, scope=scope)


def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None):
def alloc_var(dtype, *args, scope="local.var", init: PrimExpr | None = None):
"""Allocate a single-element variable buffer.

Args: