diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index fe0c0bc60fa..9da8d56c20f 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2896,6 +2896,11 @@ DIAGNOSTIC( noTypeConformancesFoundForInterface, "No type conformances are found for interface '$0'. Code generation for current target " "requires at least one implementation type present in the linkage.") +DIAGNOSTIC( + 50101, + Error, + dynamicDispatchOnPotentiallyUninitializedExistential, + "Cannot dynamically dispatch on potentially uninitialized interface object '$0'.") DIAGNOSTIC( 52000, diff --git a/source/slang/slang-emit-vm.cpp b/source/slang/slang-emit-vm.cpp index 71439ba0b9d..6c9827510f5 100644 --- a/source/slang/slang-emit-vm.cpp +++ b/source/slang/slang-emit-vm.cpp @@ -1083,6 +1083,11 @@ class ByteCodeEmitter } } + bool isUnreachableBlock(IRBlock* block) + { + return as(block->getTerminator()) != nullptr; + } + void emitFunction(IRFunc* func) { VMByteCodeFunctionBuilder funcBuilder; @@ -1100,6 +1105,9 @@ class ByteCodeEmitter for (auto block : func->getBlocks()) { + if (isUnreachableBlock(block)) + continue; + mapBlockToByteOffset[block] = funcBuilder.code.getCount(); for (auto inst : block->getChildren()) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 294de16d6fa..6d7234f620a 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -22,6 +22,7 @@ #include "slang-emit-vm.h" #include "slang-emit-wgsl.h" #include "slang-ir-any-value-inference.h" +#include "slang-ir-any-value-marshalling.h" #include "slang-ir-autodiff.h" #include "slang-ir-bind-existentials.h" #include "slang-ir-byte-address-legalize.h" @@ -77,9 +78,9 @@ #include "slang-ir-lower-buffer-element-type.h" #include "slang-ir-lower-combined-texture-sampler.h" #include "slang-ir-lower-coopvec.h" +#include "slang-ir-lower-dynamic-dispatch-insts.h" #include "slang-ir-lower-dynamic-resource-heap.h" #include "slang-ir-lower-enum-type.h" -#include "slang-ir-lower-generics.h" #include "slang-ir-lower-glsl-ssbo-types.h" #include "slang-ir-lower-l-value-cast.h" #include "slang-ir-lower-optional-type.h" @@ -114,6 +115,7 @@ #include "slang-ir-synthesize-active-mask.h" #include "slang-ir-transform-params-to-constref.h" #include "slang-ir-translate-global-varying-var.h" +#include "slang-ir-typeflow-specialize.h" #include "slang-ir-undo-param-copy.h" #include "slang-ir-uniformity.h" #include "slang-ir-user-type-hint.h" @@ -862,6 +864,12 @@ Result linkAndOptimizeIR( // Lower all the LValue implict casts (used for out/inout/ref scenarios) SLANG_PASS(lowerLValueCast, targetProgram); + // Lower enum types early since enums and enum casts may appear in + // specialization & not resolving them here would block specialization. + // + if (requiredLoweringPassSet.enumType) + SLANG_PASS(lowerEnumType, sink); + IRSimplificationOptions defaultIRSimplificationOptions = IRSimplificationOptions::getDefault(targetProgram); IRSimplificationOptions fastIRSimplificationOptions = @@ -1111,21 +1119,14 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(SLANG_PASS(performTypeInlining, targetProgram, sink)); } - if (requiredLoweringPassSet.reinterpret) - SLANG_PASS(lowerReinterpret, targetProgram, sink); - if (sink->getErrorCount() != 0) return SLANG_FAIL; validateIRModuleIfEnabled(codeGenContext, irModule); - SLANG_PASS(inferAnyValueSizeWhereNecessary, targetProgram); + SLANG_PASS(inferAnyValueSizeWhereNecessary, targetProgram, sink); - // If we have any witness tables that are marked as `KeepAlive`, - // but are not used for dynamic dispatch, unpin them so we don't - // do unnecessary work to lower them. SLANG_PASS(unpinWitnessTables); - if (!fastIRSimplificationOptions.minimalOptimization) { SLANG_PASS(simplifyIR, targetProgram, fastIRSimplificationOptions, sink); @@ -1135,6 +1136,26 @@ Result linkAndOptimizeIR( SLANG_PASS(eliminateDeadCode, fastIRSimplificationOptions.deadCodeElimOptions); } + // Tagged union type lowering typically generates more reinterpret instructions. + if (SLANG_PASS(lowerTaggedUnionTypes, sink)) + requiredLoweringPassSet.reinterpret = true; + + SLANG_PASS(lowerUntaggedUnionTypes, targetProgram, sink); + + if (requiredLoweringPassSet.reinterpret) + SLANG_PASS(lowerReinterpret, targetProgram, sink); + + SLANG_PASS(lowerSequentialIDTagCasts, codeGenContext->getLinkage(), sink); + SLANG_PASS(lowerTagInsts, sink); + SLANG_PASS(lowerTagTypes); + + SLANG_PASS(eliminateDeadCode, fastIRSimplificationOptions.deadCodeElimOptions); + + SLANG_PASS(lowerExistentials, targetProgram, sink); + + if (sink->getErrorCount() != 0) + return SLANG_FAIL; + if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc) && targetProgram->getOptionSet().shouldRunNonEssentialValidation()) { @@ -1143,16 +1164,9 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(SLANG_PASS(checkGetStringHashInsts, sink)); } - // For targets that supports dynamic dispatch, we need to lower the - // generics / interface types to ordinary functions and types using - // function pointers. - if (requiredLoweringPassSet.generics) - SLANG_PASS(lowerGenerics, targetProgram, sink); - else - SLANG_PASS(cleanupGenerics, targetProgram, sink); + SLANG_PASS(lowerTuples, sink); - if (requiredLoweringPassSet.enumType) - SLANG_PASS(lowerEnumType, sink); + SLANG_PASS(generateAnyValueMarshallingFunctions); // Don't need to run any further target-dependent passes if we are generating code // for host vm. diff --git a/source/slang/slang-ir-any-value-inference.cpp b/source/slang/slang-ir-any-value-inference.cpp index e2d0494ce73..751ca7c2c22 100644 --- a/source/slang/slang-ir-any-value-inference.cpp +++ b/source/slang/slang-ir-any-value-inference.cpp @@ -1,7 +1,6 @@ #include "slang-ir-any-value-inference.h" #include "../core/slang-func-ptr.h" -#include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" #include "slang-ir-layout.h" #include "slang-ir-util.h" @@ -90,7 +89,10 @@ List sortTopologically( return sortedInterfaceTypes; } -void inferAnyValueSizeWhereNecessary(IRModule* module, TargetProgram* targetProgram) +void inferAnyValueSizeWhereNecessary( + IRModule* module, + TargetProgram* targetProgram, + DiagnosticSink* sink) { // Go through the global insts and collect all interface types. // For each interface type, infer its any-value-size, by looking up @@ -130,9 +132,10 @@ void inferAnyValueSizeWhereNecessary(IRModule* module, TargetProgram* targetProg if (interfaceType->findDecoration()) continue; - // If the interface already has an explicit any-value-size, don't infer anything. + /* If the interface already has an explicit any-value-size, don't infer anything. if (interfaceType->findDecoration()) continue; + */ // Skip interfaces that are not implemented by any type. if (!implementedInterfaces.contains(interfaceType)) @@ -213,6 +216,12 @@ void inferAnyValueSizeWhereNecessary(IRModule* module, TargetProgram* targetProg for (auto interfaceType : sortedInterfaceTypes) { + IRIntegerValue existingMaxSize = (IRIntegerValue)kMaxInt; // Default to max int. + if (auto existingAnyValueDecor = interfaceType->findDecoration()) + { + existingMaxSize = existingAnyValueDecor->getSize(); + } + IRIntegerValue maxAnyValueSize = -1; for (auto implType : mapInterfaceToImplementations[interfaceType]) { @@ -223,6 +232,18 @@ void inferAnyValueSizeWhereNecessary(IRModule* module, TargetProgram* targetProg &sizeAndAlignment); maxAnyValueSize = Math::Max(maxAnyValueSize, sizeAndAlignment.size); + + // Diagnose if the existing any-value-size is smaller than the inferred size. + if (existingMaxSize < sizeAndAlignment.size) + { + sink->diagnose(implType, Diagnostics::typeDoesNotFitAnyValueSize, implType); + sink->diagnoseWithoutSourceView( + implType, + Diagnostics::typeAndLimit, + implType, + sizeAndAlignment.size, + existingMaxSize); + } } // Should not encounter interface types without any conforming implementations. @@ -232,7 +253,10 @@ void inferAnyValueSizeWhereNecessary(IRModule* module, TargetProgram* targetProg if (maxAnyValueSize >= 0) { IRBuilder builder(module); - builder.addAnyValueSizeDecoration(interfaceType, maxAnyValueSize); + if (!interfaceType->findDecoration()) + { + builder.addAnyValueSizeDecoration(interfaceType, maxAnyValueSize); + } } } } diff --git a/source/slang/slang-ir-any-value-inference.h b/source/slang/slang-ir-any-value-inference.h index e92550d3dd4..df21ba95b80 100644 --- a/source/slang/slang-ir-any-value-inference.h +++ b/source/slang/slang-ir-any-value-inference.h @@ -8,5 +8,8 @@ namespace Slang { -void inferAnyValueSizeWhereNecessary(IRModule* module, TargetProgram* targetProgram); +void inferAnyValueSizeWhereNecessary( + IRModule* module, + TargetProgram* targetProgram, + DiagnosticSink* sink); } diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 1b306978d32..2bb5ed24aa3 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -1,8 +1,8 @@ #include "slang-ir-any-value-marshalling.h" #include "../core/slang-math.h" -#include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" #include "slang-ir.h" #include "slang-legalize-types.h" @@ -14,7 +14,36 @@ namespace Slang // functions. struct AnyValueMarshallingContext { - SharedGenericsLoweringContext* sharedContext; + IRModule* module; + + // We will use a single work list of instructions that need + // to be considered for lowering. + // + InstWorkList workList; + InstHashSet workListSet; + + AnyValueMarshallingContext(IRModule* module) + : module(module), workList(module), workListSet(module) + { + } + + void addToWorkList(IRInst* inst) + { + if (!inst) + return; + + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as(ii)) + return; + } + + if (workListSet.contains(inst)) + return; + + workList.add(inst); + workListSet.add(inst); + } // Stores information about generated `AnyValue` struct types. struct AnyValueTypeInfo : RefObject @@ -54,7 +83,7 @@ struct AnyValueMarshallingContext if (auto typeInfo = generatedAnyValueTypes.tryGetValue(size)) return typeInfo->Ptr(); RefPtr info = new AnyValueTypeInfo(); - IRBuilder builder(sharedContext->module); + IRBuilder builder(module); builder.setInsertBefore(type); auto structType = builder.createStructType(); info->type = structType; @@ -498,7 +527,7 @@ struct AnyValueMarshallingContext IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType) { - IRBuilder builder(sharedContext->module); + IRBuilder builder(module); builder.setInsertBefore(type); auto anyValInfo = ensureAnyValueType(anyValueType); @@ -767,7 +796,7 @@ struct AnyValueMarshallingContext IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType) { - IRBuilder builder(sharedContext->module); + IRBuilder builder(module); builder.setInsertBefore(type); auto anyValInfo = ensureAnyValueType(anyValueType); @@ -822,7 +851,7 @@ struct AnyValueMarshallingContext auto func = ensureMarshallingFunc( operand->getDataType(), cast(packInst->getDataType())); - IRBuilder builderStorage(sharedContext->module); + IRBuilder builderStorage(module); auto builder = &builderStorage; builder->setInsertBefore(packInst); auto callInst = builder->emitCallInst(packInst->getDataType(), func.packFunc, 1, &operand); @@ -836,7 +865,7 @@ struct AnyValueMarshallingContext auto func = ensureMarshallingFunc( unpackInst->getDataType(), cast(operand->getDataType())); - IRBuilder builderStorage(sharedContext->module); + IRBuilder builderStorage(module); auto builder = &builderStorage; builder->setInsertBefore(unpackInst); auto callInst = @@ -869,37 +898,35 @@ struct AnyValueMarshallingContext // since we will re-use that state for any code we // generate along the way. // - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + addToWorkList(module->getModuleInst()); - while (sharedContext->workList.getCount() != 0) + while (workList.getCount() != 0) { - IRInst* inst = sharedContext->workList.getLast(); + IRInst* inst = workList.getLast(); - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); + workList.removeLast(); + workListSet.remove(inst); processInst(inst); for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) { - sharedContext->addToWorkList(child); + addToWorkList(child); } } // Finally, replace all `AnyValueType` with the actual struct type that implements it. - for (auto inst : sharedContext->module->getModuleInst()->getChildren()) + for (auto inst : module->getModuleInst()->getChildren()) { if (auto anyValueType = as(inst)) processAnyValueType(anyValueType); } - sharedContext->mapInterfaceRequirementKeyValue.clear(); } }; -void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext) +void generateAnyValueMarshallingFunctions(IRModule* module) { - AnyValueMarshallingContext context; - context.sharedContext = sharedContext; + AnyValueMarshallingContext context(module); context.processModule(); } @@ -1018,12 +1045,14 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) { return alignUp(offset, 4) + kRTTIHandleSize; } + case kIROp_SetTagType: + { + return alignUp(offset, 4) + 4; + } case kIROp_InterfaceType: { auto interfaceType = cast(type); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( - interfaceType, - interfaceType->sourceLoc); + auto size = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); size += kRTTIHeaderSize; return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } @@ -1041,18 +1070,14 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) { auto thisType = cast(type); auto interfaceType = thisType->getConstraintType(); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( - interfaceType, - interfaceType->sourceLoc); + auto size = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } case kIROp_ExtractExistentialType: { auto existentialValue = type->getOperand(0); auto interfaceType = cast(existentialValue->getDataType()); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( - interfaceType, - interfaceType->sourceLoc); + auto size = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } case kIROp_LookupWitnessMethod: @@ -1087,9 +1112,7 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) { anyValueSize = Math::Min( anyValueSize, - SharedGenericsLoweringContext::getInterfaceAnyValueSize( - assocType->getOperand(i), - type->sourceLoc)); + getInterfaceAnyValueSize(assocType->getOperand(i), type->sourceLoc)); } if (anyValueSize == kInvalidAnyValueSize) diff --git a/source/slang/slang-ir-any-value-marshalling.h b/source/slang/slang-ir-any-value-marshalling.h index 941dae5e8c2..bfebf1b0123 100644 --- a/source/slang/slang-ir-any-value-marshalling.h +++ b/source/slang/slang-ir-any-value-marshalling.h @@ -6,13 +6,13 @@ namespace Slang { struct IRType; -struct SharedGenericsLoweringContext; +struct IRModule; /// Generates functions that pack and unpack `AnyValue`s, and replaces /// all `IRPackAnyValue` and `IRUnpackAnyValue` instructions with calls /// to these packing/unpacking functions. /// This is a sub-pass of lower-generics. -void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext); +void generateAnyValueMarshallingFunctions(IRModule* module); /// Get the AnyValue size required to hold a value of `type`. diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 5ca277672c0..44b6a9babad 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1194,7 +1194,11 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( // the generic args to specialize the primal function. This is true for all of our core // module functions, but we may need to rely on more general substitution logic here. auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + diffBaseSpecialize->getBase()->getDataType(), + args.getCount(), + args.getBuffer()), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); @@ -1209,7 +1213,11 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( } diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + diffBase->getDataType(), + args.getCount(), + args.getBuffer()), diffBase, args.getCount(), args.getBuffer()); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index ab964aef69b..be1687289b7 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -1435,8 +1435,28 @@ InstPair BackwardDiffTranscriberBase::transcribeSpecialize( { args.add(primalSpecialize->getArg(i)); } + + IRType* typeForSpecialization = nullptr; + switch ((*diffBase)->getDataType()->getOp()) + { + case kIROp_TypeKind: + case kIROp_GenericKind: + typeForSpecialization = (*diffBase)->getDataType(); + break; + case kIROp_Generic: + typeForSpecialization = (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + (*diffBase)->getDataType(), + args.getCount(), + args.getBuffer()); + break; + default: + typeForSpecialization = builder->getTypeKind(); + break; + } + auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + typeForSpecialization, *diffBase, args.getCount(), args.getBuffer()); @@ -1472,7 +1492,11 @@ InstPair BackwardDiffTranscriberBase::transcribeSpecialize( // the generic args to specialize the primal function. This is true for all of our core // module functions, but we may need to rely on more general substitution logic here. auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + diffBaseSpecialize->getBase()->getDataType(), + args.getCount(), + args.getBuffer()), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); @@ -1488,7 +1512,11 @@ InstPair BackwardDiffTranscriberBase::transcribeSpecialize( } auto diffCallee = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + diffCallee->getDataType(), + args.getCount(), + args.getBuffer()), diffCallee, args.getCount(), args.getBuffer()); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 0c6fcfd5627..9a1172492d2 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1034,6 +1034,15 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene mapDifferentialInst(origGeneric->getFirstBlock(), bodyBlock); transcribeBlockImpl(&builder, origGeneric->getFirstBlock(), instsToSkip); + auto diffReturnVal = getGenericReturnVal(diffGeneric); + if (auto func = as(diffReturnVal)) + { + IRInst* outSpecializedValue = nullptr; + auto hoistedFuncType = + hoistValueFromGeneric(builder, func->getDataType(), outSpecializedValue); + diffGeneric->setFullType((IRType*)hoistedFuncType); + } + return InstPair(primalGeneric, diffGeneric); } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index c70374e7706..b24ed50095f 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2212,7 +2212,7 @@ struct DiffTransposePass // If we reach this point, revValue must be a differentiable type. auto revTypeWitness = diffTypeContext.tryGetDifferentiableWitness( builder, - primalType, + fwdInst->getDataType(), DiffConformanceKind::Value); SLANG_ASSERT(revTypeWitness); diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 84154c18261..ea0c681492d 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -670,16 +670,17 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex) { + SLANG_UNUSED(operandIndex); // There are some type of operands that needs to be treated as // "weak" references -- they can never hold things alive, and // whenever we delete the referenced value, these operands needs // to be replaced with `undef`. switch (inst->getOp()) { - case kIROp_BoundInterfaceType: + /*case kIROp_BoundInterfaceType: if (inst->getOperand(operandIndex)->getOp() == kIROp_WitnessTable) return true; - break; + break;*/ case kIROp_SpecializationDictionaryItem: // Ignore all operands of SpecializationDictionaryItem. // This inst is used as a cache and shouldn't hold anything alive. diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp deleted file mode 100644 index 0dee5631a93..00000000000 --- a/source/slang/slang-ir-generics-lowering-context.cpp +++ /dev/null @@ -1,443 +0,0 @@ -// slang-ir-generics-lowering-context.cpp - -#include "slang-ir-generics-lowering-context.h" - -#include "slang-ir-layout.h" -#include "slang-ir-util.h" - -namespace Slang -{ -bool isPolymorphicType(IRInst* typeInst) -{ - if (as(typeInst) && as(typeInst->getDataType())) - return true; - switch (typeInst->getOp()) - { - case kIROp_ThisType: - case kIROp_AssociatedType: - case kIROp_InterfaceType: - case kIROp_LookupWitnessMethod: - return true; - case kIROp_Specialize: - { - for (UInt i = 0; i < typeInst->getOperandCount(); i++) - { - if (isPolymorphicType(typeInst->getOperand(i))) - return true; - } - return false; - } - default: - break; - } - if (auto ptrType = as(typeInst)) - { - return isPolymorphicType(ptrType->getValueType()); - } - return false; -} - -bool isTypeValue(IRInst* typeInst) -{ - if (typeInst) - { - switch (typeInst->getOp()) - { - case kIROp_TypeType: - case kIROp_TypeKind: - return true; - default: - return false; - } - } - return false; -} - -IRInst* SharedGenericsLoweringContext::maybeEmitRTTIObject(IRInst* typeInst) -{ - IRInst* result = nullptr; - if (mapTypeToRTTIObject.tryGetValue(typeInst, result)) - return result; - IRBuilder builderStorage(module); - auto builder = &builderStorage; - builder->setInsertAfter(typeInst); - - result = builder->emitMakeRTTIObject(typeInst); - - // For now the only type info we encapsualte is type size. - IRSizeAndAlignment sizeAndAlignment; - getNaturalSizeAndAlignment(targetProgram->getOptionSet(), (IRType*)typeInst, &sizeAndAlignment); - builder->addRTTITypeSizeDecoration(result, sizeAndAlignment.size); - - // Give a name to the rtti object. - if (auto exportDecoration = typeInst->findDecoration()) - { - String rttiObjName = exportDecoration->getMangledName(); - builder->addExportDecoration(result, rttiObjName.getUnownedSlice()); - } - - // Make sure the RTTI object for an exported struct type is marked as export if the type is. - if (typeInst->findDecoration()) - { - builder->addHLSLExportDecoration(result); - builder->addKeepAliveDecoration(result); - } - mapTypeToRTTIObject[typeInst] = result; - return result; -} - -IRInst* SharedGenericsLoweringContext::findInterfaceRequirementVal( - IRInterfaceType* interfaceType, - IRInst* requirementKey) -{ - if (auto dict = mapInterfaceRequirementKeyValue.tryGetValue(interfaceType)) - return dict->getValue(requirementKey); - _builldInterfaceRequirementMap(interfaceType); - return findInterfaceRequirementVal(interfaceType, requirementKey); -} - -void SharedGenericsLoweringContext::_builldInterfaceRequirementMap(IRInterfaceType* interfaceType) -{ - mapInterfaceRequirementKeyValue.add(interfaceType, Dictionary()); - auto dict = mapInterfaceRequirementKeyValue.tryGetValue(interfaceType); - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - auto entry = cast(interfaceType->getOperand(i)); - (*dict)[entry->getRequirementKey()] = entry->getRequirementVal(); - } -} - -IRType* SharedGenericsLoweringContext::lowerAssociatedType(IRBuilder* builder, IRInst* type) -{ - if (type->getOp() != kIROp_AssociatedType) - return (IRType*)type; - IRIntegerValue anyValueSize = kInvalidAnyValueSize; - for (UInt i = 0; i < type->getOperandCount(); i++) - { - anyValueSize = - Math::Min(anyValueSize, getInterfaceAnyValueSize(type->getOperand(i), type->sourceLoc)); - } - if (anyValueSize == kInvalidAnyValueSize) - { - // We could conceivably make it an error to have an associated type - // without an `[anyValueSize(...)]` attribute, but then we risk - // producing error messages even when doing 100% static specialization. - // - // It is simpler to use a reasonable default size and treat any - // type without an explicit attribute as using that size. - // - anyValueSize = kDefaultAnyValueSize; - } - return builder->getAnyValueType(anyValueSize); -} - -IRType* SharedGenericsLoweringContext::lowerType( - IRBuilder* builder, - IRInst* paramType, - const Dictionary& typeMapping, - IRType* concreteType) -{ - if (!paramType) - return nullptr; - - IRInst* resultType; - if (typeMapping.tryGetValue(paramType, resultType)) - return (IRType*)resultType; - - if (isTypeValue(paramType)) - { - return builder->getRTTIHandleType(); - } - - switch (paramType->getOp()) - { - case kIROp_WitnessTableType: - case kIROp_WitnessTableIDType: - case kIROp_ExtractExistentialType: - // Do not translate these types. - return (IRType*)paramType; - case kIROp_Param: - { - if (auto anyValueSizeDecor = paramType->findDecoration()) - { - if (isBuiltin(anyValueSizeDecor->getConstraintType())) - return (IRType*)paramType; - auto anyValueSize = getInterfaceAnyValueSize( - anyValueSizeDecor->getConstraintType(), - paramType->sourceLoc); - return builder->getAnyValueType(anyValueSize); - } - // We could conceivably make it an error to have a generic parameter - // without an `[anyValueSize(...)]` attribute, but then we risk - // producing error messages even when doing 100% static specialization. - // - // It is simpler to use a reasonable default size and treat any - // type without an explicit attribute as using that size. - // - return builder->getAnyValueType(kDefaultAnyValueSize); - } - case kIROp_ThisType: - { - auto interfaceType = cast(paramType)->getConstraintType(); - - if (isBuiltin(interfaceType)) - return (IRType*)paramType; - - if (isComInterfaceType((IRType*)interfaceType)) - return (IRType*)interfaceType; - - auto anyValueSize = getInterfaceAnyValueSize( - cast(paramType)->getConstraintType(), - paramType->sourceLoc); - return builder->getAnyValueType(anyValueSize); - } - case kIROp_AssociatedType: - { - return lowerAssociatedType(builder, paramType); - } - case kIROp_InterfaceType: - { - if (isBuiltin(paramType)) - return (IRType*)paramType; - - if (isComInterfaceType((IRType*)paramType)) - return (IRType*)paramType; - - // In the dynamic-dispatch case, a value of interface type - // is going to be packed into the "any value" part of a tuple. - // The size of the "any value" part depends on the interface - // type (e.g., it might have an `[anyValueSize(8)]` attribute - // indicating that 8 bytes needs to be reserved). - // - auto anyValueSize = getInterfaceAnyValueSize(paramType, paramType->sourceLoc); - - // If there is a non-null `concreteType` parameter, then this - // interface type is one that has been statically bound (via - // specialization parameters) to hold a value of that concrete - // type. - // - IRType* pendingType = nullptr; - if (concreteType) - { - // Because static specialization is being used (at least in part), - // we do *not* have a guarantee that the `concreteType` is one - // that can fit into the `anyValueSize` of the interface. - // - // We will use the IR layout logic to see if we can compute - // a size for the type, which can lead to a few different outcomes: - // - // * If a size is computed successfully, and it is smaller than or - // equal to `anyValueSize`, then the concrete value will fit into - // the reserved area, and the layout will match the dynamic case. - // - // * If a size is computed successfully, and it is larger than - // `anyValueSize`, then the concrete value cannot fit into the - // reserved area, and it needs to be stored out-of-line. - // - // * If size cannot be computed, then that implies that the type - // includes non-ordinary data (e.g., a `Texture2D` on a D3D11 - // target), and cannot possible fit into the reserved area - // (which consists of only uniform bytes). In this case, the - // value must be stored out-of-line. - // - IRSizeAndAlignment sizeAndAlignment; - Result result = getNaturalSizeAndAlignment( - targetProgram->getOptionSet(), - concreteType, - &sizeAndAlignment); - if (SLANG_FAILED(result) || (sizeAndAlignment.size > anyValueSize)) - { - // If the value must be stored out-of-line, we construct - // a "pseudo pointer" to the concrete type, and the - // constructed tuple will contain such a pseudo pointer. - // - // Semantically, the pseudo pointer behaves a bit like - // a pointer to the concrete type, in that it can be - // (pseudo-)dereferenced to produce a value of the chosen - // type. - // - // In terms of layout, the pseudo pointer occupies no - // space in the parent tuple/type, and will be automatically - // moved out-of-line by a later type legalization pass. - // - pendingType = builder->getPseudoPtrType(concreteType); - } - } - - auto anyValueType = builder->getAnyValueType(anyValueSize); - auto witnessTableType = builder->getWitnessTableIDType((IRType*)paramType); - auto rttiType = builder->getRTTIHandleType(); - - IRType* tupleType = nullptr; - if (!pendingType) - { - // In the oridnary (dynamic) case, an existential type decomposes - // into a tuple of: - // - // (RTTI, witness table, any-value). - // - tupleType = builder->getTupleType(rttiType, witnessTableType, anyValueType); - } - else - { - // In the case where static specialization mandateds out-of-line storage, - // an existential type decomposes into a tuple of: - // - // (RTTI, witness table, pseudo pointer, any-value) - // - tupleType = - builder->getTupleType(rttiType, witnessTableType, pendingType, anyValueType); - // - // Note that in each of the cases, the third element of the tuple - // is a representation of the value being stored in the existential. - // - // Also note that each of these representations has the same - // size and alignment when only "ordinary" data is considered - // (the pseudo-pointer will eventually be legalized away, leaving - // behind a tuple with equivalent layout). - } - - return tupleType; - } - case kIROp_LookupWitnessMethod: - { - auto lookupInterface = static_cast(paramType); - auto witnessTableType = - as(lookupInterface->getWitnessTable()->getDataType()); - if (!witnessTableType) - return (IRType*)paramType; - auto interfaceType = as(witnessTableType->getConformanceType()); - if (!interfaceType || isBuiltin(interfaceType)) - return (IRType*)paramType; - // Make sure we are looking up inside the original interface type (prior to lowering). - // Only in the original interface type will an associated type entry have an - // IRAssociatedType value. We need to extract AnyValueSize from this IRAssociatedType. - // In lowered interface type, that entry is lowered into an Ptr(RTTIType) and this info - // is lost. - mapLoweredInterfaceToOriginal.tryGetValue(interfaceType, interfaceType); - auto reqVal = - findInterfaceRequirementVal(interfaceType, lookupInterface->getRequirementKey()); - SLANG_ASSERT(reqVal && reqVal->getOp() == kIROp_AssociatedType); - return lowerType(builder, reqVal, typeMapping, nullptr); - } - case kIROp_BoundInterfaceType: - { - // A bound interface type represents an existential together with - // static knowledge that the value stored in the extistential has - // a particular concrete type. - // - // We handle this case by lowering the underlying interface type, - // but pass along the concrete type so that it can impact the - // layout of the interface type. - // - auto boundInterfaceType = static_cast(paramType); - return lowerType( - builder, - boundInterfaceType->getInterfaceType(), - typeMapping, - boundInterfaceType->getConcreteType()); - } - default: - { - bool translated = false; - List loweredOperands; - for (UInt i = 0; i < paramType->getOperandCount(); i++) - { - loweredOperands.add( - lowerType(builder, paramType->getOperand(i), typeMapping, nullptr)); - if (loweredOperands.getLast() != paramType->getOperand(i)) - translated = true; - } - if (translated) - return builder->getType( - paramType->getOp(), - loweredOperands.getCount(), - loweredOperands.getBuffer()); - return (IRType*)paramType; - } - } -} - -List getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType) -{ - List witnessTables; - for (auto globalInst : module->getGlobalInsts()) - { - if (globalInst->getOp() == kIROp_WitnessTable && - cast(globalInst->getDataType())->getConformanceType() == - interfaceType) - { - witnessTables.add(cast(globalInst)); - } - } - return witnessTables; -} - -List SharedGenericsLoweringContext::getWitnessTablesFromInterfaceType( - IRInst* interfaceType) -{ - return Slang::getWitnessTablesFromInterfaceType(module, interfaceType); -} - -IRIntegerValue SharedGenericsLoweringContext::getInterfaceAnyValueSize( - IRInst* type, - SourceLoc usageLoc) -{ - SLANG_UNUSED(usageLoc); - - if (auto decor = type->findDecoration()) - { - return decor->getSize(); - } - - // We could conceivably make it an error to have an interface - // without an `[anyValueSize(...)]` attribute, but then we risk - // producing error messages even when doing 100% static specialization. - // - // It is simpler to use a reasonable default size and treat any - // type without an explicit attribute as using that size. - // - return kDefaultAnyValueSize; -} - - -bool SharedGenericsLoweringContext::doesTypeFitInAnyValue( - IRType* concreteType, - IRInterfaceType* interfaceType, - IRIntegerValue* outTypeSize, - IRIntegerValue* outLimit, - bool* outIsTypeOpaque) -{ - auto anyValueSize = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); - if (outLimit) - *outLimit = anyValueSize; - - if (!areResourceTypesBindlessOnTarget(targetProgram->getTargetReq())) - { - IRType* opaqueType = nullptr; - if (isOpaqueType(concreteType, &opaqueType)) - { - if (outIsTypeOpaque) - *outIsTypeOpaque = true; - return false; - } - } - IRSizeAndAlignment sizeAndAlignment; - Result result = - getNaturalSizeAndAlignment(targetProgram->getOptionSet(), concreteType, &sizeAndAlignment); - if (outTypeSize) - *outTypeSize = sizeAndAlignment.size; - - if (SLANG_FAILED(result) || (sizeAndAlignment.size > anyValueSize)) - { - // The value does not fit, either because it is too large, - // or because it includes types that cannot be stored - // in uniform/ordinary memory for this target. - // - return false; - } - - return true; -} - -} // namespace Slang diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h deleted file mode 100644 index 62848c4b7dd..00000000000 --- a/source/slang/slang-ir-generics-lowering-context.h +++ /dev/null @@ -1,171 +0,0 @@ -// slang-ir-generics-lowering-context.h -#pragma once - -#include "slang-ir-dce.h" -#include "slang-ir-insts.h" -#include "slang-ir-lower-generics.h" -#include "slang-ir.h" - -namespace Slang -{ -struct IRModule; - -constexpr IRIntegerValue kInvalidAnyValueSize = 0xFFFFFFFF; -constexpr IRIntegerValue kDefaultAnyValueSize = 16; -constexpr SlangInt kRTTIHeaderSize = 16; -constexpr SlangInt kRTTIHandleSize = 8; - -struct SharedGenericsLoweringContext -{ - // For convenience, we will keep a pointer to the module - // we are processing. - IRModule* module; - - TargetProgram* targetProgram; - - DiagnosticSink* sink; - - // RTTI objects for each type used to call a generic function. - OrderedDictionary mapTypeToRTTIObject; - - Dictionary loweredGenericFunctions; - Dictionary loweredInterfaceTypes; - Dictionary mapLoweredInterfaceToOriginal; - - // Dictionaries for interface type requirement key-value lookups. - // Used by `findInterfaceRequirementVal`. - Dictionary> mapInterfaceRequirementKeyValue; - - // Map from interface requirement keys to its corresponding dispatch method. - OrderedDictionary mapInterfaceRequirementKeyToDispatchMethods; - - // We will use a single work list of instructions that need - // to be considered for lowering. - // - InstWorkList workList; - InstHashSet workListSet; - - SharedGenericsLoweringContext(IRModule* inModule) - : module(inModule), workList(inModule), workListSet(inModule) - { - } - - void addToWorkList(IRInst* inst) - { - if (!inst) - return; - - for (auto ii = inst->getParent(); ii; ii = ii->getParent()) - { - if (as(ii)) - return; - } - - if (workListSet.contains(inst)) - return; - - workList.add(inst); - workListSet.add(inst); - } - - - void _builldInterfaceRequirementMap(IRInterfaceType* interfaceType); - - IRInst* findInterfaceRequirementVal(IRInterfaceType* interfaceType, IRInst* requirementKey); - - // Emits an IRRTTIObject containing type information for a given type. - IRInst* maybeEmitRTTIObject(IRInst* typeInst); - - static IRIntegerValue getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc); - static IRType* lowerAssociatedType(IRBuilder* builder, IRInst* type); - - IRType* lowerType( - IRBuilder* builder, - IRInst* paramType, - const Dictionary& typeMapping, - IRType* concreteType); - - IRType* lowerType(IRBuilder* builder, IRInst* paramType) - { - return lowerType(builder, paramType, Dictionary(), nullptr); - } - - // Get a list of all witness tables whose conformance type is `interfaceType`. - List getWitnessTablesFromInterfaceType(IRInst* interfaceType); - - /// Does the given `concreteType` fit within the any-value size deterined by `interfaceType`? - bool doesTypeFitInAnyValue( - IRType* concreteType, - IRInterfaceType* interfaceType, - IRIntegerValue* outTypeSize = nullptr, - IRIntegerValue* outLimit = nullptr, - bool* outIsTypeOpaque = nullptr); -}; - -List getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType); - -bool isPolymorphicType(IRInst* typeInst); - -// Returns true if typeInst represents a type and should be lowered into -// Ptr(RTTIType). -bool isTypeValue(IRInst* typeInst); - -template -void workOnModule(SharedGenericsLoweringContext* sharedContext, const TFunc& func) -{ - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - func(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } -} - -template -void workOnCallGraph(SharedGenericsLoweringContext* sharedContext, const TFunc& func) -{ - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - IRDeadCodeEliminationOptions dceOptions; - dceOptions.keepExportsAlive = true; - dceOptions.keepLayoutsAlive = true; - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - - sharedContext->addToWorkList(inst->parent); - sharedContext->addToWorkList(inst->getFullType()); - - UInt operandCount = inst->getOperandCount(); - for (UInt ii = 0; ii < operandCount; ++ii) - { - if (!isWeakReferenceOperand(inst, ii)) - sharedContext->addToWorkList(inst->getOperand(ii)); - } - - if (auto call = as(inst)) - { - if (func(call)) - return; - } - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - if (shouldInstBeLiveIfParentIsLive(child, dceOptions)) - sharedContext->addToWorkList(child); - } - } -} -} // namespace Slang diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 638e55902b5..922f3976ef4 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -685,4 +685,35 @@ return { ["StoreBase.copyLogical"] = 681, ["MakeStorageTypeLoweringConfig"] = 682, ["Decoration.experimentalModule"] = 683, + ["SetBase.TypeSet"] = 684, + ["SetBase.FuncSet"] = 685, + ["SetBase.WitnessTableSet"] = 686, + ["SetBase.GenericSet"] = 687, + ["Type.SetTagType"] = 689, + ["Type.TaggedUnionType"] = 690, + ["CastInterfaceToTaggedUnionPtr"] = 691, + ["GetTagForSuperSet"] = 693, + ["GetTagForMappedSet"] = 694, + ["GetTagForSpecializedSet"] = 695, + ["GetTagFromSequentialID"] = 696, + ["GetSequentialIDFromTag"] = 697, + ["GetElementFromTag"] = 698, + ["GetDispatcher"] = 699, + ["GetSpecializedDispatcher"] = 700, + ["GetTagFromTaggedUnion"] = 701, + ["GetValueFromTaggedUnion"] = 702, + ["Type.UntaggedUnionType"] = 703, + ["Type.ElementOfSetType"] = 704, + ["MakeTaggedUnion"] = 705, + ["GetTypeTagFromTaggedUnion"] = 706, + ["GetTagOfElementInSet"] = 707, + ["UnboundedTypeElement"] = 708, + ["UnboundedFuncElement"] = 709, + ["UnboundedWitnessTableElement"] = 710, + ["UnboundedGenericElement"] = 711, + ["UninitializedTypeElement"] = 712, + ["UninitializedWitnessTableElement"] = 713, + ["NoneTypeElement"] = 714, + ["NoneWitnessTableElement"] = 715, + ["GetTagForSubSet"] = 716 } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index e838f0cf18f..c1b20ef8ffb 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2821,6 +2821,124 @@ struct IREmbeddedDownstreamIR : IRInst CodeGenTarget getTarget() { return static_cast(getTargetOperand()->getValue()); } }; +FIDDLE() +struct IRSetBase : IRInst +{ + FIDDLE(baseInst()) + UInt getCount() { return getOperandCount(); } + IRInst* getElement(UInt idx) { return getOperand(idx); } + bool isEmpty() { return getOperandCount() == 0; } + bool isSingleton() { return (getOperandCount() == 1) && !isUnbounded(); } + bool isUnbounded() + { + // This is an unbounded set if any of its elements are unbounded. + for (UInt ii = 0; ii < getOperandCount(); ++ii) + { + switch (getElement(ii)->getOp()) + { + case kIROp_UnboundedTypeElement: + case kIROp_UnboundedWitnessTableElement: + case kIROp_UnboundedFuncElement: + case kIROp_UnboundedGenericElement: + return true; + } + } + return false; + } + + IRInst* tryGetUnboundedElement() + { + for (UInt ii = 0; ii < getOperandCount(); ++ii) + { + switch (getElement(ii)->getOp()) + { + case kIROp_UnboundedTypeElement: + case kIROp_UnboundedWitnessTableElement: + case kIROp_UnboundedFuncElement: + case kIROp_UnboundedGenericElement: + return getElement(ii); + } + } + return nullptr; + } + + bool containsUninitializedElement() + { + // This is a "potentially uninitialized" set if any of its elements are unbounded. + for (UInt ii = 0; ii < getOperandCount(); ++ii) + { + switch (getElement(ii)->getOp()) + { + case kIROp_UninitializedTypeElement: + case kIROp_UninitializedWitnessTableElement: + return true; + } + } + return false; + } + + IRInst* tryGetUninitializedElement() + { + for (UInt ii = 0; ii < getOperandCount(); ++ii) + { + switch (getElement(ii)->getOp()) + { + case kIROp_UninitializedTypeElement: + case kIROp_UninitializedWitnessTableElement: + return getElement(ii); + } + } + return nullptr; + } +}; + +FIDDLE() +struct IRWitnessTableSet : IRSetBase +{ + FIDDLE(leafInst()) +}; + + +FIDDLE() +struct IRTypeSet : IRSetBase +{ + FIDDLE(leafInst()) +}; + +FIDDLE() +struct IRSetTagType : IRType +{ + FIDDLE(leafInst()) + IRSetBase* getSet() { return as(getOperand(0)); } + bool isSingleton() { return getSet()->isSingleton(); } +}; + +FIDDLE() +struct IRTaggedUnionType : IRType +{ + FIDDLE(leafInst()) + IRWitnessTableSet* getWitnessTableSet() { return as(getOperand(0)); } + IRTypeSet* getTypeSet() { return as(getOperand(1)); } + bool isSingleton() + { + return getTypeSet()->isSingleton() && getWitnessTableSet()->isSingleton(); + } +}; + +FIDDLE() +struct IRElementOfSetType : IRType +{ + FIDDLE(leafInst()) + IRSetBase* getSet() { return as(getOperand(0)); } +}; + +FIDDLE() +struct IRUntaggedUnionType : IRType +{ + FIDDLE(leafInst()) + IRSetBase* getSet() { return as(getOperand(0)); } +}; + // Generate struct definitions for all IR instructions not explicitly defined in this file #if 0 // FIDDLE TEMPLATE: % local lua_module = require("source/slang/slang-ir.h.lua") @@ -3008,6 +3126,7 @@ struct IRBuilder IRPtrLit* getNullPtrValue(IRType* type); IRPtrLit* getNullVoidPtrValue() { return getNullPtrValue(getPtrType(getVoidType())); } IRVoidLit* getVoidValue(); + IRVoidLit* getVoidValue(IRType* type); IRInst* getCapabilityValue(CapabilitySet const& caps); IRBasicType* getBasicType(BaseType baseType); @@ -3515,6 +3634,16 @@ struct IRBuilder return emitMakeTuple(SLANG_COUNT_OF(args), args); } + IRMakeTaggedUnion* emitMakeTaggedUnion( + IRType* type, + IRInst* typeTag, + IRInst* witnessTableTag, + IRInst* value) + { + IRInst* args[] = {typeTag, witnessTableTag, value}; + return cast(emitIntrinsicInst(type, kIROp_MakeTaggedUnion, 3, args)); + } + IRInst* emitMakeValuePack(IRType* type, UInt count, IRInst* const* args); IRInst* emitMakeValuePack(UInt count, IRInst* const* args); @@ -4036,6 +4165,170 @@ struct IRBuilder IRMetalSetPrimitive* emitMetalSetPrimitive(IRInst* index, IRInst* primitive); IRMetalSetIndices* emitMetalSetIndices(IRInst* index, IRInst* indices); + IRGetElementFromTag* emitGetElementFromTag(IRInst* tag) + { + auto tagType = cast(tag->getDataType()); + IRInst* set = tagType->getSet(); + auto elementType = + cast(emitIntrinsicInst(nullptr, kIROp_ElementOfSetType, 1, &set)); + return cast( + emitIntrinsicInst(elementType, kIROp_GetElementFromTag, 1, &tag)); + } + + IRGetTagFromTaggedUnion* emitGetTagFromTaggedUnion(IRInst* tag) + { + auto taggedUnionType = cast(tag->getDataType()); + + IRInst* set = taggedUnionType->getWitnessTableSet(); + auto tableTagType = + cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &set)); + + return cast( + emitIntrinsicInst(tableTagType, kIROp_GetTagFromTaggedUnion, 1, &tag)); + } + + IRGetTypeTagFromTaggedUnion* emitGetTypeTagFromTaggedUnion(IRInst* tag) + { + auto taggedUnionType = cast(tag->getDataType()); + + IRInst* typeSet = taggedUnionType->getTypeSet(); + auto typeTagType = + cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &typeSet)); + + return cast( + emitIntrinsicInst(typeTagType, kIROp_GetTypeTagFromTaggedUnion, 1, &tag)); + } + + IRGetValueFromTaggedUnion* emitGetValueFromTaggedUnion(IRInst* taggedUnion) + { + auto taggedUnionType = cast(taggedUnion->getDataType()); + + IRInst* typeSet = taggedUnionType->getTypeSet(); + auto valueOfTypeSetType = cast( + emitIntrinsicInst(nullptr, kIROp_UntaggedUnionType, 1, &typeSet)); + + return cast( + emitIntrinsicInst(valueOfTypeSetType, kIROp_GetValueFromTaggedUnion, 1, &taggedUnion)); + } + + IRGetDispatcher* emitGetDispatcher( + IRFuncType* funcType, + IRWitnessTableSet* witnessTableSet, + IRStructKey* key) + { + IRInst* args[] = {witnessTableSet, key}; + return cast(emitIntrinsicInst(funcType, kIROp_GetDispatcher, 2, args)); + } + + IRGetSpecializedDispatcher* emitGetSpecializedDispatcher( + IRFuncType* funcType, + IRWitnessTableSet* witnessTableSet, + IRStructKey* key, + List const& specArgs) + { + List args; + args.add(witnessTableSet); + args.add(key); + for (auto specArg : specArgs) + { + args.add(specArg); + } + return cast(emitIntrinsicInst( + funcType, + kIROp_GetSpecializedDispatcher, + (UInt)args.getCount(), + args.getBuffer())); + } + + IRUntaggedUnionType* getUntaggedUnionType(IRInst* operand) + { + return as( + emitIntrinsicInst(nullptr, kIROp_UntaggedUnionType, 1, &operand)); + } + + IRElementOfSetType* getElementOfSetType(IRInst* operand) + { + return as( + emitIntrinsicInst(nullptr, kIROp_ElementOfSetType, 1, &operand)); + } + + IRTaggedUnionType* getTaggedUnionType(IRWitnessTableSet* tables, IRTypeSet* types) + { + IRInst* operands[] = {tables, types}; + return as( + emitIntrinsicInst(nullptr, kIROp_TaggedUnionType, 2, operands)); + } + + IRSetTagType* getSetTagType(IRSetBase* collection) + { + IRInst* operands[] = {collection}; + return cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, operands)); + } + + IRUnboundedTypeElement* getUnboundedTypeElement(IRInst* interfaceType) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedTypeElement, 1, &interfaceType)); + } + + IRUnboundedWitnessTableElement* getUnboundedWitnessTableElement(IRInst* interfaceType) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedWitnessTableElement, 1, &interfaceType)); + } + + IRUnboundedFuncElement* getUnboundedFuncElement() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedFuncElement, 0, nullptr)); + } + + IRUnboundedGenericElement* getUnboundedGenericElement() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedGenericElement, 0, nullptr)); + } + + IRUninitializedTypeElement* getUninitializedTypeElement(IRInst* interfaceType) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UninitializedTypeElement, 1, &interfaceType)); + } + + IRUninitializedWitnessTableElement* getUninitializedWitnessTableElement(IRInst* interfaceType) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UninitializedWitnessTableElement, 1, &interfaceType)); + } + + IRNoneTypeElement* getNoneTypeElement() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_NoneTypeElement, 0, nullptr)); + } + + IRNoneWitnessTableElement* getNoneWitnessTableElement() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_NoneWitnessTableElement, 0, nullptr)); + } + + IRGetTagOfElementInSet* emitGetTagOfElementInSet( + IRType* tagType, + IRInst* element, + IRInst* collection) + { + SLANG_ASSERT(tagType->getOp() == kIROp_SetTagType); + IRInst* args[] = {element, collection}; + return cast( + emitIntrinsicInst(tagType, kIROp_GetTagOfElementInSet, 2, args)); + } + + IRSetTagType* getSetTagType(IRInst* collection) + { + return cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &collection)); + } + // // Decorations // @@ -4811,6 +5104,12 @@ struct IRBuilder } void addRayPayloadDecoration(IRType* inst) { addDecoration(inst, kIROp_RayPayloadDecoration); } + + IRSetBase* getSet(IROp op, const HashSet& elements); + + IRSetBase* getSingletonSet(IROp op, IRInst* element); + + UInt getUniqueID(IRInst* inst); }; // Helper to establish the source location that will be used diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index ff686ca1e9a..6349f7b5e78 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -677,6 +677,35 @@ local insts = { }, }, }, + { UntaggedUnionType = { + hoistable = true, + -- A type that represents that the value's _type_ is one of types in the set operand. + } }, + { ElementOfSetType = { + hoistable = true, + -- A type that represents that the value must be an element of the set operand. + } }, + { SetTagType = { + hoistable = true, + -- Represents a tag-type for a set. + -- + -- An inst whose type is SetTagType(set) is semantically carrying a + -- run-time value that "picks" one of the elements of the set operand. + -- + -- Only operand is a SetBase + } }, + { TaggedUnionType = { + hoistable = true, + -- Represents a tagged union type. + -- + -- An inst whose type is a TaggedUnionType(typeSet, witnessTableSet) is semantically carrying a tuple of + -- two values: a value of SetTagType(witnessTableSet) to represent the tag, and a payload value of type + -- UntaggedUnionType(typeSet), which conceptually represents a union/"anyvalue" type. + -- + -- This is most commonly used to specialize the type of existential insts once the possibilities can be statically determined. + -- + -- Operands are a TypeSet and a WitnessTableSet that represent the possibilities of the existential + } } }, }, -- IRGlobalValueWithCode @@ -2662,6 +2691,240 @@ local insts = { }, }, }, + { + SetBase = { + -- Base class for all set representation.s + -- + -- Semantically, `SetBase` types model sets of concrete values, and use Slang's de-duplication infrastructure + -- to allow set-equality to be the same as inst identity. + -- + -- - Set ops have one or more operands that represent the elements of the set + -- + -- - Set ops must have at least one operand. A zero-operand set is illegal. + -- The type-flow pass will represent this case using nullptr, so that uniqueness is preserved. + -- + -- - All operands of a set _must_ be concrete, individual insts + -- - Operands should NOT be an interface or abstract type. + -- - Operands should NOT be type parameters or existentail types (i.e. insts that appear in blocks) + -- - Operands should NOT be sets (i.e. sets should be flat and never heirarchical) + -- + -- - Since sets are hositable, set ops should (consequently) only appear in the global scope. + -- + -- - Set operands must be consistently sorted. i.e. a TypeSet(A, B) and TypeSet(B, A) + -- cannot exist at the same time, but either one is okay. + -- + -- - To help with the implementation of sets, the IRBuilder class provides operations such as `getSet` + -- that will ensure the above invariants are maintained, and uses a persistent unique ID map to + -- ensure stable ordering of set elements. + -- + -- Set representations should never be manually constructed to avoid breaking these invariants. + -- + hoistable = true, + { TypeSet = {} }, + { FuncSet = {} }, + { WitnessTableSet = {} }, + { GenericSet = {} } + }, + }, + { CastInterfaceToTaggedUnionPtr = { + -- Cast an interface-typed pointer to a tagged-union pointer with a known set. + } }, + { GetTagForSuperSet = { + -- Translate a tag from a set to its equivalent in a super-set + -- + -- Operands: (the tag for the source set) + -- The source and destination sets are implied by the type of the operand and the type of the result + } }, + { GetTagForSubSet = { + -- Translate a tag from a set to its equivalent in a sub-set + -- + -- Operands: (the tag for the source set) + -- The source and destination sets are implied by the type of the operand and the type of the result + } }, + { GetTagForMappedSet = { + -- Translate a tag from a set to its equivalent in a different set + -- based on a mapping induced by a lookup key + -- + -- Operands: (the tag for the witness table set, the lookup key) + } }, + { GetTagForSpecializedSet = { + -- Translate a tag from a set of generics to its equivalent in a specialized set + -- according to the set of specialization arguments that are encoded in the + -- operands of this instruction. + -- + -- Operands: (the tag for the generic set, any number of specialization arguments....) + } }, + { GetTagFromSequentialID = { + -- Translate an existing sequential ID (a 'global' ID) & and interface type into a tag + -- the provided set (a 'local' ID) + } }, + { GetSequentialIDFromTag = { + -- Translate a tag from the given set (a 'local' ID) to a sequential ID (a 'global' ID) + } }, + { GetElementFromTag = { + -- Translate a tag to its corresponding element in the set. + -- Input's type: SetTagType(set). + -- Output's type: ElementOfSetType(set) + -- + operands = {{"tag"}} + } }, + { GetDispatcher = { + -- Get a dispatcher function for a given witness table set + key. + -- + -- Inputs: set of witness tables to create a dispatched for and the key to use to identify the + -- entry that needs to be dispatched to. All witness tables must have an entry for the given key. + -- or else this is a malformed inst. + -- + -- Output: a value of 'FuncType' that can be called. + -- This func-type will take a `TagType(witnessTableSet)` as the first parameter to + -- discriminate which witness table to use, and the rest of the parameters. + -- + hoistable = true, + operands = {{"witnessTableSet", "IRWitnessTableSet"}, {"lookupKey", "IRStructKey"}} + } }, + { GetSpecializedDispatcher = { + -- Get a specialized dispatcher function for a given witness table set + key, where + -- the key points to a generic function. + -- + -- Operands: (set of witness tables, lookup key, specialization args...) + -- + -- + -- Output: a value of `FuncType` that can be called. + -- This func-type will take a `TagType(witnessTableSet)` as the first parameter to + -- discriminate which generic to use, and the rest of the parameters. + -- + hoistable = true + } }, + { GetTagFromTaggedUnion = { + -- Translate a tagged-union value to its corresponding tag in the tagged-union's set. + -- + -- Input's type: TaggedUnionType(typeSet, tableSet) + -- + -- Output's type: SetTagType(tableSet) + -- + operands = {{"taggedUnionValue"}} + } }, + { GetTypeTagFromTaggedUnion = { + -- Translate a tagged-union value to its corresponding type tag in the tagged-union's set. + -- + -- Input's type: TaggedUnionType(typeSet, tableSet) + -- + -- Output's type: SetTagType(typeSet) + -- + operands = {{"taggedUnionValue"}} + } }, + { GetValueFromTaggedUnion = { + -- Translate a tagged-union value to its corresponding value in the tagged-union's set. + -- + -- Input's type: TaggedUnionType(typeSet, tableSet) + -- + -- Output's type: UntaggedUnionType(typeSet) + -- + operands = {{"taggedUnionValue"}} + } }, + { MakeTaggedUnion = { + -- Create a tagged-union value from a tag and a value. + -- + -- Input's type: SetTagType(tableSet), UntaggedUnionType(typeSet) + -- + -- Output's type: TaggedUnionType(typeSet, tableSet) + -- + operands = { { "tag" }, { "value" } }, + } }, + { GetTagOfElementInSet = { + -- Get the tag corresponding to an element in a set. + -- + -- Operands: (element, set) + -- "element" must resolve into a concrete inst before lowering, + -- otherwise, this is an error. + -- + -- Output's type: SetTagType(set) + -- + hoistable = true + } }, + { UnboundedTypeElement = { + -- An element of TypeSet that represents an unbounded set of types conforming to + -- the given interface type. + -- + -- Used in cases where a finite set of types cannot be determined during type-flow analysis. + -- + -- Note that this is a set element, not a set in itself, so a TypeSet(A, B, UnboundedTypeElement(I)) + -- represents a set where we know two concrete types A and B, and any number of other types that conform to interface I. + -- + hoistable = true, + operands = { {"baseInterfaceType"} } + } }, + { UnboundedFuncElement = { + -- An element of FuncSet that represents an unbounded set of functions of a certain + -- func-type + -- + -- Used in cases where a finite set of functions cannot be determined during type-flow analysis. + -- + -- Similar to UnboundedTypeElement, this is a set element, not a set in itself. + -- + hoistable = true, + operands = { {"funcType"} } + } }, + { UnboundedWitnessTableElement = { + -- An element of WitnessTableSet that represents an unbounded set of witness tables of a certain + -- interface type + -- + -- Used in cases where a finite set of witness tables cannot be determined during type-flow analysis. + -- + -- Similar to UnboundedTypeElement, this is a set element, not a set in itself. + -- + hoistable = true, + operands = { {"baseInterfaceType"} } + } }, + { UnboundedGenericElement = { + -- An element of GenericSet that represents an unbounded set of generics of a certain + -- interface type + -- + -- Used in cases where a finite set of generics cannot be determined during type-flow analysis. + -- + -- Similar to UnboundedTypeElement, this is a set element, not a set in itself. + -- + hoistable = true, + } }, + { UninitializedTypeElement = { + -- An element that represents an uninitialized type of a certain interface. + -- + -- Used to denote cases where the type represented may be garbage (e.g. from a `LoadFromUninitializedMemory`) + -- + -- Similar to UnboundedXYZElement IR ops described above, this is a set element, not a set in itself. + -- e.g. a `TypeSet(A, B, UninitializedTypeElement(I))` represents a set where we know two concrete types A and B, + -- and an uninitialized type that conforms to interface I. + -- + -- Note: In practice, having any uninitialized type in a TypeSet will likely force the entire set to be treated as + -- uninitialized, and this element is mainly so that we can provide useful errors during the type-flow specialization pass. + -- + hoistable = true, + operands = { {"baseInterfaceType"} } + } }, + { UninitializedWitnessTableElement = { + -- An element that represents an uninitialized witness table of a certain interface. + -- + -- Used to denote cases where the witness table information may be garbage (e.g. from a `LoadFromUninitializedMemory`) + -- + -- Similar to UninitializedTypeElement, this is a set element, not a set in itself. + -- + hoistable = true, + operands = { {"baseInterfaceType"} } + } }, + { NoneTypeElement = { + -- An element that represents a default 'none' case (only relevant in the context of OptionalType) + -- + -- Similar to UnboundedXYZElement IR ops described above, this is a set element, not a set in itself. + -- + hoistable = true + } }, + { NoneWitnessTableElement = { + -- An element that represents a default 'none' case (only relevant in the context of OptionalType) + -- + -- Similar to UnboundedXYZElement IR ops described above, this is a set element, not a set in itself. + -- + hoistable = true + } }, } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index 9db4b309757..614ca2baa63 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -1,8 +1,8 @@ // slang-ir-layout.cpp #include "slang-ir-layout.h" -#include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" // This file implements facilities for computing and caching layout // information on IR types. @@ -300,12 +300,17 @@ Result IRTypeLayoutRules::calcSizeAndAlignment( return SLANG_OK; } break; + case kIROp_SetTagType: + { + outSizeAndAlignment->size = 4; + outSizeAndAlignment->alignment = 4; + return SLANG_OK; + } + break; case kIROp_InterfaceType: { auto interfaceType = cast(type); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( - interfaceType, - interfaceType->sourceLoc); + auto size = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); size += kRTTIHeaderSize; size = align(size, 4); IRSizeAndAlignment resultLayout; diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 2948a6e53bf..f8d10c80521 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1492,6 +1492,9 @@ static LegalVal legalizeGetElement( // the "index" argument. auto indexOperand = legalIndexOperand.getSimple(); + if (type.flavor == LegalType::Flavor::none) + return LegalVal(); + return legalizeGetElement(context, type, legalPtrOperand, indexOperand); } diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index c3243389cfb..7c1694c2d84 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -302,6 +302,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) case kIROp_GlobalGenericParam: case kIROp_WitnessTable: case kIROp_InterfaceType: + case kIROp_EnumType: case kIROp_SymbolAlias: return cloneGlobalValue(this, originalValue); @@ -805,6 +806,18 @@ IRStructType* cloneStructTypeImpl( return clonedStruct; } +IREnumType* cloneEnumTypeImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IREnumType* originalEnum, + IROriginalValuesForClone const& originalValues) +{ + auto clonedEnum = + builder->createEnumType(cloneType(context, (IRType*)originalEnum->getOperand(0))); + cloneSimpleGlobalValueImpl(context, originalEnum, originalValues, clonedEnum); + return clonedEnum; +} + IRInterfaceType* cloneInterfaceTypeImpl( IRSpecContextBase* context, @@ -1388,6 +1401,9 @@ IRInst* cloneInst( cast(originalInst), originalValues); + case kIROp_EnumType: + return cloneEnumTypeImpl(context, builder, cast(originalInst), originalValues); + case kIROp_InterfaceType: return cloneInterfaceTypeImpl( context, @@ -2405,6 +2421,10 @@ struct IRPrelinkContext : IRSpecContext case kIROp_ClassType: clonedInst = builderForClone->createClassType(); break; + case kIROp_EnumType: + clonedInst = builderForClone->createEnumType( + cloneType(this, (IRType*)cast(originalVal)->getOperand(0))); + break; default: return completeClonedInst(IRSpecContext::maybeCloneValue(originalVal)); } diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp index e7c24c511e1..2f255d10432 100644 --- a/source/slang/slang-ir-loop-inversion.cpp +++ b/source/slang/slang-ir-loop-inversion.cpp @@ -3,7 +3,6 @@ #include "slang-ir-clone.h" #include "slang-ir-dominators.h" #include "slang-ir-insts.h" -#include "slang-ir-lower-witness-lookup.h" #include "slang-ir-reachability.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-util.h" diff --git a/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp new file mode 100644 index 00000000000..3fde6d1b96d --- /dev/null +++ b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp @@ -0,0 +1,1948 @@ +#include "slang-ir-lower-dynamic-dispatch-insts.h" + +#include "slang-ir-any-value-marshalling.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-insts.h" +#include "slang-ir-layout.h" +#include "slang-ir-specialize.h" +#include "slang-ir-typeflow-set.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ + +// Represents a work item for packing `inout` or `out` arguments after a concrete call. +struct ArgumentPackWorkItem +{ + enum Kind + { + Pack, + UpCast, + } kind = Pack; + + // A `AnyValue` typed destination. + IRInst* dstArg = nullptr; + // A concrete value to be packed. + IRInst* concreteArg = nullptr; +}; + +bool isAnyValueType(IRType* type) +{ + if (as(type) || as(type)) + return true; + return false; +} + +// Unpack an `arg` of `IRAnyValue` into concrete type if necessary, to make it feedable into the +// parameter. If `arg` represents a AnyValue typed variable passed in to a concrete `out` +// parameter, this function indicates that it needs to be packed after the call by setting +// `packAfterCall`. +IRInst* maybeUnpackArg( + IRBuilder* builder, + IRType* paramType, + IRInst* arg, + ArgumentPackWorkItem& packAfterCall) +{ + packAfterCall.dstArg = nullptr; + packAfterCall.concreteArg = nullptr; + + // If either paramType or argType is a pointer type + // (because of `inout` or `out` modifiers), we extract + // the underlying value type first. + IRType* paramValType = paramType; + IRType* argValType = arg->getDataType(); + IRInst* argVal = arg; + if (auto ptrType = as(paramType)) + { + paramValType = ptrType->getValueType(); + } + auto argType = arg->getDataType(); + if (auto argPtrType = as(argType)) + { + argValType = argPtrType->getValueType(); + } + + + // Unpack `arg` if the parameter expects concrete type but + // `arg` is an AnyValue. + if (!isAnyValueType(paramValType) && isAnyValueType(argValType)) + { + // if parameter expects an `out` pointer, store the unpacked val into a + // variable and pass in a pointer to that variable. + if (as(paramType)) + { + auto tempVar = builder->emitVar(paramValType); + if (as(paramType)) + builder->emitStore( + tempVar, + builder->emitUnpackAnyValue(paramValType, builder->emitLoad(arg))); + + // tempVar needs to be unpacked into original var after the call. + packAfterCall.kind = ArgumentPackWorkItem::Kind::Pack; + packAfterCall.dstArg = arg; + packAfterCall.concreteArg = tempVar; + return tempVar; + } + else + { + return builder->emitUnpackAnyValue(paramValType, argVal); + } + } + + // Reinterpret 'arg' if it is being passed to a parameter with + // a different type collection. For now, we'll approximate this + // by checking if the types are different, but this should be + // encoded in the types. + // + if (as(paramValType) && as(argValType) && + paramValType != argValType) + { + // if parameter expects an `out` pointer, store the unpacked val into a + // variable and pass in a pointer to that variable. + if (as(paramType)) + { + auto tempVar = builder->emitVar(paramValType); + + // tempVar needs to be unpacked into original var after the call. + packAfterCall.kind = ArgumentPackWorkItem::Kind::UpCast; + packAfterCall.dstArg = arg; + packAfterCall.concreteArg = tempVar; + return tempVar; + } + else + { + SLANG_UNEXPECTED("Unexpected upcast for non-out parameter"); + } + } + return arg; +} + +IRStringLit* _getWitnessTableWrapperFuncName(IRModule* module, IRFunc* func) +{ + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(func); + if (auto linkageDecoration = func->findDecoration()) + { + return builder->getStringValue( + (String(linkageDecoration->getMangledName()) + "_wtwrapper").getUnownedSlice()); + } + if (auto namehintDecoration = func->findDecoration()) + { + return builder->getStringValue( + (String(namehintDecoration->getName()) + "_wtwrapper").getUnownedSlice()); + } + return nullptr; +} + +// Create a wrapper function that makes a specific function's signature match it's type in the +// interface requirement. +// +// e.g. the signature of the function from the caller's side might look like: ((ThisType, float) -> +// ThisType) while the actual implementation function might be ((FooImpl, float) -> FooImpl). +// +// The witness table wrapper will marshal from the union-types (ThisType) to the concrete types +// (FooImpl) expected by the implementation. +// +IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* interfaceRequirementVal) +{ + auto funcTypeInInterface = cast(interfaceRequirementVal); + auto targetFuncType = as(funcInst->getDataType()); + + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(funcInst); + + auto wrapperFunc = builder->createFunc(); + wrapperFunc->setFullType((IRType*)interfaceRequirementVal); + if (auto func = as(funcInst)) + if (auto name = _getWitnessTableWrapperFuncName(module, func)) + builder->addNameHintDecoration(wrapperFunc, name); + + builder->setInsertInto(wrapperFunc); + auto block = builder->emitBlock(); + builder->setInsertInto(block); + + ShortList params; + for (UInt i = 0; i < funcTypeInInterface->getParamCount(); i++) + { + params.add(builder->emitParam(funcTypeInInterface->getParamType(i))); + } + + List args; + List argsToPack; + + SLANG_ASSERT(params.getCount() == (Index)targetFuncType->getParamCount()); + for (UInt i = 0; i < targetFuncType->getParamCount(); i++) + { + auto wrapperParam = params[i]; + // Type of the parameter in the callee. + auto funcParamType = targetFuncType->getParamType(i); + + // If the implementation expects a concrete type + // (either in the form of a pointer for `out`/`inout` parameters, + // or in the form a value for `in` parameters, while + // the interface exposes an AnyValue type, + // we need to unpack the AnyValue argument to the appropriate + // concerete type. + ArgumentPackWorkItem packWorkItem; + auto newArg = maybeUnpackArg(builder, funcParamType, wrapperParam, packWorkItem); + args.add(newArg); + if (packWorkItem.concreteArg) + argsToPack.add(packWorkItem); + } + auto call = builder->emitCallInst(targetFuncType->getResultType(), funcInst, args); + + // Pack all `out` arguments. + for (auto item : argsToPack) + { + auto anyValType = cast(item.dstArg->getDataType())->getValueType(); + auto concreteVal = builder->emitLoad(item.concreteArg); + auto packedVal = (item.kind == ArgumentPackWorkItem::Kind::Pack) + ? builder->emitPackAnyValue(anyValType, concreteVal) + : upcastSet(builder, concreteVal, anyValType); + builder->emitStore(item.dstArg, packedVal); + } + + // Pack return value if necessary. + if (!isAnyValueType(call->getDataType()) && + isAnyValueType(funcTypeInInterface->getResultType())) + { + auto pack = builder->emitPackAnyValue(funcTypeInInterface->getResultType(), call); + builder->emitReturn(pack); + } + else if (call->getDataType() != funcTypeInInterface->getResultType()) + { + auto reinterpret = upcastSet(builder, call, funcTypeInInterface->getResultType()); + builder->emitReturn(reinterpret); + } + else + { + if (call->getDataType()->getOp() == kIROp_VoidType) + builder->emitReturn(); + else + builder->emitReturn(call); + } + return wrapperFunc; +} + +UInt getUniqueID(IRBuilder* builder, IRInst* inst) +{ + // Fallback. + return builder->getUniqueID(inst); +} + +// Generate a single function that dispatches to each function in the collection. +// The resulting function will have one additional parameter to accept the tag +// indicating which function to call. +// +IRFunc* createDispatchFunc(IRFuncType* dispatchFuncType, Dictionary& mapping) +{ + // Create a dispatch function with switch-case for each function + IRBuilder builder(dispatchFuncType->getModule()); + + // Consume the first parameter of the expected function type + List innerParamTypes; + for (auto paramType : dispatchFuncType->getParamTypes()) + innerParamTypes.add(paramType); + innerParamTypes.removeAt(0); // Remove the first parameter (ID) + + auto resultType = dispatchFuncType->getResultType(); + auto innerFuncType = builder.getFuncType(innerParamTypes, resultType); + + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(dispatchFuncType); + + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); + + auto idParam = builder.emitParam(builder.getUIntType()); + + // Create parameters for the original function arguments + List originalParams; + for (Index i = 0; i < innerParamTypes.getCount(); i++) + { + originalParams.add(builder.emitParam(innerParamTypes[i])); + } + + // Create default block + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + // Return a default-constructed value + auto defaultValue = builder.emitDefaultConstruct(resultType); + builder.emitReturn(defaultValue); + } + + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); + + // Create case blocks for each function + List caseValues; + List caseBlocks; + + for (auto kvPair : mapping) + { + auto funcInst = kvPair.second; + auto funcTag = kvPair.first; + + // The different functions in the mapping may have different signatures, + // so we need to emit a wrapper that marshals the parameters to the expected types for + // each function. + // + auto wrapperFunc = emitWitnessTableWrapper(funcInst->getModule(), funcInst, innerFuncType); + + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + + List callArgs; + auto wrappedFuncType = as(wrapperFunc->getDataType()); + for (Index ii = 0; ii < originalParams.getCount(); ii++) + { + callArgs.add(originalParams[ii]); + } + + // Call the specific function + auto callResult = + builder.emitCallInst(wrappedFuncType->getResultType(), wrapperFunc, callArgs); + + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + builder.emitReturn(callResult); + } + + caseValues.add(funcTag); + caseBlocks.add(caseBlock); + } + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } + + // Create an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + idParam, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; +} + +// Create a function that maps input integers to output integers based on the provided mapping. +IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mapping, UInt defaultVal) +{ + // Emit a switch statement with the inputs as case labels and outputs as return values. + + IRBuilder builder(module); + + auto funcType = + builder.getFuncType(List({builder.getUIntType()}), builder.getUIntType()); + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(funcType); + + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); + + auto param = builder.emitParam(builder.getUIntType()); + + // Create default block that returns defaultVal + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), defaultVal)); + + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); + + // Create case blocks for each input table + List caseValues; + List caseBlocks; + + for (auto item : mapping) + { + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), item.second)); + + caseValues.add(builder.getIntValue(builder.getUIntType(), item.first)); + caseBlocks.add(caseBlock); + } + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } + + // Emit an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + param, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; +} + +// This context lowers `GetTagOfElementInSet`, +// `GetTagForSuperSet`, and `GetTagForMappedSet` instructions, +// +struct TagOpsLoweringContext : public InstPassBase +{ + // Our strategy for lowering tag operations is to + // assign each element to a unique integer ID that is stable + // across the same module. This is acheived via `getUniqueID`, + // on the IRBuilder, which uses a dicionary on the module inst + // to keep track of assignments. + // + // Then, tag operations can be lowered to mapping functions that + // take an integer in and return an integer out, based on the + // input and output sets (and any other operands) + // + TagOpsLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void lowerGetTagForSuperSet(IRGetTagForSuperSet* inst) + { + // `GetTagForSuperSet` is a no-op since we want to translate the tag + // for an element in the sub-set to a tag for the same element in the super-set. + // + // Since all elements have a unique ID across the module, this is the identity operation. + // + + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + inst->replaceUsesWith(builder.emitCast(inst->getDataType(), inst->getOperand(0), true)); + inst->removeAndDeallocate(); + } + + void lowerGetTagForSubSet(IRGetTagForSubSet* inst) + { + // `GetTagForSubSet` is a no-op since we want to translate the tag + // for an element in the sub-set to a tag for the same element in a sub-set. + // It is assumed that this operation has already been confirmed as safe. + // + // Since all elements have a unique ID across the module, this is the identity operation. + // + + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + inst->replaceUsesWith(builder.emitCast(inst->getDataType(), inst->getOperand(0), true)); + inst->removeAndDeallocate(); + } + + void lowerGetTagForMappedSet(IRGetTagForMappedSet* inst) + { + // `GetTagForMappedSet` turns into a integer mapping from + // the unique ID of each input set element to the unique ID of the + // corresponding element (as determined by witness table lookup) in the destination set. + // + auto srcSet = cast( + cast(inst->getOperand(0)->getDataType())->getOperand(0)); + auto destSet = cast(cast(inst->getDataType())->getOperand(0)); + auto key = cast(inst->getOperand(1)); + + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + + Dictionary mapping; + for (UInt i = 0; i < srcSet->getCount(); i++) + { + // Find in destSet + bool found = false; + auto srcMappedElement = + findWitnessTableEntry(cast(srcSet->getElement(i)), key); + for (UInt j = 0; j < destSet->getCount(); j++) + { + auto destElement = destSet->getElement(j); + if (srcMappedElement == destElement) + { + found = true; + // We rely on the fact that if the element ever appeared in a collection, + // it must have been assigned a unique ID. + // + mapping.add( + getUniqueID(&builder, srcSet->getElement(i)), + getUniqueID(&builder, destElement)); + break; // Found the index + } + } + + if (!found) + { + // destSet must be a super-set + SLANG_UNEXPECTED("Element not found in destination collection"); + } + } + + // Create an index mapping func and call that. + auto mappingFunc = createIntegerMappingFunc(inst->getModule(), mapping, 0); + + auto resultID = builder.emitCallInst( + inst->getDataType(), + mappingFunc, + List({inst->getOperand(0)})); + inst->replaceUsesWith(resultID); + inst->removeAndDeallocate(); + } + + void lowerGetTagOfElementInSet(IRGetTagOfElementInSet* inst) + { + // `GetTagOfElementInSet` simply gets replaced by the element's + // unique ID (as an integer literal value) + // + // Note: the element must be a concrete global inst (cannot by a + // dynamic value) + // + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + + auto uniqueId = getUniqueID(&builder, inst->getOperand(0)); + auto resultValue = builder.getIntValue(inst->getDataType(), uniqueId); + inst->replaceUsesWith(resultValue); + inst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_GetTagForSuperSet: + lowerGetTagForSuperSet(as(inst)); + break; + case kIROp_GetTagForSubSet: + lowerGetTagForSubSet(as(inst)); + break; + case kIROp_GetTagForMappedSet: + lowerGetTagForMappedSet(as(inst)); + break; + case kIROp_GetTagOfElementInSet: + lowerGetTagOfElementInSet(as(inst)); + break; + default: + break; + } + } + + void processModule() + { + processAllInsts([&](IRInst* inst) { return processInst(inst); }); + } +}; + +struct DispatcherLoweringContext : public InstPassBase +{ + DispatcherLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void lowerGetDispatcher(IRGetDispatcher* dispatcher) + { + // Replace the `IRGetDispatcher` with a dispatch function, + // which takes an extra first parameter for the tag (i.e. ID) + // + // We'll also replace the callee in all 'call' insts. + // + // The generated dispatch function uses a switch-case to call the + // appropriate function based on the integer tag. Since tags + // may not yet be lowered into actual integers, we use `GetTagOfElementInSet` + // as a placeholder literal. + // + // Note that before each function is called, it needs to be wrapped in a + // method (a 'witness table wrapper') that handles marshalling between the input types + // to the dispatcher and the actual function types (which may be different) + // + + auto witnessTableSet = cast(dispatcher->getOperand(0)); + auto key = cast(dispatcher->getOperand(1)); + + IRBuilder builder(dispatcher->getModule()); + + Dictionary elements; + forEachInSet( + witnessTableSet, + [&](IRInst* table) + { + auto tag = builder.emitGetTagOfElementInSet( + builder.getSetTagType(witnessTableSet), + table, + witnessTableSet); + elements.add( + tag, + cast(findWitnessTableEntry(cast(table), key))); + }); + + if (dispatcher->hasUses() && dispatcher->getDataType() != nullptr) + { + auto dispatchFunc = + createDispatchFunc(cast(dispatcher->getDataType()), elements); + + if (auto nameHint = dispatcher->getLookupKey()->findDecoration()) + { + builder.setInsertBefore(dispatchFunc); + StringBuilder sb; + sb << "s_dispatch_" << nameHint->getName() << ""; + builder.addNameHintDecoration(dispatchFunc, sb.getUnownedSlice()); + } + + traverseUses( + dispatcher, + [&](IRUse* use) + { + if (auto callInst = as(use->getUser())) + { + // Replace callee with the generated dispatchFunc. + if (callInst->getCallee() == dispatcher) + { + IRBuilder callBuilder(callInst); + callBuilder.setInsertBefore(callInst); + callBuilder.replaceOperand(callInst->getCalleeUse(), dispatchFunc); + } + } + }); + } + } + + void lowerGetSpecializedDispatcher(IRGetSpecializedDispatcher* dispatcher) + { + // Replace the `IRGetSpecializedDispatcher` with a dispatch function, + // which takes an extra first parameter for the tag (i.e. ID) + // + // We'll also replace the callee in all 'call' insts. + // + // The logic here is very similar to `lowerGetDispatcher`, except that we need to + // account for the specialization arguments when creating the dispatch function. + // We construct an `IRSpecialize` inst around each generic function before dispatching + // to it. + // + + auto witnessTableSet = cast(dispatcher->getOperand(0)); + auto key = cast(dispatcher->getOperand(1)); + + List specArgs; + for (UIndex i = 2; i < dispatcher->getOperandCount(); i++) + { + specArgs.add(dispatcher->getOperand(i)); + } + + Dictionary elements; + IRBuilder builder(dispatcher->getModule()); + forEachInSet( + witnessTableSet, + [&](IRInst* table) + { + auto generic = + cast(findWitnessTableEntry(cast(table), key)); + + auto specializedFuncType = + (IRType*)specializeGeneric(cast(builder.emitSpecializeInst( + builder.getTypeKind(), + generic->getDataType(), + specArgs.getCount(), + specArgs.getBuffer()))); + + auto specializedFunc = builder.emitSpecializeInst( + specializedFuncType, + generic, + specArgs.getCount(), + specArgs.getBuffer()); + + auto singletonTag = builder.emitGetTagOfElementInSet( + builder.getSetTagType(witnessTableSet), + table, + witnessTableSet); + + elements.add(singletonTag, specializedFunc); + }); + + if (dispatcher->hasUses() && dispatcher->getDataType() != nullptr) + { + auto dispatchFunc = + createDispatchFunc(cast(dispatcher->getDataType()), elements); + + if (auto keyNameHint = key->findDecoration()) + { + builder.setInsertBefore(dispatchFunc); + StringBuilder sb; + sb << "s_dispatch_" << keyNameHint->getName() << ""; + for (auto specArg : specArgs) + { + sb << "_"; + getTypeNameHint(sb, specArg); + } + builder.addNameHintDecoration(dispatchFunc, sb.getUnownedSlice()); + } + + traverseUses( + dispatcher, + [&](IRUse* use) + { + if (auto callInst = as(use->getUser())) + { + // Replace callee with the generated dispatchFunc. + if (callInst->getCallee() == dispatcher) + { + IRBuilder callBuilder(callInst); + callBuilder.setInsertBefore(callInst); + callBuilder.replaceOperand(callInst->getCalleeUse(), dispatchFunc); + } + } + }); + } + } + + void processModule() + { + processInstsOfType( + kIROp_GetDispatcher, + [&](IRGetDispatcher* inst) { return lowerGetDispatcher(inst); }); + + processInstsOfType( + kIROp_GetSpecializedDispatcher, + [&](IRGetSpecializedDispatcher* inst) { return lowerGetSpecializedDispatcher(inst); }); + } +}; + +bool lowerDispatchers(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + DispatcherLoweringContext context(module); + context.processModule(); + return true; +} + +// This context lowers `TypeSet` instructions. +struct UntaggedUnionLoweringContext : public InstPassBase +{ + UntaggedUnionLoweringContext( + IRModule* module, + TargetProgram* targetProgram, + DiagnosticSink* sink = nullptr) + : InstPassBase(module), targetProgram(targetProgram), sink(sink) + { + } + + SlangInt tryCalculateAnyValueSize(const HashSet& types) + { + SlangInt maxSize = 0; + for (auto type : types) + { + auto size = getAnyValueSize(type); + if (size > maxSize) + maxSize = size; + + // We need to consider whether the concrete type can be stored in an `AnyValue`. + // + // For example, resource types that are not bindless on the target cannot be marshalled + // and we need to diagnose that. + // + if (sink && !canTypeBeStored(type)) + { + sink->diagnose( + type->sourceLoc, + Slang::Diagnostics::typeCannotBePackedIntoAnyValue, + type); + } + } + + // Defaults to 0 if any type could not be sized. + return maxSize; + } + + IRAnyValueType* createAnyValueType(IRBuilder* builder, const HashSet& types) + { + auto size = tryCalculateAnyValueSize(types); + return builder->getAnyValueType(size); + } + + bool canTypeBeStored(IRType* concreteType) + { + if (!areResourceTypesBindlessOnTarget(targetProgram->getTargetReq())) + { + IRType* opaqueType = nullptr; + if (isOpaqueType(concreteType, &opaqueType)) + { + return false; + } + } + + IRSizeAndAlignment sizeAndAlignment; + Result result = getNaturalSizeAndAlignment( + targetProgram->getOptionSet(), + concreteType, + &sizeAndAlignment); + + if (SLANG_FAILED(result)) + return false; + + return true; + } + + void lowerUntaggedUnionType(IRUntaggedUnionType* untaggedUnionType) + { + // Type collections are replaced with `AnyValueType` large enough to hold + // any of the types in the collection. + // + + HashSet types; + for (UInt i = 0; i < untaggedUnionType->getSet()->getCount(); i++) + { + if (auto type = as(untaggedUnionType->getSet()->getElement(i))) + { + types.add(type); + } + else if (as(untaggedUnionType->getSet()->getElement(i))) + { + // Can safely skip. (effectively 0 size) + } + else + { + // Should either be a type or NoneTypeElement (if other cases are okay, need to + // handle them here) + // + SLANG_UNEXPECTED("Expected type element in UntaggedUnionType set"); + } + } + + IRBuilder builder(module); + auto anyValueType = createAnyValueType(&builder, types); + untaggedUnionType->replaceUsesWith(anyValueType); + } + + // Replace any uses of `NoneTypeElement` with `VoidType`. + void replaceNoneTypeElementWithVoidType() + { + IRBuilder builder(module); + auto noneTypeElement = builder.getNoneTypeElement(); + noneTypeElement->replaceUsesWith(builder.getVoidType()); + noneTypeElement->removeAndDeallocate(); + } + + void processModule() + { + processInstsOfType( + kIROp_UntaggedUnionType, + [&](IRUntaggedUnionType* inst) { return lowerUntaggedUnionType(inst); }); + + replaceNoneTypeElementWithVoidType(); + } + +private: + DiagnosticSink* sink; + TargetProgram* targetProgram; +}; + +// Lower `UntaggedUnionType(TypeSet(...))` instructions by replacing them with +// appropriate `AnyValueType` instructions. +// +void lowerUntaggedUnionTypes(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + UntaggedUnionLoweringContext context(module, targetProgram, sink); + context.processModule(); +} + +// This context lowers `IRGetTagFromSequentialID` and `IRGetSequentialIDFromTag` instructions. +// Note: This pass requires that sequential ID decorations have been created for all witness +// tables. +// +struct SequentialIDTagLoweringContext : public InstPassBase +{ + SequentialIDTagLoweringContext(Linkage* linkage, IRModule* module) + : InstPassBase(module), m_linkage(linkage) + { + } + void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) + { + // We use the result type to figure out the destination collection + // for which we need to generate the tag. + // + // We then replace this with call into an integer mapping function, + // which takes the sequential ID and returns the local ID (i.e. tag). + // + // To construct, the mapping, we lookup the sequential ID decorator on + // each element of the destination collection, and map it to the table's + // operand index in the collection. + // + + // We use the result type and the type of the operand + auto srcSeqID = inst->getOperand(1); + + Dictionary mapping; + + // Map from sequential ID to unique ID + auto destSet = cast(inst->getDataType())->getSet(); + + IRBuilder builder(inst); + builder.setInsertAfter(inst); + + forEachInSet( + destSet, + [&](IRInst* table) + { + // Get unique ID for the witness table + auto outputId = builder.getUniqueID(table); + auto seqDecoration = table->findDecoration(); + if (seqDecoration) + { + auto inputId = seqDecoration->getSequentialID(); + mapping[inputId] = outputId; // Map ID to itself for now + } + }); + + // By default, use the tag for the largest available sequential ID. + UInt defaultSeqID = 0; + for (auto [inputId, outputId] : mapping) + { + if (inputId > defaultSeqID) + defaultSeqID = inputId; + } + + auto translatedID = builder.emitCallInst( + inst->getDataType(), + createIntegerMappingFunc(builder.getModule(), mapping, mapping[defaultSeqID]), + List({srcSeqID})); + + inst->replaceUsesWith(translatedID); + inst->removeAndDeallocate(); + } + + + void lowerGetSequentialIDFromTag(IRGetSequentialIDFromTag* inst) + { + // Similar logic to the `GetTagFromSequentialID` case, except that + // we reverse the mapping. + // + + SLANG_UNUSED(cast(inst->getOperand(0))); + auto srcTagInst = inst->getOperand(1); + + Dictionary mapping; + + // Map from sequential ID to unique ID + auto destSet = cast(srcTagInst->getDataType())->getSet(); + + IRBuilder builder(inst); + builder.setInsertAfter(inst); + + forEachInSet( + destSet, + [&](IRInst* table) + { + // Get unique ID for the witness table + SLANG_UNUSED(cast(table)); + auto inputId = builder.getUniqueID(table); + auto seqDecoration = table->findDecoration(); + if (seqDecoration) + { + auto outputId = seqDecoration->getSequentialID(); + mapping.add({inputId, outputId}); + } + }); + + auto translatedID = builder.emitCallInst( + inst->getDataType(), + createIntegerMappingFunc(builder.getModule(), mapping, 0), + List({srcTagInst})); + + inst->replaceUsesWith(translatedID); + inst->removeAndDeallocate(); + } + + + // Ensures every witness table object has been assigned a sequential ID. + // + // All witness tables will have a SequentialID decoration after this function is run. + // + // The sequential ID in the decoration will be the same as the one specified in the Linkage. + // + // Otherwise, a new ID will be generated and assigned to the witness table object, and + // the sequential ID map in the Linkage will be updated to include the new ID, so they + // can be looked up by the user via future Slang API calls. + // + void ensureWitnessTableSequentialIDs() + { + StringBuilder generatedMangledName; + + auto linkage = getLinkage(); + for (auto inst : module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_WitnessTable) + { + UnownedStringSlice witnessTableMangledName; + if (auto instLinkage = inst->findDecoration()) + { + witnessTableMangledName = instLinkage->getMangledName(); + } + else + { + auto witnessTableType = as(inst->getDataType()); + + if (witnessTableType && witnessTableType->getConformanceType() == nullptr) + { + // Ignore witness tables that represent 'none' for optional witness table + // types. + continue; + } + + if (witnessTableType && witnessTableType->getConformanceType() + ->findDecoration()) + { + // The interface is for specialization only, it would be an error if dynamic + // dispatch is used through the interface. Skip assigning ID for the witness + // table. + continue; + } + + // generate a unique linkage for it. + static int32_t uniqueId = 0; + uniqueId++; + if (auto nameHint = inst->findDecoration()) + { + generatedMangledName << nameHint->getName(); + } + generatedMangledName << "_generated_witness_uuid_" << uniqueId; + witnessTableMangledName = generatedMangledName.getUnownedSlice(); + } + + // If the inst already has a SequentialIDDecoration, stop now. + if (inst->findDecoration()) + continue; + + // Get a sequential ID for the witness table using the map from the Linkage. + uint32_t seqID = 0; + if (!linkage->mapMangledNameToRTTIObjectIndex.tryGetValue( + witnessTableMangledName, + seqID)) + { + auto interfaceType = + cast(inst->getDataType())->getConformanceType(); + if (as(interfaceType)) + { + auto interfaceLinkage = + interfaceType->findDecoration(); + SLANG_ASSERT( + interfaceLinkage && "An interface type does not have a linkage," + "but a witness table associated with it has one."); + auto interfaceName = interfaceLinkage->getMangledName(); + auto idAllocator = + linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( + interfaceName); + if (!idAllocator) + { + linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = + 0; + idAllocator = + linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( + interfaceName); + } + seqID = *idAllocator; + ++(*idAllocator); + } + else + { + // NoneWitness, has special ID of -1. + seqID = uint32_t(-1); + } + linkage->mapMangledNameToRTTIObjectIndex[witnessTableMangledName] = seqID; + } + + // Add a decoration to the inst. + IRBuilder builder(module); + builder.setInsertBefore(inst); + builder.addSequentialIDDecoration(inst, seqID); + } + } + } + + void processModule() + { + ensureWitnessTableSequentialIDs(); + + processInstsOfType( + kIROp_GetTagFromSequentialID, + [&](IRGetTagFromSequentialID* inst) { return lowerGetTagFromSequentialID(inst); }); + + processInstsOfType( + kIROp_GetSequentialIDFromTag, + [&](IRGetSequentialIDFromTag* inst) { return lowerGetSequentialIDFromTag(inst); }); + } + + Linkage* getLinkage() { return m_linkage; } + +private: + Linkage* m_linkage; +}; + +void lowerSequentialIDTagCasts(IRModule* module, Linkage* linkage, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + SequentialIDTagLoweringContext context(linkage, module); + context.processModule(); +} + +void lowerTagInsts(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + TagOpsLoweringContext tagContext(module); + tagContext.processModule(); +} + +// This context lowers `IRSetTagType` instructions, by replacing +// them with a suitable integer type. +struct TagTypeLoweringContext : public InstPassBase +{ + TagTypeLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void processModule() + { + processInstsOfType( + kIROp_SetTagType, + [&](IRSetTagType* inst) + { + IRBuilder builder(inst->getModule()); + inst->replaceUsesWith(builder.getUIntType()); + }); + } +}; + +void lowerTagTypes(IRModule* module) +{ + TagTypeLoweringContext context(module); + context.processModule(); +} + +bool isEffectivelyComPtrType(IRType* type) +{ + if (!type) + return false; + if (type->findDecoration() || type->getOp() == kIROp_ComPtrType) + { + return true; + } + if (auto witnessTableType = as(type)) + { + return isComInterfaceType((IRType*)witnessTableType->getConformanceType()); + } + if (auto ptrType = as(type)) + { + auto valueType = ptrType->getValueType(); + return valueType->findDecoration() != nullptr; + } + + return false; +} + +// This context lowers `CastInterfaceToTaggedUnionPtr` by finding all `IRLoad` and +// `IRStore` uses of these insts, and upcasting the tagged-union +// tuple to the the interface-based tuple (of the loaded inst or before +// storing the val, as necessary) +// +struct TaggedUnionLoweringContext : public InstPassBase +{ + TaggedUnionLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + // Extract the required components from an interface-typed value + // and create a tagged union tuple from the result. + // + IRInst* convertToTaggedUnion( + IRBuilder* builder, + IRInst* val, + IRInst* interfaceType, + IRInst* targetType) + { + auto baseInterfaceValue = val; + auto witnessTable = builder->emitExtractExistentialWitnessTable(baseInterfaceValue); + auto tableID = builder->emitGetSequentialIDInst(witnessTable); + + auto taggedUnionTupleType = cast(targetType); + + List getTagOperands; + getTagOperands.add(interfaceType); + getTagOperands.add(tableID); + auto tableTag = builder->emitIntrinsicInst( + (IRType*)taggedUnionTupleType->getOperand(0), + kIROp_GetTagFromSequentialID, + getTagOperands.getCount(), + getTagOperands.getBuffer()); + + return builder->emitMakeTuple( + {tableTag, + builder->emitReinterpret( + (IRType*)taggedUnionTupleType->getOperand(1), + builder->emitExtractExistentialValue( + (IRType*)builder->emitExtractExistentialType(baseInterfaceValue), + baseInterfaceValue))}); + } + + void lowerCastInterfaceToTaggedUnionPtr(IRCastInterfaceToTaggedUnionPtr* inst) + { + // `CastInterfaceToTaggedUnionPtr` is used to 'reinterpret' a pointer to an interface-typed + // location into a tagged union type. Usually this is to avoid changing the type of the + // base location because it is externally visible, and to avoid touching the external layout + // of the interface type. + // + // To lower this, we won't actually change the pointer or the base location, but instead + // rewrite all loads and stores out of this pointer by converting the existential into a + // tagged union tuple. + // + // e.g. + // + // let basePtr : PtrType(InterfaceType(I)) = /* ... */; + // let tuPtr : PtrType(TaggedUnionType(types, tables)) = + // CastInterfaceToTaggedUnionPtr(basePtr); + // let loadedVal : TaggedUnionType(...) = Load(tuPtr); + // + // becomes + // + // let basePtr : PtrType(InterfaceType(I)) = /* ... */; + // let intermediateVal : InterfaceType(I) = Load(basePtr); + // let loadedTableID : TagType(tables) = + // GetTagFromSequentialID( + // InterfaceType(I), + // GetSequentialID( + // ExtractExistentialWitnessTable(intermediateVal))); + // let loadedVal : types = ExtractExistentialValue(intermediateVal); + // let loadedTuple : TupleType(TagType(tables), types) = + // MakeTuple(loadedTableID, loadedVal); + // + // The logic is similar for StructuredBufferLoad and RWStructuredBufferLoad, + // but the operands structure is slightly different. + // + + traverseUses( + inst, + [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_Load: + { + auto baseInterfacePtr = inst->getOperand(0); + auto baseInterfaceType = as( + as(baseInterfacePtr->getDataType())->getValueType()); + + // Rewrite the load to use the original ptr and load + // an interface-typed object. + // + IRBuilder builder(module); + builder.setInsertAfter(user); + builder.replaceOperand(user->getOperands() + 0, baseInterfacePtr); + builder.replaceOperand(&user->typeUse, baseInterfaceType); + + // Then, we'll rewrite it. + List oldUses; + traverseUses(user, [&](IRUse* oldUse) { oldUses.add(oldUse); }); + + auto newVal = convertToTaggedUnion( + &builder, + user, + baseInterfaceType, + as(inst->getDataType())->getValueType()); + for (auto oldUse : oldUses) + { + builder.replaceOperand(oldUse, newVal); + } + break; + } + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + auto baseInterfacePtr = inst->getOperand(0); + auto baseInterfaceType = + as((baseInterfacePtr->getDataType())->getOperand(0)); + + IRBuilder builder(module); + builder.setInsertAfter(user); + builder.replaceOperand(user->getOperands() + 0, baseInterfacePtr); + builder.replaceOperand(&user->typeUse, baseInterfaceType); + + // Then, we'll rewrite it. + List oldUses; + traverseUses(user, [&](IRUse* oldUse) { oldUses.add(oldUse); }); + + auto newVal = convertToTaggedUnion( + &builder, + user, + baseInterfaceType, + as(inst->getDataType())->getValueType()); + for (auto oldUse : oldUses) + { + builder.replaceOperand(oldUse, newVal); + } + break; + } + default: + SLANG_UNEXPECTED("Unexpected user of CastInterfaceToTaggedUnionPtr"); + } + }); + + SLANG_ASSERT(!inst->hasUses()); + inst->removeAndDeallocate(); + } + + IRType* lowerTaggedUnionType(IRTaggedUnionType* taggedUnion) + { + // Replace `TaggedUnionType(typeSet, tableSet)` with + // `TupleType(SetTagType(tableSet), typeSet)` + // + // Unless the set has a single element, in which case we + // replace it with `TupleType(SetTagType(tableSet), elementType)` + // + // We still maintain a tuple type (even though it's not really necesssary) to avoid + // breaking any operations that assumed this is a tuple. + // In the single element case, the tuple should be optimized away. + // + + IRBuilder builder(module); + builder.setInsertInto(module); + + auto typeSet = builder.getUntaggedUnionType(taggedUnion->getTypeSet()); + auto tableSet = taggedUnion->getWitnessTableSet(); + + if (taggedUnion->getTypeSet()->isSingleton()) + return builder.getTupleType(List( + {(IRType*)builder.getSetTagType(tableSet), + (IRType*)taggedUnion->getTypeSet()->getElement(0)})); + + return builder.getTupleType( + List({(IRType*)builder.getSetTagType(tableSet), (IRType*)typeSet})); + } + + bool lowerGetValueFromTaggedUnion(IRGetValueFromTaggedUnion* inst) + { + // We replace `GetValueFromTaggedUnion(taggedUnionVal)` with + // `GetTupleElement(taggedUnionVal, 1)` + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + auto tupleVal = inst->getOperand(0); + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)as(tupleVal->getDataType())->getOperand(1), + tupleVal, + 1)); + inst->removeAndDeallocate(); + return true; + } + + bool lowerGetTagFromTaggedUnion(IRGetTagFromTaggedUnion* inst) + { + // We replace `GetTagFromTaggedUnion(taggedUnionVal)` with + // `GetTupleElement(taggedUnionVal, 0)` + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + auto tupleVal = inst->getOperand(0); + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)as(tupleVal->getDataType())->getOperand(0), + tupleVal, + 0)); + inst->removeAndDeallocate(); + return true; + } + + bool lowerGetTypeTagFromTaggedUnion(IRGetTypeTagFromTaggedUnion* inst) + { + // `GetTypeTagFromTaggedUnion(taggedUnionVal)` is not expected to + // appear after lowering, since we currently don't need the type tag + // for anything. + // + // We'll replace it with a poison value so that any accidental uses will result in + // an error later on. + // + IRBuilder builder(module); + builder.setInsertAfter(inst); + inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); + return true; + } + + + bool lowerMakeTaggedUnion(IRMakeTaggedUnion* inst) + { + // We replace `MakeTaggedUnion(typeTag, witnessTableTag, val)` with `MakeTuple(tag, val)` + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + auto tuTupleType = cast(inst->getDataType()); + + // The current lowering logic is only for bounded tagged unions (finite sets) + SLANG_ASSERT(!as(tuTupleType->getOperand(0))->getSet()->isUnbounded()); + + auto typeTag = inst->getOperand(0); + // We'll ignore the type tag, since the table is the only thing we need. + // for the bounded case. + SLANG_UNUSED(typeTag); + + auto witnessTableTag = inst->getOperand(1); + auto val = inst->getOperand(2); + inst->replaceUsesWith( + builder.emitMakeTuple((IRType*)inst->getDataType(), {witnessTableTag, val})); + inst->removeAndDeallocate(); + return true; + } + + bool processModule() + { + // First, we'll lower all TaggedUnionType insts + // into tuples. + // + processInstsOfType( + kIROp_TaggedUnionType, + [&](IRTaggedUnionType* inst) + { + inst->replaceUsesWith(lowerTaggedUnionType(inst)); + inst->removeAndDeallocate(); + }); + + bool hasCastInsts = false; + processAllInsts( + [&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_GetTagFromTaggedUnion: + lowerGetTagFromTaggedUnion(as(inst)); + break; + case kIROp_GetTypeTagFromTaggedUnion: + lowerGetTypeTagFromTaggedUnion(as(inst)); + break; + case kIROp_GetValueFromTaggedUnion: + lowerGetValueFromTaggedUnion(as(inst)); + break; + case kIROp_MakeTaggedUnion: + lowerMakeTaggedUnion(as(inst)); + break; + case kIROp_CastInterfaceToTaggedUnionPtr: + { + hasCastInsts = true; + lowerCastInterfaceToTaggedUnionPtr( + as(inst)); + } + break; + default: + break; + } + }); + return hasCastInsts; + } +}; + +bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + + TaggedUnionLoweringContext context(module); + return context.processModule(); +} + +// Convert `IsType` insts into an boolean equality check on the sequential IDs of the +// witness table operands. +// +void lowerIsTypeInsts(IRModule* module) +{ + InstPassBase pass(module); + pass.processInstsOfType( + kIROp_IsType, + [&](IRIsType* inst) + { + auto witnessTableType = + as(inst->getValueWitness()->getDataType()); + if (witnessTableType && + isComInterfaceType((IRType*)witnessTableType->getConformanceType())) + return; + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto eqlInst = builder.emitEql( + builder.emitGetSequentialIDInst(inst->getValueWitness()), + builder.emitGetSequentialIDInst(inst->getTargetWitness())); + inst->replaceUsesWith(eqlInst); + inst->removeAndDeallocate(); + }); +} + +struct ExistentialLoweringContext : public InstPassBase +{ + TargetProgram* targetProgram; + + ExistentialLoweringContext(IRModule* module, TargetProgram* targetProgram) + : InstPassBase(module), targetProgram(targetProgram) + { + } + + bool _canReplace(IRUse* use) + { + switch (use->getUser()->getOp()) + { + case kIROp_WitnessTableIDType: + case kIROp_WitnessTableType: + case kIROp_RTTIPointerType: + case kIROp_RTTIHandleType: + case kIROp_ComPtrType: + case kIROp_NativePtrType: + { + // Don't replace + return false; + } + case kIROp_ThisType: + { + // Appears replacable. + break; + } + case kIROp_PtrType: + { + // We can have ** and ComPtr*. + // If it's a pointer type it could be because it is a global. + break; + } + default: + break; + } + return true; + } + + // Replace all WitnessTableID type or RTTIHandleType with `uint2`. + void lowerHandleTypes() + { + List instsToRemove; + for (auto inst : module->getGlobalInsts()) + { + switch (inst->getOp()) + { + case kIROp_WitnessTableIDType: + if (isComInterfaceType((IRType*)inst->getOperand(0))) + continue; + // fall through + case kIROp_RTTIHandleType: + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto uint2Type = builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 2)); + inst->replaceUsesWith(uint2Type); + instsToRemove.add(inst); + } + break; + } + } + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); + } + + IRInst* lowerInterfaceType(IRInst* interfaceType) + { + if (isComInterfaceType((IRType*)interfaceType)) + return (IRType*)interfaceType; + + IRBuilder builder(module); + if (isBuiltin(interfaceType)) + return (IRType*)builder.getIntValue(builder.getIntType(), 0); + + // In the dynamic-dispatch case, a value of interface type + // is going to be packed into the "any value" part of a tuple. + // The size of the "any value" part depends on the interface + // type (e.g., it might have an `[anyValueSize(8)]` attribute + // indicating that 8 bytes needs to be reserved). + // + + IRIntegerValue anyValueSize = 0; + if (auto decor = interfaceType->findDecoration()) + { + anyValueSize = decor->getSize(); + } + + auto anyValueType = builder.getAnyValueType(anyValueSize); + auto witnessTableType = builder.getWitnessTableIDType((IRType*)interfaceType); + auto rttiType = builder.getRTTIHandleType(); + + // In the ordinary (dynamic) case, an existential type decomposes + // into a tuple of: + // + // (RTTI, witness table, any-value). + // + return builder.getTupleType(rttiType, witnessTableType, anyValueType); + } + + IRInst* lowerBoundInterfaceType(IRBoundInterfaceType* boundInterfaceType) + { + // A bound interface type represents an existential together with + // static knowledge that the value stored in the extistential has + // a particular concrete type. + // + + IRBuilder builder(module); + + auto payloadType = boundInterfaceType->getConcreteType(); + auto witnessTableType = builder.getWitnessTableIDType( + (IRType*)as(boundInterfaceType->getWitnessTable()) + ->getConformanceType()); + auto rttiType = builder.getRTTIHandleType(); + auto interfaceType = boundInterfaceType->getInterfaceType(); + + IRIntegerValue anyValueSize = 16; + if (auto decor = interfaceType->findDecoration()) + { + anyValueSize = decor->getSize(); + } + + auto anyValueType = builder.getAnyValueType(anyValueSize); + + // Because static specialization is being used (at least in part), + // we do *not* have a guarantee that the `concreteType` is one + // that can fit into the `anyValueSize` of the interface. + // + // We will use the IR layout logic to see if we can compute + // a size for the type, which can lead to a few different outcomes: + // + // * If a size is computed successfully, and it is smaller than or + // equal to `anyValueSize`, then the concrete value will fit into + // the reserved area, and the layout will match the dynamic case. + // + // * If a size is computed successfully, and it is larger than + // `anyValueSize`, then the concrete value cannot fit into the + // reserved area, and it needs to be stored out-of-line. + // + // * If size cannot be computed, then that implies that the type + // includes non-ordinary data (e.g., a `Texture2D` on a D3D11 + // target), and cannot possible fit into the reserved area + // (which consists of only uniform bytes). In this case, the + // value must be stored out-of-line. + // + IRSizeAndAlignment sizeAndAlignment; + Result result = getNaturalSizeAndAlignment( + targetProgram->getOptionSet(), + payloadType, + &sizeAndAlignment); + if (SLANG_FAILED(result) || sizeAndAlignment.size > anyValueSize) + { + // In the case where static specialization mandateds out-of-line storage, + // an existential type decomposes into a tuple of: + // + // (RTTI, witness table, pseudo pointer, any-value) + // + return builder.getTupleType( + rttiType, + witnessTableType, + builder.getPseudoPtrType(payloadType), + anyValueType); + } + else + { + // Regular case (lower in the same way as unbound interface types) + return builder.getTupleType(rttiType, witnessTableType, anyValueType); + } + } + + bool lowerExtractExistentialType(IRExtractExistentialType* inst) + { + // Replace with extraction of the type as a value from the existential tuple. + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + if (auto tupleType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)tupleType->getOperand(0), + inst->getOperand(0), + 0)); + inst->removeAndDeallocate(); + } + else if (isEffectivelyComPtrType((IRType*)inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + } + return true; + } + + bool lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) + { + // Replace with extraction of the witness table identifier from the existential tuple. + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + if (auto tupleType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)tupleType->getOperand(1), + inst->getOperand(0), + 1)); + inst->removeAndDeallocate(); + return true; + } + else if (isEffectivelyComPtrType((IRType*)inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + return true; + } + else + { + SLANG_UNEXPECTED("Unexpected type for ExtractExistentialWitnessTable operand"); + } + } + + bool lowerGetValueFromBoundInterface(IRGetValueFromBoundInterface* inst) + { + // Replace with extraction of the value from the tagged-union tuple. + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + auto tupleType = as(inst->getOperand(0)->getDataType()); + auto element = + builder.emitGetTupleElement((IRType*)tupleType->getOperand(2), inst->getOperand(0), 2); + if (as(tupleType->getOperand(2))) + { + // The first case is when legacy static specialization + // is applied, and the element is a "pseudo-pointer." + // + // Semantically, we should emit a (pseudo-)load from the pseudo-pointer + // to go from `PseudoPtr` to `T`. + // + // TODO: Actually introduce and emit a "psedudo-load" instruction + // here. For right now we are just using the value directly and + // downstream passes seem okay with it, but it isn't really + // type-correct to be doing this. + // + inst->replaceUsesWith(element); + inst->removeAndDeallocate(); + return true; + } + else + { + // The second case is when the dynamic-dispatch layout is + // being used, and the element is an "any-value." + // + // In this case we need to emit an unpacking operation + // to get from `AnyValue` to `T`. + // + inst->replaceUsesWith(builder.emitUnpackAnyValue(inst->getDataType(), element)); + inst->removeAndDeallocate(); + return true; + } + } + + bool lowerExtractExistentialValue(IRExtractExistentialValue* inst) + { + // Replace with extraction of the value payload from the existential tuple. + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + if (auto tupleType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)tupleType->getOperand(2), + inst->getOperand(0), + 2)); + inst->removeAndDeallocate(); + return true; + } + else if (isEffectivelyComPtrType((IRType*)inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + return true; + } + else + { + SLANG_UNEXPECTED("Unexpected type for ExtractExistentialValue operand"); + } + } + + bool processGetSequentialIDInst(IRGetSequentialID* inst) + { + // If the operand is a witness table, it is already replaced with a uint2 + // at this point, where the first element in the uint2 is the id of the + // witness table. + // + + IRBuilder builder(module); + builder.setInsertBefore(inst); + + if (auto table = as(inst->getRTTIOperand())) + { + auto seqDecoration = table->findDecoration(); + SLANG_ASSERT(seqDecoration && "Witness table missing SequentialID decoration"); + auto id = builder.getIntValue(builder.getUIntType(), seqDecoration->getSequentialID()); + inst->replaceUsesWith(id); + inst->removeAndDeallocate(); + return true; + } + + UInt index = 0; + auto id = builder.emitSwizzle(builder.getUIntType(), inst->getRTTIOperand(), 1, &index); + inst->replaceUsesWith(id); + inst->removeAndDeallocate(); + return true; + } + + void processModule() + { + // Then, start lowering the remaining non-COM/non-Builtin interface types + // At this point, we should only be dealing with public facing uses of + // interface types, as all internal uses would have been rewritten + // into known tagged-union types. + // + // These must lower into either a + // TupleType(RTTI, witness table ID, AnyValue) for regular interface types or a + // TupleType(RTTI, witness table ID, PseudoPtr, AnyValue) for bound interface types. + // + processInstsOfType( + kIROp_InterfaceType, + [&](IRInterfaceType* inst) + { + IRBuilder builder(module); + builder.setInsertInto(module); + if (auto loweredInterfaceType = lowerInterfaceType(inst)) + { + if (loweredInterfaceType != inst) + { + + traverseUses( + inst, + [&](IRUse* use) + { + if (_canReplace(use)) + builder.replaceOperand(use, loweredInterfaceType); + }); + } + } + }); + + processInstsOfType( + kIROp_BoundInterfaceType, + [&](IRBoundInterfaceType* inst) + { + IRBuilder builder(module); + builder.setInsertInto(module); + if (auto loweredBoundInterfaceType = lowerBoundInterfaceType(inst)) + { + if (loweredBoundInterfaceType != inst) + { + traverseUses( + inst, + [&](IRUse* use) + { + if (_canReplace(use)) + builder.replaceOperand(use, loweredBoundInterfaceType); + }); + } + } + }); + + // Replace any other uses with dummy value 0. + // TODO: Ideally, we should replace it with OpPoison? + { + IRBuilder builder(module); + builder.setInsertInto(module); + auto dummyInterfaceObj = builder.getIntValue(builder.getIntType(), 0); + processInstsOfType( + kIROp_InterfaceType, + [&](IRInterfaceType* inst) + { + if (!isComInterfaceType((IRType*)inst)) + { + inst->replaceUsesWith(dummyInterfaceObj); + inst->removeAndDeallocate(); + } + }); + } + + processAllInsts( + [&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_ExtractExistentialType: + lowerExtractExistentialType(cast(inst)); + break; + case kIROp_ExtractExistentialValue: + lowerExtractExistentialValue(cast(inst)); + break; + case kIROp_ExtractExistentialWitnessTable: + lowerExtractExistentialWitnessTable( + cast(inst)); + break; + case kIROp_GetValueFromBoundInterface: + lowerGetValueFromBoundInterface(cast(inst)); + break; + } + }); + + lowerIsTypeInsts(module); + + processInstsOfType( + kIROp_GetSequentialID, + [&](IRGetSequentialID* inst) { return processGetSequentialIDInst(inst); }); + + lowerHandleTypes(); + } +}; + +bool lowerExistentials(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + ExistentialLoweringContext context(module, targetProgram); + context.processModule(); + return true; +}; + +}; // namespace Slang \ No newline at end of file diff --git a/source/slang/slang-ir-lower-dynamic-dispatch-insts.h b/source/slang/slang-ir-lower-dynamic-dispatch-insts.h new file mode 100644 index 00000000000..d4c4ff064ca --- /dev/null +++ b/source/slang/slang-ir-lower-dynamic-dispatch-insts.h @@ -0,0 +1,37 @@ +// slang-ir-lower-dynamic-dispatch-insts.h +#pragma once +#include "slang-ir.h" + +namespace Slang +{ +class Linkage; +class TargetProgram; + +// Lower `UntaggedUnionType` types. +void lowerUntaggedUnionTypes(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink); + +// Lower `SetTaggedUnion` and `CastInterfaceToTaggedUnionPtr` instructions +// May create new `Reinterpret` instructions. +// +bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink); + +// Lower `SetTagType` types +void lowerTagTypes(IRModule* module); + +// Lower `GetTagOfElementInSet`, `GetTagForSuperSet`, `GetTagForSubSet` and `GetTagForMappedSet` +// instructions, +// +void lowerTagInsts(IRModule* module, DiagnosticSink* sink); + +// Lower `GetTagFromSequentialID` and `GetSequentialIDFromTag` instructions +void lowerSequentialIDTagCasts(IRModule* module, Linkage* linkage, DiagnosticSink* sink); + +// Lower `GetDispatcher` and `GetSpecializedDispatcher` instructions +bool lowerDispatchers(IRModule* module, DiagnosticSink* sink); + +// Lower `ExtractExistentialValue`, `ExtractExistentialType`, `ExtractExistentialWitnessTable`, +// `InterfaceType`, `GetSequentialID`, `WitnessTableIDType` and `RTTIHandleType` instructions. +// +bool lowerExistentials(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-ir-lower-enum-type.cpp b/source/slang/slang-ir-lower-enum-type.cpp index 548c29a5182..b22539192df 100644 --- a/source/slang/slang-ir-lower-enum-type.cpp +++ b/source/slang/slang-ir-lower-enum-type.cpp @@ -49,6 +49,23 @@ struct EnumTypeLoweringContext if (!type) return nullptr; + if (auto attributedType = as(type)) + { + IRBuilder builder(module); + + List attrs; + for (auto attr : attributedType->getAllAttrs()) + attrs.add(attr); + + RefPtr info = new LoweredEnumTypeInfo(); + info->enumType = (IRType*)type; + info->loweredType = builder.getAttributedType( + getLoweredEnumType(type->getOperand(0))->loweredType, + attrs); + loweredEnumTypes[type] = info; + return info.Ptr(); + } + if (type->getOp() != kIROp_EnumType) return nullptr; diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp deleted file mode 100644 index ea7f906a9a6..00000000000 --- a/source/slang/slang-ir-lower-existential.cpp +++ /dev/null @@ -1,313 +0,0 @@ -// slang-ir-lower-generic-existential.cpp - -#include "slang-ir-lower-existential.h" - -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ -bool isCPUTarget(TargetRequest* targetReq); -bool isCUDATarget(TargetRequest* targetReq); - -struct ExistentialLoweringContext -{ - SharedGenericsLoweringContext* sharedContext; - - void processMakeExistential(IRMakeExistentialWithRTTI* inst) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(inst); - auto value = inst->getWrappedValue(); - auto valueType = sharedContext->lowerType(builder, value->getDataType()); - if (valueType->getOp() == kIROp_ComPtrType) - return; - auto witnessTableType = - cast(inst->getWitnessTable()->getDataType()); - auto interfaceType = witnessTableType->getConformanceType(); - if (interfaceType->findDecoration()) - return; - auto witnessTableIdType = builder->getWitnessTableIDType((IRType*)interfaceType); - auto anyValueSize = sharedContext->getInterfaceAnyValueSize(interfaceType, inst->sourceLoc); - auto anyValueType = builder->getAnyValueType(anyValueSize); - auto rttiType = builder->getRTTIHandleType(); - auto tupleType = builder->getTupleType(rttiType, witnessTableIdType, anyValueType); - - IRInst* rttiObject = inst->getRTTI(); - if (auto type = as(rttiObject)) - { - rttiObject = sharedContext->maybeEmitRTTIObject(type); - rttiObject = builder->emitGetAddress(rttiType, rttiObject); - } - IRInst* packedValue = value; - if (valueType->getOp() != kIROp_AnyValueType) - packedValue = builder->emitPackAnyValue(anyValueType, value); - IRInst* tupleArgs[] = {rttiObject, inst->getWitnessTable(), packedValue}; - auto tuple = builder->emitMakeTuple(tupleType, 3, tupleArgs); - inst->replaceUsesWith(tuple); - inst->removeAndDeallocate(); - } - - // Translates `createExistentialObject` insts, which takes a user defined - // type id and user defined value and turns into an existential value, - // into a `makeTuple` inst that makes the tuple representing the lowered - // existential value. - void processCreateExistentialObject(IRCreateExistentialObject* inst) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(inst); - - // The result type of this `createExistentialObject` inst should already - // be lowered into a `TupleType(rttiType, WitnessTableIDType, AnyValueType)` - // in the previous `lowerGenericType` pass. - auto tupleType = inst->getDataType(); - auto witnessTableIdType = cast(tupleType->getOperand(1)); - auto anyValueType = cast(tupleType->getOperand(2)); - - // Create a standin value for `rttiObject` for now since it will not be used - // other than test for null in the case of `Optional`. - auto uint2Type = builder->getVectorType( - builder->getUIntType(), - builder->getIntValue(builder->getIntType(), 2)); - IRInst* standinVal = builder->getIntValue(builder->getUIntType(), 0xFFFFFFFF); - IRInst* zero = builder->getIntValue(builder->getUIntType(), 0); - IRInst* standinRTTIVectorArgs[] = {standinVal, zero}; - IRInst* rttiObject = builder->emitMakeVector(uint2Type, 2, standinRTTIVectorArgs); - - // Pack the user provided value into `AnyValue`. - IRInst* packedValue = inst->getValue(); - if (packedValue->getDataType()->getOp() != kIROp_AnyValueType) - packedValue = builder->emitPackAnyValue(anyValueType, packedValue); - - // Use the user provided `typeID` value as the witness table ID field in the - // newly constructed tuple. - // All `WitnessTableID` types are lowered into `uint2`s, so we need to create - // a `uint2` value from `typeID` to stay consistent with the convention. - IRInst* vectorArgs[2] = { - inst->getTypeID(), - builder->getIntValue(builder->getUIntType(), 0)}; - - IRInst* typeIdValue = builder->emitMakeVector(uint2Type, 2, vectorArgs); - typeIdValue = builder->emitBitCast(witnessTableIdType, typeIdValue); - IRInst* tupleArgs[] = {rttiObject, typeIdValue, packedValue}; - auto tuple = builder->emitMakeTuple(tupleType, 3, tupleArgs); - inst->replaceUsesWith(tuple); - inst->removeAndDeallocate(); - } - - IRInst* extractTupleElement(IRBuilder* builder, IRInst* value, UInt index) - { - auto tupleType = cast(sharedContext->lowerType(builder, value->getDataType())); - auto getElement = - builder->emitGetTupleElement((IRType*)tupleType->getOperand(index), value, index); - return getElement; - } - - void processExtractExistentialElement(IRInst* extractInst, UInt elementId) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(extractInst); - - IRInst* element = nullptr; - if (isComInterfaceType(extractInst->getOperand(0)->getDataType())) - { - // If this is an COM interface, the elements (witness table/rtti) are just the interface - // value itself. - element = extractInst->getOperand(0); - } - else - { - element = extractTupleElement(builder, extractInst->getOperand(0), elementId); - } - extractInst->replaceUsesWith(element); - extractInst->removeAndDeallocate(); - } - - void processExtractExistentialValue(IRExtractExistentialValue* inst) - { - processExtractExistentialElement(inst, 2); - } - - void processIsNullExistential(IRIsNullExistential* inst) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - - auto rttiElement = extractTupleElement(&builder, inst->getOperand(0), 0); - auto isNull = builder.emitNeq( - builder.emitGetElement(builder.getUIntType(), rttiElement, 0), - builder.getIntValue(builder.getUIntType(), 0)); - inst->replaceUsesWith(isNull); - inst->removeAndDeallocate(); - } - - void processExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) - { - processExtractExistentialElement(inst, 1); - } - - void processExtractExistentialType(IRExtractExistentialType* extractInst) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(extractInst); - - IRInst* element = nullptr; - IRInst* anyValueType = nullptr; - if (isComInterfaceType(extractInst->getOperand(0)->getDataType())) - { - // If this is an COM interface, the elements (witness table/rtti) are just the interface - // value itself. - element = extractInst->getOperand(0); - } - else - { - element = extractTupleElement(builder, extractInst->getOperand(0), 0); - if (auto tupleType = as(extractInst->getOperand(0)->getDataType())) - { - anyValueType = tupleType->getOperand(2); - } - } - - // If this instruction is used as a type, we need to replace it with the lowered type, - // which should be an AnyValueType. - // If it is used as a value, then we can replace it with the extracted element. - auto isTypeUse = [](IRUse* use) -> bool - { - auto user = use->getUser(); - if (as(user)) - return true; - if (use == &use->getUser()->typeUse) - return true; - return false; - }; - traverseUses( - extractInst, - [&](IRUse* use) - { - if (anyValueType && isTypeUse(use)) - { - builder->replaceOperand(use, anyValueType); - return; - } - builder->replaceOperand(use, element); - }); - } - - void processGetValueFromBoundInterface(IRGetValueFromBoundInterface* inst) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(inst); - if (inst->getDataType()->getOp() == kIROp_ClassType) - { - return; - } - // A value of interface will lower as a tuple, and - // the third element of that tuple represents the - // concrete value that was put into the existential. - // - auto element = extractTupleElement(builder, inst->getOperand(0), 2); - auto elementType = element->getDataType(); - - // There are two cases we expect to see for that - // tuple element. - // - IRInst* replacement = nullptr; - if (as(elementType)) - { - // The first case is when legacy static specialization - // is applied, and the element is a "pseudo-pointer." - // - // Semantically, we should emit a (pseudo-)load from the pseudo-pointer - // to go from `PseudoPtr` to `T`. - // - // TODO: Actually introduce and emit a "psedudo-load" instruction - // here. For right now we are just using the value directly and - // downstream passes seem okay with it, but it isn't really - // type-correct to be doing this. - // - replacement = element; - } - else - { - // The second case is when the dynamic-dispatch layout is - // being used, and the element is an "any-value." - // - // In this case we need to emit an unpacking operation - // to get from `AnyValue` to `T`. - // - SLANG_ASSERT(as(elementType)); - replacement = builder->emitUnpackAnyValue(inst->getFullType(), element); - } - - inst->replaceUsesWith(replacement); - inst->removeAndDeallocate(); - } - - void processInst(IRInst* inst) - { - if (auto makeExistential = as(inst)) - { - processMakeExistential(makeExistential); - } - else if (auto createExistentialObject = as(inst)) - { - processCreateExistentialObject(createExistentialObject); - } - else if (auto getValueFromBoundInterface = as(inst)) - { - processGetValueFromBoundInterface(getValueFromBoundInterface); - } - else if (auto extractExistentialVal = as(inst)) - { - processExtractExistentialValue(extractExistentialVal); - } - else if (auto extractExistentialType = as(inst)) - { - processExtractExistentialType(extractExistentialType); - } - else if (auto extractExistentialWitnessTable = as(inst)) - { - processExtractExistentialWitnessTable(extractExistentialWitnessTable); - } - else if (auto isNullExistential = as(inst)) - { - processIsNullExistential(isNullExistential); - } - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - } -}; - - -void lowerExistentials(SharedGenericsLoweringContext* sharedContext) -{ - ExistentialLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); -} -} // namespace Slang diff --git a/source/slang/slang-ir-lower-existential.h b/source/slang/slang-ir-lower-existential.h deleted file mode 100644 index 73ae61f75a7..00000000000 --- a/source/slang/slang-ir-lower-existential.h +++ /dev/null @@ -1,11 +0,0 @@ -// slang-ir-lower-existential.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Lower existential types and related instructions to Tuple types. -void lowerExistentials(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp deleted file mode 100644 index e4fb5ac9814..00000000000 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ /dev/null @@ -1,404 +0,0 @@ -// slang-ir-lower-generic-call.cpp -#include "slang-ir-lower-generic-call.h" - -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-util.h" - -namespace Slang -{ -struct GenericCallLoweringContext -{ - SharedGenericsLoweringContext* sharedContext; - - // Represents a work item for unpacking `inout` or `out` arguments after a generic call. - struct ArgumentUnpackWorkItem - { - // Concrete typed destination. - IRInst* dstArg = nullptr; - // Packed argument. - IRInst* packedArg = nullptr; - }; - - // Packs `arg` into a `IRAnyValue` if necessary, to make it feedable into the parameter. - // If `arg` represents a concrete typed variable passed in to a generic `out` parameter, - // this function indicates that it needs to be unpacked after the call by setting - // `unpackAfterCall`. - IRInst* maybePackArgument( - IRBuilder* builder, - IRType* paramType, - IRInst* arg, - ArgumentUnpackWorkItem& unpackAfterCall) - { - unpackAfterCall.dstArg = nullptr; - unpackAfterCall.packedArg = nullptr; - - // If either paramType or argType is a pointer type - // (because of `inout` or `out` modifiers), we extract - // the underlying value type first. - IRType* paramValType = paramType; - IRType* argValType = arg->getDataType(); - IRInst* argVal = arg; - if (auto ptrType = as(paramType)) - { - paramValType = ptrType->getValueType(); - } - auto argType = arg->getDataType(); - if (auto argPtrType = as(argType)) - { - argValType = argPtrType->getValueType(); - argVal = builder->emitLoad(arg); - } - - // Pack `arg` if the parameter expects AnyValue but - // `arg` is not an AnyValue. - if (as(paramValType) && !as(argValType)) - { - auto packedArgVal = builder->emitPackAnyValue(paramValType, argVal); - // if parameter expects an `out` pointer, store the packed val into a - // variable and pass in a pointer to that variable. - if (as(paramType)) - { - auto tempVar = builder->emitVar(paramValType); - builder->emitStore(tempVar, packedArgVal); - // tempVar needs to be unpacked into original var after the call. - unpackAfterCall.dstArg = arg; - unpackAfterCall.packedArg = tempVar; - return tempVar; - } - else - { - return packedArgVal; - } - } - return arg; - } - - IRInst* maybeUnpackValue( - IRBuilder* builder, - IRType* expectedType, - IRType* actualType, - IRInst* value) - { - if (as(actualType) && !as(expectedType)) - { - auto unpack = builder->emitUnpackAnyValue(expectedType, value); - return unpack; - } - return value; - } - - // Create a dispatch function for a interface method. - // On CPU, the dispatch function is implemented as a witness table lookup followed by - // a function-pointer call. - // On GPU targets, we can modify the body of the dispatch function in a follow-up - // pass to implement it with a `switch` statement based on the type ID. - IRFunc* _createInterfaceDispatchMethod( - IRBuilder* builder, - IRInterfaceType* interfaceType, - IRInst* requirementKey, - IRInst* requirementVal) - { - auto func = builder->createFunc(); - if (auto linkage = requirementKey->findDecoration()) - { - builder->addNameHintDecoration(func, linkage->getMangledName()); - } - - auto reqFuncType = cast(requirementVal); - List paramTypes; - paramTypes.add(builder->getWitnessTableType(interfaceType)); - for (UInt i = 0; i < reqFuncType->getParamCount(); i++) - { - paramTypes.add(reqFuncType->getParamType(i)); - } - auto dispatchFuncType = builder->getFuncType(paramTypes, reqFuncType->getResultType()); - func->setFullType(dispatchFuncType); - builder->setInsertInto(func); - builder->emitBlock(); - List params; - IRParam* witnessTableParam = builder->emitParam(paramTypes[0]); - for (Index i = 1; i < paramTypes.getCount(); i++) - { - params.add(builder->emitParam(paramTypes[i])); - } - auto callee = - builder->emitLookupInterfaceMethodInst(reqFuncType, witnessTableParam, requirementKey); - auto call = (IRCall*)builder->emitCallInst(reqFuncType->getResultType(), callee, params); - if (call->getDataType()->getOp() == kIROp_VoidType) - builder->emitReturn(); - else - builder->emitReturn(call); - return func; - } - - // If an interface dispatch method is already created, return it. - // Otherwise, create the method. - IRFunc* getOrCreateInterfaceDispatchMethod( - IRBuilder* builder, - IRInterfaceType* interfaceType, - IRInst* requirementKey, - IRInst* requirementVal) - { - if (auto func = sharedContext->mapInterfaceRequirementKeyToDispatchMethods.tryGetValue( - requirementKey)) - return *func; - auto dispatchFunc = - _createInterfaceDispatchMethod(builder, interfaceType, requirementKey, requirementVal); - sharedContext->mapInterfaceRequirementKeyToDispatchMethods.addIfNotExists( - requirementKey, - dispatchFunc); - return dispatchFunc; - } - - // Translate `callInst` into a call of `newCallee`, and respect the new `funcType`. - // If `newCallee` is a lowered generic function, `specializeInst` contains the type - // arguments used to specialize the callee. - void translateCallInst( - IRCall* callInst, - IRFuncType* funcType, - IRInst* newCallee, - IRSpecialize* specializeInst) - { - List paramTypes; - for (UInt i = 0; i < funcType->getParamCount(); i++) - paramTypes.add(funcType->getParamType(i)); - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(callInst); - - // Process the argument list of the call. - // For each argument, we test if it needs to be packed into an `AnyValue` for the - // call. For `out` and `inout` parameters, they may also need to be unpacked after - // the call, in which case we add such the argument to `argsToUnpack` so it can be - // processed after the new call inst is emitted. - List args; - List argsToUnpack; - for (UInt i = 0; i < callInst->getArgCount(); i++) - { - auto arg = callInst->getArg(i); - ArgumentUnpackWorkItem unpackWorkItem; - auto newArg = maybePackArgument(builder, paramTypes[i], arg, unpackWorkItem); - args.add(newArg); - if (unpackWorkItem.packedArg) - argsToUnpack.add(unpackWorkItem); - } - if (specializeInst) - { - for (UInt i = 0; i < specializeInst->getArgCount(); i++) - { - auto arg = specializeInst->getArg(i); - // Translate Type arguments into RTTI object. - if (as(arg)) - { - // We are using a simple type to specialize a callee. - // Generate RTTI for this type. - auto rttiObject = sharedContext->maybeEmitRTTIObject(arg); - arg = builder->emitGetAddress(builder->getRTTIHandleType(), rttiObject); - } - else if (arg->getOp() == kIROp_Specialize) - { - // The type argument used to specialize a callee is itself a - // specialization of some generic type. - // TODO: generate RTTI object for specializations of generic types. - SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types"); - } - else if (arg->getOp() == kIROp_RTTIObject) - { - // We are inside a generic function and using a generic parameter - // to specialize another callee. The generic parameter of the caller - // has already been translated into an RTTI object, so we just need - // to pass this object down. - } - args.add(arg); - } - } - - // If callee returns `AnyValue` but we are expecting a concrete value, unpack it. - auto calleeRetType = funcType->getResultType(); - auto newCall = builder->emitCallInst(calleeRetType, newCallee, args); - auto callInstType = callInst->getDataType(); - auto unpackInst = maybeUnpackValue(builder, callInstType, calleeRetType, newCall); - // Unpack other `out` arguments. - for (auto& item : argsToUnpack) - { - auto packedVal = builder->emitLoad(item.packedArg); - auto originalValType = cast(item.dstArg->getDataType())->getValueType(); - auto unpackedVal = builder->emitUnpackAnyValue(originalValType, packedVal); - builder->emitStore(item.dstArg, unpackedVal); - } - callInst->replaceUsesWith(unpackInst); - callInst->removeAndDeallocate(); - } - - IRInst* findInnerMostSpecializingBase(IRSpecialize* inst) - { - auto result = inst->getBase(); - while (auto specialize = as(result)) - result = specialize->getBase(); - return result; - } - - void lowerCallToSpecializedFunc(IRCall* callInst, IRSpecialize* specializeInst) - { - // If we see a call(specialize(gFunc, Targs), args), - // translate it into call(gFunc, args, Targs). - auto loweredFunc = specializeInst->getBase(); - - // Don't process intrinsic functions. - UnownedStringSlice intrinsicDef; - IRInst* intrinsicInst; - if (findTargetIntrinsicDefinition( - getResolvedInstForDecorations(loweredFunc), - sharedContext->targetProgram->getTargetReq()->getTargetCaps(), - intrinsicDef, - intrinsicInst)) - return; - - // All callees should have already been lowered in lower-generic-functions pass. - // For intrinsic generic functions, they are left as is, and we also need to ignore - // them here. - if (loweredFunc->getOp() == kIROp_Generic) - { - return; - } - else if (loweredFunc->getOp() == kIROp_Specialize) - { - // All nested generic functions are supposed to be flattend before this pass. - // If they are not, they represent an intrinsic function that should not be - // modified in this pass. - SLANG_UNEXPECTED("Nested generics specialization."); - } - else if (loweredFunc->getOp() == kIROp_LookupWitnessMethod) - { - lowerCallToInterfaceMethod( - callInst, - cast(loweredFunc), - specializeInst); - return; - } - IRFuncType* funcType = cast(loweredFunc->getDataType()); - translateCallInst(callInst, funcType, loweredFunc, specializeInst); - } - - void lowerCallToInterfaceMethod( - IRCall* callInst, - IRLookupWitnessMethod* lookupInst, - IRSpecialize* specializeInst) - { - // If we see a call(lookup_interface_method(...), ...), we need to translate - // all occurences of associatedtypes. - - // If `w` in `lookup_interface_method(w, ...)` is a COM interface, bail. - if (isComInterfaceType(lookupInst->getWitnessTable()->getDataType())) - { - return; - } - - auto interfaceType = as( - cast(lookupInst->getWitnessTable()->getDataType()) - ->getConformanceType()); - - if (!interfaceType) - { - // NoneWitness -> remove call. - callInst->removeAndDeallocate(); - return; - } - - if (isBuiltin(interfaceType)) - return; - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(callInst); - - // Create interface dispatch method that bottlenecks the dispatch logic. - auto requirementKey = lookupInst->getRequirementKey(); - auto requirementVal = - sharedContext->findInterfaceRequirementVal(interfaceType, requirementKey); - - if (interfaceType->findDecoration()) - { - sharedContext->sink->diagnose( - callInst->sourceLoc, - Diagnostics::dynamicDispatchOnSpecializeOnlyInterface, - interfaceType); - } - auto dispatchFunc = getOrCreateInterfaceDispatchMethod( - builder, - interfaceType, - requirementKey, - requirementVal); - - auto parentFunc = getParentFunc(callInst); - // Don't process the call inst that is the one in the dispatch function itself. - if (parentFunc == dispatchFunc) - return; - - // Replace `callInst` with a new call inst that calls `dispatchFunc` instead, and - // with the witness table as first argument, - builder->setInsertBefore(callInst); - List newArgs; - newArgs.add(lookupInst->getWitnessTable()); - for (UInt i = 0; i < callInst->getArgCount(); i++) - newArgs.add(callInst->getArg(i)); - auto newCall = - (IRCall*)builder->emitCallInst(callInst->getFullType(), dispatchFunc, newArgs); - callInst->replaceUsesWith(newCall); - callInst->removeAndDeallocate(); - - // Translate the new call inst as normal, taking care of packing/unpacking inputs - // and outputs. - translateCallInst( - newCall, - cast(dispatchFunc->getFullType()), - dispatchFunc, - specializeInst); - } - - void lowerCall(IRCall* callInst) - { - if (auto specializeInst = as(callInst->getCallee())) - lowerCallToSpecializedFunc(callInst, specializeInst); - else if (auto lookupInst = as(callInst->getCallee())) - lowerCallToInterfaceMethod(callInst, lookupInst, nullptr); - } - - void processInst(IRInst* inst) - { - if (auto callInst = as(inst)) - { - lowerCall(callInst); - } - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - } -}; - -void lowerGenericCalls(SharedGenericsLoweringContext* sharedContext) -{ - GenericCallLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-call.h b/source/slang/slang-ir-lower-generic-call.h deleted file mode 100644 index a876634e6ed..00000000000 --- a/source/slang/slang-ir-lower-generic-call.h +++ /dev/null @@ -1,12 +0,0 @@ -// slang-ir-lower-generic-call.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Lower generic and interface-based code to ordinary types and functions using -/// dynamic dispatch mechanisms. -void lowerGenericCalls(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp deleted file mode 100644 index 87243ff5cee..00000000000 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ /dev/null @@ -1,450 +0,0 @@ -// slang-ir-lower-generic-function.cpp -#include "slang-ir-lower-generic-function.h" - -#include "slang-ir-clone.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ -// This is a subpass of generics lowering IR transformation. -// This pass lowers all generic function types and function definitions, including -// the function types used in interface types, to ordinary functions that takes -// raw pointers in place of generic types. -struct GenericFunctionLoweringContext -{ - SharedGenericsLoweringContext* sharedContext; - - IRInst* lowerGenericFunction(IRInst* genericValue) - { - IRInst* result = nullptr; - if (sharedContext->loweredGenericFunctions.tryGetValue(genericValue, result)) - return result; - // Do not lower intrinsic functions. - if (genericValue->findDecoration()) - return genericValue; - auto genericParent = as(genericValue); - SLANG_ASSERT(genericParent); - SLANG_ASSERT(genericParent->getDataType()); - auto genericRetVal = findGenericReturnVal(genericParent); - auto func = as(genericRetVal); - if (!func) - { - // Nested generic functions are supposed to be flattened before entering - // this pass. The reason we are still seeing them must be that they are - // intrinsic functions. In this case we ignore the function. - if (as(genericRetVal)) - { - SLANG_ASSERT( - findInnerMostGenericReturnVal(genericParent) - ->findDecoration() != nullptr); - } - return genericValue; - } - SLANG_ASSERT(func); - // Do not lower intrinsic functions. - UnownedStringSlice intrinsicDef; - IRInst* intrinsicInst; - if (!func->isDefinition() || - findTargetIntrinsicDefinition( - func, - sharedContext->targetProgram->getTargetReq()->getTargetCaps(), - intrinsicDef, - intrinsicInst)) - { - sharedContext->loweredGenericFunctions[genericValue] = genericValue; - return genericValue; - } - IRCloneEnv cloneEnv; - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(genericParent); - // Do not clone func type (which would break IR def-use rules if we do it here) - // This is OK since we will lower the type immediately after the clone. - cloneEnv.mapOldValToNew[func->getFullType()] = builder.getTypeKind(); - auto loweredFunc = cast(cloneInstAndOperands(&cloneEnv, &builder, func)); - auto loweredGenericType = - lowerGenericFuncType(&builder, genericParent, cast(func->getFullType())); - SLANG_ASSERT(loweredGenericType); - loweredFunc->setFullType(loweredGenericType); - - OrderedHashSet childrenToDemote; - List clonedParams; - auto moduleInst = genericParent->getModule()->getModuleInst(); - for (auto genericChild : genericParent->getFirstBlock()->getChildren()) - { - switch (genericChild->getOp()) - { - case kIROp_Func: - continue; - case kIROp_Return: - continue; - } - // Process all generic parameters and local type definitions. - auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); - switch (clonedChild->getOp()) - { - case kIROp_Param: - { - auto paramType = clonedChild->getFullType(); - auto loweredParamType = sharedContext->lowerType(&builder, paramType); - if (loweredParamType != paramType) - { - clonedChild->setFullType((IRType*)loweredParamType); - } - clonedParams.add(clonedChild); - } - break; - case kIROp_Specialize: - case kIROp_LookupWitnessMethod: - childrenToDemote.add(clonedChild); - break; - default: - { - bool shouldDemote = false; - if (childrenToDemote.contains(clonedChild->getFullType())) - shouldDemote = true; - for (UInt i = 0; i < clonedChild->getOperandCount(); i++) - { - if (childrenToDemote.contains(clonedChild->getOperand(i))) - { - shouldDemote = true; - break; - } - } - if (shouldDemote && clonedChild->getParent() == moduleInst) - { - childrenToDemote.add(clonedChild); - } - continue; - } - } - } - cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, func, loweredFunc); - fixUpDebugFuncType(loweredFunc); - - auto block = as(loweredFunc->getFirstChild()); - for (auto param : clonedParams) - { - param->removeFromParent(); - block->addParam(as(param)); - } - - // Demote specialize and lookupWitness insts and their dependents down to function body. - auto insertPoint = block->getFirstOrdinaryInst(); - List childrenToDemoteList; - for (auto child : childrenToDemote) - childrenToDemoteList.add(child); - for (Index i = childrenToDemoteList.getCount() - 1; i >= 0; i--) - { - auto child = childrenToDemoteList[i]; - child->insertBefore(insertPoint); - } - - // Lower generic typed parameters into AnyValueType. - auto firstInst = loweredFunc->getFirstOrdinaryInst(); - builder.setInsertBefore(firstInst); - sharedContext->loweredGenericFunctions[genericValue] = loweredFunc; - sharedContext->addToWorkList(loweredFunc); - return loweredFunc; - } - - IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal, IRFuncType* funcType) - { - ShortList genericParamTypes; - Dictionary typeMapping; - for (auto genericParam : genericVal->getParams()) - { - genericParamTypes.add(sharedContext->lowerType(builder, genericParam->getFullType())); - if (auto anyValueSizeDecor = genericParam->findDecoration()) - { - auto anyValueSize = sharedContext->getInterfaceAnyValueSize( - anyValueSizeDecor->getConstraintType(), - genericParam->sourceLoc); - auto anyValueType = builder->getAnyValueType(anyValueSize); - typeMapping[genericParam] = anyValueType; - } - } - - auto innerType = (IRFuncType*)lowerFuncType( - builder, - funcType, - typeMapping, - genericParamTypes.getArrayView().arrayView); - - return innerType; - } - - IRType* lowerFuncType( - IRBuilder* builder, - IRFuncType* funcType, - const Dictionary& typeMapping, - ArrayView additionalParams) - { - List newOperands; - bool translated = false; - for (UInt i = 0; i < funcType->getOperandCount(); i++) - { - auto paramType = funcType->getOperand(i); - auto loweredParamType = - sharedContext->lowerType(builder, paramType, typeMapping, nullptr); - SLANG_ASSERT(loweredParamType); - translated = translated || (loweredParamType != paramType); - newOperands.add(loweredParamType); - } - if (!translated && additionalParams.getCount() == 0) - return funcType; - for (Index i = 0; i < additionalParams.getCount(); i++) - { - newOperands.add(additionalParams[i]); - } - auto newFuncType = builder->getFuncType( - newOperands.getCount() - 1, - (IRType**)(newOperands.begin() + 1), - (IRType*)newOperands[0]); - - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, funcType, newFuncType); - return newFuncType; - } - - IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) - { - IRInterfaceType* loweredType = nullptr; - if (sharedContext->loweredInterfaceTypes.tryGetValue(interfaceType, loweredType)) - return loweredType; - if (sharedContext->mapLoweredInterfaceToOriginal.containsKey(interfaceType)) - return interfaceType; - // Do not lower intrinsic interfaces. - if (isBuiltin(interfaceType)) - return interfaceType; - // Do not lower COM interfaces. - if (isComInterfaceType(interfaceType)) - return interfaceType; - - List newEntries; - - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(interfaceType); - - // Translate IRFuncType in interface requirements. - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - if (auto entry = as(interfaceType->getOperand(i))) - { - // Note: The logic that creates the `IRInterfaceRequirementEntry`s does - // not currently guarantee that the *value* part of each key-value pair - // gets filled in. We thus need to defend against a null `requirementVal` - // here, at least until the underlying issue gets resolved. - // - IRInst* requirementVal = entry->getRequirementVal(); - IRInst* loweredVal = nullptr; - if (!requirementVal) - { - } - else if (auto funcType = as(requirementVal)) - { - loweredVal = lowerFuncType( - &builder, - funcType, - Dictionary(), - ArrayView()); - } - else if (auto genericFuncType = as(requirementVal)) - { - loweredVal = lowerGenericFuncType( - &builder, - genericFuncType, - cast(findGenericReturnVal(genericFuncType))); - } - else if (requirementVal->getOp() == kIROp_AssociatedType) - { - loweredVal = builder.getRTTIHandleType(); - } - else - { - loweredVal = requirementVal; - } - auto newEntry = - builder.createInterfaceRequirementEntry(entry->getRequirementKey(), loweredVal); - newEntries.add(newEntry); - } - } - loweredType = - builder.createInterfaceType(newEntries.getCount(), (IRInst**)newEntries.getBuffer()); - loweredType->sourceLoc = interfaceType->sourceLoc; - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren( - &cloneEnv, - sharedContext->module, - interfaceType, - loweredType); - sharedContext->loweredInterfaceTypes.add(interfaceType, loweredType); - sharedContext->mapLoweredInterfaceToOriginal[loweredType] = interfaceType; - return loweredType; - } - - bool isTypeKindVal(IRInst* inst) - { - auto type = inst->getDataType(); - if (!type) - return false; - return type->getOp() == kIROp_TypeKind; - } - - // Lower items in a witness table. This triggers lowering of generic functions, - // and emission of wrapper functions. - void lowerWitnessTable(IRWitnessTable* witnessTable) - { - IRInterfaceType* conformanceType = as(witnessTable->getConformanceType()); - if (!conformanceType) - return; - - auto interfaceType = maybeLowerInterfaceType(conformanceType); - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(witnessTable); - if (interfaceType != witnessTable->getConformanceType()) - { - auto newWitnessTableType = builder->getWitnessTableType(interfaceType); - witnessTable->setFullType(newWitnessTableType); - } - if (isBuiltin(interfaceType)) - return; - for (auto child : witnessTable->getChildren()) - { - auto entry = as(child); - if (!entry) - continue; - if (auto genericVal = as(entry->getSatisfyingVal())) - { - // Lower generic functions in witness table. - if (findGenericReturnVal(genericVal)->getOp() == kIROp_Func) - { - auto loweredFunc = lowerGenericFunction(genericVal); - entry->satisfyingVal.set(loweredFunc); - } - } - else if (isTypeKindVal(entry->getSatisfyingVal())) - { - // Translate a Type value to an RTTI object pointer. - auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal()); - auto rttiObjectPtr = - builder->emitGetAddress(builder->getRTTIHandleType(), rttiObject); - entry->satisfyingVal.set(rttiObjectPtr); - } - else if (as(entry->getSatisfyingVal())) - { - // No processing needed here. - // The witness table will be processed from the work list. - } - } - } - - void lowerLookupInterfaceMethodInst(IRLookupWitnessMethod* lookupInst) - { - // Update the type of lookupInst to the lowered type of the corresponding interface - // requirement val. - - // If the requirement is a function, interfaceRequirementVal will be the lowered function - // type. If the requirement is an associatedtype, interfaceRequirementVal will be - // Ptr. - IRInst* interfaceRequirementVal = nullptr; - auto witnessTableType = - as(lookupInst->getWitnessTable()->getDataType()); - if (!witnessTableType) - return; - if (witnessTableType->getConformanceType()->findDecoration()) - return; - - IRInterfaceType* conformanceType = - as(witnessTableType->getConformanceType()); - - // NoneWitness generates conformance types which aren't interfaces. In - // that case, the method can just be skipped entirely, since there's no - // real witness for it and it should be in unreachable code at this - // point. - if (!conformanceType) - return; - auto interfaceType = maybeLowerInterfaceType(conformanceType); - interfaceRequirementVal = sharedContext->findInterfaceRequirementVal( - interfaceType, - lookupInst->getRequirementKey()); - IRBuilder builder(lookupInst); - builder.replaceOperand(&lookupInst->typeUse, interfaceRequirementVal); - } - - void lowerSpecialize(IRSpecialize* specializeInst) - { - // If we see a call(specialize(gFunc, Targs), args), - // translate it into call(gFunc, args, Targs). - IRInst* loweredFunc = nullptr; - auto funcToSpecialize = specializeInst->getBase(); - if (funcToSpecialize->getOp() == kIROp_Generic) - { - loweredFunc = lowerGenericFunction(funcToSpecialize); - if (loweredFunc != funcToSpecialize) - { - IRBuilder builder; - builder.replaceOperand(specializeInst->getOperands(), loweredFunc); - } - } - } - - void processInst(IRInst* inst) - { - if (auto specializeInst = as(inst)) - { - lowerSpecialize(specializeInst); - } - else if (auto lookupInterfaceMethod = as(inst)) - { - lowerLookupInterfaceMethodInst(lookupInterfaceMethod); - } - else if (auto witnessTable = as(inst)) - { - lowerWitnessTable(witnessTable); - } - else if (auto interfaceType = as(inst)) - { - maybeLowerInterfaceType(interfaceType); - } - } - - void replaceLoweredInterfaceTypes() - { - for (const auto& [loweredKey, loweredValue] : sharedContext->loweredInterfaceTypes) - loweredKey->replaceUsesWith(loweredValue); - sharedContext->mapInterfaceRequirementKeyValue.clear(); - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - - replaceLoweredInterfaceTypes(); - } -}; -void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext) -{ - GenericFunctionLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); -} -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-function.h b/source/slang/slang-ir-lower-generic-function.h deleted file mode 100644 index c70ad0ae6c6..00000000000 --- a/source/slang/slang-ir-lower-generic-function.h +++ /dev/null @@ -1,19 +0,0 @@ -// slang-ir-lower-generic-function.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Lower generic and interface-based code to ordinary types and functions using -/// dynamic dispatch mechanisms. -/// After this pass, generic type parameters will be lowered into `AnyValue` types, -/// and an existential type I in function signatures will be lowered into -/// `Tuple`. -/// Note that this pass mostly deals with function signatures and interface definitions, -/// and does not modify function bodies. -/// All variable declarations and type uses are handled in `lower-generic-type`, -/// and all call sites are handled in `lower-generic-call`. -void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-type.cpp b/source/slang/slang-ir-lower-generic-type.cpp deleted file mode 100644 index c4dcd92a9a1..00000000000 --- a/source/slang/slang-ir-lower-generic-type.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// slang-ir-lower-generic-type.cpp -#include "slang-ir-lower-generic-type.h" - -#include "slang-ir-clone.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir.h" - -namespace Slang -{ -// This is a subpass of generics lowering IR transformation. -// This pass lowers all generic/polymorphic types into IRAnyValueType. -struct GenericTypeLoweringContext -{ - SharedGenericsLoweringContext* sharedContext; - - IRInst* processInst(IRInst* inst) - { - // Ensure exported struct types has RTTI object defined. - if (as(inst)) - { - if (inst->findDecoration()) - { - sharedContext->maybeEmitRTTIObject(inst); - } - } - - // Don't modify type insts themselves. - if (as(inst)) - return inst; - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(inst); - - auto newType = sharedContext->lowerType(builder, inst->getFullType()); - if (newType != inst->getFullType()) - inst = builder->replaceOperand(&inst->typeUse, newType); - - switch (inst->getOp()) - { - default: - break; - case kIROp_StructField: - { - // Translate the struct field type. - auto structField = static_cast(inst); - auto loweredFieldType = - sharedContext->lowerType(builder, structField->getFieldType()); - structField->setOperand(1, loweredFieldType); - } - break; - case kIROp_DebugFunction: - { - auto oldFuncType = as(inst)->getDebugType(); - auto newFuncType = sharedContext->lowerType(builder, oldFuncType); - if (newFuncType != oldFuncType) - inst = builder->replaceOperand(inst->getOperandUse(4), newFuncType); - } - break; - } - return inst; - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - inst = processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - sharedContext->mapInterfaceRequirementKeyValue.clear(); - } -}; - -void lowerGenericType(SharedGenericsLoweringContext* sharedContext) -{ - GenericTypeLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); -} -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-type.h b/source/slang/slang-ir-lower-generic-type.h deleted file mode 100644 index 20d4fa7b33c..00000000000 --- a/source/slang/slang-ir-lower-generic-type.h +++ /dev/null @@ -1,12 +0,0 @@ -// slang-ir-lower-generic-type.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Lower all references to generic types (ThisType, AssociatedType, etc.) into IRAnyValueType, -/// and existential types into Tuple. -void lowerGenericType(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp deleted file mode 100644 index 6b25ff10fe1..00000000000 --- a/source/slang/slang-ir-lower-generics.cpp +++ /dev/null @@ -1,309 +0,0 @@ -// slang-ir-lower-generics.cpp -#include "slang-ir-lower-generics.h" - -#include "../core/slang-func-ptr.h" -#include "../core/slang-performance-profiler.h" -#include "slang-ir-any-value-inference.h" -#include "slang-ir-any-value-marshalling.h" -#include "slang-ir-augment-make-existential.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-inst-pass-base.h" -#include "slang-ir-layout.h" -#include "slang-ir-lower-existential.h" -#include "slang-ir-lower-generic-call.h" -#include "slang-ir-lower-generic-function.h" -#include "slang-ir-lower-generic-type.h" -#include "slang-ir-lower-tuple-types.h" -#include "slang-ir-specialize-dispatch.h" -#include "slang-ir-specialize-dynamic-associatedtype-lookup.h" -#include "slang-ir-ssa-simplification.h" -#include "slang-ir-util.h" -#include "slang-ir-witness-table-wrapper.h" - -namespace Slang -{ -// Replace all uses of RTTI objects with its sequential ID. -// Currently we don't use RTTI objects other than check for null/invalid value, -// so all of them are 0xFFFFFFFF. -void specializeRTTIObjectReferences(SharedGenericsLoweringContext* sharedContext) -{ - uint32_t id = 0xFFFFFFFF; - for (auto rtti : sharedContext->mapTypeToRTTIObject) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(rtti.value); - IRUse* nextUse = nullptr; - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - IRInst* uint2Args[] = { - builder.getIntValue(builder.getUIntType(), id), - builder.getIntValue(builder.getUIntType(), 0)}; - auto idOperand = builder.emitMakeVector(uint2Type, 2, uint2Args); - for (auto use = rtti.value->firstUse; use; use = nextUse) - { - nextUse = use->nextUse; - if (use->getUser()->getOp() == kIROp_GetAddress) - { - use->getUser()->replaceUsesWith(idOperand); - } - } - } -} - -// Replace all WitnessTableID type or RTTIHandleType with `uint2`. -void cleanUpRTTIHandleTypes(SharedGenericsLoweringContext* sharedContext) -{ - List instsToRemove; - for (auto inst : sharedContext->module->getGlobalInsts()) - { - switch (inst->getOp()) - { - case kIROp_WitnessTableIDType: - if (isComInterfaceType((IRType*)inst->getOperand(0))) - continue; - // fall through - case kIROp_RTTIHandleType: - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - inst->replaceUsesWith(uint2Type); - instsToRemove.add(inst); - } - break; - } - } - for (auto inst : instsToRemove) - inst->removeAndDeallocate(); -} - -// Remove all interface types from module. -void cleanUpInterfaceTypes(SharedGenericsLoweringContext* sharedContext) -{ - IRBuilder builder(sharedContext->module); - builder.setInsertInto(sharedContext->module->getModuleInst()); - auto dummyInterfaceObj = builder.getIntValue(builder.getIntType(), 0); - List interfaceInsts; - for (auto inst : sharedContext->module->getGlobalInsts()) - { - if (inst->getOp() == kIROp_InterfaceType) - { - if (inst->findDecoration()) - continue; - - interfaceInsts.add(inst); - } - } - for (auto inst : interfaceInsts) - { - inst->replaceUsesWith(dummyInterfaceObj); - inst->removeAndDeallocate(); - } -} - -void lowerIsTypeInsts(SharedGenericsLoweringContext* sharedContext) -{ - InstPassBase pass(sharedContext->module); - pass.processInstsOfType( - kIROp_IsType, - [&](IRIsType* inst) - { - auto witnessTableType = - as(inst->getValueWitness()->getDataType()); - if (witnessTableType && - isComInterfaceType((IRType*)witnessTableType->getConformanceType())) - return; - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto eqlInst = builder.emitEql( - builder.emitGetSequentialIDInst(inst->getValueWitness()), - builder.emitGetSequentialIDInst(inst->getTargetWitness())); - inst->replaceUsesWith(eqlInst); - inst->removeAndDeallocate(); - }); -} - -// Turn all references of witness table or RTTI objects into integer IDs, generate -// specialized `switch` based dispatch functions based on witness table IDs, and remove -// all original witness table, RTTI object and interface definitions from IR module. -// With these transformations, the resulting code is compatible with D3D/Vulkan where -// no pointers are involved in RTTI / dynamic dispatch logic. -void specializeRTTIObjects(SharedGenericsLoweringContext* sharedContext, DiagnosticSink* sink) -{ - specializeDispatchFunctions(sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerIsTypeInsts(sharedContext); - - specializeDynamicAssociatedTypeLookup(sharedContext); - if (sink->getErrorCount() != 0) - return; - - sharedContext->mapInterfaceRequirementKeyValue.clear(); - - specializeRTTIObjectReferences(sharedContext); - - cleanUpRTTIHandleTypes(sharedContext); - - cleanUpInterfaceTypes(sharedContext); -} - -void checkTypeConformanceExists(SharedGenericsLoweringContext* context) -{ - HashSet implementedInterfaces; - - // Add all interface type that are implemented by at least one type to a set. - for (auto inst : context->module->getGlobalInsts()) - { - if (inst->getOp() == kIROp_WitnessTable) - { - auto interfaceType = - cast(inst->getDataType())->getConformanceType(); - implementedInterfaces.add(interfaceType); - } - } - // Check if an interface type has any implementations. - workOnModule( - context, - [&](IRInst* inst) - { - if (auto lookupWitnessMethod = as(inst)) - { - auto witnessTableType = lookupWitnessMethod->getWitnessTable()->getDataType(); - if (!witnessTableType) - return; - auto interfaceType = - cast(witnessTableType)->getConformanceType(); - if (isComInterfaceType((IRType*)interfaceType)) - return; - if (!implementedInterfaces.contains(interfaceType)) - { - context->sink->diagnose( - interfaceType->sourceLoc, - Diagnostics::noTypeConformancesFoundForInterface, - interfaceType); - // Add to set to prevent duplicate diagnostic messages. - implementedInterfaces.add(interfaceType); - } - } - }); -} - -void stripWrapExistential(IRModule* module) -{ - InstWorkList workList(module); - - workList.add(module->getModuleInst()); - for (Index i = 0; i < workList.getCount(); i++) - { - auto inst = workList[i]; - switch (inst->getOp()) - { - case kIROp_WrapExistential: - { - auto operand = inst->getOperand(0); - inst->replaceUsesWith(operand); - inst->removeAndDeallocate(); - } - break; - default: - for (auto child : inst->getChildren()) - workList.add(child); - break; - } - } -} - -void lowerGenerics(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink) -{ - SLANG_PROFILE; - - SharedGenericsLoweringContext sharedContext(module); - sharedContext.targetProgram = targetProgram; - sharedContext.sink = sink; - - checkTypeConformanceExists(&sharedContext); - - // Replace all `makeExistential` insts with `makeExistentialWithRTTI` - // before making any other changes. This is necessary because a parameter of - // generic type will be lowered into `AnyValueType`, and after that we can no longer - // access the original generic type parameter from the lowered parameter value. - // This steps ensures that the generic type parameter is available via an - // explicit operand in `makeExistentialWithRTTI`, so that type parameter - // can be translated into an RTTI object during `lower-generic-type`, - // and used to create a tuple representing the existential value. - augmentMakeExistentialInsts(module); - - lowerGenericFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerGenericType(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerExistentials(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerGenericCalls(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - generateWitnessTableWrapperFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - // This optional step replaces all uses of witness tables and RTTI objects with - // sequential IDs. Without this step, we will emit code that uses function pointers and - // real RTTI objects and witness tables. - specializeRTTIObjects(&sharedContext, sink); - - simplifyIR( - module, - sharedContext.targetProgram, - IRSimplificationOptions::getFast(sharedContext.targetProgram)); - - lowerTuples(module, sink); - if (sink->getErrorCount() != 0) - return; - - generateAnyValueMarshallingFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - // At this point, we should no longer need to care any `WrapExistential` insts, - // although they could still exist in the IR in order to call generic core module functions, - // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). - // We should remove them now. - stripWrapExistential(module); -} - -void cleanupGenerics(IRModule* module, TargetProgram* program, DiagnosticSink* sink) -{ - SharedGenericsLoweringContext sharedContext(module); - sharedContext.targetProgram = program; - sharedContext.sink = sink; - - specializeRTTIObjects(&sharedContext, sink); - - lowerTuples(module, sink); - if (sink->getErrorCount() != 0) - return; - - generateAnyValueMarshallingFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - // At this point, we should no longer need to care any `WrapExistential` insts, - // although they could still exist in the IR in order to call generic core module functions, - // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). - // We should remove them now. - stripWrapExistential(module); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generics.h b/source/slang/slang-ir-lower-generics.h deleted file mode 100644 index 9f746f48578..00000000000 --- a/source/slang/slang-ir-lower-generics.h +++ /dev/null @@ -1,19 +0,0 @@ -// slang-ir-lower-generics.h -#pragma once - -#include "slang-ir.h" - -namespace Slang -{ -struct IRModule; -class DiagnosticSink; -class TargetProgram; - -/// Lower generic and interface-based code to ordinary types and functions using -/// dynamic dispatch mechanisms. -void lowerGenerics(IRModule* module, TargetProgram* targetReq, DiagnosticSink* sink); - -// Clean up any generic-related IR insts that are no longer needed. Called when -// it has been determined that no more dynamic dispatch code will be generated. -void cleanupGenerics(IRModule* module, TargetProgram* targetReq, DiagnosticSink* sink); -} // namespace Slang diff --git a/source/slang/slang-ir-lower-reinterpret.cpp b/source/slang/slang-ir-lower-reinterpret.cpp index 4b66515526a..4f3796dfb23 100644 --- a/source/slang/slang-ir-lower-reinterpret.cpp +++ b/source/slang/slang-ir-lower-reinterpret.cpp @@ -94,7 +94,6 @@ void lowerReinterpret(IRModule* module, TargetProgram* target, DiagnosticSink* s // Before processing reinterpret insts, ensure that existential types without // user-defined sizes have inferred sizes where possible. // - inferAnyValueSizeWhereNecessary(module, target); ReinterpretLoweringContext context; context.module = module; diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp deleted file mode 100644 index f50cafd1d69..00000000000 --- a/source/slang/slang-ir-lower-witness-lookup.cpp +++ /dev/null @@ -1,447 +0,0 @@ -// slang-ir-lower-generic-existential.cpp - -#include "slang-ir-lower-witness-lookup.h" - -#include "slang-ir-clone.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ - -struct WitnessLookupLoweringContext -{ - IRModule* module; - DiagnosticSink* sink; - - Dictionary witnessDispatchFunctions; - - void init() - { - // Reconstruct the witness dispatch functions map. - for (auto inst : module->getGlobalInsts()) - { - if (auto key = as(inst)) - { - for (auto decor : key->getDecorations()) - { - if (auto witnessDispatchFunc = as(decor)) - { - witnessDispatchFunctions.add(key, witnessDispatchFunc->getFunc()); - } - } - } - } - } - - bool hasAssocType(IRInst* type) - { - if (!type) - return false; - - InstHashSet processedSet(type->getModule()); - InstWorkList workList(type->getModule()); - workList.add(type); - processedSet.add(type); - for (Index i = 0; i < workList.getCount(); i++) - { - auto inst = workList[i]; - if (inst->getOp() == kIROp_AssociatedType) - return true; - - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - if (!inst->getOperand(j)) - continue; - if (processedSet.add(inst->getOperand(j))) - workList.add(inst->getOperand(j)); - } - } - return false; - } - - IRType* translateType(IRBuilder builder, IRInst* type) - { - if (!type) - return nullptr; - if (auto genType = as(type)) - { - IRCloneEnv cloneEnv; - builder.setInsertBefore(genType); - auto newGeneric = as(cloneInst(&cloneEnv, &builder, genType)); - newGeneric->setFullType(builder.getGenericKind()); - auto retVal = findGenericReturnVal(newGeneric); - builder.setInsertBefore(retVal); - auto translated = translateType(builder, retVal); - retVal->replaceUsesWith(translated); - return (IRType*)newGeneric; - } - else if (auto thisType = as(type)) - { - return (IRType*)thisType->getConstraintType(); - } - else if (auto assocType = as(type)) - { - return assocType; - } - - if (as(type)) - return (IRType*)type; - - switch (type->getOp()) - { - case kIROp_Param: - case kIROp_VectorType: - case kIROp_MatrixType: - case kIROp_StructType: - case kIROp_ClassType: - case kIROp_InterfaceType: - case kIROp_LookupWitnessMethod: - case kIROp_EnumType: - return (IRType*)type; - default: - { - List translatedOperands; - for (UInt i = 0; i < type->getOperandCount(); i++) - { - translatedOperands.add(translateType(builder, type->getOperand(i))); - } - auto translated = builder.emitIntrinsicInst( - type->getFullType(), - type->getOp(), - (UInt)translatedOperands.getCount(), - translatedOperands.getBuffer()); - return (IRType*)translated; - } - } - } - - IRInst* findOrCreateDispatchFunc(IRLookupWitnessMethod* lookupInst) - { - IRInst* func = nullptr; - auto requirementKey = cast(lookupInst->getRequirementKey()); - if (witnessDispatchFunctions.tryGetValue(requirementKey, func)) - { - return func; - } - - auto witnessTableOperand = lookupInst->getWitnessTable(); - auto witnessTableType = as(witnessTableOperand->getDataType()); - SLANG_RELEASE_ASSERT(witnessTableType); - auto interfaceType = - as(unwrapAttributedType(witnessTableType->getConformanceType())); - SLANG_RELEASE_ASSERT(interfaceType); - if (interfaceType->findDecoration()) - return nullptr; - auto requirementType = findInterfaceRequirement(interfaceType, requirementKey); - SLANG_RELEASE_ASSERT(requirementType); - - // We only lower non-static function requirement lookups for now. - // Our front end will stick a StaticRequirementDecoration on the IRStructKey for static - // member requirements. - if (lookupInst->getRequirementKey()->findDecoration()) - return nullptr; - auto interfaceMethodFuncType = - as(getResolvedInstForDecorations(requirementType)); - if (interfaceMethodFuncType) - { - // Detect cases that we currently does not support and exit. - - // If this is a non static function requirement, we should - // make sure the first parameter is the interface type. If not, something has gone - // wrong. - if (interfaceMethodFuncType->getParamCount() == 0) - return nullptr; - if (!as(unwrapAttributedType(interfaceMethodFuncType->getParamType(0)))) - return nullptr; - - // The function has any associated type parameter, we currently can't lower it early in - // this pass. We will lower it in the catch all generic lowering pass. - for (UInt i = 1; i < interfaceMethodFuncType->getParamCount(); i++) - { - if (hasAssocType(interfaceMethodFuncType->getParamType(i))) - return nullptr; - } - - // If return type is a composite type containing an assoc type, we won't lower it now. - // Supporting general use of assoc type is possible, but would require more complex - // logic in this pass to marshal things to and from existential types. - if (interfaceMethodFuncType->getResultType()->getOp() != kIROp_AssociatedType && - hasAssocType(interfaceMethodFuncType->getResultType())) - return nullptr; - } - else - { - return nullptr; - } - - - IRBuilder builder(module); - builder.setInsertBefore(getParentFunc(lookupInst)); - - // Create a dispatch func. - IRFunc* dispatchFunc = nullptr; - IRFuncType* dispatchFuncType = nullptr; - IRGeneric* parentGeneric = nullptr; - - // If requirementType is a generic, we need to create a new generic that has the same - // parameters. - if (auto genericRequirement = as(requirementType)) - { - IRCloneEnv cloneEnv; - parentGeneric = as(cloneInst(&cloneEnv, &builder, genericRequirement)); - - auto returnInst = as(parentGeneric->getFirstBlock()->getLastInst()); - SLANG_RELEASE_ASSERT(returnInst); - builder.setInsertBefore(returnInst); - auto oldDispatchFuncType = as(returnInst->getVal()); - if (!oldDispatchFuncType) - return nullptr; - - dispatchFuncType = as(translateType(builder, oldDispatchFuncType)); - - SLANG_RELEASE_ASSERT(dispatchFuncType); - - dispatchFunc = builder.createFunc(); - dispatchFunc->setFullType(dispatchFuncType); - builder.emitReturn(dispatchFunc); - returnInst->removeAndDeallocate(); - - parentGeneric->setFullType(translateType(builder, requirementType)); - } - else - { - dispatchFuncType = as(translateType(builder, requirementType)); - dispatchFunc = builder.createFunc(); - dispatchFunc->setFullType(dispatchFuncType); - } - - // We need to inline this function if the requirement is differentiable, - // so that the autodiff pass doesn't need to handle the dispatch function. - if (requirementKey->findDecoration() || - requirementKey->findDecoration()) - { - builder.addForceInlineDecoration(dispatchFunc); - } - - // Collect generic params. - List genericParams; - if (parentGeneric) - { - for (auto param : parentGeneric->getParams()) - genericParams.add(param); - } - - // Emit the body of the dispatch func. - builder.setInsertInto(dispatchFunc); - auto firstBlock = builder.emitBlock(); - auto firstBlockBuilder = builder; - // Emit parameters. - List params; - - for (UInt i = 0; i < dispatchFuncType->getParamCount(); i++) - { - params.add(builder.emitParam(dispatchFuncType->getParamType(i))); - } - auto witness = builder.emitExtractExistentialWitnessTable(params[0]); - - auto witnessTables = getWitnessTablesFromInterfaceType(module, interfaceType); - if (witnessTables.getCount() == 0) - { - // If there is no witness table, we should emit an error. - sink->diagnose( - lookupInst, - Diagnostics::noTypeConformancesFoundForInterface, - interfaceType); - return nullptr; - } - else - { - List cases; - for (auto witnessTable : witnessTables) - { - IRBlock* block = builder.emitBlock(); - auto caseValue = firstBlockBuilder.emitGetSequentialIDInst(witnessTable); - cases.add(caseValue); - cases.add(block); - auto entry = findWitnessTableEntry(witnessTable, requirementKey); - SLANG_RELEASE_ASSERT(entry); - // If the entry is a generic, we need to specialize it. - if (const auto genericEntry = as(entry)) - { - auto specializedFuncType = builder.emitSpecializeInst( - builder.getTypeKind(), - entry->getFullType(), - (UInt)genericParams.getCount(), - genericParams.getBuffer()); - entry = builder.emitSpecializeInst( - (IRType*)specializedFuncType, - entry, - (UInt)genericParams.getCount(), - genericParams.getBuffer()); - } - auto args = params; - // Reinterpret the first arg into the concrete type. - args[0] = builder.emitReinterpret( - witnessTable->getConcreteType(), - builder.emitExtractExistentialValue( - builder.emitExtractExistentialType(args[0]), - args[0])); - - auto calleeFuncType = - as(getResolvedInstForDecorations(entry)->getFullType()); - auto callReturnType = calleeFuncType->getResultType(); - if (callReturnType->getParent() != module->getModuleInst()) - { - // the return type is dependent on generic parameter, use the type from - // dispatchFuncType instead. - callReturnType = dispatchFuncType->getResultType(); - } - - IRInst* ret = builder.emitCallInst( - callReturnType, - entry, - (UInt)args.getCount(), - args.getBuffer()); - // If result type is an associated type, we need to pack it into an anyValue. - if (as(dispatchFuncType->getResultType())) - { - ret = builder.emitPackAnyValue(dispatchFuncType->getResultType(), ret); - } - builder.emitReturn(ret); - } - builder.setInsertInto(firstBlock); - if (witnessTables.getCount() == 1) - { - builder.emitBranch((IRBlock*)cases[1]); - } - else - { - auto witnessId = firstBlockBuilder.emitGetSequentialIDInst(witness); - auto breakLabel = builder.emitBlock(); - builder.emitUnreachable(); - firstBlockBuilder.emitSwitch( - witnessId, - breakLabel, - (IRBlock*)cases.getLast(), - (UInt)(cases.getCount() - 2), - cases.getBuffer()); - } - } - - // Stick a decoration on the requirement key so we can find the dispatch func later. - IRInst* resultValue = parentGeneric ? (IRInst*)parentGeneric : dispatchFunc; - builder.addDispatchFuncDecoration(requirementKey, resultValue); - - // Register the dispatch func to witnessDispatchFunctions dictionary. - witnessDispatchFunctions[requirementKey] = resultValue; - - return resultValue; - } - - void rewriteCallSite(IRCall* call, IRInst* dispatchFunc, IRInst* initialExistentialObject) - { - SLANG_RELEASE_ASSERT(call->getArgCount() != 0); - call->setOperand(0, dispatchFunc); - IRBuilder builder(call); - builder.setInsertBefore(call); - auto witnessTable = builder.emitExtractExistentialWitnessTable(initialExistentialObject); - auto newExistentialObject = builder.emitMakeExistential( - initialExistentialObject->getDataType(), - call->getOperand(1), - witnessTable); - call->setOperand(1, newExistentialObject); - } - - bool processWitnessLookup(IRLookupWitnessMethod* lookupInst) - { - auto witnessTableOperand = lookupInst->getWitnessTable(); - auto extractInst = as(witnessTableOperand); - if (!extractInst) - return false; - auto dispatchFunc = findOrCreateDispatchFunc(lookupInst); - if (!dispatchFunc) - return false; - bool changed = false; - auto existentialObject = extractInst->getOperand(0); - - IRBuilder builder(lookupInst); - builder.setInsertBefore(lookupInst); - traverseUses( - lookupInst, - [&](IRUse* use) - { - if (auto specialize = as(use->getUser())) - { - builder.setInsertBefore(use->getUser()); - List args; - for (UInt i = 0; i < specialize->getArgCount(); i++) - args.add(specialize->getArg(i)); - auto specializedType = builder.emitSpecializeInst( - builder.getTypeKind(), - dispatchFunc->getFullType(), - (UInt)args.getCount(), - args.getBuffer()); - auto newSpecialize = builder.emitSpecializeInst( - (IRType*)specializedType, - dispatchFunc, - (UInt)args.getCount(), - args.getBuffer()); - traverseUses( - specialize, - [&](IRUse* specializeUse) - { - if (auto call = as(specializeUse->getUser())) - { - changed = true; - rewriteCallSite(call, newSpecialize, existentialObject); - } - }); - } - else if (auto call = as(use->getUser())) - { - changed = true; - rewriteCallSite(call, dispatchFunc, existentialObject); - } - }); - return changed; - } - - bool processFunc(IRFunc* func) - { - bool changed = false; - for (auto bb : func->getBlocks()) - { - for (auto inst : bb->getModifiableChildren()) - { - if (auto witnessLookupInst = as(inst)) - { - changed |= processWitnessLookup(witnessLookupInst); - } - } - } - return changed; - } -}; - -bool lowerWitnessLookup(IRModule* module, DiagnosticSink* sink) -{ - bool changed = false; - WitnessLookupLoweringContext context; - context.module = module; - context.sink = sink; - context.init(); - - for (auto inst : module->getGlobalInsts()) - { - // Process all fully specialized functions and look for - // witness lookup instructions. If we see a lookup for a non-static function, - // create a dispatch function and replace the lookup with a call to the dispatch function. - if (auto func = as(inst)) - changed |= context.processFunc(func); - } - return changed; -} -} // namespace Slang diff --git a/source/slang/slang-ir-lower-witness-lookup.h b/source/slang/slang-ir-lower-witness-lookup.h deleted file mode 100644 index 434e97f3e86..00000000000 --- a/source/slang/slang-ir-lower-witness-lookup.h +++ /dev/null @@ -1,15 +0,0 @@ -// slang-ir-lower-witness-lookup.h -#pragma once - -namespace Slang -{ -struct IRModule; -class DiagnosticSink; - -/// Lower calls to a witness lookup into a call to a dispatch function. -/// For example, if we see call(witnessLookup(wt, key)), we will create a -/// dispatch function that calls into different implementations based on witness table -/// ID. The dispatch function will be called instead of witnessLookup. -bool lowerWitnessLookup(IRModule* module, DiagnosticSink* sink); - -} // namespace Slang diff --git a/source/slang/slang-ir-propagate-func-properties.cpp b/source/slang/slang-ir-propagate-func-properties.cpp index 0e5ee731a38..ece196079ae 100644 --- a/source/slang/slang-ir-propagate-func-properties.cpp +++ b/source/slang/slang-ir-propagate-func-properties.cpp @@ -82,6 +82,20 @@ class ReadNoneFuncPropertyPropagationContext : public FuncPropertyPropagationCon return true; } + bool isDebugInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_DebugLine: + case kIROp_DebugScope: + case kIROp_DebugVar: + case kIROp_DebugValue: + return true; + default: + return false; + } + } + virtual bool propagate(IRBuilder& builder, IRFunc* f) override { bool hasReadNoneCall = false; @@ -90,7 +104,7 @@ class ReadNoneFuncPropertyPropagationContext : public FuncPropertyPropagationCon for (auto inst : block->getChildren()) { // Is this inst known to not have global side effect/analyzable? - if (!isKnownOpCodeWithSideEffect(inst->getOp())) + if (!isKnownOpCodeWithSideEffect(inst->getOp()) && !isDebugInst(inst)) { if (inst->mightHaveSideEffects() || isResourceLoad(inst->getOp())) { @@ -122,6 +136,12 @@ class ReadNoneFuncPropertyPropagationContext : public FuncPropertyPropagationCon } } + // If the inst is a debug instruction, skip it. + // these are only annotations + // + if (isDebugInst(inst)) // TODO: May not need this + continue; + // Do any operands defined have pointer type of global or // unknown source? Passing them into a readNone callee may cause // side effects that breaks the readNone property. diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp deleted file mode 100644 index b862b3dd087..00000000000 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ /dev/null @@ -1,319 +0,0 @@ -#include "slang-ir-specialize-dispatch.h" - -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ -IRFunc* specializeDispatchFunction( - SharedGenericsLoweringContext* sharedContext, - IRFunc* dispatchFunc) -{ - auto witnessTableType = cast(dispatchFunc->getDataType())->getParamType(0); - auto conformanceType = cast(witnessTableType)->getConformanceType(); - // Collect all witness tables of `witnessTableType` in current module. - List witnessTables = - sharedContext->getWitnessTablesFromInterfaceType(conformanceType); - - SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock()); - auto block = dispatchFunc->getFirstBlock(); - - // The dispatch function before modification must be in the form of - // call(lookup_interface_method(witnessTableParam, interfaceReqKey), args) - // We now find the relavent instructions. - IRCall* callInst = nullptr; - IRLookupWitnessMethod* lookupInst = nullptr; - // Only used in debug builds as a sanity check - [[maybe_unused]] IRReturn* returnInst = nullptr; - for (auto inst : block->getOrdinaryInsts()) - { - switch (inst->getOp()) - { - case kIROp_Call: - callInst = cast(inst); - break; - case kIROp_LookupWitnessMethod: - lookupInst = cast(inst); - break; - case kIROp_Return: - returnInst = cast(inst); - break; - default: - break; - } - } - SLANG_ASSERT(callInst && lookupInst && returnInst); - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(dispatchFunc); - - // Create a new dispatch func to replace the existing one. - auto newDispatchFunc = builder->createFunc(); - - List paramTypes; - for (auto paramInst : dispatchFunc->getParams()) - { - paramTypes.add(paramInst->getFullType()); - } - - // Modify the first paramter from IRWitnessTable to IRWitnessTableID representing the sequential - // ID. - paramTypes[0] = builder->getWitnessTableIDType((IRType*)conformanceType); - - auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType()); - newDispatchFunc->setFullType(newDipsatchFuncType); - dispatchFunc->transferDecorationsTo(newDispatchFunc); - - builder->setInsertInto(newDispatchFunc); - auto newBlock = builder->emitBlock(); - - IRBlock* defaultBlock = nullptr; - - auto requirementKey = lookupInst->getRequirementKey(); - List params; - for (Index i = 0; i < paramTypes.getCount(); i++) - { - auto param = builder->emitParam(paramTypes[i]); - if (i > 0) - params.add(param); - } - auto witnessTableParam = newBlock->getFirstParam(); - - // `witnessTableParam` is expected to have `IRWitnessTableID` type, which - // will later lower into a `uint2`. We only use the first element of the uint2 - // to store the sequential ID and reserve the second 32-bit value for future - // pointer-compatibility. We insert a member extract inst right now - // to obtain the first element and use it in our switch statement. - UInt elemIdx = 0; - auto witnessTableSequentialID = - builder->emitSwizzle(builder->getUIntType(), witnessTableParam, 1, &elemIdx); - - // Generate case blocks for each possible witness table. - List caseBlocks; - for (Index i = 0; i < witnessTables.getCount(); i++) - { - auto witnessTable = witnessTables[i]; - auto seqIdDecoration = witnessTable->findDecoration(); - if (!seqIdDecoration) - { - sharedContext->sink->diagnose( - witnessTable->getConcreteType(), - Diagnostics::typeCannotBeUsedInDynamicDispatch, - witnessTable->getConcreteType()); - } - - if (i != witnessTables.getCount() - 1) - { - // Create a case block if we are not the last case. - caseBlocks.add(seqIdDecoration->getSequentialIDOperand()); - builder->setInsertInto(newDispatchFunc); - auto caseBlock = builder->emitBlock(); - caseBlocks.add(caseBlock); - } - else - { - // Generate code for the last possible value in the `default` block. - builder->setInsertInto(newDispatchFunc); - defaultBlock = builder->emitBlock(); - builder->setInsertInto(defaultBlock); - } - - auto callee = findWitnessTableEntry(witnessTable, requirementKey); - SLANG_ASSERT(callee); - auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params); - if (callInst->getDataType()->getOp() == kIROp_VoidType) - builder->emitReturn(); - else - builder->emitReturn(specializedCallInst); - } - - // Emit a switch statement to call the correct concrete function based on - // the witness table sequential ID passed in. - builder->setInsertInto(newDispatchFunc); - - - if (witnessTables.getCount() == 1) - { - // If there is only 1 case, no switch statement is necessary. - builder->setInsertInto(newBlock); - builder->emitBranch(defaultBlock); - } - else if (witnessTables.getCount() > 1) - { - auto breakBlock = builder->emitBlock(); - builder->setInsertInto(breakBlock); - builder->emitUnreachable(); - - builder->setInsertInto(newBlock); - builder->emitSwitch( - witnessTableSequentialID, - breakBlock, - defaultBlock, - caseBlocks.getCount(), - caseBlocks.getBuffer()); - } - else - { - // We have no witness tables that implements this interface. - // Just return a default value. - builder->setInsertInto(newBlock); - if (callInst->getDataType()->getOp() == kIROp_VoidType) - { - builder->emitReturn(); - } - else - { - auto defaultValue = builder->emitDefaultConstruct(callInst->getDataType()); - builder->emitReturn(defaultValue); - } - } - // Remove old implementation. - dispatchFunc->replaceUsesWith(newDispatchFunc); - dispatchFunc->removeAndDeallocate(); - - return newDispatchFunc; -} - -// Ensures every witness table object has been assigned a sequential ID. -// All witness tables will have a SequentialID decoration after this function is run. -// The sequantial ID in the decoration will be the same as the one specified in the Linkage. -// Otherwise, a new ID will be generated and assigned to the witness table object, and -// the sequantial ID map in the Linkage will be updated to include the new ID, so they -// can be looked up by the user via future Slang API calls. -void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContext) -{ - StringBuilder generatedMangledName; - - auto linkage = sharedContext->targetProgram->getTargetReq()->getLinkage(); - for (auto inst : sharedContext->module->getGlobalInsts()) - { - if (inst->getOp() == kIROp_WitnessTable) - { - UnownedStringSlice witnessTableMangledName; - if (auto instLinkage = inst->findDecoration()) - { - witnessTableMangledName = instLinkage->getMangledName(); - } - else - { - auto witnessTableType = as(inst->getDataType()); - if (witnessTableType && witnessTableType->getConformanceType() - ->findDecoration()) - { - // The interface is for specialization only, it would be an error if dynamic - // dispatch is used through the interface. Skip assigning ID for the witness - // table. - continue; - } - - // generate a unique linkage for it. - static int32_t uniqueId = 0; - uniqueId++; - if (auto nameHint = inst->findDecoration()) - { - generatedMangledName << nameHint->getName(); - } - generatedMangledName << "_generated_witness_uuid_" << uniqueId; - witnessTableMangledName = generatedMangledName.getUnownedSlice(); - } - - // If the inst already has a SequentialIDDecoration, stop now. - if (inst->findDecoration()) - continue; - - // Get a sequential ID for the witness table using the map from the Linkage. - uint32_t seqID = 0; - if (!linkage->mapMangledNameToRTTIObjectIndex.tryGetValue( - witnessTableMangledName, - seqID)) - { - auto interfaceType = - cast(inst->getDataType())->getConformanceType(); - if (as(interfaceType)) - { - auto interfaceLinkage = interfaceType->findDecoration(); - SLANG_ASSERT( - interfaceLinkage && "An interface type does not have a linkage," - "but a witness table associated with it has one."); - auto interfaceName = interfaceLinkage->getMangledName(); - auto idAllocator = - linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( - interfaceName); - if (!idAllocator) - { - linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0; - idAllocator = - linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( - interfaceName); - } - seqID = *idAllocator; - ++(*idAllocator); - } - else - { - // NoneWitness, has special ID of -1. - seqID = uint32_t(-1); - } - linkage->mapMangledNameToRTTIObjectIndex[witnessTableMangledName] = seqID; - } - - // Add a decoration to the inst. - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - builder.addSequentialIDDecoration(inst, seqID); - } - } -} - -// Fixes up call sites of a dispatch function, so that the witness table argument is replaced with -// its sequential ID. -void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* newDispatchFunc) -{ - List users; - for (auto use = newDispatchFunc->firstUse; use; use = use->nextUse) - { - users.add(use->getUser()); - } - for (auto user : users) - { - if (auto call = as(user)) - { - if (call->getCallee() != newDispatchFunc) - continue; - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(call); - List args; - for (UInt i = 0; i < call->getArgCount(); i++) - { - args.add(call->getArg(i)); - } - if (as(args[0]->getDataType())) - continue; - auto newCall = builder.emitCallInst(call->getFullType(), newDispatchFunc, args); - call->replaceUsesWith(newCall); - call->removeAndDeallocate(); - } - } -} - -void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext) -{ - // First we ensure that all witness table objects has a sequential ID assigned. - ensureWitnessTableSequentialIDs(sharedContext); - - // Generate specialized dispatch functions and fixup call sites. - for (const auto& [_, dispatchFunc] : sharedContext->mapInterfaceRequirementKeyToDispatchMethods) - { - // Generate a specialized `switch` statement based dispatch func, - // from the witness tables present in the module. - auto newDispatchFunc = specializeDispatchFunction(sharedContext, dispatchFunc); - - // Fix up the call sites of newDispatchFunc to pass in sequential IDs instead of - // witness table objects. - fixupDispatchFuncCall(sharedContext, newDispatchFunc); - } -} -} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dispatch.h b/source/slang/slang-ir-specialize-dispatch.h deleted file mode 100644 index f176229cfc4..00000000000 --- a/source/slang/slang-ir-specialize-dispatch.h +++ /dev/null @@ -1,13 +0,0 @@ -// slang-ir-specialize-dispatch.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Modifies the body of interface dispatch functions to use branching instead -/// of function pointer calls to implement the dynamic dispatch logic. -/// This is only used on GPU targets where function pointers are not supported -/// or are not efficient. -void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext); -} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp deleted file mode 100644 index 5bbc62bed5a..00000000000 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ /dev/null @@ -1,266 +0,0 @@ -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-specialize-dispatch.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ - -struct AssociatedTypeLookupSpecializationContext -{ - SharedGenericsLoweringContext* sharedContext; - - IRFunc* createWitnessTableLookupFunc(IRInterfaceType* interfaceType, IRInst* key) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(interfaceType); - - auto inputWitnessTableIDType = builder.getWitnessTableIDType(interfaceType); - auto requirementEntry = sharedContext->findInterfaceRequirementVal(interfaceType, key); - - auto resultWitnessTableType = cast(requirementEntry); - auto resultWitnessTableIDType = - builder.getWitnessTableIDType((IRType*)resultWitnessTableType->getConformanceType()); - - auto funcType = - builder.getFuncType(1, (IRType**)&inputWitnessTableIDType, resultWitnessTableIDType); - auto func = builder.createFunc(); - func->setFullType(funcType); - - if (auto linkage = key->findDecoration()) - builder.addNameHintDecoration(func, linkage->getMangledName()); - - builder.setInsertInto(func); - - auto block = builder.emitBlock(); - auto witnessTableParam = builder.emitParam(inputWitnessTableIDType); - - // `witnessTableParam` is expected to have `IRWitnessTableID` type, which - // will later lower into a `uint2`. We only use the first element of the uint2 - // to store the sequential ID and reserve the second 32-bit value for future - // pointer-compatibility. We insert a member extract inst right now - // to obtain the first element and use it in our switch statement. - UInt elemIdx = 0; - auto witnessTableSequentialID = - builder.emitSwizzle(builder.getUIntType(), witnessTableParam, 1, &elemIdx); - - // Collect all witness tables of `witnessTableType` in current module. - List witnessTables = - sharedContext->getWitnessTablesFromInterfaceType(interfaceType); - - // Generate case blocks for each possible witness table. - IRBlock* defaultBlock = nullptr; - List caseBlocks; - for (Index i = 0; i < witnessTables.getCount(); i++) - { - auto witnessTable = witnessTables[i]; - auto seqIdDecoration = witnessTable->findDecoration(); - SLANG_ASSERT(seqIdDecoration); - - if (i != witnessTables.getCount() - 1) - { - // Create a case block if we are not the last case. - caseBlocks.add(seqIdDecoration->getSequentialIDOperand()); - builder.setInsertInto(func); - auto caseBlock = builder.emitBlock(); - caseBlocks.add(caseBlock); - } - else - { - // Generate code for the last possible value in the `default` block. - builder.setInsertInto(func); - defaultBlock = builder.emitBlock(); - builder.setInsertInto(defaultBlock); - } - - auto resultWitnessTable = findWitnessTableEntry(witnessTable, key); - auto resultWitnessTableIDDecoration = - resultWitnessTable->findDecoration(); - SLANG_ASSERT(resultWitnessTableIDDecoration); - // Pack the resulting witness table ID into a `uint2`. - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - IRInst* uint2Args[] = { - resultWitnessTableIDDecoration->getSequentialIDOperand(), - builder.getIntValue(builder.getUIntType(), 0)}; - auto resultID = builder.emitMakeVector(uint2Type, 2, uint2Args); - builder.emitReturn(resultID); - } - - builder.setInsertInto(func); - - if (witnessTables.getCount() == 1) - { - // If there is only 1 case, no switch statement is necessary. - builder.setInsertInto(block); - builder.emitBranch(defaultBlock); - } - else - { - // If there are more than 1 cases, - // emit a switch statement to return the correct witness table ID based on - // the witness table ID passed in. - auto breakBlock = builder.emitBlock(); - builder.setInsertInto(breakBlock); - builder.emitUnreachable(); - - builder.setInsertInto(block); - builder.emitSwitch( - witnessTableSequentialID, - breakBlock, - defaultBlock, - caseBlocks.getCount(), - caseBlocks.getBuffer()); - } - - return func; - } - - void processLookupInterfaceMethodInst(IRLookupWitnessMethod* inst) - { - if (isComInterfaceType(inst->getWitnessTable()->getDataType())) - { - return; - } - - // Ignore lookups for RTTI objects for now, since they are not used anywhere. - if (!as(inst->getDataType())) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - auto zero = builder.getIntValue(builder.getUIntType(), 0); - IRInst* args[] = {zero, zero}; - auto zeroUint2 = builder.emitMakeVector(uint2Type, 2, args); - inst->replaceUsesWith(zeroUint2); - return; - } - - // Replace all witness table lookups with calls to specialized functions that directly - // returns the sequential ID of the resulting witness table, effectively getting rid - // of actual witness table objects in the target code (they all become IDs). - auto witnessTableType = inst->getWitnessTable()->getDataType(); - IRInterfaceType* interfaceType = cast( - cast(witnessTableType)->getConformanceType()); - if (!interfaceType) - return; - auto key = inst->getRequirementKey(); - IRFunc* func = nullptr; - if (!sharedContext->mapInterfaceRequirementKeyToDispatchMethods.tryGetValue(key, func)) - { - func = createWitnessTableLookupFunc(interfaceType, key); - sharedContext->mapInterfaceRequirementKeyToDispatchMethods[key] = func; - } - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto witnessTableArg = inst->getWitnessTable(); - auto callInst = builder.emitCallInst(func->getResultType(), func, witnessTableArg); - inst->replaceUsesWith(callInst); - inst->removeAndDeallocate(); - } - - void processGetSequentialIDInst(IRGetSequentialID* inst) - { - // If the operand is a witness table, it is already replaced with a uint2 - // at this point, where the first element in the uint2 is the id of the - // witness table. - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - UInt index = 0; - auto id = builder.emitSwizzle(builder.getUIntType(), inst->getRTTIOperand(), 1, &index); - inst->replaceUsesWith(id); - inst->removeAndDeallocate(); - } - - void processModule() - { - // Replace all `lookup_interface_method():IRWitnessTable` with call to specialized - // functions. - workOnModule( - sharedContext, - [this](IRInst* inst) - { - if (inst->getOp() == kIROp_LookupWitnessMethod) - { - processLookupInterfaceMethodInst(cast(inst)); - } - }); - - // Replace all direct uses of IRWitnessTables with its sequential ID. - workOnModule( - sharedContext, - [this](IRInst* inst) - { - if (inst->getOp() == kIROp_WitnessTable) - { - auto seqId = inst->findDecoration(); - if (!seqId) - return; - // Insert code to pack sequential ID into an uint2 at all use sites. - traverseUses( - inst, - [&](IRUse* use) - { - if (as(use->getUser())) - { - return; - } - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(use->getUser()); - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - IRInst* uint2Args[] = { - seqId->getSequentialIDOperand(), - builder.getIntValue(builder.getUIntType(), 0)}; - auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args); - builder.replaceOperand(use, uint2seqID); - }); - } - }); - - // Replace all `IRWitnessTableType`s with `IRWitnessTableIDType`. - for (auto globalInst : sharedContext->module->getGlobalInsts()) - { - if (globalInst->getOp() == kIROp_WitnessTableType) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(globalInst); - auto witnessTableIDType = builder.getWitnessTableIDType( - (IRType*)cast(globalInst)->getConformanceType()); - traverseUses( - globalInst, - [&](IRUse* use) - { - if (use->getUser()->getOp() == kIROp_WitnessTable) - return; - builder.replaceOperand(use, witnessTableIDType); - }); - } - } - - // `GetSequentialID(WitnessTableIDOperand)` becomes just `WitnessTableIDOperand`. - workOnModule( - sharedContext, - [this](IRInst* inst) - { - if (inst->getOp() == kIROp_GetSequentialID) - { - processGetSequentialIDInst(cast(inst)); - } - }); - } -}; - -void specializeDynamicAssociatedTypeLookup(SharedGenericsLoweringContext* sharedContext) -{ - AssociatedTypeLookupSpecializationContext context; - context.sharedContext = sharedContext; - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h deleted file mode 100644 index 83039eca54c..00000000000 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h +++ /dev/null @@ -1,15 +0,0 @@ -// slang-ir-specialize-dynamic-associatedtype-lookup.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Modifies the lookup of associatedtype entries from witness tables into -/// calls to a specialized "lookup" function that takes a witness table id -/// and returns a witness table id. -/// This is used on GPU targets where all witness tables are replaced as -/// integral IDs instead of a real pointer table. -void specializeDynamicAssociatedTypeLookup(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 781ccfd0500..0197826bde6 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -5,10 +5,12 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-insts.h" -#include "slang-ir-lower-witness-lookup.h" +#include "slang-ir-lower-dynamic-dispatch-insts.h" #include "slang-ir-peephole.h" #include "slang-ir-sccp.h" #include "slang-ir-ssa-simplification.h" +#include "slang-ir-typeflow-set.h" +#include "slang-ir-typeflow-specialize.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -52,6 +54,7 @@ struct SpecializationContext DiagnosticSink* sink; TargetProgram* targetProgram; SpecializationOptions options; + Dictionary irDictionaryMap; bool changed = false; @@ -269,7 +272,7 @@ struct SpecializationContext // using the simple key type defined as part of the IR cloning infrastructure. // typedef IRSimpleSpecializationKey Key; - Dictionary genericSpecializations; + Dictionary genericSpecializations; // Now let's look at the task of finding or generation a @@ -326,8 +329,7 @@ struct SpecializationContext // existing specialization that has been registered. // If one is found, our work is done. // - IRInst* specializedVal = nullptr; - if (genericSpecializations.tryGetValue(key, specializedVal)) + if (auto specializedVal = tryGetDictionaryEntry(genericSpecializations, key)) return specializedVal; } @@ -354,7 +356,12 @@ struct SpecializationContext // specializations so that we don't instantiate // this generic again for the same arguments. // - genericSpecializations.add(key, specializedVal); + genericSpecializations.add( + key, + addEntryToIRDictionary( + kIROp_GenericSpecializationDictionary, + key.vals, + specializedVal)); return specializedVal; } @@ -856,7 +863,57 @@ struct SpecializationContext IRInterfaceType* interfaceType = nullptr; if (!witnessTable) { - if (auto thisTypeWitness = as(lookupInst->getWitnessTable())) + if (auto witnessTableSet = as(lookupInst->getWitnessTable())) + { + auto requirementKey = lookupInst->getRequirementKey(); + + HashSet satisfyingValSet; + bool skipSpecialization = false; + forEachInSet( + witnessTableSet, + [&](IRInst* instElement) + { + if (auto table = as(instElement)) + { + if (auto satisfyingVal = findWitnessTableEntry(table, requirementKey)) + { + satisfyingValSet.add(satisfyingVal); + return; + } + } + + // If we reach here, we didn't find a satisfying value. + skipSpecialization = true; + }); + + if (!skipSpecialization) + { + IRBuilder builder(module); + auto setOp = getSetOpFromType(lookupInst->getDataType()); + auto newSet = builder.getSet(setOp, satisfyingValSet); + addUsersToWorkList(lookupInst); + if (as(newSet)) + { + lookupInst->replaceUsesWith(builder.getUntaggedUnionType(newSet)); + lookupInst->removeAndDeallocate(); + } + else if (as(newSet)) + { + lookupInst->replaceUsesWith(newSet); + lookupInst->removeAndDeallocate(); + } + else + { + // Should not see any other case. + SLANG_UNREACHABLE("unexpected set kind"); + } + + return true; + } + else + return false; + } + else if (auto thisTypeWitness = as(lookupInst->getWitnessTable())) { if (auto witnessTableType = as(thisTypeWitness->getDataType())) @@ -867,9 +924,10 @@ struct SpecializationContext interfaceType = as(witnessTableType->getConformanceType()); } } - - if (!interfaceType) + else + { return false; + } } // Because we have a concrete witness table, we can @@ -944,111 +1002,125 @@ struct SpecializationContext for (auto child = dictInst->getFirstChild(); child; child = child->next) childrenCount++; dict.reserve(Index{1} << Math::Log2Ceil(childrenCount * 2)); + + List invalidItems; for (auto child : dictInst->getChildren()) { auto item = as(child); if (!item) continue; IRSimpleSpecializationKey key; - bool shouldSkip = false; + bool isInvalid = false; for (UInt i = 0; i < item->getOperandCount(); i++) { if (item->getOperand(i) == nullptr) { - shouldSkip = true; - break; + isInvalid = true; } - if (item->getOperand(i)->getParent() == nullptr) + else if (item->getOperand(i)->getParent() == nullptr) { - shouldSkip = true; - break; + isInvalid = true; } if (as(item->getOperand(i))) { - shouldSkip = true; - break; + isInvalid = true; } + if (i > 0) { key.vals.add(item->getOperand(i)); } } - if (shouldSkip) + if (isInvalid) + { + invalidItems.add(item); + if (dict.containsKey(key)) + dict.remove(key); continue; - auto value = as::type>( - item->getOperand(0)); - SLANG_ASSERT(value); - dict[key] = value; + } + dict[key] = item; } - dictInst->removeAndDeallocate(); + + // Clean up the IR dictionary + for (auto item : invalidItems) + item->removeAndDeallocate(); } void readSpecializationDictionaries() { - auto moduleInst = module->getModuleInst(); - for (auto child : moduleInst->getChildren()) + _readSpecializationDictionaryImpl( + genericSpecializations, + getOrCreateIRDictionary(kIROp_GenericSpecializationDictionary)); + + _readSpecializationDictionaryImpl( + existentialSpecializedFuncs, + getOrCreateIRDictionary(kIROp_ExistentialFuncSpecializationDictionary)); + + _readSpecializationDictionaryImpl( + existentialSpecializedStructs, + getOrCreateIRDictionary(kIROp_ExistentialTypeSpecializationDictionary)); + } + + IRInst* getOrCreateIRDictionary(IROp dictOp) + { + if (irDictionaryMap.containsKey(dictOp)) + return irDictionaryMap[dictOp]; + + for (auto child : module->getModuleInst()->getChildren()) { - switch (child->getOp()) + if (child->getOp() == dictOp) { - case kIROp_GenericSpecializationDictionary: - _readSpecializationDictionaryImpl(genericSpecializations, child); - break; - case kIROp_ExistentialFuncSpecializationDictionary: - _readSpecializationDictionaryImpl(existentialSpecializedFuncs, child); - break; - case kIROp_ExistentialTypeSpecializationDictionary: - _readSpecializationDictionaryImpl(existentialSpecializedStructs, child); - break; - default: - continue; + irDictionaryMap[dictOp] = child; + return child; } } + + IRBuilder builder(module); + builder.setInsertInto(module); + auto dictInst = builder.emitIntrinsicInst(nullptr, dictOp, 0, nullptr); + irDictionaryMap[dictOp] = dictInst; + return dictInst; } - template - void _writeSpecializationDictionaryImpl(TDict& dict, IROp dictOp, IRInst* moduleInst) + IRSpecializationDictionaryItem* addEntryToIRDictionary( + IROp dictOp, + const List& key, + IRInst* val) { - IRBuilder builder(moduleInst); - builder.setInsertInto(moduleInst); - auto dictInst = builder.emitIntrinsicInst(nullptr, dictOp, 0, nullptr); - builder.setInsertInto(dictInst); + auto dictInst = getOrCreateIRDictionary(dictOp); List args; - for (const auto& [key, value] : dict) + args.add(val); + args.addRange(key); + IRBuilder builder(module); + builder.setInsertInto(dictInst); + return cast(builder.emitIntrinsicInst( + nullptr, + kIROp_SpecializationDictionaryItem, + (UInt)args.getCount(), + args.getBuffer())); + } + + // Look up an entry in the given specialization dictionary. + // + // Takes care of cases where the entry is not longer valid + // by removing the entry in the dictionary and returning null. + // + IRInst* tryGetDictionaryEntry( + Dictionary& dict, + const IRSimpleSpecializationKey& key) + { + IRSpecializationDictionaryItem* item = nullptr; + if (dict.tryGetValue(key, item)) { - if (!value->parent) - continue; - for (auto keyVal : key.vals) - { - if (!keyVal->parent) - goto next; - } + if (!as(item->getOperand(0))) + return item->getOperand(0); + else { - args.clear(); - args.add(value); - args.addRange(key.vals); - builder.emitIntrinsicInst( - nullptr, - kIROp_SpecializationDictionaryItem, - (UInt)args.getCount(), - args.getBuffer()); + dict.remove(key); + return nullptr; } - next:; } - } - void writeSpecializationDictionaries() - { - auto moduleInst = module->getModuleInst(); - _writeSpecializationDictionaryImpl( - genericSpecializations, - kIROp_GenericSpecializationDictionary, - moduleInst); - _writeSpecializationDictionaryImpl( - existentialSpecializedFuncs, - kIROp_ExistentialFuncSpecializationDictionary, - moduleInst); - _writeSpecializationDictionaryImpl( - existentialSpecializedStructs, - kIROp_ExistentialTypeSpecializationDictionary, - moduleInst); + + return nullptr; } // All of the machinery for generic specialization @@ -1057,9 +1129,9 @@ struct SpecializationContext // void processModule() { - // Read specialization dictionary from module if it is defined. - // This prevents us from generating duplicated specializations - // when this pass is invoked iteratively. + // Sync local dictionaries with the IR specialization + // dictionaries for faster lookup. + // readSpecializationDictionaries(); // The unspecialized IR we receive as input will have @@ -1180,18 +1252,17 @@ struct SpecializationContext // if (options.lowerWitnessLookups) { - iterChanged = lowerWitnessLookup(module, sink); + iterChanged = specializeDynamicInsts(module, sink); + if (iterChanged) + { + eliminateDeadCode(module->getModuleInst()); + lowerDispatchers(module, sink); + } } if (!iterChanged || sink->getErrorCount()) break; } - - - // For functions that still have `specialize` uses left, we need to preserve the - // its specializations in resulting IR so they can be reconstructed when this - // specialization pass gets invoked again. - writeSpecializationDictionaries(); } void addInstsToWorkListRec(IRInst* inst) @@ -1508,14 +1579,21 @@ struct SpecializationContext // Once we've constructed our key, we can try to look for an // existing specialization of the callee that we can use. // - IRFunc* specializedCallee = nullptr; - if (!existentialSpecializedFuncs.tryGetValue(key, specializedCallee)) + IRFunc* specializedCallee = + cast(tryGetDictionaryEntry(existentialSpecializedFuncs, key)); + + if (!specializedCallee) { // If we didn't find a specialized callee already made, then we // will go ahead and create one, and then register it in our cache. // specializedCallee = createExistentialSpecializedFunc(inst, calleeFunc); - existentialSpecializedFuncs.add(key, specializedCallee); + existentialSpecializedFuncs.add( + key, + addEntryToIRDictionary( + kIROp_ExistentialFuncSpecializationDictionary, + key.vals, + specializedCallee)); } // At this point we have found or generated a specialized version @@ -1739,7 +1817,8 @@ struct SpecializationContext // In order to cache and re-use functions that have had existential-type // parameters specialized, we need storage for the cache. // - Dictionary existentialSpecializedFuncs; + Dictionary + existentialSpecializedFuncs; // The logic for creating a specialized callee function by plugging // in concrete types for existentials is similar to other cases of @@ -2545,7 +2624,8 @@ struct SpecializationContext } } - Dictionary existentialSpecializedStructs; + Dictionary + existentialSpecializedStructs; bool maybeSpecializeBindExistentialsType(IRBindExistentialsType* type) { @@ -2639,10 +2719,12 @@ struct SpecializationContext key.vals.add(type->getExistentialArg(ii)); } - IRStructType* newStructType = nullptr; addUsersToWorkList(type); - if (!existentialSpecializedStructs.tryGetValue(key, newStructType)) + IRStructType* newStructType = + cast(tryGetDictionaryEntry(existentialSpecializedStructs, key)); + + if (!newStructType) { builder.setInsertBefore(baseStructType); newStructType = builder.createStructType(); @@ -2670,7 +2752,12 @@ struct SpecializationContext builder.createStructField(newStructType, oldField->getKey(), newFieldType); } - existentialSpecializedStructs.add(key, newStructType); + existentialSpecializedStructs.add( + key, + addEntryToIRDictionary( + kIROp_ExistentialTypeSpecializationDictionary, + key.vals, + newStructType)); } type->replaceUsesWith(newStructType); @@ -3073,12 +3160,242 @@ void finalizeSpecialization(IRModule* module) } } +// Evaluate a `Specialize` inst where the arguments are sets of types (or witness tables) +// rather than concrete singleton types and the generic returns a function. +// +// This needs to be slightly different from the usual case because the function +// needs dynamic information to select a specific element from each collection +// at runtime. +// +// The resulting function will therefore have additional parameters at the beginning +// to accept this information. +// +IRInst* specializeGenericWithSetArgs(IRSpecialize* specializeInst) +{ + // The high-level logic for specializing a generic to operate over collections + // is similar to specializing a simple generic: + // We "evaluate" the instructions in the first block of the generic and return + // the function that is returned by the generic. + // + // The key difference is that in the static case, all generic parameters, and instructions + // in the generic's body are guaranteed to be "baked out" into concrete types or witness tables. + // + // In the dynamic case, some generic parameters may turn into function parameters that accept a + // tag, and any lookup instructions might then have to be cloned into the function body. + // + // This is a slightly complex transformation that proceeds as follows: + // + // - Create an empty function that represents the final product. + // + // - Add any dynamic parameters of the generic to the function's first block. Keep track of the + // first block for later. For now, we only treat `WitnessTableType` parameters that have + // `WitnessTableSet` arguments (with atleast 2 distinct elements) as dynamic. Each such + // parameter will get a corresponding parameter of `TagType(tableSet)` + // + // - Clone in the rest of the generic's body into the first block of the function. + // The tricky part here is that we may have parameter types that depend on other parameters. + // This is a pattern that is allowed in the generic, but not in functions. + // + // To handle this, we maintain two cloning environments, a regular `cloneEnv` that registers + // the parameters, and a `staticCloneEnv` that is a child of `cloneEnv`, but overrides the + // dynamic parameters with their static collection. The `staticCloneEnv` is used to clone in + // the parameter and function types, while the `cloneEnv` is used for the rest. + // + // - When we reach the return value (i.e. the function inside the generic), we clone in the + // parameters + // of the function into the first block, and place them _after_ the parameters derived from + // the generic. Then the rest of the inner function's first block is cloned in. + // + // - All other blocks can be cloned in as usual. + // + + auto generic = cast(specializeInst->getBase()); + auto genericReturnVal = findGenericReturnVal(generic); + + IRBuilder builder(specializeInst->getModule()); + builder.setInsertInto(specializeInst->getModule()); + + // Let's start by creating the function itself. + auto loweredFunc = builder.createFunc(); + builder.setInsertInto(loweredFunc); + builder.setInsertInto(builder.emitBlock()); + + IRCloneEnv cloneEnv; + cloneEnv.squashChildrenMapping = true; + + IRCloneEnv staticCloningEnv; + // Use this as the child to 'override' certain elements in the parent environment with + // their static versions. + // + staticCloningEnv.parent = &cloneEnv; + + Index argIndex = 0; + List extraParamTypes; + OrderedDictionary extraParamMap; + // Map the generic's parameters to the specialized arguments. + for (auto param : generic->getFirstBlock()->getParams()) + { + auto specArg = specializeInst->getArg(argIndex++); + if (auto set = as(specArg)) + { + // We're dealing with a set of types. + if (as(param->getDataType())) + { + // TODO: This case should not happen anymore. + cloneEnv.mapOldValToNew[param] = builder.getUntaggedUnionType(set); + } + else if (as(param->getDataType())) + { + // For cloning parameter types, we want to just use the + // set. + // + staticCloningEnv.mapOldValToNew[param] = set; + + // We'll create an integer parameter for all the rest of + // the insts which will may need the runtime tag. + // + auto tagType = (IRType*)builder.getSetTagType(set); + // cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); + extraParamMap.add(param, builder.emitParam(tagType)); + extraParamTypes.add(tagType); + } + } + else + { + // For everything else, just set the parameter type to the argument; + SLANG_ASSERT(specArg->getParent()->getOp() == kIROp_ModuleInst); + cloneEnv.mapOldValToNew[param] = specArg; + } + } + + // The parameters we've used so far are merely tags, we can't use them directly + // without turning them into elements. + // + // We'll emit a `GetElementFromTag` on the parameter and map that to the original generic + // parameter. + // + for (auto paramPair : extraParamMap) + { + auto originalParam = paramPair.key; + auto newTagParam = paramPair.value; + + auto getElementInst = builder.emitGetElementFromTag(newTagParam); + cloneEnv.mapOldValToNew[originalParam] = getElementInst; + } + + // Clone in the rest of the generic's body including the blocks of the returned func. + for (auto inst = generic->getFirstBlock()->getFirstOrdinaryInst(); inst; + inst = inst->getNextInst()) + { + if (inst == genericReturnVal) + { + auto returnedFunc = cast(inst); + auto funcFirstBlock = returnedFunc->getFirstBlock(); + + builder.setInsertBefore(loweredFunc->getFirstBlock()); + for (auto decoration : returnedFunc->getDecorations()) + { + cloneInst(&staticCloningEnv, &builder, decoration); + } + + builder.setInsertInto(loweredFunc); + for (auto block : returnedFunc->getBlocks()) + { + // Merge the first block of the generic with the first block of the + // returned function to merge the parameter lists. + // + cloneEnv.mapOldValToNew[block] = cloneInstAndOperands(&cloneEnv, &builder, block); + } + + builder.setInsertInto(builder.getModule()); + auto loweredFuncType = + as(cloneInst(&staticCloningEnv, &builder, returnedFunc->getFullType())); + loweredFunc->setFullType((IRType*)loweredFuncType); + + builder.setInsertInto(loweredFunc->getFirstBlock()); + builder.emitBranch(as(cloneEnv.mapOldValToNew[funcFirstBlock])); + + for (auto param : funcFirstBlock->getParams()) + { + // Clone the parameters of the first block. + if (loweredFunc->getFirstBlock()->getFirstParam() == nullptr) + builder.setInsertBefore(loweredFunc->getFirstBlock()->getFirstChild()); + else + builder.setInsertAfter(loweredFunc->getFirstBlock()->getLastParam()); + + auto newParam = cloneInst(&staticCloningEnv, &builder, param); + cloneEnv.mapOldValToNew[param] = newParam; // Transfer the param to the dynamic env + } + + builder.setInsertInto(as(cloneEnv.mapOldValToNew[funcFirstBlock])); + for (auto _inst = funcFirstBlock->getFirstOrdinaryInst(); _inst; + _inst = _inst->getNextInst()) + { + // Clone the instructions in the first block. + cloneInst(&cloneEnv, &builder, _inst); + } + + for (auto block : returnedFunc->getBlocks()) + { + if (block == funcFirstBlock) + continue; // Already cloned the first block + + cloneInstDecorationsAndChildren( + &cloneEnv, + builder.getModule(), + block, + cloneEnv.mapOldValToNew[block]); + } + + // Add extra indices to the func-type parameters + List funcTypeParams; + for (Index i = 0; i < extraParamTypes.getCount(); i++) + funcTypeParams.add(extraParamTypes[i]); + + for (auto paramType : loweredFuncType->getParamTypes()) + funcTypeParams.add(paramType); + + // Set the new function type with the extra indices + loweredFunc->setFullType( + builder.getFuncType(funcTypeParams, loweredFuncType->getResultType())); + } + else if (as(inst)) + { + // Emit out into the global scope. + IRBuilder globalBuilder(builder.getModule()); + globalBuilder.setInsertInto(builder.getModule()); + cloneInst(&staticCloningEnv, &globalBuilder, inst); + } + else if (!as(inst)) + { + // Clone insts in the generic under two different environments: + // One that is "static" (for cloning types), and one that is "dynamic" + // which uses tags for witness tables instead of a static collection. + // + // We'll want to use the dynamic environment for cloning everything in the function + // body, but the static environment for cloning parameter types. + // + // Note that this can result in some types inside the function (i.e. not used in the + // parameter types) being cloned under the dynamic environment, but + // a subsequent pass of dynamic-inst-lowering will convert those to the static form. + // + cloneInst(&staticCloningEnv, &builder, inst); + cloneInst(&cloneEnv, &builder, inst); + } + } + + return loweredFunc; +} + IRInst* specializeGenericImpl( IRGeneric* genericVal, IRSpecialize* specializeInst, IRModule* module, SpecializationContext* context) { + if (isSetSpecializedGeneric(specializeInst)) + return specializeGenericWithSetArgs(specializeInst); + // Effectively, specializing a generic amounts to "calling" the generic // on its concrete argument values and computing the // result it returns. @@ -3216,6 +3533,10 @@ IRInst* specializeGeneric(IRSpecialize* specializeInst) if (!module) return specializeInst; + if (isSetSpecializedGeneric(specializeInst)) + return specializeGenericWithSetArgs(specializeInst); + + // Standard static specialization of generic. return specializeGenericImpl(baseGeneric, specializeInst, module, nullptr); } diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h index d069324b2d9..082ea50bad2 100644 --- a/source/slang/slang-ir-specialize.h +++ b/source/slang/slang-ir-specialize.h @@ -4,6 +4,8 @@ namespace Slang { struct IRModule; +struct IRInst; +struct IRSpecialize; class DiagnosticSink; class TargetProgram; @@ -24,4 +26,11 @@ bool specializeModule( void finalizeSpecialization(IRModule* module); +IRInst* specializeGeneric(IRSpecialize* specInst); + +// Specialize a generic with one or more arguments that are collections rather +// than single concrete values. +// +IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst); + } // namespace Slang diff --git a/source/slang/slang-ir-strip-legalization-insts.cpp b/source/slang/slang-ir-strip-legalization-insts.cpp index f55d3f11974..316d091c6ca 100644 --- a/source/slang/slang-ir-strip-legalization-insts.cpp +++ b/source/slang/slang-ir-strip-legalization-insts.cpp @@ -68,12 +68,8 @@ void unpinWitnessTables(IRModule* module) if (!witnessTable) continue; - // If a witness table is not used for dynamic dispatch, unpin it. - if (!witnessTable->findDecoration()) - { - while (auto decor = witnessTable->findDecoration()) - decor->removeAndDeallocate(); - } + while (auto decor = witnessTable->findDecoration()) + decor->removeAndDeallocate(); } } diff --git a/source/slang/slang-ir-typeflow-set.cpp b/source/slang/slang-ir-typeflow-set.cpp new file mode 100644 index 00000000000..b712ffbb5eb --- /dev/null +++ b/source/slang/slang-ir-typeflow-set.cpp @@ -0,0 +1,119 @@ +#include "slang-ir-typeflow-set.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ + +// Upcast the value in 'arg' to match the destInfo type. This method inserts +// any necessary reinterprets or tag translation instructions. +// +IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo) +{ + // The upcasting process inserts the appropriate instructions + // to make arg's type match the type provided by destInfo. + // + // This process depends on the structure of arg and destInfo. + // + // We only deal with the type-flow data-types that are created in + // our pass (SetBase/TaggedUnionType/SetTagType/any other + // composites of these insts) + // + + auto argInfo = arg->getDataType(); + if (!argInfo || !destInfo) + return arg; + + if (as(argInfo) && as(destInfo)) + { + // A tagged union is essentially a tuple(TagType(tableSet), + // typeSet) We simply extract the two components, upcast each one, and put it + // back together. + // + + auto argTUType = as(argInfo); + auto destTUType = as(destInfo); + + if (argTUType != destTUType) + { + auto argTableTag = builder->emitGetTagFromTaggedUnion(arg); + auto reinterpretedTableTag = upcastSet( + builder, + argTableTag, + builder->getSetTagType(destTUType->getWitnessTableSet())); + + auto argTypeTag = builder->emitGetTypeTagFromTaggedUnion(arg); + auto reinterpretedTypeTag = + upcastSet(builder, argTypeTag, builder->getSetTagType(destTUType->getTypeSet())); + + auto argVal = builder->emitGetValueFromTaggedUnion(arg); + auto reinterpretedVal = + upcastSet(builder, argVal, builder->getUntaggedUnionType(destTUType->getTypeSet())); + return builder->emitMakeTaggedUnion( + destTUType, + reinterpretedTypeTag, + reinterpretedTableTag, + reinterpretedVal); + } + } + else if (as(argInfo) && as(destInfo)) + { + // If the arg represents a tag of a set, but the dest is a _different_ + // set, then we need to emit a tag operation to reinterpret the + // tag. + // + // Note that, by the invariant provided by the typeflow analysis, the target + // set must necessarily be a super-set. + // + if (argInfo != destInfo) + { + return builder->emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperSet, 1, &arg); + } + } + else if (as(argInfo) && as(destInfo)) + { + // If the arg has a untagged union type, but the dest is a _different_ untagged union, + // we need to perform a reinterpret. + // + // e.g. TypeSet({T1, T2}) may lower to AnyValueType(N), while + // TypeSet({T1, T2, T3}) may lower to AnyValueType(M). Since the target + // is necessarily a super-set, the target any-value-type is always larger (M >= N), + // so we only need a simple reinterpret. + // + if (argInfo != destInfo) + { + auto argSet = as(argInfo)->getSet(); + if (argSet->isSingleton() && as(argSet->getElement(0))) + { + // There's a specific case where we're trying to reinterpret a value of 'none' + // type. We'll avoid emitting a reinterpret in this case, and emit a + // default-construct instead. + // + return builder->emitDefaultConstruct((IRType*)destInfo); + } + + // General case: + // + // If the sets of witness tables are not equal, reinterpret to the + // parameter type + // + return builder->emitReinterpret((IRType*)destInfo, arg); + } + } + else if (!as(argInfo) && as(destInfo)) + { + // If the arg is not a collection-type, but the dest is a collection, + // we need to perform a pack operation. + // + // This case only arises when passing a value of type T to a parameter + // of a type-set that contains T. + // + return builder->emitPackAnyValue((IRType*)destInfo, arg); + } + + return arg; // Can use as-is. +} + +} // namespace Slang diff --git a/source/slang/slang-ir-typeflow-set.h b/source/slang/slang-ir-typeflow-set.h new file mode 100644 index 00000000000..6431af6ac0b --- /dev/null +++ b/source/slang/slang-ir-typeflow-set.h @@ -0,0 +1,25 @@ +// slang-ir-typeflow-set.h +#pragma once +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ + +// +// Helpers to iterate over elements of a collection. +// + +template +void forEachInSet(IRSetBase* info, F func) +{ + for (UInt i = 0; i < info->getOperandCount(); ++i) + func(info->getOperand(i)); +} + +// Upcast the value in 'arg' to match the destInfo type. This method inserts +// any necessary reinterprets or tag translation instructions. +// +IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo); + +} // namespace Slang diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp new file mode 100644 index 00000000000..3fc863d039e --- /dev/null +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -0,0 +1,4726 @@ +#include "slang-ir-typeflow-specialize.h" + +#include "slang-ir-any-value-marshalling.h" +#include "slang-ir-clone.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-insts.h" +#include "slang-ir-specialize.h" +#include "slang-ir-typeflow-set.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + + +namespace Slang +{ + +// Basic unit for which we keep track of propagation information. +// +// This unit has two components: an 'inst' and a 'context' under which we +// are recording propagation info. +// +// The 'inst' must be inside a block (with either generic or func parent), since +// we assume everything in the global scope is concrete. +// +// The 'context' can be one of two cases: +// 1. an IRFunc ONLY if it is not generic (func is in the global scope). 'inst' must +// be inside the func. +// 2. an IRSpecialize(generic, ...). 'inst' must be inside the generic. the +// specialization args must all be global values (either concrete types/values, or collections). +// +// All other possibilites for 'context' are illegal. +// `InstWithContext::validateInstWithContext` enforces these rules. +// +// For an inst inside a generic, it is possible to have different propagation information +// depending on the specialization args, which is why we use the pair to keep track of the context. +// +struct InstWithContext +{ + IRInst* context; + IRInst* inst; + + InstWithContext() + : context(nullptr), inst(nullptr) + { + } + + InstWithContext(IRInst* context, IRInst* inst) + : context(context), inst(inst) + { + validateInstWithContext(); + } + + void validateInstWithContext() const + { + switch (context->getOp()) + { + case kIROp_Func: + { + SLANG_ASSERT(inst->getParent()->getParent() == context); + break; + } + case kIROp_Specialize: + { + auto generic = as(context)->getBase(); + // The base should be a parent of the inst. + bool foundParent = false; + for (auto parent = inst->getParent(); parent; parent = parent->getParent()) + { + if (parent == generic) + { + foundParent = true; + break; + } + } + SLANG_ASSERT(foundParent); + } + break; + default: + { + SLANG_UNEXPECTED("Invalid context for InstWithContext"); + } + } + } + + // If a context is not specified, we assume it is not in a generic, and + // simply use the parent func. + // + InstWithContext(IRInst* inst) + : inst(inst) + { + auto block = cast(inst->getParent()); + auto func = cast(block->getParent()); + + // If parent func is not a global, then it is not a direct + // reference. An explicit IRSpecialize instruction must be provided as + // context. + // + SLANG_ASSERT(func->getParent()->getOp() == kIROp_ModuleInst); + + context = func; + } + + bool operator==(const InstWithContext& other) const + { + return context == other.context && inst == other.inst; + } + + HashCode64 getHashCode() const { return combineHash(HashCode(context), HashCode(inst)); } +}; + +// Test if inst represents a pointer to a global resource. +bool isResourcePointer(IRInst* inst) +{ + if (isPointerToResourceType(inst->getDataType()) || + inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) + return true; + + if (as(inst)) + return true; + + if (auto ptr = as(inst)) + return isResourcePointer(ptr->getBase()); + + if (auto fieldAddress = as(inst)) + return isResourcePointer(fieldAddress->getBase()); + + return false; +} + +// Test if the callee represents an invalid call. This can arise from looking something up +// on a 'none' witness (in the presence of optional witness tables) +// +bool isNoneCallee(IRInst* callee) +{ + if (auto lookupWitness = as(callee)) + { + if (auto table = as(lookupWitness->getWitnessTable())) + { + return table->getConcreteType()->getOp() == kIROp_VoidType; + } + } + + return false; +} + +// Represents an interprocedural edge between call sites and functions +struct InterproceduralEdge +{ + enum class Direction + { + CallToFunc, // From call site to function entry (propagating arguments) + FuncToCall // From function return to call site (propagating return value) + }; + + Direction direction; + IRInst* callerContext; // The context of the call (e.g. function or specialized generic) + IRCall* callInst; // The call instruction + IRInst* targetContext; // The function/specialized-generic being called or returned from + + InterproceduralEdge() = default; + InterproceduralEdge(Direction dir, IRInst* callerContext, IRCall* call, IRInst* func) + : direction(dir), callerContext(callerContext), callInst(call), targetContext(func) + { + } +}; + +// Representation of a work item used to register work for the main propagation queue. +// When the propagation information for a particular inst is modified non-trivially, new +// 'WorkItem' objects are added to the queue to further propagate the changes. +// +// The "Type" captures the granularity & type of the propagation work, while the union +// holds on to any auxiliary information. +// +struct WorkItem +{ + enum class Type + { + None, // Invalid + Inst, // Propagate through a single instruction. + Block, // Propagate through each instruction in a block. + IntraProc, // Propagate through within-function edge (IREdge) + InterProc // Propagate across function call/return (InterproceduralEdge) + }; + + Type type; + IRInst* context; // The context of the work item. + union + { + IRInst* inst; // Type::Inst + IRBlock* block; // Type::Block + IREdge intraProcEdge; // Type::IntraProc + InterproceduralEdge interProcEdge; // Type::InterProc + }; + + WorkItem() + : type(Type::None) + { + } + + WorkItem(IRInst* context, IRInst* inst) + : type(Type::Inst), inst(inst), context(context) + { + SLANG_ASSERT(context != nullptr && inst != nullptr); + // Validate that the context is appropriate for the instruction + InstWithContext(context, inst).validateInstWithContext(); + } + + WorkItem(IRInst* context, IRBlock* block) + : type(Type::Block), block(block), context(context) + { + SLANG_ASSERT(context != nullptr && block != nullptr); + // Validate that the context is appropriate for the block + InstWithContext(context, block->getFirstChild()).validateInstWithContext(); + } + + WorkItem(IRInst* context, IREdge edge) + : type(Type::IntraProc), intraProcEdge(edge), context(context) + { + SLANG_ASSERT(context != nullptr); + } + + WorkItem(InterproceduralEdge edge) + : type(Type::InterProc), interProcEdge(edge), context(nullptr) + { + } + + WorkItem(InterproceduralEdge::Direction dir, IRInst* callerCtx, IRCall* call, IRInst* callee) + : type(Type::InterProc), interProcEdge(dir, callerCtx, call, callee), context(nullptr) + { + } + + // Copy constructor and assignment needed for union with non-trivial types + WorkItem(const WorkItem& other) + : type(other.type), context(other.context) + { + if (type == Type::IntraProc) + intraProcEdge = other.intraProcEdge; + else if (type == Type::InterProc) + interProcEdge = other.interProcEdge; + else if (type == Type::Inst) + inst = other.inst; + else + block = other.block; + } + + WorkItem& operator=(const WorkItem& other) + { + type = other.type; + context = other.context; + if (type == Type::IntraProc) + intraProcEdge = other.intraProcEdge; + else if (type == Type::InterProc) + interProcEdge = other.interProcEdge; + else if (type == Type::Inst) + inst = other.inst; + else + block = other.block; + return *this; + } +}; + +// Returns true if the two propagation infos are equal. +bool areInfosEqual(IRInst* a, IRInst* b) +{ + // Since all inst opcodes that are used to represent propagation information + // are hoistable and automatically de-duplicated by the Slang IR infrastructure, + // we can simply test pointer equality + // + return a == b; +} + +// Helper data-structure for efficient enqueueing/dequeueing of work items. +// +// Our 'List' data-structure are currently designed to be efficient at operating as a stack +// but have poor performance for queue-like operations, so this rolls two stacks into a queue +// structure. +// +struct WorkQueue +{ + List enqueueList; + List dequeueList; + Index dequeueIndex = 0; + + void enqueue(const WorkItem& item) { enqueueList.add(item); } + + WorkItem dequeue() + { + if (dequeueList.getCount() <= dequeueIndex) + { + dequeueList.swapWith(enqueueList); + enqueueList.clear(); + dequeueIndex = 0; + } + + SLANG_ASSERT(dequeueIndex < dequeueList.getCount()); + return dequeueList[dequeueIndex++]; + } + + bool hasItems() const + { + return (dequeueIndex < dequeueList.getCount()) || (enqueueList.getCount() > 0); + } +}; + +// Tests whether a generic can be fully specialized, or if it requires a dynamic information. +// +// This test is primarily used to determine if additional parameters are requried to place a call to +// this callee. +// +bool isSetSpecializedGeneric(IRInst* callee) +{ + // If the callee is a specialization, and at least one of its arguments + // is a collection, then it needs dynamic-dispatch logic to be generated. + // + if (auto specialize = as(callee)) + { + for (UInt i = 0; i < specialize->getArgCount(); i++) + { + // Only functions need set-aware specialization. + auto generic = specialize->getBase(); + if (getGenericReturnVal(generic)->getOp() != kIROp_Func) + return false; + + auto arg = specialize->getArg(i); + if (as(arg)) + return true; // Found a set argument + } + return false; // No set arguments found + } + + return false; +} + +IRInst* getArrayStride(IRArrayType* arrayType) +{ + if (arrayType->getOperandCount() >= 3) + return arrayType->getStride(); + return nullptr; +} + +// +// Helper struct to represent a parameter's direction and type component. +// This is used by the type flow system to figure out which direction to propagate +// information for each parameter. +// +struct ParameterDirectionInfo +{ + enum Kind + { + In, + BorrowIn, + Out, + BorrowInOut, + Ref + } kind; + + // For Ref and BorrowInOut + AddressSpace addressSpace; + + ParameterDirectionInfo(Kind kind, AddressSpace addressSpace = (AddressSpace)0) + : kind(kind), addressSpace(addressSpace) + { + } + + ParameterDirectionInfo() + : kind(Kind::In), addressSpace((AddressSpace)0) + { + } + + bool operator==(const ParameterDirectionInfo& other) const + { + return kind == other.kind && addressSpace == other.addressSpace; + } +}; + +// Split parameter type into a direction and a type +std::tuple splitParameterDirectionAndType(IRType* paramType) +{ + if (as(paramType)) + return { + ParameterDirectionInfo(ParameterDirectionInfo::Kind::Out), + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirectionInfo(ParameterDirectionInfo::Kind::BorrowInOut), + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirectionInfo( + ParameterDirectionInfo::Kind::Ref, + as(paramType)->getAddressSpace()), + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirectionInfo( + ParameterDirectionInfo::Kind::BorrowIn, + as(paramType)->getAddressSpace()), + as(paramType)->getValueType()}; + else + return {ParameterDirectionInfo(ParameterDirectionInfo::Kind::In), paramType}; +} + +// Join parameter direction and a type back into a parameter type +IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo info, IRType* type) +{ + switch (info.kind) + { + case ParameterDirectionInfo::Kind::In: + return type; + case ParameterDirectionInfo::Kind::Out: + return builder->getOutParamType(type); + case ParameterDirectionInfo::Kind::BorrowInOut: + return builder->getBorrowInOutParamType(type); + case ParameterDirectionInfo::Kind::BorrowIn: + return builder->getBorrowInParamType(type, info.addressSpace); + case ParameterDirectionInfo::Kind::Ref: + return builder->getRefParamType(type, info.addressSpace); + default: + SLANG_UNEXPECTED("Unhandled parameter info in fromDirectionAndType"); + } +} + +// Helper to test if an inst is in the global scope. +bool isGlobalInst(IRInst* inst) +{ + return inst->getParent()->getOp() == kIROp_ModuleInst; +} + +// This is fairly fundamental check: +// This method checks whether a inst's type cannot accept any further refinement. +// +// Examples: +// 1. an inst of `UInt` type cannot be further refined (under the current scope +// of the type-flow pass), since it has a concrete type, and we cannot replace it +// with a "narrower" type. +// +// 2. an inst of `InterfaceType` represents a tagged union of any type that implements +// the interface, so it can be further refined by determining a smaller set of +// possibilities (i.e. via `TaggedUnionType(tableSet, typeSet)`). +// +// 3. Similarly, an inst of `WitnessTableType` represents any witness table, +// so it can accept a further refinement into `ElementOfSetType(tableSet)`. +// +// In the future, we may want to extend this check to something more nuanced, +// which takes in both the inst and the refined type to determine if we want to +// accept the refinement (this is useful in cases like `UnsizedArrayType`, where +// we only want to refine it if we can determine a concrete size). +// +bool isConcreteType(IRInst* inst) +{ + bool isInstGlobal = isGlobalInst(inst); + if (!isInstGlobal) + return false; + + switch (inst->getOp()) + { + case kIROp_InterfaceType: // Can be refined to tagged unions + return isComInterfaceType(cast(inst)); + case kIROp_WitnessTableType: // Can be refined into set of concrete tables + return false; + case kIROp_FuncType: // Can be refined into set of concrete functions + return false; + case kIROp_GenericKind: // Can be refined into set of concrete generics + return false; + case kIROp_TypeKind: // Can be refined into set of concrete types + return false; + case kIROp_TypeType: // Can be refined into set of concrete types + return false; + case kIROp_ArrayType: + return isConcreteType(cast(inst)->getElementType()) && + isGlobalInst(cast(inst)->getElementCount()); + case kIROp_OptionalType: + return isConcreteType(cast(inst)->getValueType()); + default: + break; + } + + if (as(inst)) + { + auto ptrType = as(inst); + return isConcreteType(ptrType->getValueType()); + } + + if (auto generic = as(inst)) + { + if (as(getGenericReturnVal(generic))) + return false; // Can be refined into set of concrete generics. + } + + return true; +} + +IRInst* makeInfoForConcreteType(IRModule* module, IRInst* type) +{ + SLANG_ASSERT(isConcreteType(type)); + IRBuilder builder(module); + if (auto ptrType = as(type)) + { + return builder.getPtrTypeWithAddressSpace( + (IRType*)makeInfoForConcreteType(module, ptrType->getValueType()), + ptrType); + } + + if (auto arrayType = as(type)) + { + return builder.getArrayType( + (IRType*)makeInfoForConcreteType(module, arrayType->getElementType()), + arrayType->getElementCount(), + getArrayStride(arrayType)); + } + + return builder.getUntaggedUnionType( + cast(builder.getSingletonSet(kIROp_TypeSet, type))); +} + +// Determines a suitable set opcode to use to represent a set of elements of the given type. +// +// e.g. an inst of WitnessTableType can be be refined using a WitnessTableSet (but it doesn't make +// sense to use a TypeSet or FuncSet). +// +// This is primarily used for `LookupWitnessMethod`, where the resulting element could be +// a generic, witness table, type or function, depending on the type of the lookup inst. +// +IROp getSetOpFromType(IRType* type) +{ + switch (type->getOp()) + { + case kIROp_WitnessTableType: // Can be refined into set of concrete tables + return kIROp_WitnessTableSet; + case kIROp_FuncType: // Can be refined into set of concrete functions + return kIROp_FuncSet; + case kIROp_GenericKind: // Can be refined into set of concrete generics + return kIROp_GenericSet; + case kIROp_TypeKind: // Can be refined into set of concrete types + return kIROp_TypeSet; + case kIROp_TypeType: // Can be refined into set of concrete types + return kIROp_TypeSet; + default: + break; + } + + if (auto generic = as(type)) + { + auto innerValType = getGenericReturnVal(generic); + if (as(innerValType) || as(innerValType)) + return kIROp_GenericSet; // Can be refined into set of concrete generics. + } + + // Slight workaround for the fact that we can have cases where the type has not been specialized + // yet (particularly from auto-diff) + // + if (auto specialize = as(type)) + { + auto innerValType = getGenericReturnVal(specialize->getBase()); + if (as(innerValType)) + return kIROp_FuncSet; + if (as(innerValType)) + return kIROp_WitnessTableSet; + } + + return kIROp_Invalid; +} + +// Helper to check if an IRParam is a function parameter (vs. a phi param or generic param) +bool isFuncParam(IRParam* param) +{ + auto paramBlock = as(param->getParent()); + auto paramFunc = as(paramBlock->getParent()); + return (paramFunc && paramFunc->getFirstBlock() == paramBlock); +} + +// Helper to test if a function or generic contains a body (i.e. is intrinsic/external) +// For the purposes of type-flow, if a function body is not available, we can't analyze it. +// +bool isIntrinsic(IRInst* inst) +{ + auto func = as(inst); + if (auto specialize = as(inst)) + { + auto generic = specialize->getBase(); + func = as(getGenericReturnVal(generic)); + } + + if (!func) + return false; + + if (func->getFirstBlock() == nullptr) + return true; + + return false; +} + +// Returns true if the inst is of the form OptionalType +bool isOptionalExistentialType(IRInst* inst) +{ + if (auto optionalType = as(inst)) + if (auto interfaceType = as(optionalType->getValueType())) + return !isComInterfaceType(interfaceType) && !isBuiltin(interfaceType); + return false; +} + +// Parent context for the full type-flow pass. +struct TypeFlowSpecializationContext +{ + // Make a tagged-union-type out of a given set of tables. + // + // This type can be used for insts that are semantically a tuple of a tag (to select a table) + // and a payload to contain the existential value. + // + IRTaggedUnionType* makeTaggedUnionType(IRWitnessTableSet* tableSet) + { + IRBuilder builder(module); + HashSet typeSet; + + // Create a type set out of the base types from each table. + forEachInSet( + tableSet, + [&](IRInst* witnessTable) + { + switch (witnessTable->getOp()) + { + case kIROp_UnboundedWitnessTableElement: + typeSet.add(builder.getUnboundedTypeElement( + as(witnessTable)->getBaseInterfaceType())); + break; + case kIROp_WitnessTable: + typeSet.add(as(witnessTable)->getConcreteType()); + break; + case kIROp_UninitializedWitnessTableElement: + typeSet.add(builder.getUninitializedTypeElement( + as(witnessTable) + ->getBaseInterfaceType())); + break; + case kIROp_NoneWitnessTableElement: + typeSet.add(builder.getNoneTypeElement()); + break; + } + }); + + // Create the tagged union type out of the type and table collection. + return builder.getTaggedUnionType( + tableSet, + cast(builder.getSet(kIROp_TypeSet, typeSet))); + } + + // Creates an 'empty' inst (denoted by nullptr), that + // can be used to denote one of two things: + // + // 1. This inst does not have any dynamic components to specialize. + // e.g. an inst with a concrete int-type. + // 2. No possibilties have been propagated for this inst yet. This is the + // default starting state of all insts. + // + // From an order-theoretic perspective, 'none' is the bottom of the lattice. + // + IRInst* none() { return nullptr; } + + // Make an untagged-union type out of a given set of types. + // + // This is used to denote insts whose value can be of multiple possible types, + // Note that unlike tagged-unions, untagged-unions do not have any information + // on which type is currently held. + // + // Typically used as the type of the value part of existential objects. + // + IRUntaggedUnionType* makeUntaggedUnionType(IRTypeSet* typeSet) + { + IRBuilder builder(module); + return builder.getUntaggedUnionType(typeSet); + } + + // Make an element-of-set type out of a given set. + // + // ElementOfSetType can be used as the type of an inst whose + // _value_ is known to be one of the elements of the set. + // + // e.g. + // if we have IR of the form: + // %1 : ElementOfSetType> = ExtractExistentialType(%existentialObj) + // + // then %1 is Int or Float. + // (Note that this is different from saying %1's value is an int or a float) + // + IRElementOfSetType* makeElementOfSetType(IRSetBase* set) + { + IRBuilder builder(module); + return builder.getElementOfSetType(set); + } + + // Make a tag-type for a given set. + // + // `TagType(set)` is used to denote insts that are carrying an identifier for one of the + // elements of the set. + // These insts cannot be used directly as one of the values (must be used with `GetDispatcher` + // or `GetElementFromTag`) before they can be used as values. + // + IRSetTagType* makeTagType(IRSetBase* set) + { + IRBuilder builder(module); + return builder.getSetTagType(set); + } + + IRInst* _tryGetInfo(InstWithContext element) + { + auto found = propagationMap.tryGetValue(element); + if (found) + return *found; + return none(); // Default info for any inst that we haven't registered. + } + + // Find information for an inst that is being used as an argument + // under a given context. + // + // This is a bit different from `tryGetInfo`, in that we want to + // obtain an info that can be passed to a parameter (which can have multiple + // possibilities stemming from different argument types) + // + // This key difference is that even if the argument has no propagated info because it + // is not a dynamic or abstract type, it's possible that the parameter's type is + // abstract. + // Thus, we will construct an `UntaggedUnionType` of the concrete type so that it + // can be union-ed with other argument infos. + // + // Such 'argument' cases arise for phi-args and function args. + // + IRInst* tryGetArgInfo(IRInst* context, IRInst* inst) + { + if (auto info = tryGetInfo(context, inst)) + return info; + + if (isConcreteType(inst->getDataType())) + { + // If the inst has a concrete type, we can make a default info for it, + // considering it as a singleton set. Note that this needs to be + // nested into any relevant type structure that we want to propagate through + // `makeInfoForConcreteType` handles this logic. + // + return makeInfoForConcreteType(module, inst->getDataType()); + } + + return none(); + } + + // Bottleneck method to fetch the current propagation info + // for a given instruction under context. + // + IRInst* tryGetInfo(IRInst* context, IRInst* inst) + { + if (inst->getDataType()) + { + // If the data-type is already a tagged union or untagged union or + // element-of-set type, then the refinement occured during a previous phase. + // + // For now, we simply re-use that info directly. + // + // In the future, it makes sense to treat it as non-concrete and use + // them as an upper-bound for further refinement. + // + switch (inst->getDataType()->getOp()) + { + case kIROp_TaggedUnionType: + case kIROp_UntaggedUnionType: + case kIROp_ElementOfSetType: + return inst->getDataType(); + } + } + + // A small check for de-allocated insts. + if (!inst->getParent()) + return none(); + + // Global insts always have no info. + if (as(inst->getParent())) + return none(); + + return _tryGetInfo(InstWithContext(context, inst)); + } + + // Performs set-union over the two sets, and returns a new + // inst to represent the set. + // + template + T* unionSet(T* set1, T* set2) + { + // It's possible to accelerate this further, but we usually + // don't have to deal with overly large sets (usually 3-20 elements) + // + + SLANG_ASSERT(as(set1) && as(set2)); + SLANG_ASSERT(set1->getOp() == set2->getOp()); + + if (!set1) + return set2; + if (!set2) + return set1; + if (set1 == set2) + return set1; + + HashSet allValues; + // Collect all values from both sets + forEachInSet(set1, [&](IRInst* value) { allValues.add(value); }); + forEachInSet(set2, [&](IRInst* value) { allValues.add(value); }); + + IRBuilder builder(module); + return as(builder.getSet( + set1->getOp(), + allValues)); // Create a new set with the union of values + } + + // Find the union of two propagation info insts, and return and + // inst representing the result. + // + IRInst* unionPropagationInfo(IRInst* info1, IRInst* info2) + { + // This is similar to unionSet, but must consider structures that + // can be built out of collections. + // + // We allow some level of nesting of collections into other type instructions, + // to let us propagate information elegantly for pointers, parameters, arrays + // and existential tuples. + // + // A few cases are missing, but could be added in easily in the future: + // - TupleType (will allow us to propagate information for each tuple element) + // - Vector/Matrix types + // - TypePack + + // Basic cases: if either info is null, it is considered "empty" + // if they're equal, union must be the same inst. + // + + if (!info1) + return info2; + + if (!info2) + return info1; + + if (areInfosEqual(info1, info2)) + return info1; + + if (as(info1) && as(info2)) + { + SLANG_ASSERT(info1->getOperand(1) == info2->getOperand(1)); + // If both are array types, union their element types + IRBuilder builder(module); + builder.setInsertInto(module); + return builder.getArrayType( + (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), + as(info1)->getElementCount(), + getArrayStride(as(info1))); // Keep the same size + } + + if (as(info1) && as(info2)) + { + SLANG_ASSERT(info1->getOp() == info2->getOp()); + + // If both are array types, union their element types + IRBuilder builder(module); + builder.setInsertInto(module); + return builder.getPtrTypeWithAddressSpace( + (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), + as(info1)); + } + + // For all other cases which are structured composites of sets, + // we simply take the set union for all the set operands. + // + + if (as(info1) && as(info2)) + { + return makeTaggedUnionType(unionSet( + as(info1)->getWitnessTableSet(), + as(info2)->getWitnessTableSet())); + } + + if (as(info1) && as(info2)) + { + return makeTagType(unionSet( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); + } + + if (as(info1) && as(info2)) + { + return makeElementOfSetType(unionSet( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); + } + + if (as(info1) && as(info2)) + { + return makeUntaggedUnionType(unionSet( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); + } + + SLANG_UNEXPECTED("Unhandled propagation info types in unionPropagationInfo"); + } + + // Centralized method to update propagation info and add + // relevant work items to the work queue if the info changed. + // + void updateInfo( + IRInst* context, + IRInst* inst, + IRInst* newInfo, + bool takeUnion, + WorkQueue& workQueue) + { + if (isConcreteType(inst->getDataType())) + { + // No need to update info for insts with already defined concrete types, + // since these can't be refined any further. + // + return; + } + + auto existingInfo = tryGetInfo(context, inst); + auto unionedInfo = (takeUnion) ? unionPropagationInfo(existingInfo, newInfo) : newInfo; + + // Only proceed if info actually changed + if (areInfosEqual(existingInfo, unionedInfo)) + return; + + // Update the propagation map + propagationMap[InstWithContext(context, inst)] = unionedInfo; + + // Add all users to appropriate work items + addUsersToWorkQueue(context, inst, unionedInfo, workQueue); + } + + // Helper method to add work items for all call sites of a function/generic. + void addContextUsersToWorkQueue(IRInst* context, WorkQueue& workQueue) + { + if (this->funcCallSites.containsKey(context)) + for (auto callSite : this->funcCallSites[context]) + { + workQueue.enqueue(WorkItem( + InterproceduralEdge::Direction::FuncToCall, + callSite.context, + as(callSite.inst), + context)); + } + } + + // Helper method to add new work items to the queue based on the + // flavor of the instruction whose info has been updated. + // + void addUsersToWorkQueue(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) + { + // This method is responsible for ensuring the following property: + // + // If inst's information has changed, then all insts that may potentially depend + // (directly) on it should be added to the work queue. + // + // This includes a few cases: + // + // 1. In the default case, we simply add all users of the insts. + // + // 2. For insts that are used as phi-arguments, we add an intra-procedural + // edge to the target block. + // + // 3. For insts that are used as return values, we add an inter-procedural edge + // to all call sites. We do this by calling `updateFuncReturnInfo`, which takes + // care of modifying the workQueue. + // + + if (auto param = as(inst)) + if (isFuncParam(param)) + addContextUsersToWorkQueue(context, workQueue); + + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + + // If user is in a different block (or the inst is a param), add that block to work + // queue. + // + workQueue.enqueue(WorkItem(context, user)); + + // If user is a terminator, add intra-procedural edges + if (auto terminator = as(user)) + { + auto parentBlock = as(terminator->getParent()); + if (parentBlock) + { + auto successors = parentBlock->getSuccessors(); + for (auto succIter = successors.begin(); succIter != successors.end(); + ++succIter) + { + workQueue.enqueue(WorkItem(context, succIter.getEdge())); + } + } + } + + // If user is a return instruction, handle function return propagation + if (as(user)) + updateFuncReturnInfo(context, info, workQueue); + + // If the user is a top-level inout/out parameter, we need to handle it + // like we would a func-return. + // + if (auto param = as(user)) + if (isFuncParam(param)) + addContextUsersToWorkQueue(context, workQueue); + + // TODO: This is a tiny bit of a hack.. but we don't currently need to + // analyze FuncType insts, but we do want to make sure that any changes + // are propagated through to their users. + // + if (auto funcType = as(user)) + if (as(funcType->getParent())) + addUsersToWorkQueue(context, funcType, none(), workQueue); + } + } + + // Helper method to update function's return info and propagate back to call sites + void updateFuncReturnInfo(IRInst* callable, IRInst* returnInfo, WorkQueue& workQueue) + { + // Don't update info if the callee has a concrete return type. + auto callableFuncType = cast(callable->getDataType()); + if (isConcreteType(callableFuncType->getResultType())) + return; + + auto existingReturnInfo = getFuncReturnInfo(callable); + auto newReturnInfo = unionPropagationInfo(existingReturnInfo, returnInfo); + + if (!areInfosEqual(existingReturnInfo, newReturnInfo)) + { + funcReturnInfo[callable] = newReturnInfo; + + // Add interprocedural edges from the function back to all callsites. + if (funcCallSites.containsKey(callable)) + { + for (auto callSite : funcCallSites[callable]) + { + workQueue.enqueue(WorkItem( + InterproceduralEdge::Direction::FuncToCall, + callSite.context, + as(callSite.inst), + callable)); + } + } + } + } + + void performInformationPropagation() + { + // This method contains the main loop responsible for propagating information across all + // relevant functions, generics in the call graph. + // + // The mechanism is similar to data-flow analysis: + // 1. We start by initializing the propagation info for all instructions in functions + // that may be externally called. + // + // 2. For each instruction that received a non-trivial update, we add their users to the + // queue + // for further propagation. + // + // 3. Continue (2) until no more information has changed. + // + // This process is guaranteed to terminate because our propagation 'states' (i.e. + // sets and their composites) form a lattice. This is an order-theoretic + // structure that implies that + // (i) each update moves us strictly 'upward', and + // (ii) that there are a finite number of possible states. + // + + // Global worklist for interprocedural analysis. + WorkQueue workQueue; + + // Add all global functions to worklist. + // + // This could potentially be narrowed down to just entry points, but for + // now we are being conservative. Missing a potential entry point is worse + // than analyzing something that isn't used. + // + for (auto inst : module->getGlobalInsts()) + if (auto func = as(inst)) + discoverContext(func, workQueue); + + // Process until fixed point. + while (workQueue.hasItems()) + { + auto item = workQueue.dequeue(); + + switch (item.type) + { + case WorkItem::Type::Inst: + processInstForPropagation(item.context, item.inst, workQueue); + break; + case WorkItem::Type::Block: + processBlock(item.context, item.block, workQueue); + break; + case WorkItem::Type::IntraProc: + propagateWithinFuncEdge(item.context, item.intraProcEdge, workQueue); + break; + case WorkItem::Type::InterProc: + propagateInterproceduralEdge(item.interProcEdge, workQueue); + break; + default: + SLANG_UNEXPECTED("Unhandled work item type"); + return; + } + } + } + + IRInst* analyzeByType(IRInst* context, IRInst* inst) + { + if (!inst->getDataType()) + return none(); + + if (auto dataTypeInfo = tryGetInfo(context, inst->getDataType())) + { + if (auto elementOfSetType = as(dataTypeInfo)) + { + return makeUntaggedUnionType(cast(elementOfSetType->getSet())); + } + } + + return none(); + } + + void processInstForPropagation(IRInst* context, IRInst* inst, WorkQueue& workQueue) + { + IRInst* info = nullptr; + + switch (inst->getOp()) + { + case kIROp_CreateExistentialObject: + info = analyzeCreateExistentialObject(context, as(inst)); + break; + case kIROp_MakeExistential: + info = analyzeMakeExistential(context, as(inst)); + break; + case kIROp_LookupWitnessMethod: + info = analyzeLookupWitnessMethod(context, as(inst)); + break; + case kIROp_ExtractExistentialWitnessTable: + info = analyzeExtractExistentialWitnessTable( + context, + as(inst)); + break; + case kIROp_ExtractExistentialType: + info = analyzeExtractExistentialType(context, as(inst)); + break; + case kIROp_ExtractExistentialValue: + info = analyzeExtractExistentialValue(context, as(inst)); + break; + case kIROp_Call: + info = analyzeCall(context, as(inst), workQueue); + break; + case kIROp_Specialize: + info = analyzeSpecialize(context, as(inst)); + break; + case kIROp_Load: + case kIROp_RWStructuredBufferLoad: + case kIROp_StructuredBufferLoad: + info = analyzeLoad(context, inst); + break; + case kIROp_LoadFromUninitializedMemory: + info = analyzeLoadFromUninitializedMemory(context, inst); + break; + case kIROp_MakeStruct: + info = analyzeMakeStruct(context, as(inst), workQueue); + break; + case kIROp_Store: + info = analyzeStore(context, as(inst), workQueue); + break; + case kIROp_GetElementPtr: + info = analyzeGetElementPtr(context, as(inst)); + break; + case kIROp_FieldAddress: + info = analyzeFieldAddress(context, as(inst)); + break; + case kIROp_FieldExtract: + info = analyzeFieldExtract(context, as(inst)); + break; + case kIROp_MakeOptionalNone: + info = analyzeMakeOptionalNone(context, as(inst)); + break; + case kIROp_MakeOptionalValue: + info = analyzeMakeOptionalValue(context, as(inst)); + break; + case kIROp_GetOptionalValue: + info = analyzeGetOptionalValue(context, as(inst)); + break; + } + + // If we didn't get any info from inst-specific analysis, we'll try to get + // info from the data-type's info. + // + if (!info) + info = analyzeByType(context, inst); + + if (info) + updateInfo(context, inst, info, false, workQueue); + } + + void processBlock(IRInst* context, IRBlock* block, WorkQueue& workQueue) + { + for (auto inst : block->getChildren()) + { + // Skip parameters & terminator + if (as(inst) || as(inst)) + continue; + processInstForPropagation(context, inst, workQueue); + } + + if (auto returnInfo = as(block->getTerminator())) + { + auto val = returnInfo->getVal(); + if (!as(val->getDataType())) + updateFuncReturnInfo(context, tryGetArgInfo(context, val), workQueue); + } + }; + + void propagateWithinFuncEdge(IRInst* context, IREdge edge, WorkQueue& workQueue) + { + // Handle intra-procedural edge (original logic) + auto predecessorBlock = edge.getPredecessor(); + auto successorBlock = edge.getSuccessor(); + + // Get the terminator instruction and extract arguments + auto terminator = predecessorBlock->getTerminator(); + if (!terminator) + return; + + // Right now, only unconditional branches can propagate new information + auto unconditionalBranch = as(terminator); + if (!unconditionalBranch) + return; + + // Find which successor this edge leads to (should be the target) + if (unconditionalBranch->getTargetBlock() != successorBlock) + return; + + // Collect propagation info for each argument and update corresponding parameter + UInt paramIndex = 0; + for (auto param : successorBlock->getParams()) + { + if (paramIndex < unconditionalBranch->getArgCount()) + { + auto arg = unconditionalBranch->getArg(paramIndex); + if (auto argInfo = tryGetArgInfo(context, arg)) + { + // Use centralized update method + updateInfo(context, param, argInfo, true, workQueue); + } + } + paramIndex++; + } + } + + void propagateInterproceduralEdge(InterproceduralEdge edge, WorkQueue& workQueue) + { + // Handle interprocedural edge + auto callInst = edge.callInst; + auto targetCallee = edge.targetContext; + + switch (edge.direction) + { + case InterproceduralEdge::Direction::CallToFunc: + { + // Propagate argument info from call site to function parameters + IRBlock* firstBlock = nullptr; + + if (as(targetCallee)) + firstBlock = targetCallee->getFirstBlock(); + else if (auto specInst = as(targetCallee)) + firstBlock = getGenericReturnVal(specInst->getBase())->getFirstBlock(); + + if (!firstBlock) + return; + + UInt argIndex = 1; // Skip callee (operand 0) + for (auto param : firstBlock->getParams()) + { + if (argIndex < callInst->getOperandCount()) + { + auto arg = callInst->getOperand(argIndex); + const auto [paramDirection, paramType] = + splitParameterDirectionAndType(param->getDataType()); + + // Only update if the parameter is not a concrete type. + // + // This is primarily just an optimization. + // Without this, we'd be storing 'singleton' sets for parameters with + // regular concrete types (i.e. 99% of cases), which can clog up + // the propagation dictionary when analyzing large modules. + // This optimization ignores them and re-derives the info + // from the data-type. + // + if (isConcreteType(paramType)) + { + argIndex++; + continue; + } + + IRInst* argInfo = tryGetArgInfo(edge.callerContext, arg); + + switch (paramDirection.kind) + { + case ParameterDirectionInfo::Kind::Out: + case ParameterDirectionInfo::Kind::BorrowInOut: + case ParameterDirectionInfo::Kind::BorrowIn: + { + IRBuilder builder(module); + if (!argInfo) + break; + + auto newInfo = fromDirectionAndType( + &builder, + paramDirection, + as(argInfo)->getValueType()); + updateInfo(edge.targetContext, param, newInfo, true, workQueue); + break; + } + case ParameterDirectionInfo::Kind::In: + { + updateInfo(edge.targetContext, param, argInfo, true, workQueue); + break; + } + default: + SLANG_UNEXPECTED( + "Unhandled parameter direction in interprocedural edge"); + } + } + argIndex++; + } + break; + } + case InterproceduralEdge::Direction::FuncToCall: + { + // If the call inst cannot accept anything dynamic, then + // no need to propagate anything to the result of the call inst. + // + // We'll still need to consider out parameters separately. + // + if (!isConcreteType(callInst->getDataType())) + { + auto returnInfoPtr = funcReturnInfo.tryGetValue(targetCallee); + auto returnInfo = (returnInfoPtr) ? *returnInfoPtr : nullptr; + if (!returnInfo) + { + // If the targetCallee's return type is concrete, but the + // callInst's return type is not, we should still propagate the + // known concrete type. + // + auto concreteReturnType = + cast(targetCallee->getDataType())->getResultType(); + if (isConcreteType(concreteReturnType)) + { + IRBuilder builder(module); + returnInfo = builder.getUntaggedUnionType(cast( + builder.getSingletonSet(kIROp_TypeSet, concreteReturnType))); + } + } + + if (returnInfo) + { + // Use centralized update method + updateInfo(edge.callerContext, callInst, returnInfo, true, workQueue); + } + } + + // Also update infos of any out parameters + auto paramInfos = getParamInfos(edge.targetContext); + auto paramDirections = getParamDirections(edge.targetContext); + UIndex argIndex = 0; + for (auto paramInfo : paramInfos) + { + if (paramInfo) + { + if (paramDirections[argIndex].kind == ParameterDirectionInfo::Kind::Out || + paramDirections[argIndex].kind == + ParameterDirectionInfo::Kind::BorrowInOut) + { + auto arg = callInst->getArg(argIndex); + auto argPtrType = as(arg->getDataType()); + + IRBuilder builder(module); + updateInfo( + edge.callerContext, + arg, + builder.getPtrTypeWithAddressSpace( + (IRType*)as(paramInfo)->getValueType(), + argPtrType), + true, + workQueue); + } + } + argIndex++; + } + + break; + } + default: + SLANG_UNEXPECTED("Unhandled interprocedural edge direction"); + return; + } + } + + IRInst* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) + { + SLANG_UNUSED(context); + + IRBuilder builder(module); + if (auto interfaceType = as(inst->getDataType())) + { + if (isComInterfaceType(interfaceType)) + { + // If this is a COM interface, we ignore it. + return none(); + } + + HashSet& tables = *module->getContainerPool().getHashSet(); + collectExistentialTables(interfaceType, tables); + if (tables.getCount() > 0) + { + auto resultTaggedUnionType = makeTaggedUnionType( + as(builder.getSet(kIROp_WitnessTableSet, tables))); + module->getContainerPool().free(&tables); + return resultTaggedUnionType; + } + else + { + sink->diagnose( + inst, + Diagnostics::noTypeConformancesFoundForInterface, + interfaceType); + module->getContainerPool().free(&tables); + return none(); + } + } + + return none(); + } + + IRInst* analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) + { + IRBuilder builder(module); + + // If we're building an existential for a COM interface, + // we always assume it is unbounded, since we can receive + // types that we know nothing about in the current linkage. + // + if (isComInterfaceType(inst->getDataType())) + { + return none(); + } + + auto witnessTable = inst->getWitnessTable(); + // Concrete case. + if (as(witnessTable)) + return makeTaggedUnionType(as( + builder.getSingletonSet(kIROp_WitnessTableSet, witnessTable))); + + // Get the witness table info + auto witnessTableInfo = tryGetInfo(context, witnessTable); + + if (!witnessTableInfo) + return none(); + + if (auto elementOfSetType = as(witnessTableInfo)) + return makeTaggedUnionType(cast(elementOfSetType->getSet())); + + SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); + } + + IRInst* analyzeMakeStruct(IRInst* context, IRMakeStruct* makeStruct, WorkQueue& workQueue) + { + // We'll process this in the same way as a field-address, but for + // all fields of the struct. + // + auto structType = as(makeStruct->getDataType()); + if (!structType) + return none(); + + UIndex operandIndex = 0; + for (auto field : structType->getFields()) + { + auto operand = makeStruct->getOperand(operandIndex); + if (auto operandInfo = tryGetInfo(context, operand)) + { + IRInst* existingInfo = nullptr; + this->fieldInfo.tryGetValue(field, existingInfo); + auto newInfo = unionPropagationInfo(existingInfo, operandInfo); + if (newInfo && !areInfosEqual(existingInfo, newInfo)) + { + // Update the field info map + this->fieldInfo[field] = newInfo; + + if (this->fieldUseSites.containsKey(field)) + for (auto useSite : this->fieldUseSites[field]) + workQueue.enqueue(WorkItem(useSite.context, useSite.inst)); + } + } + + operandIndex++; + } + + return none(); // the make struct itself doesn't have any info. + } + + IRInst* analyzeLoadFromUninitializedMemory(IRInst* context, IRInst* inst) + { + SLANG_UNUSED(context); + IRBuilder builder(module); + if (as(inst->getDataType()) && !isConcreteType(inst->getDataType())) + { + auto uninitializedSet = builder.getSingletonSet( + kIROp_WitnessTableSet, + builder.getUninitializedWitnessTableElement(inst->getDataType())); + return makeTaggedUnionType(as(uninitializedSet)); + } + + return none(); + } + + IRInst* analyzeLoad(IRInst* context, IRInst* inst) + { + IRBuilder builder(module); + if (auto loadInst = as(inst)) + { + // If we have a simple load, theres one of two cases: + // + // 1. If we're loading from a resource pointer, we need to treat it + // as unspecializable. If it's a COM interface, we consider it truly + // unbounded. Otherwise, we can simply enumerate all tables for the interface + // type. + // + // 2. In the default case, we can look up the registered information + // for the pointer, which should be of the form PtrTypeBase(valueInfo), and + // use the valueInfo + // + + if (isResourcePointer(loadInst->getPtr())) + { + if (auto interfaceType = as(loadInst->getDataType())) + { + if (!isComInterfaceType(interfaceType)) + { + HashSet& tables = *module->getContainerPool().getHashSet(); + collectExistentialTables(interfaceType, tables); + if (tables.getCount() > 0) + { + auto resultTaggedUnionType = makeTaggedUnionType(as( + builder.getSet(kIROp_WitnessTableSet, tables))); + module->getContainerPool().free(&tables); + return resultTaggedUnionType; + } + else + { + sink->diagnose( + loadInst, + Diagnostics::noTypeConformancesFoundForInterface, + interfaceType); + module->getContainerPool().free(&tables); + return none(); + } + } + else + { + return none(); + } + } + else if ( + auto boundInterfaceType = as(loadInst->getDataType())) + { + return makeTaggedUnionType(cast(builder.getSingletonSet( + kIROp_WitnessTableSet, + boundInterfaceType->getWitnessTable()))); + } + else + { + // Loading from a resource pointer that isn't an interface? + // Just return no info. + return none(); + } + } + + // If the load is from a pointer, we can transfer the info directly + auto address = as(loadInst)->getPtr(); + if (auto addrInfo = tryGetInfo(context, address)) + return as(addrInfo)->getValueType(); + else + return none(); // No info for the address + } + else if (as(inst) || as(inst)) + { + // In case of a buffer load, we know we're dealing with a location that cannot + // be specialized, so the logic is similar to case (1) from above. + // + if (auto interfaceType = as(inst->getDataType())) + { + if (!isComInterfaceType(interfaceType)) + { + HashSet& tables = *module->getContainerPool().getHashSet(); + collectExistentialTables(interfaceType, tables); + if (tables.getCount() > 0) + { + auto resultTaggedUnionType = makeTaggedUnionType( + as(builder.getSet(kIROp_WitnessTableSet, tables))); + module->getContainerPool().free(&tables); + return resultTaggedUnionType; + } + else + { + module->getContainerPool().free(&tables); + return none(); + } + } + else + { + return none(); + } + } + else if (auto boundInterfaceType = as(inst->getDataType())) + { + return makeTaggedUnionType(cast(builder.getSingletonSet( + kIROp_WitnessTableSet, + boundInterfaceType->getWitnessTable()))); + } + } + + return none(); // No info for other load types + } + + IRInst* analyzeStore(IRInst* context, IRStore* storeInst, WorkQueue& workQueue) + { + // For a simple store, we will attempt to update the location with + // the information from the stored value. + // + // Since the pointer can be an access chain, we have to recursively transfer + // the information down to the base. This logic is handled by `maybeUpdateInfoForAddress` + // + // If the value has "info", we construct an appropriate PtrType(info) and + // update the ptr with it. + // + + auto address = storeInst->getPtr(); + if (auto valInfo = tryGetInfo(context, storeInst->getVal())) + { + IRBuilder builder(module); + auto ptrInfo = builder.getPtrTypeWithAddressSpace( + (IRType*)valInfo, + as(address->getDataType())); + + // Propagate the information up the access chain to the base location. + maybeUpdateInfoForAddress(context, address, ptrInfo, workQueue); + } + + // The store inst itself doesn't produce anything, so it has no info + return none(); + } + + IRInst* analyzeGetElementPtr(IRInst* context, IRGetElementPtr* getElementPtr) + { + // The base info should be in Ptr> form, so we just need to unpack and + // return Ptr as the result. + // + IRBuilder builder(module); + builder.setInsertAfter(getElementPtr); + auto basePtr = getElementPtr->getBase(); + if (auto ptrType = as(tryGetInfo(context, basePtr))) + { + auto arrayType = as(ptrType->getValueType()); + SLANG_ASSERT(arrayType); + return builder.getPtrTypeWithAddressSpace(arrayType->getElementType(), ptrType); + } + + return none(); // No info for the base pointer => no info for the result. + } + + IRInst* analyzeFieldAddress(IRInst* context, IRFieldAddress* fieldAddress) + { + // In this case, we don't look up the base's info, but rather, we find + // the IRStructField being accessed, and look up the info in the fieldInfos + // map. + // + // This info will be the in the value form, so we need to wrap it in a + // pointer since the result is an address. + // + IRBuilder builder(module); + builder.setInsertAfter(fieldAddress); + auto basePtr = fieldAddress->getBase(); + + if (auto basePtrType = as(basePtr->getDataType())) + { + if (auto structType = as(basePtrType->getValueType())) + { + auto structField = + findStructField(structType, as(fieldAddress->getField())); + + // Register this as a user of the field so updates will invoke this function again. + this->fieldUseSites.addIfNotExists(structField, HashSet()); + this->fieldUseSites[structField].add(InstWithContext(context, fieldAddress)); + + if (this->fieldInfo.containsKey(structField)) + { + return builder.getPtrTypeWithAddressSpace( + (IRType*)this->fieldInfo[structField], + as(fieldAddress->getDataType())); + } + } + } + + return none(); // No info for the field => no info for the result. + } + + IRInst* analyzeFieldExtract(IRInst* context, IRFieldExtract* fieldExtract) + { + // Very similar logic to `analyzeFieldAddress`, but without having to + // wrap the result in a pointer. + // + + IRBuilder builder(module); + + if (auto structType = as(fieldExtract->getBase()->getDataType())) + { + auto structField = + findStructField(structType, as(fieldExtract->getField())); + + // Register this as a user of the field so updates will invoke this function again. + this->fieldUseSites.addIfNotExists(structField, HashSet()); + this->fieldUseSites[structField].add(InstWithContext(context, fieldExtract)); + + if (this->fieldInfo.containsKey(structField)) + { + return this->fieldInfo[structField]; + } + } + return none(); + } + + // Get the witness table inst to be used for the 'none' case of + // an optional witness table. + // + IRInst* getNoneWitness() + { + IRBuilder builder(module); + return builder.getNoneWitnessTableElement(); + } + + IRInst* analyzeMakeOptionalNone(IRInst* context, IRMakeOptionalNone* inst) + { + // If the optional type we're dealing with is an optional concrete type, we won't + // touch this case, since there's nothing dynamic to specialize. + // + // If the type inside the optional is an interface type, then we will treat it slightly + // differently by including 'none' as one of the possible candidates of the existential + // value. + // + // The `MakeOptionalNone` case represents the creating of an existential out of the + // 'none' witness table and a void value, so we'll represent that using the tagged union + // type. + // + SLANG_UNUSED(context); + IRBuilder builder(module); + if (isOptionalExistentialType(inst->getDataType())) + { + auto noneTableSet = + cast(builder.getSet(kIROp_WitnessTableSet, getNoneWitness())); + return makeTaggedUnionType(noneTableSet); + } + + return none(); + } + + IRInst* analyzeMakeOptionalValue(IRInst* context, IRMakeOptionalValue* inst) + { + // If the optional type we're dealing with is an optional concrete type, we won't + // touch this case, since there's nothing dynamic to specialize. + // + // If the type inside the optional is an interface type, then we will treat it slightly + // differently, by conceptually treating it as an interface type that has all the possible + // elements of the interface type plus an additional 'none' element. + // + // The `MakeOptionalValue` case is then very similar to the `MakeExistential` case, only we + // already have an existential as input. + // + // Thus, we simply pass the input existential info as-is. + // + // Note: we don't actually have to add a new 'none' table to the set, since that will + // automatically occur if this value ever merges with a value created using + // `MakeOptionalNone` + // + if (isOptionalExistentialType(inst->getDataType())) + { + if (auto info = tryGetInfo(context, inst->getValue())) + { + SLANG_ASSERT(as(info)); + return info; + } + } + + return none(); + } + + IRInst* analyzeGetOptionalValue(IRInst* context, IRGetOptionalValue* inst) + { + if (isOptionalExistentialType(inst->getOperand(0)->getDataType())) + { + // TODO: Document. + if (auto info = tryGetInfo(context, inst->getOperand(0))) + { + SLANG_ASSERT(as(info)); + + IRBuilder builder(module); + auto taggedUnion = as(info); + return builder.getTaggedUnionType( + cast(filterNoneElements(taggedUnion->getWitnessTableSet())), + cast(filterNoneElements(taggedUnion->getTypeSet()))); + } + } + + return none(); + } + + IRInst* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) + { + // A LookupWitnessMethod is assumed to by dynamic, so we + // (i) construct a set of the results by looking up the given + // key in each of the input witness tables + // (ii) wrap the result in a tag type, since the lookup inst is logically holding + // on to run-time information about which element of the set is active. + // + // Note that the input must be a set of concrete witness tables (or none/unbounded). + // If this is not the case and we see anything abstract, then something has gone + // wrong somewhere when analyzing a previous instruction. + // + + auto key = inst->getRequirementKey(); + + auto witnessTable = inst->getWitnessTable(); + auto witnessTableInfo = tryGetInfo(context, witnessTable); + + if (auto elementOfSetType = as(witnessTableInfo)) + { + IRBuilder builder(module); + HashSet& results = *module->getContainerPool().getHashSet(); + forEachInSet( + cast(elementOfSetType->getSet()), + [&](IRInst* table) + { + if (as(table)) + { + if (inst->getDataType()->getOp() == kIROp_FuncType) + results.add(builder.getUnboundedFuncElement()); + else if (inst->getDataType()->getOp() == kIROp_WitnessTableType) + results.add(builder.getUnboundedWitnessTableElement( + as(inst)->getDataType())); + else if (inst->getDataType()->getOp() == kIROp_TypeKind) + { + SLANG_UNEXPECTED( + "TypeKind result from LookupWitnessMethod not supported"); + } + return; + } + + results.add(findWitnessTableEntry(cast(table), key)); + }); + + auto setOp = getSetOpFromType(inst->getDataType()); + auto resultSetType = makeElementOfSetType(builder.getSet(setOp, results)); + module->getContainerPool().free(&results); + + return resultSetType; + } + + if (!witnessTableInfo) + return none(); + + SLANG_UNEXPECTED("Unexpected witness table info type in analyzeLookupWitnessMethod"); + } + + IRInst* analyzeExtractExistentialWitnessTable( + IRInst* context, + IRExtractExistentialWitnessTable* inst) + { + // An ExtractExistentialWitnessTable inst is assumed to by dynamic, so we + // extract the set of witness tables from the input existential and + // state that the info of the result is a tag-type of that set. + // + // Note that since ExtractExistentialWitnessTable can only be used on + // an existential, the input info must be a TaggedUnionType of + // concrete table and type sets (or none/unbounded) + // + + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(context, operand); + + if (!operandInfo) + return none(); + + if (auto taggedUnion = as(operandInfo)) + { + auto tableSet = taggedUnion->getWitnessTableSet(); + if (auto uninitElement = tableSet->tryGetUninitializedElement()) + { + sink->diagnose( + inst->sourceLoc, + Diagnostics::dynamicDispatchOnPotentiallyUninitializedExistential, + uninitElement->getOperand(0)); + + return none(); // We'll return none so that the analysis doesn't + // crash early, before we can detect the error count + // and exit gracefully. + } + + return makeElementOfSetType(tableSet); + } + + SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialWitnessTable"); + } + + IRInst* analyzeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) + { + // An ExtractExistentialType inst is assumed to be dynamic, so we + // extract the set of witness tables from the input existential and + // state that the info of the result is a tag-type of that set. + // + // Note: Since ExtractExistentialType can only be used on + // an existential, the input info must be a TaggedUnionType of + // concrete table and type sets (or none/unbounded) + // + + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(context, operand); + + if (!operandInfo) + return none(); + + if (auto taggedUnion = as(operandInfo)) + return makeElementOfSetType(taggedUnion->getTypeSet()); + + SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialType"); + } + + IRInst* analyzeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) + { + // Logically, an ExtractExistentialValue inst is carrying a payload + // of a union type. + // + // We represent this by setting its info to be equal to the type-set, + // which will later lower into an any-value-type. + // + // Note that there is no 'tag' here since ExtractExistentialValue is not representing + // tag information about which type in the set is active, but is representing + // a value of the set's union type. + // + + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(context, operand); + + if (!operandInfo) + return none(); + + if (auto taggedUnion = as(operandInfo)) + return makeUntaggedUnionType(taggedUnion->getTypeSet()); + + return none(); + } + + IRInst* analyzeSpecialize(IRInst* context, IRSpecialize* inst) + { + // Analyzing an IRSpecialize inst is an interesting case. + // + // If we hit this case, it means we are encountering this instruction + // inside a block, so the arguments to the specialization likely have some + // dynamic types or witness tables. + // + // We'll first look at the specialization base, which may be a single generic + // or a set of generics. + // + // Then, for each generic, we'll create a specialized version by using the + // set info for each argument in place of the argument. + // e.g. Specialize(G, A0, A1) becomes Specialize(G, info(A1).set, + // info(A2).set) + // (i.e. if the args are tag-types, we only use the set part) + // + // This transformation is important to lift the 'dynamic' specialize instruction into a + // global specialize instruction while still retaining the information about what types and + // tables the resulting generic should support. + // + // Finally, we put all the specialized vesions back into a set and return that info. + // + + auto operand = inst->getBase(); + auto operandInfo = tryGetInfo(context, operand); + + if (as(operandInfo)) + { + SLANG_UNEXPECTED("Unexpected operand for IRSpecialize"); + } + + // Handle the 'many' or 'one' cases. + if (as(operandInfo) || isGlobalInst(operand)) + { + List& specializationArgs = *module->getContainerPool().getList(); + for (UInt i = 0; i < inst->getArgCount(); ++i) + { + // For concrete args, add as-is. + if (isGlobalInst(inst->getArg(i))) + { + specializationArgs.add(inst->getArg(i)); + continue; + } + + // For dynamic args, we need to replace them with + // their sets (if available) + // + auto argInfo = tryGetInfo(context, inst->getArg(i)); + + // If any of the args are 'empty' sets, we can't generate a specialization just yet. + if (!argInfo) + { + module->getContainerPool().free(&specializationArgs); + return none(); + } + + if (as(argInfo)) + { + SLANG_UNEXPECTED("Unexpected Existential operand in specialization argument."); + } + + if (auto elementOfSetType = as(argInfo)) + { + if (elementOfSetType->getSet()->isSingleton()) + specializationArgs.add(elementOfSetType->getSet()->getElement(0)); + else if ( + auto unboundedElement = + elementOfSetType->getSet()->tryGetUnboundedElement()) + { + // Infinite set. + // + // While our sets allow for encoding an unbounded case along with known + // cases, when it comes to specializing a function or placing a call to a + // function, we will default to the single unbounded element case. + // + IRBuilder builder(module); + SLANG_ASSERT(unboundedElement); + auto setOp = getSetOpFromType(inst->getArg(i)->getDataType()); + auto pureUnboundedSet = builder.getSingletonSet(setOp, unboundedElement); + if (auto typeSet = as(pureUnboundedSet)) + specializationArgs.add(makeUntaggedUnionType(typeSet)); + else + specializationArgs.add(pureUnboundedSet); + } + else + { + // Dealing with a non-singleton, but finite set. + if (auto typeSet = as(elementOfSetType->getSet())) + { + specializationArgs.add(makeUntaggedUnionType(typeSet)); + } + else if (as(elementOfSetType->getSet())) + { + specializationArgs.add(elementOfSetType->getSet()); + } + else + { + SLANG_UNEXPECTED("Unexpected set type in specialization argument."); + } + } + } + else + { + SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeSpecialize"); + } + } + + // This part creates a correct type for the specialization, by following the same + // process: replace all operands in the composite type with their propagated set. + // + + IRType* typeOfSpecialization = nullptr; + if (inst->getDataType()->getParent()->getOp() == kIROp_ModuleInst) + typeOfSpecialization = inst->getDataType(); + else if (auto funcType = as(inst->getDataType())) + { + auto substituteSets = [&](IRInst* type) -> IRInst* + { + if (auto info = tryGetInfo(context, type)) + { + if (auto elementOfSetType = as(info)) + { + if (elementOfSetType->getSet()->isSingleton()) + return elementOfSetType->getSet()->getElement(0); + else if ( + auto unboundedElement = + elementOfSetType->getSet()->tryGetUnboundedElement()) + { + IRBuilder builder(module); + return makeUntaggedUnionType(cast( + builder.getSingletonSet(kIROp_TypeSet, unboundedElement))); + } + else + return makeUntaggedUnionType( + cast(elementOfSetType->getSet())); + } + else + return type; + } + else + return type; + }; + + List& newParamTypes = *module->getContainerPool().getList(); + for (auto paramType : funcType->getParamTypes()) + newParamTypes.add((IRType*)substituteSets(paramType)); + IRBuilder builder(module); + builder.setInsertInto(module); + typeOfSpecialization = builder.getFuncType( + newParamTypes.getCount(), + newParamTypes.getBuffer(), + (IRType*)substituteSets(funcType->getResultType())); + module->getContainerPool().free(&newParamTypes); + } + else if (auto typeInfo = tryGetInfo(context, inst->getDataType())) + { + // There's one other case we'd like to handle, where the func-type itself is a + // dynamic IRSpecialize. In this situation, we'd want to use the type inst's info to + // find the set-based specialization and create a func-type from it. + // + if (auto elementOfSetType = as(typeInfo)) + { + SLANG_ASSERT(elementOfSetType->getSet()->isSingleton()); + auto specializeInst = + cast(elementOfSetType->getSet()->getElement(0)); + auto specializedFuncType = cast(specializeGeneric(specializeInst)); + typeOfSpecialization = specializedFuncType; + } + else + { + module->getContainerPool().free(&specializationArgs); + return none(); + } + } + else + { + // We don't have a type we can work with just yet. + module->getContainerPool().free(&specializationArgs); + return none(); // No info for the type + } + + if (!isGlobalInst(typeOfSpecialization)) + { + // Our func-type operand is not yet been lifted. + // For now, we can't say anything. + // + module->getContainerPool().free(&specializationArgs); + return none(); + } + + // Specialize each element in the set + HashSet& specializedSet = *module->getContainerPool().getHashSet(); + + IRSetBase* set = nullptr; + if (auto elementOfSetType = as(operandInfo)) + { + set = elementOfSetType->getSet(); + + forEachInSet( + set, + [&](IRInst* arg) + { + // Create a new specialized instruction for each argument + IRBuilder builder(module); + builder.setInsertInto(module); + + if (as(arg)) + { + // Infinite set. + // + // We currently only support specializing generic functions + // in this way, so we'll assume its an unbounded-func-element + // + specializedSet.add(builder.getUnboundedFuncElement()); + return; + } + + specializedSet.add(builder.emitSpecializeInst( + typeOfSpecialization, + arg, + specializationArgs)); + }); + } + else + { + // Concrete case.. + IRBuilder builder(module); + builder.setInsertInto(module); + specializedSet.add( + builder.emitSpecializeInst(typeOfSpecialization, operand, specializationArgs)); + } + + IRBuilder builder(module); + auto setOp = getSetOpFromType(inst->getDataType()); + auto resultSetType = makeElementOfSetType(builder.getSet(setOp, specializedSet)); + module->getContainerPool().free(&specializedSet); + module->getContainerPool().free(&specializationArgs); + return resultSetType; + } + + if (!operandInfo) + return none(); + + SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); + } + + void discoverContext(IRInst* context, WorkQueue& workQueue) + { + // "Discovering" a context, essentially means we check if this is the first + // time we're trying to propagate information into this context. A context + // is a global-scope IRFunc or IRSpecialize. + // + // If it is the first, we enqueue some work to perform initialization of all + // the insts in the body of the func. + // + // Since discover context is only called 'on-demand' as the type-flow propagation + // happens, we avoid having to deal with functions/generics that are never used, + // and minimize the amount of work being performed. + // + + if (this->availableContexts.add(context)) + { + IRFunc* func = nullptr; + + // Newly discovered context. Initialize it. + switch (context->getOp()) + { + case kIROp_Func: + { + func = cast(context); + + // Initialize the first block parameters + initializeFirstBlockParameters(context, func); + + // Add all blocks to the work queue + for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) + workQueue.enqueue(WorkItem(context, block)); + + break; + } + case kIROp_Specialize: + { + auto specialize = cast(context); + auto generic = cast(specialize->getBase()); + func = cast(findGenericReturnVal(generic)); + + // Transfer information from specialization arguments to the params in the + // first generic block. + // + IRParam* param = generic->getFirstBlock()->getFirstParam(); + for (UInt index = 0; index < specialize->getArgCount() && param; + ++index, param = param->getNextParam()) + { + // Map the specialization argument to the corresponding parameter + auto arg = specialize->getArg(index); + if (as(arg)) + continue; + + if (auto set = as(arg)) + { + updateInfo(context, param, makeElementOfSetType(set), true, workQueue); + } + else if (auto untaggedUnion = as(arg)) + { + IRBuilder builder(module); + updateInfo( + context, + param, + makeElementOfSetType(untaggedUnion->getSet()), + true, + workQueue); + } + else if (as(arg)) + { + IRBuilder builder(module); + updateInfo( + context, + param, + makeElementOfSetType(builder.getSingletonSet(kIROp_TypeSet, arg)), + true, + workQueue); + } + else if (as(arg)) + { + IRBuilder builder(module); + updateInfo( + context, + param, + makeElementOfSetType( + builder.getSingletonSet(kIROp_WitnessTableSet, arg)), + true, + workQueue); + } + else + { + SLANG_UNEXPECTED("Unexpected argument type in specialization"); + } + } + + // Initialize the first block parameters + initializeFirstBlockParameters(context, func); + + // Add all blocks to the work queue for an initial sweep + for (auto block = generic->getFirstBlock(); block; + block = block->getNextBlock()) + workQueue.enqueue(WorkItem(context, block)); + + for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) + workQueue.enqueue(WorkItem(context, block)); + } + } + } + } + + IRInst* analyzeCall(IRInst* context, IRCall* inst, WorkQueue& workQueue) + { + // We don't perform the propagation here, but instead we add inter-procedural + // edges to the work queue. + // The propagation logic is handled in `propagateInterproceduralEdge()` + // + + auto callee = inst->getCallee(); + auto calleeInfo = tryGetInfo(context, callee); + + if (isNoneCallee(callee)) + return none(); + + auto propagateToCallSite = [&](IRInst* callee) + { + if (as(callee)) + { + // An unbounded element represents an unknown function, + // so we can't propagate anything in this case. + // + return; + } + + // Register the call site in the map to allow for the + // return-edge to be created. + // + // We use an explicit map instead of walking the uses of the + // func, since we might have functions that are called indirectly + // through lookups. + // + discoverContext(callee, workQueue); + + this->funcCallSites.addIfNotExists(callee, HashSet()); + if (this->funcCallSites[callee].add(InstWithContext(context, inst))) + { + // If this is a new call site, add a propagation task to the queue (in case there's + // already information about this function) + workQueue.enqueue( + WorkItem(InterproceduralEdge::Direction::FuncToCall, context, inst, callee)); + } + workQueue.enqueue( + WorkItem(InterproceduralEdge::Direction::CallToFunc, context, inst, callee)); + }; + + // If we have a set of functions (with or without a dynamic tag), register + // each one. + // + if (auto elementOfSetType = as(calleeInfo)) + { + forEachInSet( + elementOfSetType->getSet(), + [&](IRInst* func) { propagateToCallSite(func); }); + } + else if (isGlobalInst(callee)) + { + propagateToCallSite(callee); + } + + if (auto callInfo = tryGetInfo(context, inst)) + return callInfo; + else + return none(); + } + + // Updates the information for an address. + void maybeUpdateInfoForAddress( + IRInst* context, + IRInst* inst, + IRInst* info, + WorkQueue& workQueue) + { + // This method recursively walks up the access chain until it hits a location. + // + // Pointers don't have any unique information attached to them because two pointers to the + // same location (directly or indirectly) must mirror the same propagation info. + // + // Thus, our approach will be to update the info for the base value (usually an IRVar or + // IRParam), and then register users for updates to propagate the updated information + // to any access chain instructions. + // + + if (auto getElementPtr = as(inst)) + { + if (auto thisPtrInfo = as(info)) + { + // For get-element-ptr, we propagate information by + // wrapping the result's info into the array type. + // + // a : PtrType(ArrayType(T, count)) = ... + // b : PtrType(T) = GetElementPtr(a) + // + // info(a) = ArrayType(info(b), count) + // + + auto thisValueInfo = thisPtrInfo->getValueType(); + + IRInst* baseValueType = + as(getElementPtr->getBase()->getDataType())->getValueType(); + SLANG_ASSERT(as(baseValueType)); + + // Propagate 'this' information to the base by wrapping it as a pointer to array. + IRBuilder builder(module); + auto baseInfo = builder.getPtrTypeWithAddressSpace( + builder.getArrayType( + (IRType*)thisValueInfo, + as(baseValueType)->getElementCount(), + getArrayStride(as(baseValueType))), + as(getElementPtr->getBase()->getDataType())); + + // Recursively try to update the base pointer. + maybeUpdateInfoForAddress(context, getElementPtr->getBase(), baseInfo, workQueue); + } + } + else if (auto fieldAddress = as(inst)) + { + if (as(info)) + { + // Field address is also treated as a base case for the recursion. + // + // For field-address, we record the information against the field itself + // by using the fieldInfos map (after unwrapping the pointer) + // + // a : PtrType(S) = ... + // b : PtrType(T) = GetFieldAddress(a, fieldKey) + // + // infos[findField(S, fieldKey)] = info(T) + // + + IRBuilder builder(module); + auto baseStructPtrType = as(fieldAddress->getBase()->getDataType()); + auto baseStructType = as(baseStructPtrType->getValueType()); + if (!baseStructType) + return; // Do nothing.. + + if (auto fieldKey = as(fieldAddress->getField())) + { + IRStructField* foundField = findStructField(baseStructType, fieldKey); + IRInst* existingInfo = nullptr; + this->fieldInfo.tryGetValue(foundField, existingInfo); + + if (existingInfo) + existingInfo = builder.getPtrTypeWithAddressSpace( + (IRType*)existingInfo, + as(fieldAddress->getDataType())); + + // Manually update the prop info add work items for all users of this field. + // + // This case is not handled by updateInfo(), though in the future + // it makes sense to include this as a case in updateInfo() + // + if (auto newInfo = unionPropagationInfo(info, existingInfo)) + { + if (newInfo != existingInfo) + { + auto newInfoValType = cast(newInfo)->getValueType(); + + // Update the field info map + this->fieldInfo[foundField] = newInfoValType; + + if (this->fieldUseSites.containsKey(foundField)) + for (auto useSite : this->fieldUseSites[foundField]) + workQueue.enqueue(WorkItem(useSite.context, useSite.inst)); + } + } + } + } + } + else if (auto var = as(inst)) + { + // If we hit a local var, we'll update its info. + // + // This is one of the base cases for the recursion. + // + updateInfo(context, var, info, true, workQueue); + } + else if (auto param = as(inst)) + { + // We'll also update function parameters, + // but first change the info from PtrTypeBase + // to the specific pointer type for the parameter. + // + // (e.g. parameter may use a BorrowInOutType, but the info + // may be some other PtrType) + // + // This is one of the base cases for the recursion. + // + IRBuilder builder(param->getModule()); + auto newInfo = builder.getPtrTypeWithAddressSpace( + (IRType*)as(info)->getValueType(), + as(param->getDataType())); + + updateInfo(context, param, newInfo, true, workQueue); + } + else + { + // If we hit something unsupported, assume there's nothing to update. + return; + } + } + + // Returns the effective parameter types for a given calling context, after + // the type-flow propagation is complete. + // + List getEffectiveParamTypes(IRInst* context) + { + // This proceeds by looking at the propagation info for each parameter, + // then returning the info if one exists. + // + // If one does not exist, it means the parameter has a concrete type + // (not dynamic or generic), and we can just use that for the parameter. + // + + List effectiveTypes; + IRFunc* func = nullptr; + if (as(context)) + { + func = as(context); + } + else if (auto specialize = as(context)) + { + auto generic = specialize->getBase(); + auto innerFunc = getGenericReturnVal(generic); + func = cast(innerFunc); + } + else + { + // If it's not a function or a specialization, we can't get parameter info + SLANG_UNEXPECTED("Unexpected context type for parameter info retrieval"); + } + + for (auto param : func->getParams()) + { + if (auto newInfo = tryGetInfo(context, param)) + if (getLoweredType(newInfo) != nullptr) // Check that info isn't unbounded + { + effectiveTypes.add((IRType*)newInfo); + continue; + } + + // Fallback.. no new info, just use the param type. + effectiveTypes.add(param->getDataType()); + } + + return effectiveTypes; + } + + // Helper to get any recorded propagation info for each parameter of a calling context. + List getParamInfos(IRInst* context) + { + List infos; + if (as(context)) + { + for (auto param : as(context)->getParams()) + infos.add(tryGetArgInfo(context, param)); + } + else if (auto specialize = as(context)) + { + auto generic = specialize->getBase(); + auto innerFunc = getGenericReturnVal(generic); + for (auto param : as(innerFunc)->getParams()) + infos.add(tryGetArgInfo(context, param)); + } + else + { + // If it's not a function or a specialization, we can't get parameter info + SLANG_UNEXPECTED("Unexpected context type for parameter info retrieval"); + } + + return infos; + } + + // Helper to extract the directions of each parameter for a calling context. + List getParamDirections(IRInst* context) + { + // Note that this method does not actually have to retreive any propagation info, + // since the directions/address-spaces of parameters are always concrete. + // + + List directions; + if (as(context)) + { + for (auto param : as(context)->getParams()) + { + const auto [direction, type] = splitParameterDirectionAndType(param->getDataType()); + directions.add(direction); + } + } + else if (auto specialize = as(context)) + { + auto generic = specialize->getBase(); + auto innerFunc = getGenericReturnVal(generic); + for (auto param : as(innerFunc)->getParams()) + { + const auto [direction, type] = splitParameterDirectionAndType(param->getDataType()); + directions.add(direction); + } + } + else + { + // If it's not a function or a specialization, we can't get parameter info + SLANG_UNEXPECTED("Unexpected context type for parameter info retrieval"); + } + + return directions; + } + + // Extract the return value information for a given calling context + IRInst* getFuncReturnInfo(IRInst* context) + { + // We record the information in a separate map, rather than using + // a specific inst. + // + // This is because we need the union of the infos of all Return instructions + // in the function, but there's no physical instruction that represents this. + // (unlike the block control flow case, where phi params exist) + // + // Effectively, this is a 'virtual' inst that represents the union of all + // the return values. + // + funcReturnInfo.addIfNotExists(context, none()); + return funcReturnInfo[context]; + } + + // Set up initial information for parameters based on their types. + void initializeFirstBlockParameters(IRInst* context, IRFunc* func) + { + // This method primarily just initializes known COM & Builtin interface + // types to 'unbounded', to avoid specializing any instructions derived from + // these parameters. + // + + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + return; + + // Initialize parameters with COM/Builtin interface types to 'unbounded' and + // everything else to none. + // + for (auto param : firstBlock->getParams()) + { + auto paramInfo = tryGetInfo(context, param); + if (paramInfo) + continue; // Already has some information + } + } + + // Specialize the fields of a struct type based on the recorded field info (if we + // have a non-trivial specilialization) + // + bool specializeStructType(IRStructType* structType) + { + bool hasChanges = false; + for (auto field : structType->getFields()) + { + IRInst* info = nullptr; + this->fieldInfo.tryGetValue(field, info); + if (!info) + continue; + + auto specializedFieldType = getLoweredType(info); + if (specializedFieldType != field->getFieldType()) + { + hasChanges = true; + field->setFieldType(specializedFieldType); + } + } + + return hasChanges; + } + + bool specializeInstsInBlock(IRInst* context, IRBlock* block) + { + List& instsToLower = *module->getContainerPool().getList(); + + bool hasChanges = false; + for (auto inst : block->getChildren()) + instsToLower.add(inst); + + for (auto inst : instsToLower) + hasChanges |= specializeInst(context, inst); + + module->getContainerPool().free(&instsToLower); + return hasChanges; + } + + bool specializeFunc(IRFunc* func) + { + // When specializing a func, we + // (i) rewrite the types and insts by calling `specializeInstsInBlock` and + // (ii) handle 'merge' points where the sets need to be upcasted. + // + // The merge points are places where a specialized inst might be passed as + // argument to a parameter that has a 'wider' type. + // + // This frequently occurs with phi parameters. + // + // For example: + // A B + // \ / + // C + // + // After specialization, A could pass a value of type TagType(WitnessTableSet{T1, + // T2}) while B passes a value of type TagType(WitnessTableSet{T2, T3}), while the + // phi parameter's type in C has the union type `TagType(WitnessTableSet{T1, T2, + // T3})` + // + // In this case, we use `upcastSet` to insert a cast from + // TagType(WitnessTableSet{T1, T2}) -> TagType(WitnessTableSet{T1, T2, T3}) + // before passing the result as a phi argument. + // + // The same logic applies for the return values. The function's caller expects a union type + // of all possible return statements, so we cast each return inst if there is a mismatch. + // + + // Don't make any changes to non-global or intrinsic functions. + // + // If a function is inside a generic, we wait until the main specialization pass + // turns it into a regular func and the typeflow pass is re-run again. + // This approach is much simpler that trying to incorporate generic parameters into the + // typeflow specialization logic. + // + if (!isGlobalInst(func) || isIntrinsic(func)) + return false; + + bool hasChanges = false; + for (auto block : func->getBlocks()) + hasChanges |= specializeInstsInBlock(func, block); + + for (auto block : func->getBlocks()) + { + UIndex paramIndex = 0; + for (auto param : block->getParams()) + { + auto paramInfo = _tryGetInfo(InstWithContext(func, param)); + if (!paramInfo) + { + paramIndex++; + continue; + } + + // Find all predecessors of this block + for (auto pred : block->getPredecessors()) + { + auto terminator = pred->getTerminator(); + if (auto unconditionalBranch = as(terminator)) + { + auto arg = unconditionalBranch->getArg(paramIndex); + IRBuilder builder(module); + builder.setInsertBefore(unconditionalBranch); + auto newArg = upcastSet(&builder, arg, param->getDataType()); + + if (newArg != arg) + { + hasChanges = true; + + // Replace the argument in the branch instruction with the + // properly casted argument. + // + if (auto loop = as(unconditionalBranch)) + loop->setOperand(3 + paramIndex, newArg); + else + unconditionalBranch->setOperand(1 + paramIndex, newArg); + } + } + } + + paramIndex++; + } + + // If the terminator is a return instruction, perform the same upcasting to + // match the registered return value type for this function. + // + if (auto returnInst = as(block->getTerminator())) + { + if (!as(returnInst->getVal()->getDataType())) + { + if (auto specializedType = getLoweredType(getFuncReturnInfo(func))) + { + IRBuilder builder(module); + builder.setInsertBefore(returnInst); + auto newReturnVal = + upcastSet(&builder, returnInst->getVal(), specializedType); + if (newReturnVal != returnInst->getVal()) + { + // Replace the return value with the reinterpreted value + hasChanges = true; + returnInst->setOperand(0, newReturnVal); + } + } + } + } + } + + // Update the func type for this func accordingly. + auto effectiveFuncType = getEffectiveFuncType(func); + if (effectiveFuncType != func->getFullType()) + { + hasChanges = true; + func->setFullType(effectiveFuncType); + } + + return hasChanges; + } + + // Implements phase 2 of the type-flow specialization pass. + // + // This method is called after information propagation is complete and + // stabilized, and it replaces dynamic insts and types with specialized versions + // based on the collected information. + // + // After this pass is run, there should be no dynamic insts or types remaining, + // _except_ for those that are considered unbounded. + // + // i.e. `ExtractExistentialType`, `ExtractExistentialWitnessTable`, `ExtractExistentialValue`, + // `MakeExistential`, `LookupWitness` (and more) are rewritten to concrete tag translation + // insts (e.g. `GetTagForMappedSet`, `GetTagForSpecializedSet`, etc.) + // + bool performDynamicInstLowering() + { + List funcsToProcess; + List structsToProcess; + + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as(globalInst)) + funcsToProcess.add(func); + else if (auto structType = as(globalInst)) + structsToProcess.add(structType); + } + + bool hasChanges = false; + + // Lower struct types first so that data access can be + // marshalled properly during func specializeing. + // + for (auto structType : structsToProcess) + hasChanges |= specializeStructType(structType); + + for (auto func : funcsToProcess) + hasChanges |= specializeFunc(func); + + return hasChanges; + } + + // Returns an effective type to use for an inst, given its + // info. + // + // This basically recursively walks the info and applies the array/ptr-type + // wrappers, while replacing unbounded set with a nullptr. + // + // If the result of this is null, then the inst should keep its current type. + // + IRType* getLoweredType(IRInst* info) + { + if (!info) + return nullptr; + + if (auto ptrType = as(info)) + { + IRBuilder builder(module); + if (auto specializedValueType = getLoweredType(ptrType->getValueType())) + { + return builder.getPtrTypeWithAddressSpace((IRType*)specializedValueType, ptrType); + } + else + return nullptr; + } + + if (auto arrayType = as(info)) + { + IRBuilder builder(module); + if (auto specializedElementType = getLoweredType(arrayType->getElementType())) + { + return builder.getArrayType( + (IRType*)specializedElementType, + arrayType->getElementCount(), + getArrayStride(arrayType)); + } + else + return nullptr; + } + + if (auto taggedUnion = as(info)) + { + return (IRType*)taggedUnion; + } + + if (auto elementOfSetType = as(info)) + { + // Replace element-of-set types with tag types. + return makeTagType(elementOfSetType->getSet()); + } + + if (auto valOfSetType = as(info)) + { + if (valOfSetType->getSet()->isSingleton()) + { + // If there's only one type in the set, return it directly + return (IRType*)valOfSetType->getSet()->getElement(0); + } + + return valOfSetType; + } + + if (as(info) || as(info)) + { + // Don't specialize these sets.. they should be used through + // tag types, or be processed out during specializeing. + // + return nullptr; + } + + return (IRType*)info; + } + + // Replace an insts type with its effective type as determined by the analysis. + bool replaceType(IRInst* context, IRInst* inst) + { + // If the inst is a global val, we won't modify it. + if (as(inst->getParent())) + { + if (as(inst) || as(inst) || as(inst) || + as(inst)) + { + return false; + } + } + + if (auto info = tryGetInfo(context, inst)) + { + if (auto specializedType = getLoweredType(info)) + { + if (specializedType == inst->getDataType()) + return false; // No change + inst->setFullType(specializedType); + return true; + } + } + return false; + } + + bool specializeInst(IRInst* context, IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_LookupWitnessMethod: + return specializeLookupWitnessMethod(context, as(inst)); + case kIROp_ExtractExistentialWitnessTable: + return specializeExtractExistentialWitnessTable( + context, + as(inst)); + case kIROp_ExtractExistentialType: + return specializeExtractExistentialType(context, as(inst)); + case kIROp_ExtractExistentialValue: + return specializeExtractExistentialValue(context, as(inst)); + case kIROp_Call: + return specializeCall(context, as(inst)); + case kIROp_MakeExistential: + return specializeMakeExistential(context, as(inst)); + case kIROp_WrapExistential: + return specializeWrapExistential(context, as(inst)); + case kIROp_MakeStruct: + return specializeMakeStruct(context, as(inst)); + case kIROp_CreateExistentialObject: + return specializeCreateExistentialObject(context, as(inst)); + case kIROp_RWStructuredBufferLoad: + case kIROp_StructuredBufferLoad: + return specializeStructuredBufferLoad(context, inst); + case kIROp_Specialize: + return specializeSpecialize(context, as(inst)); + case kIROp_GetValueFromBoundInterface: + return specializeGetValueFromBoundInterface( + context, + as(inst)); + case kIROp_GetElementFromTag: + return specializeGetElementFromTag(context, as(inst)); + case kIROp_Load: + return specializeLoad(context, inst); + case kIROp_Store: + return specializeStore(context, as(inst)); + case kIROp_GetSequentialID: + return specializeGetSequentialID(context, as(inst)); + case kIROp_IsType: + return specializeIsType(context, as(inst)); + case kIROp_MakeOptionalNone: + return specializeMakeOptionalNone(context, as(inst)); + case kIROp_MakeOptionalValue: + return specializeMakeOptionalValue(context, as(inst)); + case kIROp_OptionalHasValue: + return specializeOptionalHasValue(context, as(inst)); + case kIROp_GetOptionalValue: + return specializeGetOptionalValue(context, as(inst)); + default: + { + // Default case: replace inst type with specialized type (if available) + if (tryGetInfo(context, inst)) + return replaceType(context, inst); + return false; + } + } + } + + bool specializeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) + { + // Handle trivial case where inst's operand is a concrete table. + if (auto witnessTable = as(inst->getWitnessTable())) + { + inst->replaceUsesWith(findWitnessTableEntry(witnessTable, inst->getRequirementKey())); + inst->removeAndDeallocate(); + return true; + } + + // Otherwise, we go off the info for the inst. + auto info = tryGetInfo(context, inst); + if (!info) + return false; + + // If we didn't resolve anything for this inst, don't modify it. + auto elementOfSetType = as(info); + if (!elementOfSetType) + return false; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // If there's a single element, we can do a simple replacement. + if (elementOfSetType->getSet()->isSingleton()) + { + auto element = elementOfSetType->getSet()->getElement(0); + inst->replaceUsesWith(element); + inst->removeAndDeallocate(); + return true; + } + else if (elementOfSetType->getSet()->isEmpty()) + { + auto poison = builder.emitPoison(inst->getDataType()); + inst->replaceUsesWith(poison); + inst->removeAndDeallocate(); + return true; + } + + // If we reach here, we have a truly dynamic case. Multiple elements. + // We need to emit a run-time inst to keep track of the tag. + // + // We use the GetTagForMappedSet inst to do this, and set its data type to + // the appropriate tag-type. + // + + auto witnessTableInst = inst->getWitnessTable(); + auto witnessTableInfo = witnessTableInst->getDataType(); + + if (as(witnessTableInfo)) + { + auto thisInstInfo = cast(tryGetInfo(context, inst)); + if (thisInstInfo->getSet() != nullptr) + { + IRInst* operands[] = {witnessTableInst, inst->getRequirementKey()}; + + auto newInst = builder.emitIntrinsicInst( + (IRType*)makeTagType(thisInstInfo->getSet()), + kIROp_GetTagForMappedSet, + 2, + operands); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } + + return false; + } + + bool specializeExtractExistentialWitnessTable( + IRInst* context, + IRExtractExistentialWitnessTable* inst) + { + // If we have a non-trivial info registered, it must of + // SetTagType(WitnessTableSet(...)) + // + // Further, the operand must be an existential (TaggedUnionType), which is + // conceptually a pair of TagType(tableSet) and a + // UntaggedUnionType(typeSet) + // + // We will simply extract the first element of this tuple. + // + + auto info = tryGetInfo(context, inst); + if (!info) + return false; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + if (auto elementOfSetType = as(info)) + { + if (elementOfSetType->getSet()->isSingleton()) + { + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(elementOfSetType->getSet()->getElement(0)); + inst->removeAndDeallocate(); + return true; + } + else if (elementOfSetType->getSet()->isEmpty()) + { + inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); + inst->removeAndDeallocate(); + return true; + } + else + { + // Replace with GetElement(specializedInst, 0) -> TagType(tableSet) + // which retreives a 'tag' (i.e. a run-time identifier for one of the elements + // of the set) + // + auto operand = inst->getOperand(0); + auto element = builder.emitGetTagFromTaggedUnion(operand); + inst->replaceUsesWith(element); + inst->removeAndDeallocate(); + return true; + } + } + else + { + SLANG_UNEXPECTED("Unexpected info type for ExtractExistentialWitnessTable"); + } + } + + bool specializeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) + { + SLANG_UNUSED(context); + + auto existential = inst->getOperand(0); + auto existentialInfo = existential->getDataType(); + if (as(existentialInfo)) + { + IRBuilder builder(inst); + builder.setInsertAfter(inst); + + auto val = builder.emitGetValueFromTaggedUnion(existential); + inst->replaceUsesWith(val); + inst->removeAndDeallocate(); + return true; + } + else if (as(existential)) + { + IRBuilder builder(inst); + builder.setInsertAfter(inst); + + inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); + inst->removeAndDeallocate(); + return true; + } + + return false; + } + + bool specializeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) + { + auto info = tryGetInfo(context, inst); + if (auto elementOfSetType = as(info)) + { + if (elementOfSetType->getSet()->isSingleton()) + { + // Found a single possible type. Statically known concrete type. + auto singletonValue = elementOfSetType->getSet()->getElement(0); + inst->replaceUsesWith(singletonValue); + inst->removeAndDeallocate(); + return true; + } + else if (elementOfSetType->getSet()->isEmpty()) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); + inst->removeAndDeallocate(); + return true; + } + else + { + // Multiple elements, emit a tag inst. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newInst = builder.emitGetTypeTagFromTaggedUnion(inst->getOperand(0)); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } + + return false; + } + + bool isTaggedUnionType(IRInst* type) { return as(type) != nullptr; } + + IRType* updateType(IRType* currentType, IRType* newType) + { + if (auto valOfSetType = as(currentType)) + { + HashSet& setElements = *module->getContainerPool().getHashSet(); + forEachInSet( + valOfSetType->getSet(), + [&](IRInst* element) { setElements.add(element); }); + + if (auto newValOfSetType = as(newType)) + { + // If the new type is also a set, merge the two sets + forEachInSet( + newValOfSetType->getSet(), + [&](IRInst* element) { setElements.add(element); }); + } + else + { + // Otherwise, just add the new type to the set + setElements.add(newType); + } + + // If this is a set, we need to create a new set with the new type + IRBuilder builder(module); + auto newSet = builder.getSet(kIROp_TypeSet, setElements); + module->getContainerPool().free(&setElements); + return makeUntaggedUnionType(cast(newSet)); + } + else if (currentType == newType) + { + return currentType; + } + else if (currentType == nullptr) + { + return newType; + } + else if (as(currentType) && as(newType)) + { + // Merge the elements of both tagged unions into a new tuple type + return (IRType*)makeTaggedUnionType((unionSet( + as(currentType)->getWitnessTableSet(), + as(newType)->getWitnessTableSet()))); + } + else // Need to create a new set. + { + HashSet& setElements = *module->getContainerPool().getHashSet(); + + SLANG_ASSERT(!as(currentType) && !as(newType)); + + setElements.add(currentType); + setElements.add(newType); + + // If this is a set, we need to create a new set with the new type + IRBuilder builder(module); + auto newSet = builder.getSet(kIROp_TypeSet, setElements); + module->getContainerPool().free(&setElements); + return makeUntaggedUnionType(cast(newSet)); + } + } + + IRFuncType* getEffectiveFuncTypeForDispatcher( + IRWitnessTableSet* tableSet, + IRStructKey* key, + IRFuncSet* resultFuncSet) + { + SLANG_UNUSED(key); + + List& extraParamTypes = *module->getContainerPool().getList(); + extraParamTypes.add((IRType*)makeTagType(tableSet)); + + auto innerFuncType = getEffectiveFuncTypeForSet(resultFuncSet); + List& allParamTypes = *module->getContainerPool().getList(); + allParamTypes.addRange(extraParamTypes); + for (auto paramType : innerFuncType->getParamTypes()) + allParamTypes.add(paramType); + + IRBuilder builder(module); + auto resultFuncType = builder.getFuncType(allParamTypes, innerFuncType->getResultType()); + + module->getContainerPool().free(&extraParamTypes); + module->getContainerPool().free(&allParamTypes); + + return resultFuncType; + } + + // Get an effective func type to use for the callee. + // The callee may be a set, in which case, this returns a union-ed functype. + // + IRFuncType* getEffectiveFuncTypeForSet(IRFuncSet* calleeSet) + { + // The effective func type for a callee is calculated as follows: + // + // (i) we build up the effective parameter types for the callee + // by taking the union of each parameter type + // for each callee in the set. + // + // (ii) build up the effective result type in a similar manner. + // + // (iii) add extra tag parameters as necessary: + // + // - if we have multiple callees, then a parameter of TagType(callee) is appended + // to the beginning to select the callee. + // + // - if our callee is Specialize inst with set args, then for each + // table-set argument, a tag is required as input. + // + + if (calleeSet->isUnbounded()) + { + IRUnboundedFuncElement* unboundedFuncElement = + cast(calleeSet->tryGetUnboundedElement()); + return cast(unboundedFuncElement->getOperand(0)); + } + + IRBuilder builder(module); + + List& paramTypes = *module->getContainerPool().getList(); + IRType* resultType = nullptr; + + auto updateParamType = [&](Index index, IRType* paramType) -> IRType* + { + if (paramTypes.getCount() <= index) + { + // If this index hasn't been seen yet, expand the buffer and initialize + // the type. + // + paramTypes.growToCount(index + 1); + paramTypes[index] = paramType; + return paramType; + } + else + { + // Otherwise, update the existing type + auto [currentDirection, currentType] = + splitParameterDirectionAndType(paramTypes[index]); + auto [newDirection, newType] = splitParameterDirectionAndType(paramType); + auto updatedType = updateType(currentType, newType); + SLANG_ASSERT(currentDirection == newDirection); + paramTypes[index] = fromDirectionAndType(&builder, currentDirection, updatedType); + return updatedType; + } + }; + + List& calleesToProcess = *module->getContainerPool().getList(); + forEachInSet(calleeSet, [&](IRInst* func) { calleesToProcess.add(func); }); + + for (auto context : calleesToProcess) + { + auto paramEffectiveTypes = getEffectiveParamTypes(context); + auto paramDirections = getParamDirections(context); + + for (Index i = 0; i < paramEffectiveTypes.getCount(); i++) + updateParamType(i, getLoweredType(paramEffectiveTypes[i])); + + auto returnType = getFuncReturnInfo(context); + if (auto newResultType = getLoweredType(returnType)) + { + resultType = updateType(resultType, newResultType); + } + else if (auto funcType = as(context->getDataType())) + { + SLANG_ASSERT(isGlobalInst(funcType->getResultType())); + resultType = updateType(resultType, funcType->getResultType()); + } + else + { + SLANG_UNEXPECTED("Cannot determine result type for context"); + } + } + + module->getContainerPool().free(&calleesToProcess); + + // + // Add in extra parameter types for a call to a dynamic generic callee + // + + List& extraParamTypes = *module->getContainerPool().getList(); + + // If the any of the elements in the callee (or the callee itself in case + // of a singleton) is a dynamic specialization, each non-singleton WitnessTableSet, + // requries a corresponding tag input. + // + if (calleeSet->isSingleton() && isSetSpecializedGeneric(calleeSet->getElement(0))) + { + auto specializeInst = as(calleeSet->getElement(0)); + + // If this is a dynamic generic, we need to add a tag type for each + // WitnessTableSet in the callee. + // + for (UIndex i = 0; i < specializeInst->getArgCount(); i++) + if (auto tableSet = as(specializeInst->getArg(i))) + extraParamTypes.add((IRType*)makeTagType(tableSet)); + } + + List& allParamTypes = *module->getContainerPool().getList(); + allParamTypes.addRange(extraParamTypes); + allParamTypes.addRange(paramTypes); + + auto resultFuncType = builder.getFuncType(allParamTypes, resultType); + + module->getContainerPool().free(¶mTypes); + module->getContainerPool().free(&extraParamTypes); + module->getContainerPool().free(&allParamTypes); + + return resultFuncType; + } + + IRFuncType* getEffectiveFuncType(IRInst* callee) + { + IRBuilder builder(module); + return getEffectiveFuncTypeForSet( + cast(builder.getSingletonSet(kIROp_FuncSet, callee))); + } + + // Helper function for specializing calls. + // + // For a `Specialize` instruction that has dynamic tag arguments, + // extract all the tags and return them as a list. + // + void addArgsForSetSpecializedGeneric( + IRSpecialize* specializedCallee, + List& outCallArgs) + { + for (UInt ii = 0; ii < specializedCallee->getArgCount(); ii++) + { + auto specArg = specializedCallee->getArg(ii); + auto argInfo = specArg->getDataType(); + + // Pull all tag-type arguments from the specialization arguments + // and add them to the call arguments. + // + if (auto tagType = as(argInfo)) + if (as(tagType->getSet())) + outCallArgs.add(specArg); + } + } + + IRInst* maybeSpecializeCalleeType(IRInst* callee) + { + if (auto specializeInst = as(callee->getDataType())) + { + if (isGlobalInst(specializeInst)) + { + // callee->setFullType((IRType*)specializeGeneric(specializeInst)); + IRBuilder builder(module); + return builder.replaceOperand(&callee->typeUse, specializeGeneric(specializeInst)); + } + } + + return callee; + } + + // Filter out `NoneTypeElement` and `NoneWitnessTableElement` from a set (if any exist) + // and construct a new set of the same kind. + // + IRSetBase* filterNoneElements(IRSetBase* set) + { + auto setOp = set->getOp(); + + IRBuilder builder(module); + HashSet& filteredElements = *module->getContainerPool().getHashSet(); + bool containsNone = false; + forEachInSet( + set, + [&](IRInst* element) + { + switch (element->getOp()) + { + case kIROp_NoneTypeElement: + case kIROp_NoneWitnessTableElement: + containsNone = true; + break; + default: + filteredElements.add(element); + break; + } + }); + + if (!containsNone) + { + // Return the same set if there were no None elements + module->getContainerPool().free(&filteredElements); + return set; + } + + // Create a new set without the filtered elements + auto newFuncSet = cast(builder.getSet(setOp, filteredElements)); + module->getContainerPool().free(&filteredElements); + return newFuncSet; + } + + bool specializeCall(IRInst* context, IRCall* inst) + { + // The overall goal is to remove any dynamic-ness in the call inst + // (i.e. the callee as well as types of arguments should be global + // insts) + // + // There are a few cases we need to handle when specializing a call + // inst. + // + // First, we handle the callee: + // + // The callee can only be in a few specific patterns: + // + // 1. If the callee is already a concrete function, there's nothing to do + // + // 2. If the callee is a dynamic inst of tag type, then we need to look at the + // tag inst's structure: + // + // i. If inst is a GetTagForMappedSet (resulting from a lookup), + // + // let tableTag : TagType(witnessTableSet) = /* ... */; + // let tag : TagType(funcSet) = GetTagForMappedSet(tableTag, key); + // let val = Call(tag, arg1, arg2, ...); + // becomes + // let tableTag : TagType(witnessTableSet) = /* ... */; + // let dispatcher : FuncType(...) = GetDispatcher(witnessTableSet, key); + // let val = Call(dispatcher, tableTag, arg1, arg2, ...); + // where the dispatcher represents a dispatch function that selects the function + // based on the witness table tag. + // + // ii. If the inst is a GetTagForSpecializedCollection (resulting from a specialization + // resulting from a lookup), + // + // let tableTag : TagType(witnessTableSet) = /* ... */; + // let tag : TagType(genericSet) = GetTagForMappedSet(tableTag, key); + // let specializedTag : + // TagType(funcSet) = GetTagForSpecializedCollection(tag, specArgs...); + // let val = Call(specializedTag, arg1, arg2, ...); + // becomes + // let tableTag : TagType(witnessTableSet) = /* ... */; + // let dispatcher : FuncType(...) = + // GetSpecializedDispatcher(witnessTableSet, key, specArgs...); + // let val = Call(dispatcher, tableTag, arg1, arg2, ...); + // + // iii. If the inst is a Specialize of a concrete generic, then + // it means that one or more specialization arguments are dynamic. + // + // let specCallee = Specialize(generic, specArgs...); + // let val = Call(specCallee, callArgs...); + // becomes + // let specCallee = Specialize(generic, staticFormOfSpecArgs...); + // let val = Call(specCallee, dynamicSpecArgs..., callArgs...); + // where the new dynamicSpecArgs are the tag insts of WitnessTableSets + // and the static form is the corresponding WitnessTableSet itself. + // + // This creates a specialization that includes set arguments (and is handled + // by `specializeGenericWithSetArgs`) + // + // More concrete example for (iii): + // + // // --- before specialization --- + // let s1 : TagType(WitnessTableSet(tA, tB, tC)) = /* ... */; + // let s2 : TagType(TypeSet(A, B, C)) = /* ... */; + // let specCallee = Specialize(generic, s1, s2); + // let val = Call(specCallee, /* call args */); + // + // // --- after specialization --- + // let s1 : TagType(WitnessTableSet(tA, tB, tC)) = /* ... */; + // let s2 : TagType(TypeSet(A, B, C)) = /* ... */; + // let newSpecCallee = Specialize(generic, + // WitnessTableSet(tA, tB, tC), TypeSet(A, B, C)); + // let newVal = Call(newSpecCallee, s1, /* call args */); + // + // + // After the callee has been selected, we handle the argument types: + // It is possible that the parameters in the callee have been specialized + // to accept a super-set of types compared to the arguments from this call site + // + // In this case, we just upcast them using `upcastSet` before + // creating a new call inst + // + + auto callee = inst->getCallee(); + + if (isNoneCallee(callee)) + return false; + + List& callArgs = *module->getContainerPool().getList(); + + // This is a bit of a workaround for specialized callee's + // whose function types haven't been specialized yet (can + // occur for concrete IRSpecialize insts that are created + // during the specializeing process). + // + callee = maybeSpecializeCalleeType(callee); + + // If we're calling using a tag, place a call to the set, + // with the tag as the first argument. So the callee is + // the set itself. + // + if (auto setTag = as(callee->getDataType())) + { + if (!setTag->isSingleton() && !setTag->getSet()->isEmpty()) + { + // Multiple callees case: + // + // If we need to use a tag, we'll do a bit of an optimization here.. + // + // Instead of building a dispatcher on then func-set, we'll + // build it on the table set that it is looked up from. This + // avoids the extra map. + // + // This works primarily because this is the only way to call a dynamic + // function. If we ever have the ability to pass functions around more + // flexibly, then this should just become a specific case. + + if (auto tagMapOperand = as(callee)) + { + auto tableTag = tagMapOperand->getOperand(0); + auto lookupKey = cast(tagMapOperand->getOperand(1)); + + auto tableSet = cast( + cast(tableTag->getDataType())->getSet()); + IRBuilder builder(module); + + callee = builder.emitGetDispatcher( + getEffectiveFuncTypeForDispatcher( + tableSet, + lookupKey, + cast(setTag->getSet())), + tableSet, + lookupKey); + + callArgs.add(tableTag); + } + else if (auto specializedTagMapOperand = as(callee)) + { + auto innerTagMapOperand = + cast(specializedTagMapOperand->getOperand(0)); + auto tableTag = innerTagMapOperand->getOperand(0); + auto tableSet = cast( + cast(tableTag->getDataType())->getSet()); + auto lookupKey = cast(innerTagMapOperand->getOperand(1)); + + List specArgs; + for (UInt argIdx = 1; argIdx < specializedTagMapOperand->getOperandCount(); + ++argIdx) + { + auto arg = specializedTagMapOperand->getOperand(argIdx); + if (auto tagType = as(arg->getDataType())) + { + SLANG_ASSERT(!tagType->getSet()->isSingleton()); + if (as(tagType->getSet())) + { + callArgs.add(arg); + specArgs.add(tagType->getSet()); + } + else + { + specArgs.add(tagType->getSet()); + } + } + else + { + SLANG_ASSERT(isGlobalInst(arg)); + specArgs.add(arg); + } + } + + IRBuilder builder(module); + builder.setInsertBefore(callee); + callee = builder.emitGetSpecializedDispatcher( + getEffectiveFuncTypeForDispatcher( + tableSet, + lookupKey, + cast(setTag->getSet())), + tableSet, + lookupKey, + specArgs); + + callArgs.add(tableTag); + } + else + { + SLANG_UNEXPECTED("Cannot specialize call with non-singleton set tag callee"); + } + } + else if (isSetSpecializedGeneric(setTag->getSet()->getElement(0))) + { + // Single element which is a set specialized generic. + addArgsForSetSpecializedGeneric(cast(callee), callArgs); + callee = setTag->getSet()->getElement(0); + + auto funcType = getEffectiveFuncType(callee); + IRBuilder builder(module); + builder.setInsertInto(module); + callee = builder.replaceOperand(&callee->typeUse, funcType); + } + else + { + // If we reach here, then something is wrong. If our callee is an inst of tag-type, + // we expect it to either be a `GetTagForMappedSet`, `Specialize` or + // `GetTagForSpecializedSet`. + // Any other case should never occur (in the current design of the compiler) + // + SLANG_UNEXPECTED( + "Unexpected operand type for type-flow specialization of Call inst"); + } + } + else if (as(callee)) + { + // Occasionally, we will determine that there are absolutely no possible callees + // for a call site. This typically happens to impossible branches. + // + // If this happens, the inst representing the callee would have been replaced + // with a poison value. In this case, we're simply going to replace the entire call + // with a default-constructed value of the appropriate type. + // + // Note that it doesn't matter what we replace it with since this code should be + // effectively unreachable. + // + IRBuilder builder(context); + builder.setInsertBefore(inst); + auto defaultVal = builder.emitDefaultConstruct(inst->getDataType()); + inst->replaceUsesWith(defaultVal); + inst->removeAndDeallocate(); + return true; + } + else if (isGlobalInst(callee) && !isIntrinsic(callee)) + { + // If our callee is not a tag-type, then it is necessarily a simple concrete function. + // We will fix-up the function type so that it has the effective types as determined + // by the analysis. + // + auto funcType = getEffectiveFuncType(callee); + IRBuilder builder(module); + builder.setInsertInto(module); + callee = builder.replaceOperand(&callee->typeUse, funcType); + } + else if (isGlobalInst(callee)) + { + auto resultType = getLoweredType(getFuncReturnInfo(callee)); + if (resultType && resultType != inst->getFullType()) + { + auto oldFuncType = as(callee->getDataType()); + IRBuilder builder(module); + + List paramTypes; + for (auto paramType : oldFuncType->getParamTypes()) + paramTypes.add(paramType); + + auto newFuncType = builder.getFuncType(paramTypes, resultType); + builder.setInsertInto(module); + callee = builder.replaceOperand(&callee->typeUse, newFuncType); + } + } + + // If by this point, we haven't resolved our callee into a global inst ( + // either a set or a single function), then we can't specialize it (likely unbounded) + // + if (!isGlobalInst(callee)) + return false; + + // First, we'll legalize all operands by upcasting if necessary. + // This needs to be done even if the callee is not a set. + // + UCount extraArgCount = callArgs.getCount(); + for (UInt i = 0; i < inst->getArgCount(); i++) + { + auto arg = inst->getArg(i); + const auto [paramDirection, paramType] = splitParameterDirectionAndType( + cast(callee->getFullType())->getParamType(i + extraArgCount)); + + switch (paramDirection.kind) + { + // We'll upcast any in-parameters. + case ParameterDirectionInfo::Kind::In: + { + IRBuilder builder(context); + builder.setInsertBefore(inst); + callArgs.add(upcastSet(&builder, arg, paramType)); + break; + } + + // Upcasting of out-parameters is the responsibility of the callee. + case ParameterDirectionInfo::Kind::Out: + + // For all other modes, sets must match ('subtyping' is not allowed) + case ParameterDirectionInfo::Kind::BorrowInOut: + case ParameterDirectionInfo::Kind::BorrowIn: + case ParameterDirectionInfo::Kind::Ref: + { + callArgs.add(arg); + break; + } + default: + SLANG_UNEXPECTED("Unhandled parameter direction in specializeCall"); + } + } + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + bool changed = false; + if (((UInt)callArgs.getCount()) != inst->getArgCount()) + changed = true; + else + { + for (Index i = 0; i < callArgs.getCount(); i++) + { + if (callArgs[i] != inst->getArg((UInt)i)) + { + changed = true; + break; + } + } + } + + if (callee != inst->getCallee()) + { + changed = true; + } + + auto calleeFuncType = cast(callee->getFullType()); + + if (changed) + { + auto newCall = builder.emitCallInst(calleeFuncType->getResultType(), callee, callArgs); + inst->replaceUsesWith(newCall); + inst->removeAndDeallocate(); + } + else if (calleeFuncType->getResultType() != inst->getFullType()) + { + // If we didn't change the callee or the arguments, we still might + // need to update the result type. + // + inst->setFullType(calleeFuncType->getResultType()); + changed = true; + } + + module->getContainerPool().free(&callArgs); + + return changed; + } + + bool specializeMakeStruct(IRInst* context, IRMakeStruct* inst) + { + // The main thing to handle here is that we might have specialized + // the fields of the struct, so we need to upcast the arguments + // if necessary. + // + + auto structType = as(inst->getDataType()); + if (!structType) + return false; + + // Reinterpret any of the arguments as necessary. + bool changed = false; + UIndex operandIndex = 0; + for (auto field : structType->getFields()) + { + auto arg = inst->getOperand(operandIndex); + IRBuilder builder(context); + builder.setInsertBefore(inst); + auto newArg = upcastSet(&builder, arg, field->getFieldType()); + + if (arg != newArg) + { + changed = true; + inst->setOperand(operandIndex, newArg); + } + + operandIndex++; + } + + return changed; + } + + bool specializeMakeExistential(IRInst* context, IRMakeExistential* inst) + { + // After specialization, existentials (that are not unbounded) are treated as tuples + // of a WitnessTableSet tag and a value of type TypeSet. + // + // A MakeExistential is just converted into a MakeTaggedUnion, with any necessary + // upcasts. + // + + auto info = tryGetInfo(context, inst); + auto taggedUnion = as(info); + if (!taggedUnion) + return false; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Collect types from the witness tables to determine the any-value type + auto tableSet = taggedUnion->getWitnessTableSet(); + auto typeSet = taggedUnion->getTypeSet(); + + IRInst* witnessTableTag = nullptr; + if (auto witnessTable = as(inst->getWitnessTable())) + { + witnessTableTag = builder.emitGetTagOfElementInSet( + (IRType*)makeTagType(tableSet), + witnessTable, + tableSet); + } + else if (as(inst->getWitnessTable()->getDataType())) + { + witnessTableTag = upcastSet(&builder, inst->getWitnessTable(), makeTagType(tableSet)); + } + + // Create the appropriate any-value type + auto effectiveType = typeSet->isSingleton() + ? (IRType*)typeSet->getElement(0) + : builder.getUntaggedUnionType((IRType*)typeSet); + + // Pack the value + auto packedValue = as(effectiveType) + ? builder.emitPackAnyValue(effectiveType, inst->getWrappedValue()) + : inst->getWrappedValue(); + + auto taggedUnionType = getLoweredType(taggedUnion); + + inst->replaceUsesWith(builder.emitMakeTaggedUnion( + taggedUnionType, + builder.emitPoison(makeTagType(typeSet)), + witnessTableTag, + packedValue)); + inst->removeAndDeallocate(); + return true; + } + + bool specializeWrapExistential(IRInst* context, IRWrapExistential* inst) + { + SLANG_UNUSED(context); + inst->replaceUsesWith(inst->getWrappedValue()); + inst->removeAndDeallocate(); + return true; + } + + bool specializeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) + { + // A CreateExistentialObject uses an user-provided ID to create an object. + // Note that this ID is not the same as the tags we use. The user-provided ID must be + // compared against the SequentialID, which is a globally consistent & public ID present + // on the witness tables. + // + // The tags are a locally consistent ID whose semantics are only meaningful within the + // function. We use a special op `GetTagFromSequentialID` to convert from the user-provided + // global ID to a local tag ID. + // + + auto info = tryGetInfo(context, inst); + auto taggedUnion = as(info); + if (!taggedUnion) + return false; + + auto taggedUnionType = as(getLoweredType(taggedUnion)); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + IRInst* args[] = {inst->getDataType(), inst->getTypeID()}; + auto translatedTag = builder.emitIntrinsicInst( + (IRType*)makeTagType(taggedUnionType->getWitnessTableSet()), + kIROp_GetTagFromSequentialID, + 2, + args); + + IRInst* packedValue = nullptr; + auto set = taggedUnionType->getTypeSet(); + if (!set->isSingleton()) + { + packedValue = builder.emitPackAnyValue( + (IRType*)builder.getUntaggedUnionType(set), + inst->getValue()); + } + else + { + packedValue = builder.emitReinterpret( + (IRType*)builder.getUntaggedUnionType(set), + inst->getValue()); + } + + auto newInst = builder.emitMakeTaggedUnion( + (IRType*)taggedUnionType, + builder.emitPoison(makeTagType(taggedUnionType->getTypeSet())), + translatedTag, + packedValue); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + + bool specializeStructuredBufferLoad(IRInst* context, IRInst* inst) + { + // The key thing to take care of here is a load from an + // interface-typed pointer. + // + // Our type-flow analysis will convert the + // result into a set of all available implementations of this + // interface, so we need to cast the result. + // + + auto valInfo = tryGetInfo(context, inst); + + if (!valInfo) + return false; + + auto bufferType = (IRType*)inst->getOperand(0)->getDataType(); + auto bufferBaseType = (IRType*)bufferType->getOperand(0); + + auto specializedValType = (IRType*)getLoweredType(valInfo); + if (bufferBaseType != specializedValType) + { + if (as(bufferBaseType) && !isComInterfaceType(bufferBaseType) && + !isBuiltin(bufferBaseType)) + { + // If we're dealing with a loading a known tagged union value from + // an interface-typed pointer, we'll cast the pointer itself and + // defer the specializing of the load until a later lowering + // pass. + // + // This avoids having to change the source pointer type + // and confusing any future runs of the type flow + // analysis pass. + // + // This is slightly different from how a local 'load' is handled, + // because we don't want to modify the pointer (and consequently the global + // buffer) type, since it is a publicly visible type that is laid out + // in a certain way. + // + IRBuilder builder(inst); + builder.setInsertAfter(inst); + auto bufferHandle = inst->getOperand(0); + auto newHandle = builder.emitIntrinsicInst( + builder.getPtrType(specializedValType), + kIROp_CastInterfaceToTaggedUnionPtr, + 1, + &bufferHandle); + IRInst* newLoadOperands[] = {newHandle, inst->getOperand(1)}; + auto newLoad = builder.emitIntrinsicInst( + specializedValType, + inst->getOp(), + 2, + newLoadOperands); + + inst->replaceUsesWith(newLoad); + inst->removeAndDeallocate(); + + return true; + } + } + else if (inst->getDataType() != bufferBaseType) + { + // If the data type is not the same, we need to update it. + inst->setFullType((IRType*)getLoweredType(valInfo)); + return true; + } + + return false; + } + + bool specializeSpecialize(IRInst* context, IRSpecialize* inst) + { + // When specializing a `Specialize` instruction, we have a few nuances. + // + // If we're dealing with specializing a type, witness table, or any other + // generic, we simply drop all dynamic tag information, and replace all + // operands with their set variants. + // + // If we're dealing with a function, there are two cases: + // - A single function when dynamic specialization arguments. + // Removing the dynamic tag information will result in the eventual 'call' + // inst not have access to these insts. + // + // Instead, we'll just replace the type, and retain the `Specialize` inst with + // the dynamic args. It will be specialized out in `specializeCall` instead. + // + // - A set of functions with concrete specialization arguments. + // In this case, we will emit an instruction to map from the input generic set + // to the output specialized set via `GetTagForSpecializedSet`. + // This inst encodes the key-value mapping in its operands: + // e.g.(input_tag, key0, value0, key1, value1, ...) + // + // - The case where there is a set of functions with dynamic specialization arguments + // is not currently properly handled. This case should not arise naturally since we + // don't advertise support for it. + // + + bool isFuncReturn = false; + + if (auto concreteGeneric = as(inst->getBase())) + isFuncReturn = as(getGenericReturnVal(concreteGeneric)) != nullptr; + else if (auto tagType = as(inst->getBase()->getDataType())) + { + auto firstConcreteGeneric = as(tagType->getSet()->getElement(0)); + isFuncReturn = as(getGenericReturnVal(firstConcreteGeneric)) != nullptr; + } + + // We'll emit a dynamic tag inst if the result is a func set with multiple elements + if (isFuncReturn) + { + if (auto info = tryGetInfo(context, inst)) + { + if (auto elementOfSetType = as(info)) + { + // Note for future reworks: + // Should we make it such that the `GetTagForSpecializedSet` + // is emitted in the single func case too? + // + // Basically, as long as any of the specialization operands are dynamic, + // we should probably emit a tag. + // + // Currently, if the func is a singleton, we leave it as a Specialize inst + // with dynamic args to be handled in specializeCall. + // + + if (elementOfSetType->getSet()->isSingleton()) + { + // If the result is a singleton set, we can just + // replace the type (if necessary) and be done with it. + return replaceType(context, inst); + } + else + { + // Otherwise, we'll emit a tag mapping instruction. + IRBuilder builder(inst); + setInsertBeforeOrdinaryInst(&builder, inst); + + List& specOperands = *module->getContainerPool().getList(); + specOperands.add(inst->getBase()); + + for (UInt ii = 0; ii < inst->getArgCount(); ii++) + specOperands.add(inst->getArg(ii)); + + auto newInst = builder.emitIntrinsicInst( + (IRType*)makeTagType(elementOfSetType->getSet()), + kIROp_GetTagForSpecializedSet, + specOperands.getCount(), + specOperands.getBuffer()); + module->getContainerPool().free(&specOperands); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } + else + { + SLANG_UNEXPECTED("Expected element-of-set type for function specialization"); + } + } + } + + // For all other specializations, we'll 'drop' the dynamic tag information. + bool changed = false; + List& args = *module->getContainerPool().getList(); + for (UIndex i = 0; i < inst->getArgCount(); i++) + { + auto arg = inst->getArg(i); + auto argDataType = arg->getDataType(); + if (auto setTagType = as(argDataType)) + { + // If this is a tag type, replace with set. + changed = true; + if (as(setTagType->getSet())) + { + args.add(setTagType->getSet()); + } + else if (auto typeSet = as(setTagType->getSet())) + { + IRBuilder builder(inst); + args.add(builder.getUntaggedUnionType(typeSet)); + } + } + else + { + args.add(arg); + } + } + + IRBuilder builder(inst); + IRType* typeForSpecialization = builder.getTypeKind(); + + if (changed) + { + auto newInst = builder.emitSpecializeInst( + typeForSpecialization, + inst->getBase(), + args.getCount(), + args.getBuffer()); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + module->getContainerPool().free(&args); + return changed; + } + + bool specializeGetValueFromBoundInterface(IRInst* context, IRGetValueFromBoundInterface* inst) + { + // `GetValueFromBoundInterface` is essentially accessing the value component of + // an existential. If the operand has been specialized into a tagged-union, then we can + // turn it into a `GetValueFromTaggedUnion`. + // + + SLANG_UNUSED(context); + + auto operandInfo = inst->getOperand(0)->getDataType(); + if (as(operandInfo)) + { + IRBuilder builder(inst); + setInsertAfterOrdinaryInst(&builder, inst); + auto newInst = builder.emitGetValueFromTaggedUnion(inst->getOperand(0)); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + return false; + } + + bool specializeGetElementFromTag(IRInst* context, IRGetElementFromTag* inst) + { + // During specialization, we convert all "element-of" types into + // run-time tag types so that we can obtain and use this information. + // + // Thus, `GetElementFromTag` simply becomes a no-op. Any instructions using + // the result of `GetElementFromTag` should be specialized accordingly to + // use the tag operand. + // + SLANG_UNUSED(context); + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + return true; + } + + bool specializeLoad(IRInst* context, IRInst* inst) + { + // There's two cases to handle.. + // + // (i) For a simple load, the pointer itself is already specialized. + // so we just need to replace the type of the load with specialized type. + // + // (ii) if there is a mismatch between the two types, the most likely + // case is that we're trying to load from an interface typed location + // (whose type we cannot modify), and cast it into a tagged union tuple. + // + // This case is similar to `specializeStructuredBufferLoad`, where we + // cast the _pointer_ to convert its type, and defer the legalization of + // the load to a later lowering pass. + // + + auto valInfo = tryGetInfo(context, inst); + + if (!valInfo) + return false; + + auto loadPtr = as(inst)->getPtr(); + auto loadPtrType = as(loadPtr->getDataType()); + auto ptrValType = loadPtrType->getValueType(); + + IRType* specializedType = (IRType*)getLoweredType(valInfo); + if (ptrValType != specializedType) + { + SLANG_ASSERT(!as(inst)); + + if (as(ptrValType) && !isComInterfaceType(ptrValType) && + !isBuiltin(ptrValType)) + { + // If we're dealing with a loading a known tagged union value from + // an interface-typed pointer, we'll cast the pointer itself and + // defer the specializeing of the load until later. + // + // This avoids having to change the source pointer type + // and confusing any future runs of the type flow + // analysis pass. + // + IRBuilder builder(inst); + builder.setInsertAfter(inst); + auto newLoadPtr = builder.emitIntrinsicInst( + builder.getPtrTypeWithAddressSpace(specializedType, loadPtrType), + kIROp_CastInterfaceToTaggedUnionPtr, + 1, + &loadPtr); + auto newLoad = builder.emitLoad(specializedType, newLoadPtr); + + inst->replaceUsesWith(newLoad); + inst->removeAndDeallocate(); + + return true; + } + } + else if (inst->getDataType() != ptrValType) + { + inst->setFullType((IRType*)getLoweredType(valInfo)); + return true; + } + + return false; + } + + bool handleDefaultStore(IRInst* context, IRStore* inst) + { + // This handles a rare case in the compiler, where we + // try to use default-construct to initialize a field. + // + // This is not technically supported, but it can occur + // during some corner cases during higher-order auto-diff. + // + // In case we've specialized the field, we just need to + // modify the default-construct operand's type to + // match the field. + // + SLANG_UNUSED(context); + SLANG_ASSERT(inst->getVal()->getOp() == kIROp_DefaultConstruct); + auto ptr = inst->getPtr(); + auto destInfo = as(ptr->getDataType())->getValueType(); + auto valInfo = inst->getVal()->getDataType(); + + // "Legalize" the store type. + if (destInfo != valInfo) + { + inst->getVal()->setFullType(destInfo); + return true; + } + else + return false; + } + + bool specializeStore(IRInst* context, IRStore* inst) + { + // Similar to `specializeLoad`, we handle cases where + // the pointer has been specialized, so that we upcast + // our value to match the type before writing to the location. + // + + auto ptr = inst->getPtr(); + auto ptrInfo = as(ptr->getDataType())->getValueType(); + + // Special case for default initialization: + // + // Raw default initialization has been almost entirely + // removed from Slang, but the auto-diff process can sometimes + // produce a store of default-constructed value. + // + if (as(inst->getVal())) + return handleDefaultStore(context, inst); + + IRBuilder builder(context); + builder.setInsertBefore(inst); + auto specializedVal = upcastSet(&builder, inst->getVal(), ptrInfo); + + if (specializedVal != inst->getVal()) + { + // If the value was changed, we need to update the store instruction. + builder.replaceOperand(inst->getValUse(), specializedVal); + return true; + } + + return false; + } + + bool specializeGetSequentialID(IRInst* context, IRGetSequentialID* inst) + { + // A sequential ID is a globally unique ID for a witness table, while the + // the tags we use in the specialization are only locally consistent. + // + // To extract the global ID, we'll use a separate op code `GetSequentialIDFromTag` + // for now and lower it later once all the global sequential IDs have been assigned. + // + SLANG_UNUSED(context); + auto arg = inst->getOperand(0); + if (auto tagType = as(arg->getDataType())) + { + IRBuilder builder(inst); + setInsertAfterOrdinaryInst(&builder, inst); + auto firstElement = tagType->getSet()->getElement(0); + auto interfaceType = + as(as(firstElement)->getConformanceType()); + IRInst* args[] = {interfaceType, arg}; + auto newInst = builder.emitIntrinsicInst( + (IRType*)builder.getUIntType(), + kIROp_GetSequentialIDFromTag, + 2, + args); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + + return false; + } + + bool specializeIsType(IRInst* context, IRIsType* inst) + { + // The is-type checks equality between two witness tables + // + // We'll turn this into a tag comparison by extracting the tag + // for a specific element in the set, and comparing that to the + // dynamic witness table tag. + // + SLANG_UNUSED(context); + auto witnessTableArg = inst->getValueWitness(); + if (auto tagType = as(witnessTableArg->getDataType())) + { + IRBuilder builder(inst); + setInsertAfterOrdinaryInst(&builder, inst); + + auto targetTag = builder.emitGetTagOfElementInSet( + (IRType*)tagType, + inst->getTargetWitness(), + tagType->getSet()); + auto eqlInst = builder.emitEql(targetTag, witnessTableArg); + + inst->replaceUsesWith(eqlInst); + inst->removeAndDeallocate(); + return true; + } + + return false; + } + + bool specializeMakeOptionalNone(IRInst* context, IRMakeOptionalNone* inst) + { + if (auto taggedUnionType = as(tryGetInfo(context, inst))) + { + // If we're dealing with a `MakeOptionalNone` for an existential type, then + // this just becomes a tagged union tuple where the set of tables is {none} + // (i.e. singleton set of none witness) + // + + IRBuilder builder(module); + builder.setInsertBefore(inst); + + // Create a tuple for the empty type.. + SLANG_ASSERT(taggedUnionType->getWitnessTableSet()->isSingleton()); + auto noneWitnessTable = taggedUnionType->getWitnessTableSet()->getElement(0); + + auto singletonWitnessTableTagType = + makeTagType(builder.getSingletonSet(kIROp_WitnessTableSet, noneWitnessTable)); + IRInst* tableTag = builder.emitGetTagOfElementInSet( + (IRType*)singletonWitnessTableTagType, + noneWitnessTable, + taggedUnionType->getWitnessTableSet()); + + auto singletonTypeTagType = + makeTagType(builder.getSingletonSet(kIROp_TypeSet, builder.getNoneTypeElement())); + IRInst* typeTag = builder.emitGetTagOfElementInSet( + (IRType*)singletonTypeTagType, + builder.getNoneTypeElement(), + taggedUnionType->getTypeSet()); + + auto newTuple = builder.emitMakeTaggedUnion( + (IRType*)taggedUnionType, + typeTag, + tableTag, + builder.emitDefaultConstruct(makeUntaggedUnionType(taggedUnionType->getTypeSet()))); + + inst->replaceUsesWith(newTuple); + propagationMap[InstWithContext(context, newTuple)] = taggedUnionType; + inst->removeAndDeallocate(); + + return true; + } + + return false; + } + + bool specializeMakeOptionalValue(IRInst* context, IRMakeOptionalValue* inst) + { + SLANG_UNUSED(context); + if (as(inst->getValue()->getDataType())) + { + // If we're dealing with a `MakeOptionalValue` for an existential type, + // we don't actually have to change anything, since logically, the input and output + // represent the same set of types and tables. + // + // We'll do a simple replace. + // + + auto newInst = inst->getValue(); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + + return true; + } + + return false; + } + + bool specializeGetOptionalValue(IRInst* context, IRGetOptionalValue* inst) + { + SLANG_UNUSED(context); + if (auto srcTaggedUnionType = + as(inst->getOptionalOperand()->getDataType())) + { + // How we handle `GetOptionalValue` depends on whether our analysis + // shows that there can be a 'none' type in the input. + // + // If not, we can do a simple replace with the operand since the + // input and output are equivalent (this is a no-op) + // + // If so, then we will cast the tagged-union's tag to a sub-set (without the none), + // and unpack the untagged union-type'd value into the smaller union type. + // + // Note that the union is "smaller" only in the sense that it doesn't have the 'none' + // type. Since a 'none' value takes up 0 space, the two union types will end up having + // the same size in the end. + // + + + IRBuilder builder(inst); + auto destTaggedUnionType = cast(tryGetInfo(context, inst)); + if (destTaggedUnionType != srcTaggedUnionType) + { + // If the source and destination tagged-union types are different, + // we need to emit a cast. + builder.setInsertBefore(inst); + IRInst* tag = builder.emitGetTagFromTaggedUnion(inst->getOptionalOperand()); + auto downcastedTag = builder.emitIntrinsicInst( + (IRType*)makeTagType(destTaggedUnionType->getWitnessTableSet()), + kIROp_GetTagForSubSet, + 1, + &tag); + + auto unpackedValue = builder.emitUnpackAnyValue( + getLoweredType(makeUntaggedUnionType(destTaggedUnionType->getTypeSet())), + builder.emitGetValueFromTaggedUnion(inst->getOptionalOperand())); + + auto newTaggedUnion = builder.emitMakeTaggedUnion( + getLoweredType(destTaggedUnionType), + builder.emitPoison(makeTagType(destTaggedUnionType->getTypeSet())), + downcastedTag, + unpackedValue); + + + inst->replaceUsesWith(newTaggedUnion); + inst->removeAndDeallocate(); + } + else + { + inst->replaceUsesWith(inst->getOptionalOperand()); + inst->removeAndDeallocate(); + } + + return true; + } + return false; + } + + bool specializeOptionalHasValue(IRInst* context, IROptionalHasValue* inst) + { + SLANG_UNUSED(context); + if (auto taggedUnionType = as(inst->getOptionalOperand()->getDataType())) + { + // The logic here is similar to specializing IsType, but we'll directly compare + // tags instead of trying to use sequential ID. + // + // There's two cases to handle here: + // 1. We statically know that it cannot be a 'none' because the + // input's set type doesn't have a 'none'. In this case + // we just return a true. + // + // 2. 'none' is a possibility. In this case, we'll get the + // tag of 'none' in the set by emitting a `GetTagOfElementInSet` + // compare those to determine if we have a value. + // + + IRBuilder builder(inst); + + bool containsNone = false; + forEachInSet( + taggedUnionType->getWitnessTableSet(), + [&](IRInst* wt) + { + if (wt == getNoneWitness()) + containsNone = true; + }); + + if (!containsNone) + { + // If 'none' isn't a part of the set, statically set + // to true. + // + + auto trueVal = builder.getBoolValue(true); + inst->replaceUsesWith(trueVal); + inst->removeAndDeallocate(); + return true; + } + else + { + // Otherwise, we'll extract the tag and compare against + // the value for 'none' (in the context of the tag's set) + // + builder.setInsertBefore(inst); + + auto dynTag = builder.emitGetTagFromTaggedUnion(inst->getOptionalOperand()); + + // Cast the singleton tag to the target set tag (will convert the + // value to the corresponding value for the larger set) + // + auto noneWitnessTag = builder.emitGetTagOfElementInSet( + (IRType*)makeTagType(taggedUnionType->getWitnessTableSet()), + getNoneWitness(), + taggedUnionType->getWitnessTableSet()); + + auto newInst = builder.emitNeq(dynTag, noneWitnessTag); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } + return false; + } + + void collectExistentialTables(IRInterfaceType* interfaceType, HashSet& outTables) + { + IRWitnessTableType* targetTableType = nullptr; + // First, find the IRWitnessTableType that wraps the given interfaceType + for (auto use = interfaceType->firstUse; use; use = use->nextUse) + { + if (auto wtType = as(use->getUser())) + { + if (wtType->getConformanceType() == interfaceType) + { + targetTableType = wtType; + break; + } + } + } + + // If the target witness table type was found, gather all witness tables using it + if (targetTableType) + { + for (auto use = targetTableType->firstUse; use; use = use->nextUse) + { + if (auto witnessTable = as(use->getUser())) + { + if (witnessTable->getDataType() == targetTableType) + { + outTables.add(witnessTable); + } + } + } + } + } + + bool processModule() + { + bool hasChanges = false; + + // Part 1: Information Propagation + // This phase propagates type information through the module + // and records them into different maps in the current context. + // + performInformationPropagation(); + + if (sink->getErrorCount() > 0) + { + // If there were errors during propagation, we bail out early. + return false; + } + + // Part 2: Dynamic Instruction Specialization + // Re-write dynamic instructions into specialized versions based on the + // type information in the previous phase. + // + hasChanges |= performDynamicInstLowering(); + + return hasChanges; + } + + TypeFlowSpecializationContext(IRModule* module, DiagnosticSink* sink) + : module(module), sink(sink) + { + } + + + // Basic context + IRModule* module; + DiagnosticSink* sink; + + // Mapping from (context, inst) --> propagated info + Dictionary propagationMap; + + // Mapping from context --> return value info + Dictionary funcReturnInfo; + + // Mapping from (struct field) --> propagated info + Dictionary fieldInfo; + + // Mapping from context --> Set<(context, inst)> + // + // Maintains a mapping from a callable context to all call-sites + // (and caller contexts) + // + Dictionary> funcCallSites; + + // Mapping from (struct-field) --> Set<(context, inst)> + // + // Maintains a mapping from a struct field to all accesses of that + // field + // + Dictionary> fieldUseSites; + + // Set of already discovered contexts. + HashSet availableContexts; +}; + +// Main entry point +bool specializeDynamicInsts(IRModule* module, DiagnosticSink* sink) +{ + TypeFlowSpecializationContext context(module, sink); + return context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-typeflow-specialize.h b/source/slang/slang-ir-typeflow-specialize.h new file mode 100644 index 00000000000..8a4dfd8efbf --- /dev/null +++ b/source/slang/slang-ir-typeflow-specialize.h @@ -0,0 +1,24 @@ +// slang-ir-typeflow-specialize.h +#pragma once +#include "slang-ir.h" + +namespace Slang +{ + +// Convert dynamic insts such as `LookupWitnessMethod`, `ExtractExistentialValue`, +// `ExtractExistentialType`, `ExtractExistentialWitnessTable` and more into specialized versions +// based on the possible values at at the use sites, based on a data-flow-style interprocedural +// analysis. +// +// This pass is intended to be run after all specialization insts with concrete arguments have +// already been processed. +// +// This pass may generate more `Specialize` insts, so it should be run in a loop with +// the standard specialization pass until a no more changes can be made. +// +bool specializeDynamicInsts(IRModule* module, DiagnosticSink* sink); + +bool isSetSpecializedGeneric(IRInst* callee); + +IROp getSetOpFromType(IRType* type); +} // namespace Slang diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index d25d1164d7d..f43228026d9 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -183,8 +183,29 @@ IRInst* specializeWithGeneric( { genArgs.add(param); } + + // Default to type kind for now. + IRType* typeForSpecialization = builder.getTypeKind(); + + auto dataType = genericToSpecialize->getDataType(); + if (dataType) + { + if (dataType->getOp() == kIROp_TypeKind || dataType->getOp() == kIROp_GenericKind) + { + typeForSpecialization = (genericToSpecialize)->getDataType(); + } + else if (dataType->getOp() == kIROp_Generic) + { + typeForSpecialization = (IRType*)builder.emitSpecializeInst( + builder.getTypeKind(), + (genericToSpecialize)->getDataType(), + genArgs.getCount(), + genArgs.getBuffer()); + } + } + return builder.emitSpecializeInst( - builder.getTypeKind(), + typeForSpecialization, genericToSpecialize, (UInt)genArgs.getCount(), genArgs.getBuffer()); @@ -1710,6 +1731,7 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* case kIROp_GlobalConstant: case kIROp_Var: case kIROp_Param: + case kIROp_DebugVar: break; case kIROp_Call: return true; @@ -2830,4 +2852,23 @@ bool canRelaxInstOrderRule(IRInst* inst, IRInst* useOfInst) return isSameBlock && isGenericParameter(useOfInst) && (useOfInst->getDataType() == inst); } +IRIntegerValue getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc) +{ + SLANG_UNUSED(usageLoc); + + if (auto decor = type->findDecoration()) + { + return decor->getSize(); + } + + // We could conceivably make it an error to have an interface + // without an `[anyValueSize(...)]` attribute, but then we risk + // producing error messages even when doing 100% static specialization. + // + // It is simpler to use a reasonable default size and treat any + // type without an explicit attribute as using that size. + // + return kDefaultAnyValueSize; +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 4423a7e2bd8..bf19b0687cc 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -11,6 +11,11 @@ namespace Slang struct GenericChildrenMigrationContextImpl; struct IRCloneEnv; +constexpr IRIntegerValue kInvalidAnyValueSize = 0xFFFFFFFF; +constexpr IRIntegerValue kDefaultAnyValueSize = 16; +constexpr SlangInt kRTTIHeaderSize = 16; +constexpr SlangInt kRTTIHandleSize = 8; + // A helper class to clone children insts to a different generic parent that has equivalent set of // generic parameters. The clone will take care of substitution of equivalent generic parameters and // intermediate values between the two generic parents. @@ -460,6 +465,8 @@ bool isGenericParameter(IRInst* inst); // of the generic parameter. bool canRelaxInstOrderRule(IRInst* instToCheck, IRInst* otherInst); +IRIntegerValue getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc); + } // namespace Slang #endif diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp deleted file mode 100644 index cfda6225edd..00000000000 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ /dev/null @@ -1,280 +0,0 @@ -// slang-ir-witness-table-wrapper.cpp -#include "slang-ir-witness-table-wrapper.h" - -#include "slang-ir-clone.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ -struct GenericsLoweringContext; - -struct GenerateWitnessTableWrapperContext -{ - SharedGenericsLoweringContext* sharedContext; - - // Represents a work item for packing `inout` or `out` arguments after a concrete call. - struct ArgumentPackWorkItem - { - // A `AnyValue` typed destination. - IRInst* dstArg = nullptr; - // A concrete value to be packed. - IRInst* concreteArg = nullptr; - }; - - // Unpack an `arg` of `IRAnyValue` into concrete type if necessary, to make it feedable into the - // parameter. If `arg` represents a AnyValue typed variable passed in to a concrete `out` - // parameter, this function indicates that it needs to be packed after the call by setting - // `packAfterCall`. - IRInst* maybeUnpackArg( - IRBuilder* builder, - IRType* paramType, - IRInst* arg, - ArgumentPackWorkItem& packAfterCall) - { - packAfterCall.dstArg = nullptr; - packAfterCall.concreteArg = nullptr; - - // If either paramType or argType is a pointer type - // (because of `inout` or `out` modifiers), we extract - // the underlying value type first. - IRType* paramValType = paramType; - IRType* argValType = arg->getDataType(); - IRInst* argVal = arg; - if (auto ptrType = as(paramType)) - { - paramValType = ptrType->getValueType(); - } - auto argType = arg->getDataType(); - if (auto argPtrType = as(argType)) - { - argValType = argPtrType->getValueType(); - argVal = builder->emitLoad(arg); - } - - // Unpack `arg` if the parameter expects concrete type but - // `arg` is an AnyValue. - if (!as(paramValType) && as(argValType)) - { - auto unpackedArgVal = builder->emitUnpackAnyValue(paramValType, argVal); - // if parameter expects an `out` pointer, store the unpacked val into a - // variable and pass in a pointer to that variable. - if (as(paramType)) - { - auto tempVar = builder->emitVar(paramValType); - builder->emitStore(tempVar, unpackedArgVal); - // tempVar needs to be unpacked into original var after the call. - packAfterCall.dstArg = arg; - packAfterCall.concreteArg = tempVar; - return tempVar; - } - else - { - return unpackedArgVal; - } - } - return arg; - } - - IRStringLit* _getWitnessTableWrapperFuncName(IRFunc* func) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(func); - if (auto linkageDecoration = func->findDecoration()) - { - return builder->getStringValue( - (String(linkageDecoration->getMangledName()) + "_wtwrapper").getUnownedSlice()); - } - if (auto namehintDecoration = func->findDecoration()) - { - return builder->getStringValue( - (String(namehintDecoration->getName()) + "_wtwrapper").getUnownedSlice()); - } - return nullptr; - } - - IRFunc* emitWitnessTableWrapper(IRFunc* func, IRInst* interfaceRequirementVal) - { - auto funcTypeInInterface = cast(interfaceRequirementVal); - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(func); - - auto wrapperFunc = builder->createFunc(); - wrapperFunc->setFullType((IRType*)interfaceRequirementVal); - if (auto name = _getWitnessTableWrapperFuncName(func)) - builder->addNameHintDecoration(wrapperFunc, name); - - builder->setInsertInto(wrapperFunc); - auto block = builder->emitBlock(); - builder->setInsertInto(block); - - ShortList params; - for (UInt i = 0; i < funcTypeInInterface->getParamCount(); i++) - { - params.add(builder->emitParam(funcTypeInInterface->getParamType(i))); - } - - List args; - List argsToPack; - - SLANG_ASSERT(params.getCount() == (Index)func->getParamCount()); - for (UInt i = 0; i < func->getParamCount(); i++) - { - auto wrapperParam = params[i]; - // Type of the parameter in the callee. - auto funcParamType = func->getParamType(i); - - // If the implementation expects a concrete type - // (either in the form of a pointer for `out`/`inout` parameters, - // or in the form a value for `in` parameters, while - // the interface exposes an AnyValue type, - // we need to unpack the AnyValue argument to the appropriate - // concerete type. - ArgumentPackWorkItem packWorkItem; - auto newArg = maybeUnpackArg(builder, funcParamType, wrapperParam, packWorkItem); - args.add(newArg); - if (packWorkItem.concreteArg) - argsToPack.add(packWorkItem); - } - auto call = builder->emitCallInst(func->getResultType(), func, args); - - // Pack all `out` arguments. - for (auto item : argsToPack) - { - auto anyValType = cast(item.dstArg->getDataType())->getValueType(); - auto concreteVal = builder->emitLoad(item.concreteArg); - auto packedVal = builder->emitPackAnyValue(anyValType, concreteVal); - builder->emitStore(item.dstArg, packedVal); - } - - // Pack return value if necessary. - if (!as(call->getDataType()) && - as(funcTypeInInterface->getResultType())) - { - auto pack = builder->emitPackAnyValue(funcTypeInInterface->getResultType(), call); - builder->emitReturn(pack); - } - else - { - if (call->getDataType()->getOp() == kIROp_VoidType) - builder->emitReturn(); - else - builder->emitReturn(call); - } - return wrapperFunc; - } - - void lowerWitnessTable(IRWitnessTable* witnessTable) - { - auto interfaceType = as(witnessTable->getConformanceType()); - if (!interfaceType) - return; - if (isBuiltin(interfaceType)) - return; - if (isComInterfaceType(interfaceType)) - return; - - // We need to consider whether the concrete type that is conforming - // in this witness table actually fits within the declared any-value - // size for the interface. - // - // If the type doesn't fit then it would be invalid to use for dynamic - // dispatch, and the packing/unpacking operations we emit would fail - // to generate valid code. - // - // Such a type might still be useful for static specialization, so - // we can't consider this case a hard error. - // - auto concreteType = witnessTable->getConcreteType(); - IRIntegerValue typeSize, sizeLimit; - bool isTypeOpaque = false; - if (!sharedContext->doesTypeFitInAnyValue( - concreteType, - interfaceType, - &typeSize, - &sizeLimit, - &isTypeOpaque)) - { - HashSet visited; - if (isTypeOpaque) - { - sharedContext->sink->diagnose( - concreteType, - Diagnostics::typeCannotBePackedIntoAnyValue, - concreteType); - } - else - { - sharedContext->sink->diagnose( - concreteType, - Diagnostics::typeDoesNotFitAnyValueSize, - concreteType); - sharedContext->sink->diagnoseWithoutSourceView( - concreteType, - Diagnostics::typeAndLimit, - concreteType, - typeSize, - sizeLimit); - } - return; - } - - for (auto child : witnessTable->getChildren()) - { - auto entry = as(child); - if (!entry) - continue; - auto interfaceRequirementVal = sharedContext->findInterfaceRequirementVal( - interfaceType, - entry->getRequirementKey()); - if (auto ordinaryFunc = as(entry->getSatisfyingVal())) - { - auto wrapper = emitWitnessTableWrapper(ordinaryFunc, interfaceRequirementVal); - entry->satisfyingVal.set(wrapper); - sharedContext->addToWorkList(wrapper); - } - } - } - - void processInst(IRInst* inst) - { - if (auto witnessTable = as(inst)) - { - lowerWitnessTable(witnessTable); - } - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - } -}; - -void generateWitnessTableWrapperFunctions(SharedGenericsLoweringContext* sharedContext) -{ - GenerateWitnessTableWrapperContext context; - context.sharedContext = sharedContext; - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-witness-table-wrapper.h b/source/slang/slang-ir-witness-table-wrapper.h deleted file mode 100644 index 27008b085ed..00000000000 --- a/source/slang/slang-ir-witness-table-wrapper.h +++ /dev/null @@ -1,22 +0,0 @@ -// slang-ir-witness-table-wrapper.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// This pass generates wrapper functions for witness table function entries. -/// -/// Enabled for generation of dynamic dispatch code only. -/// -/// Functions that are used to satisfy interface requirement have concrete -/// type signatures for `this` and `associatedtype` parameters/return values. -/// However, when they are called from a witness table, the callee only have a -/// raw pointer for this arguments, since the conrete type is not known to the -/// callee. Therefore, we need to generate wrappers for each member function -/// callable through a witness table, so that the wrapper functions take general void* -/// pointer for arguments whose type is unknown at call sites, and convert them -/// to concrete types and calls the actual implementation. -void generateWitnessTableWrapperFunctions(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2b0c47cd10c..a8376973bc0 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2454,6 +2454,16 @@ IRVoidLit* IRBuilder::getVoidValue() return (IRVoidLit*)_findOrEmitConstant(keyInst); } +IRVoidLit* IRBuilder::getVoidValue(IRType* type) +{ + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.m_op = kIROp_VoidLit; + keyInst.typeUse.usedValue = type; + keyInst.value.intVal = 0; + return (IRVoidLit*)_findOrEmitConstant(keyInst); +} + IRInst* IRBuilder::getCapabilityValue(CapabilitySet const& caps) { IRType* capabilityAtomType = getIntType(); @@ -3579,7 +3589,7 @@ IRInst* IRBuilder::emitPackAnyValue(IRType* type, IRInst* value) IRInst* IRBuilder::emitUnpackAnyValue(IRType* type, IRInst* value) { - auto inst = createInst(this, kIROp_UnpackAnyValue, type, value); + auto inst = createInst(this, kIROp_UnpackAnyValue, type, value); addInst(inst); return inst; @@ -6565,6 +6575,46 @@ IREntryPointLayout* IRBuilder::getEntryPointLayout( operands)); } +IRSetBase* IRBuilder::getSet(IROp op, const HashSet& elements) +{ + // Verify that all operands are global instructions + for (auto element : elements) + if (element->getParent()->getOp() != kIROp_ModuleInst) + SLANG_ASSERT_FAILURE("createSet called with non-global operands"); + + List* sortedElements = getModule()->getContainerPool().getList(); + for (auto element : elements) + sortedElements->add(element); + + // Sort elements by their unique IDs to ensure canonical ordering + sortedElements->sort( + [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); + + auto setBaseInst = as( + emitIntrinsicInst(nullptr, op, sortedElements->getCount(), sortedElements->getBuffer())); + + getModule()->getContainerPool().free(sortedElements); + + return setBaseInst; +} + +IRSetBase* IRBuilder::getSingletonSet(IROp op, IRInst* element) +{ + return getSet(op, {element}); +} + +UInt IRBuilder::getUniqueID(IRInst* inst) +{ + auto uniqueIDMap = getModule()->getUniqueIdMap(); + auto existingId = uniqueIDMap->tryGetValue(inst); + if (existingId) + return *existingId; + + auto id = uniqueIDMap->getCount(); + uniqueIDMap->add(inst, id); + return id; +} + // struct IRDumpContext @@ -7998,6 +8048,38 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) // Swap this use over to use the other value. uu->usedValue = other; + bool userIsSetInst = as(uu->getUser()) != nullptr; + if (userIsSetInst) + { + // Set insts need their operands sorted + auto module = user->getModule(); + SLANG_ASSERT(module); + + List& operands = *module->getContainerPool().getList(); + for (UInt ii = 0; ii < user->getOperandCount(); ii++) + operands.add(user->getOperand(ii)); + + auto getUniqueId = [&](IRInst* inst) -> UInt + { + auto uniqueIDMap = module->getUniqueIdMap(); + auto existingId = uniqueIDMap->tryGetValue(inst); + if (existingId) + return *existingId; + + auto id = uniqueIDMap->getCount(); + uniqueIDMap->add(inst, id); + return (UInt)id; + }; + + operands.sort( + [&](IRInst* a, IRInst* b) -> bool { return getUniqueId(a) < getUniqueId(b); }); + + for (UInt ii = 0; ii < user->getOperandCount(); ii++) + user->getOperandUse(ii)->usedValue = operands[ii]; + + module->getContainerPool().free(&operands); + } + // If `other` is hoistable, then we need to make sure `other` is hoisted // to a point before `user`, if it is not already so. _maybeHoistOperand(uu); @@ -8297,6 +8379,9 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) if (as(this)) return false; + if (as(this)) + return false; + switch (getOp()) { // By default, assume that we might have side effects, @@ -8485,9 +8570,36 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetPerVertexInputArray: case kIROp_MetalCastToDepthTexture: case kIROp_GetCurrentStage: + case kIROp_GetDispatcher: + case kIROp_GetSpecializedDispatcher: + case kIROp_GetTagForMappedSet: + case kIROp_GetTagForSpecializedSet: + case kIROp_GetTagForSuperSet: + case kIROp_GetTagForSubSet: + case kIROp_GetTagFromSequentialID: + case kIROp_GetSequentialIDFromTag: + case kIROp_CastInterfaceToTaggedUnionPtr: + case kIROp_GetElementFromTag: + case kIROp_GetTagFromTaggedUnion: + case kIROp_GetTypeTagFromTaggedUnion: + case kIROp_GetValueFromTaggedUnion: + case kIROp_MakeTaggedUnion: + case kIROp_GetTagOfElementInSet: + case kIROp_MakeDifferentialPairUserCode: + case kIROp_MakeDifferentialPtrPair: case kIROp_MakeStorageTypeLoweringConfig: return false; + case kIROp_UnboundedFuncElement: + case kIROp_UnboundedTypeElement: + case kIROp_UnboundedWitnessTableElement: + case kIROp_UnboundedGenericElement: + case kIROp_UninitializedTypeElement: + case kIROp_UninitializedWitnessTableElement: + case kIROp_NoneTypeElement: + case kIROp_NoneWitnessTableElement: + return false; + case kIROp_ForwardDifferentiate: case kIROp_BackwardDifferentiate: case kIROp_BackwardDifferentiatePrimal: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 053b91aa1e6..73110a08473 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -2048,6 +2048,8 @@ struct IRModule : RefObject IRDeduplicationContext* getDeduplicationContext() const { return &m_deduplicationContext; } + Dictionary* getUniqueIdMap() { return &m_mapInstToUniqueId; } + IRDominatorTree* findDominatorTree(IRGlobalValueWithCode* func) { IRAnalysis* analysis = m_mapInstToAnalysis.tryGetValue(func); @@ -2164,6 +2166,12 @@ struct IRModule : RefObject Dictionary m_mapInstToAnalysis; Dictionary> m_mapMangledNameToGlobalInst; + + /// Hold a mapping for inst -> uniqueID. This mapping is generated on + /// demand if passes need them, rather than eagerly storing them on + /// insts when unnecessary. + /// + Dictionary m_mapInstToUniqueId; }; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b20588a53d9..d00bdbf96ec 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10902,10 +10902,22 @@ struct DeclLoweringVisitor : DeclVisitor for (auto param : paramCloneInfos) { typeBuilder.setInsertInto(param.clonedParam); - param.clonedParam->setFullType((IRType*)cloneInst( - &cloneEnv, - &typeBuilder, - param.originalParam->getFullType())); + + // If the type is present in the generic (i.e. not in module scope), + // then we need to make a clone since it is likely to be dependent on the + // new parameters we just emitted and we want to obtain the effective type + // in the current context. + // + // If it's in the global scope, we're good if we use the same type. + // + if (!as(param.originalParam->getFullType()->getParent())) + param.clonedParam->setFullType((IRType*)cloneInst( + &cloneEnv, + &typeBuilder, + param.originalParam->getFullType())); + else + param.clonedParam->setFullType(param.originalParam->getFullType()); + cloneInstDecorationsAndChildren( &cloneEnv, typeBuilder.getModule(), diff --git a/tests/autodiff/custom-intrinsic-1.slang b/tests/autodiff/custom-intrinsic-1.slang index e2ad6010b7a..f9616ca08bd 100644 --- a/tests/autodiff/custom-intrinsic-1.slang +++ b/tests/autodiff/custom-intrinsic-1.slang @@ -12,15 +12,22 @@ typealias IDFloat = __BuiltinFloatingPointType & IDifferentiable; namespace myintrinsiclib { __generic - __target_intrinsic(hlsl, "exp($0)") - __target_intrinsic(glsl, "exp($0)") - __target_intrinsic(cuda, "$P_exp($0)") - __target_intrinsic(cpp, "$P_exp($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 27 _0") - __target_intrinsic(metal, "exp($0)") - __target_intrinsic(wgsl, "exp($0)") [ForwardDerivative(d_myexp)] - T myexp(T x); + T myexp(T x) + { + __target_switch + { + case cpp: __intrinsic_asm "$P_exp($0)"; + case cuda: __intrinsic_asm "$P_exp($0)"; + case glsl: __intrinsic_asm "exp"; + case hlsl: __intrinsic_asm "exp"; + case metal: __intrinsic_asm "exp"; + case spirv: return spirv_asm { + OpExtInst $$T result glsl450 27 $x + }; + case wgsl: __intrinsic_asm "exp"; + } + } __generic DifferentialPair d_myexp(DifferentialPair dpx) @@ -33,15 +40,22 @@ namespace myintrinsiclib // Sine __generic - __target_intrinsic(hlsl, "sin($0)") - __target_intrinsic(glsl, "sin($0)") - __target_intrinsic(metal, "sin($0)") - __target_intrinsic(cuda, "$P_sin($0)") - __target_intrinsic(cpp, "$P_sin($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 13 _0") - __target_intrinsic(wgsl, "sin($0)") [ForwardDerivative(d_mysin)] - T mysin(T x); + T mysin(T x) + { + __target_switch + { + case hlsl: __intrinsic_asm "sin($0)"; + case glsl: __intrinsic_asm "sin($0)"; + case metal: __intrinsic_asm "sin($0)"; + case cuda: __intrinsic_asm "$P_sin($0)"; + case cpp: __intrinsic_asm "$P_sin($0)"; + case spirv: return spirv_asm { + OpExtInst $$T result glsl450 13 $x + }; + case wgsl: __intrinsic_asm "sin($0)"; + } + } __generic DifferentialPair d_mysin(DifferentialPair dpx) @@ -53,15 +67,22 @@ namespace myintrinsiclib // Cosine __generic - __target_intrinsic(hlsl, "cos($0)") - __target_intrinsic(glsl, "cos($0)") - __target_intrinsic(metal, "cos($0)") - __target_intrinsic(cuda, "$P_cos($0)") - __target_intrinsic(cpp, "$P_cos($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 14 _0") - __target_intrinsic(wgsl, "cos($0)") [ForwardDerivative(d_mycos)] - T mycos(T x); + T mycos(T x) + { + __target_switch + { + case hlsl: __intrinsic_asm "cos($0)"; + case glsl: __intrinsic_asm "cos($0)"; + case metal: __intrinsic_asm "cos($0)"; + case cuda: __intrinsic_asm "$P_cos($0)"; + case cpp: __intrinsic_asm "$P_cos($0)"; + case spirv: return spirv_asm { + OpExtInst $$T result glsl450 14 $x + }; + case wgsl: __intrinsic_asm "cos($0)"; + } + } __generic DifferentialPair d_mycos(DifferentialPair dpx) diff --git a/tests/compute/dynamic-dispatch-1.slang b/tests/compute/dynamic-dispatch-1.slang index 10d3552ab88..f28704d47b3 100644 --- a/tests/compute/dynamic-dispatch-1.slang +++ b/tests/compute/dynamic-dispatch-1.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj // Test dynamic dispatch code gen for non-static member functions. diff --git a/tests/compute/dynamic-dispatch-10.slang b/tests/compute/dynamic-dispatch-10.slang index ddd51c4b920..9df6a6aaf37 100644 --- a/tests/compute/dynamic-dispatch-10.slang +++ b/tests/compute/dynamic-dispatch-10.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for specializing a generic with // an existential value. diff --git a/tests/compute/dynamic-dispatch-11.slang b/tests/compute/dynamic-dispatch-11.slang index 59e7ce58170..4b7a67324ad 100644 --- a/tests/compute/dynamic-dispatch-11.slang +++ b/tests/compute/dynamic-dispatch-11.slang @@ -6,8 +6,8 @@ //DISABLE_TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj [anyValueSize(8)] interface IInterface diff --git a/tests/compute/dynamic-dispatch-2.slang b/tests/compute/dynamic-dispatch-2.slang index 06aed6e71c9..3c83ee2a0f8 100644 --- a/tests/compute/dynamic-dispatch-2.slang +++ b/tests/compute/dynamic-dispatch-2.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for static member functions // of associated type. diff --git a/tests/compute/dynamic-dispatch-3.slang b/tests/compute/dynamic-dispatch-3.slang index 94430e638ab..2f8ca6504d3 100644 --- a/tests/compute/dynamic-dispatch-3.slang +++ b/tests/compute/dynamic-dispatch-3.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for static member functions // of associated type. diff --git a/tests/compute/dynamic-dispatch-4.slang b/tests/compute/dynamic-dispatch-4.slang index 996e706655b..6cb1c50d083 100644 --- a/tests/compute/dynamic-dispatch-4.slang +++ b/tests/compute/dynamic-dispatch-4.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for generic-typed local variables. diff --git a/tests/compute/dynamic-dispatch-5.slang b/tests/compute/dynamic-dispatch-5.slang index 61d4024af5a..4a7e12f9e5b 100644 --- a/tests/compute/dynamic-dispatch-5.slang +++ b/tests/compute/dynamic-dispatch-5.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for general `This` type. [anyValueSize(8)] diff --git a/tests/compute/dynamic-dispatch-6.slang b/tests/compute/dynamic-dispatch-6.slang index 69edb51bc56..6364274d57c 100644 --- a/tests/compute/dynamic-dispatch-6.slang +++ b/tests/compute/dynamic-dispatch-6.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for generic-typed return values. [anyValueSize(8)] diff --git a/tests/compute/dynamic-dispatch-7.slang b/tests/compute/dynamic-dispatch-7.slang index 7f516b0d556..1e5d7a639e0 100644 --- a/tests/compute/dynamic-dispatch-7.slang +++ b/tests/compute/dynamic-dispatch-7.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for associated-typed return values // and local variables. diff --git a/tests/compute/dynamic-dispatch-8.slang b/tests/compute/dynamic-dispatch-8.slang index 62d3f0a9d41..77faf771928 100644 --- a/tests/compute/dynamic-dispatch-8.slang +++ b/tests/compute/dynamic-dispatch-8.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for extential type parameters. diff --git a/tests/compute/dynamic-dispatch-9.slang b/tests/compute/dynamic-dispatch-9.slang index 77112a318b7..de213dc4224 100644 --- a/tests/compute/dynamic-dispatch-9.slang +++ b/tests/compute/dynamic-dispatch-9.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for initializing an extential value // from a generic value. diff --git a/tests/compute/dynamic-generics-simple.slang b/tests/compute/dynamic-generics-simple.slang index c3db4542041..ecf18486927 100644 --- a/tests/compute/dynamic-generics-simple.slang +++ b/tests/compute/dynamic-generics-simple.slang @@ -1,5 +1,5 @@ -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization +//TEST(compute):COMPARE_COMPUTE:-cpu +//TEST(compute):COMPARE_COMPUTE:-cuda // Test basic dynamic dispatch code gen diff --git a/tests/diagnostics/interfaces/anyvalue-size-validation.slang b/tests/diagnostics/interfaces/anyvalue-size-validation.slang index ffe968c678c..213c2718993 100644 --- a/tests/diagnostics/interfaces/anyvalue-size-validation.slang +++ b/tests/diagnostics/interfaces/anyvalue-size-validation.slang @@ -1,6 +1,6 @@ // anyvalue-size-validation.slang -//DIAGNOSTIC_TEST:SIMPLE:-target cpp -stage compute -entry main -disable-specialization +//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute -conformance "A:IInterface=0" -conformance "B:IInterface=1" [anyValueSize(8)] interface IInterface @@ -8,7 +8,9 @@ interface IInterface int doSomething(); }; -struct S : IInterface +//CHECK: ([[# @LINE+2]]): error 41011: type 'A' does not fit in the size required by its conforming interface. +//CHECK-NEXT: struct A : IInterface +struct A : IInterface { uint a; uint b; @@ -16,16 +18,27 @@ struct S : IInterface int doSomething() { return 5; } }; -T test(T s) +//CHECK: note: sizeof(A) is 12, limit is 8 + +struct B : IInterface { - return s; -} + int doSomething() { return 6; } +}; + +//TEST_INPUT: +//TEST_INPUT: type_conformance B:IInterface = 1 -RWStructuredBuffer output; +RWStructuredBuffer data; +RWStructuredBuffer output; + +struct Data +{ + float x; + float y; +}; [numthreads(4, 1, 1)] -void main() +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - S s = S(1, 2, 3); - output[0] = test(s).a; + output[1] = data[0].doSomething(); } diff --git a/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected b/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected index e9465167114..a95188a0bb6 100644 --- a/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected +++ b/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected @@ -1,9 +1,11 @@ -result code = -1 +result code = 1 standard error = { -tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): error 41011: type 'S' does not fit in the size required by its conforming interface. -struct S : IInterface - ^ -tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): note: sizeof(S) is 12, limit is 8 } standard output = { } +debug layer = { +tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): error 41011: type 'A' does not fit in the size required by its conforming interface. +struct A : IInterface + ^ +tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): note: sizeof(A) is 12, limit is 8 +} diff --git a/tests/diagnostics/interfaces/interface-extension.slang b/tests/diagnostics/interfaces/interface-extension.slang index b63b454abdc..7c42b3874e6 100644 --- a/tests/diagnostics/interfaces/interface-extension.slang +++ b/tests/diagnostics/interfaces/interface-extension.slang @@ -1,4 +1,4 @@ -//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):-target cpp -stage compute -entry main -disable-specialization +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):-target cpp -stage compute -entry main interface IFoo{} diff --git a/tests/diagnostics/no-type-conformance.slang b/tests/diagnostics/no-type-conformance.slang index ada4f663a75..08147268871 100644 --- a/tests/diagnostics/no-type-conformance.slang +++ b/tests/diagnostics/no-type-conformance.slang @@ -1,6 +1,6 @@ -//DIAGNOSTIC_TEST:COMMAND_LINE_SIMPLE:-target hlsl -entry computeMain -stage compute -o no-type-conformance.hlsl -// no type conformance linked +//TEST:SIMPLE(filecheck=CHECK):-target hlsl -entry computeMain -stage compute -o no-type-conformance.hlsl +//CHECK: error 50100: No type conformances are found for interface 'IFoo' interface IFoo { float get(); @@ -8,7 +8,7 @@ interface IFoo void foo() { - IFoo obj; + IFoo obj = createDynamicObject(0, 0); obj.get(); } diff --git a/tests/diagnostics/no-type-conformance.slang.expected b/tests/diagnostics/no-type-conformance.slang.expected deleted file mode 100644 index 5f5eda6afe7..00000000000 --- a/tests/diagnostics/no-type-conformance.slang.expected +++ /dev/null @@ -1,8 +0,0 @@ -result code = -1 -standard error = { -tests/diagnostics/no-type-conformance.slang(12): error 50100: No type conformances are found for interface 'IFoo'. Code generation for current target requires at least one implementation type present in the linkage. - obj.get(); - ^ -} -standard output = { -} diff --git a/tests/diagnostics/uninitialized-existential.slang b/tests/diagnostics/uninitialized-existential.slang new file mode 100644 index 00000000000..26a5ed3a4ad --- /dev/null +++ b/tests/diagnostics/uninitialized-existential.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK):-target hlsl -entry computeMain -stage compute -o no-type-conformance.hlsl + +//CHECK: error 50101: Cannot dynamically dispatch on potentially uninitialized interface object 'IFoo' +interface IFoo +{ + float get(); +} + +void foo() +{ + IFoo obj; + obj.get(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + foo(); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/assoc-types.slang b/tests/language-feature/dynamic-dispatch/assoc-types.slang new file mode 100644 index 00000000000..8b9f82e8d0f --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/assoc-types.slang @@ -0,0 +1,60 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface ICalculation +{ + associatedtype Data; + float calc(Data d, float x); + Data make(float q); +} + +struct A : ICalculation +{ + typealias Data = float; + float calc(Data d, float x) { return d * x * x; } + Data make(float q) { return q; } +}; + +struct BData +{ + float x; + float y; +}; + +struct B : ICalculation +{ + typealias Data = BData; + float calc(Data d, float x) { return d.x * x * x + d.y; } + Data make(float q) { return {q, q}; } +}; + +struct C : ICalculation +{ + typealias Data = float; + float calc(Data d, float x) { return d * x; } + Data make(float q) { return q; } +}; + +ICalculation factoryAB(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + return obj.calc(d, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); // CHECK: 1.0 + outputBuffer[1] = f(1, 1); // CHECK: 2.0 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dependent-assoc-types-1.slang b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-1.slang new file mode 100644 index 00000000000..b2e3da8dc91 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-1.slang @@ -0,0 +1,121 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface ISerializer +{ + static void serialize(T data, RWStructuredBuffer buffer); + static T deserialize(RWStructuredBuffer buffer); +} + +interface ICalculation +{ + associatedtype Data; + associatedtype DataSerializer : ISerializer; + float calc(Data d, float x); + Data make(float q); +} + +struct StandardSerializer : ISerializer +{ + static void serialize(T data, RWStructuredBuffer buffer) + { + // Note: just for testing.. don't ever serialize this way + buffer[0] = __realCast(data); + } + + static T deserialize(RWStructuredBuffer buffer) + { + return __realCast(buffer[0]); + } +}; + +struct A : ICalculation +{ + typealias Data = float; + typealias DataSerializer = StandardSerializer; + float calc(Data d, float x) { return d * x * x; } + Data make(float q) { return q; } +}; + +struct BData +{ + float x; + float y; +}; + +struct BDataSerializer : ISerializer +{ + static void serialize(BData data, RWStructuredBuffer buffer) + { + buffer[0] = data.x; + buffer[1] = data.y; + } + + static BData deserialize(RWStructuredBuffer buffer) + { + return {buffer[0], buffer[1]}; + } +}; + +struct B : ICalculation +{ + typealias Data = BData; + typealias DataSerializer = BDataSerializer; + float calc(Data d, float x) { return d.x * x + d.y; } + Data make(float q) { return {q, q * q}; } +}; + +struct C : ICalculation +{ + typealias Data = float; + typealias DataSerializer = StandardSerializer; + float calc(Data d, float x) { return d * x; } + Data make(float q) { return q; } +}; + +ICalculation factoryAB(uint id) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float q) +{ + let obj = factoryAB(id); + obj.Data d = obj.make(q); + obj.DataSerializer::serialize(d, outputBuffer); + outputBuffer[2] = id; + return 0; +} + +float g(float x) +{ + uint id = (uint)outputBuffer[2]; + let obj = factoryAB(id); + obj.Data d = obj.DataSerializer::deserialize(outputBuffer); + return obj.calc(d, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + f(0, 3.f); + outputBuffer[3] = g(2.f); + f(1, 2.f); + outputBuffer[4] = g(2.f); + + // Clear the first 3 elements + outputBuffer[0] = 0.f; + outputBuffer[1] = 0.f; + outputBuffer[2] = 0.f; + + // CHECK: 0.0 + // CHECK: 0.0 + // CHECK: 0.0 + // CHECK: 12.0 + // CHECK: 8.0 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang new file mode 100644 index 00000000000..4721387f0f5 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang @@ -0,0 +1,150 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):name=scratchBuffer +RWStructuredBuffer scratchBuffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IBufferRef +{ + uint[N] get(uint index); + void set(uint index, uint[N] value); + This withOffset(uint offset); +}; + +interface ISerializer +{ + static void serialize(T data, IBufferRef buffer); + static T deserialize(IBufferRef buffer); +} + +interface ICalculation +{ + associatedtype Data; + associatedtype DataSerializer : ISerializer; + float calc(Data d, float x); + Data make(float q); +} + +struct StandardSerializer : ISerializer +{ + static void serialize(T data, IBufferRef buffer) + { + buffer.set<1>(0, bit_cast(data)); + } + static T deserialize(IBufferRef buffer) + { + return bit_cast(buffer.get<1>(0)); + } +}; + +struct A : ICalculation +{ + typealias Data = float; + typealias DataSerializer = StandardSerializer; + float calc(Data d, float x) { return d * x * x; } + Data make(float q) { return q; } +}; + +struct BData +{ + float x; + float y; +}; + +struct BDataSerializer : ISerializer +{ + static void serialize(BData data, IBufferRef buffer) + { + StandardSerializer::serialize(data.x, buffer.withOffset(0)); + StandardSerializer::serialize(data.y, buffer.withOffset(1)); + } + + static BData deserialize(IBufferRef buffer) + { + return BData(StandardSerializer::deserialize(buffer.withOffset(0)), + StandardSerializer::deserialize(buffer.withOffset(1))); + } +}; + +struct B : ICalculation +{ + typealias Data = BData; + typealias DataSerializer = BDataSerializer; + float calc(Data d, float x) { return d.x * x * x + d.y; } + Data make(float q) { return {q, q}; } +}; + +struct C : ICalculation +{ + typealias Data = float; + typealias DataSerializer = StandardSerializer; + float calc(Data d, float x) { return d * x; } + Data make(float q) { return q; } +}; + +ICalculation factoryAB(uint id) +{ + if (id == 0) + return A(); + else + return B(); +} + +struct BufferRef : IBufferRef +{ + uint offset; + uint[N] get(uint index) + { + uint[N] result; + for (int i = 0; i < N; ++i) + result[i] = scratchBuffer[offset + i]; + return result; + } + + void set(uint index, uint[N] value) + { + for (int i = 0; i < N; ++i) + scratchBuffer[offset + i] = value[i]; + } + + This withOffset(uint offset) + { + return {this.offset + offset}; + } +}; +IBufferRef getScratchBufferAsRef() +{ + return BufferRef(0); +} + +void f(uint id, float q) +{ + let obj = factoryAB(id); + obj.Data d = obj.make(q); + IBufferRef buf = getScratchBufferAsRef(); + buf.set<1>(0, {(uint)id}); + obj.DataSerializer::serialize(d, buf.withOffset(1)); +} + +float g(float x) +{ + IBufferRef buf = getScratchBufferAsRef(); + uint id = buf.get<1>(0)[0]; + let obj = factoryAB(id); + obj.Data d = obj.DataSerializer::deserialize(buf.withOffset(1)); + return obj.calc(d, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + f(0, 3.f); + outputBuffer[0] = g(2.f); + f(1, 2.f); + outputBuffer[1] = g(2.f); + + // CHECK: 12.0 + // CHECK: 10.0 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/derived-interface.slang b/tests/language-feature/dynamic-dispatch/derived-interface.slang new file mode 100644 index 00000000000..a52598477f0 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/derived-interface.slang @@ -0,0 +1,82 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IScalable +{ + This scale(float x); +} + +interface IShape : IScalable +{ + float getArea(); + float getPerimeter(); +} + +struct Circle : IShape +{ + float radius; + + This scale(float x) { return {radius * x}; } + float getArea() { return 3.14159f * radius * radius; } + float getPerimeter() { return 2 * 3.14159f * radius; } +}; + +struct Square : IShape +{ + float side; + + This scale(float x) { return {side * x}; } + float getArea() { return side * side; } + float getPerimeter() { return 4 * side; } +}; + +struct Vector : IScalable +{ + float x, y; + + This scale(float factor) { return {x * factor, y * factor}; } +}; + +struct Matrix : IScalable +{ + float m[4][4]; + + This scale(float factor) + { + Matrix result; + for (int i = 0; i < 4; ++i) + for (int j = 0; j < 4; ++j) + result.m[i][j] = m[i][j] * factor; + return result; + } +}; + +IShape createShape(uint id, float data) +{ + return createDynamicObject(id, data); +} + +float f(uint id, float data, float scale) +{ + let obj = createShape(id, data); + let scaledObj = obj.scale(scale); + return scaledObj.getArea() + scaledObj.getPerimeter(); +}; + + +//TEST_INPUT: type_conformance Vector:IScalable = 0 +//TEST_INPUT: type_conformance Matrix:IScalable = 1 +//TEST_INPUT: type_conformance Circle:IScalable = 2 +//TEST_INPUT: type_conformance Square:IScalable = 3 + +//TEST_INPUT: type_conformance Circle:IShape = 0 +//TEST_INPUT: type_conformance Square:IShape = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3, 3); // CHECK: 311.017426 + outputBuffer[1] = f(1, 3, 3); // CHECK: 117.000000 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang new file mode 100644 index 00000000000..0f9516c4279 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang @@ -0,0 +1,67 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IFoo +{ + float calc(float x); +} + +interface IFactory +{ + IFoo create(); +} + +struct AFoo : IFoo +{ + float calc(float x) { return x * x * x; } +}; + +struct A : IFactory +{ + IFoo create() { return AFoo(); } +}; + +struct B : IFactory, IFoo +{ + float calc(float x) { return x * x; } + IFoo create() { return this; } +}; + +struct CFoo : IFoo +{ + float calc(float x) { return x; } +}; + +struct C : IFactory +{ + IFoo create() { return CFoo(); } +}; + +IFactory getFactory(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float calc(IFoo obj, float y) +{ + return obj.calc(y); +} + +float f(uint id, float x) +{ + IFactory obj = getFactory(id, x); + IFoo foo = obj.create(); + return calc(foo, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang new file mode 100644 index 00000000000..f02203c6f60 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang @@ -0,0 +1,69 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IFoo +{ + float calc(float x); +} + +interface IFactory +{ + IFoo create(float q); +} + +struct AFoo : IFoo +{ + float q; + float calc(float x) { return q * x; } +}; + +struct A : IFactory +{ + IFoo create(float q) { return AFoo(q); } +}; + +struct B : IFactory, IFoo +{ + float q; + float calc(float x) { return q * x * x; } + IFoo create(float q) { return This(q); } +}; + +struct CFoo : IFoo +{ + float calc(float x) { return x; } +}; + +struct C : IFactory +{ + IFoo create(float q) { return CFoo(); } +}; + +IFactory getFactory(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(x); +} + +float calc(IFoo obj, float y) +{ + return obj.calc(y); +} + +float f(uint id, float x) +{ + IFactory factory = getFactory(id, x); + IFoo foo = factory.create(2 * x); + return calc(foo, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 16 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-1.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-1.slang new file mode 100644 index 00000000000..6be5c9c6acb --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-1.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + static float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + static float calc(float x) { return x; } +}; + +float calc(T obj, float y) +{ + return T::calc(y); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else + obj = B(); + + return calc(obj, x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-2.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-2.slang new file mode 100644 index 00000000000..43bc8d57175 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-2.slang @@ -0,0 +1,74 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + static float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + static float calc(float x) { return x; } +}; + +struct D : IInterface +{ + static float calc(float x) { return x * x * x * x; } +}; + +struct E : IInterface +{ + static float calc(float x) { return x * x * x * x; } +}; + +float calc(T obj, float y) +{ + return T::calc(y); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else + obj = B(); + + return calc(obj, x); // Specialized for A & B. +} + +float g(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = C(); + else + obj = D(); + + return calc(obj, x); // Specialized for C & D. +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 + + outputBuffer[2] = g(0, 2); // CHECK: 2 + outputBuffer[3] = g(1, 2); // CHECK: 16 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-3.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-3.slang new file mode 100644 index 00000000000..09bba484e4e --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-3.slang @@ -0,0 +1,61 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + This withOffset(float offset); + float calc(float x); +} + +struct A : IInterface +{ + float factor; + A withOffset(float offset) { return {factor + offset}; } + float calc(float x) { return factor * x * x * x; } +}; + +struct B : IInterface +{ + float factor1; + float factor2; + B withOffset(float offset) { return {factor1 + offset, factor2 + offset}; } + float calc(float x) { return factor1 * x * x + factor2 * x; } +}; + +struct C : IInterface +{ + float factor; + C withOffset(float offset) { return {factor + offset}; } + float calc(float x) { return x; } +}; + +T transfer(T obj) +{ + return obj.withOffset(2.5f); +} + +float calc(T obj, float y) +{ + return transfer(obj).calc(y); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else + obj = B(); + + return calc(obj, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 20 + outputBuffer[1] = f(1, 2); // CHECK: 15 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-4.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-4.slang new file mode 100644 index 00000000000..752f5b62c3c --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-4.slang @@ -0,0 +1,72 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + This withOffset(float offset); + float calc(float x); +} + +struct A : IInterface +{ + float factor; + A withOffset(float offset) { return {factor + offset}; } + float calc(float x) { return factor * x * x * x; } +}; + +struct B : IInterface +{ + float factor1; + float factor2; + B withOffset(float offset) { return {factor1 + offset, factor2 + offset}; } + float calc(float x) { return factor1 * x * x + factor2 * x; } +}; + +struct C : IInterface +{ + float factor; + C withOffset(float offset) { return {factor + offset}; } + float calc(float x) { return x; } +}; + +T transfer(T obj) +{ + return obj.withOffset(2.5f); +} + +struct Foo +{ + T a; + T b; +} + +Foo make(T obj) +{ + return {obj, obj.withOffset(1.0f)}; +} + +float calc(Foo obj, float y) +{ + return obj.a.calc(y) + obj.b.calc(y); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(x); + else + obj = B(x, x + 1); + + return calc(make(obj), x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 40 + outputBuffer[1] = f(1, 2); // CHECK: 34 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-5.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-5.slang new file mode 100644 index 00000000000..0d92ded6825 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-5.slang @@ -0,0 +1,66 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + associatedtype Extras; + This withOffset(float offset); + Extras preCalc(float x); + float postCalc(Extras extras); +} + +struct A : IInterface +{ + float factor; + typealias Extras = float; + A withOffset(float offset) { return {factor + offset}; } + float preCalc(float x) { return factor * x * x * x; } + float postCalc(float extras) { return extras; } +}; + +struct B : IInterface +{ + float factor1; + float factor2; + typealias Extras = float2; + B withOffset(float offset) { return {factor1 + offset, factor2 + offset}; } + float preCalc(float x) { return factor1 * x * x + factor2 * x; } + float postCalc(Extras extras) { return dot(extras, float2(1, 1)); } +}; + +struct Foo +{ + T a; + T.Extras extras; +} + +Foo make(T obj, float y) +{ + return {obj, obj.preCalc(y)}; +} + +float calc(Foo obj) +{ + return obj.a.postCalc(obj.extras); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(x); + else + obj = B(x, x + 1); + + return calc(make(obj, x)); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 16 + outputBuffer[1] = f(1, 2); // CHECK: 28 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-input-1.slang b/tests/language-feature/dynamic-dispatch/func-call-input-1.slang new file mode 100644 index 00000000000..e93ededa7ae --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/func-call-input-1.slang @@ -0,0 +1,51 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +IInterface factoryAB(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +// Should lower to accept A, B (but not C) +float calc(IInterface obj, float y) +{ + return obj.calc(y); +} + +float f(uint id, float x) +{ + IInterface obj = factoryAB(id, x); + return calc(obj, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3); // CHECK: 27 + outputBuffer[1] = f(1, 3); // CHECK: 9 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-input-2.slang b/tests/language-feature/dynamic-dispatch/func-call-input-2.slang new file mode 100644 index 00000000000..62988cb7b02 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/func-call-input-2.slang @@ -0,0 +1,77 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +struct D : IInterface +{ + float calc(float x) { return x + x * x; } +}; + +struct E : IInterface +{ + float calc(float x) { return x * x + x * 3; } +}; + +IInterface factoryAB(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +IInterface factoryCD(uint id, float x) +{ + if (id == 0) + return C(); + else + return D(); +} + +// Should lower to accept A, B, C or D (but not E) +float calc(IInterface obj, float y) +{ + return obj.calc(y); +} + +float f(uint id, float x) +{ + IInterface obj = factoryAB(id, x); + return calc(obj, x); +} + +float g(uint id, float x) +{ + IInterface obj = factoryCD(id, x); + return calc(obj, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3); // CHECK: 27 + outputBuffer[1] = f(1, 3); // CHECK: 9 + outputBuffer[2] = g(0, 3); // CHECK: 3 + outputBuffer[3] = g(1, 3); // CHECK: 12 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-return.slang b/tests/language-feature/dynamic-dispatch/func-call-return.slang new file mode 100644 index 00000000000..a9d91823fb9 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/func-call-return.slang @@ -0,0 +1,45 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +IInterface factory(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + IInterface obj = factory(id, x); + return obj.calc(x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3); // CHECK: 27 + outputBuffer[1] = f(1, 3); // CHECK: 9 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface-1.slang b/tests/language-feature/dynamic-dispatch/generic-interface-1.slang new file mode 100644 index 00000000000..13d5843e0e9 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/generic-interface-1.slang @@ -0,0 +1,60 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface ICalculation +{ + associatedtype Data; + T calc(Data d, T x); + Data make(T q); +} + +struct A : ICalculation +{ + typealias Data = T; + T calc(Data d, T x) { return d * x * x; } + Data make(T q) { return q; } +}; + +struct BData +{ + T x; + T y; +}; + +struct B : ICalculation +{ + typealias Data = BData; + T calc(Data d, T x) { return d.x * x * x + d.y; } + Data make(T q) { return {q, q}; } +}; + +struct C : ICalculation +{ + typealias Data = T; + T calc(Data d, T x) { return d * x; } + Data make(T q) { return q; } +}; + +ICalculation factoryAB(uint id, T x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + return obj.calc(d, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3); // CHECK: 27.0 + outputBuffer[1] = f(1, 3); // CHECK: 30.0 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface-2.slang b/tests/language-feature/dynamic-dispatch/generic-interface-2.slang new file mode 100644 index 00000000000..e91abdedd3a --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/generic-interface-2.slang @@ -0,0 +1,67 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IData +{ + T getNorm2(); +} + +interface ICalculation +{ + associatedtype Data : IData; + T calc(Data d, T x); + Data make(T q); +} + +struct AData : IData +{ + T x; + + T getNorm2() { return x * x; } +}; + +struct A : ICalculation +{ + typealias Data = AData; + T calc(Data d, T x) { return d.getNorm2() * x * x; } + Data make(T q) { return {q}; } +}; + +struct BData : IData +{ + T x; + T y; + + T getNorm2() { return x * x + y * y; } +}; + +struct B : ICalculation +{ + typealias Data = BData; + T calc(Data d, T x) { return d.x * x * x + d.y; } + Data make(T q) { return {q, q}; } +}; + +ICalculation factoryAB(uint id, T x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + return obj.calc(d, x) + d.getNorm2(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3); // CHECK: 90 + outputBuffer[1] = f(1, 3); // CHECK: 48 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface-3.slang b/tests/language-feature/dynamic-dispatch/generic-interface-3.slang new file mode 100644 index 00000000000..e94518fb97a --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/generic-interface-3.slang @@ -0,0 +1,67 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IData +{ + T getNorm2(); +} + +interface ICalculation +{ + associatedtype Data : IData; + T calc(Data d, T x); + Data make(T q); +} + +struct AData : IData +{ + T x; + + T getNorm2() { return x * x; } +}; + +struct A : ICalculation +{ + typealias Data = AData; + T calc(Data d, T x) { return d.x * x * x; } + Data make(T q) { return {q}; } +}; + +struct BData : IData +{ + T x; + T y; + + T getNorm2() { return x * x + y * y; } +}; + +struct B : ICalculation +{ + typealias Data = BData; + T calc(Data d, T x) { return d.x * x * x + d.y; } + Data make(T q) { return {q, q}; } +}; + +ICalculation factoryAB(uint id, T x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + return obj.calc(d, x) + d.getNorm2(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3); // CHECK: 36 + outputBuffer[1] = f(1, 3); // CHECK: 48 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-method.slang b/tests/language-feature/dynamic-dispatch/generic-method.slang new file mode 100644 index 00000000000..ff839d0df36 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/generic-method.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + T calc(T x); +} + +struct A : IInterface +{ + T calc(T x) { return x * x * x; } +}; + +struct B : IInterface +{ + T calc(T x) { return x * x; } +}; + +struct C : IInterface +{ + T calc(T x) { return x; } +}; + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else if (id == 1) + obj = B(); + else + obj = C(); + + return obj.calc(x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/inout.slang b/tests/language-feature/dynamic-dispatch/inout.slang new file mode 100644 index 00000000000..e1c4063eb83 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/inout.slang @@ -0,0 +1,76 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// Test data-flow through inout parameters (which will require data-flow through 'IRVar' insts) + +interface IReserver +{ + [mutating] uint reserve(uint x); +} + +struct Single : IReserver +{ + uint total; + [mutating] uint reserve(uint x) + { + if (total < x) + { + total = 0; + return total; + } + else + { + total -= x; + return x; + } + } +}; + +struct Double : IReserver +{ + uint total; + [mutating] uint reserve(uint x) + { + if (total < 2 * x) + { + total = 0; + return total; + } + else + { + total -= 2 * x; + return 2 * x; + } + } +}; + +IReserver make(uint id, uint total) +{ + if (id == 0) + return Single(total); + else + return Double(total); +} + +float f(inout IReserver obj, uint res, float x) +{ + uint y = obj.reserve(res + 1); + + return x * y; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = make(0, 10); + outputBuffer[0] = f(obj, 0, 2); // CHECK: 2 + outputBuffer[1] = f(obj, 1, 3); // CHECK: 6 + outputBuffer[2] = f(obj, 0, 4); // CHECK: 4 + + obj = make(1, 10); + outputBuffer[3] = f(obj, 0, 2); // CHECK: 4 + outputBuffer[4] = f(obj, 1, 3); // CHECK: 12 + outputBuffer[5] = f(obj, 0, 4); // CHECK: 8 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/simple.slang b/tests/language-feature/dynamic-dispatch/simple.slang new file mode 100644 index 00000000000..6d52cfbe444 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/simple.slang @@ -0,0 +1,44 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else + obj = B(); + + return obj.calc(x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/with-data.slang b/tests/language-feature/dynamic-dispatch/with-data.slang new file mode 100644 index 00000000000..c48e2f2d04f --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/with-data.slang @@ -0,0 +1,47 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float factor; + float calc(float x) { return x * x * factor; } +}; + +struct B : IInterface +{ + float factor; + float offset; + float calc(float x) { return x * x * factor + offset; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(x); + else + obj = B(x, x); + + return obj.calc(x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3); // CHECK: 27 + outputBuffer[1] = f(1, 3); // CHECK: 30 +} \ No newline at end of file diff --git a/tests/language-feature/types/optional-ifoo.slang b/tests/language-feature/types/optional-ifoo-1.slang similarity index 100% rename from tests/language-feature/types/optional-ifoo.slang rename to tests/language-feature/types/optional-ifoo-1.slang diff --git a/tests/language-feature/types/optional-ifoo-2.slang b/tests/language-feature/types/optional-ifoo-2.slang new file mode 100644 index 00000000000..0a19fa9ad35 --- /dev/null +++ b/tests/language-feature/types/optional-ifoo-2.slang @@ -0,0 +1,33 @@ +//TEST:INTERPRET(filecheck=CHECK): + +interface IFoo +{ + int get_result(); +} + +struct FooImpl : IFoo +{ + int get_result() { return data; } + int data; +} + +Optional generate_foo(int i) +{ + if (i % 5 == 0) + { + FooImpl result = {}; + result.data = i; + return { result }; + } + else + { + return {}; + } +} + +void main() +{ + // CHECK: hasValue: 1 + let result_foo = generate_foo(100); + printf("hasValue: %d\n", (int)result_foo.hasValue); +} diff --git a/tests/language-feature/types/optional-ifoo-3.slang b/tests/language-feature/types/optional-ifoo-3.slang new file mode 100644 index 00000000000..0299eb3e96f --- /dev/null +++ b/tests/language-feature/types/optional-ifoo-3.slang @@ -0,0 +1,35 @@ +//TEST:INTERPRET(filecheck=CHECK): + +interface IFoo +{ + T get_result(); +} + +struct FooImpl : IFoo +{ + T get_result() { return (T)data; } + int data; +} + +Optional generate_foo(int i) +{ + if (i % 5 == 0) + { + FooImpl result = {}; + result.data = i; + return { result }; + } + else + { + return {}; + } +} + +void main() +{ + // CHECK: hasValue: 1 + // CHECK-NEXT: result: 100.0 + let result_foo = generate_foo(100); + printf("hasValue: %d\n", (int)result_foo.hasValue); + printf("result: %f\n", (float)result_foo.value.get_result()); +} diff --git a/tests/wgsl/switch-case.slang b/tests/wgsl/switch-case.slang index fc24bd67a19..163178efc89 100644 --- a/tests/wgsl/switch-case.slang +++ b/tests/wgsl/switch-case.slang @@ -70,14 +70,8 @@ func fs_main(VertexOutput input)->FragmentOutput return output; } -//WGSL: switch({{.*}}) -//WGSL-NEXT: { -//WGSL-NEXT: case u32(0): -//WGSL-NEXT: { -//WGSL-NEXT: return Circle_getArea_0 -//WGSL-NEXT: } -//WGSL-NEXT: default : -//WGSL-NEXT: { -//WGSL-NEXT: return Rectangle_getArea_0 -//WGSL-NEXT: } -//WGSL-NEXT: } +//WGSL: fn s_dispatch_IShape_getArea_{{[0-9]+}}( _S{{[0-9]+}} : u32, _S{{[0-9]+}} : AnyValue8) -> f32 +//WGSL-NEXT:{ +//WGSL-DAG: return U_SR14switch_2Dxcase6Circle7getAreap0pf_wtwrapper_0(_S{{[0-9]+}}); +//WGSL-DAG: return U_SR14switch_2Dxcase9Rectangle7getAreap0pf_wtwrapper_0(_S{{[0-9]+}}); +//WGSL:}