
[BUG] Compilation Error with Single TMA Copy Instruction: 'key is not in Map' #1166

@LJC00118


What version of TileLang are you using?

0.1.6.post1+cuda.git4efd2d2d

System information

Python: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] linux
TileLang: 0.1.6.post1+cuda.git4efd2d2d
PyTorch: 2.7.0+cu128
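The version lines above are unlabeled in the original form output. Assuming they are the Python, TileLang and PyTorch versions in that order, a small helper like the one below can regenerate them for a follow-up report. `env_report` is a hypothetical helper, not part of TileLang; it reads distribution versions via `importlib.metadata`, which for a source build may not match `tilelang.__version__` exactly.

```python
import sys
from importlib import metadata


def env_report(packages=("tilelang", "torch")) -> str:
    """Hypothetical helper: collect the Python version plus installed package versions."""
    # Keep the Python version on one line even on builds where sys.version has a newline.
    lines = ["Python: " + sys.version.replace("\n", " ")]
    for name in packages:
        try:
            # Distribution version as recorded by the installer (may differ from __version__).
            version = metadata.version(name)
        except metadata.PackageNotFoundError:
            version = "not installed"
        lines.append(f"{name}: {version}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(env_report())
```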

Problem description

Compiling a kernel whose only operation is a single TMA copy (a global-to-shared `T.copy`) fails in the `InjectTmaBarrier` pass with:

```
File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
KeyError: 'key is not in Map'
```

Reproducible example code

The Python snippet:

```python
import torch
import tilelang
from tilelang import language as T


tilelang.disable_cache()  # disable the kernel cache so every run recompiles
@tilelang.jit()
def get_kernel(m: int):
    @T.prim_func
    def test_kernel(
        a: T.Tensor[(m,), "bfloat16"],
    ):
        with T.Kernel(1, threads=32) as bx:  # one block of 32 threads
            a_shared = T.alloc_shared((m,), "bfloat16")
            T.copy(a, a_shared)  # the only copy in the kernel: global -> shared via TMA

    return test_kernel


m = 4096
kernel = get_kernel(m)  # compilation fails here (see traceback)

print(kernel.get_kernel_source())

a = torch.randn((m,), device="cuda", dtype=torch.bfloat16)

kernel(a)
```

Traceback

```
Traceback (most recent call last):
  File "~/workspace/qwq/issue_tma_copy.py", line 20, in <module>
    kernel = get_kernel(m)
             ^^^^^^^^^^^^^
  File "~/workspace/tilelang/tilelang/jit/__init__.py", line 205, in wrapper
    kernel_result = compile(
                    ^^^^^^^^
  File "~/workspace/tilelang/tilelang/jit/__init__.py", line 70, in compile
    return cached(
           ^^^^^^^
  File "~/workspace/tilelang/tilelang/cache/__init__.py", line 29, in cached
    return _kernel_cache_instance.cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/workspace/tilelang/tilelang/cache/kernel_cache.py", line 185, in cached
    kernel = JITKernel(
             ^^^^^^^^^^
  File "~/workspace/tilelang/tilelang/jit/kernel.py", line 121, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/workspace/tilelang/tilelang/jit/kernel.py", line 219, in _compile_and_create_adapter
    artifact = tilelang.lower(
               ^^^^^^^^^^^^^^^
  File "~/workspace/tilelang/tilelang/engine/lower.py", line 230, in lower
    mod = OptimizeForTarget(mod, target)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/workspace/tilelang/tilelang/engine/phase.py", line 137, in OptimizeForTarget
    mod = tilelang.transform.InjectTmaBarrier()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/workspace/tilelang/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "tvm/ffi/cython/function.pxi", line 228, in tvm.ffi.core.Function.__call__
KeyError: 'key is not in Map'
```
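When checking whether a TileLang build still hits this failure (for example as a regression check once a fix lands), it can help to distinguish this specific `KeyError` from unrelated errors. The sketch below is a minimal, hypothetical helper: `is_issue_1166` and `check_compile` are illustrative names, not TileLang APIs, and the only assumption is the one visible in the traceback above, namely that the failure surfaces in Python as a `KeyError` whose message contains `key is not in Map`.

```python
from typing import Callable

# Message taken from the traceback in this report.
ISSUE_1166_MESSAGE = "key is not in Map"


def is_issue_1166(exc: BaseException) -> bool:
    """Hypothetical check: does exc look like the KeyError reported here?"""
    return isinstance(exc, KeyError) and ISSUE_1166_MESSAGE in str(exc)


def check_compile(build: Callable[[], object]) -> str:
    """Run a zero-argument build callable and classify the outcome.

    Returns "ok" if the build succeeds and "issue-1166" if it fails with the
    'key is not in Map' KeyError. Any other exception is re-raised unchanged.
    """
    try:
        build()
    except KeyError as exc:
        if is_issue_1166(exc):
            return "issue-1166"
        raise
    return "ok"
```

With the repro above, this would be called as `check_compile(lambda: get_kernel(4096))`; on the affected version it would be expected to return `"issue-1166"`, and `"ok"` once compilation succeeds.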



Labels: bug (Something isn't working)
