Skip to content

Commit 1e8cad2

Browse files
committed
fix comments
1 parent 001f0c9 commit 1e8cad2

File tree

3 files changed

+92
-249
lines changed

3 files changed

+92
-249
lines changed

python/tvm/relax/backend/dispatch_sort_scan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def _get_target(self, expr: Expr) -> Target:
4747
target = Target.current()
4848
if target is None:
4949
raise ValueError(
50-
"Target not found. Please ensure that the target is annotated within the module, \
51-
or alternatively, execute this within a specified target context."
50+
"Target not found. Please ensure that the target is annotated within the module, "
51+
"or alternatively, execute this within a specified target context."
5252
)
5353
return target
5454

@@ -65,12 +65,14 @@ def visit_call_(self, call: Call) -> Expr:
6565
call.args[0],
6666
call.attrs.axis,
6767
not call.attrs.descending,
68+
primfunc_attrs={"tir.is_scheduled": 1},
6869
)
6970
return self.builder_.call_te(
7071
topi.cuda.sort if tgt.kind.name == "cuda" else topi.sort,
7172
call.args[0],
7273
call.attrs.axis,
7374
not call.attrs.descending,
75+
primfunc_attrs={"tir.is_scheduled": 1},
7476
)
7577

7678
if call.op.name == "relax.cumsum":
@@ -82,6 +84,7 @@ def visit_call_(self, call: Call) -> Expr:
8284
call.args[0],
8385
axis,
8486
call.attrs.dtype,
87+
primfunc_attrs={"tir.is_scheduled": 1},
8588
)
8689

8790
return super().visit_call_(call)

python/tvm/relax/op/sort.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Sortings operators."""
18-
from typing import List, Optional
1918

2019
from . import _ffi_api
2120
from ..expr import Expr
@@ -30,11 +29,11 @@ def sort(x: Expr, axis: int = -1, descending: bool = False):
3029
x : relax.Expr
3130
The input tensor.
3231
33-
axis : Optional[int]
32+
axis : int
3433
Axis along which to sort the input tensor.
3534
By default the last axis of the input is used.
3635
37-
descending : Optional[bool]
36+
descending : bool
3837
Whether to sort in descending order, the default is False
3938
4039
Returns

0 commit comments

Comments
 (0)