@@ -126,6 +126,49 @@ def main(
126126def test_extended_unary_ops ():
127127 example_args = (torch .randn (1 , 3 , 10 , 10 , dtype = torch .float32 ),)
128128
129+ # celu
130+ class Celu1 (Module ):
131+ def __init__ (self ):
132+ super ().__init__ ()
133+ self .celu = torch .nn .CELU ()
134+
135+ def forward (self , input ):
136+ return self .celu (input )
137+
138+ class Celu2 (Module ):
139+ def forward (self , input ):
140+ return torch .nn .functional .celu (input )
141+
142+ # alpha * min(0, exp(x / alpha) - 1) + max(0, x)
143+ @tvm .script .ir_module
144+ class expected_celu :
145+ @R .function
146+ def main (
147+ input_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
148+ ) -> R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )):
149+ with R .dataflow ():
150+ lv : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .exp (input_1 )
151+ lv_div : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .divide (
152+ lv , R .const (1.0 , "float32" )
153+ )
154+ lv_sub : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .subtract (
155+ lv_div , R .const (1.0 , "float32" )
156+ )
157+ lv_min : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .minimum (
158+ R .const (0.0 , "float32" ), lv_sub
159+ )
160+ lv_scaled : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .multiply (
161+ R .const (1.0 , "float32" ), lv_min
162+ )
163+ lv_relu_x : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .nn .relu (input_1 )
164+ lv_celu : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .add (lv_scaled , lv_relu_x )
165+ gv : R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )) = (lv_celu ,)
166+ R .output (gv )
167+ return gv
168+
169+ verify_model (Celu1 (), example_args , {}, expected_celu )
170+ verify_model (Celu2 (), example_args , {}, expected_celu )
171+
129172 # clamp
130173 class Clamp (Module ):
131174 def forward (self , input ):
@@ -226,6 +269,46 @@ def main(
226269 verify_model (Dropout1 (), example_args , {}, expected_dropout )
227270 verify_model (Dropout2 (), example_args , {}, expected_dropout )
228271
272+ # elu
273+ class Elu (Module ):
274+ def __init__ (self ):
275+ super ().__init__ ()
276+ self .elu = torch .nn .ELU ()
277+
278+ def forward (self , input ):
279+ return self .elu (input )
280+
281+ class Elu2 (Module ):
282+ def forward (self , input ):
283+ return torch .nn .functional .elu (input )
284+
285+ @tvm .script .ir_module
286+ class expected_elu :
287+ @R .function
288+ def main (
289+ input_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
290+ ) -> R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )):
291+ # block 0
292+ with R .dataflow ():
293+ lv_exp : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .exp (input_1 )
294+ lv_one_minus_exp : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .subtract (
295+ R .const (1.0 , dtype = "float32" ), lv_exp
296+ )
297+ lv_relu_one_minus_exp : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .nn .relu (
298+ lv_one_minus_exp
299+ )
300+ lv_scaled : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .multiply (
301+ R .const (- 1.0 , dtype = "float32" ), lv_relu_one_minus_exp
302+ )
303+ lv_relu_x : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .nn .relu (input_1 )
304+ lv_elu : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .add (lv_scaled , lv_relu_x )
305+ gv : R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )) = (lv_elu ,)
306+ R .output (gv )
307+ return gv
308+
309+ verify_model (Elu (), example_args , {}, expected_elu )
310+ verify_model (Elu2 (), example_args , {}, expected_elu )
311+
229312 # gelu
230313 class Gelu (Module ):
231314 def __init__ (self ):
@@ -358,6 +441,46 @@ def main(
358441 verify_model (ReLU0 (), example_args , {}, expected_relu )
359442 verify_model (ReLU1 (), example_args , {}, expected_relu )
360443
444+ # selu
445+ class Selu1 (Module ):
446+ def __init__ (self ):
447+ super ().__init__ ()
448+ self .selu = torch .nn .SELU ()
449+
450+ def forward (self , input ):
451+ return self .selu (input )
452+
453+ class Selu2 (Module ):
454+ def forward (self , input ):
455+ return torch .nn .functional .selu (input )
456+
457+ @tvm .script .ir_module
458+ class expected_selu :
459+ @R .function
460+ def main (
461+ input_1 : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )
462+ ) -> R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )):
463+ with R .dataflow ():
464+ lv_relu : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .nn .relu (input_1 )
465+ lv_exp : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .exp (input_1 )
466+ lv_sub : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .subtract (
467+ lv_exp , R .const (1.0 , "float32" )
468+ )
469+ lv_scaled : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .multiply (
470+ R .const (1.6732631921768188 , "float32" ), lv_sub
471+ )
472+ lv_add : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .add (lv_relu , lv_scaled )
473+ lv_selu : R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" ) = R .multiply (
474+ R .const (1.0507010221481323 , "float32" ), lv_add
475+ )
476+ gv : R .Tuple (R .Tensor ((1 , 3 , 10 , 10 ), dtype = "float32" )) = (lv_selu ,)
477+ R .output (gv )
478+
479+ return gv
480+
481+ verify_model (Selu1 (), example_args , {}, expected_selu )
482+ verify_model (Selu2 (), example_args , {}, expected_selu )
483+
361484 # sigmoid
362485 class Sigmoid (Module ):
363486 def __init__ (self ):
0 commit comments