File tree Expand file tree Collapse file tree 2 files changed +10
-12
lines changed
onnxscript/function_libs/torch_lib/ops
tests/function_libs/torch_lib Expand file tree Collapse file tree 2 files changed +10
-12
lines changed Original file line number Diff line number Diff line change 6161Rank = common_ops .Rank
6262
6363
64- @torch_op ("aten::_local_scalar_dense" )
65- def aten__local_scalar_dense (self : Union [ FLOAT16 , FLOAT , DOUBLE , BFLOAT16 ] ) -> FLOAT :
64+ @torch_op ("aten::_local_scalar_dense" , trace_only = True )
65+ def aten__local_scalar_dense (self : TensorType ) -> TensorType :
6666 """_local_scalar_dense(Tensor self) -> Scalar"""
6767
6868 # Return the first element in tensor as a scalar.
69- return op .Cast (op .Gather (op .Reshape (self , [- 1 ]), 0 ), to = FLOAT .dtype )
70-
71-
72- @torch_op ("aten::_local_scalar_dense" )
73- def aten__local_scalar_dense_int (self : IntType ) -> INT64 :
74- """_local_scalar_dense(Tensor self) -> Scalar"""
75-
76- # Return the first element in tensor as a scalar.
77- return op .Cast (op .Gather (op .Reshape (self , [- 1 ]), 0 ), to = INT64 .dtype )
69+ if self .dtype .is_floating_point ():
70+ dtype = ir .DataType .FLOAT
71+ elif self .dtype == ir .DataType .BOOL :
72+ dtype = ir .DataType .BOOL
73+ else :
74+ dtype = ir .DataType .INT64
75+ return op .Cast (op .Gather (op .Reshape (self , [- 1 ]), 0 ), to = dtype )
7876
7977
8078@torch_op ("aten::_log_softmax" , trace_only = True )
Original file line number Diff line number Diff line change @@ -2308,7 +2308,7 @@ def __init__(self):
23082308 opinfo_core .OpInfo (
23092309 "ops.aten._local_scalar_dense" ,
23102310 aten_name = "_local_scalar_dense" ,
2311- dtypes = common_dtype .all_types ( ),
2311+ dtypes = common_dtype .all_types_and ( torch . bool ),
23122312 sample_inputs_func = sample_inputs__local_scalar_dense ,
23132313 supports_out = False ,
23142314 ),
You can’t perform that action at this time.
0 commit comments