Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions stdlib/public/TensorFlow/ArrayOps.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,47 @@ public extension Raw {
return out
}

/// Splits a tensor into `numSplit` tensors along one dimension.
///
/// - Parameters:
/// - splitDim: 0-D. The dimension along which to split. Must be in the range
/// `[-rank(value), rank(value))`.
/// - value: The tensor to split.
/// - numSplit: The number of splits to create.
///
/// - Returns: Tensors whose shape matches that of `value`
/// except along `axis`, where their sizes are
/// `value.shape[axis] / numSplit`.
@inlinable @inline(__always)
static func split<T: TensorFlowScalar>(
splitDim: Tensor<Int32>,
value: Tensor<T>,
numSplit: Int64
) -> [Tensor<T>] {
let s: CTFStatus = TF_NewStatus()
defer { TF_DeleteStatus(s) }
let op: CTFEOp = TFE_NewOp(_ExecutionContext.global.eagerContext, "Split", s)
defer { TFE_DeleteOp(op) }
let _ = _TFCOpAddInputFromTensorGroup(op, splitDim, s)
let _ = _TFCOpAddInputFromTensorGroup(op, value, s)
TFE_OpSetAttrInt(op, "num_split", numSplit)
TFE_OpSetAttrType(op, "T", T.tensorFlowDataType._cDataType)
var count: Int32 = Int32(numSplit)
let buffer: UnsafeMutablePointer<CTensorHandle> =
UnsafeMutablePointer.allocate(capacity: Int(count))
defer { buffer.deallocate() }
_TFCEagerExecute(op, UnsafeMutablePointer<CTensorHandle?>(buffer), &count, s)
checkOk(s)

var out: [Tensor<T>] = []
var cursor = buffer
for _ in 0..<numSplit {
out.append(Tensor<T>(handle: TensorHandle(_owning: cursor.pointee)))
cursor = cursor.advanced(by: 1)
}
return out
}

/// Splits a tensor into `numSplit` tensors along one dimension.
///
/// - Parameters:
Expand Down
12 changes: 12 additions & 0 deletions stdlib/public/TensorFlow/Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,23 @@ public extension Tensor {
public extension Tensor where Scalar : Numeric {
/// Perform an element-wise conversion from another `Tensor`.
@inlinable @inline(__always)
@differentiable(
vjp: _vjpCast where Scalar : TensorFlowFloatingPoint,
OtherScalar: TensorFlowFloatingPoint)
init<OtherScalar : Numeric>(_ other: Tensor<OtherScalar>) {
self = Raw.cast(other)
}
}

internal extension Tensor where Scalar : TensorFlowFloatingPoint {
@inlinable
static func _vjpCast<OtherScalar : TensorFlowFloatingPoint>(
_ other: Tensor<OtherScalar>
) -> (Tensor, (Tensor) -> Tensor<OtherScalar>) {
return (Tensor(other), { v in Tensor<OtherScalar>(v) })
}
}

public extension Tensor {
/// Creates a tensor from a scalar value.
@inlinable @inline(__always)
Expand Down