diff --git a/Sources/TensorFlow/Core/Tensor.swift b/Sources/TensorFlow/Core/Tensor.swift index 06fb869fb..55b88ca20 100644 --- a/Sources/TensorFlow/Core/Tensor.swift +++ b/Sources/TensorFlow/Core/Tensor.swift @@ -500,11 +500,9 @@ extension Tensor: Codable where Scalar: Codable { //===------------------------------------------------------------------------------------------===// extension Tensor: AdditiveArithmetic where Scalar: Numeric { - /// A scalar zero tensor. + /// The scalar zero tensor. @inlinable - public static var zero: Tensor { - return Tensor(0) - } + public static var zero: Tensor { Tensor(0) } /// Adds two tensors and produces their sum. /// - Note: `+` supports broadcasting. @@ -553,6 +551,26 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint { } } +//===------------------------------------------------------------------------------------------===// +// Multiplicative Group +//===------------------------------------------------------------------------------------------===// + +extension Tensor: PointwiseMultiplicative where Scalar: Numeric { + /// The scalar one tensor. + @inlinable + public static var one: Tensor { Tensor(1) } + + /// Returns the pointwise reciprocal of `self`. + @inlinable + public var reciprocal: Tensor { 1 / self } + + /// Multiplies two tensors element-wise and produces their product. + /// - Note: `.*` supports broadcasting. + public static func .* (lhs: Tensor, rhs: Tensor) -> Tensor { + return lhs * rhs + } +} + //===------------------------------------------------------------------------------------------===// // Differentiable //===------------------------------------------------------------------------------------------===//