-
Notifications
You must be signed in to change notification settings - Fork 332
Closed
Description
import torch
import tilelang
from tilelang import language as T
@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel():
    """Build a minimal tilelang kernel used to reproduce a shape-check bug.

    The generated ``prim_func`` declares a single 1-D ``int64`` tensor
    argument of symbolic length ``num_tokens``; the issue being reported
    is that invoking it with a 2-D tensor is not rejected.

    Returns:
        The JIT-compilable ``buggy_kernel`` prim_func.
    """
    # Symbolic dimension: the kernel's only declared shape constraint.
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, ), 'int64']):
        # One 128-thread block per token; body only allocates two
        # scratch vars so the repro stays minimal.
        with T.Kernel(num_tokens, threads=128) as pid:
            a, b = T.alloc_var('int'), T.alloc_var('int')

    return buggy_kernel
def main():
    """Compile the repro kernel, print its source, and trigger the bug.

    NOTE(review): the kernel declares a 1-D ``(num_tokens,)`` argument,
    yet the (1, 128)-shaped tensor below is accepted — that mismatch is
    the bug being reported. Requires a CUDA-capable device.
    """
    kernel = get_buggy_kernel()
    # Inspect the generated device code.
    print(kernel.get_kernel_source())
    # 2-D input that should be rejected by a 1-D signature, but is not.
    x = torch.zeros((1, 128, ), dtype=torch.int64, device='cuda')
    kernel(x)


if __name__ == '__main__':
    main()
As titled: the kernel signature requires a 1-D tensor of shape (num_tokens,), yet launching it with a tensor of shape (1, 128) succeeds — no shape validation is performed on the argument.
Metadata
Metadata
Assignees
Labels
No labels