diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 3c254850f479b..70b9ea742328f 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -556,6 +556,8 @@ struct AutoDiffAssociatedFunctionKind { AutoDiffAssociatedFunctionKind() = default; AutoDiffAssociatedFunctionKind(innerty rawValue) : rawValue(rawValue) {} + AutoDiffAssociatedFunctionKind(AutoDiffLinearMapKind linMapKind) + : rawValue(static_cast(linMapKind.rawValue)) {} explicit AutoDiffAssociatedFunctionKind(StringRef string); operator innerty() const { return rawValue; } AutoDiffLinearMapKind getLinearMapKind() { diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 1002285d47782..cd2580ec842b5 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -91,13 +91,35 @@ static bool isWithoutDerivative(SILValue v) { return false; } +static bool isArrayLiteralIntrinsic(ApplyInst *ai) { + return ai->hasSemantics("array.uninitialized_intrinsic"); +} + static ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) { if (auto *applyInst = dyn_cast(v)) - if (applyInst->hasSemantics("array.uninitialized_intrinsic")) + if (isArrayLiteralIntrinsic(applyInst)) return applyInst; return nullptr; } +/// Given a value, find its single `destructure_tuple` user if the value is +/// tuple-typed and such a user exists. 
+static DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) { + bool foundDestructureTupleUser = false; + if (!value->getType().is()) + return nullptr; + DestructureTupleInst *result = nullptr; + for (auto *use : value->getUses()) { + if (auto *dti = dyn_cast(use->getUser())) { + assert(!foundDestructureTupleUser && + "There should only be one `destructure_tuple` user of a tuple"); + foundDestructureTupleUser = true; + result = dti; + } + } + return result; +} + /// Given a function, gather all of its formal results (both direct and /// indirect) in an order defined by its result type. Note that "formal results" /// refer to result values in the body of the function, not at call sites. @@ -122,6 +144,23 @@ collectAllFormalResultsInTypeOrder(SILFunction &function, : indResults[indResIdx++]); } +/// Given a function, gather all of its direct results in an order defined by +/// its result type. Note that "direct results" refer to result values in the +/// body of the function, not at call sites. +static void +collectAllDirectResultsInTypeOrder(SILFunction &function, + SmallVectorImpl &results) { + SILFunctionConventions convs(function.getLoweredFunctionType(), + function.getModule()); + auto *retInst = cast(function.findReturnBB()->getTerminator()); + auto retVal = retInst->getOperand(); + if (auto *tupleInst = dyn_cast(retVal)) + results.append(tupleInst->getElements().begin(), + tupleInst->getElements().end()); + else + results.push_back(retVal); +} + /// Given a function call site, gather all of its actual results (both direct /// and indirect) in an order defined by its result type. 
template @@ -256,10 +295,6 @@ static Inst *peerThroughFunctionConversions(SILValue value) { return nullptr; } -static bool isArrayLiteralIntrinsic(ApplyInst *ai) { - return ai->hasSemantics("array.uninitialized_intrinsic"); -} - //===----------------------------------------------------------------------===// // Auxiliary data structures //===----------------------------------------------------------------------===// @@ -369,7 +404,7 @@ class DifferentiableActivityInfo; class LinearMapInfo { private: /// The linear map kind. - AutoDiffAssociatedFunctionKind kind; + AutoDiffLinearMapKind kind; /// The original function. SILFunction *const original; @@ -481,13 +516,13 @@ class LinearMapInfo { // Create a branching trace enum. std::string enumName; switch (kind) { - case swift::AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffLinearMapKind::Differential: enumName = "_AD__" + original->getName().str() + "_bb" + std::to_string(originalBB->getDebugID()) + "__Succ__" + indices.mangle(); break; - case swift::AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffLinearMapKind::Pullback: enumName = "_AD__" + original->getName().str() + "_bb" + std::to_string(originalBB->getDebugID()) + @@ -546,10 +581,10 @@ class LinearMapInfo { auto &s = getADDebugStream(); std::string enumName; switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffLinearMapKind::Differential: enumName = "Predecessor"; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffLinearMapKind::Pullback: enumName = "Successor"; break; } @@ -573,13 +608,13 @@ class LinearMapInfo { std::string structName; switch (kind) { - case swift::AutoDiffAssociatedFunctionKind::JVP: + case swift::AutoDiffLinearMapKind::Differential: structName = "_AD__" + original->getName().str() + "_bb" + std::to_string(originalBB->getDebugID()) + "__DF__" + indices.mangle(); break; - case swift::AutoDiffAssociatedFunctionKind::VJP: + case swift::AutoDiffLinearMapKind::Pullback: structName = "_AD__" + 
original->getName().str() + "_bb" + std::to_string(originalBB->getDebugID()) + @@ -609,10 +644,10 @@ class LinearMapInfo { auto &s = getADDebugStream(); std::string structName; switch (kind) { - case AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffLinearMapKind::Differential: structName = "Differential"; break; - case AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffLinearMapKind::Pullback: structName = "Pullback"; break; } @@ -645,10 +680,10 @@ class LinearMapInfo { auto *linMapStruct = getLinearMapStruct(origBB); std::string linearMapName; switch (kind) { - case swift::AutoDiffAssociatedFunctionKind::JVP: + case AutoDiffLinearMapKind::Differential: linearMapName = "differential_" + llvm::itostr(linearMapValueMap.size()); break; - case swift::AutoDiffAssociatedFunctionKind::VJP: + case AutoDiffLinearMapKind::Pullback: linearMapName = "pullback_" + llvm::itostr(linearMapValueMap.size()); break; } @@ -675,7 +710,7 @@ class LinearMapInfo { LinearMapInfo &operator=(const LinearMapInfo &) = delete; explicit LinearMapInfo(ADContext &context, - AutoDiffAssociatedFunctionKind kind, + AutoDiffLinearMapKind kind, SILFunction *original, SILFunction *assocFn, const SILAutoDiffIndices &indices, const DifferentiableActivityInfo &activityInfo, @@ -1343,6 +1378,12 @@ using Activity = OptionSet; /// indices. class DifferentiableActivityInfo { private: + // TODO(TF-800): Temporarily store `AutoDiffAssociatedFunctionKind` because + // special logic for `apply` result does not work for reverse-mode. + // Until that is resolved, activity analysis handles `apply` results + // differently for JVP (forward-mode) than for VJP (reverse-mode). 
+ AutoDiffAssociatedFunctionKind kind; + DifferentiableActivityCollection &parent; GenericSignature *assocGenSig = nullptr; @@ -1377,7 +1418,8 @@ class DifferentiableActivityInfo { public: explicit DifferentiableActivityInfo( - DifferentiableActivityCollection &parent, GenericSignature *assocGenSig); + DifferentiableActivityCollection &parent, GenericSignature *assocGenSig, + AutoDiffAssociatedFunctionKind kind); bool isVaried(SILValue value, unsigned independentVariableIndex) const; bool isUseful(SILValue value, unsigned dependentVariableIndex) const; @@ -1479,7 +1521,7 @@ static void collectMinimalIndicesForFunctionCall( } LinearMapInfo::LinearMapInfo(ADContext &context, - AutoDiffAssociatedFunctionKind kind, + AutoDiffLinearMapKind kind, SILFunction *original, SILFunction *assocFn, const SILAutoDiffIndices &indices, const DifferentiableActivityInfo &activityInfo, @@ -1509,7 +1551,19 @@ bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) { activityInfo.isActive(paramArgs[i], indices)) return true; - // TODO(bartchr): Check `destructure_tuple` user's results' acvitity. + // TODO(TF-800): Investigate why `apply` result special logic does not work + // for reverse-mode. + if (kind == AutoDiffLinearMapKind::Differential) { + for (auto use : ai->getUses()) { + if (auto *dti = dyn_cast(use->getUser())) { + for (auto result : dti->getResults()) { + if (activityInfo.isActive(result, indices)) + return true; + } + } + } + } + bool hasActiveDirectResults = activityInfo.isActive(ai, indices); bool hasActiveIndirectResults = llvm::any_of(ai->getIndirectSILResults(), [&](SILValue result) { return activityInfo.isActive(result, indices); }); @@ -1525,10 +1579,12 @@ bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) { return hasActiveResults && hasActiveParamArguments; } -/// Returns a flag that indicates whether the instruction should be -/// differentiated, given the differentiation indices of the instruction's -/// parent function. 
Whether the instruction should be differentiated is -/// determined sequentially from the following conditions: +// TODO(TF-800): Investigate why reverse-mode requires special "should +// differentiate logic" and update comment. +/// Returns a flag indicating whether the instruction should be differentiated, +/// given the differentiation indices of the instruction's parent function. +/// Whether the instruction should be differentiated is determined sequentially +/// from the following conditions: /// 1. The instruction is an `apply` and `shouldDifferentiateApplyInst` returns /// true. /// 2. The instruction has an active operand and an active result. @@ -1546,17 +1602,70 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) { [&](SILValue val) { return activityInfo.isActive(val, indices); }); if (hasActiveOperands && hasActiveResults) return true; - if (inst->mayHaveSideEffects() && hasActiveOperands) - return true; + + // TODO(TF-800): Investigate why reverse-mode requires special "should + // differentiate logic" and update comment. 
+ switch (kind) { + case AutoDiffLinearMapKind::Differential: { + +#define CHECK_INST_TYPE_ACTIVE_OPERANDS(TYPE) \ +if (isa(inst) && hasActiveOperands) \ + return true; + +#define CHECK_INST_TYPE_ACTIVE_DEST(TYPE) \ +if (auto *castInst = dyn_cast(inst)) { \ + return activityInfo.isActive(castInst->getDest(), indices); \ +} + + CHECK_INST_TYPE_ACTIVE_DEST(StoreInst) + CHECK_INST_TYPE_ACTIVE_DEST(StoreBorrowInst) + CHECK_INST_TYPE_ACTIVE_DEST(CopyAddrInst) + if ((isa(inst) && hasActiveResults)) + return true; + CHECK_INST_TYPE_ACTIVE_OPERANDS(RefCountingInst) + CHECK_INST_TYPE_ACTIVE_OPERANDS(EndAccessInst) + CHECK_INST_TYPE_ACTIVE_OPERANDS(EndBorrowInst) + CHECK_INST_TYPE_ACTIVE_OPERANDS(DeallocationInst) + CHECK_INST_TYPE_ACTIVE_OPERANDS(CopyValueInst) + CHECK_INST_TYPE_ACTIVE_OPERANDS(DestroyValueInst) + CHECK_INST_TYPE_ACTIVE_OPERANDS(DestroyAddrInst) + break; + +#undef CHECK_INST_TYPE_ACTIVE_OPERANDS +#undef CHECK_INST_TYPE_ACTIVE_DEST + } + case AutoDiffLinearMapKind::Pullback: { + if (inst->mayHaveSideEffects() && hasActiveOperands) + return true; + break; + } + } + return false; } -/// Takes an `apply` instruction and adds its linear map function to the +/// Given an `apply` instruction, conditionally adds its linear map function to the /// linear map struct if it's active. void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai, const SILAutoDiffIndices &indices) { SmallVector allResults; - allResults.push_back(ai); + // TODO(TF-800): Investigate why `apply` result special logic does not work + // for reverse-mode. + // If differential, handle `apply` result specially. + // If `apply` result is tuple-typed with a `destructure_tuple` user, add the + // results of the `destructure_tuple` user to `allResults` instead of adding + // the `apply` result itself. 
+ bool isDifferentialAndFoundDestructureTupleUser = false; + if (kind == AutoDiffLinearMapKind::Differential) { + if (auto *dti = getSingleDestructureTupleUser(ai)) { + isDifferentialAndFoundDestructureTupleUser = true; + for (auto result : dti->getResults()) + allResults.push_back(result); + } + } + // Otherwise, add `apply` result to `allResults`. + if (!isDifferentialAndFoundDestructureTupleUser) + allResults.push_back(ai); allResults.append(ai->getIndirectSILResults().begin(), ai->getIndirectSILResults().end()); @@ -1573,18 +1682,20 @@ void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai, if (!hasActiveResults || !hasActiveArguments) return; - unsigned source; - AutoDiffIndexSubset *parameters; SmallVector activeParamIndices; SmallVector activeResultIndices; collectMinimalIndicesForFunctionCall( ai, allResults, indices, activityInfo, activeParamIndices, activeResultIndices); - source = activeResultIndices.front(); - // If function is already marked differentiable, differentiate W.R.T. - // all parameters. + // Compute differentiation result index. + auto source = activeResultIndices.front(); + // Compute differentiation parameters. + // - If the callee has `@differentiable` function type, use differentiation + // parameters from the function type. + // - Otherwise, use the active parameters. + AutoDiffIndexSubset *parameters; auto originalFnSubstTy = ai->getSubstCalleeType(); if (originalFnSubstTy->isDifferentiable()) { parameters = originalFnSubstTy->getDifferentiationParameterIndices(); @@ -1594,25 +1705,22 @@ void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai, ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices); } - SILAutoDiffIndices curIndices(activeResultIndices.front(), - AutoDiffIndexSubset::get( - builder.getASTContext(), - ai->getArgumentsWithoutIndirectResults().size(), - activeParamIndices)); + // Create autodiff indices for the `apply` instruction. 
+ SILAutoDiffIndices applyIndices(source, parameters); // Check for non-differentiable original function type. auto checkNondifferentiableOriginalFunctionType = [&](CanSILFunctionType origFnTy) { // Check and diagnose non-differentiable arguments. for (unsigned paramIndex : range(origFnTy->getNumParameters())) { - if (curIndices.isWrtParameter(paramIndex) && + if (applyIndices.isWrtParameter(paramIndex) && !origFnTy->getParameters()[paramIndex] .getSILStorageType() .isDifferentiable(builder.getModule())) return true; } // Check non-differentiable results. - if (!origFnTy->getResults()[curIndices.source] + if (!origFnTy->getResults()[applyIndices.source] .getSILStorageType() .isDifferentiable(builder.getModule())) return true; @@ -1621,8 +1729,10 @@ void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai, if (checkNondifferentiableOriginalFunctionType(originalFnSubstTy)) return; + AutoDiffAssociatedFunctionKind assocFnKind(kind); auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType( - parameters, source, /*differentiationOrder*/ 1, kind, builder.getModule(), + parameters, source, /*differentiationOrder*/ 1, assocFnKind, + builder.getModule(), LookUpConformanceInModule(builder.getModule().getSwiftModule())); auto assocFnResultTypes = @@ -1716,12 +1826,13 @@ class DifferentiableActivityCollection { DominanceInfo *domInfo; PostDominanceInfo *postDomInfo; - DifferentiableActivityInfo &getActivityInfo(GenericSignature *assocGenSig) { + DifferentiableActivityInfo &getActivityInfo( + GenericSignature *assocGenSig, AutoDiffAssociatedFunctionKind kind) { auto activityInfoLookup = activityInfoMap.find(assocGenSig); if (activityInfoLookup != activityInfoMap.end()) return activityInfoLookup->getSecond(); auto insertion = activityInfoMap.insert( - {assocGenSig, DifferentiableActivityInfo(*this, assocGenSig)}); + {assocGenSig, DifferentiableActivityInfo(*this, assocGenSig, kind)}); return insertion.first->getSecond(); } @@ -1754,8 +1865,9 @@ 
DifferentiableActivityCollection::DifferentiableActivityCollection( : function(f), domInfo(di), postDomInfo(pdi) {} DifferentiableActivityInfo::DifferentiableActivityInfo( - DifferentiableActivityCollection &parent, GenericSignature *assocGenSig) - : parent(parent), assocGenSig(assocGenSig) { + DifferentiableActivityCollection &parent, GenericSignature *assocGenSig, + AutoDiffAssociatedFunctionKind kind) + : kind(kind), parent(parent), assocGenSig(assocGenSig) { analyze(parent.domInfo, parent.postDomInfo); } @@ -1800,8 +1912,24 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, if (isVaried(arg, i)) { for (auto indRes : ai->getIndirectSILResults()) setVaried(indRes, i); - for (auto dirRes : ai->getResults()) - setVaried(dirRes, i); + // TODO(TF-800): Investigate why `apply` result special logic + // does not work for reverse-mode. + // Until then, only forward-mode gets the special handling below. + // If JVP, handle `apply` result specially. + // If `apply` result is tuple-typed with a `destructure_tuple` + // user, mark the results of the `destructure_tuple` user as + // varied instead of marking the `apply` result itself. + bool isJVPAndFoundDestructureTupleUser = false; + if (kind == swift::AutoDiffAssociatedFunctionKind::JVP) { + if (auto *dti = getSingleDestructureTupleUser(ai)) { + for (auto result : dti->getResults()) + setVaried(result, i); + isJVPAndFoundDestructureTupleUser = true; + } + } + // Otherwise, mark the `apply` result as varied. 
+ if (!isJVPAndFoundDestructureTupleUser) + setVaried(ai, i); } } } @@ -3200,7 +3328,8 @@ class VJPEmitter final passManager.getAnalysis(); auto &activityCollection = *activityAnalysis->get(original); auto &activityInfo = activityCollection.getActivityInfo( - vjp->getLoweredFunctionType()->getGenericSignature()); + vjp->getLoweredFunctionType()->getGenericSignature(), + AutoDiffAssociatedFunctionKind::VJP); LLVM_DEBUG( dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); return activityInfo; @@ -3214,7 +3343,7 @@ class VJPEmitter final context(context), original(original), attr(attr), vjp(vjp), invoker(invoker), activityInfo(getActivityInfo( context, original, attr->getIndices(), vjp)), - pullbackInfo(context, AutoDiffAssociatedFunctionKind::VJP, original, + pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp, attr->getIndices(), activityInfo, getBuilder()) { // Create empty pullback function. pullback = createEmptyPullback(); @@ -3618,7 +3747,15 @@ class VJPEmitter final // Get the parameter indices required for differentiating this function. SmallVector allResults; - allResults.push_back(ai); + // If the `apply` result has a single `destructure_tuple` user, append all + // of that user's results instead of the tuple-typed `apply` result itself; + // otherwise, append the `apply` result. + if (auto *dti = getSingleDestructureTupleUser(ai)) { + for (auto result : dti->getResults()) + allResults.push_back(result); + } else { + allResults.push_back(ai); + } allResults.append(ai->getIndirectSILResults().begin(), ai->getIndirectSILResults().end()); SmallVector activeParamIndices; @@ -3642,11 +3779,12 @@ class VJPEmitter final errorOccurred = true; return; } - // Form expected indices by assuming there's only one result. - SILAutoDiffIndices indices(activeResultIndices.front(), + + // Form expected indices, assuming there's only one result. 
+ SILAutoDiffIndices indices( + activeResultIndices.front(), AutoDiffIndexSubset::get( - getASTContext(), - ai->getArgumentsWithoutIndirectResults().size(), + getASTContext(), ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices)); // Emit the VJP. @@ -3705,7 +3843,7 @@ class VJPEmitter final // If VJP has not yet been found, emit an `autodiff_function` instruction // on the remapped original function operand and `autodiff_function_extract` - // the VJP. The actual JVP/VJP functions will be populated in the + // the VJP. The actual VJP functions will be populated in the // `autodiff_function` during the transform main loop. if (!vjpValue) { // FIXME: Handle indirect differentiation invokers. This may require some @@ -4052,8 +4190,8 @@ class JVPEmitter final /// Mapping from differential struct field declarations to differential struct /// elements destructured from the linear map basic block argument. In the - /// beginning of each differential basic block, the block's differential struct is - /// destructured into individual elements stored here. + /// beginning of each differential basic block, the block's differential + /// struct is destructured into the individual elements stored here. 
DenseMap differentialStructElements; /// Mapping from original basic blocks and original values to corresponding @@ -4101,8 +4239,12 @@ class JVPEmitter final static SubstitutionMap getSubstitutionMap(SILFunction *original, SILFunction *jvp) { auto substMap = original->getForwardingSubstitutionMap(); - if (auto *jvpGenEnv = jvp->getGenericEnvironment()) - substMap = substMap.subst(jvpGenEnv->getForwardingSubstitutionMap()); + if (auto *jvpGenEnv = jvp->getGenericEnvironment()) { + auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap(); + substMap = SubstitutionMap::get( + jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpSubstMap}, + LookUpConformanceInSubstitutionMap(jvpSubstMap)); + } return substMap; } @@ -4116,14 +4258,16 @@ class JVPEmitter final passManager.getAnalysis(); auto &activityCollection = *activityAnalysis->get(original); auto &activityInfo = activityCollection.getActivityInfo( - jvp->getLoweredFunctionType()->getGenericSignature()); + jvp->getLoweredFunctionType()->getGenericSignature(), + AutoDiffAssociatedFunctionKind::JVP); LLVM_DEBUG( dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); return activityInfo; } static SILBuilder - initializeDifferentialAndBuilder(ADContext &context, SILFunction *original, SILDifferentiableAttr *attr, + initializeDifferentialAndBuilder(ADContext &context, SILFunction *original, + SILDifferentiableAttr *attr, LinearMapInfo *linearMapInfo) { auto *differential = createEmptyDifferential(context, original, attr, linearMapInfo); @@ -4147,7 +4291,8 @@ class JVPEmitter final auto insertion = differentialStructElements.insert({std::get<0>(pair), std::get<1>(pair)}); (void)insertion; - assert(insertion.second && "A differential struct element already exists!"); + assert(insertion.second && + "A differential struct element mapping already exists!"); } } @@ -4210,7 +4355,8 @@ class JVPEmitter final //--------------------------------------------------------------------------// AdjointValue 
makeZeroTangentValue(SILType type) { - return AdjointValue::createZero(allocator, remapType(type)); + return AdjointValue::createZero( + allocator, remapSILTypeInDifferential(type)); } AdjointValue makeConcreteTangentValue(SILValue value) { @@ -4296,7 +4442,8 @@ class JVPEmitter final assert(originalBuffer->getType().isAddress()); auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); - assert(insertion.second); (void)insertion; + assert(insertion.second && "tangent buffer already exists."); + (void)insertion; } SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { @@ -4309,30 +4456,52 @@ class JVPEmitter final } //--------------------------------------------------------------------------// - // Type transformer + // Differential type calculations //--------------------------------------------------------------------------// + /// Substitutes all replacement types of the given substitution map using the + /// tangent function's substitution map. + SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap) { + return substMap.subst(getDifferential().getForwardingSubstitutionMap()); + } + + /// Remap any archetypes into the differential function's context. + Type remapTypeInDifferential(Type ty) { + if (ty->hasArchetype()) + return getDifferential().mapTypeIntoContext(ty->mapTypeOutOfContext()); + return getDifferential().mapTypeIntoContext(ty); + } + + /// Remap any archetypes into the differential function's context. + SILType remapSILTypeInDifferential(SILType ty) { + if (ty.hasArchetype()) + return getDifferential().mapTypeIntoContext(ty.mapTypeOutOfContext()); + return getDifferential().mapTypeIntoContext(ty); + } + + /// Find the tangent space of a given canonical type. 
Optional getTangentSpace(CanType type) { return type->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(getModule().getSwiftModule())); } /// Assuming the given type conforms to `Differentiable` after remapping, - /// returns the associated tangent space type. + /// returns the associated tangent space SIL type. SILType getRemappedTangentType(SILType type) { return SILType::getPrimitiveType( - getTangentSpace(remapType(type).getASTType())->getCanonicalType(), + getTangentSpace(remapSILTypeInDifferential(type).getASTType()) + ->getCanonicalType(), type.getCategory()); } //--------------------------------------------------------------------------// - // Tngent value mapping + // Tangent value mapping //--------------------------------------------------------------------------// /// Get the tangent for an original value. The given value must be in the /// original function. /// - /// This method first tries to find an entry in `tangentValueMap`. If a tangent + /// This method first tries to find an entry in `tangentValueMap`. If an entry /// doesn't exist, create a zero tangent. AdjointValue getTangentValue(SILValue originalValue) { assert(originalValue->getType().isObject()); @@ -4346,6 +4515,14 @@ class JVPEmitter final /// Map the tangent value to the given original value. 
void setTangentValue(SILBasicBlock *origBB, SILValue originalValue, AdjointValue newTangentValue) { + if (auto *defInst = originalValue->getDefiningInstruction()) { + bool isTupleTypedApplyResult = + isa(defInst) && originalValue->getType().is(); + assert(!isTupleTypedApplyResult && + "Should not set tangent value for tuple-typed result from `apply` " + "instruction; use `destructure_tuple` on `apply` result and set " + "tangent value for `destructure_tuple` results instead."); + } assert(originalValue->getType().isObject()); assert(newTangentValue.getType().isObject()); assert(originalValue->getFunction() == original); @@ -4362,15 +4539,16 @@ class JVPEmitter final //--------------------------------------------------------------------------// // Tangent emission helpers //--------------------------------------------------------------------------// +public: +#define CLONE_AND_EMIT_TANGENT(INST, ID) \ + void visit##INST##Inst(INST##Inst *inst) { \ + TypeSubstCloner::visit##INST##Inst(inst); \ + if (differentialInfo.shouldDifferentiateInstruction(inst)) \ + emitTangentFor##INST##Inst(inst); \ + } \ + void emitTangentFor##INST##Inst(INST##Inst *(ID)) - void emitTangentForDestroyValueInst(DestroyValueInst *dvi) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = dvi->getLoc(); - auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); - diffBuilder.emitDestroyValue(loc, tanVal); - } - - void emitTangentForBeginBorrow(BeginBorrowInst *bbi) { + CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) { auto &diffBuilder = getDifferentialBuilder(); auto loc = bbi->getLoc(); auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); @@ -4379,14 +4557,21 @@ class JVPEmitter final makeConcreteTangentValue(tanValBorrow)); } - void emitTangentForEndBorrow(EndBorrowInst *ebi) { + CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) { auto &diffBuilder = getDifferentialBuilder(); auto loc = ebi->getLoc(); auto tanVal = 
materializeTangent(getTangentValue(ebi->getOperand()), loc); diffBuilder.emitEndBorrowOperation(loc, tanVal); } - void emitTangentForCopyValueInst(CopyValueInst *cvi) { + CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = dvi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); + diffBuilder.emitDestroyValue(loc, tanVal); + } + + CLONE_AND_EMIT_TANGENT(CopyValue, cvi) { auto &diffBuilder = getDifferentialBuilder(); auto tan = getTangentValue(cvi->getOperand()); auto tanVal = materializeTangent(tan, cvi->getLoc()); @@ -4395,44 +4580,437 @@ class JVPEmitter final makeConcreteTangentValue(tanValCopy)); } - void emitTangentForReturnInst(ReturnInst *ri) { - auto loc = ri->getOperand().getLoc(); + /// Handle `struct_extract` instruction. + /// Original: y = struct_extract x, #field + /// Tangent: tan[y] = struct_extract tan[x], tan[#field] + CLONE_AND_EMIT_TANGENT(StructExtract, sei) { + assert(!sei->getField()->getAttrs().hasAttribute() && + "`struct_extract` with `@noDerivative` field should not be " + "differentiated; activity analysis should not marked as varied."); + + auto diffBuilder = getDifferentialBuilder();; + auto tangentVectorTy = + getRemappedTangentType(sei->getOperand()->getType()); + auto *tangentVectorDecl = + tangentVectorTy.getStructOrBoundGenericStruct(); + + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + // If the tangent space is the original struct, then field is the same. + if (tangentVectorDecl == sei->getStructDecl()) + tanField = sei->getField(); + // Otherwise, look up the field by name. 
+ else { + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(sei->getField()->getName()); + if (tanFieldLookup.empty()) { + context.emitNondifferentiabilityError( + sei, invoker, + diag::autodiff_stored_property_no_corresponding_tangent, + sei->getStructDecl()->getNameStr(), + sei->getField()->getNameStr()); + errorOccurred = true; + return; + } + tanField = cast(tanFieldLookup.front()); + } + // Emit tangent `struct_extract`. + auto tanStruct = + materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc()); + auto tangentInst = + diffBuilder.createStructExtract(sei->getLoc(), tanStruct, tanField); + // Update tangent value mapping for `struct_extract` result. + auto tangentResult = makeConcreteTangentValue(tangentInst); + setTangentValue(sei->getParent(), sei, tangentResult); + } + + /// Handle `struct_element_addr` instruction. + /// Original: y = struct_element_addr x, #field + /// Tangent: tan[y] = struct_element_addr tan[x], tan[#field] + CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { + assert(!seai->getField()->getAttrs().hasAttribute() && + "`struct_element_addr` with `@noDerivative` field should not be " + "differentiated; activity analysis should not marked as varied."); + auto diffBuilder = getDifferentialBuilder(); - // This vector will contain all the materialized return elements. - SmallVector retElts; - // This vector will contain all indirect parameter tangent buffers. - // TODO: Handle indirect results. - auto tanParam = - materializeTangent(getTangentValue(ri->getOperand()), loc); - diffBuilder.createReturn(ri->getLoc(), tanParam); + auto *bb = seai->getParent(); + auto tangentVectorTy = + getRemappedTangentType(seai->getOperand()->getType()); + auto *tangentVectorDecl = + tangentVectorTy.getStructOrBoundGenericStruct(); + + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + // If the tangent space is the original struct, then field is the same. 
+ if (tangentVectorDecl == seai->getStructDecl()) + tanField = seai->getField(); + // Otherwise, look up the field by name. + else { + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(seai->getField()->getName()); + if (tanFieldLookup.empty()) { + context.emitNondifferentiabilityError( + seai, invoker, + diag::autodiff_stored_property_no_corresponding_tangent, + seai->getStructDecl()->getNameStr(), + seai->getField()->getNameStr()); + errorOccurred = true; + return; + } + tanField = cast(tanFieldLookup.front()); + } + + // Emit tangent `struct_element_addr`. + auto tanOperand = getTangentBuffer(bb, seai->getOperand()); + auto tangentInst = diffBuilder.createStructElementAddr( + seai->getLoc(), tanOperand, tanField); + // Update tangent buffer map for `struct_element_addr`. + setTangentBuffer(bb, seai, tangentInst); + } + + /// Handle `load` instruction. + /// Original: y = load x + /// Tangent: tan[y] = load tan[x] + CLONE_AND_EMIT_TANGENT(Load, li) { + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = li->getParent(); + auto loc = li->getLoc(); + auto tanBuf = getTangentBuffer(bb, li->getOperand()); + auto tanVal = diffBuilder.emitLoadValueOperation( + loc, tanBuf, li->getOwnershipQualifier()); + setTangentValue(bb, li, makeConcreteTangentValue(tanVal)); + } + + /// Handle `load_borrow` instruction. + /// Original: y = load_borrow x + /// Tangent: tan[y] = load_borrow tan[x] + CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) { + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = lbi->getParent(); + auto loc = lbi->getLoc(); + auto tanBuf = getTangentBuffer(bb, lbi->getOperand()); + auto tanVal = diffBuilder.emitLoadBorrowOperation( + loc, tanBuf); + setTangentValue(bb, lbi, makeConcreteTangentValue(tanVal)); + } + + /// Handle `store` instruction in the differential. 
+ /// Original: store x to y + /// Tangent: store tan[x] to tan[y] + CLONE_AND_EMIT_TANGENT(Store, si) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = si->getLoc(); + auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc); + auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest()); + if (errorOccurred) + return; + diffBuilder.emitStoreValueOperation( + loc, tanValSrc, tanValDest, si->getOwnershipQualifier()); + } + + /// Handle `store_borrow` instruction in the differential. + /// Original: store_borrow x to y + /// Tangent: store_borrow tan[x] to tan[y] + CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = sbi->getLoc(); + auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc); + auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest()); + if (errorOccurred) + return; + diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest); } + /// Handle `copy_addr` instruction. + /// Original: copy_addr x to y + /// Tangent: copy_addr tan[x] to tan[y] + CLONE_AND_EMIT_TANGENT(CopyAddr, cai) { + auto *diffGenEnv = getDifferential().getGenericEnvironment(); + auto diffGenSig = diffGenEnv + ? diffGenEnv->getGenericSignature()->getCanonicalSignature() + : nullptr; + Lowering::GenericContextScope genericContextScope( + context.getTypeConverter(), diffGenSig); + + auto diffBuilder = getDifferentialBuilder(); + auto loc = cai->getLoc(); + auto *bb = cai->getParent(); + auto &tanSrc = getTangentBuffer(bb, cai->getSrc()); + auto tanDest = getTangentBuffer(bb, cai->getDest()); + if (errorOccurred) + return; + + diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(), + cai->isInitializationOfDest()); + } + + /// Handle `begin_access` instruction (and do differentiability checks). + /// Original: y = begin_access x + /// Tangent: tan[y] = begin_access tan[x] + CLONE_AND_EMIT_TANGENT(BeginAccess, bai) { + // Check for non-differentiable writes. 
+ if (bai->getAccessKind() == SILAccessKind::Modify) { + if (auto *gai = dyn_cast(bai->getSource())) { + context.emitNondifferentiabilityError(bai, invoker, + diag::autodiff_cannot_differentiate_writes_to_global_variables); + errorOccurred = true; + return; + } + if (auto *pbi = dyn_cast(bai->getSource())) { + context.emitNondifferentiabilityError(bai, invoker, + diag::autodiff_cannot_differentiate_writes_to_mutable_captures); + errorOccurred = true; + return; + } + } + + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = bai->getParent(); + + auto tanSrc = getTangentBuffer(bb, bai->getSource()); + auto *tanDest = diffBuilder.createBeginAccess( + bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(), + bai->hasNoNestedConflict(), bai->isFromBuiltin()); + setTangentBuffer(bb, bai, tanDest); + } + + /// Handle `end_access` instruction. + /// Original: end_access x + /// Tangent: end_access tan[x] + CLONE_AND_EMIT_TANGENT(EndAccess, eai) { + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = eai->getParent(); + auto loc = eai->getLoc(); + auto tanSrc = getTangentBuffer(bb, eai->getOperand()); + diffBuilder.createEndAccess(loc, tanSrc, eai->isAborting()); + } + + /// Handle `alloc_stack` instruction. + /// Original: y = alloc_stack $T + /// Tangent: tan[y] = alloc_stack $T.Tangent + CLONE_AND_EMIT_TANGENT(AllocStack, asi) { + auto &diffBuilder = getDifferentialBuilder(); + auto *mappedAllocStackInst = diffBuilder.createAllocStack( + asi->getLoc(), getRemappedTangentType(asi->getElementType())); + bufferMap.try_emplace({asi->getParent(), asi}, + mappedAllocStackInst); + } + + /// Handle `dealloc_stack` instruction. 
+ /// Original: dealloc_stack x + /// Tangent: dealloc_stack tan[x] + CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) { + auto &diffBuilder = getDifferentialBuilder(); + auto tanBuf = getTangentBuffer(dsi->getParent(), dsi->getOperand()); + diffBuilder.createDeallocStack(dsi->getLoc(), tanBuf); + } + + /// Handle `destroy_addr` instruction. + /// Original: destroy_addr x + /// Tangent: destroy_addr tan[x] + CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) { + auto &diffBuilder = getDifferentialBuilder(); + auto tanBuf = getTangentBuffer(dai->getParent(), dai->getOperand()); + diffBuilder.createDestroyAddr(dai->getLoc(), tanBuf); + } + + /// Handle `struct` instruction. + /// Original: y = struct $T (x0, x1, x2, ...) + /// Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...) + CLONE_AND_EMIT_TANGENT(Struct, si) { + auto &diffBuilder = getDifferentialBuilder(); + SmallVector tangentElements; + for (auto elem : si->getElements()) + tangentElements.push_back(getTangentValue(elem).getConcreteValue()); + auto tanExtract = diffBuilder.createStruct( + si->getLoc(), getRemappedTangentType(si->getType()), tangentElements); + setTangentValue(si->getParent(), si, makeConcreteTangentValue(tanExtract)); + } + + /// Handle `tuple` instruction. + /// Original: y = tuple (x0, x1, x2, ...) + /// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...) + CLONE_AND_EMIT_TANGENT(Tuple, ti) { + auto diffBuilder = getDifferentialBuilder(); + + // Get the tangents of all the tuple elements. + SmallVector tangentTupleElements; + for (auto elem : ti->getElements()) { + tangentTupleElements.push_back( + materializeTangent(getTangentValue(elem), ti->getLoc())); + } + + // Emit the instruction and add the tangent mapping. + auto tanTuple = diffBuilder.createTuple(ti->getLoc(), tangentTupleElements); + setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple)); + } + + /// Handle `tuple_element_addr` instruction. 
+ /// Original: y = tuple_element_addr x, + /// Tangent: tan[y] = tuple_element_addr tan[x], + CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) { + auto &diffBuilder = getDifferentialBuilder(); + auto origTupleTy = teai->getOperand()->getType().castTo(); + unsigned tanIndex = 0; + for (unsigned i : range(teai->getFieldNo())) { + if (getTangentSpace( + origTupleTy->getElement(i).getType()->getCanonicalType())) + ++tanIndex; + } + auto tanType = getRemappedTangentType(teai->getType()); + auto tanSource = getTangentBuffer(teai->getParent(), teai->getOperand()); + SILValue tanBuf; + // If the tangent buffer of the source does not have a tuple type, then + // it must represent a "single element tuple type". Use it directly. + if (!tanSource->getType().is()) { + tanBuf = tanSource; + } else { + tanBuf = diffBuilder.createTupleElementAddr( + teai->getLoc(), tanSource, tanIndex, tanType); + } + bufferMap.try_emplace({teai->getParent(), teai}, tanBuf); + } + + /// Handle `tuple_extract` instruction. + /// Original: y = tuple_extract x, + /// Tangent: tan[y] = tuple_extract tan[x], + CLONE_AND_EMIT_TANGENT(TupleExtract, tei) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = tei->getLoc(); + auto origTupleTy = tei->getOperand()->getType().castTo(); + unsigned tanIndex = 0; + for (unsigned i : range(tei->getFieldNo())) { + if (getTangentSpace( + origTupleTy->getElement(i).getType()->getCanonicalType())) + ++tanIndex; + } + auto tanType = getRemappedTangentType(tei->getType()); + auto tanSource = materializeTangent( + getTangentValue(tei->getOperand()), loc); + SILValue tanElt; + // If the tangent value of the source does not have a tuple type, then + // it must represent a "single element tuple type". Use it directly. 
+ if (!tanSource->getType().is()) { + setTangentValue(tei->getParent(), tei, + makeConcreteTangentValue(tanSource)); + } else { + tanElt = diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType); + setTangentValue(tei->getParent(), tei, makeConcreteTangentValue(tanElt)); + } + } + + /// Handle `destructure_tuple` instruction. + /// Original: (y0, y1, y2, ...) = destructure_tuple x, + /// Tangent: (tan[y0], tan[y1], tan[y2], ...) = destructure_tuple tan[x], + CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) { + auto &diffBuilder = getDifferentialBuilder(); + auto *bb = dti->getParent(); + auto loc = dti->getLoc(); + + SmallVector activeOrigResults; + bool hasActiveResult = false; + for (auto result : dti->getResults()) { + if (activityInfo.isActive(result, getIndices())) { + activeOrigResults.push_back(result); + hasActiveResult = true; + break; + } + } + assert(!activeOrigResults.empty() && + "original 'destructure_tuple' should have at least one active " + "result"); + + auto tanTuple = + materializeTangent(getTangentValue(dti->getOperand()), loc); + auto *tupleElements = diffBuilder.createDestructureTuple(loc, tanTuple); + for (auto i : range(tupleElements->getNumResults())) { + auto origElem = dti->getResult(i); + auto tanElem = tupleElements->getResult(i); + setTangentValue(bb, origElem, makeConcreteTangentValue(tanElem)); + } + } + +#undef CLONE_AND_EMIT_TANGENT + + /// Handle `apply` instruction. + /// Original: y = apply f(x) + /// Tangent: tan[y] = apply diff_f(tan[x]) void emitTangentForApplyInst(ApplyInst *ai, - SILAutoDiffIndices &actualIndices) { + const SILAutoDiffIndices &actualIndices, + CanSILFunctionType originalDifferentialType) { assert(differentialInfo.shouldDifferentiateApplyInst(ai)); auto *bb = ai->getParent(); auto loc = ai->getLoc(); - auto diffBuilder = getDifferentialBuilder(); + auto &diffBuilder = getDifferentialBuilder(); - // Get the differential. + // Get the differential value. 
auto *field = differentialInfo.lookUpLinearMapDecl(ai); assert(field); SILValue differential = getDifferentialStructElement(bb, field); + auto differentialType = remapSILTypeInDifferential(differential->getType()) + .castTo(); + // Get the differential arguments. SmallVector diffArgs; - for (auto origArg : ai->getArguments()) { - // Get the tangent value of the original parameter. - if (!activityInfo.isActive(origArg, getIndices())) - continue; - SILValue tanParam; - if (origArg->getType().isObject()) { - tanParam = materializeTangent(getTangentValue(origArg), loc); - } else { - tanParam = getTangentBuffer(ai->getParent(), origArg); + + for (auto indRes : ai->getIndirectSILResults()) + diffArgs.push_back(getTangentBuffer(bb, indRes)); + + auto paramArgs = ai->getArgumentsWithoutIndirectResults(); + // Get the tangent value of the original arguments. + for (auto i : indices(paramArgs)) { + auto origArg = paramArgs[i]; + // If the argument is not active: + // - Skip the element, if it is not differentiable. + // - Otherwise, pass a zero tangent (direct or indirect) at that location. + if (!activityInfo.isActive(origArg, getIndices())) { + auto origCalleeType = ai->getSubstCalleeType(); + if (!origCalleeType->isDifferentiable()) + continue; + auto actualOrigCalleeIndices = + origCalleeType->getDifferentiationParameterIndices(); + if (actualOrigCalleeIndices->contains(i)) { + SILValue tanParam; + if (origArg->getType().isObject()) { + tanParam = emitZeroDirect( + getRemappedTangentType(origArg->getType()).getASTType(), loc); + } else { + tanParam = diffBuilder.createAllocStack( + loc, getRemappedTangentType(origArg->getType())); + emitZeroIndirect( + getRemappedTangentType(origArg->getType()).getASTType(), tanParam, + loc); + } + diffArgs.push_back(tanParam); + } + } + // Otherwise, if the argument is active, handle the argument normally by + // getting its tangent value. 
+ else { + SILValue tanParam; + if (origArg->getType().isObject()) { + tanParam = materializeTangent(getTangentValue(origArg), loc); + } else { + tanParam = getTangentBuffer(ai->getParent(), origArg); + } + diffArgs.push_back(tanParam); if (errorOccurred) return; } - diffArgs.push_back(tanParam); + } + + // If callee differential was reabstracted in JVP, reabstract the callee + // differential. + if (!differentialType->isEqual(originalDifferentialType)) { + SILOptFunctionBuilder fb(context.getTransform()); + auto *thunk = getOrCreateReabstractionThunk( + fb, context.getModule(), loc, &getDifferential(), + differentialType, originalDifferentialType); + auto *thunkRef = diffBuilder.createFunctionRef(loc, thunk); + differential = diffBuilder.createPartialApply( + loc, thunkRef, + remapSubstitutionMapInDifferential(thunk->getForwardingSubstitutionMap()), + {differential}, differentialType->getCalleeConvention()); } // Call the differential. @@ -4443,23 +5021,198 @@ class JVPEmitter final assert(differentialCall->getNumResults() == 1 && "Expected differential to return one result"); - // TODO: Generalize for indirect results, multiple results, etc. - auto origResult = ai->getResult(actualIndices.source); + // Get the original results of the `apply` instructions. + SmallVector origDirResults; + collectAllExtractedElements(ai, origDirResults); + SmallVector origAllResults; + collectAllActualResultsInTypeOrder( + ai, origDirResults, ai->getIndirectSILResults(), origAllResults); + auto origResult = origAllResults[actualIndices.source]; - // Extract all direct results from the differential. + // Get the differential results of the `apply` instructions. SmallVector differentialDirResults; - extractAllElements(differentialCall, diffBuilder, differentialDirResults); - // Get all differential results in type-defined order. 
+ collectAllExtractedElements(differentialCall, differentialDirResults); SmallVector differentialAllResults; collectAllActualResultsInTypeOrder( differentialCall, differentialDirResults, differentialCall->getIndirectSILResults(), differentialAllResults); - auto differentialResult = differentialAllResults[actualIndices.source]; + auto differentialResult = differentialAllResults.front(); // Add tangent for original result. - assert(actualIndices.source == 0 && "Expected result index to be first."); - setTangentValue(bb, origResult, - makeConcreteTangentValue(differentialResult)); + if (origResult->getType().isObject()) { + if (!origResult->getType().is()) { + setTangentValue(bb, origResult, + makeConcreteTangentValue(differentialResult)); + } else if (auto *dti = getSingleDestructureTupleUser(ai)) { + bool notSetValue = true; + for (auto result : dti->getResults()) { + if (activityInfo.isActive(result, getIndices())) { + assert(notSetValue && + "This was incorrectly set, should only have one active " + "result from the tuple."); + notSetValue = false; + setTangentValue(bb, result, + makeConcreteTangentValue(differentialResult)); + } + } + } + } + } + + /// Generate a `return` instruction in the current differential basic block. + void emitReturnInstForDifferential() { + auto &differential = getDifferential(); + auto diffLoc = differential.getLocation(); + auto &diffBuilder = getDifferentialBuilder(); + + SmallVector activeResults; + + // This vector will contain all the materialized return elements. + SmallVector retElts; + SmallVector originalResults; + collectAllDirectResultsInTypeOrder(*original, originalResults); + + // Materializes the return element corresponding to the result + // `resultIndex` into the `retElts` vector. 
+ auto addActiveResult = [&](unsigned resultIndex) -> void { + auto origResult = originalResults[resultIndex]; + assert(origResult->getType().isObject() && + "Should only be handling direct results for 'return' " + "instruction."); + if (activityInfo.isActive(origResult, getIndices())) { + activeResults.push_back(origResult); + } + }; + // Create an array of the direct tangent values of the original results. + for (auto i : range(originalResults.size())) + addActiveResult(i); + assert(activeResults.size() <= 1); + + if (activeResults.empty() && !originalResults.empty()) { + // Create zero tangent value for direct result. + auto origResult = originalResults[getIndices().source]; + assert(origResult->getType().isObject() && + "Should only be handling direct results for 'return' " + "instruction."); + auto zeroType = origResult->getType().getASTType(); + auto zero = + emitZeroDirect(getTangentSpace(zeroType)->getCanonicalType(), + diffLoc); + retElts.push_back(zero); + } else if (!activeResults.empty()) { + auto diffVal = getTangentValue(activeResults.front()); + auto val = materializeTangent(diffVal, diffLoc); + retElts.push_back(val); + } + + diffBuilder.createReturn( + diffLoc, joinElements(retElts, diffBuilder, diffLoc)); + } + +private: + + /// Set up the differential function. This includes: + /// - Creating all differential blocks. + /// - Creating differential entry block arguments based on the function type. + /// - Creating tangent value mapping for original/differential parameters. + /// - Checking for unvaried result and emitting related warnings. + void prepareForDifferentialGeneration() { + // Create differential blocks and arguments. + auto *diffGenEnv = getDifferential().getGenericEnvironment(); + auto diffGenSig = diffGenEnv + ? 
diffGenEnv->getGenericSignature()->getCanonicalSignature() + : nullptr; + auto &differential = getDifferential(); + auto *origEntry = original->getEntryBlock(); + for (auto &origBB : *original) { + auto *diffBB = differential.createBasicBlock(); + diffBBMap.insert({&origBB, diffBB}); + { + Lowering::GenericContextScope genericContextScope( + context.getTypeConverter(), diffGenSig); + auto diffStructLoweredType = remapSILTypeInDifferential( + differentialInfo.getLinearMapStructLoweredType(&origBB)); + + // If the BB is the original entry, then the differential block that we + // just created must be the differential function's entry. Create + // differential entry arguments and continue. + if (&origBB == origEntry) { + assert(diffBB->isEntry()); + createEntryArguments(&differential); + auto *lastArg = diffBB->getArguments().back(); + assert(lastArg->getType() == diffStructLoweredType); + differentialStructArguments[&origBB] = lastArg; + } + } + + LLVM_DEBUG({ + auto &s = getADDebugStream() + << "Original bb" + std::to_string(origBB.getDebugID()) + << ": To differentiate or not to differentiate?\n"; + for (auto &inst : origBB) { + s << (differentialInfo.shouldDifferentiateInstruction(&inst) + ? "[∂] " : "[ ] ") + << inst; + } + }); + } + + assert(diffBBMap.size() == 1 && + "Can only currently handle single basic block functions"); + + // The differential function has type: + // (arg0', ..., argn', entry_df_struct) -> result'. + auto diffParamArgs = + differential.getArgumentsWithoutIndirectResults().drop_back(); + assert(diffParamArgs.size() == + attr->getIndices().parameters->getNumIndices()); + auto origParamArgs = original->getArgumentsWithoutIndirectResults(); + + // TODO(TF-788): Re-enable non-varied result warning. + /* + // Check if result is not varied. 
+ SmallVector origFormalResults; + collectAllFormalResultsInTypeOrder(*original, origFormalResults); + auto origResult = origFormalResults[getIndices().source]; + // Emit warning if original result is not varied, because it will always + // have a zero derivative. + if (!activityInfo.isVaried(origResult, getIndices().parameters)) { + // Emit fixit if original result has a valid source location. + auto startLoc = origResult.getLoc().getStartSourceLoc(); + auto endLoc = origResult.getLoc().getEndSourceLoc(); + if (startLoc.isValid() && endLoc.isValid()) { + context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) + .fixItInsert(startLoc, "withoutDerivative(at:") + .fixItInsertAfter(endLoc, ")"); + } + } + */ + + // Map each differential argument to the original parameter it is the tangent of. + auto autoDiffIndex = getIndices().parameters->begin(); + for (auto index : range(diffParamArgs.size())) { + auto *diffParam = diffParamArgs[index]; + auto *origParam = origParamArgs[*autoDiffIndex]; + autoDiffIndex++; + if (diffParam->getType().isAddress()) { + setTangentBuffer(origEntry, origParam, diffParam); + } else { + setTangentValue( + origEntry, origParam, makeConcreteTangentValue(diffParam)); + } + LLVM_DEBUG(getADDebugStream() + << "Assigned parameter " << *diffParam + << " as the tangent of original parameter " << *origParam); + } + + // If there are indirect results, create a mapping. 
+ auto origIndResults = original->getIndirectResults(); + auto diffIndResults = differential.getIndirectResults(); + assert(origIndResults.size() == diffIndResults.size()); + + for (auto &origBB : *original) + for (auto i : indices(diffIndResults)) + setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]); } public: @@ -4470,7 +5223,7 @@ class JVPEmitter final context(context), original(original), attr(attr), jvp(jvp), invoker(invoker), activityInfo(getActivityInfo( context, original, attr->getIndices(), jvp)), - differentialInfo(context, AutoDiffAssociatedFunctionKind::JVP, original, + differentialInfo(context, AutoDiffLinearMapKind::Differential, original, jvp, attr->getIndices(), activityInfo, getBuilder()), differentialAndBuilder(initializeDifferentialAndBuilder( context, original, attr, &differentialInfo)), @@ -4560,91 +5313,6 @@ class JVPEmitter final return differential; } - /// Set up the differential function. This includes: - /// - Creating all the differential blocks. - /// - Create arguments for the entry block according to the function type. - /// - Adding the tangent values of the parameters to the tangent value map. - /// - Checking for unvaried result and emitting related warnings. - void prepareForDifferentialGeneration() { - auto &diffBuilder = getDifferentialBuilder(); - - // Create differential blocks and arguments. - // TODO: Consider visiting original blocks in pre-order (dominance) order. - auto &differential = getDifferential(); - auto *origEntry = original->getEntryBlock(); - for (auto &origBB : *original) { - auto *diffBB = differential.createBasicBlock(); - diffBBMap.insert({&origBB, diffBB}); - auto diffStructLoweredType = - remapType(differentialInfo.getLinearMapStructLoweredType(&origBB)); - // If the BB is the original entry, then the differential block that we - // just created must be the differential function's entry. Create - // differential entry arguments and continue. 
- if (&origBB == origEntry) { - assert(diffBB->isEntry()); - createEntryArguments(&differential); - auto *mainDifferentialStruct = diffBB->getArguments().back(); - assert(mainDifferentialStruct->getType() == diffStructLoweredType); - differentialStructArguments[&origBB] = mainDifferentialStruct; - } - - LLVM_DEBUG({ - auto &s = getADDebugStream() - << "Original bb" + std::to_string(origBB.getDebugID()) - << ": To differentiate or not to differentiate?\n"; - for (auto &inst : origBB) { - s << (differentialInfo.shouldDifferentiateInstruction(&inst) - ? "[∂] " : "[ ] ") - << inst; - } - }); - } - - assert(diffBBMap.size() == 1 && - "Can only currently handle single basic block functions"); - - // The differential function has type: - // (arg0', ..., argn', entry_df_struct) -> result'. - auto diffParamArgs = - differential.getArgumentsWithoutIndirectResults().drop_back(); - assert(diffParamArgs.size() == - attr->getIndices().parameters->getNumIndices()); - auto origParamArgs = original->getArgumentsWithoutIndirectResults(); - - // TODO(TF-788): Re-enable non-varied result warning. - /* - // Emit a warning and fixit if original result is not varied, because it - // will always have a zero derivative. - SmallVector origFormalResults; - collectAllFormalResultsInTypeOrder(*original, origFormalResults); - auto origResult = origFormalResults[getIndices().source]; - if (!activityInfo.isVaried(origResult, getIndices().parameters)) { - // Emit fixit if original result has a valid source location. 
- auto startLoc = origResult.getLoc().getStartSourceLoc(); - auto endLoc = origResult.getLoc().getEndSourceLoc(); - if (startLoc.isValid() && endLoc.isValid()) { - context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) - .fixItInsert(startLoc, "withoutDerivative(at:") - .fixItInsertAfter(endLoc, ")"); - } - } - */ - - auto *diffEntry = getDifferential().getEntryBlock(); - diffBuilder.setInsertionPoint( - diffEntry, getNextDifferentialLocalAllocationInsertionPoint()); - - for (auto index : *getIndices().parameters) { - auto diffParam = diffParamArgs[index]; - auto origParam = origParamArgs[index]; - setTangentValue(origEntry, origParam, - makeConcreteTangentValue(diffParam)); - LLVM_DEBUG(getADDebugStream() - << "Assigned parameter " << *diffParam - << " as the tangent of original result " << origParam); - } - } - /// Run JVP generation. Returns true on error. bool run() { LLVM_DEBUG(getADDebugStream() @@ -4658,6 +5326,7 @@ class JVPEmitter final SmallVector entryArgs(entry->getArguments().begin(), entry->getArguments().end()); cloneFunctionBody(original, entry, entryArgs); + emitReturnInstForDifferential(); // If errors occurred, back out. if (errorOccurred) return true; @@ -4683,9 +5352,25 @@ class JVPEmitter final /// General visitor for all instructions. If any error is emitted by previous /// visits, bail out. 
void visit(SILInstruction *inst) { + auto diffBuilder = getDifferentialBuilder(); if (errorOccurred) return; - TypeSubstCloner::visit(inst); + if (differentialInfo.shouldDifferentiateInstruction(inst)) { + LLVM_DEBUG(getADDebugStream() << "JVPEmitter visited:\n[ORIG]" + << *inst); +#ifndef NDEBUG + auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); +#endif + TypeSubstCloner::visit(inst); + LLVM_DEBUG({ + auto &s = llvm::dbgs() << "[DF] Emitted in Differential:\n"; + auto afterInsertion = diffBuilder.getInsertionPoint(); + for (auto it = ++beforeInsertion; it != afterInsertion; ++it) + s << *it; + }); + } else { + TypeSubstCloner::visit(inst); + } } void visitSILInstruction(SILInstruction *inst) { @@ -4694,47 +5379,6 @@ class JVPEmitter final errorOccurred = true; } - /// Handle `copy_value` instruction. - /// Original: y = copy_value x - /// Adjoint: tan[x] = copy_value tan[y] - void visitCopyValueInst(CopyValueInst *cvi) { - TypeSubstCloner::visitCopyValueInst(cvi); - emitTangentForCopyValueInst(cvi); - } - - void visitReturnInst(ReturnInst *ri) { - auto loc = ri->getOperand().getLoc(); - auto *origExit = ri->getParent(); - auto &builder = getBuilder(); - auto *diffStructVal = buildDifferentialValueStructValue(ri); - - // Get the value in the JVP corresponding to the original result. - auto *origRetInst = cast(origExit->getTerminator()); - auto origResult = getOpValue(origRetInst->getOperand()); - SmallVector origResults; - extractAllElements(origResult, builder, origResults); - - // Get and partially apply the differential. - auto jvpGenericEnv = jvp->getGenericEnvironment(); - auto jvpSubstMap = jvpGenericEnv - ? 
jvpGenericEnv->getForwardingSubstitutionMap() - : jvp->getForwardingSubstitutionMap(); - auto *differentialRef = builder.createFunctionRef(loc, &getDifferential()); - auto *differentialPartialApply = builder.createPartialApply( - loc, differentialRef, jvpSubstMap, {diffStructVal}, - ParameterConvention::Direct_Guaranteed); - - // Return a tuple of the original result and differential. - SmallVector directResults; - directResults.append(origResults.begin(), origResults.end()); - directResults.push_back(differentialPartialApply); - builder.createReturn( - ri->getLoc(), joinElements(directResults, builder, loc)); - - // Differential emission. - emitTangentForReturnInst(ri); - } - void visitInstructionsInBlock(SILBasicBlock *bb) { // Destructure the differential struct to get the elements. auto &diffBuilder = getDifferentialBuilder(); @@ -4748,24 +5392,6 @@ class JVPEmitter final TypeSubstCloner::visitInstructionsInBlock(bb); } - void visitDestroyValueInst(DestroyValueInst *dvi) { - TypeSubstCloner::visitDestroyValueInst(dvi); - if (differentialInfo.shouldDifferentiateInstruction(dvi)) - emitTangentForDestroyValueInst(dvi); - } - - void visitBeginBorrowInst(BeginBorrowInst *bbi) { - TypeSubstCloner::visitBeginBorrowInst(bbi); - if (differentialInfo.shouldDifferentiateInstruction(bbi)) - emitTangentForBeginBorrow(bbi); - } - - void visitEndBorrowInst(EndBorrowInst *ebi) { - TypeSubstCloner::visitEndBorrowInst(ebi); - if (differentialInfo.shouldDifferentiateInstruction(ebi)) - emitTangentForEndBorrow(ebi); - } - // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its JVP. void visitApplyInst(ApplyInst *ai) { @@ -4792,11 +5418,21 @@ class JVPEmitter final } } - LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n'); + LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n'); // Get the parameter indices required for differentiating this function. 
SmallVector allResults; - allResults.push_back(ai); + // If `apply` result is tuple-typed with a `destructure_tuple` user, add the + // results of the `destructure_tuple` user to `allResults` instead of adding + // the `apply` result itself. + // Otherwise, add `apply` result to `allResults`. + if (auto *dti = getSingleDestructureTupleUser(ai)) { + for (auto result : dti->getResults()) + allResults.push_back(result); + } else { + allResults.push_back(ai); + } + allResults.append(ai->getIndirectSILResults().begin(), ai->getIndirectSILResults().end()); SmallVector activeParamIndices; @@ -4820,7 +5456,7 @@ class JVPEmitter final errorOccurred = true; return; } - // Form expected indices by assuming there's only one result. + // Form expected indices, assuming there's only one result. SILAutoDiffIndices indices( activeResultIndices.front(), AutoDiffIndexSubset::get( @@ -4969,13 +5605,85 @@ class JVPEmitter final recursivelyDeleteTriviallyDeadInstructions( getOpValue(origCallee)->getDefiningInstruction()); - // Record the callee differential function value. - // This is used later to construct a differential struct. - auto diffFunc = jvpDirectResults.back(); - differentialValues[ai->getParent()].push_back(diffFunc); + // Add the differential function for when we create the struct we partially + // apply to the differential we are generating. + auto differential = jvpDirectResults.back(); + auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai); + auto originalDifferentialType = + getOpType(differential->getType()).getAs(); + auto differentialType = + remapType(differential->getType()) + .castTo(); + auto jvpGenSig = SubsMap.getGenericSignature() + ? 
SubsMap.getGenericSignature()->getCanonicalSignature() + : nullptr; + Lowering::GenericContextScope genericContextScope( + context.getTypeConverter(), jvpGenSig); + auto loweredDifferentialType = + getOpType(context.getTypeConverter().getLoweredType( + differentialDecl->getInterfaceType()->getCanonicalType(), + ResilienceExpansion::Minimal)) + .castTo(); + // If actual differential type does not match lowered differential type, + // reabstract the differential using a thunk. + if (!loweredDifferentialType->isEqual(originalDifferentialType)) { + SILOptFunctionBuilder fb(context.getTransform()); + auto *thunk = getOrCreateReabstractionThunk( + fb, context.getModule(), loc, &getDifferential(), + differentialType, loweredDifferentialType); + auto *thunkRef = builder.createFunctionRef(loc, thunk); + differential = builder.createPartialApply( + loc, thunkRef, + getOpSubstitutionMap(thunk->getForwardingSubstitutionMap()), + {differential}, differentialType->getCalleeConvention()); + } + differentialValues[ai->getParent()].push_back(differential); // Differential emission. - emitTangentForApplyInst(ai, indices); + emitTangentForApplyInst(ai, indices, originalDifferentialType); + } + + void visitReturnInst(ReturnInst *ri) { + auto loc = ri->getOperand().getLoc(); + auto *origExit = ri->getParent(); + auto &builder = getBuilder(); + auto *diffStructVal = buildDifferentialValueStructValue(ri); + + // Get the JVP value corresponding to the original function's return value. + auto *origRetInst = cast(origExit->getTerminator()); + auto origResult = getOpValue(origRetInst->getOperand()); + SmallVector origResults; + extractAllElements(origResult, builder, origResults); + + // Get and partially apply the differential. + auto jvpGenericEnv = jvp->getGenericEnvironment(); + auto jvpSubstMap = jvpGenericEnv + ? 
jvpGenericEnv->getForwardingSubstitutionMap() + : jvp->getForwardingSubstitutionMap(); + auto *differentialRef = + builder.createFunctionRef(loc, &getDifferential()); + auto *differentialPartialApply = builder.createPartialApply( + loc, differentialRef, jvpSubstMap, {diffStructVal}, + ParameterConvention::Direct_Guaranteed); + + // Return a tuple of the original result and differential. + SmallVector directResults; + directResults.append(origResults.begin(), origResults.end()); + directResults.push_back(differentialPartialApply); + builder.createReturn( + ri->getLoc(), joinElements(directResults, builder, loc)); + } + + void visitBranchInst(BranchInst *bi) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitCondBranchInst(CondBranchInst *cbi) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitSwitchEnumInst(SwitchEnumInst *sei) { + llvm_unreachable("Unsupported SIL instruction."); } void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { @@ -6137,7 +6845,7 @@ class PullbackEmitter final : public SILInstructionVisitor { void visitApplyInst(ApplyInst *ai) { assert(getPullbackInfo().shouldDifferentiateApplyInst(ai)); // Handle array uninitialized allocation intrinsic specially. - if (ai->hasSemantics("array.uninitialized_intrinsic")) + if (isArrayLiteralIntrinsic(ai)) return visitArrayInitialization(ai); // Replace a call to a function with a call to its pullback. auto &nestedApplyInfo = getContext().getNestedApplyInfo(); @@ -6297,8 +7005,7 @@ class PullbackEmitter final : public SILInstructionVisitor { tangentVectorTy->getStructOrBoundGenericStruct(); assert(tangentVectorDecl); - auto *destructure = - builder.createDestructureStruct(si->getLoc(), adjStruct); + auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct); // Accumulate adjoints for the fields of the `struct` operand. 
unsigned fieldIndex = 0; for (auto it = structDecl->getStoredProperties().begin(); @@ -6325,7 +7032,7 @@ class PullbackEmitter final : public SILInstructionVisitor { tanField = cast(tanFieldLookup.front()); } assert(tanField); - auto tanElt = destructure->getResult(fieldIndex); + auto tanElt = dti->getResult(fieldIndex); addAdjointValue( bb, si->getFieldValue(field), makeConcreteAdjointValue(tanElt), si->getLoc()); diff --git a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift index ba4c3ff02327f..ecd52ac0a8364 100644 --- a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift +++ b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift @@ -202,6 +202,13 @@ extension Tracked where T : Differentiable, T == T.TangentVector { return (lhs + rhs, { v in (v, v) }) } + @usableFromInline + @differentiating(+) + internal static func _jvpAdd(lhs: Self, rhs: Self) + -> (value: Self, differential: (Self, Self) -> Self) { + return (lhs + rhs, { $0 + $1 }) + } + @usableFromInline @differentiating(-) internal static func _vjpSubtract(lhs: Self, rhs: Self) @@ -209,18 +216,11 @@ extension Tracked where T : Differentiable, T == T.TangentVector { return (lhs - rhs, { v in (v, .zero - v) }) } - @usableFromInline - @differentiating(+) - internal static func _vjpAdd(lhs: Self, rhs: Self) - -> (value: Self, differential: (Self, Self) -> (Self)) { - return (lhs + rhs, { (dx, dy) in dx + dy }) - } - @usableFromInline @differentiating(-) - internal static func _vjpSubtract(lhs: Self, rhs: Self) - -> (value: Self, differential: (Self, Self) -> (Self)) { - return (lhs - rhs, { (dx, dy) in dx - dy }) + internal static func _jvpSubtract(lhs: Self, rhs: Self) + -> (value: Self, differential: (Self, Self) -> Self) { + return (lhs - rhs, { $0 - $1 }) } } @@ -235,7 +235,7 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude, @usableFromInline 
@differentiating(*) - internal static func _vjpMultiply(lhs: Self, rhs: Self) + internal static func _jvpMultiply(lhs: Self, rhs: Self) -> (value: Self, differential: (Self, Self) -> (Self)) { return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs }) } @@ -251,7 +251,7 @@ extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector @usableFromInline @differentiating(/) - internal static func _vjpDivide(lhs: Self, rhs: Self) + internal static func _jvpDivide(lhs: Self, rhs: Self) -> (value: Self, differential: (Self, Self) -> (Self)) { return (lhs / rhs, { (dx, dy) in dx / rhs - lhs / (rhs * rhs) * dy }) } diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index f1956cc5cdb17..f70ff914680d8 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -914,7 +914,7 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { } /// Creates a type-erased derivative from the given derivative. - @differentiable(vjp: _vjpInit(_:)) + @differentiable(jvp: _jvpInit(_:), vjp: _vjpInit(_:)) public init(_ base: T) where T : Differentiable, T.TangentVector == T { self._box = _ConcreteDerivativeBox(base) } @@ -927,6 +927,14 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { return (AnyDerivative(base), { v in v.base as! T.TangentVector }) } + @usableFromInline internal static func _jvpInit( + _ base: T + ) -> (AnyDerivative, (T.TangentVector) -> AnyDerivative) + where T : Differentiable, T.TangentVector == T + { + return (AnyDerivative(base), { dbase in AnyDerivative(dbase) }) + } + public typealias TangentVector = AnyDerivative // `Equatable` requirements (implied by `AdditiveArithmetic`). 
@@ -963,6 +971,14 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { return (lhs + rhs, { v in (v, v) }) } + @differentiating(+) + @usableFromInline internal static func _jvpAdd( + lhs: AnyDerivative, rhs: AnyDerivative + ) -> (value: AnyDerivative, + differential: (AnyDerivative, AnyDerivative) -> (AnyDerivative)) { + return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs }) + } + public static func - ( lhs: AnyDerivative, rhs: AnyDerivative ) -> AnyDerivative { @@ -977,6 +993,14 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { return (lhs - rhs, { v in (v, .zero - v) }) } + @differentiating(-) + @usableFromInline internal static func _jvpSubtract( + lhs: AnyDerivative, rhs: AnyDerivative + ) -> (value: AnyDerivative, + differential: (AnyDerivative, AnyDerivative) -> AnyDerivative) { + return (lhs - rhs, { (dlhs, drhs) in dlhs - drhs }) + } + // `Differentiable` requirements. public mutating func move(along direction: TangentVector) { if _box._isOpaqueZero() { diff --git a/test/AutoDiff/forward_mode_runtime.swift b/test/AutoDiff/forward_mode_runtime.swift index 81a215837bb3f..a1e2ad83ff8c5 100644 --- a/test/AutoDiff/forward_mode_runtime.swift +++ b/test/AutoDiff/forward_mode_runtime.swift @@ -3,6 +3,11 @@ import StdlibUnittest import DifferentiationUnittest +#if os(macOS) +import Darwin.C +#else +import Glibc +#endif var ForwardModeTests = TestSuite("ForwardMode") @@ -10,6 +15,15 @@ var ForwardModeTests = TestSuite("ForwardMode") // Basic tests. 
//===----------------------------------------------------------------------===// +ForwardModeTests.test("Identity") { + func func_to_diff(x: Float) -> Float { + return x + } + let (y, differential) = valueWithDifferential(at: 4, in: func_to_diff) + expectEqual(4, y) + expectEqual(1, differential(1)) +} + ForwardModeTests.test("Unary") { func func_to_diff(x: Float) -> Float { return x * x @@ -39,8 +53,156 @@ ForwardModeTests.test("BinaryWithLets") { expectEqual(-19, differential(1, 1)) } +ForwardModeTests.test("SubsetParametersDiff") { + func func_to_diff1(x: Int, y: Float, z: Int) -> Float { + return y + } + let (y1, differential1) = valueWithDifferential(at: 5) { y in + func_to_diff1(x: 0, y: y, z: 0) + } + expectEqual(5, y1) + expectEqual(1, differential1(1)) + + func func_to_diff2(x: Float, y: Int, z: Int) -> Float { + return 2 * x + } + let (y2, differential2) = valueWithDifferential(at: 6) { x in + func_to_diff2(x: x, y: 0, z: 0) + } + expectEqual(12, y2) + expectEqual(2, differential2(1)) + + func func_to_diff3(x: Int, y: Int, z: Float) -> Float { + return 3 * z + } + let (y3, differential3) = valueWithDifferential(at: 7) { z in + func_to_diff3(x: 0, y: 0, z: z) + } + expectEqual(21, y3) + expectEqual(3, differential3(1)) +} + +//===----------------------------------------------------------------------===// +// Functions with variables +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("UnaryWithVars") { + func unary(x: Float) -> Float { + var a = x + a = x + var b = a + 2 + b = b - 1 + let c: Float = 3 + var d = a + b + c - 1 + d = d + d + return d + } + + let (y, differential) = valueWithDifferential(at: 4, in: unary) + expectEqual(22, y) + expectEqual(4, differential(1)) +} + +//===----------------------------------------------------------------------===// +// Functions with basic struct +//===----------------------------------------------------------------------===// + +struct A: Differentiable & 
AdditiveArithmetic { + var x: Float +} + +ForwardModeTests.test("StructInit") { + func structInit(x: Float) -> A { + return A(x: 2 * x) + } + + let (y, differential) = valueWithDifferential(at: 4, in: structInit) + expectEqual(A(x: 8), y) + expectEqual(A(x: 2), differential(1)) +} + +ForwardModeTests.test("StructExtract") { + func structExtract(x: A) -> Float { + return 2 * x.x + } + + let (y, differential) = valueWithDifferential( + at: A(x: 4), + in: structExtract) + expectEqual(8, y) + expectEqual(2, differential(A(x: 1))) +} + +ForwardModeTests.test("LocalStructVariable") { + func structExtract(x: A) -> A { + let a = A(x: 2 * x.x) // 2x + var b = A(x: a.x + 2) // 2x + 2 + b = A(x: b.x + a.x) // 2x + 2 + 2x = 4x + 2 + return b + } + + let (y, differential) = valueWithDifferential( + at: A(x: 4), + in: structExtract) + expectEqual(A(x: 18), y) + expectEqual(A(x: 4), differential(A(x: 1))) +} + +//===----------------------------------------------------------------------===// +// Functions with methods +//===----------------------------------------------------------------------===// + +extension A { + func noParamMethodA() -> A { + return A(x: 2 * x) + } + + func noParamMethodx() -> Float { + return 2 * x + } + + static func *(lhs: A, rhs: A) -> A { + return A(x: lhs.x * rhs.x) + } + + func complexBinaryMethod(u: A, v: Float) -> A { + var b: A = u * A(x: 2) // A(x: u * 2) + b.x = b.x * v // A(x: u * 2 * v) + let c = b.x + 1 // u * 2 * v + 1 + + // A(x: u * 2 * v + 1 + u * 2 * v) = A(x: x * (4uv + 1)) + return A(x: x * (c + b.x)) + } +} + +ForwardModeTests.test("noParamMethodA") { + let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in + x.noParamMethodA() + } + expectEqual(A(x: 8), y) + expectEqual(A(x: 2), differential(A(x: 1))) +} + +ForwardModeTests.test("noParamMethodx") { + let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in + x.noParamMethodx() + } + expectEqual(8, y) + expectEqual(2, differential(A(x: 1))) +} + 
+ForwardModeTests.test("complexBinaryMethod") { + let (y, differential) = valueWithDifferential(at: A(x: 4), A(x: 5), 3) { + (x, y, z) in + // derivative = A(x: 4uv + 4xv + 4ux + 1) = 4*5*3 + 4*4*3 + 4*5*4 + 1 = 189 + x.complexBinaryMethod(u: y, v: z) + } + expectEqual(A(x: 244), y) + expectEqual(A(x: 189), differential(A(x: 1), A(x: 1), 1)) +} + //===----------------------------------------------------------------------===// -// `Tracked` struct +// Tracked struct //===----------------------------------------------------------------------===// ForwardModeTests.test("TrackedIdentity") { @@ -93,6 +255,378 @@ ForwardModeTests.test("TrackedWithLets") { expectEqual(4.9375, differential(1, 1)) } +//===----------------------------------------------------------------------===// +// Tuples +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("SimpleTupleExtractLet") { + func foo(_ x: Float) -> Float { + let tuple = (2*x, x) + return tuple.0 + } + let (y, differential) = valueWithDifferential(at: 4, in: foo) + expectEqual(8, y) + expectEqual(2, differential(1)) +} + +ForwardModeTests.test("SimpleTupleExtractVar") { + func foo(_ x: Float) -> Float { + let tuple = (2*x, x) + return tuple.0 + } + let (y, differential) = valueWithDifferential(at: 4, in: foo) + expectEqual(8, y) + expectEqual(2, differential(1)) +} + +ForwardModeTests.test("TupleSideEffects") { + func foo(_ x: Float) -> Float { + var tuple = (x, x) + tuple.0 = tuple.0 * x + return x * tuple.0 + } + expectEqual(27, derivative(at: 3, in: foo)) + + func fifthPower(_ x: Float) -> Float { + var tuple = (x, x) + tuple.0 = tuple.0 * x + tuple.1 = tuple.0 * x + return tuple.0 * tuple.1 + } + expectEqual(405, derivative(at: 3, in: fifthPower)) + + func nested(_ x: Float) -> Float { + var tuple = ((x, x), x) + tuple.0.0 = tuple.0.0 * x + tuple.0.1 = tuple.0.0 * x + return tuple.0.0 * tuple.0.1 + } + expectEqual(405, derivative(at: 3, in: nested)) + + // FIXME(TF-201): 
Update after reabstraction thunks can be directly differentiated. + /* + func generic(_ x: T) -> T { + var tuple = (x, x) + tuple.0 += x + tuple.1 += x + return tuple.0 + tuple.0 + } + expectEqual(1, derivative(at: 3.0, in: generic)) + */ +} + +// Tests TF-321. +ForwardModeTests.test("TupleNonDifferentiableElements") { + // @differentiable + func foo(_ x: Float) -> Float { + var tuple = (x, 1) + tuple.0 = x + tuple.1 = 1 + return tuple.0 + } + expectEqual(1, derivative(at: 1, in: foo)) + + func bar(_ x: Float) -> Float { + var tuple: (Int, Int, Float, Float) = (1, 1, x, x) + tuple.0 = 1 + tuple.1 = 1 + tuple.3 = x + return tuple.3 + } + expectEqual(1, derivative(at: 1, in: bar)) + + struct Wrapper { + @differentiable(where T : Differentiable) + func baz(_ x: T) -> T { + var tuple = (1, 1, x, 1) + tuple.0 = 1 + tuple.2 = x + tuple.3 = 1 + return tuple.2 + } + } + expectEqual(1, derivative(at: Float(1), in: { x -> Float in + let wrapper = Wrapper() + return wrapper.baz(x) + })) +} + +//===----------------------------------------------------------------------===// +// Generics +//===----------------------------------------------------------------------===// + +struct Tensor + : VectorProtocol, Differentiable { + // NOTE: `value` must have type with known size (e.g. `Float`, not `Scalar`) + // until differentiation has indirect passing support. 
+ var value: Float + init(_ value: Float) { self.value = value } +} + +ForwardModeTests.test("GenericIdentity") { + func identity(_ x: T) -> T { + return x + } + let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in + identity(x) + } + expectEqual(4, y) + expectEqual(1, differential(1)) +} + +ForwardModeTests.test("GenericTensorIdentity") { + func identity( + _ x: Tensor) -> Tensor { + return x + } + let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in + identity(Tensor(x)) + } + expectEqual(Tensor(4), y) + expectEqual(Tensor(1), differential(1)) +} + +ForwardModeTests.test("GenericTensorPlus") { + func plus(_ x: Tensor) -> Float { + return x.value + x.value + } + let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in + plus(Tensor(x)) + } + expectEqual(8, y) + expectEqual(2, differential(1)) +} + +ForwardModeTests.test("GenericTensorBinaryInput") { + func binary( + _ x: Tensor, _ y: Tensor) -> Float { + return x.value * y.value + } + let (y, differential) = valueWithDifferential(at: 4, 5) { + (x: Float, y: Float) in + binary(Tensor(x), Tensor(y)) + } + expectEqual(20, y) + expectEqual(9, differential(1, 1)) +} + +ForwardModeTests.test("GenericTensorWithLets") { + func binary( + _ x: Tensor, _ y: Tensor) -> Float { + let a = Tensor(x.value) + let b = Tensor(y.value) + return a.value * b.value + } + let (y, differential) = valueWithDifferential(at: 4, 5) { + (x: Float, y: Float) in + binary(Tensor(x), Tensor(y)) + } + expectEqual(20, y) + expectEqual(9, differential(1, 1)) +} + +ForwardModeTests.test("GenericTensorWithVars") { + func binary( + _ x: Tensor, _ y: Tensor) -> Float { + var a = Tensor(x.value) + var b = Tensor(y.value) + b = a + a = Tensor(y.value) + return a.value * b.value + } + let (y, differential) = valueWithDifferential(at: 4, 5) { + (x: Float, y: Float) in + binary(Tensor(x), Tensor(y)) + } + expectEqual(20, y) + expectEqual(9, differential(1, 1)) +} + +// Test case where associated derivative 
function's requirements are met. +extension Tensor where Scalar : Numeric { + @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint) + func mean() -> Tensor { + return self + } + + @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint) + func variance() -> Tensor { + return mean() // ok + } +} +_ = differential(at: Tensor(1), in: { $0.variance() }) + +// Tests TF-508: differentiation requirements with dependent member types. +protocol TF_508_Proto { + associatedtype Scalar +} +extension TF_508_Proto where Scalar : FloatingPoint { + @differentiable( + jvp: jvpAdd + where Self : Differentiable, Scalar : Differentiable, + // Conformance requirement with dependent member type. + Self.TangentVector : TF_508_Proto + ) + static func +(lhs: Self, rhs: Self) -> Self { + return lhs + } + + @differentiable( + jvp: jvpSubtract + where Self : Differentiable, Scalar : Differentiable, + // Same-type requirement with dependent member type. + Self.TangentVector == Float + ) + static func -(lhs: Self, rhs: Self) -> Self { + return lhs + } +} +extension TF_508_Proto where Self : Differentiable, + Scalar : FloatingPoint & Differentiable, + Self.TangentVector : TF_508_Proto { + static func jvpAdd(lhs: Self, rhs: Self) + -> (Self, (TangentVector, TangentVector) -> TangentVector) { + return (lhs, { (dlhs, drhs) in dlhs }) + } +} +extension TF_508_Proto where Self : Differentiable, + Scalar : FloatingPoint & Differentiable, + Self.TangentVector == Float { + static func jvpSubtract(lhs: Self, rhs: Self) + -> (Self, (TangentVector, TangentVector) -> TangentVector) { + return (lhs, { (dlhs, drhs) in dlhs }) + } +} + +struct TF_508_Struct + : TF_508_Proto, AdditiveArithmetic {} +extension TF_508_Struct : Differentiable where Scalar : Differentiable { + typealias TangentVector = TF_508_Struct +} + +// func TF_508() { +// let x = TF_508_Struct() +// // Test conformance requirement with dependent member type. 
+// _ = differential(at: x, in: { +// (x: TF_508_Struct) -> TF_508_Struct in +// return x + x +// }) +// // Test same-type requirement with dependent member type. +// _ = differential(at: x, in: { +// (x: TF_508_Struct) -> TF_508_Struct in +// return x - x +// }) +// } + +// TF-523 +struct TF_523_Struct : Differentiable & AdditiveArithmetic { + var a: Float = 1 + typealias TangentVector = TF_523_Struct + typealias AllDifferentiableVariables = TF_523_Struct +} + +@differentiable +func TF_523_f(_ x: TF_523_Struct) -> Float { + return x.a * 2 +} + +// TF-534: Thunk substitution map remapping. +protocol TF_534_Layer : Differentiable { + associatedtype Input : Differentiable + associatedtype Output : Differentiable + + @differentiable + func callAsFunction(_ input: Input) -> Output +} +struct TF_534_Tensor : Differentiable {} + +func TF_534( + _ model: inout Model, inputs: Model.Input +) -> TF_534_Tensor where Model.Output == TF_534_Tensor { + return valueWithDifferential(at: model) { model -> Model.Output in + return model(inputs) + }.0 +} + +// TODO: uncomment once control flow is supported in forward mode. +// TF-652: Test VJPEmitter substitution map generic signature. +// The substitution map should have the VJP's generic signature, not the +// original function's. +// struct TF_652 {} +// extension TF_652 : Differentiable where Scalar : FloatingPoint {} + +// @differentiable(wrt: x where Scalar: FloatingPoint) +// func test(x: TF_652) -> TF_652 { +// for _ in 0..<10 { +// let _ = x +// } +// return x +// } + +//===----------------------------------------------------------------------===// +// Tracked Generic. 
+//===----------------------------------------------------------------------===// + +ForwardModeTests.test("GenericTrackedIdentity") { + func identity(_ x: Tracked) -> Tracked { + return x + } + let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in + identity(Tracked(x)) + } + expectEqual(4, y) + expectEqual(1, differential(1)) +} + +ForwardModeTests.test("GenericTrackedBinaryAdd") { + func add(_ x: Tracked, _ y: Tracked) -> Tracked + where T: Differentiable, T == T.TangentVector { + return x + y + } + let (y, differential) = valueWithDifferential(at: 4, 5) { + (x: Float, y: Float) in + add(Tracked(x), Tracked(y)) + } + expectEqual(9, y) + expectEqual(2, differential(1, 1)) +} + +ForwardModeTests.test("GenericTrackedBinaryLets") { + func add(_ x: Tracked, _ y: Tracked) -> Tracked + where T: Differentiable & SignedNumeric, + T == T.TangentVector, + T == T.Magnitude { + let a = x * y // xy + let b = a + a // 2xy + return b + b // 4xy + } + // 4y + 4x + let (y, differential) = valueWithDifferential(at: 4, 5) { (x: Float, y: Float) in + add(Tracked(x), Tracked(y)) + } + expectEqual(80, y) + expectEqual(36, differential(1, 1)) +} + +ForwardModeTests.test("GenericTrackedBinaryVars") { + func add(_ x: Tracked, _ y: Tracked) -> Tracked + where T: Differentiable & SignedNumeric, + T == T.TangentVector, + T == T.Magnitude { + var a = x * y // xy + a = a + a // 2xy + var b = x + b = a + return b + b // 4xy + } + // 4y + 4x + let (y, differential) = valueWithDifferential(at: 4, 5) { (x: Float, y: Float) in + add(Tracked(x), Tracked(y)) + } + expectEqual(80, y) + expectEqual(36, differential(1, 1)) +} + ForwardModeTests.test("TrackedDifferentiableFuncType") { func valAndDeriv( f: @escaping @differentiable (Tracked) -> Tracked @@ -110,11 +644,10 @@ ForwardModeTests.test("TrackedDifferentiableFuncType") { expectEqual(400, val1) expectEqual(160, dv1) } + //===----------------------------------------------------------------------===// // Classes 
//===----------------------------------------------------------------------===// -// NOTE: once forward mode is done, can copy and replace this in -// `class_method.swift` as it already calls reverse mode functions. ForwardModeTests.test("Final") { final class Final : Differentiable { @@ -124,7 +657,6 @@ ForwardModeTests.test("Final") { } for i in -5...5 { - expectEqual(Float(i) * 2, gradient(at: Float(i)) { x in Final().method(x) }) expectEqual( Float(i) * 2, derivative(at: Float(i)) { x in Final().method(x) }) @@ -168,16 +700,10 @@ ForwardModeTests.test("Simple") { func classValueWithDerivative(_ c: Super) -> (Float, Float) { return valueWithDerivative(at: 1) { c.f($0) } } - func classValueWithGradient(_ c: Super) -> (Float, Float) { - return valueWithGradient(at: 1) { c.f($0) } - } expectEqual((2, 2), classValueWithDerivative(Super())) expectEqual((3, 3), classValueWithDerivative(SubOverride())) expectEqual((3, 3), classValueWithDerivative(SubOverrideCustomDerivatives())) - expectEqual((2, 2), classValueWithGradient(Super())) - expectEqual((3, 3), classValueWithGradient(SubOverride())) - expectEqual((3, 3), classValueWithGradient(SubOverrideCustomDerivatives())) } ForwardModeTests.test("SimpleWrtSelf") { @@ -241,13 +767,6 @@ ForwardModeTests.test("SimpleWrtSelf") { // expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v)) // expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v)) // expectEqual(100, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v)) - - - // `valueWithGradient` is not used because nested tuples cannot be compared - // with `expectEqual`. - func classGradient(_ c: Super) -> (Super.TangentVector, Float) { - return gradient(at: c, 10) { c, x in c.f(x) } - } // `valueWithDerivative` is not used because the derivative requires `Super` // to conform to `FloatingPoint`. 
@@ -269,43 +788,500 @@ ForwardModeTests.test("SimpleWrtSelf") { expectEqual(30, y3) let c3 = SubOverrideCustomDerivatives.TangentVector(base: 1, _nontrivial: []) expectEqual(3, diff3(c3, 1)) - expectEqual((Super.TangentVector(base: 10, _nontrivial: []), 2), - classGradient(Super(base: 2))) - expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3), - classGradient(SubOverride(base: 2))) - expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3), - classGradient(SubOverrideCustomDerivatives(base: 2))) } //===----------------------------------------------------------------------===// // Protocols //===----------------------------------------------------------------------===// -// TODO: add more protocol tests. -// protocol DiffReq : Differentiable { -// @differentiable(wrt: x) -// func foo(x: Float) -> Float -// } -// struct Linear: DiffReq, VectorProtocol { -// typealias TangentVector = Linear +protocol Prot : Differentiable { + @differentiable(wrt: x) + func foo(x: Float) -> Float +} +ForwardModeTests.test("Simple Protocol") { + struct Linear: Prot, VectorProtocol { + typealias TangentVector = Linear -// let m: Float -// let b: Float + let m: Float + let b: Float -// @differentiable(wrt: x) -// func foo(x: Float) -> Float { -// return m * x + b -// } -// } + @differentiable(wrt: x) + func foo(x: Float) -> Float { + return m * x + b + } + } -// ForwardModeTests.test("Protocols") { -// func genericFoo(_ t: T, _ x: Float) -> Float { -// t.foo(x: x) -// } -// let inst = Linear(m: 5, b: -2) -// let (y1, diff1) = valueWithDifferential(at: 5) { x in genericFoo(inst, x) } -// expectEqual(23, y1) -// expectEqual(5, diff1(1)) -// } + func genericFoo(_ t: T, _ x: Float) -> Float { + t.foo(x: x) + } + let inst = Linear(m: 5, b: -2) + let (y1, diff1) = valueWithDifferential(at: 5) { x in genericFoo(inst, x) } + expectEqual(23, y1) + expectEqual(5, diff1(1)) +} + +protocol DiffReq : Differentiable { + @differentiable(wrt: (self, x)) + func f(_ x: Float) -> Float +} + 
+extension DiffReq where TangentVector : AdditiveArithmetic { + @inline(never) // Prevent specialization, to test all witness code. + func derivF(at x: Float) -> Float { + return (valueWithDifferential(at: x) { x in self.f(x) }).1(1) + } +} + +struct Quadratic : DiffReq, VectorProtocol { + typealias TangentVector = Quadratic + + @differentiable + let a: Float + + @differentiable + let b: Float + + @differentiable + let c: Float + + init(_ a: Float, _ b: Float, _ c: Float) { + self.a = a + self.b = b + self.c = c + } + + @differentiable(wrt: (self, x)) + func f(_ x: Float) -> Float { + return a * x * x + b * x + c + } +} + +ForwardModeTests.test("ProtocolFunc") { + expectEqual(12, Quadratic(11, 12, 13).derivF(at: 0)) + expectEqual(2 * 11 + 12, Quadratic(11, 12, 13).derivF(at: 1)) + expectEqual(2 * 11 * 2 + 12, Quadratic(11, 12, 13).derivF(at: 2)) +} + +// MARK: Constructor, accessor, and subscript requirements. + +protocol FunctionsOfX: Differentiable { + @differentiable + init(x: Float) + + @differentiable + var x: Float { get } + + @differentiable + var y: Float { get } + + @differentiable + var z: Float { get } + + @differentiable + subscript() -> Float { get } +} + +struct TestFunctionsOfX: FunctionsOfX { + @differentiable + init(x: Float) { + self.x = x + self.y = x * x + } + + /// x = x + var x: Float + + /// y = x * x + var y: Float + + /// z = x * x + x + var z: Float { + return y + x + } + + @differentiable + subscript() -> Float { + return z + } +} + +@inline(never) // Prevent specialization, to test all witness code. 
+func derivatives(at x: Float, in: F.Type) + -> (Float, Float, Float, Float) +{ + let dxdx = derivative(at: x) { x in F(x: x).x } + let dydx = derivative(at: x) { x in F(x: x).y } + let dzdx = derivative(at: x) { x in F(x: x).z } + let dsubscriptdx = derivative(at: x) { x in F(x: x)[] } + return (dxdx, dydx, dzdx, dsubscriptdx) +} + +ForwardModeTests.test("constructor, accessor, subscript") { + expectEqual( + (1.0, 4.0, 5.0, 5.0), + derivatives(at: 2.0, in: TestFunctionsOfX.self)) +} + +// MARK: - Test witness method SIL type computation. + +protocol P : Differentiable { + @differentiable(wrt: (x, y)) + func foo(_ x: Float, _ y: Double) -> Float +} +struct S : P { + @differentiable(wrt: (x, y)) + func foo(_ x: Float, _ y: Double) -> Float { + return x + } +} + +// MARK: - Overridden protocol method adding differentiable attribute. + +public protocol Distribution { + associatedtype Value + func logProbability(of value: Value) -> Float +} + +public protocol DifferentiableDistribution: Differentiable, Distribution { + @differentiable(wrt: self) + func logProbability(of value: Value) -> Float +} + +struct Foo: DifferentiableDistribution { + @differentiable(wrt: self) + func logProbability(of value: Float) -> Float { + .zero + } +} + +@differentiable +func blah(_ x: T) -> Float where T.Value: AdditiveArithmetic { + x.logProbability(of: .zero) +} + +// Adding a more general `@differentiable` attribute. 
+public protocol DoubleDifferentiableDistribution: DifferentiableDistribution + where Value: Differentiable { + @differentiable(wrt: self) + @differentiable(wrt: (self, value)) + func logProbability(of value: Value) -> Float +} + +@differentiable +func blah2(_ x: T, _ value: T.Value) -> Float + where T.Value: AdditiveArithmetic { + x.logProbability(of: value) +} + +protocol DifferentiableFoo { + associatedtype T: Differentiable + @differentiable(wrt: x) + func foo(_ x: T) -> Float +} + +protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo { + @differentiable(wrt: (self, x)) + func foo(_ x: T) -> Float +} + +struct MoreDifferentiableFooStruct: MoreDifferentiableFoo { + @differentiable(wrt: (self, x)) + func foo(_ x: Float) -> Float { + x + } +} + +//===----------------------------------------------------------------------===// +// Simple Math +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("Arithmetics") { + func foo1(x: Float, y: Float) -> Float { + return x * y + } + expectEqual(7, derivative(at: 3, 4, in: foo1)) + func foo2(x: Float, y: Float) -> Float { + return -x * y + } + expectEqual(-7, derivative(at: 3, 4, in: foo2)) + func foo3(x: Float, y: Float) -> Float { + return -x + y + } + expectEqual(0, derivative(at: 3, 4, in: foo3)) +} + +ForwardModeTests.test("Fanout") { + func foo1(x: Float) -> Float { + x - x + } + expectEqual(0, derivative(at: 100, in: foo1)) + func foo2(x: Float) -> Float { + x + x + } + expectEqual(2, derivative(at: 100, in: foo2)) + func foo3(x: Float, y: Float) -> Float { + x + x + x * y + } + expectEqual(7, derivative(at: 3, 2, in: foo3)) +} + +ForwardModeTests.test("FunctionCall") { + func foo(_ x: Float, _ y: Float) -> Float { + return 3 * x + { $0 * 3 }(3) * y + } + expectEqual(12, derivative(at: 3, 4, in: foo)) + expectEqual(3, derivative(at: 3) { x in foo(x, 4) }) +} + +ForwardModeTests.test("ResultSelection") { + func foo(_ x: Float, _ y: Float) -> (Float, 
Float) { + return (x + 1, y + 2) + } + expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).0 })) + expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).1 })) +} + +ForwardModeTests.test("CaptureLocal") { + let z: Float = 10 + func foo(_ x: Float) -> Float { + return z * x + } + expectEqual(10, derivative(at: 0, in: foo)) +} + +var globalVar: Float = 10 +ForwardModeTests.test("CaptureGlobal") { + func foo(x: Float) -> Float { + globalVar += 20 + return globalVar * x + } + expectEqual(30, derivative(at: 0, in: foo)) +} + +ForwardModeTests.test("SideEffects") { + func fourthPower(x: Float) -> Float { + var a = x + a = a * x + a = a * x + return a * x + } + expectEqual(4 * 27, derivative(at: 3, in: fourthPower)) +} + +ForwardModeTests.test("TupleSideEffects") { + func foo(_ x: Float) -> Float { + var tuple = (x, x) + tuple.0 = tuple.0 * x + return x * tuple.0 + } + expectEqual(27, derivative(at: 3, in: foo)) + + func fifthPower(_ x: Float) -> Float { + var tuple = (x, x) + tuple.0 = tuple.0 * x + tuple.1 = tuple.0 * x + return tuple.0 * tuple.1 + } + expectEqual(405, derivative(at: 3, in: fifthPower)) + + func nested(_ x: Float) -> Float { + var tuple = ((x, x), x) + tuple.0.0 = tuple.0.0 * x + tuple.0.1 = tuple.0.0 * x + return tuple.0.0 * tuple.0.1 + } + expectEqual(405, derivative(at: 3, in: nested)) + + // FIXME(TF-201): Update after reabstraction thunks can be directly differentiated. + /* + func generic(_ x: T) -> T { + var tuple = (x, x) + tuple.0 += x + tuple.1 += x + return tuple.0 + tuple.0 + } + expectEqual(1, derivative(at: 3.0, in: generic)) + */ +} + +// Tests TF-321. 
+ForwardModeTests.test("TupleNonDifferentiableElements") { + func foo(_ x: Float) -> Float { + var tuple = (x, 1) + tuple.0 = x + tuple.1 = 1 + return tuple.0 + } + expectEqual(1, derivative(at: 1, in: foo)) + + func bar(_ x: Float) -> Float { + var tuple: (Int, Int, Float, Float) = (1, 1, x, x) + tuple.0 = 1 + tuple.1 = 1 + tuple.3 = x + return tuple.3 + } + expectEqual(1, derivative(at: 1, in: bar)) + + struct Wrapper { + @differentiable(where T : Differentiable) + func baz(_ x: T) -> T { + var tuple = (1, 1, x, 1) + tuple.0 = 1 + tuple.2 = x + tuple.3 = 1 + return tuple.2 + } + } + expectEqual(1, derivative(at: Float(1), in: { x -> Float in + let wrapper = Wrapper() + return wrapper.baz(x) + })) +} + +// Tests TF-21. +ForwardModeTests.test("StructMemberwiseInitializer") { + struct Foo : AdditiveArithmetic, Differentiable { + var stored: Float + var computed: Float { + return stored * stored + } + } + + let derivFoo = differential(at: Float(4), in: { input -> Foo in + let foo = Foo(stored: input) + let foo2 = foo + foo + return Foo(stored: foo2.stored) + })(1) + expectEqual(Foo.TangentVector(stored: 2), derivFoo) + + let computed = derivative(at: Float(4)) { input -> Float in + let foo = Foo(stored: input) + return foo.computed + } + expectEqual(8, computed) + + let derivProduct = derivative(at: Float(4)) { input -> Float in + let foo = Foo(stored: input) + return foo.computed * foo.stored + } + expectEqual(48, derivProduct) + + struct Custom : AdditiveArithmetic, Differentiable { + var x: Float + + // Custom initializer with `@differentiable`. + @differentiable + init(x: Float) { + print(x) + self.x = x + } + } + + let derivCustom = differential(at: Float(4), in: { input -> Custom in + let foo = Custom(x: input) + return foo + foo + })(1) + expectEqual(Custom.TangentVector(x: 2), derivCustom) +} + +// Tests TF-319: struct with non-differentiable constant stored property. 
+ForwardModeTests.test("StructConstantStoredProperty") {
+  struct TF_319 : Differentiable {
+    var x: Float
+    @noDerivative let constant = Float(2)
+
+    @differentiable
+    init(x: Float) {
+      self.x = x
+    }
+
+    @differentiable(wrt: (self, input))
+    func applied(to input: Float) -> Float {
+      return x * constant * input
+    }
+  }
+  func testStructInit(to input: Float) -> Float {
+    let model = TF_319(x: 10)
+    return model.applied(to: input)  // 10 * 2 * input
+  }
+  expectEqual(6, derivative(at: 10, in: { TF_319(x: $0).applied(to: 3) }))  // d/dx (x * 2 * 3) = 6
+  expectEqual(20, derivative(at: 3, in: testStructInit))  // d/dinput (10 * 2 * input) = 20
+}
+
+ForwardModeTests.test("StructSideEffects") {
+  struct Point : AdditiveArithmetic, Differentiable {
+    var x: Float
+    var y: Float
+    var z: Float
+  }
+
+  func double(_ input: Float) -> Point {
+    let point = Point(x: input, y: input, z: input)
+    return point + point  // every field is 2 * input
+  }
+  expectEqual(Point(x: 2, y: 2, z: 2), differential(at: 4, in: double)(1))
+
+  func fifthPower(_ input: Float) -> Float {
+    var point = Point(x: input, y: input, z: input)
+    point.x = point.x * input  // input^2
+    point.y = point.x * input  // input^3
+    return point.x * point.y  // input^5
+  }
+  expectEqual(405, derivative(at: 3, in: fifthPower))  // d/dx x^5 = 5x^4 = 405 at x = 3
+
+  func mix(_ input: Float) -> Float {
+    var tuple = (point: Point(x: input, y: input, z: input), float: input)
+    tuple.point.x = tuple.point.x * tuple.float  // input^2
+    tuple.point.y = tuple.point.x * input  // input^3
+    return tuple.point.x * tuple.point.y  // input^5
+  }
+  expectEqual(405, derivative(at: 3, in: mix))  // same x^5 derivative, struct nested in a tuple
+
+  // Test TF-282.
+  struct Add : Differentiable {
+    var bias: Float
+    func applied(to input: Float) -> Float {
+      var tmp = input
+      tmp = tmp + bias
+      return tmp
+    }
+  }
+  expectEqual(1, derivative(at: 1) { m in Add(bias: m).applied(to: 1) })  // d/dm (1 + m) = 1
+}
+
+ForwardModeTests.test("StructGeneric") {
+  struct Generic<T : AdditiveArithmetic & Differentiable> : AdditiveArithmetic, Differentiable {  // declares `T` used by the fields
+    var x: T
+    var y: T
+    var z: T
+  }
+
+  let deriv = differential(at: Float(3), in: { input -> Generic<Float> in
+    var generic = Generic(x: input, y: input, z: input)  // `T` inferred as Float from `input`
+    return generic
+  })(1)
+  expectEqual(Generic<Float>.TangentVector(x: 1, y: 1, z: 1), deriv)  // each field is `input`, so each tangent is 1
+
+  func fifthPower(_ input: Float) -> Float {
+    var generic = Generic(x: input, y: input, z: input)
+    generic.x = generic.x * input  // input^2
+    generic.y = generic.x * input  // input^3
+    return generic.x * generic.y  // input^5
+  }
+  expectEqual(405, derivative(at: 3, in: fifthPower))  // d/dx x^5 = 5x^4 = 405 at x = 3
+}
+
+ForwardModeTests.test("SubsetIndices") {
+  func deriv(_ lossFunction: @differentiable (Float, Float) -> Float) -> Float {
+    return derivative(at: 1) { x in lossFunction(x * x, 10.0) }
+  }
+  expectEqual(2, deriv { x, y in x + y })  // d/dx (x^2 + 10) = 2x = 2 at x = 1
+
+  func derivWRTNonDiff(_ lossFunction: @differentiable (Float, @nondiff Int) -> Float) -> Float {
+    return derivative(at: 2) { x in lossFunction(x * x, 10) }
+  }
+  expectEqual(4, derivWRTNonDiff { x, y in x + Float(y) })  // d/dx (x^2 + 10) = 2x = 4 at x = 2
+}
 runAllTests()
diff --git a/test/AutoDiff/protocol_requirement_autodiff.swift b/test/AutoDiff/protocol_requirement_autodiff.swift
index 38fc6a14a2b77..64f2eb432d825 100644
--- a/test/AutoDiff/protocol_requirement_autodiff.swift
+++ b/test/AutoDiff/protocol_requirement_autodiff.swift
@@ -106,8 +106,8 @@ func derivatives(at x: Float, in: F.Type)
 
 ProtocolRequirementAutodiffTests.test("constructor, accessor, subscript") {
   expectEqual(
-    derivatives(at: 2.0, in: TestFunctionsOfX.self),
-    (1.0, 4.0, 5.0, 5.0))
+    (1.0, 4.0, 5.0, 5.0),
+    derivatives(at: 2.0, in: TestFunctionsOfX.self))
 }
 
 // MARK: - Test witness method SIL type computation.
diff --git a/test/AutoDiff/simple_math.swift b/test/AutoDiff/simple_math.swift index 5fb2629ebddad..e159a886610ed 100644 --- a/test/AutoDiff/simple_math.swift +++ b/test/AutoDiff/simple_math.swift @@ -299,7 +299,6 @@ SimpleMathTests.test("StructGeneric") { generic.y = generic.x * input return generic.x * generic.y } - // FIXME(TF-274): The true expected result is `405`, like other variants of `fifthPower` above. expectEqual(405, gradient(at: 3, in: fifthPower)) }