Skip to content

Commit 056c6a3

Browse files
committed
test fix
1 parent 6ed0611 commit 056c6a3

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

examples/fusedmoe/example_fusedmoe_tilelang.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from tilelang.autotuner import *
88
from example_fusedmoe_torch import *
99

10-
# tilelang.disable_cache()
11-
1210

1311
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
1412
def moe_forward_tilelang_shared(d_hidden,

tilelang/language/proxy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The language interface for tl programs."""
22

33
from __future__ import annotations
4-
from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING, Tuple
4+
from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING, Tuple, Union
55
from typing_extensions import Self
66

77
from tvm import tir
@@ -150,7 +150,9 @@ def _construct_strides(shape: Tuple[Any]):
150150
strides.append(s)
151151
return tuple(reversed(strides))
152152

153-
def __call__(self, shape: Tuple[Any], dtype: str = "float32", data=None) -> tir.Buffer:
153+
def __call__(self, shape: Union[Tuple[Any], PrimExpr, int], dtype: str = "float32", data=None) -> tir.Buffer:
154+
if isinstance(shape, (int, PrimExpr)):
155+
shape = (shape,)
154156
return super().__call__(
155157
shape, dtype=dtype, strides=TensorProxy._construct_strides(shape), data=data)
156158

0 commit comments

Comments
 (0)