Skip to content

Commit 06dd4f7

Browse files
committed
[ONNX][FRONTEND] Update Resize to accept ShapeExpr
1 parent 60f5568 commit 06dd4f7

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,18 +2143,24 @@ def _impl_v18(cls, bb, inputs, attr, params):
21432143

21442144
# Convert scales to sizes if needed.
21452145
if scales is not None:
2146-
assert isinstance(scales, relax.Constant), "Only constant scales currently supported."
2147-
scales = scales.data.numpy()
2146+
if isinstance(scales, relax.Constant):
2147+
scales = scales.data.numpy()
2148+
elif isinstance(scales, relax.expr.ShapeExpr):
2149+
scales = [int(val.value) for val in scales.values]
2150+
else:
2151+
assert f"Type {type(size)} for scale is currently unsupported."
21482152
sizes = []
21492153

21502154
for i, dim in enumerate(x.struct_info.shape):
21512155
sizes.append(cast(scales[i] * dim, "int64"))
21522156
sizes = sizes[2:]
21532157
else:
2154-
assert isinstance(
2155-
sizes, relax.Constant
2156-
), "Only constant output size currently supported."
2157-
sizes = sizes.data.numpy().astype("int64").tolist()[2:]
2158+
if isinstance(sizes, relax.Constant):
2159+
sizes = sizes.data.numpy().astype("int64").tolist()[2:]
2160+
elif isinstance(sizes, relax.expr.ShapeExpr):
2161+
sizes = [int(val.value) for val in sizes.values][2:]
2162+
else:
2163+
assert f"Type {type(size)} for size is currently unsupported."
21582164

21592165
return relax.op.image.resize2d(
21602166
x,
@@ -3751,6 +3757,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
37513757
# convert it to a tensor.
37523758
shape_compatible_ops = [
37533759
"Reshape",
3760+
"Resize",
37543761
"ConstantOfShape",
37553762
"Gather",
37563763
"Slice",

0 commit comments

Comments
 (0)