Skip to content

Commit 45368fd

Browse files
committed
[Relax,Topi] Allow passing workspace to thrust to avoid allocations
1 parent 61249b4 commit 45368fd

File tree

7 files changed

+474
-133
lines changed

7 files changed

+474
-133
lines changed

python/tvm/relax/backend/dispatch_sort_scan.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@
1717
# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
1818
"""Dispatch sort and scan operators to platform dependent implementation."""
1919

20-
from tvm import topi, dlight, relax
20+
from functools import reduce
21+
from operator import mul
22+
23+
from tvm import DataType, dlight, relax, topi
24+
from tvm.contrib.thrust import can_use_thrust
2125
from tvm.ir import Op
2226
from tvm.ir.module import IRModule
2327
from tvm.ir.transform import PassContext, module_pass
24-
from tvm.target import Target
25-
from tvm.contrib.thrust import can_use_thrust
2628
from tvm.relax import PyExprMutator, expr_functor
29+
from tvm.target import Target
2730

2831

2932
@expr_functor.mutator
@@ -80,23 +83,24 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
8083
if call.op.name == "relax.sort":
8184
tgt = self._get_target(call.struct_info)
8285
te_func = topi.sort
86+
kwargs = {}
8387
with tgt:
8488
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
8589
te_func = topi.cuda.sort_thrust
90+
kwargs["workspace"] = self.allocate_workspace(call)
8691
elif tgt.kind.name == "cuda":
8792
te_func = topi.cuda.sort
8893
return self.builder_.call_te(
89-
te_func,
90-
call.args[0],
91-
call.attrs.axis,
92-
not call.attrs.descending,
94+
te_func, call.args[0], call.attrs.axis, not call.attrs.descending, **kwargs
9395
)
9496
if call.op.name == "relax.argsort":
9597
tgt = self._get_target(call.struct_info)
9698
te_func = topi.argsort
99+
kwargs = {}
97100
with tgt:
98101
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
99102
te_func = topi.cuda.argsort_thrust
103+
kwargs["workspace"] = self.allocate_workspace(call)
100104
elif tgt.kind.name == "cuda":
101105
te_func = topi.cuda.argsort
102106
return self.builder_.call_te(
@@ -105,12 +109,15 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
105109
axis=call.attrs.axis,
106110
is_ascend=not call.attrs.descending,
107111
dtype=call.attrs.dtype,
112+
**kwargs,
108113
)
109114
if call.op.name == "relax.topk":
110115
tgt = self._get_target(call.struct_info)
111116
te_func = topi.topk
117+
kwargs = {}
112118
if can_use_thrust(tgt, "tvm.contrib.thrust.sort"):
113119
te_func = topi.cuda.topk_thrust
120+
kwargs["workspace"] = self.allocate_workspace(call)
114121
elif tgt.kind.name == "cuda":
115122
te_func = topi.cuda.topk
116123
tir_call = self.builder_.call_te(
@@ -121,6 +128,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
121128
ret_type=call.attrs.ret_type,
122129
is_ascend=not call.attrs.largest,
123130
dtype=call.attrs.dtype,
131+
**kwargs,
124132
)
125133
if tgt.kind.name != "cuda":
126134
return tir_call
@@ -130,23 +138,51 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
130138
if call.op.name in ("relax.cumprod", "relax.cumsum"):
131139
tgt = self._get_target(call.struct_info)
132140
axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
133-
te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
134-
if call.op.name == "relax.cumprod":
135-
te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
136-
tir_call = self.builder_.call_te(
137-
te_func,
138-
call.args[0],
139-
axis,
140-
call.attrs.dtype,
141-
call.attrs.exclusive,
142-
)
141+
kwargs = {}
142+
with tgt:
143+
if call.op.name == "relax.cumsum":
144+
te_func = topi.cuda.cumsum if tgt.kind.name == "cuda" else topi.cumsum
145+
if can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan"):
146+
kwargs["workspace"] = self.allocate_workspace(call)
147+
elif call.op.name == "relax.cumprod":
148+
te_func = topi.cuda.cumprod if tgt.kind.name == "cuda" else topi.cumprod
149+
else:
150+
raise ValueError(f"Unsupported op: {call.op.name}")
151+
tir_call = self.builder_.call_te(
152+
te_func,
153+
call.args[0],
154+
axis,
155+
call.attrs.dtype,
156+
call.attrs.exclusive,
157+
**kwargs,
158+
)
143159
if tgt.kind.name != "cuda":
144160
return tir_call
145161
# apply dlight gpu fallback
146162
self._apply_dlight_gpu_fallback(tgt, tir_call)
147163
return tir_call
148164
return super().visit_call_(call)
149165

166+
def estimate_thrust_workspace_size(self, call: relax.Call) -> int:
    """
    Estimate the workspace size for thrust sort/argsort/topk/cumsum.

    The estimate is based only on the size of the first input tensor:
    2 * input_bytes + 4 MB, which safely covers the O(n)-or-less scratch
    space the thrust algorithms need.
    """
    shape = call.args[0].struct_info.shape
    bytes_per_elem = DataType(call.args[0].struct_info.dtype).bits // 8
    # Accumulate the total byte size of the input tensor.
    total_bytes = bytes_per_elem
    for dim in shape:
        total_bytes = total_bytes * dim
    # Most GPU algorithms take O(n) space or less; 2N + 4MB is a safe estimation.
    return 2 * total_bytes + 4 * 1024 * 1024
175+
176+
def allocate_workspace(self, call: relax.Call) -> relax.Var:
    """
    Allocate a uint8 workspace tensor for thrust sort/argsort/topk.

    The size is taken from estimate_thrust_workspace_size; the allocation
    is emitted into the current block and the bound variable returned.
    """
    num_bytes = self.estimate_thrust_workspace_size(call)
    workspace = relax.op.builtin.alloc_tensor(
        relax.ShapeExpr((num_bytes,)), "uint8", runtime_device_index=0
    )
    return self.builder_.emit(workspace)
185+
150186

151187
@module_pass(opt_level=0, name="DispatchSortScan")
152188
class DispatchSortScan:

python/tvm/relax/frontend/nn/op.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,6 +2241,112 @@ def cumsum(
22412241
return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)
22422242

22432243

2244+
def sort(x: Tensor, axis: int = -1, descending: bool = False, name="sort"):
    """Performs sorting along the given axis and returns an array
    in sorted order.

    Parameters
    ----------
    x : Tensor
        The input tensor.

    axis : int
        Axis along which to sort the input tensor.
        By default the last axis of the input is used.

    descending : bool
        Whether to sort in descending order, the default is False

    name : str
        Name hint.

    Returns
    -------
    out : Tensor
        The sorted tensor.
    """
    # Unwrap the frontend Tensor to its underlying relax expression before
    # calling the relax op, matching the convention of the other wrappers in
    # this module (e.g. cumsum uses data._expr).
    return wrap_nested(_op.sort(x._expr, axis, descending), name=name)
2269+
2270+
2271+
def argsort(
    data: Tensor, axis: int = -1, descending: bool = False, dtype: str = "int32", name="argsort"
):
    """Performs sorting along the given axis and returns an array of indices
    having same shape as an input array that index data in sorted order.

    Parameters
    ----------
    data : Tensor
        The input data tensor.

    axis : int
        Axis along which to sort the input tensor.

    descending : bool
        Whether to sort in descending order, the default is False

    dtype : str
        The data type of the output indices.

    name : str
        Name hint.

    Returns
    -------
    out : Tensor
        The indices of the sorted tensor.
    """
    # Unwrap the frontend Tensor to its underlying relax expression before
    # calling the relax op, matching the convention of the other wrappers in
    # this module (e.g. cumsum uses data._expr).
    return wrap_nested(_op.argsort(data._expr, axis, descending, dtype), name=name)
2300+
2301+
2302+
def topk(
    data: Tensor,
    k: int = 1,
    axis: int = -1,
    ret_type: str = "both",
    largest: bool = True,
    dtype: str = "int32",
    name: str = "topk",
):
    """Get the top k elements in an input tensor along the given axis.

    ret_type specifies the return type, can be one of ("both", "values", "indices").

    Parameters
    ----------
    data : Tensor
        The input data tensor.

    k : int
        Number of top elements to select. Return all elements if k < 1.

    axis : int
        Axis along which to sort the input tensor.

    ret_type: str
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    largest : bool
        Whether to return largest or smallest elements.
        The k smallest elements are returned if largest is False.

    dtype : str
        The data type of the indices output.

    name : str
        Name hint.

    Returns
    -------
    out : Tensor or Tuple[Tensor, Tensor]
        The computed result.
    """
    # Unwrap the frontend Tensor to its underlying relax expression before
    # calling the relax op, matching the convention of the other wrappers in
    # this module (e.g. cumsum uses data._expr).
    return wrap_nested(_op.topk(data._expr, k, axis, ret_type, largest, dtype), name=name)
2348+
2349+
22442350
def multinomial_from_uniform(
22452351
prob: Tensor,
22462352
uniform_sample: Tensor,

python/tvm/te/operation.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,15 +333,15 @@ def extern(
333333
)
334334
types.add(t.dtype)
335335

336-
if dtype is None:
337-
if len(types) != 1:
338-
raise ValueError("Cannot infer output type, please provide dtype argument")
339-
infered_type = types.pop()
340-
dtype = [infered_type for _ in shape]
341-
if isinstance(dtype, str):
342-
dtype = [dtype]
343-
344336
if out_buffers is None:
337+
if dtype is None:
338+
if len(types) != 1:
339+
raise ValueError("Cannot infer output type, please provide dtype argument")
340+
infered_type = types.pop()
341+
dtype = [infered_type for _ in shape]
342+
if isinstance(dtype, str):
343+
dtype = [dtype]
344+
345345
for shp, dt in zip(shape, dtype):
346346
output_placeholders.append(
347347
tvm.tir.decl_buffer(shp, dt, name, elem_offset=tvm.tir.Var("elem_offset", "int32"))

0 commit comments

Comments
 (0)