@@ -2236,23 +2236,13 @@ def aten_cov(
2236
2236
raise NotImplementedError ()
2237
2237
2238
2238
2239
- @torch_op ("aten::cross" )
2239
+ @torch_op (( "aten::cross" , "aten::linalg_cross" ) )
2240
2240
def aten_cross (self : TTensor , other : TTensor , dim : int = - 1 ) -> TTensor :
2241
2241
"""cross(Tensor self, Tensor other, int? dim=None) -> Tensor"""
2242
2242
2243
- zero = op .Constant (value_ints = [0 ])
2244
- one = op .Constant (value_ints = [1 ])
2245
- two = op .Constant (value_ints = [2 ])
2246
- three = op .Constant (value_ints = [3 ])
2247
- axes = op .Expand (dim , op .Constant (value_ints = [1 ]))
2248
-
2249
2243
# Reference https://en.wikipedia.org/w/index.php?title=Cross_product&oldid=1143125073
2250
- a1 = op .Slice (self , zero , one , axes )
2251
- a2 = op .Slice (self , one , two , axes )
2252
- a3 = op .Slice (self , two , three , axes )
2253
- b1 = op .Slice (other , zero , one , axes )
2254
- b2 = op .Slice (other , one , two , axes )
2255
- b3 = op .Slice (other , two , three , axes )
2244
+ a1 , a2 , a3 = op .Split (self , axis = dim , num_outputs = 3 )
2245
+ b1 , b2 , b3 = op .Split (other , axis = dim , num_outputs = 3 )
2256
2246
# Broadcasting is implicitly supported by Mul
2257
2247
c1 = op .Sub (op .Mul (a2 , b3 ), op .Mul (a3 , b2 ))
2258
2248
c2 = op .Sub (op .Mul (a3 , b1 ), op .Mul (a1 , b3 ))
@@ -3571,7 +3561,7 @@ def aten_fmin(self: TensorType, other: TensorType) -> TensorType:
3571
3561
raise NotImplementedError ()
3572
3562
3573
3563
3574
- @torch_op ("aten::fmod" )
3564
+ @torch_op (( "aten::fmod.Tensor" , "aten::fmod.Scalar" ) )
3575
3565
def aten_fmod (self : TRealOrUInt8 , other : TRealOrUInt8 ) -> TRealOrUInt8 :
3576
3566
"""fmod.Tensor(Tensor self, Tensor other) -> Tensor"""
3577
3567
@@ -4659,7 +4649,7 @@ def aten_le(self: TReal, other: TReal) -> BOOL:
4659
4649
return op .LessOrEqual (self , other )
4660
4650
4661
4651
4662
- @torch_op (("aten::le.Tensor" , "aten::less_equal.Tensor" , "_operator::le" ))
4652
+ @torch_op (("aten::le.Tensor" , "aten::le.Scalar" , "aten::less_equal.Tensor" , "_operator::le" ))
4663
4653
def aten_le_bool (self : BOOL , other : BOOL ) -> BOOL :
4664
4654
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""
4665
4655
@@ -4672,10 +4662,17 @@ def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
4672
4662
return op .Or (other , op .Not (self ))
4673
4663
4674
4664
4675
@torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar"))
def aten_lerp(self: TTensor, end: TTensor, weight: TTensor) -> TTensor:
    """lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"""

    # Match weight's dtype to the input so the arithmetic below is well-typed
    # (the .Scalar overload may pass weight with a different dtype).
    weight = op.CastLike(weight, self)
    diff = op.Sub(end, self)
    # Two algebraically equivalent forms of self + weight * (end - self),
    # selected elementwise by weight < 0.5: interpolate from whichever
    # endpoint is nearer. Presumably this mirrors PyTorch's lerp kernel for
    # floating-point accuracy near the endpoints — confirm against upstream.
    return op.Where(
        op.Less(weight, 0.5),
        op.Add(self, op.Mul(weight, diff)),
        op.Sub(end, op.Mul(diff, op.Sub(1.0, weight))),
    )
4679
4676
4680
4677
4681
4678
def aten_lgamma (self : TensorType ) -> TensorType :
@@ -5619,10 +5616,11 @@ def aten_multiply(self: TensorType, other: TensorType) -> TensorType:
5619
5616
raise NotImplementedError ()
5620
5617
5621
5618
5619
@torch_op("aten::mv")
def aten_mv(self: TensorType, vec: TensorType) -> TensorType:
    """mv(Tensor self, Tensor vec) -> Tensor"""

    # Matrix-vector product. ONNX MatMul follows numpy.matmul semantics, so
    # an (n, m) matrix times an (m,) vector yields an (n,) vector directly;
    # no explicit reshape of vec is needed.
    return op.MatMul(self, vec)
5626
5624
5627
5625
5628
5626
def aten_mvlgamma (self : TensorType , p : int ) -> TensorType :
@@ -7011,7 +7009,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
7011
7009
raise NotImplementedError ()
7012
7010
7013
7011
7014
- @torch_op ("aten::remainder" )
7012
+ @torch_op (( "aten::remainder.Tensor" , "aten::remainder.Scalar" ) )
7015
7013
def aten_remainder (self : TFloatOrBFloat16 , other : TFloatOrBFloat16 ) -> TFloatOrBFloat16 :
7016
7014
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
7017
7015
@@ -7024,7 +7022,7 @@ def aten_remainder(self: TFloatOrBFloat16, other: TFloatOrBFloat16) -> TFloatOrB
7024
7022
return op .Sub (self , op .Mul (rounded_quotient , other ))
7025
7023
7026
7024
7027
- @torch_op ("aten::remainder" )
7025
+ @torch_op (( "aten::remainder.Tensor" , "aten::remainder.Scalar" ) )
7028
7026
def aten_remainder_int (self : TInt , other : TInt ) -> TInt :
7029
7027
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
7030
7028
@@ -8533,10 +8531,11 @@ def aten_unsafe_chunk(self: TensorType, chunks: int, dim: int = 0) -> TensorType
8533
8531
raise NotImplementedError ()
8534
8532
8535
8533
8536
@torch_op(("aten::unsafe_split", "aten::unsafe_split.Tensor"))
def aten_unsafe_split(self: TTensor, split_size: INT64, dim: int = 0) -> Sequence[TTensor]:
    """unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"""

    # In PyTorch, "unsafe" only relaxes autograd aliasing checks; the ONNX
    # lowering is the same as a regular split. SplitToSequence with a scalar
    # split_size yields chunks of that length along `dim` (per the ONNX spec,
    # the final chunk may be smaller when the dimension is not divisible).
    return op.SplitToSequence(self, split_size, axis=dim)
8540
8539
8541
8540
8542
8541
def aten_unsafe_split_with_sizes (
0 commit comments