Skip to content

Commit 28d3eda

Browse files
modified conditions for parameters
1 parent 431722f commit 28d3eda

File tree

5 files changed

+15
-16
lines changed

5 files changed

+15
-16
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,21 +1148,21 @@ def _gather(self, node: fx.Node) -> relax.Var:
11481148
index = self.env[node.args[2]]
11491149
return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))
11501150

1151-
def _index_put_(self, node: fx.Node) -> relax.Var:
1151+
def _index_put(self, node: fx.Node) -> relax.Var:
11521152
args = self.retrieve_args(node)
11531153
tensor = args[0]
1154-
indices = args[1] if len(args) > 1 else node.kwargs.get("indices", ())
1154+
indices = args[1] if len(args) > 1 else node.kwargs.get("indices")
11551155
values = args[2] if len(args) > 2 else node.kwargs.get("values")
11561156
accumulate = args[3] if len(args) > 3 else node.kwargs.get("accumulate", False)
11571157

1158-
# Ensure accumulate is a boolean
1159-
if isinstance(accumulate, str):
1160-
accumulate = accumulate.lower() == "true"
1161-
elif not isinstance(accumulate, bool):
1162-
accumulate = bool(accumulate)
1158+
if indices is None or values is None:
1159+
raise ValueError("'indices and values' arguments are required for index_put operation")
1160+
1161+
if not isinstance(accumulate, bool):
1162+
raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate)))
11631163

11641164
if isinstance(indices, (list, tuple)):
1165-
indices = relax.Tuple(indices) if indices else relax.Tuple([])
1165+
indices = relax.Tuple(indices)
11661166
return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
11671167

11681168
def _permute(self, node: fx.Node) -> relax.Var:

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def create_convert_map(
420420
"flatten.using_ints": self._flatten,
421421
"flip.default": self._flip,
422422
"gather.default": self._gather,
423-
"index_put_.default": self._index_put_,
423+
"index_put_.default": self._index_put,
424424
"narrow.default": self._narrow,
425425
"permute.default": self._permute,
426426
"repeat.default": self._repeat,

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ def create_convert_map(
746746
"flatten": self._flatten,
747747
"flip": self._flip,
748748
"gather": self._gather,
749-
"index_put_": self._index_put_,
749+
"index_put_": self._index_put,
750750
"narrow": self._narrow,
751751
"numel": self._numel,
752752
"permute": self._permute,

python/tvm/relax/op/manipulate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ def index_put(
588588
]
589589
"""
590590
if not isinstance(indices, (list, tuple)):
591-
indices = RxTuple(indices) if indices else RxTuple([])
591+
indices = RxTuple(indices)
592592
return _ffi_api.index_put(data, indices, values, accumulate) # type: ignore
593593

594594

python/tvm/topi/index_put.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from tvm import te
1919
from tvm import tir
2020
from . import utils
21+
import math
2122

2223

2324
def index_put(data, indices, values, accumulate=False):
@@ -56,14 +57,12 @@ def index_put(data, indices, values, accumulate=False):
5657

5758
# Prepare ranges and strides
5859
shape = data.shape
59-
full_range = 1
60-
for dim in shape:
61-
full_range *= dim
60+
full_range = math.prod(data.shape)
6261

6362
# Check all indices have same length
64-
index_len = indices[0].shape[0]
63+
index_len = len(indices[0])
6564
for idx in indices[1:]:
66-
if not utils.equal_const_int(idx.shape[0], index_len):
65+
if not utils.equal_const_int(len(idx), index_len):
6766
raise ValueError("All index tensors must have same length")
6867

6968
def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):

0 commit comments

Comments
 (0)