diff --git a/Sources/TensorFlow/Core/TensorGroup.swift b/Sources/TensorFlow/Core/TensorGroup.swift index 4390e0bf4..8a84c4d0d 100644 --- a/Sources/TensorFlow/Core/TensorGroup.swift +++ b/Sources/TensorFlow/Core/TensorGroup.swift @@ -32,8 +32,10 @@ public protocol TensorArrayProtocol { var _tensorHandleCount: Int32 { get } var _typeList: [TensorDataType] { get } + var _tensorHandles: [_AnyTensorHandle] { get } init(_owning tensorHandles: UnsafePointer?, count: Int) + init(_handles: C) where C.Element == _AnyTensorHandle } /// A protocol representing types that can be mapped to and from `Array`. @@ -88,6 +90,8 @@ extension TensorHandle: TensorGroup { return [Scalar.tensorFlowDataType] } + public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } + public func _unpackTensorHandles(into address: UnsafeMutablePointer?) { address!.initialize(to: _cTensorHandle) } @@ -95,6 +99,12 @@ extension TensorHandle: TensorGroup { public init(_owning tensorHandles: UnsafePointer?) { self.init(_owning: tensorHandles!.pointee) } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 1) + self.init(handle: _handles[_handles.startIndex]) + } } extension ResourceHandle: TensorGroup { @@ -108,6 +118,8 @@ extension ResourceHandle: TensorGroup { return [TensorDataType(TF_RESOURCE)] } + public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } + public func _unpackTensorHandles(into address: UnsafeMutablePointer?) { address!.initialize(to: _cTensorHandle) } @@ -115,6 +127,12 @@ extension ResourceHandle: TensorGroup { public init(_owning tensorHandles: UnsafePointer?) { self.init(owning: tensorHandles!.pointee) } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 1) + self.init(handle: _handles[_handles.startIndex]) + } } extension VariantHandle: TensorGroup { @@ -128,6 +146,8 @@ extension VariantHandle: TensorGroup { return [TensorDataType(TF_VARIANT)] } + public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } + public func _unpackTensorHandles(into address: UnsafeMutablePointer?) { address!.initialize(to: _cTensorHandle) } @@ -135,6 +155,12 @@ extension VariantHandle: TensorGroup { public init(_owning tensorHandles: UnsafePointer?) { self.init(owning: tensorHandles!.pointee) } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 1) + self.init(handle: _handles[_handles.startIndex]) + } } extension Tensor: TensorGroup { @@ -152,9 +178,17 @@ extension Tensor: TensorGroup { address!.initialize(to: handle._cTensorHandle) } + public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } + public init(_owning tensorHandles: UnsafePointer?) { self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 1) + self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) + } } extension _TensorElementLiteral: TensorGroup { @@ -168,6 +202,8 @@ extension _TensorElementLiteral: TensorGroup { return [Scalar.tensorFlowDataType] } + public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } + public func _unpackTensorHandles(into address: UnsafeMutablePointer?) { address!.initialize(to: handle._cTensorHandle) } @@ -175,6 +211,12 @@ extension _TensorElementLiteral: TensorGroup { public init(_owning tensorHandles: UnsafePointer?) { self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 1) + self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) + } } extension StringTensor: TensorGroup { @@ -192,9 +234,17 @@ extension StringTensor: TensorGroup { address!.initialize(to: handle._cTensorHandle) } + public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } + public init(_owning tensorHandles: UnsafePointer?) { self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 1) + self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) + } } extension Array: TensorArrayProtocol where Element: TensorGroup { @@ -216,10 +266,31 @@ extension Array: TensorArrayProtocol where Element: TensorGroup { count: Int(count)).joined()) } + public var _tensorHandles: ([_AnyTensorHandle]) { + var result: [_AnyTensorHandle] = [] + result.reserveCapacity(Int(self._tensorHandleCount)) + for elem in self { + result += elem._tensorHandles + } + return result + } + public init(_owning tensorHandles: UnsafePointer?, count: Int) { let size = count / Int(Element._tensorHandleCount) self = Array((0..( + _handles: C) where C.Element == _AnyTensorHandle { + let size = _handles.count / Int(Element._tensorHandleCount) + self = (0.. where Scalar: _TensorFlowDataTypeCompatible { self.handle = TFETensorHandle(_owning: cTensorHandle) } + public init(handle: _AnyTensorHandle) { + self.handle = handle + } + @usableFromInline init(copyingFromCTensor cTensor: CTensor) { let status = TF_NewStatus() @@ -105,7 +109,7 @@ public struct TensorHandle where Scalar: _TensorFlowDataTypeCompatible { extension TensorHandle where Scalar: TensorFlowScalar { /// Create a `TensorHandle` with a closure that initializes the underlying buffer. /// - /// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a + /// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a /// tensor with shape `shape`. `scalarsInitializer` must initialize the entire buffer, with /// contiguous scalars in row-major order. @inlinable @@ -145,6 +149,11 @@ public struct ResourceHandle { init(owning cTensorHandle: CTensorHandle) { self.handle = TFETensorHandle(_owning: cTensorHandle) } + + @usableFromInline + init(handle: _AnyTensorHandle) { + self.handle = handle + } } public struct VariantHandle { @@ -157,4 +166,9 @@ public struct VariantHandle { init(owning cTensorHandle: CTensorHandle) { self.handle = TFETensorHandle(_owning: cTensorHandle) } + + @usableFromInline + init(handle: _AnyTensorHandle) { + self.handle = handle + } } diff --git a/Sources/TensorFlow/Operators/Dataset.swift b/Sources/TensorFlow/Operators/Dataset.swift index 1338a384e..ac9cf5f8d 100644 --- a/Sources/TensorFlow/Operators/Dataset.swift +++ b/Sources/TensorFlow/Operators/Dataset.swift @@ -215,6 +215,19 @@ public struct Zip2TensorGroup: TensorGroup { self.first = first self.second = second } + + public var _tensorHandles: [_AnyTensorHandle] { + first._tensorHandles + second._tensorHandles + } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + let firstStart = _handles.startIndex + let firstEnd = _handles.index( + firstStart, offsetBy: Int(T._tensorHandleCount)) + self.first = T.init(_handles: _handles[firstStart.. let b: TensorHandle + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 2) + let aIndex = _handles.startIndex + let bIndex = _handles.index(aIndex, offsetBy: 1) + a = TensorHandle(handle: _handles[aIndex]) + b = TensorHandle(handle: _handles[bIndex]) + } + + public var _tensorHandles: [_AnyTensorHandle] { [a.handle, b.handle] } } final class DatasetTests: XCTestCase { diff --git a/Tests/TensorFlowTests/TensorGroupTests.swift b/Tests/TensorFlowTests/TensorGroupTests.swift index 2cfd168f2..99d246098 100644 --- a/Tests/TensorFlowTests/TensorGroupTests.swift +++ b/Tests/TensorFlowTests/TensorGroupTests.swift @@ -22,10 +22,31 @@ extension TensorDataType : Equatable { } } -struct Empty : TensorGroup {} +struct Empty : TensorGroup { + init() {} + public init( + _handles: C) where C.Element == _AnyTensorHandle {} + public var _tensorHandles: [_AnyTensorHandle] { [] } +} struct Simple : TensorGroup, Equatable { var w, b: Tensor + + init(w: Tensor, b: Tensor) { + self.w = w + self.b = b + } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 2) + let wIndex = _handles.startIndex + let bIndex = _handles.index(wIndex, offsetBy: 1) + w = Tensor(handle: TensorHandle(handle: _handles[wIndex])) + b = Tensor(handle: TensorHandle(handle: _handles[bIndex])) + } + + public var _tensorHandles: [_AnyTensorHandle] { [w.handle.handle, b.handle.handle] } } struct Mixed : TensorGroup, Equatable { @@ -33,6 +54,26 @@ struct Mixed : TensorGroup, Equatable { var float: Tensor // Immutable. let int: Tensor + + init(float: Tensor, int: Tensor) { + self.float = float + self.int = int + } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + precondition(_handles.count == 2) + let floatIndex = _handles.startIndex + let intIndex = _handles.index(floatIndex, offsetBy: 1) + float = Tensor( + handle: TensorHandle(handle: _handles[floatIndex])) + int = Tensor( + handle: TensorHandle(handle: _handles[intIndex])) + } + + public var _tensorHandles: [_AnyTensorHandle] { + [float.handle.handle, int.handle.handle] + } } struct Nested : TensorGroup, Equatable { @@ -40,16 +81,85 @@ struct Nested : TensorGroup, Equatable { let simple: Simple // Mutable. var mixed: Mixed + + init(simple: Simple, mixed: Mixed) { + self.simple = simple + self.mixed = mixed + } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + let simpleStart = _handles.startIndex + let simpleEnd = _handles.index( + simpleStart, offsetBy: Int(Simple._tensorHandleCount)) + simple = Simple(_handles: _handles[simpleStart.. : TensorGroup, Equatable { var t: T var u: U + + public init(t: T, u: U) { + self.t = t + self.u = u + } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + let tStart = _handles.startIndex + let tEnd = _handles.index(tStart, offsetBy: Int(T._tensorHandleCount)) + t = T.init(_handles: _handles[tStart.. + : TensorGroup, Equatable { + var a: Generic + var b: Generic + + init(a: Generic, b: Generic) { + self.a = a + self.b = b + } + + public init( + _handles: C) where C.Element == _AnyTensorHandle { + let firstStart = _handles.startIndex + let firstEnd = _handles.index( + firstStart, offsetBy: Int(Generic._tensorHandleCount)) + a = Generic.init(_handles: _handles[firstStart...init(_handles: _handles[firstEnd..<_handles.endIndex]) + } + + public var _tensorHandles: [_AnyTensorHandle] { + return a._tensorHandles + b._tensorHandles + } +} + +func copy(of handle: TensorHandle) -> _AnyTensorHandle { + let status = TF_NewStatus() + let result = TFETensorHandle(_owning: TFE_TensorHandleCopySharingTensor( + handle._cTensorHandle, status)!) + XCTAssertEqual(TF_GetCode(status), TF_OK) + TF_DeleteStatus(status) + return result } final class TensorGroupTests: XCTestCase { func testEmptyList() { XCTAssertEqual([], Empty._typeList) + XCTAssertEqual(Empty()._tensorHandles.count, 0) } func testSimpleTypeList() { @@ -62,17 +172,10 @@ final class TensorGroupTests: XCTestCase { let b = Tensor(0.1) let simple = Simple(w: w, b: b) - let status = TF_NewStatus() - let wHandle = TFE_TensorHandleCopySharingTensor( - w.handle._cTensorHandle, status)! - let bHandle = TFE_TensorHandleCopySharingTensor( - b.handle._cTensorHandle, status)! - TF_DeleteStatus(status) + let wHandle = copy(of: w.handle) + let bHandle = copy(of: b.handle) - let buffer = UnsafeMutableBufferPointer.allocate( - capacity: 2) - let _ = buffer.initialize(from: [wHandle, bHandle]) - let expectedSimple = Simple(_owning: UnsafePointer(buffer.baseAddress)) + let expectedSimple = Simple(_handles: [wHandle, bHandle]) XCTAssertEqual(expectedSimple, simple) } @@ -88,17 +191,10 @@ final class TensorGroupTests: XCTestCase { let int = Tensor(1) let mixed = Mixed(float: float, int: int) - let status = TF_NewStatus() - let floatHandle = TFE_TensorHandleCopySharingTensor( - float.handle._cTensorHandle, status)! - let intHandle = TFE_TensorHandleCopySharingTensor( - int.handle._cTensorHandle, status)! - TF_DeleteStatus(status) + let floatHandle = copy(of: float.handle) + let intHandle = copy(of: int.handle) - let buffer = UnsafeMutableBufferPointer.allocate( - capacity: 2) - let _ = buffer.initialize(from: [floatHandle, intHandle]) - let expectedMixed = Mixed(_owning: UnsafePointer(buffer.baseAddress)) + let expectedMixed = Mixed(_handles: [floatHandle, intHandle]) XCTAssertEqual(expectedMixed, mixed) } @@ -118,24 +214,14 @@ final class TensorGroupTests: XCTestCase { let mixed = Mixed(float: float, int: int) let nested = Nested(simple: simple, mixed: mixed) - let status = TF_NewStatus() - let wHandle = TFE_TensorHandleCopySharingTensor( - w.handle._cTensorHandle, status)! - let bHandle = TFE_TensorHandleCopySharingTensor( - b.handle._cTensorHandle, status)! - let floatHandle = TFE_TensorHandleCopySharingTensor( - float.handle._cTensorHandle, status)! - let intHandle = TFE_TensorHandleCopySharingTensor( - int.handle._cTensorHandle, status)! - TF_DeleteStatus(status) - - let buffer = UnsafeMutableBufferPointer.allocate( - capacity: 4) - let _ = buffer.initialize( - from: [wHandle, bHandle, floatHandle, intHandle]) - let expectedNested = Nested( - _owning: UnsafePointer(buffer.baseAddress)) + let wHandle = copy(of: w.handle) + let bHandle = copy(of: b.handle) + let floatHandle = copy(of: float.handle) + let intHandle = copy(of: int.handle) + let expectedNested = Nested( + _handles: [wHandle, bHandle, floatHandle, intHandle]) + XCTAssertEqual(expectedNested, nested) } @@ -155,23 +241,13 @@ final class TensorGroupTests: XCTestCase { let mixed = Mixed(float: float, int: int) let generic = Generic(t: simple, u: mixed) - let status = TF_NewStatus() - let wHandle = TFE_TensorHandleCopySharingTensor( - w.handle._cTensorHandle, status)! - let bHandle = TFE_TensorHandleCopySharingTensor( - b.handle._cTensorHandle, status)! - let floatHandle = TFE_TensorHandleCopySharingTensor( - float.handle._cTensorHandle, status)! - let intHandle = TFE_TensorHandleCopySharingTensor( - int.handle._cTensorHandle, status)! - TF_DeleteStatus(status) - - let buffer = UnsafeMutableBufferPointer.allocate( - capacity: 4) - let _ = buffer.initialize( - from: [wHandle, bHandle, floatHandle, intHandle]) + let wHandle = copy(of: w.handle) + let bHandle = copy(of: b.handle) + let floatHandle = copy(of: float.handle) + let intHandle = copy(of: int.handle) + let expectedGeneric = Generic( - _owning: UnsafePointer(buffer.baseAddress)) + _handles: [wHandle, bHandle, floatHandle, intHandle]) XCTAssertEqual(expectedGeneric, generic) } @@ -179,12 +255,6 @@ final class TensorGroupTests: XCTestCase { func testNestedGenericTypeList() { struct NestedGeneric { func function() { - struct UltraNested< - T: TensorGroup & Equatable, V: TensorGroup & Equatable> - : TensorGroup, Equatable { - var a: Generic - var b: Generic - } let float = Float.tensorFlowDataType let int = Int32.tensorFlowDataType XCTAssertEqual([float, float, float, int, float, int, float, float], @@ -198,13 +268,6 @@ final class TensorGroupTests: XCTestCase { func testNestedGenericInit() { struct NestedGeneric { func function() { - struct UltraNested< - T: TensorGroup & Equatable, V: TensorGroup & Equatable> - : TensorGroup, Equatable { - var a: Generic - var b: Generic - } - let w = Tensor(0.1) let b = Tensor(0.1) let simple = Simple(w: w, b: b) @@ -215,32 +278,18 @@ final class TensorGroupTests: XCTestCase { let genericMS = Generic(t: mixed, u: simple) let generic = UltraNested(a: genericSM, b: genericMS) - let status = TF_NewStatus() - let wHandle1 = TFE_TensorHandleCopySharingTensor( - w.handle._cTensorHandle, status)! - let wHandle2 = TFE_TensorHandleCopySharingTensor( - w.handle._cTensorHandle, status)! - let bHandle1 = TFE_TensorHandleCopySharingTensor( - b.handle._cTensorHandle, status)! - let bHandle2 = TFE_TensorHandleCopySharingTensor( - b.handle._cTensorHandle, status)! - let floatHandle1 = TFE_TensorHandleCopySharingTensor( - float.handle._cTensorHandle, status)! - let floatHandle2 = TFE_TensorHandleCopySharingTensor( - float.handle._cTensorHandle, status)! - let intHandle1 = TFE_TensorHandleCopySharingTensor( - int.handle._cTensorHandle, status)! - let intHandle2 = TFE_TensorHandleCopySharingTensor( - int.handle._cTensorHandle, status)! - TF_DeleteStatus(status) - - let buffer = UnsafeMutableBufferPointer.allocate( - capacity: 8) - let _ = buffer.initialize( - from: [wHandle1, bHandle1, floatHandle1, intHandle1, - floatHandle2, intHandle2, wHandle2, bHandle2]) + let wHandle1 = copy(of: w.handle) + let wHandle2 = copy(of: w.handle) + let bHandle1 = copy(of: b.handle) + let bHandle2 = copy(of: b.handle) + let floatHandle1 = copy(of: float.handle) + let floatHandle2 = copy(of: float.handle) + let intHandle1 = copy(of: int.handle) + let intHandle2 = copy(of: int.handle) + let expectedGeneric = UltraNested( - _owning: UnsafePointer(buffer.baseAddress)) + _handles: [wHandle1, bHandle1, floatHandle1, intHandle1, + floatHandle2, intHandle2, wHandle2, bHandle2]) XCTAssertEqual(expectedGeneric, generic) }