@@ -24,6 +24,7 @@
 from tvm.relax.dpl import is_op, rewrite_call, wildcard, has_target

 from ..transform import function_pass
+from tvm.topi.utils import prod, swap


 @function_pass(opt_level=0)
@@ -35,10 +36,21 @@ class DispatchOps:
     def __init__(self):
         self.input = wildcard()
         # cumsum on cpu will be legalized
+        self.cumsum_cpu = is_op("relax.cumsum")(self.input) & has_target("llvm")
         self.cumsum_gpu = is_op("relax.cumsum")(self.input) & has_target("cuda")
-        self.sort_gpu = is_op("relax.sort")(self.input) & has_target("cuda")
         self.sort_cpu = is_op("relax.sort")(self.input) & has_target("llvm")
-        self.pattern = self.cumsum_gpu | self.sort_gpu | self.sort_cpu
+        self.sort_gpu = is_op("relax.sort")(self.input) & has_target("cuda")
+        # if no target is specified, default to the GPU lowering
+        self.sort = is_op("relax.sort")(self.input)
+        self.cumsum = is_op("relax.cumsum")(self.input)
+        self.pattern = (
+            self.cumsum_gpu
+            | self.cumsum_cpu
+            | self.sort_gpu
+            | self.sort_cpu
+            | self.sort
+            | self.cumsum
+        )

     def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule:
         """
@@ -64,55 +76,87 @@ def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule:
                 continue

         def rewriter(expr, matches):
-            print("got here 70, expr: ", expr)
             arg = matches[self.input]
-            print("got arg: ", arg)

-            if self.cumsum_gpu in matches:
-                print("86 matches[self.no_op_reshape]: ", matches[self.cumsum_gpu])
-                return relax.call_dps_packed(
+            if self.cumsum_gpu in matches or (
+                self.cumsum in matches and self.cumsum_cpu not in matches
+            ):
+                # `expr` is the matched cumsum call, whichever pattern fired
+                axis = expr.attrs.axis
+                output_dtype = expr.attrs.dtype
+                if output_dtype is None:
+                    output_dtype = arg.struct_info.dtype
+                out_sinfo = relax.TensorStructInfo(
+                    arg.struct_info.shape, output_dtype, arg.struct_info.vdevice
+                )
+                if axis is None:
+                    # no axis means scan over the flattened input
+                    axis = 0
+                    new_shape = (prod(arg.struct_info.shape),)
+                    arg = relax.op.reshape(arg, new_shape)
+                    out_sinfo = relax.TensorStructInfo(
+                        new_shape, output_dtype, out_sinfo.vdevice
+                    )
+                return relax.op.call_dps_packed(
                     "tvm.contrib.thrust.sum_scan",
-                    [arg],
-                    out_sinfo=arg.struct_info,
-                )
-            elif self.sort_gpu in matches:
-                print("86 matches[self.no_op_reshape]: ", matches[self.sort_gpu])
-                return relax.call_dps_packed(
-                    "tvm.contrib.thrust.sort",
-                    [arg],
-                    out_sinfo=arg.struct_info,
+                    [arg, int(axis)],
+                    out_sinfo=out_sinfo,
                 )
+
             elif self.sort_cpu in matches:
+                axis = expr.attrs.axis
+                is_ascend = int(expr.attrs.is_ascend)
+                out_sinfo = arg.struct_info
+                if axis is None:
+                    # no axis means sort the flattened input
+                    axis = 0
+                    new_shape = (prod(arg.struct_info.shape),)
+                    arg = relax.op.reshape(arg, new_shape)
+                    out_sinfo = relax.TensorStructInfo(
+                        new_shape, arg.struct_info.dtype, out_sinfo.vdevice
+                    )
+
                 return relax.call_dps_packed(
                     "tvm.contrib.sort.sort",
-                    [arg],
-                    out_sinfo=arg.struct_info,
+                    [arg, int(axis), is_ascend],
+                    out_sinfo=out_sinfo,
                 )

+            elif self.sort_gpu in matches or self.sort in matches:
+                axis = expr.attrs.axis
+                if axis is None:
+                    axis = -1
+                axis = int(axis)
+
+                is_ascend = expr.attrs.is_ascend
+                if is_ascend is None:
+                    is_ascend = True
+                out_sinfo = arg.struct_info
+                ndim = arg.struct_info.ndim
+
+                axis = ndim + axis if axis < 0 else axis
+                if axis != ndim - 1:
+                    # thrust sorts along the last axis, so move `axis` there
+                    axes = swap(list(range(ndim)), axis)
+                    arg = relax.op.permute_dims(arg, axes)
+                    new_shape = [out_sinfo.shape[i] for i in axes]
+                    out_sinfo = relax.TensorStructInfo(
+                        new_shape, out_sinfo.dtype, out_sinfo.vdevice
+                    )
+
+                out = relax.op.call_dps_packed(
+                    "tvm.contrib.thrust.sort",
+                    [arg, int(is_ascend)],
+                    out_sinfo=out_sinfo,
+                )
+                if axis != ndim - 1:
+                    # transpose back to restore the original axis order
+                    axes = swap(list(range(ndim)), axis)
+                    out = relax.op.permute_dims(out, axes)
+                return out
+
             return expr

         updated_func = rewrite_call(self.pattern, rewriter, func)

         return updated_func
-
-
-# Option 0): add a global dict for it: {op, target, condition, dps_packed},
-# condition is some specific setting like the value of k in topk
-# Q: how to work it with pattern match?
-#
-# Option 1): normal python mod pass, straightforward, but not easy to hack like topk
-#
-# Option 2): c++ pass, not easy to be updated. Don't go
-#
-# How to handle with target? don't require RealizeVDevice, just specify the vdevice in inputs
-# but vdevice is necessary, we could have default for it
-#
-# Sample map
-# cumsum - cpu => ignore for legalization
-# cumsum - gpu => relax.call_dps_packed(
-#     "tvm.contrib.thrust.sum_scan",
-#     [data],
-#     out_sinfo=data.struct_info,
-# )
-# f32_233 = wildcard().has_shape((2, 3, 3)) & has_dtype("float32")  # and pattern
-# is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [4, 3]}).match(conv2d)
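
For context, a rough sketch of how the pass above could be driven. The module construction (via relax.BlockBuilder) and the pass invocation follow standard TVM conventions but are illustrative only, not part of this commit:

    import tvm
    from tvm import relax

    # build a tiny module containing a relax.cumsum call
    x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("main", [x]):
        bb.emit_func_output(relax.op.cumsum(x, axis=0))
    mod = bb.get()

    # No target/vdevice attached: the has_target patterns don't fire, so the
    # bare self.cumsum pattern matches and, per the default-to-GPU rule above,
    # the call is rewritten to
    # call_dps_packed("tvm.contrib.thrust.sum_scan", ...).
    mod = DispatchOps()(mod)
    mod.show()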