diff --git a/stdlib/public/TensorFlow/ArrayOps.swift b/stdlib/public/TensorFlow/ArrayOps.swift index 82a345157af05..619d2ad539171 100644 --- a/stdlib/public/TensorFlow/ArrayOps.swift +++ b/stdlib/public/TensorFlow/ArrayOps.swift @@ -110,6 +110,47 @@ public extension Raw { return out } + /// Splits a tensor into `numSplit` tensors along one dimension. + /// + /// - Parameters: + /// - splitDim: 0-D. The dimension along which to split. Must be in the range + /// `[-rank(value), rank(value))`. + /// - value: The tensor to split. + /// - numSplit: The number of splits to create. + /// + /// - Returns: Tensors whose shape matches that of `value` + /// except along `axis`, where their sizes are + /// `value.shape[axis] / numSplit`. + @inlinable @inline(__always) + static func split( + splitDim: Tensor, + value: Tensor, + numSplit: Int64 + ) -> [Tensor] { + let s: CTFStatus = TF_NewStatus() + defer { TF_DeleteStatus(s) } + let op: CTFEOp = TFE_NewOp(_ExecutionContext.global.eagerContext, "Split", s) + defer { TFE_DeleteOp(op) } + let _ = _TFCOpAddInputFromTensorGroup(op, splitDim, s) + let _ = _TFCOpAddInputFromTensorGroup(op, value, s) + TFE_OpSetAttrInt(op, "num_split", numSplit) + TFE_OpSetAttrType(op, "T", T.tensorFlowDataType._cDataType) + var count: Int32 = Int32(numSplit) + let buffer: UnsafeMutablePointer = + UnsafeMutablePointer.allocate(capacity: Int(count)) + defer { buffer.deallocate() } + _TFCEagerExecute(op, UnsafeMutablePointer(buffer), &count, s) + checkOk(s) + + var out: [Tensor] = [] + var cursor = buffer + for _ in 0..(handle: TensorHandle(_owning: cursor.pointee))) + cursor = cursor.advanced(by: 1) + } + return out + } + /// Splits a tensor into `numSplit` tensors along one dimension. /// /// - Parameters: diff --git a/stdlib/public/TensorFlow/Tensor.swift b/stdlib/public/TensorFlow/Tensor.swift index 05b66a867ad6c..f0e3972e7771f 100644 --- a/stdlib/public/TensorFlow/Tensor.swift +++ b/stdlib/public/TensorFlow/Tensor.swift @@ -206,11 +206,23 @@ public extension Tensor { public extension Tensor where Scalar : Numeric { /// Perform an element-wise conversion from another `Tensor`. @inlinable @inline(__always) + @differentiable( + vjp: _vjpCast where Scalar : TensorFlowFloatingPoint, + OtherScalar: TensorFlowFloatingPoint) init(_ other: Tensor) { self = Raw.cast(other) } } +internal extension Tensor where Scalar : TensorFlowFloatingPoint { + @inlinable + static func _vjpCast( + _ other: Tensor + ) -> (Tensor, (Tensor) -> Tensor) { + return (Tensor(other), { v in Tensor(v) }) + } +} + public extension Tensor { /// Creates a tensor from a scalar value. @inlinable @inline(__always)