-
Notifications
You must be signed in to change notification settings - Fork 136
Add two new requirements to TensorArrayProtocol #165
Changes from all commits
99b8d03
2eeb847
96b0f5a
e0423fe
5a4d789
75c98e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -32,8 +32,10 @@ public protocol TensorArrayProtocol { | |||||||
|
|
||||||||
| var _tensorHandleCount: Int32 { get } | ||||||||
| var _typeList: [TensorDataType] { get } | ||||||||
| var _tensorHandles: [_AnyTensorHandle] { get } | ||||||||
|
|
||||||||
| init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) | ||||||||
| init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle | ||||||||
| } | ||||||||
|
|
||||||||
| /// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`. | ||||||||
|
|
@@ -88,13 +90,21 @@ extension TensorHandle: TensorGroup { | |||||||
| return [Scalar.tensorFlowDataType] | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } | ||||||||
|
|
||||||||
| public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||||||||
| address!.initialize(to: _cTensorHandle) | ||||||||
| } | ||||||||
|
|
||||||||
| public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
| self.init(_owning: tensorHandles!.pointee) | ||||||||
| } | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| precondition(_handles.count == 1) | ||||||||
| self.init(handle: _handles[_handles.startIndex]) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| extension ResourceHandle: TensorGroup { | ||||||||
|
|
@@ -108,13 +118,21 @@ extension ResourceHandle: TensorGroup { | |||||||
| return [TensorDataType(TF_RESOURCE)] | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } | ||||||||
|
|
||||||||
| public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||||||||
| address!.initialize(to: _cTensorHandle) | ||||||||
| } | ||||||||
|
|
||||||||
| public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
| self.init(owning: tensorHandles!.pointee) | ||||||||
| } | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| precondition(_handles.count == 1) | ||||||||
| self.init(handle: _handles[_handles.startIndex]) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| extension VariantHandle: TensorGroup { | ||||||||
|
|
@@ -128,13 +146,21 @@ extension VariantHandle: TensorGroup { | |||||||
| return [TensorDataType(TF_VARIANT)] | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: [_AnyTensorHandle] { [self.handle] } | ||||||||
|
|
||||||||
| public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||||||||
| address!.initialize(to: _cTensorHandle) | ||||||||
| } | ||||||||
|
|
||||||||
| public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
| self.init(owning: tensorHandles!.pointee) | ||||||||
| } | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
| precondition(_handles.count == 1) | ||||||||
| self.init(handle: _handles[_handles.startIndex]) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| extension Tensor: TensorGroup { | ||||||||
|
|
@@ -152,9 +178,17 @@ extension Tensor: TensorGroup { | |||||||
| address!.initialize(to: handle._cTensorHandle) | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } | ||||||||
|
|
||||||||
| public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
| self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||||||||
| } | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| precondition(_handles.count == 1) | ||||||||
| self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| extension _TensorElementLiteral: TensorGroup { | ||||||||
|
|
@@ -168,13 +202,21 @@ extension _TensorElementLiteral: TensorGroup { | |||||||
| return [Scalar.tensorFlowDataType] | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } | ||||||||
|
|
||||||||
| public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) { | ||||||||
| address!.initialize(to: handle._cTensorHandle) | ||||||||
| } | ||||||||
|
|
||||||||
| public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
| self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||||||||
| } | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| precondition(_handles.count == 1) | ||||||||
| self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| extension StringTensor: TensorGroup { | ||||||||
|
|
@@ -192,9 +234,17 @@ extension StringTensor: TensorGroup { | |||||||
| address!.initialize(to: handle._cTensorHandle) | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] } | ||||||||
|
|
||||||||
| public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) { | ||||||||
| self.init(handle: TensorHandle(_owning: tensorHandles!.pointee)) | ||||||||
| } | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| precondition(_handles.count == 1) | ||||||||
| self.init(handle: TensorHandle(handle: _handles[_handles.startIndex])) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| extension Array: TensorArrayProtocol where Element: TensorGroup { | ||||||||
|
|
@@ -216,10 +266,31 @@ extension Array: TensorArrayProtocol where Element: TensorGroup { | |||||||
| count: Int(count)).joined()) | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: ([_AnyTensorHandle]) { | ||||||||
| var result: [_AnyTensorHandle] = [] | ||||||||
| result.reserveCapacity(Int(self._tensorHandleCount)) | ||||||||
| for elem in self { | ||||||||
| result += elem._tensorHandles | ||||||||
| } | ||||||||
| return result | ||||||||
| } | ||||||||
|
|
||||||||
| public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) { | ||||||||
| let size = count / Int(Element._tensorHandleCount) | ||||||||
| self = Array((0..<size).map { Element.init( | ||||||||
| _owning: tensorHandles?.advanced(by: $0 * Int(Element._tensorHandleCount))) | ||||||||
| }) | ||||||||
| } | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| let size = _handles.count / Int(Element._tensorHandleCount) | ||||||||
| self = (0..<size).map { | ||||||||
| let start = _handles.index( | ||||||||
| _handles.startIndex, offsetBy: $0 * Int(Element._tensorHandleCount)) | ||||||||
| let end = _handles.index( | ||||||||
| start, offsetBy: Int(Element._tensorHandleCount)) | ||||||||
| return Element.init(_handles: _handles[start..<end]) | ||||||||
| } | ||||||||
| } | ||||||||
| } | ||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,10 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible { | |
| self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
| } | ||
|
|
||
| public init(handle: _AnyTensorHandle) { | ||
| self.handle = handle | ||
| } | ||
|
|
||
| @usableFromInline | ||
| init(copyingFromCTensor cTensor: CTensor) { | ||
| let status = TF_NewStatus() | ||
|
|
@@ -105,7 +109,7 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible { | |
| extension TensorHandle where Scalar: TensorFlowScalar { | ||
| /// Create a `TensorHandle` with a closure that initializes the underlying buffer. | ||
| /// | ||
| /// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a | ||
| /// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a | ||
| /// tensor with shape `shape`. `scalarsInitializer` must initialize the entire buffer, with | ||
| /// contiguous scalars in row-major order. | ||
| @inlinable | ||
|
|
@@ -145,6 +149,11 @@ public struct ResourceHandle { | |
| init(owning cTensorHandle: CTensorHandle) { | ||
| self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
| } | ||
|
|
||
| @usableFromInline | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Drop
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| init(handle: _AnyTensorHandle) { | ||
| self.handle = handle | ||
| } | ||
| } | ||
|
|
||
| public struct VariantHandle { | ||
|
|
@@ -157,4 +166,9 @@ public struct VariantHandle { | |
| init(owning cTensorHandle: CTensorHandle) { | ||
| self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
| } | ||
|
|
||
| @usableFromInline | ||
| init(handle: _AnyTensorHandle) { | ||
| self.handle = handle | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -215,6 +215,19 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup { | |||||||
| self.first = first | ||||||||
| self.second = second | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: [_AnyTensorHandle] { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| first._tensorHandles + second._tensorHandles | ||||||||
| } | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| let firstStart = _handles.startIndex | ||||||||
| let firstEnd = _handles.index( | ||||||||
| firstStart, offsetBy: Int(T._tensorHandleCount)) | ||||||||
| self.first = T.init(_handles: _handles[firstStart..<firstEnd]) | ||||||||
| self.second = U.init(_handles: _handles[firstEnd..<_handles.endIndex]) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // TODO(SR-9156): This does not work in graph mode. | ||||||||
|
|
||||||||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,6 +18,17 @@ import XCTest | |||||||
| struct SimpleOutput: TensorGroup { | ||||||||
| let a: TensorHandle<Int32> | ||||||||
| let b: TensorHandle<Int32> | ||||||||
|
|
||||||||
| public init<C: RandomAccessCollection>( | ||||||||
| _handles: C) where C.Element == _AnyTensorHandle { | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
| precondition(_handles.count == 2) | ||||||||
| let aIndex = _handles.startIndex | ||||||||
| let bIndex = _handles.index(aIndex, offsetBy: 1) | ||||||||
| a = TensorHandle<Int32>(handle: _handles[aIndex]) | ||||||||
| b = TensorHandle<Int32>(handle: _handles[bIndex]) | ||||||||
| } | ||||||||
|
|
||||||||
| public var _tensorHandles: [_AnyTensorHandle] { [a.handle, b.handle] } | ||||||||
| } | ||||||||
|
|
||||||||
| final class DatasetTests: XCTestCase { | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be changed to
Intnow?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, but will do so in a subsequent CL. I have updated https://bugs.swift.org/browse/TF-542 to reflect this.