diff --git a/.gitignore b/.gitignore index 567f66548..661f4ab2e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ xcuserdata DerivedData/ *.xcodeproj *~ +*.vscode +*.idea ### MacOS ### .DS_Store diff --git a/Sources/TensorFlow/Core/DataTypes.swift b/Sources/TensorFlow/Core/DataTypes.swift index 73ec15191..7e9618f3f 100644 --- a/Sources/TensorFlow/Core/DataTypes.swift +++ b/Sources/TensorFlow/Core/DataTypes.swift @@ -14,14 +14,18 @@ import CTensorFlow -public extension TensorDataType { - var _cDataType: TF_DataType { - return TF_DataType(rawValue: _internalStorageType) +/// A TensorFlow dynamic type value that can be created from types that conform to +/// `TensorFlowScalar`. +// This simply wraps a `TF_DataType` and allows user code to handle +// `TF_DataType` without importing CTensorFlow, which pollutes the namespace +// with TensorFlow C API declarations. +public struct TensorDataType { + public var _cDataType: TF_DataType + + @usableFromInline + internal init(_ cDataType: TF_DataType) { + self._cDataType = cDataType } - - init(_ cDataType: TF_DataType) { - self.init(rawValue: cDataType.rawValue) - } } @usableFromInline diff --git a/Sources/TensorFlow/Core/TensorGroup.swift b/Sources/TensorFlow/Core/TensorGroup.swift index f06809bda..4390e0bf4 100644 --- a/Sources/TensorFlow/Core/TensorGroup.swift +++ b/Sources/TensorFlow/Core/TensorGroup.swift @@ -14,6 +14,45 @@ import CTensorFlow +/// A protocol representing types that can be mapped to `Array`. +/// +/// This protocol is defined separately from `TensorGroup` in order for the number of tensors to be +/// determined at runtime. For example, `[Tensor]` may have an unknown number of elements at +/// compile time. +/// +/// This protocol can be derived automatically for structs whose stored properties all conform to +/// the `TensorGroup` protocol. It cannot be derived automatically for structs whose properties all +/// conform to `TensorArrayProtocol` due to the constructor requirement (i.e., in such cases it +/// would be impossible to know how to break down `count` among the stored properties). +public protocol TensorArrayProtocol { + /// Writes the tensor handles to `address`, which must be allocated with enough capacity to hold + /// `_tensorHandleCount` handles. The tensor handles written to `address` are borrowed: this + /// container still owns them. + func _unpackTensorHandles(into address: UnsafeMutablePointer?) + + var _tensorHandleCount: Int32 { get } + var _typeList: [TensorDataType] { get } + + init(_owning tensorHandles: UnsafePointer?, count: Int) +} + +/// A protocol representing types that can be mapped to and from `Array`. +/// +/// When a `TensorGroup` is used as an argument to a tensor operation, it is passed as an argument +/// list whose elements are the tensor fields of the type. +/// +/// When a `TensorGroup` is returned as a result of a tensor operation, it is initialized with its +/// tensor fields set to the tensor operation's tensor results. +public protocol TensorGroup: TensorArrayProtocol { + + /// The types of the tensor stored properties in this type. + static var _typeList: [TensorDataType] { get } + + /// Initializes a value of this type, taking ownership of the `_tensorHandleCount` tensors + /// starting at address `tensorHandles`. + init(_owning tensorHandles: UnsafePointer?) +} + public extension TensorGroup { /// The number of tensor fields in this type. static var _tensorHandleCount: Int32 { return Int32(Self._typeList.count) }