-
Notifications
You must be signed in to change notification settings - Fork 334
Closed
Description
import torch
import tilelang
from tilelang import language as T
@tilelang.jit()
def get_buggy_kernel():
num_tokens = T.symbolic('num_tokens')
@T.prim_func
def buggy_kernel(indices: T.Tensor[(num_tokens, ), 'int'],
x: T.Tensor[(num_tokens, ), 'float']):
with T.Kernel(num_tokens, threads=32) as pid:
idx = T.alloc_local([1], 'int')
T.copy(indices[pid], idx[0])
x[idx[0]] = x[idx[0]] + 1
return buggy_kernel
if __name__ == '__main__':
kernel = get_buggy_kernel()
print(kernel.get_kernel_source())
i = torch.arange(128, dtype=torch.int, device='cuda')
x = torch.randn((128, ), dtype=torch.float, device='cuda')
kernel(i, x)error: Can't deduce copy extents from args, the error is raised by T.copy(indices[pid], idx[0]).
Metadata
Metadata
Assignees
Labels
No labels