Skip to content

Commit 0c20c01

Browse files
committed
fix a typo mistake
1 parent 0b2358c commit 0c20c01

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2680,7 +2680,7 @@ def nonzero(self, inputs, input_types, is_numpy_style=False):
26802680
return ret
26812681

26822682
def nonzero_numpy(self, inputs, input_types):
2683-
return self.nonzero(inputs, input_types, is_numpy_style=False)
2683+
return self.nonzero(inputs, input_types, is_numpy_style=True)
26842684

26852685
def scatter(self, inputs, input_types):
26862686
assert len(inputs) == 4 or len(inputs) == 5, (
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
18+
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
19+
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
20+
# pylint: disable=missing-function-docstring, redefined-builtin, use-implicit-booleaness-not-comparison
21+
"""Tests to ensure nonzero_numpy are correctly"""
22+
from torch import nn
23+
import torch
24+
import tvm
25+
26+
27+
class NonZeroModule(nn.Module):
28+
"""Module that performs nonzero"""
29+
30+
def __init__(self):
31+
super().__init__()
32+
33+
def forward(self, x, mask):
34+
mask_index = torch.nonzero(mask, as_tuple=True)
35+
x[mask_index] = torch.ones_like(x[mask_index])
36+
return x
37+
38+
def test_pytorch_nonzero():
39+
model = NonZeroModule()
40+
x = torch.zeros((2, 10), dtype=torch.float32)
41+
mask = torch.randint(0, 2, (2, 10)).bool()
42+
with torch.no_grad():
43+
traced_torch_model = torch.jit.trace(model, (x, mask))
44+
import_input = [("input0", (2, 10)), ("input1", (2, 10))]
45+
relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
46+
traced_torch_model, import_input
47+
)
48+
49+

0 commit comments

Comments
 (0)