Skip to content

Commit f5d3fc2

Browse files
authored
[Relax][Frontend][Onnx] Cast Op special handling for ShapeExpr input (#17061)
Co-authored-by: tsu-bin <[email protected]>
1 parent 1c05902 commit f5d3fc2

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,11 @@ class Cast(OnnxOpConverter):
442442
@classmethod
443443
def _impl_v13(cls, bb, inputs, attr, params):
444444
to_type = get_type(attr["to"])
445+
if isinstance(inputs[0], relax.ShapeExpr):
446+
shape = inputs[0]
447+
if all([isinstance(x, tir.IntImm) for x in shape]):
448+
shape = [int(x) for x in shape]
449+
return relax.const(shape, to_type)
445450
if isinstance(inputs[0], relax.Constant):
446451
output = inputs[0].data.numpy().astype(to_type)
447452
return relax.const(output, to_type)
@@ -2210,6 +2215,7 @@ def _construct_nodes(self, graph: onnx.onnx_ml_pb2.GraphProto):
22102215
"Concat",
22112216
"Equal",
22122217
"Where",
2218+
"Cast",
22132219
]
22142220
for i, inp in enumerate(inputs):
22152221
if (

0 commit comments

Comments
 (0)