Skip to content

Commit 57800b3

Browse files
MasterJH5574tqchen
authored andcommitted
[Unity] Relax op: index (#13987)
This PR is about the high-level tensor computation operators in Relax. This PR includes the tensor indexing operators.
1 parent 3e8560e commit 57800b3

File tree

9 files changed

+1122
-0
lines changed

9 files changed

+1122
-0
lines changed

include/tvm/relax/attrs/index.h

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/attrs/index.h
22+
* \brief Attributes for indexing operators.
23+
*/
24+
#ifndef TVM_RELAX_ATTRS_INDEX_H_
25+
#define TVM_RELAX_ATTRS_INDEX_H_
26+
27+
#include <tvm/relax/expr.h>
28+
29+
namespace tvm {
30+
namespace relax {
31+
32+
/*! \brief Attributes used in take operator */
33+
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
34+
Optional<Integer> axis;
35+
36+
TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") {
37+
TVM_ATTR_FIELD(axis).describe("The axis over which to select values.");
38+
}
39+
}; // struct TakeAttrs
40+
41+
/*! \brief Attributes used in strided_slice operator */
42+
struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
43+
Array<Integer> axes;
44+
Array<PrimExpr> begin;
45+
Array<PrimExpr> end;
46+
Optional<Array<PrimExpr>> strides;
47+
48+
TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") {
49+
TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied.");
50+
TVM_ATTR_FIELD(begin).describe("The indices to begin with in the slicing, inclusive.");
51+
TVM_ATTR_FIELD(end).describe("The indices indicating end of the slice, exclusive.");
52+
TVM_ATTR_FIELD(strides).describe(
53+
"Specifies the stride values, it can be negative in that case, the input tensor will be "
54+
"reversed in that particular axis. If not specified, it by default is an list of ones of "
55+
"the same length as `axes`.");
56+
}
57+
}; // struct StridedSliceAttrs
58+
59+
} // namespace relax
60+
} // namespace tvm
61+
62+
#endif // TVM_RELAX_ATTRS_INDEX_H_

python/tvm/relax/op/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
# Operators
2121
from .base import *
2222
from .binary import *
23+
from .index import *
2324
from .manipulate import *
25+
from .op_attrs import *
2426
from . import builtin
2527
from . import memory

python/tvm/relax/op/index.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Indexing operators."""
18+
from typing import List, Optional, Union
19+
20+
from tvm.ir.expr import PrimExpr
21+
22+
from . import _ffi_api
23+
from ..expr import Expr
24+
25+
PrimExprLike = Union[int, PrimExpr]
26+
27+
28+
def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
29+
"""Take elements from a tensor along an axis.
30+
31+
Parameters
32+
----------
33+
x : relax.Expr
34+
The source tensor.
35+
36+
indices : relax.Expr
37+
The indices of the values to extract.
38+
It is required to be a one-dimensional tensor which has integer dtype.
39+
40+
axis : Optional[int]
41+
The axis over which to select values.
42+
If it is none, the input tensor is required to be one-dimensional.
43+
44+
Returns
45+
-------
46+
ret : relax.Expr
47+
The taken result.
48+
"""
49+
return _ffi_api.take(x, indices, axis) # type: ignore
50+
51+
52+
def strided_slice(
53+
x: Expr,
54+
axes: List[int],
55+
begin: List[PrimExprLike],
56+
end: List[PrimExprLike],
57+
strides: Optional[List[PrimExprLike]] = None,
58+
) -> Expr:
59+
"""Strided slice of a tensor.
60+
61+
Parameters
62+
----------
63+
x : relax.Expr
64+
The source tensor to be sliced.
65+
66+
axes : List[int]
67+
Axes along which slicing is applied.
68+
69+
begin : List[PrimExprLike]
70+
The indices to begin with in the slicing, inclusive.
71+
72+
end : List[PrimExprLike]
73+
The indices indicating end of the slice, exclusive.
74+
75+
strides : Optional[List[PrimExprLike]]
76+
Specifies the stride values, it can be negative in that case,
77+
the input tensor will be reversed in that particular axis.
78+
If not specified, it by default is an list of ones of the same length as `axes`.
79+
80+
Returns
81+
-------
82+
ret : relax.Expr
83+
The sliced result.
84+
85+
Note
86+
----
87+
strided_slice require the input `begin`, `end` and `strides` to have the
88+
same length as `axes`.
89+
"""
90+
return _ffi_api.strided_slice(x, axes, begin, end, strides) # type: ignore

python/tvm/relax/op/op_attrs.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""The attributes node used for Relax operators"""
18+
from tvm.ir import Attrs
19+
import tvm._ffi
20+
21+
22+
@tvm._ffi.register_object("relax.attrs.TakeAttrs")
23+
class TakeAttrs(Attrs):
24+
"""Attributes used in take operator"""
25+
26+
27+
@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs")
28+
class StridedSliceAttrs(Attrs):
29+
"""Attributes used in strided_slice operator"""

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
print,
4343
reshape,
4444
shape_of,
45+
strided_slice,
46+
take,
4547
)
4648
from tvm.relax.struct_info import StructInfo
4749
from tvm.relax.utils import args_converter
@@ -427,5 +429,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
427429
"shape",
428430
"shape_of",
429431
"str",
432+
"strided_slice",
433+
"take",
430434
"tuple",
431435
]

0 commit comments

Comments
 (0)