
Commit 0d7971d

[Enhancement] support composable expression for shape with symbolic vars (#624)
* [Refactor] Enhance expression handling in utils.py and update wrapper to use pythonic_expr
  - Added support for additional TIR expressions (FloorDiv, Min, Max, Add, Sub, FloorMod) in the pythonic_expr function to improve string representation.
  - Replaced the deprecated legalize_c function calls in TLCUDASourceWrapper and TLCPUSourceWrapper with pythonic_expr for better expression handling in kernel launch code.

* [Refactor] Simplify expression handling in pythonic_expr function
  - Consolidated binary and min/max operation handling in the pythonic_expr function to improve readability and maintainability.
  - Replaced individual checks for binary operations with a mapping approach, streamlining the code and improving expression representation.

* [Enhancement] Improve expression representation in pythonic_expr function
  - Added operator precedence handling to the pythonic_expr function, enhancing the conversion of TVM PrimExpr to Python-style strings.
  - Updated the visitor logic to add parentheses based on operator precedence, improving the accuracy of expression representation.
  - Included a docstring for better clarity on the function's purpose and usage.

* test fix

* minor update
1 parent f2203ae commit 0d7971d
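
A minimal sketch of the precedence-aware conversion described in the commit message. This is an illustration only, not the actual pythonic_expr implementation in tilelang: the name expr_to_py, the _BINOPS table, and the chosen precedence values are assumptions made for the example, and it covers only the operators named above (Add, Sub, FloorDiv, FloorMod, Min, Max) plus Mul.

from tvm import tir

# Illustrative precedence table; larger numbers bind tighter.
_BINOPS = {
    tir.Mul: ("*", 2),
    tir.FloorDiv: ("//", 2),
    tir.FloorMod: ("%", 2),
    tir.Add: ("+", 1),
    tir.Sub: ("-", 1),
}

def expr_to_py(expr, parent_prec=0):
    """Render a TIR PrimExpr as a Python-style string, adding parentheses
    only where a child operator binds less tightly than its parent."""
    if isinstance(expr, tir.IntImm):
        return str(expr.value)
    if isinstance(expr, tir.Var):
        return expr.name
    if isinstance(expr, (tir.Min, tir.Max)):
        fn = "min" if isinstance(expr, tir.Min) else "max"
        return f"{fn}({expr_to_py(expr.a)}, {expr_to_py(expr.b)})"
    for node_type, (symbol, prec) in _BINOPS.items():
        if isinstance(expr, node_type):
            lhs = expr_to_py(expr.a, prec)
            rhs = expr_to_py(expr.b, prec + 1)  # parenthesize the right side on ties
            text = f"{lhs} {symbol} {rhs}"
            return f"({text})" if prec < parent_prec else text
    return str(expr)  # fall back to TVM's own printer for unhandled nodes

m = tir.Var("m", "int32")
print(expr_to_py((m + 127) // 128))  # -> (m + 127) // 128
print(expr_to_py(m * 2 + 1))         # -> m * 2 + 1

A single table drives all binary operators, mirroring the mapping-based consolidation the message describes, and parentheses appear only where a child binds less tightly than its parent.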

File tree

2 files changed (+9, -3 lines)


tilelang/engine/param.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 from typing import List, Union, Optional
 import torch
 from tilelang import tvm as tvm
-from tvm.tir import Buffer, IntImm, Var
+from tvm.tir import Buffer, IntImm, Var, PrimExpr
 from tilelang.utils.tensor import map_torch_type
 
 
@@ -38,10 +38,10 @@ def from_buffer(cls, buffer: Buffer):
         for s in buffer.shape:
             if isinstance(s, IntImm):
                 shape.append(s.value)
-            elif isinstance(s, Var):
+            elif isinstance(s, (Var, PrimExpr)):
                 shape.append(s)
             else:
-                raise ValueError(f"Unsupported dimension type: {type(s)}")
+                raise ValueError(f"Unsupported dimension type: {type(s)} {s}")
         return cls(dtype, shape)
 
     @classmethod
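
To see why the broader isinstance check matters: a composed dimension such as m * 2 is a PrimExpr (a Mul node) but not a bare Var, so the previous Var-only branch raised ValueError for it. A minimal check, assuming a buffer declared directly with tvm.tir.decl_buffer rather than through tilelang itself:

from tvm import tir

m = tir.Var("m", "int32")
buf = tir.decl_buffer((m * 2, 128), "float16", name="A")

for s in buf.shape:
    # m * 2 is a Mul node: not a Var, but still a PrimExpr, so it is accepted after this change.
    # 128 is an IntImm and is handled by the first branch either way.
    print(type(s).__name__, isinstance(s, tir.Var), isinstance(s, tir.PrimExpr))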

tilelang/jit/adapter/cython/cython_wrapper.pyx

Lines changed: 6 additions & 0 deletions
@@ -147,6 +147,12 @@ cdef class CythonKernelWrapper:
                     else:  # Already converted to Python int during initialization
                         shape.append(s)
                 device = inputs[0].device if len(inputs) > 0 else torch.cuda.current_device()
+                if len(shape) == 0:
+                    param_name = self.params[i].name if hasattr(self.params[i], 'name') else f'parameter_{i}'
+                    raise ValueError(
+                        f"Cannot create output tensor (name={param_name}) - 0-dimensional tensors are not supported. "
+                        f"Expected shape: {shape}"
+                    )
                 tensor = torch.empty(*shape, dtype=dtype, device=device)
             else:
                 tensor = inputs[ins_idx]
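
For reference, a plain-Python sketch of the same guard. The real code is Cython inside CythonKernelWrapper; allocate_output and its parameters here are hypothetical, used only to make the snippet self-contained.

import torch

def allocate_output(shape, dtype, device, param, index):
    # If the shape list resolved to nothing, refuse to allocate rather than
    # attempt a 0-dimensional output tensor, and name the offending parameter.
    if len(shape) == 0:
        param_name = getattr(param, "name", f"parameter_{index}")
        raise ValueError(
            f"Cannot create output tensor (name={param_name}) - "
            f"0-dimensional tensors are not supported. Expected shape: {shape}")
    return torch.empty(*shape, dtype=dtype, device=device)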
