diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index f460ca9bf76b0..43dbd2aacb059 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -118,6 +118,7 @@ IDENTIFIER(withArguments) IDENTIFIER(withKeywordArguments) // SWIFT_ENABLE_TENSORFLOW +IDENTIFIER(TensorFlow) // KeyPathIterable IDENTIFIER(AllKeyPaths) IDENTIFIER(allKeyPaths) diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def index 9f5a5a5656429..9506f1f4c23a2 100644 --- a/include/swift/AST/KnownProtocols.def +++ b/include/swift/AST/KnownProtocols.def @@ -83,6 +83,8 @@ PROTOCOL(FloatingPoint) PROTOCOL(KeyPathIterable) PROTOCOL(TensorArrayProtocol) PROTOCOL(TensorGroup) +PROTOCOL_(TensorFlowDataTypeCompatible) +PROTOCOL(TensorProtocol) PROTOCOL(VectorNumeric) PROTOCOL(Differentiable) diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 8249b987f1033..ad1ee8fd99c98 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -809,19 +809,62 @@ CanType ASTContext::getAnyObjectType() const { // SWIFT_ENABLE_TENSORFLOW -/// Retrieve the decl for TensorDataType. +/// Retrieve the decl for TensorFlow.TensorHandle iff the TensorFlow module has +/// been imported. Otherwise, this returns null. +ClassDecl *ASTContext::getTensorHandleDecl() const { + if (getImpl().TensorHandleDecl) + return getImpl().TensorHandleDecl; + + // See if the TensorFlow module was imported. If not, return null. + auto tfModule = getLoadedModule(Id_TensorFlow); + if (!tfModule) + return nullptr; + + SmallVector results; + tfModule->lookupValue({ }, getIdentifier("TensorHandle"), + NLKind::UnqualifiedLookup, results); + + for (auto result : results) + if (auto CD = dyn_cast(result)) + return getImpl().TensorHandleDecl = CD; + return nullptr; +} + +/// Retrieve the decl for TensorFlow.TensorShape iff the TensorFlow module has +/// been imported. Otherwise, this returns null. +StructDecl *ASTContext::getTensorShapeDecl() const { + if (getImpl().TensorShapeDecl) + return getImpl().TensorShapeDecl; + + // See if the TensorFlow module was imported. If not, return null. + auto tfModule = getLoadedModule(Id_TensorFlow); + if (!tfModule) + return nullptr; + + SmallVector results; + tfModule->lookupValue({}, getIdentifier("TensorShape"), + NLKind::UnqualifiedLookup, results); + + for (auto result : results) + if (auto CD = dyn_cast(result)) + return getImpl().TensorShapeDecl = CD; + return nullptr; +} + +/// Retrieve the decl for TensorFlow.TensorDataType iff the TensorFlow module has +/// been imported. Otherwise, this returns null. StructDecl *ASTContext::getTensorDataTypeDecl() const { if (getImpl().TensorDataTypeDecl) return getImpl().TensorDataTypeDecl; - // See if the Stdlib module was imported. If not, return null. - auto stdlibModule = getStdlibModule(); - if (!stdlibModule) + // See if the TensorFlow module was imported. If not, return null. + auto tfModule = getLoadedModule(Id_TensorFlow); + if (!tfModule) return nullptr; SmallVector results; - stdlibModule->lookupValue({}, getIdentifier("TensorDataType"), - NLKind::UnqualifiedLookup, results); + tfModule->lookupValue({}, getIdentifier("TensorDataType"), + NLKind::UnqualifiedLookup, results); for (auto result : results) if (auto CD = dyn_cast(result)) @@ -905,6 +948,10 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const { // SWIFT_ENABLE_TENSORFLOW case KnownProtocolKind::TensorArrayProtocol: case KnownProtocolKind::TensorGroup: + case KnownProtocolKind::TensorFlowDataTypeCompatible: + case KnownProtocolKind::TensorProtocol: + M = getLoadedModule(Id_TensorFlow); + break; default: M = getStdlibModule(); break; diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index 3a40e939f78ad..722ea38369446 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -4211,6 +4211,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { case KnownProtocolKind::KeyPathIterable: case KnownProtocolKind::TensorArrayProtocol: case KnownProtocolKind::TensorGroup: + case KnownProtocolKind::TensorFlowDataTypeCompatible: + case KnownProtocolKind::TensorProtocol: case KnownProtocolKind::VectorNumeric: case KnownProtocolKind::Differentiable: return SpecialProtocol::None; diff --git a/stdlib/public/core/CMakeLists.txt b/stdlib/public/core/CMakeLists.txt index f07e1ebd8c14a..fe5bf7a66894f 100644 --- a/stdlib/public/core/CMakeLists.txt +++ b/stdlib/public/core/CMakeLists.txt @@ -184,7 +184,6 @@ set(SWIFTLIB_ESSENTIAL StringGraphemeBreaking.swift # ORDER DEPENDENCY: Must follow UTF16.swift ValidUTF8Buffer.swift WriteBackMutableSlice.swift - HackyTensorflowMigrationSupport.swift MigrationSupport.swift) set(SWIFTLIB_ESSENTIAL_GYB_SOURCES diff --git a/stdlib/public/core/GroupInfo.json b/stdlib/public/core/GroupInfo.json index 6940f9afc1078..c5ef1e9f7e630 100644 --- a/stdlib/public/core/GroupInfo.json +++ b/stdlib/public/core/GroupInfo.json @@ -176,9 +176,6 @@ "AutoDiff": [ "AutoDiff.swift", ], - "HackyTensorflowMigrationSupport": [ - "HackyTensorflowMigrationSupport.swift", - ], "Optional": [ "Optional.swift" ], diff --git a/stdlib/public/core/HackyTensorflowMigrationSupport.swift b/stdlib/public/core/HackyTensorflowMigrationSupport.swift deleted file mode 100644 index e63f58f419e87..0000000000000 --- a/stdlib/public/core/HackyTensorflowMigrationSupport.swift +++ /dev/null @@ -1,80 +0,0 @@ -//===-- HackyTensorflowMigrationSupport.swift -----------------*- swift -*-===// -// -// This source file is part of the Swift.org open source project -// -// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors -// Licensed under Apache License v2.0 with Runtime Library Exception -// -// See https://swift.org/LICENSE.txt for license information -// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors -// -//===----------------------------------------------------------------------===// -// -// This file defines the TensorGroup,TensorDataType and TensorArrayProtocol -// types. -// -//===----------------------------------------------------------------------===// - -// This whole file is a hack in order to allow moving TensorFlow out of the -// swift compiler and to build it independently. There is an obviously more -// general version of this using associated types or something that can be -// given a general name. - -public typealias CTensorHandle = OpaquePointer - -/// A TensorFlow dynamic type value that can be created from types that conform -/// to `TensorFlowScalar`. -// This simply wraps a `TF_DataType` and allows user code to handle -// `TF_DataType` without importing CTensorFlow, which pollutes the namespace -// with TensorFlow C API declarations. -public struct TensorDataType { - public var _internalStorageType: UInt32 - - public init(rawValue: UInt32) { - self._internalStorageType = rawValue - } -} - -/// A protocol representing types that can be mapped to `Array`. -/// -/// This protocol is defined separately from `TensorGroup` in order for the -/// number of tensors to be determined at runtime. For example, -/// `[Tensor]` may have an unknown number of elements at compile time. -/// -/// This protocol can be derived automatically for structs whose stored -/// properties all conform to the `TensorGroup` protocol. It cannot be derived -/// automatically for structs whose properties all conform to -/// `TensorArrayProtocol` due to the constructor requirement (i.e., in such -/// cases it would be impossible to know how to break down `count` among the -/// stored properties). -public protocol TensorArrayProtocol { - /// Writes the tensor handles to `address`, which must be allocated - /// with enough capacity to hold `_tensorHandleCount` handles. The tensor - /// handles written to `address` are borrowed: this container still - /// owns them. - func _unpackTensorHandles(into address: UnsafeMutablePointer?) - - var _tensorHandleCount: Int32 { get } - var _typeList: [TensorDataType] { get } - - init(_owning tensorHandles: UnsafePointer?, count: Int) -} - -/// A protocol representing types that can be mapped to and from -/// `Array`. -/// -/// When a `TensorGroup` is used as an argument to a tensor operation, it is -/// passed as an argument list whose elements are the tensor fields of the type. -/// -/// When a `TensorGroup` is returned as a result of a tensor operation, it is -/// initialized with its tensor fields set to the tensor operation's tensor -/// results. -public protocol TensorGroup : TensorArrayProtocol { - - /// The types of the tensor stored properties in this type. - static var _typeList: [TensorDataType] { get } - - /// Initializes a value of this type, taking ownership of the - /// `_tensorHandleCount` tensors starting at address `tensorHandles`. - init(_owning tensorHandles: UnsafePointer?) -} diff --git a/test/Index/Store/unit-one-file-multi-file-invocation.swift b/test/Index/Store/unit-one-file-multi-file-invocation.swift index 15fec596fae31..057ddbbe68c50 100644 --- a/test/Index/Store/unit-one-file-multi-file-invocation.swift +++ b/test/Index/Store/unit-one-file-multi-file-invocation.swift @@ -13,8 +13,8 @@ // CHECK: [[SWIFT]] // CHECK: DEPEND START -// CHECK: Record | system | Swift.String | [[MODULE]] | {{.+}}.swiftmodule_String-{{.*}} // CHECK: Record | system | Swift.Math.Floating | [[MODULE]] | {{.+}}.swiftmodule_Math_Floating-{{.*}} +// CHECK: Record | system | Swift.String | [[MODULE]] | {{.+}}.swiftmodule_String-{{.*}} // CHECK: DEPEND END func test1() { diff --git a/utils/update_checkout/update-checkout-config.json b/utils/update_checkout/update-checkout-config.json index d7349324aee56..06391da553cd5 100644 --- a/utils/update_checkout/update-checkout-config.json +++ b/utils/update_checkout/update-checkout-config.json @@ -349,7 +349,7 @@ "clang-tools-extra": "swift-DEVELOPMENT-SNAPSHOT-2019-05-26-a", "libcxx": "swift-DEVELOPMENT-SNAPSHOT-2019-05-26-a", "tensorflow": "ebc41609e27dcf0998d8970e77a2e1f53e13ac86", - "tensorflow-swift-apis": "1d484e1826c7d4efff6d9eb66d6eb6722ce84f12", + "tensorflow-swift-apis": "835d1436a01d9261f0467bc2803cc7f6ac56ed80", "indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-05-26-a", "sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-05-26-a" }