-
Notifications
You must be signed in to change notification settings - Fork 136
Add two new requirements to TensorArrayProtocol #165
Changes from 3 commits
99b8d03
2eeb847
96b0f5a
e0423fe
5a4d789
75c98e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -32,8 +32,10 @@ public protocol TensorArrayProtocol { | |||||
|
|
||||||
| var _tensorHandleCount: Int32 { get } | ||||||
| var _typeList: [TensorDataType] { get } | ||||||
| var _tensorHandles: [_AnyTensorHandle] { get } | ||||||
|
|
||||||
| init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) | ||||||
| init(handles: [_AnyTensorHandle]) | ||||||
|
||||||
| init(handles: [_AnyTensorHandle]) | |
| init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is redundant because TensorGroup inherits from TensorArrayProtocol.
| init(handles: [_AnyTensorHandle]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The map(_:) method already produces an array. No need to call the array initializer.
| self = Array((0..<size).map { | |
| self = (0..<size).map { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix indentation here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,6 +62,10 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible { | |
| self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
| } | ||
|
|
||
| public init(handle: _AnyTensorHandle) { | ||
| self.handle = handle | ||
| } | ||
|
|
||
| @usableFromInline | ||
| init(copyingFromCTensor cTensor: CTensor) { | ||
| let status = TF_NewStatus() | ||
|
|
@@ -105,7 +109,7 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible { | |
| extension TensorHandle where Scalar: TensorFlowScalar { | ||
| /// Create a `TensorHandle` with a closure that initializes the underlying buffer. | ||
| /// | ||
| /// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a | ||
| /// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a | ||
| /// tensor with shape `shape`. `scalarsInitializer` must initialize the entire buffer, with | ||
| /// contiguous scalars in row-major order. | ||
| @inlinable | ||
|
|
@@ -145,6 +149,11 @@ public struct ResourceHandle { | |
| init(owning cTensorHandle: CTensorHandle) { | ||
| self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
| } | ||
|
|
||
| @usableFromInline | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Drop
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| init(handle: _AnyTensorHandle) { | ||
| self.handle = handle | ||
| } | ||
| } | ||
|
|
||
| public struct VariantHandle { | ||
|
|
@@ -157,4 +166,9 @@ public struct VariantHandle { | |
| init(owning cTensorHandle: CTensorHandle) { | ||
| self.handle = TFETensorHandle(_owning: cTensorHandle) | ||
| } | ||
|
|
||
| @usableFromInline | ||
| init(handle: _AnyTensorHandle) { | ||
| self.handle = handle | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -215,6 +215,16 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup { | |||||
| self.first = first | ||||||
| self.second = second | ||||||
| } | ||||||
|
|
||||||
| public var _tensorHandles: [_AnyTensorHandle] { | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| first._tensorHandles + second._tensorHandles | ||||||
| } | ||||||
|
|
||||||
| public init(handles: [_AnyTensorHandle]) { | ||||||
| let firstEnd = Int(T._tensorHandleCount) | ||||||
| self.first = T.init(handles: Array(handles[0..<firstEnd])) | ||||||
| self.second = U.init(handles: Array(handles[firstEnd..<handles.count])) | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| // TODO(SR-9156): This does not work in graph mode. | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be changed to
Intnow?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, but will do so in a subsequent CL. I have updated https://bugs.swift.org/browse/TF-542 to reflect this.