diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 2b0f1e3610bcd..b6772f989d590 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -12,7 +12,7 @@ // // SWIFT_ENABLE_TENSORFLOW // -// This file implements reverse-mode automatic differentiation. +// This file implements automatic differentiation. // // NOTE: Although the AD feature is developed as part of the Swift for // TensorFlow project, it is completely independent from TensorFlow support. @@ -3803,6 +3803,42 @@ class JVPEmitter final bbDifferentialValues); } + bool shouldBeDifferentiated(SILInstruction *inst, + const SILAutoDiffIndices &indices) { + // Anything with an active result should be differentiated. + if (llvm::any_of(inst->getResults(), [&](SILValue val) { + return activityInfo.isActive(val, indices); + })) + return true; + // Anything with an an active argument should be differentiated + // (i.e. `return %0`). + if (llvm::any_of(inst->getAllOperands(), [&](Operand &val) { + return activityInfo.isActive(val.get(), indices); + })) + return true; + if (auto *ai = dyn_cast(inst)) { + // Function applications with an active indirect result should be + // differentiated. + for (auto indRes : ai->getIndirectSILResults()) + if (activityInfo.isActive(indRes, indices)) + return true; + // Function applications with an inout argument should be differentiated. + auto paramInfos = ai->getSubstCalleeConv().getParameters(); + for (auto i : swift::indices(paramInfos)) + if (paramInfos[i].isIndirectInOut() && + activityInfo.isActive( + ai->getArgumentsWithoutIndirectResults()[i], indices)) + return true; + } + // Instructions that may write to memory and that have an active operand + // should be differentiated. + if (inst->mayWriteToMemory()) + for (auto &op : inst->getAllOperands()) + if (activityInfo.isActive(op.get(), indices)) + return true; + return false; + } + //--------------------------------------------------------------------------// // Tangent materialization //--------------------------------------------------------------------------// @@ -4168,6 +4204,164 @@ class JVPEmitter final // ; } + /// Handle `struct_extract` instruction. + /// Original: y = struct_extract x, #field + /// Tangent: tan[y] = struct_extract tan[x], tan[#field]] + /// ^~~~~~~ + /// field in tangent space corresponding to #field + void visitStructExtractInstDifferential(StructExtractInst *sei) { + assert(!sei->getField()->getAttrs().hasAttribute() && + "`struct_extract` with `@noDerivative` field should not be " + "differentiated; activity analysis should not marked as varied"); + + auto diffBuilder = getDifferentialBuilder(); + auto structTy = remapType(sei->getOperand()->getType()).getASTType(); + auto tangentVectorTy = + getTangentSpace(structTy)->getType()->getCanonicalType(); + assert(!getModule().Types.getTypeLowering( + tangentVectorTy, ResilienceExpansion::Minimal) + .isAddressOnly()); + auto *tangentVectorDecl = + tangentVectorTy->getStructOrBoundGenericStruct(); + assert(tangentVectorDecl); + + // Get the tangent of the field and create the extract inst in the SIL + // of the differential. + // Find the corresponding field in the tangent space. + VarDecl *tanField = nullptr; + // If the tangent space is the original struct, then field is the same. + if (tangentVectorDecl == sei->getStructDecl()) + tanField = sei->getField(); + // Otherwise, look up the field by name. + else { + auto tanFieldLookup = + tangentVectorDecl->lookupDirect(sei->getField()->getName()); + if (tanFieldLookup.empty()) { + context.emitNondifferentiabilityError( + sei, invoker, + diag::autodiff_stored_property_no_corresponding_tangent, + sei->getStructDecl()->getNameStr(), + sei->getField()->getNameStr()); + errorOccurred = true; + return; + } + tanField = cast(tanFieldLookup.front()); + } + + // Get the Tangent of the operand (the struct) + auto tanOperand = + materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc()); + + // Emit the instruction + auto tangentExtractInst = + diffBuilder.createStructExtract(sei->getLoc(), tanOperand, tanField); + + // Add tangent for original result into value mapping. + auto tangentResult = makeConcreteTangentValue(tangentExtractInst); + addTangentValue(sei->getParent(), sei, tangentResult); + } + + /// Handle `load` instruction. + /// Original: y = load x + /// Tangent: tan[y] = load tan[x] + void visitLoadInstDifferential(LoadInst *li) { + auto *bb = li->getParent(); + auto diffBuilder = getDifferentialBuilder(); + + auto tanValSrc = getTangentBuffer(bb, li->getOperand()); + auto *tanValDest = diffBuilder.createLoad(li->getLoc(), tanValSrc, + getBufferLOQ(li->getType().getASTType(), + getDifferential())); + addTangentValue(bb, li, makeConcreteTangentValue( + tanValDest)); + } + + /// Handle `store` instruction in the differential. + /// Original: store x to y + /// Tangent: store tan[x] to tan[y] + void visitStoreInstDifferential(StoreInst *si) { + auto *bb = si->getParent(); + auto &diffBuilder = getDifferentialBuilder(); + auto tanValSrc = materializeTangent(getTangentValue(si->getSrc()), + si->getLoc()); + auto tanValDest = getTangentBuffer(bb, si->getDest()); + diffBuilder.createStore(si->getLoc(), tanValSrc, tanValDest, + getBufferSOQ(tanValDest->getType().getASTType(), + getDifferential())); + } + + /// Handle `copy_addr` instruction. + /// Original: copy_addr x to y + /// Tangent: copy_addr tan[x] to tan[y] + void visitCopyAddrInstDifferential(CopyAddrInst *cai) { + auto diffBuilder = getDifferentialBuilder(); + auto *bb = cai->getParent(); + auto &adjDest = getTangentBuffer(bb, cai->getDest()); + if (errorOccurred) + return; + + // Begin access, set the corresponding tangent buffer, and end access. + auto *readAccess = diffBuilder.createBeginAccess( + cai->getLoc(), adjDest, SILAccessKind::Read, + SILAccessEnforcement::Static, /*noNestedConflict*/ true, + /*fromBuiltin*/ false); + setTangentBuffer(bb, cai->getSrc(), readAccess); + diffBuilder.createEndAccess(cai->getLoc(), readAccess, /*aborted*/ false); + } + + /// Handle `begin_access` instruction. + /// Original: y = begin_access x + /// Tangent: nothing (differentiability checks) + void visitBeginAccessInstDifferential(BeginAccessInst *bai) { + // Check for non-differentiable writes. + if (bai->getAccessKind() == SILAccessKind::Modify) { + if (auto *gai = dyn_cast(bai->getSource())) { + context.emitNondifferentiabilityError(bai, invoker, + diag::autodiff_cannot_differentiate_writes_to_global_variables); + errorOccurred = true; + return; + } + if (auto *pbi = dyn_cast(bai->getSource())) { + context.emitNondifferentiabilityError(bai, invoker, + diag::autodiff_cannot_differentiate_writes_to_mutable_captures); + errorOccurred = true; + return; + } + } + } + + /// Add the value mapping and emit the same instruction. + void visitAllocStackInstDifferential(AllocStackInst *asi) { + auto &diffBuilder = getDifferentialBuilder(); + + auto *mappedAllocStackInst = + diffBuilder.createAllocStack( + asi->getLoc(), getRemappedTangentType(asi->getElementType())); + bufferMap.try_emplace({asi->getParent(), asi}, + mappedAllocStackInst); + } + + /// Emit the same instruction but on the tangent instead. + void visitDeallocStackInstDifferential(DeallocStackInst *dsi) { + auto &diffBuilder = getDifferentialBuilder(); + auto tanBuffer = getTangentBuffer(dsi->getParent(), dsi->getOperand()); + diffBuilder.createDeallocStack(dsi->getLoc(), tanBuffer); + } + + void visitStructInstDifferential(StructInst *si) { + auto diffBuilder = getDifferentialBuilder(); + auto *bb = si->getParent(); + auto loc = si->getLoc(); + SmallVector tangentElements; + for (auto elem : si->getElements()) + tangentElements.push_back(getTangentValue(elem).getConcreteValue()); + + auto tanExtract = diffBuilder.createStruct(loc, si->getType(), + tangentElements); + + addTangentValue(bb, si, makeConcreteTangentValue(tanExtract)); + } + public: explicit JVPEmitter(ADContext &context, SILFunction *original, SILDifferentiableAttr *attr, SILFunction *jvp, @@ -4275,17 +4469,16 @@ class JVPEmitter final // Create differential blocks and arguments. // TODO: Consider visiting original blocks in pre-order (dominance) order. - SmallVector preOrderDomOrder; auto &differential = getDifferential(); auto *origEntry = original->getEntryBlock(); for (auto &origBB : *original) { auto *diffBB = differential.createBasicBlock(); diffBBMap.insert({&origBB, diffBB}); auto diffStructLoweredType = - remapType(differentialInfo.getLinearMapStructLoweredType(&origBB)); - // If the BB is the original exit, then the differential block that we just - // created must be the differential function's entry. Create differential - // entry arguments and continue. + remapType(differentialInfo.getLinearMapStructLoweredType(&origBB)); + // If the BB is the original entry, then the differential block that we + // just created must be the differential function's entry. Create + // differential entry arguments and continue. if (&origBB == origEntry) { assert(diffBB->isEntry()); createEntryArguments(&differential); @@ -4293,7 +4486,18 @@ class JVPEmitter final assert(lastArg->getType() == diffStructLoweredType); differentialStructArguments[&origBB] = lastArg; } + + LLVM_DEBUG({ + auto &s = getADDebugStream() + << "Original bb" + std::to_string(origBB.getDebugID()) + << ": To differentiate or not to differentiate?\n"; + for (auto &inst : origBB) { + s << (shouldBeDifferentiated(&inst, getIndices()) ? "[∂] " : "[ ] ") + << inst; + } + }); } + assert(diffBBMap.size() == 1 && "Can only currently handle single basic block functions"); @@ -4302,7 +4506,8 @@ class JVPEmitter final auto &diffBuilder = getDifferentialBuilder(); auto diffParamArgs = differential.getArgumentsWithoutIndirectResults().drop_back(); - assert(diffParamArgs.size() == attr->getIndices().parameters->getCapacity()); + assert(diffParamArgs.size() == + attr->getIndices().parameters->getNumIndices()); auto origParamArgs = original->getArgumentsWithoutIndirectResults(); // Check if result is not varied. @@ -4363,11 +4568,35 @@ class JVPEmitter final if (errorOccurred) return true; + LLVM_DEBUG(getADDebugStream() << "Generated differential for " + << original->getName() << ":\n" << differential); LLVM_DEBUG(getADDebugStream() << "Generated JVP for " << original->getName() << ":\n" << *jvp); return errorOccurred; } + void visit(SILInstruction *inst) { + auto diffBuilder = getDifferentialBuilder(); + if (errorOccurred) + return; + if (shouldBeDifferentiated(inst, getIndices())) { + LLVM_DEBUG(getADDebugStream() << "JVPEmitter visited:\n[ORIG]" + << *inst); +#ifndef NDEBUG + auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint()); +#endif + SILInstructionVisitor::visit(inst); // TypeSubstCloner::visit(inst); + LLVM_DEBUG({ + auto &s = llvm::dbgs() << "[DF] Emitted in Differential:\n"; + auto afterInsertion = diffBuilder.getInsertionPoint(); + for (auto it = ++beforeInsertion; it != afterInsertion; ++it) + s << *it; + }); + } else { + SILInstructionVisitor::visit(inst); // TypeSubstCloner::visit(inst); + } + } + void postProcess(SILInstruction *orig, SILInstruction *cloned) { if (errorOccurred) return; @@ -4380,54 +4609,6 @@ class JVPEmitter final return jvpBB; } - /// General visitor for all instructions. If any error is emitted by previous - /// visits, bail out. - void visit(SILInstruction *inst) { - if (errorOccurred) - return; - TypeSubstCloner::visit(inst); - } - - void visitSILInstruction(SILInstruction *inst) { - context.emitNondifferentiabilityError(inst, invoker, - diag::autodiff_expression_not_differentiable_note); - errorOccurred = true; - } - - void visitReturnInst(ReturnInst *ri) { - auto loc = ri->getOperand().getLoc(); - auto *origExit = ri->getParent(); - auto &builder = getBuilder(); - auto *diffStructVal = buildDifferentialValueStructValue(ri); - - // Get the JVP value corresponding to the original functions's return value. - auto *origRetInst = cast(origExit->getTerminator()); - auto origResult = getOpValue(origRetInst->getOperand()); - SmallVector origResults; - extractAllElements(origResult, builder, origResults); - - // Get and partially apply the differential. - auto jvpGenericEnv = jvp->getGenericEnvironment(); - auto jvpSubstMap = jvpGenericEnv - ? jvpGenericEnv->getForwardingSubstitutionMap() - : jvp->getForwardingSubstitutionMap(); - auto *differentialRef = - builder.createFunctionRef(loc, &getDifferential()); - auto *differentialPartialApply = builder.createPartialApply( - loc, differentialRef, jvpSubstMap, {diffStructVal}, - ParameterConvention::Direct_Guaranteed); - - // Return a tuple of the original result and pullback. - SmallVector directResults; - directResults.append(origResults.begin(), origResults.end()); - directResults.push_back(differentialPartialApply); - builder.createReturn( - ri->getLoc(), joinElements(directResults, builder, loc)); - - // Differential emission. - visitReturnInstDifferential(ri); - } - // If an `apply` has active results or active inout parameters, replace it // with an `apply` of its JVP. void visitApplyInst(ApplyInst *ai) { @@ -4649,7 +4830,115 @@ class JVPEmitter final differentialValues[ai->getParent()].push_back(diffFunc); // Differential emission. - visitApplyInstDifferential(ai, indices); + if (shouldBeDifferentiated(ai, getIndices())) + visitApplyInstDifferential(ai, indices); + } + + void visitReturnInst(ReturnInst *ri) { + auto loc = ri->getOperand().getLoc(); + auto *origExit = ri->getParent(); + auto &builder = getBuilder(); + auto *diffStructVal = buildDifferentialValueStructValue(ri); + + // Get the JVP value corresponding to the original functions's return value. + auto *origRetInst = cast(origExit->getTerminator()); + auto origResult = getOpValue(origRetInst->getOperand()); + SmallVector origResults; + extractAllElements(origResult, builder, origResults); + + // Get and partially apply the differential. + auto jvpGenericEnv = jvp->getGenericEnvironment(); + auto jvpSubstMap = jvpGenericEnv + ? jvpGenericEnv->getForwardingSubstitutionMap() + : jvp->getForwardingSubstitutionMap(); + auto *differentialRef = + builder.createFunctionRef(loc, &getDifferential()); + auto *differentialPartialApply = builder.createPartialApply( + loc, differentialRef, jvpSubstMap, {diffStructVal}, + ParameterConvention::Direct_Guaranteed); + + // Return a tuple of the original result and pullback. + SmallVector directResults; + directResults.append(origResults.begin(), origResults.end()); + directResults.push_back(differentialPartialApply); + builder.createReturn( + ri->getLoc(), joinElements(directResults, builder, loc)); + + // Differential emission. + if (shouldBeDifferentiated(ri, getIndices())) + visitReturnInstDifferential(ri); + } + + void visitLoadInst(LoadInst *li) { + TypeSubstCloner::visitLoadInst(li); + if (shouldBeDifferentiated(li, getIndices())) + visitLoadInstDifferential(li); + } + + void visitStoreInst(StoreInst *si) { + TypeSubstCloner::visitStoreInst(si); + if (shouldBeDifferentiated(si, getIndices())) + visitStoreInstDifferential(si); + } + + void visitCopyAddrInst(CopyAddrInst *cai) { + TypeSubstCloner::visitCopyAddrInst(cai); + if (shouldBeDifferentiated(cai, getIndices())) + visitCopyAddrInstDifferential(cai); + } + + void visitBeginAccessInst(BeginAccessInst *bai) { + TypeSubstCloner::visitBeginAccessInst(bai); + if (shouldBeDifferentiated(bai, getIndices())) + visitBeginAccessInstDifferential(bai); + } + + void visitAllocStackInst(AllocStackInst *asi) { + TypeSubstCloner::visitAllocStackInst(asi); + if (shouldBeDifferentiated(asi, getIndices())) + visitAllocStackInstDifferential(asi); + } + + void visitDeallocStackInst(DeallocStackInst *dsi) { + TypeSubstCloner::visitDeallocStackInst(dsi); + if (shouldBeDifferentiated(dsi, getIndices())) + visitDeallocStackInstDifferential(dsi); + } + + void visitStructExtractInst(StructExtractInst *sei) { + TypeSubstCloner::visitStructExtractInst(sei); + if (shouldBeDifferentiated(sei, getIndices())) + visitStructExtractInstDifferential(sei); + } + + void visitStructInst(StructInst *si) { + TypeSubstCloner::visitStructInst(si); + if (shouldBeDifferentiated(si, getIndices())) + visitStructInstDifferential(si); + } + + void visitArrayInitialization(ApplyInst *ai) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitTupleInst(TupleInst *ai) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitTupleExtractInst(TupleExtractInst *ai) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitBranchInst(BranchInst *bi) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitCondBranchInst(CondBranchInst *cbi) { + llvm_unreachable("Unsupported SIL instruction."); + } + + void visitSwitchEnumInst(SwitchEnumInst *sei) { + llvm_unreachable("Unsupported SIL instruction."); } void visitAutoDiffFunctionInst(AutoDiffFunctionInst *adfi) { @@ -4659,6 +4948,12 @@ class JVPEmitter final auto *newADFI = cast(getOpValue(adfi)); context.getAutoDiffFunctionInsts().push_back(newADFI); } + + void visitSILInstruction(SILInstruction *inst) { + context.emitNondifferentiabilityError(inst, invoker, + diag::autodiff_expression_not_differentiable_note); + errorOccurred = true; + } }; } // end anonymous namespace @@ -5112,7 +5407,7 @@ class PullbackEmitter final : public SILInstructionVisitor { //--------------------------------------------------------------------------// // Other utilities //--------------------------------------------------------------------------// - + bool shouldBeDifferentiated(SILInstruction *inst, const SILAutoDiffIndices &indices) { // Anything with an active result should be differentiated. diff --git a/test/AutoDiff/forward_mode_runtime.swift b/test/AutoDiff/forward_mode_runtime.swift index 50a6c9a36c3c4..44d4ec8e08c16 100644 --- a/test/AutoDiff/forward_mode_runtime.swift +++ b/test/AutoDiff/forward_mode_runtime.swift @@ -5,6 +5,8 @@ import StdlibUnittest var ForwardModeTests = TestSuite("ForwardMode") +// Basic Float constant functions. + ForwardModeTests.test("Unary") { func func_to_diff(x: Float) -> Float { return x * x @@ -34,4 +36,117 @@ ForwardModeTests.test("BinaryWithLets") { expectEqual(-19, differential(1, 1)) } +// Functions with variables. + +ForwardModeTests.test("UnaryWithVars") { + func unary(x: Float) -> Float { + var a = x + a = x + var b = a + 2 + b = b - 1 + let c: Float = 3 + var d = a + b + c - 1 + d = d + d + return d + } + + let (y, differential) = valueWithDifferential(at: 4, in: unary) + expectEqual(22, y) + expectEqual(4, differential(1)) +} + +// Functions with basic struct + +struct A: Differentiable & AdditiveArithmetic { + var x: Float + } + +ForwardModeTests.test("StructInit") { + func structInit(x: Float) -> A { + return A(x: 2 * x) + } + + let (y, differential) = valueWithDifferential(at: 4, in: structInit) + expectEqual(A(x: 8), y) + expectEqual(A(x: 2), differential(1)) +} + +ForwardModeTests.test("StructExtract") { + func structExtract(x: A) -> Float { + return 2 * x.x + } + + let (y, differential) = valueWithDifferential( + at: A(x: 4), + in: structExtract) + expectEqual(8, y) + expectEqual(2, differential(A(x: 1))) +} + +ForwardModeTests.test("LocalStructVariable") { + func structExtract(x: A) -> A { + let a = A(x: 2 * x.x) // 2x + var b = A(x: a.x + 2) // 2x + 2 + b = A(x: b.x + a.x) // 2x + 2 + 2x = 4x + 2 + return b + } + + let (y, differential) = valueWithDifferential( + at: A(x: 4), + in: structExtract) + expectEqual(A(x: 18), y) + expectEqual(A(x: 4), differential(A(x: 1))) +} + +// Functions with methods. + +extension A { + func noParamMethodA() -> A { + return A(x: 2 * x) + } + + func noParamMethodx() -> Float { + return 2 * x + } + + static func *(lhs: A, rhs: A) -> A { + return A(x: lhs.x * rhs.x) + } + + func complexBinaryMethod(u: A, v: Float) -> A { + var b: A = u * A(x: 2) // A(x: u * 2) + b.x = b.x * v // A(x: u * 2 * v) + let c = b.x + 1 // u * 2 * v + 1 + + // A(x: u * 2 * v + 1 + u * 2 * v) = A(x: x * (4uv + 1)) + return A(x: x * (c + b.x)) + } +} + +ForwardModeTests.test("noParamMethodA") { + let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in + x.noParamMethodA() + } + expectEqual(A(x: 8), y) + expectEqual(A(x: 2), differential(A(x: 1))) +} + +ForwardModeTests.test("noParamMethodx") { + let (y, differential) = valueWithDifferential(at: A(x: 4)) { x in + x.noParamMethodx() + } + expectEqual(8, y) + expectEqual(2, differential(A(x: 1))) +} + +ForwardModeTests.test("complexBinaryMethod") { + let (y, differential) = valueWithDifferential(at: A(x: 4), A(x: 5), 3) { + (x, y, z) in + // derivative = A(x: 4uv + 4xv + 4ux + 1) = 4*5*3 + 4*4*3 + 4*5*4 + 1 = 189 + x.complexBinaryMethod(u: y, v: z) + } + expectEqual(A(x: 244), y) + expectEqual(A(x: 189), differential(A(x: 1), A(x: 1), 1)) +} + runAllTests()