2626from tvm .ir import GlobalVar , Op
2727from tvm .ir .module import IRModule
2828from tvm .ir .transform import PassContext , module_pass
29- from tvm .relax import PyExprMutator , expr_functor
29+ from tvm .relax import expr_functor
3030from tvm .target import Target
3131
32-
33- def is_gpu_target (target : Target ) -> bool :
34- """Check if the target is a GPU target."""
35- return "gpu" in target .keys
32+ from .utils import BackendDispatcher
3633
3734
3835@expr_functor .mutator
39- class SortScanDispatcher (PyExprMutator ):
40- """
41- Dispatcher to dispatch sort and scan.
42-
43- """
36+ class SortScanDispatcher (BackendDispatcher ):
37+ """Dispatcher to dispatch sort and scan."""
4438
4539 calls_to_update : Dict [GlobalVar , Target ]
4640
4741 def __init__ (self , mod ):
4842 super ().__init__ (mod )
4943 self .calls_to_update = {}
5044
51- def _get_target (self , sinfo : relax .StructInfo ) -> Target :
52- # Get target information from TensorStructInfo
53- if isinstance (sinfo , relax .TensorStructInfo ):
54- vdevice = sinfo .vdevice
55- if vdevice is not None :
56- return vdevice .target
57- elif isinstance (sinfo , relax .TupleStructInfo ):
58- for f in sinfo .fields :
59- tgt = self ._get_target (f )
60- if tgt != Target .current ():
61- return tgt
62- # Return the target in current context
63- target = Target .current ()
64- if target is None :
65- raise ValueError (
66- "Target not found. Please ensure that the target is annotated within the module, "
67- "or alternatively, execute this within a specified target context."
68- )
69- return target
70-
7145 def apply_dlight_gpu_fallback (
7246 self ,
7347 ) -> None :
@@ -107,7 +81,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
10781 if can_use_thrust (tgt , "tvm.contrib.thrust.sort" ):
10882 te_func = topi .cuda .sort_thrust
10983 kwargs ["workspace" ] = self .allocate_workspace (call )
110- elif is_gpu_target (tgt ):
84+ elif self . is_gpu_target (tgt ):
11185 te_func = topi .cuda .sort
11286 return self .builder_ .call_te (
11387 te_func , call .args [0 ], call .attrs .axis , not call .attrs .descending , ** kwargs
@@ -120,7 +94,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
12094 if can_use_thrust (tgt , "tvm.contrib.thrust.sort" ):
12195 te_func = topi .cuda .argsort_thrust
12296 kwargs ["workspace" ] = self .allocate_workspace (call )
123- elif is_gpu_target (tgt ):
97+ elif self . is_gpu_target (tgt ):
12498 te_func = topi .cuda .argsort
12599 return self .builder_ .call_te (
126100 te_func ,
@@ -137,7 +111,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
137111 if can_use_thrust (tgt , "tvm.contrib.thrust.sort" ):
138112 te_func = topi .cuda .topk_thrust
139113 kwargs ["workspace" ] = self .allocate_workspace (call )
140- elif is_gpu_target (tgt ):
114+ elif self . is_gpu_target (tgt ):
141115 te_func = topi .cuda .topk
142116 tir_call = self .builder_ .call_te (
143117 te_func ,
@@ -162,7 +136,7 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
162136 if (
163137 shape is not None
164138 and (axis == - 1 or axis == len (shape ) - 1 )
165- and is_gpu_target (tgt )
139+ and self . is_gpu_target (tgt )
166140 and not can_use_thrust (tgt , "tvm.contrib.thrust.sum_scan" )
167141 and call .op .name == "relax.cumsum"
168142 and call .attrs .exclusive == 0
@@ -202,11 +176,11 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
202176
203177 with tgt :
204178 if call .op .name == "relax.cumsum" :
205- te_func = topi .cuda .cumsum if is_gpu_target (tgt ) else topi .cumsum
179+ te_func = topi .cuda .cumsum if self . is_gpu_target (tgt ) else topi .cumsum
206180 if can_use_thrust (tgt , "tvm.contrib.thrust.sum_scan" ):
207181 kwargs ["workspace" ] = self .allocate_workspace (call )
208182 elif call .op .name == "relax.cumprod" :
209- te_func = topi .cuda .cumprod if is_gpu_target (tgt ) else topi .cumprod
183+ te_func = topi .cuda .cumprod if self . is_gpu_target (tgt ) else topi .cumprod
210184 else :
211185 raise ValueError (f"Unsupported op: { call .op .name } " )
212186 tir_call = self .builder_ .call_te (
0 commit comments