diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index 70b9ea742328f..3c254850f479b 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -556,8 +556,6 @@ struct AutoDiffAssociatedFunctionKind { AutoDiffAssociatedFunctionKind() = default; AutoDiffAssociatedFunctionKind(innerty rawValue) : rawValue(rawValue) {} - AutoDiffAssociatedFunctionKind(AutoDiffLinearMapKind linMapKind) - : rawValue(static_cast(linMapKind.rawValue)) {} explicit AutoDiffAssociatedFunctionKind(StringRef string); operator innerty() const { return rawValue; } AutoDiffLinearMapKind getLinearMapKind() { diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index cd2580ec842b5..1002285d47782 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -91,35 +91,13 @@ static bool isWithoutDerivative(SILValue v) { return false; } -static bool isArrayLiteralIntrinsic(ApplyInst *ai) { - return ai->hasSemantics("array.uninitialized_intrinsic"); -} - static ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) { if (auto *applyInst = dyn_cast(v)) - if (isArrayLiteralIntrinsic(applyInst)) + if (applyInst->hasSemantics("array.uninitialized_intrinsic")) return applyInst; return nullptr; } -/// Given a value, find its single `destructure_tuple` user if the value is -/// tuple-typed and such a user exists. -static DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) { - bool foundDestructureTupleUser = false; - if (!value->getType().is()) - return nullptr; - DestructureTupleInst *result = nullptr; - for (auto *use : value->getUses()) { - if (auto *dti = dyn_cast(use->getUser())) { - assert(!foundDestructureTupleUser && - "There should only be one `destructure_tuple` user of a tuple"); - foundDestructureTupleUser = true; - result = dti; - } - } - return result; -} - /// Given a function, gather all of its formal results (both direct and /// indirect) in an order defined by its result type. Note that "formal results" /// refer to result values in the body of the function, not at call sites. @@ -144,23 +122,6 @@ collectAllFormalResultsInTypeOrder(SILFunction &function, : indResults[indResIdx++]); } -/// Given a function, gather all of its direct results in an order defined by -/// its result type. Note that "formal results" refer to result values in the -/// body of the function, not at call sites. -static void -collectAllDirectResultsInTypeOrder(SILFunction &function, - SmallVectorImpl &results) { - SILFunctionConventions convs(function.getLoweredFunctionType(), - function.getModule()); - auto *retInst = cast(function.findReturnBB()->getTerminator()); - auto retVal = retInst->getOperand(); - if (auto *tupleInst = dyn_cast(retVal)) - results.append(tupleInst->getElements().begin(), - tupleInst->getElements().end()); - else - results.push_back(retVal); -} - /// Given a function call site, gather all of its actual results (both direct /// and indirect) in an order defined by its result type. template @@ -295,6 +256,10 @@ static Inst *peerThroughFunctionConversions(SILValue value) { return nullptr; } +static bool isArrayLiteralIntrinsic(ApplyInst *ai) { + return ai->hasSemantics("array.uninitialized_intrinsic"); +} + //===----------------------------------------------------------------------===// // Auxiliary data structures //===----------------------------------------------------------------------===// @@ -404,7 +369,7 @@ class DifferentiableActivityInfo; class LinearMapInfo { private: /// The linear map kind. - AutoDiffLinearMapKind kind; + AutoDiffAssociatedFunctionKind kind; /// The original function. SILFunction *const original; @@ -516,13 +481,13 @@ class LinearMapInfo { // Create a branching trace enum. std::string enumName; switch (kind) { - case AutoDiffLinearMapKind::Differential: + case swift::AutoDiffAssociatedFunctionKind::JVP: enumName = "_AD__" + original->getName().str() + "_bb" + std::to_string(originalBB->getDebugID()) + "__Succ__" + indices.mangle(); break; - case AutoDiffLinearMapKind::Pullback: + case swift::AutoDiffAssociatedFunctionKind::VJP: enumName = "_AD__" + original->getName().str() + "_bb" + std::to_string(originalBB->getDebugID()) + @@ -581,10 +546,10 @@ class LinearMapInfo { auto &s = getADDebugStream(); std::string enumName; switch (kind) { - case AutoDiffLinearMapKind::Differential: + case AutoDiffAssociatedFunctionKind::JVP: enumName = "Predecessor"; break; - case AutoDiffLinearMapKind::Pullback: + case AutoDiffAssociatedFunctionKind::VJP: enumName = "Successor"; break; } @@ -608,13 +573,13 @@ class LinearMapInfo { std::string structName; switch (kind) { - case swift::AutoDiffLinearMapKind::Differential: + case swift::AutoDiffAssociatedFunctionKind::JVP: structName = "_AD__" + original->getName().str() + "_bb" + std::to_string(originalBB->getDebugID()) + "__DF__" + indices.mangle(); break; - case swift::AutoDiffLinearMapKind::Pullback: + case swift::AutoDiffAssociatedFunctionKind::VJP: structName = "_AD__" + original->getName().str() + "_bb" + std::to_string(originalBB->getDebugID()) + @@ -644,10 +609,10 @@ class LinearMapInfo { auto &s = getADDebugStream(); std::string structName; switch (kind) { - case AutoDiffLinearMapKind::Differential: + case AutoDiffAssociatedFunctionKind::JVP: structName = "Differential"; break; - case AutoDiffLinearMapKind::Pullback: + case AutoDiffAssociatedFunctionKind::VJP: structName = "Pullback"; break; } @@ -680,10 +645,10 @@ class LinearMapInfo { auto *linMapStruct = getLinearMapStruct(origBB); std::string linearMapName; switch (kind) { - case AutoDiffLinearMapKind::Differential: + case swift::AutoDiffAssociatedFunctionKind::JVP: linearMapName = "differential_" + llvm::itostr(linearMapValueMap.size()); break; - case AutoDiffLinearMapKind::Pullback: + case swift::AutoDiffAssociatedFunctionKind::VJP: linearMapName = "pullback_" + llvm::itostr(linearMapValueMap.size()); break; } @@ -710,7 +675,7 @@ class LinearMapInfo { LinearMapInfo &operator=(const LinearMapInfo &) = delete; explicit LinearMapInfo(ADContext &context, - AutoDiffLinearMapKind kind, + AutoDiffAssociatedFunctionKind kind, SILFunction *original, SILFunction *assocFn, const SILAutoDiffIndices &indices, const DifferentiableActivityInfo &activityInfo, @@ -1378,12 +1343,6 @@ using Activity = OptionSet; /// indices. class DifferentiableActivityInfo { private: - // TODO(TF-800): Temporarily store `AutoDiffAssociatedFunctionKind` because - // special logic for `apply` result does not work for reverse-mode. - - // with us handling `apply` instructions differently. - AutoDiffAssociatedFunctionKind kind; - DifferentiableActivityCollection &parent; GenericSignature *assocGenSig = nullptr; @@ -1418,8 +1377,7 @@ class DifferentiableActivityInfo { public: explicit DifferentiableActivityInfo( - DifferentiableActivityCollection &parent, GenericSignature *assocGenSig, - AutoDiffAssociatedFunctionKind kind); + DifferentiableActivityCollection &parent, GenericSignature *assocGenSig); bool isVaried(SILValue value, unsigned independentVariableIndex) const; bool isUseful(SILValue value, unsigned dependentVariableIndex) const; @@ -1521,7 +1479,7 @@ static void collectMinimalIndicesForFunctionCall( } LinearMapInfo::LinearMapInfo(ADContext &context, - AutoDiffLinearMapKind kind, + AutoDiffAssociatedFunctionKind kind, SILFunction *original, SILFunction *assocFn, const SILAutoDiffIndices &indices, const DifferentiableActivityInfo &activityInfo, @@ -1551,19 +1509,7 @@ bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) { activityInfo.isActive(paramArgs[i], indices)) return true; - // TODO(TF-800): Investigate why `apply` result special logic does not work - // for reverse-mode. - if (kind == AutoDiffLinearMapKind::Differential) { - for (auto use : ai->getUses()) { - if (auto *dti = dyn_cast(use->getUser())) { - for (auto result : dti->getResults()) { - if (activityInfo.isActive(result, indices)) - return true; - } - } - } - } - + // TODO(bartchr): Check `destructure_tuple` user's results' acvitity. bool hasActiveDirectResults = activityInfo.isActive(ai, indices); bool hasActiveIndirectResults = llvm::any_of(ai->getIndirectSILResults(), [&](SILValue result) { return activityInfo.isActive(result, indices); }); @@ -1579,12 +1525,10 @@ bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) { return hasActiveResults && hasActiveParamArguments; } -// TODO(TF-800): Investigate why reverse-mode requires special "should -// differentiate logic" and update comment. -/// Returns a flag indicating whether the instruction should be differentiated, -/// given the differentiation indices of the instruction's parent function. -/// Whether the instruction should be differentiated is determined sequentially -/// from the following conditions: +/// Returns a flag that indicates whether the instruction should be +/// differentiated, given the differentiation indices of the instruction's +/// parent function. Whether the instruction should be differentiated is +/// determined sequentially from the following conditions: /// 1. The instruction is an `apply` and `shouldDifferentiateApplyInst` returns /// true. /// 2. The instruction has an active operand and an active result. @@ -1602,70 +1546,17 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) { [&](SILValue val) { return activityInfo.isActive(val, indices); }); if (hasActiveOperands && hasActiveResults) return true; - - // TODO(TF-800): Investigate why reverse-mode requires special "should - // differentiate logic" and update comment. - switch (kind) { - case AutoDiffLinearMapKind::Differential: { - -#define CHECK_INST_TYPE_ACTIVE_OPERANDS(TYPE) \ -if (isa(inst) && hasActiveOperands) \ - return true; - -#define CHECK_INST_TYPE_ACTIVE_DEST(TYPE) \ -if (auto *castInst = dyn_cast(inst)) { \ - return activityInfo.isActive(castInst->getDest(), indices); \ -} - - CHECK_INST_TYPE_ACTIVE_DEST(StoreInst) - CHECK_INST_TYPE_ACTIVE_DEST(StoreBorrowInst) - CHECK_INST_TYPE_ACTIVE_DEST(CopyAddrInst) - if ((isa(inst) && hasActiveResults)) - return true; - CHECK_INST_TYPE_ACTIVE_OPERANDS(RefCountingInst) - CHECK_INST_TYPE_ACTIVE_OPERANDS(EndAccessInst) - CHECK_INST_TYPE_ACTIVE_OPERANDS(EndBorrowInst) - CHECK_INST_TYPE_ACTIVE_OPERANDS(DeallocationInst) - CHECK_INST_TYPE_ACTIVE_OPERANDS(CopyValueInst) - CHECK_INST_TYPE_ACTIVE_OPERANDS(DestroyValueInst) - CHECK_INST_TYPE_ACTIVE_OPERANDS(DestroyAddrInst) - break; - -#undef CHECK_INST_TYPE_ACTIVE_OPERANDS -#undef CHECK_INST_TYPE_ACTIVE_DEST - } - case AutoDiffLinearMapKind::Pullback: { - if (inst->mayHaveSideEffects() && hasActiveOperands) - return true; - break; - } - } - + if (inst->mayHaveSideEffects() && hasActiveOperands) + return true; return false; } -/// Given an `apply` instruction, conditionally adds its linear map function to the +/// Takes an `apply` instruction and adds its linear map function to the /// linear map struct if it's active. void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai, const SILAutoDiffIndices &indices) { SmallVector allResults; - // TODO(TF-800): Investigate why `apply` result special logic does not work - // for reverse-mode. - // If differential, handle `apply` result specially. - // If `apply` result is tuple-typed with a `destructure_tuple` user, add the - // results of the `destructure_tuple` user to `allResults` instead of adding - // the `apply` result itself. - bool isDifferentialAndFoundDestructureTupleUser = false; - if (kind == AutoDiffLinearMapKind::Differential) { - if (auto *dti = getSingleDestructureTupleUser(ai)) { - isDifferentialAndFoundDestructureTupleUser = true; - for (auto result : dti->getResults()) - allResults.push_back(result); - } - } - // Otherwise, add `apply` result to `allResults`. - if (!isDifferentialAndFoundDestructureTupleUser) - allResults.push_back(ai); + allResults.push_back(ai); allResults.append(ai->getIndirectSILResults().begin(), ai->getIndirectSILResults().end()); @@ -1682,20 +1573,18 @@ void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai, if (!hasActiveResults || !hasActiveArguments) return; + unsigned source; + AutoDiffIndexSubset *parameters; SmallVector activeParamIndices; SmallVector activeResultIndices; collectMinimalIndicesForFunctionCall( ai, allResults, indices, activityInfo, activeParamIndices, activeResultIndices); + source = activeResultIndices.front(); - // Compute differentiation result index. - auto source = activeResultIndices.front(); - // Compute differentiation parameters. - // - If the callee has `@differentiable` function type, use differentiation - // parameters from the function type. - // - Otherwise, use the active parameters. - AutoDiffIndexSubset *parameters; + // If function is already marked differentiable, differentiate W.R.T. + // all parameters. auto originalFnSubstTy = ai->getSubstCalleeType(); if (originalFnSubstTy->isDifferentiable()) { parameters = originalFnSubstTy->getDifferentiationParameterIndices(); @@ -1705,22 +1594,25 @@ void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai, ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices); } - // Create autodiff indices for the `apply` instruction. - SILAutoDiffIndices applyIndices(source, parameters); + SILAutoDiffIndices curIndices(activeResultIndices.front(), + AutoDiffIndexSubset::get( + builder.getASTContext(), + ai->getArgumentsWithoutIndirectResults().size(), + activeParamIndices)); // Check for non-differentiable original function type. auto checkNondifferentiableOriginalFunctionType = [&](CanSILFunctionType origFnTy) { // Check and diagnose non-differentiable arguments. for (unsigned paramIndex : range(origFnTy->getNumParameters())) { - if (applyIndices.isWrtParameter(paramIndex) && + if (curIndices.isWrtParameter(paramIndex) && !origFnTy->getParameters()[paramIndex] .getSILStorageType() .isDifferentiable(builder.getModule())) return true; } // Check non-differentiable results. - if (!origFnTy->getResults()[applyIndices.source] + if (!origFnTy->getResults()[curIndices.source] .getSILStorageType() .isDifferentiable(builder.getModule())) return true; @@ -1729,10 +1621,8 @@ void LinearMapInfo::addLinearMapToStruct(ApplyInst *ai, if (checkNondifferentiableOriginalFunctionType(originalFnSubstTy)) return; - AutoDiffAssociatedFunctionKind assocFnKind(kind); auto assocFnType = originalFnSubstTy->getAutoDiffAssociatedFunctionType( - parameters, source, /*differentiationOrder*/ 1, assocFnKind, - builder.getModule(), + parameters, source, /*differentiationOrder*/ 1, kind, builder.getModule(), LookUpConformanceInModule(builder.getModule().getSwiftModule())); auto assocFnResultTypes = @@ -1826,13 +1716,12 @@ class DifferentiableActivityCollection { DominanceInfo *domInfo; PostDominanceInfo *postDomInfo; - DifferentiableActivityInfo &getActivityInfo( - GenericSignature *assocGenSig, AutoDiffAssociatedFunctionKind kind) { + DifferentiableActivityInfo &getActivityInfo(GenericSignature *assocGenSig) { auto activityInfoLookup = activityInfoMap.find(assocGenSig); if (activityInfoLookup != activityInfoMap.end()) return activityInfoLookup->getSecond(); auto insertion = activityInfoMap.insert( - {assocGenSig, DifferentiableActivityInfo(*this, assocGenSig, kind)}); + {assocGenSig, DifferentiableActivityInfo(*this, assocGenSig)}); return insertion.first->getSecond(); } @@ -1865,9 +1754,8 @@ DifferentiableActivityCollection::DifferentiableActivityCollection( : function(f), domInfo(di), postDomInfo(pdi) {} DifferentiableActivityInfo::DifferentiableActivityInfo( - DifferentiableActivityCollection &parent, GenericSignature *assocGenSig, - AutoDiffAssociatedFunctionKind kind) - : kind(kind), parent(parent), assocGenSig(assocGenSig) { + DifferentiableActivityCollection &parent, GenericSignature *assocGenSig) + : parent(parent), assocGenSig(assocGenSig) { analyze(parent.domInfo, parent.postDomInfo); } @@ -1912,24 +1800,8 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, if (isVaried(arg, i)) { for (auto indRes : ai->getIndirectSILResults()) setVaried(indRes, i); - // TODO(TF-800): Investigate why `apply` result special logic - // does not work for reverse-mode. - // If differential, handle `apply` result specially. - // If JVP, handle `apply` result specially. - // If `apply` result is tuple-typed with a `destructure_tuple` - // user, mark the results of the `destructure_tuple` user as - // varied instead of marking the `apply` result itself. - bool isJVPAndFoundDestructureTupleUser = false; - if (kind == swift::AutoDiffAssociatedFunctionKind::JVP) { - if (auto *dti = getSingleDestructureTupleUser(ai)) { - for (auto result : dti->getResults()) - setVaried(result, i); - isJVPAndFoundDestructureTupleUser = true; - } - } - // Otherwise, mark the `apply` result as varied. - if (!isJVPAndFoundDestructureTupleUser) - setVaried(ai, i); + for (auto dirRes : ai->getResults()) + setVaried(dirRes, i); } } } @@ -3328,8 +3200,7 @@ class VJPEmitter final passManager.getAnalysis(); auto &activityCollection = *activityAnalysis->get(original); auto &activityInfo = activityCollection.getActivityInfo( - vjp->getLoweredFunctionType()->getGenericSignature(), - AutoDiffAssociatedFunctionKind::VJP); + vjp->getLoweredFunctionType()->getGenericSignature()); LLVM_DEBUG( dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); return activityInfo; @@ -3343,7 +3214,7 @@ class VJPEmitter final context(context), original(original), attr(attr), vjp(vjp), invoker(invoker), activityInfo(getActivityInfo( context, original, attr->getIndices(), vjp)), - pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, + pullbackInfo(context, AutoDiffAssociatedFunctionKind::VJP, original, vjp, attr->getIndices(), activityInfo, getBuilder()) { // Create empty pullback function. pullback = createEmptyPullback(); @@ -3747,15 +3618,7 @@ class VJPEmitter final // Get the parameter indices required for differentiating this function. SmallVector allResults; - // Only append the results from the `destruct_tuple` instruction which are - // active, we don't consider the result of the original apply if it's a - // tuple. - if (auto *dti = getSingleDestructureTupleUser(ai)) { - for (auto result : dti->getResults()) - allResults.push_back(result); - } else { - allResults.push_back(ai); - } + allResults.push_back(ai); allResults.append(ai->getIndirectSILResults().begin(), ai->getIndirectSILResults().end()); SmallVector activeParamIndices; @@ -3779,12 +3642,11 @@ class VJPEmitter final errorOccurred = true; return; } - - // Form expected indices, assuming there's only one result. - SILAutoDiffIndices indices( - activeResultIndices.front(), + // Form expected indices by assuming there's only one result. + SILAutoDiffIndices indices(activeResultIndices.front(), AutoDiffIndexSubset::get( - getASTContext(), ai->getArgumentsWithoutIndirectResults().size(), + getASTContext(), + ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices)); // Emit the VJP. @@ -3843,7 +3705,7 @@ class VJPEmitter final // If VJP has not yet been found, emit an `autodiff_function` instruction // on the remapped original function operand and `autodiff_function_extract` - // the VJP. The actual VJP functions will be populated in the + // the VJP. The actual JVP/VJP functions will be populated in the // `autodiff_function` during the transform main loop. if (!vjpValue) { // FIXME: Handle indirect differentiation invokers. This may require some @@ -4190,8 +4052,8 @@ class JVPEmitter final /// Mapping from differential struct field declarations to differential struct /// elements destructured from the linear map basic block argument. In the - /// beginning of each differential basic block, the block's differential - /// struct is destructured into the individual elements stored here. + /// beginning of each differential basic block, the block's differential struct is + /// destructured into individual elements stored here. DenseMap differentialStructElements; /// Mapping from original basic blocks and original values to corresponding @@ -4239,12 +4101,8 @@ class JVPEmitter final static SubstitutionMap getSubstitutionMap(SILFunction *original, SILFunction *jvp) { auto substMap = original->getForwardingSubstitutionMap(); - if (auto *jvpGenEnv = jvp->getGenericEnvironment()) { - auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap(); - substMap = SubstitutionMap::get( - jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpSubstMap}, - LookUpConformanceInSubstitutionMap(jvpSubstMap)); - } + if (auto *jvpGenEnv = jvp->getGenericEnvironment()) + substMap = substMap.subst(jvpGenEnv->getForwardingSubstitutionMap()); return substMap; } @@ -4258,16 +4116,14 @@ class JVPEmitter final passManager.getAnalysis(); auto &activityCollection = *activityAnalysis->get(original); auto &activityInfo = activityCollection.getActivityInfo( - jvp->getLoweredFunctionType()->getGenericSignature(), - AutoDiffAssociatedFunctionKind::JVP); + jvp->getLoweredFunctionType()->getGenericSignature()); LLVM_DEBUG( dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); return activityInfo; } static SILBuilder - initializeDifferentialAndBuilder(ADContext &context, SILFunction *original, - SILDifferentiableAttr *attr, + initializeDifferentialAndBuilder(ADContext &context, SILFunction *original, SILDifferentiableAttr *attr, LinearMapInfo *linearMapInfo) { auto *differential = createEmptyDifferential(context, original, attr, linearMapInfo); @@ -4291,8 +4147,7 @@ class JVPEmitter final auto insertion = differentialStructElements.insert({std::get<0>(pair), std::get<1>(pair)}); (void)insertion; - assert(insertion.second && - "A differential struct element mapping already exists!"); + assert(insertion.second && "A differential struct element already exists!"); } } @@ -4355,8 +4210,7 @@ class JVPEmitter final //--------------------------------------------------------------------------// AdjointValue makeZeroTangentValue(SILType type) { - return AdjointValue::createZero( - allocator, remapSILTypeInDifferential(type)); + return AdjointValue::createZero(allocator, remapType(type)); } AdjointValue makeConcreteTangentValue(SILValue value) { @@ -4442,8 +4296,7 @@ class JVPEmitter final assert(originalBuffer->getType().isAddress()); auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); - assert(insertion.second && "tangent buffer already exists."); - (void)insertion; + assert(insertion.second); (void)insertion; } SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { @@ -4456,52 +4309,30 @@ class JVPEmitter final } //--------------------------------------------------------------------------// - // Differential type calculations + // Type transformer //--------------------------------------------------------------------------// - /// Substitutes all replacement types of the given substitution map using the - /// tangent function's substitution map. - SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap) { - return substMap.subst(getDifferential().getForwardingSubstitutionMap()); - } - - /// Remap any archetypes into the differential function's context. - Type remapTypeInDifferential(Type ty) { - if (ty->hasArchetype()) - return getDifferential().mapTypeIntoContext(ty->mapTypeOutOfContext()); - return getDifferential().mapTypeIntoContext(ty); - } - - /// Remap any archetypes into the differential function's context. - SILType remapSILTypeInDifferential(SILType ty) { - if (ty.hasArchetype()) - return getDifferential().mapTypeIntoContext(ty.mapTypeOutOfContext()); - return getDifferential().mapTypeIntoContext(ty); - } - - /// Find the tangent space of a given canonical type. Optional getTangentSpace(CanType type) { return type->getAutoDiffAssociatedTangentSpace( LookUpConformanceInModule(getModule().getSwiftModule())); } /// Assuming the given type conforms to `Differentiable` after remapping, - /// returns the associated tangent space SIL type. + /// returns the associated tangent space type. SILType getRemappedTangentType(SILType type) { return SILType::getPrimitiveType( - getTangentSpace(remapSILTypeInDifferential(type).getASTType()) - ->getCanonicalType(), + getTangentSpace(remapType(type).getASTType())->getCanonicalType(), type.getCategory()); } //--------------------------------------------------------------------------// - // Tangent value mapping + // Tngent value mapping //--------------------------------------------------------------------------// /// Get the tangent for an original value. The given value must be in the /// original function. /// - /// This method first tries to find an entry in `tangentValueMap`. If an entry + /// This method first tries to find an entry in `tangentValueMap`. If a tangent /// doesn't exist, create a zero tangent. AdjointValue getTangentValue(SILValue originalValue) { assert(originalValue->getType().isObject()); @@ -4515,14 +4346,6 @@ class JVPEmitter final /// Map the tangent value to the given original value. void setTangentValue(SILBasicBlock *origBB, SILValue originalValue, AdjointValue newTangentValue) { - if (auto *defInst = originalValue->getDefiningInstruction()) { - bool isTupleTypedApplyResult = - isa(defInst) && originalValue->getType().is(); - assert(!isTupleTypedApplyResult && - "Should not set tangent value for tuple-typed result from `apply` " - "instruction; use `destructure_tuple` on `apply` result and set " - "tangent value for `destructure_tuple` results instead."); - } assert(originalValue->getType().isObject()); assert(newTangentValue.getType().isObject()); assert(originalValue->getFunction() == original); @@ -4539,16 +4362,15 @@ class JVPEmitter final //--------------------------------------------------------------------------// // Tangent emission helpers //--------------------------------------------------------------------------// -public: -#define CLONE_AND_EMIT_TANGENT(INST, ID) \ - void visit##INST##Inst(INST##Inst *inst) { \ - TypeSubstCloner::visit##INST##Inst(inst); \ - if (differentialInfo.shouldDifferentiateInstruction(inst)) \ - emitTangentFor##INST##Inst(inst); \ - } \ - void emitTangentFor##INST##Inst(INST##Inst *(ID)) - CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) { + void emitTangentForDestroyValueInst(DestroyValueInst *dvi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = dvi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); + diffBuilder.emitDestroyValue(loc, tanVal); + } + + void emitTangentForBeginBorrow(BeginBorrowInst *bbi) { auto &diffBuilder = getDifferentialBuilder(); auto loc = bbi->getLoc(); auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); @@ -4557,21 +4379,14 @@ class JVPEmitter final makeConcreteTangentValue(tanValBorrow)); } - CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) { + void emitTangentForEndBorrow(EndBorrowInst *ebi) { auto &diffBuilder = getDifferentialBuilder(); auto loc = ebi->getLoc(); auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc); diffBuilder.emitEndBorrowOperation(loc, tanVal); } - CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = dvi->getLoc(); - auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); - diffBuilder.emitDestroyValue(loc, tanVal); - } - - CLONE_AND_EMIT_TANGENT(CopyValue, cvi) { + void emitTangentForCopyValueInst(CopyValueInst *cvi) { auto &diffBuilder = getDifferentialBuilder(); auto tan = getTangentValue(cvi->getOperand()); auto tanVal = materializeTangent(tan, cvi->getLoc()); @@ -4580,437 +4395,44 @@ class JVPEmitter final makeConcreteTangentValue(tanValCopy)); } - /// Handle `struct_extract` instruction. - /// Original: y = struct_extract x, #field - /// Tangent: tan[y] = struct_extract tan[x], tan[#field]] - CLONE_AND_EMIT_TANGENT(StructExtract, sei) { - assert(!sei->getField()->getAttrs().hasAttribute() && - "`struct_extract` with `@noDerivative` field should not be " - "differentiated; activity analysis should not marked as varied."); - - auto diffBuilder = getDifferentialBuilder();; - auto tangentVectorTy = - getRemappedTangentType(sei->getOperand()->getType()); - auto *tangentVectorDecl = - tangentVectorTy.getStructOrBoundGenericStruct(); - - // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - // If the tangent space is the original struct, then field is the same. - if (tangentVectorDecl == sei->getStructDecl()) - tanField = sei->getField(); - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(sei->getField()->getName()); - if (tanFieldLookup.empty()) { - context.emitNondifferentiabilityError( - sei, invoker, - diag::autodiff_stored_property_no_corresponding_tangent, - sei->getStructDecl()->getNameStr(), - sei->getField()->getNameStr()); - errorOccurred = true; - return; - } - tanField = cast(tanFieldLookup.front()); - } - // Emit tangent `struct_extract`. - auto tanStruct = - materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc()); - auto tangentInst = - diffBuilder.createStructExtract(sei->getLoc(), tanStruct, tanField); - // Update tangent value mapping for `struct_extract` result. - auto tangentResult = makeConcreteTangentValue(tangentInst); - setTangentValue(sei->getParent(), sei, tangentResult); - } - - /// Handle `struct_element_addr` instruction. - /// Original: y = struct_element_addr x, #field - /// Tangent: tan[y] = struct_element_addr tan[x], tan[#field] - CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) { - assert(!seai->getField()->getAttrs().hasAttribute() && - "`struct_element_addr` with `@noDerivative` field should not be " - "differentiated; activity analysis should not marked as varied."); - - auto diffBuilder = getDifferentialBuilder(); - auto *bb = seai->getParent(); - auto tangentVectorTy = - getRemappedTangentType(seai->getOperand()->getType()); - auto *tangentVectorDecl = - tangentVectorTy.getStructOrBoundGenericStruct(); - - // Find the corresponding field in the tangent space. - VarDecl *tanField = nullptr; - // If the tangent space is the original struct, then field is the same. - if (tangentVectorDecl == seai->getStructDecl()) - tanField = seai->getField(); - // Otherwise, look up the field by name. - else { - auto tanFieldLookup = - tangentVectorDecl->lookupDirect(seai->getField()->getName()); - if (tanFieldLookup.empty()) { - context.emitNondifferentiabilityError( - seai, invoker, - diag::autodiff_stored_property_no_corresponding_tangent, - seai->getStructDecl()->getNameStr(), - seai->getField()->getNameStr()); - errorOccurred = true; - return; - } - tanField = cast(tanFieldLookup.front()); - } - - // Emit tangent `struct_element_addr`. - auto tanOperand = getTangentBuffer(bb, seai->getOperand()); - auto tangentInst = diffBuilder.createStructElementAddr( - seai->getLoc(), tanOperand, tanField); - // Update tangent buffer map for `struct_element_addr`. - setTangentBuffer(bb, seai, tangentInst); - } - - /// Handle `load` instruction. - /// Original: y = load x - /// Tangent: tan[y] = load tan[x] - CLONE_AND_EMIT_TANGENT(Load, li) { - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = li->getParent(); - auto loc = li->getLoc(); - auto tanBuf = getTangentBuffer(bb, li->getOperand()); - auto tanVal = diffBuilder.emitLoadValueOperation( - loc, tanBuf, li->getOwnershipQualifier()); - setTangentValue(bb, li, makeConcreteTangentValue(tanVal)); - } - - /// Handle `load_borrow` instruction. - /// Original: y = load_borrow x - /// Tangent: tan[y] = load_borrow tan[x] - CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) { - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = lbi->getParent(); - auto loc = lbi->getLoc(); - auto tanBuf = getTangentBuffer(bb, lbi->getOperand()); - auto tanVal = diffBuilder.emitLoadBorrowOperation( - loc, tanBuf); - setTangentValue(bb, lbi, makeConcreteTangentValue(tanVal)); - } - - /// Handle `store` instruction in the differential. - /// Original: store x to y - /// Tangent: store tan[x] to tan[y] - CLONE_AND_EMIT_TANGENT(Store, si) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = si->getLoc(); - auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), loc); - auto &tanValDest = getTangentBuffer(si->getParent(), si->getDest()); - if (errorOccurred) - return; - diffBuilder.emitStoreValueOperation( - loc, tanValSrc, tanValDest, si->getOwnershipQualifier()); - } - - /// Handle `store_borrow` instruction in the differential. - /// Original: store_borrow x to y - /// Tangent: store_borrow tan[x] to tan[y] - CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = sbi->getLoc(); - auto tanValSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc); - auto &tanValDest = getTangentBuffer(sbi->getParent(), sbi->getDest()); - if (errorOccurred) - return; - diffBuilder.createStoreBorrow(loc, tanValSrc, tanValDest); - } - - /// Handle `copy_addr` instruction. - /// Original: copy_addr x to y - /// Tangent: copy_addr tan[x] to tan[y] - CLONE_AND_EMIT_TANGENT(CopyAddr, cai) { - auto *diffGenEnv = getDifferential().getGenericEnvironment(); - auto diffGenSig = diffGenEnv - ? diffGenEnv->getGenericSignature()->getCanonicalSignature() - : nullptr; - Lowering::GenericContextScope genericContextScope( - context.getTypeConverter(), diffGenSig); - - auto diffBuilder = getDifferentialBuilder(); - auto loc = cai->getLoc(); - auto *bb = cai->getParent(); - auto &tanSrc = getTangentBuffer(bb, cai->getSrc()); - auto tanDest = getTangentBuffer(bb, cai->getDest()); - if (errorOccurred) - return; - - diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(), - cai->isInitializationOfDest()); - } - - /// Handle `begin_access` instruction (and do differentiability checks). - /// Original: y = begin_access x - /// Tangent: tan[y] = begin_access tan[x] - CLONE_AND_EMIT_TANGENT(BeginAccess, bai) { - // Check for non-differentiable writes. - if (bai->getAccessKind() == SILAccessKind::Modify) { - if (auto *gai = dyn_cast(bai->getSource())) { - context.emitNondifferentiabilityError(bai, invoker, - diag::autodiff_cannot_differentiate_writes_to_global_variables); - errorOccurred = true; - return; - } - if (auto *pbi = dyn_cast(bai->getSource())) { - context.emitNondifferentiabilityError(bai, invoker, - diag::autodiff_cannot_differentiate_writes_to_mutable_captures); - errorOccurred = true; - return; - } - } - - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = bai->getParent(); - - auto tanSrc = getTangentBuffer(bb, bai->getSource()); - auto *tanDest = diffBuilder.createBeginAccess( - bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(), - bai->hasNoNestedConflict(), bai->isFromBuiltin()); - setTangentBuffer(bb, bai, tanDest); - } - - /// Handle `end_access` instruction. - /// Original: begin_access x - /// Tangent: end_access tan[x] - CLONE_AND_EMIT_TANGENT(EndAccess, eai) { - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = eai->getParent(); - auto loc = eai->getLoc(); - auto tanSrc = getTangentBuffer(bb, eai->getOperand()); - diffBuilder.createEndAccess(loc, tanSrc, eai->isAborting()); - } - - /// Handle `alloc_stack` instruction. - /// Original: y = alloc_stack $T - /// Tangent: tan[y] = alloc_stack $T.Tangent - CLONE_AND_EMIT_TANGENT(AllocStack, asi) { - auto &diffBuilder = getDifferentialBuilder(); - auto *mappedAllocStackInst = diffBuilder.createAllocStack( - asi->getLoc(), getRemappedTangentType(asi->getElementType())); - bufferMap.try_emplace({asi->getParent(), asi}, - mappedAllocStackInst); - } - - /// Handle `dealloc_stack` instruction. - /// Original: dealloc_stack x - /// Tangent: dealloc_stack tan[x] - CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) { - auto &diffBuilder = getDifferentialBuilder(); - auto tanBuf = getTangentBuffer(dsi->getParent(), dsi->getOperand()); - diffBuilder.createDeallocStack(dsi->getLoc(), tanBuf); - } - - /// Handle `destroy_addr` instruction. - /// Original: destroy_addr x - /// Tangent: destroy_addr tan[x] - CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) { - auto &diffBuilder = getDifferentialBuilder(); - auto tanBuf = getTangentBuffer(dai->getParent(), dai->getOperand()); - diffBuilder.createDestroyAddr(dai->getLoc(), tanBuf); - } - - /// Handle `struct` instruction. - /// Original: y = struct $T (x0, x1, x2, ...) - /// Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...) - CLONE_AND_EMIT_TANGENT(Struct, si) { - auto &diffBuilder = getDifferentialBuilder(); - SmallVector tangentElements; - for (auto elem : si->getElements()) - tangentElements.push_back(getTangentValue(elem).getConcreteValue()); - auto tanExtract = diffBuilder.createStruct( - si->getLoc(), getRemappedTangentType(si->getType()), tangentElements); - setTangentValue(si->getParent(), si, makeConcreteTangentValue(tanExtract)); - } - - /// Handle `tuple` instruction. - /// Original: y = tuple (x0, x1, x2, ...) - /// Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...) - CLONE_AND_EMIT_TANGENT(Tuple, ti) { + void emitTangentForReturnInst(ReturnInst *ri) { + auto loc = ri->getOperand().getLoc(); auto diffBuilder = getDifferentialBuilder(); - - // Get the tangents of all the tuple elements. - SmallVector tangentTupleElements; - for (auto elem : ti->getElements()) { - tangentTupleElements.push_back( - materializeTangent(getTangentValue(elem), ti->getLoc())); - } - - // Emit the instruction and add the tangent mapping. - auto tanTuple = diffBuilder.createTuple(ti->getLoc(), tangentTupleElements); - setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple)); - } - - /// Handle `tuple_extract` instruction. - /// Original: y = tuple_element_addr x, - /// Tangent: tan[y] = tuple_element_addr tan[x], - CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) { - auto &diffBuilder = getDifferentialBuilder(); - auto origTupleTy = teai->getOperand()->getType().castTo(); - unsigned tanIndex = 0; - for (unsigned i : range(teai->getFieldNo())) { - if (getTangentSpace( - origTupleTy->getElement(i).getType()->getCanonicalType())) - ++tanIndex; - } - auto tanType = getRemappedTangentType(teai->getType()); - auto tanSource = getTangentBuffer(teai->getParent(), teai->getOperand()); - SILValue tanBuf; - // If the tangent buffer of the source does not have a tuple type, then - // it must represent a "single element tuple type". Use it directly. - if (!tanSource->getType().is()) { - tanBuf = tanSource; - } else { - tanBuf = diffBuilder.createTupleElementAddr( - teai->getLoc(), tanSource, tanIndex, tanType); - } - bufferMap.try_emplace({teai->getParent(), teai}, tanBuf); - } - - /// Handle `tuple_extract` instruction. - /// Original: y = tuple_extract x, - /// Tangent: tan[y] = tuple_extract tan[x], - CLONE_AND_EMIT_TANGENT(TupleExtract, tei) { - auto &diffBuilder = getDifferentialBuilder(); - auto loc = tei->getLoc(); - auto origTupleTy = tei->getOperand()->getType().castTo(); - unsigned tanIndex = 0; - for (unsigned i : range(tei->getFieldNo())) { - if (getTangentSpace( - origTupleTy->getElement(i).getType()->getCanonicalType())) - ++tanIndex; - } - auto tanType = getRemappedTangentType(tei->getType()); - auto tanSource = materializeTangent( - getTangentValue(tei->getOperand()), loc); - SILValue tanBuf; - // If the tangent buffer of the source does not have a tuple type, then - // it must represent a "single element tuple type". Use it directly. - if (!tanSource->getType().is()) { - setTangentValue(tei->getParent(), tei, - makeConcreteTangentValue(tanSource)); - } else { - tanBuf = diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType); - bufferMap.try_emplace({tei->getParent(), tei}, tanBuf); - } - } - - /// Handle `destructure_tuple` instruction. - /// Original: (y0, y1, y2, ...) = destructure_tuple x, - /// Tangent: (tan[y0], tan[y1], tan[y2], ...) = destructure_tuple tan[x], - CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) { - auto &diffBuilder = getDifferentialBuilder(); - auto *bb = dti->getParent(); - auto loc = dti->getLoc(); - - SmallVector activeOrigResults; - bool hasActiveResult = false; - for (auto result : dti->getResults()) { - if (activityInfo.isActive(result, getIndices())) { - activeOrigResults.push_back(result); - hasActiveResult = true; - break; - } - } - assert(!activeOrigResults.empty() && - "original 'destructure_tuple' should have at least one active " - "result"); - - auto tanTuple = - materializeTangent(getTangentValue(dti->getOperand()), loc); - auto *tupleElements = diffBuilder.createDestructureTuple(loc, tanTuple); - for (auto i : range(tupleElements->getNumResults())) { - auto origElem = dti->getResult(i); - auto tanElem = tupleElements->getResult(i); - setTangentValue(bb, origElem, makeConcreteTangentValue(tanElem)); - } + // This vector will contain all the materialized return elements. + SmallVector retElts; + // This vector will contain all indirect parameter tangent buffers. + // TODO: Handle indirect results. + auto tanParam = + materializeTangent(getTangentValue(ri->getOperand()), loc); + diffBuilder.createReturn(ri->getLoc(), tanParam); } -#undef CLONE_AND_EMIT_TANGENT - - /// Handle `apply` instruction. - /// Original: y = apply f(x) - /// Tangent: tan[y] = apply diff_f(tan[x]) void emitTangentForApplyInst(ApplyInst *ai, - const SILAutoDiffIndices &actualIndices, - CanSILFunctionType originalDifferentialType) { + SILAutoDiffIndices &actualIndices) { assert(differentialInfo.shouldDifferentiateApplyInst(ai)); auto *bb = ai->getParent(); auto loc = ai->getLoc(); - auto &diffBuilder = getDifferentialBuilder(); + auto diffBuilder = getDifferentialBuilder(); - // Get the differential value. + // Get the differential. auto *field = differentialInfo.lookUpLinearMapDecl(ai); assert(field); SILValue differential = getDifferentialStructElement(bb, field); - auto differentialType = remapSILTypeInDifferential(differential->getType()) - .castTo(); - // Get the differential arguments. SmallVector diffArgs; - - for (auto indRes : ai->getIndirectSILResults()) - diffArgs.push_back(getTangentBuffer(bb, indRes)); - - auto paramArgs = ai->getArgumentsWithoutIndirectResults(); - // Get the tangent value of the original arguments. - for (auto i : indices(paramArgs)) { - auto origArg = paramArgs[i]; - // If the argument is not active: - // - Skip the element, if it is not differentiable. - // - Otherwise, add a zero value to that location. - if (!activityInfo.isActive(origArg, getIndices())) { - auto origCalleeType = ai->getSubstCalleeType(); - if (!origCalleeType->isDifferentiable()) - continue; - auto actualOrigCalleeIndices = - origCalleeType->getDifferentiationParameterIndices(); - if (actualOrigCalleeIndices->contains(i)) { - SILValue tanParam; - if (origArg->getType().isObject()) { - tanParam = emitZeroDirect( - getRemappedTangentType(origArg->getType()).getASTType(), loc); - diffArgs.push_back(tanParam); - } else { - tanParam = diffBuilder.createAllocStack( - loc, getRemappedTangentType(origArg->getType())); - emitZeroIndirect( - getRemappedTangentType(origArg->getType()).getASTType(), tanParam, - loc); - } - } - } - // Otherwise, if the argument is active, handle the argument normally by - // getting its tangent value. - else { - SILValue tanParam; - if (origArg->getType().isObject()) { - tanParam = materializeTangent(getTangentValue(origArg), loc); - } else { - tanParam = getTangentBuffer(ai->getParent(), origArg); - } - diffArgs.push_back(tanParam); + for (auto origArg : ai->getArguments()) { + // Get the tangent value of the original parameter. + if (!activityInfo.isActive(origArg, getIndices())) + continue; + SILValue tanParam; + if (origArg->getType().isObject()) { + tanParam = materializeTangent(getTangentValue(origArg), loc); + } else { + tanParam = getTangentBuffer(ai->getParent(), origArg); if (errorOccurred) return; } - } - - // If callee differential was reabstracted in JVP, reabstract the callee - // differential. - if (!differentialType->isEqual(originalDifferentialType)) { - SILOptFunctionBuilder fb(context.getTransform()); - auto *thunk = getOrCreateReabstractionThunk( - fb, context.getModule(), loc, &getDifferential(), - differentialType, originalDifferentialType); - auto *thunkRef = diffBuilder.createFunctionRef(loc, thunk); - differential = diffBuilder.createPartialApply( - loc, thunkRef, - remapSubstitutionMapInDifferential(thunk->getForwardingSubstitutionMap()), - {differential}, differentialType->getCalleeConvention()); + diffArgs.push_back(tanParam); } // Call the differential. @@ -5021,198 +4443,23 @@ class JVPEmitter final assert(differentialCall->getNumResults() == 1 && "Expected differential to return one result"); - // Get the original results of the `apply` instructions. - SmallVector origDirResults; - collectAllExtractedElements(ai, origDirResults); - SmallVector origAllResults; - collectAllActualResultsInTypeOrder( - ai, origDirResults, ai->getIndirectSILResults(), origAllResults); - auto origResult = origAllResults[actualIndices.source]; + // TODO: Generalize for indirect results, multiple results, etc. + auto origResult = ai->getResult(actualIndices.source); - // Get the differential results of the `apply` instructions. + // Extract all direct results from the differential. SmallVector differentialDirResults; - collectAllExtractedElements(differentialCall, differentialDirResults); + extractAllElements(differentialCall, diffBuilder, differentialDirResults); + // Get all differential results in type-defined order. SmallVector differentialAllResults; collectAllActualResultsInTypeOrder( differentialCall, differentialDirResults, differentialCall->getIndirectSILResults(), differentialAllResults); - auto differentialResult = differentialAllResults.front(); + auto differentialResult = differentialAllResults[actualIndices.source]; // Add tangent for original result. - if (origResult->getType().isObject()) { - if (!origResult->getType().is()) { - setTangentValue(bb, origResult, - makeConcreteTangentValue(differentialResult)); - } else if (auto *dti = getSingleDestructureTupleUser(ai)) { - bool notSetValue = true; - for (auto result : dti->getResults()) { - if (activityInfo.isActive(result, getIndices())) { - assert(notSetValue && - "This was incorrectly set, should only have one active " - "result from the tuple."); - notSetValue = false; - setTangentValue(bb, result, - makeConcreteTangentValue(differentialResult)); - } - } - } - } - } - - /// Generate a `return` instruction in the current differential basic block. - void emitReturnInstForDifferential() { - auto &differential = getDifferential(); - auto diffLoc = differential.getLocation(); - auto &diffBuilder = getDifferentialBuilder(); - - SmallVector activeResults; - - // This vector will contain all the materialized return elements. - SmallVector retElts; - SmallVector originalResults; - collectAllDirectResultsInTypeOrder(*original, originalResults); - - // Materializes the return element corresponding to the result - // `resultIndex` into the `retElts` vector. - auto addActiveResult = [&](unsigned resultIndex) -> void { - auto origResult = originalResults[resultIndex]; - assert(origResult->getType().isObject() && - "Should only be handling direct results for 'return' " - "instruction."); - if (activityInfo.isActive(origResult, getIndices())) { - activeResults.push_back(origResult); - } - }; - // Create an array of the direct tangent values of the original results. - for (auto i : range(originalResults.size())) - addActiveResult(i); - assert(activeResults.size() <= 1); - - if (activeResults.empty() && !originalResults.empty()) { - // Create zero tangent value for direct result. - auto origResult = originalResults[getIndices().source]; - assert(origResult->getType().isObject() && - "Should only be handling direct results for 'return' " - "instruction."); - auto zeroType = origResult->getType().getASTType(); - auto zero = - emitZeroDirect(getTangentSpace(zeroType)->getCanonicalType(), - diffLoc); - retElts.push_back(zero); - } else if (!activeResults.empty()) { - auto diffVal = getTangentValue(activeResults.front()); - auto val = materializeTangent(diffVal, diffLoc); - retElts.push_back(val); - } - - diffBuilder.createReturn( - diffLoc, joinElements(retElts, diffBuilder, diffLoc)); - } - -private: - - /// Set up the differential function. This includes: - /// - Creating all differential blocks. - /// - Creating differential entry block arguments based on the function type. - /// - Creating tangent value mapping for original/differential parameters. - /// - Checking for unvaried result and emitting related warnings. - void prepareForDifferentialGeneration() { - // Create differential blocks and arguments. - auto *diffGenEnv = getDifferential().getGenericEnvironment(); - auto diffGenSig = diffGenEnv - ? diffGenEnv->getGenericSignature()->getCanonicalSignature() - : nullptr; - auto &differential = getDifferential(); - auto *origEntry = original->getEntryBlock(); - for (auto &origBB : *original) { - auto *diffBB = differential.createBasicBlock(); - diffBBMap.insert({&origBB, diffBB}); - { - Lowering::GenericContextScope genericContextScope( - context.getTypeConverter(), diffGenSig); - auto diffStructLoweredType = remapSILTypeInDifferential( - differentialInfo.getLinearMapStructLoweredType(&origBB)); - - // If the BB is the original entry, then the differential block that we - // just created must be the differential function's entry. Create - // differential entry arguments and continue. - if (&origBB == origEntry) { - assert(diffBB->isEntry()); - createEntryArguments(&differential); - auto *lastArg = diffBB->getArguments().back(); - assert(lastArg->getType() == diffStructLoweredType); - differentialStructArguments[&origBB] = lastArg; - } - } - - LLVM_DEBUG({ - auto &s = getADDebugStream() - << "Original bb" + std::to_string(origBB.getDebugID()) - << ": To differentiate or not to differentiate?\n"; - for (auto &inst : origBB) { - s << (differentialInfo.shouldDifferentiateInstruction(&inst) - ? "[∂] " : "[ ] ") - << inst; - } - }); - } - - assert(diffBBMap.size() == 1 && - "Can only currently handle single basic block functions"); - - // The differential function has type: - // (arg0', ..., argn', entry_df_struct) -> result'. - auto diffParamArgs = - differential.getArgumentsWithoutIndirectResults().drop_back(); - assert(diffParamArgs.size() == - attr->getIndices().parameters->getNumIndices()); - auto origParamArgs = original->getArgumentsWithoutIndirectResults(); - - // TODO(TF-788): Re-enable non-varied result warning. - /* - // Check if result is not varied. - SmallVector origFormalResults; - collectAllFormalResultsInTypeOrder(*original, origFormalResults); - auto origResult = origFormalResults[getIndices().source]; - // Emit warning if original result is not varied, because it will always - // have a zero derivative. - if (!activityInfo.isVaried(origResult, getIndices().parameters)) { - // Emit fixit if original result has a valid source location. - auto startLoc = origResult.getLoc().getStartSourceLoc(); - auto endLoc = origResult.getLoc().getEndSourceLoc(); - if (startLoc.isValid() && endLoc.isValid()) { - context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) - .fixItInsert(startLoc, "withoutDerivative(at:") - .fixItInsertAfter(endLoc, ")"); - } - } - */ - - // Create a mapping of the parameters. - auto autoDiffIndex = getIndices().parameters->begin(); - for (auto index : range(diffParamArgs.size())) { - auto *diffParam = diffParamArgs[index]; - auto *origParam = origParamArgs[*autoDiffIndex]; - autoDiffIndex++; - if (diffParam->getType().isAddress()) { - setTangentBuffer(origEntry, origParam, diffParam); - } else { - setTangentValue( - origEntry, origParam, makeConcreteTangentValue(diffParam)); - } - LLVM_DEBUG(getADDebugStream() - << "Assigned parameter " << *diffParam - << " as the tangent of original result " << *origParam); - } - - // If there are indirect results, create a mapping. - auto origIndResults = original->getIndirectResults(); - auto diffIndResults = differential.getIndirectResults(); - assert(origIndResults.size() == diffIndResults.size()); - - for (auto &origBB : *original) - for (auto i : indices(diffIndResults)) - setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]); + assert(actualIndices.source == 0 && "Expected result index to be first."); + setTangentValue(bb, origResult, + makeConcreteTangentValue(differentialResult)); } public: @@ -5223,7 +4470,7 @@ class JVPEmitter final context(context), original(original), attr(attr), jvp(jvp), invoker(invoker), activityInfo(getActivityInfo( context, original, attr->getIndices(), jvp)), - differentialInfo(context, AutoDiffLinearMapKind::Differential, original, + differentialInfo(context, AutoDiffAssociatedFunctionKind::JVP, original, jvp, attr->getIndices(), activityInfo, getBuilder()), differentialAndBuilder(initializeDifferentialAndBuilder( context, original, attr, &differentialInfo)), @@ -5313,6 +4560,91 @@ class JVPEmitter final return differential; } + /// Set up the differential function. This includes: + /// - Creating all the differential blocks. + /// - Create arguments for the entry block according to the function type. + /// - Adding the tangent values of the parameters to the tangent value map. + /// - Checking for unvaried result and emitting related warnings. + void prepareForDifferentialGeneration() { + auto &diffBuilder = getDifferentialBuilder(); + + // Create differential blocks and arguments. + // TODO: Consider visiting original blocks in pre-order (dominance) order. + auto &differential = getDifferential(); + auto *origEntry = original->getEntryBlock(); + for (auto &origBB : *original) { + auto *diffBB = differential.createBasicBlock(); + diffBBMap.insert({&origBB, diffBB}); + auto diffStructLoweredType = + remapType(differentialInfo.getLinearMapStructLoweredType(&origBB)); + // If the BB is the original entry, then the differential block that we + // just created must be the differential function's entry. Create + // differential entry arguments and continue. + if (&origBB == origEntry) { + assert(diffBB->isEntry()); + createEntryArguments(&differential); + auto *mainDifferentialStruct = diffBB->getArguments().back(); + assert(mainDifferentialStruct->getType() == diffStructLoweredType); + differentialStructArguments[&origBB] = mainDifferentialStruct; + } + + LLVM_DEBUG({ + auto &s = getADDebugStream() + << "Original bb" + std::to_string(origBB.getDebugID()) + << ": To differentiate or not to differentiate?\n"; + for (auto &inst : origBB) { + s << (differentialInfo.shouldDifferentiateInstruction(&inst) + ? "[∂] " : "[ ] ") + << inst; + } + }); + } + + assert(diffBBMap.size() == 1 && + "Can only currently handle single basic block functions"); + + // The differential function has type: + // (arg0', ..., argn', entry_df_struct) -> result'. + auto diffParamArgs = + differential.getArgumentsWithoutIndirectResults().drop_back(); + assert(diffParamArgs.size() == + attr->getIndices().parameters->getNumIndices()); + auto origParamArgs = original->getArgumentsWithoutIndirectResults(); + + // TODO(TF-788): Re-enable non-varied result warning. + /* + // Emit a warning and fixit if original result is not varied, because it + // will always have a zero derivative. + SmallVector origFormalResults; + collectAllFormalResultsInTypeOrder(*original, origFormalResults); + auto origResult = origFormalResults[getIndices().source]; + if (!activityInfo.isVaried(origResult, getIndices().parameters)) { + // Emit fixit if original result has a valid source location. + auto startLoc = origResult.getLoc().getStartSourceLoc(); + auto endLoc = origResult.getLoc().getEndSourceLoc(); + if (startLoc.isValid() && endLoc.isValid()) { + context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) + .fixItInsert(startLoc, "withoutDerivative(at:") + .fixItInsertAfter(endLoc, ")"); + } + } + */ + + auto *diffEntry = getDifferential().getEntryBlock(); + diffBuilder.setInsertionPoint( + diffEntry, getNextDifferentialLocalAllocationInsertionPoint()); + + for (auto index : *getIndices().parameters) { + auto diffParam = diffParamArgs[index]; + auto origParam = origParamArgs[index]; + setTangentValue(origEntry, origParam, + makeConcreteTangentValue(diffParam)); + LLVM_DEBUG(getADDebugStream() + << "Assigned parameter " << *diffParam + << " as the tangent of original result " << origParam); + } + } + /// Run JVP generation. Returns true on error. bool run() { LLVM_DEBUG(getADDebugStream() @@ -5326,7 +4658,6 @@ class JVPEmitter final SmallVector entryArgs(entry->getArguments().begin(), entry->getArguments().end()); cloneFunctionBody(original, entry, entryArgs); - emitReturnInstForDifferential(); // If errors occurred, back out. if (errorOccurred) return true; @@ -5352,25 +4683,9 @@ class JVPEmitter final /// General visitor for all instructions. If any error is emitted by previous /// visits, bail out. void visit(SILInstruction *inst) { - auto diffBuilder = getDifferentialBuilder(); if (errorOccurred) return; - if (differentialInfo.shouldDifferentiateInstruction(inst)) { - LLVM_DEBUG(getADDebugStream() << "JVPEmitter visited:\n[ORIG]" - << *inst); -#ifndef NDEBUG - auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); -#endif - TypeSubstCloner::visit(inst); - LLVM_DEBUG({ - auto &s = llvm::dbgs() << "[DF] Emitted in Differential:\n"; - auto afterInsertion = diffBuilder.getInsertionPoint(); - for (auto it = ++beforeInsertion; it != afterInsertion; ++it) - s << *it; - }); - } else { - TypeSubstCloner::visit(inst); - } + TypeSubstCloner::visit(inst); } void visitSILInstruction(SILInstruction *inst) { @@ -5379,6 +4694,47 @@ class JVPEmitter final errorOccurred = true; } + /// Handle `copy_value` instruction. + /// Original: y = copy_value x + /// Adjoint: tan[x] = copy_value tan[y] + void visitCopyValueInst(CopyValueInst *cvi) { + TypeSubstCloner::visitCopyValueInst(cvi); + emitTangentForCopyValueInst(cvi); + } + + void visitReturnInst(ReturnInst *ri) { + auto loc = ri->getOperand().getLoc(); + auto *origExit = ri->getParent(); + auto &builder = getBuilder(); + auto *diffStructVal = buildDifferentialValueStructValue(ri); + + // Get the value in the JVP corresponding to the original result. + auto *origRetInst = cast(origExit->getTerminator()); + auto origResult = getOpValue(origRetInst->getOperand()); + SmallVector origResults; + extractAllElements(origResult, builder, origResults); + + // Get and partially apply the differential. + auto jvpGenericEnv = jvp->getGenericEnvironment(); + auto jvpSubstMap = jvpGenericEnv + ? jvpGenericEnv->getForwardingSubstitutionMap() + : jvp->getForwardingSubstitutionMap(); + auto *differentialRef = builder.createFunctionRef(loc, &getDifferential()); + auto *differentialPartialApply = builder.createPartialApply( + loc, differentialRef, jvpSubstMap, {diffStructVal}, + ParameterConvention::Direct_Guaranteed); + + // Return a tuple of the original result and differential. + SmallVector directResults; + directResults.append(origResults.begin(), origResults.end()); + directResults.push_back(differentialPartialApply); + builder.createReturn( + ri->getLoc(), joinElements(directResults, builder, loc)); + + // Differential emission. + emitTangentForReturnInst(ri); + } + void visitInstructionsInBlock(SILBasicBlock *bb) { // Destructure the differential struct to get the elements. auto &diffBuilder = getDifferentialBuilder(); @@ -5392,6 +4748,24 @@ class JVPEmitter final TypeSubstCloner::visitInstructionsInBlock(bb); } + void visitDestroyValueInst(DestroyValueInst *dvi) { + TypeSubstCloner::visitDestroyValueInst(dvi); + if (differentialInfo.shouldDifferentiateInstruction(dvi)) + emitTangentForDestroyValueInst(dvi); + } + + void visitBeginBorrowInst(BeginBorrowInst *bbi) { + TypeSubstCloner::visitBeginBorrowInst(bbi); + if (differentialInfo.shouldDifferentiateInstruction(bbi)) + emitTangentForBeginBorrow(bbi); + } + + void visitEndBorrowInst(EndBorrowInst *ebi) { + TypeSubstCloner::visitEndBorrowInst(ebi); + if (differentialInfo.shouldDifferentiateInstruction(ebi)) + emitTangentForEndBorrow(ebi); + } + // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its JVP. void visitApplyInst(ApplyInst *ai) { @@ -5418,21 +4792,11 @@ class JVPEmitter final } } - LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n'); + LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n'); // Get the parameter indices required for differentiating this function. SmallVector allResults; - // If `apply` result is tuple-typed with a `destructure_tuple` user, add the - // results of the `destructure_tuple` user to `allResults` instead of adding - // the `apply` result itself. - // Otherwise, add `apply` result to `allResults`. - if (auto *dti = getSingleDestructureTupleUser(ai)) { - for (auto result : dti->getResults()) - allResults.push_back(result); - } else { - allResults.push_back(ai); - } - + allResults.push_back(ai); allResults.append(ai->getIndirectSILResults().begin(), ai->getIndirectSILResults().end()); SmallVector activeParamIndices; @@ -5456,7 +4820,7 @@ class JVPEmitter final errorOccurred = true; return; } - // Form expected indices, assuming there's only one result. + // Form expected indices by assuming there's only one result. SILAutoDiffIndices indices( activeResultIndices.front(), AutoDiffIndexSubset::get( @@ -5605,85 +4969,13 @@ class JVPEmitter final recursivelyDeleteTriviallyDeadInstructions( getOpValue(origCallee)->getDefiningInstruction()); - // Add the differential function for when we create the struct we partially - // apply to the differential we are generating. - auto differential = jvpDirectResults.back(); - auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai); - auto originalDifferentialType = - getOpType(differential->getType()).getAs(); - auto differentialType = - remapType(differential->getType()) - .castTo(); - auto jvpGenSig = SubsMap.getGenericSignature() - ? SubsMap.getGenericSignature()->getCanonicalSignature() - : nullptr; - Lowering::GenericContextScope genericContextScope( - context.getTypeConverter(), jvpGenSig); - auto loweredDifferentialType = - getOpType(context.getTypeConverter().getLoweredType( - differentialDecl->getInterfaceType()->getCanonicalType(), - ResilienceExpansion::Minimal)) - .castTo(); - // If actual differential type does not match lowered differential type, - // reabstract the differential using a thunk. - if (!loweredDifferentialType->isEqual(originalDifferentialType)) { - SILOptFunctionBuilder fb(context.getTransform()); - auto *thunk = getOrCreateReabstractionThunk( - fb, context.getModule(), loc, &getDifferential(), - differentialType, loweredDifferentialType); - auto *thunkRef = builder.createFunctionRef(loc, thunk); - differential = builder.createPartialApply( - loc, thunkRef, - getOpSubstitutionMap(thunk->getForwardingSubstitutionMap()), - {differential}, differentialType->getCalleeConvention()); - } - differentialValues[ai->getParent()].push_back(differential); + // Record the callee differential function value. + // This is used later to construct a differential struct. + auto diffFunc = jvpDirectResults.back(); + differentialValues[ai->getParent()].push_back(diffFunc); // Differential emission. - emitTangentForApplyInst(ai, indices, originalDifferentialType); - } - - void visitReturnInst(ReturnInst *ri) { - auto loc = ri->getOperand().getLoc(); - auto *origExit = ri->getParent(); - auto &builder = getBuilder(); - auto *diffStructVal = buildDifferentialValueStructValue(ri); - - // Get the JVP value corresponding to the original functions's return value. - auto *origRetInst = cast(origExit->getTerminator()); - auto origResult = getOpValue(origRetInst->getOperand()); - SmallVector origResults; - extractAllElements(origResult, builder, origResults); - - // Get and partially apply the differential. - auto jvpGenericEnv = jvp->getGenericEnvironment(); - auto jvpSubstMap = jvpGenericEnv - ? jvpGenericEnv->getForwardingSubstitutionMap() - : jvp->getForwardingSubstitutionMap(); - auto *differentialRef = - builder.createFunctionRef(loc, &getDifferential()); - auto *differentialPartialApply = builder.createPartialApply( - loc, differentialRef, jvpSubstMap, {diffStructVal}, - ParameterConvention::Direct_Guaranteed); - - // Return a tuple of the original result and pullback. - SmallVector directResults; - directResults.append(origResults.begin(), origResults.end()); - directResults.push_back(differentialPartialApply); - builder.createReturn( - ri->getLoc(), joinElements(directResults, builder, loc)); - } - - void visitBranchInst(BranchInst *bi) { - llvm_unreachable("Unsupported SIL instruction."); - } - - void visitCondBranchInst(CondBranchInst *cbi) { - llvm_unreachable("Unsupported SIL instruction."); - } - - void visitSwitchEnumInst(SwitchEnumInst *sei) { - llvm_unreachable("Unsupported SIL instruction."); + emitTangentForApplyInst(ai, indices); } void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { @@ -6845,7 +6137,7 @@ class PullbackEmitter final : public SILInstructionVisitor { void visitApplyInst(ApplyInst *ai) { assert(getPullbackInfo().shouldDifferentiateApplyInst(ai)); // Handle array uninitialized allocation intrinsic specially. - if (isArrayLiteralIntrinsic(ai)) + if (ai->hasSemantics("array.uninitialized_intrinsic")) return visitArrayInitialization(ai); // Replace a call to a function with a call to its pullback. auto &nestedApplyInfo = getContext().getNestedApplyInfo(); @@ -7005,7 +6297,8 @@ class PullbackEmitter final : public SILInstructionVisitor { tangentVectorTy->getStructOrBoundGenericStruct(); assert(tangentVectorDecl); - auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct); + auto *destructure = + builder.createDestructureStruct(si->getLoc(), adjStruct); // Accumulate adjoints for the fields of the `struct` operand. unsigned fieldIndex = 0; for (auto it = structDecl->getStoredProperties().begin(); @@ -7032,7 +6325,7 @@ class PullbackEmitter final : public SILInstructionVisitor { tanField = cast(tanFieldLookup.front()); } assert(tanField); - auto tanElt = dti->getResult(fieldIndex); + auto tanElt = destructure->getResult(fieldIndex); addAdjointValue( bb, si->getFieldValue(field), makeConcreteAdjointValue(tanElt), si->getLoc()); diff --git a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift index ecd52ac0a8364..ba4c3ff02327f 100644 --- a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift +++ b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift @@ -202,13 +202,6 @@ extension Tracked where T : Differentiable, T == T.TangentVector { return (lhs + rhs, { v in (v, v) }) } - @usableFromInline - @differentiating(+) - internal static func _jvpAdd(lhs: Self, rhs: Self) - -> (value: Self, differential: (Self, Self) -> Self) { - return (lhs + rhs, { $0 + $1 }) - } - @usableFromInline @differentiating(-) internal static func _vjpSubtract(lhs: Self, rhs: Self) @@ -216,11 +209,18 @@ extension Tracked where T : Differentiable, T == T.TangentVector { return (lhs - rhs, { v in (v, .zero - v) }) } + @usableFromInline + @differentiating(+) + internal static func _vjpAdd(lhs: Self, rhs: Self) + -> (value: Self, differential: (Self, Self) -> (Self)) { + return (lhs + rhs, { (dx, dy) in dx + dy }) + } + @usableFromInline @differentiating(-) - internal static func _jvpSubtract(lhs: Self, rhs: Self) - -> (value: Self, differential: (Self, Self) -> Self) { - return (lhs - rhs, { $0 - $1 }) + internal static func _vjpSubtract(lhs: Self, rhs: Self) + -> (value: Self, differential: (Self, Self) -> (Self)) { + return (lhs - rhs, { (dx, dy) in dx - dy }) } } @@ -235,7 +235,7 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude, @usableFromInline @differentiating(*) - internal static func _jvpMultiply(lhs: Self, rhs: Self) + internal static func _vjpMultiply(lhs: Self, rhs: Self) -> (value: Self, differential: (Self, Self) -> (Self)) { return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs }) } @@ -251,7 +251,7 @@ extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector @usableFromInline @differentiating(/) - internal static func _jvpDivide(lhs: Self, rhs: Self) + internal static func _vjpDivide(lhs: Self, rhs: Self) -> (value: Self, differential: (Self, Self) -> (Self)) { return (lhs / rhs, { (dx, dy) in dx / rhs - lhs / (rhs * rhs) * dy }) } diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index f70ff914680d8..f1956cc5cdb17 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -914,7 +914,7 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { } /// Creates a type-erased derivative from the given derivative. - @differentiable(jvp: _jvpInit(_:), vjp: _vjpInit(_:)) + @differentiable(vjp: _vjpInit(_:)) public init(_ base: T) where T : Differentiable, T.TangentVector == T { self._box = _ConcreteDerivativeBox(base) } @@ -927,14 +927,6 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { return (AnyDerivative(base), { v in v.base as! T.TangentVector }) } - @usableFromInline internal static func _jvpInit( - _ base: T - ) -> (AnyDerivative, (T.TangentVector) -> AnyDerivative) - where T : Differentiable, T.TangentVector == T - { - return (AnyDerivative(base), { dbase in AnyDerivative(dbase) }) - } - public typealias TangentVector = AnyDerivative // `Equatable` requirements (implied by `AdditiveArithmetic`). @@ -971,14 +963,6 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { return (lhs + rhs, { v in (v, v) }) } - @differentiating(+) - @usableFromInline internal static func _jvpAdd( - lhs: AnyDerivative, rhs: AnyDerivative - ) -> (value: AnyDerivative, - differential: (AnyDerivative, AnyDerivative) -> (AnyDerivative)) { - return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs }) - } - public static func - ( lhs: AnyDerivative, rhs: AnyDerivative ) -> AnyDerivative { @@ -993,14 +977,6 @@ public struct AnyDerivative : EuclideanDifferentiable & AdditiveArithmetic { return (lhs - rhs, { v in (v, .zero - v) }) } - @differentiating(-) - @usableFromInline internal static func _jvpSubtract( - lhs: AnyDerivative, rhs: AnyDerivative - ) -> (value: AnyDerivative, - differential: (AnyDerivative, AnyDerivative) -> AnyDerivative) { - return (lhs - rhs, { (dlhs, drhs) in dlhs - drhs }) - } - // `Differentiable` requirements. public mutating func move(along direction: TangentVector) { if _box._isOpaqueZero() { diff --git a/test/AutoDiff/forward_mode_runtime.swift b/test/AutoDiff/forward_mode_runtime.swift index a1e2ad83ff8c5..81a215837bb3f 100644 --- a/test/AutoDiff/forward_mode_runtime.swift +++ b/test/AutoDiff/forward_mode_runtime.swift @@ -3,11 +3,6 @@ import StdlibUnittest import DifferentiationUnittest -#if os(macOS) -import Darwin.C -#else -import Glibc -#endif var ForwardModeTests = TestSuite("ForwardMode") @@ -15,15 +10,6 @@ var ForwardModeTests = TestSuite("ForwardMode") // Basic tests. //===----------------------------------------------------------------------===// -ForwardModeTests.test("Identity") { - func func_to_diff(x: Float) -> Float { - return x - } - let (y, differential) = valueWithDifferential(at: 4, in: func_to_diff) - expectEqual(4, y) - expectEqual(1, differential(1)) -} - ForwardModeTests.test("Unary") { func func_to_diff(x: Float) -> Float { return x * x @@ -53,156 +39,8 @@ ForwardModeTests.test("BinaryWithLets") { expectEqual(-19, differential(1, 1)) } -ForwardModeTests.test("SubsetParametersDiff") { - func func_to_diff1(x: Int, y: Float, z: Int) -> Float { - return y - } - let (y1, differential1) = valueWithDifferential(at: 5) { y in - func_to_diff1(x: 0, y: y, z: 0) - } - expectEqual(5, y1) - expectEqual(1, differential1(1)) - - func func_to_diff2(x: Float, y: Int, z: Int) -> Float { - return 2 * x - } - let (y2, differential2) = valueWithDifferential(at: 6) { x in - func_to_diff2(x: x, y: 0, z: 0) - } - expectEqual(12, y2) - expectEqual(2, differential2(1)) - - func func_to_diff3(x: Int, y: Int, z: Float) -> Float { - return 3 * z - } - let (y3, differential3) = valueWithDifferential(at: 7) { z in - func_to_diff3(x: 0, y: 0, z: z) - } - expectEqual(21, y3) - expectEqual(3, differential3(1)) -} - -//===----------------------------------------------------------------------===// -// Functions with variables -//===----------------------------------------------------------------------===// - -ForwardModeTests.test("UnaryWithVars") { - func unary(x: Float) -> Float { - var a = x - a = x - var b = a + 2 - b = b - 1 - let c: Float = 3 - var d = a + b + c - 1 - d = d + d - return d - } - - let (y, differential) = valueWithDifferential(at: 4, in: unary) - expectEqual(22, y) - expectEqual(4, differential(1)) -} - -//===----------------------------------------------------------------------===// -// Functions with basic struct -//===----------------------------------------------------------------------===// - -struct A: Differentiable & AdditiveArithmetic { - var x: Float -} - -ForwardModeTests.test("StructInit") { - func structInit(x: Float) -> A { - return A(x: 2 * x) - } - - let (y, differential) = valueWithDifferential(at: 4, in: structInit) - expectEqual(A(x: 8), y) - expectEqual(A(x: 2), differential(1)) -} - -ForwardModeTests.test("StructExtract") { - func structExtract(x: A) -> Float { - return 2 * x.x - } - - let (y, differential) = valueWithDifferential( - at: A(x: 4), - in: structExtract) - expectEqual(8, y) - expectEqual(2, differential(A(x: 1))) -} - -ForwardModeTests.test("LocalStructVariable") { - func structExtract(x: A) -> A { - let a = A(x: 2 * x.x) // 2x - var b = A(x: a.x + 2) // 2x + 2 - b = A(x: b.x + a.x) // 2x + 2 + 2x = 4x + 2 - return b - } - - let (y, differential) = valueWithDifferential( - at: A(x: 4), - in: structExtract) - expectEqual(A(x: 18), y) - expectEqual(A(x: 4), differential(A(x: 1))) -} - -//===----------------------------------------------------------------------===// -// Functions with methods -//===----------------------------------------------------------------------===// - -extension A { - func noParamMethodA() -> A { - return A(x: 2 * x) - } - - func noParamMethodx() -> Float { - return 2 * x - } - - static func *(lhs: A, rhs: A) -> A { - return A(x: lhs.x * rhs.x) - } - - func complexBinaryMethod(u: A, v: Float) -> A { - var b: A = u * A(x: 2) // A(x: u * 2) - b.x = b.x * v // A(x: u * 2 * v) - let c = b.x + 1 // u * 2 * v + 1 - - // A(x: u * 2 * v + 1 + u * 2 * v) = A(x: x * (4uv + 1)) - return A(x: x * (c + b.x)) - } -} - -ForwardModeTests.test("noParamMethodA") { - let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in - x.noParamMethodA() - } - expectEqual(A(x: 8), y) - expectEqual(A(x: 2), differential(A(x: 1))) -} - -ForwardModeTests.test("noParamMethodx") { - let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in - x.noParamMethodx() - } - expectEqual(8, y) - expectEqual(2, differential(A(x: 1))) -} - -ForwardModeTests.test("complexBinaryMethod") { - let (y, differential) = valueWithDifferential(at: A(x: 4), A(x: 5), 3) { - (x, y, z) in - // derivative = A(x: 4uv + 4xv + 4ux + 1) = 4*5*3 + 4*4*3 + 4*5*4 + 1 = 189 - x.complexBinaryMethod(u: y, v: z) - } - expectEqual(A(x: 244), y) - expectEqual(A(x: 189), differential(A(x: 1), A(x: 1), 1)) -} - //===----------------------------------------------------------------------===// -// Tracked struct +// `Tracked` struct //===----------------------------------------------------------------------===// ForwardModeTests.test("TrackedIdentity") { @@ -255,378 +93,6 @@ ForwardModeTests.test("TrackedWithLets") { expectEqual(4.9375, differential(1, 1)) } -//===----------------------------------------------------------------------===// -// Tuples -//===----------------------------------------------------------------------===// - -ForwardModeTests.test("SimpleTupleExtractLet") { - func foo(_ x: Float) -> Float { - let tuple = (2*x, x) - return tuple.0 - } - let (y, differential) = valueWithDifferential(at: 4, in: foo) - expectEqual(8, y) - expectEqual(2, differential(1)) -} - -ForwardModeTests.test("SimpleTupleExtractVar") { - func foo(_ x: Float) -> Float { - let tuple = (2*x, x) - return tuple.0 - } - let (y, differential) = valueWithDifferential(at: 4, in: foo) - expectEqual(8, y) - expectEqual(2, differential(1)) -} - -ForwardModeTests.test("TupleSideEffects") { - func foo(_ x: Float) -> Float { - var tuple = (x, x) - tuple.0 = tuple.0 * x - return x * tuple.0 - } - expectEqual(27, derivative(at: 3, in: foo)) - - func fifthPower(_ x: Float) -> Float { - var tuple = (x, x) - tuple.0 = tuple.0 * x - tuple.1 = tuple.0 * x - return tuple.0 * tuple.1 - } - expectEqual(405, derivative(at: 3, in: fifthPower)) - - func nested(_ x: Float) -> Float { - var tuple = ((x, x), x) - tuple.0.0 = tuple.0.0 * x - tuple.0.1 = tuple.0.0 * x - return tuple.0.0 * tuple.0.1 - } - expectEqual(405, derivative(at: 3, in: nested)) - - // FIXME(TF-201): Update after reabstraction thunks can be directly differentiated. - /* - func generic(_ x: T) -> T { - var tuple = (x, x) - tuple.0 += x - tuple.1 += x - return tuple.0 + tuple.0 - } - expectEqual(1, derivative(at: 3.0, in: generic)) - */ -} - -// Tests TF-321. -ForwardModeTests.test("TupleNonDifferentiableElements") { - // @differentiable - func foo(_ x: Float) -> Float { - var tuple = (x, 1) - tuple.0 = x - tuple.1 = 1 - return tuple.0 - } - expectEqual(1, derivative(at: 1, in: foo)) - - func bar(_ x: Float) -> Float { - var tuple: (Int, Int, Float, Float) = (1, 1, x, x) - tuple.0 = 1 - tuple.1 = 1 - tuple.3 = x - return tuple.3 - } - expectEqual(1, derivative(at: 1, in: bar)) - - struct Wrapper { - @differentiable(where T : Differentiable) - func baz(_ x: T) -> T { - var tuple = (1, 1, x, 1) - tuple.0 = 1 - tuple.2 = x - tuple.3 = 1 - return tuple.2 - } - } - expectEqual(1, derivative(at: Float(1), in: { x -> Float in - let wrapper = Wrapper() - return wrapper.baz(x) - })) -} - -//===----------------------------------------------------------------------===// -// Generics -//===----------------------------------------------------------------------===// - -struct Tensor - : VectorProtocol, Differentiable { - // NOTE: `value` must have type with known size (e.g. `Float`, not `Scalar`) - // until differentiation has indirect passing support. - var value: Float - init(_ value: Float) { self.value = value } -} - -ForwardModeTests.test("GenericIdentity") { - func identity(_ x: T) -> T { - return x - } - let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in - identity(x) - } - expectEqual(4, y) - expectEqual(1, differential(1)) -} - -ForwardModeTests.test("GenericTensorIdentity") { - func identity( - _ x: Tensor) -> Tensor { - return x - } - let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in - identity(Tensor(x)) - } - expectEqual(Tensor(4), y) - expectEqual(Tensor(1), differential(1)) -} - -ForwardModeTests.test("GenericTensorPlus") { - func plus(_ x: Tensor) -> Float { - return x.value + x.value - } - let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in - plus(Tensor(x)) - } - expectEqual(8, y) - expectEqual(2, differential(1)) -} - -ForwardModeTests.test("GenericTensorBinaryInput") { - func binary( - _ x: Tensor, _ y: Tensor) -> Float { - return x.value * y.value - } - let (y, differential) = valueWithDifferential(at: 4, 5) { - (x: Float, y: Float) in - binary(Tensor(x), Tensor(y)) - } - expectEqual(20, y) - expectEqual(9, differential(1, 1)) -} - -ForwardModeTests.test("GenericTensorWithLets") { - func binary( - _ x: Tensor, _ y: Tensor) -> Float { - let a = Tensor(x.value) - let b = Tensor(y.value) - return a.value * b.value - } - let (y, differential) = valueWithDifferential(at: 4, 5) { - (x: Float, y: Float) in - binary(Tensor(x), Tensor(y)) - } - expectEqual(20, y) - expectEqual(9, differential(1, 1)) -} - -ForwardModeTests.test("GenericTensorWithVars") { - func binary( - _ x: Tensor, _ y: Tensor) -> Float { - var a = Tensor(x.value) - var b = Tensor(y.value) - b = a - a = Tensor(y.value) - return a.value * b.value - } - let (y, differential) = valueWithDifferential(at: 4, 5) { - (x: Float, y: Float) in - binary(Tensor(x), Tensor(y)) - } - expectEqual(20, y) - expectEqual(9, differential(1, 1)) -} - -// Test case where associated derivative function's requirements are met. -extension Tensor where Scalar : Numeric { - @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint) - func mean() -> Tensor { - return self - } - - @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint) - func variance() -> Tensor { - return mean() // ok - } -} -_ = differential(at: Tensor(1), in: { $0.variance() }) - -// Tests TF-508: differentiation requirements with dependent member types. -protocol TF_508_Proto { - associatedtype Scalar -} -extension TF_508_Proto where Scalar : FloatingPoint { - @differentiable( - jvp: jvpAdd - where Self : Differentiable, Scalar : Differentiable, - // Conformance requirement with dependent member type. - Self.TangentVector : TF_508_Proto - ) - static func +(lhs: Self, rhs: Self) -> Self { - return lhs - } - - @differentiable( - jvp: jvpSubtract - where Self : Differentiable, Scalar : Differentiable, - // Same-type requirement with dependent member type. - Self.TangentVector == Float - ) - static func -(lhs: Self, rhs: Self) -> Self { - return lhs - } -} -extension TF_508_Proto where Self : Differentiable, - Scalar : FloatingPoint & Differentiable, - Self.TangentVector : TF_508_Proto { - static func jvpAdd(lhs: Self, rhs: Self) - -> (Self, (TangentVector, TangentVector) -> TangentVector) { - return (lhs, { (dlhs, drhs) in dlhs }) - } -} -extension TF_508_Proto where Self : Differentiable, - Scalar : FloatingPoint & Differentiable, - Self.TangentVector == Float { - static func jvpSubtract(lhs: Self, rhs: Self) - -> (Self, (TangentVector, TangentVector) -> TangentVector) { - return (lhs, { (dlhs, drhs) in dlhs }) - } -} - -struct TF_508_Struct - : TF_508_Proto, AdditiveArithmetic {} -extension TF_508_Struct : Differentiable where Scalar : Differentiable { - typealias TangentVector = TF_508_Struct -} - -// func TF_508() { -// let x = TF_508_Struct() -// // Test conformance requirement with dependent member type. -// _ = differential(at: x, in: { -// (x: TF_508_Struct) -> TF_508_Struct in -// return x + x -// }) -// // Test same-type requirement with dependent member type. -// _ = differential(at: x, in: { -// (x: TF_508_Struct) -> TF_508_Struct in -// return x - x -// }) -// } - -// TF-523 -struct TF_523_Struct : Differentiable & AdditiveArithmetic { - var a: Float = 1 - typealias TangentVector = TF_523_Struct - typealias AllDifferentiableVariables = TF_523_Struct -} - -@differentiable -func TF_523_f(_ x: TF_523_Struct) -> Float { - return x.a * 2 -} - -// TF-534: Thunk substitution map remapping. -protocol TF_534_Layer : Differentiable { - associatedtype Input : Differentiable - associatedtype Output : Differentiable - - @differentiable - func callAsFunction(_ input: Input) -> Output -} -struct TF_534_Tensor : Differentiable {} - -func TF_534( - _ model: inout Model, inputs: Model.Input -) -> TF_534_Tensor where Model.Output == TF_534_Tensor { - return valueWithDifferential(at: model) { model -> Model.Output in - return model(inputs) - }.0 -} - -// TODO: uncomment once control flow is supported in forward mode. -// TF-652: Test VJPEmitter substitution map generic signature. -// The substitution map should have the VJP's generic signature, not the -// original function's. -// struct TF_652 {} -// extension TF_652 : Differentiable where Scalar : FloatingPoint {} - -// @differentiable(wrt: x where Scalar: FloatingPoint) -// func test(x: TF_652) -> TF_652 { -// for _ in 0..<10 { -// let _ = x -// } -// return x -// } - -//===----------------------------------------------------------------------===// -// Tracked Generic. -//===----------------------------------------------------------------------===// - -ForwardModeTests.test("GenericTrackedIdentity") { - func identity(_ x: Tracked) -> Tracked { - return x - } - let (y, differential) = valueWithDifferential(at: 4) { (x: Float) in - identity(Tracked(x)) - } - expectEqual(4, y) - expectEqual(1, differential(1)) -} - -ForwardModeTests.test("GenericTrackedBinaryAdd") { - func add(_ x: Tracked, _ y: Tracked) -> Tracked - where T: Differentiable, T == T.TangentVector { - return x + y - } - let (y, differential) = valueWithDifferential(at: 4, 5) { - (x: Float, y: Float) in - add(Tracked(x), Tracked(y)) - } - expectEqual(9, y) - expectEqual(2, differential(1, 1)) -} - -ForwardModeTests.test("GenericTrackedBinaryLets") { - func add(_ x: Tracked, _ y: Tracked) -> Tracked - where T: Differentiable & SignedNumeric, - T == T.TangentVector, - T == T.Magnitude { - let a = x * y // xy - let b = a + a // 2xy - return b + b // 4xy - } - // 4y + 4x - let (y, differential) = valueWithDifferential(at: 4, 5) { (x: Float, y: Float) in - add(Tracked(x), Tracked(y)) - } - expectEqual(80, y) - expectEqual(36, differential(1, 1)) -} - -ForwardModeTests.test("GenericTrackedBinaryVars") { - func add(_ x: Tracked, _ y: Tracked) -> Tracked - where T: Differentiable & SignedNumeric, - T == T.TangentVector, - T == T.Magnitude { - var a = x * y // xy - a = a + a // 2xy - var b = x - b = a - return b + b // 4xy - } - // 4y + 4x - let (y, differential) = valueWithDifferential(at: 4, 5) { (x: Float, y: Float) in - add(Tracked(x), Tracked(y)) - } - expectEqual(80, y) - expectEqual(36, differential(1, 1)) -} - ForwardModeTests.test("TrackedDifferentiableFuncType") { func valAndDeriv( f: @escaping @differentiable (Tracked) -> Tracked @@ -644,10 +110,11 @@ ForwardModeTests.test("TrackedDifferentiableFuncType") { expectEqual(400, val1) expectEqual(160, dv1) } - //===----------------------------------------------------------------------===// // Classes //===----------------------------------------------------------------------===// +// NOTE: once forward mode is done, can copy and replace this in +// `class_method.swift` as it already calls reverse mode functions. ForwardModeTests.test("Final") { final class Final : Differentiable { @@ -657,6 +124,7 @@ ForwardModeTests.test("Final") { } for i in -5...5 { + expectEqual(Float(i) * 2, gradient(at: Float(i)) { x in Final().method(x) }) expectEqual( Float(i) * 2, derivative(at: Float(i)) { x in Final().method(x) }) @@ -700,10 +168,16 @@ ForwardModeTests.test("Simple") { func classValueWithDerivative(_ c: Super) -> (Float, Float) { return valueWithDerivative(at: 1) { c.f($0) } } + func classValueWithGradient(_ c: Super) -> (Float, Float) { + return valueWithGradient(at: 1) { c.f($0) } + } expectEqual((2, 2), classValueWithDerivative(Super())) expectEqual((3, 3), classValueWithDerivative(SubOverride())) expectEqual((3, 3), classValueWithDerivative(SubOverrideCustomDerivatives())) + expectEqual((2, 2), classValueWithGradient(Super())) + expectEqual((3, 3), classValueWithGradient(SubOverride())) + expectEqual((3, 3), classValueWithGradient(SubOverrideCustomDerivatives())) } ForwardModeTests.test("SimpleWrtSelf") { @@ -767,6 +241,13 @@ ForwardModeTests.test("SimpleWrtSelf") { // expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v)) // expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v)) // expectEqual(100, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v)) + + + // `valueWithGradient` is not used because nested tuples cannot be compared + // with `expectEqual`. + func classGradient(_ c: Super) -> (Super.TangentVector, Float) { + return gradient(at: c, 10) { c, x in c.f(x) } + } // `valueWithDerivative` is not used because the derivative requires `Super` // to conform to `FloatingPoint`. @@ -788,500 +269,43 @@ ForwardModeTests.test("SimpleWrtSelf") { expectEqual(30, y3) let c3 = SubOverrideCustomDerivatives.TangentVector(base: 1, _nontrivial: []) expectEqual(3, diff3(c3, 1)) + expectEqual((Super.TangentVector(base: 10, _nontrivial: []), 2), + classGradient(Super(base: 2))) + expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3), + classGradient(SubOverride(base: 2))) + expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3), + classGradient(SubOverrideCustomDerivatives(base: 2))) } //===----------------------------------------------------------------------===// // Protocols //===----------------------------------------------------------------------===// +// TODO: add more protocol tests. +// protocol DiffReq : Differentiable { +// @differentiable(wrt: x) +// func foo(x: Float) -> Float +// } -protocol Prot : Differentiable { - @differentiable(wrt: x) - func foo(x: Float) -> Float -} -ForwardModeTests.test("Simple Protocol") { - struct Linear: Prot, VectorProtocol { - typealias TangentVector = Linear - - let m: Float - let b: Float - - @differentiable(wrt: x) - func foo(x: Float) -> Float { - return m * x + b - } - } - - func genericFoo(_ t: T, _ x: Float) -> Float { - t.foo(x: x) - } - let inst = Linear(m: 5, b: -2) - let (y1, diff1) = valueWithDifferential(at: 5) { x in genericFoo(inst, x) } - expectEqual(23, y1) - expectEqual(5, diff1(1)) -} - -protocol DiffReq : Differentiable { - @differentiable(wrt: (self, x)) - func f(_ x: Float) -> Float -} - -extension DiffReq where TangentVector : AdditiveArithmetic { - @inline(never) // Prevent specialization, to test all witness code. - func derivF(at x: Float) -> Float { - return (valueWithDifferential(at: x) { x in self.f(x) }).1(1) - } -} - -struct Quadratic : DiffReq, VectorProtocol { - typealias TangentVector = Quadratic - - @differentiable - let a: Float - - @differentiable - let b: Float - - @differentiable - let c: Float - - init(_ a: Float, _ b: Float, _ c: Float) { - self.a = a - self.b = b - self.c = c - } - - @differentiable(wrt: (self, x)) - func f(_ x: Float) -> Float { - return a * x * x + b * x + c - } -} - -ForwardModeTests.test("ProtocolFunc") { - expectEqual(12, Quadratic(11, 12, 13).derivF(at: 0)) - expectEqual(2 * 11 + 12, Quadratic(11, 12, 13).derivF(at: 1)) - expectEqual(2 * 11 * 2 + 12, Quadratic(11, 12, 13).derivF(at: 2)) -} - -// MARK: Constructor, accessor, and subscript requirements. - -protocol FunctionsOfX: Differentiable { - @differentiable - init(x: Float) - - @differentiable - var x: Float { get } - - @differentiable - var y: Float { get } - - @differentiable - var z: Float { get } - - @differentiable - subscript() -> Float { get } -} - -struct TestFunctionsOfX: FunctionsOfX { - @differentiable - init(x: Float) { - self.x = x - self.y = x * x - } - - /// x = x - var x: Float - - /// y = x * x - var y: Float - - /// z = x * x + x - var z: Float { - return y + x - } - - @differentiable - subscript() -> Float { - return z - } -} - -@inline(never) // Prevent specialization, to test all witness code. -func derivatives(at x: Float, in: F.Type) - -> (Float, Float, Float, Float) -{ - let dxdx = derivative(at: x) { x in F(x: x).x } - let dydx = derivative(at: x) { x in F(x: x).y } - let dzdx = derivative(at: x) { x in F(x: x).z } - let dsubscriptdx = derivative(at: x) { x in F(x: x)[] } - return (dxdx, dydx, dzdx, dsubscriptdx) -} - -ForwardModeTests.test("constructor, accessor, subscript") { - expectEqual( - (1.0, 4.0, 5.0, 5.0), - derivatives(at: 2.0, in: TestFunctionsOfX.self)) -} - -// MARK: - Test witness method SIL type computation. - -protocol P : Differentiable { - @differentiable(wrt: (x, y)) - func foo(_ x: Float, _ y: Double) -> Float -} -struct S : P { - @differentiable(wrt: (x, y)) - func foo(_ x: Float, _ y: Double) -> Float { - return x - } -} - -// MARK: - Overridden protocol method adding differentiable attribute. - -public protocol Distribution { - associatedtype Value - func logProbability(of value: Value) -> Float -} - -public protocol DifferentiableDistribution: Differentiable, Distribution { - @differentiable(wrt: self) - func logProbability(of value: Value) -> Float -} - -struct Foo: DifferentiableDistribution { - @differentiable(wrt: self) - func logProbability(of value: Float) -> Float { - .zero - } -} - -@differentiable -func blah(_ x: T) -> Float where T.Value: AdditiveArithmetic { - x.logProbability(of: .zero) -} - -// Adding a more general `@differentiable` attribute. -public protocol DoubleDifferentiableDistribution: DifferentiableDistribution - where Value: Differentiable { - @differentiable(wrt: self) - @differentiable(wrt: (self, value)) - func logProbability(of value: Value) -> Float -} - -@differentiable -func blah2(_ x: T, _ value: T.Value) -> Float - where T.Value: AdditiveArithmetic { - x.logProbability(of: value) -} - -protocol DifferentiableFoo { - associatedtype T: Differentiable - @differentiable(wrt: x) - func foo(_ x: T) -> Float -} - -protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo { - @differentiable(wrt: (self, x)) - func foo(_ x: T) -> Float -} - -struct MoreDifferentiableFooStruct: MoreDifferentiableFoo { - @differentiable(wrt: (self, x)) - func foo(_ x: Float) -> Float { - x - } -} - -//===----------------------------------------------------------------------===// -// Simple Math -//===----------------------------------------------------------------------===// - -ForwardModeTests.test("Arithmetics") { - func foo1(x: Float, y: Float) -> Float { - return x * y - } - expectEqual(7, derivative(at: 3, 4, in: foo1)) - func foo2(x: Float, y: Float) -> Float { - return -x * y - } - expectEqual(-7, derivative(at: 3, 4, in: foo2)) - func foo3(x: Float, y: Float) -> Float { - return -x + y - } - expectEqual(0, derivative(at: 3, 4, in: foo3)) -} - -ForwardModeTests.test("Fanout") { - func foo1(x: Float) -> Float { - x - x - } - expectEqual(0, derivative(at: 100, in: foo1)) - func foo2(x: Float) -> Float { - x + x - } - expectEqual(2, derivative(at: 100, in: foo2)) - func foo3(x: Float, y: Float) -> Float { - x + x + x * y - } - expectEqual(7, derivative(at: 3, 2, in: foo3)) -} - -ForwardModeTests.test("FunctionCall") { - func foo(_ x: Float, _ y: Float) -> Float { - return 3 * x + { $0 * 3 }(3) * y - } - expectEqual(12, derivative(at: 3, 4, in: foo)) - expectEqual(3, derivative(at: 3) { x in foo(x, 4) }) -} - -ForwardModeTests.test("ResultSelection") { - func foo(_ x: Float, _ y: Float) -> (Float, Float) { - return (x + 1, y + 2) - } - expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).0 })) - expectEqual(1, derivative(at: 3, 3, in: { x, y in foo(x, y).1 })) -} - -ForwardModeTests.test("CaptureLocal") { - let z: Float = 10 - func foo(_ x: Float) -> Float { - return z * x - } - expectEqual(10, derivative(at: 0, in: foo)) -} - -var globalVar: Float = 10 -ForwardModeTests.test("CaptureGlobal") { - func foo(x: Float) -> Float { - globalVar += 20 - return globalVar * x - } - expectEqual(30, derivative(at: 0, in: foo)) -} - -ForwardModeTests.test("SideEffects") { - func fourthPower(x: Float) -> Float { - var a = x - a = a * x - a = a * x - return a * x - } - expectEqual(4 * 27, derivative(at: 3, in: fourthPower)) -} - -ForwardModeTests.test("TupleSideEffects") { - func foo(_ x: Float) -> Float { - var tuple = (x, x) - tuple.0 = tuple.0 * x - return x * tuple.0 - } - expectEqual(27, derivative(at: 3, in: foo)) - - func fifthPower(_ x: Float) -> Float { - var tuple = (x, x) - tuple.0 = tuple.0 * x - tuple.1 = tuple.0 * x - return tuple.0 * tuple.1 - } - expectEqual(405, derivative(at: 3, in: fifthPower)) - - func nested(_ x: Float) -> Float { - var tuple = ((x, x), x) - tuple.0.0 = tuple.0.0 * x - tuple.0.1 = tuple.0.0 * x - return tuple.0.0 * tuple.0.1 - } - expectEqual(405, derivative(at: 3, in: nested)) - - // FIXME(TF-201): Update after reabstraction thunks can be directly differentiated. - /* - func generic(_ x: T) -> T { - var tuple = (x, x) - tuple.0 += x - tuple.1 += x - return tuple.0 + tuple.0 - } - expectEqual(1, derivative(at: 3.0, in: generic)) - */ -} - -// Tests TF-321. -ForwardModeTests.test("TupleNonDifferentiableElements") { - func foo(_ x: Float) -> Float { - var tuple = (x, 1) - tuple.0 = x - tuple.1 = 1 - return tuple.0 - } - expectEqual(1, derivative(at: 1, in: foo)) - - func bar(_ x: Float) -> Float { - var tuple: (Int, Int, Float, Float) = (1, 1, x, x) - tuple.0 = 1 - tuple.1 = 1 - tuple.3 = x - return tuple.3 - } - expectEqual(1, derivative(at: 1, in: bar)) - - struct Wrapper { - @differentiable(where T : Differentiable) - func baz(_ x: T) -> T { - var tuple = (1, 1, x, 1) - tuple.0 = 1 - tuple.2 = x - tuple.3 = 1 - return tuple.2 - } - } - expectEqual(1, derivative(at: Float(1), in: { x -> Float in - let wrapper = Wrapper() - return wrapper.baz(x) - })) -} - -// Tests TF-21. -ForwardModeTests.test("StructMemberwiseInitializer") { - struct Foo : AdditiveArithmetic, Differentiable { - var stored: Float - var computed: Float { - return stored * stored - } - } - - let derivFoo = differential(at: Float(4), in: { input -> Foo in - let foo = Foo(stored: input) - let foo2 = foo + foo - return Foo(stored: foo2.stored) - })(1) - expectEqual(Foo.TangentVector(stored: 2), derivFoo) - - let computed = derivative(at: Float(4)) { input -> Float in - let foo = Foo(stored: input) - return foo.computed - } - expectEqual(8, computed) - - let derivProduct = derivative(at: Float(4)) { input -> Float in - let foo = Foo(stored: input) - return foo.computed * foo.stored - } - expectEqual(48, derivProduct) - - struct Custom : AdditiveArithmetic, Differentiable { - var x: Float - - // Custom initializer with `@differentiable`. - @differentiable - init(x: Float) { - print(x) - self.x = x - } - } - - let derivCustom = differential(at: Float(4), in: { input -> Custom in - let foo = Custom(x: input) - return foo + foo - })(1) - expectEqual(Custom.TangentVector(x: 2), derivCustom) -} - -// Tests TF-319: struct with non-differentiable constant stored property. -ForwardModeTests.test("StructConstantStoredProperty") { - struct TF_319 : Differentiable { - var x: Float - @noDerivative let constant = Float(2) - - @differentiable - init(x: Float) { - self.x = x - } - - @differentiable(wrt: (self, input)) - func applied(to input: Float) -> Float { - return x * constant * input - } - } - func testStructInit(to input: Float) -> Float { - let model = TF_319(x: 10) - return model.applied(to: input) - } - expectEqual(6, derivative(at: 10, in: { TF_319(x: $0).applied(to: 3) })) - expectEqual(20, derivative(at: 3, in: testStructInit)) -} - -ForwardModeTests.test("StructSideEffects") { - struct Point : AdditiveArithmetic, Differentiable { - var x: Float - var y: Float - var z: Float - } - - func double(_ input: Float) -> Point { - let point = Point(x: input, y: input, z: input) - return point + point - } - expectEqual(Point(x: 2, y: 2, z: 2), differential(at: 4, in: double)(1)) - - func fifthPower(_ input: Float) -> Float { - var point = Point(x: input, y: input, z: input) - point.x = point.x * input - point.y = point.x * input - return point.x * point.y - } - expectEqual(405, derivative(at: 3, in: fifthPower)) - - func mix(_ input: Float) -> Float { - var tuple = (point: Point(x: input, y: input, z: input), float: input) - tuple.point.x = tuple.point.x * tuple.float - tuple.point.y = tuple.point.x * input - return tuple.point.x * tuple.point.y - } - expectEqual(405, derivative(at: 3, in: mix)) - - // Test TF-282. - struct Add : Differentiable { - var bias: Float - func applied(to input: Float) -> Float { - var tmp = input - tmp = tmp + bias - return tmp - } - } - expectEqual(1, derivative(at: 1) { m in Add(bias: m).applied(to: 1) }) -} - -ForwardModeTests.test("StructGeneric") { - struct Generic : AdditiveArithmetic, Differentiable { - var x: T - var y: T - var z: T - } - - let deriv = differential(at: Float(3), in: { input -> Generic in - var generic = Generic(x: input, y: input, z: input) - return generic - })(1) - expectEqual(Generic.TangentVector(x: 1, y: 1, z: 1), deriv) +// struct Linear: DiffReq, VectorProtocol { +// typealias TangentVector = Linear - func fifthPower(_ input: Float) -> Float { - var generic = Generic(x: input, y: input, z: input) - generic.x = generic.x * input - generic.y = generic.x * input - return generic.x * generic.y - } - expectEqual(405, derivative(at: 3, in: fifthPower)) -} +// let m: Float +// let b: Float -ForwardModeTests.test("SubsetIndices") { - func deriv(_ lossFunction: @differentiable (Float, Float) -> Float) -> Float { - return derivative(at: 1) { x in lossFunction(x * x, 10.0) } - } - expectEqual(2, deriv { x, y in x + y }) +// @differentiable(wrt: x) +// func foo(x: Float) -> Float { +// return m * x + b +// } +// } - func derivWRTNonDiff(_ lossFunction: @differentiable (Float, @nondiff Int) -> Float) -> Float { - return derivative(at: 2) { x in lossFunction(x * x, 10) } - } - expectEqual(4, derivWRTNonDiff { x, y in x + Float(y) }) -} +// ForwardModeTests.test("Protocols") { +// func genericFoo(_ t: T, _ x: Float) -> Float { +// t.foo(x: x) +// } +// let inst = Linear(m: 5, b: -2) +// let (y1, diff1) = valueWithDifferential(at: 5) { x in genericFoo(inst, x) } +// expectEqual(23, y1) +// expectEqual(5, diff1(1)) +// } runAllTests() diff --git a/test/AutoDiff/protocol_requirement_autodiff.swift b/test/AutoDiff/protocol_requirement_autodiff.swift index 64f2eb432d825..38fc6a14a2b77 100644 --- a/test/AutoDiff/protocol_requirement_autodiff.swift +++ b/test/AutoDiff/protocol_requirement_autodiff.swift @@ -106,8 +106,8 @@ func derivatives(at x: Float, in: F.Type) ProtocolRequirementAutodiffTests.test("constructor, accessor, subscript") { expectEqual( - (1.0, 4.0, 5.0, 5.0), - derivatives(at: 2.0, in: TestFunctionsOfX.self)) + derivatives(at: 2.0, in: TestFunctionsOfX.self), + (1.0, 4.0, 5.0, 5.0)) } // MARK: - Test witness method SIL type computation. diff --git a/test/AutoDiff/simple_math.swift b/test/AutoDiff/simple_math.swift index e159a886610ed..5fb2629ebddad 100644 --- a/test/AutoDiff/simple_math.swift +++ b/test/AutoDiff/simple_math.swift @@ -299,6 +299,7 @@ SimpleMathTests.test("StructGeneric") { generic.y = generic.x * input return generic.x * generic.y } + // FIXME(TF-274): The true expected result is `405`, like other variants of `fifthPower` above. expectEqual(405, gradient(at: 3, in: fifthPower)) }