1- # pylint: disable=invalid-name
1+ # pylint: disable=invalid-name, unused-argument
22"""Tensor ops"""
33from __future__ import absolute_import
44
88from ..compiler import registry as reg
99from ..compiler import OpPattern
1010
11- def schedule_elemwise (_ , outs , target ):
12- """Generic schedule for elemwise operation"""
13- if target == "cuda" :
14- return topi .cuda .schedule_elemwise (outs )
15- assert target .startswith ("llvm" )
16- s = tvm .create_schedule ([x .op for x in outs ])
17- tvm .schedule .AutoInlineInjective (s )
18- return s
19-
2011def _schedule_broadcast (_ , outs , target ):
2112 """Generic schedule for binary bcast"""
2213 if target == "cuda" :
@@ -29,66 +20,140 @@ def _schedule_broadcast(_, outs, target):
2920def _compute_binary_scalar (f ):
3021 """auxiliary function"""
3122 @tvm .tag_scope ("ewise" )
32- def _compute (attrs , x ):
23+ def _compute (attrs , x , _ ):
3324 x = x [0 ]
3425 scalar = attrs .get_float ("scalar" )
3526 scalar = tvm .const (scalar , x .dtype )
3627 return tvm .compute (x .shape , lambda * i : f (x (* i ), scalar ))
3728 return _compute
3829
3930
31+ def _compute_unary (f ):
32+ """auxiliary function"""
33+ def _compute (attrs , x , _ ):
34+ return f (x [0 ])
35+ return _compute
36+
37+
38+ def _compute_binary (f ):
39+ """auxiliary function"""
40+ def _compute (attrs , x , _ ):
41+ return f (x [0 ], x [1 ])
42+ return _compute
43+
44+
4045_fschedule_broadcast = tvm .convert (_schedule_broadcast )
4146
4247# exp
43- reg .register_compute ("exp" ,
44- lambda _ , x : topi .exp (x [0 ]))
48+ reg .register_compute ("exp" , _compute_unary (topi .exp ))
4549reg .register_pattern ("exp" , OpPattern .ELEM_WISE )
4650reg .register_schedule ("exp" , _fschedule_broadcast )
4751
52+ # sqrt
53+ reg .register_compute ("sqrt" , _compute_unary (topi .sqrt ))
54+ reg .register_pattern ("sqrt" , OpPattern .ELEM_WISE )
55+ reg .register_schedule ("sqrt" , _fschedule_broadcast )
56+
4857# log
49- reg .register_compute ("log" ,
50- lambda _ , x : topi .log (x [0 ]))
58+ reg .register_compute ("log" , _compute_unary (topi .log ))
5159reg .register_pattern ("log" , OpPattern .ELEM_WISE )
5260reg .register_schedule ("log" , _fschedule_broadcast )
5361
5462# tanh
55- reg .register_compute ("tanh" ,
56- lambda _ , x : topi .tanh (x [0 ]))
63+ reg .register_compute ("tanh" , _compute_unary (topi .tanh ))
5764reg .register_pattern ("tanh" , OpPattern .ELEM_WISE )
5865reg .register_schedule ("tanh" , _fschedule_broadcast )
5966
67+ # negative
68+ reg .register_compute ("negative" , _compute_unary (topi .negative ))
69+ reg .register_pattern ("negative" , OpPattern .ELEM_WISE )
70+ reg .register_schedule ("negative" , _fschedule_broadcast )
71+
6072# sigmoid
61- reg .register_compute ("sigmoid" ,
62- lambda _ , x : topi .sigmoid (x [0 ]))
73+ reg .register_compute ("sigmoid" , _compute_unary (topi .sigmoid ))
6374reg .register_pattern ("sigmoid" , OpPattern .ELEM_WISE )
6475reg .register_schedule ("sigmoid" , _fschedule_broadcast )
6576
66- # add scalar
77+ # add_scalar
6778reg .register_compute ("__add_scalar__" ,
6879 _compute_binary_scalar (lambda x , y : x + y ))
6980reg .register_pattern ("__add_scalar__" , OpPattern .ELEM_WISE )
7081reg .register_schedule ("__add_scalar__" , _fschedule_broadcast )
7182
83+ # sub_calar
84+ reg .register_compute ("__sub_scalar__" ,
85+ _compute_binary_scalar (lambda x , y : x - y ))
86+ reg .register_pattern ("__sub_scalar__" , OpPattern .ELEM_WISE )
87+ reg .register_schedule ("__sub_scalar__" , _fschedule_broadcast )
88+
89+ # rsub_scalar
90+ reg .register_compute ("__rsub_scalar__" ,
91+ _compute_binary_scalar (lambda x , y : y - x ))
92+ reg .register_pattern ("__rsub_scalar__" , OpPattern .ELEM_WISE )
93+ reg .register_schedule ("__rsub_scalar__" , _fschedule_broadcast )
94+
95+ # mul_scalar
96+ reg .register_compute ("__mul_scalar__" ,
97+ _compute_binary_scalar (lambda x , y : x * y ))
98+ reg .register_pattern ("__mul_scalar__" , OpPattern .ELEM_WISE )
99+ reg .register_schedule ("__mul_scalar__" , _fschedule_broadcast )
100+
101+ # div_scalar
102+ reg .register_compute ("__div_scalar__" ,
103+ _compute_binary_scalar (lambda x , y : x / y ))
104+ reg .register_pattern ("__div_scalar__" , OpPattern .ELEM_WISE )
105+ reg .register_schedule ("__div_scalar__" , _fschedule_broadcast )
106+
107+ # rdiv_scalar
108+ reg .register_compute ("__rdiv_scalar__" ,
109+ _compute_binary_scalar (lambda x , y : y / x ))
110+ reg .register_pattern ("__rdiv_scalar__" , OpPattern .ELEM_WISE )
111+ reg .register_schedule ("__rdiv_scalar__" , _fschedule_broadcast )
112+
113+ # elemwise_add
114+ reg .register_compute ("elemwise_add" , _compute_binary (topi .broadcast_add ))
115+ reg .register_pattern ("elemwise_add" , OpPattern .BROADCAST )
116+ reg .register_schedule ("elemwise_add" , _fschedule_broadcast )
117+
118+ # elemwise_sub
119+ reg .register_compute ("elemwise_sub" , _compute_binary (topi .broadcast_sub ))
120+ reg .register_pattern ("elemwise_sub" , OpPattern .BROADCAST )
121+ reg .register_schedule ("elemwise_sub" , _fschedule_broadcast )
122+
123+ # elemwise_mul
124+ reg .register_compute ("elemwise_mul" , _compute_binary (topi .broadcast_mul ))
125+ reg .register_pattern ("elemwise_mul" , OpPattern .BROADCAST )
126+ reg .register_schedule ("elemwise_mul" , _fschedule_broadcast )
127+
128+ # elemwise_div
129+ reg .register_compute ("elemwise_div" , _compute_binary (topi .broadcast_div ))
130+ reg .register_pattern ("elemwise_div" , OpPattern .BROADCAST )
131+ reg .register_schedule ("elemwise_div" , _fschedule_broadcast )
132+
72133# broadcast_add
73- reg .register_compute ("broadcast_add" ,
74- lambda _ , x : topi .broadcast_add (x [0 ], x [1 ]))
134+ reg .register_compute ("broadcast_add" , _compute_binary (topi .broadcast_add ))
75135reg .register_pattern ("broadcast_add" , OpPattern .BROADCAST )
76136reg .register_schedule ("broadcast_add" , _fschedule_broadcast )
77137
78138# broadcast_sub
79- reg .register_compute ("broadcast_sub" ,
80- lambda _ , x : topi .broadcast_sub (x [0 ], x [1 ]))
139+ reg .register_compute ("broadcast_sub" , _compute_binary (topi .broadcast_sub ))
81140reg .register_pattern ("broadcast_sub" , OpPattern .BROADCAST )
82141reg .register_schedule ("broadcast_sub" , _fschedule_broadcast )
83142
84143# broadcast_mul
85- reg .register_compute ("broadcast_mul" ,
86- lambda _ , x : topi .broadcast_mul (x [0 ], x [1 ]))
144+ reg .register_compute ("broadcast_mul" , _compute_binary (topi .broadcast_mul ))
87145reg .register_pattern ("broadcast_mul" , OpPattern .BROADCAST )
88146reg .register_schedule ("broadcast_mul" , _fschedule_broadcast )
89147
90148# broadcast_div
91- reg .register_compute ("broadcast_div" ,
92- lambda _ , x : topi .broadcast_div (x [0 ], x [1 ]))
149+ reg .register_compute ("broadcast_div" , _compute_binary (topi .broadcast_div ))
93150reg .register_pattern ("broadcast_div" , OpPattern .BROADCAST )
94151reg .register_schedule ("broadcast_div" , _fschedule_broadcast )
152+
153+ # broadcast_to
154+ @reg .register_compute ("broadcast_to" )
155+ def compute_softmax (attrs , inputs , out_info ):
156+ """Compute definition of softmax"""
157+ return topi .broadcast_to (inputs [0 ], shape = out_info [0 ].shape )
158+ reg .register_pattern ("broadcast_to" , OpPattern .BROADCAST )
159+ reg .register_schedule ("broadcast_to" , _fschedule_broadcast )
0 commit comments