Skip to content

Commit f8ee736

Browse files
[torchlib] Implement missing operators (set1) (#1706)
Implement missing operators uncovered by torch.onnx tests as per #1644 - [x] Implement <OpOverload(op='aten.fmod', overload='Scalar')> - [x] Implement <OpOverload(op='aten.fmod', overload='Tensor')> - [x] Implement <OpOverload(op='aten.glu', overload='default')> @shubhambhokare1 - [x] Implement <OpOverload(op='aten.le', overload='Scalar')> - [x] Implement <OpOverload(op='aten.lerp', overload='Scalar')> - [x] Implement <OpOverload(op='aten.linalg_cross', overload='default')> - [x] Implement <OpOverload(op='aten.mv', overload='default')> - [x] Implement <OpOverload(op='aten.pow', overload='Scalar')> - [x] Implement <OpOverload(op='aten.remainder', overload='Scalar')> - [x] Implement <OpOverload(op='aten.remainder', overload='Tensor')> - [x] Implement <OpOverload(op='aten.silu', overload='default')> - [x] Implement <OpOverload(op='aten.unsafe_split', overload='Tensor')> [**NOT PART OF THIS PR**] Requires adding implementation functions in torchlib eventually (not currently high in priority) - [ ] Implement `<OpOverload(op='aten.__rshift__', overload='Scalar')>` - [ ] Implement <OpOverload(op='aten._linalg_det', overload='default')> - [ ] Implement <OpOverload(op='aten._linalg_slogdet', overload='default')> - [ ] Implement <OpOverload(op='aten._prelu_kernel', overload='default')> - [ ] Implement <OpOverload(op='aten.add', overload='Scalar')> - [ ] Implement <OpOverload(op='aten.add', overload='Tensor')> - [ ] Implement <OpOverload(op='aten.affine_grid_generator', overload='default')> - [ ] Implement <OpOverload(op='aten.aminmax', overload='default')> - [ ] Implement <OpOverload(op='aten.binary_cross_entropy_with_logits', overload='default')> - [ ] Implement <OpOverload(op='aten.bitwise_and', overload='Tensor')> - [ ] Implement <OpOverload(op='aten.bucketize', overload='Tensor')> - [ ] Implement <OpOverload(op='aten.conv_tbc', overload='default')> - [ ] Implement <OpOverload(op='aten.fake_quantize_per_tensor_affine_cachemask', overload='default')> - [ ] Implement 
<OpOverload(op='aten.fill', overload='Scalar')> - [ ] Implement <OpOverload(op='aten.index_add', overload='default')> - [ ] Implement <OpOverload(op='aten.index_copy', overload='default')> - [ ] Implement <OpOverload(op='aten.index_fill', overload='int_Scalar')> - [ ] Implement <OpOverload(op='aten.index_put', overload='default')> - [ ] Implement <OpOverload(op='aten.masked_scatter', overload='default')> - [ ] Implement <OpOverload(op='aten.masked_select', overload='default')> - [ ] Implement <OpOverload(op='aten.prod', overload='dim_int')> - [ ] Implement <OpOverload(op='aten.rsub', overload='Tensor')> - [ ] Implement <OpOverload(op='aten.scatter', overload='src')> - [ ] Implement <OpOverload(op='aten.scatter', overload='value')> - [ ] Implement <OpOverload(op='aten.sort', overload='default')> - [ ] Implement <OpOverload(op='aten.std', overload='correction')> - [ ] Implement <OpOverload(op='aten.std_mean', overload='correction')> - [ ] Implement <OpOverload(op='aten.sym_size', overload='int')> - [ ] Implement <OpOverload(op='aten.take', overload='default')> - Implement <OpOverload(op='aten._adaptive_avg_pool2d', overload='default')> - Implement <OpOverload(op='aten._cdist_forward', overload='default')> - Implement <OpOverload(op='aten._convolution', overload='default')> - Implement <OpOverload(op='aten._fake_quantize_per_tensor_affine_cachemask_tensor_qparams', overload='default')> - Implement <OpOverload(op='aten.grid_sampler_3d', overload='default')> - Implement <OpOverload(op='aten.hann_window', overload='default')> - Implement <OpOverload(op='aten.im2col', overload='default')> - Implement <OpOverload(op='aten.repeat_interleave', overload='Tensor')> - Implement <OpOverload(op='torchvision.nms', overload='default')> - Implement <OpOverload(op='torchvision.roi_align', overload='default')> - Implement <OpOverload(op='torchvision.roi_pool', overload='default')> - [ ] Implement <OpOverload(op='aten.nan_to_num', overload='default')> - [ ] Implement 
<OpOverload(op='aten.nll_loss2d_forward', overload='default')> - [ ] Implement <OpOverload(op='aten.nll_loss_forward', overload='default')> - [ ] Implement <OpOverload(op='aten.norm', overload='ScalarOpt_dim_dtype')> - [ ] Implement <OpOverload(op='aten.pixel_unshuffle', overload='default')> Add operator registration - [ ] aten::empty - [ ] aten::fill - [ ] aten::getitem - [ ] aten::normal - [ ] aten::rsub - [ ] aten::scatter_reduce - [ ] aten::select - [ ] aten::slice - [ ] aten::softmax - [ ] aten::subtract - [ ] aten::transpose - [ ] aten::unbind
1 parent fb7dea4 commit f8ee736

File tree

3 files changed

+34
-24
lines changed

3 files changed

+34
-24
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

+21-22
Original file line number | Diff line number | Diff line change
@@ -2236,23 +2236,13 @@ def aten_cov(
22362236
raise NotImplementedError()
22372237

22382238

2239-
@torch_op("aten::cross")
2239+
@torch_op(("aten::cross", "aten::linalg_cross"))
22402240
def aten_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
22412241
"""cross(Tensor self, Tensor other, int? dim=None) -> Tensor"""
22422242

2243-
zero = op.Constant(value_ints=[0])
2244-
one = op.Constant(value_ints=[1])
2245-
two = op.Constant(value_ints=[2])
2246-
three = op.Constant(value_ints=[3])
2247-
axes = op.Expand(dim, op.Constant(value_ints=[1]))
2248-
22492243
# Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
2250-
a1 = op.Slice(self, zero, one, axes)
2251-
a2 = op.Slice(self, one, two, axes)
2252-
a3 = op.Slice(self, two, three, axes)
2253-
b1 = op.Slice(other, zero, one, axes)
2254-
b2 = op.Slice(other, one, two, axes)
2255-
b3 = op.Slice(other, two, three, axes)
2244+
a1, a2, a3 = op.Split(self, axis=dim, num_outputs=3)
2245+
b1, b2, b3 = op.Split(other, axis=dim, num_outputs=3)
22562246
# Broadcasting is implicitly supported by Mul
22572247
c1 = op.Sub(op.Mul(a2, b3), op.Mul(a3, b2))
22582248
c2 = op.Sub(op.Mul(a3, b1), op.Mul(a1, b3))
@@ -3571,7 +3561,7 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType:
35713561
raise NotImplementedError()
35723562

35733563

3574-
@torch_op("aten::fmod")
3564+
@torch_op(("aten::fmod.Tensor", "aten::fmod.Scalar"))
35753565
def aten_fmod(self: TRealOrUInt8, other: TRealOrUInt8) -> TRealOrUInt8:
35763566
"""fmod.Tensor(Tensor self, Tensor other) -> Tensor"""
35773567

@@ -4659,7 +4649,7 @@ def aten_le(self: TReal, other: TReal) -> BOOL:
46594649
return op.LessOrEqual(self, other)
46604650

46614651

4662-
@torch_op(("aten::le.Tensor", "aten::less_equal.Tensor", "_operator::le"))
4652+
@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
46634653
def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
46644654
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""
46654655

@@ -4672,10 +4662,17 @@ def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
46724662
return op.Or(other, op.Not(self))
46734663

46744664

4675-
def aten_lerp(self: TensorType, end: TensorType, weight: TensorType) -> TensorType:
4665+
@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar"))
4666+
def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor:
46764667
"""lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""
46774668

4678-
raise NotImplementedError()
4669+
weight = op.CastLike(weight, self)
4670+
diff = op.Sub(end, self)
4671+
return op.Where(
4672+
op.Less(weight, 0.5),
4673+
op.Add(self, op.Mul(weight, diff)),
4674+
op.Sub(end, op.Mul(diff, op.Sub(1.0, weight))),
4675+
)
46794676

46804677

46814678
def aten_lgamma(self: TensorType) -> TensorType:
@@ -5619,10 +5616,11 @@ def aten_multiply(self: TensorType, other: TensorType) -> TensorType:
56195616
raise NotImplementedError()
56205617

56215618

5619+
@torch_op("aten::mv")
56225620
def aten_mv(self: TensorType, vec: TensorType) -> TensorType:
56235621
"""mv(Tensor self, Tensor vec) -> Tensor"""
56245622

5625-
raise NotImplementedError()
5623+
return op.MatMul(self, vec)
56265624

56275625

56285626
def aten_mvlgamma(self: TensorType, p: int) -> TensorType:
@@ -7011,7 +7009,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
70117009
raise NotImplementedError()
70127010

70137011

7014-
@torch_op("aten::remainder")
7012+
@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
70157013
def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrBFloat16:
70167014
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
70177015

@@ -7024,7 +7022,7 @@ def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrB
70247022
return op.Sub(self, op.Mul(rounded_quotient, other))
70257023

70267024

7027-
@torch_op("aten::remainder")
7025+
@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
70287026
def aten_remainder_int(self: TInt, other: TInt) -> TInt:
70297027
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
70307028

@@ -8533,10 +8531,11 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType
85338531
raise NotImplementedError()
85348532

85358533

8536-
def aten_unsafe_split(self: TensorType, split_size: INT64, dim: int = 0) -> TensorType:
8534+
@torch_op(("aten::unsafe_split", "aten::unsafe_split.Tensor"))
8535+
def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]:
85378536
"""unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"""
85388537

8539-
raise NotImplementedError()
8538+
return op.SplitToSequence(self, split_size, axis=dim)
85408539

85418540

85428541
def aten_unsafe_split_with_sizes(

onnxscript/function_libs/torch_lib/ops/linalg.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from onnxscript import BOOL, FLOAT, INT64
1818
from onnxscript.function_libs.torch_lib.ops import common as common_ops
1919
from onnxscript.function_libs.torch_lib.registration import torch_op
20-
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat
20+
from onnxscript.function_libs.torch_lib.tensor_typing import TFloat, TTensor
2121
from onnxscript.onnx_opset import opset18 as op
2222
from onnxscript.onnx_types import TensorType
2323

@@ -44,9 +44,10 @@ def aten_linalg_cond(self: TensorType, p: Optional[float] = None) -> TensorType:
4444
raise NotImplementedError()
4545

4646

47-
def aten_linalg_cross(self: TensorType, other: TensorType, dim: int = -1) -> TensorType:
47+
def aten_linalg_cross(self: TTensor, other: TTensor, dim: int = -1) -> TTensor:
4848
"""linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor"""
4949

50+
# Same implementation as aten_cross
5051
raise NotImplementedError()
5152

5253

tests/function_libs/torch_lib/ops_test_data.py

+10
Original file line number | Diff line number | Diff line change
@@ -900,6 +900,11 @@ def _where_input_wrangler(
900900
TorchLibOpInfo("log", core_ops.aten_log),
901901
TorchLibOpInfo("le", core_ops.aten_le),
902902
TorchLibOpInfo("le_bool", core_ops.aten_le_bool),
903+
TorchLibOpInfo(
904+
"lerp",
905+
core_ops.aten_lerp,
906+
tolerance={torch.float16: (2e-3, 2e-1)},
907+
),
903908
TorchLibOpInfo("log10", core_ops.aten_log10),
904909
TorchLibOpInfo("log1p", core_ops.aten_log1p),
905910
TorchLibOpInfo(
@@ -1020,6 +1025,11 @@ def _where_input_wrangler(
10201025
TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True),
10211026
TorchLibOpInfo("mul", core_ops.aten_mul),
10221027
TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True),
1028+
TorchLibOpInfo(
1029+
"mv",
1030+
core_ops.aten_mv,
1031+
tolerance={torch.float16: (3e-2, 1e-2)},
1032+
),
10231033
TorchLibOpInfo("narrow", core_ops.aten_narrow),
10241034
TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout),
10251035
TorchLibOpInfo("ne", core_ops.aten_ne),

0 commit comments

Comments (0)