Skip to content

[Feature request] Support dynamic range indexing #1008

@LyricZhao

Description

@LyricZhao
import torch
import tilelang
from tilelang import language as T


# Minimal reproduction for the dynamic range indexing request: returns a
# prim_func that slices its tensor argument with bounds taken from variables
# created via T.alloc_var rather than from literal constants.
@tilelang.jit(
    pass_configs={
        # Both pass-config flags are set to True for this reproduction.
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel():
    # The tensor length is a symbolic dimension rather than a fixed size.
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, ), 'int64']):
        with T.Kernel(num_tokens, threads=128) as pid:
            # `a` and `b` are never assigned a value in this snippet; only
            # their use as slice bounds below matters for the report.
            a, b = T.alloc_var('int'), T.alloc_var('int')
            # Fill over a slice whose start/stop are those variables -- the
            # construct this issue asks to have supported.
            T.fill(x[a:b], 0)
            # x[a:b] = 0  # also GG (presumably: plain slice assignment breaks the same way)

    return buggy_kernel


if __name__ == '__main__':
    # Build the kernel and dump its generated source before launching it.
    buggy = get_buggy_kernel()
    generated_source = buggy.get_kernel_source()
    print(generated_source)

    # Launch on a 128-element int64 tensor placed on the CUDA device.
    tokens = torch.zeros((128, ), dtype=torch.int64, device='cuda')
    buggy(tokens)

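As titled: please support indexing a tensor with a dynamic range, i.e. a slice such as `x[a:b]` whose bounds are runtime variables, as in the snippet above (both `T.fill(x[a:b], 0)` and `x[a:b] = 0`).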

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions