diff --git a/include/swift/AST/ASTContext.h b/include/swift/AST/ASTContext.h index 2a6c812bebc05..dd4f1c6dd8a51 100644 --- a/include/swift/AST/ASTContext.h +++ b/include/swift/AST/ASTContext.h @@ -107,7 +107,6 @@ namespace swift { class VarDecl; class UnifiedStatsReporter; // SWIFT_ENABLE_TENSORFLOW - enum class AutoDiffAssociatedVectorSpaceKind : unsigned; class VectorSpace; class AutoDiffParameterIndices; class DifferentiableAttr; @@ -276,8 +275,7 @@ class ASTContext final { llvm::StringMap RemappedTypes; /// Cache of autodiff-associated vector spaces. - llvm::DenseMap, - Optional> AutoDiffVectorSpaces; + llvm::DenseMap> AutoDiffVectorSpaces; /// Cache of `@differentiable` attributes keyed by parameter indices. This /// helps us diagnose multiple `@differentiable`s that are with respect to the diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 287c335408aaa..37c1cfcd697f9 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -551,11 +551,6 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode { } }; -/// The kind of an associated type. -enum class AutoDiffAssociatedVectorSpaceKind : unsigned { - Tangent = 0, Cotangent = 1 -}; - /// Automatic differentiation utility namespace. namespace autodiff { diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 995a7e4314e89..c8befd9b422cf 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2715,7 +2715,7 @@ NOTE(protocol_witness_missing_specific_differentiable_attr,none, // @differentiating ERROR(differentiating_attr_expected_result_tuple,none, "'@differentiating' attribute requires function to return a two-element tuple of type " - "'(value: T..., pullback: (U.CotangentVector) -> T.CotangentVector...)' or " + "'(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or " "'(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'", ()) ERROR(differentiating_attr_invalid_result_tuple_value_label,none, "'@differentiating' attribute requires function to return a two-element " diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index ceab2728714b8..af915296809f3 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -135,11 +135,9 @@ IDENTIFIER(zero) IDENTIFIER(Scalar) // Differentiable IDENTIFIER(AllDifferentiableVariables) -IDENTIFIER(CotangentVector) IDENTIFIER(TangentVector) IDENTIFIER(allDifferentiableVariables) IDENTIFIER(moved) -IDENTIFIER(tangentVector) // Kinds of layout constraints IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout") diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def index 5f5a1197e1a57..9506f1f4c23a2 100644 --- a/include/swift/AST/KnownProtocols.def +++ b/include/swift/AST/KnownProtocols.def @@ -86,9 +86,6 @@ PROTOCOL(TensorGroup) PROTOCOL_(TensorFlowDataTypeCompatible) PROTOCOL(TensorProtocol) PROTOCOL(VectorNumeric) -// TODO(TF-213): Remove underscore `Differentiable` protocols. -PROTOCOL(__Differentiable) -PROTOCOL(_Differentiable) PROTOCOL(Differentiable) PROTOCOL_(ObjectiveCBridgeable) diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 24b42df37d74a..b1c47933ed04b 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -1096,20 +1096,17 @@ class alignas(1 << TypeAlignInBits) TypeBase { TypeTraitResult canBeClass(); // SWIFT_ENABLE_TENSORFLOW - /// Return the associated tangent or cotangent type. Return the null type if - /// there is no associated tangent/cotangent type. - /// - /// `kind` specifies whether to return the tangent or cotangent type. + /// Return the associated tangent type. Return the null type if there is no + /// associated tangent type. /// /// If the type conforms to `Differentiable`, then the associated - /// tangent/cotangent type is the associated `TangentVector`/`CotangentVector` - /// from the `Differentiable` requirement. If the type is a tuple, then the - /// associated tangent/cotangent type is the elementwise tangent/cotangent - /// type of its elements. If the type is a builtin float, then the associated - /// tangent/cotangent type is itself. Otherwise, there is no associated type. + /// tangent type is the associated `TangentVector` from the `Differentiable` + /// requirement. If the type is a tuple, then the associated tangent type is + /// the elementwise tangent type of its elements. If the type is a builtin + /// float, then the associated tangent type is itself. Otherwise, there is no + /// associated type. Optional - getAutoDiffAssociatedVectorSpace(AutoDiffAssociatedVectorSpaceKind kind, - LookupConformanceFn lookupConformance); + getAutoDiffAssociatedTangentSpace(LookupConformanceFn lookupConformance); private: // Make vanilla new/delete illegal for Types. @@ -3074,12 +3071,12 @@ class AnyFunctionType : public TypeBase { /// /// By default, if the original type has a self parameter list and parameter /// indices include self, the computed associated function type will return a - /// linear map taking/returning self's tangent/cotangent *last* instead of - /// first, for consistency with SIL. + /// linear map taking/returning self's tangent *last* instead of first, for + /// consistency with SIL. /// - /// If `makeSelfParamFirst` is true, self's tangent/cotangent is reordered to - /// appear first. This should be used during type-checking, e.g. - /// type-checking `@differentiable` and `@differentiating` attributes. + /// If `makeSelfParamFirst` is true, self's tangent is reordered to appear + /// first. This should be used during type-checking, e.g. type-checking + /// `@differentiable` and `@differentiating` attributes. /// /// \note The original function type (`self`) need not be `@differentiable`. /// The resulting function will preserve all `ExtInfo` of the original diff --git a/lib/AST/Builtins.cpp b/lib/AST/Builtins.cpp index 6a7f8d928628f..1d28a95cd3f5f 100644 --- a/lib/AST/Builtins.cpp +++ b/lib/AST/Builtins.cpp @@ -985,7 +985,7 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction( // rethrows -> (R, (...T.TangentVector) -> R.TangentVector) // VJP: // <...T...(arity), R> (@differentiable (...T) throws -> R, ...T) - // rethrows -> (R, (R.CotangentVector) -> ...T.CotangentVector) + // rethrows -> (R, (R.TangentVector) -> ...T.TangentVector) unsigned numGenericParams = 1 + arity; BuiltinGenericSignatureBuilder builder(Context, numGenericParams); // Look up the Differentiable protocol. diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 87855dfd92cc6..55b4144f98437 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -4364,13 +4364,12 @@ makeFunctionType(AnyFunctionType *copy, ArrayRef params, return FunctionType::get(params, retTy, copy->getExtInfo()); } -Optional TypeBase::getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind kind, +Optional TypeBase::getAutoDiffAssociatedTangentSpace( LookupConformanceFn lookupConformance) { assert(lookupConformance); auto &ctx = getASTContext(); - std::pair cacheKey {this, (unsigned)kind}; + Type cacheKey = this; auto lookup = ctx.AutoDiffVectorSpaces.find(cacheKey); if (lookup != ctx.AutoDiffVectorSpaces.end()) return lookup->getSecond(); @@ -4379,11 +4378,11 @@ Optional TypeBase::getAutoDiffAssociatedVectorSpace( return vs; }; - // Functions' tangent/cotangent is the same function except the innermost - // return type being replaced by its tangent/cotangent. + // Functions' tangent is the same function except the innermost return type + // being replaced by its tangent. if (auto *fnTy = getAs()) { - auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedVectorSpace( - kind, lookupConformance); + auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedTangentSpace( + lookupConformance); if (!resultSpace) return cache(None); return cache(VectorSpace::getFunction( @@ -4391,12 +4390,12 @@ Optional TypeBase::getAutoDiffAssociatedVectorSpace( fnTy->getOptGenericSignature()))); } - // Tuples' tangent/cotangent is a tuple of each element's Tangent/Cotangent. + // Tuples' tangent is a tuple of each element's Tangent. if (auto *tupleTy = getAs()) { SmallVector newElts; for (auto elt : tupleTy->getElements()) { auto eltSpace = elt.getType() - ->getAutoDiffAssociatedVectorSpace(kind, lookupConformance); + ->getAutoDiffAssociatedTangentSpace(lookupConformance); if (!eltSpace) continue; newElts.push_back(elt.getWithType(eltSpace->getType())); @@ -4410,22 +4409,12 @@ Optional TypeBase::getAutoDiffAssociatedVectorSpace( return cache(VectorSpace::getTuple(tupleType)); } - // Find the TangentVector/CotangentVector associated type on the - // Differentiable protocol. + // Find the TangentVector associated type on the Differentiable protocol. auto *differentiableProtocol = - ctx.getProtocol(KnownProtocolKind::__Differentiable); - assert(differentiableProtocol && "Could not find __Differentiable protocol"); - Identifier associatedTypeIdentifier; - switch (kind) { - case AutoDiffAssociatedVectorSpaceKind::Tangent: - associatedTypeIdentifier = ctx.Id_TangentVector; - break; - case AutoDiffAssociatedVectorSpaceKind::Cotangent: - associatedTypeIdentifier = ctx.Id_CotangentVector; - break; - } + ctx.getProtocol(KnownProtocolKind::Differentiable); + assert(differentiableProtocol && "Could not find Differentiable protocol"); auto associatedTypeLookup = - differentiableProtocol->lookupDirect(associatedTypeIdentifier); + differentiableProtocol->lookupDirect(ctx.Id_TangentVector); assert(associatedTypeLookup.size() == 1); auto *dependentType = DependentMemberType::get( differentiableProtocol->getDeclaredInterfaceType(), @@ -4448,7 +4437,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( // JVP: (T...) -> ((R...), // (T.TangentVector...) -> (R.TangentVector...)) // VJP: (T...) -> ((R...), - // (R.CotangentVector...) -> (T.CotangentVector...)) + // (R.TangentVector...) -> (T.TangentVector...)) // // Note that both can be written as "(T...) -> ((R...), Closure)", so we build // "Closure" and then use common code to wrap "Closure" in the outer function @@ -4487,23 +4476,20 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( SmallVector differentialParams; for (auto wrtParamType : wrtParamTypes) differentialParams.push_back( - AnyFunctionType::Param(wrtParamType->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance) + AnyFunctionType::Param( + wrtParamType->getAutoDiffAssociatedTangentSpace(lookupConformance) ->getType())); SmallVector differentialResults; if (auto *resultTuple = originalResult->getAs()) { auto resultTupleEltType = resultTuple->getElementType(resultIndex); - differentialResults.push_back( - resultTupleEltType->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance) - ->getType()); + differentialResults.push_back(resultTupleEltType + ->getAutoDiffAssociatedTangentSpace(lookupConformance)->getType()); } else { assert(resultIndex == 0 && "resultIndex out of bounds"); differentialResults.push_back( - originalResult->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance) - ->getType()); + originalResult->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getType()); } Type differentialResult = differentialResults.size() > 1 @@ -4515,28 +4501,26 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType( } case AutoDiffAssociatedFunctionKind::VJP: { // closure is the VJP "pullback": - // (R.CotangentVector...) -> (T.CotangentVector...) + // (R.TangentVector...) -> (T.TangentVector...) SmallVector pullbackParams; if (auto *resultTuple = originalResult->getAs()) { auto resultTupleEltType = resultTuple->getElementType(resultIndex); pullbackParams.push_back( - AnyFunctionType::Param( - resultTupleEltType->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, - lookupConformance)->getType())); + AnyFunctionType::Param(resultTupleEltType + ->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getType())); } else { assert(resultIndex == 0 && "resultIndex out of bounds"); pullbackParams.push_back( - AnyFunctionType::Param( - originalResult->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, - lookupConformance)->getType())); + AnyFunctionType::Param(originalResult + ->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getType())); } SmallVector pullbackResults; for (auto wrtParamType : wrtParamTypes) - pullbackResults.push_back(wrtParamType->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance) + pullbackResults.push_back(wrtParamType + ->getAutoDiffAssociatedTangentSpace(lookupConformance) ->getType()); Type pullbackResult = pullbackResults.size() > 1 ? TupleType::get(pullbackResults, ctx) diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index 82c398e7460ae..a4fce9fa045f8 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -4192,9 +4192,6 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { case KnownProtocolKind::TensorFlowDataTypeCompatible: case KnownProtocolKind::TensorProtocol: case KnownProtocolKind::VectorNumeric: - // TODO(TF-213): Remove underscore `Differentiable` protocols. - case KnownProtocolKind::__Differentiable: - case KnownProtocolKind::_Differentiable: case KnownProtocolKind::Differentiable: return SpecialProtocol::None; } diff --git a/lib/SIL/SILFunctionType.cpp b/lib/SIL/SILFunctionType.cpp index ada1cefe6a865..2ce00297ef676 100644 --- a/lib/SIL/SILFunctionType.cpp +++ b/lib/SIL/SILFunctionType.cpp @@ -154,7 +154,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( // JVP: (T...) -> ((R...), // (T.TangentVector...) -> (R.TangentVector...)) // VJP: (T...) -> ((R...), - // (R.CotangentVector...) -> (T.CotangentVector...)) + // (R.TangentVector...) -> (T.TangentVector...)) auto &ctx = getASTContext(); auto &typeConverter = module.Types; @@ -164,9 +164,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( whereClauseGenSig = getGenericSignature(); // Given a type, returns its formal SIL parameter info. - auto getCotangentParameterInfoForOriginalResult = [&]( - CanType cotanType, ResultConvention origResConv) -> SILParameterInfo { - auto &tl = typeConverter.getTypeLowering(cotanType, ResilienceExpansion::Minimal); + auto getTangentParameterInfoForOriginalResult = [&]( + CanType tanType, ResultConvention origResConv) -> SILParameterInfo { + auto &tl = typeConverter.getTypeLowering(tanType, + ResilienceExpansion::Minimal); ParameterConvention conv; switch (origResConv) { case ResultConvention::Owned: @@ -183,13 +184,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( conv = ParameterConvention::Indirect_In_Guaranteed; break; } - return {cotanType, conv}; + return {tanType, conv}; }; // Given a type, returns its formal SIL result info. - auto getCotangentResultInfoForOriginalParameter = [&]( - CanType cotanType, ParameterConvention origParamConv) -> SILResultInfo { - auto &tl = typeConverter.getTypeLowering(cotanType, ResilienceExpansion::Minimal); + auto getTangentResultInfoForOriginalParameter = [&]( + CanType tanType, ParameterConvention origParamConv) -> SILResultInfo { + auto &tl = typeConverter.getTypeLowering(tanType, + ResilienceExpansion::Minimal); ResultConvention conv; switch (origParamConv) { case ParameterConvention::Direct_Owned: @@ -207,7 +209,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( conv = ResultConvention::Indirect; break; } - return {cotanType, conv}; + return {tanType, conv}; }; // Helper function testing if we are differentiating wrt this index. @@ -228,17 +230,15 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( SmallVector differentialParams; for (auto ¶m : wrtParams) { differentialParams.push_back( - {param.getType()->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance) - ->getCanonicalType(), + {param.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getCanonicalType(), param.getConvention()}); } SmallVector differentialResults; auto &result = getResults()[resultIndex]; differentialResults.push_back( - {result.getType()->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance) - ->getCanonicalType(), + {result.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getCanonicalType(), result.getConvention()}); closureType = SILFunctionType::get( /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None, @@ -249,22 +249,20 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType( case AutoDiffAssociatedFunctionKind::VJP: { SmallVector pullbackParams; auto &origRes = getResults()[resultIndex]; - auto cotangentAssocTy = - origRes.getType()->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance) - ->getCanonicalType(); + auto tangentAssocTy = + origRes.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getCanonicalType(); pullbackParams.push_back( - getCotangentParameterInfoForOriginalResult(cotangentAssocTy, - origRes.getConvention())); + getTangentParameterInfoForOriginalResult(tangentAssocTy, + origRes.getConvention())); SmallVector pullbackResults; for (auto ¶m : wrtParams) { - auto paramCotangentTy = - param.getType()->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance) - ->getCanonicalType(); + auto paramTangentTy = + param.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getCanonicalType(); pullbackResults.push_back( - getCotangentResultInfoForOriginalParameter(paramCotangentTy, - param.getConvention())); + getTangentResultInfoForOriginalParameter(paramTangentTy, + param.getConvention())); } closureType = SILFunctionType::get( /*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None, diff --git a/lib/SIL/SILType.cpp b/lib/SIL/SILType.cpp index cf4902dd51fb7..529cdce53c004 100644 --- a/lib/SIL/SILType.cpp +++ b/lib/SIL/SILType.cpp @@ -595,7 +595,6 @@ bool SILType::isLoweringOf(SILModule &Mod, CanType formalType) { // SWIFT_ENABLE_TENSORFLOW /// Returns true if this SILType is a differentiable type. bool SILType::isDifferentiable(SILModule &M) const { - return getASTType()->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Tangent, + return getASTType()->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(M.getSwiftModule())).hasValue(); } diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index e47e2fe32c934..42ea0c0c603ff 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -3444,8 +3444,8 @@ SILGenFunction::getOrCreateAutoDiffLinearMapReorderingThunk( break; } case AutoDiffAssociatedFunctionKind::VJP: { - auto selfCotanInfo = thunkConv.getResults().back(); - if (selfCotanInfo.isFormalDirect()) { + auto selfTanInfo = thunkConv.getResults().back(); + if (selfTanInfo.isFormalDirect()) { for (auto *indRes : indirectResults) argValues.push_back(indRes); } else { @@ -3468,8 +3468,8 @@ SILGenFunction::getOrCreateAutoDiffLinearMapReorderingThunk( break; } case AutoDiffAssociatedFunctionKind::VJP: { - auto selfCotanInfo = thunkConv.getResults().back(); - if (selfCotanInfo.isFormalIndirect()) { + auto selfTanInfo = thunkConv.getResults().back(); + if (selfTanInfo.isFormalIndirect()) { thunkSGF.B.createReturn(loc, apply); break; } @@ -3572,7 +3572,7 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionReorderingThunk( // Otherwise, generate a thunk for reordering: // - The differential self tangent parameter: move from first to last. - // - The pullback self cotangent result: move from first to last. + // - The pullback self tangent result: move from first to last. SmallVector directResults; extractAllElements(apply, thunkSGF.B, directResults); auto linearMap = directResults.back(); diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index f460023bc1edb..8da5b8df0d92e 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -489,8 +489,8 @@ enum class StructExtractDifferentiationStrategy { Inactive, // The `struct_extract` is extracting a field from a Differentiable struct - // with @_fieldwiseProductSpace cotangent space. Therefore, differentiate the - // `struct_extract` by setting the adjoint to a vector in the cotangent space + // with @_fieldwiseProductSpace tangent space. Therefore, differentiate the + // `struct_extract` by setting the adjoint to a vector in the tangent space // that is zero except along the direction of the corresponding field. // // Fields correspond by matching name. @@ -1369,8 +1369,7 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, if (assocGenSig && projType->hasArchetype()) projType = assocGenSig->getCanonicalTypeInContext( projType->mapTypeOutOfContext()); - if (projType->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + if (projType->getAutoDiffAssociatedTangentSpace( LookUpConformanceInSignature(*assocGenSig))) setVaried(teai, i); } @@ -3617,17 +3616,16 @@ class AdjointEmitter final : public SILInstructionVisitor { adjointGenEnv->getForwardingSubstitutionMap()); } - Optional getCotangentSpace(CanType type) { - return type->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + Optional getTangentSpace(CanType type) { + return type->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(getModule().getSwiftModule())); } /// Assuming the given type conforms to `Differentiable` after remapping, - /// returns the associated cotangent space type. - SILType getRemappedCotangentType(SILType type) { + /// returns the associated tangent space type. + SILType getRemappedTangentType(SILType type) { return SILType::getPrimitiveObjectType( - getCotangentSpace(remapType(type).getASTType())->getCanonicalType()); + getTangentSpace(remapType(type).getASTType())->getCanonicalType()); } //--------------------------------------------------------------------------// @@ -3659,7 +3657,7 @@ class AdjointEmitter final : public SILInstructionVisitor { assert(originalValue->getFunction() == &getOriginal()); auto insertion = valueMap.try_emplace( originalValue, makeZeroAdjointValue( - getRemappedCotangentType(originalValue->getType()))); + getRemappedTangentType(originalValue->getType()))); auto it = insertion.first; SWIFT_DEFER { valueMap.erase(it); }; return std::move(it->getSecond()); @@ -3673,12 +3671,11 @@ class AdjointEmitter final : public SILInstructionVisitor { LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue); #ifndef NDEBUG auto origTy = remapType(originalValue->getType()).getASTType(); - auto cotanSpace = origTy->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + auto tanSpace = origTy->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(getModule().getSwiftModule())); - // The adjoint value must be in the cotangent space. - assert(cotanSpace && newAdjointValue.getType().getASTType()->isEqual( - cotanSpace->getCanonicalType())); + // The adjoint value must be in the tangent space. + assert(tanSpace && newAdjointValue.getType().getASTType()->isEqual( + tanSpace->getCanonicalType())); #endif auto insertion = valueMap.try_emplace(originalValue, std::move(newAdjointValue)); @@ -3710,14 +3707,14 @@ class AdjointEmitter final : public SILInstructionVisitor { // Handle `struct_element_addr`. if (auto *seai = dyn_cast(originalProjection)) { auto adjSource = getAdjointBuffer(seai->getOperand()); - auto *cotangentVectorDecl = + auto *tangentVectorDecl = adjSource.getType().getStructOrBoundGenericStruct(); - auto cotanFieldLookup = - cotangentVectorDecl->lookupDirect(seai->getField()->getName()); - assert(cotanFieldLookup.size() == 1); - auto *cotanField = cast(cotanFieldLookup.front()); + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(seai->getField()->getName()); + assert(tanFieldLookup.size() == 1); + auto *tanField = cast(tanFieldLookup.front()); return builder.createStructElementAddr( - seai->getLoc(), adjSource.getValue(), cotanField); + seai->getLoc(), adjSource.getValue(), tanField); } // Handle `tuple_element_addr`. if (auto *teai = dyn_cast(originalProjection)) { @@ -3728,7 +3725,7 @@ class AdjointEmitter final : public SILInstructionVisitor { auto origTupleTy = source->getType().castTo(); unsigned adjIndex = 0; for (unsigned i : range(teai->getFieldNo())) { - if (getCotangentSpace( + if (getTangentSpace( origTupleTy->getElement(i).getType()->getCanonicalType())) ++adjIndex; } @@ -3786,7 +3783,7 @@ class AdjointEmitter final : public SILInstructionVisitor { // Allocate local buffer and initialize to zero. auto *newBuf = localAllocBuilder.createAllocStack( originalBuffer.getLoc(), - getRemappedCotangentType(originalBuffer->getType())); + getRemappedTangentType(originalBuffer->getType())); auto *access = localAllocBuilder.createBeginAccess( newBuf->getLoc(), newBuf, SILAccessKind::Init, SILAccessEnforcement::Static, /*noNestedConflict*/ true, @@ -4153,23 +4150,23 @@ class AdjointEmitter final : public SILInstructionVisitor { auto allResultsIt = allResults.begin(); for (unsigned i : applyInfo.actualIndices.parameters->getIndices()) { auto origArg = ai->getArgument(origNumIndRes + i); - auto cotan = *allResultsIt++; - // If a cotangent value corresponds to a non-desired parameter, it won't + auto tan = *allResultsIt++; + // If a tangent value corresponds to a non-desired parameter, it won't // be used, so release it. if (!applyInfo.desiredIndices.parameters->contains(i)) { - emitCleanup(builder, loc, cotan); + emitCleanup(builder, loc, tan); continue; } - if (cotan->getType().isAddress()) { - addToAdjointBuffer(origArg, cotan); - emitCleanup(builder, loc, cotan); + if (tan->getType().isAddress()) { + addToAdjointBuffer(origArg, tan); + emitCleanup(builder, loc, tan); } else { if (origArg->getType().isAddress()) { auto adjBuf = getAdjointBuffer(origArg); if (errorOccurred) return; - auto *tmpBuf = builder.createAllocStack(loc, cotan->getType()); - builder.createStore(loc, cotan, tmpBuf, + auto *tmpBuf = builder.createAllocStack(loc, tan->getType()); + builder.createStore(loc, tan, tmpBuf, getBufferSOQ(tmpBuf->getType().getASTType(), getAdjoint())); auto *readAccess = builder.createBeginAccess( loc, tmpBuf, SILAccessKind::Read, SILAccessEnforcement::Static, @@ -4181,7 +4178,7 @@ class AdjointEmitter final : public SILInstructionVisitor { } else addAdjointValue(origArg, makeConcreteAdjointValue(ValueWithCleanup( - cotan, makeCleanup(cotan, emitCleanup, {seed.getCleanup()})))); + tan, makeCleanup(tan, emitCleanup, {seed.getCleanup()})))); } } // Deallocate pullback indirect results. @@ -4204,51 +4201,50 @@ class AdjointEmitter final : public SILInstructionVisitor { for (auto *field : structDecl->getStoredProperties()) { auto fv = si->getFieldValue(field); addAdjointValue(fv, makeZeroAdjointValue( - getRemappedCotangentType(fv->getType()))); + getRemappedTangentType(fv->getType()))); } break; case AdjointValueKind::Concrete: { auto adjStruct = materializeAdjointDirect(std::move(av), loc); if (structDecl->getAttrs().hasAttribute()) { - // Find the struct `CotangentVector` type. + // Find the struct `TangentVector` type. auto structTy = remapType(si->getType()).getASTType(); - auto cotangentVectorTy = structTy->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(getModule().getSwiftModule())) ->getType()->getCanonicalType(); assert(!getModule().Types.getTypeLowering( - cotangentVectorTy, ResilienceExpansion::Minimal) + tangentVectorTy, ResilienceExpansion::Minimal) .isAddressOnly()); - auto *cotangentVectorDecl = - cotangentVectorTy->getStructOrBoundGenericStruct(); - assert(cotangentVectorDecl); + auto *tangentVectorDecl = + tangentVectorTy->getStructOrBoundGenericStruct(); + assert(tangentVectorDecl); // Accumulate adjoints for the fields of the `struct` operand. for (auto *field : structDecl->getStoredProperties()) { - // There does not exist a corresponding cotangent field for original + // There does not exist a corresponding tangent field for original // fields with `@noDerivative` attribute. Emit an error. if (field->getAttrs().hasAttribute()) continue; - // Find the corresponding field in the cotangent space. - VarDecl *cotanField = nullptr; - if (cotangentVectorDecl == structDecl) - cotanField = field; + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + if (tangentVectorDecl == structDecl) + tanField = field; // Otherwise, look up the field by name. else { - auto cotanFieldLookup = - cotangentVectorDecl->lookupDirect(field->getName()); - assert(cotanFieldLookup.size() == 1); - cotanField = cast(cotanFieldLookup.front()); + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(field->getName()); + assert(tanFieldLookup.size() == 1); + tanField = cast(tanFieldLookup.front()); } auto *adjStructElt = - builder.createStructExtract(loc, adjStruct, cotanField); + builder.createStructExtract(loc, adjStruct, tanField); addAdjointValue( si->getFieldValue(field), makeConcreteAdjointValue(ValueWithCleanup( adjStructElt, makeCleanup(adjStructElt, emitCleanup)))); } } else { - // FIXME(TF-21): If `CotangentVector` is not marked + // FIXME(TF-21): If `TangentVector` is not marked // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. llvm_unreachable("Unhandled. Are you trying to differentiate a " "memberwise initializer?"); @@ -4256,7 +4252,7 @@ class AdjointEmitter final : public SILInstructionVisitor { break; } case AdjointValueKind::Aggregate: { - // FIXME(TF-21): If `CotangentVector` is not marked + // FIXME(TF-21): If `TangentVector` is not marked // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. // for (auto pair : llvm::zip(si->getElements(), av.getAggregateElements())) // addAdjointValue(std::get<0>(pair), std::get<1>(pair)); @@ -4282,48 +4278,47 @@ class AdjointEmitter final : public SILInstructionVisitor { // Compute adjoint as follows: // y = struct_extract x, #key // adj[x] += struct (0, ..., #key': adj[y], ..., 0) - // where `#key'` is the field in the cotangent space corresponding to + // where `#key'` is the field in the tangent space corresponding to // `#key`. auto structTy = remapType(sei->getOperand()->getType()).getASTType(); - auto cotangentVectorTy = structTy->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + auto tangentVectorTy = structTy->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(getModule().getSwiftModule())) ->getType()->getCanonicalType(); assert(!getModule().Types.getTypeLowering( - cotangentVectorTy, ResilienceExpansion::Minimal) + tangentVectorTy, ResilienceExpansion::Minimal) .isAddressOnly()); - auto cotangentVectorSILTy = - SILType::getPrimitiveObjectType(cotangentVectorTy); - auto *cotangentVectorDecl = - cotangentVectorTy->getStructOrBoundGenericStruct(); - assert(cotangentVectorDecl); - // Find the corresponding field in the cotangent space. - VarDecl *cotanField = nullptr; - // If the cotangent space is the original struct, then field is the same. - if (cotangentVectorDecl == sei->getStructDecl()) - cotanField = sei->getField(); + auto tangentVectorSILTy = + SILType::getPrimitiveObjectType(tangentVectorTy); + auto *tangentVectorDecl = + tangentVectorTy->getStructOrBoundGenericStruct(); + assert(tangentVectorDecl); + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + // If the tangent space is the original struct, then field is the same. + if (tangentVectorDecl == sei->getStructDecl()) + tanField = sei->getField(); // Otherwise, look up the field by name. else { - auto cotanFieldLookup = - cotangentVectorDecl->lookupDirect(sei->getField()->getName()); - assert(cotanFieldLookup.size() == 1); - cotanField = cast(cotanFieldLookup.front()); + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(sei->getField()->getName()); + assert(tanFieldLookup.size() == 1); + tanField = cast(tanFieldLookup.front()); } // Accumulate adjoint for the `struct_extract` operand. auto av = takeAdjointValue(sei); switch (av.getKind()) { case AdjointValueKind::Zero: addAdjointValue(sei->getOperand(), - makeZeroAdjointValue(cotangentVectorSILTy)); + makeZeroAdjointValue(tangentVectorSILTy)); break; case AdjointValueKind::Concrete: case AdjointValueKind::Aggregate: { SmallVector eltVals; - for (auto *field : cotangentVectorDecl->getStoredProperties()) { - if (field == cotanField) { + for (auto *field : tangentVectorDecl->getStoredProperties()) { + if (field == tanField) { eltVals.push_back(av); } else { - auto substMap = cotangentVectorTy->getMemberSubstitutionMap( + auto substMap = tangentVectorTy->getMemberSubstitutionMap( field->getModuleContext(), field); auto fieldTy = field->getType().subst(substMap); auto fieldSILTy = @@ -4334,7 +4329,7 @@ class AdjointEmitter final : public SILInstructionVisitor { } } addAdjointValue(sei->getOperand(), - makeAggregateAdjointValue(cotangentVectorSILTy, eltVals)); + makeAggregateAdjointValue(tangentVectorSILTy, eltVals)); } } return; @@ -4373,17 +4368,17 @@ class AdjointEmitter final : public SILInstructionVisitor { switch (av.getKind()) { case AdjointValueKind::Zero: for (auto eltVal : ti->getElements()) { - if (!getCotangentSpace(eltVal->getType().getASTType())) + if (!getTangentSpace(eltVal->getType().getASTType())) continue; addAdjointValue(eltVal, makeZeroAdjointValue( - getRemappedCotangentType(eltVal->getType()))); + getRemappedTangentType(eltVal->getType()))); } break; case AdjointValueKind::Concrete: { auto val = av.getConcreteValue(); unsigned adjIdx = 0; for (auto i : range(ti->getNumOperands())) { - if (!getCotangentSpace(ti->getOperand(i)->getType().getASTType())) + if (!getTangentSpace(ti->getOperand(i)->getType().getASTType())) continue; auto adjElt = val; if (val.getType().is()) @@ -4396,7 +4391,7 @@ class AdjointEmitter final : public SILInstructionVisitor { case AdjointValueKind::Aggregate: unsigned adjIdx = 0; for (auto i : range(ti->getElements().size())) { - if (!getCotangentSpace(ti->getElement(i)->getType().getASTType())) + if (!getTangentSpace(ti->getElement(i)->getType().getASTType())) continue; addAdjointValue(ti->getElement(i), av.takeAggregateElement(adjIdx++)); } @@ -4409,32 +4404,32 @@ class AdjointEmitter final : public SILInstructionVisitor { /// |--- n-th element /// adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0) void visitTupleExtractInst(TupleExtractInst *tei) { - auto tupleCotanTy = getRemappedCotangentType(tei->getOperand()->getType()); + auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType()); auto av = takeAdjointValue(tei); switch (av.getKind()) { case AdjointValueKind::Zero: - addAdjointValue(tei->getOperand(), makeZeroAdjointValue(tupleCotanTy)); + addAdjointValue(tei->getOperand(), makeZeroAdjointValue(tupleTanTy)); break; case AdjointValueKind::Aggregate: case AdjointValueKind::Concrete: { auto tupleTy = tei->getTupleType(); - auto tupleCotanTupleTy = tupleCotanTy.getAs(); - if (!tupleCotanTupleTy) { + auto tupleTanTupleTy = tupleTanTy.getAs(); + if (!tupleTanTupleTy) { addAdjointValue(tei->getOperand(), std::move(av)); break; } SmallVector elements; unsigned adjIdx = 0; for (unsigned i : range(tupleTy->getNumElements())) { - if (!getCotangentSpace( + if (!getTangentSpace( tupleTy->getElement(i).getType()->getCanonicalType())) continue; if (tei->getFieldNo() == i) elements.push_back(av); else elements.push_back(makeZeroAdjointValue( - getRemappedCotangentType(SILType::getPrimitiveObjectType( - tupleCotanTupleTy->getElementType(adjIdx++) + getRemappedTangentType(SILType::getPrimitiveObjectType( + tupleTanTupleTy->getElementType(adjIdx++) ->getCanonicalType())))); } if (elements.size() == 1) { @@ -4442,7 +4437,7 @@ class AdjointEmitter final : public SILInstructionVisitor { break; } addAdjointValue(tei->getOperand(), - makeAggregateAdjointValue(tupleCotanTy, elements)); + makeAggregateAdjointValue(tupleTanTy, elements)); break; } } @@ -4462,9 +4457,9 @@ class AdjointEmitter final : public SILInstructionVisitor { // Handle `dealloc_stack` instruction. // Original: dealloc_stack y - // Adjoint: adj[y] = alloc_stack $T.CotangentVector + // Adjoint: adj[y] = alloc_stack $T.TangentVector void visitDeallocStackInst(DeallocStackInst *dsi) { - auto bufType = getRemappedCotangentType(dsi->getOperand()->getType()); + auto bufType = getRemappedTangentType(dsi->getOperand()->getType()); auto *adjBuf = builder.createAllocStack(dsi->getLoc(), bufType); auto *access = builder.createBeginAccess(dsi->getLoc(), adjBuf, SILAccessKind::Init, @@ -4832,11 +4827,10 @@ void AdjointEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, auto *swiftMod = getModule().getSwiftModule(); // TODO(TF-202): Diagnose no `AdditiveArithmetic` due to generic signature // minimization bug. - auto cotangentSpace = type->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + auto tangentSpace = type->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(swiftMod)); - assert(cotangentSpace && "No tangent space for this type"); - switch (cotangentSpace->getKind()) { + assert(tangentSpace && "No tangent space for this type"); + switch (tangentSpace->getKind()) { case VectorSpace::Kind::Vector: { // Look up conformance to `AdditiveArithmetic`. auto *additiveArithmeticProto = @@ -4867,7 +4861,7 @@ void AdjointEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess, return; } case VectorSpace::Kind::Tuple: { - auto tupleType = cotangentSpace->getTuple(); + auto tupleType = tangentSpace->getTuple(); SmallVector zeroElements; for (unsigned i : range(tupleType->getNumElements())) { auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i); @@ -4996,11 +4990,10 @@ SILValue AdjointEmitter::accumulateDirect(SILValue lhs, SILValue rhs) { auto adjointASTTy = adjointTy.getASTType(); auto loc = lhs.getLoc(); auto *swiftMod = getModule().getSwiftModule(); - auto cotangentSpace = adjointASTTy->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + auto tangentSpace = adjointASTTy->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(swiftMod)); - assert(cotangentSpace && "No tangent space for this type"); - switch (cotangentSpace->getKind()) { + assert(tangentSpace && "No tangent space for this type"); + switch (tangentSpace->getKind()) { case VectorSpace::Kind::Vector: { // Allocate buffers for inputs and output. auto *resultBuf = builder.createAllocStack(loc, adjointTy); @@ -5048,7 +5041,7 @@ SILValue AdjointEmitter::accumulateDirect(SILValue lhs, SILValue rhs) { return val; } case VectorSpace::Kind::Tuple: { - auto tupleType = cotangentSpace->getTuple(); + auto tupleType = tangentSpace->getTuple(); SmallVector adjElements; for (unsigned i : range(tupleType->getNumElements())) { auto *eltLHS = builder.createTupleExtract(loc, lhs, i); @@ -5076,17 +5069,16 @@ void AdjointEmitter::accumulateIndirect( auto adjointTy = lhsBufAccess->getType(); auto adjointASTTy = adjointTy.getASTType(); auto *swiftMod = getModule().getSwiftModule(); - auto cotangentSpace = adjointASTTy->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + auto tangentSpace = adjointASTTy->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(swiftMod)); - assert(cotangentSpace && "No tangent space for this type"); - switch (cotangentSpace->getKind()) { + assert(tangentSpace && "No tangent space for this type"); + switch (tangentSpace->getKind()) { case VectorSpace::Kind::Vector: { auto *proto = getContext().getAdditiveArithmeticProtocol(); auto *combinerFuncDecl = getContext().getPlusDecl(); // Call the combiner function and return. - auto adjointParentModule = cotangentSpace->getNominal() - ? cotangentSpace->getNominal()->getModuleContext() + auto adjointParentModule = tangentSpace->getNominal() + ? tangentSpace->getNominal()->getModuleContext() : getModule().getSwiftModule(); auto confRef = adjointParentModule->lookupConformance(adjointASTTy, proto); // TODO(TF-202): Diagnose no `AdditiveArithmetic` due to generic signature @@ -5112,7 +5104,7 @@ void AdjointEmitter::accumulateIndirect( return; } case VectorSpace::Kind::Tuple: { - auto tupleType = cotangentSpace->getTuple(); + auto tupleType = tangentSpace->getTuple(); for (unsigned i : range(tupleType->getNumElements())) { auto *destAddr = builder.createTupleElementAddr(loc, resultBufAccess, i); auto *eltAddrLHS = builder.createTupleElementAddr(loc, lhsBufAccess, i); @@ -5138,11 +5130,10 @@ void AdjointEmitter::accumulateIndirect(SILValue lhsDestAccess, auto type = lhsDestAccess->getType(); auto astType = type.getASTType(); auto *swiftMod = getModule().getSwiftModule(); - auto cotangentSpace = astType->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, + auto tangentSpace = astType->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(swiftMod)); - assert(cotangentSpace && "No tangent space for this type"); - switch (cotangentSpace->getKind()) { + assert(tangentSpace && "No tangent space for this type"); + switch (tangentSpace->getKind()) { case VectorSpace::Kind::Vector: { auto *proto = getContext().getAdditiveArithmeticProtocol(); auto *accumulatorFuncDecl = getContext().getPlusEqualDecl(); @@ -5168,7 +5159,7 @@ void AdjointEmitter::accumulateIndirect(SILValue lhsDestAccess, return; } case VectorSpace::Kind::Tuple: { - auto tupleType = cotangentSpace->getTuple(); + auto tupleType = tangentSpace->getTuple(); for (unsigned i : range(tupleType->getNumElements())) { auto *destAddr = builder.createTupleElementAddr(loc, lhsDestAccess, i); auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsAccess, i); @@ -5361,10 +5352,10 @@ void DifferentiationTask::createEmptyAdjoint() { module.Types, origTy->getGenericSignature()); // Given a type, returns its formal SIL parameter info. - auto getCotangentParameterInfoForOriginalResult = [&]( - CanType cotanType, ResultConvention origResConv) -> SILParameterInfo { + auto getTangentParameterInfoForOriginalResult = [&]( + CanType tanType, ResultConvention origResConv) -> SILParameterInfo { auto &tl = context.getTypeConverter().getTypeLowering( - cotanType, ResilienceExpansion::Minimal); + tanType, ResilienceExpansion::Minimal); ParameterConvention conv; switch (origResConv) { case ResultConvention::Owned: @@ -5381,14 +5372,14 @@ void DifferentiationTask::createEmptyAdjoint() { conv = ParameterConvention::Indirect_In_Guaranteed; break; } - return {cotanType, conv}; + return {tanType, conv}; }; // Given a type, returns its formal SIL result info. - auto getCotangentResultInfoForOriginalParameter = [&]( - CanType cotanType, ParameterConvention origParamConv) -> SILResultInfo { + auto getTangentResultInfoForOriginalParameter = [&]( + CanType tanType, ParameterConvention origParamConv) -> SILResultInfo { auto &tl = context.getTypeConverter().getTypeLowering( - cotanType, ResilienceExpansion::Minimal); + tanType, ResilienceExpansion::Minimal); ResultConvention conv; switch (origParamConv) { case ParameterConvention::Direct_Owned: @@ -5406,7 +5397,7 @@ void DifferentiationTask::createEmptyAdjoint() { conv = ResultConvention::Indirect; break; } - return {cotanType, conv}; + return {tanType, conv}; }; // Parameters of the adjoint are: @@ -5414,7 +5405,7 @@ void DifferentiationTask::createEmptyAdjoint() { // - a primal value struct, // - original results, and // - the original parameters. - // Results of the adjoint are in the cotangent space of the original + // Results of the adjoint are in the tangent space of the original // parameters. SmallVector adjParams; SmallVector adjResults; @@ -5422,10 +5413,9 @@ void DifferentiationTask::createEmptyAdjoint() { // Add adjoint parameter for the seed. auto origResInfo = origTy->getResults()[getIndices().source]; - adjParams.push_back(getCotangentParameterInfoForOriginalResult( + adjParams.push_back(getTangentParameterInfoForOriginalResult( origResInfo.getType() - ->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance) + ->getAutoDiffAssociatedTangentSpace(lookupConformance) ->getCanonicalType(), origResInfo.getConvention())); // Accept a primal value struct in the adjoint parameter list. This is the @@ -5437,10 +5427,9 @@ void DifferentiationTask::createEmptyAdjoint() { // Add adjoint results for the original differentiation parameters. for (auto i : getIndices().parameters->getIndices()) { auto origParam = origParams[i]; - adjResults.push_back(getCotangentResultInfoForOriginalParameter( + adjResults.push_back(getTangentResultInfoForOriginalParameter( origParam.getType() - ->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance) + ->getAutoDiffAssociatedTangentSpace(lookupConformance) ->getCanonicalType(), origParam.getConvention())); } diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 69c14612acfc4..074df8a6bfec8 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -88,7 +88,7 @@ static StructDecl *convertToStructDecl(ValueDecl *v) { // conformances. static Type getAssociatedType(VarDecl *decl, DeclContext *DC, Identifier id) { auto &C = decl->getASTContext(); - auto *diffableProto = C.getProtocol(KnownProtocolKind::__Differentiable); + auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); if (!decl->hasInterfaceType()) C.getLazyResolver()->resolveDeclSignature(decl); auto varType = DC->mapTypeIntoContext(decl->getValueInterfaceType()); @@ -106,7 +106,7 @@ static Type getAssociatedType(VarDecl *decl, DeclContext *DC, Identifier id) { static StructDecl *getAssociatedStructDecl(DeclContext *DC, Identifier id) { assert(DC->getSelfNominalTypeDecl() && "Must be a nominal `DeclContext`"); auto &C = DC->getASTContext(); - auto *diffableProto = C.getProtocol(KnownProtocolKind::__Differentiable); + auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); assert(diffableProto && "`Differentiable` protocol not found"); auto conf = TypeChecker::conformsToProtocol(DC->getSelfTypeInContext(), diffableProto, @@ -131,11 +131,10 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); - // Nominal type must not customize `TangentVector`, `CotangentVector`, or + // Nominal type must not customize `TangentVector` or // `AllDifferentiableVariables` to anything other than `Self`. // Otherwise, synthesis is semantically unsupported. auto tangentDecls = nominal->lookupDirect(C.Id_TangentVector); - auto cotangentDecls = nominal->lookupDirect(C.Id_CotangentVector); auto allDiffableVarsDecls = nominal->lookupDirect(C.Id_AllDifferentiableVariables); auto nominalTypeInContext = @@ -153,7 +152,7 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, return structDecl; // 2. Equal nominal's implicit parent. // This can occur during mutually recursive constraints. Example: - // `X == X.TangentVector, X.CotangentVector.CotangentVector == X`. + // `X == X.TangentVector`. if (nominal->isImplicit() && structDecl == nominal->getDeclContext() && TypeChecker::conformsToProtocol(structDecl->getDeclaredInterfaceType(), diffableProto, DC, @@ -175,10 +174,6 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, auto invalidTangentDecls = llvm::partition(tangentDecls, [&](ValueDecl *v) { return isValidAssocTypeCandidate(v, /*checkAdditiveArithmetic*/ true); }); - auto invalidCotangentDecls = - llvm::partition(cotangentDecls, [&](ValueDecl *v) { - return isValidAssocTypeCandidate(v, /*checkAdditiveArithmetic*/ true); - }); auto invalidAllDiffableVarsDecls = llvm::partition(allDiffableVarsDecls, isValidAssocTypeCandidate); @@ -186,10 +181,6 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, std::distance(tangentDecls.begin(), invalidTangentDecls); auto invalidTangentDeclCount = std::distance(invalidTangentDecls, tangentDecls.end()); - auto validCotangentDeclCount = - std::distance(cotangentDecls.begin(), invalidCotangentDecls); - auto invalidCotangentDeclCount = - std::distance(invalidCotangentDecls, cotangentDecls.end()); auto validAllDiffableVarsDeclCount = std::distance(allDiffableVarsDecls.begin(), invalidAllDiffableVarsDecls); auto invalidAllDiffableVarsDeclCount = @@ -198,10 +189,8 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, // There cannot be any invalid associated types. There can be at most one // valid associated type. if (invalidTangentDeclCount != 0 || - invalidCotangentDeclCount != 0 || invalidAllDiffableVarsDeclCount != 0 || validTangentDeclCount > 1 || - validCotangentDeclCount > 1 || validAllDiffableVarsDeclCount > 1) return false; @@ -247,7 +236,7 @@ static void deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, retNominalTypeExpr); // Get method protocol requirement. - auto *diffProto = C.getProtocol(KnownProtocolKind::__Differentiable); + auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable); auto *methodReq = getProtocolRequirement(diffProto, methodName); // Get references to `self` and parameter declarations. @@ -258,18 +247,6 @@ static void deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, auto *paramDRE = new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true); - // If this is the `tangentVector(from:)` method and the `TangentVector` and - // `CotangentVector` types are identical, simply return the parameter - // `cotangent` expression. This is more efficient than constructing a new - // `TangentVector` instance, which is unnecessary. - if (methodName == C.Id_tangentVector && - retNominalInterfaceType->isEqual(paramDecl->getInterfaceType())) { - ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), paramDRE, true); - funcDecl->setBody( - BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true)); - return; - } - // Hash properties for differentiation into a set for fast lookup. SmallVector diffProps; getStoredPropertiesForDifferentiation(nominal, parentDC, diffProps); @@ -367,14 +344,6 @@ static void deriveBodyDifferentiable_moved(AbstractFunctionDecl *funcDecl, C.getIdentifier("along")); } -// Synthesize body for `tangentVector(from:)`. -static void -deriveBodyDifferentiable_tangentVector(AbstractFunctionDecl *funcDecl, void *) { - auto &C = funcDecl->getASTContext(); - deriveBodyDifferentiable_method(funcDecl, C.Id_tangentVector, - C.getIdentifier("from")); -} - // Synthesize function declaration for a `Differentiable` method requirement. static ValueDecl *deriveDifferentiable_method( DerivedConformance &derived, Identifier methodName, Identifier argumentName, @@ -445,24 +414,6 @@ static ValueDecl *deriveDifferentiable_moved(DerivedConformance &derived) { {deriveBodyDifferentiable_moved, nullptr}); } -// Synthesize the `tangentVector(from:)` function declaration. -static ValueDecl * -deriveDifferentiable_tangentVector(DerivedConformance &derived) { - auto parentDC = derived.getConformanceContext(); - auto &C = derived.TC.Context; - - auto *tangentDecl = getAssociatedStructDecl(parentDC, C.Id_TangentVector); - auto tangentType = tangentDecl->getDeclaredInterfaceType(); - - auto *cotangentDecl = getAssociatedStructDecl(parentDC, C.Id_CotangentVector); - auto cotangentType = cotangentDecl->getDeclaredInterfaceType(); - - return deriveDifferentiable_method( - derived, C.Id_tangentVector, C.getIdentifier("from"), - C.getIdentifier("cotangent"), cotangentType, tangentType, - {deriveBodyDifferentiable_tangentVector, nullptr}); -} - // Return the underlying `allDifferentiableVariables` of a VarDecl `x`. // If `x` conforms to `Differentiable`, return `allDifferentiableVariables`. // Otherwise, return `x`. @@ -470,7 +421,7 @@ static ValueDecl *getUnderlyingAllDiffableVariables(DeclContext *DC, VarDecl *varDecl) { auto *module = DC->getParentModule(); auto &C = module->getASTContext(); - auto *diffableProto = C.getProtocol(KnownProtocolKind::__Differentiable); + auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); auto allDiffableVarsReq = getProtocolRequirement(diffableProto, C.Id_allDifferentiableVariables); if (!varDecl->hasInterfaceType()) @@ -646,8 +597,8 @@ deriveDifferentiable_allDifferentiableVariables(DerivedConformance &derived) { return allDiffableVarsDecl; } -// Return associated `TangentVector`, `CotangentVector`, or -// `AllDifferentiableVariables` struct for a nominal type, if it exists. +// Return associated `TangentVector` or `AllDifferentiableVariables` struct for +// a nominal type, if it exists. // If not, synthesize the struct. Also return a Boolean value that indicates // whether synthesis occurred. static std::pair @@ -658,8 +609,7 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, auto nominal = derived.Nominal; auto &C = nominal->getASTContext(); - assert(id == C.Id_TangentVector || id == C.Id_CotangentVector || - id == C.Id_AllDifferentiableVariables); + assert(id == C.Id_TangentVector || id == C.Id_AllDifferentiableVariables); // If the associated struct already exists, return it. auto lookup = nominal->lookupDirect(id); @@ -688,9 +638,9 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, SmallVector diffProperties; getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); - // If the associated type is `TangentVector` or `CotangentVector`, make it - // also conform to `AdditiveArithmetic`. - if (id == C.Id_TangentVector || id == C.Id_CotangentVector) + // If the associated type is `TangentVector`, make it also conform to + // `AdditiveArithmetic`. + if (id == C.Id_TangentVector) inherited.push_back(addArithType); // Associated struct can derive `AdditiveArithmetic` if the associated types @@ -923,7 +873,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, } // Get or synthesize all associated struct types: `TangentVector`, -// `CotangentVector`, and `AllDifferentiableVariables`. +// and `AllDifferentiableVariables`. // Return the type corresponding to the given identifier. static Type getOrSynthesizeAssociatedStructType(DerivedConformance &derived, @@ -933,10 +883,9 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived, auto *nominal = derived.Nominal; auto &C = nominal->getASTContext(); - // Get or synthesize `AllDifferentiableVariables`, `TangentVector`, and - // `CotangentVector` structs at once. Synthesizing all three structs at once - // is necessary in order to correctly set their mutually recursive associated - // types. + // Get or synthesize `AllDifferentiableVariables` and `TangentVector` structs + // at once. Synthesizing all three structs at once is necessary in order to + // correctly set their mutually recursive associated types. auto allDiffableVarsStructSynthesis = getOrSynthesizeSingleAssociatedStruct(derived, C.Id_AllDifferentiableVariables); @@ -952,13 +901,6 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived, return nullptr; freshlySynthesized |= tangentStructSynthesis.second; - auto cotangentStructSynthesis = - getOrSynthesizeSingleAssociatedStruct(derived, C.Id_CotangentVector); - auto *cotangentStruct = cotangentStructSynthesis.first; - if (!cotangentStruct) - return nullptr; - freshlySynthesized |= cotangentStructSynthesis.second; - // When all structs are freshly synthesized, we check emit warnings for // implicit `@noDerivative` members. Checking for fresh synthesis is necessary // because `getOrSynthesizeAssociatedStructType` will be called multiple times @@ -969,41 +911,23 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived, // Add associated typealiases for structs. addAssociatedTypeAliasDecl(C.Id_TangentVector, tangentStruct, tangentStruct, TC); - addAssociatedTypeAliasDecl(C.Id_TangentVector, - cotangentStruct, cotangentStruct, TC); addAssociatedTypeAliasDecl(C.Id_TangentVector, allDiffableVarsStruct, tangentStruct, TC); - addAssociatedTypeAliasDecl(C.Id_CotangentVector, - tangentStruct, cotangentStruct, TC); - addAssociatedTypeAliasDecl(C.Id_CotangentVector, - cotangentStruct, tangentStruct, TC); - addAssociatedTypeAliasDecl(C.Id_CotangentVector, - allDiffableVarsStruct, cotangentStruct, TC); - addAssociatedTypeAliasDecl(C.Id_AllDifferentiableVariables, allDiffableVarsStruct, allDiffableVarsStruct, TC); addAssociatedTypeAliasDecl(C.Id_AllDifferentiableVariables, tangentStruct, tangentStruct, TC); - addAssociatedTypeAliasDecl(C.Id_AllDifferentiableVariables, - cotangentStruct, cotangentStruct, TC); TC.validateDecl(allDiffableVarsStruct); TC.validateDecl(tangentStruct); - TC.validateDecl(cotangentStruct); // Sanity checks for synthesized structs. assert(DerivedConformance::canDeriveAdditiveArithmetic(tangentStruct, parentDC) && "Should be able to derive `AdditiveArithmetic`"); - assert(DerivedConformance::canDeriveAdditiveArithmetic(cotangentStruct, - parentDC) && - "Should be able to derive `AdditiveArithmetic`"); assert(DerivedConformance::canDeriveDifferentiable( tangentStruct, parentDC) && "Should be able to derive `Differentiable`"); - assert(DerivedConformance::canDeriveDifferentiable( - cotangentStruct, parentDC) && - "Should be able to derive `Differentiable`"); assert(DerivedConformance::canDeriveDifferentiable( allDiffableVarsStruct, parentDC) && "Should be able to derive `Differentiable`"); @@ -1012,8 +936,6 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived, StructDecl *requestedStructDecl = nullptr; if (id == C.Id_TangentVector) requestedStructDecl = tangentStruct; - else if (id == C.Id_CotangentVector) - requestedStructDecl = cotangentStruct; else if (id == C.Id_AllDifferentiableVariables) requestedStructDecl = allDiffableVarsStruct; else @@ -1022,7 +944,7 @@ getOrSynthesizeAssociatedStructType(DerivedConformance &derived, requestedStructDecl->getDeclaredInterfaceType()); } -// Synthesize an associated struct type (`TangentVector`, `CotangentVector`, or +// Synthesize an associated struct type (`TangentVector` or // `AllDifferentiableVariables`). static Type deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, @@ -1065,8 +987,8 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, // - No `@noDerivative` stored properties exist. // - All stored properties must have specified associated type equal to // `Self`. - // - If associated type is `TangentVector` or `CotangentVector`, parent type - // must also conform to `AdditiveArithmetic`. + // - If associated type is `TangentVector`, parent type must also conform to + // `AdditiveArithmetic`. bool allMembersAssocTypeEqualsSelf = llvm::all_of(diffProperties, [&](VarDecl *member) { auto memberAssocType = getAssociatedType(member, parentDC, id); @@ -1097,23 +1019,20 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, // Otherwise, check if all stored properties have all `Differentiable` // protocol associated types equal to each other: - // `TangentVector == CotangentVector == AllDifferentiableVariables`. + // `TangentVector == AllDifferentiableVariables`. bool allMembersAssocTypesEqualsSelf = llvm::all_of(diffProperties, [&](VarDecl *member) { auto tangentType = getAssociatedType(member, parentDC, C.Id_TangentVector); - auto cotangentType = - getAssociatedType(member, parentDC, C.Id_CotangentVector); auto allDiffableVarsType = getAssociatedType(member, parentDC, C.Id_AllDifferentiableVariables); - return tangentType->isEqual(cotangentType) && - tangentType->isEqual(allDiffableVarsType); + return tangentType->isEqual(allDiffableVarsType); }); // If all stored properties (excluding ones with `@noDerivative`) have all // `Differentiable` protocol associated types equal to `Self`, then get or - // synthesize `AllDifferentiableVariables` struct and let `TangentVector` and - // `CotangentVector` alias to it. + // synthesize `AllDifferentiableVariables` struct and let `TangentVector` + // alias to it. if (allMembersAssocTypesEqualsSelf) { auto allDiffableVarsStructSynthesis = getOrSynthesizeSingleAssociatedStruct( derived, C.Id_AllDifferentiableVariables); @@ -1129,12 +1048,8 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, allDiffableVarsStruct, allDiffableVarsStruct, TC); addAssociatedTypeAliasDecl(C.Id_TangentVector, allDiffableVarsStruct, allDiffableVarsStruct, TC); - addAssociatedTypeAliasDecl(C.Id_CotangentVector, - allDiffableVarsStruct, allDiffableVarsStruct, TC); addAssociatedTypeAliasDecl(C.Id_TangentVector, parentDC, allDiffableVarsStruct, TC); - addAssociatedTypeAliasDecl(C.Id_CotangentVector, - parentDC, allDiffableVarsStruct, TC); TC.validateDecl(allDiffableVarsStruct); return parentDC->mapTypeIntoContext( allDiffableVarsStruct->getDeclaredInterfaceType()); @@ -1150,8 +1065,6 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) { return nullptr; if (requirement->getBaseName() == TC.Context.Id_moved) return deriveDifferentiable_moved(*this); - if (requirement->getBaseName() == TC.Context.Id_tangentVector) - return deriveDifferentiable_tangentVector(*this); if (requirement->getBaseName() == TC.Context.Id_allDifferentiableVariables) return deriveDifferentiable_allDifferentiableVariables(*this); TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement); @@ -1165,9 +1078,6 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) { if (requirement->getBaseName() == TC.Context.Id_TangentVector) return deriveDifferentiable_AssociatedStruct( *this, TC.Context.Id_TangentVector); - if (requirement->getBaseName() == TC.Context.Id_CotangentVector) - return deriveDifferentiable_AssociatedStruct( - *this, TC.Context.Id_CotangentVector); if (requirement->getBaseName() == TC.Context.Id_AllDifferentiableVariables) return deriveDifferentiable_AssociatedStruct( *this, TC.Context.Id_AllDifferentiableVariables); diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 4881f79056dcf..9fc65100e61e3 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -83,7 +83,7 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC, return canDeriveVectorNumeric(Nominal, DC); // SWIFT_ENABLE_TENSORFLOW - if (*knownProtocol == KnownProtocolKind::__Differentiable) + if (*knownProtocol == KnownProtocolKind::Differentiable) return canDeriveDifferentiable(Nominal, DC); if (auto *enumDecl = dyn_cast(Nominal)) { @@ -244,7 +244,7 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, // SWIFT_ENABLE_TENSORFLOW // Differentiable.allDifferentiableVariables if (name.isSimpleName(ctx.Id_allDifferentiableVariables)) - return getRequirement(KnownProtocolKind::__Differentiable); + return getRequirement(KnownProtocolKind::Differentiable); return nullptr; } @@ -305,18 +305,7 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, auto argumentNames = name.getArgumentNames(); if (argumentNames.size() == 1 && argumentNames[0] == ctx.getIdentifier("along")) { - return getRequirement(KnownProtocolKind::__Differentiable); - } - } - - // SWIFT_ENABLE_TENSORFLOW - // Differentiable.tangentVector(from:) - if (name.isCompoundName() && - name.getBaseName() == ctx.Id_tangentVector) { - auto argumentNames = name.getArgumentNames(); - if (argumentNames.size() == 1 && - argumentNames[0] == ctx.getIdentifier("from")) { - return getRequirement(KnownProtocolKind::__Differentiable); + return getRequirement(KnownProtocolKind::Differentiable); } } @@ -374,12 +363,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, // SWIFT_ENABLE_TENSORFLOW // Differentiable.TangentVector - // Differentiable.CotangentVector // Differentiable.AllDifferentiableVariables if (name.isSimpleName(ctx.Id_TangentVector) || - name.isSimpleName(ctx.Id_CotangentVector) || name.isSimpleName(ctx.Id_AllDifferentiableVariables)) - return getRequirement(KnownProtocolKind::__Differentiable); + return getRequirement(KnownProtocolKind::Differentiable); // SWIFT_ENABLE_TENSORFLOW // VectorNumeric.Scalar diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 50bcb54703d37..04be6e308ceb8 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3191,7 +3191,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // The result type should be a two-element tuple. // Either a value and pullback: - // (value: R, pullback: (R.CotangentVector) -> (T.CotangentVector...) + // (value: R, pullback: (R.TangentVector) -> (T.TangentVector...) // Or a value and differential: // (value: R, differential: (T.TangentVector...) -> (R.TangentVector) auto derivativeResultType = derivative->getResultInterfaceType(); @@ -3219,7 +3219,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { autoDiffAssocTyId = ctx.Id_TangentVector; } else if (funcResultElt.getName().str() == "pullback") { kind = AutoDiffAssociatedFunctionKind::VJP; - autoDiffAssocTyId = ctx.Id_CotangentVector; + autoDiffAssocTyId = ctx.Id_TangentVector; } else { TC.diagnose(attr->getLocation(), diag::differentiating_attr_invalid_result_tuple_func_label); @@ -3227,7 +3227,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { return; } // `value: R` result tuple element must conform to `Differentiable`. - auto diffableProto = ctx.getProtocol(KnownProtocolKind::__Differentiable); + auto diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable); auto valueResultType = valueResultElt.getType(); if (valueResultType->hasTypeParameter()) valueResultType = derivative->mapTypeIntoContext(valueResultType); diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 173c9cc1acda0..2594e80920036 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -1827,20 +1827,10 @@ checkIndividualConformance(NormalProtocolConformance *conformance, // should go into the new extension we (might) suggest here. impliedDisablesMissingWitnessFixits = true; - // SWIFT_ENABLE_TENSORFLOW - // Before diagnosing implied conditional conformances, check if the - // implied protocol is an underscored protocol for internal purposes - // (e.g. `_Differentiable` or `__Differentiable`, which are workarounds - // for TF-213). If so, allow the implied conformance. - auto *proto = conformance->getProtocol(); - auto &ctx = DC->getASTContext(); - // TODO(TF-213): Remove underscore `Differentiable` protocols. - if (proto != ctx.getProtocol(KnownProtocolKind::__Differentiable) && - proto != ctx.getProtocol(KnownProtocolKind::_Differentiable) ) { - diagnoseConformanceImpliedByConditionalConformance( - TC.Diags, conformance, implyingConf, issueFixit); - conformance->setInvalid(); - } + diagnoseConformanceImpliedByConditionalConformance( + TC.Diags, conformance, implyingConf, issueFixit); + + conformance->setInvalid(); } } @@ -5530,8 +5520,7 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC, return derived.deriveVectorNumeric(Requirement); // SWIFT_ENABLE_TENSORFLOW - // TODO(TF-213): Replace with `KnownProtocolKind::Differentiable`. - case KnownProtocolKind::__Differentiable: + case KnownProtocolKind::Differentiable: return derived.deriveDifferentiable(Requirement); default: @@ -5562,8 +5551,7 @@ Type TypeChecker::deriveTypeWitness(DeclContext *DC, return derived.deriveKeyPathIterable(AssocType); case KnownProtocolKind::VectorNumeric: return derived.deriveVectorNumeric(AssocType); - // TODO(TF-213): Replace with `KnownProtocolKind::Differentiable`. - case KnownProtocolKind::__Differentiable: + case KnownProtocolKind::Differentiable: return derived.deriveDifferentiable(AssocType); default: return nullptr; diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index 1a8ea18762047..648d129aeca84 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -2629,8 +2629,7 @@ bool TypeResolver::isDifferentiableType(Type ty) { ty = DC->mapTypeIntoContext(ty); } return ty - ->getAutoDiffAssociatedVectorSpace( - AutoDiffAssociatedVectorSpaceKind::Tangent, + ->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(DC->getParentModule())) .hasValue(); } diff --git a/stdlib/private/DifferentiationUnittest/GenericLifetimeTracked.swift b/stdlib/private/DifferentiationUnittest/GenericLifetimeTracked.swift index f61c682852975..8beca492b8a75 100644 --- a/stdlib/private/DifferentiationUnittest/GenericLifetimeTracked.swift +++ b/stdlib/private/DifferentiationUnittest/GenericLifetimeTracked.swift @@ -117,15 +117,10 @@ extension Tracked : Strideable where T : Strideable, T.Stride == T.Stride.Magnit // For now, `T` must be restricted to trivial types (like `Float` or `Tensor`). extension Tracked : Differentiable where T : Differentiable, T == T.AllDifferentiableVariables, - T == T.TangentVector, T == T.CotangentVector + T == T.TangentVector { public typealias AllDifferentiableVariables = Tracked public typealias TangentVector = Tracked - public typealias CotangentVector = Tracked - @inlinable @inline(__always) - public func tangentVector(from cotangent: CotangentVector) -> TangentVector { - return Tracked(value.tangentVector(from: cotangent.value)) - } } @differentiable(vjp: _vjpAdd) @@ -146,14 +141,14 @@ public extension Differentiable { @inlinable func gradient( in f: @differentiable (Self) -> Tracked - ) -> CotangentVector { + ) -> TangentVector { return self.pullback(in: f)(1) } @inlinable func gradient( at x: T, in f: @differentiable (Self, T) -> Tracked - ) -> (CotangentVector, T.CotangentVector) { + ) -> (TangentVector, T.TangentVector) { return self.pullback(at: x, in: f)(1) } } diff --git a/stdlib/public/TensorFlow/DataTypes.swift b/stdlib/public/TensorFlow/DataTypes.swift index 147cce6c94a29..baa3ee14957a4 100644 --- a/stdlib/public/TensorFlow/DataTypes.swift +++ b/stdlib/public/TensorFlow/DataTypes.swift @@ -101,7 +101,6 @@ public protocol TensorFlowFloatingPoint : TensorFlowScalar & BinaryFloatingPoint & Differentiable where Self.RawSignificand: FixedWidthInteger, Self == Self.TangentVector, - Self == Self.CotangentVector, Self == Self.AllDifferentiableVariables {} extension Float : TensorFlowFloatingPoint {} diff --git a/stdlib/public/TensorFlow/Gradients.swift b/stdlib/public/TensorFlow/Gradients.swift index cc8d4c67ad998..7a268c638ecab 100644 --- a/stdlib/public/TensorFlow/Gradients.swift +++ b/stdlib/public/TensorFlow/Gradients.swift @@ -48,14 +48,14 @@ public extension Differentiable { @inlinable func gradient( in f: @differentiable (Self) -> Tensor - ) -> CotangentVector { + ) -> TangentVector { return self.pullback(in: f)(Tensor(1)) } @inlinable func valueWithGradient( in f: @differentiable (Self) -> Tensor - ) -> (value: Tensor, gradient: CotangentVector) { + ) -> (value: Tensor, gradient: TangentVector) { let (y, pb) = self.valueWithPullback(in: f) return (y, pb(Tensor(1))) } @@ -63,14 +63,14 @@ public extension Differentiable { @inlinable func gradient( at x: T, in f: @differentiable (Self, T) -> Tensor - ) -> (CotangentVector, T.CotangentVector) { + ) -> (TangentVector, T.TangentVector) { return self.pullback(at: x, in: f)(Tensor(1)) } @inlinable func valueWithGradient( at x: T, in f: @differentiable (Self, T) -> Tensor - ) -> (value: Tensor, gradient: (CotangentVector, T.CotangentVector)) { + ) -> (value: Tensor, gradient: (TangentVector, T.TangentVector)) { let (y, pb) = self.valueWithPullback(at: x, in: f) return (y, pb(Tensor(1))) } @@ -85,7 +85,7 @@ public extension Differentiable { @inlinable public func valueWithGradient( at x: T, in f: @differentiable (T) -> Tensor -) -> (value: Tensor, gradient: T.CotangentVector) +) -> (value: Tensor, gradient: T.TangentVector) where T : Differentiable, R : TensorFlowFloatingPoint { let (y, pullback) = valueWithPullback(at: x, in: f) return (y, pullback(Tensor(1))) @@ -94,7 +94,7 @@ where T : Differentiable, R : TensorFlowFloatingPoint { @inlinable public func valueWithGradient( at x: T, _ y: U, in f: @differentiable (T, U) -> Tensor -) -> (value: Tensor, gradient: (T.CotangentVector, U.CotangentVector)) +) -> (value: Tensor, gradient: (T.TangentVector, U.TangentVector)) where T : Differentiable, U : Differentiable, R : TensorFlowFloatingPoint { let (y, pullback) = valueWithPullback(at: x, y, in: f) @@ -105,7 +105,7 @@ public func valueWithGradient( public func valueWithGradient( at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> Tensor ) -> (value: Tensor, - gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector)) + gradient: (T.TangentVector, U.TangentVector, V.TangentVector)) where T : Differentiable, U : Differentiable, V : Differentiable, R : TensorFlowFloatingPoint { let (y, pullback) = valueWithPullback(at: x, y, z, in: f) @@ -117,7 +117,7 @@ public func valueWithGradient( @inlinable public func valueWithGradient( of f: @escaping @differentiable (T) -> Tensor -) -> (T) -> (value: Tensor, gradient: T.CotangentVector) +) -> (T) -> (value: Tensor, gradient: T.TangentVector) where T : Differentiable, R : TensorFlowFloatingPoint { return { x in valueWithGradient(at: x, in: f) } } @@ -126,7 +126,7 @@ public func valueWithGradient( public func valueWithGradient( of f: @escaping @differentiable (T, U) -> Tensor ) -> (T, U) - -> (value: Tensor, gradient: (T.CotangentVector, U.CotangentVector)) + -> (value: Tensor, gradient: (T.TangentVector, U.TangentVector)) where T : Differentiable, U : Differentiable, R : TensorFlowFloatingPoint { return { x, y in valueWithGradient(at: x, y, in: f) } @@ -137,7 +137,7 @@ public func valueWithGradient( of f: @escaping @differentiable (T, U, V) -> Tensor ) -> (T, U, V) -> (value: Tensor, - gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector)) + gradient: (T.TangentVector, U.TangentVector, V.TangentVector)) where T : Differentiable, U : Differentiable, V : Differentiable, R : TensorFlowFloatingPoint { return { x, y, z in valueWithGradient(at: x, y, z, in: f) } @@ -148,7 +148,7 @@ public func valueWithGradient( @inlinable public func gradient( at x: T, in f: @differentiable (T) -> Tensor -) -> T.CotangentVector +) -> T.TangentVector where T : Differentiable, R : TensorFlowFloatingPoint { return pullback(at: x, in: f)(Tensor(1)) } @@ -156,7 +156,7 @@ public func gradient( @inlinable public func gradient( at x: T, _ y: U, in f: @differentiable (T, U) -> Tensor -) -> (T.CotangentVector, U.CotangentVector) +) -> (T.TangentVector, U.TangentVector) where T : Differentiable, U : Differentiable, R : TensorFlowFloatingPoint { return pullback(at: x, y, in: f)(Tensor(1)) @@ -165,7 +165,7 @@ public func gradient( @inlinable public func gradient( at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> Tensor -) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector) +) -> (T.TangentVector, U.TangentVector, V.TangentVector) where T : Differentiable, U : Differentiable, V : Differentiable, R : TensorFlowFloatingPoint { return pullback(at: x, y, z, in: f)(Tensor(1)) @@ -176,7 +176,7 @@ public func gradient( @inlinable public func gradient( of f: @escaping @differentiable (T) -> Tensor -) -> (T) -> T.CotangentVector +) -> (T) -> T.TangentVector where T : Differentiable, R : TensorFlowFloatingPoint { return { x in gradient(at: x, in: f) } } @@ -184,7 +184,7 @@ public func gradient( @inlinable public func gradient( of f: @escaping @differentiable (T, U) -> Tensor -) -> (T, U) -> (T.CotangentVector, U.CotangentVector) +) -> (T, U) -> (T.TangentVector, U.TangentVector) where T : Differentiable, U : Differentiable, R : TensorFlowFloatingPoint { return { x, y in gradient(at: x, y, in: f) } @@ -193,7 +193,7 @@ public func gradient( @inlinable public func gradient( of f: @escaping @differentiable (T, U, V) -> Tensor -) -> (T, U, V) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector) +) -> (T, U, V) -> (T.TangentVector, U.TangentVector, V.TangentVector) where T : Differentiable, U : Differentiable, V : Differentiable, R : TensorFlowFloatingPoint { return { x, y, z in gradient(at: x, y, z, in: f) } diff --git a/stdlib/public/TensorFlow/Ops.swift b/stdlib/public/TensorFlow/Ops.swift index abc303fa40494..78b0b3b6d1846 100644 --- a/stdlib/public/TensorFlow/Ops.swift +++ b/stdlib/public/TensorFlow/Ops.swift @@ -107,12 +107,7 @@ extension Tensor : ShapedVectorNumeric where Scalar : Numeric {} extension Tensor : Differentiable where Scalar : TensorFlowFloatingPoint { public typealias TangentVector = Tensor - public typealias CotangentVector = Tensor public typealias AllDifferentiableVariables = Tensor - @inlinable - public func tangentVector(from cotangent: CotangentVector) -> TangentVector { - return cotangent - } } //===----------------------------------------------------------------------===// diff --git a/stdlib/public/core/Array.swift b/stdlib/public/core/Array.swift index 0398d91e76f52..d15172a818e7f 100644 --- a/stdlib/public/core/Array.swift +++ b/stdlib/public/core/Array.swift @@ -1934,7 +1934,7 @@ extension Array where Element : Differentiable { @usableFromInline func _vjpBase() -> - ([Element], (Array.CotangentVector) -> CotangentVector) { + ([Element], (Array.TangentVector) -> TangentVector) { return (base, { $0 }) } @@ -1944,7 +1944,7 @@ extension Array where Element : Differentiable { @usableFromInline static func _vjpInit(_ base: [Element]) -> - (Array.DifferentiableView, (CotangentVector) -> CotangentVector) { + (Array.DifferentiableView, (TangentVector) -> TangentVector) { return (Array.DifferentiableView(base), { $0 }) } @@ -1952,8 +1952,6 @@ extension Array where Element : Differentiable { public typealias TangentVector = Array.DifferentiableView - public typealias CotangentVector = - Array.DifferentiableView public typealias AllDifferentiableVariables = Array.DifferentiableView @@ -1983,19 +1981,6 @@ extension Array where Element : Differentiable { return DifferentiableView( zip(base, direction.base).map { $0.moved(along: $1) }) } - - public func tangentVector(from cotangentVector: CotangentVector) -> - TangentVector { - precondition( - base.count == cotangentVector.base.count, - "cannot use Array.DifferentiableView with count \(base.count) to " + - "get tangentVector from cotangentVector with different count " + - "\(cotangentVector.base.count)") - return TangentVector(zip(base, cotangentVector.base).map { - (selfElement, cotangentVectorElement) in - selfElement.tangentVector(from: cotangentVectorElement) - }) - } } } @@ -2065,16 +2050,14 @@ extension Array.DifferentiableView : AdditiveArithmetic /// Makes `Array` differentiable as the product manifold of `Element` /// multiplied with itself `count` times. extension Array : Differentiable where Element : Differentiable { - // In an ideal world, `TangentVector`, `CotangentVector`, and + // In an ideal world, `TangentVector`, `TangentVector`, and // `AllDifferentiableVariables` would all be `Array`s. Unfortunately, we // can't conform `Array` to `AdditiveArithmetic` for `TangentVector` and - // `CotangentVector`, because `Array` already has a static `+` method with + // `TangentVector`, because `Array` already has a static `+` method with // different semantics from `AdditiveArithmetic` `+`. So we use // `Array.DifferentiableView` for all these associated types. public typealias TangentVector = Array.DifferentiableView - public typealias CotangentVector = - Array.DifferentiableView public typealias AllDifferentiableVariables = Array.DifferentiableView @@ -2092,40 +2075,35 @@ extension Array : Differentiable where Element : Differentiable { public func moved(along direction: TangentVector) -> Array { return DifferentiableView(self).moved(along: direction).base } - - public func tangentVector(from cotangentVector: CotangentVector) -> - TangentVector { - return DifferentiableView(self).tangentVector(from: cotangentVector) - } } extension Array where Element : Differentiable { public func _vjpSubscript(index: Int) -> - (Element, (Element.CotangentVector) -> CotangentVector) + (Element, (Element.TangentVector) -> TangentVector) { - func pullback(_ gradientIn: Element.CotangentVector) -> CotangentVector { - var gradientOut = Array( + func pullback(_ gradientIn: Element.TangentVector) -> TangentVector { + var gradientOut = Array( repeating: .zero, count: count) gradientOut[index] = gradientIn - return CotangentVector(gradientOut) + return TangentVector(gradientOut) } return (self[index], pullback) } public static func _vjpPlus(_ lhs: [Element], _ rhs: [Element]) -> - ([Element], (CotangentVector) -> (CotangentVector, CotangentVector)) { - func pullback(_ gradientIn: CotangentVector) -> - (CotangentVector, CotangentVector) { + ([Element], (TangentVector) -> (TangentVector, TangentVector)) { + func pullback(_ gradientIn: TangentVector) -> + (TangentVector, TangentVector) { precondition( gradientIn.base.count == lhs.count + rhs.count, "+ should receive gradient with count equal to sum of operand " + "counts, but counts are: gradient \(gradientIn.base.count), " + "lhs \(lhs.count), rhs \(rhs.count)") return ( - CotangentVector(Array( + TangentVector(Array( gradientIn.base[0..( + TangentVector(Array( gradientIn.base[lhs.count...]))) } return (lhs + rhs, pullback) diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index 62b35c30f1026..2c5c14baf64ca 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -58,16 +58,12 @@ public protocol ShapedVectorNumeric : VectorNumeric { /// A type that mathematically represents a differentiable manifold whose /// tangent spaces are finite-dimensional. -/// -/// - Note: Do not use this protocol directly. Use `Differentiable` instead. -/// -// TODO(TF-213): Merge this into `Differentiable` when the generic signature -// minimization bug (SR-9595) is fixed. -public protocol __Differentiable { - /// The tangent bundle of this differentiable manifold. - associatedtype TangentVector : AdditiveArithmetic - /// The cotangent bundle of this differentiable manifold. - associatedtype CotangentVector : AdditiveArithmetic +public protocol Differentiable { + associatedtype TangentVector: Differentiable & AdditiveArithmetic + where TangentVector.TangentVector == TangentVector, + AllDifferentiableVariables.AllDifferentiableVariables == + AllDifferentiableVariables, + AllDifferentiableVariables.TangentVector == TangentVector /// The type of all differentiable variables in this type. associatedtype AllDifferentiableVariables : Differentiable @@ -79,38 +75,11 @@ public protocol __Differentiable { /// exponential map. func moved(along direction: TangentVector) -> Self - /// Convert a cotangent vector to its corresponding tangent vector. - func tangentVector(from cotangent: CotangentVector) -> TangentVector -} - -/// A type that mathematically represents a differentiable manifold whose -/// tangent spaces are finite-dimensional. -/// -/// - Note: Do not use this protocol directly. Use `Differentiable` instead. -/// -// TODO(TF-213): Merge this into `Differentiable` when the generic signature -// minimization bug (SR-9595) is fixed. -public protocol _Differentiable : __Differentiable - where TangentVector : Differentiable, CotangentVector : Differentiable { + @available(*, deprecated, + message: "'CotangentVector' is now equal to 'TangentVector' and will be removed") + typealias CotangentVector = TangentVector } -/// A type that mathematically represents a differentiable manifold whose -/// tangent spaces are finite-dimensional. -// BEGIN DIFFERENTIABLE -// - Note: these marks are identified during API doc generation and the -// contents are replaced with the ideal `Differentiable` protocol design. -public protocol Differentiable : _Differentiable - where TangentVector.TangentVector == TangentVector, - TangentVector.CotangentVector == CotangentVector, - CotangentVector.TangentVector == CotangentVector, - CotangentVector.CotangentVector == TangentVector, - AllDifferentiableVariables.AllDifferentiableVariables == - AllDifferentiableVariables, - AllDifferentiableVariables.TangentVector == TangentVector, - AllDifferentiableVariables.CotangentVector == CotangentVector { -} -// END DIFFERENTIABLE - public extension Differentiable where AllDifferentiableVariables == Self { var allDifferentiableVariables: AllDifferentiableVariables { get { return self } @@ -143,14 +112,14 @@ public extension Differentiable { @inlinable public func differentiableFunction( from vjp: @escaping (T) - -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) + -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) ) -> @differentiable (T) -> R { func original(_ x: T) -> R { return vjp(x).value } @differentiating(original) func derivative(_ x: T) - -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) { + -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) { return vjp(x) } return original @@ -160,8 +129,8 @@ public func differentiableFunction( @inlinable public func differentiableFunction( from vjp: @escaping (T, U) - -> (value: R, pullback: (R.CotangentVector) - -> (T.CotangentVector, U.CotangentVector)) + -> (value: R, pullback: (R.TangentVector) + -> (T.TangentVector, U.TangentVector)) ) -> @differentiable (T, U) -> R where T : Differentiable, U : Differentiable, R : Differentiable { func original(_ x: T, _ y: U) -> R { @@ -170,8 +139,8 @@ public func differentiableFunction( @differentiating(original) func derivative(_ x: T, _ y: U) -> (value: R, - pullback: (R.CotangentVector) - -> (T.CotangentVector, U.CotangentVector)) { + pullback: (R.TangentVector) + -> (T.TangentVector, U.TangentVector)) { return vjp(x, y) } return original @@ -179,14 +148,14 @@ public func differentiableFunction( public extension Differentiable { @differentiable(wrt: self, vjp: _vjpWithGrad) - func withGradient(_ body: @escaping (inout CotangentVector) -> Void) -> Self { + func withGradient(_ body: @escaping (inout TangentVector) -> Void) -> Self { return self } @inlinable internal func _vjpWithGrad( - _ body: @escaping (inout CotangentVector) -> Void - ) -> (Self, (CotangentVector) -> CotangentVector) { + _ body: @escaping (inout TangentVector) -> Void + ) -> (Self, (TangentVector) -> TangentVector) { return (self, { grad in var grad = grad body(&grad) @@ -195,14 +164,14 @@ public extension Differentiable { } @differentiable(wrt: self, vjp: _vjpWithGrad) - func withGradient(_ body: @escaping (CotangentVector) -> Void) -> Self { + func withGradient(_ body: @escaping (TangentVector) -> Void) -> Self { return self } @inlinable internal func _vjpWithGrad( - _ body: @escaping (CotangentVector) -> Void - ) -> (Self, (CotangentVector) -> CotangentVector) { + _ body: @escaping (TangentVector) -> Void + ) -> (Self, (TangentVector) -> TangentVector) { return (self, { grad in body(grad) return grad @@ -233,7 +202,7 @@ public extension Differentiable { @inlinable internal func _vjp_withRecomputationInPullbacks( _ body: @escaping @differentiable (Self) -> Result - ) -> (Result, (Result.CotangentVector) -> CotangentVector) { + ) -> (Result, (Result.TangentVector) -> TangentVector) { return valueWithPullback(in: Swift.withRecomputationInPullbacks(body)) } } @@ -246,30 +215,30 @@ public extension Differentiable { @inlinable func valueWithPullback( in f: @differentiable (Self) -> R - ) -> (value: R, pullback: (R.CotangentVector) -> CotangentVector) { + ) -> (value: R, pullback: (R.TangentVector) -> TangentVector) { return Builtin.autodiffApply_vjp_arity1(f, self) } @inlinable func pullback( in f: @differentiable (Self) -> R - ) -> (R.CotangentVector) -> CotangentVector { + ) -> (R.TangentVector) -> TangentVector { return Builtin.autodiffApply_vjp_arity1(f, self).1 } @inlinable func gradient( in f: @differentiable (Self) -> R - ) -> CotangentVector - where R : FloatingPoint, R.CotangentVector == R { + ) -> TangentVector + where R : FloatingPoint, R.TangentVector == R { return self.pullback(in: f)(R(1)) } @inlinable func valueWithGradient( in f: @differentiable (Self) -> R - ) -> (value: R, gradient: CotangentVector) - where R : FloatingPoint, R.CotangentVector == R { + ) -> (value: R, gradient: TangentVector) + where R : FloatingPoint, R.TangentVector == R { let (y, pb) = self.valueWithPullback(in: f) return (y, pb(R(1))) } @@ -278,30 +247,30 @@ public extension Differentiable { func valueWithPullback( at x: T, in f: @differentiable (Self, T) -> R ) -> (value: R, - pullback: (R.CotangentVector) -> (CotangentVector, T.CotangentVector)) { + pullback: (R.TangentVector) -> (TangentVector, T.TangentVector)) { return Builtin.autodiffApply_vjp_arity2(f, self, x) } @inlinable func pullback( at x: T, in f: @differentiable (Self, T) -> R - ) -> (R.CotangentVector) -> (CotangentVector, T.CotangentVector) { + ) -> (R.TangentVector) -> (TangentVector, T.TangentVector) { return Builtin.autodiffApply_vjp_arity2(f, self, x).1 } @inlinable func gradient( at x: T, in f: @differentiable (Self, T) -> R - ) -> (CotangentVector, T.CotangentVector) - where R : FloatingPoint, R.CotangentVector == R { + ) -> (TangentVector, T.TangentVector) + where R : FloatingPoint, R.TangentVector == R { return self.pullback(at: x, in: f)(R(1)) } @inlinable func valueWithGradient( at x: T, in f: @differentiable (Self, T) -> R - ) -> (value: R, gradient: (CotangentVector, T.CotangentVector)) - where R : FloatingPoint, R.CotangentVector == R { + ) -> (value: R, gradient: (TangentVector, T.TangentVector)) + where R : FloatingPoint, R.TangentVector == R { let (y, pb) = self.valueWithPullback(at: x, in: f) return (y, pb(R(1))) } @@ -316,7 +285,7 @@ public extension Differentiable { @inlinable public func valueWithPullback( at x: T, in f: @differentiable (T) -> R -) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector) +) -> (value: R, pullback: (R.TangentVector) -> T.TangentVector) where T : Differentiable, R : Differentiable { return Builtin.autodiffApply_vjp(f, x) } @@ -325,7 +294,7 @@ public func valueWithPullback( public func valueWithPullback( at x: T, _ y: U, in f: @differentiable (T, U) -> R ) -> (value: R, - pullback: (R.CotangentVector) -> (T.CotangentVector, U.CotangentVector)) + pullback: (R.TangentVector) -> (T.TangentVector, U.TangentVector)) where T : Differentiable, U : Differentiable, R : Differentiable { return Builtin.autodiffApply_vjp_arity2(f, x, y) } @@ -334,8 +303,8 @@ public func valueWithPullback( public func valueWithPullback( at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R ) -> (value: R, - pullback: (R.CotangentVector) - -> (T.CotangentVector, U.CotangentVector, V.CotangentVector)) + pullback: (R.TangentVector) + -> (T.TangentVector, U.TangentVector, V.TangentVector)) where T : Differentiable, U : Differentiable, V : Differentiable, R : Differentiable { return Builtin.autodiffApply_vjp_arity3(f, x, y, z) @@ -346,7 +315,7 @@ public func valueWithPullback( @inlinable public func pullback( at x: T, in f: @differentiable (T) -> R -) -> (R.CotangentVector) -> T.CotangentVector +) -> (R.TangentVector) -> T.TangentVector where T : Differentiable, R : Differentiable { return Builtin.autodiffApply_vjp(f, x).1 } @@ -354,7 +323,7 @@ public func pullback( @inlinable public func pullback( at x: T, _ y: U, in f: @differentiable (T, U) -> R -) -> (R.CotangentVector) -> (T.CotangentVector, U.CotangentVector) +) -> (R.TangentVector) -> (T.TangentVector, U.TangentVector) where T : Differentiable, U : Differentiable, R : Differentiable { return Builtin.autodiffApply_vjp_arity2(f, x, y).1 } @@ -362,8 +331,8 @@ public func pullback( @inlinable public func pullback( at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R -) -> (R.CotangentVector) - -> (T.CotangentVector, U.CotangentVector, V.CotangentVector) +) -> (R.TangentVector) + -> (T.TangentVector, U.TangentVector, V.TangentVector) where T : Differentiable, U : Differentiable, V : Differentiable, R : Differentiable { return Builtin.autodiffApply_vjp_arity3(f, x, y, z).1 @@ -374,9 +343,9 @@ public func pullback( @inlinable public func valueWithGradient( at x: T, in f: @differentiable (T) -> R -) -> (value: R, gradient: T.CotangentVector) +) -> (value: R, gradient: T.TangentVector) where T : Differentiable, R : FloatingPoint & Differentiable, - R.CotangentVector == R { + R.TangentVector == R { let (y, pullback) = valueWithPullback(at: x, in: f) return (y, pullback(R(1))) } @@ -384,9 +353,9 @@ public func valueWithGradient( @inlinable public func valueWithGradient( at x: T, _ y: U, in f: @differentiable (T, U) -> R -) -> (value: R, gradient: (T.CotangentVector, U.CotangentVector)) +) -> (value: R, gradient: (T.TangentVector, U.TangentVector)) where T : Differentiable, U : Differentiable, - R : FloatingPoint & Differentiable, R.CotangentVector == R { + R : FloatingPoint & Differentiable, R.TangentVector == R { let (y, pullback) = valueWithPullback(at: x, y, in: f) return (y, pullback(R(1))) } @@ -395,9 +364,9 @@ public func valueWithGradient( public func valueWithGradient( at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R ) -> (value: R, - gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector)) + gradient: (T.TangentVector, U.TangentVector, V.TangentVector)) where T : Differentiable, U : Differentiable, V : Differentiable, - R : FloatingPoint & Differentiable, R.CotangentVector == R { + R : FloatingPoint & Differentiable, R.TangentVector == R { let (y, pullback) = valueWithPullback(at: x, y, z, in: f) return (y, pullback(R(1))) } @@ -407,19 +376,19 @@ public func valueWithGradient( @inlinable public func valueWithGradient( of f: @escaping @differentiable (T) -> R -) -> (T) -> (value: R, gradient: T.CotangentVector) +) -> (T) -> (value: R, gradient: T.TangentVector) where T : Differentiable, R : FloatingPoint & Differentiable, - R.CotangentVector == R { + R.TangentVector == R { return { x in valueWithGradient(at: x, in: f) } } @inlinable public func valueWithGradient( of f: @escaping @differentiable (T, U) -> R -) -> (T, U) -> (value: R, gradient: (T.CotangentVector, U.CotangentVector)) +) -> (T, U) -> (value: R, gradient: (T.TangentVector, U.TangentVector)) where T : Differentiable, U : Differentiable, R : FloatingPoint & Differentiable, - R.CotangentVector == R { + R.TangentVector == R { return { x, y in valueWithGradient(at: x, y, in: f) } } @@ -428,10 +397,10 @@ public func valueWithGradient( of f: @escaping @differentiable (T, U, V) -> R ) -> (T, U, V) -> (value: R, - gradient: (T.CotangentVector, U.CotangentVector, V.CotangentVector)) + gradient: (T.TangentVector, U.TangentVector, V.TangentVector)) where T : Differentiable, U : Differentiable, V : Differentiable, R : FloatingPoint & Differentiable, - R.CotangentVector == R { + R.TangentVector == R { return { x, y, z in valueWithGradient(at: x, y, z, in: f) } } @@ -440,27 +409,27 @@ public func valueWithGradient( @inlinable public func gradient( at x: T, in f: @differentiable (T) -> R -) -> T.CotangentVector +) -> T.TangentVector where T : Differentiable, R : FloatingPoint & Differentiable, - R.CotangentVector == R { + R.TangentVector == R { return pullback(at: x, in: f)(R(1)) } @inlinable public func gradient( at x: T, _ y: U, in f: @differentiable (T, U) -> R -) -> (T.CotangentVector, U.CotangentVector) +) -> (T.TangentVector, U.TangentVector) where T : Differentiable, U : Differentiable, - R : FloatingPoint & Differentiable, R.CotangentVector == R { + R : FloatingPoint & Differentiable, R.TangentVector == R { return pullback(at: x, y, in: f)(R(1)) } @inlinable public func gradient( at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R -) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector) +) -> (T.TangentVector, U.TangentVector, V.TangentVector) where T : Differentiable, U : Differentiable, V : Differentiable, - R : FloatingPoint & Differentiable, R.CotangentVector == R { + R : FloatingPoint & Differentiable, R.TangentVector == R { return pullback(at: x, y, z, in: f)(R(1)) } @@ -469,29 +438,29 @@ public func gradient( @inlinable public func gradient( of f: @escaping @differentiable (T) -> R -) -> (T) -> T.CotangentVector +) -> (T) -> T.TangentVector where T : Differentiable, R : FloatingPoint & Differentiable, - R.CotangentVector == R { + R.TangentVector == R { return { x in gradient(at: x, in: f) } } @inlinable public func gradient( of f: @escaping @differentiable (T, U) -> R -) -> (T, U) -> (T.CotangentVector, U.CotangentVector) +) -> (T, U) -> (T.TangentVector, U.TangentVector) where T : Differentiable, U : Differentiable, R : FloatingPoint & Differentiable, - R.CotangentVector == R { + R.TangentVector == R { return { x, y in gradient(at: x, y, in: f) } } @inlinable public func gradient( of f: @escaping @differentiable (T, U, V) -> R -) -> (T, U, V) -> (T.CotangentVector, U.CotangentVector, V.CotangentVector) +) -> (T, U, V) -> (T.TangentVector, U.TangentVector, V.TangentVector) where T : Differentiable, U : Differentiable, V : Differentiable, R : FloatingPoint & Differentiable, - R.CotangentVector == R { + R.TangentVector == R { return { x, y, z in gradient(at: x, y, z, in: f) } } @@ -506,14 +475,12 @@ internal protocol _AnyDerivativeBox { // `AdditiveArithmetic` requirements. static var _zero: _AnyDerivativeBox { get } - static var _dualSpaceZero: _AnyDerivativeBox { get } func _adding(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox func _subtracting(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox // `Differentiable` requirements. var _allDifferentiableVariables: _AnyDerivativeBox { get } func _moved(along direction: _AnyDerivativeBox) -> _AnyDerivativeBox - func _tangentVector(from cotangent: _AnyDerivativeBox) -> _AnyDerivativeBox /// The underlying base value, type-erased to `Any`. var _typeErasedBase: Any { get } @@ -521,10 +488,7 @@ internal protocol _AnyDerivativeBox { /// Returns the underlying value unboxed to the given type, if possible. func _unboxed(to type: U.Type) -> U? where U : Differentiable, U.TangentVector == U, - U.AllDifferentiableVariables == U, - // NOTE: The requirement below should be defined on `Differentiable`. - // But it causes a crash due to generic signature minimization bug. - U.CotangentVector == U.CotangentVector.AllDifferentiableVariables + U.AllDifferentiableVariables == U } extension _AnyDerivativeBox { @@ -547,10 +511,7 @@ internal func _derivativeTypeMismatch( internal struct _ConcreteDerivativeBox : _AnyDerivativeBox where T : Differentiable, T.TangentVector == T, - T.AllDifferentiableVariables == T, - // NOTE: The requirement below should be defined on `Differentiable`. - // But it causes a crash due to generic signature minimization bug. - T.CotangentVector == T.CotangentVector.AllDifferentiableVariables + T.AllDifferentiableVariables == T { /// The underlying base value. var _base: T @@ -566,10 +527,7 @@ internal struct _ConcreteDerivativeBox : _AnyDerivativeBox func _unboxed(to type: U.Type) -> U? where U : Differentiable, U.TangentVector == U, - U.AllDifferentiableVariables == U, - // NOTE: The requirement below should be defined on `Differentiable`. - // But it causes a crash due to generic signature minimization bug. - U.CotangentVector == U.CotangentVector.AllDifferentiableVariables + U.AllDifferentiableVariables == U { return (self as? _ConcreteDerivativeBox)?._base } @@ -590,10 +548,6 @@ internal struct _ConcreteDerivativeBox : _AnyDerivativeBox return _ConcreteDerivativeBox(T.zero) } - static var _dualSpaceZero: _AnyDerivativeBox { - return _ConcreteDerivativeBox(T.CotangentVector.zero) - } - func _adding(_ x: _AnyDerivativeBox) -> _AnyDerivativeBox { // 0 + x = x if _isOpaqueZero() { @@ -643,21 +597,6 @@ internal struct _ConcreteDerivativeBox : _AnyDerivativeBox } return _ConcreteDerivativeBox(_base.moved(along: directionBase)) } - - func _tangentVector(from cotangent: _AnyDerivativeBox) -> _AnyDerivativeBox { - if _isOpaqueZero() { - return type(of: cotangent)._dualSpaceZero._tangentVector(from: cotangent) - } - if cotangent._isOpaqueZero() { - return cotangent - } - guard let cotangentBase = - cotangent._unboxed(to: T.CotangentVector.self) else { - _derivativeTypeMismatch(T.self, type(of: cotangent._typeErasedBase)) - } - return _ConcreteDerivativeBox( - _base.tangentVector(from: cotangentBase)) - } } /// A type-erased derivative value. @@ -681,28 +620,21 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic { @differentiable(vjp: _vjpInit(_:)) public init(_ base: T) where T : Differentiable, T.TangentVector == T, - T.AllDifferentiableVariables == T, - // NOTE: The requirement below should be defined on `Differentiable`. - // But it causes a crash due to generic signature minimization bug. - T.CotangentVector == T.CotangentVector.AllDifferentiableVariables + T.AllDifferentiableVariables == T { self._box = _ConcreteDerivativeBox(base) } @usableFromInline internal static func _vjpInit( _ base: T - ) -> (AnyDerivative, (AnyDerivative) -> T.CotangentVector) + ) -> (AnyDerivative, (AnyDerivative) -> T.TangentVector) where T : Differentiable, T.TangentVector == T, - T.AllDifferentiableVariables == T, - // NOTE: The requirement below should be defined on `Differentiable`. - // But it causes a crash due to generic signature minimization bug. - T.CotangentVector == T.CotangentVector.AllDifferentiableVariables + T.AllDifferentiableVariables == T { - return (AnyDerivative(base), { v in v.base as! T.CotangentVector }) + return (AnyDerivative(base), { v in v.base as! T.TangentVector }) } public typealias TangentVector = AnyDerivative - public typealias CotangentVector = AnyDerivative public typealias AllDifferentiableVariables = AnyDerivative // `Equatable` requirements (implied by `AdditiveArithmetic`). @@ -761,7 +693,4 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic { public func moved(along direction: TangentVector) -> AnyDerivative { return AnyDerivative(_box: _box._moved(along: direction._box)) } - public func tangentVector(from cotangent: CotangentVector) -> TangentVector { - return AnyDerivative(_box: _box._tangentVector(from: cotangent._box)) - } } diff --git a/stdlib/public/core/FloatingPoint.swift b/stdlib/public/core/FloatingPoint.swift index 29283df0d1eb7..0f7e00a0dd120 100644 --- a/stdlib/public/core/FloatingPoint.swift +++ b/stdlib/public/core/FloatingPoint.swift @@ -1845,7 +1845,7 @@ extension FloatingPoint { @_transparent // SWIFT_ENABLE_TENSORFLOW @differentiable(wrt: self, vjp: _vjpSquareRoot - where Self : Differentiable, Self == Self.CotangentVector) + where Self : Differentiable, Self == Self.TangentVector) public func squareRoot( ) -> Self { var lhs = self lhs.formSquareRoot( ) @@ -1868,7 +1868,7 @@ extension FloatingPoint { @_transparent /// SWIFT_ENABLE_TENSORFLOW @differentiable(wrt: (self, lhs, rhs), vjp: _vjpAddingProduct - where Self : Differentiable, Self == Self.CotangentVector) + where Self : Differentiable, Self == Self.TangentVector) public func addingProduct(_ lhs: Self, _ rhs: Self) -> Self { var addend = self addend.addProduct(lhs, rhs) @@ -2030,7 +2030,7 @@ extension FloatingPoint { /// SWIFT_ENABLE_TENSORFLOW extension FloatingPoint where Self : Differentiable, - Self == Self.CotangentVector { + Self == Self.TangentVector { /// The vector-Jacobian product function of `addingProduct`. Returns the /// original result and pullback of `addingProduct` with respect to `self`, /// `lhs` and `rhs`. diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index 88eeb35af4796..2b883a906747c 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1882,11 +1882,7 @@ extension ${Self} : VectorNumeric { extension ${Self} : Differentiable { public typealias TangentVector = ${Self} - public typealias CotangentVector = ${Self} public typealias AllDifferentiableVariables = ${Self} - public func tangentVector(from cotangent: CotangentVector) -> TangentVector { - return cotangent - } } //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/anyderivative.swift b/test/AutoDiff/anyderivative.swift index 720f5cbd22689..d0fde4db62b69 100644 --- a/test/AutoDiff/anyderivative.swift +++ b/test/AutoDiff/anyderivative.swift @@ -20,19 +20,19 @@ AnyDerivativeTests.test("Vector") { expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan + tan) expectEqual(AnyDerivative(Vector.TangentVector(x: 0, y: 0)), tan - tan) expectEqual(AnyDerivative(Vector.TangentVector(x: 4, y: 4)), tan.moved(along: tan)) - expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan.tangentVector(from: tan)) + expectEqual(AnyDerivative(Vector.TangentVector(x: 2, y: 2)), tan) } AnyDerivativeTests.test("Generic") { var tan = AnyDerivative(Generic.TangentVector(x: 1)) - let cotan = AnyDerivative(Generic.CotangentVector(x: 1)) + let cotan = AnyDerivative(Generic.TangentVector(x: 1)) tan += tan expectEqual(AnyDerivative(Generic.TangentVector(x: 2)), tan) expectEqual(tan, tan.allDifferentiableVariables) expectEqual(AnyDerivative(Generic.TangentVector(x: 4)), tan + tan) expectEqual(AnyDerivative(Generic.TangentVector(x: 0)), tan - tan) expectEqual(AnyDerivative(Generic.TangentVector(x: 4)), tan.moved(along: tan)) - expectEqual(AnyDerivative(Generic.TangentVector(x: 1)), tan.tangentVector(from: cotan)) + expectEqual(AnyDerivative(Generic.TangentVector(x: 1)), cotan) } AnyDerivativeTests.test("Zero") { @@ -45,7 +45,7 @@ AnyDerivativeTests.test("Zero") { expectEqual(zero, zero.allDifferentiableVariables) var tan = AnyDerivative(Vector.TangentVector(x: 1, y: 1)) - expectEqual(zero, tan.tangentVector(from: zero)) + expectEqual(zero, zero) expectEqual(AnyDerivative(Vector.TangentVector.zero), tan - tan) expectNotEqual(AnyDerivative(Vector.TangentVector.zero), zero) expectNotEqual(AnyDerivative.zero, tan - tan) @@ -55,8 +55,8 @@ AnyDerivativeTests.test("Zero") { expectEqual(tan, tan - zero) expectEqual(tan, tan.moved(along: zero)) expectEqual(tan, zero.moved(along: tan)) - expectEqual(zero, tan.tangentVector(from: zero)) - expectEqual(tan, zero.tangentVector(from: tan)) + expectEqual(zero, zero) + expectEqual(tan, tan) } AnyDerivativeTests.test("Casting") { @@ -95,23 +95,23 @@ AnyDerivativeTests.test("Derivatives") { do { let x = AnyDerivative(Vector.TangentVector(x: 4, y: 5)) let y = AnyDerivative(Vector.TangentVector(x: -2, y: -1)) - let v = AnyDerivative(Vector.CotangentVector(x: 1, y: 1)) - let expectedVJP = Vector.CotangentVector(x: 3, y: 3) + let v = AnyDerivative(Vector.TangentVector(x: 1, y: 1)) + let expectedVJP = Vector.TangentVector(x: 3, y: 3) let (𝛁x, 𝛁y) = pullback(at: x, y, in: tripleSum)(v) - expectEqual(expectedVJP, 𝛁x.base as? Vector.CotangentVector) - expectEqual(expectedVJP, 𝛁y.base as? Vector.CotangentVector) + expectEqual(expectedVJP, 𝛁x.base as? Vector.TangentVector) + expectEqual(expectedVJP, 𝛁y.base as? Vector.TangentVector) } do { let x = AnyDerivative(Generic.TangentVector(x: 4)) let y = AnyDerivative(Generic.TangentVector(x: -2)) - let v = AnyDerivative(Generic.CotangentVector(x: 1)) - let expectedVJP = Generic.CotangentVector(x: 3) + let v = AnyDerivative(Generic.TangentVector(x: 1)) + let expectedVJP = Generic.TangentVector(x: 3) let (𝛁x, 𝛁y) = pullback(at: x, y, in: tripleSum)(v) - expectEqual(expectedVJP, 𝛁x.base as? Generic.CotangentVector) - expectEqual(expectedVJP, 𝛁y.base as? Generic.CotangentVector) + expectEqual(expectedVJP, 𝛁x.base as? Generic.TangentVector) + expectEqual(expectedVJP, 𝛁y.base as? Generic.TangentVector) } // Test `AnyDerivative` initializer. @@ -120,7 +120,7 @@ AnyDerivativeTests.test("Derivatives") { T.AllDifferentiableVariables == T, // NOTE: The requirement below should be defined on `Differentiable`. // But it causes a crash due to generic signature minimization bug. - T.CotangentVector == T.CotangentVector.AllDifferentiableVariables + T.TangentVector == T.TangentVector.AllDifferentiableVariables { let any = AnyDerivative(x) return any + any @@ -136,17 +136,17 @@ AnyDerivativeTests.test("Derivatives") { do { let x = Vector.TangentVector(x: 4, y: 5) - let v = AnyDerivative(Vector.CotangentVector(x: 1, y: 1)) + let v = AnyDerivative(Vector.TangentVector(x: 1, y: 1)) let 𝛁x = pullback(at: x, in: { x in typeErased(x) })(v) - let expectedVJP = Vector.CotangentVector(x: 2, y: 2) + let expectedVJP = Vector.TangentVector(x: 2, y: 2) expectEqual(expectedVJP, 𝛁x) } do { let x = Generic.TangentVector(x: 4) - let v = AnyDerivative(Generic.CotangentVector(x: 1)) + let v = AnyDerivative(Generic.TangentVector(x: 1)) let 𝛁x = pullback(at: x, in: { x in typeErased(x) })(v) - let expectedVJP = Generic.CotangentVector(x: 2) + let expectedVJP = Generic.TangentVector(x: 2) expectEqual(expectedVJP, 𝛁x) } } diff --git a/test/AutoDiff/autodiff_indirect_diagnostics.swift b/test/AutoDiff/autodiff_indirect_diagnostics.swift index ff931b1caa65d..6695660e8c9e4 100644 --- a/test/AutoDiff/autodiff_indirect_diagnostics.swift +++ b/test/AutoDiff/autodiff_indirect_diagnostics.swift @@ -24,7 +24,7 @@ func weird(_ x: T) -> T { } func vjpWeirdExtraRequirements< T : Differentiable & CaseIterable ->(_ x: T) -> (T, (T.CotangentVector) -> T.CotangentVector) +>(_ x: T) -> (T, (T.TangentVector) -> T.TangentVector) where T.AllCases : ExpressibleByStringLiteral { return (x, { $0 }) diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index 3f092dd28d4cb..35126fa38c58f 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -13,9 +13,7 @@ public struct Foo : Differentiable { // CHECK-AST: @_fieldwiseDifferentiable public struct AllDifferentiableVariables // CHECK-AST: public typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables // CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables -// CHECK-AST: public typealias CotangentVector = Foo.AllDifferentiableVariables // CHECK-AST: public typealias TangentVector = Foo.AllDifferentiableVariables -// CHECK-AST: public typealias CotangentVector = Foo.AllDifferentiableVariables // CHECK-SILGEN-LABEL: // Foo.a.getter // CHECK-SILGEN-NEXT: sil [transparent] [serialized] [differentiable source 0 wrt 0] [ossa] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float @@ -31,7 +29,6 @@ let _: @differentiable (AdditiveTangentIsSelf) -> Float = { x in // CHECK-AST: internal var a: Float // CHECK-AST: internal init(a: Float) // CHECK-AST: internal typealias TangentVector = AdditiveTangentIsSelf -// CHECK-AST: internal typealias CotangentVector = AdditiveTangentIsSelf // CHECK-AST: internal typealias AllDifferentiableVariables = AdditiveTangentIsSelf struct TestNoDerivative : Differentiable { @@ -46,9 +43,7 @@ struct TestNoDerivative : Differentiable { // CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric // CHECK-AST: internal typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables -// CHECK-AST: internal typealias CotangentVector = TestNoDerivative.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestNoDerivative.AllDifferentiableVariables -// CHECK-AST: internal typealias CotangentVector = TestNoDerivative.AllDifferentiableVariables struct TestKeyPathIterable : Differentiable, KeyPathIterable { var w: Float @@ -62,29 +57,24 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable { // CHECK-AST: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric // CHECK-AST: internal typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables -// CHECK-AST: internal typealias CotangentVector = TestKeyPathIterable.AllDifferentiableVariables // CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables -// CHECK-AST: internal typealias CotangentVector = TestKeyPathIterable.AllDifferentiableVariables -struct GenericCotanMember : Differentiable, AdditiveArithmetic { - var x: T.CotangentVector +struct GenericTanMember : Differentiable, AdditiveArithmetic { + var x: T.TangentVector } // TODO(TF-316): Revisit after `Differentiable` derived conformances behavior is standardized. -// `AllDifferentiableVariables` and `CotangentVector` structs need not both be synthesized. +// `AllDifferentiableVariables` and `TangentVector` structs need not both be synthesized. -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct GenericCotanMember : Differentiable, AdditiveArithmetic where T : Differentiable { -// CHECK-AST: var x: T.CotangentVector -// CHECK-AST: internal init(x: T.CotangentVector) -// CHECK-AST: internal typealias TangentVector = GenericCotanMember -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct AllDifferentiableVariables : Differentiable -// CHECK-AST: internal typealias TangentVector = GenericCotanMember -// CHECK-AST: internal typealias CotangentVector = GenericCotanMember.CotangentVector -// CHECK-AST: internal typealias AllDifferentiableVariables = GenericCotanMember.AllDifferentiableVariables -// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct CotangentVector : Differentiable, AdditiveArithmetic -// CHECK-AST: internal typealias TangentVector = GenericCotanMember.CotangentVector -// CHECK-AST: internal typealias CotangentVector = GenericCotanMember -// CHECK-AST: internal typealias AllDifferentiableVariables = GenericCotanMember.CotangentVector +// CHECK-AST-LABEL: @_fieldwiseDifferentiable internal struct GenericTanMember : Differentiable, AdditiveArithmetic where T : Differentiable +// CHECK-AST: internal var x: T.TangentVector +// CHECK-AST: internal init(x: T.TangentVector) +// CHECK-AST: internal typealias TangentVector = GenericTanMember +// CHECK-AST: internal typealias AllDifferentiableVariables = GenericTanMember +// CHECK-AST: internal static var zero: GenericTanMember { get } +// CHECK-AST: internal static func + (lhs: GenericTanMember, rhs: GenericTanMember) -> GenericTanMember +// CHECK-AST: internal static func - (lhs: GenericTanMember, rhs: GenericTanMember) -> GenericTanMember +// CHECK-AST: @_implements(Equatable, ==(_:_:)) internal static func __derived_struct_equals(_ a: GenericTanMember, _ b: GenericTanMember) -> Bool public struct ConditionallyDifferentiable { public let x: T diff --git a/test/AutoDiff/differentiable_attr_silgen.swift b/test/AutoDiff/differentiable_attr_silgen.swift index 107b88a21625a..a63a8f84ea114 100644 --- a/test/AutoDiff/differentiable_attr_silgen.swift +++ b/test/AutoDiff/differentiable_attr_silgen.swift @@ -33,7 +33,7 @@ public func foo_indir_ret(_ x: Float, _ y: T) -> T { // CHECK: bb0(%0 : $*T, %1 : $Float, %2 : $*T): @_silgen_name("dfoo_indir_ret") -public func dfoo_indir_ret(_ x: Float, _ y: T) -> (T, (T.CotangentVector) -> (Float, T.CotangentVector)) { +public func dfoo_indir_ret(_ x: Float, _ y: T) -> (T, (T.TangentVector) -> (Float, T.TangentVector)) { return (y, { v in (x, v) }) } @@ -111,7 +111,6 @@ extension DiffStoredProp : VectorNumeric { extension DiffStoredProp : Differentiable { typealias TangentVector = DiffStoredProp - typealias CotangentVector = DiffStoredProp } //===----------------------------------------------------------------------===// @@ -151,7 +150,6 @@ extension DiffComputedProp : VectorNumeric { extension DiffComputedProp : Differentiable { typealias TangentVector = DiffComputedProp - typealias CotangentVector = DiffComputedProp } // CHECK-LABEL: DiffComputedProp.computedProp.getter diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index a65e2efe29ff6..45fe8e0f35bda 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -264,7 +264,6 @@ extension JVPStruct : VectorNumeric { extension JVPStruct : Differentiable { typealias TangentVector = JVPStruct - typealias CotangentVector = JVPStruct } extension JVPStruct { @@ -353,7 +352,7 @@ func vjp2ParamsVJP(x: Float, y: Float) -> (Float, (Float) -> (Float, Float)) { return (x + y, { v in (v, v) }) } -// expected-error @+1 {{'vjpWrongTypeVJP' does not have expected type '(Float) -> (Float, (Float.CotangentVector) -> Float.CotangentVector)' (aka '(Float) -> (Float, (Float) -> Float)'}} +// expected-error @+1 {{'vjpWrongTypeVJP' does not have expected type '(Float) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float) -> (Float, (Float) -> Float)'}} @differentiable(vjp: vjpWrongTypeVJP) func vjpWrongType(x: Float) -> Float { return x @@ -387,14 +386,14 @@ struct VJPStruct { @differentiable(vjp: storedPropVJP) let storedImmutableOk: Float - // expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.CotangentVector) -> VJPStruct.CotangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} + // expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} @differentiable(vjp: storedPropVJP) let storedImmutableWrongType: Double @differentiable(vjp: storedPropVJP) var storedMutableOk: Float - // expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.CotangentVector) -> VJPStruct.CotangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} + // expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} @differentiable(vjp: storedPropVJP) var storedMutableWrongType: Double } @@ -421,7 +420,6 @@ extension VJPStruct : VectorNumeric { extension VJPStruct : Differentiable { typealias TangentVector = VJPStruct - typealias CotangentVector = VJPStruct } extension VJPStruct { @@ -459,7 +457,7 @@ extension VJPStruct { } } - // expected-error @+1 {{'computedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.CotangentVector) -> VJPStruct.CotangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} + // expected-error @+1 {{'computedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}} @differentiable(vjp: computedPropVJP) var computedPropWrongType: Double { return 0 @@ -501,7 +499,7 @@ func where1(x: T) -> T { func jvpWhere1(x: T) -> (T, (T.TangentVector) -> T.TangentVector) { return (x, { v in v }) } -func vjpWhere1(x: T) -> (T, (T.CotangentVector) -> T.CotangentVector) { +func vjpWhere1(x: T) -> (T, (T.TangentVector) -> T.TangentVector) { return (x, { v in v }) } @@ -543,7 +541,6 @@ struct ResultLabelTest { struct Tensor : AdditiveArithmetic {} extension Tensor : Differentiable where Scalar : Differentiable { typealias TangentVector = Tensor - typealias CotangentVector = Tensor typealias AllDifferentiableVariables = Tensor func moved(along direction: Tensor) -> Tensor { return self } func tangentVector(from cotangent: Tensor) -> Tensor { return cotangent } @@ -584,12 +581,12 @@ protocol MethodDiffReq { } extension MethodDiffReq where Self : Differentiable { - func vjpFoo(x: Self) -> (Self, (Self.CotangentVector) -> Self.CotangentVector) { + func vjpFoo(x: Self) -> (Self, (Self.TangentVector) -> Self.TangentVector) { return (self, { $0 }) } } -// expected-error @+1 {{'vjpNonvariadic' does not have expected type '(Float, Int32...) -> (Float, (Float.CotangentVector) -> Float.CotangentVector)' (aka '(Float, Int32...) -> (Float, (Float) -> Float)')}} +// expected-error @+1 {{'vjpNonvariadic' does not have expected type '(Float, Int32...) -> (Float, (Float.TangentVector) -> Float.TangentVector)' (aka '(Float, Int32...) -> (Float, (Float) -> Float)')}} @differentiable(wrt: x, vjp: vjpNonvariadic) func variadic(_ x: Float, indices: Int32...) -> Float { return x diff --git a/test/AutoDiff/differentiating_attr_type_checking.swift b/test/AutoDiff/differentiating_attr_type_checking.swift index 6d8665459fb78..6e6b7c525fd0f 100644 --- a/test/AutoDiff/differentiating_attr_type_checking.swift +++ b/test/AutoDiff/differentiating_attr_type_checking.swift @@ -20,7 +20,7 @@ func vjpSinExplicitWrt(x: Float) -> (value: Float, pullback: (Float) -> Float) { func vjpDuplicate(x: Float) -> (value: Float, pullback: (Float) -> Float) { return (x, { $0 }) } -// expected-error @+1 {{'@differentiating' attribute requires function to return a two-element tuple of type '(value: T..., pullback: (U.CotangentVector) -> T.CotangentVector...)' or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'}} +// expected-error @+1 {{'@differentiating' attribute requires function to return a two-element tuple of type '(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or '(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'}} @differentiating(sin) func jvpSinResultInvalid(x: @nondiff Float) -> Float { return x @@ -36,7 +36,7 @@ func vjpSinResultNotDifferentiable(x: Int) -> (value: Int, pullback: (Int) -> In return (x, { $0 }) } // expected-error @+2 {{function result's 'pullback' type does not match 'sin'}} -// expected-note @+2 {{'pullback' does not have expected type '(Float.CotangentVector) -> (Float.CotangentVector)' (aka '(Float) -> Float')}} +// expected-note @+2 {{'pullback' does not have expected type '(Float.TangentVector) -> (Float.TangentVector)' (aka '(Float) -> Float')}} @differentiating(sin) func vjpSinResultInvalidSeedType(x: Float) -> (value: Float, pullback: (Double) -> Double) { return (x, { $0 }) @@ -54,13 +54,13 @@ func jvpGeneric(x: T, y: T) -> (value: T, differential: (T.T func vjpGenericWrongLabel(x: T, y: T) -> (value: T, (T) -> (T, T)) { return (x, { ($0, $0) }) } -// expected-error @+1 {{could not find function 'generic' with expected type ' (x: T) -> T'}} +// expected-error @+1 {{could not find function 'generic' with expected type ' (x: T) -> T'}} @differentiating(generic) -func vjpGenericDiffParamMismatch(x: T) -> (value: T, pullback: (T) -> (T, T)) where T == T.CotangentVector { +func vjpGenericDiffParamMismatch(x: T) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector { return (x, { ($0, $0) }) } @differentiating(generic) // ok -func vjpGenericExtraGenericRequirements(x: T, y: T) -> (value: T, pullback: (T) -> (T, T)) where T == T.CotangentVector { +func vjpGenericExtraGenericRequirements(x: T, y: T) -> (value: T, pullback: (T) -> (T, T)) where T == T.TangentVector { return (x, { ($0, $0) }) } @@ -122,11 +122,11 @@ func invalidDiffWrtFunction(_ fn: @differentiable(Float) -> Float) -> Float { // expected-error @+2 {{type 'T' does not conform to protocol 'FloatingPoint'}} // expected-error @+1 {{could not find function 'foo' with expected type ' (T) -> T'}} @differentiating(foo) -func vjpFoo(_ x: T) -> (value: T, pullback: (T.CotangentVector) -> (T.CotangentVector)) { +func vjpFoo(_ x: T) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector)) { return (x, { $0 }) } @differentiating(foo) -func vjpFooExtraGenericRequirements(_ x: T) -> (value: T, pullback: (T) -> (T)) where T == T.CotangentVector { +func vjpFooExtraGenericRequirements(_ x: T) -> (value: T, pullback: (T) -> (T)) where T == T.TangentVector { return (x, { $0 }) } @@ -135,12 +135,12 @@ func vjpFooExtraGenericRequirements (value: Self, pullback: (Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)) { + static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) { return (x + y, { v in (v, v) }) } } -extension FloatingPoint where Self : Differentiable, Self == Self.CotangentVector { +extension FloatingPoint where Self : Differentiable, Self == Self.TangentVector { // expected-error @+1 {{derivative not in the same file as the original function}} @differentiating(+) static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) { @@ -151,13 +151,13 @@ extension FloatingPoint where Self : Differentiable, Self == Self.CotangentVecto extension Differentiable where Self : AdditiveArithmetic { // expected-error @+1 {{'+' is not defined in the current type context}} @differentiating(+) - static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)) { + static func vjpPlus(x: Self, y: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) { return (x + y, { v in (v, v) }) } } -extension AdditiveArithmetic where Self : Differentiable, Self == Self.CotangentVector { - // expected-error @+1 {{could not find function '+' with expected type ' (Self) -> (Self, Self) -> Self'}} +extension AdditiveArithmetic where Self : Differentiable, Self == Self.TangentVector { + // expected-error @+1 {{could not find function '+' with expected type ' (Self) -> (Self, Self) -> Self'}} @differentiating(+) func vjpPlusInstanceMethod(x: Self, y: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) { return (x + y, { v in (v, v) }) @@ -178,9 +178,9 @@ protocol InstanceMethod : Differentiable { extension InstanceMethod { // If `Self` conforms to `Differentiable`, then `Self` is currently always inferred to be a differentiation parameter. // expected-error @+2 {{function result's 'pullback' type does not match 'foo'}} - // expected-note @+2 {{'pullback' does not have expected type '(Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)'}} + // expected-note @+2 {{'pullback' does not have expected type '(Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)'}} @differentiating(foo) - func vjpFoo(x: Self) -> (value: Self, pullback: (CotangentVector) -> CotangentVector) { + func vjpFoo(x: Self) -> (value: Self, pullback: (TangentVector) -> TangentVector) { return (x, { $0 }) } @@ -190,21 +190,21 @@ extension InstanceMethod { } @differentiating(foo, wrt: (self, x)) - func vjpFooWrt(x: Self) -> (value: Self, pullback: (CotangentVector) -> (CotangentVector, CotangentVector)) { + func vjpFooWrt(x: Self) -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) { return (x, { ($0, $0) }) } } extension InstanceMethod { // expected-error @+2 {{function result's 'pullback' type does not match 'bar'}} - // expected-note @+2 {{'pullback' does not have expected type '(Self.CotangentVector) -> (Self.CotangentVector, T.CotangentVector)'}} + // expected-note @+2 {{'pullback' does not have expected type '(Self.TangentVector) -> (Self.TangentVector, T.TangentVector)'}} @differentiating(bar) - func vjpBar(_ x: T) -> (value: Self, pullback: (CotangentVector) -> T.CotangentVector) { + func vjpBar(_ x: T) -> (value: Self, pullback: (TangentVector) -> T.TangentVector) { return (self, { _ in .zero }) } @differentiating(bar) - func vjpBar(_ x: T) -> (value: Self, pullback: (CotangentVector) -> (CotangentVector, T.CotangentVector)) { + func vjpBar(_ x: T) -> (value: Self, pullback: (TangentVector) -> (TangentVector, T.TangentVector)) { return (self, { ($0, .zero) }) } @@ -214,7 +214,7 @@ extension InstanceMethod { } } -extension InstanceMethod where Self == Self.TangentVector, Self == Self.CotangentVector { +extension InstanceMethod where Self == Self.TangentVector { @differentiating(foo2) func vjpFooExtraRequirements(x: Self) -> (value: Self, pullback: (Self) -> (Self, Self)) { return (x, { ($0, $0) }) @@ -226,7 +226,7 @@ extension InstanceMethod where Self == Self.TangentVector, Self == Self.Cotangen } @differentiating(bar2) - func vjpBarExtraRequirements(x: T) -> (value: Self, pullback: (Self) -> (Self, T.CotangentVector)) { + func vjpBarExtraRequirements(x: T) -> (value: Self, pullback: (Self) -> (Self, T.TangentVector)) { return (self, { ($0, .zero) }) } @@ -236,7 +236,7 @@ extension InstanceMethod where Self == Self.TangentVector, Self == Self.Cotangen } } -protocol GenericInstanceMethod : Differentiable where Self == Self.TangentVector, Self == Self.CotangentVector { +protocol GenericInstanceMethod : Differentiable where Self == Self.TangentVector { func instanceMethod(_ x: T) -> T } @@ -245,7 +245,7 @@ extension GenericInstanceMethod { return (x, { v in (self, v) }) } - func vjpInstanceMethod(_ x: T) -> (T, (T.CotangentVector) -> (CotangentVector, T.CotangentVector)) { + func vjpInstanceMethod(_ x: T) -> (T, (T.TangentVector) -> (TangentVector, T.TangentVector)) { return (x, { v in (self, v) }) } } @@ -256,7 +256,7 @@ func bar(_ x: T) -> T { return x } @differentiating(bar) -func vjpBar(_ x: T) -> (value: T, pullback: (T.CotangentVector) -> T.CotangentVector) { +func vjpBar(_ x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) { return (x, { $0 }) } @@ -266,7 +266,7 @@ func baz(_ x: T, _ y: U) -> T { @differentiating(baz) func vjpBaz(_ x: T, _ y: U) -> (value: T, pullback: (T) -> (T, U)) - where T == T.CotangentVector, U == U.CotangentVector + where T == T.TangentVector, U == U.TangentVector { return (x, { ($0, .zero) }) } @@ -276,7 +276,7 @@ protocol InstanceMethodProto { } extension InstanceMethodProto where Self : Differentiable { @differentiating(bar) - func vjpBar() -> (value: Float, pullback: (Float) -> CotangentVector) { + func vjpBar() -> (value: Float, pullback: (Float) -> TangentVector) { return (bar(), { _ in .zero }) } } diff --git a/test/AutoDiff/generics.swift b/test/AutoDiff/generics.swift index bfb922d1bcb07..209cddcb787c3 100644 --- a/test/AutoDiff/generics.swift +++ b/test/AutoDiff/generics.swift @@ -10,11 +10,11 @@ _ = gradient(at: Float(1), in: { x in identity(x) }) // Verify that local buffers are immediately set to zero. // CHECK-SIL-LABEL: sil hidden @AD__identity__adjoint_src_0_wrt_0 -// CHECK-SIL: [[ORIG_COTAN:%.*]] = alloc_stack $τ_0_0.CotangentVector +// CHECK-SIL: [[ORIG_COTAN:%.*]] = alloc_stack $τ_0_0.TangentVector // CHECK-SIL-NEXT: [[ORIG_COTAN_BEGIN:%.*]] = begin_access [init] [static] [no_nested_conflict] [[ORIG_COTAN]] -// CHECK-SIL-NEXT: [[ZERO_WITNESS:%.*]] = witness_method $τ_0_0.CotangentVector, #AdditiveArithmetic.zero!getter.1 -// CHECK-SIL-NEXT: [[ORIG_COTAN_METATYPE:%.*]] = metatype $@thick τ_0_0.CotangentVector.Type -// CHECK-SIL-NEXT: [[EMIT_ZERO_INDIRECT:%.*]] = apply [[ZERO_WITNESS]]<τ_0_0.CotangentVector>([[ORIG_COTAN_BEGIN]], [[ORIG_COTAN_METATYPE]]) +// CHECK-SIL-NEXT: [[ZERO_WITNESS:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter.1 +// CHECK-SIL-NEXT: [[ORIG_COTAN_METATYPE:%.*]] = metatype $@thick τ_0_0.TangentVector.Type +// CHECK-SIL-NEXT: [[EMIT_ZERO_INDIRECT:%.*]] = apply [[ZERO_WITNESS]]<τ_0_0.TangentVector>([[ORIG_COTAN_BEGIN]], [[ORIG_COTAN_METATYPE]]) // CHECK-SIL-NEXT: end_access [[ORIG_COTAN_BEGIN]] // CHECK-SIL: } diff --git a/test/AutoDiff/separate_cotangent_type.swift b/test/AutoDiff/separate_cotangent_type.swift index eea55bc6396ed..0ebad06398bd8 100644 --- a/test/AutoDiff/separate_cotangent_type.swift +++ b/test/AutoDiff/separate_cotangent_type.swift @@ -8,7 +8,7 @@ import Darwin.C import Glibc #endif -var SeparateCotangentTypeTests = TestSuite("SeparateCotangentType") +var SeparateTangentTypeTests = TestSuite("SeparateTangentType") @_fieldwiseDifferentiable struct DifferentiableSubset : Differentiable { @@ -21,49 +21,35 @@ struct DifferentiableSubset : Differentiable { @_fieldwiseDifferentiable struct TangentVector : Differentiable, VectorNumeric { typealias TangentVector = DifferentiableSubset.TangentVector - typealias CotangentVector = DifferentiableSubset.CotangentVector var w: Float var b: Float - func tangentVector(from cotan: CotangentVector) -> TangentVector { + func tangentVector(from cotan: TangentVector) -> TangentVector { return TangentVector(w: cotan.w, b: cotan.b) } } - @_fieldwiseDifferentiable - struct CotangentVector : Differentiable, VectorNumeric { - typealias TangentVector = DifferentiableSubset.CotangentVector - typealias CotangentVector = DifferentiableSubset.TangentVector - var w: Float - var b: Float - func tangentVector(from cotan: CotangentVector) -> TangentVector { - return TangentVector(w: cotan.w, b: cotan.b) - } - } - func tangentVector(from cotan: CotangentVector) -> TangentVector { - return TangentVector(w: cotan.w, b: cotan.b) - } func moved(along v: TangentVector) -> DifferentiableSubset { return DifferentiableSubset(w: w.moved(along: v.w), b: b.moved(along: v.b), flag: flag) } } -SeparateCotangentTypeTests.test("Trivial") { +SeparateTangentTypeTests.test("Trivial") { let x = DifferentiableSubset(w: 0, b: 1, flag: false) let pb = pullback(at: x) { x in x } - expectEqual(pb(DifferentiableSubset.CotangentVector.zero), DifferentiableSubset.CotangentVector.zero) + expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) } -SeparateCotangentTypeTests.test("Initialization") { +SeparateTangentTypeTests.test("Initialization") { let x = DifferentiableSubset(w: 0, b: 1, flag: false) let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) } - expectEqual(pb(DifferentiableSubset.CotangentVector.zero), DifferentiableSubset.CotangentVector.zero) + expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) } -// FIXME(SR-9602): If `CotangentVector` is not marked +// FIXME(SR-9602): If `TangentVector` is not marked // `@_fieldwiseProductSpace`, call the VJP of the memberwise initializer. -// SeparateCotangentTypeTests.test("SomeArithmetics") { +// SeparateTangentTypeTests.test("SomeArithmetics") { // let x = DifferentiableSubset(w: 0, b: 1, flag: false) // let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) } -// expectEqual(pb(DifferentiableSubset.CotangentVector.zero), DifferentiableSubset.CotangentVector.zero) +// expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) // } runAllTests() diff --git a/test/Sema/struct_differentiable.swift b/test/Sema/struct_differentiable.swift index d6fef7602fec0..79c30329220a8 100644 --- a/test/Sema/struct_differentiable.swift +++ b/test/Sema/struct_differentiable.swift @@ -1,9 +1,9 @@ // SWIFT_ENABLE_TENSORFLOW // RUN: %target-swift-frontend -typecheck -verify -primary-file %s %S/Inputs/struct_differentiable_other_module.swift -// Verify that a `Differentiable` type upholds `AllDifferentiableVariables == CotangentVector`. -func assertAllDifferentiableVariablesEqualsCotangentVector(_: T.Type) - where T : Differentiable, T.AllDifferentiableVariables == T.CotangentVector {} +// Verify that a `Differentiable` type upholds `AllDifferentiableVariables == TangentVector`. +func assertAllDifferentiableVariablesEqualsTangentVector(_: T.Type) + where T : Differentiable, T.AllDifferentiableVariables == T.TangentVector {} // Verify that a type `T` conforms to `AdditiveArithmetic`. func assertConformsToAdditiveArithmetic(_: T.Type) where T : AdditiveArithmetic {} @@ -55,7 +55,6 @@ func testSimple() { var simple = Simple(w: 1, b: 1) simple.allDifferentiableVariables = simple + simple assert(simple.moved(along: simple) == simple + simple) - assert(simple.tangentVector(from: simple) == simple) } // Test type with mixed members. @@ -67,7 +66,6 @@ func testMixed(_ simple: Simple) { var mixed = Mixed(simple: simple, float: 1) mixed.allDifferentiableVariables = Mixed(simple: simple, float: 2) assert(mixed.moved(along: mixed) == mixed + mixed) - assert(mixed.tangentVector(from: mixed) == mixed) } // Test type with manual definition of vector space types to `Self`. @@ -75,13 +73,12 @@ struct VectorSpacesEqualSelf : AdditiveArithmetic, Differentiable { var w: Float var b: Float typealias TangentVector = VectorSpacesEqualSelf - typealias CotangentVector = VectorSpacesEqualSelf typealias AllDifferentiableVariables = VectorSpacesEqualSelf } // Test generic type with vector space types to `Self`. struct GenericVectorSpacesEqualSelf : AdditiveArithmetic, Differentiable - where T : Differentiable, T == T.TangentVector, T == T.CotangentVector, + where T : Differentiable, T == T.TangentVector, T == T.AllDifferentiableVariables { var w: T @@ -91,7 +88,6 @@ func testGenericVectorSpacesEqualSelf() { var genericSame = GenericVectorSpacesEqualSelf(w: 1, b: 1) genericSame.allDifferentiableVariables = genericSame + genericSame assert(genericSame.moved(along: genericSame) == genericSame + genericSame) - assert(genericSame.tangentVector(from: genericSame) == genericSame) } // Test nested type. @@ -106,7 +102,6 @@ func testNested( ) { let nested = Nested(simple: simple, mixed: mixed, generic: genericSame) assert(nested.moved(along: nested) == nested + nested) - assert(nested.tangentVector(from: nested) == nested) _ = pullback(at: nested) { model in model.simple + model.simple @@ -114,7 +109,7 @@ func testNested( } // Test type that does not conform to `AdditiveArithmetic` but whose members do. -// Thus, `Self` cannot be used as `TangentVector` or `CotangentVector`. +// Thus, `Self` cannot be used as `TangentVector` or `TangentVector`. // Vector space structs types must be synthesized. // Note: it would be nice to emit a warning if conforming `Self` to // `AdditiveArithmetic` is possible. @@ -123,11 +118,11 @@ struct AllMembersAdditiveArithmetic : Differentiable { var b: Float } func testAllMembersAdditiveArithmetic() { - assertAllDifferentiableVariablesEqualsCotangentVector(AllMembersAdditiveArithmetic.self) + assertAllDifferentiableVariablesEqualsTangentVector(AllMembersAdditiveArithmetic.self) } // Test type `AllMembersVectorNumeric` whose members conforms to `VectorNumeric`, -// in which case we should make `TangentVector` and `CotangentVector` conform to +// in which case we should make `TangentVector` and `TangentVector` conform to // `VectorNumeric`. struct MyVector : VectorNumeric, Differentiable { var w: Float @@ -139,7 +134,7 @@ struct AllMembersVectorNumeric : Differentiable { } func testAllMembersVectorNumeric() { assertConformsToVectorNumeric(AllMembersVectorNumeric.TangentVector.self) - assertConformsToVectorNumeric(AllMembersVectorNumeric.CotangentVector.self) + assertConformsToVectorNumeric(AllMembersVectorNumeric.TangentVector.self) } // Test type with immutable, differentiable stored property. @@ -161,9 +156,9 @@ struct DifferentiableSubset : Differentiable { func testDifferentiableSubset() { assertConformsToAdditiveArithmetic(DifferentiableSubset.AllDifferentiableVariables.self) assertConformsToVectorNumeric(DifferentiableSubset.AllDifferentiableVariables.self) - assertAllDifferentiableVariablesEqualsCotangentVector(DifferentiableSubset.self) + assertAllDifferentiableVariablesEqualsTangentVector(DifferentiableSubset.self) + _ = DifferentiableSubset.TangentVector(w: 1, b: 1) _ = DifferentiableSubset.TangentVector(w: 1, b: 1) - _ = DifferentiableSubset.CotangentVector(w: 1, b: 1) _ = DifferentiableSubset.AllDifferentiableVariables(w: 1, b: 1) _ = pullback(at: DifferentiableSubset(w: 1, b: 2, flag: false)) { model in @@ -178,7 +173,7 @@ struct NestedDifferentiableSubset : Differentiable { @noDerivative var technicallyDifferentiable: Float } func testNestedDifferentiableSubset() { - assertAllDifferentiableVariablesEqualsCotangentVector(NestedDifferentiableSubset.self) + assertAllDifferentiableVariablesEqualsTangentVector(NestedDifferentiableSubset.self) } // Test type that uses synthesized vector space types but provides custom @@ -267,7 +262,7 @@ extension GenericConstrained : Differentiable where T : Differentiable {} struct TF_260 : Differentiable & AdditiveArithmetic { - var x: T.CotangentVector + var x: T.TangentVector } // TF-269: Test crash when differentiation properties have no getter. @@ -295,25 +290,20 @@ public struct TF_269 : TF_269_Layer { // Test manually customizing vector space types. // Thees should fail. Synthesis is semantically unsupported if vector space // types are customized. -// expected-error @+3 {{type 'VectorSpaceTypeAlias' does not conform to protocol '__Differentiable'}} -// expected-error @+2 {{type 'VectorSpaceTypeAlias' does not conform to protocol '_Differentiable'}} // expected-error @+1 {{type 'VectorSpaceTypeAlias' does not conform to protocol 'Differentiable'}} struct VectorSpaceTypeAlias : AdditiveArithmetic, Differentiable { var w: Float var b: Float typealias TangentVector = Simple } -// expected-error @+3 {{type 'VectorSpaceCustomStruct' does not conform to protocol '__Differentiable'}} -// expected-error @+2 {{type 'VectorSpaceCustomStruct' does not conform to protocol '_Differentiable'}} // expected-error @+1 {{type 'VectorSpaceCustomStruct' does not conform to protocol 'Differentiable'}} struct VectorSpaceCustomStruct : AdditiveArithmetic, Differentiable { var w: Float var b: Float - struct CotangentVector : AdditiveArithmetic, Differentiable { - var w: Float.CotangentVector - var b: Float.CotangentVector - typealias TangentVector = VectorSpaceCustomStruct.CotangentVector - typealias CotangentVector = VectorSpaceCustomStruct.CotangentVector + struct TangentVector : AdditiveArithmetic, Differentiable { + var w: Float.TangentVector + var b: Float.TangentVector + typealias TangentVector = VectorSpaceCustomStruct.TangentVector } } @@ -343,14 +333,10 @@ struct InvalidInitializer : Differentiable { // Test derived conformances in disallowed contexts. -// expected-error @+4 {{type 'OtherFileNonconforming' does not conform to protocol '__Differentiable'}} -// expected-error @+3 {{type 'OtherFileNonconforming' does not conform to protocol '_Differentiable'}} // expected-error @+2 {{type 'OtherFileNonconforming' does not conform to protocol 'Differentiable'}} -// expected-error @+1 {{implementation of '__Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} +// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} extension OtherFileNonconforming : Differentiable {} -// expected-error @+4 {{type 'GenericOtherFileNonconforming' does not conform to protocol '__Differentiable'}} -// expected-error @+3 {{type 'GenericOtherFileNonconforming' does not conform to protocol '_Differentiable'}} // expected-error @+2 {{type 'GenericOtherFileNonconforming' does not conform to protocol 'Differentiable'}} -// expected-error @+1 {{implementation of '__Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} +// expected-error @+1 {{implementation of 'Differentiable' cannot be automatically synthesized in an extension in a different file to the type}} extension GenericOtherFileNonconforming : Differentiable {} diff --git a/test/Sema/struct_differentiable_member_types.swift b/test/Sema/struct_differentiable_member_types.swift index 70b4c601468a8..e213871797c43 100644 --- a/test/Sema/struct_differentiable_member_types.swift +++ b/test/Sema/struct_differentiable_member_types.swift @@ -16,4 +16,4 @@ struct Foo : Differentiable { // synthesized member types require extra non-trivial work, due to the // current type-checker design. let randomGlobal = 1 -extension Foo.CotangentVector : Proto {} +extension Foo.TangentVector : Proto {} diff --git a/test/Serialization/differentiable_attr.swift b/test/Serialization/differentiable_attr.swift index c43f54493ce24..09bb7ec96ebdd 100644 --- a/test/Serialization/differentiable_attr.swift +++ b/test/Serialization/differentiable_attr.swift @@ -51,7 +51,7 @@ func testOnlyWhereClause(x: T) -> T { func testWhereClause(x: T) -> T { return x } -func vjpTestWhereClause(x: T) -> (T, (T.CotangentVector) -> T.CotangentVector) +func vjpTestWhereClause(x: T) -> (T, (T.TangentVector) -> T.TangentVector) where T : Numeric, T : Differentiable { return (x, { v in v }) @@ -67,33 +67,33 @@ extension P { } } extension P where Self : Differentiable { - func vjpTestWhereClauseMethod() -> (Self, (Self.CotangentVector) -> Self.CotangentVector) { + func vjpTestWhereClauseMethod() -> (Self, (Self.TangentVector) -> Self.TangentVector) { return (self, { v in v }) } } -// CHECK: @differentiable(wrt: x, vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.CotangentVector) +// CHECK: @differentiable(wrt: x, vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.TangentVector) // CHECK-NEXT: func testWhereClauseMethodTypeConstraint(x: T) -> T where T : Numeric -@differentiable(vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.CotangentVector) +@differentiable(vjp: vjpTestWhereClauseMethodTypeConstraint where T : Differentiable, T == T.TangentVector) func testWhereClauseMethodTypeConstraint(x: T) -> T { return x } func vjpTestWhereClauseMethodTypeConstraint(x: T) -> (T, (T) -> T) - where T : Numeric, T : Differentiable, T == T.CotangentVector + where T : Numeric, T : Differentiable, T == T.TangentVector { return (x, { v in v }) } extension P { - // CHECK: @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self : Differentiable, Self == Self.CotangentVector) + // CHECK: @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self : Differentiable, Self == Self.TangentVector) // CHECK-NEXT: func testWhereClauseMethodTypeConstraint() -> Self - @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self.CotangentVector == Self, Self : Differentiable) + @differentiable(wrt: self, vjp: vjpTestWhereClauseMethodTypeConstraint where Self.TangentVector == Self, Self : Differentiable) func testWhereClauseMethodTypeConstraint() -> Self { return self } } -extension P where Self : Differentiable, Self == Self.CotangentVector { - func vjpTestWhereClauseMethodTypeConstraint() -> (Self, (Self.CotangentVector) -> Self.CotangentVector) { +extension P where Self : Differentiable, Self == Self.TangentVector { + func vjpTestWhereClauseMethodTypeConstraint() -> (Self, (Self.TangentVector) -> Self.TangentVector) { return (self, { v in v }) } } diff --git a/test/TensorFlowRuntime/model_autodiff_runtime.swift b/test/TensorFlowRuntime/model_autodiff_runtime.swift index 7a44f19f76e5f..1245ba8df7acf 100644 --- a/test/TensorFlowRuntime/model_autodiff_runtime.swift +++ b/test/TensorFlowRuntime/model_autodiff_runtime.swift @@ -90,7 +90,7 @@ public protocol Optimizer { associatedtype Scalar: FloatingPoint var learningRate: Scalar { get } mutating func update(_ variables: inout Model.AllDifferentiableVariables, - along vector: Model.CotangentVector) + along vector: Model.TangentVector) } public class RiemannSGD: Optimizer @@ -107,9 +107,9 @@ public class RiemannSGD: Optimizer } public func update(_ model: inout Model.AllDifferentiableVariables, - along vector: Model.CotangentVector) { + along vector: Model.TangentVector) { model = model.moved( - along: learningRate * (.zero - model.tangentVector(from: vector))) + along: learningRate * (.zero - vector)) } } @@ -161,7 +161,7 @@ ModelADTests.testAllBackends("WithRespectToModel") { } } let x = Tensor(0) - var model = Foo(bar: x, baz: x) + let model = Foo(bar: x, baz: x) let d = gradient(at: model) { model in model.applied(to: x) } @@ -191,9 +191,9 @@ ModelADTests.testAllBackends("TF437") { func tf437Step(_ model: inout Model, inputs: Model.Input) -> () - where Model.AllDifferentiableVariables == Model.CotangentVector, + where Model.AllDifferentiableVariables == Model.TangentVector, Model.Output == Tensor { - gradient(at: model) { model -> Model.Output in + _ = gradient(at: model) { model -> Model.Output in let logits = model.applied(to: inputs) return logits.mean() } diff --git a/test/TensorFlowRuntime/tensor_autodiff_indirect.swift b/test/TensorFlowRuntime/tensor_autodiff_indirect.swift index 81f48afdce46b..424a26beebc51 100644 --- a/test/TensorFlowRuntime/tensor_autodiff_indirect.swift +++ b/test/TensorFlowRuntime/tensor_autodiff_indirect.swift @@ -35,7 +35,7 @@ extension Tensor where Scalar : Differentiable & FloatingPoint { func foo(_ x: Scalar) -> Scalar { return x } - func vjpFoo(_ x: Scalar) -> (Scalar, (Scalar.CotangentVector) -> Scalar.CotangentVector) { + func vjpFoo(_ x: Scalar) -> (Scalar, (Scalar.TangentVector) -> Scalar.TangentVector) { return (x, { v in v }) } } @@ -163,7 +163,7 @@ TensorADTests.testAllBackends("GenericLayerMembers") { rhs.applied(to: lhs.applied(to: input)) }(seed) let 𝛁combined = pullback(at: combined) { $0.applied(to: input) }(seed) - expectEqual(Sequential.CotangentVector(lhs: 𝛁lhs, rhs: 𝛁rhs), 𝛁combined) + expectEqual(Sequential.TangentVector(lhs: 𝛁lhs, rhs: 𝛁rhs), 𝛁combined) } testFixedInput() @@ -178,7 +178,7 @@ TensorADTests.testAllBackends("GenericLayerMembers") { rhs.applied(to: lhs.applied(to: input)) }(seed) let 𝛁combined = pullback(at: combined) { $0.applied(to: input) }(seed) - expectEqual(Sequential.CotangentVector(lhs: 𝛁lhs, rhs: 𝛁rhs), 𝛁combined) + expectEqual(Sequential.TangentVector(lhs: 𝛁lhs, rhs: 𝛁rhs), 𝛁combined) } testWrtInput(Tensor(randomUniform: [2, 3])) } @@ -203,7 +203,7 @@ TensorADTests.testAllBackends("GenericWrapperLayer") { let 𝛁wrapper = pullback(at: wrapper) { $0.applied(to: input) }(seed) let 𝛁dense = pullback(at: dense) { $0.applied(to: input) }(seed) - expectEqual(Wrapper.CotangentVector(layer: 𝛁dense), 𝛁wrapper) + expectEqual(Wrapper.TangentVector(layer: 𝛁dense), 𝛁wrapper) } runAllTests()