diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 57dd835ab75d1..6b8b95fb7b788 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -12,7 +12,7 @@ // // SWIFT_ENABLE_TENSORFLOW // -// This file implements reverse-mode automatic differentiation. +// This file implements automatic differentiation. // // NOTE: Although the AD feature is developed as part of the Swift for // TensorFlow project, it is completely independent from TensorFlow support. @@ -64,6 +64,12 @@ static llvm::cl::opt SkipFoldingAutoDiffFunctionExtraction( "differentiation-skip-folding-autodiff-function-extraction", llvm::cl::init(true)); +/// This flag is used to enable full JVP generation. +/// It will be removed when JVP/differential generation is robust. +static llvm::cl::opt RunJVPGeneration( + "run-jvp-generation", + llvm::cl::init(false)); + //===----------------------------------------------------------------------===// // Helpers //===----------------------------------------------------------------------===// @@ -506,7 +512,7 @@ class LinearMapInfo { assert(linearMapStruct); auto linearMapStructTy = linearMapStruct->getDeclaredInterfaceType()->getCanonicalType(); - // Create dummy declaration representing enum case p?arameter. + // Create dummy declaration representing enum case parameter. auto *decl = new (astCtx) ParamDecl(ParamDecl::Specifier::Default, loc, loc, Identifier(), loc, Identifier(), moduleDecl); @@ -3145,8 +3151,8 @@ class VJPEmitter final auto &activityCollection = *activityAnalysis->get(original); auto &activityInfo = activityCollection.getActivityInfo( vjp->getLoweredFunctionType()->getGenericSignature()); - LLVM_DEBUG(dumpActivityInfo(*original, indices, activityInfo, - getADDebugStream())); + LLVM_DEBUG( + dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); return activityInfo; } @@ -3171,8 +3177,8 @@ class VJPEmitter final auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); // RAII that pushes the original function's generic signature to - // `module.Types` so that the calls `module.Types.getTypeLowering()` below - // will know the original function's generic parameter types. + // `module.Types` so that the calls to `module.Types.getTypeLowering()` + // below will know the original function's generic parameter types. Lowering::GenericContextScope genericContextScope( module.Types, origTy->getGenericSignature()); @@ -3391,7 +3397,7 @@ class VJPEmitter final auto &builder = getBuilder(); auto *pbStructVal = buildPullbackValueStructValue(ri); - // Get the VJP value corresponding to the original functions's return value. + // Get the value in the VJP corresponding to the original result. auto *origRetInst = cast(origExit->getTerminator()); auto origResult = getOpValue(origRetInst->getOperand()); SmallVector origResults; @@ -3608,7 +3614,7 @@ class VJPEmitter final auto &builder = getBuilder(); auto original = getOpValue(ai->getCallee()); SILValue vjpValue; - // If functionSource is a @differentiable function, just extract it. + // If functionSource is a `@differentiable` function, just extract it. auto originalFnTy = original->getType().castTo(); if (originalFnTy->isDifferentiable()) { auto paramIndices = originalFnTy->getDifferentiationParameterIndices(); @@ -3797,200 +3803,6 @@ class VJPEmitter final }; } // end anonymous namespace -namespace { -class JVPEmitter final - : public TypeSubstCloner { -private: - /// The global context. - ADContext &context; - - /// The original function. - SILFunction *const original; - - /// The `[differentiable]` attribute. - SILDifferentiableAttr *const attr; - - /// The JVP function. - SILFunction *const jvp; - - /// The differential function. - SILFunction *differential; - - /// The differentiation invoker. - DifferentiationInvoker invoker; - - bool errorOccurred = false; - - ASTContext &getASTContext() const { return jvp->getASTContext(); } - SILModule &getModule() const { return jvp->getModule(); } - const SILAutoDiffIndices &getIndices() const { return attr->getIndices(); } - - static SubstitutionMap getSubstitutionMap(SILFunction *original, - SILFunction *jvp) { - auto substMap = original->getForwardingSubstitutionMap(); - if (auto *jvpGenEnv = jvp->getGenericEnvironment()) { - auto jvpSubstMap = jvpGenEnv->getForwardingSubstitutionMap(); - substMap = SubstitutionMap::get( - jvpGenEnv->getGenericSignature(), QuerySubstitutionMap{jvpSubstMap}, - LookUpConformanceInSubstitutionMap(jvpSubstMap)); - } - return substMap; - } - -public: - explicit JVPEmitter(ADContext &context, SILFunction *original, - SILDifferentiableAttr *attr, SILFunction *jvp, - DifferentiationInvoker invoker) - : TypeSubstCloner(*jvp, *original, getSubstitutionMap(original, jvp)), - context(context), original(original), attr(attr), jvp(jvp), - invoker(invoker) { - // Create empty differential function. - differential = createEmptyDifferential(); - context.getGeneratedFunctions().push_back(differential); - } - - SILFunction *createEmptyDifferential() { - auto &module = context.getModule(); - auto origTy = original->getLoweredFunctionType(); - auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); - - // RAII that pushes the original function's generic signature to - // `module.Types` so that the calls `module.Types.getTypeLowering()` below - // will know the original function's generic parameter types. - Lowering::GenericContextScope genericContextScope( - module.Types, origTy->getGenericSignature()); - - SmallVector diffParams; - SmallVector diffResults; - auto origParams = origTy->getParameters(); - auto indices = attr->getIndices(); - - // Add differential result for the seed. - auto origResInfo = origTy->getResults()[indices.source]; - diffResults.push_back( - SILResultInfo(origResInfo.getType() - ->getAutoDiffAssociatedTangentSpace(lookupConformance) - ->getCanonicalType(), origResInfo.getConvention())); - - // Add pullback results for the requested wrt parameters. - for (auto i : indices.parameters->getIndices()) { - auto origParam = origParams[i]; - diffParams.push_back( - SILParameterInfo(origParam.getType() - ->getAutoDiffAssociatedTangentSpace(lookupConformance) - ->getCanonicalType(), origParam.getConvention())); - } - Mangle::ASTMangler mangler; - auto diffName = original->getASTContext().getIdentifier( - mangler.mangleAutoDiffLinearMapHelper( - original->getName(), AutoDiffLinearMapKind::Differential, - indices)).str(); - auto diffGenericSig = getAssociatedFunctionGenericSignature(attr, original); - auto *diffGenericEnv = diffGenericSig - ? diffGenericSig->createGenericEnvironment() - : nullptr; - auto diffType = SILFunctionType::get( - diffGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(), - origTy->getCalleeConvention(), diffParams, {}, diffResults, None, - original->getASTContext()); - - SILOptFunctionBuilder fb(context.getTransform()); - // The generated tangent linkage is set to Hidden because generated tangent - // are never called cross-module. - auto linkage = SILLinkage::Hidden; - auto *differential = fb.createFunction( - linkage, diffName, diffType, diffGenericEnv, original->getLocation(), - original->isBare(), IsNotTransparent, original->isSerialized(), - original->isDynamicallyReplaceable()); - differential->setOwnershipEliminated(); - differential->setDebugScope( - new (module) SILDebugScope(original->getLocation(), differential)); - // Create empty body of differential. - auto diffConv = differential->getConventions(); - auto *entry = differential->createBasicBlock(); - createEntryArguments(differential); - // Return undef. - SILBuilder builder(entry); - auto loc = differential->getLocation(); - builder.createReturn(loc, SILUndef::get( - differential->mapTypeIntoContext(diffConv.getSILResultType()), - *differential)); - return differential; - } - - /// Run JVP generation. Returns true on error. - bool run(); - - void postProcess(SILInstruction *orig, SILInstruction *cloned) { - if (errorOccurred) - return; - SILClonerWithScopes::postProcess(orig, cloned); - } - - /// General visitor for all instructions. If any error is emitted by previous - /// visits, bail out. - void visit(SILInstruction *inst) { - if (errorOccurred) - return; - TypeSubstCloner::visit(inst); - } - - void visitSILInstruction(SILInstruction *inst) { - context.emitNondifferentiabilityError(inst, invoker, - diag::autodiff_expression_not_differentiable_note); - errorOccurred = true; - } - -private: - /// Get the lowered SIL type of the given nominal type declaration. - SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { - auto nomType = getOpASTType( - nominal->getDeclaredInterfaceType()->getCanonicalType()); - auto nomSILType = context.getTypeConverter().getLoweredType( - nomType, ResilienceExpansion::Minimal); - return nomSILType; - } - -public: - void visitReturnInst(ReturnInst *ri) { - auto loc = ri->getOperand().getLoc(); - auto *origExit = ri->getParent(); - auto &builder = getBuilder(); - - // Get the JVP value corresponding to the original functions's return value. - auto *origRetInst = cast(origExit->getTerminator()); - auto origResult = getOpValue(origRetInst->getOperand()); - SmallVector origResults; - extractAllElements(origResult, builder, origResults); - - // Return a tuple of the original result and an undef, which at some point - // will be the differential. - SmallVector directResults; - directResults.append(origResults.begin(), origResults.end()); - - // Get differential result type. - auto jvpResultArray = jvp->getLoweredFunctionType()->getResults(); - auto funcType = jvpResultArray.back().getType(); - auto silFuncCanType = funcType->castTo() - ->getCanonicalType(); - - directResults.push_back( - SILUndef::get(jvp->mapTypeIntoContext( - SILType::getPrimitiveObjectType(silFuncCanType)), *jvp)); - builder.createReturn( - ri->getLoc(), joinElements(directResults, builder, loc)); - } - - void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { - // Clone `autodiff_function` from original to JVP, then add the cloned - // instruction to the `autodiff_function` worklist. - TypeSubstCloner::visitAutoDiffFunctionInst(adfi); - auto *newADFI = cast(getOpValue(adfi)); - context.getAutoDiffFunctionInsts().push_back(newADFI); - } -}; -} // end anonymous namespace - //===----------------------------------------------------------------------===// // AdjointValue - a symbolic representation for adjoint values that allows // for efficient differentiation of aggregates. @@ -4153,87 +3965,1104 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, } // end anonymous namespace -//===----------------------------------------------------------------------===// -// PullbackEmitter - visitors on the original function for pullback code -// generation -//===----------------------------------------------------------------------===// - namespace { -class PullbackEmitter final : public SILInstructionVisitor { + +class JVPEmitter final + : public TypeSubstCloner { private: - /// The parent VJP emitter. - VJPEmitter &vjpEmitter; + /// The global context. + ADContext &context; - /// Dominance info for the original function. - DominanceInfo *domInfo = nullptr; + /// The original function. + SILFunction *const original; - /// Post-dominance info for the original function. - PostDominanceInfo *postDomInfo = nullptr; + /// The `[differentiable]` attribute. + SILDifferentiableAttr *const attr; - /// Post-order info for the original function. - PostOrderFunctionInfo *postOrderInfo = nullptr; + /// The JVP function. + SILFunction *const jvp; - /// Mapping from original basic blocks and original values to corresponding - /// adjoint values. - DenseMap, AdjointValue> valueMap; + llvm::BumpPtrAllocator allocator; - /// Mapping from original basic blocks and original buffers to corresponding - /// adjoint buffers. - DenseMap, SILValue> bufferMap; + /// The differentiation invoker. + DifferentiationInvoker invoker; - /// Mapping from original basic blocks to corresponding pullback basic blocks. - /// Pullback basic blocks always have the predecessor as the single argument. - DenseMap pullbackBBMap; + /// Info from activity analysis on the original function. + const DifferentiableActivityInfo &activityInfo; - /// Mapping from pullback basic blocks to pullback struct arguments. - DenseMap pullbackStructArguments; + /// The differential info. + LinearMapInfo differentialInfo; - /// Mapping from pullback struct field declarations to pullback struct + bool errorOccurred = false; + + /// + /// Differential generation related fields. + /// + + /// The builder for the differential function. + SILBuilder differentialAndBuilder; + + /// Mapping from differential basic blocks to differential struct arguments. + DenseMap differentialStructArguments; + + /// Mapping from differential struct field declarations to differential struct /// elements destructured from the linear map basic block argument. In the - /// beginning of each pullback basic block, the block's pullback struct is + /// beginning of each differential basic block, the block's differential struct is /// destructured into individual elements stored here. - DenseMap pullbackStructElements; + DenseMap differentialStructElements; - /// Mapping from original basic blocks and successor basic blocks to - /// corresponding pullback trampoline basic blocks. Trampoline basic blocks - /// take additional arguments in addition to the predecessor enum argument. - DenseMap, SILBasicBlock *> - pullbackTrampolineBBMap; + /// Mapping from original basic blocks and original values to corresponding + /// tangent values. + DenseMap tangentValueMap; - /// Mapping from original basic blocks to dominated active values. - DenseMap> activeValues; + DenseMap diffBBMap; - /// Mapping from original basic blocks and original active values to - /// corresponding pullback block arguments. - DenseMap, SILArgument *> - activeValuePullbackBBArgumentMap; + /// Stack buffers allocated for storing local tangent values. + SmallVector differentialLocalAllocations; - /// Mapping from original basic blocks to local temporary values to be cleaned - /// up. This is populated when pullback emission is run on one basic block and - /// cleaned before processing another basic block. - DenseMap> - blockTemporaries; - llvm::DenseSet blockTemporarySet; + /// Mapping from original blocks to differential values. Used to build differential + /// struct instances. + DenseMap> differentialValues; - /// Stack buffers allocated for storing local adjoint values. - SmallVector functionLocalAllocations; - /// A set used to remember local allocations that were destroyed. - llvm::SmallDenseSet destroyedLocalAllocations; + /// Mapping from original basic blocks and original buffers to corresponding + /// tangent buffers. + DenseMap, SILValue> bufferMap; - /// The seed argument in the pullback function. - SILArgument *seed = nullptr; + /// An auxiliary differential local allocation builder. + SILBuilder diffLocalAllocBuilder; - /// The main builder. - SILBuilder builder; + //--------------------------------------------------------------------------// + // Getters + //--------------------------------------------------------------------------// - /// An auxiliary local allocation builder. - SILBuilder localAllocBuilder; + ASTContext &getASTContext() const { return jvp->getASTContext(); } + SILModule &getModule() const { return jvp->getModule(); } + const SILAutoDiffIndices &getIndices() const { return attr->getIndices(); } + SILFunction &getDifferential() { return differentialAndBuilder.getFunction(); } + SILBuilder &getDifferentialBuilder() { return differentialAndBuilder; } + SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) { +#ifndef NDEBUG + auto *diffStruct = differentialStructArguments[origBB]->getType() + .getStructOrBoundGenericStruct(); + assert(diffStruct == differentialInfo.getLinearMapStruct(origBB)); +#endif + return differentialStructArguments[origBB]; + } - llvm::BumpPtrAllocator allocator; + //--------------------------------------------------------------------------// + // Initialization helpers + //--------------------------------------------------------------------------// - bool errorOccurred = false; + static SubstitutionMap getSubstitutionMap(SILFunction *original, + SILFunction *jvp) { + auto substMap = original->getForwardingSubstitutionMap(); + if (auto *jvpGenEnv = jvp->getGenericEnvironment()) + substMap = substMap.subst(jvpGenEnv->getForwardingSubstitutionMap()); + return substMap; + } - ADContext &getContext() const { return vjpEmitter.context; } + /// Returns the activity info about the SILValues in the original function. + static const DifferentiableActivityInfo &getActivityInfo( + ADContext &context, SILFunction *original, + const SILAutoDiffIndices &indices, SILFunction *jvp) { + // Get activity info of the original function. + auto &passManager = context.getPassManager(); + auto *activityAnalysis = + passManager.getAnalysis(); + auto &activityCollection = *activityAnalysis->get(original); + auto &activityInfo = activityCollection.getActivityInfo( + jvp->getLoweredFunctionType()->getGenericSignature()); + LLVM_DEBUG( + dumpActivityInfo(*original, indices, activityInfo, getADDebugStream())); + return activityInfo; + } + + static SILBuilder + initializeDifferentialAndBuilder(ADContext &context, SILFunction *original, SILDifferentiableAttr *attr, + LinearMapInfo *linearMapInfo) { + auto *differential = + createEmptyDifferential(context, original, attr, linearMapInfo); + return SILBuilder(*differential); + } + + //--------------------------------------------------------------------------// + // Differential struct mapping + //--------------------------------------------------------------------------// + + void initializeDifferentialStructElements(SILBasicBlock *origBB, + SILInstructionResultArray values) { + auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB); + assert(diffStructDecl->getStoredProperties().size() == values.size() && + "The number of differential struct fields must equal the number of " + "differential struct element values"); + for (auto pair : llvm::zip(diffStructDecl->getStoredProperties(), values)) { + assert( + std::get<1>(pair).getOwnershipKind() != ValueOwnershipKind::Guaranteed + && "Differential struct elements must be @owned"); + auto insertion = differentialStructElements.insert({std::get<0>(pair), + std::get<1>(pair)}); + (void)insertion; + assert(insertion.second && "A differential struct element already exists!"); + } + } + + SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field) { + assert(differentialInfo.getLinearMapStruct(origBB) == + cast(field->getDeclContext())); + assert(differentialStructElements.count(field) && + "Differential struct element for this field does not exist!"); + return differentialStructElements.lookup(field); + } + + //--------------------------------------------------------------------------// + // General utilities + //--------------------------------------------------------------------------// + + SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint() { + // If there are no local allocations, insert at the beginning of the tangent + // entry. + if (differentialLocalAllocations.empty()) + return getDifferential().getEntryBlock()->begin(); + // Otherwise, insert before the last local allocation. Inserting before + // rather than after ensures that allocation and zero initialization + // instructions are grouped together. + auto lastLocalAlloc = differentialLocalAllocations.back(); + auto it = lastLocalAlloc->getDefiningInstruction()->getIterator(); + return it; + } + + /// Get the lowered SIL type of the given nominal type declaration. + SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) { + auto nomType = + getOpASTType(nominal->getDeclaredInterfaceType()->getCanonicalType()); + auto nomSILType = context.getTypeConverter().getLoweredType( + nomType, ResilienceExpansion::Minimal); + return nomSILType; + } + + /// Build a differential struct value for the original block corresponding to + /// the given terminator. + StructInst *buildDifferentialValueStructValue(TermInst *termInst) { + assert(termInst->getFunction() == original); + auto loc = termInst->getFunction()->getLocation(); + auto *origBB = termInst->getParent(); + auto *jvpBB = BBMap[origBB]; + assert(jvpBB && "Basic block mapping should exist"); + auto *diffStruct = differentialInfo.getLinearMapStruct(origBB); + assert(diffStruct && "The differential struct should have been declared"); + auto structLoweredTy = getNominalDeclLoweredType(diffStruct); + auto bbDifferentialValues = differentialValues[origBB]; + if (!origBB->isEntry()) { + auto *enumArg = jvpBB->getArguments().back(); + bbDifferentialValues.insert(bbDifferentialValues.begin(), enumArg); + } + return getBuilder().createStruct(loc, structLoweredTy, + bbDifferentialValues); + } + + bool shouldBeDifferentiated(SILInstruction *inst, + const SILAutoDiffIndices &indices) { + // Anything with an active result should be differentiated. + if (llvm::any_of(inst->getResults(), [&](SILValue val) { + return activityInfo.isActive(val, indices); + })) + return true; + // Anything with an an active argument should be differentiated + // (i.e. `return %0`). + if (llvm::any_of(inst->getAllOperands(), [&](Operand &val) { + return activityInfo.isActive(val.get(), indices); + })) + return true; + if (auto *ai = dyn_cast(inst)) { + // Function applications with an active indirect result should be + // differentiated. + for (auto indRes : ai->getIndirectSILResults()) + if (activityInfo.isActive(indRes, indices)) + return true; + // Function applications with an inout argument should be differentiated. + auto paramInfos = ai->getSubstCalleeConv().getParameters(); + for (auto i : swift::indices(paramInfos)) + if (paramInfos[i].isIndirectInOut() && + activityInfo.isActive( + ai->getArgumentsWithoutIndirectResults()[i], indices)) + return true; + } + // Instructions that may write to memory and that have an active operand + // should be differentiated. + if (inst->mayWriteToMemory()) + for (auto &op : inst->getAllOperands()) + if (activityInfo.isActive(op.get(), indices)) + return true; + return false; + } + + //--------------------------------------------------------------------------// + // Tangent value factory methods + //--------------------------------------------------------------------------// + + AdjointValue makeZeroTangentValue(SILType type) { + return AdjointValue::createZero(allocator, remapType(type)); + } + + AdjointValue makeConcreteTangentValue(SILValue value) { + return AdjointValue::createConcrete(allocator, value); + } + + //--------------------------------------------------------------------------// + // Tangent materialization + //--------------------------------------------------------------------------// + + void emitZeroIndirect(CanType type, SILValue bufferAccess, + SILLocation loc) { + auto builder = getDifferentialBuilder(); + auto tangentSpace = getTangentSpace(type); + assert(tangentSpace && "No tangent space for this type"); + switch (tangentSpace->getKind()) { + case VectorSpace::Kind::Vector: + emitZeroIntoBuffer(builder, type, bufferAccess, loc); + return; + case VectorSpace::Kind::Tuple: { + auto tupleType = tangentSpace->getTuple(); + SmallVector zeroElements; + for (unsigned i : range(tupleType->getNumElements())) { + auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i); + emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(), + eltAddr, loc); + } + return; + } + case VectorSpace::Kind::Function: { + llvm_unreachable( + "Unimplemented: Emit thunks for abstracting zero initialization"); + } + } + } + + SILValue emitZeroDirect(CanType type, SILLocation loc) { + auto diffBuilder = getDifferentialBuilder(); + auto silType = getModule().Types.getLoweredLoadableType( + type, ResilienceExpansion::Minimal); + auto *buffer = diffBuilder.createAllocStack(loc, silType); + emitZeroIndirect(type, buffer, loc); + auto loaded = diffBuilder.emitLoadValueOperation( + loc, buffer, LoadOwnershipQualifier::Take); + diffBuilder.createDeallocStack(loc, buffer); + return loaded; + } + + SILValue materializeTangentDirect(AdjointValue val, SILLocation loc) { + assert(val.getType().isObject()); + LLVM_DEBUG(getADDebugStream() + << "Materializing tangents for " << val << '\n'); + switch (val.getKind()) { + case AdjointValueKind::Zero: { + auto zeroVal = emitZeroDirect(val.getSwiftType(), loc); + return zeroVal; + } + case AdjointValueKind::Aggregate: + llvm_unreachable( + "Tuples and structs are not supported in forward mode yet."); + case AdjointValueKind::Concrete: + return val.getConcreteValue(); + } + } + + SILValue materializeTangent(AdjointValue val, SILLocation loc) { + if (val.isConcrete()) { + LLVM_DEBUG(getADDebugStream() + << "Materializing tangent: Value is concrete.\n"); + return val.getConcreteValue(); + } + LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is " + "non-concrete. Materializing directly.\n"); + return materializeTangentDirect(val, loc); + } + + //--------------------------------------------------------------------------// + // Tangent buffer mapping + //--------------------------------------------------------------------------// + + void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer, + SILValue tangentBuffer) { + assert(originalBuffer->getType().isAddress()); + auto insertion = + bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer); + assert(insertion.second); (void)insertion; + } + + SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { + assert(originalBuffer->getType().isAddress()); + assert(originalBuffer->getFunction() == original); + auto insertion = bufferMap.try_emplace({origBB, originalBuffer}, + SILValue()); + assert(!insertion.second && "tangent buffer should already exist"); + return insertion.first->getSecond(); + } + + //--------------------------------------------------------------------------// + // Type transformer + //--------------------------------------------------------------------------// + + Optional getTangentSpace(CanType type) { + return type->getAutoDiffAssociatedTangentSpace( + LookUpConformanceInModule(getModule().getSwiftModule())); + } + + /// Assuming the given type conforms to `Differentiable` after remapping, + /// returns the associated tangent space type. + SILType getRemappedTangentType(SILType type) { + return SILType::getPrimitiveType( + getTangentSpace(remapType(type).getASTType())->getCanonicalType(), + type.getCategory()); + } + + //--------------------------------------------------------------------------// + // Tngent value mapping + //--------------------------------------------------------------------------// + + /// Get the tangent for an original value. The given value must be in the + /// original function. + /// + /// This method first tries to find an entry in `tangentValueMap`. If a tangent + /// doesn't exist, create a zero tangent. + AdjointValue getTangentValue(SILValue originalValue) { + assert(originalValue->getType().isObject()); + assert(originalValue->getFunction() == original); + auto insertion = tangentValueMap.try_emplace( + originalValue, makeZeroTangentValue( + getRemappedTangentType(originalValue->getType()))); + return insertion.first->getSecond(); + } + + /// Map the tangent value to the given original value. + void setTangentValue(SILBasicBlock *origBB, SILValue originalValue, + AdjointValue newTangentValue) { + assert(originalValue->getType().isObject()); + assert(newTangentValue.getType().isObject()); + assert(originalValue->getFunction() == original); + LLVM_DEBUG(getADDebugStream() << "Adding tangent for " << originalValue); + // The tangent value must be in the tangent space. + assert(newTangentValue.getType() == + getRemappedTangentType(originalValue->getType())); + auto insertion = + tangentValueMap.try_emplace(originalValue, newTangentValue); + auto inserted = insertion.second; + assert(inserted && "The tangent value should not already exist."); + } + + //--------------------------------------------------------------------------// + // Tangent emission helpers + //--------------------------------------------------------------------------// + + void emitTangentForDestroyValueInst(DestroyValueInst *dvi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = dvi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(dvi->getOperand()), loc); + diffBuilder.emitDestroyValue(loc, tanVal); + } + + void emitTangentForBeginBorrow(BeginBorrowInst *bbi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = bbi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(bbi->getOperand()), loc); + auto tanValBorrow = diffBuilder.emitBeginBorrowOperation(loc, tanVal); + setTangentValue(bbi->getParent(), bbi, + makeConcreteTangentValue(tanValBorrow)); + } + + void emitTangentForEndBorrow(EndBorrowInst *ebi) { + auto &diffBuilder = getDifferentialBuilder(); + auto loc = ebi->getLoc(); + auto tanVal = materializeTangent(getTangentValue(ebi->getOperand()), loc); + diffBuilder.emitEndBorrowOperation(loc, tanVal); + } + + void emitTangentForCopyValueInst(CopyValueInst *cvi) { + auto &diffBuilder = getDifferentialBuilder(); + auto tan = getTangentValue(cvi->getOperand()); + auto tanVal = materializeTangent(tan, cvi->getLoc()); + auto tanValCopy = diffBuilder.emitCopyValueOperation(cvi->getLoc(), tanVal); + setTangentValue(cvi->getParent(), cvi, + makeConcreteTangentValue(tanValCopy)); + } + + void emitTangentForReturnInst(ReturnInst *ri) { + auto loc = ri->getOperand().getLoc(); + auto diffBuilder = getDifferentialBuilder(); + // This vector will contain all the materialized return elements. + SmallVector retElts; + // This vector will contain all indirect parameter tangent buffers. + // TODO: Handle indirect results. + auto tanParam = + materializeTangent(getTangentValue(ri->getOperand()), loc); + diffBuilder.createReturn(ri->getLoc(), tanParam); + } + + void emitTangentForApplyInst(ApplyInst *ai, SILAutoDiffIndices indices) { + auto *bb = ai->getParent(); + auto loc = ai->getLoc(); + auto diffBuilder = getDifferentialBuilder(); + + // Get the differential. + auto *field = differentialInfo.lookUpLinearMapDecl(ai); + assert(field); + SILValue differential = getDifferentialStructElement(bb, field); + + SmallVector diffArgs; + for (auto origArg : ai->getArguments()) { + // Get the tangent value of the original parameter. + if (!activityInfo.isActive(origArg, getIndices())) + continue; + SILValue tanParam; + if (origArg->getType().isObject()) { + tanParam = materializeTangent(getTangentValue(origArg), loc); + } else { + tanParam = getTangentBuffer(ai->getParent(), origArg); + if (errorOccurred) + return; + } + diffArgs.push_back(tanParam); + } + + // Call the differential. + auto *differentialCall = diffBuilder.createApply( + loc, differential, SubstitutionMap(), diffArgs, + /*isNonThrowing*/ false); + diffBuilder.emitDestroyValueOperation(loc, differential); + assert(differentialCall->getNumResults() == 1 && + "Expected differential to return one result"); + + // TODO: Generalize for indirect results, multiple results, etc + auto origResult = ai->getResult(indices.source); + + // Extract all direct results from the differential. + SmallVector differentialDirResults; + extractAllElements(differentialCall, diffBuilder, differentialDirResults); + // Get all differential results in type-defined order. + SmallVector differentialAllResults; + collectAllActualResultsInTypeOrder( + differentialCall, differentialDirResults, + differentialCall->getIndirectSILResults(), differentialAllResults); + auto differentialResult = differentialAllResults[indices.source]; + + // Add tangent for original result. + assert(indices.source == 0 && "Expected result index to be first."); + setTangentValue(bb, origResult, + makeConcreteTangentValue(differentialResult)); + } + +public: + explicit JVPEmitter(ADContext &context, SILFunction *original, + SILDifferentiableAttr *attr, SILFunction *jvp, + DifferentiationInvoker invoker) + : TypeSubstCloner(*jvp, *original, getSubstitutionMap(original, jvp)), + context(context), original(original), attr(attr), jvp(jvp), + invoker(invoker), activityInfo(getActivityInfo( + context, original, attr->getIndices(), jvp)), + differentialInfo(context, AutoDiffAssociatedFunctionKind::JVP, original, + jvp, attr->getIndices(), activityInfo, getBuilder()), + differentialAndBuilder(initializeDifferentialAndBuilder( + context, original, attr, &differentialInfo)), + diffLocalAllocBuilder(getDifferential()) { + // Get JVP generic signature. + CanGenericSignature jvpGenSig = nullptr; + if (auto *jvpGenEnv = jvp->getGenericEnvironment()) + jvpGenSig = jvpGenEnv->getGenericSignature()->getCanonicalSignature(); + // Create empty differential function. + context.getGeneratedFunctions().push_back(&getDifferential()); + } + + static SILFunction *createEmptyDifferential(ADContext &context, + SILFunction *original, + SILDifferentiableAttr *attr, + LinearMapInfo *linearMapInfo) { + auto &module = context.getModule(); + auto origTy = original->getLoweredFunctionType(); + auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule()); + + // RAII that pushes the original function's generic signature to + // `module.Types` so that calls to `module.Types.getTypeLowering()` below + // will know the original function's generic parameter types. + Lowering::GenericContextScope genericContextScope( + module.Types, origTy->getGenericSignature()); + + // Parameters of the differential are: + // - the tangent values of the wrt parameters. + // - the differential struct for the original entry. + // Result of the differential is in the tangent space of the original + // result. + SmallVector dfParams; + SmallVector dfResults; + auto origParams = origTy->getParameters(); + auto indices = attr->getIndices(); + + // Add differential results. + auto origResInfo = origTy->getResults()[indices.source]; + dfResults.push_back( + SILResultInfo(origResInfo.getType() + ->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getCanonicalType(), + origResInfo.getConvention())); + + // Add differential parameters for the requested wrt parameters. + for (auto i : indices.parameters->getIndices()) { + auto origParam = origParams[i]; + dfParams.push_back(SILParameterInfo( + origParam.getType() + ->getAutoDiffAssociatedTangentSpace(lookupConformance) + ->getCanonicalType(), + origParam.getConvention())); + } + + // Accept a differential struct in the differential parameter list. This is + // the returned differential's closure context. + auto *origEntry = original->getEntryBlock(); + auto *dfStruct = linearMapInfo->getLinearMapStruct(origEntry); + auto dfStructType = + dfStruct->getDeclaredInterfaceType()->getCanonicalType(); + dfParams.push_back({dfStructType, ParameterConvention::Direct_Owned}); + + Mangle::ASTMangler mangler; + auto diffName = original->getASTContext().getIdentifier( + mangler.mangleAutoDiffLinearMapHelper( + original->getName(), AutoDiffLinearMapKind::Differential, + indices)).str(); + auto diffGenericSig = getAssociatedFunctionGenericSignature(attr, original); + auto *diffGenericEnv = + diffGenericSig ? diffGenericSig->createGenericEnvironment() : nullptr; + auto diffType = SILFunctionType::get( + diffGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(), + origTy->getCalleeConvention(), dfParams, {}, dfResults, None, + original->getASTContext()); + + SILOptFunctionBuilder fb(context.getTransform()); + // The generated tangent linkage is set to Hidden because generated tangent + // are never called cross-module. + auto linkage = SILLinkage::Hidden; + auto *differential = fb.createFunction( + linkage, diffName, diffType, diffGenericEnv, original->getLocation(), + original->isBare(), IsNotTransparent, original->isSerialized(), + original->isDynamicallyReplaceable()); + differential->setDebugScope( + new (module) SILDebugScope(original->getLocation(), differential)); + + return differential; + } + + /// Set up the differential function. This includes: + /// - Creating all the differential blocks. + /// - Create arguments for the entry block according to the function type. + /// - Adding the tangent values of the parameters to the tangent value map. + /// - Checking for unvaried result and emitting related warnings. + void prepareForDifferentialGeneration() { + auto &diffBuilder = getDifferentialBuilder(); + + // Create differential blocks and arguments. + // TODO: Consider visiting original blocks in pre-order (dominance) order. + auto &differential = getDifferential(); + auto *origEntry = original->getEntryBlock(); + for (auto &origBB : *original) { + auto *diffBB = differential.createBasicBlock(); + diffBBMap.insert({&origBB, diffBB}); + auto diffStructLoweredType = + remapType(differentialInfo.getLinearMapStructLoweredType(&origBB)); + // If the BB is the original entry, then the differential block that we + // just created must be the differential function's entry. Create + // differential entry arguments and continue. + if (&origBB == origEntry) { + assert(diffBB->isEntry()); + createEntryArguments(&differential); + auto *mainDifferentialStruct = diffBB->getArguments().back(); + assert(mainDifferentialStruct->getType() == diffStructLoweredType); + differentialStructArguments[&origBB] = mainDifferentialStruct; + } + + LLVM_DEBUG({ + auto &s = getADDebugStream() + << "Original bb" + std::to_string(origBB.getDebugID()) + << ": To differentiate or not to differentiate?\n"; + for (auto &inst : origBB) { + s << (shouldBeDifferentiated(&inst, getIndices()) ? "[∂] " : "[ ] ") + << inst; + } + }); + } + + assert(diffBBMap.size() == 1 && + "Can only currently handle single basic block functions"); + + // The differential function has type: + // (arg0', ..., argn', entry_df_struct) -> result'. + auto diffParamArgs = + differential.getArgumentsWithoutIndirectResults().drop_back(); + assert(diffParamArgs.size() == + attr->getIndices().parameters->getNumIndices()); + auto origParamArgs = original->getArgumentsWithoutIndirectResults(); + + // Check if result is not varied. + SmallVector origFormalResults; + collectAllFormalResultsInTypeOrder(*original, origFormalResults); + auto origResult = origFormalResults[getIndices().source]; + // Emit warning if original result is not varied, because it will always + // have a zero derivative. + if (!activityInfo.isVaried(origResult, getIndices().parameters)) { + // Emit fixit if original result has a valid source location. + auto startLoc = origResult.getLoc().getStartSourceLoc(); + auto endLoc = origResult.getLoc().getEndSourceLoc(); + if (startLoc.isValid() && endLoc.isValid()) { + context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) + .fixItInsert(startLoc, "withoutDerivative(at:") + .fixItInsertAfter(endLoc, ")"); + } + } + + auto *diffEntry = getDifferential().getEntryBlock(); + diffBuilder.setInsertionPoint( + diffEntry, getNextDifferentialLocalAllocationInsertionPoint()); + + for (auto index : *getIndices().parameters) { + auto diffParam = diffParamArgs[index]; + auto origParam = origParamArgs[index]; + setTangentValue(origEntry, origParam, + makeConcreteTangentValue(diffParam)); + LLVM_DEBUG(getADDebugStream() + << "Assigned parameter " << *diffParam + << " as the tangent of original result " << origParam); + } + } + + /// Run JVP generation. Returns true on error. + bool run() { + LLVM_DEBUG(getADDebugStream() + << "Cloning original @" << original->getName() + << " to jvp @" << jvp->getName() << '\n'); + // Create JVP and differential entry and arguments. + auto *entry = jvp->createBasicBlock(); + createEntryArguments(jvp); + prepareForDifferentialGeneration(); + // Clone. + SmallVector entryArgs(entry->getArguments().begin(), + entry->getArguments().end()); + cloneFunctionBody(original, entry, entryArgs); + // If errors occurred, back out. + if (errorOccurred) + return true; + LLVM_DEBUG(getADDebugStream() << "Generated JVP for " + << original->getName() << ":\n" << *jvp); + LLVM_DEBUG(getADDebugStream() << "Generated differential for " + << original->getName() << ":\n" << getDifferential()); + return errorOccurred; + } + + void postProcess(SILInstruction *orig, SILInstruction *cloned) { + if (errorOccurred) + return; + SILClonerWithScopes::postProcess(orig, cloned); + } + + /// Remap original basic blocks. + SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { + auto *jvpBB = BBMap[bb]; + return jvpBB; + } + + /// General visitor for all instructions. If any error is emitted by previous + /// visits, bail out. + void visit(SILInstruction *inst) { + if (errorOccurred) + return; + TypeSubstCloner::visit(inst); + } + + void visitSILInstruction(SILInstruction *inst) { + context.emitNondifferentiabilityError(inst, invoker, + diag::autodiff_expression_not_differentiable_note); + errorOccurred = true; + } + + /// Handle `copy_value` instruction. + /// Original: y = copy_value x + /// Adjoint: tan[x] = copy_value tan[y] + void visitCopyValueInst(CopyValueInst *cvi) { + TypeSubstCloner::visitCopyValueInst(cvi); + emitTangentForCopyValueInst(cvi); + } + + void visitReturnInst(ReturnInst *ri) { + auto loc = ri->getOperand().getLoc(); + auto *origExit = ri->getParent(); + auto &builder = getBuilder(); + auto *diffStructVal = buildDifferentialValueStructValue(ri); + + // Get the value in the JVP corresponding to the original result. + auto *origRetInst = cast(origExit->getTerminator()); + auto origResult = getOpValue(origRetInst->getOperand()); + SmallVector origResults; + extractAllElements(origResult, builder, origResults); + + // Get and partially apply the differential. + auto jvpGenericEnv = jvp->getGenericEnvironment(); + auto jvpSubstMap = jvpGenericEnv + ? jvpGenericEnv->getForwardingSubstitutionMap() + : jvp->getForwardingSubstitutionMap(); + auto *differentialRef = builder.createFunctionRef(loc, &getDifferential()); + auto *differentialPartialApply = builder.createPartialApply( + loc, differentialRef, jvpSubstMap, {diffStructVal}, + ParameterConvention::Direct_Guaranteed); + + // Return a tuple of the original result and differential. + SmallVector directResults; + directResults.append(origResults.begin(), origResults.end()); + directResults.push_back(differentialPartialApply); + builder.createReturn( + ri->getLoc(), joinElements(directResults, builder, loc)); + + // Differential emission. + emitTangentForReturnInst(ri); + } + + void visitInstructionsInBlock(SILBasicBlock *bb) { + // Destructure the differential struct to get the elements. + auto &diffBuilder = getDifferentialBuilder(); + auto diffLoc = getDifferential().getLocation(); + auto *diffBB = diffBBMap.lookup(bb); + auto *mainDifferentialStruct = diffBB->getArguments().back(); + diffBuilder.setInsertionPoint(diffBB); + auto *dsi = diffBuilder.createDestructureStruct( + diffLoc, mainDifferentialStruct); + initializeDifferentialStructElements(bb, dsi->getResults()); + TypeSubstCloner::visitInstructionsInBlock(bb); + } + + void visitDestroyValueInst(DestroyValueInst *dvi) { + TypeSubstCloner::visitDestroyValueInst(dvi); + if (shouldBeDifferentiated(dvi, getIndices())) + emitTangentForDestroyValueInst(dvi); + } + + void visitBeginBorrowInst(BeginBorrowInst *bbi) { + TypeSubstCloner::visitBeginBorrowInst(bbi); + if (shouldBeDifferentiated(bbi, getIndices())) + emitTangentForBeginBorrow(bbi); + } + + void visitEndBorrowInst(EndBorrowInst *ebi) { + TypeSubstCloner::visitEndBorrowInst(ebi); + if (shouldBeDifferentiated(ebi, getIndices())) + emitTangentForEndBorrow(ebi); + } + + // If an `apply` has active results or active inout parameters, replace it + // with an `apply` of its JVP. + void visitApplyInst(ApplyInst *ai) { + // Special handling logic only applies when `apply` has active results or + // active arguments at an active parameter position. If not, just do + // standard cloning. + SmallVector allResults; + allResults.push_back(ai); + allResults.append(ai->getIndirectSILResults().begin(), + ai->getIndirectSILResults().end()); + auto hasActiveResults = llvm::any_of(allResults, [this](SILValue res) { + return activityInfo.isActive(res, getIndices()); + }); + auto hasActiveArguments = llvm::any_of( + ai->getArgumentsWithoutIndirectResults(), [this](SILValue arg) { + return activityInfo.isActive(arg, getIndices()); + }); + // Check for active 'inout' arguments. + auto paramInfos = ai->getSubstCalleeConv().getParameters(); + for (unsigned i : swift::indices(paramInfos)) { + if (paramInfos[i].isIndirectInOut() && + activityInfo.isActive(ai->getArgumentsWithoutIndirectResults()[i], + getIndices())) { + // Reject functions with active inout arguments. It's not yet supported. + context.emitNondifferentiabilityError( + ai, invoker, + diag::autodiff_cannot_differentiate_through_inout_arguments); + errorOccurred = true; + return; + } + } + + // If there's no active results, this function should not be differentiated. + // Do standard cloning. + if (!hasActiveResults || !hasActiveArguments) { + LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n'); + TypeSubstCloner::visitApplyInst(ai); + return; + } + + // Get the parameter indices required for differentiating this function. + LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n'); + SmallVector activeParamIndices; + SmallVector activeResultIndices; + collectMinimalIndicesForFunctionCall(ai, allResults, getIndices(), + activityInfo, activeParamIndices, + activeResultIndices); + assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); + assert(!activeResultIndices.empty() && "Result indices cannot be empty"); + LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; + interleave(activeParamIndices.begin(), activeParamIndices.end(), + [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); + s << "}, results={"; interleave( + activeResultIndices.begin(), activeResultIndices.end(), + [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); + s << "}\n";); + // FIXME: We don't support multiple active results yet. + if (activeResultIndices.size() > 1) { + context.emitNondifferentiabilityError( + ai, invoker, diag::autodiff_expression_not_differentiable_note); + errorOccurred = true; + return; + } + + // Form expected indices by assuming there's only one result. + SILAutoDiffIndices indices( + activeResultIndices.front(), + AutoDiffIndexSubset::get( + getASTContext(), ai->getArgumentsWithoutIndirectResults().size(), + activeParamIndices)); + + // Emit the JVP. + auto loc = ai->getLoc(); + auto &builder = getBuilder(); + auto original = getOpValue(ai->getCallee()); + SILValue jvpValue; + // If functionSource is a `@differentiable` function, just extract it. + auto originalFnTy = original->getType().castTo(); + if (originalFnTy->isDifferentiable()) { + auto paramIndices = originalFnTy->getDifferentiationParameterIndices(); + for (auto i : indices.parameters->getIndices()) { + if (!paramIndices->contains(i)) { + context.emitNondifferentiabilityError(original, invoker, + diag::autodiff_function_nondiff_parameter_not_differentiable); + errorOccurred = true; + return; + } + } + auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original); + jvpValue = builder.createAutoDiffFunctionExtract( + loc, AutoDiffFunctionExtractInst::Extractee::JVP, + /*differentiationOrder*/ 1, borrowedDiffFunc); + jvpValue = builder.emitCopyValueOperation(loc, jvpValue); + } + + // If JVP has not yet been found, emit an `autodiff_function` instruction + // on the remapped original function operand and `autodiff_function_extract` + // the JVP. The actual JVP functions will be populated in the + // `autodiff_function` during the transform main loop. + SILValue differentiableFunc; + if (!jvpValue) { + // FIXME: Handle indirect differentiation invokers. This may require some + // redesign: currently, each original function + attribute pair is mapped + // only to one invoker. + /* + DifferentiationInvoker indirect(ai, attr); + auto insertion = + context.getInvokers().try_emplace({this->original, attr}, indirect); + auto &invoker = insertion.first->getSecond(); + invoker = indirect; + */ + + // If the original `apply` instruction has a substitution map, then the + // applied function is specialized. + // In the JVP, specialization is also necessary for parity. The original + // function operand is specialized with a remapped version of same + // substitution map using an argument-less `partial_apply`. + if (ai->getSubstitutionMap().empty()) { + original = builder.emitCopyValueOperation(loc, original); + } else { + auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap()); + auto jvpPartialApply = getBuilder().createPartialApply( + ai->getLoc(), original, substMap, {}, + ParameterConvention::Direct_Guaranteed); + original = jvpPartialApply; + } + + // Check and diagnose non-differentiable original function type. + auto diagnoseNondifferentiableOriginalFunctionType = + [&](CanSILFunctionType origFnTy) { + // Check and diagnose non-differentiable arguments. + for (unsigned paramIndex : range(originalFnTy->getNumParameters())) { + if (indices.isWrtParameter(paramIndex) && + !originalFnTy->getParameters()[paramIndex] + .getSILStorageType() + .isDifferentiable(getModule())) { + context.emitNondifferentiabilityError( + ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker, + diag::autodiff_nondifferentiable_argument); + errorOccurred = true; + return true; + } + } + // Check and diagnose non-differentiable results. + if (!originalFnTy->getResults()[indices.source] + .getSILStorageType() + .isDifferentiable(getModule())) { + context.emitNondifferentiabilityError( + original, invoker, diag::autodiff_nondifferentiable_result); + errorOccurred = true; + return true; + } + return false; + }; + if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy)) + return; + + auto *autoDiffFuncInst = + context.createAutoDiffFunction(builder, loc, indices.parameters, + /*differentiationOrder*/ 1, original); + differentiableFunc = autoDiffFuncInst; + + // Record the `autodiff_function` instruction. + context.getAutoDiffFunctionInsts().push_back(autoDiffFuncInst); + context.getResultIndices()[autoDiffFuncInst] = + activeResultIndices.front(); + + jvpValue = builder.createAutoDiffFunctionExtract( + loc, AutoDiffFunctionExtractInst::Extractee::JVP, + /*differentiationOrder*/ 1, autoDiffFuncInst); + } + + // Call the JVP using the original parameters. + SmallVector jvpArgs; + auto jvpFnTy = getOpType(jvpValue->getType()).castTo(); + auto numJVPArgs = + jvpFnTy->getNumParameters() + jvpFnTy->getNumIndirectFormalResults(); + jvpArgs.reserve(numJVPArgs); + // Collect substituted arguments. + for (auto origArg : ai->getArguments()) + jvpArgs.push_back(getOpValue(origArg)); + assert(jvpArgs.size() == numJVPArgs); + // Apply the JVP. + // The JVP should be specialized, so no substitution map is necessary. + auto *jvpCall = getBuilder().createApply(loc, jvpValue, SubstitutionMap(), + jvpArgs, ai->isNonThrowing()); + LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall); + + // Release the differentiable function. + builder.emitDestroyValueOperation(loc, jvpValue); + + // Get the JVP results (original results and differential). + SmallVector jvpDirectResults; + extractAllElements(jvpCall, builder, jvpDirectResults); + auto originalDirectResults = + ArrayRef(jvpDirectResults).drop_back(1); + auto originalDirectResult = + joinElements(originalDirectResults, getBuilder(), jvpCall->getLoc()); + + mapValue(ai, originalDirectResult); + + // Some instructions that produce the callee may have been cloned. + // If the original callee did not have any users beyond this `apply`, + // recursively kill the cloned callee. + if (auto *origCallee = cast_or_null( + ai->getCallee()->getDefiningInstruction())) + if (origCallee->hasOneUse()) + recursivelyDeleteTriviallyDeadInstructions( + getOpValue(origCallee)->getDefiningInstruction()); + + // Record the callee differential function value. + // This is used later to construct a differential struct. + auto diffFunc = jvpDirectResults.back(); + differentialValues[ai->getParent()].push_back(diffFunc); + + // Differential emission. + emitTangentForApplyInst(ai, indices); + } + + void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { + // Clone `autodiff_function` from original to JVP, then add the cloned + // instruction to the `autodiff_function` worklist. + TypeSubstCloner::visitAutoDiffFunctionInst(adfi); + auto *newADFI = cast(getOpValue(adfi)); + context.getAutoDiffFunctionInsts().push_back(newADFI); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// PullbackEmitter - visitors on the original function for pullback code +// generation +//===----------------------------------------------------------------------===// + +namespace { +class PullbackEmitter final : public SILInstructionVisitor { +private: + /// The parent VJP emitter. + VJPEmitter &vjpEmitter; + + /// Dominance info for the original function. + DominanceInfo *domInfo = nullptr; + + /// Post-dominance info for the original function. + PostDominanceInfo *postDomInfo = nullptr; + + /// Post-order info for the original function. + PostOrderFunctionInfo *postOrderInfo = nullptr; + + /// Mapping from original basic blocks and original values to corresponding + /// adjoint values. + DenseMap, AdjointValue> valueMap; + + /// Mapping from original basic blocks and original buffers to corresponding + /// adjoint buffers. + DenseMap, SILValue> bufferMap; + + /// Mapping from original basic blocks to corresponding pullback basic blocks. + /// Pullback basic blocks always have the predecessor as the single argument. + DenseMap pullbackBBMap; + + /// Mapping from pullback basic blocks to pullback struct arguments. + DenseMap pullbackStructArguments; + + /// Mapping from pullback struct field declarations to pullback struct + /// elements destructured from the linear map basic block argument. In the + /// beginning of each pullback basic block, the block's pullback struct is + /// destructured into individual elements stored here. + DenseMap pullbackStructElements; + + /// Mapping from original basic blocks and successor basic blocks to + /// corresponding pullback trampoline basic blocks. Trampoline basic blocks + /// take additional arguments in addition to the predecessor enum argument. + DenseMap, SILBasicBlock *> + pullbackTrampolineBBMap; + + /// Mapping from original basic blocks to dominated active values. + DenseMap> activeValues; + + /// Mapping from original basic blocks and original active values to + /// corresponding pullback block arguments. + DenseMap, SILArgument *> + activeValuePullbackBBArgumentMap; + + /// Mapping from original basic blocks to local temporary values to be cleaned + /// up. This is populated when pullback emission is run on one basic block and + /// cleaned before processing another basic block. + DenseMap> + blockTemporaries; + + llvm::DenseSet blockTemporarySet; + + /// Stack buffers allocated for storing local adjoint values. + SmallVector functionLocalAllocations; + /// A set used to remember local allocations that were destroyed. + llvm::SmallDenseSet destroyedLocalAllocations; + + /// The seed argument in the pullback function. + SILArgument *seed = nullptr; + + /// The main builder. + SILBuilder builder; + + /// An auxiliary local allocation builder. + SILBuilder localAllocBuilder; + + llvm::BumpPtrAllocator allocator; + + bool errorOccurred = false; + + ADContext &getContext() const { return vjpEmitter.context; } SILModule &getModule() const { return getContext().getModule(); } ASTContext &getASTContext() const { return getPullback().getASTContext(); } SILFunction &getOriginal() const { return *vjpEmitter.original; } @@ -4267,16 +5096,6 @@ class PullbackEmitter final : public SILInstructionVisitor { // Pullback struct mapping //--------------------------------------------------------------------------// - SILArgument *getPullbackBlockPullbackStructArgument(SILBasicBlock *origBB) { -#ifndef NDEBUG - assert(origBB->getParent() == &getOriginal()); - auto *pbStruct = pullbackStructArguments[origBB]->getType() - .getStructOrBoundGenericStruct(); - assert(pbStruct == getPullbackInfo().getLinearMapStruct(origBB)); -#endif - return pullbackStructArguments[origBB]; - } - void initializePullbackStructElements(SILBasicBlock *origBB, SILInstructionResultArray values) { auto *pbStructDecl = getPullbackInfo().getLinearMapStruct(origBB); @@ -4565,8 +5384,7 @@ class PullbackEmitter final : public SILInstructionVisitor { // rather than after ensures that allocation and zero initialization // instructions are grouped together. auto lastLocalAlloc = functionLocalAllocations.back(); - auto it = lastLocalAlloc->getDefiningInstruction()->getIterator(); - return it; + return lastLocalAlloc->getDefiningInstruction()->getIterator(); } SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer) { @@ -4862,8 +5680,7 @@ class PullbackEmitter final : public SILInstructionVisitor { auto startLoc = origResult.getLoc().getStartSourceLoc(); auto endLoc = origResult.getLoc().getEndSourceLoc(); if (startLoc.isValid() && endLoc.isValid()) { - getContext() - .diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) + getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit) .fixItInsert(startLoc, "withoutDerivative(at:") .fixItInsertAfter(endLoc, ")"); } @@ -6312,25 +7129,6 @@ bool VJPEmitter::run() { return errorOccurred; } -bool JVPEmitter::run() { - LLVM_DEBUG(getADDebugStream() - << "Cloning original @" << original->getName() - << " to jvp @" << jvp->getName() << '\n'); - // Create entry BB and arguments. - auto *entry = jvp->createBasicBlock(); - createEntryArguments(jvp); - SmallVector entryArgs(entry->getArguments().begin(), - entry->getArguments().end()); - cloneFunctionBody(original, entry, entryArgs); - // If errors occurred, back out. - if (errorOccurred) - return true; - - LLVM_DEBUG(getADDebugStream() << "Generated JVP for " - << original->getName() << ":\n" << *jvp); - return errorOccurred; -} - //===----------------------------------------------------------------------===// // `[differentiable]` attribute processing //===----------------------------------------------------------------------===// @@ -6380,7 +7178,7 @@ static SILFunction *createEmptyVJP( auto vjpGenericSig = getAssociatedFunctionGenericSignature(attr, original); // RAII that pushes the original function's generic signature to - // `module.Types` so that the calls `module.Types.getTypeLowering()` below + // `module.Types` so that calls to `module.Types.getTypeLowering()` below // will know the VJP's generic parameter types. Lowering::GenericContextScope genericContextScope( module.Types, vjpGenericSig); @@ -6430,7 +7228,7 @@ static SILFunction *createEmptyJVP( auto jvpGenericSig = getAssociatedFunctionGenericSignature(attr, original); // RAII that pushes the original function's generic signature to - // `module.Types` so that the calls `module.Types.getTypeLowering()` below + // `module.Types` so that calls to `module.Types.getTypeLowering()` below // will know the VJP's generic parameter types. Lowering::GenericContextScope genericContextScope( module.Types, jvpGenericSig); @@ -6513,51 +7311,36 @@ bool ADContext::processDifferentiableAttribute( attr->setVJPName(vjpName); } - // If the JVP doesn't exist, need to synthesize it. - auto vjpGenerated = false; - if (!vjp) { - // Diagnose: - // - Functions with no return. - // - Functions with unsupported control flow. - if (diagnoseNoReturn(*this, original, invoker) || - diagnoseUnsupportedControlFlow(*this, original, invoker)) - return true; - - vjpGenerated = true; - vjp = createEmptyVJP(*this, original, attr, isAssocFnExported); - getGeneratedFunctions().push_back(vjp); - VJPEmitter emitter(*this, original, attr, vjp, invoker); - if (emitter.run()) { - return true; - } - } - // If the JVP doesn't exist, need to synthesize it. if (!jvp) { // Diagnose: // - Functions with no return. // - Functions with unsupported control flow. - if (vjpGenerated && (diagnoseNoReturn(*this, original, invoker) || + if (RunJVPGeneration && (diagnoseNoReturn(*this, original, invoker) || diagnoseUnsupportedControlFlow(*this, original, invoker))) return true; jvp = createEmptyJVP(*this, original, attr, isAssocFnExported); getGeneratedFunctions().push_back(jvp); - if (vjpGenerated) { + // For now, only run JVP emission if the flag is on and if there is no + // user defined VJP. If there is a user defined VJP but no JVP, that means + // the user should have provided a custom JVP as well since we likely + // cannot derive a custom JVP. Thus create empty body. + if (RunJVPGeneration && !vjp) { JVPEmitter emitter(*this, original, attr, jvp, invoker); - return emitter.run(); + if (emitter.run()) + return true; } else { LLVM_DEBUG(getADDebugStream() << "Generating empty JVP for original @" << original->getName() << '\n'); // Create empty body of JVP if the user defined their own custom VJP. - // Return undef. auto *entry = jvp->createBasicBlock(); createEntryArguments(jvp); - auto diffConv = jvp->getConventions(); SILBuilder builder(entry); auto loc = jvp->getLocation(); + // Destroy all owned arguments. for (auto *arg : entry->getArguments()) { if (arg->getOwnershipKind() == ValueOwnershipKind::Owned) { @@ -6567,15 +7350,46 @@ bool ADContext::processDifferentiableAttribute( builder.emitDestroyAddr(loc, arg); } } - // Return an `undef`. - builder.createReturn(loc, SILUndef::get( - jvp->mapTypeIntoContext(diffConv.getSILResultType()), - *jvp)); + + // Add a fatal error in case this function is called by the user. + auto neverResultInfo = SILResultInfo( + module.getASTContext().getNeverType(), ResultConvention::Unowned); + auto fatalErrorJVPType = SILFunctionType::get( + /*genericSig*/ nullptr, + SILFunctionType::ExtInfo().withRepresentation( + SILFunctionTypeRepresentation::Thin), + SILCoroutineKind::None, ParameterConvention::Direct_Unowned, {}, + /*interfaceYields*/ {}, neverResultInfo, + /*interfaceErrorResults*/ None, getASTContext()); + auto fnBuilder = SILOptFunctionBuilder(getTransform()); + auto *fatalErrrorJvpFunc = fnBuilder.getOrCreateFunction( + loc, "_printJVPErrorAndExit", SILLinkage::PublicExternal, + fatalErrorJVPType, IsNotBare, IsNotTransparent, IsNotSerialized, + IsNotDynamic, ProfileCounter(), IsNotThunk); + auto *jvpErrorFuncRef = + builder.createFunctionRef(loc, fatalErrrorJvpFunc); + builder.createApply(loc, jvpErrorFuncRef, SubstitutionMap(), {}); + builder.createUnreachable(loc); LLVM_DEBUG(getADDebugStream() << "Generated empty JVP for " << original->getName() << ":\n" << *jvp); } } + // If the VJP doesn't exist, need to synthesize it. + if (!vjp) { + // Diagnose: + // - Functions with no return. + // - Functions with unsupported control flow. + if (diagnoseNoReturn(*this, original, invoker) || + diagnoseUnsupportedControlFlow(*this, original, invoker)) + return true; + + vjp = createEmptyVJP(*this, original, attr, isAssocFnExported); + getGeneratedFunctions().push_back(vjp); + VJPEmitter emitter(*this, original, attr, vjp, invoker); + return emitter.run(); + } + return false; } diff --git a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift index 05150d8fe2175..ba4c3ff02327f 100644 --- a/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift +++ b/stdlib/private/DifferentiationUnittest/DifferentiationUnittest.swift @@ -65,12 +65,12 @@ public struct Tracked { } private var handle: Box - @differentiable(vjp: _vjpInit where T : Differentiable, T == T.TangentVector) + @differentiable(jvp: _jvpInit, vjp: _vjpInit where T : Differentiable, T == T.TangentVector) public init(_ value: T) { self.handle = Box(value) } - @differentiable(vjp: _vjpValue where T : Differentiable, T == T.TangentVector) + @differentiable(jvp: _jvpValue, vjp: _vjpValue where T : Differentiable, T == T.TangentVector) public var value: T { get { handle.value } set { handle.value = newValue } @@ -177,10 +177,21 @@ extension Tracked where T : Differentiable, T == T.TangentVector { return (Tracked(value), { v in v.value }) } + @usableFromInline + internal static func _jvpInit(_ value: T) + -> (value: Self, differential: (T.TangentVector) -> (Self.TangentVector)) { + return (Tracked(value), { v in Tracked(v) }) + } + @usableFromInline internal func _vjpValue() -> (T, (T.TangentVector) -> Self.TangentVector) { return (value, { v in Tracked(v) }) } + + @usableFromInline + internal func _jvpValue() -> (T, (Self.TangentVector) -> T.TangentVector) { + return (value, { v in v.value }) + } } extension Tracked where T : Differentiable, T == T.TangentVector { @@ -197,6 +208,20 @@ extension Tracked where T : Differentiable, T == T.TangentVector { -> (value: Self, pullback: (Self) -> (Self, Self)) { return (lhs - rhs, { v in (v, .zero - v) }) } + + @usableFromInline + @differentiating(+) + internal static func _vjpAdd(lhs: Self, rhs: Self) + -> (value: Self, differential: (Self, Self) -> (Self)) { + return (lhs + rhs, { (dx, dy) in dx + dy }) + } + + @usableFromInline + @differentiating(-) + internal static func _vjpSubtract(lhs: Self, rhs: Self) + -> (value: Self, differential: (Self, Self) -> (Self)) { + return (lhs - rhs, { (dx, dy) in dx - dy }) + } } extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude, @@ -207,6 +232,13 @@ extension Tracked where T : Differentiable & SignedNumeric, T == T.Magnitude, -> (value: Self, pullback: (Self) -> (Self, Self)) { return (lhs * rhs, { v in (v * rhs, v * lhs) }) } + + @usableFromInline + @differentiating(*) + internal static func _vjpMultiply(lhs: Self, rhs: Self) + -> (value: Self, differential: (Self, Self) -> (Self)) { + return (lhs * rhs, { (dx, dy) in dx * rhs + dy * lhs }) + } } extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector { @@ -216,6 +248,13 @@ extension Tracked where T : Differentiable & FloatingPoint, T == T.TangentVector -> (value: Self, pullback: (Self) -> (Self, Self)) { return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) } + + @usableFromInline + @differentiating(/) + internal static func _vjpDivide(lhs: Self, rhs: Self) + -> (value: Self, differential: (Self, Self) -> (Self)) { + return (lhs / rhs, { (dx, dy) in dx / rhs - lhs / (rhs * rhs) * dy }) + } } // Differential operators for `Tracked`. diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index 809bdc347864a..159352d177e69 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -479,7 +479,7 @@ public func pullback( @inlinable public func derivative( - at x: T, in f: @escaping @differentiable (T) -> R + at x: T, in f: @differentiable (T) -> R ) -> R.TangentVector where T.TangentVector == T { return differential(at: x, in: f)(T(1)) @@ -487,7 +487,7 @@ public func derivative( @inlinable public func derivative( - at x: T, _ y: U, in f: @escaping @differentiable (T, U) -> R + at x: T, _ y: U, in f: @differentiable (T, U) -> R ) -> R.TangentVector where T.TangentVector == T, U.TangentVector == U { @@ -496,7 +496,7 @@ public func derivative( @inlinable public func derivative( - at x: T, _ y: U, _ z: V, in f: @escaping @differentiable (T, U, V) -> R + at x: T, _ y: U, _ z: V, in f: @differentiable (T, U, V) -> R ) -> R.TangentVector where T.TangentVector == T, U.TangentVector == U, @@ -995,3 +995,14 @@ public extension Array where Element: Differentiable { return (value: values, pullback: pullback) } } + +//===----------------------------------------------------------------------===// +// JVP Diagnostics +//===----------------------------------------------------------------------===// +@_silgen_name("_printJVPErrorAndExit") +public func _printJVPErrorAndExit() -> Never { + fatalError(""" + JVP does not exist. Differential-first differentiation APIs are \ + experimental and should not be used. + """) +} diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index df4a993640519..915a0fa2498c1 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1634,6 +1634,15 @@ extension ${Self} { -> (value: ${Self}, pullback: (${Self}) -> ${Self}) { return (-x, { v in -v }) } + + @usableFromInline + @_transparent + // SWIFT_ENABLE_TENSORFLOW + @differentiating(-) + static func _jvpNegate(x: ${Self}) + -> (value: ${Self}, differential: (${Self}) -> ${Self}) { + return (-x, { dx in -dx }) + } } //===----------------------------------------------------------------------===// @@ -1797,6 +1806,15 @@ extension ${Self} { return (lhs + rhs, { v in (v, v) }) } + @inlinable // FIXME(sil-serialize-all) + @_transparent + @differentiating(+) + static func _jvpAdd( + lhs: ${Self}, rhs: ${Self} + ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) { + return (lhs + rhs, { (dlhs, drhs) in dlhs + drhs }) + } + @inlinable // FIXME(sil-serialize-all) @_transparent @differentiating(-) @@ -1806,6 +1824,15 @@ extension ${Self} { return (lhs - rhs, { v in (v, -v) }) } + @inlinable // FIXME(sil-serialize-all) + @_transparent + @differentiating(-) + static func _jvpSubtract( + lhs: ${Self}, rhs: ${Self} + ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) { + return (lhs - rhs, { (dlhs, drhs) in dlhs - drhs }) + } + @inlinable // FIXME(sil-serialize-all) @_transparent @differentiating(*) @@ -1815,6 +1842,15 @@ extension ${Self} { return (lhs * rhs, { v in (rhs * v, lhs * v) }) } + @inlinable // FIXME(sil-serialize-all) + @_transparent + @differentiating(*) + static func _jvpMultiply( + lhs: ${Self}, rhs: ${Self} + ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) { + return (lhs * rhs, { (dlhs, drhs) in lhs * drhs + rhs * dlhs }) + } + @inlinable // FIXME(sil-serialize-all) @_transparent @differentiating(/) @@ -1823,6 +1859,15 @@ extension ${Self} { ) -> (value: ${Self}, pullback: (${Self}) -> (${Self}, ${Self})) { return (lhs / rhs, { v in (v / rhs, -lhs / (rhs * rhs) * v) }) } + + @inlinable // FIXME(sil-serialize-all) + @_transparent + @differentiating(/) + static func _jvpDivide( + lhs: ${Self}, rhs: ${Self} + ) -> (value: ${Self}, differential: (${Self}, ${Self}) -> ${Self}) { + return (lhs / rhs, { (dlhs, drhs) in dlhs / rhs - lhs / (rhs * rhs) * drhs }) + } } //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index c4f87cf35ca05..dbdddd681ff81 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -277,7 +277,6 @@ func activeInoutArg(_ x: Float) -> Float { // expected-error @+1 {{function is not differentiable}} _ = pullback(at: .zero, in: activeInoutArg(_:)) - func activeInoutArgTuple(_ x: Float) -> Float { var tuple = (x, x) // expected-note @+1 {{cannot differentiate through 'inout' arguments}} diff --git a/test/AutoDiff/differentiable_attr_silgen_cross_module.swift b/test/AutoDiff/differentiable_attr_silgen_cross_module.swift index ed4828ade0bfa..634d189f3c36f 100644 --- a/test/AutoDiff/differentiable_attr_silgen_cross_module.swift +++ b/test/AutoDiff/differentiable_attr_silgen_cross_module.swift @@ -24,4 +24,3 @@ _ = pullback(at: Wrapper(1)) { x in x + x * x } // CHECK-SILGEN-NEXT: sil [differentiable source 0 wrt 0, 1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper // CHECK-SIL-LABEL: // static Wrapper.+ infix(_:_:) // CHECK-SIL-NEXT: sil [differentiable source 0 wrt 0, 1 jvp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__jvp_src_0_wrt_0_1 vjp @AD__$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ__vjp_src_0_wrt_0_1] @$s39differentiable_attr_silgen_other_module7WrapperV1poiyA2C_ACtFZ : $@convention(method) (Wrapper, Wrapper, @thin Wrapper.Type) -> Wrapper - diff --git a/test/AutoDiff/forward_mode_diagnostics.swift b/test/AutoDiff/forward_mode_diagnostics.swift new file mode 100644 index 0000000000000..f249b733ac9c8 --- /dev/null +++ b/test/AutoDiff/forward_mode_diagnostics.swift @@ -0,0 +1,95 @@ +// RUN: %target-swift-frontend -Xllvm -run-jvp-generation -emit-sil -verify %s + +// TODO: move these tests back into `autodiff_diagnostics.swift` once +// forward mode reaches feature parity with reverse mode. + +//===----------------------------------------------------------------------===// +// Basic function +//===----------------------------------------------------------------------===// + +func one_to_one_0(_ x: Float) -> Float { + return x + 2 +} + +_ = derivative(at: 0, in: one_to_one_0) // okay! + +//===----------------------------------------------------------------------===// +// Function composition +//===----------------------------------------------------------------------===// + +func base(_ x: Float) -> Float { + // expected-error @+2 2 {{expression is not differentiable}} + // expected-note @+1 2 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}} + return Float(Int(x)) +} + +// TODO: Fix nested differentiation diagnostics. Need to fix indirect differentiation invokers. +func nested(_ x: Float) -> Float { + // xpected-note @+1 {{when differentiating this function call}} + return base(x) +} + +func middle(_ x: Float) -> Float { + // xpected-note @+1 {{when differentiating this function call}} + return nested(x) +} + +func middle2(_ x: Float) -> Float { + // xpected-note @+1 {{when differentiating this function call}} + return middle(x) +} + +func func_to_diff(_ x: Float) -> Float { + // xpected-note @+1 {{expression is not differentiable}} + return middle2(x) +} + +func calls_diff_of_nested(_ x: Float) -> Float { + // xpected-error @+1 {{function is not differentiable}} + return derivative(at: x, in: func_to_diff) +} + +//===----------------------------------------------------------------------===// +// Inout arguments +//===----------------------------------------------------------------------===// + +func activeInoutArg(_ x: Float) -> Float { + var a = x + // expected-note @+1 {{cannot differentiate through 'inout' arguments}} + a += x + return a +} +// expected-error @+1 {{function is not differentiable}} +_ = differential(at: .zero, in: activeInoutArg(_:)) + +func activeInoutArgTuple(_ x: Float) -> Float { + var tuple = (x, x) + // expected-note @+1 {{cannot differentiate through 'inout' arguments}} + tuple.0 *= x + return x * tuple.0 +} +// expected-error @+1 {{function is not differentiable}} +_ = differential(at: .zero, in: activeInoutArgTuple(_:)) + +//===----------------------------------------------------------------------===// +// Non-varied results +//===----------------------------------------------------------------------===// + +func one() -> Float { + return 1 +} +@differentiable +func nonVariedResult(_ x: Float) -> Float { + // expected-warning @+1 2 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} + return one() +} + +//===----------------------------------------------------------------------===// +// Subset parameters +//===----------------------------------------------------------------------===// + +func nondiff(_ f: @differentiable (Float, @nondiff Float) -> Float) -> Float { + // expected-note @+2 {{cannot differentiate with respect to a '@nondiff' parameter}} + // expected-error @+1 {{function is not differentiable}} + return derivative(at: 2, 3) { (x, y) in f(x * x, y) } +} diff --git a/test/AutoDiff/forward_mode_runtime.swift b/test/AutoDiff/forward_mode_runtime.swift new file mode 100644 index 0000000000000..81a215837bb3f --- /dev/null +++ b/test/AutoDiff/forward_mode_runtime.swift @@ -0,0 +1,311 @@ +// RUN: %target_run_simple_swift_forward_mode_differentiation +// REQUIRES: executable_test + +import StdlibUnittest +import DifferentiationUnittest + +var ForwardModeTests = TestSuite("ForwardMode") + +//===----------------------------------------------------------------------===// +// Basic tests. +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("Unary") { + func func_to_diff(x: Float) -> Float { + return x * x + } + let (y, differential) = valueWithDifferential(at: 4, in: func_to_diff) + expectEqual(16, y) + expectEqual(8, differential(1)) +} + +ForwardModeTests.test("Binary") { + func func_to_diff(x: Float, y: Float) -> Float { + return x * y + } + let (y, differential) = valueWithDifferential(at: 4, 5, in: func_to_diff) + expectEqual(20, y) + expectEqual(9, differential(1, 1)) +} + +ForwardModeTests.test("BinaryWithLets") { + func func_to_diff(x: Float, y: Float) -> Float { + let a = x + y + let b = a + return b * -y + } + let (y, differential) = valueWithDifferential(at: 4, 5, in: func_to_diff) + expectEqual(-45, y) + expectEqual(-19, differential(1, 1)) +} + +//===----------------------------------------------------------------------===// +// `Tracked` struct +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("TrackedIdentity") { + func identity(x: Tracked) -> Tracked { + return x + } + let (y, differential) = valueWithDifferential(at: 4, in: identity) + expectEqual(4, y) + expectEqual(1, differential(1)) +} + +ForwardModeTests.test("TrackedAddition") { + func add(x: Tracked, y: Tracked) -> Tracked { + return x + y + } + let (y, differential) = valueWithDifferential(at: 4, 5, in: add) + expectEqual(9, y) + expectEqual(2, differential(1, 1)) +} + +ForwardModeTests.test("TrackedDivision") { + func divide(x: Tracked, y: Tracked) -> Tracked { + return x / y + } + let (y, differential) = valueWithDifferential(at: 10, 5, in: divide) + expectEqual(2, y) + expectEqual(-0.2, differential(1, 1)) +} + +ForwardModeTests.test("TrackedMultipleMultiplication") { + func add(x: Tracked, y: Tracked) -> Tracked { + return x * y * x + } + let (y, differential) = valueWithDifferential(at: 4, 5, in: add) + expectEqual(80, y) + // 2yx+xx + expectEqual(56, differential(1, 1)) +} + +ForwardModeTests.test("TrackedWithLets") { + func add(x: Tracked, y: Tracked) -> Tracked { + let a = x + y + let b = a * a // (x+y)^2 + let c = b / x + y // (x+y)^2/x+y + return c + } + // (3x^2+2xy-y^2)/x^2+1 + let (y, differential) = valueWithDifferential(at: 4, 5, in: add) + expectEqual(25.25, y) + expectEqual(4.9375, differential(1, 1)) +} + +ForwardModeTests.test("TrackedDifferentiableFuncType") { + func valAndDeriv( + f: @escaping @differentiable (Tracked) -> Tracked + ) -> (Tracked, Tracked) { + let (y, diff) = valueWithDifferential(at: 5, in: f) + return (y, diff(1)) + } + + func func1(_ x: Tracked) -> Tracked { + let a = x + x // 2x + let b = a + a // 4x + return b * b // 16x^2 + } + let (val1, dv1) = valAndDeriv(f: func1) + expectEqual(400, val1) + expectEqual(160, dv1) +} +//===----------------------------------------------------------------------===// +// Classes +//===----------------------------------------------------------------------===// +// NOTE: once forward mode is done, can copy and replace this in +// `class_method.swift` as it already calls reverse mode functions. + +ForwardModeTests.test("Final") { + final class Final : Differentiable { + func method(_ x: Float) -> Float { + return x * x + } + } + + for i in -5...5 { + expectEqual(Float(i) * 2, gradient(at: Float(i)) { x in Final().method(x) }) + expectEqual( + Float(i) * 2, + derivative(at: Float(i)) { x in Final().method(x) }) + } +} + +ForwardModeTests.test("Simple") { + class Super { + @differentiable(wrt: x, jvp: jvpf, vjp: vjpf) + func f(_ x: Float) -> Float { + return 2 * x + } + final func jvpf(_ x: Float) -> (Float, (Float) -> Float) { + return (f(x), { v in 2 * v }) + } + final func vjpf(_ x: Float) -> (Float, (Float) -> Float) { + return (f(x), { v in 2 * v }) + } + } + + class SubOverride : Super { + @differentiable(wrt: x) + override func f(_ x: Float) -> Float { + return 3 * x + } + } + + class SubOverrideCustomDerivatives : Super { + @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2) + override func f(_ x: Float) -> Float { + return 3 * x + } + final func jvpf2(_ x: Float) -> (Float, (Float) -> Float) { + return (f(x), { v in 3 * v }) + } + final func vjpf2(_ x: Float) -> (Float, (Float) -> Float) { + return (f(x), { v in 3 * v }) + } + } + + func classValueWithDerivative(_ c: Super) -> (Float, Float) { + return valueWithDerivative(at: 1) { c.f($0) } + } + func classValueWithGradient(_ c: Super) -> (Float, Float) { + return valueWithGradient(at: 1) { c.f($0) } + } + + expectEqual((2, 2), classValueWithDerivative(Super())) + expectEqual((3, 3), classValueWithDerivative(SubOverride())) + expectEqual((3, 3), classValueWithDerivative(SubOverrideCustomDerivatives())) + expectEqual((2, 2), classValueWithGradient(Super())) + expectEqual((3, 3), classValueWithGradient(SubOverride())) + expectEqual((3, 3), classValueWithGradient(SubOverrideCustomDerivatives())) +} + +ForwardModeTests.test("SimpleWrtSelf") { + class Super : Differentiable { + var base: Float + // FIXME(TF-648): Dummy to make `Super.AllDifferentiableVariables` be nontrivial. + var _nontrivial: [Float] = [] + + // TODO(TF-654): Uncomment attribute when differentiation supports class initializers. + // TODO(TF-645): Remove `vjpInit` when differentiation supports `ref_element_addr`. + // @differentiable(vjp: vjpInit) + required init(base: Float) { + self.base = base + } + static func vjpInit(base: Float) -> (Super, (TangentVector) -> Float) { + return (Super(base: base), { x in x.base }) + } + + static func jvpInit(base: Float) -> (Super, (Float) -> TangentVector) { + return (Super(base: base), { x in TangentVector(base: x, _nontrivial: []) }) + } + + @differentiable(wrt: (self, x), jvp: jvpf, vjp: vjpf) + func f(_ x: Float) -> Float { + return base * x + } + final func jvpf(_ x: Float) -> (Float, (TangentVector, Float) -> Float) { + return (f(x), { (dself, dx) in dself.base * dx }) + } + final func vjpf(_ x: Float) -> (Float, (Float) -> (TangentVector, Float)) { + let base = self.base + return (f(x), { v in + (TangentVector(base: v * x, _nontrivial: []), base * v) + }) + } + } + + class SubOverride : Super { + @differentiable(wrt: (self, x)) + override func f(_ x: Float) -> Float { + return 3 * x + } + } + + class SubOverrideCustomDerivatives : Super { + @differentiable(wrt: (self, x)) + @differentiable(wrt: x, jvp: jvpf2, vjp: vjpf2) + override func f(_ x: Float) -> Float { + return 3 * x + } + final func jvpf2(_ x: Float) -> (Float, (Float) -> Float) { + return (f(x), { v in 3 * v }) + } + final func vjpf2(_ x: Float) -> (Float, (Float) -> Float) { + return (f(x), { v in 3 * v }) + } + } + + // TODO(TF-654): Uncomment when differentiation supports class initializers. + // let v = Super.TangentVector(base: 100, _nontrivial: []) + // expectEqual(100, pullback(at: 1337) { x in Super(base: x) }(v)) + // expectEqual(100, pullback(at: 1337) { x in SubOverride(base: x) }(v)) + // expectEqual(100, pullback(at: 1337) { x in SubOverrideCustomDerivatives(base: x) }(v)) + + + // `valueWithGradient` is not used because nested tuples cannot be compared + // with `expectEqual`. + func classGradient(_ c: Super) -> (Super.TangentVector, Float) { + return gradient(at: c, 10) { c, x in c.f(x) } + } + + // `valueWithDerivative` is not used because the derivative requires `Super` + // to conform to `FloatingPoint`. + func classDifferential( + _ c: Super + ) -> (Float, (Super.TangentVector, Float) -> Float) { + return valueWithDifferential(at: c, 10) { (c: Super, x: Float) in c.f(x) } + } + + let (y1, diff1) = classDifferential(Super(base: 5)) + expectEqual(50, y1) + let c1 = Super.TangentVector(base: 1, _nontrivial: []) + expectEqual(1, diff1(c1, 1)) + let (y2, diff2) = classDifferential(SubOverride(base: 5)) + expectEqual(30, y2) + let c2 = SubOverride.TangentVector(base: 1, _nontrivial: []) + expectEqual(3, diff2(c2, 1)) + let (y3, diff3) = classDifferential(SubOverrideCustomDerivatives(base: 5)) + expectEqual(30, y3) + let c3 = SubOverrideCustomDerivatives.TangentVector(base: 1, _nontrivial: []) + expectEqual(3, diff3(c3, 1)) + expectEqual((Super.TangentVector(base: 10, _nontrivial: []), 2), + classGradient(Super(base: 2))) + expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3), + classGradient(SubOverride(base: 2))) + expectEqual((Super.TangentVector(base: 0, _nontrivial: []), 3), + classGradient(SubOverrideCustomDerivatives(base: 2))) +} + +//===----------------------------------------------------------------------===// +// Protocols +//===----------------------------------------------------------------------===// +// TODO: add more protocol tests. +// protocol DiffReq : Differentiable { +// @differentiable(wrt: x) +// func foo(x: Float) -> Float +// } + +// struct Linear: DiffReq, VectorProtocol { +// typealias TangentVector = Linear + +// let m: Float +// let b: Float + +// @differentiable(wrt: x) +// func foo(x: Float) -> Float { +// return m * x + b +// } +// } + +// ForwardModeTests.test("Protocols") { +// func genericFoo(_ t: T, _ x: Float) -> Float { +// t.foo(x: x) +// } +// let inst = Linear(m: 5, b: -2) +// let (y1, diff1) = valueWithDifferential(at: 5) { x in genericFoo(inst, x) } +// expectEqual(23, y1) +// expectEqual(5, diff1(1)) +// } + +runAllTests() diff --git a/test/AutoDiff/forward_mode_sil.swift b/test/AutoDiff/forward_mode_sil.swift new file mode 100644 index 0000000000000..12e529a85226f --- /dev/null +++ b/test/AutoDiff/forward_mode_sil.swift @@ -0,0 +1,85 @@ +// RUN: %target-swift-frontend -emit-sil -Xllvm -run-jvp-generation -verify %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES +// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -Xllvm -run-jvp-generation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL + + +//===----------------------------------------------------------------------===// +// Unary +//===----------------------------------------------------------------------===// + +@differentiable +@_silgen_name("unary") +func unary(_ x: Float) -> Float { + return x * x * x +} +// CHECK-DATA-STRUCTURES: struct _AD__unary_bb0__DF__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: @_hasStorage var differential_0: (Float, Float) -> Float { get set } +// CHECK-DATA-STRUCTURES: @_hasStorage var differential_1: (Float, Float) -> Float { get set } +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__unary_bb0__Succ__src_0_wrt_0 { +// CHECK-DATA-STRUCTURES: } + +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__unary__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) { +// CHECK-SIL: bb0([[X_ARG:%.*]] : $Float): +// CHECK-SIL: [[MULT_FUNC_1:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_FUNC_JVP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: [[MULT_FUNC_VJP_1:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) +// CHECK-SIL: [[AUTODIFF_INST_1:%.*]] = autodiff_function [wrt 0 1] [order 1] [[MULT_FUNC_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_1]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = autodiff_function_extract [jvp] [order 1] [[AUTODIFF_INST_1]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_1:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[X_ARG]], [[X_ARG]], %3) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: ([[ORIG_RESULT_1:%.*]], [[MULT_DIFF_1:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_1]] : $(Float, @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: [[MULT_FUNC_2:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_FUNC_JVP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: [[MULT_FUNC_VJP_2:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) +// CHECK-SIL: [[AUTODIFF_INST_2:%.*]] = autodiff_function [wrt 0 1] [order 1] [[MULT_FUNC_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP_2]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_EXTRACT_INST_1:%.*]] = autodiff_function_extract [jvp] [order 1] [[AUTODIFF_INST_2]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_JVP_APPLY_TUPLE_2:%.*]] = apply [[AUTODIFF_EXTRACT_INST_1]]([[ORIG_RESULT_1]], [[X_ARG]], %2) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: ([[ORIG_RESULT_2:%.*]], [[MULT_DIFF_2:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE_2]] : $(Float, @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: [[DIFF_STRUCT:%.*]] = struct $_AD__unary_bb0__DF__src_0_wrt_0 ([[MULT_DIFF_1]] : $@callee_guaranteed (Float, Float) -> Float, [[MULT_DIFF_2]] : $@callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: [[UNARY_DIFFERENTIAL:%.*]] = function_ref @AD__unary__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float +// CHECK-SIL: [[PARTIAL_APP_DIFFERENTIAL:%.*]] = partial_apply [callee_guaranteed] [[UNARY_DIFFERENTIAL]]([[DIFF_STRUCT]]) : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float +// CHECK-SIL: [[RESULT:%.*]] = tuple ([[ORIG_RESULT_2]] : $Float, [[PARTIAL_APP_DIFFERENTIAL]] : $@callee_guaranteed (Float) -> Float) +// CHECK-SIL: return [[RESULT]] : $(Float, @callee_guaranteed (Float) -> Float) + +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__unary__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__unary_bb0__DF__src_0_wrt_0) -> Float { +// CHECK-SIL: bb0([[X_TAN:%.*]] : $Float, [[DIFF_STRUCT:%.*]] : @owned $_AD__unary_bb0__DF__src_0_wrt_0): +// CHECK-SIL: ([[MULT_DIFF_1:%.*]], [[MULT_DIFF_2:%.*]]) = destructure_struct %1 : $_AD__unary_bb0__DF__src_0_wrt_0 +// CHECK-SIL: [[TEMP_TAN_1:%.*]] = apply [[MULT_DIFF_1]]([[X_TAN]], [[X_TAN]]) : $@callee_guaranteed (Float, Float) -> Float +// CHECK-SIL: [[TAN_RESULT:%.*]] = apply [[MULT_DIFF_2]]([[TEMP_TAN_1]], [[X_TAN]]) : $@callee_guaranteed (Float, Float) -> Float +// CHECK-SIL: return [[TAN_RESULT]] : $Float + +//===----------------------------------------------------------------------===// +// Binary +//===----------------------------------------------------------------------===// + +@differentiable +@_silgen_name("binary") +func binary(x: Float, y: Float) -> Float { + return x * y +} + +// CHECK-DATA-STRUCTURES: struct _AD__binary_bb0__DF__src_0_wrt_0_1 { +// CHECK-DATA-STRUCTURES: @_hasStorage var differential_0: (Float, Float) -> Float { get set } +// CHECK-DATA-STRUCTURES: } +// CHECK-DATA-STRUCTURES: enum _AD__binary_bb0__Succ__src_0_wrt_0_1 { +// CHECK-DATA-STRUCTURES: } + +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__binary__jvp_src_0_wrt_0_1 : $@convention(thin) (Float, Float) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) { +// CHECK-SIL: bb0([[X_ARG:%.*]] : $Float, [[Y_ARG:%.*]] : $Float): +// CHECK-SIL: [[MULT_FUNC:%.*]] = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_FUNC_JVP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__jvp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: [[MULT_FUNC_VJP:%.*]] = function_ref @AD__$sSf1moiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float)) +// CHECK-SIL: [[AUTODIFF_INST:%.*]] = autodiff_function [wrt 0 1] [order 1] [[MULT_FUNC]] : $@convention(method) (Float, Float, @thin Float.Type) -> Float with {[[MULT_FUNC_JVP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float), [[MULT_FUNC_VJP]] : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))} +// CHECK-SIL: [[AUTODIFF_EXTRACT_INST:%.*]] = autodiff_function_extract [jvp] [order 1] [[AUTODIFF_INST]] : $@differentiable @convention(method) (Float, Float, @nondiff @thin Float.Type) -> Float +// CHECK-SIL: [[MULT_JVP_APPLY_TUPLE:%.*]] = apply [[AUTODIFF_EXTRACT_INST]]([[X_ARG]], [[Y_ARG]], %4) : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: ([[ORIG_RESULT:%.*]], [[MULT_DIFF:%.*]]) = destructure_tuple [[MULT_JVP_APPLY_TUPLE]] : $(Float, @callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: [[DIFF_STRUCT:%.*]] = struct $_AD__binary_bb0__DF__src_0_wrt_0_1 ([[MULT_DIFF]] : $@callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: [[BINARY_DIFFERENTIAL:%.*]] = function_ref @AD__binary__differential_src_0_wrt_0_1 : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float +// CHECK-SIL: [[PARTIAL_APP_DIFFERENTIAL:%.*]] = partial_apply [callee_guaranteed] [[BINARY_DIFFERENTIAL]]([[DIFF_STRUCT]]) : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float +// CHECK-SIL: [[RESULT:%.*]] = tuple ([[ORIG_RESULT]] : $Float, [[PARTIAL_APP_DIFFERENTIAL]] : $@callee_guaranteed (Float, Float) -> Float) +// CHECK-SIL: return [[RESULT:%.*]] : $(Float, @callee_guaranteed (Float, Float) -> Float) + +// CHECK-SIL-LABEL: sil hidden [ossa] @AD__binary__differential_src_0_wrt_0_1 : $@convention(thin) (Float, Float, @owned _AD__binary_bb0__DF__src_0_wrt_0_1) -> Float { +// CHECK-SIL: bb0([[X_TAN:%.*]] : $Float, [[Y_TAN:%.*]] : $Float, [[DIFF_STRUCT:%.*]] : @owned $_AD__binary_bb0__DF__src_0_wrt_0_1): +// CHECK-SIL: [[MULT_DIFF:%.*]] = destructure_struct [[DIFF_STRUCT]] : $_AD__binary_bb0__DF__src_0_wrt_0_1 +// CHECK-SIL: [[TAN_RESULT:%.*]] = apply [[MULT_DIFF]]([[X_TAN]], [[Y_TAN]]) : $@callee_guaranteed (Float, Float) -> Float +// CHECK-SIL: return [[TAN_RESULT]] : $Float diff --git a/test/lit.cfg b/test/lit.cfg index 9f5ea7be43de6..1da02449f0117 100644 --- a/test/lit.cfg +++ b/test/lit.cfg @@ -1529,6 +1529,14 @@ if not getattr(config, 'target_run_simple_swift', None): '%%line-directive %%t/main.swift -- ' '%s %%t/a.out' % (config.target_build_swift, mcp_opt, config.target_codesign, config.target_run)) + # SWIFT_ENABLE_TENSORFLOW + # TODO: Remove when forward mode AD support is robust. + config.target_run_simple_swift_forward_mode_differentiation = ( + '%%empty-directory(%%t) && ' + '%s %s %%s -Xllvm -run-jvp-generation -o %%t/a.out %s -module-name main && ' + '%s %%t/a.out &&' + '%s %%t/a.out' + % (config.target_build_swift, mcp_opt, swift_tensorflow_extra_options, config.target_codesign, config.target_run)) # # When changing substitutions, update docs/Testing.md. @@ -1551,7 +1559,9 @@ config.substitutions.append(('%target-swift-frontend\(mock-sdk:([^)]+)\)', SubstituteCaptures(r'%s \1 %s' % (subst_target_swift_frontend_mock_sdk, subst_target_swift_frontend_mock_sdk_after)))) config.substitutions.append(('%target-swift-frontend', config.target_swift_frontend)) - +# SWIFT_ENABLE_TENSORFLOW +# TODO: Remove when forward mode AD support is robust. +config.substitutions.append(('%target_run_simple_swift_forward_mode_differentiation', config.target_run_simple_swift_forward_mode_differentiation)) config.substitutions.append(('%target-run-simple-swiftgyb', config.target_run_simple_swiftgyb)) config.substitutions.append(('%target-run-simple-swift\(([^)]+)\)', config.target_run_simple_swift_parameterized))