Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions Sources/TensorFlow/Core/TensorGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ public protocol TensorArrayProtocol {

var _tensorHandleCount: Int32 { get }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be changed to Int now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, but will do so in a subsequent CL. I have updated https://bugs.swift.org/browse/TF-542 to reflect this.

var _typeList: [TensorDataType] { get }
var _tensorHandles: [_AnyTensorHandle] { get }

init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int)
init(handles: [_AnyTensorHandle])
Copy link
Contributor

@rxwei rxwei Jun 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Requiring the argument to be an Array is not ideal, because in a lot of cases we are creating a tensor group from a slice of an array. Converting a slice to an Array creates an unnecessary copy. Instead, I think this should take a generic RandomAccessCollection.

Also, for consistency with other requirements, it's better for the first argument label of this initializer to start with an underscore.

Suggested change
init(handles: [_AnyTensorHandle])
init<C: RandomAccessCollection>(_handles: C) where C.Element == _AnyTensorHandle

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Done.

}

/// A protocol representing types that can be mapped to and from `Array<CTensorHandle>`.
Expand All @@ -51,6 +53,8 @@ public protocol TensorGroup: TensorArrayProtocol {
/// Initializes a value of this type, taking ownership of the `_tensorHandleCount` tensors
/// starting at address `tensorHandles`.
init(_owning tensorHandles: UnsafePointer<CTensorHandle>?)

init(handles: [_AnyTensorHandle])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is redundant because TensorGroup inherits from TensorArrayProtocol.

Suggested change
init(handles: [_AnyTensorHandle])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

public extension TensorGroup {
Expand Down Expand Up @@ -88,13 +92,20 @@ extension TensorHandle: TensorGroup {
return [Scalar.tensorFlowDataType]
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(_owning: tensorHandles!.pointee)
}

public init(handles: [_AnyTensorHandle]) {
precondition(handles.count == 1)
self.init(handle: handles[0])
}
}

extension ResourceHandle: TensorGroup {
Expand All @@ -108,13 +119,20 @@ extension ResourceHandle: TensorGroup {
return [TensorDataType(TF_RESOURCE)]
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(owning: tensorHandles!.pointee)
}

public init(handles: [_AnyTensorHandle]) {
precondition(handles.count == 1)
self.init(handle: handles[0])
}
}

extension VariantHandle: TensorGroup {
Expand All @@ -128,13 +146,20 @@ extension VariantHandle: TensorGroup {
return [TensorDataType(TF_VARIANT)]
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle] }

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: _cTensorHandle)
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(owning: tensorHandles!.pointee)
}

public init(handles: [_AnyTensorHandle]) {
precondition(handles.count == 1)
self.init(handle: handles[0])
}
}

extension Tensor: TensorGroup {
Expand All @@ -152,9 +177,16 @@ extension Tensor: TensorGroup {
address!.initialize(to: handle._cTensorHandle)
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
}

public init(handles: [_AnyTensorHandle]) {
precondition(handles.count == 1)
self.init(handle: TensorHandle(handle: handles[0]))
}
}

extension _TensorElementLiteral: TensorGroup {
Expand All @@ -168,13 +200,20 @@ extension _TensorElementLiteral: TensorGroup {
return [Scalar.tensorFlowDataType]
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }

public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
address!.initialize(to: handle._cTensorHandle)
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
}

public init(handles: [_AnyTensorHandle]) {
precondition(handles.count == 1)
self.init(handle: TensorHandle(handle: handles[0]))
}
}

extension StringTensor: TensorGroup {
Expand All @@ -192,9 +231,16 @@ extension StringTensor: TensorGroup {
address!.initialize(to: handle._cTensorHandle)
}

public var _tensorHandles: [_AnyTensorHandle] { [self.handle.handle] }

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
self.init(handle: TensorHandle(_owning: tensorHandles!.pointee))
}

public init(handles: [_AnyTensorHandle]) {
precondition(handles.count == 1)
self.init(handle: TensorHandle(handle: handles[0]))
}
}

extension Array: TensorArrayProtocol where Element: TensorGroup {
Expand All @@ -216,10 +262,29 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
count: Int(count)).joined())
}

public var _tensorHandles: ([_AnyTensorHandle]) {
var result: [_AnyTensorHandle] = []
result.reserveCapacity(Int(self._tensorHandleCount))
for elem in self {
result += elem._tensorHandles
}
return result
}

public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?, count: Int) {
let size = count / Int(Element._tensorHandleCount)
self = Array((0..<size).map { Element.init(
_owning: tensorHandles?.advanced(by: $0 * Int(Element._tensorHandleCount)))
})
}

public init(handles: [_AnyTensorHandle]) {
let size = handles.count / Int(Element._tensorHandleCount)
self = Array((0..<size).map {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The map(_:) method already produces an array. No need to call the array initializer.

Suggested change
self = Array((0..<size).map {
self = (0..<size).map {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

let start = $0 * Int(Element._tensorHandleCount)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix indentation here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

let end = start + Int(Element._tensorHandleCount)
let elemHandles = Array<_AnyTensorHandle>(handles[start..<end])
return Element.init(handles: elemHandles)
})
}
}
16 changes: 15 additions & 1 deletion Sources/TensorFlow/Core/TensorHandle.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible {
self.handle = TFETensorHandle(_owning: cTensorHandle)
}

public init(handle: _AnyTensorHandle) {
self.handle = handle
}

@usableFromInline
init(copyingFromCTensor cTensor: CTensor) {
let status = TF_NewStatus()
Expand Down Expand Up @@ -105,7 +109,7 @@ public struct TensorHandle<Scalar> where Scalar: _TensorFlowDataTypeCompatible {
extension TensorHandle where Scalar: TensorFlowScalar {
/// Create a `TensorHandle` with a closure that initializes the underlying buffer.
///
/// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a
/// `scalarsInitializer` receives a buffer with exactly enough capacity to hold the scalars in a
/// tensor with shape `shape`. `scalarsInitializer` must initialize the entire buffer, with
/// contiguous scalars in row-major order.
@inlinable
Expand Down Expand Up @@ -145,6 +149,11 @@ public struct ResourceHandle {
init(owning cTensorHandle: CTensorHandle) {
self.handle = TFETensorHandle(_owning: cTensorHandle)
}

@usableFromInline
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Drop @usableFromInline?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

init(handle: _AnyTensorHandle) {
self.handle = handle
}
}

public struct VariantHandle {
Expand All @@ -157,4 +166,9 @@ public struct VariantHandle {
init(owning cTensorHandle: CTensorHandle) {
self.handle = TFETensorHandle(_owning: cTensorHandle)
}

@usableFromInline
init(handle: _AnyTensorHandle) {
self.handle = handle
}
}
10 changes: 10 additions & 0 deletions Sources/TensorFlow/Operators/Dataset.swift
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,16 @@ public struct Zip2TensorGroup<T: TensorGroup, U: TensorGroup>: TensorGroup {
self.first = first
self.second = second
}

public var _tensorHandles: [_AnyTensorHandle] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public var _tensorHandles: [_AnyTensorHandle] {
public var _tensorHandles: [_AnyTensorHandle] {

first._tensorHandles + second._tensorHandles
}

public init(handles: [_AnyTensorHandle]) {
let firstEnd = Int(T._tensorHandleCount)
self.first = T.init(handles: Array(handles[0..<firstEnd]))
self.second = U.init(handles: Array(handles[firstEnd..<handles.count]))
}
}

// TODO(SR-9156): This does not work in graph mode.
Expand Down
8 changes: 8 additions & 0 deletions Tests/TensorFlowTests/OperatorTests/DatasetTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ import XCTest
struct SimpleOutput: TensorGroup {
let a: TensorHandle<Int32>
let b: TensorHandle<Int32>

init(handles: [_AnyTensorHandle]) {
precondition(handles.count == 2)
a = TensorHandle<Int32>(handle: handles[0])
b = TensorHandle<Int32>(handle: handles[1])
}

public var _tensorHandles: [_AnyTensorHandle] { [a.handle, b.handle] }
}

final class DatasetTests: XCTestCase {
Expand Down
Loading