Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
}
}; // struct GatherNDAttrs

/*! \brief Attributes used in index_put operator */
struct IndexPutAttrs : public tvm::AttrsNode<IndexPutAttrs> {
bool accumulate;

TVM_DECLARE_ATTRS(IndexPutAttrs, "relax.attrs.IndexPutAttrs") {
TVM_ATTR_FIELD(accumulate)
.set_default(false)
.describe(
"Whether to accumulate (add) values rather than replace. "
"If true, performs tensor[indices] += values, "
"otherwise performs tensor[indices] = values.");
}
}; // struct IndexPutAttrs

/*! \brief Attributes used in scatter_elements operators */
struct ScatterElementsAttrs : public tvm::AttrsNode<ScatterElementsAttrs> {
Integer axis;
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,6 +1118,23 @@ def _gather(self, node: fx.Node) -> relax.Var:
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))

def _index_put(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
tensor = args[0]
indices = args[1] if len(args) > 1 else node.kwargs.get("indices")
values = args[2] if len(args) > 2 else node.kwargs.get("values")
accumulate = args[3] if len(args) > 3 else node.kwargs.get("accumulate", False)

if indices is None or values is None:
raise ValueError("'indices and values' arguments are required for index_put operation")

if not isinstance(accumulate, bool):
raise TypeError("'accumulate' must be a boolean value, got {}".format(type(accumulate)))

if isinstance(indices, (list, tuple)):
indices = relax.Tuple(indices)
return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))

def _index_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
indices = args[1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ def create_convert_map(
"flip.default": self._flip,
"gather.default": self._gather,
"index.Tensor": self._index_tensor,
"index_put_.default": self._index_put,
"narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,7 @@ def create_convert_map(
"flatten": self._flatten,
"flip": self._flip,
"gather": self._gather,
"index_put_": self._index_put,
"narrow": self._narrow,
"numel": self._numel,
"permute": self._permute,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
flip,
gather_elements,
gather_nd,
index_put,
index_tensor,
layout_transform,
one_hot,
Expand Down
51 changes: 51 additions & 0 deletions python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,57 @@ def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
return _ffi_api.index_tensor(data, indices) # type: ignore


def index_put(
data: Expr,
indices: Union[Expr, Tuple[Expr]],
values: Expr,
accumulate: bool = False,
) -> Expr:
"""This operation updates values in `data` at positions
specified by `indices` with corresponding values from `values`. The `indices` is a tuple
of tensors where each tensor corresponds to a dimension in `data`.
When `accumulate` is True, the operation performs accumulation (addition) rather than
replacement. The `reduction` parameter allows specifying different reduction operations.
Parameters
----------
data : relax.Expr
The input tensor to be modified
indices : Union[Expr, Tuple[Expr]]
Tuple of index tensors (one for each dimension) specifying positions to update
values : relax.Expr
Values to place at the specified indices
accumulate : bool
Whether to accumulate (add) values rather than replace (default: False)

Returns
-------
result : relax.Expr
A new tensor with the same shape as data but with specified positions updated
Examples
--------
.. code-block:: python
# inputs
data = torch.zeros(3, 3)
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
values = torch.tensor([1.0, 2.0])
# output
output = [
[0.0, 1.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 2.0, 0.0],
]
# with accumulate=True
output = [
[0.0, 1.0, 0.0],
[0.0, 0.0, 0.0],
[0.0, 3.0, 0.0],
]
"""
if not isinstance(indices, (list, tuple)):
indices = RxTuple(indices)
return _ffi_api.index_put(data, indices, values, accumulate) # type: ignore


def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
):
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ class StackAttrs(Attrs):
"""Attributes for concat operator"""


@tvm._ffi.register_object("relax.attrs.IndexPutAttrs")
class IndexPutAttrs(Attrs):
"""Attributes for index_put operator"""


@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
class LayoutTransformAttrs(Attrs):
"""Attributes used in layout_transform operator"""
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,28 @@ def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.index_tensor, call.args[0], fields)


@register_legalize("relax.index_put")
def _index_put(bb: BlockBuilder, call: Call) -> Expr:
data = call.args[0]
indices = call.args[1]
values = call.args[2]
accumulate = call.attrs.accumulate

# If indices is a Tuple, unpack it into individual tensors
if isinstance(indices, relax.Tuple):
indices_list = [indices.fields[i] for i in range(len(indices.fields))]
else:
indices_list = [indices]

return bb.call_te(
topi.index_put,
data,
indices_list,
values,
accumulate=accumulate,
)


@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
greater,
greater_equal,
hint_on_device,
index_put,
image,
index_tensor,
invoke_closure,
Expand Down Expand Up @@ -785,6 +786,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"greater_equal",
"hexagon",
"hint_on_device",
"index_put",
"image",
"index_tensor",
"invoke_closure",
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .math import *
from .tensor import *
from .generic_op_impl import *
from .index_put import *
from .reduction import *
from .transform import *
from .broadcast import *
Expand Down
117 changes: 117 additions & 0 deletions python/tvm/topi/index_put.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contrir_builderutor license agreements. See the NOTICE file
# distrir_builderuted with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distrir_builderuted under the License is distrir_builderuted on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""IndexPut operator"""
from tvm import te
from tvm import tir
from . import utils


def index_put(data, indices, values, accumulate=False):
"""Put values into an array according to indices.

Parameters
----------
data : tvm.te.Tensor
The source array to be modified.

indices : Tuple[tvm.te.Tensor]
Tuple of 1D index tensors (one for each dimension) specifying positions.

values : tvm.te.Tensor
The values to place at the specified indices.

accumulate : bool, optional
Whether to accumulate (add) values rather than replace.
If True, performs tensor[indices] += values
If False, performs tensor[indices] = values
Default is False.

Returns
-------
ret : tvm.te.Tensor
"""
if not isinstance(indices, (list, tuple)):
indices = [indices]

# Check indices match data dimensions
if len(indices) != len(data.shape):
raise ValueError(
f"Number of index tensors ({len(indices)}) must match "
f"data dimensions ({len(data.shape)})"
)

# Prepare ranges and strides
shape = data.shape
full_range = 1
for dim in shape:
full_range *= dim

# Check all indices have same length
index_len = len(indices[0])
for idx in indices[1:]:
if not utils.equal_const_int(len(idx), index_len):
raise ValueError("All index tensors must have same length")

def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
ir_builder = tir.ir_builder.create()

data = ir_builder.buffer_ptr(data_ptr)
indices = [ir_builder.buffer_ptr(idx) for idx in index_ptrs]
values = ir_builder.buffer_ptr(values_ptr)
out = ir_builder.buffer_ptr(out_ptr)

with ir_builder.for_range(0, full_range, "i", kind="parallel") as i:
out[i] = data[i]

with ir_builder.for_range(0, index_len, "k", kind="parallel") as k:
# Calculate multi-dimensional index
flat_index = 0
stride = 1
for dim in range(len(shape) - 1, -1, -1):
# Get index and shift to positive if needed
idx_val = indices[dim][k]
shifted_idx = idx_val + (idx_val < 0) * shape[dim]
flat_index += shifted_idx * stride
stride *= shape[dim]

reduce_func(out, flat_index, values[k])

return ir_builder.get()

def update_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] = update

def add_func(dst_ptr, dst_index, update):
dst_ptr[dst_index] += update

reduce_func = add_func if accumulate else update_func

# Prepare input buffers
in_buffers = [data]
in_buffers.extend(indices)
in_buffers.append(values)

out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf")
return te.extern(
[data.shape],
in_buffers,
lambda ins, outs: gen_ir(ins[0], ins[1:-1], ins[-1], outs[0], reduce_func),
dtype=data.dtype,
out_buffers=[out_buf],
name="index_put.generic",
tag="index_put.generic",
)
Loading
Loading