Skip to content

Commit 196b413

Browse files
authored
[Relay][Frontend][Torch] fix a typo mistake in nonzero_numpy (#16390)
fix a typo mistake in pytorch frontend nonzero_numpy
1 parent 4258c86 commit 196b413

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2680,7 +2680,7 @@ def nonzero(self, inputs, input_types, is_numpy_style=False):
26802680
return ret
26812681

26822682
def nonzero_numpy(self, inputs, input_types):
2683-
return self.nonzero(inputs, input_types, is_numpy_style=False)
2683+
return self.nonzero(inputs, input_types, is_numpy_style=True)
26842684

26852685
def scatter(self, inputs, input_types):
26862686
assert len(inputs) == 4 or len(inputs) == 5, (

tests/python/frontend/pytorch/test_forward.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4445,6 +4445,7 @@ def forward(self, data):
44454445

44464446
inp = torch.Tensor(np.array([[0, 1, 0], [2, 0, 9], [-1, -1, 0]]).astype("float32"))
44474447
verify_trace_model(Nonzero(), [inp], ["llvm"])
4448+
verify_trace_model(Nonzero(as_tuple=True), [inp], ["llvm"])
44484449

44494450

44504451
def test_forward_scatter():

0 commit comments

Comments
 (0)