diff --git a/stdlib/public/TensorFlow/ArrayOps.swift b/stdlib/public/TensorFlow/ArrayOps.swift new file mode 100644 index 0000000000000..7ff208c7528c0 --- /dev/null +++ b/stdlib/public/TensorFlow/ArrayOps.swift @@ -0,0 +1,112 @@ +//===-- ArrayOps.swift ----------------------------------------*- swift -*-===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file contains some Array ops that cannot be properly handled by #tfop. +// +// TODO: These should be deleted once we can properly generate raw ops for these. +// +//===----------------------------------------------------------------------===// + +import CTensorFlow + +public extension Raw { + /// Saves tensors in V2 checkpoint format. + /// + /// By default, saves the named tensors in full. If the caller wishes to save + /// specific slices of full tensors, "shape_and_slices" should be non-empty strings + /// and correspondingly well-formed. + /// + /// - Parameters: + /// - prefix: Must have a single element. The prefix of the V2 checkpoint to which we + /// write the tensors. + /// - tensor_names: shape {N}. The names of the tensors to be saved. + /// - shape_and_slices: shape {N}. The slice specs of the tensors to be saved. + /// Empty strings indicate that they are non-partitioned tensors. + /// - tensors: `N` tensors to save. + @inlinable @inline(__always) + static func saveV2( + prefix: StringTensor, + tensorNames: StringTensor, + shapeAndSlices: StringTensor, + tensors: [AnyTensor] + ) { + let s: CTFStatus = TF_NewStatus() + defer { TF_DeleteStatus(s) } + let op: CTFEOp = TFE_NewOp(_ExecutionContext.global.eagerContext, "SaveV2", s) + defer { TFE_DeleteOp(op) } + let _ = _TFCOpAddInputFromTensorGroup(op, prefix, s) + let _ = _TFCOpAddInputFromTensorGroup(op, tensorNames, s) + let _ = _TFCOpAddInputFromTensorGroup(op, shapeAndSlices, s) + let _ = _TFCOpAddInputFromAnyTensors(op, tensors, s) + let _ = _TFCOpSetAttrTypeArray(op, "dtypes", tensors.map { $0._tensorFlowDataType }) + return _TFCExecuteOp(op, s) + } + + /// Restores tensors from a V2 checkpoint. + /// + /// For backward compatibility with the V1 format, this Op currently allows + /// restoring from a V1 checkpoint as well: + /// - This Op first attempts to find the V2 index file pointed to by "prefix", and + /// if found proceed to read it as a V2 checkpoint; + /// - Otherwise the V1 read path is invoked. + /// Relying on this behavior is not recommended, as the ability to fall back to read + /// V1 might be deprecated and eventually removed. + /// + /// By default, restores the named tensors in full. If the caller wishes to restore + /// specific slices of stored tensors, "shape_and_slices" should be non-empty + /// strings and correspondingly well-formed. + /// + /// Callers must ensure all the named tensors are indeed stored in the checkpoint. + /// + /// - Parameters: + /// - prefix: Must have a single element. The prefix of a V2 checkpoint. + /// - tensor_names: shape {N}. The names of the tensors to be restored. + /// - shape_and_slices: shape {N}. The slice specs of the tensors to be restored. + /// Empty strings indicate that they are non-partitioned tensors. + /// + /// - Attr dtypes: shape {N}. The list of expected dtype for the tensors. Must match + /// those stored in the checkpoint. + /// + /// - Output tensors: shape {N}. The restored tensors, whose shapes are read from the + /// checkpoint directly. + @inlinable @inline(__always) + static func restoreV2( + prefix: StringTensor, + tensorNames: StringTensor, + shapeAndSlices: StringTensor, + dtypes: [TensorDataType] + ) -> [AnyTensor] { + let s: CTFStatus = TF_NewStatus() + defer { TF_DeleteStatus(s) } + let op: CTFEOp = TFE_NewOp(_ExecutionContext.global.eagerContext, "RestoreV2", s) + defer { TFE_DeleteOp(op) } + let _ = _TFCOpAddInputFromTensorGroup(op, prefix, s) + let _ = _TFCOpAddInputFromTensorGroup(op, tensorNames, s) + let _ = _TFCOpAddInputFromTensorGroup(op, shapeAndSlices, s) + let _ = _TFCOpSetAttrTypeArray(op, "dtypes", dtypes) + + var count: Int32 = Int32(dtypes.count) + let buffer: UnsafeMutablePointer = + UnsafeMutablePointer.allocate(capacity: Int(count)) + defer { buffer.deallocate() } + _TFCEagerExecute(op, UnsafeMutablePointer(buffer), &count, s) + checkOk(s) + + var out: [AnyTensor] = [] + var cursor = buffer + for type in dtypes { + out.append(makeTensor(dataType: type, owning: cursor.pointee)) + cursor = cursor.advanced(by: 1) + } + return out + } +} diff --git a/stdlib/public/TensorFlow/CMakeLists.txt b/stdlib/public/TensorFlow/CMakeLists.txt index f099d82a44859..83c41fbe074e8 100644 --- a/stdlib/public/TensorFlow/CMakeLists.txt +++ b/stdlib/public/TensorFlow/CMakeLists.txt @@ -48,7 +48,9 @@ set(SOURCES TensorProtocol.swift TensorShape.swift Utilities.swift + ArrayOps.swift Threading.swift + ExecuteOp.swift.gyb # NumPy bridging for `ShapedArray` and `Tensor`. NumpyConversion.swift) diff --git a/stdlib/public/TensorFlow/CompilerRuntime.swift b/stdlib/public/TensorFlow/CompilerRuntime.swift index 9278f6534ec8f..3aca91dbd05d8 100644 --- a/stdlib/public/TensorFlow/CompilerRuntime.swift +++ b/stdlib/public/TensorFlow/CompilerRuntime.swift @@ -1703,6 +1703,29 @@ func _TFCOpAddInputFromTensorGroup( return count } +/// Special protocol for calling tensorflow operations that take heterogeneous +/// arrays as input. +public protocol AnyTensor { + var _rawTensorHandle: CTensorHandle { get } + var _tensorFlowDataType: TensorDataType { get } +} + +extension Tensor : AnyTensor { + public var _rawTensorHandle: CTensorHandle { return handle._cTensorHandle } + public var _tensorFlowDataType: TensorDataType { return Scalar.tensorFlowDataType } +} + +@usableFromInline +func _TFCOpAddInputFromAnyTensors( + _ op: CTFEOp, _ tensors: [AnyTensor], _ status: CTFStatus +) { + for tensor in tensors { + let handle = tensor._rawTensorHandle + TFE_OpAddInput(op, handle, status) + checkOk(status) + } +} + /// Initializes a TensorGroup value, taking ownership of all the tensor /// handles in `tensorHandles`. @usableFromInline diff --git a/stdlib/public/TensorFlow/DataTypes.swift b/stdlib/public/TensorFlow/DataTypes.swift index 565bb159ec0b7..4b466d09b9691 100644 --- a/stdlib/public/TensorFlow/DataTypes.swift +++ b/stdlib/public/TensorFlow/DataTypes.swift @@ -36,6 +36,43 @@ public struct TensorDataType { } } +@usableFromInline +internal func makeTensor( + dataType: TensorDataType, + owning pointer: CTensorHandle +) -> AnyTensor { + switch dataType._cDataType { + case TF_BOOL: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_INT8: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_UINT8: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_INT16: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_UINT16: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_INT32: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_UINT32: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_INT64: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_UINT64: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_BFLOAT16: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_FLOAT: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_DOUBLE: + return Tensor(handle: TensorHandle(_owning: pointer)) + case TF_STRING: + fatalError("StringTensor does not conform to AnyTensor") + default: + fatalError("Unhandled type: \(dataType)") + } +} + /// A data type compatible with TensorFlow. public protocol _TensorFlowDataTypeCompatible { /// The underlying TensorFlow data type. diff --git a/stdlib/public/TensorFlow/ExecuteOp.swift.gyb b/stdlib/public/TensorFlow/ExecuteOp.swift.gyb new file mode 100644 index 0000000000000..92e04ed5319f9 --- /dev/null +++ b/stdlib/public/TensorFlow/ExecuteOp.swift.gyb @@ -0,0 +1,48 @@ +//===-- ExecuteOp.swift.gyb -----------------------------------*- swift -*-===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file contains _TFCExecuteOp which allows dispatching an op and +// returning an arbitrary set of tensor-groups. +// +// TODO: A nice wrapper for TFEOp could possibly make this simpler to use. This +// may need to be extended in order to work with multiple tfops. +// +//===----------------------------------------------------------------------===// + +@usableFromInline +func _TFCExecuteOp(_ op: CTFEOp, _ s: CTFStatus) { + var count: Int32 = 0 + var unused: CTensorHandle? + _TFCEagerExecute(op, &unused, &count, s) + checkOk(s) +} + +%for n in range(1, 11): +// Calls _TFCEagerExecute under the hood and unpacks into TensorGroup conforming +// types. +@usableFromInline +func _TFCExecuteOp<${", ".join(["T" + str(i) + " : TensorGroup" for i in range(n)])}> + (_ op: CTFEOp, _ s: CTFStatus) + -> (${", ".join(["T" + str(i) for i in range(n)])}) { + + var count: Int32 = ${" + ".join(["T" + str(i) + "._tensorHandleCount" for i in range(n)])} + let buffer: UnsafeMutablePointer = + UnsafeMutablePointer.allocate(capacity: Int(count)) + defer { buffer.deallocate() } + _TFCEagerExecute(op, UnsafeMutablePointer(buffer), &count, s) + checkOk(s) +%for i in range(n): +let off${i}: Int32 = ${"0" if i == 0 else "off" + str(i - 1) + " + T" + str(i - 1) + "._tensorHandleCount"} +%end + return (${", ".join(["T" + str(i) + ".init(_owning: buffer.advanced(by: Int(off" + str(i) + ")))" for i in range(n)])}) +} +%end