Skip to content

Commit 0cf5f47

Browse files
authored
[Unity] Dispatch cumsum and sort (#16254)
* [Unity] Add dispatch for scan and sort * add test cases * Use pass instead of pattern rewriter * add test case * fix lint * fix comments * fix lint * Add target context for default pipeline * fix tests * remove 'is_scheduled'
1 parent 6f2fe45 commit 0cf5f47

File tree

17 files changed

+740
-8
lines changed

17 files changed

+740
-8
lines changed

include/tvm/relax/attrs/sort.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/attrs/sort.h
22+
* \brief Attributes for sorting operators.
23+
*/
24+
#ifndef TVM_RELAX_ATTRS_SORT_H_
25+
#define TVM_RELAX_ATTRS_SORT_H_
26+
27+
#include <tvm/relax/expr.h>
28+
#include <tvm/tir/index_map.h>
29+
30+
namespace tvm {
31+
namespace relax {
32+
33+
/*! \brief Attributes used in sort operator */
34+
struct SortAttrs : public tvm::AttrsNode<SortAttrs> {
35+
int axis;
36+
bool descending;
37+
38+
TVM_DECLARE_ATTRS(SortAttrs, "relax.attrs.SortAttrs") {
39+
TVM_ATTR_FIELD(axis).set_default(-1).describe(
40+
"Axis along which the sort is computed."
41+
"The default the last axis is used.");
42+
TVM_ATTR_FIELD(descending)
43+
.set_default(false)
44+
.describe(
45+
"Whether to sort in descending order."
46+
"If it is not specified, it defaults to the ascending order.");
47+
}
48+
}; // struct SortAttrs
49+
} // namespace relax
50+
} // namespace tvm
51+
52+
#endif // TVM_RELAX_ATTRS_SORT_H_

python/tvm/relax/backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818

1919
from . import contrib
2020
from .pattern_registry import get_pattern, get_patterns_with_prefix
21+
from .dispatch_sort_scan import DispatchSortScan
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
18+
"""Dispatch sort and scan operators to platform dependent implementation."""
19+
20+
from tvm import topi
21+
from tvm.ir import Op
22+
from tvm.ir.module import IRModule
23+
from tvm.ir.transform import PassContext, module_pass
24+
from tvm.target import Target
25+
from tvm.contrib.thrust import can_use_thrust
26+
from tvm.relax import Expr, Function, Call, PyExprMutator, expr_functor, TensorStructInfo
27+
28+
29+
@expr_functor.mutator
30+
class SortScanDispatcher(PyExprMutator):
31+
"""
32+
Dispatcher to dispatch sort and scan.
33+
34+
"""
35+
36+
def __init__(self, mod):
37+
super().__init__(mod)
38+
39+
def _get_target(self, expr: Expr) -> Target:
40+
sinfo = expr.struct_info
41+
# Get target information from TensorStructInfo
42+
if isinstance(sinfo, TensorStructInfo):
43+
vdevice = sinfo.vdevice
44+
if vdevice is not None:
45+
return vdevice.target
46+
# Return the target in current context
47+
target = Target.current()
48+
if target is None:
49+
raise ValueError(
50+
"Target not found. Please ensure that the target is annotated within the module, "
51+
"or alternatively, execute this within a specified target context."
52+
)
53+
return target
54+
55+
def visit_call_(self, call: Call) -> Expr:
56+
if not isinstance(call.op, Op):
57+
return super().visit_call_(call)
58+
59+
if call.op.name == "relax.sort":
60+
tgt = self._get_target(call)
61+
with tgt:
62+
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
63+
return self.builder_.call_te(
64+
topi.cuda.sort_thrust,
65+
call.args[0],
66+
call.attrs.axis,
67+
not call.attrs.descending,
68+
)
69+
return self.builder_.call_te(
70+
topi.cuda.sort if tgt.kind.name == "cuda" else topi.sort,
71+
call.args[0],
72+
call.attrs.axis,
73+
not call.attrs.descending,
74+
)
75+
76+
if call.op.name == "relax.cumsum":
77+
tgt = self._get_target(call)
78+
axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
79+
with tgt:
80+
return self.builder_.call_te(
81+
topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum,
82+
call.args[0],
83+
axis,
84+
call.attrs.dtype,
85+
)
86+
87+
return super().visit_call_(call)
88+
89+
90+
@module_pass(opt_level=0, name="DispatchSortScan")
91+
class DispatchSortScan:
92+
"""
93+
Pass to dispatch scan and sort operators to platform dependent implementation.
94+
"""
95+
96+
def transform_module(self, mod: IRModule, ctx: PassContext) -> IRModule:
97+
sort_scan_dispater = SortScanDispatcher(mod)
98+
for gv, func in mod.functions_items():
99+
if isinstance(func, Function):
100+
func = sort_scan_dispater.visit_expr(func)
101+
sort_scan_dispater.builder_.update_func(gv, func)
102+
return sort_scan_dispater.builder_.get()

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
from .qdq import quantize, dequantize
100100
from .search import argmax, argmin, where
101101
from .set import unique
102+
from .sort import sort
102103
from .statistical import cumsum, max, mean, min, prod, std, sum, variance
103104
from .ternary import ewise_fma
104105
from .unary import (

python/tvm/relax/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ class PermuteDimsAttrs(Attrs):
114114
"""Attributes for permute_dims operator"""
115115

116116

117+
@tvm._ffi.register_object("relax.attrs.SortAttrs")
118+
class SortAttrs(Attrs):
119+
"""Attributes for sort operator"""
120+
121+
117122
@tvm._ffi.register_object("relax.attrs.SplitAttrs")
118123
class SplitAttrs(Attrs):
119124
"""Attributes used in split operator"""

python/tvm/relax/op/sort.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Sortings operators."""
18+
19+
from . import _ffi_api
20+
from ..expr import Expr
21+
22+
23+
def sort(x: Expr, axis: int = -1, descending: bool = False):
24+
"""Performs sorting along the given axis and returns an array
25+
in sorted order.
26+
27+
Parameters
28+
----------
29+
x : relax.Expr
30+
The input tensor.
31+
32+
axis : int
33+
Axis along which to sort the input tensor.
34+
By default the last axis of the input is used.
35+
36+
descending : bool
37+
Whether to sort in descending order, the default is False
38+
39+
Returns
40+
-------
41+
out : relax.Expr
42+
Sorted tensor.
43+
44+
"""
45+
return _ffi_api.sort(x, axis, descending) # type: ignore

python/tvm/relax/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import tvm
2525
from tvm import meta_schedule as ms
2626

27-
from . import transform
27+
from . import transform, backend
2828

2929

3030
def zero_pipeline(*, enable_warning: bool = False):
@@ -81,6 +81,7 @@ def default_build_pipeline():
8181
def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
8282
seq = tvm.transform.Sequential(
8383
[
84+
backend.DispatchSortScan(),
8485
transform.LegalizeOps(),
8586
transform.RewriteDataflowReshape(),
8687
transform.ToNonDataflow(),

python/tvm/relax/vm_build.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,12 @@ def _extract_attrs(mod: tvm.IRModule):
328328
if pipeline is not None:
329329
if isinstance(pipeline, str):
330330
pipeline = relax.get_pipeline(pipeline)
331-
mod = pipeline(mod)
331+
if target is None:
332+
mod = pipeline(mod)
333+
else:
334+
with target:
335+
mod = pipeline(mod)
336+
332337
ext_libs, constants = _extract_attrs(mod)
333338
params.update(dict(constants))
334339
builder = relax.ExecBuilder()

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
sign,
136136
sin,
137137
sinh,
138+
sort,
138139
split,
139140
square,
140141
squeeze,
@@ -758,6 +759,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
758759
"sign",
759760
"sin",
760761
"sinh",
762+
"sort",
761763
"split",
762764
"square",
763765
"squeeze",

python/tvm/topi/cuda/sort.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ def _odd_even_sort(
120120
values=None,
121121
values_swap=None,
122122
):
123-
124123
nthread_tx = block_size // 2
125124
nthread_bx = ceil_div(size, block_size)
126125
nthread_by = axis_mul_before

0 commit comments

Comments
 (0)