@@ -400,6 +400,30 @@ def convert_conv2d_transpose(g, op, block):
400400 g .add_node (op .output ("Output" )[0 ], out )
401401
402402
403+ def convert_dist (g , op , block ):
404+ """Operator converter for dist."""
405+
406+ x = g .get_node (op .input ("X" )[0 ])
407+ y = g .get_node (op .input ("Y" )[0 ])
408+ z = _op .abs (_op .subtract (x , y ))
409+ dtype = infer_type (x ).checked_type .dtype
410+ p = op .attr ("p" )
411+ if p == np .inf :
412+ out = _op .reduce .max (_op .abs (z ))
413+ elif p == np .NINF :
414+ out = _op .reduce .min (_op .abs (z ))
415+ elif p == 0.0 :
416+ out = _op .reduce .sum (_op .sign (_op .abs (z )))
417+ else :
418+ inv_p = _expr .const (1.0 / p , dtype = dtype )
419+ p = _expr .const (p , dtype = dtype )
420+ power_z = _op .power (z , p )
421+ sum_pow = _op .reduce .sum (power_z )
422+ out = _op .power (sum_pow , inv_p )
423+ out = _op .full (out , shape = (1 ))
424+ g .add_node (op .output ("Out" )[0 ], out )
425+
426+
403427def convert_cumsum (g , op , block ):
404428 """Operator converter for cumsum."""
405429
@@ -475,6 +499,39 @@ def convert_elementwise_op(g, op, block):
475499 g .add_node (op .output ("Out" )[0 ], out )
476500
477501
502+ def convert_linspace (g , op , block ):
503+ """Operator converter for linspace."""
504+
505+ start = g .get_node (op .input ("Start" )[0 ])
506+ stop = g .get_node (op .input ("Stop" )[0 ])
507+ num = g .get_node (op .input ("Num" )[0 ])
508+ dtype = _convert_dtype_value (op .attr ("dtype" ))
509+
510+ start = _op .cast (start , dtype )
511+ stop = _op .cast (stop , dtype )
512+ num = _op .cast (num , dtype )
513+
514+ if dtype in ["int32" , "float32" ]:
515+ tmp_dtype = "float32"
516+ else :
517+ tmp_dtype = "float64"
518+ start = _op .cast (start , tmp_dtype )
519+ stop = _op .cast (stop , tmp_dtype )
520+ num = _op .cast (num , tmp_dtype )
521+ const_one = _expr .const (1 , tmp_dtype )
522+ const_zero = _expr .const (0 , tmp_dtype )
523+ seg_num = _op .where (num > const_one , num - const_one , num - const_zero )
524+ seg_len = _op .subtract (stop , start )
525+ step_len = _op .divide (seg_len , seg_num )
526+ step_cnt = _op .argwhere (_op .ones (num , dtype = tmp_dtype ))
527+ step_cnt = _op .cast (step_cnt , dtype = tmp_dtype )
528+ out = _op .multiply (step_len , step_cnt )
529+ out = _op .add (start , out )
530+ out = _op .squeeze (out , axis = [1 ])
531+ out = _op .cast (out , dtype )
532+ g .add_node (op .output ("Out" )[0 ], out )
533+
534+
478535def convert_elu (g , op , block ):
479536 """Operator converter for elu."""
480537
@@ -514,6 +571,27 @@ def convert_expand_as(g, op, block):
514571 g .add_node (op .output ("Out" )[0 ], out )
515572
516573
574+ def convert_eye (g , op , block ):
575+ """Operator converter for eye."""
576+
577+ num_rows = op .attr ("num_rows" )
578+ num_columns = op .attr ("num_columns" )
579+ one_nums = min (num_rows , num_columns )
580+ dtype = op .attr ("dtype" )
581+ dtype = _convert_dtype_value (dtype )
582+
583+ zeros = _op .zeros ((num_rows , num_columns ), dtype )
584+ if one_nums == 0 :
585+ out = zeros
586+ else :
587+ ones = _op .ones (one_nums , dtype )
588+ indices = _op .arange (
589+ _expr .const (0 , dtype = "int32" ), _expr .const (one_nums , dtype = "int32" ), dtype = "int32"
590+ )
591+ out = _op .scatter_nd (zeros , _op .stack ([indices , indices ], axis = 0 ), ones , "update" )
592+ g .add_node (op .output ("Out" )[0 ], out )
593+
594+
517595def convert_feed (g , op , block ):
518596 """Converter for model input node."""
519597
@@ -830,6 +908,16 @@ def get_interpolate_mode(op):
830908 g .add_node (op .output ("Out" )[0 ], out )
831909
832910
911+ def convert_index_select (g , op , block ):
912+ """Operator converter for index_select."""
913+
914+ x = g .get_node (op .input ("X" )[0 ])
915+ index = g .get_node (op .input ("Index" )[0 ])
916+ axis = op .attr ("dim" )
917+ out = _op .transform .take (x , index , axis , mode = "wrap" )
918+ g .add_node (op .output ("Out" )[0 ], out )
919+
920+
833921def convert_instance_norm (g , op , block ):
834922 """Operator converter for instance_norm."""
835923
@@ -2072,13 +2160,27 @@ def convert_swish(g, op, block):
20722160
20732161
20742162def convert_take_along_axis (g , op , block ):
2163+ """Operator converter for take_along_axis."""
2164+
20752165 x = g .get_node (op .input ("Input" )[0 ])
20762166 idx = g .get_node (op .input ("Index" )[0 ])
20772167 axis = op .attr ("Axis" )
20782168 out = _op .gather (x , axis , idx )
20792169 g .add_node (op .output ("Result" )[0 ], out )
20802170
20812171
2172+ def convert_thresholded_relu (g , op , block ):
2173+ """Operator converter for thresholded_relu."""
2174+
2175+ x = g .get_node (op .input ("X" )[0 ])
2176+ dtype = infer_type (x ).checked_type .dtype
2177+ threshold = op .attr ("threshold" )
2178+ threshold = _expr .const (threshold , dtype )
2179+ zero = _expr .const (0 , dtype = dtype )
2180+ out = tvm .relay .where (x > threshold , x , zero )
2181+ g .add_node (op .output ("Out" )[0 ], out )
2182+
2183+
20822184def convert_tile (g , op , block ):
20832185 """Operator converter for tile."""
20842186
@@ -2220,6 +2322,7 @@ def convert_where_index(g, op, block):
22202322 "cumsum" : convert_cumsum ,
22212323 "depthwise_conv2d" : convert_conv2d ,
22222324 "depthwise_conv2d_transpose" : convert_conv2d_transpose ,
2325+ "dist" : convert_dist ,
22232326 "dot" : convert_dot ,
22242327 "dropout" : convert_dropout ,
22252328 "elementwise_add" : convert_elementwise_op ,
@@ -2238,6 +2341,7 @@ def convert_where_index(g, op, block):
22382341 "exp" : convert_unary_op ,
22392342 "expand_v2" : convert_expand ,
22402343 "expand_as_v2" : convert_expand_as ,
2344+ "eye" : convert_eye ,
22412345 "feed" : convert_feed ,
22422346 "fill_any_like" : convert_fill_any_like ,
22432347 "fill_constant" : convert_fill_constant ,
@@ -2254,6 +2358,7 @@ def convert_where_index(g, op, block):
22542358 "hard_shrink" : convert_hard_shrink ,
22552359 "hard_sigmoid" : convert_hard_sigmoid ,
22562360 "hard_swish" : convert_hard_swish ,
2361+ "index_select" : convert_index_select ,
22572362 "instance_norm" : convert_instance_norm ,
22582363 "isfinite_v2" : convert_unary_op ,
22592364 "isinf_v2" : convert_unary_op ,
@@ -2262,6 +2367,7 @@ def convert_where_index(g, op, block):
22622367 "leaky_relu" : convert_leaky_relu ,
22632368 "less_equal" : convert_elementwise_op ,
22642369 "less_than" : convert_elementwise_op ,
2370+ "linspace" : convert_linspace ,
22652371 "log" : convert_unary_op ,
22662372 "log2" : convert_unary_op ,
22672373 "log10" : convert_unary_op ,
@@ -2333,6 +2439,7 @@ def convert_where_index(g, op, block):
23332439 "tan" : convert_unary_op ,
23342440 "tanh" : convert_unary_op ,
23352441 "top_k" : convert_topk ,
2442+ "thresholded_relu" : convert_thresholded_relu ,
23362443 "tile" : convert_tile ,
23372444 "top_k_v2" : convert_topk ,
23382445 "transpose2" : convert_transpose ,
0 commit comments