Skip to content

Commit 7aecc1a

Browse files
[Torch] Fix advanced indexing with NoneType index arguments (#13826)
[Torch] Fix advanced indexing with NoneType index
1 parent 9008ec2 commit 7aecc1a

File tree

2 files changed

+70
-5
lines changed

2 files changed

+70
-5
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2330,19 +2330,49 @@ def one_hot(self, inputs, input_types):
23302330

23312331
def index(self, inputs, input_types):
    """Convert PyTorch advanced indexing (aten::index) to Relay.

    inputs[0] is the data tensor; inputs[1] is the per-axis list of index
    tensors, with None entries for axes that are not advanced-indexed
    (e.g. ``x[None, [0, 1], :, [2, 3]]`` after the frontend has stripped
    the newaxis dims).  Returns the Relay expression for the result.
    """
    data = inputs[0]
    data_shape = self.infer_type(data).shape

    # Axes that carry a real (non-None) index tensor, and all the others.
    axes_adv_idx = [i for i, v in enumerate(inputs[1]) if v is not None]
    axes_rest = [i for i in range(len(data_shape)) if i not in axes_adv_idx]

    # No index tensors at all (every entry was None): adv_index would be
    # called with no indices and axes_adv_idx[0] below would raise
    # IndexError.  Indexing with only Nones leaves the data unchanged.
    if not axes_adv_idx:
        return data

    # Check whether the advanced-index axes are consecutive.  Following
    # NumPy/PyTorch semantics, when they are consecutive the dims produced
    # by the index tensors stay in place, so the result must be transposed
    # back at the end; when they are not, those dims move to the front.
    consecutive = all(
        nxt - curr == 1 for curr, nxt in zip(axes_adv_idx[:-1], axes_adv_idx[1:])
    )

    indices_list = []
    # Move the advanced-indexed axes to the front so adv_index sees the
    # index tensors aligned with the leading dims.
    axes_order = axes_adv_idx + axes_rest

    for i in axes_adv_idx:
        inp = inputs[1][i]
        if self.infer_type(inp).dtype == "bool":
            # adv_index does not support a mask as the index tensor (it will
            # treat 0/1 as an index rather than a flag).
            # So we use argwhere to turn the mask into indices, which will
            # also take care of the dynamism in the indexing by mask.
            indices_list.append(_op.squeeze(_op.transform.argwhere(inp), axis=[1]))
        else:
            indices_list.append(inp)

    data_after_adv_index = _op.adv_index(
        [_op.transpose(data, axes=axes_order)] + indices_list
    )

    if consecutive:
        num_dims = len(self.infer_type(data_after_adv_index).shape)
        # Dims contributed by the (broadcast) index tensors, currently at
        # the front of the result.
        num_new_dims = num_dims - len(axes_rest)

        # Move the new dims back to the position of the first advanced axis.
        axes_final_order = list(range(num_dims))
        axes_final_order = (
            axes_final_order[num_new_dims : num_new_dims + axes_adv_idx[0]]
            + axes_final_order[:num_new_dims]
            + axes_final_order[num_new_dims + axes_adv_idx[0] :]
        )

        return _op.transpose(data_after_adv_index, axes=axes_final_order)
    else:
        return data_after_adv_index
23462376

23472377
def meshgrid(self, inputs, input_types):
23482378
data = inputs[0]

tests/python/frontend/pytorch/test_forward.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4034,6 +4034,41 @@ def forward(self, x):
40344034
input_data = torch.rand(input_shape).float()
40354035
verify_model(Index1().eval(), input_data=input_data)
40364036

4037+
class Index2(Module):
    # None (newaxis) followed by a repeated advanced index list.
    def forward(self, x):
        return x[None, [2, 2]]

input_data = torch.rand(input_shape).float()
verify_model(Index2().eval(), input_data=input_data)

class Index3(Module):
    # Two advanced index lists separated by a scalar index -> the
    # advanced axes are NOT consecutive.
    def forward(self, x):
        return x[None, [0, 1, 2], 1, [2, 3, 4]]

input_data = torch.rand(input_shape).float()
verify_model(Index3().eval(), input_data=input_data)

class Index4(Module):
    # Nones interleaved between a list index and a broadcastable
    # ndarray index.
    def forward(self, x):
        return x[None, [0, 0], None, np.array([[0], [1], [2]]), None]

input_data = torch.rand(input_shape).float()
verify_model(Index4().eval(), input_data=input_data)

class Index5(Module):
    # Consecutive leading Nones before the advanced indices.
    def forward(self, x):
        return x[None, None, [0, 0], np.array([[0], [1], [2]]), None]

input_data = torch.rand(input_shape).float()
verify_model(Index5().eval(), input_data=input_data)

class Index6(Module):
    # Scalar index and None between newaxis and the advanced index.
    def forward(self, x):
        return x[None, 1, None, [1, 2, 3]]

input_data = torch.rand(input_shape).float()
verify_model(Index6().eval(), input_data=input_data)
4071+
40374072
def test_fn_bool_mask():
40384073
return lambda data, mask: data[0, mask]
40394074

0 commit comments

Comments
 (0)