Commit 1eef633
[torchlib] Unregister stft, var, var_mean, std, std_mean (#1867)
Following pytorch/pytorch#136153, we remove the stft, var, var_mean, std, and std_mean ops. These registrations were never used, even before this change, because the exporter always decomposed the ops before they reached torchlib.
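A quick way to sanity-check that claim (a sketch, not part of this commit; assumes a recent PyTorch where torch._decomp.core_aten_decompositions is available, and the overload names are illustrative):

# Sketch: check whether these ATen overloads appear in PyTorch's default
# decomposition table. If they do, the exporter rewrites them before
# torchlib is ever consulted, so the registrations below were dead code.
import torch
from torch._decomp import core_aten_decompositions

table = core_aten_decompositions()
for aten_op in (
    torch.ops.aten.stft.default,
    torch.ops.aten.var.correction,
    torch.ops.aten.var_mean.correction,
    torch.ops.aten.std.correction,
    torch.ops.aten.std_mean.correction,
):
    # Prints each overload and whether a decomposition is registered for it.
    print(aten_op, aten_op in table)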
1 parent 377869a commit 1eef633

File tree

2 files changed: +17 -274 lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 17 additions & 153 deletions
@@ -3974,8 +3974,6 @@ def aten_hspmm(mat1: TensorType, mat2: TensorType) -> TensorType:
 
 
 # Do not register hstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918
-
-
 def aten_hstack(tensors: Sequence[TTensor]) -> TTensor:
     """hstack(Tensor[] tensors) -> Tensor"""
 
@@ -7887,14 +7885,14 @@ def aten_stack(tensors: Sequence[TTensorOrString], dim: int = 0) -> TTensorOrString:
     return op.ConcatFromSequence(tensors, axis=dim, new_axis=1)
 
 
-@torch_op("aten::std", trace_only=True)
+# std is decomposed by PyTorch
 def aten_std(self: TReal, unbiased: bool = True) -> TReal:
     """std(Tensor self, bool unbiased=True) -> Tensor"""
     var = _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
     return op.Sqrt(var)
 
 
-@torch_op("aten::std.dim", trace_only=True)
+# std_dim is decomposed by PyTorch
 def aten_std_dim(
     self: TReal,
     dim: Sequence[int],
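For context, the std wrappers being unregistered here are thin layers over variance: std == sqrt(var), with unbiased=True mapping to correction=1.0. A minimal PyTorch check of that identity (a sketch, assuming torch >= 2.0 for the correction keyword):

# Sketch: std is the square root of var; unbiased=True <=> correction=1.0.
import torch

x = torch.randn(8, 3)
torch.testing.assert_close(x.std(correction=1), x.var(correction=1).sqrt())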
@@ -7907,7 +7905,7 @@ def aten_std_dim(
     return op.Sqrt(var)
 
 
-@torch_op("aten::var.correction", trace_only=True)
+# std is decomposed by PyTorch
 def aten_std_correction(
     self: TReal,
     # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -7927,7 +7925,7 @@ def aten_std_correction(
     return op.Sqrt(var)
 
 
-@torch_op("aten::std_mean", trace_only=True)
+# std_mean is decomposed by PyTorch
 def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
     """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
 
@@ -7937,7 +7935,7 @@ def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
     return op.Sqrt(var), mean
 
 
-@torch_op("aten::std_mean.dim", trace_only=True)
+# std_mean is decomposed by PyTorch
 def aten_std_mean_dim(
     self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
 ) -> Tuple[TReal, TReal]:
@@ -7951,7 +7949,7 @@ def aten_std_mean_dim(
     return op.Sqrt(var), mean
 
 
-@torch_op("aten::std_mean.correction", trace_only=True)
+# std_mean is decomposed by PyTorch
 def aten_std_mean_correction(
     self: TReal,
     # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -7973,139 +7971,6 @@ def aten_std_mean_correction(
     return op.Sqrt(var), mean
 
 
-@torch_op("aten::stft", private=True)
-def _add_batch_dimension(self: TFloatOrBFloat16) -> Tuple[TFloatOrBFloat16, INT64]:
-    signal_rank = Rank(self)
-    if signal_rank == 1:
-        # Add a batch dimension
-        self = op.Unsqueeze(self, op.Constant(value_ints=[0]))
-    return op.Identity(self), signal_rank
-
-
-@torch_op("aten::stft", private=True)
-def _center_window_around_zeros_if_needed(
-    window: TFloatOrBFloat16, n_fft: int
-) -> TFloatOrBFloat16:
-    # first dimension
-    n_win = op.Shape(window, start=0, end=1)
-    # Center window around zeros if needed (required by ONNX's STFT)
-    if n_win < n_fft:
-        left = (n_fft - n_win) / 2
-
-        right = n_fft - left - n_win
-        left = op.Reshape(left, op.Constant(value_ints=[1]))
-        right = op.Reshape(right, op.Constant(value_ints=[1]))
-
-        left_win = op.Expand(op.Constant(value_ints=[0]), left)
-        right_win = op.Expand(op.Constant(value_ints=[0]), right)
-        right_win = op.CastLike(right_win, window)
-        left_win = op.CastLike(left_win, window)
-        window = op.Concat(left_win, window, right_win, axis=0)
-    return window
-
-
-@torch_op("aten::stft", private=True)
-def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloatOrBFloat16:
-    left = (n_fft - win_length) / 2
-
-    right = n_fft - left - win_length
-    left = op.Reshape(left, op.Constant(value_ints=[1]))
-    right = op.Reshape(right, op.Constant(value_ints=[1]))
-    win_length = op.Reshape(win_length, op.Constant(value_ints=[1]))
-
-    left_win = op.Expand(op.Constant(value_ints=[0]), left)
-    right_win = op.Expand(op.Constant(value_ints=[0]), right)
-    window_list = op.Expand(op.Constant(value_ints=[1]), win_length)
-    return op.Concat(left_win, window_list, right_win, axis=0)
-
-
-@torch_op("aten::stft", private=True)
-def _create_window_from_n_fft(n_fft: int) -> TFloatOrBFloat16:
-    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
-    window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor)
-    return window
-
-
-@torch_op("aten::stft", private=True)
-def _normalize_fft_result(
-    signal: TFloatOrBFloat16, result: TFloatOrBFloat16, n_fft: int
-) -> TFloatOrBFloat16:
-    n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
-    sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
-    result = result / sqrt_nfft
-    return result
-
-
-@torch_op("aten::stft", private=True)
-def _aten_stft_onnx(
-    signal: TFloatOrBFloat16,
-    frame_step_const: INT64,
-    window: Union[TFloatOrBFloat16, INT64],
-    frame_length_const: INT64,
-    signal_rank: INT64,
-    onesided: int,
-) -> TFloatOrBFloat16:
-    window = op.CastLike(window, signal)
-    result = op.STFT(signal, frame_step_const, window, frame_length_const, onesided=onesided)
-    result = op.Transpose(result, perm=[0, 2, 1, 3])
-    # Remove batch dimension, if needed
-    if signal_rank == 1:
-        result = op.Squeeze(result, op.Constant(value_ints=[0]))
-    return result
-
-
-@torch_op("aten::stft", trace_only=True)
-def aten_stft(
-    self: TFloatOrBFloat16,
-    n_fft: int,
-    hop_length: Optional[int] = None,
-    win_length: Optional[int] = None,
-    window: Optional[TFloatOrBFloat16] = None,
-    normalized: bool = False,
-    onesided: Optional[bool] = None,
-    return_complex: Optional[bool] = None,
-) -> TFloatOrBFloat16:
-    """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor"""
-
-    # NOTE: regarless of the value of return_complex, we always return a real representation.
-    del return_complex
-
-    # Get STFT sizes
-    if hop_length is None:
-        # core dump
-        # hop_leagth = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4]))
-        hop_length = n_fft // 4
-    frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1]))
-    frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
-
-    # Pre-process input if needed
-    self, signal_rank = _add_batch_dimension(self)
-
-    # Get window and make sure it's the same size as `win_length` or `n_fft`
-    if window is not None and window.shape[0] is not None:
-        window = _center_window_around_zeros_if_needed(window, n_fft)
-    elif window is None:
-        if win_length is not None:
-            window = _create_window_from_win_length(win_length, n_fft)
-        else:
-            window = _create_window_from_n_fft(n_fft)
-
-    if onesided is None or onesided:
-        onesided = 1
-    else:
-        onesided = 0
-    # remove batch dimension included
-    result = _aten_stft_onnx(
-        self, frame_step_const, window, frame_length_const, signal_rank, onesided
-    )
-
-    # Normalize, if needed
-    if normalized:
-        result = _normalize_fft_result(self, result, n_fft)
-
-    return result
-
-
 @torch_op(
     (
         "aten::sub.Tensor",
@@ -8738,7 +8603,7 @@ def aten_vander(
     raise NotImplementedError()
 
 
-@torch_op("aten::var", trace_only=True)
+# var is decomposed by PyTorch
 def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
     """var(Tensor self, bool unbiased=True) -> Tensor"""
 
@@ -8747,7 +8612,7 @@ def aten_var(self: TReal, unbiased: Optional[bool] = True) -> TReal:
     return _aten_var_onnx(self, correction=float(unbiased), keepdim=False)
 
 
-@torch_op("aten::var.dim", trace_only=True)
+# var is decomposed by PyTorch
 def aten_var_dim(
     self: TReal,
     dim: Sequence[int],
@@ -8759,7 +8624,7 @@ def aten_var_dim(
     return _aten_var_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)
 
 
-@torch_op("aten::var.correction", trace_only=True)
+# var is decomposed by PyTorch
 def aten_var_correction(
     self: TReal,
     # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8779,7 +8644,7 @@ def aten_var_correction(
     return var
 
 
-@torch_op("aten::var", private=True, traceable=True)
+# var is decomposed by PyTorch
 def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
     mean = op.ReduceMean(self, keepdims=keepdim)
     sub_mean = op.Sub(self, mean)
@@ -8796,7 +8661,7 @@ def _aten_var_onnx(self: TReal, correction: float, keepdim: bool = False) -> TReal:
     return var
 
 
-@torch_op("aten::var.dim", private=True, traceable=True)
+# var is decomposed by PyTorch
 def _aten_var_dim_onnx(
     self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
 ) -> TReal:
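The _aten_var_onnx helpers being unregistered implement correction-based variance: mean, subtract, square, reduce, then a rescale for the correction term. The same computation in plain PyTorch (a sketch for reference; the rescale factor follows the standard N / (N - correction) form):

# Sketch: variance with a correction term, matching the decomposition
# pattern above (ReduceMean -> Sub -> Mul -> ReduceMean -> rescale).
import torch

x = torch.randn(4, 5)
n = x.numel()
correction = 1.0  # unbiased=True
biased = ((x - x.mean()) ** 2).mean()
var_manual = biased * n / (n - correction)
torch.testing.assert_close(var_manual, x.var(correction=1))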
@@ -8817,7 +8682,7 @@ def _aten_var_dim_onnx(
     return var
 
 
-@torch_op("aten::var_mean", trace_only=True)
+# var_mean is decomposed by PyTorch
 def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
     """var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
 
@@ -8826,7 +8691,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
     return _aten_var_mean_onnx(self, correction=float(unbiased), keepdim=False)
 
 
-@torch_op("aten::var_mean.dim", trace_only=True)
+# var_mean is decomposed by PyTorch
 def aten_var_mean_dim(
     self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False
 ) -> Tuple[TReal, TReal]:
@@ -8837,7 +8702,7 @@ def aten_var_mean_dim(
     return _aten_var_mean_dim_onnx(self, dims=dim, correction=float(unbiased), keepdim=keepdim)
 
 
-@torch_op("aten::var_mean.correction", trace_only=True)
+# var_mean is decomposed by PyTorch
 def aten_var_mean_correction(
     self: TReal,
     # FIXME(justinchuby): Make dim Optional[Sequence[int]]
@@ -8859,7 +8724,7 @@ def aten_var_mean_correction(
     return var, mean
 
 
-@torch_op("aten::var_mean", private=True)
+# var_mean is decomposed by PyTorch
 def _aten_var_mean_onnx(
     self: TReal, correction: float = 1.0, keepdim: bool = False
 ) -> Tuple[TReal, TReal]:
@@ -8879,7 +8744,7 @@ def _aten_var_mean_onnx(
     return var, mean
 
 
-@torch_op("aten::var_mean.dim", private=True)
+# var_mean is decomposed by PyTorch
 def _aten_var_mean_dim_onnx(
     self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False
 ) -> Tuple[TReal, TReal]:
@@ -8977,8 +8842,6 @@ def aten_view_copy(self: TTensor, size: IntType) -> TTensor:
 
 
 # Do not register vstack - decomposed by PyTorch: https://github.com/pytorch/pytorch/blob/bedf96d7ffe74b34bcfe52c7ae1ae05f40d6c8ee/torch/_refs/__init__.py#L3918
-
-
 def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
     """vstack(Tensor[] tensors) -> Tensor"""
 
@@ -8998,6 +8861,7 @@ def reshape_to_2d(tensor):
 
 
 @torch_op(
     (
+        "aten::where",
         "aten::where.Scalar",
        "aten::where.ScalarSelf",
        "aten::where.ScalarOther",
