1717# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
1818"""Dispatch sort and scan operators to platform dependent implementation."""
1919
20- from tvm import topi , dlight , relax
20+ from functools import reduce
21+ from operator import mul
22+
23+ from tvm import DataType , dlight , relax , topi
24+ from tvm .contrib .thrust import can_use_thrust
2125from tvm .ir import Op
2226from tvm .ir .module import IRModule
2327from tvm .ir .transform import PassContext , module_pass
24- from tvm .target import Target
25- from tvm .contrib .thrust import can_use_thrust
2628from tvm .relax import PyExprMutator , expr_functor
29+ from tvm .target import Target
2730
2831
2932@expr_functor .mutator
@@ -80,23 +83,24 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
8083 if call .op .name == "relax.sort" :
8184 tgt = self ._get_target (call .struct_info )
8285 te_func = topi .sort
86+ kwargs = {}
8387 with tgt :
8488 if can_use_thrust (tgt , "tvm.contrib.thrust.sort" ):
8589 te_func = topi .cuda .sort_thrust
90+ kwargs ["workspace" ] = self .allocate_workspace (call )
8691 elif tgt .kind .name == "cuda" :
8792 te_func = topi .cuda .sort
8893 return self .builder_ .call_te (
89- te_func ,
90- call .args [0 ],
91- call .attrs .axis ,
92- not call .attrs .descending ,
94+ te_func , call .args [0 ], call .attrs .axis , not call .attrs .descending , ** kwargs
9395 )
9496 if call .op .name == "relax.argsort" :
9597 tgt = self ._get_target (call .struct_info )
9698 te_func = topi .argsort
99+ kwargs = {}
97100 with tgt :
98101 if can_use_thrust (tgt , "tvm.contrib.thrust.sort" ):
99102 te_func = topi .cuda .argsort_thrust
103+ kwargs ["workspace" ] = self .allocate_workspace (call )
100104 elif tgt .kind .name == "cuda" :
101105 te_func = topi .cuda .argsort
102106 return self .builder_ .call_te (
@@ -105,12 +109,15 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
105109 axis = call .attrs .axis ,
106110 is_ascend = not call .attrs .descending ,
107111 dtype = call .attrs .dtype ,
112+ ** kwargs ,
108113 )
109114 if call .op .name == "relax.topk" :
110115 tgt = self ._get_target (call .struct_info )
111116 te_func = topi .topk
117+ kwargs = {}
112118 if can_use_thrust (tgt , "tvm.contrib.thrust.sort" ):
113119 te_func = topi .cuda .topk_thrust
120+ kwargs ["workspace" ] = self .allocate_workspace (call )
114121 elif tgt .kind .name == "cuda" :
115122 te_func = topi .cuda .topk
116123 tir_call = self .builder_ .call_te (
@@ -121,6 +128,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
121128 ret_type = call .attrs .ret_type ,
122129 is_ascend = not call .attrs .largest ,
123130 dtype = call .attrs .dtype ,
131+ ** kwargs ,
124132 )
125133 if tgt .kind .name != "cuda" :
126134 return tir_call
@@ -130,23 +138,51 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
130138 if call .op .name in ("relax.cumprod" , "relax.cumsum" ):
131139 tgt = self ._get_target (call .struct_info )
132140 axis = int (call .attrs .axis ) if call .attrs .axis is not None else call .attrs .axis
133- te_func = topi .cuda .cumsum if tgt .kind .name == "cuda" else topi .cumsum
134- if call .op .name == "relax.cumprod" :
135- te_func = topi .cuda .cumprod if tgt .kind .name == "cuda" else topi .cumprod
136- tir_call = self .builder_ .call_te (
137- te_func ,
138- call .args [0 ],
139- axis ,
140- call .attrs .dtype ,
141- call .attrs .exclusive ,
142- )
141+ kwargs = {}
142+ with tgt :
143+ if call .op .name == "relax.cumsum" :
144+ te_func = topi .cuda .cumsum if tgt .kind .name == "cuda" else topi .cumsum
145+ if can_use_thrust (tgt , "tvm.contrib.thrust.sum_scan" ):
146+ kwargs ["workspace" ] = self .allocate_workspace (call )
147+ elif call .op .name == "relax.cumprod" :
148+ te_func = topi .cuda .cumprod if tgt .kind .name == "cuda" else topi .cumprod
149+ else :
150+ raise ValueError (f"Unsupported op: { call .op .name } " )
151+ tir_call = self .builder_ .call_te (
152+ te_func ,
153+ call .args [0 ],
154+ axis ,
155+ call .attrs .dtype ,
156+ call .attrs .exclusive ,
157+ ** kwargs ,
158+ )
143159 if tgt .kind .name != "cuda" :
144160 return tir_call
145161 # apply dlight gpu fallback
146162 self ._apply_dlight_gpu_fallback (tgt , tir_call )
147163 return tir_call
148164 return super ().visit_call_ (call )
149165
166+ def estimate_thrust_workspace_size (self , call : relax .Call ) -> int :
167+ """
168+ Estimate the workspace size for thrust sort/argsort/topk/cumsum
169+ """
170+ input_shape = call .args [0 ].struct_info .shape
171+ input_byte_per_elem = DataType (call .args [0 ].struct_info .dtype ).bits // 8
172+ input_size = reduce (mul , input_shape , 1 ) * input_byte_per_elem
173+ # Most GPU algorithms take O(n) space or less, we choose 2N + 4MB as a safe estimation
174+ return 2 * input_size + 4 * 1024 * 1024
175+
176+ def allocate_workspace (self , call : relax .Call ) -> relax .Var :
177+ """
178+ Allocate workspace for thrust sort/argsort/topk.
179+ """
180+ workspace_size = self .estimate_thrust_workspace_size (call )
181+ alloc = relax .op .builtin .alloc_tensor (
182+ relax .ShapeExpr ((workspace_size ,)), "uint8" , runtime_device_index = 0
183+ )
184+ return self .builder_ .emit (alloc )
185+
150186
151187@module_pass (opt_level = 0 , name = "DispatchSortScan" )
152188class DispatchSortScan :
0 commit comments