Skip to content

Commit 9f6ce7c

Browse files
[relay][frontend][pytorch]Fix a bug in the _get_pytorch_value_type function (#14421)
* Fix a bug in the _get_pytorch_value_type function * Fix lint
1 parent 0e28541 commit 9f6ce7c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4493,7 +4493,7 @@ def _get_pytorch_value_type(typ, default_dtype="float32"):
44934493
return "ListType"
44944494
elif kind in ["IntType", "FloatType", "BoolType", "StringType", "OptionalType"]:
44954495
pt_dtype = str(typ).lower()
4496-
dtype = pt_dtype if pt_dtype == "OptionalType" else _convert_data_type(pt_dtype)
4496+
dtype = pt_dtype if kind == "OptionalType" else _convert_data_type(pt_dtype)
44974497
return dtype
44984498
else:
44994499
return "UnsupportedType"

0 commit comments

Comments
 (0)