Merged

124 commits
0c8c904
suddenly copy.default is unsupported
hugolatendresse Mar 16, 2025
3ae2b79
Merge branch 'main' into split
hugolatendresse Mar 16, 2025
f7f0637
wip
hugolatendresse Mar 16, 2025
7e4cf05
Able to split uneven tensors!
hugolatendresse Mar 16, 2025
dcbee0c
split size test passes!
hugolatendresse Mar 16, 2025
75890ce
test sizes and lists
hugolatendresse Mar 16, 2025
5a8eab1
just one func
hugolatendresse Mar 16, 2025
c771389
cleanup
hugolatendresse Mar 16, 2025
2fbe4c1
no assert
hugolatendresse Mar 16, 2025
e5095b8
linting
hugolatendresse Mar 16, 2025
490a454
chunk
hugolatendresse Mar 16, 2025
1ef561e
Merge branch 'main' of https://github.com/apache/tvm into split_uneven
hugolatendresse Mar 20, 2025
ec7311c
remove unused modulo
hugolatendresse Mar 20, 2025
0744701
fixed first test
hugolatendresse Mar 20, 2025
40f1711
fixed second test and lint
hugolatendresse Mar 20, 2025
ca2bf9a
Merge branch 'main' into chunk2
hugolatendresse Mar 24, 2025
e20a43f
Merge branch 'split_uneven' into chunk2
hugolatendresse Mar 24, 2025
7bdc0bf
merge main
hugolatendresse Mar 24, 2025
4582c3a
linting
hugolatendresse Mar 24, 2025
afa793a
fix one test
hugolatendresse Mar 24, 2025
d71b518
chunk not passing anymore
hugolatendresse Mar 24, 2025
e95aef6
get_item error
hugolatendresse Mar 24, 2025
bc50446
chunk unit tests
hugolatendresse Mar 24, 2025
9951924
fix conflicts
hugolatendresse Mar 26, 2025
00f04b1
Merge branch 'chunk2' into index_tensor
hugolatendresse Mar 27, 2025
db5ec01
index select test passes
hugolatendresse Mar 28, 2025
859ca17
merge main
hugolatendresse Mar 28, 2025
c39e6e1
fix test
hugolatendresse Mar 28, 2025
f8d50f2
cleanup
hugolatendresse Mar 28, 2025
086410c
index_select
hugolatendresse Mar 29, 2025
97ada56
arange.default ok
hugolatendresse Mar 30, 2025
3431e18
all arange tests pass
hugolatendresse Mar 30, 2025
5c0aa4f
arange test complete
hugolatendresse Mar 30, 2025
343fb58
merge arange
hugolatendresse Mar 30, 2025
fc9155e
merge index select
hugolatendresse Mar 30, 2025
353c399
dummy tensor.Index compiles - ready to test running
hugolatendresse Mar 31, 2025
3ed1f1b
type error in dummy test
hugolatendresse Mar 31, 2025
1d1606c
failing a check
hugolatendresse Mar 31, 2025
3ce339c
codegen error
hugolatendresse Mar 31, 2025
a8c7185
code gen error
hugolatendresse Mar 31, 2025
eeefce6
explode broadcast_shapes
hugolatendresse Mar 31, 2025
330f165
combine pending PRs
hugolatendresse Mar 31, 2025
c4f5e7e
Merge branch 'pending_PRs' into index.Tensor
hugolatendresse Mar 31, 2025
3b8a4a0
debugging
hugolatendresse Mar 31, 2025
c12968e
merge main
hugolatendresse Apr 6, 2025
4b990e4
first week of work, index.Tensor branch
hugolatendresse Apr 6, 2025
77dbc3b
testing
hugolatendresse Apr 7, 2025
557885b
able to get an output with dummy relax.op.collapse_sum_like
hugolatendresse Apr 13, 2025
6b664cc
whether I can use collapse sum like TWO depends on whether I legalize…
hugolatendresse Apr 13, 2025
0adab38
Merge remote-tracking branch 'origin/main' into index.Tensor
hugolatendresse Apr 13, 2025
a411662
able to call relax.collapse_sum_like_TWO
hugolatendresse Apr 13, 2025
c418fc5
_TWO works with regular registration
hugolatendresse Apr 13, 2025
c497eaf
both topi options work
hugolatendresse Apr 13, 2025
258467d
merge tensor1 and tensor3 together
hugolatendresse Apr 13, 2025
7438ce5
can still output from _TWO after merge index.Tensor and index.Tensor3
hugolatendresse Apr 13, 2025
9d6270b
still able to output after building (hadn't built after merge)
hugolatendresse Apr 13, 2025
29fbc01
must do a topi op in transform.py to get an output
hugolatendresse Apr 13, 2025
ff88c5a
error re. IndexTensorAttrs
hugolatendresse Apr 13, 2025
4df949a
able to get an output from index.tensor!
hugolatendresse Apr 13, 2025
d9f2589
need to isolate what goes wrong in building
hugolatendresse Apr 13, 2025
e6ae241
ok
hugolatendresse Apr 13, 2025
8db2a09
ok
hugolatendresse Apr 13, 2025
301bc40
ok
hugolatendresse Apr 13, 2025
57d01ff
ok
hugolatendresse Apr 13, 2025
d04a814
ok
hugolatendresse Apr 13, 2025
098c541
calculate 3 results. result2 and result3 have same type but only resu…
hugolatendresse Apr 13, 2025
82775dc
gets correctness with topi.take
hugolatendresse Apr 13, 2025
82e9edd
passing index1D and index2D tests!
hugolatendresse Apr 13, 2025
3df31fc
first 3 tests pass
hugolatendresse Apr 13, 2025
142a16d
other test
hugolatendresse Apr 14, 2025
bfb2ea9
all tests written. 0 to 5 pass, 6 to 8 fail
hugolatendresse Apr 14, 2025
03d1124
added full
hugolatendresse Apr 14, 2025
e13fa3d
unit test
hugolatendresse Apr 14, 2025
5b23c30
full.default
hugolatendresse Apr 14, 2025
35aee29
linting
hugolatendresse Apr 14, 2025
5c0e18b
ones ok
hugolatendresse Apr 14, 2025
40316a0
tests for ones, full, and full like work
hugolatendresse Apr 14, 2025
2bb3f15
merge full.default
hugolatendresse Apr 14, 2025
ac33a59
before switching to list[Expr]
hugolatendresse Apr 14, 2025
eddcd39
able to get list of tensors in topi
hugolatendresse Apr 14, 2025
6f62e0d
unable to reproduce results for second case
hugolatendresse Apr 14, 2025
65c6ba2
not working
hugolatendresse Apr 14, 2025
85737ec
dummy concat works
hugolatendresse Apr 14, 2025
64f9297
concat2 doesn't work either
hugolatendresse Apr 14, 2025
aa4cfd8
original concat works
hugolatendresse Apr 14, 2025
1eba32b
concat2 works as a perfect copy of concat
hugolatendresse Apr 14, 2025
4e096d1
concat2 works! now need to strip as much as possible, and then conver…
hugolatendresse Apr 14, 2025
8adbc61
concat2 still passes
hugolatendresse Apr 14, 2025
1cfa024
concat2 still passes
hugolatendresse Apr 14, 2025
f0d533b
still works when grabbing first tensor
hugolatendresse Apr 14, 2025
d22b598
still works when grabbing first tensor
hugolatendresse Apr 14, 2025
636ab70
still works when grabbing first tensor
hugolatendresse Apr 14, 2025
cd0f013
adding back struct info makes test pass
hugolatendresse Apr 14, 2025
a06db3d
concat2 passes with very simple manipulate.cc!
hugolatendresse Apr 14, 2025
aa08975
I pass old test2!
hugolatendresse Apr 14, 2025
f23b786
7 tests pass!
hugolatendresse Apr 14, 2025
583f1e5
all tests pass
hugolatendresse Apr 14, 2025
4ce35b9
all tests still pass
hugolatendresse Apr 14, 2025
6cf94b6
passes every single test
hugolatendresse Apr 14, 2025
e5e9688
won't compile anymore
hugolatendresse Apr 14, 2025
e834661
all tests pass
hugolatendresse Apr 14, 2025
2f87583
minimum manipulate.cc -> all tests pass
hugolatendresse Apr 14, 2025
c2de1e2
cleanup
hugolatendresse Apr 14, 2025
7dd9e8f
removed axis, all tests pass, doc TODOs remain
hugolatendresse Apr 14, 2025
a003829
merge main
hugolatendresse Apr 14, 2025
2f2505f
resolve conflict
hugolatendresse Apr 14, 2025
41bf141
linting
hugolatendresse Apr 14, 2025
4aac8e6
Merge branch 'main' of https://github.com/apache/tvm into indextensor…
hugolatendresse Apr 16, 2025
01b23f0
pass correctness with first indices shape
hugolatendresse Apr 16, 2025
982719f
correctness passes with indices shape
hugolatendresse Apr 16, 2025
b124846
correctness checks pass!
hugolatendresse Apr 17, 2025
f1737b2
cleanup - still passes correctness
hugolatendresse Apr 17, 2025
47adecf
comments
hugolatendresse Apr 17, 2025
8626f59
all pass
hugolatendresse Apr 17, 2025
93c0ac1
combine into one test. all pass
hugolatendresse Apr 17, 2025
76a3e21
blank line
hugolatendresse Apr 17, 2025
695edf1
docs indentation
hugolatendresse Apr 17, 2025
e505fec
lint
hugolatendresse Apr 17, 2025
7887502
dummy whitespace change to trigger tests
hugolatendresse Apr 17, 2025
9f260f2
whitespace
hugolatendresse Apr 17, 2025
d70fba1
whitespace
hugolatendresse Apr 17, 2025
d620c4f
dummy whitespace change to trigger tests
hugolatendresse Apr 17, 2025
3e28c66
Merge branch 'main' into indextensor_1_and_3
hugolatendresse Apr 17, 2025
46841a3
no backtracking
hugolatendresse Apr 19, 2025
5 changes: 5 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1114,6 +1114,11 @@ def _gather(self, node: fx.Node) -> relax.Var:
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.gather_elements(x, index, axis=dim))

def _index_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
indices = args[1]
return self.block_builder.emit(relax.op.index_tensor(args[0], indices))

def _permute(self, node: fx.Node) -> relax.Var:
import torch # type: ignore

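For context, here is a minimal sketch (not part of this diff) of how the new mapping is exercised from PyTorch; it assumes the exported-program frontend entry point tvm.relax.frontend.torch.from_exported_program:

# Hedged sketch: PyTorch advanced indexing is traced as the ATen op
# "index.Tensor", which the translator above routes to relax.op.index_tensor.
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class AdvIndex(torch.nn.Module):
    def forward(self, x, row, col):
        return x[row, col]  # lowers to aten::index.Tensor

example_args = (
    torch.arange(9.0).reshape(3, 3),
    torch.tensor([0, 2]),
    torch.tensor([1, 0]),
)
mod = from_exported_program(export(AdvIndex(), example_args))
mod.show()  # the printed IRModule contains a call to R.index_tensor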
@@ -411,6 +411,7 @@ def create_convert_map(
"flatten.using_ints": self._flatten,
"flip.default": self._flip,
"gather.default": self._gather,
"index.Tensor": self._index_tensor,
"narrow.default": self._narrow,
"permute.default": self._permute,
"repeat.default": self._repeat,
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
@@ -95,6 +95,7 @@
flip,
gather_elements,
gather_nd,
index_tensor,
layout_transform,
one_hot,
permute_dims,
63 changes: 63 additions & 0 deletions python/tvm/relax/op/manipulate.py
@@ -508,6 +508,69 @@ def gather_nd(data: Expr, indices: Expr, batch_dims: int = 0) -> Expr:
return _ffi_api.gather_nd(data, indices, batch_dims) # type: ignore


def index_tensor(data: Expr, indices: Union[Expr, List[Expr]]) -> Expr:
"""Advanced‑tensor indexing (NumPy/PyTorch‐style).

Given k index tensors ``indices = (I0, I1, …, Ik‑1)`` this
operator selects elements from ``data`` as if one had written
``data[I0, I1, …, Ik‑1]`` in NumPy/PyTorch:

All index tensors must have an integer dtype.

Their shapes are broadcast together to a common shape ``B`` in
the usual NumPy way.

The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
shape followed by the remaining axes of ``data`` that are *not*
indexed).

At compile‑time Relax checks that the number of index tensors
``k`` does not exceed ``data.ndim``, that the dtypes are integer,
and that the shapes are consitent (broadcast‑compatible).

Parameters
----------
data : relax.Expr
The input tensor to be indexed.

indices : Union[relax.Expr, List[relax.Expr]]
A Tuple expression containing the index tensors,
or a Python ``list`` / ``tuple`` that will be promoted to a
tuple expression automatically. Each tensor must have an
integer dtype.

Returns
-------
result : relax.Expr
The tensor obtained after advanced indexing. Its dtype equals
``data.dtype``.

Examples
--------
.. code-block:: python

import numpy as np
import tvm.relax as R

x = R.const(np.arange(9).reshape(3, 3).astype("float32"))
row = R.const(np.array([0, 2])) # shape (2,)
col = R.const(np.array([1, 0])) # shape (2,)

y = R.index_tensor(x, [row, col])
# y.shape == (2,) ; y == [1., 6.]

# Broadcasting: row : (2,1), col : (1,3) → B = (2,3)
row = R.const(np.array([[0],[1]]))
col = R.const(np.array([[0,1,2]]))
z = R.index_tensor(x, [row, col])
# z.shape == (2,3)

"""
if isinstance(indices, (list, tuple)):
indices = RxTuple(indices)
return _ffi_api.index_tensor(data, indices) # type: ignore


def scatter_elements(
data: Expr, indices: Expr, updates: Expr, axis: int = 0, reduction: str = "update"
):
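To make the compile-time checks described in the docstring concrete, here is a hedged BlockBuilder sketch (not from the PR) showing the inferred struct info for the broadcasting case:

# Hedged sketch: broadcasting (2,1) against (1,3) gives B = (2,3); both
# axes of the 2-D input are indexed, so no trailing axes are appended.
import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
row = relax.Var("row", relax.TensorStructInfo((2, 1), "int64"))
col = relax.Var("col", relax.TensorStructInfo((1, 3), "int64"))
with bb.function("main", [x, row, col]):
    y = bb.emit(relax.op.index_tensor(x, [row, col]))
    bb.emit_func_output(y)
print(bb.get()["main"])  # y is inferred as R.Tensor((2, 3), dtype="float32")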
9 changes: 9 additions & 0 deletions python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -49,6 +49,7 @@ def reshape_call_te(bb: BlockBuilder, call: Call):
"relax.collapse_sum_like",
_reshape(topi.collapse_sum, "collapse_sum", is_collapse_sum_like=True),
)

register_legalize("relax.collapse_sum_to", _reshape(topi.collapse_sum, "collapse_sum"))


@@ -162,6 +163,14 @@ def te_gather_nd(data, indices, batch_dims):
return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))


@register_legalize("relax.index_tensor")
def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
t = call.args[1]
n_field = len(t.struct_info.fields)
fields = [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
return bb.call_te(topi.index_tensor, call.args[0], fields)


@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
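Continuing the sketch above, the legalization registered here lowers the Relax op into a TIR PrimFunc generated from the TOPI compute (hedged, not from the PR):

# Hedged sketch: LegalizeOps replaces the relax.index_tensor call with a
# call_tir into a PrimFunc built from topi.index_tensor (topi.adv_index).
mod = bb.get()
lowered = tvm.relax.transform.LegalizeOps()(mod)
lowered.show()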
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -101,6 +101,7 @@
greater_equal,
hint_on_device,
image,
index_tensor,
invoke_closure,
invoke_pure_closure,
isfinite,
@@ -784,6 +785,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"hexagon",
"hint_on_device",
"image",
"index_tensor",
"invoke_closure",
"invoke_pure_closure",
"isfinite",
51 changes: 51 additions & 0 deletions python/tvm/topi/transform.py
@@ -1052,3 +1052,54 @@ def _apply_trilu(*indices):
return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype))

return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE)


def index_tensor(data, indices):
"""Advanced‑tensor indexing (NumPy/PyTorch‐style).

Given k index tensors ``indices = (I0, I1, …, Ik‑1)`` this
operator selects elements from ``data`` as if one had written
``data[I0, I1, …, Ik‑1]`` in NumPy/PyTorch:

* All index tensors must have an integer dtype.
* Their shapes are broadcast together to a common shape ``B`` in
the usual NumPy way.
* The result shape is ``B + data.shape[k:]`` (i.e. the broadcast
shape followed by the remaining axes of ``data`` that are *not*
indexed).
* ``k`` must not exceed ``data.ndim``; otherwise a compile‑time
error is raised.

Parameters
----------
data : tvm.te.Tensor
The tensor to be indexed.

indices : Sequence[tvm.te.Tensor]
A Python ``list`` / ``tuple`` of **k** index tensors, each with an
integer dtype.

Returns
-------
result : tvm.te.Tensor
The tensor obtained after advanced indexing. Its dtype equals
``data.dtype``.

Examples
--------
.. code-block:: python

x = te.placeholder((3, 3), name="x") # shape (3,3)
row = te.placeholder((2,), name="row", dtype="int32")
col = te.placeholder((2,), name="col", dtype="int32")

# Equivalent to x[row, col] in NumPy / PyTorch
y = topi.index_tensor(x, [row, col]) # shape (2,)

# Broadcasting example:
row = te.placeholder((2, 1), name="row", dtype="int32")
col = te.placeholder((1, 3), name="col", dtype="int32")
z = topi.index_tensor(x, [row, col]) # shape (2, 3)
"""
return topi.adv_index(data, indices)
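A hedged end-to-end check of the TOPI compute against NumPy advanced indexing (a sketch, assuming topi.index_tensor simply forwards to topi.adv_index as above):

# Hedged sketch: build the TE compute and compare with NumPy.
import numpy as np
import tvm
from tvm import te, topi

x = te.placeholder((3, 3), name="x", dtype="float32")
row = te.placeholder((2, 1), name="row", dtype="int32")
col = te.placeholder((1, 3), name="col", dtype="int32")
y = topi.index_tensor(x, [row, col])  # broadcast shape (2, 3)

lib = tvm.build(te.create_prim_func([x, row, col, y]), target="llvm")

x_np = np.arange(9, dtype="float32").reshape(3, 3)
row_np = np.array([[0], [1]], dtype="int32")
col_np = np.array([[0, 1, 2]], dtype="int32")
out = tvm.nd.empty((2, 3), "float32")
lib(tvm.nd.array(x_np), tvm.nd.array(row_np), tvm.nd.array(col_np), out)
np.testing.assert_allclose(out.numpy(), x_np[row_np, col_np])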
145 changes: 145 additions & 0 deletions src/relax/op/tensor/manipulate.cc
@@ -474,6 +474,151 @@ TVM_REGISTER_OP("relax.flatten")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.index_tensor */

Expr index_tensor(Expr first, Expr tensors) {
static const Op& op = Op::Get("relax.index_tensor");
return Call(op, {std::move(first), std::move(tensors)}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relax.op.index_tensor").set_body_typed(index_tensor);

StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) {
if (call->args.size() != 2) {
ctx->ReportFatal(Diagnostic::Error(call) << "Index.Tensor op should have 2 arguments");
}

TensorStructInfo data_sinfo = GetInputTensorStructInfo(call, 0, ctx);
Array<TensorStructInfo> indices_sinfo = GetTensorStructInfoFromTuple(call, ctx, call->args[1]);

if (indices_sinfo.empty()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "index_tensor expects a non‑empty tuple of index tensors");
}

DataType output_dtype = data_sinfo->dtype;
int n_indices = static_cast<int>(indices_sinfo.size());
Optional<VDevice> vdev = data_sinfo->vdevice;

// Indices must be integers
for (int i = 0; i < n_indices; ++i) {
const auto& s = indices_sinfo[i];
if (!s->IsUnknownDtype() && !s->dtype.is_int()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "index_tensor requires every index tensor to have an integer dtype; "
<< "index " << i << " has dtype " << s->dtype);
}
}

// Count of indices must be less than or equal to data.ndim
if (!data_sinfo->IsUnknownNdim() && n_indices > data_sinfo->ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "index_tensor received " << n_indices
<< " index tensors, but data has only " << data_sinfo->ndim << " dimensions");
}

arith::Analyzer* analyzer = ctx->GetAnalyzer();
bool all_index_have_shape_value = true;
std::vector<Array<PrimExpr>> index_shapes;
int max_index_ndim = 0;

for (const auto& s : indices_sinfo) {
const auto* shp = s->shape.as<ShapeExprNode>();
if (!shp) {
all_index_have_shape_value = false;
} else {
index_shapes.push_back(shp->values);
max_index_ndim = std::max(max_index_ndim, static_cast<int>(shp->values.size()));
}
if (!s->IsUnknownNdim()) {
max_index_ndim = std::max(max_index_ndim, s->ndim);
}
}

Optional<Array<PrimExpr>> broadcast_shape;
bool shape_unknown = !all_index_have_shape_value;

if (all_index_have_shape_value) {
// initialise broadcast result with 1s
Array<PrimExpr> out_shape;
for (int i = 0; i < max_index_ndim; ++i) {
out_shape.push_back(IntImm(DataType::Int(64), 1));
}

for (const auto& ishape : index_shapes) {
int cur_ndim = ishape.size();
for (int axis = 0; axis < max_index_ndim; ++axis) {
int lhs_axis = max_index_ndim - 1 - axis; // aligned from right
int rhs_axis = cur_ndim - 1 - axis;
if (rhs_axis < 0) break; // shorter rank – done

PrimExpr lhs_dim = out_shape[lhs_axis];
PrimExpr rhs_dim = ishape[rhs_axis];

const auto* lhs_int = lhs_dim.as<IntImmNode>();
const auto* rhs_int = rhs_dim.as<IntImmNode>();

// Case 1: current broadcast slot is 1 -> always replace
if (lhs_int && lhs_int->value == 1) {
out_shape.Set(lhs_axis, rhs_dim);
continue;
}
// Case 2: rhs is 1 -> keep lhs_dim unchanged
if (rhs_int && rhs_int->value == 1) {
continue;
}
// Both are non‑one constants: must equal
if (lhs_int && rhs_int && lhs_int->value != rhs_int->value) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "index_tensor: cannot broadcast index shapes. Mismatch at axis "
<< lhs_axis << ": " << lhs_dim << " vs " << rhs_dim);
}
// Give up if not provably equal
if (!analyzer->CanProveEqual(lhs_dim, rhs_dim)) {
shape_unknown = true;
break;
}
}
if (shape_unknown) break;
}

if (!shape_unknown) broadcast_shape = out_shape;
}

// Count of dimensions in output
int out_ndim = kUnknownNDim;
if (!data_sinfo->IsUnknownNdim()) {
int tail_ndim = data_sinfo->ndim - n_indices;
if (broadcast_shape.defined()) {
out_ndim = static_cast<int>(broadcast_shape.value().size()) + tail_ndim;
} else if (!shape_unknown) {
out_ndim = max_index_ndim + tail_ndim;
}
}

// Derive output shape
if (broadcast_shape.defined()) {
const auto* data_shape_expr = data_sinfo->shape.as<ShapeExprNode>();
if (data_shape_expr) {
Array<PrimExpr> result_shape = broadcast_shape.value();
for (int i = n_indices; i < data_sinfo->ndim; ++i) {
result_shape.push_back(data_shape_expr->values[i]);
}
return TensorStructInfo(ShapeExpr(result_shape), output_dtype, vdev);
}
}

// Unknown output shape
return TensorStructInfo(output_dtype, out_ndim, vdev);
}

TVM_REGISTER_OP("relax.index_tensor")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input data.")
.add_argument("indices", "List of Tensors", "The indices used to index.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoIndexTensor)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.layout_transform */
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);

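A sketch of the diagnostics this inference emits (hypothetical snippet, not from the PR): indexing a 2-D tensor with three index tensors is rejected during struct-info inference:

# Hedged sketch: too many index tensors must fail at compile time.
import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
idx = [relax.Var(f"i{k}", relax.TensorStructInfo((2,), "int64")) for k in range(3)]
try:
    with bb.function("bad", [x, *idx]):
        y = bb.emit(relax.op.index_tensor(x, idx))
        bb.emit_func_output(y)
except Exception as err:  # surfaces as a TVM diagnostic error
    print(err)  # "...received 3 index tensors, but data has only 2 dimensions"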
12 changes: 12 additions & 0 deletions src/relax/op/tensor/manipulate.h
@@ -200,6 +200,18 @@ Expr gather_elements(Expr data, Expr indices, int axis = 0);
*/
Expr gather_nd(Expr data, Expr indices, int batch_dims = 0);

/*!
* \brief NumPy/PyTorch‑style advanced indexing with tensors.
* \param data The input tensor.
* \param indices A Tuple expression (or list) containing the index tensors.
* \return The indexed tensor.
*
* \note When all shapes are static, Relax checks that the index shapes are
* broadcast-compatible. Bounds checking of the values in indices is
* deferred to runtime.
*/
Expr index_tensor(Expr data, Expr indices);

/*!
* \brief Scatter updates into an array according to indices.
* \param data The input tensor.