@@ -3974,8 +3974,6 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType:
39743974
39753975
39763976# Do not register hstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918
3977-
3978-
39793977def aten_hstack (tensors : Sequence [TTensor ]) -> TTensor :
39803978 """hstack(Tensor[] tensors) -> Tensor"""
39813979
@@ -7887,14 +7885,14 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrStr
78877885 return op .ConcatFromSequence (tensors , axis = dim , new_axis = 1 )
78887886
78897887
7890- @ torch_op ( "aten:: std" , trace_only = True )
7888+ # std is decomposed by PyTroch
78917889def aten_std (self : TReal , unbiased : bool = True ) -> TReal :
78927890 """std(Tensor self, bool unbiased=True) -> Tensor"""
78937891 var = _aten_var_onnx (self , correction = float (unbiased ), keepdim = False )
78947892 return op .Sqrt (var )
78957893
78967894
7897- @ torch_op ( "aten::std.dim" , trace_only = True )
7895+ # std_dim is decomposed by PyTroch
78987896def aten_std_dim (
78997897 self : TReal ,
79007898 dim : Sequence [int ],
@@ -7907,7 +7905,7 @@ def aten_std_dim(
79077905 return op .Sqrt (var )
79087906
79097907
7910- @ torch_op ( "aten::var.correction" , trace_only = True )
7908+ # std is decomposed by PyTroch
79117909def aten_std_correction (
79127910 self : TReal ,
79137911 # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -7927,7 +7925,7 @@ def aten_std_correction(
79277925 return op .Sqrt (var )
79287926
79297927
7930- @ torch_op ( "aten:: std_mean" , trace_only = True )
7928+ # std_mean is decomposed by PyTroch
79317929def aten_std_mean (self : TReal , unbiased : bool = True ) -> Tuple [TReal , TReal ]:
79327930 """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
79337931
@@ -7937,7 +7935,7 @@ def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
79377935 return op .Sqrt (var ), mean
79387936
79397937
7940- @ torch_op ( "aten:: std_mean.dim" , trace_only = True )
7938+ # std_mean is decomposed by PyTroch
79417939def aten_std_mean_dim (
79427940 self : TReal , dim : Sequence [int ], unbiased : bool = True , keepdim : bool = False
79437941) -> Tuple [TReal , TReal ]:
@@ -7951,7 +7949,7 @@ def aten_std_mean_dim(
79517949 return op .Sqrt (var ), mean
79527950
79537951
7954- @ torch_op ( "aten:: std_mean.correction" , trace_only = True )
7952+ # std_mean is decomposed by PyTroch
79557953def aten_std_mean_correction (
79567954 self : TReal ,
79577955 # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -7973,139 +7971,6 @@ def aten_std_mean_correction(
79737971 return op .Sqrt (var ), mean
79747972
79757973
7976- @torch_op ("aten::stft" , private = True )
7977- def _add_batch_dimension (self : TFloatOrBFloat16 ) -> Tuple [TFloatOrBFloat16 , INT64 ]:
7978- signal_rank = Rank (self )
7979- if signal_rank == 1 :
7980- # Add a batch dimension
7981- self = op .Unsqueeze (self , op .Constant (value_ints = [0 ]))
7982- return op .Identity (self ), signal_rank
7983-
7984-
7985- @torch_op ("aten::stft" , private = True )
7986- def _center_window_around_zeros_if_needed (
7987- window : TFloatOrBFloat16 , n_fft : int
7988- ) -> TFloatOrBFloat16 :
7989- # first dimension
7990- n_win = op .Shape (window , start = 0 , end = 1 )
7991- # Center window around zeros if needed (required by ONNX's STFT)
7992- if n_win < n_fft :
7993- left = (n_fft - n_win ) / 2
7994-
7995- right = n_fft - left - n_win
7996- left = op .Reshape (left , op .Constant (value_ints = [1 ]))
7997- right = op .Reshape (right , op .Constant (value_ints = [1 ]))
7998-
7999- left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8000- right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8001- right_win = op .CastLike (right_win , window )
8002- left_win = op .CastLike (left_win , window )
8003- window = op .Concat (left_win , window , right_win , axis = 0 )
8004- return window
8005-
8006-
8007- @torch_op ("aten::stft" , private = True )
8008- def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloatOrBFloat16 :
8009- left = (n_fft - win_length ) / 2
8010-
8011- right = n_fft - left - win_length
8012- left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8013- right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8014- win_length = op .Reshape (win_length , op .Constant (value_ints = [1 ]))
8015-
8016- left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8017- right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8018- window_list = op .Expand (op .Constant (value_ints = [1 ]), win_length )
8019- return op .Concat (left_win , window_list , right_win , axis = 0 )
8020-
8021-
8022- @torch_op ("aten::stft" , private = True )
8023- def _create_window_from_n_fft (n_fft : int ) -> TFloatOrBFloat16 :
8024- n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
8025- window = op .Expand (op .Constant (value_ints = [1 ]), n_fft_tensor )
8026- return window
8027-
8028-
8029- @torch_op ("aten::stft" , private = True )
8030- def _normalize_fft_result (
8031- signal : TFloatOrBFloat16 , result : TFloatOrBFloat16 , n_fft : int
8032- ) -> TFloatOrBFloat16 :
8033- n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
8034- sqrt_nfft = op .Sqrt (op .CastLike (n_fft_tensor , signal ))
8035- result = result / sqrt_nfft
8036- return result
8037-
8038-
8039- @torch_op ("aten::stft" , private = True )
8040- def _aten_stft_onnx (
8041- signal : TFloatOrBFloat16 ,
8042- frame_step_const : INT64 ,
8043- window : Union [TFloatOrBFloat16 , INT64 ],
8044- frame_length_const : INT64 ,
8045- signal_rank : INT64 ,
8046- onesided : int ,
8047- ) -> TFloatOrBFloat16 :
8048- window = op .CastLike (window , signal )
8049- result = op .STFT (signal , frame_step_const , window , frame_length_const , onesided = onesided )
8050- result = op .Transpose (result , perm = [0 , 2 , 1 , 3 ])
8051- # Remove batch dimension, if needed
8052- if signal_rank == 1 :
8053- result = op .Squeeze (result , op .Constant (value_ints = [0 ]))
8054- return result
8055-
8056-
8057- @torch_op ("aten::stft" , trace_only = True )
8058- def aten_stft (
8059- self : TFloatOrBFloat16 ,
8060- n_fft : int ,
8061- hop_length : Optional [int ] = None ,
8062- win_length : Optional [int ] = None ,
8063- window : Optional [TFloatOrBFloat16 ] = None ,
8064- normalized : bool = False ,
8065- onesided : Optional [bool ] = None ,
8066- return_complex : Optional [bool ] = None ,
8067- ) -> TFloatOrBFloat16 :
8068- """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
8069-
8070- # NOTE: regarless of the value of return_complex, we always return a real representation.
8071- del return_complex
8072-
8073- # Get STFT sizes
8074- if hop_length is None :
8075- # core dump
8076- # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
8077- hop_length = n_fft // 4
8078- frame_step_const = op .Reshape (hop_length , op .Constant (value_ints = [1 ]))
8079- frame_length_const = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
8080-
8081- # Pre-process input if needed
8082- self , signal_rank = _add_batch_dimension (self )
8083-
8084- # Get window and make sure it's the same size as `win_length` or `n_fft`
8085- if window is not None and window .shape [0 ] is not None :
8086- window = _center_window_around_zeros_if_needed (window , n_fft )
8087- elif window is None :
8088- if win_length is not None :
8089- window = _create_window_from_win_length (win_length , n_fft )
8090- else :
8091- window = _create_window_from_n_fft (n_fft )
8092-
8093- if onesided is None or onesided :
8094- onesided = 1
8095- else :
8096- onesided = 0
8097- # remove batch dimension included
8098- result = _aten_stft_onnx (
8099- self , frame_step_const , window , frame_length_const , signal_rank , onesided
8100- )
8101-
8102- # Normalize, if needed
8103- if normalized :
8104- result = _normalize_fft_result (self , result , n_fft )
8105-
8106- return result
8107-
8108-
81097974@torch_op (
81107975 (
81117976 "aten::sub.Tensor" ,
@@ -8738,7 +8603,7 @@ def aten_vander(
87388603 raise NotImplementedError ()
87398604
87408605
8741- @ torch_op ( "aten:: var" , trace_only = True )
8606+ # var is decomposed by PyTroch
87428607def aten_var (self : TReal , unbiased : Optional [bool ] = True ) -> TReal :
87438608 """var(Tensor self, bool unbiased=True) -> Tensor"""
87448609
@@ -8747,7 +8612,7 @@ def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
87478612 return _aten_var_onnx (self , correction = float (unbiased ), keepdim = False )
87488613
87498614
8750- @ torch_op ( "aten:: var.dim" , trace_only = True )
8615+ # var is decomposed by PyTroch
87518616def aten_var_dim (
87528617 self : TReal ,
87538618 dim : Sequence [int ],
@@ -8759,7 +8624,7 @@ def aten_var_dim(
87598624 return _aten_var_dim_onnx (self , dims = dim , correction = float (unbiased ), keepdim = keepdim )
87608625
87618626
8762- @ torch_op ( "aten:: var.correction" , trace_only = True )
8627+ # var is decomposed by PyTroch
87638628def aten_var_correction (
87648629 self : TReal ,
87658630 # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8779,7 +8644,7 @@ def aten_var_correction(
87798644 return var
87808645
87818646
8782- @ torch_op ( "aten:: var" , private = True , traceable = True )
8647+ # var is decomposed by PyTroch
87838648def _aten_var_onnx (self : TReal , correction : float , keepdim : bool = False ) -> TReal :
87848649 mean = op .ReduceMean (self , keepdims = keepdim )
87858650 sub_mean = op .Sub (self , mean )
@@ -8796,7 +8661,7 @@ def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TRe
87968661 return var
87978662
87988663
8799- @ torch_op ( "aten:: var.dim" , private = True , traceable = True )
8664+ # var is decomposed by PyTroch
88008665def _aten_var_dim_onnx (
88018666 self : TReal , dims : Sequence [int ], correction : float , keepdim : bool = False
88028667) -> TReal :
@@ -8817,7 +8682,7 @@ def _aten_var_dim_onnx(
88178682 return var
88188683
88198684
8820- @ torch_op ( "aten:: var_mean" , trace_only = True )
8685+ # var_mean is decomposed by PyTroch
88218686def aten_var_mean (self : TReal , unbiased : bool = True ) -> Tuple [TReal , TReal ]:
88228687 """var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
88238688
@@ -8826,7 +8691,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
88268691 return _aten_var_mean_onnx (self , correction = float (unbiased ), keepdim = False )
88278692
88288693
8829- @ torch_op ( "aten:: var_mean.dim" , trace_only = True )
8694+ # var_mean is decomposed by PyTroch
88308695def aten_var_mean_dim (
88318696 self : TReal , dim : Sequence [int ], unbiased : bool = True , keepdim : bool = False
88328697) -> Tuple [TReal , TReal ]:
@@ -8837,7 +8702,7 @@ def aten_var_mean_dim(
88378702 return _aten_var_mean_dim_onnx (self , dims = dim , correction = float (unbiased ), keepdim = keepdim )
88388703
88398704
8840- @ torch_op ( "aten:: var_mean.correction" , trace_only = True )
8705+ # var_mean is decomposed by PyTroch
88418706def aten_var_mean_correction (
88428707 self : TReal ,
88438708 # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8859,7 +8724,7 @@ def aten_var_mean_correction(
88598724 return var , mean
88608725
88618726
8862- @ torch_op ( "aten:: var_mean" , private = True )
8727+ # var_mean is decomposed by PyTroch
88638728def _aten_var_mean_onnx (
88648729 self : TReal , correction : float = 1.0 , keepdim : bool = False
88658730) -> Tuple [TReal , TReal ]:
@@ -8879,7 +8744,7 @@ def _aten_var_mean_onnx(
88798744 return var , mean
88808745
88818746
8882- @ torch_op ( "aten:: var_mean.dim" , private = True )
8747+ # var_mean is decomposed by PyTroch
88838748def _aten_var_mean_dim_onnx (
88848749 self : TReal , dims : Sequence [int ], correction : float , keepdim : bool = False
88858750) -> Tuple [TReal , TReal ]:
@@ -8977,8 +8842,6 @@ def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
89778842
89788843
89798844# Do not register vstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918
8980-
8981-
89828845def aten_vstack (tensors : Sequence [TTensor ]) -> TTensor :
89838846 """vstack(Tensor[] tensors) -> Tensor"""
89848847
@@ -8998,6 +8861,7 @@ def reshape_to_2d(tensor):
89988861
89998862@torch_op (
90008863 (
8864+ "aten::where" ,
90018865 "aten::where.Scalar" ,
90028866 "aten::where.ScalarSelf" ,
90038867 "aten::where.ScalarOther" ,
0 commit comments