From d9bc2cb2fa451632f2d8cc0651860041afe94d80 Mon Sep 17 00:00:00 2001 From: Dario Rexin Date: Thu, 6 Jun 2024 12:52:51 -0700 Subject: [PATCH 1/2] [IRGen] Return typed errors directly in synchronous functions when possible rdar://129359355 This PR implements the basic support for returning typed errors directly and applies it to synchronous functions. --- lib/IRGen/CallEmission.h | 6 + lib/IRGen/GenCall.cpp | 336 ++++++++++++++++++++++++--- lib/IRGen/GenCall.h | 9 + lib/IRGen/GenFunc.cpp | 15 +- lib/IRGen/GenThunk.cpp | 43 +++- lib/IRGen/IRGenFunction.h | 2 +- lib/IRGen/IRGenSIL.cpp | 147 ++++++++++-- lib/IRGen/NativeConventionSchema.h | 3 + test/IRGen/typed_throws.sil | 34 ++- test/IRGen/typed_throws.swift | 7 +- test/IRGen/typed_throws_thunks.swift | 4 +- 11 files changed, 515 insertions(+), 91 deletions(-) diff --git a/lib/IRGen/CallEmission.h b/lib/IRGen/CallEmission.h index 5f189f4c4c705..fbdff0893e9dd 100644 --- a/lib/IRGen/CallEmission.h +++ b/lib/IRGen/CallEmission.h @@ -19,6 +19,7 @@ #include "Address.h" #include "Callee.h" +#include "Explosion.h" #include "Temporary.h" namespace llvm { @@ -88,6 +89,7 @@ class CallEmission { unsigned IndirectTypedErrorArgIdx = 0; + std::optional typedErrorExplosion; virtual void setFromCallee(); void emitToUnmappedMemory(Address addr); @@ -123,6 +125,10 @@ class CallEmission { return CurCallee.getSubstitutions(); } + std::optional &getTypedErrorExplosion() { + return typedErrorExplosion; + } + virtual void begin(); virtual void end(); virtual SILType getParameterType(unsigned index) = 0; diff --git a/lib/IRGen/GenCall.cpp b/lib/IRGen/GenCall.cpp index b06534faec1bf..64d9d79c8be17 100644 --- a/lib/IRGen/GenCall.cpp +++ b/lib/IRGen/GenCall.cpp @@ -373,6 +373,70 @@ static void addIndirectResultAttributes(IRGenModule &IGM, attrs = attrs.addParamAttributes(IGM.getLLVMContext(), paramIndex, b); } +CombinedResultAndErrorType irgen::combineResultAndTypedErrorType( + const IRGenModule &IGM, const NativeConventionSchema &resultSchema, + const NativeConventionSchema &errorSchema) { + CombinedResultAndErrorType result; + SmallVector elts; + resultSchema.enumerateComponents( + [&](clang::CharUnits offset, clang::CharUnits end, llvm::Type *type) { + elts.push_back(type); + }); + + SmallVector errorElts; + errorSchema.enumerateComponents( + [&](clang::CharUnits offset, clang::CharUnits end, llvm::Type *type) { + errorElts.push_back(type); + }); + + llvm::SmallVector combined; + + auto resIt = elts.begin(); + auto errorIt = errorElts.begin(); + + while (resIt < elts.end() && errorIt < errorElts.end()) { + auto *res = *resIt; + if (!res->isIntOrPtrTy()) { + combined.push_back(res); + ++resIt; + continue; + } + + auto *error = *errorIt; + result.errorValueMapping.push_back(combined.size()); + if (res->getPrimitiveSizeInBits() >= error->getPrimitiveSizeInBits()) { + combined.push_back(res); + } else { + combined.push_back(error); + } + + ++resIt; + ++errorIt; + } + + while (resIt < elts.end()) { + combined.push_back(*resIt); + ++resIt; + } + + while (errorIt < errorElts.end()) { + result.errorValueMapping.push_back(combined.size()); + combined.push_back(*errorIt); + ++errorIt; + } + + if (combined.empty()) { + result.combinedTy = llvm::Type::getVoidTy(IGM.getLLVMContext()); + } else if (combined.size() == 1) { + result.combinedTy = combined[0]; + } else { + result.combinedTy = + llvm::StructType::get(IGM.getLLVMContext(), combined, /*packed*/ false); + } + + return result; +} + void IRGenModule::addSwiftAsyncContextAttributes(llvm::AttributeList &attrs, unsigned argIndex) { llvm::AttrBuilder b(getLLVMContext()); @@ -519,6 +583,7 @@ namespace { /// the direct result of this function. If the result is passed indirectly, /// a void type is returned instead, with a \c null type info. std::pair expandDirectResult(); + std::pair expandDirectErrorType(); void expandIndirectResults(); void expandParameters(SignatureExpansionABIDetails *recordedABIDetails); void expandKeyPathAccessorParameters(); @@ -572,6 +637,17 @@ void SignatureExpansion::expandResult( const TypeInfo *directResultTypeInfo; std::tie(ResultIRType, directResultTypeInfo) = expandDirectResult(); + if (!fnConv.hasIndirectSILErrorResults()) { + llvm::Type *directErrorType; + const TypeInfo *directErrorTypeInfo; + std::tie(directErrorType, directErrorTypeInfo) = expandDirectErrorType(); + if ((directResultTypeInfo || ResultIRType->isVoidTy()) && + directErrorTypeInfo) { + ResultIRType = directErrorType; + directResultTypeInfo = directErrorTypeInfo; + } + } + // Expand the indirect results. expandIndirectResults(); @@ -806,9 +882,8 @@ llvm::Type *NativeConventionSchema::getExpandedType(IRGenModule &IGM) const { if (empty()) return IGM.VoidTy; SmallVector elts; - Lowering.enumerateComponents([&](clang::CharUnits offset, - clang::CharUnits end, - llvm::Type *type) { elts.push_back(type); }); + enumerateComponents([&](clang::CharUnits offset, clang::CharUnits end, + llvm::Type *type) { elts.push_back(type); }); if (elts.size() == 1) return elts[0]; @@ -832,7 +907,7 @@ NativeConventionSchema::getCoercionTypes( unsigned idx = 0; // Mark overlapping ranges. - Lowering.enumerateComponents( + enumerateComponents( [&](clang::CharUnits offset, clang::CharUnits end, llvm::Type *type) { if (offset < lastEnd) { overlappedWithSuccessor.insert(idx); @@ -847,7 +922,7 @@ NativeConventionSchema::getCoercionTypes( lastEnd = clang::CharUnits::Zero(); SmallVector elts; bool packed = false; - Lowering.enumerateComponents( + enumerateComponents( [&](clang::CharUnits begin, clang::CharUnits end, llvm::Type *type) { bool overlapped = overlappedWithSuccessor.count(idx) || (idx && overlappedWithSuccessor.count(idx - 1)); @@ -887,7 +962,7 @@ NativeConventionSchema::getCoercionTypes( lastEnd = clang::CharUnits::Zero(); elts.clear(); packed = false; - Lowering.enumerateComponents( + enumerateComponents( [&](clang::CharUnits begin, clang::CharUnits end, llvm::Type *type) { bool overlapped = overlappedWithSuccessor.count(idx) || (idx && overlappedWithSuccessor.count(idx - 1)); @@ -952,6 +1027,38 @@ SignatureExpansion::expandDirectResult() { llvm_unreachable("Not a valid SILFunctionLanguage."); } +std::pair +SignatureExpansion::expandDirectErrorType() { + if (!getSILFuncConventions().funcTy->hasErrorResult() || + !getSILFuncConventions().isTypedError()) { + return std::make_pair(nullptr, nullptr); + } + + switch (FnType->getLanguage()) { + case SILFunctionLanguage::C: + llvm_unreachable("Expanding C/ObjC parameters in the wrong place!"); + break; + case SILFunctionLanguage::Swift: { + auto resultType = getSILFuncConventions().getSILResultType( + IGM.getMaximalTypeExpansionContext()); + auto errorType = getSILFuncConventions().getSILErrorType( + IGM.getMaximalTypeExpansionContext()); + const auto &ti = IGM.getTypeInfo(resultType); + auto &native = ti.nativeReturnValueSchema(IGM); + const auto &errorTI = IGM.getTypeInfo(errorType); + auto &errorNative = errorTI.nativeReturnValueSchema(IGM); + if (native.requiresIndirect() || + errorNative.shouldReturnTypedErrorIndirectly()) { + return std::make_pair(nullptr, nullptr); + } + + auto combined = combineResultAndTypedErrorType(IGM, native, errorNative); + + return std::make_pair(combined.combinedTy, &errorTI); + } + } +} + static const clang::FieldDecl * getLargestUnionField(const clang::RecordDecl *record, const clang::ASTContext &ctx) { @@ -1901,10 +2008,21 @@ void SignatureExpansion::expandParameters( if (recordedABIDetails) recordedABIDetails->hasErrorResult = true; if (getSILFuncConventions().isTypedError()) { - ParamIRTypes.push_back( - IGM.getStorageType(getSILFuncConventions().getSILType( - FnType->getErrorResult(), IGM.getMaximalTypeExpansionContext()) - )->getPointerTo()); + + auto resultType = getSILFuncConventions().getSILResultType( + IGM.getMaximalTypeExpansionContext()); + auto &resultTI = IGM.getTypeInfo(resultType); + auto &native = resultTI.nativeReturnValueSchema(IGM); + auto errorType = getSILFuncConventions().getSILErrorType( + IGM.getMaximalTypeExpansionContext()); + auto &errorTI = IGM.getTypeInfo(errorType); + auto &nativeError = errorTI.nativeReturnValueSchema(IGM); + + if (getSILFuncConventions().hasIndirectSILErrorResults() || + native.requiresIndirect() || + nativeError.shouldReturnTypedErrorIndirectly()) { + ParamIRTypes.push_back(IGM.getStorageType(errorType)->getPointerTo()); + } } } @@ -2453,10 +2571,22 @@ class SyncCallEmission final : public CallEmission { setIndirectTypedErrorResultSlotArgsIndex(--LastArgWritten); Args[LastArgWritten] = nullptr; } else { - // Return the error indirectly. - auto buf = IGF.getCalleeTypedErrorResultSlot( - fnConv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext())); - Args[--LastArgWritten] = buf.getAddress(); + auto silResultTy = + fnConv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext()); + auto silErrorTy = + fnConv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext()); + + auto &nativeSchema = + IGF.IGM.getTypeInfo(silResultTy).nativeReturnValueSchema(IGF.IGM); + auto &errorSchema = + IGF.IGM.getTypeInfo(silErrorTy).nativeReturnValueSchema(IGF.IGM); + + if (nativeSchema.requiresIndirect() || + errorSchema.shouldReturnTypedErrorIndirectly()) { + // Return the error indirectly. + auto buf = IGF.getCalleeTypedErrorResultSlot(silErrorTy); + Args[--LastArgWritten] = buf.getAddress(); + } } } Args[--LastArgWritten] = errorResultSlot.getAddress(); @@ -2599,14 +2729,25 @@ class SyncCallEmission final : public CallEmission { } void emitCallToUnmappedExplosion(llvm::CallBase *call, Explosion &out) override { + SILFunctionConventions fnConv(getCallee().getOrigFunctionType(), + IGF.getSILModule()); + bool mayReturnErrorDirectly = false; + if (!convertDirectToIndirectReturn && + !fnConv.hasIndirectSILErrorResults() && + fnConv.funcTy->hasErrorResult() && fnConv.isTypedError()) { + auto errorType = + fnConv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext()); + auto &errorSchema = + IGF.IGM.getTypeInfo(errorType).nativeReturnValueSchema(IGF.IGM); + + mayReturnErrorDirectly = !errorSchema.shouldReturnTypedErrorIndirectly(); + } + // Bail out immediately on a void result. llvm::Value *result = call; - if (result->getType()->isVoidTy()) + if (result->getType()->isVoidTy() && !mayReturnErrorDirectly) return; - SILFunctionConventions fnConv(getCallee().getOrigFunctionType(), - IGF.getSILModule()); - // If the result was returned autoreleased, implicitly insert the reclaim. // This is only allowed on a single direct result. if (fnConv.getNumDirectSILResults() == 1 @@ -2645,6 +2786,76 @@ class SyncCallEmission final : public CallEmission { auto &nativeSchema = IGF.IGM.getTypeInfo(resultType).nativeReturnValueSchema(IGF.IGM); + // Handle direct return of typed errors + if (mayReturnErrorDirectly && !nativeSchema.requiresIndirect()) { + auto errorType = + fnConv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext()); + auto &errorSchema = + IGF.IGM.getTypeInfo(errorType).nativeReturnValueSchema(IGF.IGM); + + auto combined = + combineResultAndTypedErrorType(IGF.IGM, nativeSchema, errorSchema); + + if (combined.combinedTy->isVoidTy()) { + typedErrorExplosion = Explosion(); + return; + } + + Explosion nativeExplosion; + extractScalarResults(IGF, result->getType(), result, nativeExplosion); + auto values = nativeExplosion.claimAll(); + + auto convertIfNecessary = [&](llvm::Type *nativeTy, + llvm::Value *elt) -> llvm::Value * { + auto *eltTy = elt->getType(); + if (nativeTy->isIntOrPtrTy() && eltTy->isIntOrPtrTy() && + nativeTy->getPrimitiveSizeInBits() != + eltTy->getPrimitiveSizeInBits()) { + return IGF.Builder.CreateTruncOrBitCast(elt, nativeTy); + } + return elt; + }; + + Explosion errorExplosion; + if (!errorSchema.empty()) { + if (auto *structTy = dyn_cast( + errorSchema.getExpandedType(IGF.IGM))) { + for (unsigned i = 0, e = structTy->getNumElements(); i < e; ++i) { + llvm::Value *elt = values[combined.errorValueMapping[i]]; + auto *nativeTy = structTy->getElementType(i); + elt = convertIfNecessary(nativeTy, elt); + errorExplosion.add(elt); + } + } else { + errorExplosion.add(convertIfNecessary(combined.combinedTy, values[0])); + } + + typedErrorExplosion = + errorSchema.mapFromNative(IGF.IGM, IGF, errorExplosion, errorType); + } else { + typedErrorExplosion = std::move(errorExplosion); + } + + // If the regular result type is void, there is nothing to explode + if (!resultType.isVoid()) { + Explosion resultExplosion; + if (auto *structTy = dyn_cast( + nativeSchema.getExpandedType(IGF.IGM))) { + for (unsigned i = 0, e = structTy->getNumElements(); i < e; ++i) { + resultExplosion.add(values[i]); + } + } else { + resultExplosion.add(values[0]); + } + out = nativeSchema.mapFromNative(IGF.IGM, IGF, resultExplosion, + resultType); + } + return; + } + + if (result->getType()->isVoidTy()) + return; + // For ABI reasons the result type of the call might not actually match the // expected result type. // @@ -5105,9 +5316,8 @@ unsigned NativeConventionSchema::size() const { if (empty()) return 0; unsigned size = 0; - Lowering.enumerateComponents([&](clang::CharUnits offset, - clang::CharUnits end, - llvm::Type *type) { ++size; }); + enumerateComponents([&](clang::CharUnits offset, clang::CharUnits end, + llvm::Type *type) { ++size; }); return size; } @@ -5449,8 +5659,16 @@ Explosion IRGenFunction::coerceValueTo(SILType fromTy, Explosion &from, void IRGenFunction::emitScalarReturn(SILType returnResultType, SILType funcResultType, Explosion &result, - bool isSwiftCCReturn, bool isOutlined) { - if (result.empty()) { + bool isSwiftCCReturn, bool isOutlined, + SILType errorType) { + bool mayReturnErrorDirectly = false; + if (errorType) { + auto &errorTI = IGM.getTypeInfo(errorType); + auto &nativeError = errorTI.nativeReturnValueSchema(IGM); + mayReturnErrorDirectly = !nativeError.shouldReturnTypedErrorIndirectly(); + } + + if (result.empty() && !mayReturnErrorDirectly) { assert(IGM.getTypeInfo(returnResultType) .nativeReturnValueSchema(IGM) .empty() && @@ -5462,24 +5680,74 @@ void IRGenFunction::emitScalarReturn(SILType returnResultType, // In the native case no coercion is needed. if (isSwiftCCReturn) { - result = coerceValueTo(returnResultType, result, funcResultType); - auto &nativeSchema = - IGM.getTypeInfo(funcResultType).nativeReturnValueSchema(IGM); + auto &resultTI = IGM.getTypeInfo(funcResultType); + auto &nativeSchema = resultTI.nativeReturnValueSchema(IGM); assert(!nativeSchema.requiresIndirect()); + result = coerceValueTo(returnResultType, result, funcResultType); Explosion native = nativeSchema.mapIntoNative(IGM, *this, result, funcResultType, isOutlined); - if (native.size() == 1) { - Builder.CreateRet(native.claimNext()); - return; + llvm::Value *nativeAgg = nullptr; + + if (mayReturnErrorDirectly) { + auto &errorTI = IGM.getTypeInfo(errorType); + auto &nativeError = errorTI.nativeReturnValueSchema(IGM); + auto *combinedTy = + combineResultAndTypedErrorType(IGM, nativeSchema, nativeError) + .combinedTy; + + if (combinedTy->isVoidTy()) { + Builder.CreateRetVoid(); + return; + } + + if (native.empty()) { + Builder.CreateRet(llvm::UndefValue::get(combinedTy)); + return; + } + + auto convertIfNecessary = [&](llvm::Type *nativeTy, + llvm::Value *elt) -> llvm::Value * { + auto *eltTy = elt->getType(); + if (nativeTy->isIntOrPtrTy() && eltTy->isIntOrPtrTy() && + nativeTy->getPrimitiveSizeInBits() != + eltTy->getPrimitiveSizeInBits()) { + assert(nativeTy->getPrimitiveSizeInBits() > + eltTy->getPrimitiveSizeInBits()); + return Builder.CreateZExt(elt, nativeTy); + } + return elt; + }; + + if (auto *structTy = dyn_cast(combinedTy)) { + nativeAgg = llvm::UndefValue::get(combinedTy); + for (unsigned i = 0, e = native.size(); i != e; ++i) { + llvm::Value *elt = native.claimNext(); + auto *nativeTy = structTy->getElementType(i); + elt = convertIfNecessary(nativeTy, elt); + nativeAgg = Builder.CreateInsertValue(nativeAgg, elt, i); + } + } else { + nativeAgg = convertIfNecessary(combinedTy, native.claimNext()); + } } - llvm::Value *nativeAgg = - llvm::UndefValue::get(nativeSchema.getExpandedType(IGM)); - for (unsigned i = 0, e = native.size(); i != e; ++i) { - llvm::Value *elt = native.claimNext(); - nativeAgg = Builder.CreateInsertValue(nativeAgg, elt, i); + + if (!nativeAgg) { + if (native.size() == 1) { + Builder.CreateRet(native.claimNext()); + return; + } + + nativeAgg = llvm::UndefValue::get(nativeSchema.getExpandedType(IGM)); + + for (unsigned i = 0, e = native.size(); i != e; ++i) { + llvm::Value *elt = native.claimNext(); + nativeAgg = Builder.CreateInsertValue(nativeAgg, elt, i); + } } + Builder.CreateRet(nativeAgg); + return; } diff --git a/lib/IRGen/GenCall.h b/lib/IRGen/GenCall.h index 7adca3f8afc39..a912f14b39751 100644 --- a/lib/IRGen/GenCall.h +++ b/lib/IRGen/GenCall.h @@ -121,6 +121,15 @@ namespace irgen { CanSILFunctionType substitutedType, SubstitutionMap substitutionMap); + struct CombinedResultAndErrorType { + llvm::Type *combinedTy; + llvm::SmallVector errorValueMapping; + }; + CombinedResultAndErrorType + combineResultAndTypedErrorType(const IRGenModule &IGM, + const NativeConventionSchema &resultSchema, + const NativeConventionSchema &errorSchema); + /// Given an async function, get the pointer to the function to be called and /// the size of the context to be allocated. /// diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp index 37e4cf848b776..67b46cf7eeab6 100644 --- a/lib/IRGen/GenFunc.cpp +++ b/lib/IRGen/GenFunc.cpp @@ -1166,8 +1166,19 @@ class SyncPartialApplicationForwarderEmission llvm::Value *errorResultPtr = origParams.claimNext(); args.add(errorResultPtr); if (origConv.isTypedError()) { - auto *typedErrorResultPtr = origParams.claimNext(); - args.add(typedErrorResultPtr); + auto errorType = + origConv.getSILErrorType(IGM.getMaximalTypeExpansionContext()); + auto silResultTy = + origConv.getSILResultType(IGM.getMaximalTypeExpansionContext()); + auto &errorTI = IGM.getTypeInfo(errorType); + auto &resultTI = IGM.getTypeInfo(silResultTy); + auto &resultSchema = resultTI.nativeReturnValueSchema(IGM); + auto &errorSchema = errorTI.nativeReturnValueSchema(IGM); + + if (resultSchema.requiresIndirect() || errorSchema.shouldReturnTypedErrorIndirectly() || outConv.hasIndirectSILErrorResults()) { + auto *typedErrorResultPtr = origParams.claimNext(); + args.add(typedErrorResultPtr); + } } } llvm::CallInst *createCall(FunctionPointer &fnPtr) override { diff --git a/lib/IRGen/GenThunk.cpp b/lib/IRGen/GenThunk.cpp index 49bff4aaf5dd5..4d3901b6d6152 100644 --- a/lib/IRGen/GenThunk.cpp +++ b/lib/IRGen/GenThunk.cpp @@ -142,14 +142,22 @@ void IRGenThunk::prepareArguments() { // Set the typed error value result slot. if (conv.isTypedError() && !conv.hasIndirectSILErrorResults()) { - auto directTypedErrorAddr = original.takeLast(); auto errorType = conv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext()); auto &errorTI = cast(IGF.getTypeInfo(errorType)); - - IGF.setCalleeTypedErrorResultSlot(Address(directTypedErrorAddr, - errorTI.getStorageType(), - errorTI.getFixedAlignment())); + auto &errorSchema = errorTI.nativeReturnValueSchema(IGF.IGM); + auto resultType = + conv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext()); + auto &resultTI = cast(IGF.getTypeInfo(resultType)); + auto &resultSchema = resultTI.nativeReturnValueSchema(IGF.IGM); + + if (isAsync || resultSchema.requiresIndirect() || + errorSchema.shouldReturnTypedErrorIndirectly()) { + auto directTypedErrorAddr = original.takeLast(); + IGF.setCalleeTypedErrorResultSlot(Address(directTypedErrorAddr, + errorTI.getStorageType(), + errorTI.getFixedAlignment())); + } } else if (conv.isTypedError()) { auto directTypedErrorAddr = original.takeLast(); // Store for later processing when we know the argument index. @@ -329,7 +337,8 @@ void IRGenThunk::emit() { llvm::Value *errorValue = nullptr; - if (isAsync && origTy->hasErrorResult()) { + if (emission->getTypedErrorExplosion() || + (isAsync && origTy->hasErrorResult())) { SILType errorType = conv.getSILErrorType(expansionContext); Address calleeErrorSlot = emission->getCalleeErrorSlot( errorType, /*isCalleeAsync=*/origTy->isAsync()); @@ -338,6 +347,22 @@ void IRGenThunk::emit() { emission->end(); + // FIXME: we shouldn't have to generate all of this. We should just forward + // the value as is + if (auto &error = emission->getTypedErrorExplosion()) { + llvm::BasicBlock *successBB = IGF.createBasicBlock("success"); + llvm::BasicBlock *errorBB = IGF.createBasicBlock("failure"); + + llvm::Value *nil = llvm::ConstantPointerNull::get( + cast(errorValue->getType())); + auto *hasError = IGF.Builder.CreateICmpNE(errorValue, nil); + IGF.Builder.CreateCondBr(hasError, errorBB, successBB); + + IGF.Builder.emitBlock(errorBB); + IGF.emitScalarReturn(IGF.CurFn->getReturnType(), *error); + IGF.Builder.emitBlock(successBB); + } + if (isAsync) { Explosion error; if (errorValue) @@ -348,7 +373,11 @@ void IRGenThunk::emit() { // Return the result. if (result.empty()) { - IGF.Builder.CreateRetVoid(); + if (emission->getTypedErrorExplosion()) { + IGF.Builder.CreateRet(llvm::UndefValue::get(IGF.CurFn->getReturnType())); + } else { + IGF.Builder.CreateRetVoid(); + } return; } diff --git a/lib/IRGen/IRGenFunction.h b/lib/IRGen/IRGenFunction.h index 6351792cafdd0..273f6b41b7d91 100644 --- a/lib/IRGen/IRGenFunction.h +++ b/lib/IRGen/IRGenFunction.h @@ -102,7 +102,7 @@ class IRGenFunction { Explosion collectParameters(); void emitScalarReturn(SILType returnResultType, SILType funcResultType, Explosion &scalars, bool isSwiftCCReturn, - bool isOutlined); + bool isOutlined, SILType errorType = {}); void emitScalarReturn(llvm::Type *resultTy, Explosion &scalars); void emitBBForReturn(); diff --git a/lib/IRGen/IRGenSIL.cpp b/lib/IRGen/IRGenSIL.cpp index d707c5b2c5b48..8b15fadd0ac95 100644 --- a/lib/IRGen/IRGenSIL.cpp +++ b/lib/IRGen/IRGenSIL.cpp @@ -2151,7 +2151,7 @@ static void emitEntryPointArgumentsNativeCC(IRGenSILFunction &IGF, // Remap the entry block. IGF.LoweredBBs[&*IGF.CurSILFn->begin()] = LoweredBB(IGF.Builder.GetInsertBlock(), {}); } - } + } // Bind the error result by popping it off the parameter list. if (funcTy->hasErrorResult()) { @@ -2163,14 +2163,21 @@ static void emitEntryPointArgumentsNativeCC(IRGenSILFunction &IGF, bool isIndirectError = fnConv.hasIndirectSILErrorResults(); if (isTypedError && !isIndirectError) { - auto &errorTI = cast(IGF.getTypeInfo(errorType)); - IGF.setCallerTypedErrorResultSlot(Address( - emission->getCallerTypedErrorResultArgument(), - errorTI.getStorageType(), - errorTI.getFixedAlignment())); - + auto resultType = + fnConv.getSILResultType(IGF.IGM.getMaximalTypeExpansionContext()); + auto inContextResultType = IGF.CurSILFn->mapTypeIntoContext(resultType); + auto &resultTI = + cast(IGF.getTypeInfo(inContextResultType)); + auto &errorTI = cast(IGF.getTypeInfo(inContextErrorType)); + auto &native = resultTI.nativeReturnValueSchema(IGF.IGM); + auto &nativeError = errorTI.nativeReturnValueSchema(IGF.IGM); + if (funcTy->isAsync() || native.requiresIndirect() || + nativeError.shouldReturnTypedErrorIndirectly()) { + IGF.setCallerTypedErrorResultSlot( + Address(emission->getCallerTypedErrorResultArgument(), + errorTI.getStorageType(), errorTI.getFixedAlignment())); + } } else if (isTypedError && isIndirectError) { - auto &errorTI = IGF.getTypeInfo(inContextErrorType); auto ptr = emission->getCallerTypedErrorResultArgument(); auto addr = errorTI.getAddressForPointer(ptr); @@ -2317,7 +2324,6 @@ static void emitEntryPointArgumentsNativeCC(IRGenSILFunction &IGF, return IGF.getLoweredSingletonExplosion(parameter); }); } - assert(allParamValues.empty() && "didn't claim all parameters!"); } @@ -3874,12 +3880,41 @@ void IRGenSILFunction::visitFullApplySite(FullApplySite site) { } else { Builder.emitBlock(typedErrorLoadBB); - auto &ti = cast(IGM.getTypeInfo(errorType)); - Explosion errorValue; - ti.loadAsTake(*this, getCalleeTypedErrorResultSlot(errorType), errorValue); - for (unsigned i = 0, e = errorDest.phis.size(); i != e; ++i) { - errorDest.phis[i]->addIncoming(errorValue.claimNext(), Builder.GetInsertBlock()); + auto &errorTI = cast(IGM.getTypeInfo(errorType)); + auto silResultTy = + substConv.getSILResultType(IGM.getMaximalTypeExpansionContext()); + auto &resultTI = cast(IGM.getTypeInfo(silResultTy)); + + auto &resultSchema = resultTI.nativeReturnValueSchema(IGM); + auto &errorSchema = errorTI.nativeReturnValueSchema(IGM); + + if (isAsync() || substConv.hasIndirectSILErrorResults() || + resultSchema.requiresIndirect() || + errorSchema.shouldReturnTypedErrorIndirectly()) { + Explosion errorValue; + errorTI.loadAsTake(*this, getCalleeTypedErrorResultSlot(errorType), + errorValue); + for (unsigned i = 0, e = errorDest.phis.size(); i != e; ++i) { + errorDest.phis[i]->addIncoming(errorValue.claimNext(), + Builder.GetInsertBlock()); + } + } else { + auto combined = + combineResultAndTypedErrorType(IGM, resultSchema, errorSchema); + if (auto &errorValue = emission->getTypedErrorExplosion()) { + if (errorDest.phis.empty()) { + errorValue->reset(); + } else { + for (unsigned i = 0, e = errorDest.phis.size(); i != e; ++i) { + errorDest.phis[i]->addIncoming(errorValue->claimNext(), + Builder.GetInsertBlock()); + } + } + } else { + llvm_unreachable("No explosion set for direct typed error result"); + } } + Builder.CreateBr(errorDest.bb); } @@ -4318,8 +4353,14 @@ static void emitReturnInst(IRGenSILFunction &IGF, auto swiftCCReturn = funcLang == SILFunctionLanguage::Swift; assert(swiftCCReturn || funcLang == SILFunctionLanguage::C && "Need to handle all cases"); - IGF.emitScalarReturn(resultTy, funcResultType, result, swiftCCReturn, - false); + SILType errorType; + if (fnType->hasErrorResult() && conv.isTypedError() && + !conv.hasIndirectSILErrorResults()) { + errorType = + conv.getSILErrorType(IGF.IGM.getMaximalTypeExpansionContext()); + } + IGF.emitScalarReturn(resultTy, funcResultType, result, swiftCCReturn, false, + errorType); } } @@ -4347,23 +4388,85 @@ void IRGenSILFunction::visitThrowInst(swift::ThrowInst *i) { assert(!conv.hasIndirectSILErrorResults()); if (!isAsync()) { + auto fnTy = CurFn->getFunctionType(); + auto retTy = fnTy->getReturnType(); if (conv.isTypedError()) { llvm::Constant *flag = llvm::ConstantInt::get(IGM.IntPtrTy, 1); flag = llvm::ConstantExpr::getIntToPtr(flag, IGM.Int8PtrTy); Explosion errorResult = getLoweredExplosion(i->getOperand()); - auto &ti = cast(IGM.getTypeInfo(conv.getSILErrorType( - IGM.getMaximalTypeExpansionContext()))); - ti.initialize(*this, errorResult, getCallerTypedErrorResultSlot(), false); + auto silErrorTy = + conv.getSILErrorType(IGM.getMaximalTypeExpansionContext()); + auto &errorTI = cast(IGM.getTypeInfo(silErrorTy)); + + auto silResultTy = + conv.getSILResultType(IGM.getMaximalTypeExpansionContext()); + + if (silErrorTy.getASTType()->isNever()) { + emitTrap("Never can't be initialized", true); + return; + } else { + auto &resultTI = cast(IGM.getTypeInfo(silResultTy)); + auto &resultSchema = resultTI.nativeReturnValueSchema(IGM); + auto &errorSchema = errorTI.nativeReturnValueSchema(IGM); + + Builder.CreateStore(flag, getCallerErrorResultSlot()); + if (resultSchema.requiresIndirect() || + errorSchema.shouldReturnTypedErrorIndirectly()) { + errorTI.initialize(*this, errorResult, getCallerTypedErrorResultSlot(), + false); + } else { + auto combined = + combineResultAndTypedErrorType(IGM, resultSchema, errorSchema); + + if (combined.combinedTy->isVoidTy()) { + Builder.CreateRetVoid(); + return; + } + + llvm::Value *expandedResult = llvm::UndefValue::get(combined.combinedTy); + + if (!errorSchema.getExpandedType(IGM)->isVoidTy()) { + auto nativeError = errorSchema.mapIntoNative(IGM, *this, errorResult, + silErrorTy, false); + + auto convertIfNecessary = [&](llvm::Type *nativeTy, + llvm::Value *elt) -> llvm::Value * { + auto *eltTy = elt->getType(); + if (nativeTy->isIntOrPtrTy() && eltTy->isIntOrPtrTy() && + nativeTy->getPrimitiveSizeInBits() != + eltTy->getPrimitiveSizeInBits()) { + assert(nativeTy->getPrimitiveSizeInBits() > + eltTy->getPrimitiveSizeInBits()); + return Builder.CreateZExt(elt, nativeTy); + } + return elt; + }; + + if (auto *structTy = dyn_cast(combined.combinedTy)) { + for (unsigned i : combined.errorValueMapping) { + llvm::Value *elt = nativeError.claimNext(); + auto *nativeTy = structTy->getElementType(i); + elt = convertIfNecessary(nativeTy, elt); + expandedResult = Builder.CreateInsertValue(expandedResult, elt, i); + } + } else if (!errorSchema.getExpandedType(IGM)->isVoidTy()) { + expandedResult = + convertIfNecessary(combined.combinedTy, nativeError.claimNext()); + } + } + + Explosion nativeAgg = Explosion(expandedResult); + emitScalarReturn(combined.combinedTy, nativeAgg); - Builder.CreateStore(flag, getCallerErrorResultSlot()); + return; + } + } } else { Explosion errorResult = getLoweredExplosion(i->getOperand()); Builder.CreateStore(errorResult.claimNext(), getCallerErrorResultSlot()); } // Create a normal return, but leaving the return value undefined. - auto fnTy = CurFn->getFunctionType(); - auto retTy = fnTy->getReturnType(); if (retTy->isVoidTy()) { Builder.CreateRetVoid(); } else { diff --git a/lib/IRGen/NativeConventionSchema.h b/lib/IRGen/NativeConventionSchema.h index 218786ceaef47..c1230d3f04e5f 100644 --- a/lib/IRGen/NativeConventionSchema.h +++ b/lib/IRGen/NativeConventionSchema.h @@ -42,6 +42,9 @@ class NativeConventionSchema { NativeConventionSchema &operator=(const NativeConventionSchema&) = delete; bool requiresIndirect() const { return RequiresIndirect; } + bool shouldReturnTypedErrorIndirectly() const { + return requiresIndirect() || Lowering.shouldReturnTypedErrorIndirectly(); + } bool empty() const { return Lowering.empty(); } llvm::Type *getExpandedType(IRGenModule &IGM) const; diff --git a/test/IRGen/typed_throws.sil b/test/IRGen/typed_throws.sil index 3338156abf65d..3049232584c3f 100644 --- a/test/IRGen/typed_throws.sil +++ b/test/IRGen/typed_throws.sil @@ -18,15 +18,13 @@ sil_vtable A {} sil @create_error : $@convention(thin) () -> @owned A -// CHECK: define{{.*}} swiftcc void @throw_error(ptr swiftself %0, ptr noalias nocapture swifterror dereferenceable({{.*}}) %1, ptr %2) +// CHECK: define{{.*}} swiftcc { ptr, ptr } @throw_error(ptr swiftself %0, ptr noalias nocapture swifterror dereferenceable({{.*}}) %1) // CHECK: [[ERR:%.*]] = call swiftcc ptr @create_error() // CHECK: call ptr @swift_retain(ptr returned [[ERR]]) -// CHECK: [[F1:%.*]] = getelementptr inbounds %T12typed_throws1SV, ptr %2, i32 0, i32 0 -// CHECK: store ptr [[ERR]], ptr [[F1]] -// CHECK: [[F2:%.*]] = getelementptr inbounds %T12typed_throws1SV, ptr %2, i32 0, i32 1 -// CHECK: store ptr [[ERR]], ptr [[F2]] // CHECK: store ptr inttoptr (i64 1 to ptr), ptr %1 -// CHECK: ret void +// CHECK: [[RET_v1:%.*]] = insertvalue { ptr, ptr } undef, ptr [[ERR]], 0 +// CHECK: [[RET_v2:%.*]] = insertvalue { ptr, ptr } [[RET_v1]], ptr [[ERR]], 1 +// CHECK: ret { ptr, ptr } [[RET_v2]] // CHECK: } sil @throw_error : $@convention(thin) () -> @error S { @@ -49,27 +47,24 @@ sil @try_apply_helper : $@convention(thin) (@owned AnyObject) -> (@owned AnyObje // CHECK: entry: // CHECK: %swifterror = alloca swifterror ptr // CHECK: store ptr null, ptr %swifterror -// CHECK: %swifterror1 = alloca %T12typed_throws1SV -// CHECK: [[RES:%.*]] = call swiftcc ptr @try_apply_helper(ptr %0, ptr swiftself undef, ptr noalias nocapture swifterror dereferenceable({{.*}}) %swifterror, ptr %swifterror1) +// CHECK: [[RES:%.*]] = call swiftcc { ptr, ptr } @try_apply_helper(ptr %0, ptr swiftself undef, ptr noalias nocapture swifterror dereferenceable({{.*}}) %swifterror) +// CHECK: [[RES_0:%.*]] = extractvalue { ptr, ptr } [[RES]], 0 +// CHECK: [[RES_1:%.*]] = extractvalue { ptr, ptr } [[RES]], 1 // CHECK: [[ERRFLAG:%.*]] = load ptr, ptr %swifterror // CHECK: [[C:%.*]] = icmp ne ptr [[ERRFLAG]], null // CHECK: br i1 [[C]], label %[[ERR_B:.*]], label %[[SUCC_B:[0-9]+]] // CHECK: [[ERR_B]]: -// CHECK: %swifterror1.x = getelementptr inbounds %T12typed_throws1SV, ptr %swifterror1, i32 0, i32 0 -// CHECK: [[ERR_v1:%.*]] = load ptr, ptr %swifterror1.x -// CHECK: %swifterror1.y = getelementptr inbounds %T12typed_throws1SV, ptr %swifterror1, i32 0, i32 1 -// CHECK: [[ERR_v2:%.*]] = load ptr, ptr %swifterror1.y // CHECK: br label %[[ERR2_B:[0-9]+]] // CHECK: [[SUCC_B]]: -// CHECK: [[R:%.*]] = phi ptr [ [[RES]], %entry ] +// CHECK: [[R:%.*]] = phi ptr [ [[RES_0]], %entry ] // CHECK: call void @swift_{{.*}}elease(ptr [[R]]) // CHECK: br label %[[RET_B:[0-9]+]] // CHECK: [[ERR2_B]]: -// CHECK: [[E1:%.*]] = phi ptr [ [[ERR_v1]], %[[ERR_B]] ] -// CHECK: [[E2:%.*]] = phi ptr [ [[ERR_v2]], %[[ERR_B]] ] +// CHECK: [[E1:%.*]] = phi ptr [ [[RES_0]], %[[ERR_B]] ] +// CHECK: [[E2:%.*]] = phi ptr [ [[RES_1]], %[[ERR_B]] ] // CHECK: store ptr null, ptr %swifterror // CHECK: call void @swift_release(ptr [[E1]]) // CHECK: call void @swift_release(ptr [[E2]]) @@ -210,9 +205,9 @@ bb6: return %7 : $() } -// CHECK: define{{.*}} internal swiftcc ptr @"$s16try_apply_helperTA"(ptr swiftself %0, ptr noalias nocapture swifterror dereferenceable({{.*}}) %1, ptr %2) -// CHECK: tail call swiftcc ptr @try_apply_helper(ptr {{.*}}, ptr swiftself undef, ptr noalias nocapture swifterror dereferenceable({{.*}}) %1, ptr %2) -// CHECK: ret ptr +// CHECK: define{{.*}} internal swiftcc { ptr, ptr } @"$s16try_apply_helperTA"(ptr swiftself %0, ptr noalias nocapture swifterror dereferenceable({{.*}}) %1) +// CHECK: tail call swiftcc { ptr, ptr } @try_apply_helper(ptr {{.*}}, ptr swiftself undef, ptr noalias nocapture swifterror dereferenceable({{.*}}) %1) +// CHECK: ret { ptr, ptr } sil @partial_apply_test : $@convention(thin) (@owned AnyObject) -> @owned @callee_guaranteed () ->(@owned AnyObject, @error S) { entry(%0: $AnyObject): @@ -235,8 +230,7 @@ entry(%0: $AnyObject): // CHECK:entry: // CHECK: %swifterror = alloca swifterror ptr // CHECK: store ptr null, ptr %swifterror -// CHECK: %swifterror1 = alloca %T12typed_throws1SV -// CHECK: call swiftcc ptr %0(ptr swiftself %1, ptr noalias nocapture swifterror dereferenceable({{[0-9]+}}) %swifterror, ptr %swifterror1) +// CHECK: call swiftcc { ptr, ptr } %0(ptr swiftself %1, ptr noalias nocapture swifterror dereferenceable({{[0-9]+}}) %swifterror) sil @apply_closure : $@convention(thin) (@guaranteed @callee_guaranteed () -> (@owned AnyObject, @error S)) -> () { entry(%0 : $@callee_guaranteed () ->(@owned AnyObject, @error S)): diff --git a/test/IRGen/typed_throws.swift b/test/IRGen/typed_throws.swift index dbe1a616e5b78..fabbd64a021c0 100644 --- a/test/IRGen/typed_throws.swift +++ b/test/IRGen/typed_throws.swift @@ -9,9 +9,9 @@ public enum MyBigError: Error { case epicFail + case evenBiggerFail } - // CHECK-MANGLE: @"$s12typed_throws1XVAA1PAAWP" = hidden global [2 x ptr] [ptr @"$s12typed_throws1XVAA1PAAMc", ptr getelementptr inbounds (i8, ptr @"symbolic ySi_____YKc 12typed_throws10MyBigErrorO", {{i32|i64}} 1)] struct X: P { typealias A = (Int) throws(MyBigError) -> Void @@ -52,7 +52,7 @@ func five() -> Int { 5 } func fiveOrBust() throws -> Int { 5 } -func fiveOrTypedBust() throws(MyBigError) -> Int { 5 } +func fiveOrTypedBust() throws(MyBigError) -> Int { throw MyBigError.epicFail } func reabstractAsNonthrowing() -> Int { passthroughCall(five) @@ -69,7 +69,8 @@ func reabstractAsConcreteThrowing() throws -> Int { // CHECK-LABEL: define {{.*}} swiftcc void @"$sSi12typed_throws10MyBigErrorOIgdzo_SiACIegrzr_TR"(ptr noalias nocapture sret(%TSi) %0, ptr %1, ptr %2, ptr swiftself %3, ptr noalias nocapture swifterror dereferenceable(8) %4, ptr %5) // CHECK: call swiftcc {{i32|i64}} %1 -// CHECK: br i1 %8, label %typed.error.load, label %10 +// CHECK: [[CMP:%.*]] = icmp ne ptr {{%.*}}, null +// CHECK: br i1 [[CMP]], label %typed.error.load struct S : Error { } diff --git a/test/IRGen/typed_throws_thunks.swift b/test/IRGen/typed_throws_thunks.swift index 8799a3c5acb05..b38c5b8102334 100644 --- a/test/IRGen/typed_throws_thunks.swift +++ b/test/IRGen/typed_throws_thunks.swift @@ -57,9 +57,9 @@ extension P { } } - // CHECK-LABEL: define{{.*}} swiftcc void @"$s19typed_throws_thunks1PP2g34bodyyyyAA9FixedSizeVYKXE_tAGYKFTj"(ptr %0, ptr %1, ptr noalias swiftself %2, ptr noalias nocapture swifterror dereferenceable(8) %3, ptr %4, ptr %5, ptr %6) + // CHECK-LABEL: define{{.*}} swiftcc { i64, i64 } @"$s19typed_throws_thunks1PP2g34bodyyyyAA9FixedSizeVYKXE_tAGYKFTj"(ptr %0, ptr %1, ptr noalias swiftself %2, ptr noalias nocapture swifterror dereferenceable(8) %3, ptr %4, ptr %5) // CHECK-NOT: ret - // CHECK: call swiftcc void {{.*}}(ptr %0, ptr %1, ptr noalias swiftself %2, ptr noalias nocapture swifterror dereferenceable(8) %3, ptr %4, ptr %5, ptr %6) + // CHECK: call swiftcc { i64, i64 } {{.*}}(ptr %0, ptr %1, ptr noalias swiftself %2, ptr noalias nocapture swifterror dereferenceable(8) %3, ptr %4, ptr %5) public func g3(body: () throws(FixedSize) -> Void) throws(FixedSize) { From 3d4163a31947f3dfc74406adf0cf8c1fa721ad45 Mon Sep 17 00:00:00 2001 From: Dario Rexin Date: Thu, 6 Jun 2024 12:52:51 -0700 Subject: [PATCH 2/2] Address review feedback --- lib/IRGen/GenCall.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/IRGen/GenCall.cpp b/lib/IRGen/GenCall.cpp index 64d9d79c8be17..3b2df8597400b 100644 --- a/lib/IRGen/GenCall.cpp +++ b/lib/IRGen/GenCall.cpp @@ -373,9 +373,15 @@ static void addIndirectResultAttributes(IRGenModule &IGM, attrs = attrs.addParamAttributes(IGM.getLLVMContext(), paramIndex, b); } +// This function should only be called with directly returnable +// result and error types. Errors can only be returned directly if +// they consists solely of int and ptr values. CombinedResultAndErrorType irgen::combineResultAndTypedErrorType( const IRGenModule &IGM, const NativeConventionSchema &resultSchema, const NativeConventionSchema &errorSchema) { + assert(!resultSchema.requiresIndirect()); + assert(!errorSchema.shouldReturnTypedErrorIndirectly()); + CombinedResultAndErrorType result; SmallVector elts; resultSchema.enumerateComponents( @@ -403,6 +409,8 @@ CombinedResultAndErrorType irgen::combineResultAndTypedErrorType( } auto *error = *errorIt; + assert(error->isIntOrPtrTy() && + "Direct errors must only consist of int or ptr values"); result.errorValueMapping.push_back(combined.size()); if (res->getPrimitiveSizeInBits() >= error->getPrimitiveSizeInBits()) { combined.push_back(res);