diff --git a/docs/SIL.rst b/docs/SIL.rst index 412d556aa4fd9..1fa1c90cd50ef 100644 --- a/docs/SIL.rst +++ b/docs/SIL.rst @@ -5673,17 +5673,23 @@ differentiable_function_extract sil-instruction ::= 'differentiable_function_extract' '[' sil-differentiable-function-extractee ']' sil-value ':' sil-type + ('as' sil-type)? sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp' differentiable_function_extract [original] %0 : $@differentiable (T) -> T differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T + differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T \ + as $(@in_constant T) -> (T, (T.TangentVector) -> T.TangentVector) Extracts the original function or a derivative function from the given ``@differentiable`` function. It must be provided with an extractee: ``[original]``, ``[jvp]`` or ``[vjp]``. +An explicit extractee type may be provided in lowered SIL. This is currently +used by the LoadableByAddress transformation, which rewrites function types. + linear_function_extract ``````````````````````` diff --git a/include/swift/AST/SILOptions.h b/include/swift/AST/SILOptions.h index a1dbc72ec031b..761bc797f6291 100644 --- a/include/swift/AST/SILOptions.h +++ b/include/swift/AST/SILOptions.h @@ -143,10 +143,7 @@ class SILOptions { bool EnableDynamicReplacementCanCallPreviousImplementation = true; /// Enable large loadable types IRGen pass. - // bool EnableLargeLoadableTypes = true; - // FIXME(TF-11, SR-9849): Disabled because LoadableByAddress cannot handle - // some functions that return closures that take/return large loadable types. - bool EnableLargeLoadableTypes = false; + bool EnableLargeLoadableTypes = true; /// Should the default pass pipelines strip ownership during the diagnostic /// pipeline or after serialization. diff --git a/include/swift/SIL/SILBuilder.h b/include/swift/SIL/SILBuilder.h index 7e7c365646122..b26adf1059d46 100644 --- a/include/swift/SIL/SILBuilder.h +++ b/include/swift/SIL/SILBuilder.h @@ -526,12 +526,14 @@ class SILBuilder { getModule(), getSILDebugLocation(Loc), ParameterIndices, OriginalFunction, TransposeFunction, hasOwnership())); } - + + /// Note: explicit extractee type may be specified only in lowered SIL. DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract( SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee, - SILValue TheFunction) { + SILValue TheFunction, Optional ExtracteeType = None) { return insert(new (getModule()) DifferentiableFunctionExtractInst( - getModule(), getSILDebugLocation(Loc), Extractee, TheFunction)); + getModule(), getSILDebugLocation(Loc), Extractee, TheFunction, + ExtracteeType)); } LinearFunctionExtractInst *createLinearFunctionExtract( diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index b889162dbf0d5..81533a4b1caae 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -7969,9 +7969,11 @@ class DifferentiableFunctionExtractInst SingleValueInstruction> { private: /// The extractee. - NormalDifferentiableFunctionTypeComponent extractee; + NormalDifferentiableFunctionTypeComponent Extractee; /// The list containing the `@differentiable` function operand. - FixedOperandList<1> operands; + FixedOperandList<1> Operands; + /// True if the instruction has an explicit extractee type. + bool HasExplicitExtracteeType; static SILType getExtracteeType( @@ -7979,24 +7981,26 @@ class DifferentiableFunctionExtractInst SILModule &module); public: + /// Note: explicit extractee type may be specified only in lowered SIL. explicit DifferentiableFunctionExtractInst( SILModule &module, SILDebugLocation debugLoc, NormalDifferentiableFunctionTypeComponent extractee, - SILValue theFunction); + SILValue theFunction, Optional extracteeType = None); NormalDifferentiableFunctionTypeComponent getExtractee() const { - return extractee; + return Extractee; } AutoDiffDerivativeFunctionKind getDerivativeFunctionKind() const { - auto kind = extractee.getAsDerivativeFunctionKind(); + auto kind = Extractee.getAsDerivativeFunctionKind(); assert(kind); return *kind; } - SILValue getFunctionOperand() const { return operands[0].get(); } - ArrayRef getAllOperands() const { return operands.asArray(); } - MutableArrayRef getAllOperands() { return operands.asArray(); } + SILValue getFunctionOperand() const { return Operands[0].get(); } + ArrayRef getAllOperands() const { return Operands.asArray(); } + MutableArrayRef getAllOperands() { return Operands.asArray(); } + bool hasExplicitExtracteeType() const { return HasExplicitExtracteeType; } }; /// `linear_function_extract` - given an `@differentiable(linear)` function diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index a6a4584e6898e..596b0bfc09566 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -4511,10 +4511,10 @@ void SILParameterInfo::print(ASTPrinter &Printer, const PrintOptions &Opts) const { /// SWIFT_ENABLE_TENSORFLOW switch (getDifferentiability()) { - case SILParameterDifferentiability::NotDifferentiable: + case SILParameterDifferentiability::NotDifferentiable: Printer << "@nondiff "; break; - default: + default: break; } Printer << getStringForParameterConvention(getConvention()); diff --git a/lib/IRGen/LoadableByAddress.cpp b/lib/IRGen/LoadableByAddress.cpp index f2a780fa7297f..996c74059740a 100644 --- a/lib/IRGen/LoadableByAddress.cpp +++ b/lib/IRGen/LoadableByAddress.cpp @@ -381,14 +381,23 @@ SILParameterInfo LargeSILTypeMapper::getNewParameter(GenericEnvironment *env, } else if (isLargeLoadableType(env, storageType, IGM)) { if (param.getConvention() == ParameterConvention::Direct_Guaranteed) return SILParameterInfo(storageType.getASTType(), - ParameterConvention::Indirect_In_Guaranteed); + // SWIFT_ENABLE_TENSORFLOW + ParameterConvention::Indirect_In_Guaranteed, + param.getDifferentiability()); + // SWIFT_ENABLE_TENSORFLOW_END else return SILParameterInfo(storageType.getASTType(), - ParameterConvention::Indirect_In_Constant); + // SWIFT_ENABLE_TENSORFLOW + ParameterConvention::Indirect_In_Constant, + param.getDifferentiability()); + // SWIFT_ENABLE_TENSORFLOW_END } else { auto newType = getNewSILType(env, storageType, IGM); return SILParameterInfo(newType.getASTType(), - param.getConvention()); + // SWIFT_ENABLE_TENSORFLOW + param.getConvention(), + param.getDifferentiability()); + // SWIFT_ENABLE_TENSORFLOW_END } } @@ -2757,8 +2766,10 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I, } case SILInstructionKind::DifferentiableFunctionExtractInst: { auto instr = cast(convInstr); + // Rewrite `differentiable_function_extract` with explicit extractee type. newInstr = convBuilder.createDifferentiableFunctionExtract( - instr->getLoc(), instr->getExtractee(), instr->getFunctionOperand()); + instr->getLoc(), instr->getExtractee(), instr->getFunctionOperand(), + newType); break; } case SILInstructionKind::LinearFunctionExtractInst: { diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp index 50ba2f3470acd..edf86eff7782d 100644 --- a/lib/ParseSIL/ParseSIL.cpp +++ b/lib/ParseSIL/ParseSIL.cpp @@ -3050,7 +3050,8 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { case SILInstructionKind::DifferentiableFunctionExtractInst: { // Parse the rest of the instruction: an extractee, a differentiable - // function operand, and a debug location. + // function operand, an optional explicit extractee type, and a debug + // location. NormalDifferentiableFunctionTypeComponent extractee; StringRef extracteeNames[3] = {"original", "jvp", "vjp"}; SILValue functionOperand; @@ -3062,11 +3063,19 @@ bool SILParser::parseSILInstruction(SILBuilder &B) { P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare, "extractee kind")) return true; - if (parseTypedValueRef(functionOperand, B) || - parseSILDebugLocation(InstLoc, B)) + if (parseTypedValueRef(functionOperand, B)) + return true; + // Parse an optional explicit extractee type. + Optional extracteeType = None; + if (P.consumeIf(tok::kw_as)) { + extracteeType = SILType(); + if (parseSILType(*extracteeType)) + return true; + } + if (parseSILDebugLocation(InstLoc, B)) return true; ResultVal = B.createDifferentiableFunctionExtract( - InstLoc, extractee, functionOperand); + InstLoc, extractee, functionOperand, extracteeType); break; } diff --git a/lib/SIL/SILInstructions.cpp b/lib/SIL/SILInstructions.cpp index 4354a8ccd28f3..7e228552eefb4 100644 --- a/lib/SIL/SILInstructions.cpp +++ b/lib/SIL/SILInstructions.cpp @@ -685,10 +685,21 @@ getExtracteeType( DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst( SILModule &module, SILDebugLocation debugLoc, - NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction) + NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction, + Optional extracteeType) : InstructionBase(debugLoc, - getExtracteeType(theFunction, extractee, module)), - extractee(extractee), operands(this, theFunction) {} + extracteeType + ? *extracteeType + : getExtracteeType(theFunction, extractee, module)), + Extractee(extractee), Operands(this, theFunction), + HasExplicitExtracteeType(extracteeType.hasValue()) { +#ifndef NDEBUG + if (extracteeType.hasValue()) { + assert(module.getStage() == SILStage::Lowered && + "Explicit type is valid only in lowered SIL"); + } +#endif +} SILType LinearFunctionExtractInst:: getExtracteeType( diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 2c2e73fa36d63..1d8454a24e8d7 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -1251,6 +1251,10 @@ class SILPrinter : public SILInstructionVisitor { } *this << "] "; *this << getIDAndType(dfei->getFunctionOperand()); + if (dfei->hasExplicitExtracteeType()) { + *this << " as "; + *this << dfei->getType(); + } } void visitLinearFunctionExtractInst(LinearFunctionExtractInst *lfei) { diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp index deb21f5e042f3..12724585fd626 100644 --- a/lib/SIL/SILVerifier.cpp +++ b/lib/SIL/SILVerifier.cpp @@ -1498,8 +1498,12 @@ class SILVerifier : public SILVerifierBase { require(origTy, "The original function must have a function type"); require(!origTy->isDifferentiable(), "The original function must not be @differentiable"); - if (F.getModule().getStage() == SILStage::Canonical || - dfi->hasDerivativeFunctions()) { + // Skip lowered SIL: LoadableByAddress changes parameter/result conventions. + // TODO: Check that derivative function types match excluding + // parameter/result conventions in lowered SIL. + if (F.getModule().getStage() == SILStage::Lowered) + return; + if (dfi->hasDerivativeFunctions()) { auto jvp = dfi->getJVPFunction(); auto jvpType = jvp->getType().getAs(); require(jvpType, "The JVP function must have a function type"); @@ -1533,8 +1537,12 @@ class SILVerifier : public SILVerifierBase { require(origTy, "The original function must have a function type"); require(!origTy->isDifferentiable(), "The original function must not be differentiable"); - if (F.getModule().getStage() == SILStage::Canonical || - lfi->hasTransposeFunction()) { + // Skip lowered SIL: LoadableByAddress changes parameter/result conventions. + // TODO: Check that transpose function type matches excluding + // parameter/result conventions in lowered SIL. + if (F.getModule().getStage() == SILStage::Lowered) + return; + if (lfi->hasTransposeFunction()) { auto transpose = lfi->getTransposeFunction(); auto transposeType = transpose->getType().getAs(); require(transposeType, diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index c0b656bc68c32..08a4390441273 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -1146,7 +1146,8 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, break; case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT: SILInstDifferentiableFunctionExtractLayout::readRecord( - scratch, TyID, TyCategory, ValID, /*extractee*/ Attr); + scratch, TyID, TyCategory, ValID, /*extractee*/ Attr, + /*hasExplicitExtracteeType*/ Attr2); RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst; break; case SIL_INST_LINEAR_FUNCTION_EXTRACT: @@ -1609,8 +1610,12 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB, auto silTy = getSILType(astTy, SILValueCategory::Object); auto val = getLocalValue(ValID, silTy); NormalDifferentiableFunctionTypeComponent extractee(Attr); + Optional explicitExtracteeType = None; + if (Attr2) + explicitExtracteeType = silTy; ResultVal = - Builder.createDifferentiableFunctionExtract(Loc, extractee, val); + Builder.createDifferentiableFunctionExtract(Loc, extractee, val, + explicitExtracteeType); break; } case SILInstructionKind::LinearFunctionExtractInst: { diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h index 29db40cdd0d38..b95f6272e8bb2 100644 --- a/lib/Serialization/ModuleFormat.h +++ b/lib/Serialization/ModuleFormat.h @@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0; /// describe what change you made. The content of this comment isn't important; /// it just ensures a conflict if two people change the module format. /// Don't worry about adhering to the 80-column limit for this line. -const uint16_t SWIFTMODULE_VERSION_MINOR = 523; // differentiable_function and differentiable_function_extract instructions +const uint16_t SWIFTMODULE_VERSION_MINOR = 524; // differentiable_function_extract explicit extractee type /// A standard hash seed used for all string hashes in a serialized module. /// diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h index 0f2f1c6b4d4bf..96aaef0c2f2a2 100644 --- a/lib/Serialization/SILFormat.h +++ b/lib/Serialization/SILFormat.h @@ -456,7 +456,8 @@ namespace sil_block { TypeIDField, SILTypeCategoryField, ValueIDField, - BCFixed<2> // extractee + BCFixed<2>, // extractee + BCFixed<1> // has explicit extractee type? >; using SILInstLinearFunctionExtractLayout = BCRecordLayout< diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 8151f58542ec3..a80ee958ab19e 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -1043,7 +1043,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) { SILInstDifferentiableFunctionExtractLayout::emitRecord(Out, ScratchRecord, SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code], operandTypeRef, (unsigned)operandType.getCategory(), operandRef, - rawExtractee); + rawExtractee, (unsigned)dfei->hasExplicitExtracteeType()); break; } case SILInstructionKind::LinearFunctionExtractInst: { diff --git a/test/AutoDiff/differentiable_function_inst_lowered.sil b/test/AutoDiff/differentiable_function_inst_lowered.sil new file mode 100644 index 0000000000000..5d06ad15ebfa9 --- /dev/null +++ b/test/AutoDiff/differentiable_function_inst_lowered.sil @@ -0,0 +1,64 @@ +// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s + +// Test `differentiable_function_extract` with explicit lowered type. +// SIL generated via `%target-sil-opt -loadable-address %s`. +// Note: SIL serialization/deserialization does not support lowered SIL. + +sil_stage lowered + +import Swift +import Builtin + +struct Large : _Differentiable { + @_hasStorage @noDerivative let a: Float { get } + @_hasStorage @noDerivative let b: Float { get } + @_hasStorage @noDerivative let c: Float { get } + @_hasStorage @noDerivative let d: Float { get } + @_hasStorage @noDerivative let e: Float { get } + init(a: Float, b: Float, c: Float, d: Float, e: Float) + struct TangentVector : _Differentiable, AdditiveArithmetic { + init() + typealias TangentVector = Large.TangentVector + static var zero: Large.TangentVector { get } + static func + (lhs: Large.TangentVector, rhs: Large.TangentVector) -> Large.TangentVector + static func - (lhs: Large.TangentVector, rhs: Large.TangentVector) -> Large.TangentVector + @_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Large.TangentVector, _ b: Large.TangentVector) -> Bool + } + mutating func move(along direction: Large.TangentVector) +} + +sil @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large +sil @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + +// CHECK-LABEL: sil @test +sil @test : $@convention(thin) () -> () { +bb0: + %0 = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + + // CHECK: %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + // CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + + %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) + + // CHECK: %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + // CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) + + %5 = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + + // CHECK: %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + // CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector)) + + %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) + + // CHECK: %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large + // CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) + + %10 = tuple () + return %10 : $() +} diff --git a/test/AutoDiff/loadable-by-address.swift b/test/AutoDiff/loadable-by-address.swift new file mode 100644 index 0000000000000..929f7acdcb9b1 --- /dev/null +++ b/test/AutoDiff/loadable-by-address.swift @@ -0,0 +1,75 @@ +// RUN: %target-swift-frontend -c -enable-large-loadable-types -Xllvm -sil-verify-after-pass=loadable-address %s +// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s -check-prefix=CHECK-SIL +// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %s 2>&1 | %FileCheck %s -check-prefix=CHECK-LBA-SIL +// RUN: %target-run-simple-swift +// REQUIRES: executable_test + +// TF-11: Verify that LoadableByAddress works with differentiation-related instructions: +// - `differentiable_function` +// - `differentiable_function_extract` + +// TODO: Add tests for `@differentiable(linear)` functions. + +import StdlibUnittest + +var LBATests = TestSuite("LoadableByAddress") + +// `Large` is a large loadable type. +// `Large.TangentVector` is not a large loadable type. +struct Large : Differentiable { + var a: Float + var b: Float + var c: Float + var d: Float + @noDerivative let e: Float +} + +@_silgen_name("large2large") +@differentiable +func large2large(_ foo: Large) -> Large { + foo +} + +// `large2large` old verification error: +// SIL verification failed: JVP type does not match expected JVP type +// $@callee_guaranteed (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) +// $@callee_guaranteed (@in_constant Large) -> (@out Large, @owned @callee_guaranteed (@in_constant Large.TangentVector) -> @out Large.TangentVector) + +@_silgen_name("large2small") +@differentiable +func large2small(_ foo: Large) -> Float { + foo.a +} + +// `large2small` old verification error: +// SIL verification failed: JVP type does not match expected JVP type +// $@callee_guaranteed (@in_constant Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) +// $@callee_guaranteed (@in_constant Large) -> (Float, @owned @callee_guaranteed (@in_constant Large.TangentVector) -> Float) + +// CHECK-SIL: sil hidden {{.*}} @large2large : $@convention(thin) (Large) -> Large { +// CHECK-LBA-SIL: sil hidden {{.*}} @large2large : $@convention(thin) (@in_constant Large) -> @out Large { + +// CHECK-SIL-LABEL: sil hidden {{.*}} @large2small : $@convention(thin) (Large) -> Float { +// CHECK-LBA-SIL: sil hidden {{.*}} @large2small : $@convention(thin) (@in_constant Large) -> Float { + +// CHECK-SIL: sil hidden @AD__large2large__jvp_src_0_wrt_0 : $@convention(thin) (Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) { +// CHECK-LBA-SIL: sil hidden @AD__large2large__jvp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) { + +// CHECK-SIL: sil hidden @AD__large2large__vjp_src_0_wrt_0 : $@convention(thin) (Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) { +// CHECK-LBA-SIL: sil hidden @AD__large2large__vjp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) { + +// CHECK-SIL: sil hidden @AD__large2small__jvp_src_0_wrt_0 : $@convention(thin) (Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) { +// CHECK-LBA-SIL: sil hidden @AD__large2small__jvp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) { + +// CHECK-SIL: sil hidden @AD__large2small__vjp_src_0_wrt_0 : $@convention(thin) (Large) -> (Float, @owned @callee_guaranteed (Float) -> Large.TangentVector) { +// CHECK-LBA-SIL: sil hidden @AD__large2small__vjp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Float, @owned @callee_guaranteed (Float) -> Large.TangentVector) { + +LBATests.test("Correctness") { + let one = Large.TangentVector(a: 1, b: 1, c: 1, d: 1) + expectEqual(one, + pullback(at: Large(a: 0, b: 0, c: 0, d: 0, e: 0), in: large2large)(one)) + expectEqual(Large.TangentVector(a: 1, b: 0, c: 0, d: 0), + gradient(at: Large(a: 0, b: 0, c: 0, d: 0, e: 0), in: large2small)) +} + +runAllTests()