From 0c8f5c48dc97707f46ba61a45ae6037759e939e1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Thu, 17 Jul 2025 18:22:30 -0400 Subject: [PATCH 001/105] Initial working data-flow-based dynamic-dispatch pass --- source/slang/slang-emit.cpp | 5 +- source/slang/slang-ir-legalize-types.cpp | 3 + source/slang/slang-ir-lower-dynamic-insts.cpp | 1526 +++++++++++++++++ source/slang/slang-ir-lower-dynamic-insts.h | 141 ++ source/slang/slang-ir.cpp | 2 +- source/slang/slang-lower-to-ir.cpp | 47 +- .../dynamic-dispatch/simple.slang | 41 + 7 files changed, 1743 insertions(+), 22 deletions(-) create mode 100644 source/slang/slang-ir-lower-dynamic-insts.cpp create mode 100644 source/slang/slang-ir-lower-dynamic-insts.h create mode 100644 tests/language-feature/dynamic-dispatch/simple.slang diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 7d8f1438d62..dad2db1fc62 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -76,6 +76,7 @@ #include "slang-ir-lower-buffer-element-type.h" #include "slang-ir-lower-combined-texture-sampler.h" #include "slang-ir-lower-coopvec.h" +#include "slang-ir-lower-dynamic-insts.h" #include "slang-ir-lower-dynamic-resource-heap.h" #include "slang-ir-lower-enum-type.h" #include "slang-ir-lower-generics.h" @@ -1044,10 +1045,12 @@ Result linkAndOptimizeIR( if (!codeGenContext->isSpecializationDisabled()) { SpecializationOptions specOptions; - specOptions.lowerWitnessLookups = true; + specOptions.lowerWitnessLookups = false; specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions); } + lowerDynamicInsts(irModule, sink); + finalizeSpecialization(irModule); // Lower `Result` types into ordinary struct types. This must happen diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 9cf8d7b5f49..a69344548cf 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1466,6 +1466,9 @@ static LegalVal legalizeGetElement( // the "index" argument. auto indexOperand = legalIndexOperand.getSimple(); + if (type.flavor == LegalType::Flavor::none) + return LegalVal(); + return legalizeGetElement(context, type, legalPtrOperand, indexOperand); } diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp new file mode 100644 index 00000000000..00fcca00f22 --- /dev/null +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -0,0 +1,1526 @@ +#include "slang-ir-lower-dynamic-insts.h" + +#include "slang-ir-any-value-marshalling.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ +// PropagationInfo implementation +RefPtr PropagationInfo::makeConcrete(PropagationJudgment j, IRInst* value) +{ + auto info = new PropagationInfo(j); + info->concreteValue = value; + return info; +} + +RefPtr PropagationInfo::makeSet( + PropagationJudgment j, + const HashSet& values) +{ + auto info = new PropagationInfo(j); + SLANG_ASSERT(j != PropagationJudgment::SetOfFuncs); + info->possibleValues = values; + return info; +} + + +RefPtr PropagationInfo::makeSetOfFuncs( + const HashSet& values, + IRFuncType* dynFuncType) +{ + auto info = new PropagationInfo(PropagationJudgment::SetOfFuncs); + info->possibleValues = values; + info->dynFuncType = dynFuncType; + return info; +} + +RefPtr PropagationInfo::makeExistential(const HashSet& tables) +{ + auto info = new PropagationInfo(PropagationJudgment::Existential); + info->possibleValues = tables; + return info; +} + +RefPtr PropagationInfo::makeUnknown() +{ + return new PropagationInfo(PropagationJudgment::UnknownSet); +} + +bool areInfosEqual(const RefPtr& a, const RefPtr& b) +{ + if (a->judgment != b->judgment) + return false; + + switch (a->judgment) + { + case PropagationJudgment::Value: + return true; // All value judgments are equal + + case PropagationJudgment::ConcreteType: + case PropagationJudgment::ConcreteTable: + case PropagationJudgment::ConcreteFunc: + return a->concreteValue == b->concreteValue; + + case PropagationJudgment::SetOfTypes: + case PropagationJudgment::SetOfTables: + case PropagationJudgment::SetOfFuncs: + if (a->possibleValues.getCount() != b->possibleValues.getCount()) + return false; + for (auto value : a->possibleValues) + { + if (!b->possibleValues.contains(value)) + return false; + } + return true; + + case PropagationJudgment::Existential: + if (a->possibleValues.getCount() != b->possibleValues.getCount()) + return false; + for (auto table : a->possibleValues) + { + if (!b->possibleValues.contains(table)) + return false; + } + return true; + + case PropagationJudgment::UnknownSet: + return true; // All unknown sets are considered equal + + default: + return false; + } +} + +// DynamicInstLoweringContext implementation +RefPtr DynamicInstLoweringContext::tryGetInfo(IRInst* inst) +{ + // If this is a global instruction (parent is module), return concrete info + if (as(inst->getParent())) + { + // Create static info based on instruction type + static RefPtr typeInfo = + PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, nullptr); + static RefPtr tableInfo = + PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, nullptr); + static RefPtr funcInfo = + PropagationInfo::makeConcrete(PropagationJudgment::ConcreteFunc, nullptr); + static RefPtr valueInfo = PropagationInfo::makeValue(); + + if (as(inst)) + { + typeInfo->concreteValue = inst; + return typeInfo; + } + else if (as(inst)) + { + tableInfo->concreteValue = inst; + return tableInfo; + } + else if (as(inst)) + { + funcInfo->concreteValue = inst; + return funcInfo; + } + else + { + return valueInfo; + } + } + + // For non-global instructions, look up in the map + auto found = propagationMap.tryGetValue(inst); + if (found) + return *found; + return nullptr; +} + +void DynamicInstLoweringContext::performInformationPropagation() +{ + // Process each function in the module + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as(inst)) + { + processFunction(func); + } + } +} + +IRInst* DynamicInstLoweringContext::maybeReinterpretArg(IRInst* arg, IRInst* param) +{ + /*auto argTypeInfo = tryGetInfo(arg->getDataType()); + auto paramTypeInfo = tryGetInfo(param->getDataType()); + + if (argTypeInfo & paramTypeInfo) + return arg; + + IRBuilder builder(module); + builder.setInsertAfter(arg); + + if (argTypeInfo->judgment == PropagationJudgment::SetOfTypes && + paramTypeInfo->judgment == PropagationJudgment::SetOfTypes) + { + if (argTypeInfo->possibleValues.getCount() != paramTypeInfo->possibleValues.getCount()) + { + SLANG_ASSERT("Unhandled"); + + // If the sets are not equal, reinterpret to the parameter type + // auto reinterpret = builder.emitReinterpret(param->getDataType(), arg); + // propagationMap[reinterpret] = paramTypeInfo; + } + }*/ + + auto argInfo = tryGetInfo(arg); + auto paramInfo = tryGetInfo(param); + + if (!argInfo || !paramInfo) + return arg; + + if (argInfo->judgment == PropagationJudgment::Existential && + paramInfo->judgment == PropagationJudgment::Existential) + { + if (argInfo->possibleValues.getCount() != paramInfo->possibleValues.getCount()) + { + // If the sets of witness tables are not equal, reinterpret to the parameter type + IRBuilder builder(module); + builder.setInsertAfter(arg); + auto reinterpret = builder.emitReinterpret(param->getDataType(), arg); + propagationMap[reinterpret] = paramInfo; + return reinterpret; // Return the reinterpret instruction + } + } + + return arg; // Can use as-is. +} + +void DynamicInstLoweringContext::insertReinterpretsForPhiParameters() +{ + // Process each function in the module + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as(inst)) + { + // Skip the first block as it contains function parameters, not phi parameters + for (auto block = func->getFirstBlock()->getNextBlock(); block; + block = block->getNextBlock()) + { + // Process each parameter in this block (these are phi parameters) + for (auto param : block->getParams()) + { + auto paramInfo = tryGetInfo(param); + if (!paramInfo) + continue; + + // Check all predecessors and their corresponding arguments + Index paramIndex = 0; + for (auto p : block->getParams()) + { + if (p == param) + break; + paramIndex++; + } + + // Find all predecessors of this block + for (auto pred : block->getPredecessors()) + { + auto terminator = pred->getTerminator(); + if (!terminator) + continue; + + if (auto unconditionalBranch = as(terminator)) + { + // Get the argument at the same index as this parameter + if (paramIndex < unconditionalBranch->getArgCount()) + { + auto arg = unconditionalBranch->getArg(paramIndex); + auto newArg = maybeReinterpretArg(arg, param); + + if (newArg != arg) + { + // Replace the argument in the branch instruction + unconditionalBranch->setOperand(1 + paramIndex, newArg); + } + } + } + } + } + } + } + } +} + +void DynamicInstLoweringContext::processInstForPropagation(IRInst* inst) +{ + RefPtr info; + + switch (inst->getOp()) + { + case kIROp_CreateExistentialObject: + info = analyzeCreateExistentialObject(as(inst)); + break; + case kIROp_MakeExistential: + info = analyzeMakeExistential(as(inst)); + break; + case kIROp_LookupWitnessMethod: + info = analyzeLookupWitnessMethod(as(inst)); + break; + case kIROp_ExtractExistentialWitnessTable: + info = analyzeExtractExistentialWitnessTable(as(inst)); + break; + case kIROp_ExtractExistentialType: + info = analyzeExtractExistentialType(as(inst)); + break; + case kIROp_ExtractExistentialValue: + info = analyzeExtractExistentialValue(as(inst)); + break; + case kIROp_Call: + info = analyzeCall(as(inst)); + break; + default: + info = analyzeDefault(inst); + break; + } + + propagationMap[inst] = info; +} + +RefPtr DynamicInstLoweringContext::analyzeCreateExistentialObject( + IRCreateExistentialObject* inst) +{ + // For now, error out as specified + SLANG_UNIMPLEMENTED_X("IRCreateExistentialObject lowering not yet implemented"); + return PropagationInfo::makeValue(); +} + +RefPtr DynamicInstLoweringContext::analyzeMakeExistential(IRMakeExistential* inst) +{ + auto witnessTable = inst->getWitnessTable(); + auto value = inst->getWrappedValue(); + auto valueType = value->getDataType(); + + // Get the witness table info + auto witnessTableInfo = tryGetInfo(witnessTable); + if (!witnessTableInfo || witnessTableInfo->judgment == PropagationJudgment::UnknownSet) + { + return PropagationInfo::makeUnknown(); + } + + HashSet tables; + + if (witnessTableInfo->judgment == PropagationJudgment::ConcreteTable) + { + tables.add(witnessTableInfo->concreteValue); + } + else if (witnessTableInfo->judgment == PropagationJudgment::SetOfTables) + { + for (auto table : witnessTableInfo->possibleValues) + { + tables.add(table); + } + } + + return PropagationInfo::makeExistential(tables); +} + +static IRInst* lookupEntry(IRInst* witnessTable, IRInst* key) +{ + if (auto concreteTable = as(witnessTable)) + { + for (auto entry : concreteTable->getEntries()) + { + if (entry->getRequirementKey() == key) + { + return entry->getSatisfyingVal(); + } + } + } + return nullptr; // Not found +} + +RefPtr DynamicInstLoweringContext::analyzeLookupWitnessMethod( + IRLookupWitnessMethod* inst) +{ + auto witnessTable = inst->getWitnessTable(); + auto key = inst->getRequirementKey(); + auto witnessTableInfo = tryGetInfo(witnessTable); + + if (!witnessTableInfo || witnessTableInfo->judgment == PropagationJudgment::UnknownSet) + { + return PropagationInfo::makeUnknown(); + } + + HashSet results; + + if (witnessTableInfo->judgment == PropagationJudgment::ConcreteTable) + { + results.add(lookupEntry(witnessTableInfo->concreteValue, key)); + } + else if (witnessTableInfo->judgment == PropagationJudgment::SetOfTables) + { + for (auto table : witnessTableInfo->possibleValues) + { + results.add(lookupEntry(table, key)); + } + } + + if (witnessTableInfo->judgment == PropagationJudgment::ConcreteTable) + { + if (as(inst->getDataType())) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteFunc, + *results.begin()); + } + else if (as(inst->getDataType())) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteType, + *results.begin()); + } + else if (as(inst->getDataType())) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteTable, + *results.begin()); + } + else + { + SLANG_UNEXPECTED("Unexpected data type for LookupWitnessMethod"); + } + } + else + { + if (auto funcType = as(inst->getDataType())) + { + return PropagationInfo::makeSetOfFuncs(results, funcType); + } + else if (as(inst->getDataType())) + { + return PropagationInfo::makeSet(PropagationJudgment::SetOfTypes, results); + } + else if (as(inst->getDataType())) + { + return PropagationInfo::makeSet(PropagationJudgment::SetOfTables, results); + } + else + { + SLANG_UNEXPECTED("Unexpected data type for LookupWitnessMethod"); + } + } +} + +RefPtr DynamicInstLoweringContext::analyzeExtractExistentialWitnessTable( + IRExtractExistentialWitnessTable* inst) +{ + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(operand); + + if (!operandInfo || operandInfo->judgment == PropagationJudgment::UnknownSet) + { + return PropagationInfo::makeUnknown(); + } + + if (operandInfo->judgment == PropagationJudgment::Existential) + { + HashSet tables; + for (auto table : operandInfo->possibleValues) + { + tables.add(table); + } + + if (tables.getCount() == 1) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteTable, + *tables.begin()); + } + else + { + return PropagationInfo::makeSet(PropagationJudgment::SetOfTables, tables); + } + } + + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, inst); +} + +RefPtr DynamicInstLoweringContext::analyzeExtractExistentialType( + IRExtractExistentialType* inst) +{ + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(operand); + + if (!operandInfo || operandInfo->judgment == PropagationJudgment::UnknownSet) + { + return PropagationInfo::makeUnknown(); + } + + if (operandInfo->judgment == PropagationJudgment::Existential) + { + HashSet types; + // Extract types from witness tables by looking at the concrete types + for (auto table : operandInfo->possibleValues) + { + // Get the concrete type from the witness table + if (auto witnessTable = as(table)) + { + if (auto concreteType = witnessTable->getConcreteType()) + { + types.add(concreteType); + } + } + else + { + SLANG_UNEXPECTED("Expected witness table in existential extraction base type"); + } + } + + if (types.getCount() == 0) + { + // No concrete types found, treat as this instruction + types.add(inst); + } + + if (types.getCount() == 1) + { + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, *types.begin()); + } + else + { + return PropagationInfo::makeSet(PropagationJudgment::SetOfTypes, types); + } + } + + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, inst); +} + +RefPtr DynamicInstLoweringContext::analyzeExtractExistentialValue( + IRExtractExistentialValue* inst) +{ + // The value itself is just a regular value + return PropagationInfo::makeValue(); +} + +RefPtr DynamicInstLoweringContext::analyzeCall(IRCall* inst) +{ + auto callee = inst->getCallee(); + auto calleeInfo = tryGetInfo(callee); + + List outputVals; + outputVals.add(inst); // The call inst itself is an output. + + // Also add all OutTypeBase parameters + auto funcType = as(callee->getDataType()); + if (funcType) + { + UIndex paramIndex = 0; + for (auto paramType : funcType->getParamTypes()) + { + if (as(paramType)) + { + // If this is an OutTypeBase, we consider it an output + outputVals.add(inst->getArg(paramIndex)); + } + paramIndex++; + } + } + + for (auto outputVal : outputVals) + { + if (as(outputVal)) + { + // TODO: We need to set up infrastructure to track variable + // assignments. + // For now, we will just return a value judgment for variables. + // (doesn't make much sense.. but its fine for now) + // + propagationMap[outputVal] = PropagationInfo::makeValue(); + } + else + { + if (auto interfaceType = as(outputVal->getDataType())) + { + if (!interfaceType->findDecoration()) + { + // If this is an interface type, we need to propagate existential info + // based on the interface type. + propagationMap[outputVal] = + PropagationInfo::makeExistential(collectExistentialTables(interfaceType)); + } + else + { + // If this is a COM interface, we treat it as unknown + propagationMap[outputVal] = PropagationInfo::makeUnknown(); + } + } + else + { + propagationMap[outputVal] = PropagationInfo::makeValue(); + } + } + } + + return tryGetInfo(inst); +} + + +void DynamicInstLoweringContext::processFunction(IRFunc* func) +{ + // Initialize parameter info for the first block + initializeFirstBlockParameters(func); + + // Initialize worklist with edges from successor blocks + LinkedList edgeQueue; + + auto processBlock = [&](IRBlock* block) + { + bool anyInfoChanged = false; + + for (auto inst : block->getChildren()) + { + // Skip parameters & terminator + if (as(inst) || as(inst)) + continue; + + auto oldInfo = tryGetInfo(inst); + processInstForPropagation(inst); + auto newInfo = tryGetInfo(inst); + + // Check if information changed + if (!anyInfoChanged) + { + if (!oldInfo || (newInfo && !areInfosEqual(oldInfo, newInfo))) + { + anyInfoChanged = true; + } + } + } + + // If any info changed, add successor edges back to queue + if (anyInfoChanged) + { + auto successors = block->getSuccessors(); + for (auto succIter = successors.begin(); succIter != successors.end(); ++succIter) + { + edgeQueue.addLast(succIter.getEdge()); + } + } + }; + + // Start processing from the first block + processBlock(func->getFirstBlock()); + + // Process until fixed point + while (edgeQueue.getCount() > 0) + { + // Pop edge from front + auto edge = edgeQueue.getFirst(); + edgeQueue.getFirstNode()->removeAndDelete(); + + // Propagate along the edge + propagateEdge(edge); + + // Process the successor block's instructions. + // This will also add any new edges to the queue if info changed + // + processBlock(edge.getSuccessor()); + } +} // namespace Slang + +void DynamicInstLoweringContext::propagateEdge(IREdge edge) +{ + auto predecessorBlock = edge.getPredecessor(); + auto successorBlock = edge.getSuccessor(); + + // Get the terminator instruction and extract arguments + auto terminator = predecessorBlock->getTerminator(); + if (!terminator) + return; + + // Handle different types of branch instructions + if (auto unconditionalBranch = as(terminator)) + { + // Find which successor this edge leads to (should be the target) + if (unconditionalBranch->getTargetBlock() != successorBlock) + return; + + // Collect propagation info for each argument and update corresponding parameter + Index paramIndex = 0; + for (auto param : successorBlock->getParams()) + { + if (paramIndex < unconditionalBranch->getArgCount()) + { + auto arg = unconditionalBranch->getArg(paramIndex); + auto argInfo = tryGetInfo(arg); + + if (argInfo) + { + // Union with existing parameter info + auto existingInfo = tryGetInfo(param); + if (existingInfo) + { + List> infos; + infos.add(existingInfo); + infos.add(argInfo); + propagationMap[param] = unionPropagationInfo(infos); + } + else + { + propagationMap[param] = argInfo; + } + } + } + paramIndex++; + } + } +} + +void DynamicInstLoweringContext::initializeFirstBlockParameters(IRFunc* func) +{ + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + return; + + // Initialize parameters based on their types + for (auto param : firstBlock->getParams()) + { + auto paramType = param->getDataType(); + + if (auto interfaceType = as(paramType)) + { + if (!interfaceType->findDecoration()) + propagationMap[param] = + PropagationInfo::makeExistential(collectExistentialTables(interfaceType)); + else + propagationMap[param] = PropagationInfo::makeUnknown(); + } + else if (as(paramType)) + { + propagationMap[param] = PropagationInfo::makeUnknown(); + } + else if (as(paramType)) + { + propagationMap[param] = PropagationInfo::makeUnknown(); + } + else if (as(paramType) || as(paramType)) + { + propagationMap[param] = PropagationInfo::makeUnknown(); + } + else + { + propagationMap[param] = PropagationInfo::makeValue(); + } + } +} + +RefPtr DynamicInstLoweringContext::unionPropagationInfo( + const List>& infos) +{ + if (infos.getCount() == 0) + { + return PropagationInfo::makeValue(); + } + + if (infos.getCount() == 1) + { + return infos[0]; + } + + // Check if all infos are the same + bool allSame = true; + for (Index i = 1; i < infos.getCount(); i++) + { + if (!areInfosEqual(infos[0], infos[i])) + { + allSame = false; + break; + } + } + + if (allSame) + { + return infos[0]; + } + + // Need to create a union - collect all possible values based on judgment types + HashSet allValues; + IRFuncType* dynFuncType = nullptr; + PropagationJudgment unionJudgment = PropagationJudgment::Value; + + // Determine the union judgment type and collect values + for (auto info : infos) + { + switch (info->judgment) + { + case PropagationJudgment::ConcreteType: + unionJudgment = PropagationJudgment::SetOfTypes; + allValues.add(info->concreteValue); + break; + case PropagationJudgment::ConcreteTable: + unionJudgment = PropagationJudgment::SetOfTables; + allValues.add(info->concreteValue); + break; + case PropagationJudgment::ConcreteFunc: + unionJudgment = PropagationJudgment::SetOfFuncs; + allValues.add(info->concreteValue); + break; + case PropagationJudgment::SetOfTypes: + unionJudgment = PropagationJudgment::SetOfTypes; + for (auto value : info->possibleValues) + allValues.add(value); + break; + case PropagationJudgment::SetOfTables: + unionJudgment = PropagationJudgment::SetOfTables; + for (auto value : info->possibleValues) + allValues.add(value); + break; + case PropagationJudgment::SetOfFuncs: + unionJudgment = PropagationJudgment::SetOfFuncs; + for (auto value : info->possibleValues) + allValues.add(value); + if (!dynFuncType) + { + // If we haven't set a function type yet, use the first one + dynFuncType = info->dynFuncType; + } + else if (dynFuncType != info->dynFuncType) + { + SLANG_UNEXPECTED( + "Mismatched function types in union propagation info for SetOfFuncs"); + } + + break; + case PropagationJudgment::Value: + // Value judgments don't contribute to the union + break; + case PropagationJudgment::Existential: + // For existential union, we need to collect all witness tables + // For now, we'll handle this properly by creating a new existential with all tables + { + HashSet allTables; + for (auto otherInfo : infos) + { + if (otherInfo->judgment == PropagationJudgment::Existential) + { + for (auto table : otherInfo->possibleValues) + { + allTables.add(table); + } + } + } + if (allTables.getCount() > 0) + { + return PropagationInfo::makeExistential(allTables); + } + } + return PropagationInfo::makeValue(); + case PropagationJudgment::UnknownSet: + // If any info is unknown, the union is unknown + return PropagationInfo::makeUnknown(); + } + } + + // If we collected values, create a set; otherwise return value + if (allValues.getCount() > 0) + { + if (unionJudgment == PropagationJudgment::SetOfFuncs && dynFuncType) + return PropagationInfo::makeSetOfFuncs(allValues, dynFuncType); + + return PropagationInfo::makeSet(unionJudgment, allValues); + } + else + { + return PropagationInfo::makeValue(); + } +} + +RefPtr DynamicInstLoweringContext::analyzeDefault(IRInst* inst) +{ + // Check if this is a type, witness table, or function + if (as(inst)) + { + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, inst); + } + else if (as(inst)) + { + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, inst); + } + else if (as(inst)) + { + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteFunc, inst); + } + else + { + return PropagationInfo::makeValue(); + } +} + +void DynamicInstLoweringContext::performDynamicInstLowering() +{ + // Collect all instructions that need lowering + List typeInstsToLower; + List valueInstsToLower; + List instWithReplacementTypes; + + for (auto globalInst : module->getGlobalInsts()) + { + if (auto func = as(globalInst)) + { + // Process each function's instructions + for (auto block : func->getBlocks()) + { + for (auto child : block->getChildren()) + { + if (as(child)) + continue; // Skip parameters and terminators + + switch (child->getOp()) + { + case kIROp_LookupWitnessMethod: + { + if (child->getDataType()->getOp() == kIROp_TypeKind) + typeInstsToLower.add(child); + else + valueInstsToLower.add(child); + break; + } + case kIROp_ExtractExistentialType: + typeInstsToLower.add(child); + break; + case kIROp_ExtractExistentialWitnessTable: + case kIROp_ExtractExistentialValue: + case kIROp_Call: + case kIROp_MakeExistential: + case kIROp_CreateExistentialObject: + valueInstsToLower.add(child); + break; + default: + if (auto info = tryGetInfo(child)) + { + if (info->judgment == PropagationJudgment::Existential) + { + // If this instruction has a set of types, tables, or funcs, + // we need to lower it to a unified type. + instWithReplacementTypes.add(child); + } + } + } + } + } + } + } + + for (auto inst : typeInstsToLower) + lowerInst(inst); + + for (auto inst : valueInstsToLower) + lowerInst(inst); + + for (auto inst : instWithReplacementTypes) + replaceType(inst); +} + +void DynamicInstLoweringContext::replaceType(IRInst* inst) +{ + auto info = tryGetInfo(inst); + if (!info || info->judgment != PropagationJudgment::Existential) + return; + + // Replace type with Tuple + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto anyValueType = createAnyValueTypeFromInsts(info->possibleValues); + auto tupleType = builder.getTupleType(List({builder.getUIntType(), anyValueType})); + inst->setFullType(tupleType); +} + +void DynamicInstLoweringContext::lowerInst(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_LookupWitnessMethod: + lowerLookupWitnessMethod(as(inst)); + break; + case kIROp_ExtractExistentialWitnessTable: + lowerExtractExistentialWitnessTable(as(inst)); + break; + case kIROp_ExtractExistentialType: + lowerExtractExistentialType(as(inst)); + break; + case kIROp_ExtractExistentialValue: + lowerExtractExistentialValue(as(inst)); + break; + case kIROp_Call: + lowerCall(as(inst)); + break; + case kIROp_MakeExistential: + lowerMakeExistential(as(inst)); + break; + case kIROp_CreateExistentialObject: + lowerCreateExistentialObject(as(inst)); + break; + } +} + +void DynamicInstLoweringContext::lowerLookupWitnessMethod(IRLookupWitnessMethod* inst) +{ + auto info = tryGetInfo(inst); + if (!info) + return; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Check if this is a TypeKind data type with SetOfTypes judgment + if (info->judgment == PropagationJudgment::SetOfTypes && + inst->getDataType()->getOp() == kIROp_TypeKind) + { + // Create an any-value type based on the set of types + auto anyValueType = createAnyValueTypeFromInsts(info->possibleValues); + + // Store the mapping for later use + loweredInstToAnyValueType[inst] = anyValueType; + + // Replace the instruction with the any-value type + inst->replaceUsesWith(anyValueType); + inst->removeAndDeallocate(); + return; + } + + if (info->judgment == PropagationJudgment::SetOfTables || + info->judgment == PropagationJudgment::SetOfFuncs) + { + // Get the witness table operand info + auto witnessTableInst = inst->getWitnessTable(); + auto witnessTableInfo = tryGetInfo(witnessTableInst); + + if (witnessTableInfo && witnessTableInfo->judgment == PropagationJudgment::SetOfTables) + { + // Create a key mapping function + auto keyMappingFunc = createKeyMappingFunc( + inst->getRequirementKey(), + witnessTableInfo->possibleValues, + info->possibleValues); + + // Replace with call to key mapping function + auto witnessTableId = builder.emitCallInst( + builder.getUIntType(), + keyMappingFunc, + List({inst->getWitnessTable()})); + inst->replaceUsesWith(witnessTableId); + propagationMap[witnessTableId] = info; + inst->removeAndDeallocate(); + } + } +} + +void DynamicInstLoweringContext::lowerExtractExistentialWitnessTable( + IRExtractExistentialWitnessTable* inst) +{ + auto info = tryGetInfo(inst); + if (!info) + return; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + if (info->judgment == PropagationJudgment::SetOfTables) + { + // Replace with GetElement(loweredInst, 0) -> uint + auto operand = inst->getOperand(0); + auto element = builder.emitGetTupleElement(builder.getUIntType(), operand, 0); + inst->replaceUsesWith(element); + propagationMap[element] = info; + inst->removeAndDeallocate(); + } +} + +void DynamicInstLoweringContext::lowerExtractExistentialValue(IRExtractExistentialValue* inst) +{ + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Check if we have a lowered any-value type for the result + auto resultType = inst->getDataType(); + auto loweredType = loweredInstToAnyValueType.tryGetValue(inst); + if (loweredType) + { + resultType = *loweredType; + } + + // Replace with GetElement(loweredInst, 1) -> AnyValueType + auto operand = inst->getOperand(0); + auto element = builder.emitGetTupleElement(resultType, operand, 1); + inst->replaceUsesWith(element); + inst->removeAndDeallocate(); +} + +void DynamicInstLoweringContext::lowerExtractExistentialType(IRExtractExistentialType* inst) +{ + auto info = tryGetInfo(inst); + if (!info || info->judgment != PropagationJudgment::SetOfTypes) + return; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Create an any-value type based on the set of types + auto anyValueType = createAnyValueTypeFromInsts(info->possibleValues); + + // Store the mapping for later use + loweredInstToAnyValueType[inst] = anyValueType; + + // Replace the instruction with the any-value type + inst->replaceUsesWith(anyValueType); + inst->removeAndDeallocate(); +} + +void DynamicInstLoweringContext::lowerCall(IRCall* inst) +{ + auto callee = inst->getCallee(); + auto calleeInfo = tryGetInfo(callee); + + if (!calleeInfo || calleeInfo->judgment != PropagationJudgment::SetOfFuncs) + return; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Create dispatch function + auto dispatchFunc = createDispatchFunc(calleeInfo->possibleValues, calleeInfo->dynFuncType); + + // Replace call with dispatch + List newArgs; + newArgs.add(callee); // Add the lookup as first argument (will get lowered into an uint tag) + for (UInt i = 1; i < inst->getOperandCount(); i++) + { + newArgs.add(inst->getOperand(i)); + } + + auto newCall = builder.emitCallInst(inst->getDataType(), dispatchFunc, newArgs); + inst->replaceUsesWith(newCall); + if (auto info = tryGetInfo(inst)) + propagationMap[newCall] = info; + inst->removeAndDeallocate(); +} + +void DynamicInstLoweringContext::lowerMakeExistential(IRMakeExistential* inst) +{ + auto info = tryGetInfo(inst); + if (!info || info->judgment != PropagationJudgment::Existential) + return; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get unique ID for the witness table. TODO: Assert that this is a concrete table.. + auto witnessTable = cast(inst->getWitnessTable()); + auto tableId = builder.getIntValue(builder.getUIntType(), getUniqueID(witnessTable)); + + // Collect types from the witness tables to determine the any-value type + HashSet types; + for (auto table : info->possibleValues) + { + if (auto witnessTableInst = as(table)) + { + if (auto concreteType = witnessTableInst->getConcreteType()) + { + types.add(concreteType); + } + } + } + + // Create the appropriate any-value type + auto anyValueType = createAnyValueType(types); + + // Pack the value + auto packedValue = builder.emitPackAnyValue(anyValueType, inst->getWrappedValue()); + + // Create tuple (table_unique_id, PackAnyValue(val)) + auto tupleType = + builder.getTupleType(List({builder.getUIntType(), packedValue->getDataType()})); + IRInst* tupleArgs[] = {tableId, packedValue}; + auto tuple = builder.emitMakeTuple(tupleType, 2, tupleArgs); + + inst->replaceUsesWith(tuple); + inst->removeAndDeallocate(); +} + +void DynamicInstLoweringContext::lowerCreateExistentialObject(IRCreateExistentialObject* inst) +{ + // Error out for now as specified + sink->diagnose( + inst, + Diagnostics::unimplemented, + "IRCreateExistentialObject lowering not yet implemented"); +} + +UInt DynamicInstLoweringContext::getUniqueID(IRInst* funcOrTable) +{ + auto existingId = uniqueIds.tryGetValue(funcOrTable); + if (existingId) + return *existingId; + + UInt newId = nextUniqueId++; + uniqueIds[funcOrTable] = newId; + return newId; +} + +IRFunc* DynamicInstLoweringContext::createKeyMappingFunc( + IRInst* key, + const HashSet& inputTables, + const HashSet& outputVals) +{ + // Create a function that maps input IDs to output IDs + IRBuilder builder(module); + + auto funcType = + builder.getFuncType(List({builder.getUIntType()}), builder.getUIntType()); + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(funcType); + + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); + + auto param = builder.emitParam(builder.getUIntType()); + + // Create default block that returns 0 + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), 0)); + + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); + + // Create case blocks for each input table + List caseValues; + List caseBlocks; + + // Build mapping from input tables to output values + List inputTableArray; + List outputValArray; + + for (auto table : inputTables) + inputTableArray.add(table); + for (auto table : outputVals) + outputValArray.add(table); + + for (Index i = 0; i < inputTableArray.getCount(); i++) + { + auto inputTable = inputTableArray[i]; + auto inputId = getUniqueID(inputTable); + + // Find corresponding output table (for now, use simple 1:1 mapping) + IRInst* outputVal = nullptr; + if (i < outputValArray.getCount()) + { + outputVal = outputValArray[i]; + } + else if (outputValArray.getCount() > 0) + { + outputVal = outputValArray[0]; // Fallback to first output + } + + if (outputVal) + { + auto outputId = getUniqueID(outputVal); + + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), outputId)); + + caseValues.add(builder.getIntValue(builder.getUIntType(), inputId)); + caseBlocks.add(caseBlock); + } + } + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } + + // Emit an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + param, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; +} + +IRFunc* DynamicInstLoweringContext::createDispatchFunc( + const HashSet& funcs, + IRFuncType* expectedFuncType) +{ + // Create a dispatch function with switch-case for each function + IRBuilder builder(module); + + // Extract parameter types from the first function in the set + List paramTypes; + paramTypes.add(builder.getUIntType()); // ID parameter + + // Get parameter types from first function + List funcArray; + for (auto func : funcs) + funcArray.add(func); + + for (UInt i = 0; i < expectedFuncType->getParamCount(); i++) + { + paramTypes.add(expectedFuncType->getParamType(i)); + } + + auto resultType = expectedFuncType->getResultType(); + auto funcType = builder.getFuncType(paramTypes, resultType); + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(funcType); + + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); + + auto idParam = builder.emitParam(builder.getUIntType()); + + // Create parameters for the original function arguments + List originalParams; + for (UInt i = 1; i < paramTypes.getCount(); i++) + { + originalParams.add(builder.emitParam(paramTypes[i])); + } + + // Create default block + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + // Return a default-constructed value + auto defaultValue = builder.emitDefaultConstruct(resultType); + builder.emitReturn(defaultValue); + } + + auto maybePackValue = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* + { + // If the type is AnyValueType, pack the value + if (as(type) && !as(value->getDataType())) + { + return builder->emitPackAnyValue(type, value); + } + return value; // Otherwise, return as is + }; + + auto maybeUnpackValue = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* + { + // If the type is AnyValueType, unpack the value + if (as(value->getDataType()) && !as(type)) + { + return builder->emitUnpackAnyValue(type, value); + } + return value; // Otherwise, return as is + }; + + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); + + // Create case blocks for each function + List caseValues; + List caseBlocks; + + for (auto funcInst : funcs) + { + auto funcId = getUniqueID(funcInst); + + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + + List callArgs; + auto concreteFuncType = as(funcInst->getDataType()); + for (UIndex ii = 0; ii < originalParams.getCount(); ii++) + { + callArgs.add( + maybeUnpackValue(&builder, originalParams[ii], concreteFuncType->getParamType(ii))); + } + + // Call the specific function + auto callResult = builder.emitCallInst(resultType, funcInst, callArgs); + + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + builder.emitReturn(maybePackValue(&builder, callResult, resultType)); + } + + caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); + caseBlocks.add(caseBlock); + } + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } + + // Create an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + idParam, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; +} + +IRAnyValueType* DynamicInstLoweringContext::createAnyValueType(const HashSet& types) +{ + IRBuilder builder(module); + auto size = calculateAnyValueSize(types); + return builder.getAnyValueType(size); +} + +IRAnyValueType* DynamicInstLoweringContext::createAnyValueTypeFromInsts( + const HashSet& typeInsts) +{ + HashSet types; + for (auto inst : typeInsts) + { + if (auto type = as(inst)) + { + types.add(type); + } + } + return createAnyValueType(types); +} + +SlangInt DynamicInstLoweringContext::calculateAnyValueSize(const HashSet& types) +{ + SlangInt maxSize = 0; + for (auto type : types) + { + auto size = getAnyValueSize(type); + if (size > maxSize) + maxSize = size; + } + return maxSize; +} + +bool DynamicInstLoweringContext::needsReinterpret( + RefPtr sourceInfo, + RefPtr targetInfo) +{ + if (!sourceInfo || !targetInfo) + return false; + + // Check if both are SetOfTypes with different sets + if (sourceInfo->judgment == PropagationJudgment::SetOfTypes && + targetInfo->judgment == PropagationJudgment::SetOfTypes) + { + if (sourceInfo->possibleValues.getCount() != targetInfo->possibleValues.getCount()) + return true; + } + + // Check if both are Existential with different witness table sets + if (sourceInfo->judgment == PropagationJudgment::Existential && + targetInfo->judgment == PropagationJudgment::Existential) + { + if (sourceInfo->possibleValues.getCount() != targetInfo->possibleValues.getCount()) + return true; + } + + return false; +} + +bool DynamicInstLoweringContext::isExistentialType(IRType* type) +{ + return as(type) != nullptr; +} + +bool DynamicInstLoweringContext::isInterfaceType(IRType* type) +{ + return as(type) != nullptr; +} + +HashSet DynamicInstLoweringContext::collectExistentialTables( + IRInterfaceType* interfaceType) +{ + HashSet tables; + + IRWitnessTableType* targetTableType = nullptr; + // First, find the IRWitnessTableType that wraps the given interfaceType + for (auto use = interfaceType->firstUse; use; use = use->nextUse) + { + if (auto wtType = as(use->getUser())) + { + if (wtType->getConformanceType() == interfaceType) + { + targetTableType = wtType; + break; + } + } + } + + // If the target witness table type was found, gather all witness tables using it + if (targetTableType) + { + for (auto use = targetTableType->firstUse; use; use = use->nextUse) + { + if (auto witnessTable = as(use->getUser())) + { + if (witnessTable->getDataType() == targetTableType) + { + tables.add(witnessTable); + } + } + } + } + + return tables; +} + +void DynamicInstLoweringContext::processModule() +{ + // Phase 1: Information Propagation + performInformationPropagation(); + + // Phase 1.5: Insert reinterprets for phi parameters where needed + insertReinterpretsForPhiParameters(); + + // Phase 2: Dynamic Instruction Lowering + performDynamicInstLowering(); +} + +// Main entry point +void lowerDynamicInsts(IRModule* module, DiagnosticSink* sink) +{ + DynamicInstLoweringContext context(module, sink); + context.processModule(); +} +} // namespace Slang diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-dynamic-insts.h new file mode 100644 index 00000000000..8fb551e88d6 --- /dev/null +++ b/source/slang/slang-ir-lower-dynamic-insts.h @@ -0,0 +1,141 @@ +// slang-ir-lower-dynamic-insts.h +#pragma once +#include "../core/slang-linked-list.h" +#include "../core/slang-smart-pointer.h" +#include "slang-ir.h" + +namespace Slang +{ + +// Enumeration for different kinds of judgments about IR instructions +enum class PropagationJudgment +{ + Value, // Regular value computation (not interface-related) + ConcreteType, // Concrete type reference + ConcreteTable, // Concrete witness table reference + ConcreteFunc, // Concrete function reference + SetOfTypes, // Set of possible types + SetOfTables, // Set of possible witness tables + SetOfFuncs, // Set of possible functions + UnknownSet, // Unknown set of possible types/tables/funcs (e.g. from COM interface types) + Existential // Existential box with a set of possible witness tables +}; + +// Data structure to hold propagation information for an instruction +struct PropagationInfo : RefObject +{ + PropagationJudgment judgment; + + // For concrete references + IRInst* concreteValue = nullptr; + + // For sets of types/tables/funcs and existential witness tables + HashSet possibleValues; + + // For SetOfFuncs + IRFuncType* dynFuncType; + + PropagationInfo() = default; + PropagationInfo(PropagationJudgment j) + : judgment(j) + { + } + + static RefPtr makeValue() + { + return new PropagationInfo(PropagationJudgment::Value); + } + static RefPtr makeConcrete(PropagationJudgment j, IRInst* value); + static RefPtr makeSet(PropagationJudgment j, const HashSet& values); + static RefPtr makeSetOfFuncs( + const HashSet& funcs, + IRFuncType* dynFuncType); + static RefPtr makeExistential(const HashSet& tables); + static RefPtr makeUnknown(); +}; + +// Context for the abstract interpretation pass +struct DynamicInstLoweringContext +{ + IRModule* module; + DiagnosticSink* sink; + + // Mapping from instruction to propagation information + Dictionary> propagationMap; + + // Unique ID assignment for functions and witness tables + Dictionary uniqueIds; + UInt nextUniqueId = 1; + + // Mapping from lowered instruction to their any-value types + Dictionary loweredInstToAnyValueType; + + DynamicInstLoweringContext(IRModule* inModule, DiagnosticSink* inSink) + : module(inModule), sink(inSink) + { + } + + // Phase 1: Information Propagation + void performInformationPropagation(); + void processFunction(IRFunc* func); + void propagateEdge(IREdge edge); + void processInstForPropagation(IRInst* inst); + + // Helper to get propagation info, handling global insts specially + RefPtr tryGetInfo(IRInst* inst); + + // Control flow analysis helpers + RefPtr unionPropagationInfo(const List>& infos); + void initializeFirstBlockParameters(IRFunc* func); + void insertReinterpretsForPhiParameters(); + + // Analysis of specific instruction types + RefPtr analyzeCreateExistentialObject(IRCreateExistentialObject* inst); + RefPtr analyzeMakeExistential(IRMakeExistential* inst); + RefPtr analyzeLookupWitnessMethod(IRLookupWitnessMethod* inst); + RefPtr analyzeExtractExistentialWitnessTable( + IRExtractExistentialWitnessTable* inst); + RefPtr analyzeExtractExistentialType(IRExtractExistentialType* inst); + RefPtr analyzeExtractExistentialValue(IRExtractExistentialValue* inst); + RefPtr analyzeCall(IRCall* inst); + RefPtr analyzeDefault(IRInst* inst); + + // Phase 2: Dynamic Instruction Lowering + void performDynamicInstLowering(); + void lowerInst(IRInst* inst); + void replaceType(IRInst* inst); + + // Lowering of specific instruction types + void lowerLookupWitnessMethod(IRLookupWitnessMethod* inst); + void lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst); + void lowerExtractExistentialType(IRExtractExistentialType* inst); + void lowerExtractExistentialValue(IRExtractExistentialValue* inst); + void lowerCall(IRCall* inst); + void lowerMakeExistential(IRMakeExistential* inst); + void lowerCreateExistentialObject(IRCreateExistentialObject* inst); + + // Helper functions + UInt getUniqueID(IRInst* funcOrTable); + IRFunc* createKeyMappingFunc( + IRInst* key, + const HashSet& inputTables, + const HashSet& outputTables); + IRFunc* createDispatchFunc(const HashSet& funcs, IRFuncType* expectedFuncType); + IRAnyValueType* createAnyValueType(const HashSet& types); + IRAnyValueType* createAnyValueTypeFromInsts(const HashSet& typeInsts); + SlangInt calculateAnyValueSize(const HashSet& types); + bool needsReinterpret(RefPtr sourceInfo, RefPtr targetInfo); + IRInst* maybeReinterpretArg(IRInst* arg, IRInst* param); + + // Utility functions + bool isExistentialType(IRType* type); + bool isInterfaceType(IRType* type); + HashSet collectExistentialTables(IRInterfaceType* interfaceType); + + // Main entry point + void processModule(); +}; + +// Main entry point for the pass +void lowerDynamicInsts(IRModule* module, DiagnosticSink* sink); +} // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6b9273c15ea..6c26708064c 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3878,7 +3878,7 @@ IRInst* IRBuilder::emitPackAnyValue(IRType* type, IRInst* value) IRInst* IRBuilder::emitUnpackAnyValue(IRType* type, IRInst* value) { - auto inst = createInst(this, kIROp_UnpackAnyValue, type, value); + auto inst = createInst(this, kIROp_UnpackAnyValue, type, value); addInst(inst); return inst; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 8e1f85f8e26..55c175a83cf 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1875,10 +1875,11 @@ struct ValLoweringVisitor : ValVisitorgetMidToSup()); } - return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( - getBuilder()->getWitnessTableType(lowerType(context, val->getSup())), - baseWitnessTable, - midToSup)); + return LoweredValInfo::simple( + getBuilder()->emitLookupInterfaceMethodInst( + getBuilder()->getWitnessTableType(lowerType(context, val->getSup())), + baseWitnessTable, + midToSup)); } LoweredValInfo visitForwardDifferentiateVal(ForwardDifferentiateVal* val) @@ -3329,6 +3330,8 @@ void collectParameterLists( auto thisType = getThisParamTypeForContainer(context, parentDeclRef); if (thisType) { + thisType = as( + thisType->substitute(getCurrentASTBuilder(), SubstitutionSet(declRef))); if (declRef.getDecl()->findModifier()) { auto noDiffAttr = context->astBuilder->getNoDiffModifierVal(); @@ -4193,18 +4196,20 @@ struct ExprLoweringVisitorBase : public ExprVisitor auto baseVal = lowerSubExpr(expr->baseFunction); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - return LoweredValInfo::simple(getBuilder()->emitForwardDifferentiateInst( - lowerType(context, expr->type), - baseVal.val)); + return LoweredValInfo::simple( + getBuilder()->emitForwardDifferentiateInst( + lowerType(context, expr->type), + baseVal.val)); } LoweredValInfo visitDetachExpr(DetachExpr* expr) { auto baseVal = lowerRValueExpr(context, expr->inner); - return LoweredValInfo::simple(getBuilder()->emitDetachDerivative( - lowerType(context, expr->type), - getSimpleVal(context, baseVal))); + return LoweredValInfo::simple( + getBuilder()->emitDetachDerivative( + lowerType(context, expr->type), + getSimpleVal(context, baseVal))); } LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) @@ -4290,9 +4295,10 @@ struct ExprLoweringVisitorBase : public ExprVisitor auto baseVal = lowerSubExpr(expr->baseFunction); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - return LoweredValInfo::simple(getBuilder()->emitBackwardDifferentiateInst( - lowerType(context, expr->type), - baseVal.val)); + return LoweredValInfo::simple( + getBuilder()->emitBackwardDifferentiateInst( + lowerType(context, expr->type), + baseVal.val)); } LoweredValInfo visitDispatchKernelExpr(DispatchKernelExpr* expr) @@ -4303,13 +4309,14 @@ struct ExprLoweringVisitorBase : public ExprVisitor auto groupSize = lowerRValueExpr(context, expr->dispatchSize); // Actual arguments to be filled in when we lower the actual call expr. // This is handled in `emitCallToVal`. - return LoweredValInfo::simple(getBuilder()->emitDispatchKernelInst( - lowerType(context, expr->type), - baseVal.val, - getSimpleVal(context, threadSize), - getSimpleVal(context, groupSize), - 0, - nullptr)); + return LoweredValInfo::simple( + getBuilder()->emitDispatchKernelInst( + lowerType(context, expr->type), + baseVal.val, + getSimpleVal(context, threadSize), + getSimpleVal(context, groupSize), + 0, + nullptr)); } LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr) diff --git a/tests/language-feature/dynamic-dispatch/simple.slang b/tests/language-feature/dynamic-dispatch/simple.slang new file mode 100644 index 00000000000..229d07648db --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/simple.slang @@ -0,0 +1,41 @@ + +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else + obj = B(); + + return obj.calc(x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file From 41c998c2beaeeb680a6971ec838033e94d56cdfe Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru Date: Sun, 20 Jul 2025 17:09:36 -0400 Subject: [PATCH 002/105] Expand implementation to also handle inter-procedural edges --- out.hlsl | 86 + source/slang/slang-ir-lower-dynamic-insts.cpp | 2633 ++++++++++------- source/slang/slang-ir-lower-dynamic-insts.h | 130 - source/slang/slang-vscode.natvis | 818 +++++ tests/sample.slang | 54 + 5 files changed, 2495 insertions(+), 1226 deletions(-) create mode 100644 out.hlsl create mode 100644 source/slang/slang-vscode.natvis create mode 100644 tests/sample.slang diff --git a/out.hlsl b/out.hlsl new file mode 100644 index 00000000000..50b7bd1e159 --- /dev/null +++ b/out.hlsl @@ -0,0 +1,86 @@ +#pragma pack_matrix(column_major) +#ifdef SLANG_HLSL_ENABLE_NVAPI +#include "nvHLSLExtns.h" +#endif + +#ifndef __DXC_VERSION_MAJOR +// warning X3557: loop doesn't seem to do anything, forcing loop to unroll +#pragma warning(disable : 3557) +#endif + +RWStructuredBuffer outputBuffer_0 : register(u0); + +struct Tuple_0 +{ + uint value0_0; +}; + +uint _S1(uint _S2) +{ + switch(_S2) + { + case 1U: + { + return 3U; + } + case 2U: + { + return 4U; + } + default: + { + return 0U; + } + } +} + +float A_calc_0(float x_0) +{ + return x_0 * x_0 * x_0; +} + +float B_calc_0(float x_1) +{ + return x_1 * x_1; +} + +float _S3(uint _S4, float _S5) +{ + switch(_S4) + { + case 3U: + { + return A_calc_0(_S5); + } + case 4U: + { + return B_calc_0(_S5); + } + default: + { + return 0.0f; + } + } +} + +float f_0(uint id_0, float x_2) +{ + Tuple_0 obj_0; + if(id_0 == 0U) + { + obj_0.value0_0 = 1U; + } + else + { + obj_0.value0_0 = 2U; + } + return _S3(_S1(obj_0.value0_0), x_2); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID_0 : SV_DispatchThreadID) +{ + outputBuffer_0[int(0)] = f_0(0U, 1.0f); + return; +} + diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 00fcca00f22..acf1af84e24 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -7,53 +7,208 @@ namespace Slang { + + +// Enumeration for different kinds of judgments about IR instructions. +// +// This forms a lattice with +// +// None < Value +// None < ConcreteX < SetOfX < Unbounded +// None < Existential < Unbounded +// +enum class PropagationJudgment +{ + None, // No judgment (initial value) + Value, // Regular value computation (unrelated to dynamic dispatch) + ConcreteType, // Concrete type reference + ConcreteTable, // Concrete witness table reference + ConcreteFunc, // Concrete function reference + SetOfTypes, // Set of possible types + SetOfTables, // Set of possible witness tables + SetOfFuncs, // Set of possible functions + Existential, // Existential box with a set of possible witness tables + Unbounded, // Unknown set of possible types/tables/funcs (e.g. COM interface types) +}; + +// Data structure to hold propagation information for an instruction +struct PropagationInfo : RefObject +{ + PropagationJudgment judgment; + + // For concrete references + IRInst* concreteValue = nullptr; + + // For sets of types/tables/funcs and existential witness tables + HashSet possibleValues; + + // For SetOfFuncs + IRFuncType* dynFuncType; + + PropagationInfo() + : judgment(PropagationJudgment::None), concreteValue(nullptr), dynFuncType(nullptr) + { + } + + PropagationInfo(PropagationJudgment j) + : judgment(j) + { + } + + static PropagationInfo makeValue() { return PropagationInfo(PropagationJudgment::Value); } + static PropagationInfo makeConcrete(PropagationJudgment j, IRInst* value); + static PropagationInfo makeSet(PropagationJudgment j, const HashSet& values); + static PropagationInfo makeSetOfFuncs(const HashSet& funcs, IRFuncType* dynFuncType); + static PropagationInfo makeExistential(const HashSet& tables); + static PropagationInfo makeUnbounded(); + static PropagationInfo none(); + + bool isNone() const { return judgment == PropagationJudgment::None; } + + operator bool() const { return judgment != PropagationJudgment::None; } +}; + +// Data structures for interprocedural data-flow analysis + +// Represents an interprocedural edge between call sites and functions +struct InterproceduralEdge +{ + enum class Direction + { + CallToFunc, // From call site to function entry (propagating arguments) + FuncToCall // From function return to call site (propagating return value) + }; + + Direction direction; + IRCall* callInst; // The call instruction + IRFunc* targetFunc; // The function being called/returned from + + InterproceduralEdge() = default; + InterproceduralEdge(Direction dir, IRCall* call, IRFunc* func) + : direction(dir), callInst(call), targetFunc(func) + { + } +}; + +// Union type representing either an intra-procedural or interprocedural edge +struct WorkItem +{ + enum class Type + { + None, // Invalid + Block, // Propagate information within a block + IntraProc, // Propagate through within-function edge (IREdge) + InterProc // Propagate across function call/return (InterproceduralEdge) + }; + + Type type; + union + { + IRBlock* block; + IREdge intraProcEdge; + InterproceduralEdge interProcEdge; + }; + + WorkItem() + : type(Type::None) + { + } + + WorkItem(IRBlock* block) + : type(Type::Block), block(block) + { + } + + WorkItem(IREdge edge) + : type(Type::IntraProc), intraProcEdge(edge) + { + } + + WorkItem(InterproceduralEdge edge) + : type(Type::InterProc), interProcEdge(edge) + { + } + + WorkItem(InterproceduralEdge::Direction dir, IRCall* call, IRFunc* func) + : type(Type::InterProc), interProcEdge(dir, call, func) + { + } + + // Copy constructor and assignment needed for union with non-trivial types + WorkItem(const WorkItem& other) + : type(other.type) + { + if (type == Type::IntraProc) + intraProcEdge = other.intraProcEdge; + else if (type == Type::InterProc) + interProcEdge = other.interProcEdge; + else + block = other.block; + } + + WorkItem& operator=(const WorkItem& other) + { + type = other.type; + if (type == Type::IntraProc) + intraProcEdge = other.intraProcEdge; + else if (type == Type::InterProc) + interProcEdge = other.interProcEdge; + else + block = other.block; + return *this; + } +}; + // PropagationInfo implementation -RefPtr PropagationInfo::makeConcrete(PropagationJudgment j, IRInst* value) +PropagationInfo PropagationInfo::makeConcrete(PropagationJudgment j, IRInst* value) { - auto info = new PropagationInfo(j); - info->concreteValue = value; + auto info = PropagationInfo(j); + info.concreteValue = value; return info; } -RefPtr PropagationInfo::makeSet( - PropagationJudgment j, - const HashSet& values) +PropagationInfo PropagationInfo::makeSet(PropagationJudgment j, const HashSet& values) { - auto info = new PropagationInfo(j); + auto info = PropagationInfo(j); SLANG_ASSERT(j != PropagationJudgment::SetOfFuncs); - info->possibleValues = values; + info.possibleValues = values; return info; } -RefPtr PropagationInfo::makeSetOfFuncs( +PropagationInfo PropagationInfo::makeSetOfFuncs( const HashSet& values, IRFuncType* dynFuncType) { - auto info = new PropagationInfo(PropagationJudgment::SetOfFuncs); - info->possibleValues = values; - info->dynFuncType = dynFuncType; + auto info = PropagationInfo(PropagationJudgment::SetOfFuncs); + info.possibleValues = values; + info.dynFuncType = dynFuncType; return info; } -RefPtr PropagationInfo::makeExistential(const HashSet& tables) +PropagationInfo PropagationInfo::makeExistential(const HashSet& tables) { - auto info = new PropagationInfo(PropagationJudgment::Existential); - info->possibleValues = tables; + auto info = PropagationInfo(PropagationJudgment::Existential); + info.possibleValues = tables; return info; } -RefPtr PropagationInfo::makeUnknown() +PropagationInfo PropagationInfo::makeUnbounded() { - return new PropagationInfo(PropagationJudgment::UnknownSet); + return PropagationInfo(PropagationJudgment::Unbounded); } -bool areInfosEqual(const RefPtr& a, const RefPtr& b) +PropagationInfo PropagationInfo::none() { - if (a->judgment != b->judgment) + return PropagationInfo(PropagationJudgment::None); +} + +bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) +{ + if (a.judgment != b.judgment) return false; - switch (a->judgment) + switch (a.judgment) { case PropagationJudgment::Value: return true; // All value judgments are equal @@ -61,31 +216,31 @@ bool areInfosEqual(const RefPtr& a, const RefPtrconcreteValue == b->concreteValue; + return a.concreteValue == b.concreteValue; case PropagationJudgment::SetOfTypes: case PropagationJudgment::SetOfTables: case PropagationJudgment::SetOfFuncs: - if (a->possibleValues.getCount() != b->possibleValues.getCount()) + if (a.possibleValues.getCount() != b.possibleValues.getCount()) return false; - for (auto value : a->possibleValues) + for (auto value : a.possibleValues) { - if (!b->possibleValues.contains(value)) + if (!b.possibleValues.contains(value)) return false; } return true; case PropagationJudgment::Existential: - if (a->possibleValues.getCount() != b->possibleValues.getCount()) + if (a.possibleValues.getCount() != b.possibleValues.getCount()) return false; - for (auto table : a->possibleValues) + for (auto table : a.possibleValues) { - if (!b->possibleValues.contains(table)) + if (!b.possibleValues.contains(table)) return false; } return true; - case PropagationJudgment::UnknownSet: + case PropagationJudgment::Unbounded: return true; // All unknown sets are considered equal default: @@ -93,155 +248,280 @@ bool areInfosEqual(const RefPtr& a, const RefPtr DynamicInstLoweringContext::tryGetInfo(IRInst* inst) +struct DynamicInstLoweringContext { - // If this is a global instruction (parent is module), return concrete info - if (as(inst->getParent())) + // DynamicInstLoweringContext implementation + PropagationInfo tryGetInfo(IRInst* inst) { - // Create static info based on instruction type - static RefPtr typeInfo = - PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, nullptr); - static RefPtr tableInfo = - PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, nullptr); - static RefPtr funcInfo = - PropagationInfo::makeConcrete(PropagationJudgment::ConcreteFunc, nullptr); - static RefPtr valueInfo = PropagationInfo::makeValue(); - - if (as(inst)) + // If this is a global instruction (parent is module), return concrete info + if (as(inst->getParent())) { - typeInfo->concreteValue = inst; - return typeInfo; + if (as(inst)) + { + PropagationInfo typeInfo = + PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, nullptr); + typeInfo.concreteValue = inst; + return typeInfo; + } + else if (as(inst)) + { + PropagationInfo tableInfo = + PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, nullptr); + tableInfo.concreteValue = inst; + return tableInfo; + } + else if (as(inst)) + { + PropagationInfo funcInfo = + PropagationInfo::makeConcrete(PropagationJudgment::ConcreteFunc, nullptr); + funcInfo.concreteValue = inst; + return funcInfo; + } + else + { + return PropagationInfo::makeValue(); + } } - else if (as(inst)) + + // For non-global instructions, look up in the map + auto found = propagationMap.tryGetValue(inst); + if (found) + return *found; + return PropagationInfo::none(); + } + + PropagationInfo tryGetFuncReturnInfo(IRFunc* func) + { + auto found = funcReturnInfo.tryGetValue(func); + if (found) + return *found; + return PropagationInfo::none(); + } + + void processBlock(IRBlock* block, LinkedList& workQueue) + { + bool anyInfoChanged = false; + + HashSet affectedBlocks; + HashSet affectedTerminators; + for (auto inst : block->getChildren()) { - tableInfo->concreteValue = inst; - return tableInfo; + // Skip parameters & terminator + if (as(inst) || as(inst)) + continue; + + auto oldInfo = tryGetInfo(inst); + processInstForPropagation(inst, workQueue); + auto newInfo = tryGetInfo(inst); + + // If information has changed, propagate to appropriate blocks/edges + if (!areInfosEqual(oldInfo, newInfo)) + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto userBlock = as(use->getUser()); + if (userBlock && userBlock != block) + affectedBlocks.add(userBlock); + + if (auto terminator = as(use->getUser())) + affectedTerminators.add(terminator); + } + } } - else if (as(inst)) + + for (auto block : affectedBlocks) { - funcInfo->concreteValue = inst; - return funcInfo; + workQueue.addLast(WorkItem(block)); } - else + + for (auto terminator : affectedTerminators) { - return valueInfo; + auto successors = as(terminator->getParent())->getSuccessors(); + for (auto succIter = successors.begin(), succEnd = successors.end(); + succIter != succEnd; + ++succIter) + { + workQueue.addLast(WorkItem(succIter.getEdge())); + } } - } - // For non-global instructions, look up in the map - auto found = propagationMap.tryGetValue(inst); - if (found) - return *found; - return nullptr; -} - -void DynamicInstLoweringContext::performInformationPropagation() -{ - // Process each function in the module - for (auto inst : module->getGlobalInsts()) - { - if (auto func = as(inst)) + if (as(block->getTerminator())) { - processFunction(func); + // If the block has a return inst, we need to propagate return values + propagateReturnValues(block, workQueue); } - } -} + }; -IRInst* DynamicInstLoweringContext::maybeReinterpretArg(IRInst* arg, IRInst* param) -{ - /*auto argTypeInfo = tryGetInfo(arg->getDataType()); - auto paramTypeInfo = tryGetInfo(param->getDataType()); + void performInformationPropagation() + { + // Global worklist for interprocedural analysis + LinkedList workQueue; - if (argTypeInfo & paramTypeInfo) - return arg; + // Add all global function entry blocks to worklist. + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as(inst)) + { + initializeFirstBlockParameters(func); - IRBuilder builder(module); - builder.setInsertAfter(arg); + // Add all blocks to start with. Once the initial + // sweep is done, propagation will proceed on an on-demand basis + // depending on affected blocks & edges + // + for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) + { + workQueue.addLast(WorkItem(block)); + } + } + } - if (argTypeInfo->judgment == PropagationJudgment::SetOfTypes && - paramTypeInfo->judgment == PropagationJudgment::SetOfTypes) - { - if (argTypeInfo->possibleValues.getCount() != paramTypeInfo->possibleValues.getCount()) + // Process until fixed point + while (workQueue.getCount() > 0) { - SLANG_ASSERT("Unhandled"); + // Pop work item from front + auto item = workQueue.getFirst(); + workQueue.getFirstNode()->removeAndDelete(); - // If the sets are not equal, reinterpret to the parameter type - // auto reinterpret = builder.emitReinterpret(param->getDataType(), arg); - // propagationMap[reinterpret] = paramTypeInfo; + switch (item.type) + { + case WorkItem::Type::Block: + processBlock(item.block, workQueue); + break; + case WorkItem::Type::IntraProc: + propagateWithinFuncEdge(item.intraProcEdge, workQueue); + break; + case WorkItem::Type::InterProc: + propagateInterproceduralEdge(item.interProcEdge, workQueue); + break; + default: + SLANG_UNEXPECTED("Unhandled work item type"); + return; + } } - }*/ + } - auto argInfo = tryGetInfo(arg); - auto paramInfo = tryGetInfo(param); + IRInst* maybeReinterpret(IRInst* arg, PropagationInfo destInfo) + { + auto argInfo = tryGetInfo(arg); - if (!argInfo || !paramInfo) - return arg; + if (!argInfo || !destInfo) + return arg; - if (argInfo->judgment == PropagationJudgment::Existential && - paramInfo->judgment == PropagationJudgment::Existential) - { - if (argInfo->possibleValues.getCount() != paramInfo->possibleValues.getCount()) + if (argInfo.judgment == PropagationJudgment::Existential && + destInfo.judgment == PropagationJudgment::Existential) { - // If the sets of witness tables are not equal, reinterpret to the parameter type - IRBuilder builder(module); - builder.setInsertAfter(arg); - auto reinterpret = builder.emitReinterpret(param->getDataType(), arg); - propagationMap[reinterpret] = paramInfo; - return reinterpret; // Return the reinterpret instruction + if (argInfo.possibleValues.getCount() != destInfo.possibleValues.getCount()) + { + // If the sets of witness tables are not equal, reinterpret to the parameter type + IRBuilder builder(module); + builder.setInsertAfter(arg); + + // We'll use nulltype for the reinterpret since the type is going to be re-written + // and if it doesn't, this will help catch it before code-gen. + // + auto reinterpret = builder.emitReinterpret(nullptr, arg); + propagationMap[reinterpret] = destInfo; + return reinterpret; // Return the reinterpret instruction + } } - } - return arg; // Can use as-is. -} + return arg; // Can use as-is. + } -void DynamicInstLoweringContext::insertReinterpretsForPhiParameters() -{ - // Process each function in the module - for (auto inst : module->getGlobalInsts()) + void insertReinterprets() { - if (auto func = as(inst)) + // Process each function in the module + for (auto inst : module->getGlobalInsts()) { - // Skip the first block as it contains function parameters, not phi parameters - for (auto block = func->getFirstBlock()->getNextBlock(); block; - block = block->getNextBlock()) + if (auto func = as(inst)) { - // Process each parameter in this block (these are phi parameters) - for (auto param : block->getParams()) + // Skip the first block as it contains function parameters, not phi parameters + for (auto block = func->getFirstBlock()->getNextBlock(); block; + block = block->getNextBlock()) { - auto paramInfo = tryGetInfo(param); - if (!paramInfo) - continue; + // Process each parameter in this block (these are phi parameters) + for (auto param : block->getParams()) + { + auto paramInfo = tryGetInfo(param); + if (!paramInfo) + continue; + + // Check all predecessors and their corresponding arguments + Index paramIndex = 0; + for (auto p : block->getParams()) + { + if (p == param) + break; + paramIndex++; + } + + // Find all predecessors of this block + for (auto pred : block->getPredecessors()) + { + auto terminator = pred->getTerminator(); + if (!terminator) + continue; + + if (auto unconditionalBranch = as(terminator)) + { + // Get the argument at the same index as this parameter + if (paramIndex < unconditionalBranch->getArgCount()) + { + auto arg = unconditionalBranch->getArg(paramIndex); + auto newArg = maybeReinterpret(arg, tryGetInfo(param)); + + if (newArg != arg) + { + // Replace the argument in the branch instruction + unconditionalBranch->setOperand(1 + paramIndex, newArg); + } + } + } + } + } - // Check all predecessors and their corresponding arguments - Index paramIndex = 0; - for (auto p : block->getParams()) + // Is the terminator a return instruction? + if (auto returnInst = as(block->getTerminator())) { - if (p == param) - break; - paramIndex++; + if (!as(returnInst->getVal()->getDataType())) + { + auto funcReturnInfo = tryGetFuncReturnInfo(func); + auto newReturnVal = + maybeReinterpret(returnInst->getVal(), funcReturnInfo); + if (newReturnVal != returnInst->getVal()) + { + // Replace the return value with the reinterpreted value + returnInst->setOperand(0, newReturnVal); + } + } } - // Find all predecessors of this block - for (auto pred : block->getPredecessors()) + List callInsts; + // Collect all call instructions in this block + for (auto inst : block->getChildren()) { - auto terminator = pred->getTerminator(); - if (!terminator) - continue; + if (auto callInst = as(inst)) + callInsts.add(callInst); + } - if (auto unconditionalBranch = as(terminator)) + // Look at all the args and reinterpret them if necessary + for (auto callInst : callInsts) + { + if (auto irFunc = as(callInst->getCallee())) { - // Get the argument at the same index as this parameter - if (paramIndex < unconditionalBranch->getArgCount()) + List params; + List args; + Index i = 0; + for (auto param : irFunc->getParams()) { - auto arg = unconditionalBranch->getArg(paramIndex); - auto newArg = maybeReinterpretArg(arg, param); - - if (newArg != arg) + auto newArg = + maybeReinterpret(callInst->getArg(i), tryGetInfo(param)); + if (newArg != callInst->getArg(i)) { - // Replace the argument in the branch instruction - unconditionalBranch->setOperand(1 + paramIndex, newArg); + // Replace the argument in the call instruction + callInst->setArg(i, newArg); } + i++; } } } @@ -249,1273 +529,1434 @@ void DynamicInstLoweringContext::insertReinterpretsForPhiParameters() } } } -} - -void DynamicInstLoweringContext::processInstForPropagation(IRInst* inst) -{ - RefPtr info; - switch (inst->getOp()) + void processInstForPropagation(IRInst* inst, LinkedList& workQueue) { - case kIROp_CreateExistentialObject: - info = analyzeCreateExistentialObject(as(inst)); - break; - case kIROp_MakeExistential: - info = analyzeMakeExistential(as(inst)); - break; - case kIROp_LookupWitnessMethod: - info = analyzeLookupWitnessMethod(as(inst)); - break; - case kIROp_ExtractExistentialWitnessTable: - info = analyzeExtractExistentialWitnessTable(as(inst)); - break; - case kIROp_ExtractExistentialType: - info = analyzeExtractExistentialType(as(inst)); - break; - case kIROp_ExtractExistentialValue: - info = analyzeExtractExistentialValue(as(inst)); - break; - case kIROp_Call: - info = analyzeCall(as(inst)); - break; - default: - info = analyzeDefault(inst); - break; - } - - propagationMap[inst] = info; -} + PropagationInfo info; -RefPtr DynamicInstLoweringContext::analyzeCreateExistentialObject( - IRCreateExistentialObject* inst) -{ - // For now, error out as specified - SLANG_UNIMPLEMENTED_X("IRCreateExistentialObject lowering not yet implemented"); - return PropagationInfo::makeValue(); -} - -RefPtr DynamicInstLoweringContext::analyzeMakeExistential(IRMakeExistential* inst) -{ - auto witnessTable = inst->getWitnessTable(); - auto value = inst->getWrappedValue(); - auto valueType = value->getDataType(); + switch (inst->getOp()) + { + case kIROp_CreateExistentialObject: + info = analyzeCreateExistentialObject(as(inst)); + break; + case kIROp_MakeExistential: + info = analyzeMakeExistential(as(inst)); + break; + case kIROp_LookupWitnessMethod: + info = analyzeLookupWitnessMethod(as(inst)); + break; + case kIROp_ExtractExistentialWitnessTable: + info = + analyzeExtractExistentialWitnessTable(as(inst)); + break; + case kIROp_ExtractExistentialType: + info = analyzeExtractExistentialType(as(inst)); + break; + case kIROp_ExtractExistentialValue: + info = analyzeExtractExistentialValue(as(inst)); + break; + case kIROp_Call: + info = analyzeCall(as(inst), workQueue); + break; + default: + info = analyzeDefault(inst); + break; + } - // Get the witness table info - auto witnessTableInfo = tryGetInfo(witnessTable); - if (!witnessTableInfo || witnessTableInfo->judgment == PropagationJudgment::UnknownSet) - { - return PropagationInfo::makeUnknown(); + propagationMap[inst] = info; } - HashSet tables; - - if (witnessTableInfo->judgment == PropagationJudgment::ConcreteTable) + PropagationInfo analyzeCreateExistentialObject(IRCreateExistentialObject* inst) { - tables.add(witnessTableInfo->concreteValue); + // For now, error out as specified + SLANG_UNIMPLEMENTED_X("IRCreateExistentialObject lowering not yet implemented"); + return PropagationInfo::makeValue(); } - else if (witnessTableInfo->judgment == PropagationJudgment::SetOfTables) + + PropagationInfo analyzeMakeExistential(IRMakeExistential* inst) { - for (auto table : witnessTableInfo->possibleValues) + auto witnessTable = inst->getWitnessTable(); + auto value = inst->getWrappedValue(); + auto valueType = value->getDataType(); + + // Get the witness table info + auto witnessTableInfo = tryGetInfo(witnessTable); + if (!witnessTableInfo || witnessTableInfo.judgment == PropagationJudgment::Unbounded) { - tables.add(table); + return PropagationInfo::makeUnbounded(); } - } - return PropagationInfo::makeExistential(tables); -} + HashSet tables; -static IRInst* lookupEntry(IRInst* witnessTable, IRInst* key) -{ - if (auto concreteTable = as(witnessTable)) - { - for (auto entry : concreteTable->getEntries()) + if (witnessTableInfo.judgment == PropagationJudgment::ConcreteTable) { - if (entry->getRequirementKey() == key) + tables.add(witnessTableInfo.concreteValue); + } + else if (witnessTableInfo.judgment == PropagationJudgment::SetOfTables) + { + for (auto table : witnessTableInfo.possibleValues) { - return entry->getSatisfyingVal(); + tables.add(table); } } - } - return nullptr; // Not found -} -RefPtr DynamicInstLoweringContext::analyzeLookupWitnessMethod( - IRLookupWitnessMethod* inst) -{ - auto witnessTable = inst->getWitnessTable(); - auto key = inst->getRequirementKey(); - auto witnessTableInfo = tryGetInfo(witnessTable); - - if (!witnessTableInfo || witnessTableInfo->judgment == PropagationJudgment::UnknownSet) - { - return PropagationInfo::makeUnknown(); + return PropagationInfo::makeExistential(tables); } - HashSet results; - - if (witnessTableInfo->judgment == PropagationJudgment::ConcreteTable) + static IRInst* lookupEntry(IRInst* witnessTable, IRInst* key) { - results.add(lookupEntry(witnessTableInfo->concreteValue, key)); - } - else if (witnessTableInfo->judgment == PropagationJudgment::SetOfTables) - { - for (auto table : witnessTableInfo->possibleValues) + if (auto concreteTable = as(witnessTable)) { - results.add(lookupEntry(table, key)); + for (auto entry : concreteTable->getEntries()) + { + if (entry->getRequirementKey() == key) + { + return entry->getSatisfyingVal(); + } + } } + return nullptr; // Not found } - if (witnessTableInfo->judgment == PropagationJudgment::ConcreteTable) + PropagationInfo analyzeLookupWitnessMethod(IRLookupWitnessMethod* inst) { - if (as(inst->getDataType())) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteFunc, - *results.begin()); - } - else if (as(inst->getDataType())) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteType, - *results.begin()); - } - else if (as(inst->getDataType())) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteTable, - *results.begin()); - } - else + auto witnessTable = inst->getWitnessTable(); + auto key = inst->getRequirementKey(); + auto witnessTableInfo = tryGetInfo(witnessTable); + + if (!witnessTableInfo || witnessTableInfo.judgment == PropagationJudgment::Unbounded) { - SLANG_UNEXPECTED("Unexpected data type for LookupWitnessMethod"); + return PropagationInfo::makeUnbounded(); } - } - else - { - if (auto funcType = as(inst->getDataType())) + + HashSet results; + + if (witnessTableInfo.judgment == PropagationJudgment::ConcreteTable) { - return PropagationInfo::makeSetOfFuncs(results, funcType); + results.add(lookupEntry(witnessTableInfo.concreteValue, key)); } - else if (as(inst->getDataType())) + else if (witnessTableInfo.judgment == PropagationJudgment::SetOfTables) { - return PropagationInfo::makeSet(PropagationJudgment::SetOfTypes, results); + for (auto table : witnessTableInfo.possibleValues) + { + results.add(lookupEntry(table, key)); + } } - else if (as(inst->getDataType())) + + if (witnessTableInfo.judgment == PropagationJudgment::ConcreteTable) { - return PropagationInfo::makeSet(PropagationJudgment::SetOfTables, results); + if (as(inst->getDataType())) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteFunc, + *results.begin()); + } + else if (as(inst->getDataType())) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteType, + *results.begin()); + } + else if (as(inst->getDataType())) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteTable, + *results.begin()); + } + else + { + SLANG_UNEXPECTED("Unexpected data type for LookupWitnessMethod"); + } } else { - SLANG_UNEXPECTED("Unexpected data type for LookupWitnessMethod"); + if (auto funcType = as(inst->getDataType())) + { + return PropagationInfo::makeSetOfFuncs(results, funcType); + } + else if (as(inst->getDataType())) + { + return PropagationInfo::makeSet(PropagationJudgment::SetOfTypes, results); + } + else if (as(inst->getDataType())) + { + return PropagationInfo::makeSet(PropagationJudgment::SetOfTables, results); + } + else + { + SLANG_UNEXPECTED("Unexpected data type for LookupWitnessMethod"); + } } } -} -RefPtr DynamicInstLoweringContext::analyzeExtractExistentialWitnessTable( - IRExtractExistentialWitnessTable* inst) -{ - auto operand = inst->getOperand(0); - auto operandInfo = tryGetInfo(operand); - - if (!operandInfo || operandInfo->judgment == PropagationJudgment::UnknownSet) + PropagationInfo analyzeExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) { - return PropagationInfo::makeUnknown(); - } + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(operand); - if (operandInfo->judgment == PropagationJudgment::Existential) - { - HashSet tables; - for (auto table : operandInfo->possibleValues) + if (!operandInfo || operandInfo.judgment == PropagationJudgment::Unbounded) { - tables.add(table); + return PropagationInfo::makeUnbounded(); } - if (tables.getCount() == 1) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteTable, - *tables.begin()); - } - else + if (operandInfo.judgment == PropagationJudgment::Existential) { - return PropagationInfo::makeSet(PropagationJudgment::SetOfTables, tables); - } - } - - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, inst); -} + HashSet tables; + for (auto table : operandInfo.possibleValues) + { + tables.add(table); + } -RefPtr DynamicInstLoweringContext::analyzeExtractExistentialType( - IRExtractExistentialType* inst) -{ - auto operand = inst->getOperand(0); - auto operandInfo = tryGetInfo(operand); + if (tables.getCount() == 1) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteTable, + *tables.begin()); + } + else + { + return PropagationInfo::makeSet(PropagationJudgment::SetOfTables, tables); + } + } - if (!operandInfo || operandInfo->judgment == PropagationJudgment::UnknownSet) - { - return PropagationInfo::makeUnknown(); + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, inst); } - if (operandInfo->judgment == PropagationJudgment::Existential) + PropagationInfo analyzeExtractExistentialType(IRExtractExistentialType* inst) { - HashSet types; - // Extract types from witness tables by looking at the concrete types - for (auto table : operandInfo->possibleValues) + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(operand); + + if (!operandInfo || operandInfo.judgment == PropagationJudgment::Unbounded) + { + return PropagationInfo::makeUnbounded(); + } + + if (operandInfo.judgment == PropagationJudgment::Existential) { - // Get the concrete type from the witness table - if (auto witnessTable = as(table)) + HashSet types; + // Extract types from witness tables by looking at the concrete types + for (auto table : operandInfo.possibleValues) { - if (auto concreteType = witnessTable->getConcreteType()) + // Get the concrete type from the witness table + if (auto witnessTable = as(table)) { - types.add(concreteType); + if (auto concreteType = witnessTable->getConcreteType()) + { + types.add(concreteType); + } + } + else + { + SLANG_UNEXPECTED("Expected witness table in existential extraction base type"); } } - else + + if (types.getCount() == 0) { - SLANG_UNEXPECTED("Expected witness table in existential extraction base type"); + // No concrete types found, treat as this instruction + types.add(inst); } - } - if (types.getCount() == 0) - { - // No concrete types found, treat as this instruction - types.add(inst); + if (types.getCount() == 1) + { + return PropagationInfo::makeConcrete( + PropagationJudgment::ConcreteType, + *types.begin()); + } + else + { + return PropagationInfo::makeSet(PropagationJudgment::SetOfTypes, types); + } } - if (types.getCount() == 1) - { - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, *types.begin()); - } - else - { - return PropagationInfo::makeSet(PropagationJudgment::SetOfTypes, types); - } + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, inst); } - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, inst); -} - -RefPtr DynamicInstLoweringContext::analyzeExtractExistentialValue( - IRExtractExistentialValue* inst) -{ - // The value itself is just a regular value - return PropagationInfo::makeValue(); -} + PropagationInfo analyzeExtractExistentialValue(IRExtractExistentialValue* inst) + { + // The value itself is just a regular value + return PropagationInfo::makeValue(); + } -RefPtr DynamicInstLoweringContext::analyzeCall(IRCall* inst) -{ - auto callee = inst->getCallee(); - auto calleeInfo = tryGetInfo(callee); + PropagationInfo analyzeCall(IRCall* inst, LinkedList& workQueue) + { + auto callee = inst->getCallee(); + auto calleeInfo = tryGetInfo(callee); - List outputVals; - outputVals.add(inst); // The call inst itself is an output. + auto funcType = as(callee->getDataType()); - // Also add all OutTypeBase parameters - auto funcType = as(callee->getDataType()); - if (funcType) - { - UIndex paramIndex = 0; - for (auto paramType : funcType->getParamTypes()) + // TODO: Expand logic to handle all outputs. + // For now, we're focused on the return value. + // + if (!funcType || !isExistentialType(funcType->getResultType())) { - if (as(paramType)) - { - // If this is an OutTypeBase, we consider it an output - outputVals.add(inst->getArg(paramIndex)); - } - paramIndex++; + return PropagationInfo::makeValue(); } - } - for (auto outputVal : outputVals) - { - if (as(outputVal)) + // Okay, we have an call that can return an existential type. + // + // Propagate the input judgments to the call & append a work item + // for inter-procedural propagation. + // + + // For now, we'll handle just a concrete func. But the logic for multiple functions + // is exactly the same (add an edge for each function). + // + if (calleeInfo && calleeInfo.judgment == PropagationJudgment::ConcreteFunc) { - // TODO: We need to set up infrastructure to track variable - // assignments. - // For now, we will just return a value judgment for variables. - // (doesn't make much sense.. but its fine for now) - // - propagationMap[outputVal] = PropagationInfo::makeValue(); + workQueue.addLast( + WorkItem(InterproceduralEdge::Direction::CallToFunc, inst, cast(callee))); } - else + + /*if (auto concreteFunc = as(callee)) { - if (auto interfaceType = as(outputVal->getDataType())) + auto returnInfo = tryGetFuncReturnInfo(concreteFunc); + if (returnInfo) { - if (!interfaceType->findDecoration()) - { - // If this is an interface type, we need to propagate existential info - // based on the interface type. - propagationMap[outputVal] = - PropagationInfo::makeExistential(collectExistentialTables(interfaceType)); - } - else - { - // If this is a COM interface, we treat it as unknown - propagationMap[outputVal] = PropagationInfo::makeUnknown(); - } - } - else - { - propagationMap[outputVal] = PropagationInfo::makeValue(); + // Use the function's return info directly + propagationMap[inst] = returnInfo; + return returnInfo; } } - } - - return tryGetInfo(inst); -} - - -void DynamicInstLoweringContext::processFunction(IRFunc* func) -{ - // Initialize parameter info for the first block - initializeFirstBlockParameters(func); - - // Initialize worklist with edges from successor blocks - LinkedList edgeQueue; - auto processBlock = [&](IRBlock* block) - { - bool anyInfoChanged = false; + List outputVals; + outputVals.add(inst); // The call inst itself is an output. - for (auto inst : block->getChildren()) + // Also add all OutTypeBase parameters + auto funcType = as(callee->getDataType()); + if (funcType) { - // Skip parameters & terminator - if (as(inst) || as(inst)) - continue; - - auto oldInfo = tryGetInfo(inst); - processInstForPropagation(inst); - auto newInfo = tryGetInfo(inst); - - // Check if information changed - if (!anyInfoChanged) + UIndex paramIndex = 0; + for (auto paramType : funcType->getParamTypes()) { - if (!oldInfo || (newInfo && !areInfosEqual(oldInfo, newInfo))) + if (as(paramType)) { - anyInfoChanged = true; + // If this is an OutTypeBase, we consider it an output + outputVals.add(inst->getArg(paramIndex)); } + paramIndex++; } } - // If any info changed, add successor edges back to queue - if (anyInfoChanged) + for (auto outputVal : outputVals) { - auto successors = block->getSuccessors(); - for (auto succIter = successors.begin(); succIter != successors.end(); ++succIter) + if (as(outputVal)) { - edgeQueue.addLast(succIter.getEdge()); + // TODO: We need to set up infrastructure to track variable + // assignments. + // For now, we will just return a value judgment for variables. + // (doesn't make much sense.. but its fine for now) + // + propagationMap[outputVal] = PropagationInfo::makeValue(); } - } - }; + else + { + if (auto interfaceType = as(outputVal->getDataType())) + { + if (!interfaceType->findDecoration()) + { + // If this is an interface type, we need to propagate existential info + // based on the interface type. + propagationMap[outputVal] = PropagationInfo::makeExistential( + collectExistentialTables(interfaceType)); + } + else + { + // If this is a COM interface, we treat it as unknown + propagationMap[outputVal] = PropagationInfo::makeUnbounded(); + } + } + else + { + propagationMap[outputVal] = PropagationInfo::makeValue(); + } + } + }*/ - // Start processing from the first block - processBlock(func->getFirstBlock()); + return PropagationInfo::none(); + } - // Process until fixed point - while (edgeQueue.getCount() > 0) + void propagateWithinFuncEdge(IREdge edge, LinkedList& workQueue) { - // Pop edge from front - auto edge = edgeQueue.getFirst(); - edgeQueue.getFirstNode()->removeAndDelete(); - - // Propagate along the edge - propagateEdge(edge); + // Handle intra-procedural edge (original logic) + auto predecessorBlock = edge.getPredecessor(); + auto successorBlock = edge.getSuccessor(); - // Process the successor block's instructions. - // This will also add any new edges to the queue if info changed - // - processBlock(edge.getSuccessor()); - } -} // namespace Slang - -void DynamicInstLoweringContext::propagateEdge(IREdge edge) -{ - auto predecessorBlock = edge.getPredecessor(); - auto successorBlock = edge.getSuccessor(); + // Get the terminator instruction and extract arguments + auto terminator = predecessorBlock->getTerminator(); + if (!terminator) + return; - // Get the terminator instruction and extract arguments - auto terminator = predecessorBlock->getTerminator(); - if (!terminator) - return; + // Right now, only unconditional branches can propagate new information + auto unconditionalBranch = as(terminator); + if (!unconditionalBranch) + return; - // Handle different types of branch instructions - if (auto unconditionalBranch = as(terminator)) - { // Find which successor this edge leads to (should be the target) if (unconditionalBranch->getTargetBlock() != successorBlock) return; // Collect propagation info for each argument and update corresponding parameter + HashSet affectedBlocks; Index paramIndex = 0; for (auto param : successorBlock->getParams()) { if (paramIndex < unconditionalBranch->getArgCount()) { auto arg = unconditionalBranch->getArg(paramIndex); - auto argInfo = tryGetInfo(arg); - - if (argInfo) + if (auto argInfo = tryGetInfo(arg)) { // Union with existing parameter info - auto existingInfo = tryGetInfo(param); - if (existingInfo) + bool infoChanged = false; + if (auto existingInfo = tryGetInfo(param)) { - List> infos; - infos.add(existingInfo); - infos.add(argInfo); - propagationMap[param] = unionPropagationInfo(infos); + propagationMap[param] = unionPropagationInfo(existingInfo, argInfo); + if (!infoChanged && !areInfosEqual(existingInfo, propagationMap[param])) + infoChanged = true; } else { propagationMap[param] = argInfo; + infoChanged = true; + } + // If any info changed, add all user blocks to the affected set + if (infoChanged) + { + for (auto use = param->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (auto block = as(user->getParent())) + affectedBlocks.add(block); + } } } } paramIndex++; } + + for (auto block : affectedBlocks) + { + workQueue.addLast(WorkItem(block)); + } } -} -void DynamicInstLoweringContext::initializeFirstBlockParameters(IRFunc* func) -{ - auto firstBlock = func->getFirstBlock(); - if (!firstBlock) - return; + void propagateInterproceduralEdge(InterproceduralEdge edge, LinkedList& workQueue) + { + // Handle interprocedural edge + auto callInst = edge.callInst; + auto targetFunc = edge.targetFunc; + + switch (edge.direction) + { + case InterproceduralEdge::Direction::CallToFunc: + { + // Propagate argument info from call site to function parameters + auto firstBlock = targetFunc->getFirstBlock(); + if (!firstBlock) + return; + + Index argIndex = 1; // Skip callee (operand 0) + HashSet affectedBlocks; + for (auto param : firstBlock->getParams()) + { + if (argIndex < callInst->getOperandCount()) + { + // TODO: handle inst-effect propagation properly + auto arg = callInst->getOperand(argIndex); + if (auto argInfo = tryGetInfo(arg)) + { + // Union with existing parameter info + auto existingInfo = tryGetInfo(param); + auto newInfo = unionPropagationInfo(tryGetInfo(param), argInfo); + propagationMap[param] = newInfo; + if (!areInfosEqual(existingInfo, newInfo)) + { + for (auto use = param->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (auto block = as(user->getParent())) + affectedBlocks.add(block); + } + } + } + } + argIndex++; + } + // Add the affected block to the work queue if any info changed + for (auto block : affectedBlocks) + workQueue.addLast(WorkItem(block)); + break; + } + case InterproceduralEdge::Direction::FuncToCall: + { + // Propagate return value info from function to call site + auto returnInfo = funcReturnInfo.tryGetValue(targetFunc); + + bool anyInfoChanged = false; + if (returnInfo) + { + // Union with existing call info + auto existingCallInfo = tryGetInfo(callInst); + if (existingCallInfo) + { + auto newInfo = unionPropagationInfo(existingCallInfo, *returnInfo); + propagationMap[callInst] = newInfo; + if (!anyInfoChanged && !areInfosEqual(existingCallInfo, newInfo)) + anyInfoChanged = true; + } + else + { + propagationMap[callInst] = *returnInfo; + anyInfoChanged = true; + } + } + + // Add the callInst's parent block to the work queue if any info changed + if (anyInfoChanged) + workQueue.addLast(WorkItem(as(callInst->getParent()))); + + break; + } + default: + SLANG_UNEXPECTED("Unhandled interprocedural edge direction"); + return; + } + } - // Initialize parameters based on their types - for (auto param : firstBlock->getParams()) + void initializeFirstBlockParameters(IRFunc* func) { - auto paramType = param->getDataType(); + auto firstBlock = func->getFirstBlock(); + if (!firstBlock) + return; - if (auto interfaceType = as(paramType)) + // Initialize parameters based on their types + for (auto param : firstBlock->getParams()) { - if (!interfaceType->findDecoration()) - propagationMap[param] = - PropagationInfo::makeExistential(collectExistentialTables(interfaceType)); + auto paramType = param->getDataType(); + auto paramInfo = tryGetInfo(param); + if (paramInfo) + continue; // Already has some information + + if (auto interfaceType = as(paramType)) + { + if (interfaceType->findDecoration()) + propagationMap[param] = PropagationInfo::makeUnbounded(); + else + propagationMap[param] = PropagationInfo::none(); // Initialize to none. + } + else if ( + as(paramType) || as(paramType) || + as(paramType) || as(paramType)) + { + propagationMap[param] = PropagationInfo::none(); + } else - propagationMap[param] = PropagationInfo::makeUnknown(); + { + propagationMap[param] = PropagationInfo::makeValue(); + } } - else if (as(paramType)) + } + + bool propagateReturnValues(IRBlock* block, LinkedList& workQueue) + { + auto terminator = block->getTerminator(); + auto returnInst = as(terminator); + if (!returnInst) + return false; + + auto func = as(block->getParent()); + if (!func) + return false; + + // Get return value info if there is a return value + PropagationInfo returnValueInfo; + if (returnInst->getOperandCount() > 0) { - propagationMap[param] = PropagationInfo::makeUnknown(); + auto returnValue = returnInst->getOperand(0); + returnValueInfo = tryGetInfo(returnValue); } - else if (as(paramType)) + else { - propagationMap[param] = PropagationInfo::makeUnknown(); + // Void return + returnValueInfo = PropagationInfo::makeValue(); } - else if (as(paramType) || as(paramType)) + + // Update function return info by unioning with existing info + auto existingReturnInfo = funcReturnInfo.tryGetValue(func); + bool returnInfoChanged = false; + + if (returnValueInfo) { - propagationMap[param] = PropagationInfo::makeUnknown(); + if (existingReturnInfo) + { + auto newReturnInfo = unionPropagationInfo( + List({*existingReturnInfo, returnValueInfo})); + + if (!areInfosEqual(*existingReturnInfo, newReturnInfo)) + { + funcReturnInfo[func] = newReturnInfo; + returnInfoChanged = true; + } + } + else + { + funcReturnInfo[func] = returnValueInfo; + returnInfoChanged = true; + } } - else + + // If return info changed, add return edges to call sites + if (returnInfoChanged) { - propagationMap[param] = PropagationInfo::makeValue(); - } - } -} + for (auto use = func->firstUse; use; use = use->nextUse) + { + if (auto callInst = as(use->getUser())) + { + if (callInst->getCallee() != func) + continue; // Not a call to this function -RefPtr DynamicInstLoweringContext::unionPropagationInfo( - const List>& infos) -{ - if (infos.getCount() == 0) - { - return PropagationInfo::makeValue(); - } + workQueue.addLast( + WorkItem(InterproceduralEdge::Direction::FuncToCall, callInst, func)); + } + } + } - if (infos.getCount() == 1) - { - return infos[0]; + return returnInfoChanged; } - // Check if all infos are the same - bool allSame = true; - for (Index i = 1; i < infos.getCount(); i++) + PropagationInfo unionPropagationInfo(const List& infos) { - if (!areInfosEqual(infos[0], infos[i])) + if (infos.getCount() == 0) { - allSame = false; - break; + return PropagationInfo::makeValue(); } - } - - if (allSame) - { - return infos[0]; - } - // Need to create a union - collect all possible values based on judgment types - HashSet allValues; - IRFuncType* dynFuncType = nullptr; - PropagationJudgment unionJudgment = PropagationJudgment::Value; + if (infos.getCount() == 1) + { + return infos[0]; + } - // Determine the union judgment type and collect values - for (auto info : infos) - { - switch (info->judgment) + // Check if all infos are the same + bool allSame = true; + for (Index i = 1; i < infos.getCount(); i++) { - case PropagationJudgment::ConcreteType: - unionJudgment = PropagationJudgment::SetOfTypes; - allValues.add(info->concreteValue); - break; - case PropagationJudgment::ConcreteTable: - unionJudgment = PropagationJudgment::SetOfTables; - allValues.add(info->concreteValue); - break; - case PropagationJudgment::ConcreteFunc: - unionJudgment = PropagationJudgment::SetOfFuncs; - allValues.add(info->concreteValue); - break; - case PropagationJudgment::SetOfTypes: - unionJudgment = PropagationJudgment::SetOfTypes; - for (auto value : info->possibleValues) - allValues.add(value); - break; - case PropagationJudgment::SetOfTables: - unionJudgment = PropagationJudgment::SetOfTables; - for (auto value : info->possibleValues) - allValues.add(value); - break; - case PropagationJudgment::SetOfFuncs: - unionJudgment = PropagationJudgment::SetOfFuncs; - for (auto value : info->possibleValues) - allValues.add(value); - if (!dynFuncType) - { - // If we haven't set a function type yet, use the first one - dynFuncType = info->dynFuncType; - } - else if (dynFuncType != info->dynFuncType) + if (!areInfosEqual(infos[0], infos[i])) { - SLANG_UNEXPECTED( - "Mismatched function types in union propagation info for SetOfFuncs"); + allSame = false; + break; } + } - break; - case PropagationJudgment::Value: - // Value judgments don't contribute to the union - break; - case PropagationJudgment::Existential: - // For existential union, we need to collect all witness tables - // For now, we'll handle this properly by creating a new existential with all tables + if (allSame) + { + return infos[0]; + } + + // Need to create a union - collect all possible values based on judgment types + HashSet allValues; + IRFuncType* dynFuncType = nullptr; + PropagationJudgment unionJudgment = PropagationJudgment::None; + + // Determine the union judgment type and collect values + for (auto info : infos) + { + switch (info.judgment) { - HashSet allTables; - for (auto otherInfo : infos) + case PropagationJudgment::ConcreteType: + unionJudgment = PropagationJudgment::SetOfTypes; + allValues.add(info.concreteValue); + break; + case PropagationJudgment::ConcreteTable: + unionJudgment = PropagationJudgment::SetOfTables; + allValues.add(info.concreteValue); + break; + case PropagationJudgment::ConcreteFunc: + unionJudgment = PropagationJudgment::SetOfFuncs; + allValues.add(info.concreteValue); + break; + case PropagationJudgment::SetOfTypes: + unionJudgment = PropagationJudgment::SetOfTypes; + for (auto value : info.possibleValues) + allValues.add(value); + break; + case PropagationJudgment::SetOfTables: + unionJudgment = PropagationJudgment::SetOfTables; + for (auto value : info.possibleValues) + allValues.add(value); + break; + case PropagationJudgment::SetOfFuncs: + unionJudgment = PropagationJudgment::SetOfFuncs; + for (auto value : info.possibleValues) + allValues.add(value); + if (!dynFuncType) + { + // If we haven't set a function type yet, use the first one + dynFuncType = info.dynFuncType; + } + else if (dynFuncType != info.dynFuncType) { - if (otherInfo->judgment == PropagationJudgment::Existential) + SLANG_UNEXPECTED( + "Mismatched function types in union propagation info for SetOfFuncs"); + } + + break; + case PropagationJudgment::Value: + if (unionJudgment == PropagationJudgment::None) + unionJudgment = PropagationJudgment::Value; + else + { + SLANG_ASSERT(unionJudgment == PropagationJudgment::Value); + } + break; + case PropagationJudgment::None: + // None judgments are basically 'empty' + break; + case PropagationJudgment::Existential: + // For existential union, we need to collect all witness tables + // For now, we'll handle this properly by creating a new existential with all tables + { + HashSet allTables; + for (auto otherInfo : infos) { - for (auto table : otherInfo->possibleValues) + if (otherInfo.judgment == PropagationJudgment::Existential) { - allTables.add(table); + for (auto table : otherInfo.possibleValues) + { + allTables.add(table); + } } } + if (allTables.getCount() > 0) + { + return PropagationInfo::makeExistential(allTables); + } } - if (allTables.getCount() > 0) - { - return PropagationInfo::makeExistential(allTables); - } + return PropagationInfo::none(); + case PropagationJudgment::Unbounded: + // If any info is unbounded, the union is unbounded + return PropagationInfo::makeUnbounded(); } - return PropagationInfo::makeValue(); - case PropagationJudgment::UnknownSet: - // If any info is unknown, the union is unknown - return PropagationInfo::makeUnknown(); } - } - // If we collected values, create a set; otherwise return value - if (allValues.getCount() > 0) - { - if (unionJudgment == PropagationJudgment::SetOfFuncs && dynFuncType) - return PropagationInfo::makeSetOfFuncs(allValues, dynFuncType); + // If we collected values, create a set; otherwise return value + if (allValues.getCount() > 0) + { + if (unionJudgment == PropagationJudgment::SetOfFuncs && dynFuncType) + return PropagationInfo::makeSetOfFuncs(allValues, dynFuncType); - return PropagationInfo::makeSet(unionJudgment, allValues); - } - else - { - return PropagationInfo::makeValue(); + return PropagationInfo::makeSet(unionJudgment, allValues); + } + else + { + return PropagationInfo::none(); + } } -} -RefPtr DynamicInstLoweringContext::analyzeDefault(IRInst* inst) -{ - // Check if this is a type, witness table, or function - if (as(inst)) + PropagationInfo unionPropagationInfo(PropagationInfo info1, PropagationInfo info2) { - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, inst); - } - else if (as(inst)) - { - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, inst); + // Union the two infos + List infos; + infos.add(info1); + infos.add(info2); + return unionPropagationInfo(infos); } - else if (as(inst)) - { - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteFunc, inst); - } - else + + PropagationInfo analyzeDefault(IRInst* inst) { - return PropagationInfo::makeValue(); + // Check if this is a type, witness table, or function + if (as(inst)) + { + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, inst); + } + else if (as(inst)) + { + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, inst); + } + else if (as(inst)) + { + return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteFunc, inst); + } + else + { + return PropagationInfo::makeValue(); + } } -} - -void DynamicInstLoweringContext::performDynamicInstLowering() -{ - // Collect all instructions that need lowering - List typeInstsToLower; - List valueInstsToLower; - List instWithReplacementTypes; - for (auto globalInst : module->getGlobalInsts()) + void performDynamicInstLowering() { - if (auto func = as(globalInst)) + // Collect all instructions that need lowering + List typeInstsToLower; + List valueInstsToLower; + List instWithReplacementTypes; + + for (auto globalInst : module->getGlobalInsts()) { - // Process each function's instructions - for (auto block : func->getBlocks()) + if (auto func = as(globalInst)) { - for (auto child : block->getChildren()) + // Process each function's instructions + for (auto block : func->getBlocks()) { - if (as(child)) - continue; // Skip parameters and terminators - - switch (child->getOp()) + for (auto child : block->getChildren()) { - case kIROp_LookupWitnessMethod: + if (as(child)) + continue; // Skip parameters and terminators + + switch (child->getOp()) { - if (child->getDataType()->getOp() == kIROp_TypeKind) - typeInstsToLower.add(child); - else - valueInstsToLower.add(child); + case kIROp_LookupWitnessMethod: + { + if (child->getDataType()->getOp() == kIROp_TypeKind) + typeInstsToLower.add(child); + else + valueInstsToLower.add(child); + break; + } + case kIROp_ExtractExistentialType: + typeInstsToLower.add(child); break; - } - case kIROp_ExtractExistentialType: - typeInstsToLower.add(child); - break; - case kIROp_ExtractExistentialWitnessTable: - case kIROp_ExtractExistentialValue: - case kIROp_Call: - case kIROp_MakeExistential: - case kIROp_CreateExistentialObject: - valueInstsToLower.add(child); - break; - default: - if (auto info = tryGetInfo(child)) - { - if (info->judgment == PropagationJudgment::Existential) + case kIROp_ExtractExistentialWitnessTable: + case kIROp_ExtractExistentialValue: + case kIROp_Call: + case kIROp_MakeExistential: + case kIROp_CreateExistentialObject: + valueInstsToLower.add(child); + break; + default: + if (auto info = tryGetInfo(child)) { - // If this instruction has a set of types, tables, or funcs, - // we need to lower it to a unified type. - instWithReplacementTypes.add(child); + if (info.judgment == PropagationJudgment::Existential) + { + // If this instruction has a set of types, tables, or funcs, + // we need to lower it to a unified type. + instWithReplacementTypes.add(child); + } } } } } } } - } - for (auto inst : typeInstsToLower) - lowerInst(inst); + for (auto inst : typeInstsToLower) + lowerInst(inst); - for (auto inst : valueInstsToLower) - lowerInst(inst); + for (auto inst : valueInstsToLower) + lowerInst(inst); - for (auto inst : instWithReplacementTypes) - replaceType(inst); -} + for (auto inst : instWithReplacementTypes) + replaceType(inst); + } -void DynamicInstLoweringContext::replaceType(IRInst* inst) -{ - auto info = tryGetInfo(inst); - if (!info || info->judgment != PropagationJudgment::Existential) - return; - - // Replace type with Tuple - IRBuilder builder(module); - builder.setInsertBefore(inst); - auto anyValueType = createAnyValueTypeFromInsts(info->possibleValues); - auto tupleType = builder.getTupleType(List({builder.getUIntType(), anyValueType})); - inst->setFullType(tupleType); -} + void replaceType(IRInst* inst) + { + auto info = tryGetInfo(inst); + if (!info || info.judgment != PropagationJudgment::Existential) + return; -void DynamicInstLoweringContext::lowerInst(IRInst* inst) -{ - switch (inst->getOp()) + // Replace type with Tuple + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto anyValueType = createAnyValueTypeFromInsts(info.possibleValues); + auto tupleType = builder.getTupleType(List({builder.getUIntType(), anyValueType})); + inst->setFullType(tupleType); + } + + void lowerInst(IRInst* inst) { - case kIROp_LookupWitnessMethod: - lowerLookupWitnessMethod(as(inst)); - break; - case kIROp_ExtractExistentialWitnessTable: - lowerExtractExistentialWitnessTable(as(inst)); - break; - case kIROp_ExtractExistentialType: - lowerExtractExistentialType(as(inst)); - break; - case kIROp_ExtractExistentialValue: - lowerExtractExistentialValue(as(inst)); - break; - case kIROp_Call: - lowerCall(as(inst)); - break; - case kIROp_MakeExistential: - lowerMakeExistential(as(inst)); - break; - case kIROp_CreateExistentialObject: - lowerCreateExistentialObject(as(inst)); - break; + switch (inst->getOp()) + { + case kIROp_LookupWitnessMethod: + lowerLookupWitnessMethod(as(inst)); + break; + case kIROp_ExtractExistentialWitnessTable: + lowerExtractExistentialWitnessTable(as(inst)); + break; + case kIROp_ExtractExistentialType: + lowerExtractExistentialType(as(inst)); + break; + case kIROp_ExtractExistentialValue: + lowerExtractExistentialValue(as(inst)); + break; + case kIROp_Call: + lowerCall(as(inst)); + break; + case kIROp_MakeExistential: + lowerMakeExistential(as(inst)); + break; + case kIROp_CreateExistentialObject: + lowerCreateExistentialObject(as(inst)); + break; + } } -} -void DynamicInstLoweringContext::lowerLookupWitnessMethod(IRLookupWitnessMethod* inst) -{ - auto info = tryGetInfo(inst); - if (!info) - return; + void lowerLookupWitnessMethod(IRLookupWitnessMethod* inst) + { + auto info = tryGetInfo(inst); + if (!info) + return; - IRBuilder builder(inst); - builder.setInsertBefore(inst); + IRBuilder builder(inst); + builder.setInsertBefore(inst); - // Check if this is a TypeKind data type with SetOfTypes judgment - if (info->judgment == PropagationJudgment::SetOfTypes && - inst->getDataType()->getOp() == kIROp_TypeKind) - { - // Create an any-value type based on the set of types - auto anyValueType = createAnyValueTypeFromInsts(info->possibleValues); + // Check if this is a TypeKind data type with SetOfTypes judgment + if (info.judgment == PropagationJudgment::SetOfTypes && + inst->getDataType()->getOp() == kIROp_TypeKind) + { + // Create an any-value type based on the set of types + auto anyValueType = createAnyValueTypeFromInsts(info.possibleValues); - // Store the mapping for later use - loweredInstToAnyValueType[inst] = anyValueType; + // Store the mapping for later use + loweredInstToAnyValueType[inst] = anyValueType; - // Replace the instruction with the any-value type - inst->replaceUsesWith(anyValueType); - inst->removeAndDeallocate(); - return; + // Replace the instruction with the any-value type + inst->replaceUsesWith(anyValueType); + inst->removeAndDeallocate(); + return; + } + + if (info.judgment == PropagationJudgment::SetOfTables || + info.judgment == PropagationJudgment::SetOfFuncs) + { + // Get the witness table operand info + auto witnessTableInst = inst->getWitnessTable(); + auto witnessTableInfo = tryGetInfo(witnessTableInst); + + if (witnessTableInfo && witnessTableInfo.judgment == PropagationJudgment::SetOfTables) + { + // Create a key mapping function + auto keyMappingFunc = createKeyMappingFunc( + inst->getRequirementKey(), + witnessTableInfo.possibleValues, + info.possibleValues); + + // Replace with call to key mapping function + auto witnessTableId = builder.emitCallInst( + builder.getUIntType(), + keyMappingFunc, + List({inst->getWitnessTable()})); + inst->replaceUsesWith(witnessTableId); + propagationMap[witnessTableId] = info; + inst->removeAndDeallocate(); + } + } } - if (info->judgment == PropagationJudgment::SetOfTables || - info->judgment == PropagationJudgment::SetOfFuncs) + void lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) { - // Get the witness table operand info - auto witnessTableInst = inst->getWitnessTable(); - auto witnessTableInfo = tryGetInfo(witnessTableInst); - - if (witnessTableInfo && witnessTableInfo->judgment == PropagationJudgment::SetOfTables) - { - // Create a key mapping function - auto keyMappingFunc = createKeyMappingFunc( - inst->getRequirementKey(), - witnessTableInfo->possibleValues, - info->possibleValues); - - // Replace with call to key mapping function - auto witnessTableId = builder.emitCallInst( - builder.getUIntType(), - keyMappingFunc, - List({inst->getWitnessTable()})); - inst->replaceUsesWith(witnessTableId); - propagationMap[witnessTableId] = info; + auto info = tryGetInfo(inst); + if (!info) + return; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + if (info.judgment == PropagationJudgment::SetOfTables) + { + // Replace with GetElement(loweredInst, 0) -> uint + auto operand = inst->getOperand(0); + auto element = builder.emitGetTupleElement(builder.getUIntType(), operand, 0); + inst->replaceUsesWith(element); + propagationMap[element] = info; inst->removeAndDeallocate(); } } -} -void DynamicInstLoweringContext::lowerExtractExistentialWitnessTable( - IRExtractExistentialWitnessTable* inst) -{ - auto info = tryGetInfo(inst); - if (!info) - return; + void lowerExtractExistentialValue(IRExtractExistentialValue* inst) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); - IRBuilder builder(inst); - builder.setInsertBefore(inst); + // Check if we have a lowered any-value type for the result + auto resultType = inst->getDataType(); + auto loweredType = loweredInstToAnyValueType.tryGetValue(inst); + if (loweredType) + { + resultType = *loweredType; + } - if (info->judgment == PropagationJudgment::SetOfTables) - { - // Replace with GetElement(loweredInst, 0) -> uint + // Replace with GetElement(loweredInst, 1) -> AnyValueType auto operand = inst->getOperand(0); - auto element = builder.emitGetTupleElement(builder.getUIntType(), operand, 0); + auto element = builder.emitGetTupleElement(resultType, operand, 1); inst->replaceUsesWith(element); - propagationMap[element] = info; inst->removeAndDeallocate(); } -} - -void DynamicInstLoweringContext::lowerExtractExistentialValue(IRExtractExistentialValue* inst) -{ - IRBuilder builder(inst); - builder.setInsertBefore(inst); - // Check if we have a lowered any-value type for the result - auto resultType = inst->getDataType(); - auto loweredType = loweredInstToAnyValueType.tryGetValue(inst); - if (loweredType) + void lowerExtractExistentialType(IRExtractExistentialType* inst) { - resultType = *loweredType; - } - - // Replace with GetElement(loweredInst, 1) -> AnyValueType - auto operand = inst->getOperand(0); - auto element = builder.emitGetTupleElement(resultType, operand, 1); - inst->replaceUsesWith(element); - inst->removeAndDeallocate(); -} + auto info = tryGetInfo(inst); + if (!info || info.judgment != PropagationJudgment::SetOfTypes) + return; -void DynamicInstLoweringContext::lowerExtractExistentialType(IRExtractExistentialType* inst) -{ - auto info = tryGetInfo(inst); - if (!info || info->judgment != PropagationJudgment::SetOfTypes) - return; + IRBuilder builder(inst); + builder.setInsertBefore(inst); - IRBuilder builder(inst); - builder.setInsertBefore(inst); + // Create an any-value type based on the set of types + auto anyValueType = createAnyValueTypeFromInsts(info.possibleValues); - // Create an any-value type based on the set of types - auto anyValueType = createAnyValueTypeFromInsts(info->possibleValues); + // Store the mapping for later use + loweredInstToAnyValueType[inst] = anyValueType; - // Store the mapping for later use - loweredInstToAnyValueType[inst] = anyValueType; + // Replace the instruction with the any-value type + inst->replaceUsesWith(anyValueType); + inst->removeAndDeallocate(); + } - // Replace the instruction with the any-value type - inst->replaceUsesWith(anyValueType); - inst->removeAndDeallocate(); -} + void lowerCall(IRCall* inst) + { + auto callee = inst->getCallee(); + auto calleeInfo = tryGetInfo(callee); -void DynamicInstLoweringContext::lowerCall(IRCall* inst) -{ - auto callee = inst->getCallee(); - auto calleeInfo = tryGetInfo(callee); + if (!calleeInfo || calleeInfo.judgment != PropagationJudgment::SetOfFuncs) + return; - if (!calleeInfo || calleeInfo->judgment != PropagationJudgment::SetOfFuncs) - return; + IRBuilder builder(inst); + builder.setInsertBefore(inst); - IRBuilder builder(inst); - builder.setInsertBefore(inst); + // Create dispatch function + auto dispatchFunc = createDispatchFunc(calleeInfo.possibleValues, calleeInfo.dynFuncType); - // Create dispatch function - auto dispatchFunc = createDispatchFunc(calleeInfo->possibleValues, calleeInfo->dynFuncType); + // Replace call with dispatch + List newArgs; + newArgs.add(callee); // Add the lookup as first argument (will get lowered into an uint tag) + for (UInt i = 1; i < inst->getOperandCount(); i++) + { + newArgs.add(inst->getOperand(i)); + } - // Replace call with dispatch - List newArgs; - newArgs.add(callee); // Add the lookup as first argument (will get lowered into an uint tag) - for (UInt i = 1; i < inst->getOperandCount(); i++) - { - newArgs.add(inst->getOperand(i)); + auto newCall = builder.emitCallInst(inst->getDataType(), dispatchFunc, newArgs); + inst->replaceUsesWith(newCall); + if (auto info = tryGetInfo(inst)) + propagationMap[newCall] = info; + inst->removeAndDeallocate(); } - auto newCall = builder.emitCallInst(inst->getDataType(), dispatchFunc, newArgs); - inst->replaceUsesWith(newCall); - if (auto info = tryGetInfo(inst)) - propagationMap[newCall] = info; - inst->removeAndDeallocate(); -} - -void DynamicInstLoweringContext::lowerMakeExistential(IRMakeExistential* inst) -{ - auto info = tryGetInfo(inst); - if (!info || info->judgment != PropagationJudgment::Existential) - return; + void lowerMakeExistential(IRMakeExistential* inst) + { + auto info = tryGetInfo(inst); + if (!info || info.judgment != PropagationJudgment::Existential) + return; - IRBuilder builder(inst); - builder.setInsertBefore(inst); + IRBuilder builder(inst); + builder.setInsertBefore(inst); - // Get unique ID for the witness table. TODO: Assert that this is a concrete table.. - auto witnessTable = cast(inst->getWitnessTable()); - auto tableId = builder.getIntValue(builder.getUIntType(), getUniqueID(witnessTable)); + // Get unique ID for the witness table. TODO: Assert that this is a concrete table.. + auto witnessTable = cast(inst->getWitnessTable()); + auto tableId = builder.getIntValue(builder.getUIntType(), getUniqueID(witnessTable)); - // Collect types from the witness tables to determine the any-value type - HashSet types; - for (auto table : info->possibleValues) - { - if (auto witnessTableInst = as(table)) + // Collect types from the witness tables to determine the any-value type + HashSet types; + for (auto table : info.possibleValues) { - if (auto concreteType = witnessTableInst->getConcreteType()) + if (auto witnessTableInst = as(table)) { - types.add(concreteType); + if (auto concreteType = witnessTableInst->getConcreteType()) + { + types.add(concreteType); + } } } - } - // Create the appropriate any-value type - auto anyValueType = createAnyValueType(types); + // Create the appropriate any-value type + auto anyValueType = createAnyValueType(types); - // Pack the value - auto packedValue = builder.emitPackAnyValue(anyValueType, inst->getWrappedValue()); + // Pack the value + auto packedValue = builder.emitPackAnyValue(anyValueType, inst->getWrappedValue()); - // Create tuple (table_unique_id, PackAnyValue(val)) - auto tupleType = - builder.getTupleType(List({builder.getUIntType(), packedValue->getDataType()})); - IRInst* tupleArgs[] = {tableId, packedValue}; - auto tuple = builder.emitMakeTuple(tupleType, 2, tupleArgs); + // Create tuple (table_unique_id, PackAnyValue(val)) + auto tupleType = builder.getTupleType( + List({builder.getUIntType(), packedValue->getDataType()})); + IRInst* tupleArgs[] = {tableId, packedValue}; + auto tuple = builder.emitMakeTuple(tupleType, 2, tupleArgs); - inst->replaceUsesWith(tuple); - inst->removeAndDeallocate(); -} - -void DynamicInstLoweringContext::lowerCreateExistentialObject(IRCreateExistentialObject* inst) -{ - // Error out for now as specified - sink->diagnose( - inst, - Diagnostics::unimplemented, - "IRCreateExistentialObject lowering not yet implemented"); -} - -UInt DynamicInstLoweringContext::getUniqueID(IRInst* funcOrTable) -{ - auto existingId = uniqueIds.tryGetValue(funcOrTable); - if (existingId) - return *existingId; + inst->replaceUsesWith(tuple); + inst->removeAndDeallocate(); + } - UInt newId = nextUniqueId++; - uniqueIds[funcOrTable] = newId; - return newId; -} + void lowerCreateExistentialObject(IRCreateExistentialObject* inst) + { + // Error out for now as specified + sink->diagnose( + inst, + Diagnostics::unimplemented, + "IRCreateExistentialObject lowering not yet implemented"); + } -IRFunc* DynamicInstLoweringContext::createKeyMappingFunc( - IRInst* key, - const HashSet& inputTables, - const HashSet& outputVals) -{ - // Create a function that maps input IDs to output IDs - IRBuilder builder(module); + UInt getUniqueID(IRInst* funcOrTable) + { + auto existingId = uniqueIds.tryGetValue(funcOrTable); + if (existingId) + return *existingId; - auto funcType = - builder.getFuncType(List({builder.getUIntType()}), builder.getUIntType()); - auto func = builder.createFunc(); - builder.setInsertInto(func); - func->setFullType(funcType); + UInt newId = nextUniqueId++; + uniqueIds[funcOrTable] = newId; + return newId; + } - auto entryBlock = builder.emitBlock(); - builder.setInsertInto(entryBlock); + IRFunc* createKeyMappingFunc( + IRInst* key, + const HashSet& inputTables, + const HashSet& outputVals) + { + // Create a function that maps input IDs to output IDs + IRBuilder builder(module); - auto param = builder.emitParam(builder.getUIntType()); + auto funcType = + builder.getFuncType(List({builder.getUIntType()}), builder.getUIntType()); + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(funcType); - // Create default block that returns 0 - auto defaultBlock = builder.emitBlock(); - builder.setInsertInto(defaultBlock); - builder.emitReturn(builder.getIntValue(builder.getUIntType(), 0)); + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); - // Go back to entry block and create switch - builder.setInsertInto(entryBlock); + auto param = builder.emitParam(builder.getUIntType()); - // Create case blocks for each input table - List caseValues; - List caseBlocks; + // Create default block that returns 0 + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), 0)); - // Build mapping from input tables to output values - List inputTableArray; - List outputValArray; + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); - for (auto table : inputTables) - inputTableArray.add(table); - for (auto table : outputVals) - outputValArray.add(table); + // Create case blocks for each input table + List caseValues; + List caseBlocks; - for (Index i = 0; i < inputTableArray.getCount(); i++) - { - auto inputTable = inputTableArray[i]; - auto inputId = getUniqueID(inputTable); + // Build mapping from input tables to output values + List inputTableArray; + List outputValArray; - // Find corresponding output table (for now, use simple 1:1 mapping) - IRInst* outputVal = nullptr; - if (i < outputValArray.getCount()) - { - outputVal = outputValArray[i]; - } - else if (outputValArray.getCount() > 0) - { - outputVal = outputValArray[0]; // Fallback to first output - } + for (auto table : inputTables) + inputTableArray.add(table); + for (auto table : outputVals) + outputValArray.add(table); - if (outputVal) + for (Index i = 0; i < inputTableArray.getCount(); i++) { - auto outputId = getUniqueID(outputVal); - - // Create case block - auto caseBlock = builder.emitBlock(); - builder.setInsertInto(caseBlock); - builder.emitReturn(builder.getIntValue(builder.getUIntType(), outputId)); + auto inputTable = inputTableArray[i]; + auto inputId = getUniqueID(inputTable); - caseValues.add(builder.getIntValue(builder.getUIntType(), inputId)); - caseBlocks.add(caseBlock); - } - } + // Find corresponding output table (for now, use simple 1:1 mapping) + IRInst* outputVal = nullptr; + if (i < outputValArray.getCount()) + { + outputVal = outputValArray[i]; + } + else if (outputValArray.getCount() > 0) + { + outputVal = outputValArray[0]; // Fallback to first output + } - // Create flattened case arguments array - List flattenedCaseArgs; - for (Index i = 0; i < caseValues.getCount(); i++) - { - flattenedCaseArgs.add(caseValues[i]); - flattenedCaseArgs.add(caseBlocks[i]); - } + if (outputVal) + { + auto outputId = getUniqueID(outputVal); - // Emit an unreachable block for the break block. - auto unreachableBlock = builder.emitBlock(); - builder.setInsertInto(unreachableBlock); - builder.emitUnreachable(); - - // Go back to entry and emit switch - builder.setInsertInto(entryBlock); - builder.emitSwitch( - param, - unreachableBlock, - defaultBlock, - flattenedCaseArgs.getCount(), - flattenedCaseArgs.getBuffer()); - - return func; -} + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), outputId)); -IRFunc* DynamicInstLoweringContext::createDispatchFunc( - const HashSet& funcs, - IRFuncType* expectedFuncType) -{ - // Create a dispatch function with switch-case for each function - IRBuilder builder(module); + caseValues.add(builder.getIntValue(builder.getUIntType(), inputId)); + caseBlocks.add(caseBlock); + } + } - // Extract parameter types from the first function in the set - List paramTypes; - paramTypes.add(builder.getUIntType()); // ID parameter + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } - // Get parameter types from first function - List funcArray; - for (auto func : funcs) - funcArray.add(func); + // Emit an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + param, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; + } - for (UInt i = 0; i < expectedFuncType->getParamCount(); i++) + IRFunc* createDispatchFunc(const HashSet& funcs, IRFuncType* expectedFuncType) { - paramTypes.add(expectedFuncType->getParamType(i)); - } + // Create a dispatch function with switch-case for each function + IRBuilder builder(module); - auto resultType = expectedFuncType->getResultType(); - auto funcType = builder.getFuncType(paramTypes, resultType); - auto func = builder.createFunc(); - builder.setInsertInto(func); - func->setFullType(funcType); + // Extract parameter types from the first function in the set + List paramTypes; + paramTypes.add(builder.getUIntType()); // ID parameter - auto entryBlock = builder.emitBlock(); - builder.setInsertInto(entryBlock); + // Get parameter types from first function + List funcArray; + for (auto func : funcs) + funcArray.add(func); - auto idParam = builder.emitParam(builder.getUIntType()); + for (UInt i = 0; i < expectedFuncType->getParamCount(); i++) + { + paramTypes.add(expectedFuncType->getParamType(i)); + } - // Create parameters for the original function arguments - List originalParams; - for (UInt i = 1; i < paramTypes.getCount(); i++) - { - originalParams.add(builder.emitParam(paramTypes[i])); - } + auto resultType = expectedFuncType->getResultType(); + auto funcType = builder.getFuncType(paramTypes, resultType); + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(funcType); - // Create default block - auto defaultBlock = builder.emitBlock(); - builder.setInsertInto(defaultBlock); - if (resultType->getOp() == kIROp_VoidType) - { - builder.emitReturn(); - } - else - { - // Return a default-constructed value - auto defaultValue = builder.emitDefaultConstruct(resultType); - builder.emitReturn(defaultValue); - } + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); - auto maybePackValue = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* - { - // If the type is AnyValueType, pack the value - if (as(type) && !as(value->getDataType())) + auto idParam = builder.emitParam(builder.getUIntType()); + + // Create parameters for the original function arguments + List originalParams; + for (UInt i = 1; i < paramTypes.getCount(); i++) { - return builder->emitPackAnyValue(type, value); + originalParams.add(builder.emitParam(paramTypes[i])); } - return value; // Otherwise, return as is - }; - auto maybeUnpackValue = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* - { - // If the type is AnyValueType, unpack the value - if (as(value->getDataType()) && !as(type)) + // Create default block + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else { - return builder->emitUnpackAnyValue(type, value); + // Return a default-constructed value + auto defaultValue = builder.emitDefaultConstruct(resultType); + builder.emitReturn(defaultValue); } - return value; // Otherwise, return as is - }; - // Go back to entry block and create switch - builder.setInsertInto(entryBlock); + auto maybePackValue = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* + { + // If the type is AnyValueType, pack the value + if (as(type) && !as(value->getDataType())) + { + return builder->emitPackAnyValue(type, value); + } + return value; // Otherwise, return as is + }; - // Create case blocks for each function - List caseValues; - List caseBlocks; + auto maybeUnpackValue = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* + { + // If the type is AnyValueType, unpack the value + if (as(value->getDataType()) && !as(type)) + { + return builder->emitUnpackAnyValue(type, value); + } + return value; // Otherwise, return as is + }; - for (auto funcInst : funcs) - { - auto funcId = getUniqueID(funcInst); + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); - // Create case block - auto caseBlock = builder.emitBlock(); - builder.setInsertInto(caseBlock); + // Create case blocks for each function + List caseValues; + List caseBlocks; - List callArgs; - auto concreteFuncType = as(funcInst->getDataType()); - for (UIndex ii = 0; ii < originalParams.getCount(); ii++) + for (auto funcInst : funcs) { - callArgs.add( - maybeUnpackValue(&builder, originalParams[ii], concreteFuncType->getParamType(ii))); - } + auto funcId = getUniqueID(funcInst); - // Call the specific function - auto callResult = builder.emitCallInst(resultType, funcInst, callArgs); + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); - if (resultType->getOp() == kIROp_VoidType) - { - builder.emitReturn(); + List callArgs; + auto concreteFuncType = as(funcInst->getDataType()); + for (UIndex ii = 0; ii < originalParams.getCount(); ii++) + { + callArgs.add(maybeUnpackValue( + &builder, + originalParams[ii], + concreteFuncType->getParamType(ii))); + } + + // Call the specific function + auto callResult = builder.emitCallInst(resultType, funcInst, callArgs); + + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + builder.emitReturn(maybePackValue(&builder, callResult, resultType)); + } + + caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); + caseBlocks.add(caseBlock); } - else + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) { - builder.emitReturn(maybePackValue(&builder, callResult, resultType)); + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); } - caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); - caseBlocks.add(caseBlock); + // Create an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + idParam, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; } - // Create flattened case arguments array - List flattenedCaseArgs; - for (Index i = 0; i < caseValues.getCount(); i++) + IRAnyValueType* createAnyValueType(const HashSet& types) { - flattenedCaseArgs.add(caseValues[i]); - flattenedCaseArgs.add(caseBlocks[i]); + IRBuilder builder(module); + auto size = calculateAnyValueSize(types); + return builder.getAnyValueType(size); } - // Create an unreachable block for the break block. - auto unreachableBlock = builder.emitBlock(); - builder.setInsertInto(unreachableBlock); - builder.emitUnreachable(); - - // Go back to entry and emit switch - builder.setInsertInto(entryBlock); - builder.emitSwitch( - idParam, - unreachableBlock, - defaultBlock, - flattenedCaseArgs.getCount(), - flattenedCaseArgs.getBuffer()); - - return func; -} - -IRAnyValueType* DynamicInstLoweringContext::createAnyValueType(const HashSet& types) -{ - IRBuilder builder(module); - auto size = calculateAnyValueSize(types); - return builder.getAnyValueType(size); -} - -IRAnyValueType* DynamicInstLoweringContext::createAnyValueTypeFromInsts( - const HashSet& typeInsts) -{ - HashSet types; - for (auto inst : typeInsts) + IRAnyValueType* createAnyValueTypeFromInsts(const HashSet& typeInsts) { - if (auto type = as(inst)) + HashSet types; + for (auto inst : typeInsts) { - types.add(type); + if (auto type = as(inst)) + { + types.add(type); + } } + return createAnyValueType(types); } - return createAnyValueType(types); -} -SlangInt DynamicInstLoweringContext::calculateAnyValueSize(const HashSet& types) -{ - SlangInt maxSize = 0; - for (auto type : types) + SlangInt calculateAnyValueSize(const HashSet& types) { - auto size = getAnyValueSize(type); - if (size > maxSize) - maxSize = size; + SlangInt maxSize = 0; + for (auto type : types) + { + auto size = getAnyValueSize(type); + if (size > maxSize) + maxSize = size; + } + return maxSize; } - return maxSize; -} - -bool DynamicInstLoweringContext::needsReinterpret( - RefPtr sourceInfo, - RefPtr targetInfo) -{ - if (!sourceInfo || !targetInfo) - return false; - // Check if both are SetOfTypes with different sets - if (sourceInfo->judgment == PropagationJudgment::SetOfTypes && - targetInfo->judgment == PropagationJudgment::SetOfTypes) + bool needsReinterpret(PropagationInfo sourceInfo, PropagationInfo targetInfo) { - if (sourceInfo->possibleValues.getCount() != targetInfo->possibleValues.getCount()) - return true; - } + if (!sourceInfo || !targetInfo) + return false; - // Check if both are Existential with different witness table sets - if (sourceInfo->judgment == PropagationJudgment::Existential && - targetInfo->judgment == PropagationJudgment::Existential) - { - if (sourceInfo->possibleValues.getCount() != targetInfo->possibleValues.getCount()) - return true; - } + // Check if both are SetOfTypes with different sets + if (sourceInfo.judgment == PropagationJudgment::SetOfTypes && + targetInfo.judgment == PropagationJudgment::SetOfTypes) + { + if (sourceInfo.possibleValues.getCount() != targetInfo.possibleValues.getCount()) + return true; + } - return false; -} + // Check if both are Existential with different witness table sets + if (sourceInfo.judgment == PropagationJudgment::Existential && + targetInfo.judgment == PropagationJudgment::Existential) + { + if (sourceInfo.possibleValues.getCount() != targetInfo.possibleValues.getCount()) + return true; + } -bool DynamicInstLoweringContext::isExistentialType(IRType* type) -{ - return as(type) != nullptr; -} + return false; + } -bool DynamicInstLoweringContext::isInterfaceType(IRType* type) -{ - return as(type) != nullptr; -} + bool isExistentialType(IRType* type) { return as(type) != nullptr; } -HashSet DynamicInstLoweringContext::collectExistentialTables( - IRInterfaceType* interfaceType) -{ - HashSet tables; + bool isInterfaceType(IRType* type) { return as(type) != nullptr; } - IRWitnessTableType* targetTableType = nullptr; - // First, find the IRWitnessTableType that wraps the given interfaceType - for (auto use = interfaceType->firstUse; use; use = use->nextUse) + HashSet collectExistentialTables(IRInterfaceType* interfaceType) { - if (auto wtType = as(use->getUser())) + HashSet tables; + + IRWitnessTableType* targetTableType = nullptr; + // First, find the IRWitnessTableType that wraps the given interfaceType + for (auto use = interfaceType->firstUse; use; use = use->nextUse) { - if (wtType->getConformanceType() == interfaceType) + if (auto wtType = as(use->getUser())) { - targetTableType = wtType; - break; + if (wtType->getConformanceType() == interfaceType) + { + targetTableType = wtType; + break; + } } } - } - // If the target witness table type was found, gather all witness tables using it - if (targetTableType) - { - for (auto use = targetTableType->firstUse; use; use = use->nextUse) + // If the target witness table type was found, gather all witness tables using it + if (targetTableType) { - if (auto witnessTable = as(use->getUser())) + for (auto use = targetTableType->firstUse; use; use = use->nextUse) { - if (witnessTable->getDataType() == targetTableType) + if (auto witnessTable = as(use->getUser())) { - tables.add(witnessTable); + if (witnessTable->getDataType() == targetTableType) + { + tables.add(witnessTable); + } } } } + + return tables; } - return tables; -} + void processModule() + { + // Phase 1: Information Propagation + performInformationPropagation(); -void DynamicInstLoweringContext::processModule() -{ - // Phase 1: Information Propagation - performInformationPropagation(); + // Phase 1.5: Insert reinterprets for points where sets merge + // e.g. phi, return, call + // + insertReinterprets(); - // Phase 1.5: Insert reinterprets for phi parameters where needed - insertReinterpretsForPhiParameters(); + // Phase 2: Dynamic Instruction Lowering + performDynamicInstLowering(); + } - // Phase 2: Dynamic Instruction Lowering - performDynamicInstLowering(); -} + DynamicInstLoweringContext(IRModule* module, DiagnosticSink* sink) + : module(module), sink(sink) + { + } + + // Basic context + IRModule* module; + DiagnosticSink* sink; + + // Mapping from instruction to propagation information + Dictionary propagationMap; + + // Mapping from function to return value propagation information + Dictionary funcReturnInfo; + + // Unique ID assignment for functions and witness tables + Dictionary uniqueIds; + UInt nextUniqueId = 1; + + // Mapping from lowered instruction to their any-value types + Dictionary loweredInstToAnyValueType; +}; // Main entry point void lowerDynamicInsts(IRModule* module, DiagnosticSink* sink) diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-dynamic-insts.h index 8fb551e88d6..33872899e40 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.h +++ b/source/slang/slang-ir-lower-dynamic-insts.h @@ -6,136 +6,6 @@ namespace Slang { - -// Enumeration for different kinds of judgments about IR instructions -enum class PropagationJudgment -{ - Value, // Regular value computation (not interface-related) - ConcreteType, // Concrete type reference - ConcreteTable, // Concrete witness table reference - ConcreteFunc, // Concrete function reference - SetOfTypes, // Set of possible types - SetOfTables, // Set of possible witness tables - SetOfFuncs, // Set of possible functions - UnknownSet, // Unknown set of possible types/tables/funcs (e.g. from COM interface types) - Existential // Existential box with a set of possible witness tables -}; - -// Data structure to hold propagation information for an instruction -struct PropagationInfo : RefObject -{ - PropagationJudgment judgment; - - // For concrete references - IRInst* concreteValue = nullptr; - - // For sets of types/tables/funcs and existential witness tables - HashSet possibleValues; - - // For SetOfFuncs - IRFuncType* dynFuncType; - - PropagationInfo() = default; - PropagationInfo(PropagationJudgment j) - : judgment(j) - { - } - - static RefPtr makeValue() - { - return new PropagationInfo(PropagationJudgment::Value); - } - static RefPtr makeConcrete(PropagationJudgment j, IRInst* value); - static RefPtr makeSet(PropagationJudgment j, const HashSet& values); - static RefPtr makeSetOfFuncs( - const HashSet& funcs, - IRFuncType* dynFuncType); - static RefPtr makeExistential(const HashSet& tables); - static RefPtr makeUnknown(); -}; - -// Context for the abstract interpretation pass -struct DynamicInstLoweringContext -{ - IRModule* module; - DiagnosticSink* sink; - - // Mapping from instruction to propagation information - Dictionary> propagationMap; - - // Unique ID assignment for functions and witness tables - Dictionary uniqueIds; - UInt nextUniqueId = 1; - - // Mapping from lowered instruction to their any-value types - Dictionary loweredInstToAnyValueType; - - DynamicInstLoweringContext(IRModule* inModule, DiagnosticSink* inSink) - : module(inModule), sink(inSink) - { - } - - // Phase 1: Information Propagation - void performInformationPropagation(); - void processFunction(IRFunc* func); - void propagateEdge(IREdge edge); - void processInstForPropagation(IRInst* inst); - - // Helper to get propagation info, handling global insts specially - RefPtr tryGetInfo(IRInst* inst); - - // Control flow analysis helpers - RefPtr unionPropagationInfo(const List>& infos); - void initializeFirstBlockParameters(IRFunc* func); - void insertReinterpretsForPhiParameters(); - - // Analysis of specific instruction types - RefPtr analyzeCreateExistentialObject(IRCreateExistentialObject* inst); - RefPtr analyzeMakeExistential(IRMakeExistential* inst); - RefPtr analyzeLookupWitnessMethod(IRLookupWitnessMethod* inst); - RefPtr analyzeExtractExistentialWitnessTable( - IRExtractExistentialWitnessTable* inst); - RefPtr analyzeExtractExistentialType(IRExtractExistentialType* inst); - RefPtr analyzeExtractExistentialValue(IRExtractExistentialValue* inst); - RefPtr analyzeCall(IRCall* inst); - RefPtr analyzeDefault(IRInst* inst); - - // Phase 2: Dynamic Instruction Lowering - void performDynamicInstLowering(); - void lowerInst(IRInst* inst); - void replaceType(IRInst* inst); - - // Lowering of specific instruction types - void lowerLookupWitnessMethod(IRLookupWitnessMethod* inst); - void lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst); - void lowerExtractExistentialType(IRExtractExistentialType* inst); - void lowerExtractExistentialValue(IRExtractExistentialValue* inst); - void lowerCall(IRCall* inst); - void lowerMakeExistential(IRMakeExistential* inst); - void lowerCreateExistentialObject(IRCreateExistentialObject* inst); - - // Helper functions - UInt getUniqueID(IRInst* funcOrTable); - IRFunc* createKeyMappingFunc( - IRInst* key, - const HashSet& inputTables, - const HashSet& outputTables); - IRFunc* createDispatchFunc(const HashSet& funcs, IRFuncType* expectedFuncType); - IRAnyValueType* createAnyValueType(const HashSet& types); - IRAnyValueType* createAnyValueTypeFromInsts(const HashSet& typeInsts); - SlangInt calculateAnyValueSize(const HashSet& types); - bool needsReinterpret(RefPtr sourceInfo, RefPtr targetInfo); - IRInst* maybeReinterpretArg(IRInst* arg, IRInst* param); - - // Utility functions - bool isExistentialType(IRType* type); - bool isInterfaceType(IRType* type); - HashSet collectExistentialTables(IRInterfaceType* interfaceType); - - // Main entry point - void processModule(); -}; - // Main entry point for the pass void lowerDynamicInsts(IRModule* module, DiagnosticSink* sink); } // namespace Slang diff --git a/source/slang/slang-vscode.natvis b/source/slang/slang-vscode.natvis new file mode 100644 index 00000000000..f159c036835 --- /dev/null +++ b/source/slang/slang-vscode.natvis @@ -0,0 +1,818 @@ + + + + + rawVal ? ($T1*)((char*)this + rawVal) : ($T1*)0 + BCPtr nullptr + BCPtr {*($T1*)((char*)this + rawVal)} + + rawVal ? ($T1*)((char*)this + rawVal) : ($T1*)0 + + + + Constant {intOperand} + {(Slang::Val*)nodeOperand} + {nodeOperand} + + *(Slang::Val*)nodeOperand + + + + DeclRef nullptr + + {*declRefBase} + + declRefBase + + + + {astNodeType,en}#{_debugUID}({(Decl*)m_operands.m_buffer[0].values.nodeOperand}) + {astNodeType,en}({(Decl*)m_operands.m_buffer[0].values.nodeOperand}) + DeclRefBase nullptr + + + {*(Decl*)m_operands.m_buffer[0].values.nodeOperand} + + *(Decl*)m_operands.m_buffer[0].values.nodeOperand + + + + {*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)} + + *(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand) + + + + {*(Val*)(this->m_operands.m_buffer[1].values.nodeOperand)} + + *(Val*)(this->m_operands.m_buffer[1].values.nodeOperand) + + + + {*(SubtypeWitness*)(this->m_operands.m_buffer[2].values.nodeOperand)} + + *(SubtypeWitness*)(this->m_operands.m_buffer[2].values.nodeOperand) + + + + {*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)} + + *(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand) + + + + + + *(Val*)(this->m_operands.m_buffer[index].values.nodeOperand) + index=index+1 + + + + + + {astNodeType,en}#{_debugUID} {*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand} + + {astNodeType,en} {*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand} + + + {astNodeType,en}#{_debugUID} {m_operands.m_buffer[0].values.nodeOperand->astNodeType, en}#{m_operands.m_buffer[0].values.nodeOperand->_debugUID} + {astNodeType,en} {m_operands.m_buffer[0].values.nodeOperand->astNodeType, en} + + *(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand + + + + FuncDecl {nameAndLoc} + + + {{name={(char*)(text.m_buffer.pointer+1), s}}} + + + {{name={(char*)((*name).text.m_buffer.pointer+1), s} loc={loc.raw}}} + + + + requirementKey + satisfyingVal + + + + {{{m_op} {(uint32_t)(void*)this, x}}} + {{{m_op} #{_debugUID}}} + + m_op + _debugUID + typeUse.usedValue + + + + + + + + ((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.chars,[((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.numChars]s8 + + + ((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.chars,[((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.numChars]s8 + + + ((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.chars,[((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.numChars]s8 + + child = child->next + + + ((IRStringLit*)this)->value.stringVal.chars,[((IRStringLit*)this)->value.stringVal.numChars]s8 + ((IRIntLit*)this)->value.intVal + + + + + + + pOperandInst = ((IRUse*)(&(typeUse) + 1 + index))->usedValue + pOperandInst + + child = pOperandInst->m_decorationsAndChildren.first + nameDecoration = 0 + + + nameDecoration = child + + + + nameDecoration = child + + + nameDecoration = child + + child = child->next + + *pOperandInst + *pOperandInst + + index = index + 1 + + + + + + + + + + + child = pItem->m_decorationsAndChildren.first + nameDecoration = 0 + + + nameDecoration = child + + + + nameDecoration = child + + + nameDecoration = child + + child = child->next + + *pItem + *pItem + pItem = pItem->next + index = index + 1 + + + + + parent + + + + firstUse + nextUse + user + + + + + + + {{IRUse {usedValue}}} + + usedValue + + + + + {astNodeType,en} + + (Slang::DeclRefExpr*)&astNodeType + (Slang::VarExpr*)&astNodeType + (Slang::MemberExpr*)&astNodeType + (Slang::StaticMemberExpr*)&astNodeType + (Slang::OverloadedExpr*)&astNodeType + (Slang::OverloadedExpr2*)&astNodeType + (Slang::LiteralExpr*)&astNodeType + (Slang::IntegerLiteralExpr*)&astNodeType + (Slang::FloatingPointLiteralExpr*)&astNodeType + (Slang::BoolLiteralExpr*)&astNodeType + (Slang::NullPtrLiteralExpr*)&astNodeType + (Slang::StringLiteralExpr*)&astNodeType + (Slang::InitializerListExpr*)&astNodeType + (Slang::ExprWithArgsBase*)&astNodeType + (Slang::AggTypeCtorExpr*)&astNodeType + (Slang::AppExprBase*)&astNodeType + (Slang::InvokeExpr*)&astNodeType + (Slang::NewExpr*)&astNodeType + (Slang::OperatorExpr*)&astNodeType + (Slang::InfixExpr*)&astNodeType + (Slang::PrefixExpr*)&astNodeType + (Slang::PostfixExpr*)&astNodeType + (Slang::SelectExpr*)&astNodeType + (Slang::TypeCastExpr*)&astNodeType + (Slang::ExplicitCastExpr*)&astNodeType + (Slang::ImplicitCastExpr*)&astNodeType + (Slang::GenericAppExpr*)&astNodeType + (Slang::TryExpr*)&astNodeType + (Slang::IndexExpr*)&astNodeType + (Slang::MatrixSwizzleExpr*)&astNodeType + (Slang::SwizzleExpr*)&astNodeType + (Slang::DerefExpr*)&astNodeType + (Slang::CastToSuperTypeExpr*)&astNodeType + (Slang::ModifierCastExpr*)&astNodeType + (Slang::SharedTypeExpr*)&astNodeType + (Slang::AssignExpr*)&astNodeType + (Slang::ParenExpr*)&astNodeType + (Slang::ThisExpr*)&astNodeType + (Slang::LetExpr*)&astNodeType + (Slang::ExtractExistentialValueExpr*)&astNodeType + (Slang::OpenRefExpr*)&astNodeType + (Slang::ForwardDifferentiateExpr*)&astNodeType + (Slang::BackwardDifferentiateExpr*)&astNodeType + (Slang::ThisTypeExpr*)&astNodeType + (Slang::AndTypeExpr*)&astNodeType + (Slang::ModifiedTypeExpr*)&astNodeType + (Slang::PointerTypeExpr*)&astNodeType + type + (Slang::Expr*)this,! + + + + {astNodeType,en} + + (Slang::ScopeStmt*)&astNodeType + (Slang::BlockStmt*)&astNodeType + (Slang::BreakableStmt*)&astNodeType + (Slang::SwitchStmt*)&astNodeType + (Slang::LoopStmt*)&astNodeType + (Slang::ForStmt*)&astNodeType + (Slang::UnscopedForStmt*)&astNodeType + (Slang::WhileStmt*)&astNodeType + (Slang::DoWhileStmt*)&astNodeType + (Slang::GpuForeachStmt*)&astNodeType + (Slang::CompileTimeForStmt*)&astNodeType + (Slang::SeqStmt*)&astNodeType + (Slang::UnparsedStmt*)&astNodeType + (Slang::EmptyStmt*)&astNodeType + (Slang::DiscardStmt*)&astNodeType + (Slang::DeclStmt*)&astNodeType + (Slang::IfStmt*)&astNodeType + (Slang::ChildStmt*)&astNodeType + (Slang::CaseStmtBase*)&astNodeType + (Slang::CaseStmt*)&astNodeType + (Slang::DefaultStmt*)&astNodeType + (Slang::JumpStmt*)&astNodeType + (Slang::BreakStmt*)&astNodeType + (Slang::ContinueStmt*)&astNodeType + (Slang::ReturnStmt*)&astNodeType + (Slang::ExpressionStmt*)&astNodeType + (Slang::Stmt*)this,! + + + + {text} + + + {astNodeType,en} {nameAndLoc.name->text} + {astNodeType,en} + + nameAndLoc.name->text + parentDecl + Slang::DeclCheckState(checkState.m_raw & ~Slang::DeclCheckStateExt::kBeingCheckedBit) + (Slang::ContainerDecl*)&astNodeType + (Slang::ExtensionDecl*)&astNodeType + (Slang::StructDecl*)&astNodeType + (Slang::ClassDecl*)&astNodeType + (Slang::EnumDecl*)&astNodeType + (Slang::InterfaceDecl*)&astNodeType + (Slang::AssocTypeDecl*)&astNodeType + (Slang::GlobalGenericParamDecl*)&astNodeType + (Slang::ScopeDecl*)&astNodeType + (Slang::ConstructorDecl*)&astNodeType + (Slang::GetterDecl*)&astNodeType + (Slang::SetterDecl*)&astNodeType + (Slang::RefAccessorDecl*)&astNodeType + (Slang::FuncDecl*)&astNodeType + (Slang::SubscriptDecl*)&astNodeType + (Slang::PropertyDecl*)&astNodeType + (Slang::NamespaceDecl*)&astNodeType + (Slang::ModuleDecl*)&astNodeType + (Slang::GenericDecl*)&astNodeType + (Slang::AttributeDecl*)&astNodeType + (Slang::VarDeclBase*)&astNodeType + (Slang::VarDecl*)&astNodeType + (Slang::LetDecl*)&astNodeType + (Slang::GlobalGenericValueParamDecl*)&astNodeType + (Slang::ParamDecl*)&astNodeType + (Slang::ModernParamDecl*)&astNodeType + (Slang::GenericValueParamDecl*)&astNodeType + (Slang::EnumCaseDecl*)&astNodeType + (Slang::TypeConstraintDecl*)&astNodeType + (Slang::InheritanceDecl*)&astNodeType + (Slang::GenericTypeConstraintDecl*)&astNodeType + (Slang::SimpleTypeDecl*)&astNodeType + (Slang::TypeDefDecl*)&astNodeType + (Slang::TypeAliasDecl*)&astNodeType + (Slang::GenericTypeParamDecl*)&astNodeType + (Slang::UsingDecl*)&astNodeType + (Slang::ImportDecl*)&astNodeType + (Slang::EmptyDecl*)&astNodeType + (Slang::SyntaxDecl*)&astNodeType + (Slang::DeclGroup*)&astNodeType + + (Slang::DeclBase*)this,! + + + + + {astNodeType,en} + + (Slang::ContainerDecl*)&astNodeType + (Slang::ExtensionDecl*)&astNodeType + (Slang::StructDecl*)&astNodeType + (Slang::ClassDecl*)&astNodeType + (Slang::EnumDecl*)&astNodeType + (Slang::InterfaceDecl*)&astNodeType + (Slang::AssocTypeDecl*)&astNodeType + (Slang::GlobalGenericParamDecl*)&astNodeType + (Slang::ScopeDecl*)&astNodeType + (Slang::ConstructorDecl*)&astNodeType + (Slang::GetterDecl*)&astNodeType + (Slang::SetterDecl*)&astNodeType + (Slang::RefAccessorDecl*)&astNodeType + (Slang::FuncDecl*)&astNodeType + (Slang::SubscriptDecl*)&astNodeType + (Slang::PropertyDecl*)&astNodeType + (Slang::NamespaceDecl*)&astNodeType + (Slang::ModuleDecl*)&astNodeType + (Slang::GenericDecl*)&astNodeType + (Slang::AttributeDecl*)&astNodeType + (Slang::VarDeclBase*)&astNodeType + (Slang::VarDecl*)&astNodeType + (Slang::LetDecl*)&astNodeType + (Slang::GlobalGenericValueParamDecl*)&astNodeType + (Slang::ParamDecl*)&astNodeType + (Slang::ModernParamDecl*)&astNodeType + (Slang::GenericValueParamDecl*)&astNodeType + (Slang::EnumCaseDecl*)&astNodeType + (Slang::TypeConstraintDecl*)&astNodeType + (Slang::InheritanceDecl*)&astNodeType + (Slang::GenericTypeConstraintDecl*)&astNodeType + (Slang::SimpleTypeDecl*)&astNodeType + (Slang::TypeDefDecl*)&astNodeType + (Slang::TypeAliasDecl*)&astNodeType + (Slang::GenericTypeParamDecl*)&astNodeType + (Slang::UsingDecl*)&astNodeType + (Slang::ImportDecl*)&astNodeType + (Slang::EmptyDecl*)&astNodeType + (Slang::SyntaxDecl*)&astNodeType + (Slang::DeclGroup*)&astNodeType + (Slang::Decl*)this,! + + + + {astNodeType,en} #{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)} + {astNodeType,en} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)} + + *(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand) + + + + {astNodeType,en} #{_debugUID} + {astNodeType,en} + + + {m_operands.m_count-2} + + + m_operands.m_count-2 + m_operands.m_buffer + + + {m_operands.m_buffer[m_operands.m_count-2]} + + m_operands.m_buffer[m_operands.m_count-2] + + + + {m_operands.m_buffer[m_operands.m_count-1]} + + m_operands.m_buffer[m_operands.m_count-1] + + + + + + + DeclRefType#{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)} + DeclRefType {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)} + DirectRef#{_debugUID} {*(Decl*)m_operands.m_buffer[0].values.nodeOperand} + DirectRef {*(Decl*)m_operands.m_buffer[0].values.nodeOperand} + {astNodeType,en} #{_debugUID} + {astNodeType,en} + + + + {astNodeType} + + + m_operands + + + + + SubstitutionSet{declRef,en} + + declRef + + + + + + substType = subst->astNodeType + shouldBreak = 1 + + + + + + subst = (DeclRefBase*)(((Slang::MemberDeclRef*)subst)->m_operands.m_buffer[1].values.nodeOperand) + shouldBreak = 0 + + + (LookupDeclRef*)subst + + + + (GenericAppDeclRef*)subst + subst = (DeclRefBase*)(((Slang::GenericAppDeclRef*)subst)->m_operands.m_buffer[1].values.nodeOperand) + shouldBreak = 0 + + + + + + + + + + {astNodeType}({nameAndLoc.name}) + + members + + + + Const({values.intOperand})#{_debugUID} + {*(Val*)values.nodeOperand} + {values.nodeOperand} + + *(Val*)values.nodeOperand + *(Decl*)values.nodeOperand + + + + + _impl + nullptr + {_impl} + + + empty + + + _head._impl != 0 ? &_head : 0 + _impl->next._impl != 0 ? &_impl->next : 0 + *this + + + + + empty + + + _head != 0 ? _head : 0 + next != 0 ? next : 0 + *this + + + + + {astNodeType,en}#{_debugUID} ({m_operands.m_buffer[1].values.intOperand} : {*(Type*)m_operands.m_buffer[0].values.nodeOperand}) + ConstantIntVal ({m_operands.m_buffer[1].values.intOperand} : {*(Type*)m_operands.m_buffer[0].values.nodeOperand}) + + + + {astNodeType,en}#{_debugUID} + {astNodeType,en} + + + m_operands.m_count + m_operands.m_buffer + + + + + + {astNodeType,en}#{_debugUID} + {astNodeType,en} + + + m_operands.m_count + m_operands.m_buffer + + + + + + {astNodeType,en}#{_debugUID} + {astNodeType,en} + + + m_operands.m_count + m_operands.m_buffer + + + + + + BasicExpressionType ({*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand}) + + + + {m_targetSets.map.m_values} + + + m_targetSets + m_targetSets.map.m_values + + + + + {{target={target}}} + + + target + shaderStageSets + shaderStageSets.map.m_values + + + + + {{size={atomSet}}} + + + stage + atomSet + + + + + + + {{max_size={m_buffer.m_count*Slang::UIntSet::kElementSize}}} + + + + + + + + + + + + iter = (Slang::UIntSet::Element)0 + bitIter = (Slang::UIntSet::Element)0 + totalBitIter = (Slang::UIntSet::Element)0 + value = 0 + + + bitIter = 0 + totalBitIter++ + iter++ + + + + + bitValue = (m_buffer[iter]>>bitIter)&1 + + value = totalBitIter + (CapabilityAtom)value + + bitIter++ + totalBitIter++ + + + + + + + + {{max_size={m_buffer.m_count*Slang::UIntSet::kElementSize}}} + + + + + + + + + + + + iter = (Slang::UIntSet::Element)0 + bitIter = (Slang::UIntSet::Element)0 + totalBitIter = (Slang::UIntSet::Element)0 + value = 0 + + + bitIter = 0 + totalBitIter++ + iter++ + + + + + bitValue = (m_buffer[iter]>>bitIter)&1 + + value = totalBitIter + (CapabilityAtom)value + + bitIter++ + totalBitIter++ + + + + + + + + + + {((char*) (m_buffer.pointer+1)),s} + ((char*) (m_buffer.pointer+1)),s + + + + {{ size={m_count} }} + + m_count + + m_count + m_buffer + + + + + + {{ size={m_count} }} + + m_count + m_capacity + + m_count + m_buffer + + + + + + {{ size={m_count} }} + + m_count + m_capacity + + m_count + m_shortBuffer + $i + m_buffer + $i - $T2 + + + + + + {{ size={m_count} }} + + m_count + + m_count + m_buffer + + + + + + {{ {map.m_values} }} + + map.m_values + map + + + + + {{ {dict} }} + + dict + + + + + {{ size={dict._count} }} + + + m_dict._count + m_dict.kvPairs.head + next + value + + + + + + {{ size={m_count} }} + + + m_count + m_kvPairs.head + next + value + + + + + + pointer + empty + RefPtr {*pointer} + + pointer + + + + + + + ($T1*)(m_base->m_data + m_offset) + + + + + + (m_offset == 0x80000000) ? nullptr : ($T1*)(((char*)this) + m_offset) + + + + + + + m_count + + m_count + ($T1*)(m_data.m_base->m_data + m_data.m_offset) + + + + + + + + m_count + + m_count + (m_data.m_offset == 0x80000000) ? nullptr : ($T1*)(((char*)&m_data) + m_data.m_offset) + + + + + + {(m_sizeThenContents + 1),s} + (m_sizeThenContents + 1),s + + + + {m_begin,[m_end-m_begin]s} + m_begin,[m_end-m_begin]s + + + + + diff --git a/tests/sample.slang b/tests/sample.slang new file mode 100644 index 00000000000..165a78bbbea --- /dev/null +++ b/tests/sample.slang @@ -0,0 +1,54 @@ + + +interface ILog +{ + static float log(float x); +} + +struct SlowLog : ILog +{ + static float log(float x) + { + // implement slow log. + } +} + +struct FastLog : ILog +{ + static float log(float x) + { + // implement slow log. + } +} + +// prototype system +// declare ILog = SlowLog | FastLog; + +// actual slang +extern struct MyLog : ILog; // slangApi->setLinkTimeConst("MyLog", "FastLog"); + + +extern static const int TILE_SIZE; // slangApi->setLinkTimeConst("TILE_SIZE", 16); + +void main() +{ + // Use 1 + MyLog.log(10.f); + + + // Use 2 + MyLog.log(20.f); +} + +matrix matmul(matrix, matrix) +{ + .... +} + +void main() +{ + + matmul(cast(matA), cast(matB)); +} + +// slangApi->setLinkTimeConst("main.MyLog1", "FastLog"); \ No newline at end of file From e7ecb75925051d9b17a5aa3ec92bdce323b381f7 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 23 Jul 2025 12:33:51 -0400 Subject: [PATCH 003/105] More tests + fixes (generic interfaces & dependent associated types should work now) --- out.hlsl | 86 ------ source/slang/slang-ir-lower-dynamic-insts.cpp | 289 +++++++++++------- .../dynamic-dispatch/assoc-types.slang | 57 ++++ .../dependent-assoc-types-2.slang | 143 +++++++++ .../dependent-assoc-types.slang | 104 +++++++ .../dynamic-func-dynamic-output-1.slang | 64 ++++ .../dynamic-func-dynamic-output-2.slang | 66 ++++ .../dynamic-dispatch/func-call-input-1.slang | 48 +++ .../dynamic-dispatch/func-call-input-2.slang | 73 +++++ .../dynamic-dispatch/func-call-return.slang | 42 +++ .../generic-interface-2.slang | 64 ++++ .../generic-interface-3.slang | 64 ++++ .../dynamic-dispatch/generic-interface.slang | 57 ++++ .../dynamic-dispatch/with-data.slang | 44 +++ 14 files changed, 998 insertions(+), 203 deletions(-) delete mode 100644 out.hlsl create mode 100644 tests/language-feature/dynamic-dispatch/assoc-types.slang create mode 100644 tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang create mode 100644 tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang create mode 100644 tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang create mode 100644 tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang create mode 100644 tests/language-feature/dynamic-dispatch/func-call-input-1.slang create mode 100644 tests/language-feature/dynamic-dispatch/func-call-input-2.slang create mode 100644 tests/language-feature/dynamic-dispatch/func-call-return.slang create mode 100644 tests/language-feature/dynamic-dispatch/generic-interface-2.slang create mode 100644 tests/language-feature/dynamic-dispatch/generic-interface-3.slang create mode 100644 tests/language-feature/dynamic-dispatch/generic-interface.slang create mode 100644 tests/language-feature/dynamic-dispatch/with-data.slang diff --git a/out.hlsl b/out.hlsl deleted file mode 100644 index 50b7bd1e159..00000000000 --- a/out.hlsl +++ /dev/null @@ -1,86 +0,0 @@ -#pragma pack_matrix(column_major) -#ifdef SLANG_HLSL_ENABLE_NVAPI -#include "nvHLSLExtns.h" -#endif - -#ifndef __DXC_VERSION_MAJOR -// warning X3557: loop doesn't seem to do anything, forcing loop to unroll -#pragma warning(disable : 3557) -#endif - -RWStructuredBuffer outputBuffer_0 : register(u0); - -struct Tuple_0 -{ - uint value0_0; -}; - -uint _S1(uint _S2) -{ - switch(_S2) - { - case 1U: - { - return 3U; - } - case 2U: - { - return 4U; - } - default: - { - return 0U; - } - } -} - -float A_calc_0(float x_0) -{ - return x_0 * x_0 * x_0; -} - -float B_calc_0(float x_1) -{ - return x_1 * x_1; -} - -float _S3(uint _S4, float _S5) -{ - switch(_S4) - { - case 3U: - { - return A_calc_0(_S5); - } - case 4U: - { - return B_calc_0(_S5); - } - default: - { - return 0.0f; - } - } -} - -float f_0(uint id_0, float x_2) -{ - Tuple_0 obj_0; - if(id_0 == 0U) - { - obj_0.value0_0 = 1U; - } - else - { - obj_0.value0_0 = 2U; - } - return _S3(_S1(obj_0.value0_0), x_2); -} - -[numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID_0 : SV_DispatchThreadID) -{ - outputBuffer_0[int(0)] = f_0(0U, 1.0f); - return; -} - diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index acf1af84e24..0f06596409f 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -473,6 +473,7 @@ struct DynamicInstLoweringContext if (newArg != arg) { // Replace the argument in the branch instruction + SLANG_ASSERT(!as(unconditionalBranch)); unconditionalBranch->setOperand(1 + paramIndex, newArg); } } @@ -581,10 +582,12 @@ struct DynamicInstLoweringContext // Get the witness table info auto witnessTableInfo = tryGetInfo(witnessTable); - if (!witnessTableInfo || witnessTableInfo.judgment == PropagationJudgment::Unbounded) - { + + if (!witnessTableInfo) + return PropagationInfo::none(); + + if (witnessTableInfo.judgment == PropagationJudgment::Unbounded) return PropagationInfo::makeUnbounded(); - } HashSet tables; @@ -624,10 +627,11 @@ struct DynamicInstLoweringContext auto key = inst->getRequirementKey(); auto witnessTableInfo = tryGetInfo(witnessTable); - if (!witnessTableInfo || witnessTableInfo.judgment == PropagationJudgment::Unbounded) - { + if (!witnessTableInfo) + return PropagationInfo::none(); + + if (witnessTableInfo.judgment == PropagationJudgment::Unbounded) return PropagationInfo::makeUnbounded(); - } HashSet results; @@ -694,10 +698,11 @@ struct DynamicInstLoweringContext auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(operand); - if (!operandInfo || operandInfo.judgment == PropagationJudgment::Unbounded) - { + if (!operandInfo) + return PropagationInfo::none(); + + if (operandInfo.judgment == PropagationJudgment::Unbounded) return PropagationInfo::makeUnbounded(); - } if (operandInfo.judgment == PropagationJudgment::Existential) { @@ -727,10 +732,11 @@ struct DynamicInstLoweringContext auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(operand); - if (!operandInfo || operandInfo.judgment == PropagationJudgment::Unbounded) - { + if (!operandInfo) + return PropagationInfo::none(); + + if (operandInfo.judgment == PropagationJudgment::Unbounded) return PropagationInfo::makeUnbounded(); - } if (operandInfo.judgment == PropagationJudgment::Existential) { @@ -786,15 +792,6 @@ struct DynamicInstLoweringContext auto funcType = as(callee->getDataType()); - // TODO: Expand logic to handle all outputs. - // For now, we're focused on the return value. - // - if (!funcType || !isExistentialType(funcType->getResultType())) - { - return PropagationInfo::makeValue(); - } - - // Okay, we have an call that can return an existential type. // // Propagate the input judgments to the call & append a work item // for inter-procedural propagation. @@ -803,78 +800,44 @@ struct DynamicInstLoweringContext // For now, we'll handle just a concrete func. But the logic for multiple functions // is exactly the same (add an edge for each function). // - if (calleeInfo && calleeInfo.judgment == PropagationJudgment::ConcreteFunc) - { - workQueue.addLast( - WorkItem(InterproceduralEdge::Direction::CallToFunc, inst, cast(callee))); - } - - /*if (auto concreteFunc = as(callee)) - { - auto returnInfo = tryGetFuncReturnInfo(concreteFunc); - if (returnInfo) - { - // Use the function's return info directly - propagationMap[inst] = returnInfo; - return returnInfo; - } - } - - List outputVals; - outputVals.add(inst); // The call inst itself is an output. - - // Also add all OutTypeBase parameters - auto funcType = as(callee->getDataType()); - if (funcType) - { - UIndex paramIndex = 0; - for (auto paramType : funcType->getParamTypes()) + auto propagateToCallSite = [&](IRFunc* func) + { + // Register the call site in the map to allow for the + // return-edge to be created. + // + // We use an explicit map instead of walking the uses of the + // func, since we might have functions that are called indirectly + // through lookups. + // + this->funcCallSites.addIfNotExists(func, HashSet()); + if (this->funcCallSites[func].add(inst)) { - if (as(paramType)) - { - // If this is an OutTypeBase, we consider it an output - outputVals.add(inst->getArg(paramIndex)); - } - paramIndex++; + // If this is a new call site, add a propagation task to the queue (in case there's + // already information about this function) + workQueue.addLast(WorkItem(InterproceduralEdge::Direction::FuncToCall, inst, func)); } - } + workQueue.addLast(WorkItem(InterproceduralEdge::Direction::CallToFunc, inst, func)); + }; - for (auto outputVal : outputVals) + if (calleeInfo) { - if (as(outputVal)) + if (calleeInfo.judgment == PropagationJudgment::ConcreteFunc) { - // TODO: We need to set up infrastructure to track variable - // assignments. - // For now, we will just return a value judgment for variables. - // (doesn't make much sense.. but its fine for now) - // - propagationMap[outputVal] = PropagationInfo::makeValue(); + // If we have a concrete function, register the call site + propagateToCallSite(as(calleeInfo.concreteValue)); } - else + else if (calleeInfo.judgment == PropagationJudgment::SetOfFuncs) { - if (auto interfaceType = as(outputVal->getDataType())) - { - if (!interfaceType->findDecoration()) - { - // If this is an interface type, we need to propagate existential info - // based on the interface type. - propagationMap[outputVal] = PropagationInfo::makeExistential( - collectExistentialTables(interfaceType)); - } - else - { - // If this is a COM interface, we treat it as unknown - propagationMap[outputVal] = PropagationInfo::makeUnbounded(); - } - } - else - { - propagationMap[outputVal] = PropagationInfo::makeValue(); - } + // If we have a set of functions, register each one + for (auto func : calleeInfo.possibleValues) + propagateToCallSite(as(func)); } - }*/ + } - return PropagationInfo::none(); + if (auto callInfo = tryGetInfo(inst)) + return callInfo; + else + return PropagationInfo::none(); } void propagateWithinFuncEdge(IREdge edge, LinkedList& workQueue) @@ -898,6 +861,7 @@ struct DynamicInstLoweringContext return; // Collect propagation info for each argument and update corresponding parameter + // TODO: Unify this logic with the affectedBlocks logic in the per-inst processing logic. HashSet affectedBlocks; Index paramIndex = 0; for (auto param : successorBlock->getParams()) @@ -962,7 +926,6 @@ struct DynamicInstLoweringContext { if (argIndex < callInst->getOperandCount()) { - // TODO: handle inst-effect propagation properly auto arg = callInst->getOperand(argIndex); if (auto argInfo = tryGetInfo(arg)) { @@ -998,18 +961,10 @@ struct DynamicInstLoweringContext { // Union with existing call info auto existingCallInfo = tryGetInfo(callInst); - if (existingCallInfo) - { - auto newInfo = unionPropagationInfo(existingCallInfo, *returnInfo); - propagationMap[callInst] = newInfo; - if (!anyInfoChanged && !areInfosEqual(existingCallInfo, newInfo)) - anyInfoChanged = true; - } - else - { - propagationMap[callInst] = *returnInfo; + auto newInfo = unionPropagationInfo(existingCallInfo, *returnInfo); + propagationMap[callInst] = newInfo; + if (!areInfosEqual(existingCallInfo, newInfo)) anyInfoChanged = true; - } } // Add the callInst's parent block to the work queue if any info changed @@ -1107,18 +1062,12 @@ struct DynamicInstLoweringContext } // If return info changed, add return edges to call sites - if (returnInfoChanged) + if (returnInfoChanged && this->funcCallSites.containsKey(func)) { - for (auto use = func->firstUse; use; use = use->nextUse) + for (auto callSite : this->funcCallSites[func]) { - if (auto callInst = as(use->getUser())) - { - if (callInst->getCallee() != func) - continue; // Not a call to this function - - workQueue.addLast( - WorkItem(InterproceduralEdge::Direction::FuncToCall, callInst, func)); - } + workQueue.addLast( + WorkItem(InterproceduralEdge::Direction::FuncToCall, callSite, func)); } } @@ -1289,6 +1238,7 @@ struct DynamicInstLoweringContext List typeInstsToLower; List valueInstsToLower; List instWithReplacementTypes; + List funcTypesToProcess; for (auto globalInst : module->getGlobalInsts()) { @@ -1317,24 +1267,32 @@ struct DynamicInstLoweringContext break; case kIROp_ExtractExistentialWitnessTable: case kIROp_ExtractExistentialValue: - case kIROp_Call: case kIROp_MakeExistential: case kIROp_CreateExistentialObject: valueInstsToLower.add(child); break; + case kIROp_Call: + { + if (auto info = tryGetInfo(child)) + if (info.judgment == PropagationJudgment::Existential) + instWithReplacementTypes.add(child); + + if (auto calleeInfo = tryGetInfo(as(child)->getCallee())) + if (calleeInfo.judgment == PropagationJudgment::SetOfFuncs) + valueInstsToLower.add(child); + } + break; default: if (auto info = tryGetInfo(child)) - { if (info.judgment == PropagationJudgment::Existential) - { // If this instruction has a set of types, tables, or funcs, // we need to lower it to a unified type. instWithReplacementTypes.add(child); - } - } } } } + + funcTypesToProcess.add(func); } } @@ -1346,6 +1304,54 @@ struct DynamicInstLoweringContext for (auto inst : instWithReplacementTypes) replaceType(inst); + + for (auto func : funcTypesToProcess) + replaceFuncType(func, this->funcReturnInfo[func]); + } + + void replaceFuncType(IRFunc* func, PropagationInfo& returnTypeInfo) + { + IRFuncType* origFuncType = as(func->getFullType()); + IRType* returnType = origFuncType->getResultType(); + if (returnTypeInfo.judgment == PropagationJudgment::Existential) + { + // If the return type is existential, we need to replace it with a tuple type + returnType = getTypeForExistential(returnTypeInfo); + } + + List paramTypes; + for (auto param : func->getFirstBlock()->getParams()) + { + // Extract the existential type from the parameter if it exists + auto paramInfo = tryGetInfo(param); + if (paramInfo && paramInfo.judgment == PropagationJudgment::Existential) + { + paramTypes.add(getTypeForExistential(paramInfo)); + } + else + paramTypes.add(param->getDataType()); + } + + IRBuilder builder(module); + builder.setInsertBefore(func); + func->setFullType(builder.getFuncType(paramTypes, returnType)); + } + + IRType* getTypeForExistential(PropagationInfo info) + { + // Replace type with Tuple + IRBuilder builder(module); + builder.setInsertInto(module); + + HashSet types; + // Extract types from witness tables by looking at the concrete types + for (auto table : info.possibleValues) + if (auto witnessTable = as(table)) + if (auto concreteType = witnessTable->getConcreteType()) + types.add(concreteType); + + auto anyValueType = createAnyValueTypeFromInsts(types); + return builder.getTupleType(List({builder.getUIntType(), anyValueType})); } void replaceType(IRInst* inst) @@ -1354,12 +1360,22 @@ struct DynamicInstLoweringContext if (!info || info.judgment != PropagationJudgment::Existential) return; - // Replace type with Tuple + /* Replace type with Tuple IRBuilder builder(module); builder.setInsertBefore(inst); - auto anyValueType = createAnyValueTypeFromInsts(info.possibleValues); - auto tupleType = builder.getTupleType(List({builder.getUIntType(), anyValueType})); - inst->setFullType(tupleType); + + HashSet types; + // Extract types from witness tables by looking at the concrete types + for (auto table : info.possibleValues) + if (auto witnessTable = as(table)) + if (auto concreteType = witnessTable->getConcreteType()) + types.add(concreteType); + + auto anyValueType = createAnyValueTypeFromInsts(types); + auto tupleType = builder.getTupleType(List({builder.getUIntType(), + anyValueType}));*/ + + inst->setFullType(getTypeForExistential(info)); } void lowerInst(IRInst* inst) @@ -1502,6 +1518,40 @@ struct DynamicInstLoweringContext inst->removeAndDeallocate(); } + IRFuncType* getExpectedFuncType(IRCall* inst) + { + // Translate argument types into expected function type. + // For now, we handle just 'in' arguments. + List argTypes; + for (UInt i = 0; i < inst->getArgCount(); i++) + { + auto arg = inst->getArg(i); + if (auto argInfo = tryGetInfo(arg)) + { + // If the argument is existential, we need to use the type for existential + if (argInfo.judgment == PropagationJudgment::Existential) + { + argTypes.add(getTypeForExistential(argInfo)); + continue; + } + } + + argTypes.add(arg->getDataType()); + } + + // Translate result type. + IRType* resultType = inst->getDataType(); + auto returnInfo = tryGetInfo(inst); + if (returnInfo && returnInfo.judgment == PropagationJudgment::Existential) + { + resultType = getTypeForExistential(returnInfo); + } + + IRBuilder builder(module); + builder.setInsertInto(module); + return builder.getFuncType(argTypes, resultType); + } + void lowerCall(IRCall* inst) { auto callee = inst->getCallee(); @@ -1513,8 +1563,9 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); + auto expectedFuncType = getExpectedFuncType(inst); // Create dispatch function - auto dispatchFunc = createDispatchFunc(calleeInfo.possibleValues, calleeInfo.dynFuncType); + auto dispatchFunc = createDispatchFunc(calleeInfo.possibleValues, expectedFuncType); // Replace call with dispatch List newArgs; @@ -1785,7 +1836,8 @@ struct DynamicInstLoweringContext } // Call the specific function - auto callResult = builder.emitCallInst(resultType, funcInst, callArgs); + auto callResult = + builder.emitCallInst(concreteFuncType->getResultType(), funcInst, callArgs); if (resultType->getOp() == kIROp_VoidType) { @@ -1950,6 +2002,9 @@ struct DynamicInstLoweringContext // Mapping from function to return value propagation information Dictionary funcReturnInfo; + // Mapping from functions to call-sites. + Dictionary> funcCallSites; + // Unique ID assignment for functions and witness tables Dictionary uniqueIds; UInt nextUniqueId = 1; diff --git a/tests/language-feature/dynamic-dispatch/assoc-types.slang b/tests/language-feature/dynamic-dispatch/assoc-types.slang new file mode 100644 index 00000000000..9900014900d --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/assoc-types.slang @@ -0,0 +1,57 @@ + +RWStructuredBuffer outputBuffer; + +interface ICalculation +{ + associatedtype Data; + float calc(Data d, float x); + Data make(float q); +} + +struct A : ICalculation +{ + typealias Data = float; + float calc(Data d, float x) { return d * x * x; } + Data make(float q) { return q; } +}; + +struct BData +{ + float x; + float y; +}; + +struct B : ICalculation +{ + typealias Data = BData; + float calc(Data d, float x) { return d.x * x * x + d.y; } + Data make(float q) { return {q, q}; } +}; + +struct C : ICalculation +{ + typealias Data = float; + float calc(Data d, float x) { return d * x; } + Data make(float q) { return q; } +}; + +ICalculation factoryAB(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + return obj.calc(d, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang new file mode 100644 index 00000000000..e14e3a1c063 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang @@ -0,0 +1,143 @@ + +RWStructuredBuffer outputBuffer; + +interface IBufferRef +{ + uint8_t[N] get(uint index); + void set(uint index, uint8_t[N] value); + This withOffset(uint offset); +}; + +interface ISerializer +{ + static void serialize(T data, IBufferRef buffer); + static T deserialize(IBufferRef buffer); +} + +interface ICalculation +{ + associatedtype Data; + associatedtype DataSerializer : ISerializer; + float calc(Data d, float x); + Data make(float q); +} + +struct StandardSerializer : ISerializer +{ + static void serialize(T data, IBufferRef buffer) + { + buffer.set<4>(0, bit_cast(data)); + } + + static T deserialize(IBufferRef buffer) + { + return bit_cast(buffer.get<4>(0)); + } +}; + +struct A : ICalculation +{ + typealias Data = float; + typealias DataSerializer = StandardSerializer; + float calc(Data d, float x) { return d * x * x; } + Data make(float q) { return q; } +}; + +struct BData +{ + float x; + float y; +}; + +struct BDataSerializer : ISerializer +{ + static void serialize(BData data, IBufferRef buffer) + { + StandardSerializer::serialize(data.x, buffer.withOffset(0)); + StandardSerializer::serialize(data.y, buffer.withOffset(4)); + } + + static BData deserialize(IBufferRef buffer) + { + return BData(StandardSerializer::deserialize(buffer.withOffset(0)), + StandardSerializer::deserialize(buffer.withOffset(4))); + } +}; + +struct B : ICalculation +{ + typealias Data = BData; + typealias DataSerializer = BDataSerializer; + float calc(Data d, float x) { return d.x * x * x + d.y; } + Data make(float q) { return {q, q}; } +}; + +struct C : ICalculation +{ + typealias Data = float; + typealias DataSerializer = StandardSerializer; + float calc(Data d, float x) { return d * x; } + Data make(float q) { return q; } +}; + +ICalculation factoryAB(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +struct BufferRef : IBufferRef +{ + uint offset; + uint8_t[N] get(uint index) + { + uint offset = index * N; + uint8_t[N] result; + for (int i = 0; i < N; ++i) + { + result[i] = outputBuffer[offset + i]; + } + return result; + } + + void set(uint index, uint8_t[N] value) + { + uint offset = index * N; + for (int i = 0; i < N; ++i) + { + outputBuffer[offset + i] = value[i]; + } + } + + This withOffset(uint offset) + { + return {offset + this.offset}; + } +}; +IBufferRef getOutputBufferAsRef() +{ + return BufferRef(0); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + obj.DataSerializer::serialize(d, getOutputBufferAsRef()); +} + +float g(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.DataSerializer::deserialize(getOutputBufferAsRef()); + return obj.calc(d, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + f(0, 1.f); + g(0, 1.f); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang b/tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang new file mode 100644 index 00000000000..223d1a0cc67 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang @@ -0,0 +1,104 @@ + +RWStructuredBuffer outputBuffer; + +interface ISerializer +{ + static void serialize(T data, RWStructuredBuffer buffer); + static T deserialize(RWStructuredBuffer buffer); +} + +interface ICalculation +{ + associatedtype Data; + associatedtype DataSerializer : ISerializer; + float calc(Data d, float x); + Data make(float q); +} + +struct StandardSerializer : ISerializer +{ + static void serialize(T data, RWStructuredBuffer buffer) + { + // Note: just for testing.. don't ever serialize this way + buffer[0] = __realCast(data); + } + + static T deserialize(RWStructuredBuffer buffer) + { + return __realCast(buffer[0]); + } +}; + +struct A : ICalculation +{ + typealias Data = float; + typealias DataSerializer = StandardSerializer; + float calc(Data d, float x) { return d * x * x; } + Data make(float q) { return q; } +}; + +struct BData +{ + float x; + float y; +}; + +struct BDataSerializer : ISerializer +{ + static void serialize(BData data, RWStructuredBuffer buffer) + { + buffer[0] = data.x; + buffer[1] = data.y; + } + + static BData deserialize(RWStructuredBuffer buffer) + { + return {buffer[0], buffer[1]}; + } +}; + +struct B : ICalculation +{ + typealias Data = BData; + typealias DataSerializer = BDataSerializer; + float calc(Data d, float x) { return d.x * x * x + d.y; } + Data make(float q) { return {q, q}; } +}; + +struct C : ICalculation +{ + typealias Data = float; + typealias DataSerializer = StandardSerializer; + float calc(Data d, float x) { return d * x; } + Data make(float q) { return q; } +}; + +ICalculation factoryAB(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + obj.DataSerializer::serialize(d, outputBuffer); + outputBuffer[3] = id; +} + +float g(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.DataSerializer::deserialize(outputBuffer); + return obj.calc(d, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + f(0, 1.f); + g(0, 1.f); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang new file mode 100644 index 00000000000..f06952769e2 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang @@ -0,0 +1,64 @@ + +RWStructuredBuffer outputBuffer; + +interface IFoo +{ + float calc(float x); +} + +interface IFactory +{ + IFoo create(); +} + +struct AFoo : IFoo +{ + float calc(float x) { return x * x * x; } +}; + +struct A : IFactory +{ + IFoo create() { return AFoo(); } +}; + +struct B : IFactory, IFoo +{ + float calc(float x) { return x * x; } + IFoo create() { return this; } +}; + +struct CFoo : IFoo +{ + float calc(float x) { return x; } +}; + +struct C : IFactory +{ + IFoo create() { return CFoo(); } +}; + +IFactory getFactory(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float calc(IFoo obj, float y) +{ + return obj.calc(y); +} + +float f(uint id, float x) +{ + IFactory obj = getFactory(id, x); + IFoo foo = obj.create(); + return calc(foo, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang new file mode 100644 index 00000000000..9d08e4a3867 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang @@ -0,0 +1,66 @@ + +RWStructuredBuffer outputBuffer; + +interface IFoo +{ + float calc(float x); +} + +interface IFactory +{ + IFoo create(float q); +} + +struct AFoo : IFoo +{ + float q; + float calc(float x) { return q * x * x; } +}; + +struct A : IFactory +{ + IFoo create(float q) { return AFoo(q); } +}; + +struct B : IFactory, IFoo +{ + float q; + float calc(float x) { return q * x * x; } + IFoo create(float q) { return This(q); } +}; + +struct CFoo : IFoo +{ + float calc(float x) { return x; } +}; + +struct C : IFactory +{ + IFoo create(float q) { return CFoo(); } +}; + +IFactory getFactory(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(0); +} + +float calc(IFoo obj, float y) +{ + return obj.calc(y); +} + +float f(uint id, float x) +{ + IFactory factory = getFactory(id, x); + IFoo foo = factory.create(2 * x); + return calc(foo, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-input-1.slang b/tests/language-feature/dynamic-dispatch/func-call-input-1.slang new file mode 100644 index 00000000000..aaeb9b6da3c --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/func-call-input-1.slang @@ -0,0 +1,48 @@ + +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +IInterface factoryAB(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +// Should lower to accept A, B, C or D (but not E) +float calc(IInterface obj, float y) +{ + return obj.calc(y); +} + +float f(uint id, float x) +{ + IInterface obj = factoryAB(id, x); + return calc(obj, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-input-2.slang b/tests/language-feature/dynamic-dispatch/func-call-input-2.slang new file mode 100644 index 00000000000..cebb191cb69 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/func-call-input-2.slang @@ -0,0 +1,73 @@ + +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +struct D : IInterface +{ + float calc(float x) { return x + x * x; } +}; + +struct E : IInterface +{ + float calc(float x) { return x * x + x * 3; } +}; + +IInterface factoryAB(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +IInterface factoryCD(uint id, float x) +{ + if (id == 0) + return C(); + else + return D(); +} + +// Should lower to accept A, B, C or D (but not E) +float calc(IInterface obj, float y) +{ + return obj.calc(y); +} + +float f(uint id, float x) +{ + IInterface obj = factoryAB(id, x); + return calc(obj, x); +} + +float g(uint id, float x) +{ + IInterface obj = factoryCD(id, x); + return calc(obj, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); + outputBuffer[1] = g(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-return.slang b/tests/language-feature/dynamic-dispatch/func-call-return.slang new file mode 100644 index 00000000000..6eb3b352bdc --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/func-call-return.slang @@ -0,0 +1,42 @@ + +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +IInterface factory(uint id, float x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + IInterface obj = factory(id, x); + return obj.calc(x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface-2.slang b/tests/language-feature/dynamic-dispatch/generic-interface-2.slang new file mode 100644 index 00000000000..3eac08feb76 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/generic-interface-2.slang @@ -0,0 +1,64 @@ + +RWStructuredBuffer outputBuffer; + +interface IData +{ + T getNorm2(); +} + +interface ICalculation +{ + associatedtype Data : IData; + T calc(Data d, T x); + Data make(T q); +} + +struct AData : IData +{ + T x; + + T getNorm2() { return x * x; } +}; + +struct A : ICalculation +{ + typealias Data = AData; + T calc(Data d, T x) { return d.getNorm2() * x * x; } + Data make(T q) { return {q}; } +}; + +struct BData : IData +{ + T x; + T y; + + T getNorm2() { return x * x + y * y; } +}; + +struct B : ICalculation +{ + typealias Data = BData; + T calc(Data d, T x) { return d.x * x * x + d.y; } + Data make(T q) { return {q, q}; } +}; + +ICalculation factoryAB(uint id, T x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + return obj.calc(d, x) + d.getNorm2(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface-3.slang b/tests/language-feature/dynamic-dispatch/generic-interface-3.slang new file mode 100644 index 00000000000..3eac08feb76 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/generic-interface-3.slang @@ -0,0 +1,64 @@ + +RWStructuredBuffer outputBuffer; + +interface IData +{ + T getNorm2(); +} + +interface ICalculation +{ + associatedtype Data : IData; + T calc(Data d, T x); + Data make(T q); +} + +struct AData : IData +{ + T x; + + T getNorm2() { return x * x; } +}; + +struct A : ICalculation +{ + typealias Data = AData; + T calc(Data d, T x) { return d.getNorm2() * x * x; } + Data make(T q) { return {q}; } +}; + +struct BData : IData +{ + T x; + T y; + + T getNorm2() { return x * x + y * y; } +}; + +struct B : ICalculation +{ + typealias Data = BData; + T calc(Data d, T x) { return d.x * x * x + d.y; } + Data make(T q) { return {q, q}; } +}; + +ICalculation factoryAB(uint id, T x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + return obj.calc(d, x) + d.getNorm2(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface.slang b/tests/language-feature/dynamic-dispatch/generic-interface.slang new file mode 100644 index 00000000000..b421ebc30be --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/generic-interface.slang @@ -0,0 +1,57 @@ + +RWStructuredBuffer outputBuffer; + +interface ICalculation +{ + associatedtype Data; + T calc(Data d, T x); + Data make(T q); +} + +struct A : ICalculation +{ + typealias Data = T; + T calc(Data d, T x) { return d * x * x; } + Data make(T q) { return q; } +}; + +struct BData +{ + T x; + T y; +}; + +struct B : ICalculation +{ + typealias Data = BData; + T calc(Data d, T x) { return d.x * x * x + d.y; } + Data make(T q) { return {q, q}; } +}; + +struct C : ICalculation +{ + typealias Data = T; + T calc(Data d, T x) { return d * x; } + Data make(T q) { return q; } +}; + +ICalculation factoryAB(uint id, T x) +{ + if (id == 0) + return A(); + else + return B(); +} + +float f(uint id, float x) +{ + let obj = factoryAB(id, x); + obj.Data d = obj.make(x); + return obj.calc(d, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/with-data.slang b/tests/language-feature/dynamic-dispatch/with-data.slang new file mode 100644 index 00000000000..85874fdcccd --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/with-data.slang @@ -0,0 +1,44 @@ + +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + float calc(float x); +} + +struct A : IInterface +{ + float factor; + float calc(float x) { return x * x * factor; } +}; + +struct B : IInterface +{ + float factor; + float offset; + float calc(float x) { return x * x * factor + offset; } +}; + +struct C : IInterface +{ + float calc(float x) { return x; } +}; + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(x); + else + obj = B(x, x); + + return obj.calc(x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 1); +} \ No newline at end of file From feaa2fa3bc5d5b6a5678d9ec4f4caa5731a36e4a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 23 Jul 2025 13:54:31 -0400 Subject: [PATCH 004/105] Make test files work with slang-test add CHECK statements --- source/slang/slang-ir-lower-dynamic-insts.cpp | 7 ++-- .../dynamic-dispatch/assoc-types.slang | 5 ++- .../dependent-assoc-types.slang | 39 +++++++++++++------ .../dynamic-func-dynamic-output-1.slang | 5 ++- .../dynamic-func-dynamic-output-2.slang | 9 +++-- .../dynamic-dispatch/func-call-input-1.slang | 5 ++- .../dynamic-dispatch/func-call-input-2.slang | 8 +++- .../dynamic-dispatch/func-call-return.slang | 5 ++- ...erface.slang => generic-interface-1.slang} | 5 ++- .../generic-interface-2.slang | 5 ++- .../generic-interface-3.slang | 7 +++- .../dynamic-dispatch/simple.slang | 5 ++- .../dynamic-dispatch/with-data.slang | 5 ++- 13 files changed, 81 insertions(+), 29 deletions(-) rename tests/language-feature/dynamic-dispatch/{generic-interface.slang => generic-interface-1.slang} (78%) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 0f06596409f..3d8a6ffcb9b 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -1299,14 +1299,14 @@ struct DynamicInstLoweringContext for (auto inst : typeInstsToLower) lowerInst(inst); + for (auto func : funcTypesToProcess) + replaceFuncType(func, this->funcReturnInfo[func]); + for (auto inst : valueInstsToLower) lowerInst(inst); for (auto inst : instWithReplacementTypes) replaceType(inst); - - for (auto func : funcTypesToProcess) - replaceFuncType(func, this->funcReturnInfo[func]); } void replaceFuncType(IRFunc* func, PropagationInfo& returnTypeInfo) @@ -1579,6 +1579,7 @@ struct DynamicInstLoweringContext inst->replaceUsesWith(newCall); if (auto info = tryGetInfo(inst)) propagationMap[newCall] = info; + replaceType(newCall); // "maybe replace type" inst->removeAndDeallocate(); } diff --git a/tests/language-feature/dynamic-dispatch/assoc-types.slang b/tests/language-feature/dynamic-dispatch/assoc-types.slang index 9900014900d..8b9f82e8d0f 100644 --- a/tests/language-feature/dynamic-dispatch/assoc-types.slang +++ b/tests/language-feature/dynamic-dispatch/assoc-types.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface ICalculation @@ -53,5 +55,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 1); // CHECK: 1.0 + outputBuffer[1] = f(1, 1); // CHECK: 2.0 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang b/tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang index 223d1a0cc67..b2e3da8dc91 100644 --- a/tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang +++ b/tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface ISerializer @@ -61,8 +63,8 @@ struct B : ICalculation { typealias Data = BData; typealias DataSerializer = BDataSerializer; - float calc(Data d, float x) { return d.x * x * x + d.y; } - Data make(float q) { return {q, q}; } + float calc(Data d, float x) { return d.x * x + d.y; } + Data make(float q) { return {q, q * q}; } }; struct C : ICalculation @@ -73,7 +75,7 @@ struct C : ICalculation Data make(float q) { return q; } }; -ICalculation factoryAB(uint id, float x) +ICalculation factoryAB(uint id) { if (id == 0) return A(); @@ -81,17 +83,19 @@ ICalculation factoryAB(uint id, float x) return B(); } -float f(uint id, float x) +float f(uint id, float q) { - let obj = factoryAB(id, x); - obj.Data d = obj.make(x); + let obj = factoryAB(id); + obj.Data d = obj.make(q); obj.DataSerializer::serialize(d, outputBuffer); - outputBuffer[3] = id; + outputBuffer[2] = id; + return 0; } -float g(uint id, float x) +float g(float x) { - let obj = factoryAB(id, x); + uint id = (uint)outputBuffer[2]; + let obj = factoryAB(id); obj.Data d = obj.DataSerializer::deserialize(outputBuffer); return obj.calc(d, x); } @@ -99,6 +103,19 @@ float g(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - f(0, 1.f); - g(0, 1.f); + f(0, 3.f); + outputBuffer[3] = g(2.f); + f(1, 2.f); + outputBuffer[4] = g(2.f); + + // Clear the first 3 elements + outputBuffer[0] = 0.f; + outputBuffer[1] = 0.f; + outputBuffer[2] = 0.f; + + // CHECK: 0.0 + // CHECK: 0.0 + // CHECK: 0.0 + // CHECK: 12.0 + // CHECK: 8.0 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang index f06952769e2..0f9516c4279 100644 --- a/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang +++ b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-1.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IFoo @@ -60,5 +62,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang index 9d08e4a3867..f02203c6f60 100644 --- a/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang +++ b/tests/language-feature/dynamic-dispatch/dynamic-func-dynamic-output-2.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IFoo @@ -14,7 +16,7 @@ interface IFactory struct AFoo : IFoo { float q; - float calc(float x) { return q * x * x; } + float calc(float x) { return q * x; } }; struct A : IFactory @@ -44,7 +46,7 @@ IFactory getFactory(uint id, float x) if (id == 0) return A(); else - return B(0); + return B(x); } float calc(IFoo obj, float y) @@ -62,5 +64,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 16 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-input-1.slang b/tests/language-feature/dynamic-dispatch/func-call-input-1.slang index aaeb9b6da3c..60e2ad547c0 100644 --- a/tests/language-feature/dynamic-dispatch/func-call-input-1.slang +++ b/tests/language-feature/dynamic-dispatch/func-call-input-1.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IInterface @@ -44,5 +46,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 3); // CHECK: 27 + outputBuffer[1] = f(1, 3); // CHECK: 9 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-input-2.slang b/tests/language-feature/dynamic-dispatch/func-call-input-2.slang index cebb191cb69..62988cb7b02 100644 --- a/tests/language-feature/dynamic-dispatch/func-call-input-2.slang +++ b/tests/language-feature/dynamic-dispatch/func-call-input-2.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IInterface @@ -68,6 +70,8 @@ float g(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); - outputBuffer[1] = g(0, 1); + outputBuffer[0] = f(0, 3); // CHECK: 27 + outputBuffer[1] = f(1, 3); // CHECK: 9 + outputBuffer[2] = g(0, 3); // CHECK: 3 + outputBuffer[3] = g(1, 3); // CHECK: 12 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/func-call-return.slang b/tests/language-feature/dynamic-dispatch/func-call-return.slang index 6eb3b352bdc..a9d91823fb9 100644 --- a/tests/language-feature/dynamic-dispatch/func-call-return.slang +++ b/tests/language-feature/dynamic-dispatch/func-call-return.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IInterface @@ -38,5 +40,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 3); // CHECK: 27 + outputBuffer[1] = f(1, 3); // CHECK: 9 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface.slang b/tests/language-feature/dynamic-dispatch/generic-interface-1.slang similarity index 78% rename from tests/language-feature/dynamic-dispatch/generic-interface.slang rename to tests/language-feature/dynamic-dispatch/generic-interface-1.slang index b421ebc30be..13d5843e0e9 100644 --- a/tests/language-feature/dynamic-dispatch/generic-interface.slang +++ b/tests/language-feature/dynamic-dispatch/generic-interface-1.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface ICalculation @@ -53,5 +55,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 3); // CHECK: 27.0 + outputBuffer[1] = f(1, 3); // CHECK: 30.0 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface-2.slang b/tests/language-feature/dynamic-dispatch/generic-interface-2.slang index 3eac08feb76..e91abdedd3a 100644 --- a/tests/language-feature/dynamic-dispatch/generic-interface-2.slang +++ b/tests/language-feature/dynamic-dispatch/generic-interface-2.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IData @@ -60,5 +62,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 3); // CHECK: 90 + outputBuffer[1] = f(1, 3); // CHECK: 48 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/generic-interface-3.slang b/tests/language-feature/dynamic-dispatch/generic-interface-3.slang index 3eac08feb76..e94518fb97a 100644 --- a/tests/language-feature/dynamic-dispatch/generic-interface-3.slang +++ b/tests/language-feature/dynamic-dispatch/generic-interface-3.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IData @@ -23,7 +25,7 @@ struct AData : IData struct A : ICalculation { typealias Data = AData; - T calc(Data d, T x) { return d.getNorm2() * x * x; } + T calc(Data d, T x) { return d.x * x * x; } Data make(T q) { return {q}; } }; @@ -60,5 +62,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 3); // CHECK: 36 + outputBuffer[1] = f(1, 3); // CHECK: 48 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/simple.slang b/tests/language-feature/dynamic-dispatch/simple.slang index 229d07648db..6d52cfbe444 100644 --- a/tests/language-feature/dynamic-dispatch/simple.slang +++ b/tests/language-feature/dynamic-dispatch/simple.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IInterface @@ -37,5 +39,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 } \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/with-data.slang b/tests/language-feature/dynamic-dispatch/with-data.slang index 85874fdcccd..c48e2f2d04f 100644 --- a/tests/language-feature/dynamic-dispatch/with-data.slang +++ b/tests/language-feature/dynamic-dispatch/with-data.slang @@ -1,4 +1,6 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; interface IInterface @@ -40,5 +42,6 @@ float f(uint id, float x) [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - outputBuffer[0] = f(0, 1); + outputBuffer[0] = f(0, 3); // CHECK: 27 + outputBuffer[1] = f(1, 3); // CHECK: 30 } \ No newline at end of file From d13f3d712fc6fc2c828d44d283625ad1ea865fa5 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 23 Jul 2025 15:26:26 -0400 Subject: [PATCH 005/105] Simplify the inst states --- source/slang/slang-ir-lower-dynamic-insts.cpp | 604 +++++------------- 1 file changed, 175 insertions(+), 429 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 3d8a6ffcb9b..438edb58f6d 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -13,22 +13,15 @@ namespace Slang // // This forms a lattice with // -// None < Value -// None < ConcreteX < SetOfX < Unbounded +// None < Set < Unbounded // None < Existential < Unbounded // enum class PropagationJudgment { - None, // No judgment (initial value) - Value, // Regular value computation (unrelated to dynamic dispatch) - ConcreteType, // Concrete type reference - ConcreteTable, // Concrete witness table reference - ConcreteFunc, // Concrete function reference - SetOfTypes, // Set of possible types - SetOfTables, // Set of possible witness tables - SetOfFuncs, // Set of possible functions - Existential, // Existential box with a set of possible witness tables - Unbounded, // Unknown set of possible types/tables/funcs (e.g. COM interface types) + None, // Either uninitialized or irrelevant + Set, // Set of possible types/tables/funcs + Existential, // Existential box with a set of possible witness tables + Unbounded, // Unknown set of possible types/tables/funcs (e.g. COM interface types) }; // Data structure to hold propagation information for an instruction @@ -36,17 +29,11 @@ struct PropagationInfo : RefObject { PropagationJudgment judgment; - // For concrete references - IRInst* concreteValue = nullptr; - // For sets of types/tables/funcs and existential witness tables HashSet possibleValues; - // For SetOfFuncs - IRFuncType* dynFuncType; - PropagationInfo() - : judgment(PropagationJudgment::None), concreteValue(nullptr), dynFuncType(nullptr) + : judgment(PropagationJudgment::None) { } @@ -55,15 +42,25 @@ struct PropagationInfo : RefObject { } - static PropagationInfo makeValue() { return PropagationInfo(PropagationJudgment::Value); } - static PropagationInfo makeConcrete(PropagationJudgment j, IRInst* value); - static PropagationInfo makeSet(PropagationJudgment j, const HashSet& values); - static PropagationInfo makeSetOfFuncs(const HashSet& funcs, IRFuncType* dynFuncType); + static PropagationInfo makeSingletonSet(IRInst* value); + static PropagationInfo makeSet(const HashSet& values); static PropagationInfo makeExistential(const HashSet& tables); static PropagationInfo makeUnbounded(); static PropagationInfo none(); bool isNone() const { return judgment == PropagationJudgment::None; } + bool isSingleton() const + { + return judgment == PropagationJudgment::Set && possibleValues.getCount() == 1; + } + + IRInst* getSingletonValue() const + { + if (judgment == PropagationJudgment::Set && possibleValues.getCount() == 1) + return *possibleValues.begin(); + + SLANG_UNEXPECTED("getSingletonValue called on non-singleton PropagationInfo"); + } operator bool() const { return judgment != PropagationJudgment::None; } }; @@ -160,32 +157,20 @@ struct WorkItem }; // PropagationInfo implementation -PropagationInfo PropagationInfo::makeConcrete(PropagationJudgment j, IRInst* value) +PropagationInfo PropagationInfo::makeSingletonSet(IRInst* value) { - auto info = PropagationInfo(j); - info.concreteValue = value; + auto info = PropagationInfo(PropagationJudgment::Set); + info.possibleValues.add(value); return info; } -PropagationInfo PropagationInfo::makeSet(PropagationJudgment j, const HashSet& values) +PropagationInfo PropagationInfo::makeSet(const HashSet& values) { - auto info = PropagationInfo(j); - SLANG_ASSERT(j != PropagationJudgment::SetOfFuncs); + auto info = PropagationInfo(PropagationJudgment::Set); info.possibleValues = values; return info; } - -PropagationInfo PropagationInfo::makeSetOfFuncs( - const HashSet& values, - IRFuncType* dynFuncType) -{ - auto info = PropagationInfo(PropagationJudgment::SetOfFuncs); - info.possibleValues = values; - info.dynFuncType = dynFuncType; - return info; -} - PropagationInfo PropagationInfo::makeExistential(const HashSet& tables) { auto info = PropagationInfo(PropagationJudgment::Existential); @@ -210,17 +195,10 @@ bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) switch (a.judgment) { - case PropagationJudgment::Value: - return true; // All value judgments are equal - - case PropagationJudgment::ConcreteType: - case PropagationJudgment::ConcreteTable: - case PropagationJudgment::ConcreteFunc: - return a.concreteValue == b.concreteValue; - - case PropagationJudgment::SetOfTypes: - case PropagationJudgment::SetOfTables: - case PropagationJudgment::SetOfFuncs: + case PropagationJudgment::None: + case PropagationJudgment::Unbounded: + return true; + case PropagationJudgment::Set: if (a.possibleValues.getCount() != b.possibleValues.getCount()) return false; for (auto value : a.possibleValues) @@ -239,10 +217,6 @@ bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) return false; } return true; - - case PropagationJudgment::Unbounded: - return true; // All unknown sets are considered equal - default: return false; } @@ -256,30 +230,13 @@ struct DynamicInstLoweringContext // If this is a global instruction (parent is module), return concrete info if (as(inst->getParent())) { - if (as(inst)) - { - PropagationInfo typeInfo = - PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, nullptr); - typeInfo.concreteValue = inst; - return typeInfo; - } - else if (as(inst)) + if (as(inst) || as(inst) || as(inst)) { - PropagationInfo tableInfo = - PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, nullptr); - tableInfo.concreteValue = inst; - return tableInfo; - } - else if (as(inst)) - { - PropagationInfo funcInfo = - PropagationInfo::makeConcrete(PropagationJudgment::ConcreteFunc, nullptr); - funcInfo.concreteValue = inst; - return funcInfo; + return PropagationInfo::makeSingletonSet(inst); } else { - return PropagationInfo::makeValue(); + return PropagationInfo::none(); } } @@ -571,7 +528,7 @@ struct DynamicInstLoweringContext { // For now, error out as specified SLANG_UNIMPLEMENTED_X("IRCreateExistentialObject lowering not yet implemented"); - return PropagationInfo::makeValue(); + return PropagationInfo::none(); } PropagationInfo analyzeMakeExistential(IRMakeExistential* inst) @@ -590,106 +547,47 @@ struct DynamicInstLoweringContext return PropagationInfo::makeUnbounded(); HashSet tables; - - if (witnessTableInfo.judgment == PropagationJudgment::ConcreteTable) - { - tables.add(witnessTableInfo.concreteValue); - } - else if (witnessTableInfo.judgment == PropagationJudgment::SetOfTables) - { + if (witnessTableInfo.judgment == PropagationJudgment::Set) for (auto table : witnessTableInfo.possibleValues) - { tables.add(table); - } - } return PropagationInfo::makeExistential(tables); } - static IRInst* lookupEntry(IRInst* witnessTable, IRInst* key) + static IRInst* findEntryInConcreteTable(IRInst* witnessTable, IRInst* key) { if (auto concreteTable = as(witnessTable)) - { for (auto entry : concreteTable->getEntries()) - { if (entry->getRequirementKey() == key) - { return entry->getSatisfyingVal(); - } - } - } return nullptr; // Not found } PropagationInfo analyzeLookupWitnessMethod(IRLookupWitnessMethod* inst) { - auto witnessTable = inst->getWitnessTable(); auto key = inst->getRequirementKey(); - auto witnessTableInfo = tryGetInfo(witnessTable); - - if (!witnessTableInfo) - return PropagationInfo::none(); - if (witnessTableInfo.judgment == PropagationJudgment::Unbounded) - return PropagationInfo::makeUnbounded(); - - HashSet results; - - if (witnessTableInfo.judgment == PropagationJudgment::ConcreteTable) - { - results.add(lookupEntry(witnessTableInfo.concreteValue, key)); - } - else if (witnessTableInfo.judgment == PropagationJudgment::SetOfTables) - { - for (auto table : witnessTableInfo.possibleValues) - { - results.add(lookupEntry(table, key)); - } - } + auto witnessTable = inst->getWitnessTable(); + auto witnessTableInfo = tryGetInfo(witnessTable); - if (witnessTableInfo.judgment == PropagationJudgment::ConcreteTable) - { - if (as(inst->getDataType())) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteFunc, - *results.begin()); - } - else if (as(inst->getDataType())) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteType, - *results.begin()); - } - else if (as(inst->getDataType())) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteTable, - *results.begin()); - } - else - { - SLANG_UNEXPECTED("Unexpected data type for LookupWitnessMethod"); - } - } - else + switch (witnessTableInfo.judgment) { - if (auto funcType = as(inst->getDataType())) - { - return PropagationInfo::makeSetOfFuncs(results, funcType); - } - else if (as(inst->getDataType())) - { - return PropagationInfo::makeSet(PropagationJudgment::SetOfTypes, results); - } - else if (as(inst->getDataType())) + case PropagationJudgment::None: + case PropagationJudgment::Unbounded: + return witnessTableInfo.judgment; + case PropagationJudgment::Set: { - return PropagationInfo::makeSet(PropagationJudgment::SetOfTables, results); - } - else - { - SLANG_UNEXPECTED("Unexpected data type for LookupWitnessMethod"); + HashSet results; + for (auto table : witnessTableInfo.possibleValues) + results.add(findEntryInConcreteTable(table, key)); + return PropagationInfo::makeSet(results); } + case PropagationJudgment::Existential: + SLANG_UNEXPECTED("Unexpected LookupWitnessMethod on Existential"); + break; + default: + SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeLookupWitnessMethod"); + break; } } @@ -698,33 +596,23 @@ struct DynamicInstLoweringContext auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(operand); - if (!operandInfo) + switch (operandInfo.judgment) + { + case PropagationJudgment::None: return PropagationInfo::none(); - - if (operandInfo.judgment == PropagationJudgment::Unbounded) + case PropagationJudgment::Unbounded: return PropagationInfo::makeUnbounded(); - - if (operandInfo.judgment == PropagationJudgment::Existential) - { - HashSet tables; - for (auto table : operandInfo.possibleValues) - { - tables.add(table); - } - - if (tables.getCount() == 1) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteTable, - *tables.begin()); - } - else - { - return PropagationInfo::makeSet(PropagationJudgment::SetOfTables, tables); - } + case PropagationJudgment::Existential: + return PropagationInfo::makeSet(operandInfo.possibleValues); + case PropagationJudgment::Set: + SLANG_UNEXPECTED( + "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); + break; + default: + SLANG_UNEXPECTED( + "Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); + break; } - - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, inst); } PropagationInfo analyzeExtractExistentialType(IRExtractExistentialType* inst) @@ -732,57 +620,37 @@ struct DynamicInstLoweringContext auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(operand); - if (!operandInfo) + switch (operandInfo.judgment) + { + case PropagationJudgment::None: return PropagationInfo::none(); - - if (operandInfo.judgment == PropagationJudgment::Unbounded) + case PropagationJudgment::Unbounded: return PropagationInfo::makeUnbounded(); - - if (operandInfo.judgment == PropagationJudgment::Existential) - { - HashSet types; - // Extract types from witness tables by looking at the concrete types - for (auto table : operandInfo.possibleValues) + case PropagationJudgment::Existential: { - // Get the concrete type from the witness table - if (auto witnessTable = as(table)) - { - if (auto concreteType = witnessTable->getConcreteType()) - { - types.add(concreteType); - } - } - else - { - SLANG_UNEXPECTED("Expected witness table in existential extraction base type"); - } - } - - if (types.getCount() == 0) - { - // No concrete types found, treat as this instruction - types.add(inst); - } - - if (types.getCount() == 1) - { - return PropagationInfo::makeConcrete( - PropagationJudgment::ConcreteType, - *types.begin()); - } - else - { - return PropagationInfo::makeSet(PropagationJudgment::SetOfTypes, types); + HashSet types; + for (auto table : operandInfo.possibleValues) + if (auto witnessTable = cast(table)) // Expect witness table + if (auto concreteType = witnessTable->getConcreteType()) + types.add(concreteType); + return PropagationInfo::makeSet(types); } + case PropagationJudgment::Set: + SLANG_UNEXPECTED( + "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); + break; + default: + SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialType"); + break; } - - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, inst); } PropagationInfo analyzeExtractExistentialValue(IRExtractExistentialValue* inst) { - // The value itself is just a regular value - return PropagationInfo::makeValue(); + // We don't care about the value itself. + // (We rely on the propagation info for the type) + // + return PropagationInfo::none(); } PropagationInfo analyzeCall(IRCall* inst, LinkedList& workQueue) @@ -819,19 +687,11 @@ struct DynamicInstLoweringContext workQueue.addLast(WorkItem(InterproceduralEdge::Direction::CallToFunc, inst, func)); }; - if (calleeInfo) + if (calleeInfo.judgment == PropagationJudgment::Set) { - if (calleeInfo.judgment == PropagationJudgment::ConcreteFunc) - { - // If we have a concrete function, register the call site - propagateToCallSite(as(calleeInfo.concreteValue)); - } - else if (calleeInfo.judgment == PropagationJudgment::SetOfFuncs) - { - // If we have a set of functions, register each one - for (auto func : calleeInfo.possibleValues) - propagateToCallSite(as(func)); - } + // If we have a set of functions, register each one + for (auto func : calleeInfo.possibleValues) + propagateToCallSite(cast(func)); } if (auto callInfo = tryGetInfo(inst)) @@ -979,6 +839,12 @@ struct DynamicInstLoweringContext } } + PropagationInfo getFuncReturnInfo(IRFunc* func) + { + funcReturnInfo.addIfNotExists(func, PropagationInfo::none()); + return funcReturnInfo[func]; + } + void initializeFirstBlockParameters(IRFunc* func) { auto firstBlock = func->getFirstBlock(); @@ -1000,15 +866,9 @@ struct DynamicInstLoweringContext else propagationMap[param] = PropagationInfo::none(); // Initialize to none. } - else if ( - as(paramType) || as(paramType) || - as(paramType) || as(paramType)) - { - propagationMap[param] = PropagationInfo::none(); - } else { - propagationMap[param] = PropagationInfo::makeValue(); + propagationMap[param] = PropagationInfo::none(); } } } @@ -1026,49 +886,20 @@ struct DynamicInstLoweringContext // Get return value info if there is a return value PropagationInfo returnValueInfo; - if (returnInst->getOperandCount() > 0) - { - auto returnValue = returnInst->getOperand(0); - returnValueInfo = tryGetInfo(returnValue); - } - else - { - // Void return - returnValueInfo = PropagationInfo::makeValue(); - } + returnValueInfo = tryGetInfo(returnInst->getVal()); // Update function return info by unioning with existing info - auto existingReturnInfo = funcReturnInfo.tryGetValue(func); bool returnInfoChanged = false; - - if (returnValueInfo) + auto existingReturnInfo = getFuncReturnInfo(func); + auto newReturnInfo = unionPropagationInfo(existingReturnInfo, returnValueInfo); + if (!areInfosEqual(newReturnInfo, existingReturnInfo)) { - if (existingReturnInfo) - { - auto newReturnInfo = unionPropagationInfo( - List({*existingReturnInfo, returnValueInfo})); + funcReturnInfo[func] = newReturnInfo; - if (!areInfosEqual(*existingReturnInfo, newReturnInfo)) - { - funcReturnInfo[func] = newReturnInfo; - returnInfoChanged = true; - } - } - else - { - funcReturnInfo[func] = returnValueInfo; - returnInfoChanged = true; - } - } - - // If return info changed, add return edges to call sites - if (returnInfoChanged && this->funcCallSites.containsKey(func)) - { - for (auto callSite : this->funcCallSites[func]) - { - workQueue.addLast( - WorkItem(InterproceduralEdge::Direction::FuncToCall, callSite, func)); - } + if (this->funcCallSites.containsKey(func)) + for (auto callSite : this->funcCallSites[func]) + workQueue.addLast( + WorkItem(InterproceduralEdge::Direction::FuncToCall, callSite, func)); } return returnInfoChanged; @@ -1078,7 +909,7 @@ struct DynamicInstLoweringContext { if (infos.getCount() == 0) { - return PropagationInfo::makeValue(); + return PropagationInfo::none(); } if (infos.getCount() == 1) @@ -1112,94 +943,42 @@ struct DynamicInstLoweringContext { switch (info.judgment) { - case PropagationJudgment::ConcreteType: - unionJudgment = PropagationJudgment::SetOfTypes; - allValues.add(info.concreteValue); - break; - case PropagationJudgment::ConcreteTable: - unionJudgment = PropagationJudgment::SetOfTables; - allValues.add(info.concreteValue); - break; - case PropagationJudgment::ConcreteFunc: - unionJudgment = PropagationJudgment::SetOfFuncs; - allValues.add(info.concreteValue); - break; - case PropagationJudgment::SetOfTypes: - unionJudgment = PropagationJudgment::SetOfTypes; - for (auto value : info.possibleValues) - allValues.add(value); + case PropagationJudgment::None: break; - case PropagationJudgment::SetOfTables: - unionJudgment = PropagationJudgment::SetOfTables; + case PropagationJudgment::Set: + unionJudgment = PropagationJudgment::Set; for (auto value : info.possibleValues) allValues.add(value); break; - case PropagationJudgment::SetOfFuncs: - unionJudgment = PropagationJudgment::SetOfFuncs; - for (auto value : info.possibleValues) - allValues.add(value); - if (!dynFuncType) - { - // If we haven't set a function type yet, use the first one - dynFuncType = info.dynFuncType; - } - else if (dynFuncType != info.dynFuncType) - { - SLANG_UNEXPECTED( - "Mismatched function types in union propagation info for SetOfFuncs"); - } - - break; - case PropagationJudgment::Value: - if (unionJudgment == PropagationJudgment::None) - unionJudgment = PropagationJudgment::Value; - else - { - SLANG_ASSERT(unionJudgment == PropagationJudgment::Value); - } - break; - case PropagationJudgment::None: - // None judgments are basically 'empty' - break; case PropagationJudgment::Existential: // For existential union, we need to collect all witness tables // For now, we'll handle this properly by creating a new existential with all tables - { - HashSet allTables; - for (auto otherInfo : infos) - { - if (otherInfo.judgment == PropagationJudgment::Existential) - { - for (auto table : otherInfo.possibleValues) - { - allTables.add(table); - } - } - } - if (allTables.getCount() > 0) - { - return PropagationInfo::makeExistential(allTables); - } - } - return PropagationInfo::none(); + unionJudgment = PropagationJudgment::Existential; + for (auto value : info.possibleValues) + allValues.add(value); + break; case PropagationJudgment::Unbounded: // If any info is unbounded, the union is unbounded return PropagationInfo::makeUnbounded(); } } - // If we collected values, create a set; otherwise return value - if (allValues.getCount() > 0) - { - if (unionJudgment == PropagationJudgment::SetOfFuncs && dynFuncType) - return PropagationInfo::makeSetOfFuncs(allValues, dynFuncType); + if (unionJudgment == PropagationJudgment::Existential) + if (allValues.getCount() > 0) + return PropagationInfo::makeExistential(allValues); + else + return PropagationInfo::none(); - return PropagationInfo::makeSet(unionJudgment, allValues); - } - else - { - return PropagationInfo::none(); - } + if (unionJudgment == PropagationJudgment::Set) + if (allValues.getCount() > 1) + return PropagationInfo::makeSet(allValues); + else + return PropagationInfo::none(); + + // If we reach here, crash instead of returning none (which could make the analysis go into + // an infinite loop) + // + SLANG_UNEXPECTED("Unhandled prop-info union"); } PropagationInfo unionPropagationInfo(PropagationInfo info1, PropagationInfo info2) @@ -1214,22 +993,10 @@ struct DynamicInstLoweringContext PropagationInfo analyzeDefault(IRInst* inst) { // Check if this is a type, witness table, or function - if (as(inst)) - { - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteType, inst); - } - else if (as(inst)) - { - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteTable, inst); - } - else if (as(inst)) - { - return PropagationInfo::makeConcrete(PropagationJudgment::ConcreteFunc, inst); - } + if (as(inst) || as(inst) || as(inst)) + return PropagationInfo::makeSingletonSet(inst); else - { - return PropagationInfo::makeValue(); - } + return PropagationInfo::none(); // Default case, no propagation info } void performDynamicInstLowering() @@ -1278,7 +1045,7 @@ struct DynamicInstLoweringContext instWithReplacementTypes.add(child); if (auto calleeInfo = tryGetInfo(as(child)->getCallee())) - if (calleeInfo.judgment == PropagationJudgment::SetOfFuncs) + if (calleeInfo.judgment == PropagationJudgment::Set) valueInstsToLower.add(child); } break; @@ -1416,44 +1183,44 @@ struct DynamicInstLoweringContext builder.setInsertBefore(inst); // Check if this is a TypeKind data type with SetOfTypes judgment - if (info.judgment == PropagationJudgment::SetOfTypes && - inst->getDataType()->getOp() == kIROp_TypeKind) + if (info.judgment == PropagationJudgment::Set) { - // Create an any-value type based on the set of types - auto anyValueType = createAnyValueTypeFromInsts(info.possibleValues); - - // Store the mapping for later use - loweredInstToAnyValueType[inst] = anyValueType; - - // Replace the instruction with the any-value type - inst->replaceUsesWith(anyValueType); - inst->removeAndDeallocate(); - return; - } + if (inst->getDataType()->getOp() == kIROp_TypeKind) + { + // Create an any-value type based on the set of types + auto anyValueType = createAnyValueTypeFromInsts(info.possibleValues); - if (info.judgment == PropagationJudgment::SetOfTables || - info.judgment == PropagationJudgment::SetOfFuncs) - { - // Get the witness table operand info - auto witnessTableInst = inst->getWitnessTable(); - auto witnessTableInfo = tryGetInfo(witnessTableInst); + // Store the mapping for later use + loweredInstToAnyValueType[inst] = anyValueType; - if (witnessTableInfo && witnessTableInfo.judgment == PropagationJudgment::SetOfTables) - { - // Create a key mapping function - auto keyMappingFunc = createKeyMappingFunc( - inst->getRequirementKey(), - witnessTableInfo.possibleValues, - info.possibleValues); - - // Replace with call to key mapping function - auto witnessTableId = builder.emitCallInst( - builder.getUIntType(), - keyMappingFunc, - List({inst->getWitnessTable()})); - inst->replaceUsesWith(witnessTableId); - propagationMap[witnessTableId] = info; + // Replace the instruction with the any-value type + inst->replaceUsesWith(anyValueType); inst->removeAndDeallocate(); + return; + } + else + { + // Get the witness table operand info + auto witnessTableInst = inst->getWitnessTable(); + auto witnessTableInfo = tryGetInfo(witnessTableInst); + + if (witnessTableInfo.judgment == PropagationJudgment::Set) + { + // Create a key mapping function + auto keyMappingFunc = createKeyMappingFunc( + inst->getRequirementKey(), + witnessTableInfo.possibleValues, + info.possibleValues); + + // Replace with call to key mapping function + auto witnessTableId = builder.emitCallInst( + builder.getUIntType(), + keyMappingFunc, + List({inst->getWitnessTable()})); + inst->replaceUsesWith(witnessTableId); + propagationMap[witnessTableId] = info; + inst->removeAndDeallocate(); + } } } } @@ -1467,7 +1234,7 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - if (info.judgment == PropagationJudgment::SetOfTables) + if (info.judgment == PropagationJudgment::Set) { // Replace with GetElement(loweredInst, 0) -> uint auto operand = inst->getOperand(0); @@ -1501,7 +1268,7 @@ struct DynamicInstLoweringContext void lowerExtractExistentialType(IRExtractExistentialType* inst) { auto info = tryGetInfo(inst); - if (!info || info.judgment != PropagationJudgment::SetOfTypes) + if (!info || info.judgment != PropagationJudgment::Set) return; IRBuilder builder(inst); @@ -1557,7 +1324,10 @@ struct DynamicInstLoweringContext auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(callee); - if (!calleeInfo || calleeInfo.judgment != PropagationJudgment::SetOfFuncs) + if (!calleeInfo || calleeInfo.judgment != PropagationJudgment::Set) + return; + + if (calleeInfo.isSingleton() && calleeInfo.getSingletonValue() == callee) return; IRBuilder builder(inst); @@ -1910,30 +1680,6 @@ struct DynamicInstLoweringContext return maxSize; } - bool needsReinterpret(PropagationInfo sourceInfo, PropagationInfo targetInfo) - { - if (!sourceInfo || !targetInfo) - return false; - - // Check if both are SetOfTypes with different sets - if (sourceInfo.judgment == PropagationJudgment::SetOfTypes && - targetInfo.judgment == PropagationJudgment::SetOfTypes) - { - if (sourceInfo.possibleValues.getCount() != targetInfo.possibleValues.getCount()) - return true; - } - - // Check if both are Existential with different witness table sets - if (sourceInfo.judgment == PropagationJudgment::Existential && - targetInfo.judgment == PropagationJudgment::Existential) - { - if (sourceInfo.possibleValues.getCount() != targetInfo.possibleValues.getCount()) - return true; - } - - return false; - } - bool isExistentialType(IRType* type) { return as(type) != nullptr; } bool isInterfaceType(IRType* type) { return as(type) != nullptr; } From 2543391b11f5ab6acb30395dab3606e1bf6903a1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 23 Jul 2025 16:26:01 -0400 Subject: [PATCH 006/105] Bottleneck info updates through a single func --- source/slang/slang-ir-lower-dynamic-insts.cpp | 255 ++++++++---------- 1 file changed, 114 insertions(+), 141 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 438edb58f6d..3ea24754ba1 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -255,57 +255,131 @@ struct DynamicInstLoweringContext return PropagationInfo::none(); } - void processBlock(IRBlock* block, LinkedList& workQueue) + // Centralized method to update propagation info and manage work queue + // Use this when you want to propagate new information to an existing instruction + // This will union the new info with existing info and add users to work queue if changed + void updateInfo(IRInst* inst, PropagationInfo newInfo, LinkedList& workQueue) { - bool anyInfoChanged = false; + auto existingInfo = tryGetInfo(inst); + auto unionedInfo = unionPropagationInfo(existingInfo, newInfo); - HashSet affectedBlocks; - HashSet affectedTerminators; - for (auto inst : block->getChildren()) + // Only proceed if info actually changed + if (areInfosEqual(existingInfo, unionedInfo)) + return; + + // Update the propagation map + propagationMap[inst] = unionedInfo; + + // Add all users to appropriate work items + addUsersToWorkQueue(inst, unionedInfo, workQueue); + } + + // Helper to add users of an instruction to the work queue based on how they use it + // This handles intra-procedural edges, inter-procedural edges, and return value propagation + void addUsersToWorkQueue(IRInst* inst, PropagationInfo info, LinkedList& workQueue) + { + for (auto use = inst->firstUse; use; use = use->nextUse) { - // Skip parameters & terminator - if (as(inst) || as(inst)) - continue; + auto user = use->getUser(); - auto oldInfo = tryGetInfo(inst); - processInstForPropagation(inst, workQueue); - auto newInfo = tryGetInfo(inst); + // If user is in a different block (or the inst is a param), add that block to work + // queue. + // + if (auto userBlock = as(user->getParent())) + { + auto instBlock = as(inst->getParent()); + if (userBlock != instBlock || as(inst)) + { + workQueue.addLast(WorkItem(userBlock)); + } + } - // If information has changed, propagate to appropriate blocks/edges - if (!areInfosEqual(oldInfo, newInfo)) + // If user is a terminator, add intra-procedural edges + if (auto terminator = as(user)) { - for (auto use = inst->firstUse; use; use = use->nextUse) + auto parentBlock = as(terminator->getParent()); + if (parentBlock) { - auto userBlock = as(use->getUser()); - if (userBlock && userBlock != block) - affectedBlocks.add(userBlock); + auto successors = parentBlock->getSuccessors(); + for (auto succIter = successors.begin(); succIter != successors.end(); + ++succIter) + { + workQueue.addLast(WorkItem(succIter.getEdge())); + } + } + } - if (auto terminator = as(use->getUser())) - affectedTerminators.add(terminator); + // If user is a return instruction, handle function return propagation + if (auto returnInst = as(user)) + { + auto func = as(returnInst->getParent()->getParent()); + if (func) + { + updateFuncReturnInfo(func, info, workQueue); } } - } - for (auto block : affectedBlocks) - { - workQueue.addLast(WorkItem(block)); + /* If user is a call instruction and inst is the callee, add interprocedural edges + if (auto callInst = as(user)) + { + if (callInst->getCallee() == inst && info.judgment == PropagationJudgment::Set) + { + // Add interprocedural edges for each possible function + for (auto funcInst : info.possibleValues) + { + if (auto func = as(funcInst)) + { + workQueue.addLast(WorkItem( + InterproceduralEdge::Direction::CallToFunc, + callInst, + func)); + } + } + } + }*/ } + } - for (auto terminator : affectedTerminators) + // Helper method to update function return info and propagate to call sites + void updateFuncReturnInfo( + IRFunc* func, + PropagationInfo returnInfo, + LinkedList& workQueue) + { + auto existingReturnInfo = getFuncReturnInfo(func); + auto newReturnInfo = unionPropagationInfo(existingReturnInfo, returnInfo); + + if (!areInfosEqual(existingReturnInfo, newReturnInfo)) { - auto successors = as(terminator->getParent())->getSuccessors(); - for (auto succIter = successors.begin(), succEnd = successors.end(); - succIter != succEnd; - ++succIter) + funcReturnInfo[func] = newReturnInfo; + + // Add interprocedural edges to all call sites + if (funcCallSites.containsKey(func)) { - workQueue.addLast(WorkItem(succIter.getEdge())); + for (auto callSite : funcCallSites[func]) + { + workQueue.addLast( + WorkItem(InterproceduralEdge::Direction::FuncToCall, callSite, func)); + } } } + } + + // Summary of centralized propagation methods: + // - updateInfo(): Use for propagating new info to existing instructions (unions and manages + // work queue) + // - addUsersToWorkQueue(): Handles adding users to work queue for intra/inter-procedural + // propagation + // - updateFuncReturnInfo(): Specialized helper for function return value propagation - if (as(block->getTerminator())) + void processBlock(IRBlock* block, LinkedList& workQueue) + { + for (auto inst : block->getChildren()) { - // If the block has a return inst, we need to propagate return values - propagateReturnValues(block, workQueue); + // Skip parameters & terminator + if (as(inst) || as(inst)) + continue; + processInstForPropagation(inst, workQueue); } }; @@ -521,14 +595,13 @@ struct DynamicInstLoweringContext break; } - propagationMap[inst] = info; + updateInfo(inst, info, workQueue); } PropagationInfo analyzeCreateExistentialObject(IRCreateExistentialObject* inst) { // For now, error out as specified SLANG_UNIMPLEMENTED_X("IRCreateExistentialObject lowering not yet implemented"); - return PropagationInfo::none(); } PropagationInfo analyzeMakeExistential(IRMakeExistential* inst) @@ -721,8 +794,6 @@ struct DynamicInstLoweringContext return; // Collect propagation info for each argument and update corresponding parameter - // TODO: Unify this logic with the affectedBlocks logic in the per-inst processing logic. - HashSet affectedBlocks; Index paramIndex = 0; for (auto param : successorBlock->getParams()) { @@ -731,38 +802,12 @@ struct DynamicInstLoweringContext auto arg = unconditionalBranch->getArg(paramIndex); if (auto argInfo = tryGetInfo(arg)) { - // Union with existing parameter info - bool infoChanged = false; - if (auto existingInfo = tryGetInfo(param)) - { - propagationMap[param] = unionPropagationInfo(existingInfo, argInfo); - if (!infoChanged && !areInfosEqual(existingInfo, propagationMap[param])) - infoChanged = true; - } - else - { - propagationMap[param] = argInfo; - infoChanged = true; - } - // If any info changed, add all user blocks to the affected set - if (infoChanged) - { - for (auto use = param->firstUse; use; use = use->nextUse) - { - auto user = use->getUser(); - if (auto block = as(user->getParent())) - affectedBlocks.add(block); - } - } + // Use centralized update method + updateInfo(param, argInfo, workQueue); } } paramIndex++; } - - for (auto block : affectedBlocks) - { - workQueue.addLast(WorkItem(block)); - } } void propagateInterproceduralEdge(InterproceduralEdge edge, LinkedList& workQueue) @@ -781,7 +826,6 @@ struct DynamicInstLoweringContext return; Index argIndex = 1; // Skip callee (operand 0) - HashSet affectedBlocks; for (auto param : firstBlock->getParams()) { if (argIndex < callInst->getOperandCount()) @@ -789,48 +833,24 @@ struct DynamicInstLoweringContext auto arg = callInst->getOperand(argIndex); if (auto argInfo = tryGetInfo(arg)) { - // Union with existing parameter info - auto existingInfo = tryGetInfo(param); - auto newInfo = unionPropagationInfo(tryGetInfo(param), argInfo); - propagationMap[param] = newInfo; - if (!areInfosEqual(existingInfo, newInfo)) - { - for (auto use = param->firstUse; use; use = use->nextUse) - { - auto user = use->getUser(); - if (auto block = as(user->getParent())) - affectedBlocks.add(block); - } - } + // Use centralized update method + updateInfo(param, argInfo, workQueue); } } argIndex++; } - // Add the affected block to the work queue if any info changed - for (auto block : affectedBlocks) - workQueue.addLast(WorkItem(block)); break; } case InterproceduralEdge::Direction::FuncToCall: { // Propagate return value info from function to call site auto returnInfo = funcReturnInfo.tryGetValue(targetFunc); - - bool anyInfoChanged = false; if (returnInfo) { - // Union with existing call info - auto existingCallInfo = tryGetInfo(callInst); - auto newInfo = unionPropagationInfo(existingCallInfo, *returnInfo); - propagationMap[callInst] = newInfo; - if (!areInfosEqual(existingCallInfo, newInfo)) - anyInfoChanged = true; - } - - // Add the callInst's parent block to the work queue if any info changed - if (anyInfoChanged) + // Use centralized update method workQueue.addLast(WorkItem(as(callInst->getParent()))); - + updateInfo(callInst, *returnInfo, workQueue); + } break; } default: @@ -873,38 +893,6 @@ struct DynamicInstLoweringContext } } - bool propagateReturnValues(IRBlock* block, LinkedList& workQueue) - { - auto terminator = block->getTerminator(); - auto returnInst = as(terminator); - if (!returnInst) - return false; - - auto func = as(block->getParent()); - if (!func) - return false; - - // Get return value info if there is a return value - PropagationInfo returnValueInfo; - returnValueInfo = tryGetInfo(returnInst->getVal()); - - // Update function return info by unioning with existing info - bool returnInfoChanged = false; - auto existingReturnInfo = getFuncReturnInfo(func); - auto newReturnInfo = unionPropagationInfo(existingReturnInfo, returnValueInfo); - if (!areInfosEqual(newReturnInfo, existingReturnInfo)) - { - funcReturnInfo[func] = newReturnInfo; - - if (this->funcCallSites.containsKey(func)) - for (auto callSite : this->funcCallSites[func]) - workQueue.addLast( - WorkItem(InterproceduralEdge::Direction::FuncToCall, callSite, func)); - } - - return returnInfoChanged; - } - PropagationInfo unionPropagationInfo(const List& infos) { if (infos.getCount() == 0) @@ -1127,21 +1115,6 @@ struct DynamicInstLoweringContext if (!info || info.judgment != PropagationJudgment::Existential) return; - /* Replace type with Tuple - IRBuilder builder(module); - builder.setInsertBefore(inst); - - HashSet types; - // Extract types from witness tables by looking at the concrete types - for (auto table : info.possibleValues) - if (auto witnessTable = as(table)) - if (auto concreteType = witnessTable->getConcreteType()) - types.add(concreteType); - - auto anyValueType = createAnyValueTypeFromInsts(types); - auto tupleType = builder.getTupleType(List({builder.getUIntType(), - anyValueType}));*/ - inst->setFullType(getTypeForExistential(info)); } From 3e38e826ed0d9e05c5f7674981a9bcbc65dc6b24 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 23 Jul 2025 17:46:32 -0400 Subject: [PATCH 007/105] Use more granular work items (inst-level) + add support for `createDynamicObject` --- source/slang/slang-ir-lower-dynamic-insts.cpp | 186 ++++++++++-------- .../dynamic-dispatch/derived-interface.slang | 82 ++++++++ 2 files changed, 184 insertions(+), 84 deletions(-) create mode 100644 tests/language-feature/dynamic-dispatch/derived-interface.slang diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 3ea24754ba1..86384f565fd 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -93,6 +93,7 @@ struct WorkItem enum class Type { None, // Invalid + Inst, // Propagate through a single instruction Block, // Propagate information within a block IntraProc, // Propagate through within-function edge (IREdge) InterProc // Propagate across function call/return (InterproceduralEdge) @@ -101,9 +102,10 @@ struct WorkItem Type type; union { - IRBlock* block; - IREdge intraProcEdge; - InterproceduralEdge interProcEdge; + IRInst* inst; // Type::Inst + IRBlock* block; // Type::Block + IREdge intraProcEdge; // Type::IntraProc + InterproceduralEdge interProcEdge; // Type::InterProc }; WorkItem() @@ -111,6 +113,11 @@ struct WorkItem { } + WorkItem(IRInst* inst) + : type(Type::Inst), inst(inst) + { + } + WorkItem(IRBlock* block) : type(Type::Block), block(block) { @@ -139,6 +146,8 @@ struct WorkItem intraProcEdge = other.intraProcEdge; else if (type == Type::InterProc) interProcEdge = other.interProcEdge; + else if (type == Type::Inst) + inst = other.inst; else block = other.block; } @@ -150,6 +159,8 @@ struct WorkItem intraProcEdge = other.intraProcEdge; else if (type == Type::InterProc) interProcEdge = other.interProcEdge; + else if (type == Type::Inst) + inst = other.inst; else block = other.block; return *this; @@ -285,14 +296,7 @@ struct DynamicInstLoweringContext // If user is in a different block (or the inst is a param), add that block to work // queue. // - if (auto userBlock = as(user->getParent())) - { - auto instBlock = as(inst->getParent()); - if (userBlock != instBlock || as(inst)) - { - workQueue.addLast(WorkItem(userBlock)); - } - } + workQueue.addLast(WorkItem(user)); // If user is a terminator, add intra-procedural edges if (auto terminator = as(user)) @@ -318,25 +322,6 @@ struct DynamicInstLoweringContext updateFuncReturnInfo(func, info, workQueue); } } - - /* If user is a call instruction and inst is the callee, add interprocedural edges - if (auto callInst = as(user)) - { - if (callInst->getCallee() == inst && info.judgment == PropagationJudgment::Set) - { - // Add interprocedural edges for each possible function - for (auto funcInst : info.possibleValues) - { - if (auto func = as(funcInst)) - { - workQueue.addLast(WorkItem( - InterproceduralEdge::Direction::CallToFunc, - callInst, - func)); - } - } - } - }*/ } } @@ -365,13 +350,6 @@ struct DynamicInstLoweringContext } } - // Summary of centralized propagation methods: - // - updateInfo(): Use for propagating new info to existing instructions (unions and manages - // work queue) - // - addUsersToWorkQueue(): Handles adding users to work queue for intra/inter-procedural - // propagation - // - updateFuncReturnInfo(): Specialized helper for function return value propagation - void processBlock(IRBlock* block, LinkedList& workQueue) { for (auto inst : block->getChildren()) @@ -415,6 +393,9 @@ struct DynamicInstLoweringContext switch (item.type) { + case WorkItem::Type::Inst: + processInstForPropagation(item.inst, workQueue); + break; case WorkItem::Type::Block: processBlock(item.block, workQueue); break; @@ -600,8 +581,28 @@ struct DynamicInstLoweringContext PropagationInfo analyzeCreateExistentialObject(IRCreateExistentialObject* inst) { - // For now, error out as specified - SLANG_UNIMPLEMENTED_X("IRCreateExistentialObject lowering not yet implemented"); + // + // TODO: Actually use the integer<->type map present in the linkage to + // extract a set of possible witness tables (if the index is a compile-time constant). + // + + if (auto interfaceType = as(inst->getDataType())) + { + if (!interfaceType->findDecoration()) + { + auto tables = collectExistentialTables(interfaceType); + if (tables.getCount() > 0) + return PropagationInfo::makeExistential(tables); + else + return PropagationInfo::none(); + } + else + { + return PropagationInfo::makeUnbounded(); + } + } + + return PropagationInfo::none(); } PropagationInfo analyzeMakeExistential(IRMakeExistential* inst) @@ -848,7 +849,6 @@ struct DynamicInstLoweringContext if (returnInfo) { // Use centralized update method - workQueue.addLast(WorkItem(as(callInst->getParent()))); updateInfo(callInst, *returnInfo, workQueue); } break; @@ -1364,17 +1364,48 @@ struct DynamicInstLoweringContext IRInst* tupleArgs[] = {tableId, packedValue}; auto tuple = builder.emitMakeTuple(tupleType, 2, tupleArgs); + if (auto info = tryGetInfo(inst)) + propagationMap[tuple] = info; + inst->replaceUsesWith(tuple); inst->removeAndDeallocate(); } void lowerCreateExistentialObject(IRCreateExistentialObject* inst) { - // Error out for now as specified - sink->diagnose( - inst, - Diagnostics::unimplemented, - "IRCreateExistentialObject lowering not yet implemented"); + auto info = tryGetInfo(inst); + if (!info || info.judgment != PropagationJudgment::Existential) + return; + + Dictionary mapping; + for (auto table : info.possibleValues) + { + // Get unique ID for the witness table + auto witnessTable = cast(table); + auto outputId = getUniqueID(witnessTable); + auto inputId = table->findDecoration()->getSequentialID(); + mapping[inputId] = outputId; // Map ID to itself for now + } + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto translatedID = builder.emitCallInst( + builder.getUIntType(), + createIntegerMappingFunc(mapping), + List({inst->getTypeID()})); + + auto existentialTupleType = as(getTypeForExistential(info)); + auto existentialTuple = builder.emitMakeTuple( + existentialTupleType, + List( + {translatedID, + builder.emitReinterpret(existentialTupleType->getOperand(1), inst->getValue())})); + + if (auto info = tryGetInfo(inst)) + propagationMap[existentialTuple] = info; + + inst->replaceUsesWith(existentialTuple); + inst->removeAndDeallocate(); } UInt getUniqueID(IRInst* funcOrTable) @@ -1388,10 +1419,7 @@ struct DynamicInstLoweringContext return newId; } - IRFunc* createKeyMappingFunc( - IRInst* key, - const HashSet& inputTables, - const HashSet& outputVals) + IRFunc* createIntegerMappingFunc(Dictionary& mapping) { // Create a function that maps input IDs to output IDs IRBuilder builder(module); @@ -1419,43 +1447,15 @@ struct DynamicInstLoweringContext List caseValues; List caseBlocks; - // Build mapping from input tables to output values - List inputTableArray; - List outputValArray; - - for (auto table : inputTables) - inputTableArray.add(table); - for (auto table : outputVals) - outputValArray.add(table); - - for (Index i = 0; i < inputTableArray.getCount(); i++) + for (auto item : mapping) { - auto inputTable = inputTableArray[i]; - auto inputId = getUniqueID(inputTable); - - // Find corresponding output table (for now, use simple 1:1 mapping) - IRInst* outputVal = nullptr; - if (i < outputValArray.getCount()) - { - outputVal = outputValArray[i]; - } - else if (outputValArray.getCount() > 0) - { - outputVal = outputValArray[0]; // Fallback to first output - } - - if (outputVal) - { - auto outputId = getUniqueID(outputVal); - - // Create case block - auto caseBlock = builder.emitBlock(); - builder.setInsertInto(caseBlock); - builder.emitReturn(builder.getIntValue(builder.getUIntType(), outputId)); + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), item.second)); - caseValues.add(builder.getIntValue(builder.getUIntType(), inputId)); - caseBlocks.add(caseBlock); - } + caseValues.add(builder.getIntValue(builder.getUIntType(), item.first)); + caseBlocks.add(caseBlock); } // Create flattened case arguments array @@ -1483,6 +1483,24 @@ struct DynamicInstLoweringContext return func; } + IRFunc* createKeyMappingFunc( + IRInst* key, + const HashSet& inputTables, + const HashSet& outputVals) + { + Dictionary mapping; + + // Create a mapping. + for (auto table : inputTables) + { + auto inputId = getUniqueID(table); + auto outputId = getUniqueID(findEntryInConcreteTable(table, key)); + mapping[inputId] = outputId; + } + + return createIntegerMappingFunc(mapping); + } + IRFunc* createDispatchFunc(const HashSet& funcs, IRFuncType* expectedFuncType) { // Create a dispatch function with switch-case for each function diff --git a/tests/language-feature/dynamic-dispatch/derived-interface.slang b/tests/language-feature/dynamic-dispatch/derived-interface.slang new file mode 100644 index 00000000000..a52598477f0 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/derived-interface.slang @@ -0,0 +1,82 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IScalable +{ + This scale(float x); +} + +interface IShape : IScalable +{ + float getArea(); + float getPerimeter(); +} + +struct Circle : IShape +{ + float radius; + + This scale(float x) { return {radius * x}; } + float getArea() { return 3.14159f * radius * radius; } + float getPerimeter() { return 2 * 3.14159f * radius; } +}; + +struct Square : IShape +{ + float side; + + This scale(float x) { return {side * x}; } + float getArea() { return side * side; } + float getPerimeter() { return 4 * side; } +}; + +struct Vector : IScalable +{ + float x, y; + + This scale(float factor) { return {x * factor, y * factor}; } +}; + +struct Matrix : IScalable +{ + float m[4][4]; + + This scale(float factor) + { + Matrix result; + for (int i = 0; i < 4; ++i) + for (int j = 0; j < 4; ++j) + result.m[i][j] = m[i][j] * factor; + return result; + } +}; + +IShape createShape(uint id, float data) +{ + return createDynamicObject(id, data); +} + +float f(uint id, float data, float scale) +{ + let obj = createShape(id, data); + let scaledObj = obj.scale(scale); + return scaledObj.getArea() + scaledObj.getPerimeter(); +}; + + +//TEST_INPUT: type_conformance Vector:IScalable = 0 +//TEST_INPUT: type_conformance Matrix:IScalable = 1 +//TEST_INPUT: type_conformance Circle:IScalable = 2 +//TEST_INPUT: type_conformance Square:IScalable = 3 + +//TEST_INPUT: type_conformance Circle:IShape = 0 +//TEST_INPUT: type_conformance Square:IShape = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 3, 3); // CHECK: 311.017426 + outputBuffer[1] = f(1, 3, 3); // CHECK: 117.000000 +} \ No newline at end of file From de04cd86b09ed1a9d57a3e61a27ea8ac558899ec Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 24 Jul 2025 12:59:15 -0400 Subject: [PATCH 008/105] Further simplify: Use IR insts instead of hash-sets to avoid repeated collection duplication --- source/slang/slang-ir-insts-stable-names.lua | 1 + source/slang/slang-ir-insts.lua | 8 + source/slang/slang-ir-lower-dynamic-insts.cpp | 354 +++++++++++------- source/slang/slang-ir.cpp | 3 + 4 files changed, 227 insertions(+), 139 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index db2a8f8c810..c4d8d177e62 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -669,4 +669,5 @@ return { ["SPIRVAsmOperand.__sampledType"] = 665, ["SPIRVAsmOperand.__imageType"] = 666, ["SPIRVAsmOperand.__sampledImageType"] = 667, + ["TypeFlow.TypeFlowCollection"] = 668 } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 0325fcbfd9e..0c7be24ca65 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2170,6 +2170,14 @@ local insts = { }, }, }, + { + -- A collection of IR instructions used for propagation analysis + -- The operands are the elements of the set, sorted by unique ID to ensure canonical ordering + TypeFlow = { + hoistable = true, + { TypeFlowCollection = {} } + }, + } } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 86384f565fd..354d3e4a54c 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -30,38 +30,55 @@ struct PropagationInfo : RefObject PropagationJudgment judgment; // For sets of types/tables/funcs and existential witness tables - HashSet possibleValues; + // Instead of HashSet, we use an IRCollection instruction with sorted operands + IRInst* collection; PropagationInfo() - : judgment(PropagationJudgment::None) + : judgment(PropagationJudgment::None), collection(nullptr) { } PropagationInfo(PropagationJudgment j) - : judgment(j) + : judgment(j), collection(nullptr) { } - static PropagationInfo makeSingletonSet(IRInst* value); - static PropagationInfo makeSet(const HashSet& values); - static PropagationInfo makeExistential(const HashSet& tables); - static PropagationInfo makeUnbounded(); - static PropagationInfo none(); + PropagationInfo(PropagationJudgment j, IRInst* coll) + : judgment(j), collection(coll) + { + } + + // NOTE: Factory methods moved to DynamicInstLoweringContext to access collection creation bool isNone() const { return judgment == PropagationJudgment::None; } bool isSingleton() const { - return judgment == PropagationJudgment::Set && possibleValues.getCount() == 1; + return judgment == PropagationJudgment::Set && getCollectionCount() == 1; } IRInst* getSingletonValue() const { - if (judgment == PropagationJudgment::Set && possibleValues.getCount() == 1) - return *possibleValues.begin(); + if (judgment == PropagationJudgment::Set && getCollectionCount() == 1) + return getCollectionElement(0); SLANG_UNEXPECTED("getSingletonValue called on non-singleton PropagationInfo"); } + // Helper functions to access collection elements + UInt getCollectionCount() const + { + if (!collection) + return 0; + return collection->getOperandCount(); + } + + IRInst* getCollectionElement(UInt index) const + { + if (!collection || index >= collection->getOperandCount()) + return nullptr; + return collection->getOperand(index); + } + operator bool() const { return judgment != PropagationJudgment::None; } }; @@ -167,37 +184,7 @@ struct WorkItem } }; -// PropagationInfo implementation -PropagationInfo PropagationInfo::makeSingletonSet(IRInst* value) -{ - auto info = PropagationInfo(PropagationJudgment::Set); - info.possibleValues.add(value); - return info; -} - -PropagationInfo PropagationInfo::makeSet(const HashSet& values) -{ - auto info = PropagationInfo(PropagationJudgment::Set); - info.possibleValues = values; - return info; -} - -PropagationInfo PropagationInfo::makeExistential(const HashSet& tables) -{ - auto info = PropagationInfo(PropagationJudgment::Existential); - info.possibleValues = tables; - return info; -} - -PropagationInfo PropagationInfo::makeUnbounded() -{ - return PropagationInfo(PropagationJudgment::Unbounded); -} - -PropagationInfo PropagationInfo::none() -{ - return PropagationInfo(PropagationJudgment::None); -} +// PropagationInfo implementations will be added to DynamicInstLoweringContext bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) { @@ -210,24 +197,10 @@ bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) case PropagationJudgment::Unbounded: return true; case PropagationJudgment::Set: - if (a.possibleValues.getCount() != b.possibleValues.getCount()) - return false; - for (auto value : a.possibleValues) - { - if (!b.possibleValues.contains(value)) - return false; - } - return true; - case PropagationJudgment::Existential: - if (a.possibleValues.getCount() != b.possibleValues.getCount()) - return false; - for (auto table : a.possibleValues) - { - if (!b.possibleValues.contains(table)) - return false; - } - return true; + // For collection-based sets, compare the collection instructions + // If both have the same collection instruction (from hoisting), they're equal + return a.collection == b.collection; default: return false; } @@ -235,27 +208,107 @@ bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) struct DynamicInstLoweringContext { + // Helper methods for creating canonical collections + IRInst* createCollection(const HashSet& elements) + { + if (elements.getCount() == 0) + return nullptr; + + List sortedElements; + for (auto element : elements) + sortedElements.add(element); + + return createCollection(sortedElements); + } + + IRInst* createCollection(const List& elements) + { + if (elements.getCount() == 0) + return nullptr; + + // Verify that all operands are global instructions + for (auto element : elements) + if (element->getParent()->getOp() != kIROp_ModuleInst) + SLANG_ASSERT_FAILURE("createCollection called with non-global operands"); + + // Sort elements by their unique IDs to ensure canonical ordering + List sortedElements = elements; + sortedElements.sort( + [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); + + // Create the collection instruction + IRBuilder builder(module); + builder.setInsertInto(module); + + // Use makeTuple as a temporary implementation until IRCollection is available + return builder.emitIntrinsicInst( + nullptr, + kIROp_TypeFlowCollection, + sortedElements.getCount(), + sortedElements.getBuffer()); + } + + // Factory methods for PropagationInfo + PropagationInfo makeSingletonSet(IRInst* value) + { + HashSet singleSet; + singleSet.add(value); + auto collection = createCollection(singleSet); + return PropagationInfo(PropagationJudgment::Set, collection); + } + + PropagationInfo makeSet(const HashSet& values) + { + SLANG_ASSERT(values.getCount() > 0); + auto collection = createCollection(values); + return PropagationInfo(PropagationJudgment::Set, collection); + } + + PropagationInfo makeExistential(const HashSet& tables) + { + SLANG_ASSERT(tables.getCount() > 0); + auto collection = createCollection(tables); + return PropagationInfo(PropagationJudgment::Existential, collection); + } + + PropagationInfo makeUnbounded() + { + return PropagationInfo(PropagationJudgment::Unbounded, nullptr); + } + + PropagationInfo none() { return PropagationInfo(PropagationJudgment::None, nullptr); } + + // Helper to iterate over collection elements + template + void forEachInCollection(const PropagationInfo& info, F func) + { + for (UInt i = 0; i < info.getCollectionCount(); ++i) + func(info.getCollectionElement(i)); + } + + // Helper to convert collection to HashSet + HashSet collectionToHashSet(const PropagationInfo& info) + { + HashSet result; + forEachInCollection(info, [&](IRInst* element) { result.add(element); }); + return result; + } + // DynamicInstLoweringContext implementation PropagationInfo tryGetInfo(IRInst* inst) { // If this is a global instruction (parent is module), return concrete info if (as(inst->getParent())) - { if (as(inst) || as(inst) || as(inst)) - { - return PropagationInfo::makeSingletonSet(inst); - } + return makeSingletonSet(inst); else - { - return PropagationInfo::none(); - } - } + return none(); // For non-global instructions, look up in the map auto found = propagationMap.tryGetValue(inst); if (found) return *found; - return PropagationInfo::none(); + return none(); } PropagationInfo tryGetFuncReturnInfo(IRFunc* func) @@ -263,7 +316,7 @@ struct DynamicInstLoweringContext auto found = funcReturnInfo.tryGetValue(func); if (found) return *found; - return PropagationInfo::none(); + return none(); } // Centralized method to update propagation info and manage work queue @@ -422,7 +475,7 @@ struct DynamicInstLoweringContext if (argInfo.judgment == PropagationJudgment::Existential && destInfo.judgment == PropagationJudgment::Existential) { - if (argInfo.possibleValues.getCount() != destInfo.possibleValues.getCount()) + if (argInfo.getCollectionCount() != destInfo.getCollectionCount()) { // If the sets of witness tables are not equal, reinterpret to the parameter type IRBuilder builder(module); @@ -592,17 +645,17 @@ struct DynamicInstLoweringContext { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return PropagationInfo::makeExistential(tables); + return makeExistential(tables); else - return PropagationInfo::none(); + return none(); } else { - return PropagationInfo::makeUnbounded(); + return makeUnbounded(); } } - return PropagationInfo::none(); + return none(); } PropagationInfo analyzeMakeExistential(IRMakeExistential* inst) @@ -615,17 +668,16 @@ struct DynamicInstLoweringContext auto witnessTableInfo = tryGetInfo(witnessTable); if (!witnessTableInfo) - return PropagationInfo::none(); + return none(); if (witnessTableInfo.judgment == PropagationJudgment::Unbounded) - return PropagationInfo::makeUnbounded(); + return makeUnbounded(); HashSet tables; if (witnessTableInfo.judgment == PropagationJudgment::Set) - for (auto table : witnessTableInfo.possibleValues) - tables.add(table); + forEachInCollection(witnessTableInfo, [&](IRInst* table) { tables.add(table); }); - return PropagationInfo::makeExistential(tables); + return makeExistential(tables); } static IRInst* findEntryInConcreteTable(IRInst* witnessTable, IRInst* key) @@ -652,9 +704,10 @@ struct DynamicInstLoweringContext case PropagationJudgment::Set: { HashSet results; - for (auto table : witnessTableInfo.possibleValues) - results.add(findEntryInConcreteTable(table, key)); - return PropagationInfo::makeSet(results); + forEachInCollection( + witnessTableInfo, + [&](IRInst* table) { results.add(findEntryInConcreteTable(table, key)); }); + return makeSet(results); } case PropagationJudgment::Existential: SLANG_UNEXPECTED("Unexpected LookupWitnessMethod on Existential"); @@ -673,11 +726,16 @@ struct DynamicInstLoweringContext switch (operandInfo.judgment) { case PropagationJudgment::None: - return PropagationInfo::none(); + return none(); case PropagationJudgment::Unbounded: - return PropagationInfo::makeUnbounded(); + return makeUnbounded(); case PropagationJudgment::Existential: - return PropagationInfo::makeSet(operandInfo.possibleValues); + { + // Convert collection to HashSet and create Set PropagationInfo + HashSet tables; + forEachInCollection(operandInfo, [&](IRInst* table) { tables.add(table); }); + return makeSet(tables); + } case PropagationJudgment::Set: SLANG_UNEXPECTED( "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); @@ -697,17 +755,21 @@ struct DynamicInstLoweringContext switch (operandInfo.judgment) { case PropagationJudgment::None: - return PropagationInfo::none(); + return none(); case PropagationJudgment::Unbounded: - return PropagationInfo::makeUnbounded(); + return makeUnbounded(); case PropagationJudgment::Existential: { HashSet types; - for (auto table : operandInfo.possibleValues) - if (auto witnessTable = cast(table)) // Expect witness table - if (auto concreteType = witnessTable->getConcreteType()) - types.add(concreteType); - return PropagationInfo::makeSet(types); + forEachInCollection( + operandInfo, + [&](IRInst* table) + { + if (auto witnessTable = cast(table)) // Expect witness table + if (auto concreteType = witnessTable->getConcreteType()) + types.add(concreteType); + }); + return makeSet(types); } case PropagationJudgment::Set: SLANG_UNEXPECTED( @@ -724,7 +786,7 @@ struct DynamicInstLoweringContext // We don't care about the value itself. // (We rely on the propagation info for the type) // - return PropagationInfo::none(); + return none(); } PropagationInfo analyzeCall(IRCall* inst, LinkedList& workQueue) @@ -764,14 +826,15 @@ struct DynamicInstLoweringContext if (calleeInfo.judgment == PropagationJudgment::Set) { // If we have a set of functions, register each one - for (auto func : calleeInfo.possibleValues) - propagateToCallSite(cast(func)); + forEachInCollection( + calleeInfo, + [&](IRInst* func) { propagateToCallSite(cast(func)); }); } if (auto callInfo = tryGetInfo(inst)) return callInfo; else - return PropagationInfo::none(); + return none(); } void propagateWithinFuncEdge(IREdge edge, LinkedList& workQueue) @@ -861,7 +924,7 @@ struct DynamicInstLoweringContext PropagationInfo getFuncReturnInfo(IRFunc* func) { - funcReturnInfo.addIfNotExists(func, PropagationInfo::none()); + funcReturnInfo.addIfNotExists(func, none()); return funcReturnInfo[func]; } @@ -882,13 +945,13 @@ struct DynamicInstLoweringContext if (auto interfaceType = as(paramType)) { if (interfaceType->findDecoration()) - propagationMap[param] = PropagationInfo::makeUnbounded(); + propagationMap[param] = makeUnbounded(); else - propagationMap[param] = PropagationInfo::none(); // Initialize to none. + propagationMap[param] = none(); // Initialize to none. } else { - propagationMap[param] = PropagationInfo::none(); + propagationMap[param] = none(); } } } @@ -897,7 +960,7 @@ struct DynamicInstLoweringContext { if (infos.getCount() == 0) { - return PropagationInfo::none(); + return none(); } if (infos.getCount() == 1) @@ -935,33 +998,31 @@ struct DynamicInstLoweringContext break; case PropagationJudgment::Set: unionJudgment = PropagationJudgment::Set; - for (auto value : info.possibleValues) - allValues.add(value); + forEachInCollection(info, [&](IRInst* value) { allValues.add(value); }); break; case PropagationJudgment::Existential: // For existential union, we need to collect all witness tables // For now, we'll handle this properly by creating a new existential with all tables unionJudgment = PropagationJudgment::Existential; - for (auto value : info.possibleValues) - allValues.add(value); + forEachInCollection(info, [&](IRInst* value) { allValues.add(value); }); break; case PropagationJudgment::Unbounded: // If any info is unbounded, the union is unbounded - return PropagationInfo::makeUnbounded(); + return makeUnbounded(); } } if (unionJudgment == PropagationJudgment::Existential) if (allValues.getCount() > 0) - return PropagationInfo::makeExistential(allValues); + return makeExistential(allValues); else - return PropagationInfo::none(); + return none(); if (unionJudgment == PropagationJudgment::Set) if (allValues.getCount() > 1) - return PropagationInfo::makeSet(allValues); + return makeSet(allValues); else - return PropagationInfo::none(); + return none(); // If we reach here, crash instead of returning none (which could make the analysis go into // an infinite loop) @@ -980,11 +1041,13 @@ struct DynamicInstLoweringContext PropagationInfo analyzeDefault(IRInst* inst) { - // Check if this is a type, witness table, or function - if (as(inst) || as(inst) || as(inst)) - return PropagationInfo::makeSingletonSet(inst); + // Check if this is a global type, witness table, or function. + // If so, it's a concrete element. We'll create a singleton set for it. + if (inst->getParent()->getOp() == kIROp_ModuleInst && + (as(inst) || as(inst) || as(inst))) + return makeSingletonSet(inst); else - return PropagationInfo::none(); // Default case, no propagation info + return none(); // Default case, no propagation info } void performDynamicInstLowering() @@ -1100,10 +1163,14 @@ struct DynamicInstLoweringContext HashSet types; // Extract types from witness tables by looking at the concrete types - for (auto table : info.possibleValues) - if (auto witnessTable = as(table)) - if (auto concreteType = witnessTable->getConcreteType()) - types.add(concreteType); + forEachInCollection( + info, + [&](IRInst* table) + { + if (auto witnessTable = as(table)) + if (auto concreteType = witnessTable->getConcreteType()) + types.add(concreteType); + }); auto anyValueType = createAnyValueTypeFromInsts(types); return builder.getTupleType(List({builder.getUIntType(), anyValueType})); @@ -1161,7 +1228,7 @@ struct DynamicInstLoweringContext if (inst->getDataType()->getOp() == kIROp_TypeKind) { // Create an any-value type based on the set of types - auto anyValueType = createAnyValueTypeFromInsts(info.possibleValues); + auto anyValueType = createAnyValueTypeFromInsts(collectionToHashSet(info)); // Store the mapping for later use loweredInstToAnyValueType[inst] = anyValueType; @@ -1182,8 +1249,8 @@ struct DynamicInstLoweringContext // Create a key mapping function auto keyMappingFunc = createKeyMappingFunc( inst->getRequirementKey(), - witnessTableInfo.possibleValues, - info.possibleValues); + collectionToHashSet(witnessTableInfo), + collectionToHashSet(info)); // Replace with call to key mapping function auto witnessTableId = builder.emitCallInst( @@ -1248,7 +1315,7 @@ struct DynamicInstLoweringContext builder.setInsertBefore(inst); // Create an any-value type based on the set of types - auto anyValueType = createAnyValueTypeFromInsts(info.possibleValues); + auto anyValueType = createAnyValueTypeFromInsts(collectionToHashSet(info)); // Store the mapping for later use loweredInstToAnyValueType[inst] = anyValueType; @@ -1308,7 +1375,7 @@ struct DynamicInstLoweringContext auto expectedFuncType = getExpectedFuncType(inst); // Create dispatch function - auto dispatchFunc = createDispatchFunc(calleeInfo.possibleValues, expectedFuncType); + auto dispatchFunc = createDispatchFunc(collectionToHashSet(calleeInfo), expectedFuncType); // Replace call with dispatch List newArgs; @@ -1335,22 +1402,24 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - // Get unique ID for the witness table. TODO: Assert that this is a concrete table.. + // Get unique ID for the witness table. auto witnessTable = cast(inst->getWitnessTable()); auto tableId = builder.getIntValue(builder.getUIntType(), getUniqueID(witnessTable)); // Collect types from the witness tables to determine the any-value type HashSet types; - for (auto table : info.possibleValues) - { - if (auto witnessTableInst = as(table)) + forEachInCollection( + info, + [&](IRInst* table) { - if (auto concreteType = witnessTableInst->getConcreteType()) + if (auto witnessTableInst = as(table)) { - types.add(concreteType); + if (auto concreteType = witnessTableInst->getConcreteType()) + { + types.add(concreteType); + } } - } - } + }); // Create the appropriate any-value type auto anyValueType = createAnyValueType(types); @@ -1378,14 +1447,20 @@ struct DynamicInstLoweringContext return; Dictionary mapping; - for (auto table : info.possibleValues) - { - // Get unique ID for the witness table - auto witnessTable = cast(table); - auto outputId = getUniqueID(witnessTable); - auto inputId = table->findDecoration()->getSequentialID(); - mapping[inputId] = outputId; // Map ID to itself for now - } + forEachInCollection( + info, + [&](IRInst* table) + { + // Get unique ID for the witness table + auto witnessTable = cast(table); + auto outputId = getUniqueID(witnessTable); + auto seqDecoration = table->findDecoration(); + if (seqDecoration) + { + auto inputId = seqDecoration->getSequentialID(); + mapping[inputId] = outputId; // Map ID to itself for now + } + }); IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -1757,4 +1832,5 @@ void lowerDynamicInsts(IRModule* module, DiagnosticSink* sink) DynamicInstLoweringContext context(module, sink); context.processModule(); } + } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6c26708064c..b2ef78c4dd6 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8722,6 +8722,9 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_DetachDerivative: return false; + case kIROp_TypeFlowCollection: + return false; + case kIROp_Div: case kIROp_IRem: if (isIntegralScalarOrCompositeType(getFullType())) From ca8886bd839915882cfe484bfb90e1037d523df9 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 24 Jul 2025 13:53:54 -0400 Subject: [PATCH 009/105] Integrate into specialization pass + add initial support for propagating through specialization --- source/slang/slang-ir-lower-dynamic-insts.cpp | 218 +++++++++++++----- source/slang/slang-ir-lower-dynamic-insts.h | 2 +- source/slang/slang-ir-specialize.cpp | 8 +- 3 files changed, 169 insertions(+), 59 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 354d3e4a54c..c12203861ba 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -184,8 +184,6 @@ struct WorkItem } }; -// PropagationInfo implementations will be added to DynamicInstLoweringContext - bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) { if (a.judgment != b.judgment) @@ -198,8 +196,6 @@ bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) return true; case PropagationJudgment::Set: case PropagationJudgment::Existential: - // For collection-based sets, compare the collection instructions - // If both have the same collection instruction (from hoisting), they're equal return a.collection == b.collection; default: return false; @@ -211,9 +207,6 @@ struct DynamicInstLoweringContext // Helper methods for creating canonical collections IRInst* createCollection(const HashSet& elements) { - if (elements.getCount() == 0) - return nullptr; - List sortedElements; for (auto element : elements) sortedElements.add(element); @@ -493,8 +486,9 @@ struct DynamicInstLoweringContext return arg; // Can use as-is. } - void insertReinterprets() + bool insertReinterprets() { + bool changed = false; // Process each function in the module for (auto inst : module->getGlobalInsts()) { @@ -537,6 +531,7 @@ struct DynamicInstLoweringContext if (newArg != arg) { + changed = true; // Replace the argument in the branch instruction SLANG_ASSERT(!as(unconditionalBranch)); unconditionalBranch->setOperand(1 + paramIndex, newArg); @@ -557,6 +552,7 @@ struct DynamicInstLoweringContext if (newReturnVal != returnInst->getVal()) { // Replace the return value with the reinterpreted value + changed = true; returnInst->setOperand(0, newReturnVal); } } @@ -585,6 +581,7 @@ struct DynamicInstLoweringContext if (newArg != callInst->getArg(i)) { // Replace the argument in the call instruction + changed = true; callInst->setArg(i, newArg); } i++; @@ -594,6 +591,8 @@ struct DynamicInstLoweringContext } } } + + return changed; } void processInstForPropagation(IRInst* inst, LinkedList& workQueue) @@ -624,6 +623,9 @@ struct DynamicInstLoweringContext case kIROp_Call: info = analyzeCall(as(inst), workQueue); break; + case kIROp_Specialize: + info = analyzeSpecialize(as(inst)); + break; default: info = analyzeDefault(inst); break; @@ -789,6 +791,90 @@ struct DynamicInstLoweringContext return none(); } + + PropagationInfo analyzeSpecialize(IRSpecialize* inst) + { + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(operand); + + switch (operandInfo.judgment) + { + case PropagationJudgment::None: + return none(); + case PropagationJudgment::Unbounded: + return makeUnbounded(); + case PropagationJudgment::Existential: + { + SLANG_UNEXPECTED( + "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); + } + case PropagationJudgment::Set: + { + List specializationArgs; + for (auto i = 0; i < inst->getArgCount(); ++i) + { + // For integer args, add as is (also applies to any value args) + if (as(inst->getArg(i))) + specializationArgs.add(inst->getArg(i)); + + // For type args, we need to replace any dynamic args with + // their sets. + // + auto argInfo = tryGetInfo(inst->getArg(i)); + switch (argInfo.judgment) + { + case PropagationJudgment::None: + case PropagationJudgment::Unbounded: + SLANG_UNEXPECTED( + "Unexpected PropagationJudgment for specialization argument"); + case PropagationJudgment::Existential: + SLANG_UNEXPECTED( + "Unexpected Existential operand in specialization argument. Should be " + "set"); + case PropagationJudgment::Set: + { + if (argInfo.getCollectionCount() == 1) + specializationArgs.add(argInfo.getSingletonValue()); + else + specializationArgs.add(argInfo.collection); + break; + } + default: + SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeSpecialize"); + break; + } + } + + IRType* typeOfSpecialization = nullptr; + if (inst->getDataType()->getParent()->getOp() == kIROp_ModuleInst) + typeOfSpecialization = inst->getDataType(); + else + SLANG_UNIMPLEMENTED_X("unhandled specialization type in non-global context"); + + // Specialize each element in the set + HashSet specializedSet; + forEachInCollection( + operandInfo, + [&](IRInst* arg) + { + // Create a new specialized instruction for each argument + IRBuilder builder(module); + builder.setInsertInto(module); + specializedSet.add(builder.emitSpecializeInst( + typeOfSpecialization, + arg, + specializationArgs)); + }); + return makeSet(specializedSet); + } + break; + default: + SLANG_UNEXPECTED( + "Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); + break; + } + } + PropagationInfo analyzeCall(IRCall* inst, LinkedList& workQueue) { auto callee = inst->getCallee(); @@ -1050,7 +1136,7 @@ struct DynamicInstLoweringContext return none(); // Default case, no propagation info } - void performDynamicInstLowering() + bool performDynamicInstLowering() { // Collect all instructions that need lowering List typeInstsToLower; @@ -1058,6 +1144,7 @@ struct DynamicInstLoweringContext List instWithReplacementTypes; List funcTypesToProcess; + bool hasChanges = false; for (auto globalInst : module->getGlobalInsts()) { if (auto func = as(globalInst)) @@ -1115,19 +1202,21 @@ struct DynamicInstLoweringContext } for (auto inst : typeInstsToLower) - lowerInst(inst); + hasChanges |= lowerInst(inst); for (auto func : funcTypesToProcess) - replaceFuncType(func, this->funcReturnInfo[func]); + hasChanges |= replaceFuncType(func, this->funcReturnInfo[func]); for (auto inst : valueInstsToLower) - lowerInst(inst); + hasChanges |= lowerInst(inst); for (auto inst : instWithReplacementTypes) - replaceType(inst); + hasChanges |= replaceType(inst); + + return hasChanges; } - void replaceFuncType(IRFunc* func, PropagationInfo& returnTypeInfo) + bool replaceFuncType(IRFunc* func, PropagationInfo& returnTypeInfo) { IRFuncType* origFuncType = as(func->getFullType()); IRType* returnType = origFuncType->getResultType(); @@ -1152,7 +1241,13 @@ struct DynamicInstLoweringContext IRBuilder builder(module); builder.setInsertBefore(func); - func->setFullType(builder.getFuncType(paramTypes, returnType)); + + auto newFuncType = builder.getFuncType(paramTypes, returnType); + if (newFuncType == func->getFullType()) + return false; // No change + + func->setFullType(newFuncType); + return true; } IRType* getTypeForExistential(PropagationInfo info) @@ -1162,7 +1257,6 @@ struct DynamicInstLoweringContext builder.setInsertInto(module); HashSet types; - // Extract types from witness tables by looking at the concrete types forEachInCollection( info, [&](IRInst* table) @@ -1176,48 +1270,44 @@ struct DynamicInstLoweringContext return builder.getTupleType(List({builder.getUIntType(), anyValueType})); } - void replaceType(IRInst* inst) + bool replaceType(IRInst* inst) { auto info = tryGetInfo(inst); if (!info || info.judgment != PropagationJudgment::Existential) - return; + return false; inst->setFullType(getTypeForExistential(info)); + return true; } - void lowerInst(IRInst* inst) + bool lowerInst(IRInst* inst) { switch (inst->getOp()) { case kIROp_LookupWitnessMethod: - lowerLookupWitnessMethod(as(inst)); - break; + return lowerLookupWitnessMethod(as(inst)); case kIROp_ExtractExistentialWitnessTable: - lowerExtractExistentialWitnessTable(as(inst)); - break; + return lowerExtractExistentialWitnessTable(as(inst)); case kIROp_ExtractExistentialType: - lowerExtractExistentialType(as(inst)); - break; + return lowerExtractExistentialType(as(inst)); case kIROp_ExtractExistentialValue: - lowerExtractExistentialValue(as(inst)); - break; + return lowerExtractExistentialValue(as(inst)); case kIROp_Call: - lowerCall(as(inst)); - break; + return lowerCall(as(inst)); case kIROp_MakeExistential: - lowerMakeExistential(as(inst)); - break; + return lowerMakeExistential(as(inst)); case kIROp_CreateExistentialObject: - lowerCreateExistentialObject(as(inst)); - break; + return lowerCreateExistentialObject(as(inst)); + default: + return false; } } - void lowerLookupWitnessMethod(IRLookupWitnessMethod* inst) + bool lowerLookupWitnessMethod(IRLookupWitnessMethod* inst) { auto info = tryGetInfo(inst); if (!info) - return; + return false; IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -1236,7 +1326,7 @@ struct DynamicInstLoweringContext // Replace the instruction with the any-value type inst->replaceUsesWith(anyValueType); inst->removeAndDeallocate(); - return; + return true; } else { @@ -1260,16 +1350,19 @@ struct DynamicInstLoweringContext inst->replaceUsesWith(witnessTableId); propagationMap[witnessTableId] = info; inst->removeAndDeallocate(); + return true; } } } + + return false; } - void lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) + bool lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) { auto info = tryGetInfo(inst); if (!info) - return; + return false; IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -1282,11 +1375,17 @@ struct DynamicInstLoweringContext inst->replaceUsesWith(element); propagationMap[element] = info; inst->removeAndDeallocate(); + return true; } + return false; } - void lowerExtractExistentialValue(IRExtractExistentialValue* inst) + bool lowerExtractExistentialValue(IRExtractExistentialValue* inst) { + auto operandInfo = tryGetInfo(inst->getOperand(0)); + if (!operandInfo || operandInfo.judgment != PropagationJudgment::Existential) + return false; + IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -1303,13 +1402,14 @@ struct DynamicInstLoweringContext auto element = builder.emitGetTupleElement(resultType, operand, 1); inst->replaceUsesWith(element); inst->removeAndDeallocate(); + return true; } - void lowerExtractExistentialType(IRExtractExistentialType* inst) + bool lowerExtractExistentialType(IRExtractExistentialType* inst) { auto info = tryGetInfo(inst); if (!info || info.judgment != PropagationJudgment::Set) - return; + return false; IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -1323,6 +1423,7 @@ struct DynamicInstLoweringContext // Replace the instruction with the any-value type inst->replaceUsesWith(anyValueType); inst->removeAndDeallocate(); + return true; } IRFuncType* getExpectedFuncType(IRCall* inst) @@ -1359,16 +1460,16 @@ struct DynamicInstLoweringContext return builder.getFuncType(argTypes, resultType); } - void lowerCall(IRCall* inst) + bool lowerCall(IRCall* inst) { auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(callee); if (!calleeInfo || calleeInfo.judgment != PropagationJudgment::Set) - return; + return false; if (calleeInfo.isSingleton() && calleeInfo.getSingletonValue() == callee) - return; + return false; IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -1391,13 +1492,14 @@ struct DynamicInstLoweringContext propagationMap[newCall] = info; replaceType(newCall); // "maybe replace type" inst->removeAndDeallocate(); + return true; } - void lowerMakeExistential(IRMakeExistential* inst) + bool lowerMakeExistential(IRMakeExistential* inst) { auto info = tryGetInfo(inst); if (!info || info.judgment != PropagationJudgment::Existential) - return; + return false; IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -1438,13 +1540,14 @@ struct DynamicInstLoweringContext inst->replaceUsesWith(tuple); inst->removeAndDeallocate(); + return true; } - void lowerCreateExistentialObject(IRCreateExistentialObject* inst) + bool lowerCreateExistentialObject(IRCreateExistentialObject* inst) { auto info = tryGetInfo(inst); if (!info || info.judgment != PropagationJudgment::Existential) - return; + return false; Dictionary mapping; forEachInCollection( @@ -1481,6 +1584,7 @@ struct DynamicInstLoweringContext inst->replaceUsesWith(existentialTuple); inst->removeAndDeallocate(); + return true; } UInt getUniqueID(IRInst* funcOrTable) @@ -1786,18 +1890,22 @@ struct DynamicInstLoweringContext return tables; } - void processModule() + bool processModule() { + bool hasChanges = false; + // Phase 1: Information Propagation performInformationPropagation(); // Phase 1.5: Insert reinterprets for points where sets merge // e.g. phi, return, call // - insertReinterprets(); + hasChanges |= insertReinterprets(); // Phase 2: Dynamic Instruction Lowering - performDynamicInstLowering(); + hasChanges |= performDynamicInstLowering(); + + return hasChanges; } DynamicInstLoweringContext(IRModule* module, DiagnosticSink* sink) @@ -1813,10 +1921,10 @@ struct DynamicInstLoweringContext Dictionary propagationMap; // Mapping from function to return value propagation information - Dictionary funcReturnInfo; + Dictionary funcReturnInfo; // Mapping from functions to call-sites. - Dictionary> funcCallSites; + Dictionary> funcCallSites; // Unique ID assignment for functions and witness tables Dictionary uniqueIds; @@ -1827,10 +1935,10 @@ struct DynamicInstLoweringContext }; // Main entry point -void lowerDynamicInsts(IRModule* module, DiagnosticSink* sink) +bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink) { DynamicInstLoweringContext context(module, sink); - context.processModule(); + return context.processModule(); } } // namespace Slang diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-dynamic-insts.h index 33872899e40..9e7d779aac6 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.h +++ b/source/slang/slang-ir-lower-dynamic-insts.h @@ -7,5 +7,5 @@ namespace Slang { // Main entry point for the pass -void lowerDynamicInsts(IRModule* module, DiagnosticSink* sink); +bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink); } // namespace Slang diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index ed96b23c1e2..7e131226e91 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -5,6 +5,7 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-insts.h" +#include "slang-ir-lower-dynamic-insts.h" #include "slang-ir-lower-witness-lookup.h" #include "slang-ir-peephole.h" #include "slang-ir-sccp.h" @@ -1157,7 +1158,7 @@ struct SpecializationContext // if (options.lowerWitnessLookups) { - iterChanged = lowerWitnessLookup(module, sink); + iterChanged = lowerDynamicInsts(module, sink); } if (!iterChanged || sink->getErrorCount()) @@ -3099,8 +3100,9 @@ IRInst* specializeGenericImpl( builder->setInsertBefore(genericVal); List pendingWorkList; - SLANG_DEFER(for (Index ii = pendingWorkList.getCount() - 1; ii >= 0; ii--) if (context) - context->addToWorkList(pendingWorkList[ii]);); + SLANG_DEFER( + for (Index ii = pendingWorkList.getCount() - 1; ii >= 0; ii--) if (context) + context->addToWorkList(pendingWorkList[ii]);); // Now we will run through the body of the generic and // clone each of its instructions into the global scope, From 47f7786256c1c0c6fa579edcc42a6d49507bd77e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 24 Jul 2025 14:35:11 -0400 Subject: [PATCH 010/105] Make propagation map use insts-with-context --- source/slang/slang-ir-lower-dynamic-insts.cpp | 54 ++++++++++++++++++- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index c12203861ba..f960707ddee 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -8,7 +8,48 @@ namespace Slang { +// Elements for which we keep track of propagation information. +struct Element +{ + IRInst* context; + IRInst* inst; + + Element(IRInst* context, IRInst* inst) + : context(context), inst(inst) + { + } + + // Create element from an instruction that has a + // concrete parent (i.e. global IRFunc) + // + Element(IRInst* inst) + : inst(inst) + { + auto block = cast(inst->getParent()); + auto func = cast(block->getParent()); + // If parent func is not a global, then it is not a direct + // reference. An explicit IRSpecialize instruction must be provided as + // context. + // + SLANG_ASSERT(func->getParent()->getOp() == kIROp_ModuleInst); + + context = func; + } + + Element(const Element& other) + : context(other.context), inst(other.inst) + { + } + + bool operator==(const Element& other) const + { + return context == other.context && inst == other.inst; + } + + // getHashCode() + HashCode64 getHashCode() const { return combineHash(HashCode(context), HashCode(inst)); } +}; // Enumeration for different kinds of judgments about IR instructions. // // This forms a lattice with @@ -298,7 +339,7 @@ struct DynamicInstLoweringContext return none(); // For non-global instructions, look up in the map - auto found = propagationMap.tryGetValue(inst); + auto found = propagationMap.tryGetValue(Element(inst)); if (found) return *found; return none(); @@ -405,6 +446,12 @@ struct DynamicInstLoweringContext continue; processInstForPropagation(inst, workQueue); } + + if (auto returnInfo = as(block->getTerminator())) + { + auto valInfo = returnInfo->getVal(); + updateFuncReturnInfo(as(block->getParent()), tryGetInfo(valInfo), workQueue); + } }; void performInformationPropagation() @@ -1272,6 +1319,9 @@ struct DynamicInstLoweringContext bool replaceType(IRInst* inst) { + if (inst->getParent() == nullptr) + return false; // Not a valid instruction + auto info = tryGetInfo(inst); if (!info || info.judgment != PropagationJudgment::Existential) return false; @@ -1918,7 +1968,7 @@ struct DynamicInstLoweringContext DiagnosticSink* sink; // Mapping from instruction to propagation information - Dictionary propagationMap; + Dictionary propagationMap; // Mapping from function to return value propagation information Dictionary funcReturnInfo; From f07726f7e9862c542c9b3b9bdd5bcbb057990a18 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 25 Jul 2025 11:27:35 -0400 Subject: [PATCH 011/105] Improve context-sensitive analysis. Concrete specialization of dynamic method is now functional --- source/slang/slang-ir-lower-dynamic-insts.cpp | 633 ++++++++++++------ .../dependent-assoc-types-2.slang | 69 +- 2 files changed, 464 insertions(+), 238 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index f960707ddee..a100a4be550 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -14,9 +14,47 @@ struct Element IRInst* context; IRInst* inst; + Element() + : context(nullptr), inst(nullptr) + { + } + Element(IRInst* context, IRInst* inst) : context(context), inst(inst) { + validateElement(); + } + + bool validateElement() const + { + switch (context->getOp()) + { + case kIROp_Func: + { + SLANG_ASSERT(inst->getParent()->getParent() == context); + break; + } + case kIROp_Specialize: + { + auto generic = as(context)->getBase(); + // The base should be a parent of the inst. + bool foundParent = false; + for (auto parent = inst->getParent(); parent; parent = parent->getParent()) + { + if (parent == generic) + { + foundParent = true; + break; + } + } + SLANG_ASSERT(foundParent); + } + break; + default: + { + SLANG_UNEXPECTED("Invalid context for Element"); + } + } } // Create element from an instruction that has a @@ -135,16 +173,18 @@ struct InterproceduralEdge }; Direction direction; - IRCall* callInst; // The call instruction - IRFunc* targetFunc; // The function being called/returned from + IRInst* callerContext; // The context of the call (e.g. function or specialized generic) + IRCall* callInst; // The call instruction + IRInst* targetContext; // The function/specialized-generic being called/returned from InterproceduralEdge() = default; - InterproceduralEdge(Direction dir, IRCall* call, IRFunc* func) - : direction(dir), callInst(call), targetFunc(func) + InterproceduralEdge(Direction dir, IRInst* callerContext, IRCall* call, IRInst* func) + : direction(dir), callerContext(callerContext), callInst(call), targetContext(func) { } }; + // Union type representing either an intra-procedural or interprocedural edge struct WorkItem { @@ -158,6 +198,7 @@ struct WorkItem }; Type type; + IRInst* context; // The context of the work item. union { IRInst* inst; // Type::Inst @@ -171,34 +212,41 @@ struct WorkItem { } - WorkItem(IRInst* inst) - : type(Type::Inst), inst(inst) + WorkItem(IRInst* context, IRInst* inst) + : type(Type::Inst), inst(inst), context(context) { + SLANG_ASSERT(context != nullptr && inst != nullptr); + // Validate that the context is appropriate for the instruction + Element(context, inst).validateElement(); } - WorkItem(IRBlock* block) - : type(Type::Block), block(block) + WorkItem(IRInst* context, IRBlock* block) + : type(Type::Block), block(block), context(context) { + SLANG_ASSERT(context != nullptr && block != nullptr); + // Validate that the context is appropriate for the block + Element(context, block->getFirstChild()).validateElement(); } - WorkItem(IREdge edge) - : type(Type::IntraProc), intraProcEdge(edge) + WorkItem(IRInst* context, IREdge edge) + : type(Type::IntraProc), intraProcEdge(edge), context(context) { + SLANG_ASSERT(context != nullptr); } WorkItem(InterproceduralEdge edge) - : type(Type::InterProc), interProcEdge(edge) + : type(Type::InterProc), interProcEdge(edge), context(nullptr) { } - WorkItem(InterproceduralEdge::Direction dir, IRCall* call, IRFunc* func) - : type(Type::InterProc), interProcEdge(dir, call, func) + WorkItem(InterproceduralEdge::Direction dir, IRInst* callerCtx, IRCall* call, IRInst* callee) + : type(Type::InterProc), interProcEdge(dir, callerCtx, call, callee), context(nullptr) { } // Copy constructor and assignment needed for union with non-trivial types WorkItem(const WorkItem& other) - : type(other.type) + : type(other.type), context(other.context) { if (type == Type::IntraProc) intraProcEdge = other.intraProcEdge; @@ -213,6 +261,7 @@ struct WorkItem WorkItem& operator=(const WorkItem& other) { type = other.type; + context = other.context; if (type == Type::IntraProc) intraProcEdge = other.intraProcEdge; else if (type == Type::InterProc) @@ -328,8 +377,7 @@ struct DynamicInstLoweringContext return result; } - // DynamicInstLoweringContext implementation - PropagationInfo tryGetInfo(IRInst* inst) + /*PropagationInfo tryGetInfo(IRInst* inst) { // If this is a global instruction (parent is module), return concrete info if (as(inst->getParent())) @@ -338,13 +386,30 @@ struct DynamicInstLoweringContext else return none(); + return tryGetInfo(Element(inst)); + }*/ + + PropagationInfo tryGetInfo(Element element) + { // For non-global instructions, look up in the map - auto found = propagationMap.tryGetValue(Element(inst)); + auto found = propagationMap.tryGetValue(element); if (found) return *found; return none(); } + PropagationInfo tryGetInfo(IRInst* context, IRInst* inst) + { + // If this is a global instruction (parent is module), return concrete info + if (as(inst->getParent())) + if (as(inst) || as(inst) || as(inst)) + return makeSingletonSet(inst); + else + return none(); + + return tryGetInfo(Element(context, inst)); + } + PropagationInfo tryGetFuncReturnInfo(IRFunc* func) { auto found = funcReturnInfo.tryGetValue(func); @@ -356,9 +421,13 @@ struct DynamicInstLoweringContext // Centralized method to update propagation info and manage work queue // Use this when you want to propagate new information to an existing instruction // This will union the new info with existing info and add users to work queue if changed - void updateInfo(IRInst* inst, PropagationInfo newInfo, LinkedList& workQueue) + void updateInfo( + IRInst* context, + IRInst* inst, + PropagationInfo newInfo, + LinkedList& workQueue) { - auto existingInfo = tryGetInfo(inst); + auto existingInfo = tryGetInfo(context, inst); auto unionedInfo = unionPropagationInfo(existingInfo, newInfo); // Only proceed if info actually changed @@ -366,15 +435,19 @@ struct DynamicInstLoweringContext return; // Update the propagation map - propagationMap[inst] = unionedInfo; + propagationMap[Element(context, inst)] = unionedInfo; // Add all users to appropriate work items - addUsersToWorkQueue(inst, unionedInfo, workQueue); + addUsersToWorkQueue(context, inst, unionedInfo, workQueue); } // Helper to add users of an instruction to the work queue based on how they use it // This handles intra-procedural edges, inter-procedural edges, and return value propagation - void addUsersToWorkQueue(IRInst* inst, PropagationInfo info, LinkedList& workQueue) + void addUsersToWorkQueue( + IRInst* context, + IRInst* inst, + PropagationInfo info, + LinkedList& workQueue) { for (auto use = inst->firstUse; use; use = use->nextUse) { @@ -383,7 +456,7 @@ struct DynamicInstLoweringContext // If user is in a different block (or the inst is a param), add that block to work // queue. // - workQueue.addLast(WorkItem(user)); + workQueue.addLast(WorkItem(context, user)); // If user is a terminator, add intra-procedural edges if (auto terminator = as(user)) @@ -395,7 +468,7 @@ struct DynamicInstLoweringContext for (auto succIter = successors.begin(); succIter != successors.end(); ++succIter) { - workQueue.addLast(WorkItem(succIter.getEdge())); + workQueue.addLast(WorkItem(context, succIter.getEdge())); } } } @@ -403,54 +476,53 @@ struct DynamicInstLoweringContext // If user is a return instruction, handle function return propagation if (auto returnInst = as(user)) { - auto func = as(returnInst->getParent()->getParent()); - if (func) - { - updateFuncReturnInfo(func, info, workQueue); - } + updateFuncReturnInfo(context, info, workQueue); } } } // Helper method to update function return info and propagate to call sites void updateFuncReturnInfo( - IRFunc* func, + IRInst* callable, PropagationInfo returnInfo, LinkedList& workQueue) { - auto existingReturnInfo = getFuncReturnInfo(func); + auto existingReturnInfo = getFuncReturnInfo(callable); auto newReturnInfo = unionPropagationInfo(existingReturnInfo, returnInfo); if (!areInfosEqual(existingReturnInfo, newReturnInfo)) { - funcReturnInfo[func] = newReturnInfo; + funcReturnInfo[callable] = newReturnInfo; // Add interprocedural edges to all call sites - if (funcCallSites.containsKey(func)) + if (funcCallSites.containsKey(callable)) { - for (auto callSite : funcCallSites[func]) + for (auto callSite : funcCallSites[callable]) { - workQueue.addLast( - WorkItem(InterproceduralEdge::Direction::FuncToCall, callSite, func)); + workQueue.addLast(WorkItem( + InterproceduralEdge::Direction::FuncToCall, + callSite.context, + as(callSite.inst), + callable)); } } } } - void processBlock(IRBlock* block, LinkedList& workQueue) + void processBlock(IRInst* context, IRBlock* block, LinkedList& workQueue) { for (auto inst : block->getChildren()) { // Skip parameters & terminator if (as(inst) || as(inst)) continue; - processInstForPropagation(inst, workQueue); + processInstForPropagation(context, inst, workQueue); } if (auto returnInfo = as(block->getTerminator())) { auto valInfo = returnInfo->getVal(); - updateFuncReturnInfo(as(block->getParent()), tryGetInfo(valInfo), workQueue); + updateFuncReturnInfo(context, tryGetInfo(context, valInfo), workQueue); } }; @@ -459,23 +531,10 @@ struct DynamicInstLoweringContext // Global worklist for interprocedural analysis LinkedList workQueue; - // Add all global function entry blocks to worklist. + // Add all global functions to worklist for (auto inst : module->getGlobalInsts()) - { if (auto func = as(inst)) - { - initializeFirstBlockParameters(func); - - // Add all blocks to start with. Once the initial - // sweep is done, propagation will proceed on an on-demand basis - // depending on affected blocks & edges - // - for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) - { - workQueue.addLast(WorkItem(block)); - } - } - } + discoverContext(func, workQueue); // Process until fixed point while (workQueue.getCount() > 0) @@ -487,13 +546,13 @@ struct DynamicInstLoweringContext switch (item.type) { case WorkItem::Type::Inst: - processInstForPropagation(item.inst, workQueue); + processInstForPropagation(item.context, item.inst, workQueue); break; case WorkItem::Type::Block: - processBlock(item.block, workQueue); + processBlock(item.context, item.block, workQueue); break; case WorkItem::Type::IntraProc: - propagateWithinFuncEdge(item.intraProcEdge, workQueue); + propagateWithinFuncEdge(item.context, item.intraProcEdge, workQueue); break; case WorkItem::Type::InterProc: propagateInterproceduralEdge(item.interProcEdge, workQueue); @@ -505,9 +564,9 @@ struct DynamicInstLoweringContext } } - IRInst* maybeReinterpret(IRInst* arg, PropagationInfo destInfo) + IRInst* maybeReinterpret(IRInst* context, IRInst* arg, PropagationInfo destInfo) { - auto argInfo = tryGetInfo(arg); + auto argInfo = tryGetInfo(context, arg); if (!argInfo || !destInfo) return arg; @@ -525,7 +584,7 @@ struct DynamicInstLoweringContext // and if it doesn't, this will help catch it before code-gen. // auto reinterpret = builder.emitReinterpret(nullptr, arg); - propagationMap[reinterpret] = destInfo; + propagationMap[Element(reinterpret)] = destInfo; return reinterpret; // Return the reinterpret instruction } } @@ -541,6 +600,7 @@ struct DynamicInstLoweringContext { if (auto func = as(inst)) { + auto context = func; // Skip the first block as it contains function parameters, not phi parameters for (auto block = func->getFirstBlock()->getNextBlock(); block; block = block->getNextBlock()) @@ -574,7 +634,7 @@ struct DynamicInstLoweringContext if (paramIndex < unconditionalBranch->getArgCount()) { auto arg = unconditionalBranch->getArg(paramIndex); - auto newArg = maybeReinterpret(arg, tryGetInfo(param)); + auto newArg = maybeReinterpret(context, arg, tryGetInfo(param)); if (newArg != arg) { @@ -595,7 +655,7 @@ struct DynamicInstLoweringContext { auto funcReturnInfo = tryGetFuncReturnInfo(func); auto newReturnVal = - maybeReinterpret(returnInst->getVal(), funcReturnInfo); + maybeReinterpret(context, returnInst->getVal(), funcReturnInfo); if (newReturnVal != returnInst->getVal()) { // Replace the return value with the reinterpreted value @@ -623,8 +683,10 @@ struct DynamicInstLoweringContext Index i = 0; for (auto param : irFunc->getParams()) { - auto newArg = - maybeReinterpret(callInst->getArg(i), tryGetInfo(param)); + auto newArg = maybeReinterpret( + context, + callInst->getArg(i), + tryGetInfo(param)); if (newArg != callInst->getArg(i)) { // Replace the argument in the call instruction @@ -642,46 +704,47 @@ struct DynamicInstLoweringContext return changed; } - void processInstForPropagation(IRInst* inst, LinkedList& workQueue) + void processInstForPropagation(IRInst* context, IRInst* inst, LinkedList& workQueue) { PropagationInfo info; switch (inst->getOp()) { case kIROp_CreateExistentialObject: - info = analyzeCreateExistentialObject(as(inst)); + info = analyzeCreateExistentialObject(context, as(inst)); break; case kIROp_MakeExistential: - info = analyzeMakeExistential(as(inst)); + info = analyzeMakeExistential(context, as(inst)); break; case kIROp_LookupWitnessMethod: - info = analyzeLookupWitnessMethod(as(inst)); + info = analyzeLookupWitnessMethod(context, as(inst)); break; case kIROp_ExtractExistentialWitnessTable: - info = - analyzeExtractExistentialWitnessTable(as(inst)); + info = analyzeExtractExistentialWitnessTable( + context, + as(inst)); break; case kIROp_ExtractExistentialType: - info = analyzeExtractExistentialType(as(inst)); + info = analyzeExtractExistentialType(context, as(inst)); break; case kIROp_ExtractExistentialValue: - info = analyzeExtractExistentialValue(as(inst)); + info = analyzeExtractExistentialValue(context, as(inst)); break; case kIROp_Call: - info = analyzeCall(as(inst), workQueue); + info = analyzeCall(context, as(inst), workQueue); break; case kIROp_Specialize: - info = analyzeSpecialize(as(inst)); + info = analyzeSpecialize(context, as(inst)); break; default: - info = analyzeDefault(inst); + info = analyzeDefault(context, inst); break; } - updateInfo(inst, info, workQueue); + updateInfo(context, inst, info, workQueue); } - PropagationInfo analyzeCreateExistentialObject(IRCreateExistentialObject* inst) + PropagationInfo analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { // // TODO: Actually use the integer<->type map present in the linkage to @@ -707,14 +770,14 @@ struct DynamicInstLoweringContext return none(); } - PropagationInfo analyzeMakeExistential(IRMakeExistential* inst) + PropagationInfo analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) { auto witnessTable = inst->getWitnessTable(); auto value = inst->getWrappedValue(); auto valueType = value->getDataType(); // Get the witness table info - auto witnessTableInfo = tryGetInfo(witnessTable); + auto witnessTableInfo = tryGetInfo(context, witnessTable); if (!witnessTableInfo) return none(); @@ -738,12 +801,12 @@ struct DynamicInstLoweringContext return nullptr; // Not found } - PropagationInfo analyzeLookupWitnessMethod(IRLookupWitnessMethod* inst) + PropagationInfo analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { auto key = inst->getRequirementKey(); auto witnessTable = inst->getWitnessTable(); - auto witnessTableInfo = tryGetInfo(witnessTable); + auto witnessTableInfo = tryGetInfo(context, witnessTable); switch (witnessTableInfo.judgment) { @@ -767,10 +830,12 @@ struct DynamicInstLoweringContext } } - PropagationInfo analyzeExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) + PropagationInfo analyzeExtractExistentialWitnessTable( + IRInst* context, + IRExtractExistentialWitnessTable* inst) { auto operand = inst->getOperand(0); - auto operandInfo = tryGetInfo(operand); + auto operandInfo = tryGetInfo(context, operand); switch (operandInfo.judgment) { @@ -796,10 +861,10 @@ struct DynamicInstLoweringContext } } - PropagationInfo analyzeExtractExistentialType(IRExtractExistentialType* inst) + PropagationInfo analyzeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { auto operand = inst->getOperand(0); - auto operandInfo = tryGetInfo(operand); + auto operandInfo = tryGetInfo(context, operand); switch (operandInfo.judgment) { @@ -830,7 +895,7 @@ struct DynamicInstLoweringContext } } - PropagationInfo analyzeExtractExistentialValue(IRExtractExistentialValue* inst) + PropagationInfo analyzeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { // We don't care about the value itself. // (We rely on the propagation info for the type) @@ -839,10 +904,10 @@ struct DynamicInstLoweringContext } - PropagationInfo analyzeSpecialize(IRSpecialize* inst) + PropagationInfo analyzeSpecialize(IRInst* context, IRSpecialize* inst) { auto operand = inst->getOperand(0); - auto operandInfo = tryGetInfo(operand); + auto operandInfo = tryGetInfo(context, operand); switch (operandInfo.judgment) { @@ -862,12 +927,15 @@ struct DynamicInstLoweringContext { // For integer args, add as is (also applies to any value args) if (as(inst->getArg(i))) + { specializationArgs.add(inst->getArg(i)); + continue; + } // For type args, we need to replace any dynamic args with // their sets. // - auto argInfo = tryGetInfo(inst->getArg(i)); + auto argInfo = tryGetInfo(context, inst->getArg(i)); switch (argInfo.judgment) { case PropagationJudgment::None: @@ -895,8 +963,40 @@ struct DynamicInstLoweringContext IRType* typeOfSpecialization = nullptr; if (inst->getDataType()->getParent()->getOp() == kIROp_ModuleInst) typeOfSpecialization = inst->getDataType(); + else if (auto funcType = as(inst->getDataType())) + { + auto substituteSets = [&](IRInst* type) -> IRInst* + { + if (auto info = tryGetInfo(context, type)) + { + if (info.judgment == PropagationJudgment::Set) + { + if (info.getCollectionCount() == 1) + return info.getSingletonValue(); + else + return info.collection; + } + else + return type; + } + else + return type; + }; + + List newParamTypes; + for (auto paramType : funcType->getParamTypes()) + newParamTypes.add((IRType*)substituteSets(paramType)); + IRBuilder builder(module); + builder.setInsertInto(module); + typeOfSpecialization = builder.getFuncType( + newParamTypes.getCount(), + newParamTypes.getBuffer(), + (IRType*)substituteSets(funcType->getResultType())); + } else - SLANG_UNIMPLEMENTED_X("unhandled specialization type in non-global context"); + { + SLANG_ASSERT_FAILURE("Unexpected data type for specialization instruction"); + } // Specialize each element in the set HashSet specializedSet; @@ -922,10 +1022,82 @@ struct DynamicInstLoweringContext } } - PropagationInfo analyzeCall(IRCall* inst, LinkedList& workQueue) + void discoverContext(IRInst* context, LinkedList& workQueue) + { + if (this->availableContexts.add(context)) + { + IRFunc* func = nullptr; + + // Newly discovered context. Initialize it. + switch (context->getOp()) + { + case kIROp_Func: + { + func = cast(context); + + // Initialize the first block parameters + initializeFirstBlockParameters(context, func); + + // Add all blocks to the work queue + for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) + workQueue.addLast(WorkItem(context, block)); + break; + } + case kIROp_Specialize: + { + auto specialize = cast(context); + auto generic = cast(specialize->getBase()); + func = cast(findGenericReturnVal(generic)); + + // Transfer information from specialization arguments to the params in the + // first generic block. + // + IRParam* param = generic->getFirstBlock()->getFirstParam(); + for (auto index = 0; index < specialize->getArgCount() && param; + ++index, param = param->getNextParam()) + { + // Map the specialization argument to the corresponding parameter + auto arg = specialize->getArg(index); + if (as(arg)) + continue; + + if (auto collection = as(arg)) + { + updateInfo( + context, + param, + PropagationInfo(PropagationJudgment::Set, collection), + workQueue); + } + else if (as(arg) || as(arg)) + { + updateInfo(context, param, makeSingletonSet(arg), workQueue); + } + else + { + SLANG_UNEXPECTED("Unexpected argument type in specialization"); + } + } + + // Initialize the first block parameters + initializeFirstBlockParameters(context, func); + + // Add all blocks to the work queue for an initial sweep + for (auto block = generic->getFirstBlock(); block; + block = block->getNextBlock()) + workQueue.addLast(WorkItem(context, block)); + + for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) + workQueue.addLast(WorkItem(context, block)); + } + } + } + } + + PropagationInfo analyzeCall(IRInst* context, IRCall* inst, LinkedList& workQueue) { auto callee = inst->getCallee(); - auto calleeInfo = tryGetInfo(callee); + auto calleeInfo = tryGetInfo(context, callee); auto funcType = as(callee->getDataType()); @@ -937,7 +1109,7 @@ struct DynamicInstLoweringContext // For now, we'll handle just a concrete func. But the logic for multiple functions // is exactly the same (add an edge for each function). // - auto propagateToCallSite = [&](IRFunc* func) + auto propagateToCallSite = [&](IRInst* callee) { // Register the call site in the map to allow for the // return-edge to be created. @@ -946,31 +1118,33 @@ struct DynamicInstLoweringContext // func, since we might have functions that are called indirectly // through lookups. // - this->funcCallSites.addIfNotExists(func, HashSet()); - if (this->funcCallSites[func].add(inst)) + discoverContext(callee, workQueue); + + this->funcCallSites.addIfNotExists(callee, HashSet()); + if (this->funcCallSites[callee].add(inst)) { // If this is a new call site, add a propagation task to the queue (in case there's // already information about this function) - workQueue.addLast(WorkItem(InterproceduralEdge::Direction::FuncToCall, inst, func)); + workQueue.addLast( + WorkItem(InterproceduralEdge::Direction::FuncToCall, context, inst, callee)); } - workQueue.addLast(WorkItem(InterproceduralEdge::Direction::CallToFunc, inst, func)); + workQueue.addLast( + WorkItem(InterproceduralEdge::Direction::CallToFunc, context, inst, callee)); }; if (calleeInfo.judgment == PropagationJudgment::Set) { // If we have a set of functions, register each one - forEachInCollection( - calleeInfo, - [&](IRInst* func) { propagateToCallSite(cast(func)); }); + forEachInCollection(calleeInfo, [&](IRInst* func) { propagateToCallSite(func); }); } - if (auto callInfo = tryGetInfo(inst)) + if (auto callInfo = tryGetInfo(context, inst)) return callInfo; else return none(); } - void propagateWithinFuncEdge(IREdge edge, LinkedList& workQueue) + void propagateWithinFuncEdge(IRInst* context, IREdge edge, LinkedList& workQueue) { // Handle intra-procedural edge (original logic) auto predecessorBlock = edge.getPredecessor(); @@ -997,10 +1171,10 @@ struct DynamicInstLoweringContext if (paramIndex < unconditionalBranch->getArgCount()) { auto arg = unconditionalBranch->getArg(paramIndex); - if (auto argInfo = tryGetInfo(arg)) + if (auto argInfo = tryGetInfo(context, arg)) { // Use centralized update method - updateInfo(param, argInfo, workQueue); + updateInfo(context, param, argInfo, workQueue); } } paramIndex++; @@ -1011,14 +1185,14 @@ struct DynamicInstLoweringContext { // Handle interprocedural edge auto callInst = edge.callInst; - auto targetFunc = edge.targetFunc; + auto targetCallee = edge.targetContext; switch (edge.direction) { case InterproceduralEdge::Direction::CallToFunc: { // Propagate argument info from call site to function parameters - auto firstBlock = targetFunc->getFirstBlock(); + auto firstBlock = targetCallee->getFirstBlock(); if (!firstBlock) return; @@ -1028,10 +1202,10 @@ struct DynamicInstLoweringContext if (argIndex < callInst->getOperandCount()) { auto arg = callInst->getOperand(argIndex); - if (auto argInfo = tryGetInfo(arg)) + if (auto argInfo = tryGetInfo(edge.callerContext, arg)) { // Use centralized update method - updateInfo(param, argInfo, workQueue); + updateInfo(edge.targetContext, param, argInfo, workQueue); } } argIndex++; @@ -1041,11 +1215,11 @@ struct DynamicInstLoweringContext case InterproceduralEdge::Direction::FuncToCall: { // Propagate return value info from function to call site - auto returnInfo = funcReturnInfo.tryGetValue(targetFunc); + auto returnInfo = funcReturnInfo.tryGetValue(targetCallee); if (returnInfo) { // Use centralized update method - updateInfo(callInst, *returnInfo, workQueue); + updateInfo(edge.callerContext, callInst, *returnInfo, workQueue); } break; } @@ -1055,13 +1229,13 @@ struct DynamicInstLoweringContext } } - PropagationInfo getFuncReturnInfo(IRFunc* func) + PropagationInfo getFuncReturnInfo(IRInst* callee) { - funcReturnInfo.addIfNotExists(func, none()); - return funcReturnInfo[func]; + funcReturnInfo.addIfNotExists(callee, none()); + return funcReturnInfo[callee]; } - void initializeFirstBlockParameters(IRFunc* func) + void initializeFirstBlockParameters(IRInst* context, IRFunc* func) { auto firstBlock = func->getFirstBlock(); if (!firstBlock) @@ -1071,20 +1245,20 @@ struct DynamicInstLoweringContext for (auto param : firstBlock->getParams()) { auto paramType = param->getDataType(); - auto paramInfo = tryGetInfo(param); + auto paramInfo = tryGetInfo(context, param); if (paramInfo) continue; // Already has some information if (auto interfaceType = as(paramType)) { if (interfaceType->findDecoration()) - propagationMap[param] = makeUnbounded(); + propagationMap[Element(context, param)] = makeUnbounded(); else - propagationMap[param] = none(); // Initialize to none. + propagationMap[Element(context, param)] = none(); // Initialize to none. } else { - propagationMap[param] = none(); + propagationMap[Element(context, param)] = none(); } } } @@ -1152,7 +1326,7 @@ struct DynamicInstLoweringContext return none(); if (unionJudgment == PropagationJudgment::Set) - if (allValues.getCount() > 1) + if (allValues.getCount() > 0) return makeSet(allValues); else return none(); @@ -1172,7 +1346,7 @@ struct DynamicInstLoweringContext return unionPropagationInfo(infos); } - PropagationInfo analyzeDefault(IRInst* inst) + PropagationInfo analyzeDefault(IRInst* context, IRInst* inst) { // Check if this is a global type, witness table, or function. // If so, it's a concrete element. We'll create a singleton set for it. @@ -1186,9 +1360,9 @@ struct DynamicInstLoweringContext bool performDynamicInstLowering() { // Collect all instructions that need lowering - List typeInstsToLower; - List valueInstsToLower; - List instWithReplacementTypes; + List typeInstsToLower; + List valueInstsToLower; + List instWithReplacementTypes; List funcTypesToProcess; bool hasChanges = false; @@ -1196,6 +1370,7 @@ struct DynamicInstLoweringContext { if (auto func = as(globalInst)) { + auto context = func; // Process each function's instructions for (auto block : func->getBlocks()) { @@ -1209,37 +1384,38 @@ struct DynamicInstLoweringContext case kIROp_LookupWitnessMethod: { if (child->getDataType()->getOp() == kIROp_TypeKind) - typeInstsToLower.add(child); + typeInstsToLower.add(Element(context, child)); else - valueInstsToLower.add(child); + valueInstsToLower.add(Element(context, child)); break; } case kIROp_ExtractExistentialType: - typeInstsToLower.add(child); + typeInstsToLower.add(Element(context, child)); break; case kIROp_ExtractExistentialWitnessTable: case kIROp_ExtractExistentialValue: case kIROp_MakeExistential: case kIROp_CreateExistentialObject: - valueInstsToLower.add(child); + valueInstsToLower.add(Element(context, child)); break; case kIROp_Call: { - if (auto info = tryGetInfo(child)) + if (auto info = tryGetInfo(context, child)) if (info.judgment == PropagationJudgment::Existential) - instWithReplacementTypes.add(child); + instWithReplacementTypes.add(Element(context, child)); - if (auto calleeInfo = tryGetInfo(as(child)->getCallee())) + if (auto calleeInfo = + tryGetInfo(context, as(child)->getCallee())) if (calleeInfo.judgment == PropagationJudgment::Set) - valueInstsToLower.add(child); + valueInstsToLower.add(Element(context, child)); } break; default: - if (auto info = tryGetInfo(child)) + if (auto info = tryGetInfo(context, child)) if (info.judgment == PropagationJudgment::Existential) // If this instruction has a set of types, tables, or funcs, // we need to lower it to a unified type. - instWithReplacementTypes.add(child); + instWithReplacementTypes.add(Element(context, child)); } } } @@ -1248,17 +1424,21 @@ struct DynamicInstLoweringContext } } - for (auto inst : typeInstsToLower) - hasChanges |= lowerInst(inst); + for (auto instWithCtx : typeInstsToLower) + hasChanges |= lowerInst(instWithCtx.context, instWithCtx.inst); for (auto func : funcTypesToProcess) hasChanges |= replaceFuncType(func, this->funcReturnInfo[func]); - for (auto inst : valueInstsToLower) - hasChanges |= lowerInst(inst); + for (auto instWithCtx : valueInstsToLower) + hasChanges |= lowerInst(instWithCtx.context, instWithCtx.inst); - for (auto inst : instWithReplacementTypes) - hasChanges |= replaceType(inst); + for (auto instWithCtx : instWithReplacementTypes) + { + if (instWithCtx.inst->getParent() == nullptr) + continue; + hasChanges |= replaceType(instWithCtx.context, instWithCtx.inst); + } return hasChanges; } @@ -1313,16 +1493,14 @@ struct DynamicInstLoweringContext types.add(concreteType); }); - auto anyValueType = createAnyValueTypeFromInsts(types); - return builder.getTupleType(List({builder.getUIntType(), anyValueType})); + SLANG_ASSERT(types.getCount() > 0); + auto unionType = types.getCount() > 1 ? createAnyValueTypeFromInsts(types) : *types.begin(); + return builder.getTupleType(List({builder.getUIntType(), (IRType*)unionType})); } - bool replaceType(IRInst* inst) + bool replaceType(IRInst* context, IRInst* inst) { - if (inst->getParent() == nullptr) - return false; // Not a valid instruction - - auto info = tryGetInfo(inst); + auto info = tryGetInfo(context, inst); if (!info || info.judgment != PropagationJudgment::Existential) return false; @@ -1330,51 +1508,62 @@ struct DynamicInstLoweringContext return true; } - bool lowerInst(IRInst* inst) + bool lowerInst(IRInst* context, IRInst* inst) { switch (inst->getOp()) { case kIROp_LookupWitnessMethod: - return lowerLookupWitnessMethod(as(inst)); + return lowerLookupWitnessMethod(context, as(inst)); case kIROp_ExtractExistentialWitnessTable: - return lowerExtractExistentialWitnessTable(as(inst)); + return lowerExtractExistentialWitnessTable( + context, + as(inst)); case kIROp_ExtractExistentialType: - return lowerExtractExistentialType(as(inst)); + return lowerExtractExistentialType(context, as(inst)); case kIROp_ExtractExistentialValue: - return lowerExtractExistentialValue(as(inst)); + return lowerExtractExistentialValue(context, as(inst)); case kIROp_Call: - return lowerCall(as(inst)); + return lowerCall(context, as(inst)); case kIROp_MakeExistential: - return lowerMakeExistential(as(inst)); + return lowerMakeExistential(context, as(inst)); case kIROp_CreateExistentialObject: - return lowerCreateExistentialObject(as(inst)); + return lowerCreateExistentialObject(context, as(inst)); default: return false; } } - bool lowerLookupWitnessMethod(IRLookupWitnessMethod* inst) + bool lowerLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { - auto info = tryGetInfo(inst); + auto info = tryGetInfo(context, inst); if (!info) return false; IRBuilder builder(inst); builder.setInsertBefore(inst); - // Check if this is a TypeKind data type with SetOfTypes judgment - if (info.judgment == PropagationJudgment::Set) + if (info.isSingleton()) + { + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(info.getSingletonValue()); + inst->removeAndDeallocate(); + return true; + } + else if (info.judgment == PropagationJudgment::Set) { + // Set of types. if (inst->getDataType()->getOp() == kIROp_TypeKind) { // Create an any-value type based on the set of types - auto anyValueType = createAnyValueTypeFromInsts(collectionToHashSet(info)); + auto typeSet = collectionToHashSet(info); + auto unionType = typeSet.getCount() > 1 ? createAnyValueTypeFromInsts(typeSet) + : *typeSet.begin(); // Store the mapping for later use - loweredInstToAnyValueType[inst] = anyValueType; + loweredInstToAnyValueType[inst] = unionType; // Replace the instruction with the any-value type - inst->replaceUsesWith(anyValueType); + inst->replaceUsesWith(unionType); inst->removeAndDeallocate(); return true; } @@ -1382,7 +1571,7 @@ struct DynamicInstLoweringContext { // Get the witness table operand info auto witnessTableInst = inst->getWitnessTable(); - auto witnessTableInfo = tryGetInfo(witnessTableInst); + auto witnessTableInfo = tryGetInfo(context, witnessTableInst); if (witnessTableInfo.judgment == PropagationJudgment::Set) { @@ -1398,7 +1587,7 @@ struct DynamicInstLoweringContext keyMappingFunc, List({inst->getWitnessTable()})); inst->replaceUsesWith(witnessTableId); - propagationMap[witnessTableId] = info; + propagationMap[Element(witnessTableId)] = info; inst->removeAndDeallocate(); return true; } @@ -1408,31 +1597,40 @@ struct DynamicInstLoweringContext return false; } - bool lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) + bool lowerExtractExistentialWitnessTable( + IRInst* context, + IRExtractExistentialWitnessTable* inst) { - auto info = tryGetInfo(inst); + auto info = tryGetInfo(context, inst); if (!info) return false; IRBuilder builder(inst); builder.setInsertBefore(inst); - if (info.judgment == PropagationJudgment::Set) + if (info.isSingleton()) + { + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(info.getSingletonValue()); + inst->removeAndDeallocate(); + return true; + } + else if (info.judgment == PropagationJudgment::Set) { // Replace with GetElement(loweredInst, 0) -> uint auto operand = inst->getOperand(0); auto element = builder.emitGetTupleElement(builder.getUIntType(), operand, 0); inst->replaceUsesWith(element); - propagationMap[element] = info; + propagationMap[Element(element)] = info; inst->removeAndDeallocate(); return true; } return false; } - bool lowerExtractExistentialValue(IRExtractExistentialValue* inst) + bool lowerExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { - auto operandInfo = tryGetInfo(inst->getOperand(0)); + auto operandInfo = tryGetInfo(context, inst->getOperand(0)); if (!operandInfo || operandInfo.judgment != PropagationJudgment::Existential) return false; @@ -1444,7 +1642,7 @@ struct DynamicInstLoweringContext auto loweredType = loweredInstToAnyValueType.tryGetValue(inst); if (loweredType) { - resultType = *loweredType; + resultType = (IRType*)*loweredType; } // Replace with GetElement(loweredInst, 1) -> AnyValueType @@ -1455,15 +1653,24 @@ struct DynamicInstLoweringContext return true; } - bool lowerExtractExistentialType(IRExtractExistentialType* inst) + bool lowerExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { - auto info = tryGetInfo(inst); + auto info = tryGetInfo(context, inst); if (!info || info.judgment != PropagationJudgment::Set) return false; IRBuilder builder(inst); builder.setInsertBefore(inst); + if (info.isSingleton()) + { + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(info.getSingletonValue()); + inst->removeAndDeallocate(); + loweredInstToAnyValueType[inst] = info.getSingletonValue(); + return true; + } + // Create an any-value type based on the set of types auto anyValueType = createAnyValueTypeFromInsts(collectionToHashSet(info)); @@ -1476,7 +1683,7 @@ struct DynamicInstLoweringContext return true; } - IRFuncType* getExpectedFuncType(IRCall* inst) + IRFuncType* getExpectedFuncType(IRInst* context, IRCall* inst) { // Translate argument types into expected function type. // For now, we handle just 'in' arguments. @@ -1484,7 +1691,7 @@ struct DynamicInstLoweringContext for (UInt i = 0; i < inst->getArgCount(); i++) { auto arg = inst->getArg(i); - if (auto argInfo = tryGetInfo(arg)) + if (auto argInfo = tryGetInfo(context, arg)) { // If the argument is existential, we need to use the type for existential if (argInfo.judgment == PropagationJudgment::Existential) @@ -1499,7 +1706,7 @@ struct DynamicInstLoweringContext // Translate result type. IRType* resultType = inst->getDataType(); - auto returnInfo = tryGetInfo(inst); + auto returnInfo = tryGetInfo(context, inst); if (returnInfo && returnInfo.judgment == PropagationJudgment::Existential) { resultType = getTypeForExistential(returnInfo); @@ -1510,21 +1717,28 @@ struct DynamicInstLoweringContext return builder.getFuncType(argTypes, resultType); } - bool lowerCall(IRCall* inst) + bool lowerCall(IRInst* context, IRCall* inst) { auto callee = inst->getCallee(); - auto calleeInfo = tryGetInfo(callee); + auto calleeInfo = tryGetInfo(context, callee); if (!calleeInfo || calleeInfo.judgment != PropagationJudgment::Set) return false; - if (calleeInfo.isSingleton() && calleeInfo.getSingletonValue() == callee) - return false; + if (calleeInfo.isSingleton()) + { + if (calleeInfo.getSingletonValue() == callee) + return false; + + IRBuilder builder(inst->getModule()); + builder.replaceOperand(inst->getCalleeUse(), calleeInfo.getSingletonValue()); + return true; // Replaced with a single function + } IRBuilder builder(inst); builder.setInsertBefore(inst); - auto expectedFuncType = getExpectedFuncType(inst); + auto expectedFuncType = getExpectedFuncType(context, inst); // Create dispatch function auto dispatchFunc = createDispatchFunc(collectionToHashSet(calleeInfo), expectedFuncType); @@ -1538,16 +1752,16 @@ struct DynamicInstLoweringContext auto newCall = builder.emitCallInst(inst->getDataType(), dispatchFunc, newArgs); inst->replaceUsesWith(newCall); - if (auto info = tryGetInfo(inst)) - propagationMap[newCall] = info; - replaceType(newCall); // "maybe replace type" + if (auto info = tryGetInfo(context, inst)) + propagationMap[Element(newCall)] = info; + replaceType(context, newCall); // "maybe replace type" inst->removeAndDeallocate(); return true; } - bool lowerMakeExistential(IRMakeExistential* inst) + bool lowerMakeExistential(IRInst* context, IRMakeExistential* inst) { - auto info = tryGetInfo(inst); + auto info = tryGetInfo(context, inst); if (!info || info.judgment != PropagationJudgment::Existential) return false; @@ -1574,10 +1788,11 @@ struct DynamicInstLoweringContext }); // Create the appropriate any-value type - auto anyValueType = createAnyValueType(types); + SLANG_ASSERT(types.getCount() > 0); + auto unionType = types.getCount() > 1 ? createAnyValueType(types) : *types.begin(); // Pack the value - auto packedValue = builder.emitPackAnyValue(anyValueType, inst->getWrappedValue()); + auto packedValue = builder.emitPackAnyValue(unionType, inst->getWrappedValue()); // Create tuple (table_unique_id, PackAnyValue(val)) auto tupleType = builder.getTupleType( @@ -1585,17 +1800,17 @@ struct DynamicInstLoweringContext IRInst* tupleArgs[] = {tableId, packedValue}; auto tuple = builder.emitMakeTuple(tupleType, 2, tupleArgs); - if (auto info = tryGetInfo(inst)) - propagationMap[tuple] = info; + if (auto info = tryGetInfo(context, inst)) + propagationMap[Element(tuple)] = info; inst->replaceUsesWith(tuple); inst->removeAndDeallocate(); return true; } - bool lowerCreateExistentialObject(IRCreateExistentialObject* inst) + bool lowerCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { - auto info = tryGetInfo(inst); + auto info = tryGetInfo(context, inst); if (!info || info.judgment != PropagationJudgment::Existential) return false; @@ -1629,8 +1844,8 @@ struct DynamicInstLoweringContext {translatedID, builder.emitReinterpret(existentialTupleType->getOperand(1), inst->getValue())})); - if (auto info = tryGetInfo(inst)) - propagationMap[existentialTuple] = info; + if (auto info = tryGetInfo(context, inst)) + propagationMap[Element(existentialTuple)] = info; inst->replaceUsesWith(existentialTuple); inst->removeAndDeallocate(); @@ -1781,24 +1996,25 @@ struct DynamicInstLoweringContext builder.emitReturn(defaultValue); } - auto maybePackValue = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* + auto maybeReinterpret = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* { - // If the type is AnyValueType, pack the value if (as(type) && !as(value->getDataType())) { return builder->emitPackAnyValue(type, value); } - return value; // Otherwise, return as is - }; - - auto maybeUnpackValue = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* - { - // If the type is AnyValueType, unpack the value - if (as(value->getDataType()) && !as(type)) + else if (as(value->getDataType()) && !as(type)) { return builder->emitUnpackAnyValue(type, value); } - return value; // Otherwise, return as is + else if (value->getDataType() != type) + { + // If the value's type is different from the expected type, reinterpret it + return builder->emitReinterpret(type, value); + } + else + { + return value; // Otherwise, return as is + } }; // Go back to entry block and create switch @@ -1820,7 +2036,7 @@ struct DynamicInstLoweringContext auto concreteFuncType = as(funcInst->getDataType()); for (UIndex ii = 0; ii < originalParams.getCount(); ii++) { - callArgs.add(maybeUnpackValue( + callArgs.add(maybeReinterpret( &builder, originalParams[ii], concreteFuncType->getParamType(ii))); @@ -1836,7 +2052,7 @@ struct DynamicInstLoweringContext } else { - builder.emitReturn(maybePackValue(&builder, callResult, resultType)); + builder.emitReturn(maybeReinterpret(&builder, callResult, resultType)); } caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); @@ -1974,14 +2190,17 @@ struct DynamicInstLoweringContext Dictionary funcReturnInfo; // Mapping from functions to call-sites. - Dictionary> funcCallSites; + Dictionary> funcCallSites; // Unique ID assignment for functions and witness tables Dictionary uniqueIds; UInt nextUniqueId = 1; // Mapping from lowered instruction to their any-value types - Dictionary loweredInstToAnyValueType; + Dictionary loweredInstToAnyValueType; + + // Set of open contexts + HashSet availableContexts; }; // Main entry point diff --git a/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang index e14e3a1c063..4721387f0f5 100644 --- a/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang +++ b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-2.slang @@ -1,10 +1,15 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -RWStructuredBuffer outputBuffer; +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):name=scratchBuffer +RWStructuredBuffer scratchBuffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; interface IBufferRef { - uint8_t[N] get(uint index); - void set(uint index, uint8_t[N] value); + uint[N] get(uint index); + void set(uint index, uint[N] value); This withOffset(uint offset); }; @@ -26,12 +31,11 @@ struct StandardSerializer : ISerializer { static void serialize(T data, IBufferRef buffer) { - buffer.set<4>(0, bit_cast(data)); + buffer.set<1>(0, bit_cast(data)); } - static T deserialize(IBufferRef buffer) { - return bit_cast(buffer.get<4>(0)); + return bit_cast(buffer.get<1>(0)); } }; @@ -54,13 +58,13 @@ struct BDataSerializer : ISerializer static void serialize(BData data, IBufferRef buffer) { StandardSerializer::serialize(data.x, buffer.withOffset(0)); - StandardSerializer::serialize(data.y, buffer.withOffset(4)); + StandardSerializer::serialize(data.y, buffer.withOffset(1)); } static BData deserialize(IBufferRef buffer) { return BData(StandardSerializer::deserialize(buffer.withOffset(0)), - StandardSerializer::deserialize(buffer.withOffset(4))); + StandardSerializer::deserialize(buffer.withOffset(1))); } }; @@ -80,7 +84,7 @@ struct C : ICalculation Data make(float q) { return q; } }; -ICalculation factoryAB(uint id, float x) +ICalculation factoryAB(uint id) { if (id == 0) return A(); @@ -91,53 +95,56 @@ ICalculation factoryAB(uint id, float x) struct BufferRef : IBufferRef { uint offset; - uint8_t[N] get(uint index) + uint[N] get(uint index) { - uint offset = index * N; - uint8_t[N] result; + uint[N] result; for (int i = 0; i < N; ++i) - { - result[i] = outputBuffer[offset + i]; - } + result[i] = scratchBuffer[offset + i]; return result; } - void set(uint index, uint8_t[N] value) + void set(uint index, uint[N] value) { - uint offset = index * N; for (int i = 0; i < N; ++i) - { - outputBuffer[offset + i] = value[i]; - } + scratchBuffer[offset + i] = value[i]; } This withOffset(uint offset) { - return {offset + this.offset}; + return {this.offset + offset}; } }; -IBufferRef getOutputBufferAsRef() +IBufferRef getScratchBufferAsRef() { return BufferRef(0); } -float f(uint id, float x) +void f(uint id, float q) { - let obj = factoryAB(id, x); - obj.Data d = obj.make(x); - obj.DataSerializer::serialize(d, getOutputBufferAsRef()); + let obj = factoryAB(id); + obj.Data d = obj.make(q); + IBufferRef buf = getScratchBufferAsRef(); + buf.set<1>(0, {(uint)id}); + obj.DataSerializer::serialize(d, buf.withOffset(1)); } -float g(uint id, float x) +float g(float x) { - let obj = factoryAB(id, x); - obj.Data d = obj.DataSerializer::deserialize(getOutputBufferAsRef()); + IBufferRef buf = getScratchBufferAsRef(); + uint id = buf.get<1>(0)[0]; + let obj = factoryAB(id); + obj.Data d = obj.DataSerializer::deserialize(buf.withOffset(1)); return obj.calc(d, x); } [numthreads(1, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - f(0, 1.f); - g(0, 1.f); + f(0, 3.f); + outputBuffer[0] = g(2.f); + f(1, 2.f); + outputBuffer[1] = g(2.f); + + // CHECK: 12.0 + // CHECK: 10.0 } \ No newline at end of file From 0f81310529cb98011b877fedf917f19b09300f9a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 28 Jul 2025 14:04:25 -0400 Subject: [PATCH 012/105] Add lowering of specialization contexts by specializing generic parameters into dynamically-dispatched objects --- source/slang/slang-ir-lower-dynamic-insts.cpp | 404 +++++++++++++++--- .../dynamic-specialization.slang | 49 +++ 2 files changed, 398 insertions(+), 55 deletions(-) create mode 100644 tests/language-feature/dynamic-dispatch/dynamic-specialization.slang diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index a100a4be550..26f3beed9d1 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -1,6 +1,7 @@ #include "slang-ir-lower-dynamic-insts.h" #include "slang-ir-any-value-marshalling.h" +#include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -400,9 +401,13 @@ struct DynamicInstLoweringContext PropagationInfo tryGetInfo(IRInst* context, IRInst* inst) { + if (!inst->getParent()) + return none(); + // If this is a global instruction (parent is module), return concrete info if (as(inst->getParent())) - if (as(inst) || as(inst) || as(inst)) + if (as(inst) || as(inst) || as(inst) || + as(inst)) return makeSingletonSet(inst); else return none(); @@ -939,9 +944,8 @@ struct DynamicInstLoweringContext switch (argInfo.judgment) { case PropagationJudgment::None: + return PropagationJudgment::None; // Can't determine the result just yet. case PropagationJudgment::Unbounded: - SLANG_UNEXPECTED( - "Unexpected PropagationJudgment for specialization argument"); case PropagationJudgment::Existential: SLANG_UNEXPECTED( "Unexpected Existential operand in specialization argument. Should be " @@ -1121,7 +1125,7 @@ struct DynamicInstLoweringContext discoverContext(callee, workQueue); this->funcCallSites.addIfNotExists(callee, HashSet()); - if (this->funcCallSites[callee].add(inst)) + if (this->funcCallSites[callee].add(Element(context, inst))) { // If this is a new call site, add a propagation task to the queue (in case there's // already information about this function) @@ -1357,7 +1361,7 @@ struct DynamicInstLoweringContext return none(); // Default case, no propagation info } - bool performDynamicInstLowering() + bool lowerInstsInFunc(IRFunc* func) { // Collect all instructions that need lowering List typeInstsToLower; @@ -1366,72 +1370,72 @@ struct DynamicInstLoweringContext List funcTypesToProcess; bool hasChanges = false; - for (auto globalInst : module->getGlobalInsts()) + auto context = func; + // Process each function's instructions + for (auto block : func->getBlocks()) { - if (auto func = as(globalInst)) + for (auto child : block->getChildren()) { - auto context = func; - // Process each function's instructions - for (auto block : func->getBlocks()) + if (as(child)) + continue; // Skip parameters and terminators + + switch (child->getOp()) { - for (auto child : block->getChildren()) + case kIROp_LookupWitnessMethod: { - if (as(child)) - continue; // Skip parameters and terminators - - switch (child->getOp()) - { - case kIROp_LookupWitnessMethod: - { - if (child->getDataType()->getOp() == kIROp_TypeKind) - typeInstsToLower.add(Element(context, child)); - else - valueInstsToLower.add(Element(context, child)); - break; - } - case kIROp_ExtractExistentialType: + if (child->getDataType()->getOp() == kIROp_TypeKind) typeInstsToLower.add(Element(context, child)); - break; - case kIROp_ExtractExistentialWitnessTable: - case kIROp_ExtractExistentialValue: - case kIROp_MakeExistential: - case kIROp_CreateExistentialObject: + else valueInstsToLower.add(Element(context, child)); - break; - case kIROp_Call: - { - if (auto info = tryGetInfo(context, child)) - if (info.judgment == PropagationJudgment::Existential) - instWithReplacementTypes.add(Element(context, child)); - - if (auto calleeInfo = - tryGetInfo(context, as(child)->getCallee())) - if (calleeInfo.judgment == PropagationJudgment::Set) - valueInstsToLower.add(Element(context, child)); - } - break; - default: - if (auto info = tryGetInfo(context, child)) - if (info.judgment == PropagationJudgment::Existential) - // If this instruction has a set of types, tables, or funcs, - // we need to lower it to a unified type. - instWithReplacementTypes.add(Element(context, child)); - } + break; } - } + case kIROp_ExtractExistentialType: + typeInstsToLower.add(Element(context, child)); + break; + case kIROp_ExtractExistentialWitnessTable: + case kIROp_ExtractExistentialValue: + case kIROp_MakeExistential: + case kIROp_CreateExistentialObject: + valueInstsToLower.add(Element(context, child)); + break; + case kIROp_Call: + { + auto callee = as(child)->getCallee(); + if (auto info = tryGetInfo(context, child)) + if (info.judgment == PropagationJudgment::Existential) + instWithReplacementTypes.add(Element(context, child)); + + if (auto calleeInfo = tryGetInfo(context, callee)) + if (calleeInfo.judgment == PropagationJudgment::Set) + valueInstsToLower.add(Element(context, child)); - funcTypesToProcess.add(func); + if (as(callee)) + valueInstsToLower.add(Element(context, child)); + } + break; + default: + if (auto info = tryGetInfo(context, child)) + if (info.judgment == PropagationJudgment::Existential) + // If this instruction has a set of types, tables, or funcs, + // we need to lower it to a unified type. + instWithReplacementTypes.add(Element(context, child)); + } } } for (auto instWithCtx : typeInstsToLower) + { + if (instWithCtx.inst->getParent() == nullptr) + continue; hasChanges |= lowerInst(instWithCtx.context, instWithCtx.inst); - - for (auto func : funcTypesToProcess) - hasChanges |= replaceFuncType(func, this->funcReturnInfo[func]); + } for (auto instWithCtx : valueInstsToLower) + { + if (instWithCtx.inst->getParent() == nullptr) + continue; hasChanges |= lowerInst(instWithCtx.context, instWithCtx.inst); + } for (auto instWithCtx : instWithReplacementTypes) { @@ -1443,6 +1447,54 @@ struct DynamicInstLoweringContext return hasChanges; } + bool performDynamicInstLowering() + { + List funcsForTypeReplacement; + List funcsToProcess; + + for (auto globalInst : module->getGlobalInsts()) + if (auto func = as(globalInst)) + { + funcsForTypeReplacement.add(func); + funcsToProcess.add(func); + } + + bool hasChanges = false; + do + { + while (funcsForTypeReplacement.getCount() > 0) + { + auto func = funcsForTypeReplacement.getLast(); + funcsForTypeReplacement.removeLast(); + + // Replace the function type with a concrete type if it has existential return types + hasChanges |= replaceFuncType(func, this->funcReturnInfo[func]); + } + + while (funcsToProcess.getCount() > 0) + { + auto func = funcsToProcess.getLast(); + funcsToProcess.removeLast(); + + // Lower the instructions in the function + hasChanges |= lowerInstsInFunc(func); + } + + // The above loops might have added new contexts to lower. + for (auto context : this->contextsToLower) + { + hasChanges |= lowerContext(context); + auto newFunc = cast(this->loweredContexts[context]); + funcsForTypeReplacement.add(newFunc); + funcsToProcess.add(newFunc); + } + this->contextsToLower.clear(); + + } while (funcsForTypeReplacement.getCount() > 0 || funcsToProcess.getCount() > 0); + + return hasChanges; + } + bool replaceFuncType(IRFunc* func, PropagationInfo& returnTypeInfo) { IRFuncType* origFuncType = as(func->getFullType()); @@ -1717,6 +1769,239 @@ struct DynamicInstLoweringContext return builder.getFuncType(argTypes, resultType); } + bool isDynamicGeneric(IRInst* callee) + { + // If the callee is a specialization, and at least one of its arguments + // is a type-flow-collection, then it is a dynamic generic. + // + if (auto specialize = as(callee)) + { + for (UInt i = 0; i < specialize->getArgCount(); i++) + { + auto arg = specialize->getArg(i); + if (as(arg)) + return true; // Found a type-flow-collection argument + } + return false; // No type-flow-collection arguments found + } + + return false; + } + + bool lowerContext(IRInst* context) + { + auto specializeInst = cast(context); + auto generic = cast(specializeInst->getBase()); + auto genericReturnVal = findGenericReturnVal(generic); + + IRBuilder builder(module); + builder.setInsertInto(module); + + // Let's start by creating the function itself. + auto loweredFunc = builder.createFunc(); + builder.setInsertInto(loweredFunc); + builder.setInsertInto(builder.emitBlock()); + // loweredFunc->setFullType(context->getFullType()); + + IRCloneEnv cloneEnv; + Index argIndex = 0; + UCount extraIndices = 0; + // Map the generic's parameters to the specialized arguments. + for (auto param : generic->getFirstBlock()->getParams()) + { + auto specArg = specializeInst->getArg(argIndex++); + if (as(specArg)) + { + // We're dealing with a set of types. + if (as(param->getDataType())) + { + HashSet collectionSet; + for (auto index = 0; index < specArg->getOperandCount(); index++) + { + auto operand = specArg->getOperand(index); + collectionSet.add(operand); + } + + auto unionType = createAnyValueTypeFromInsts(collectionSet); + cloneEnv.mapOldValToNew[param] = unionType; + } + else if (as(param->getDataType())) + { + // Add an integer param to the func. + cloneEnv.mapOldValToNew[param] = builder.emitParam(builder.getUIntType()); + extraIndices++; + } + } + else + { + // For everything else, just set the parameter type to the argument; + SLANG_ASSERT(specArg->getParent()->getOp() == kIROp_ModuleInst); + cloneEnv.mapOldValToNew[param] = specArg; + } + } + + // Clone in the rest of the generic's body including the blocks of the returned func. + for (auto inst = generic->getFirstBlock()->getFirstOrdinaryInst(); inst; + inst = inst->getNextInst()) + { + if (inst == genericReturnVal) + { + auto returnedFunc = cast(inst); + auto funcFirstBlock = returnedFunc->getFirstBlock(); + + // cloneEnv.mapOldValToNew[funcFirstBlock] = loweredFunc->getFirstBlock(); + builder.setInsertInto(loweredFunc); + for (auto block : returnedFunc->getBlocks()) + { + // Merge the first block of the generic with the first block of the + // returned function to merge the parameter lists. + // + // if (block != funcFirstBlock) + //{ + cloneEnv.mapOldValToNew[block] = + cloneInstAndOperands(&cloneEnv, &builder, block); + //} + } + + builder.setInsertInto(loweredFunc->getFirstBlock()); + builder.emitBranch(as(cloneEnv.mapOldValToNew[funcFirstBlock])); + + for (auto param : funcFirstBlock->getParams()) + { + // Clone the parameters of the first block. + builder.setInsertAfter(loweredFunc->getFirstBlock()->getLastParam()); + cloneInst(&cloneEnv, &builder, param); + } + + builder.setInsertInto(as(cloneEnv.mapOldValToNew[funcFirstBlock])); + for (auto inst = funcFirstBlock->getFirstOrdinaryInst(); inst; + inst = inst->getNextInst()) + { + // Clone the instructions in the first block. + cloneInst(&cloneEnv, &builder, inst); + } + + for (auto block : returnedFunc->getBlocks()) + { + if (block == funcFirstBlock) + continue; // Already cloned the first block + cloneInstDecorationsAndChildren( + &cloneEnv, + builder.getModule(), + block, + cloneEnv.mapOldValToNew[block]); + } + + builder.setInsertInto(builder.getModule()); + auto loweredFuncType = as( + cloneInst(&cloneEnv, &builder, as(returnedFunc->getFullType()))); + + // Add extra indices to the func-type parameters + List funcTypeParams; + for (Index i = 0; i < extraIndices; i++) + funcTypeParams.add(builder.getUIntType()); + + for (auto paramType : loweredFuncType->getParamTypes()) + funcTypeParams.add(paramType); + + // Set the new function type with the extra indices + loweredFunc->setFullType( + builder.getFuncType(funcTypeParams, loweredFuncType->getResultType())); + } + else if (!as(inst)) + { + // Keep cloning insts in the generic + cloneInst(&cloneEnv, &builder, inst); + } + } + + // Transfer propagation info. + for (auto& [oldVal, newVal] : cloneEnv.mapOldValToNew) + { + if (propagationMap.containsKey(Element(context, oldVal))) + { + // If we have propagation info for the old value, transfer it to the new value + if (auto info = propagationMap[Element(context, oldVal)]) + { + if (newVal->getParent()->getOp() != kIROp_ModuleInst) + propagationMap[Element(loweredFunc, newVal)] = info; + } + } + } + + // Transfer func-return value info. + if (this->funcReturnInfo.containsKey(context)) + { + this->funcReturnInfo[loweredFunc] = this->funcReturnInfo[context]; + } + + context->replaceUsesWith(loweredFunc); + // context->removeAndDeallocate(); + this->loweredContexts[context] = loweredFunc; + return true; + } + + IRInst* getCalleeForContext(IRInst* context) + { + if (this->contextsToLower.contains(context)) + return context; // Not lowered yet. + + if (this->loweredContexts.containsKey(context)) + return this->loweredContexts[context]; + else + this->contextsToLower.add(context); + + return context; + } + + bool lowerCallToDynamicGeneric(IRInst* context, IRCall* inst) + { + auto specializedCallee = as(inst->getCallee()); + auto targetContext = tryGetInfo(context, specializedCallee).getSingletonValue(); + + List callArgs; + for (auto ii = 0; ii < specializedCallee->getArgCount(); ii++) + { + auto specArg = specializedCallee->getArg(ii); + auto argInfo = tryGetInfo(context, specArg); + if (argInfo.judgment == PropagationJudgment::Set) + { + auto collection = as(argInfo.collection); + if (as(collection->getOperand(0))) + { + // Needs an index (spec-arg will carry an index, we'll + // just need to append it to the call) + // + callArgs.add(specArg); + } + else if (as(collection->getOperand(0))) + { + // Needs no dynamic information. Skip. + } + else + { + // If it's a witness table, we need to handle it differently + // For now, we will not lower this case. + SLANG_UNEXPECTED("Unhandled type-flow-collection in dynamic generic call"); + } + } + } + + for (auto ii = 0; ii < inst->getArgCount(); ii++) + callArgs.add(inst->getArg(ii)); + + IRBuilder builder(inst->getModule()); + // builder.replaceOperand(inst->getCalleeUse(), specializedCallee); + builder.setInsertBefore(inst); + auto newCallInst = builder.emitCallInst( + as(targetContext->getDataType())->getResultType(), + getCalleeForContext(targetContext), + callArgs); + inst->replaceUsesWith(newCallInst); + inst->removeAndDeallocate(); + return true; + } + bool lowerCall(IRInst* context, IRCall* inst) { auto callee = inst->getCallee(); @@ -1727,6 +2012,9 @@ struct DynamicInstLoweringContext if (calleeInfo.isSingleton()) { + if (isDynamicGeneric(calleeInfo.getSingletonValue())) + return lowerCallToDynamicGeneric(context, inst); + if (calleeInfo.getSingletonValue() == callee) return false; @@ -2201,6 +2489,12 @@ struct DynamicInstLoweringContext // Set of open contexts HashSet availableContexts; + + // Contexts requiring lowering + HashSet contextsToLower; + + // Lowered contexts. + Dictionary loweredContexts; }; // Main entry point diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization.slang new file mode 100644 index 00000000000..6be5c9c6acb --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + static float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + static float calc(float x) { return x; } +}; + +float calc(T obj, float y) +{ + return T::calc(y); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else + obj = B(); + + return calc(obj, x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 +} \ No newline at end of file From 9c53d1fab3dd04f617c9cd4fb1f4d47aa879f15f Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 28 Jul 2025 14:45:49 -0400 Subject: [PATCH 013/105] Update slang-emit.cpp --- source/slang/slang-emit.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index dad2db1fc62..ccb055e4d69 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1045,12 +1045,10 @@ Result linkAndOptimizeIR( if (!codeGenContext->isSpecializationDisabled()) { SpecializationOptions specOptions; - specOptions.lowerWitnessLookups = false; + specOptions.lowerWitnessLookups = true; specializeModule(targetProgram, irModule, codeGenContext->getSink(), specOptions); } - lowerDynamicInsts(irModule, sink); - finalizeSpecialization(irModule); // Lower `Result` types into ordinary struct types. This must happen @@ -2205,10 +2203,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr if (sourceMap) { - auto sourceMapArtifact = ArtifactUtil::createArtifact(ArtifactDesc::make( - ArtifactKind::Json, - ArtifactPayload::SourceMap, - ArtifactStyle::None)); + auto sourceMapArtifact = ArtifactUtil::createArtifact( + ArtifactDesc::make( + ArtifactKind::Json, + ArtifactPayload::SourceMap, + ArtifactStyle::None)); sourceMapArtifact->addRepresentation(sourceMap); From c3f57dff973493fc41444fdda3b2ebe3a8bb6320 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 29 Jul 2025 12:53:48 -0400 Subject: [PATCH 014/105] Support data-flow and lowering for inout/out function parameters --- source/slang/slang-ir-lower-dynamic-insts.cpp | 323 ++++++++++++--- .../slang/slang-ir-witness-table-wrapper.cpp | 368 ++++++++++-------- source/slang/slang-ir-witness-table-wrapper.h | 8 + ...n.slang => dynamic-specialization-1.slang} | 0 .../dynamic-specialization-2.slang | 74 ++++ .../dynamic-specialization-3.slang | 61 +++ .../dynamic-dispatch/inout.slang | 76 ++++ 7 files changed, 710 insertions(+), 200 deletions(-) rename tests/language-feature/dynamic-dispatch/{dynamic-specialization.slang => dynamic-specialization-1.slang} (100%) create mode 100644 tests/language-feature/dynamic-dispatch/dynamic-specialization-2.slang create mode 100644 tests/language-feature/dynamic-dispatch/dynamic-specialization-3.slang create mode 100644 tests/language-feature/dynamic-dispatch/inout.slang diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 26f3beed9d1..5d1fec785f7 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -4,6 +4,7 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" +#include "slang-ir-witness-table-wrapper.h" #include "slang-ir.h" namespace Slang @@ -483,6 +484,27 @@ struct DynamicInstLoweringContext { updateFuncReturnInfo(context, info, workQueue); } + + // If the user is a top-level inout/out parameter, we need to handle it + // like we would a func-return. + // + if (auto param = as(user)) + { + auto paramBlock = as(param->getParent()); + auto paramFunc = as(paramBlock->getParent()); + if (paramFunc && paramFunc->getFirstBlock() == paramBlock) + { + if (this->funcCallSites.containsKey(context)) + for (auto callSite : this->funcCallSites[context]) + { + workQueue.addLast(WorkItem( + InterproceduralEdge::Direction::FuncToCall, + callSite.context, + as(callSite.inst), + context)); + } + } + } } } @@ -671,11 +693,14 @@ struct DynamicInstLoweringContext } List callInsts; + List storeInsts; // Collect all call instructions in this block for (auto inst : block->getChildren()) { if (auto callInst = as(inst)) callInsts.add(callInst); + else if (auto storeInst = as(inst)) + storeInsts.add(storeInst); } // Look at all the args and reinterpret them if necessary @@ -702,6 +727,21 @@ struct DynamicInstLoweringContext } } } + + // Look at all the stores and reinterpret them if necessary + for (auto storeInst : storeInsts) + { + auto newValToStore = maybeReinterpret( + context, + storeInst->getVal(), + tryGetInfo(storeInst->getPtr())); + if (newValToStore != storeInst->getVal()) + { + // Replace the value in the store instruction + changed = true; + storeInst->setOperand(1, newValToStore); + } + } } } } @@ -741,6 +781,12 @@ struct DynamicInstLoweringContext case kIROp_Specialize: info = analyzeSpecialize(context, as(inst)); break; + case kIROp_Load: + info = analyzeLoad(context, as(inst)); + break; + case kIROp_Store: + info = analyzeStore(context, as(inst), workQueue); + break; default: info = analyzeDefault(context, inst); break; @@ -806,6 +852,24 @@ struct DynamicInstLoweringContext return nullptr; // Not found } + PropagationInfo analyzeLoad(IRInst* context, IRLoad* loadInst) + { + // Transfer the prop info from the address to the loaded value + auto address = loadInst->getPtr(); + return tryGetInfo(context, address); + } + + PropagationInfo analyzeStore( + IRInst* context, + IRStore* storeInst, + LinkedList& workQueue) + { + // Transfer the prop info from stored value to the address + auto address = storeInst->getPtr(); + updateInfo(context, address, tryGetInfo(context, storeInst->getVal()), workQueue); + return none(); // The store itself doesn't have any info. + } + PropagationInfo analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { auto key = inst->getRequirementKey(); @@ -1185,6 +1249,60 @@ struct DynamicInstLoweringContext } } + List getParamInfos(IRInst* context) + { + List infos; + if (as(context)) + { + for (auto param : as(context)->getParams()) + infos.add(tryGetInfo(context, param)); + } + else if (auto specialize = as(context)) + { + auto generic = specialize->getBase(); + auto innerFunc = getGenericReturnVal(generic); + for (auto param : as(innerFunc)->getParams()) + infos.add(tryGetInfo(context, param)); + } + else + { + // If it's not a function or a specialization, we can't get parameter info + SLANG_UNEXPECTED("Unexpected context type for parameter info retrieval"); + } + + return infos; + } + + List getParamDirections(IRInst* context) + { + List directions; + if (as(context)) + { + for (auto param : as(context)->getParams()) + { + const auto [direction, type] = getParameterDirectionAndType(param->getDataType()); + directions.add(direction); + } + } + else if (auto specialize = as(context)) + { + auto generic = specialize->getBase(); + auto innerFunc = getGenericReturnVal(generic); + for (auto param : as(innerFunc)->getParams()) + { + const auto [direction, type] = getParameterDirectionAndType(param->getDataType()); + directions.add(direction); + } + } + else + { + // If it's not a function or a specialization, we can't get parameter info + SLANG_UNEXPECTED("Unexpected context type for parameter info retrieval"); + } + + return directions; + } + void propagateInterproceduralEdge(InterproceduralEdge edge, LinkedList& workQueue) { // Handle interprocedural edge @@ -1225,6 +1343,25 @@ struct DynamicInstLoweringContext // Use centralized update method updateInfo(edge.callerContext, callInst, *returnInfo, workQueue); } + + // Also update infos of any out parameters + auto paramInfos = getParamInfos(edge.targetContext); + auto paramDirections = getParamDirections(edge.targetContext); + UIndex argIndex = 0; + for (auto paramInfo : paramInfos) + { + if (paramDirections[argIndex] == kParameterDirection_Out || + paramDirections[argIndex] == kParameterDirection_InOut) + { + updateInfo( + edge.callerContext, + callInst->getArg(argIndex), + paramInfo, + workQueue); + } + argIndex++; + } + break; } default: @@ -1556,7 +1693,16 @@ struct DynamicInstLoweringContext if (!info || info.judgment != PropagationJudgment::Existential) return false; - inst->setFullType(getTypeForExistential(info)); + if (auto ptrType = as(inst->getDataType())) + { + IRBuilder builder(module); + inst->setFullType( + builder.getPtrTypeWithAddressSpace(getTypeForExistential(info), ptrType)); + } + else + { + inst->setFullType(getTypeForExistential(info)); + } return true; } @@ -1735,25 +1881,86 @@ struct DynamicInstLoweringContext return true; } + // Split into direction and type + std::tuple getParameterDirectionAndType(IRType* paramType) + { + if (as(paramType)) + return { + ParameterDirection::kParameterDirection_Out, + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirection::kParameterDirection_InOut, + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirection::kParameterDirection_Ref, + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirection::kParameterDirection_ConstRef, + as(paramType)->getValueType()}; + else + return {ParameterDirection::kParameterDirection_In, paramType}; + } + IRFuncType* getExpectedFuncType(IRInst* context, IRCall* inst) { + IRBuilder builder(module); + builder.setInsertInto(module); + + // We'll retreive just the parameter directions from the callee's func-type, + // since that can't be different before & after the type-flow lowering. + // + List paramDirections; + auto calleeInfo = tryGetInfo(context, inst->getCallee()); + auto funcType = as(calleeInfo.getCollectionElement(0)->getDataType()); + for (auto paramType : funcType->getParamTypes()) + { + auto [direction, type] = getParameterDirectionAndType(paramType); + paramDirections.add(direction); + } + // Translate argument types into expected function type. - // For now, we handle just 'in' arguments. - List argTypes; + List paramTypes; for (UInt i = 0; i < inst->getArgCount(); i++) { auto arg = inst->getArg(i); - if (auto argInfo = tryGetInfo(context, arg)) + + switch (paramDirections[i]) { - // If the argument is existential, we need to use the type for existential - if (argInfo.judgment == PropagationJudgment::Existential) + case ParameterDirection::kParameterDirection_In: { - argTypes.add(getTypeForExistential(argInfo)); - continue; + auto argInfo = tryGetInfo(context, arg); + if (argInfo.judgment == PropagationJudgment::Existential) + paramTypes.add(getTypeForExistential(argInfo)); + else + paramTypes.add(arg->getDataType()); + break; + } + case ParameterDirection::kParameterDirection_Out: + { + auto argInfo = tryGetInfo(context, arg); + if (argInfo.judgment == PropagationJudgment::Existential) + paramTypes.add(builder.getOutType(getTypeForExistential(argInfo))); + else + paramTypes.add(builder.getOutType( + as(arg->getDataType())->getValueType())); + break; } + case ParameterDirection::kParameterDirection_InOut: + { + auto argInfo = tryGetInfo(context, arg); + if (argInfo.judgment == PropagationJudgment::Existential) + paramTypes.add(builder.getInOutType(getTypeForExistential(argInfo))); + else + paramTypes.add(builder.getInOutType( + as(arg->getDataType())->getValueType())); + break; + } + default: + SLANG_UNEXPECTED("Unhandled parameter direction in getExpectedFuncType"); } - - argTypes.add(arg->getDataType()); } // Translate result type. @@ -1764,9 +1971,7 @@ struct DynamicInstLoweringContext resultType = getTypeForExistential(returnInfo); } - IRBuilder builder(module); - builder.setInsertInto(module); - return builder.getFuncType(argTypes, resultType); + return builder.getFuncType(paramTypes, resultType); } bool isDynamicGeneric(IRInst* callee) @@ -2056,9 +2261,23 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - // Get unique ID for the witness table. - auto witnessTable = cast(inst->getWitnessTable()); - auto tableId = builder.getIntValue(builder.getUIntType(), getUniqueID(witnessTable)); + auto witnessTableInfo = tryGetInfo(context, inst->getWitnessTable()); + if (witnessTableInfo.judgment != PropagationJudgment::Set) + return false; // Witness table must be a set of tables + + IRInst* witnessTableID = nullptr; + if (witnessTableInfo.isSingleton()) + { + // Get unique ID for the witness table. + witnessTableID = builder.getIntValue( + builder.getUIntType(), + getUniqueID(witnessTableInfo.getSingletonValue())); + } + else + { + // Dynamic. Use the witness table inst as an integer key. + witnessTableID = inst->getWitnessTable(); + } // Collect types from the witness tables to determine the any-value type HashSet types; @@ -2085,7 +2304,7 @@ struct DynamicInstLoweringContext // Create tuple (table_unique_id, PackAnyValue(val)) auto tupleType = builder.getTupleType( List({builder.getUIntType(), packedValue->getDataType()})); - IRInst* tupleArgs[] = {tableId, packedValue}; + IRInst* tupleArgs[] = {witnessTableID, packedValue}; auto tuple = builder.emitMakeTuple(tupleType, 2, tupleArgs); if (auto info = tryGetInfo(context, inst)) @@ -2284,27 +2503,6 @@ struct DynamicInstLoweringContext builder.emitReturn(defaultValue); } - auto maybeReinterpret = [&](IRBuilder* builder, IRInst* value, IRType* type) -> IRInst* - { - if (as(type) && !as(value->getDataType())) - { - return builder->emitPackAnyValue(type, value); - } - else if (as(value->getDataType()) && !as(type)) - { - return builder->emitUnpackAnyValue(type, value); - } - else if (value->getDataType() != type) - { - // If the value's type is different from the expected type, reinterpret it - return builder->emitReinterpret(type, value); - } - else - { - return value; // Otherwise, return as is - } - }; - // Go back to entry block and create switch builder.setInsertInto(entryBlock); @@ -2316,23 +2514,23 @@ struct DynamicInstLoweringContext { auto funcId = getUniqueID(funcInst); + auto wrapperFunc = + emitWitnessTableWrapper(funcInst->getModule(), funcInst, expectedFuncType); + // Create case block auto caseBlock = builder.emitBlock(); builder.setInsertInto(caseBlock); List callArgs; - auto concreteFuncType = as(funcInst->getDataType()); + auto wrappedFuncType = as(wrapperFunc->getDataType()); for (UIndex ii = 0; ii < originalParams.getCount(); ii++) { - callArgs.add(maybeReinterpret( - &builder, - originalParams[ii], - concreteFuncType->getParamType(ii))); + callArgs.add(originalParams[ii]); } // Call the specific function auto callResult = - builder.emitCallInst(concreteFuncType->getResultType(), funcInst, callArgs); + builder.emitCallInst(wrappedFuncType->getResultType(), wrapperFunc, callArgs); if (resultType->getOp() == kIROp_VoidType) { @@ -2340,7 +2538,7 @@ struct DynamicInstLoweringContext } else { - builder.emitReturn(maybeReinterpret(&builder, callResult, resultType)); + builder.emitReturn(callResult); } caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); @@ -2444,6 +2642,37 @@ struct DynamicInstLoweringContext return tables; } + bool lowerCollectionTypes() + { + bool hasChanges = false; + + // Lower all global scope ``IRTypeFlowCollection`` objects that + // are made up of types. + // + for (auto inst : module->getGlobalInsts()) + { + if (auto collection = as(inst)) + { + if (collection->getOp() == kIROp_TypeFlowCollection) + { + HashSet types; + for (UInt i = 0; i < collection->getOperandCount(); i++) + { + if (auto type = as(collection->getOperand(i))) + { + types.add(type); + } + } + auto anyValueType = createAnyValueType(types); + collection->replaceUsesWith(anyValueType); + hasChanges = true; + } + } + } + + return hasChanges; + } + bool processModule() { bool hasChanges = false; @@ -2459,6 +2688,10 @@ struct DynamicInstLoweringContext // Phase 2: Dynamic Instruction Lowering hasChanges |= performDynamicInstLowering(); + // Phase 3: Lower collection types. + if (hasChanges) + lowerCollectionTypes(); + return hasChanges; } diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index cfda6225edd..7c573f867e7 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -10,165 +10,15 @@ namespace Slang { struct GenericsLoweringContext; +IRFunc* emitWitnessTableWrapper( + IRModule* module, + IRInst* funcInst, + IRInst* interfaceRequirementVal); struct GenerateWitnessTableWrapperContext { SharedGenericsLoweringContext* sharedContext; - // Represents a work item for packing `inout` or `out` arguments after a concrete call. - struct ArgumentPackWorkItem - { - // A `AnyValue` typed destination. - IRInst* dstArg = nullptr; - // A concrete value to be packed. - IRInst* concreteArg = nullptr; - }; - - // Unpack an `arg` of `IRAnyValue` into concrete type if necessary, to make it feedable into the - // parameter. If `arg` represents a AnyValue typed variable passed in to a concrete `out` - // parameter, this function indicates that it needs to be packed after the call by setting - // `packAfterCall`. - IRInst* maybeUnpackArg( - IRBuilder* builder, - IRType* paramType, - IRInst* arg, - ArgumentPackWorkItem& packAfterCall) - { - packAfterCall.dstArg = nullptr; - packAfterCall.concreteArg = nullptr; - - // If either paramType or argType is a pointer type - // (because of `inout` or `out` modifiers), we extract - // the underlying value type first. - IRType* paramValType = paramType; - IRType* argValType = arg->getDataType(); - IRInst* argVal = arg; - if (auto ptrType = as(paramType)) - { - paramValType = ptrType->getValueType(); - } - auto argType = arg->getDataType(); - if (auto argPtrType = as(argType)) - { - argValType = argPtrType->getValueType(); - argVal = builder->emitLoad(arg); - } - - // Unpack `arg` if the parameter expects concrete type but - // `arg` is an AnyValue. - if (!as(paramValType) && as(argValType)) - { - auto unpackedArgVal = builder->emitUnpackAnyValue(paramValType, argVal); - // if parameter expects an `out` pointer, store the unpacked val into a - // variable and pass in a pointer to that variable. - if (as(paramType)) - { - auto tempVar = builder->emitVar(paramValType); - builder->emitStore(tempVar, unpackedArgVal); - // tempVar needs to be unpacked into original var after the call. - packAfterCall.dstArg = arg; - packAfterCall.concreteArg = tempVar; - return tempVar; - } - else - { - return unpackedArgVal; - } - } - return arg; - } - - IRStringLit* _getWitnessTableWrapperFuncName(IRFunc* func) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(func); - if (auto linkageDecoration = func->findDecoration()) - { - return builder->getStringValue( - (String(linkageDecoration->getMangledName()) + "_wtwrapper").getUnownedSlice()); - } - if (auto namehintDecoration = func->findDecoration()) - { - return builder->getStringValue( - (String(namehintDecoration->getName()) + "_wtwrapper").getUnownedSlice()); - } - return nullptr; - } - - IRFunc* emitWitnessTableWrapper(IRFunc* func, IRInst* interfaceRequirementVal) - { - auto funcTypeInInterface = cast(interfaceRequirementVal); - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(func); - - auto wrapperFunc = builder->createFunc(); - wrapperFunc->setFullType((IRType*)interfaceRequirementVal); - if (auto name = _getWitnessTableWrapperFuncName(func)) - builder->addNameHintDecoration(wrapperFunc, name); - - builder->setInsertInto(wrapperFunc); - auto block = builder->emitBlock(); - builder->setInsertInto(block); - - ShortList params; - for (UInt i = 0; i < funcTypeInInterface->getParamCount(); i++) - { - params.add(builder->emitParam(funcTypeInInterface->getParamType(i))); - } - - List args; - List argsToPack; - - SLANG_ASSERT(params.getCount() == (Index)func->getParamCount()); - for (UInt i = 0; i < func->getParamCount(); i++) - { - auto wrapperParam = params[i]; - // Type of the parameter in the callee. - auto funcParamType = func->getParamType(i); - - // If the implementation expects a concrete type - // (either in the form of a pointer for `out`/`inout` parameters, - // or in the form a value for `in` parameters, while - // the interface exposes an AnyValue type, - // we need to unpack the AnyValue argument to the appropriate - // concerete type. - ArgumentPackWorkItem packWorkItem; - auto newArg = maybeUnpackArg(builder, funcParamType, wrapperParam, packWorkItem); - args.add(newArg); - if (packWorkItem.concreteArg) - argsToPack.add(packWorkItem); - } - auto call = builder->emitCallInst(func->getResultType(), func, args); - - // Pack all `out` arguments. - for (auto item : argsToPack) - { - auto anyValType = cast(item.dstArg->getDataType())->getValueType(); - auto concreteVal = builder->emitLoad(item.concreteArg); - auto packedVal = builder->emitPackAnyValue(anyValType, concreteVal); - builder->emitStore(item.dstArg, packedVal); - } - - // Pack return value if necessary. - if (!as(call->getDataType()) && - as(funcTypeInInterface->getResultType())) - { - auto pack = builder->emitPackAnyValue(funcTypeInInterface->getResultType(), call); - builder->emitReturn(pack); - } - else - { - if (call->getDataType()->getOp() == kIROp_VoidType) - builder->emitReturn(); - else - builder->emitReturn(call); - } - return wrapperFunc; - } - void lowerWitnessTable(IRWitnessTable* witnessTable) { auto interfaceType = as(witnessTable->getConformanceType()); @@ -234,7 +84,10 @@ struct GenerateWitnessTableWrapperContext entry->getRequirementKey()); if (auto ordinaryFunc = as(entry->getSatisfyingVal())) { - auto wrapper = emitWitnessTableWrapper(ordinaryFunc, interfaceRequirementVal); + auto wrapper = emitWitnessTableWrapper( + sharedContext->module, + ordinaryFunc, + interfaceRequirementVal); entry->satisfyingVal.set(wrapper); sharedContext->addToWorkList(wrapper); } @@ -270,6 +123,211 @@ struct GenerateWitnessTableWrapperContext } }; +// Represents a work item for packing `inout` or `out` arguments after a concrete call. +struct ArgumentPackWorkItem +{ + enum Kind + { + Pack, + Reinterpret + } kind = Pack; + + // A `AnyValue` typed destination. + IRInst* dstArg = nullptr; + // A concrete value to be packed. + IRInst* concreteArg = nullptr; +}; + +bool isAnyValueType(IRType* type) +{ + if (as(type)) + return true; + if (auto collection = as(type)) + return as(collection->getOperand(0)) != nullptr; + return false; +} + +// Unpack an `arg` of `IRAnyValue` into concrete type if necessary, to make it feedable into the +// parameter. If `arg` represents a AnyValue typed variable passed in to a concrete `out` +// parameter, this function indicates that it needs to be packed after the call by setting +// `packAfterCall`. +IRInst* maybeUnpackArg( + IRBuilder* builder, + IRType* paramType, + IRInst* arg, + ArgumentPackWorkItem& packAfterCall) +{ + packAfterCall.dstArg = nullptr; + packAfterCall.concreteArg = nullptr; + + // If either paramType or argType is a pointer type + // (because of `inout` or `out` modifiers), we extract + // the underlying value type first. + IRType* paramValType = paramType; + IRType* argValType = arg->getDataType(); + IRInst* argVal = arg; + if (auto ptrType = as(paramType)) + { + paramValType = ptrType->getValueType(); + } + auto argType = arg->getDataType(); + if (auto argPtrType = as(argType)) + { + argValType = argPtrType->getValueType(); + argVal = builder->emitLoad(arg); + } + + // Unpack `arg` if the parameter expects concrete type but + // `arg` is an AnyValue. + if (!isAnyValueType(paramValType) && isAnyValueType(argValType)) + { + auto unpackedArgVal = builder->emitUnpackAnyValue(paramValType, argVal); + // if parameter expects an `out` pointer, store the unpacked val into a + // variable and pass in a pointer to that variable. + if (as(paramType)) + { + auto tempVar = builder->emitVar(paramValType); + builder->emitStore(tempVar, unpackedArgVal); + // tempVar needs to be unpacked into original var after the call. + packAfterCall.kind = ArgumentPackWorkItem::Kind::Pack; + packAfterCall.dstArg = arg; + packAfterCall.concreteArg = tempVar; + return tempVar; + } + else + { + return unpackedArgVal; + } + } + + // Reinterpret 'arg' if it is being passed to a parameter with + // a different type collection. For now, we'll approximate this + // by checking if the types are different, but this should be + // encoded in the types. + // + if (paramValType != argValType) + { + auto reinterpretedArgVal = builder->emitReinterpret(paramValType, argVal); + // if parameter expects an `out` pointer, store the unpacked val into a + // variable and pass in a pointer to that variable. + if (as(paramType)) + { + auto tempVar = builder->emitVar(paramValType); + builder->emitStore(tempVar, reinterpretedArgVal); + // tempVar needs to be unpacked into original var after the call. + packAfterCall.kind = ArgumentPackWorkItem::Kind::Reinterpret; + packAfterCall.dstArg = arg; + packAfterCall.concreteArg = tempVar; + return tempVar; + } + else + { + return reinterpretedArgVal; + } + } + return arg; +} + +IRStringLit* _getWitnessTableWrapperFuncName(IRModule* module, IRFunc* func) +{ + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(func); + if (auto linkageDecoration = func->findDecoration()) + { + return builder->getStringValue( + (String(linkageDecoration->getMangledName()) + "_wtwrapper").getUnownedSlice()); + } + if (auto namehintDecoration = func->findDecoration()) + { + return builder->getStringValue( + (String(namehintDecoration->getName()) + "_wtwrapper").getUnownedSlice()); + } + return nullptr; +} + +IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* interfaceRequirementVal) +{ + auto funcTypeInInterface = cast(interfaceRequirementVal); + auto targetFuncType = as(funcInst->getDataType()); + + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(funcInst); + + auto wrapperFunc = builder->createFunc(); + wrapperFunc->setFullType((IRType*)interfaceRequirementVal); + if (auto func = as(funcInst)) + if (auto name = _getWitnessTableWrapperFuncName(module, func)) + builder->addNameHintDecoration(wrapperFunc, name); + + builder->setInsertInto(wrapperFunc); + auto block = builder->emitBlock(); + builder->setInsertInto(block); + + ShortList params; + for (UInt i = 0; i < funcTypeInInterface->getParamCount(); i++) + { + params.add(builder->emitParam(funcTypeInInterface->getParamType(i))); + } + + List args; + List argsToPack; + + SLANG_ASSERT(params.getCount() == (Index)targetFuncType->getParamCount()); + for (UInt i = 0; i < targetFuncType->getParamCount(); i++) + { + auto wrapperParam = params[i]; + // Type of the parameter in the callee. + auto funcParamType = targetFuncType->getParamType(i); + + // If the implementation expects a concrete type + // (either in the form of a pointer for `out`/`inout` parameters, + // or in the form a value for `in` parameters, while + // the interface exposes an AnyValue type, + // we need to unpack the AnyValue argument to the appropriate + // concerete type. + ArgumentPackWorkItem packWorkItem; + auto newArg = maybeUnpackArg(builder, funcParamType, wrapperParam, packWorkItem); + args.add(newArg); + if (packWorkItem.concreteArg) + argsToPack.add(packWorkItem); + } + auto call = builder->emitCallInst(targetFuncType->getResultType(), funcInst, args); + + // Pack all `out` arguments. + for (auto item : argsToPack) + { + auto anyValType = cast(item.dstArg->getDataType())->getValueType(); + auto concreteVal = builder->emitLoad(item.concreteArg); + auto packedVal = (item.kind == ArgumentPackWorkItem::Kind::Pack) + ? builder->emitPackAnyValue(anyValType, concreteVal) + : builder->emitReinterpret(anyValType, concreteVal); + builder->emitStore(item.dstArg, packedVal); + } + + // Pack return value if necessary. + if (!isAnyValueType(call->getDataType()) && + isAnyValueType(funcTypeInInterface->getResultType())) + { + auto pack = builder->emitPackAnyValue(funcTypeInInterface->getResultType(), call); + builder->emitReturn(pack); + } + else if (call->getDataType() != funcTypeInInterface->getResultType()) + { + auto reinterpret = builder->emitReinterpret(funcTypeInInterface->getResultType(), call); + builder->emitReturn(reinterpret); + } + else + { + if (call->getDataType()->getOp() == kIROp_VoidType) + builder->emitReturn(); + else + builder->emitReturn(call); + } + return wrapperFunc; +} + void generateWitnessTableWrapperFunctions(SharedGenericsLoweringContext* sharedContext) { GenerateWitnessTableWrapperContext context; diff --git a/source/slang/slang-ir-witness-table-wrapper.h b/source/slang/slang-ir-witness-table-wrapper.h index 27008b085ed..acc69aad14d 100644 --- a/source/slang/slang-ir-witness-table-wrapper.h +++ b/source/slang/slang-ir-witness-table-wrapper.h @@ -4,6 +4,9 @@ namespace Slang { struct SharedGenericsLoweringContext; +struct IRFunc; +struct IRInst; +struct IRModule; /// This pass generates wrapper functions for witness table function entries. /// @@ -19,4 +22,9 @@ struct SharedGenericsLoweringContext; /// to concrete types and calls the actual implementation. void generateWitnessTableWrapperFunctions(SharedGenericsLoweringContext* sharedContext); +IRFunc* emitWitnessTableWrapper( + IRModule* module, + IRInst* funcInst, + IRInst* interfaceRequirementVal); + } // namespace Slang diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-1.slang similarity index 100% rename from tests/language-feature/dynamic-dispatch/dynamic-specialization.slang rename to tests/language-feature/dynamic-dispatch/dynamic-specialization-1.slang diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-2.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-2.slang new file mode 100644 index 00000000000..43bc8d57175 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-2.slang @@ -0,0 +1,74 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + static float calc(float x) { return x * x; } +}; + +struct C : IInterface +{ + static float calc(float x) { return x; } +}; + +struct D : IInterface +{ + static float calc(float x) { return x * x * x * x; } +}; + +struct E : IInterface +{ + static float calc(float x) { return x * x * x * x; } +}; + +float calc(T obj, float y) +{ + return T::calc(y); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else + obj = B(); + + return calc(obj, x); // Specialized for A & B. +} + +float g(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = C(); + else + obj = D(); + + return calc(obj, x); // Specialized for C & D. +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 + + outputBuffer[2] = g(0, 2); // CHECK: 2 + outputBuffer[3] = g(1, 2); // CHECK: 16 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-3.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-3.slang new file mode 100644 index 00000000000..09bba484e4e --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-3.slang @@ -0,0 +1,61 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + This withOffset(float offset); + float calc(float x); +} + +struct A : IInterface +{ + float factor; + A withOffset(float offset) { return {factor + offset}; } + float calc(float x) { return factor * x * x * x; } +}; + +struct B : IInterface +{ + float factor1; + float factor2; + B withOffset(float offset) { return {factor1 + offset, factor2 + offset}; } + float calc(float x) { return factor1 * x * x + factor2 * x; } +}; + +struct C : IInterface +{ + float factor; + C withOffset(float offset) { return {factor + offset}; } + float calc(float x) { return x; } +}; + +T transfer(T obj) +{ + return obj.withOffset(2.5f); +} + +float calc(T obj, float y) +{ + return transfer(obj).calc(y); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else + obj = B(); + + return calc(obj, x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 20 + outputBuffer[1] = f(1, 2); // CHECK: 15 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/inout.slang b/tests/language-feature/dynamic-dispatch/inout.slang new file mode 100644 index 00000000000..e1c4063eb83 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/inout.slang @@ -0,0 +1,76 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +// Test data-flow through inout parameters (which will require data-flow through 'IRVar' insts) + +interface IReserver +{ + [mutating] uint reserve(uint x); +} + +struct Single : IReserver +{ + uint total; + [mutating] uint reserve(uint x) + { + if (total < x) + { + total = 0; + return total; + } + else + { + total -= x; + return x; + } + } +}; + +struct Double : IReserver +{ + uint total; + [mutating] uint reserve(uint x) + { + if (total < 2 * x) + { + total = 0; + return total; + } + else + { + total -= 2 * x; + return 2 * x; + } + } +}; + +IReserver make(uint id, uint total) +{ + if (id == 0) + return Single(total); + else + return Double(total); +} + +float f(inout IReserver obj, uint res, float x) +{ + uint y = obj.reserve(res + 1); + + return x * y; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = make(0, 10); + outputBuffer[0] = f(obj, 0, 2); // CHECK: 2 + outputBuffer[1] = f(obj, 1, 3); // CHECK: 6 + outputBuffer[2] = f(obj, 0, 4); // CHECK: 4 + + obj = make(1, 10); + outputBuffer[3] = f(obj, 0, 2); // CHECK: 4 + outputBuffer[4] = f(obj, 1, 3); // CHECK: 12 + outputBuffer[5] = f(obj, 0, 4); // CHECK: 8 +} \ No newline at end of file From 1b50349a565480ac8429387fc392e336db7626d0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 29 Jul 2025 12:58:23 -0400 Subject: [PATCH 015/105] Delete slang-vscode.natvis --- source/slang/slang-vscode.natvis | 818 ------------------------------- 1 file changed, 818 deletions(-) delete mode 100644 source/slang/slang-vscode.natvis diff --git a/source/slang/slang-vscode.natvis b/source/slang/slang-vscode.natvis deleted file mode 100644 index f159c036835..00000000000 --- a/source/slang/slang-vscode.natvis +++ /dev/null @@ -1,818 +0,0 @@ - - - - - rawVal ? ($T1*)((char*)this + rawVal) : ($T1*)0 - BCPtr nullptr - BCPtr {*($T1*)((char*)this + rawVal)} - - rawVal ? ($T1*)((char*)this + rawVal) : ($T1*)0 - - - - Constant {intOperand} - {(Slang::Val*)nodeOperand} - {nodeOperand} - - *(Slang::Val*)nodeOperand - - - - DeclRef nullptr - - {*declRefBase} - - declRefBase - - - - {astNodeType,en}#{_debugUID}({(Decl*)m_operands.m_buffer[0].values.nodeOperand}) - {astNodeType,en}({(Decl*)m_operands.m_buffer[0].values.nodeOperand}) - DeclRefBase nullptr - - - {*(Decl*)m_operands.m_buffer[0].values.nodeOperand} - - *(Decl*)m_operands.m_buffer[0].values.nodeOperand - - - - {*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)} - - *(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand) - - - - {*(Val*)(this->m_operands.m_buffer[1].values.nodeOperand)} - - *(Val*)(this->m_operands.m_buffer[1].values.nodeOperand) - - - - {*(SubtypeWitness*)(this->m_operands.m_buffer[2].values.nodeOperand)} - - *(SubtypeWitness*)(this->m_operands.m_buffer[2].values.nodeOperand) - - - - {*(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand)} - - *(DeclRefBase*)(this->m_operands.m_buffer[1].values.nodeOperand) - - - - - - *(Val*)(this->m_operands.m_buffer[index].values.nodeOperand) - index=index+1 - - - - - - {astNodeType,en}#{_debugUID} {*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand} - - {astNodeType,en} {*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand} - - - {astNodeType,en}#{_debugUID} {m_operands.m_buffer[0].values.nodeOperand->astNodeType, en}#{m_operands.m_buffer[0].values.nodeOperand->_debugUID} - {astNodeType,en} {m_operands.m_buffer[0].values.nodeOperand->astNodeType, en} - - *(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand - - - - FuncDecl {nameAndLoc} - - - {{name={(char*)(text.m_buffer.pointer+1), s}}} - - - {{name={(char*)((*name).text.m_buffer.pointer+1), s} loc={loc.raw}}} - - - - requirementKey - satisfyingVal - - - - {{{m_op} {(uint32_t)(void*)this, x}}} - {{{m_op} #{_debugUID}}} - - m_op - _debugUID - typeUse.usedValue - - - - - - - - ((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.chars,[((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.numChars]s8 - - - ((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.chars,[((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.numChars]s8 - - - ((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.chars,[((Slang::IRStringLit*)(((Slang::IRUse*)(child + 1))->usedValue))->value.stringVal.numChars]s8 - - child = child->next - - - ((IRStringLit*)this)->value.stringVal.chars,[((IRStringLit*)this)->value.stringVal.numChars]s8 - ((IRIntLit*)this)->value.intVal - - - - - - - pOperandInst = ((IRUse*)(&(typeUse) + 1 + index))->usedValue - pOperandInst - - child = pOperandInst->m_decorationsAndChildren.first - nameDecoration = 0 - - - nameDecoration = child - - - - nameDecoration = child - - - nameDecoration = child - - child = child->next - - *pOperandInst - *pOperandInst - - index = index + 1 - - - - - - - - - - - child = pItem->m_decorationsAndChildren.first - nameDecoration = 0 - - - nameDecoration = child - - - - nameDecoration = child - - - nameDecoration = child - - child = child->next - - *pItem - *pItem - pItem = pItem->next - index = index + 1 - - - - - parent - - - - firstUse - nextUse - user - - - - - - - {{IRUse {usedValue}}} - - usedValue - - - - - {astNodeType,en} - - (Slang::DeclRefExpr*)&astNodeType - (Slang::VarExpr*)&astNodeType - (Slang::MemberExpr*)&astNodeType - (Slang::StaticMemberExpr*)&astNodeType - (Slang::OverloadedExpr*)&astNodeType - (Slang::OverloadedExpr2*)&astNodeType - (Slang::LiteralExpr*)&astNodeType - (Slang::IntegerLiteralExpr*)&astNodeType - (Slang::FloatingPointLiteralExpr*)&astNodeType - (Slang::BoolLiteralExpr*)&astNodeType - (Slang::NullPtrLiteralExpr*)&astNodeType - (Slang::StringLiteralExpr*)&astNodeType - (Slang::InitializerListExpr*)&astNodeType - (Slang::ExprWithArgsBase*)&astNodeType - (Slang::AggTypeCtorExpr*)&astNodeType - (Slang::AppExprBase*)&astNodeType - (Slang::InvokeExpr*)&astNodeType - (Slang::NewExpr*)&astNodeType - (Slang::OperatorExpr*)&astNodeType - (Slang::InfixExpr*)&astNodeType - (Slang::PrefixExpr*)&astNodeType - (Slang::PostfixExpr*)&astNodeType - (Slang::SelectExpr*)&astNodeType - (Slang::TypeCastExpr*)&astNodeType - (Slang::ExplicitCastExpr*)&astNodeType - (Slang::ImplicitCastExpr*)&astNodeType - (Slang::GenericAppExpr*)&astNodeType - (Slang::TryExpr*)&astNodeType - (Slang::IndexExpr*)&astNodeType - (Slang::MatrixSwizzleExpr*)&astNodeType - (Slang::SwizzleExpr*)&astNodeType - (Slang::DerefExpr*)&astNodeType - (Slang::CastToSuperTypeExpr*)&astNodeType - (Slang::ModifierCastExpr*)&astNodeType - (Slang::SharedTypeExpr*)&astNodeType - (Slang::AssignExpr*)&astNodeType - (Slang::ParenExpr*)&astNodeType - (Slang::ThisExpr*)&astNodeType - (Slang::LetExpr*)&astNodeType - (Slang::ExtractExistentialValueExpr*)&astNodeType - (Slang::OpenRefExpr*)&astNodeType - (Slang::ForwardDifferentiateExpr*)&astNodeType - (Slang::BackwardDifferentiateExpr*)&astNodeType - (Slang::ThisTypeExpr*)&astNodeType - (Slang::AndTypeExpr*)&astNodeType - (Slang::ModifiedTypeExpr*)&astNodeType - (Slang::PointerTypeExpr*)&astNodeType - type - (Slang::Expr*)this,! - - - - {astNodeType,en} - - (Slang::ScopeStmt*)&astNodeType - (Slang::BlockStmt*)&astNodeType - (Slang::BreakableStmt*)&astNodeType - (Slang::SwitchStmt*)&astNodeType - (Slang::LoopStmt*)&astNodeType - (Slang::ForStmt*)&astNodeType - (Slang::UnscopedForStmt*)&astNodeType - (Slang::WhileStmt*)&astNodeType - (Slang::DoWhileStmt*)&astNodeType - (Slang::GpuForeachStmt*)&astNodeType - (Slang::CompileTimeForStmt*)&astNodeType - (Slang::SeqStmt*)&astNodeType - (Slang::UnparsedStmt*)&astNodeType - (Slang::EmptyStmt*)&astNodeType - (Slang::DiscardStmt*)&astNodeType - (Slang::DeclStmt*)&astNodeType - (Slang::IfStmt*)&astNodeType - (Slang::ChildStmt*)&astNodeType - (Slang::CaseStmtBase*)&astNodeType - (Slang::CaseStmt*)&astNodeType - (Slang::DefaultStmt*)&astNodeType - (Slang::JumpStmt*)&astNodeType - (Slang::BreakStmt*)&astNodeType - (Slang::ContinueStmt*)&astNodeType - (Slang::ReturnStmt*)&astNodeType - (Slang::ExpressionStmt*)&astNodeType - (Slang::Stmt*)this,! - - - - {text} - - - {astNodeType,en} {nameAndLoc.name->text} - {astNodeType,en} - - nameAndLoc.name->text - parentDecl - Slang::DeclCheckState(checkState.m_raw & ~Slang::DeclCheckStateExt::kBeingCheckedBit) - (Slang::ContainerDecl*)&astNodeType - (Slang::ExtensionDecl*)&astNodeType - (Slang::StructDecl*)&astNodeType - (Slang::ClassDecl*)&astNodeType - (Slang::EnumDecl*)&astNodeType - (Slang::InterfaceDecl*)&astNodeType - (Slang::AssocTypeDecl*)&astNodeType - (Slang::GlobalGenericParamDecl*)&astNodeType - (Slang::ScopeDecl*)&astNodeType - (Slang::ConstructorDecl*)&astNodeType - (Slang::GetterDecl*)&astNodeType - (Slang::SetterDecl*)&astNodeType - (Slang::RefAccessorDecl*)&astNodeType - (Slang::FuncDecl*)&astNodeType - (Slang::SubscriptDecl*)&astNodeType - (Slang::PropertyDecl*)&astNodeType - (Slang::NamespaceDecl*)&astNodeType - (Slang::ModuleDecl*)&astNodeType - (Slang::GenericDecl*)&astNodeType - (Slang::AttributeDecl*)&astNodeType - (Slang::VarDeclBase*)&astNodeType - (Slang::VarDecl*)&astNodeType - (Slang::LetDecl*)&astNodeType - (Slang::GlobalGenericValueParamDecl*)&astNodeType - (Slang::ParamDecl*)&astNodeType - (Slang::ModernParamDecl*)&astNodeType - (Slang::GenericValueParamDecl*)&astNodeType - (Slang::EnumCaseDecl*)&astNodeType - (Slang::TypeConstraintDecl*)&astNodeType - (Slang::InheritanceDecl*)&astNodeType - (Slang::GenericTypeConstraintDecl*)&astNodeType - (Slang::SimpleTypeDecl*)&astNodeType - (Slang::TypeDefDecl*)&astNodeType - (Slang::TypeAliasDecl*)&astNodeType - (Slang::GenericTypeParamDecl*)&astNodeType - (Slang::UsingDecl*)&astNodeType - (Slang::ImportDecl*)&astNodeType - (Slang::EmptyDecl*)&astNodeType - (Slang::SyntaxDecl*)&astNodeType - (Slang::DeclGroup*)&astNodeType - - (Slang::DeclBase*)this,! - - - - - {astNodeType,en} - - (Slang::ContainerDecl*)&astNodeType - (Slang::ExtensionDecl*)&astNodeType - (Slang::StructDecl*)&astNodeType - (Slang::ClassDecl*)&astNodeType - (Slang::EnumDecl*)&astNodeType - (Slang::InterfaceDecl*)&astNodeType - (Slang::AssocTypeDecl*)&astNodeType - (Slang::GlobalGenericParamDecl*)&astNodeType - (Slang::ScopeDecl*)&astNodeType - (Slang::ConstructorDecl*)&astNodeType - (Slang::GetterDecl*)&astNodeType - (Slang::SetterDecl*)&astNodeType - (Slang::RefAccessorDecl*)&astNodeType - (Slang::FuncDecl*)&astNodeType - (Slang::SubscriptDecl*)&astNodeType - (Slang::PropertyDecl*)&astNodeType - (Slang::NamespaceDecl*)&astNodeType - (Slang::ModuleDecl*)&astNodeType - (Slang::GenericDecl*)&astNodeType - (Slang::AttributeDecl*)&astNodeType - (Slang::VarDeclBase*)&astNodeType - (Slang::VarDecl*)&astNodeType - (Slang::LetDecl*)&astNodeType - (Slang::GlobalGenericValueParamDecl*)&astNodeType - (Slang::ParamDecl*)&astNodeType - (Slang::ModernParamDecl*)&astNodeType - (Slang::GenericValueParamDecl*)&astNodeType - (Slang::EnumCaseDecl*)&astNodeType - (Slang::TypeConstraintDecl*)&astNodeType - (Slang::InheritanceDecl*)&astNodeType - (Slang::GenericTypeConstraintDecl*)&astNodeType - (Slang::SimpleTypeDecl*)&astNodeType - (Slang::TypeDefDecl*)&astNodeType - (Slang::TypeAliasDecl*)&astNodeType - (Slang::GenericTypeParamDecl*)&astNodeType - (Slang::UsingDecl*)&astNodeType - (Slang::ImportDecl*)&astNodeType - (Slang::EmptyDecl*)&astNodeType - (Slang::SyntaxDecl*)&astNodeType - (Slang::DeclGroup*)&astNodeType - (Slang::Decl*)this,! - - - - {astNodeType,en} #{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)} - {astNodeType,en} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)} - - *(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand) - - - - {astNodeType,en} #{_debugUID} - {astNodeType,en} - - - {m_operands.m_count-2} - - - m_operands.m_count-2 - m_operands.m_buffer - - - {m_operands.m_buffer[m_operands.m_count-2]} - - m_operands.m_buffer[m_operands.m_count-2] - - - - {m_operands.m_buffer[m_operands.m_count-1]} - - m_operands.m_buffer[m_operands.m_count-1] - - - - - - - DeclRefType#{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)} - DeclRefType {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)} - DirectRef#{_debugUID} {*(Decl*)m_operands.m_buffer[0].values.nodeOperand} - DirectRef {*(Decl*)m_operands.m_buffer[0].values.nodeOperand} - {astNodeType,en} #{_debugUID} - {astNodeType,en} - - - - {astNodeType} - - - m_operands - - - - - SubstitutionSet{declRef,en} - - declRef - - - - - - substType = subst->astNodeType - shouldBreak = 1 - - - - - - subst = (DeclRefBase*)(((Slang::MemberDeclRef*)subst)->m_operands.m_buffer[1].values.nodeOperand) - shouldBreak = 0 - - - (LookupDeclRef*)subst - - - - (GenericAppDeclRef*)subst - subst = (DeclRefBase*)(((Slang::GenericAppDeclRef*)subst)->m_operands.m_buffer[1].values.nodeOperand) - shouldBreak = 0 - - - - - - - - - - {astNodeType}({nameAndLoc.name}) - - members - - - - Const({values.intOperand})#{_debugUID} - {*(Val*)values.nodeOperand} - {values.nodeOperand} - - *(Val*)values.nodeOperand - *(Decl*)values.nodeOperand - - - - - _impl - nullptr - {_impl} - - - empty - - - _head._impl != 0 ? &_head : 0 - _impl->next._impl != 0 ? &_impl->next : 0 - *this - - - - - empty - - - _head != 0 ? _head : 0 - next != 0 ? next : 0 - *this - - - - - {astNodeType,en}#{_debugUID} ({m_operands.m_buffer[1].values.intOperand} : {*(Type*)m_operands.m_buffer[0].values.nodeOperand}) - ConstantIntVal ({m_operands.m_buffer[1].values.intOperand} : {*(Type*)m_operands.m_buffer[0].values.nodeOperand}) - - - - {astNodeType,en}#{_debugUID} - {astNodeType,en} - - - m_operands.m_count - m_operands.m_buffer - - - - - - {astNodeType,en}#{_debugUID} - {astNodeType,en} - - - m_operands.m_count - m_operands.m_buffer - - - - - - {astNodeType,en}#{_debugUID} - {astNodeType,en} - - - m_operands.m_count - m_operands.m_buffer - - - - - - BasicExpressionType ({*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand}) - - - - {m_targetSets.map.m_values} - - - m_targetSets - m_targetSets.map.m_values - - - - - {{target={target}}} - - - target - shaderStageSets - shaderStageSets.map.m_values - - - - - {{size={atomSet}}} - - - stage - atomSet - - - - - - - {{max_size={m_buffer.m_count*Slang::UIntSet::kElementSize}}} - - - - - - - - - - - - iter = (Slang::UIntSet::Element)0 - bitIter = (Slang::UIntSet::Element)0 - totalBitIter = (Slang::UIntSet::Element)0 - value = 0 - - - bitIter = 0 - totalBitIter++ - iter++ - - - - - bitValue = (m_buffer[iter]>>bitIter)&1 - - value = totalBitIter - (CapabilityAtom)value - - bitIter++ - totalBitIter++ - - - - - - - - {{max_size={m_buffer.m_count*Slang::UIntSet::kElementSize}}} - - - - - - - - - - - - iter = (Slang::UIntSet::Element)0 - bitIter = (Slang::UIntSet::Element)0 - totalBitIter = (Slang::UIntSet::Element)0 - value = 0 - - - bitIter = 0 - totalBitIter++ - iter++ - - - - - bitValue = (m_buffer[iter]>>bitIter)&1 - - value = totalBitIter - (CapabilityAtom)value - - bitIter++ - totalBitIter++ - - - - - - - - - - {((char*) (m_buffer.pointer+1)),s} - ((char*) (m_buffer.pointer+1)),s - - - - {{ size={m_count} }} - - m_count - - m_count - m_buffer - - - - - - {{ size={m_count} }} - - m_count - m_capacity - - m_count - m_buffer - - - - - - {{ size={m_count} }} - - m_count - m_capacity - - m_count - m_shortBuffer + $i - m_buffer + $i - $T2 - - - - - - {{ size={m_count} }} - - m_count - - m_count - m_buffer - - - - - - {{ {map.m_values} }} - - map.m_values - map - - - - - {{ {dict} }} - - dict - - - - - {{ size={dict._count} }} - - - m_dict._count - m_dict.kvPairs.head - next - value - - - - - - {{ size={m_count} }} - - - m_count - m_kvPairs.head - next - value - - - - - - pointer - empty - RefPtr {*pointer} - - pointer - - - - - - - ($T1*)(m_base->m_data + m_offset) - - - - - - (m_offset == 0x80000000) ? nullptr : ($T1*)(((char*)this) + m_offset) - - - - - - - m_count - - m_count - ($T1*)(m_data.m_base->m_data + m_data.m_offset) - - - - - - - - m_count - - m_count - (m_data.m_offset == 0x80000000) ? nullptr : ($T1*)(((char*)&m_data) + m_data.m_offset) - - - - - - {(m_sizeThenContents + 1),s} - (m_sizeThenContents + 1),s - - - - {m_begin,[m_end-m_begin]s} - m_begin,[m_end-m_begin]s - - - - - From a2923fd31abba47c97017827abfde6226f9b7549 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 29 Jul 2025 12:59:36 -0400 Subject: [PATCH 016/105] Delete sample.slang --- tests/sample.slang | 54 ---------------------------------------------- 1 file changed, 54 deletions(-) delete mode 100644 tests/sample.slang diff --git a/tests/sample.slang b/tests/sample.slang deleted file mode 100644 index 165a78bbbea..00000000000 --- a/tests/sample.slang +++ /dev/null @@ -1,54 +0,0 @@ - - -interface ILog -{ - static float log(float x); -} - -struct SlowLog : ILog -{ - static float log(float x) - { - // implement slow log. - } -} - -struct FastLog : ILog -{ - static float log(float x) - { - // implement slow log. - } -} - -// prototype system -// declare ILog = SlowLog | FastLog; - -// actual slang -extern struct MyLog : ILog; // slangApi->setLinkTimeConst("MyLog", "FastLog"); - - -extern static const int TILE_SIZE; // slangApi->setLinkTimeConst("TILE_SIZE", 16); - -void main() -{ - // Use 1 - MyLog.log(10.f); - - - // Use 2 - MyLog.log(20.f); -} - -matrix matmul(matrix, matrix) -{ - .... -} - -void main() -{ - - matmul(cast(matA), cast(matB)); -} - -// slangApi->setLinkTimeConst("main.MyLog1", "FastLog"); \ No newline at end of file From 2ae4b83471231aef8a1d46609d8fddd803703785 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 29 Jul 2025 14:27:11 -0400 Subject: [PATCH 017/105] Minor fixes --- source/slang/slang-ir-lower-dynamic-insts.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 5d1fec785f7..df20c075ece 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -1785,7 +1785,7 @@ struct DynamicInstLoweringContext keyMappingFunc, List({inst->getWitnessTable()})); inst->replaceUsesWith(witnessTableId); - propagationMap[Element(witnessTableId)] = info; + propagationMap[Element(context, witnessTableId)] = info; inst->removeAndDeallocate(); return true; } @@ -1819,7 +1819,7 @@ struct DynamicInstLoweringContext auto operand = inst->getOperand(0); auto element = builder.emitGetTupleElement(builder.getUIntType(), operand, 0); inst->replaceUsesWith(element); - propagationMap[Element(element)] = info; + propagationMap[Element(context, element)] = info; inst->removeAndDeallocate(); return true; } @@ -2246,7 +2246,7 @@ struct DynamicInstLoweringContext auto newCall = builder.emitCallInst(inst->getDataType(), dispatchFunc, newArgs); inst->replaceUsesWith(newCall); if (auto info = tryGetInfo(context, inst)) - propagationMap[Element(newCall)] = info; + propagationMap[Element(context, newCall)] = info; replaceType(context, newCall); // "maybe replace type" inst->removeAndDeallocate(); return true; @@ -2308,7 +2308,7 @@ struct DynamicInstLoweringContext auto tuple = builder.emitMakeTuple(tupleType, 2, tupleArgs); if (auto info = tryGetInfo(context, inst)) - propagationMap[Element(tuple)] = info; + propagationMap[Element(context, tuple)] = info; inst->replaceUsesWith(tuple); inst->removeAndDeallocate(); @@ -2352,7 +2352,7 @@ struct DynamicInstLoweringContext builder.emitReinterpret(existentialTupleType->getOperand(1), inst->getValue())})); if (auto info = tryGetInfo(context, inst)) - propagationMap[Element(existentialTuple)] = info; + propagationMap[Element(context, existentialTuple)] = info; inst->replaceUsesWith(existentialTuple); inst->removeAndDeallocate(); From b58f2e879fec113cebcaca53d7d63eba43ef5cc3 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 29 Jul 2025 16:20:47 -0400 Subject: [PATCH 018/105] Use IR ops instead of a custom struct to track type flow data. --- source/slang/slang-ir-insts-stable-names.lua | 8 +- source/slang/slang-ir-insts.lua | 14 +- source/slang/slang-ir-lower-dynamic-insts.cpp | 861 ++++++++---------- .../slang/slang-ir-witness-table-wrapper.cpp | 8 +- source/slang/slang-ir.cpp | 6 +- 5 files changed, 428 insertions(+), 469 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index c4d8d177e62..abb9cd3b68a 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -669,5 +669,11 @@ return { ["SPIRVAsmOperand.__sampledType"] = 665, ["SPIRVAsmOperand.__imageType"] = 666, ["SPIRVAsmOperand.__sampledImageType"] = 667, - ["TypeFlow.TypeFlowCollection"] = 668 + ["TypeFlowData.CollectionBase.TypeCollection"] = 668, + ["TypeFlowData.CollectionBase.FuncCollection"] = 669, + ["TypeFlowData.CollectionBase.TableCollection"] = 670, + ["TypeFlowData.CollectionBase.GenericCollection"] = 671, + ["TypeFlowData.UnboundedCollection"] = 672, + ["TypeFlowData.CollectionTagType"] = 673, + ["TypeFlowData.CollectionTaggedUnionType"] = 674 } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 0c7be24ca65..0db1668eba4 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2173,9 +2173,19 @@ local insts = { { -- A collection of IR instructions used for propagation analysis -- The operands are the elements of the set, sorted by unique ID to ensure canonical ordering - TypeFlow = { + TypeFlowData = { hoistable = true, - { TypeFlowCollection = {} } + { + CollectionBase = { + { TypeCollection = {} }, + { FuncCollection = {} }, + { TableCollection = {} }, + { GenericCollection = {} }, + }, + }, + { UnboundedCollection = {} }, + { CollectionTagType = {} }, -- Operand is TypeCollection/FuncCollection/TableCollection (funcs/tables) + { CollectionTaggedUnionType = {}} -- Operand is TypeCollection, TableCollection for existential }, } } diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index df20c075ece..b2d866c3522 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -90,78 +90,6 @@ struct Element // getHashCode() HashCode64 getHashCode() const { return combineHash(HashCode(context), HashCode(inst)); } }; -// Enumeration for different kinds of judgments about IR instructions. -// -// This forms a lattice with -// -// None < Set < Unbounded -// None < Existential < Unbounded -// -enum class PropagationJudgment -{ - None, // Either uninitialized or irrelevant - Set, // Set of possible types/tables/funcs - Existential, // Existential box with a set of possible witness tables - Unbounded, // Unknown set of possible types/tables/funcs (e.g. COM interface types) -}; - -// Data structure to hold propagation information for an instruction -struct PropagationInfo : RefObject -{ - PropagationJudgment judgment; - - // For sets of types/tables/funcs and existential witness tables - // Instead of HashSet, we use an IRCollection instruction with sorted operands - IRInst* collection; - - PropagationInfo() - : judgment(PropagationJudgment::None), collection(nullptr) - { - } - - PropagationInfo(PropagationJudgment j) - : judgment(j), collection(nullptr) - { - } - - PropagationInfo(PropagationJudgment j, IRInst* coll) - : judgment(j), collection(coll) - { - } - - // NOTE: Factory methods moved to DynamicInstLoweringContext to access collection creation - - bool isNone() const { return judgment == PropagationJudgment::None; } - bool isSingleton() const - { - return judgment == PropagationJudgment::Set && getCollectionCount() == 1; - } - - IRInst* getSingletonValue() const - { - if (judgment == PropagationJudgment::Set && getCollectionCount() == 1) - return getCollectionElement(0); - - SLANG_UNEXPECTED("getSingletonValue called on non-singleton PropagationInfo"); - } - - // Helper functions to access collection elements - UInt getCollectionCount() const - { - if (!collection) - return 0; - return collection->getOperandCount(); - } - - IRInst* getCollectionElement(UInt index) const - { - if (!collection || index >= collection->getOperandCount()) - return nullptr; - return collection->getOperand(index); - } - - operator bool() const { return judgment != PropagationJudgment::None; } -}; // Data structures for interprocedural data-flow analysis @@ -276,38 +204,29 @@ struct WorkItem } }; -bool areInfosEqual(const PropagationInfo& a, const PropagationInfo& b) +bool areInfosEqual(IRInst* a, IRInst* b) { - if (a.judgment != b.judgment) - return false; - - switch (a.judgment) - { - case PropagationJudgment::None: - case PropagationJudgment::Unbounded: - return true; - case PropagationJudgment::Set: - case PropagationJudgment::Existential: - return a.collection == b.collection; - default: - return false; - } + return a == b; } struct DynamicInstLoweringContext { // Helper methods for creating canonical collections - IRInst* createCollection(const HashSet& elements) + IRCollectionBase* createCollection(IROp op, const HashSet& elements) { List sortedElements; for (auto element : elements) sortedElements.add(element); - return createCollection(sortedElements); + return createCollection(op, sortedElements); } - IRInst* createCollection(const List& elements) + IRCollectionBase* createCollection(IROp op, const List& elements) { + SLANG_ASSERT( + op == kIROp_TypeCollection || op == kIROp_FuncCollection || + op == kIROp_TableCollection || op == kIROp_GenericCollection); + if (elements.getCount() == 0) return nullptr; @@ -326,72 +245,120 @@ struct DynamicInstLoweringContext builder.setInsertInto(module); // Use makeTuple as a temporary implementation until IRCollection is available - return builder.emitIntrinsicInst( + return as(builder.emitIntrinsicInst( nullptr, - kIROp_TypeFlowCollection, + op, sortedElements.getCount(), - sortedElements.getBuffer()); + sortedElements.getBuffer())); + } + + IROp getCollectionTypeForInst(IRInst* inst) + { + if (as(inst->getDataType())) + return kIROp_TypeCollection; + else if (as(inst) && !as(inst)) + return kIROp_TypeCollection; + else if (as(inst->getDataType())) + return kIROp_FuncCollection; + else if (as(inst->getDataType())) + return kIROp_TableCollection; + else if (as(inst->getDataType())) + return kIROp_GenericCollection; + else + SLANG_UNEXPECTED("Unsupported collection type for instruction"); } // Factory methods for PropagationInfo - PropagationInfo makeSingletonSet(IRInst* value) + IRCollectionBase* makeSingletonSet(IRInst* value) { HashSet singleSet; singleSet.add(value); - auto collection = createCollection(singleSet); - return PropagationInfo(PropagationJudgment::Set, collection); + return createCollection(getCollectionTypeForInst(value), singleSet); } - PropagationInfo makeSet(const HashSet& values) + IRCollectionBase* makeSet(const HashSet& values) { SLANG_ASSERT(values.getCount() > 0); - auto collection = createCollection(values); - return PropagationInfo(PropagationJudgment::Set, collection); + return createCollection(getCollectionTypeForInst(*values.begin()), values); + } + + IRCollectionTaggedUnionType* makeExistential(IRTableCollection* tableCollection) + { + HashSet typeSet; + // Collect all types from the witness tables + forEachInCollection( + tableCollection, + [&](IRInst* witnessTable) + { + if (auto table = as(witnessTable)) + typeSet.add(table->getConcreteType()); + }); + + auto typeCollection = createCollection(kIROp_TypeCollection, typeSet); + + // Create the tagged union type + IRBuilder builder(module); + List elements = {typeCollection, tableCollection}; + return as(builder.emitIntrinsicInst( + nullptr, + kIROp_CollectionTaggedUnionType, + elements.getCount(), + elements.getBuffer())); } - PropagationInfo makeExistential(const HashSet& tables) + /*IRCollectionTaggedUnionType* makeExistential(const HashSet& tables) { SLANG_ASSERT(tables.getCount() > 0); - auto collection = createCollection(tables); - return PropagationInfo(PropagationJudgment::Existential, collection); + auto tableCollection = createCollection(kIROp_TableCollection, tables); + return makeExistential(tableCollection); + }*/ + + UCount getCollectionCount(IRCollectionBase* collection) + { + if (!collection) + return 0; + return collection->getOperandCount(); + } + + UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) + { + auto typeCollection = taggedUnion->getOperand(0); + return getCollectionCount(as(typeCollection)); } - PropagationInfo makeUnbounded() + IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) { - return PropagationInfo(PropagationJudgment::Unbounded, nullptr); + if (!collection || index >= collection->getOperandCount()) + return nullptr; + return collection->getOperand(index); + } + + IRUnboundedCollection* makeUnbounded() + { + IRBuilder builder(module); + return as( + builder.emitIntrinsicInst(nullptr, kIROp_UnboundedCollection, 0, nullptr)); } - PropagationInfo none() { return PropagationInfo(PropagationJudgment::None, nullptr); } + IRTypeFlowData* none() { return nullptr; } // Helper to iterate over collection elements template - void forEachInCollection(const PropagationInfo& info, F func) + void forEachInCollection(IRCollectionBase* info, F func) { - for (UInt i = 0; i < info.getCollectionCount(); ++i) - func(info.getCollectionElement(i)); + for (UInt i = 0; i < info->getOperandCount(); ++i) + func(info->getOperand(i)); } // Helper to convert collection to HashSet - HashSet collectionToHashSet(const PropagationInfo& info) + HashSet collectionToHashSet(IRCollectionBase* info) { HashSet result; forEachInCollection(info, [&](IRInst* element) { result.add(element); }); return result; } - /*PropagationInfo tryGetInfo(IRInst* inst) - { - // If this is a global instruction (parent is module), return concrete info - if (as(inst->getParent())) - if (as(inst) || as(inst) || as(inst)) - return makeSingletonSet(inst); - else - return none(); - - return tryGetInfo(Element(inst)); - }*/ - - PropagationInfo tryGetInfo(Element element) + IRTypeFlowData* tryGetInfo(Element element) { // For non-global instructions, look up in the map auto found = propagationMap.tryGetValue(element); @@ -400,7 +367,7 @@ struct DynamicInstLoweringContext return none(); } - PropagationInfo tryGetInfo(IRInst* context, IRInst* inst) + IRTypeFlowData* tryGetInfo(IRInst* context, IRInst* inst) { if (!inst->getParent()) return none(); @@ -416,7 +383,7 @@ struct DynamicInstLoweringContext return tryGetInfo(Element(context, inst)); } - PropagationInfo tryGetFuncReturnInfo(IRFunc* func) + IRTypeFlowData* tryGetFuncReturnInfo(IRFunc* func) { auto found = funcReturnInfo.tryGetValue(func); if (found) @@ -430,7 +397,7 @@ struct DynamicInstLoweringContext void updateInfo( IRInst* context, IRInst* inst, - PropagationInfo newInfo, + IRTypeFlowData* newInfo, LinkedList& workQueue) { auto existingInfo = tryGetInfo(context, inst); @@ -452,7 +419,7 @@ struct DynamicInstLoweringContext void addUsersToWorkQueue( IRInst* context, IRInst* inst, - PropagationInfo info, + IRTypeFlowData* info, LinkedList& workQueue) { for (auto use = inst->firstUse; use; use = use->nextUse) @@ -511,7 +478,7 @@ struct DynamicInstLoweringContext // Helper method to update function return info and propagate to call sites void updateFuncReturnInfo( IRInst* callable, - PropagationInfo returnInfo, + IRTypeFlowData* returnInfo, LinkedList& workQueue) { auto existingReturnInfo = getFuncReturnInfo(callable); @@ -591,17 +558,17 @@ struct DynamicInstLoweringContext } } - IRInst* maybeReinterpret(IRInst* context, IRInst* arg, PropagationInfo destInfo) + IRInst* maybeReinterpret(IRInst* context, IRInst* arg, IRTypeFlowData* destInfo) { auto argInfo = tryGetInfo(context, arg); if (!argInfo || !destInfo) return arg; - if (argInfo.judgment == PropagationJudgment::Existential && - destInfo.judgment == PropagationJudgment::Existential) + if (as(argInfo) && as(destInfo)) { - if (argInfo.getCollectionCount() != destInfo.getCollectionCount()) + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) { // If the sets of witness tables are not equal, reinterpret to the parameter type IRBuilder builder(module); @@ -751,7 +718,7 @@ struct DynamicInstLoweringContext void processInstForPropagation(IRInst* context, IRInst* inst, LinkedList& workQueue) { - PropagationInfo info; + IRTypeFlowData* info; switch (inst->getOp()) { @@ -795,7 +762,7 @@ struct DynamicInstLoweringContext updateInfo(context, inst, info, workQueue); } - PropagationInfo analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) + IRTypeFlowData* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { // // TODO: Actually use the integer<->type map present in the linkage to @@ -808,7 +775,8 @@ struct DynamicInstLoweringContext { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeExistential(tables); + return makeExistential( + as(createCollection(kIROp_TableCollection, tables))); else return none(); } @@ -821,7 +789,7 @@ struct DynamicInstLoweringContext return none(); } - PropagationInfo analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) + IRTypeFlowData* analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) { auto witnessTable = inst->getWitnessTable(); auto value = inst->getWrappedValue(); @@ -833,14 +801,14 @@ struct DynamicInstLoweringContext if (!witnessTableInfo) return none(); - if (witnessTableInfo.judgment == PropagationJudgment::Unbounded) + if (as(witnessTableInfo)) return makeUnbounded(); HashSet tables; - if (witnessTableInfo.judgment == PropagationJudgment::Set) - forEachInCollection(witnessTableInfo, [&](IRInst* table) { tables.add(table); }); + if (auto collection = as(witnessTableInfo)) + return makeExistential(collection); - return makeExistential(tables); + SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } static IRInst* findEntryInConcreteTable(IRInst* witnessTable, IRInst* key) @@ -852,14 +820,14 @@ struct DynamicInstLoweringContext return nullptr; // Not found } - PropagationInfo analyzeLoad(IRInst* context, IRLoad* loadInst) + IRTypeFlowData* analyzeLoad(IRInst* context, IRLoad* loadInst) { // Transfer the prop info from the address to the loaded value auto address = loadInst->getPtr(); return tryGetInfo(context, address); } - PropagationInfo analyzeStore( + IRTypeFlowData* analyzeStore( IRInst* context, IRStore* storeInst, LinkedList& workQueue) @@ -870,101 +838,68 @@ struct DynamicInstLoweringContext return none(); // The store itself doesn't have any info. } - PropagationInfo analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) + IRTypeFlowData* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { auto key = inst->getRequirementKey(); auto witnessTable = inst->getWitnessTable(); auto witnessTableInfo = tryGetInfo(context, witnessTable); - switch (witnessTableInfo.judgment) + if (!witnessTableInfo) + return none(); + + if (as(witnessTableInfo)) + return makeUnbounded(); + + if (auto collection = as(witnessTableInfo)) { - case PropagationJudgment::None: - case PropagationJudgment::Unbounded: - return witnessTableInfo.judgment; - case PropagationJudgment::Set: - { - HashSet results; - forEachInCollection( - witnessTableInfo, - [&](IRInst* table) { results.add(findEntryInConcreteTable(table, key)); }); - return makeSet(results); - } - case PropagationJudgment::Existential: - SLANG_UNEXPECTED("Unexpected LookupWitnessMethod on Existential"); - break; - default: - SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeLookupWitnessMethod"); - break; + HashSet results; + forEachInCollection( + collection, + [&](IRInst* table) { results.add(findEntryInConcreteTable(table, key)); }); + return makeSet(results); } + + SLANG_UNEXPECTED("Unexpected witness table info type in analyzeLookupWitnessMethod"); } - PropagationInfo analyzeExtractExistentialWitnessTable( + IRTypeFlowData* analyzeExtractExistentialWitnessTable( IRInst* context, IRExtractExistentialWitnessTable* inst) { auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); - switch (operandInfo.judgment) - { - case PropagationJudgment::None: + if (!operandInfo) return none(); - case PropagationJudgment::Unbounded: + + if (as(operandInfo)) return makeUnbounded(); - case PropagationJudgment::Existential: - { - // Convert collection to HashSet and create Set PropagationInfo - HashSet tables; - forEachInCollection(operandInfo, [&](IRInst* table) { tables.add(table); }); - return makeSet(tables); - } - case PropagationJudgment::Set: - SLANG_UNEXPECTED( - "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); - break; - default: - SLANG_UNEXPECTED( - "Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); - break; - } + + if (auto taggedUnion = as(operandInfo)) + return as(taggedUnion->getOperand(1)); + + SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialWitnessTable"); } - PropagationInfo analyzeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) + IRTypeFlowData* analyzeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); - switch (operandInfo.judgment) - { - case PropagationJudgment::None: + if (!operandInfo) return none(); - case PropagationJudgment::Unbounded: + + if (as(operandInfo)) return makeUnbounded(); - case PropagationJudgment::Existential: - { - HashSet types; - forEachInCollection( - operandInfo, - [&](IRInst* table) - { - if (auto witnessTable = cast(table)) // Expect witness table - if (auto concreteType = witnessTable->getConcreteType()) - types.add(concreteType); - }); - return makeSet(types); - } - case PropagationJudgment::Set: - SLANG_UNEXPECTED( - "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); - break; - default: - SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialType"); - break; - } + + if (auto taggedUnion = as(operandInfo)) + return as(taggedUnion->getOperand(0)); + + SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialType"); } - PropagationInfo analyzeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) + IRTypeFlowData* analyzeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { // We don't care about the value itself. // (We rely on the propagation info for the type) @@ -972,122 +907,116 @@ struct DynamicInstLoweringContext return none(); } - - PropagationInfo analyzeSpecialize(IRInst* context, IRSpecialize* inst) + IRTypeFlowData* analyzeSpecialize(IRInst* context, IRSpecialize* inst) { auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); - switch (operandInfo.judgment) - { - case PropagationJudgment::None: + if (!operandInfo) return none(); - case PropagationJudgment::Unbounded: + + if (as(operandInfo)) return makeUnbounded(); - case PropagationJudgment::Existential: - { - SLANG_UNEXPECTED( - "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); - } - case PropagationJudgment::Set: + + if (as(operandInfo)) + { + SLANG_UNEXPECTED( + "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); + } + + if (auto collection = as(operandInfo)) + { + List specializationArgs; + for (auto i = 0; i < inst->getArgCount(); ++i) { - List specializationArgs; - for (auto i = 0; i < inst->getArgCount(); ++i) + // For integer args, add as is (also applies to any value args) + if (as(inst->getArg(i))) { - // For integer args, add as is (also applies to any value args) - if (as(inst->getArg(i))) - { - specializationArgs.add(inst->getArg(i)); - continue; - } + specializationArgs.add(inst->getArg(i)); + continue; + } - // For type args, we need to replace any dynamic args with - // their sets. - // - auto argInfo = tryGetInfo(context, inst->getArg(i)); - switch (argInfo.judgment) - { - case PropagationJudgment::None: - return PropagationJudgment::None; // Can't determine the result just yet. - case PropagationJudgment::Unbounded: - case PropagationJudgment::Existential: - SLANG_UNEXPECTED( - "Unexpected Existential operand in specialization argument. Should be " - "set"); - case PropagationJudgment::Set: - { - if (argInfo.getCollectionCount() == 1) - specializationArgs.add(argInfo.getSingletonValue()); - else - specializationArgs.add(argInfo.collection); - break; - } - default: - SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeSpecialize"); - break; - } + // For type args, we need to replace any dynamic args with + // their sets. + // + auto argInfo = tryGetInfo(context, inst->getArg(i)); + if (!argInfo) + return none(); // Can't determine the result just yet. + + if (as(argInfo) || as(argInfo)) + { + SLANG_UNEXPECTED( + "Unexpected Existential operand in specialization argument. Should be " + "set"); + } + + if (auto argCollection = as(argInfo)) + { + if (getCollectionCount(argCollection) == 1) + specializationArgs.add(getCollectionElement(argCollection, 0)); + else + specializationArgs.add(argCollection); + } + else + { + SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeSpecialize"); } + } - IRType* typeOfSpecialization = nullptr; - if (inst->getDataType()->getParent()->getOp() == kIROp_ModuleInst) - typeOfSpecialization = inst->getDataType(); - else if (auto funcType = as(inst->getDataType())) + IRType* typeOfSpecialization = nullptr; + if (inst->getDataType()->getParent()->getOp() == kIROp_ModuleInst) + typeOfSpecialization = inst->getDataType(); + else if (auto funcType = as(inst->getDataType())) + { + auto substituteSets = [&](IRInst* type) -> IRInst* { - auto substituteSets = [&](IRInst* type) -> IRInst* + if (auto info = tryGetInfo(context, type)) { - if (auto info = tryGetInfo(context, type)) + if (auto infoCollection = as(info)) { - if (info.judgment == PropagationJudgment::Set) - { - if (info.getCollectionCount() == 1) - return info.getSingletonValue(); - else - return info.collection; - } + if (getCollectionCount(infoCollection) == 1) + return getCollectionElement(infoCollection, 0); else - return type; + return infoCollection; } else return type; - }; + } + else + return type; + }; - List newParamTypes; - for (auto paramType : funcType->getParamTypes()) - newParamTypes.add((IRType*)substituteSets(paramType)); + List newParamTypes; + for (auto paramType : funcType->getParamTypes()) + newParamTypes.add((IRType*)substituteSets(paramType)); + IRBuilder builder(module); + builder.setInsertInto(module); + typeOfSpecialization = builder.getFuncType( + newParamTypes.getCount(), + newParamTypes.getBuffer(), + (IRType*)substituteSets(funcType->getResultType())); + } + else + { + SLANG_ASSERT_FAILURE("Unexpected data type for specialization instruction"); + } + + // Specialize each element in the set + HashSet specializedSet; + forEachInCollection( + collection, + [&](IRInst* arg) + { + // Create a new specialized instruction for each argument IRBuilder builder(module); builder.setInsertInto(module); - typeOfSpecialization = builder.getFuncType( - newParamTypes.getCount(), - newParamTypes.getBuffer(), - (IRType*)substituteSets(funcType->getResultType())); - } - else - { - SLANG_ASSERT_FAILURE("Unexpected data type for specialization instruction"); - } - - // Specialize each element in the set - HashSet specializedSet; - forEachInCollection( - operandInfo, - [&](IRInst* arg) - { - // Create a new specialized instruction for each argument - IRBuilder builder(module); - builder.setInsertInto(module); - specializedSet.add(builder.emitSpecializeInst( - typeOfSpecialization, - arg, - specializationArgs)); - }); - return makeSet(specializedSet); - } - break; - default: - SLANG_UNEXPECTED( - "Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); - break; + specializedSet.add( + builder.emitSpecializeInst(typeOfSpecialization, arg, specializationArgs)); + }); + return makeSet(specializedSet); } + + SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); } void discoverContext(IRInst* context, LinkedList& workQueue) @@ -1129,13 +1058,9 @@ struct DynamicInstLoweringContext if (as(arg)) continue; - if (auto collection = as(arg)) + if (auto collection = as(arg)) { - updateInfo( - context, - param, - PropagationInfo(PropagationJudgment::Set, collection), - workQueue); + updateInfo(context, param, collection, workQueue); } else if (as(arg) || as(arg)) { @@ -1162,7 +1087,7 @@ struct DynamicInstLoweringContext } } - PropagationInfo analyzeCall(IRInst* context, IRCall* inst, LinkedList& workQueue) + IRTypeFlowData* analyzeCall(IRInst* context, IRCall* inst, LinkedList& workQueue) { auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(context, callee); @@ -1200,10 +1125,10 @@ struct DynamicInstLoweringContext WorkItem(InterproceduralEdge::Direction::CallToFunc, context, inst, callee)); }; - if (calleeInfo.judgment == PropagationJudgment::Set) + if (auto collection = as(calleeInfo)) { // If we have a set of functions, register each one - forEachInCollection(calleeInfo, [&](IRInst* func) { propagateToCallSite(func); }); + forEachInCollection(collection, [&](IRInst* func) { propagateToCallSite(func); }); } if (auto callInfo = tryGetInfo(context, inst)) @@ -1249,9 +1174,9 @@ struct DynamicInstLoweringContext } } - List getParamInfos(IRInst* context) + List getParamInfos(IRInst* context) { - List infos; + List infos; if (as(context)) { for (auto param : as(context)->getParams()) @@ -1370,7 +1295,7 @@ struct DynamicInstLoweringContext } } - PropagationInfo getFuncReturnInfo(IRInst* callee) + IRTypeFlowData* getFuncReturnInfo(IRInst* callee) { funcReturnInfo.addIfNotExists(callee, none()); return funcReturnInfo[callee]; @@ -1404,7 +1329,7 @@ struct DynamicInstLoweringContext } } - PropagationInfo unionPropagationInfo(const List& infos) + IRTypeFlowData* unionPropagationInfo(const List& infos) { if (infos.getCount() == 0) { @@ -1432,62 +1357,55 @@ struct DynamicInstLoweringContext return infos[0]; } - // Need to create a union - collect all possible values based on judgment types + // Need to create a union - collect all possible values based on IR instruction types HashSet allValues; - IRFuncType* dynFuncType = nullptr; - PropagationJudgment unionJudgment = PropagationJudgment::None; - // Determine the union judgment type and collect values + // Determine the union type and collect values + bool hasUnbounded = false; + bool hasExistential = false; + for (auto info : infos) { - switch (info.judgment) + if (!info) + continue; + + if (as(info)) { - case PropagationJudgment::None: - break; - case PropagationJudgment::Set: - unionJudgment = PropagationJudgment::Set; - forEachInCollection(info, [&](IRInst* value) { allValues.add(value); }); - break; - case PropagationJudgment::Existential: - // For existential union, we need to collect all witness tables - // For now, we'll handle this properly by creating a new existential with all tables - unionJudgment = PropagationJudgment::Existential; - forEachInCollection(info, [&](IRInst* value) { allValues.add(value); }); - break; - case PropagationJudgment::Unbounded: // If any info is unbounded, the union is unbounded return makeUnbounded(); } + else if (auto taggedUnion = as(info)) + { + hasExistential = true; + auto tableCollection = as(taggedUnion->getOperand(1)); + forEachInCollection(tableCollection, [&](IRInst* value) { allValues.add(value); }); + } + else if (auto collection = as(info)) + { + forEachInCollection(collection, [&](IRInst* value) { allValues.add(value); }); + } } - if (unionJudgment == PropagationJudgment::Existential) - if (allValues.getCount() > 0) - return makeExistential(allValues); - else - return none(); + if (hasExistential && allValues.getCount() > 0) + return makeExistential( + as(createCollection(kIROp_TableCollection, allValues))); - if (unionJudgment == PropagationJudgment::Set) - if (allValues.getCount() > 0) - return makeSet(allValues); - else - return none(); - - // If we reach here, crash instead of returning none (which could make the analysis go into - // an infinite loop) - // - SLANG_UNEXPECTED("Unhandled prop-info union"); + if (allValues.getCount() > 0) + return makeSet(allValues); + else + return none(); } - PropagationInfo unionPropagationInfo(PropagationInfo info1, PropagationInfo info2) + IRTypeFlowData* unionPropagationInfo(IRTypeFlowData* info1, IRTypeFlowData* info2) { // Union the two infos - List infos; + List infos; infos.add(info1); infos.add(info2); return unionPropagationInfo(infos); } - PropagationInfo analyzeDefault(IRInst* context, IRInst* inst) + IRTypeFlowData* analyzeDefault(IRInst* context, IRInst* inst) { // Check if this is a global type, witness table, or function. // If so, it's a concrete element. We'll create a singleton set for it. @@ -1539,11 +1457,11 @@ struct DynamicInstLoweringContext { auto callee = as(child)->getCallee(); if (auto info = tryGetInfo(context, child)) - if (info.judgment == PropagationJudgment::Existential) + if (as(info)) instWithReplacementTypes.add(Element(context, child)); if (auto calleeInfo = tryGetInfo(context, callee)) - if (calleeInfo.judgment == PropagationJudgment::Set) + if (as(calleeInfo)) valueInstsToLower.add(Element(context, child)); if (as(callee)) @@ -1552,7 +1470,7 @@ struct DynamicInstLoweringContext break; default: if (auto info = tryGetInfo(context, child)) - if (info.judgment == PropagationJudgment::Existential) + if (as(info)) // If this instruction has a set of types, tables, or funcs, // we need to lower it to a unified type. instWithReplacementTypes.add(Element(context, child)); @@ -1632,14 +1550,14 @@ struct DynamicInstLoweringContext return hasChanges; } - bool replaceFuncType(IRFunc* func, PropagationInfo& returnTypeInfo) + bool replaceFuncType(IRFunc* func, IRTypeFlowData* returnTypeInfo) { IRFuncType* origFuncType = as(func->getFullType()); IRType* returnType = origFuncType->getResultType(); - if (returnTypeInfo.judgment == PropagationJudgment::Existential) + if (auto taggedUnion = as(returnTypeInfo)) { // If the return type is existential, we need to replace it with a tuple type - returnType = getTypeForExistential(returnTypeInfo); + returnType = getTypeForExistential(taggedUnion); } List paramTypes; @@ -1647,9 +1565,9 @@ struct DynamicInstLoweringContext { // Extract the existential type from the parameter if it exists auto paramInfo = tryGetInfo(param); - if (paramInfo && paramInfo.judgment == PropagationJudgment::Existential) + if (auto paramTaggedUnion = as(paramInfo)) { - paramTypes.add(getTypeForExistential(paramInfo)); + paramTypes.add(getTypeForExistential(paramTaggedUnion)); } else paramTypes.add(param->getDataType()); @@ -1666,15 +1584,16 @@ struct DynamicInstLoweringContext return true; } - IRType* getTypeForExistential(PropagationInfo info) + IRType* getTypeForExistential(IRCollectionTaggedUnionType* taggedUnion) { // Replace type with Tuple IRBuilder builder(module); builder.setInsertInto(module); HashSet types; + auto tableCollection = as(taggedUnion->getOperand(1)); forEachInCollection( - info, + tableCollection, [&](IRInst* table) { if (auto witnessTable = as(table)) @@ -1690,18 +1609,19 @@ struct DynamicInstLoweringContext bool replaceType(IRInst* context, IRInst* inst) { auto info = tryGetInfo(context, inst); - if (!info || info.judgment != PropagationJudgment::Existential) + auto taggedUnion = as(info); + if (!taggedUnion) return false; if (auto ptrType = as(inst->getDataType())) { IRBuilder builder(module); inst->setFullType( - builder.getPtrTypeWithAddressSpace(getTypeForExistential(info), ptrType)); + builder.getPtrTypeWithAddressSpace(getTypeForExistential(taggedUnion), ptrType)); } else { - inst->setFullType(getTypeForExistential(info)); + inst->setFullType(getTypeForExistential(taggedUnion)); } return true; } @@ -1740,55 +1660,58 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - if (info.isSingleton()) - { - // Found a single possible type. Simple replacement. - inst->replaceUsesWith(info.getSingletonValue()); - inst->removeAndDeallocate(); - return true; - } - else if (info.judgment == PropagationJudgment::Set) + if (auto collection = as(info)) { - // Set of types. - if (inst->getDataType()->getOp() == kIROp_TypeKind) + if (getCollectionCount(collection) == 1) { - // Create an any-value type based on the set of types - auto typeSet = collectionToHashSet(info); - auto unionType = typeSet.getCount() > 1 ? createAnyValueTypeFromInsts(typeSet) - : *typeSet.begin(); - - // Store the mapping for later use - loweredInstToAnyValueType[inst] = unionType; - - // Replace the instruction with the any-value type - inst->replaceUsesWith(unionType); + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(getCollectionElement(collection, 0)); inst->removeAndDeallocate(); return true; } else { - // Get the witness table operand info - auto witnessTableInst = inst->getWitnessTable(); - auto witnessTableInfo = tryGetInfo(context, witnessTableInst); - - if (witnessTableInfo.judgment == PropagationJudgment::Set) + // Set of types. + if (inst->getDataType()->getOp() == kIROp_TypeKind) { - // Create a key mapping function - auto keyMappingFunc = createKeyMappingFunc( - inst->getRequirementKey(), - collectionToHashSet(witnessTableInfo), - collectionToHashSet(info)); - - // Replace with call to key mapping function - auto witnessTableId = builder.emitCallInst( - builder.getUIntType(), - keyMappingFunc, - List({inst->getWitnessTable()})); - inst->replaceUsesWith(witnessTableId); - propagationMap[Element(context, witnessTableId)] = info; + // Create an any-value type based on the set of types + auto typeSet = collectionToHashSet(collection); + auto unionType = typeSet.getCount() > 1 ? createAnyValueTypeFromInsts(typeSet) + : *typeSet.begin(); + + // Store the mapping for later use + loweredInstToAnyValueType[inst] = unionType; + + // Replace the instruction with the any-value type + inst->replaceUsesWith(unionType); inst->removeAndDeallocate(); return true; } + else + { + // Get the witness table operand info + auto witnessTableInst = inst->getWitnessTable(); + auto witnessTableInfo = tryGetInfo(context, witnessTableInst); + + if (auto witnessTableCollection = as(witnessTableInfo)) + { + // Create a key mapping function + auto keyMappingFunc = createKeyMappingFunc( + inst->getRequirementKey(), + collectionToHashSet(witnessTableCollection), + collectionToHashSet(collection)); + + // Replace with call to key mapping function + auto witnessTableId = builder.emitCallInst( + builder.getUIntType(), + keyMappingFunc, + List({inst->getWitnessTable()})); + inst->replaceUsesWith(witnessTableId); + propagationMap[Element(context, witnessTableId)] = info; + inst->removeAndDeallocate(); + return true; + } + } } } @@ -1806,22 +1729,25 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - if (info.isSingleton()) + if (auto collection = as(info)) { - // Found a single possible type. Simple replacement. - inst->replaceUsesWith(info.getSingletonValue()); - inst->removeAndDeallocate(); - return true; - } - else if (info.judgment == PropagationJudgment::Set) - { - // Replace with GetElement(loweredInst, 0) -> uint - auto operand = inst->getOperand(0); - auto element = builder.emitGetTupleElement(builder.getUIntType(), operand, 0); - inst->replaceUsesWith(element); - propagationMap[Element(context, element)] = info; - inst->removeAndDeallocate(); - return true; + if (getCollectionCount(collection) == 1) + { + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(getCollectionElement(collection, 0)); + inst->removeAndDeallocate(); + return true; + } + else + { + // Replace with GetElement(loweredInst, 0) -> uint + auto operand = inst->getOperand(0); + auto element = builder.emitGetTupleElement(builder.getUIntType(), operand, 0); + inst->replaceUsesWith(element); + propagationMap[Element(context, element)] = info; + inst->removeAndDeallocate(); + return true; + } } return false; } @@ -1829,7 +1755,8 @@ struct DynamicInstLoweringContext bool lowerExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { auto operandInfo = tryGetInfo(context, inst->getOperand(0)); - if (!operandInfo || operandInfo.judgment != PropagationJudgment::Existential) + auto taggedUnion = as(operandInfo); + if (!taggedUnion) return false; IRBuilder builder(inst); @@ -1854,23 +1781,25 @@ struct DynamicInstLoweringContext bool lowerExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { auto info = tryGetInfo(context, inst); - if (!info || info.judgment != PropagationJudgment::Set) + auto collection = as(info); + if (!collection) return false; IRBuilder builder(inst); builder.setInsertBefore(inst); - if (info.isSingleton()) + if (getCollectionCount(collection) == 1) { // Found a single possible type. Simple replacement. - inst->replaceUsesWith(info.getSingletonValue()); + auto singletonValue = getCollectionElement(collection, 0); + inst->replaceUsesWith(singletonValue); inst->removeAndDeallocate(); - loweredInstToAnyValueType[inst] = info.getSingletonValue(); + loweredInstToAnyValueType[inst] = singletonValue; return true; } // Create an any-value type based on the set of types - auto anyValueType = createAnyValueTypeFromInsts(collectionToHashSet(info)); + auto anyValueType = createAnyValueTypeFromInsts(collectionToHashSet(collection)); // Store the mapping for later use loweredInstToAnyValueType[inst] = anyValueType; @@ -1914,7 +1843,11 @@ struct DynamicInstLoweringContext // List paramDirections; auto calleeInfo = tryGetInfo(context, inst->getCallee()); - auto funcType = as(calleeInfo.getCollectionElement(0)->getDataType()); + auto calleeCollection = as(calleeInfo); + if (!calleeCollection) + return nullptr; + + auto funcType = as(getCollectionElement(calleeCollection, 0)->getDataType()); for (auto paramType : funcType->getParamTypes()) { auto [direction, type] = getParameterDirectionAndType(paramType); @@ -1932,8 +1865,8 @@ struct DynamicInstLoweringContext case ParameterDirection::kParameterDirection_In: { auto argInfo = tryGetInfo(context, arg); - if (argInfo.judgment == PropagationJudgment::Existential) - paramTypes.add(getTypeForExistential(argInfo)); + if (auto argTaggedUnion = as(argInfo)) + paramTypes.add(getTypeForExistential(argTaggedUnion)); else paramTypes.add(arg->getDataType()); break; @@ -1941,8 +1874,8 @@ struct DynamicInstLoweringContext case ParameterDirection::kParameterDirection_Out: { auto argInfo = tryGetInfo(context, arg); - if (argInfo.judgment == PropagationJudgment::Existential) - paramTypes.add(builder.getOutType(getTypeForExistential(argInfo))); + if (auto argTaggedUnion = as(argInfo)) + paramTypes.add(builder.getOutType(getTypeForExistential(argTaggedUnion))); else paramTypes.add(builder.getOutType( as(arg->getDataType())->getValueType())); @@ -1951,8 +1884,8 @@ struct DynamicInstLoweringContext case ParameterDirection::kParameterDirection_InOut: { auto argInfo = tryGetInfo(context, arg); - if (argInfo.judgment == PropagationJudgment::Existential) - paramTypes.add(builder.getInOutType(getTypeForExistential(argInfo))); + if (auto argTaggedUnion = as(argInfo)) + paramTypes.add(builder.getInOutType(getTypeForExistential(argTaggedUnion))); else paramTypes.add(builder.getInOutType( as(arg->getDataType())->getValueType())); @@ -1966,9 +1899,9 @@ struct DynamicInstLoweringContext // Translate result type. IRType* resultType = inst->getDataType(); auto returnInfo = tryGetInfo(context, inst); - if (returnInfo && returnInfo.judgment == PropagationJudgment::Existential) + if (auto returnTaggedUnion = as(returnInfo)) { - resultType = getTypeForExistential(returnInfo); + resultType = getTypeForExistential(returnTaggedUnion); } return builder.getFuncType(paramTypes, resultType); @@ -1984,7 +1917,7 @@ struct DynamicInstLoweringContext for (UInt i = 0; i < specialize->getArgCount(); i++) { auto arg = specialize->getArg(i); - if (as(arg)) + if (as(arg)) return true; // Found a type-flow-collection argument } return false; // No type-flow-collection arguments found @@ -2015,7 +1948,7 @@ struct DynamicInstLoweringContext for (auto param : generic->getFirstBlock()->getParams()) { auto specArg = specializeInst->getArg(argIndex++); - if (as(specArg)) + if (as(specArg)) { // We're dealing with a set of types. if (as(param->getDataType())) @@ -2162,24 +2095,28 @@ struct DynamicInstLoweringContext bool lowerCallToDynamicGeneric(IRInst* context, IRCall* inst) { auto specializedCallee = as(inst->getCallee()); - auto targetContext = tryGetInfo(context, specializedCallee).getSingletonValue(); + auto calleeInfo = tryGetInfo(context, specializedCallee); + auto calleeCollection = as(calleeInfo); + if (!calleeCollection || getCollectionCount(calleeCollection) != 1) + return false; + + auto targetContext = getCollectionElement(calleeCollection, 0); List callArgs; for (auto ii = 0; ii < specializedCallee->getArgCount(); ii++) { auto specArg = specializedCallee->getArg(ii); auto argInfo = tryGetInfo(context, specArg); - if (argInfo.judgment == PropagationJudgment::Set) + if (auto argCollection = as(argInfo)) { - auto collection = as(argInfo.collection); - if (as(collection->getOperand(0))) + if (as(getCollectionElement(argCollection, 0))) { // Needs an index (spec-arg will carry an index, we'll // just need to append it to the call) // callArgs.add(specArg); } - else if (as(collection->getOperand(0))) + else if (as(getCollectionElement(argCollection, 0))) { // Needs no dynamic information. Skip. } @@ -2212,19 +2149,21 @@ struct DynamicInstLoweringContext auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(context, callee); - if (!calleeInfo || calleeInfo.judgment != PropagationJudgment::Set) + auto calleeCollection = as(calleeInfo); + if (!calleeCollection) return false; - if (calleeInfo.isSingleton()) + if (getCollectionCount(calleeCollection) == 1) { - if (isDynamicGeneric(calleeInfo.getSingletonValue())) + auto singletonValue = getCollectionElement(calleeCollection, 0); + if (isDynamicGeneric(singletonValue)) return lowerCallToDynamicGeneric(context, inst); - if (calleeInfo.getSingletonValue() == callee) + if (singletonValue == callee) return false; IRBuilder builder(inst->getModule()); - builder.replaceOperand(inst->getCalleeUse(), calleeInfo.getSingletonValue()); + builder.replaceOperand(inst->getCalleeUse(), singletonValue); return true; // Replaced with a single function } @@ -2233,7 +2172,8 @@ struct DynamicInstLoweringContext auto expectedFuncType = getExpectedFuncType(context, inst); // Create dispatch function - auto dispatchFunc = createDispatchFunc(collectionToHashSet(calleeInfo), expectedFuncType); + auto dispatchFunc = + createDispatchFunc(collectionToHashSet(calleeCollection), expectedFuncType); // Replace call with dispatch List newArgs; @@ -2255,23 +2195,25 @@ struct DynamicInstLoweringContext bool lowerMakeExistential(IRInst* context, IRMakeExistential* inst) { auto info = tryGetInfo(context, inst); - if (!info || info.judgment != PropagationJudgment::Existential) + auto taggedUnion = as(info); + if (!taggedUnion) return false; IRBuilder builder(inst); builder.setInsertBefore(inst); auto witnessTableInfo = tryGetInfo(context, inst->getWitnessTable()); - if (witnessTableInfo.judgment != PropagationJudgment::Set) + auto witnessTableCollection = as(witnessTableInfo); + if (!witnessTableCollection) return false; // Witness table must be a set of tables IRInst* witnessTableID = nullptr; - if (witnessTableInfo.isSingleton()) + if (getCollectionCount(witnessTableCollection) == 1) { // Get unique ID for the witness table. witnessTableID = builder.getIntValue( builder.getUIntType(), - getUniqueID(witnessTableInfo.getSingletonValue())); + getUniqueID(getCollectionElement(witnessTableCollection, 0))); } else { @@ -2281,8 +2223,9 @@ struct DynamicInstLoweringContext // Collect types from the witness tables to determine the any-value type HashSet types; + auto tableCollection = as(taggedUnion->getOperand(1)); forEachInCollection( - info, + tableCollection, [&](IRInst* table) { if (auto witnessTableInst = as(table)) @@ -2318,12 +2261,14 @@ struct DynamicInstLoweringContext bool lowerCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { auto info = tryGetInfo(context, inst); - if (!info || info.judgment != PropagationJudgment::Existential) + auto taggedUnion = as(info); + if (!taggedUnion) return false; Dictionary mapping; + auto tableCollection = as(taggedUnion->getOperand(1)); forEachInCollection( - info, + tableCollection, [&](IRInst* table) { // Get unique ID for the witness table @@ -2344,7 +2289,7 @@ struct DynamicInstLoweringContext createIntegerMappingFunc(mapping), List({inst->getTypeID()})); - auto existentialTupleType = as(getTypeForExistential(info)); + auto existentialTupleType = as(getTypeForExistential(taggedUnion)); auto existentialTuple = builder.emitMakeTuple( existentialTupleType, List( @@ -2646,14 +2591,14 @@ struct DynamicInstLoweringContext { bool hasChanges = false; - // Lower all global scope ``IRTypeFlowCollection`` objects that + // Lower all global scope ``IRCollectionBase`` objects that // are made up of types. // for (auto inst : module->getGlobalInsts()) { - if (auto collection = as(inst)) + if (auto collection = as(inst)) { - if (collection->getOp() == kIROp_TypeFlowCollection) + if (collection->getOp() == kIROp_TypeCollection) { HashSet types; for (UInt i = 0; i < collection->getOperandCount(); i++) @@ -2705,10 +2650,10 @@ struct DynamicInstLoweringContext DiagnosticSink* sink; // Mapping from instruction to propagation information - Dictionary propagationMap; + Dictionary propagationMap; // Mapping from function to return value propagation information - Dictionary funcReturnInfo; + Dictionary funcReturnInfo; // Mapping from functions to call-sites. Dictionary> funcCallSites; diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index 7c573f867e7..481e4ae0219 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -140,10 +140,8 @@ struct ArgumentPackWorkItem bool isAnyValueType(IRType* type) { - if (as(type)) + if (as(type) || as(type)) return true; - if (auto collection = as(type)) - return as(collection->getOperand(0)) != nullptr; return false; } @@ -301,8 +299,8 @@ IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* inte auto anyValType = cast(item.dstArg->getDataType())->getValueType(); auto concreteVal = builder->emitLoad(item.concreteArg); auto packedVal = (item.kind == ArgumentPackWorkItem::Kind::Pack) - ? builder->emitPackAnyValue(anyValType, concreteVal) - : builder->emitReinterpret(anyValType, concreteVal); + ? builder->emitPackAnyValue(anyValType, concreteVal) + : builder->emitReinterpret(anyValType, concreteVal); builder->emitStore(item.dstArg, packedVal); } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b2ef78c4dd6..bd98de38229 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8531,6 +8531,9 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) if (as(this)) return false; + if (as(this)) + return false; + switch (getOp()) { // By default, assume that we might have side effects, @@ -8722,9 +8725,6 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_DetachDerivative: return false; - case kIROp_TypeFlowCollection: - return false; - case kIROp_Div: case kIROp_IRem: if (isIntegralScalarOrCompositeType(getFullType())) From 7cd7d71bbea2d0bfd67582e2b62a95a920e1ae51 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 30 Jul 2025 20:55:48 -0400 Subject: [PATCH 019/105] Halfway through refactoring the code to split analysis and lowering --- source/slang/slang-ir-insts-stable-names.lua | 5 +- source/slang/slang-ir-insts.lua | 7 +- source/slang/slang-ir-lower-dynamic-insts.cpp | 1801 ++++++++++++----- source/slang/slang-ir-lower-dynamic-insts.h | 1 + source/slang/slang-ir-specialize.cpp | 2 + 5 files changed, 1303 insertions(+), 513 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index abb9cd3b68a..e5dc000ebc6 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -675,5 +675,8 @@ return { ["TypeFlowData.CollectionBase.GenericCollection"] = 671, ["TypeFlowData.UnboundedCollection"] = 672, ["TypeFlowData.CollectionTagType"] = 673, - ["TypeFlowData.CollectionTaggedUnionType"] = 674 + ["TypeFlowData.CollectionTaggedUnionType"] = 674, + ["GetTagForSuperCollection"] = 675, + ["GetTagForMappedCollection"] = 676, + ["GetTagFromSequentialID"] = 677 } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 0db1668eba4..6069af690a8 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2187,7 +2187,12 @@ local insts = { { CollectionTagType = {} }, -- Operand is TypeCollection/FuncCollection/TableCollection (funcs/tables) { CollectionTaggedUnionType = {}} -- Operand is TypeCollection, TableCollection for existential }, - } + }, + { GetTagForSuperCollection = {} }, -- Translate a tag from a set to its equivalent in a super-set + { GetTagForMappedCollection = {} }, -- Translate a tag from a set to its equivalent in a different set + -- based on a mapping induced by a lookup key + { GetTagFromSequentialID = {} } -- Translate an existing sequential ID & and interface type into a tag + -- the provided collection. } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index b2d866c3522..a35cc81d72a 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -2,6 +2,7 @@ #include "slang-ir-any-value-marshalling.h" #include "slang-ir-clone.h" +#include "slang-ir-inst-pass-base.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir-witness-table-wrapper.h" @@ -209,6 +210,29 @@ bool areInfosEqual(IRInst* a, IRInst* b) return a == b; } +// Helper to iterate over collection elements +template +void forEachInCollection(IRCollectionBase* info, F func) +{ + for (UInt i = 0; i < info->getOperandCount(); ++i) + func(info->getOperand(i)); +} + +template +void forEachInCollection(IRCollectionTagType* tagType, F func) +{ + forEachInCollection(as(tagType->getOperand(0)), func); +} + +static IRInst* findEntryInConcreteTable(IRInst* witnessTable, IRInst* key) +{ + if (auto concreteTable = as(witnessTable)) + for (auto entry : concreteTable->getEntries()) + if (entry->getRequirementKey() == key) + return entry->getSatisfyingVal(); + return nullptr; // Not found +} + struct DynamicInstLoweringContext { // Helper methods for creating canonical collections @@ -306,12 +330,14 @@ struct DynamicInstLoweringContext elements.getBuffer())); } - /*IRCollectionTaggedUnionType* makeExistential(const HashSet& tables) + IRCollectionTagType* makeTagType(IRCollectionBase* collection) { - SLANG_ASSERT(tables.getCount() > 0); - auto tableCollection = createCollection(kIROp_TableCollection, tables); - return makeExistential(tableCollection); - }*/ + IRInst* collectionInst = collection; + // Create the tag type from the collection + IRBuilder builder(module); + return as( + builder.emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collectionInst)); + } UCount getCollectionCount(IRCollectionBase* collection) { @@ -326,6 +352,12 @@ struct DynamicInstLoweringContext return getCollectionCount(as(typeCollection)); } + UCount getCollectionCount(IRCollectionTagType* tagType) + { + auto collection = tagType->getOperand(0); + return getCollectionCount(as(collection)); + } + IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) { if (!collection || index >= collection->getOperandCount()) @@ -333,6 +365,12 @@ struct DynamicInstLoweringContext return collection->getOperand(index); } + IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) + { + auto typeCollection = collectionTagType->getOperand(0); + return getCollectionElement(as(typeCollection), index); + } + IRUnboundedCollection* makeUnbounded() { IRBuilder builder(module); @@ -342,14 +380,6 @@ struct DynamicInstLoweringContext IRTypeFlowData* none() { return nullptr; } - // Helper to iterate over collection elements - template - void forEachInCollection(IRCollectionBase* info, F func) - { - for (UInt i = 0; i < info->getOperandCount(); ++i) - func(info->getOperand(i)); - } - // Helper to convert collection to HashSet HashSet collectionToHashSet(IRCollectionBase* info) { @@ -558,6 +588,67 @@ struct DynamicInstLoweringContext } } + IRInst* upcastCollection(IRInst* context, IRInst* arg, IRTypeFlowData* destInfo) + { + /*auto argInfo = tryGetInfo(context, arg); + if (!argInfo || !destInfo) + return arg;*/ + auto argInfo = as(arg->getDataType()); + if (!argInfo || !destInfo) + return arg; + + if (as(argInfo) && as(destInfo)) + { + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) + { + // If the sets of witness tables are not equal, reinterpret to the parameter type + IRBuilder builder(module); + builder.setInsertAfter(arg); + + auto argTupleType = as(arg->getDataType()); + auto tag = + builder.emitGetTupleElement((IRType*)argTupleType->getOperand(0), arg, 0); + auto value = + builder.emitGetTupleElement((IRType*)argTupleType->getOperand(1), arg, 1); + + auto newTag = upcastCollection( + context, + tag, + makeTagType(cast(destInfo->getOperand(1)))); + auto newValue = upcastCollection( + context, + value, + cast(destInfo->getOperand(0))); + + return builder.emitMakeTuple(newTag, newValue); + } + } + else if (as(argInfo) && as(destInfo)) + { + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) + { + IRBuilder builder(module); + builder.setInsertAfter(arg); + return builder + .emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); + } + } + else if (as(argInfo) && as(destInfo)) + { + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) + { + // If the sets of witness tables are not equal, reinterpret to the parameter type + IRBuilder builder(module); + builder.setInsertAfter(arg); + return builder.emitReinterpret((IRType*)destInfo, arg); + } + } + } + + /* IRInst* maybeReinterpret(IRInst* context, IRInst* arg, IRTypeFlowData* destInfo) { auto argInfo = tryGetInfo(context, arg); @@ -715,6 +806,7 @@ struct DynamicInstLoweringContext return changed; } + */ void processInstForPropagation(IRInst* context, IRInst* inst, LinkedList& workQueue) { @@ -804,20 +896,13 @@ struct DynamicInstLoweringContext if (as(witnessTableInfo)) return makeUnbounded(); - HashSet tables; - if (auto collection = as(witnessTableInfo)) - return makeExistential(collection); + if (as(witnessTable)) + return makeExistential(as(makeSingletonSet(witnessTable))); - SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); - } + if (auto collectionTag = as(witnessTableInfo)) + return makeExistential(cast(collectionTag->getOperand(0))); - static IRInst* findEntryInConcreteTable(IRInst* witnessTable, IRInst* key) - { - if (auto concreteTable = as(witnessTable)) - for (auto entry : concreteTable->getEntries()) - if (entry->getRequirementKey() == key) - return entry->getSatisfyingVal(); - return nullptr; // Not found + SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } IRTypeFlowData* analyzeLoad(IRInst* context, IRLoad* loadInst) @@ -851,13 +936,13 @@ struct DynamicInstLoweringContext if (as(witnessTableInfo)) return makeUnbounded(); - if (auto collection = as(witnessTableInfo)) + if (auto tagType = as(witnessTableInfo)) { HashSet results; forEachInCollection( - collection, + cast(tagType->getOperand(0)), [&](IRInst* table) { results.add(findEntryInConcreteTable(table, key)); }); - return makeSet(results); + return makeTagType(makeSet(results)); } SLANG_UNEXPECTED("Unexpected witness table info type in analyzeLookupWitnessMethod"); @@ -877,7 +962,7 @@ struct DynamicInstLoweringContext return makeUnbounded(); if (auto taggedUnion = as(operandInfo)) - return as(taggedUnion->getOperand(1)); + return makeTagType(cast(taggedUnion->getOperand(1))); SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialWitnessTable"); } @@ -894,17 +979,24 @@ struct DynamicInstLoweringContext return makeUnbounded(); if (auto taggedUnion = as(operandInfo)) - return as(taggedUnion->getOperand(0)); + return makeTagType(cast(taggedUnion->getOperand(0))); SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialType"); } IRTypeFlowData* analyzeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { - // We don't care about the value itself. - // (We rely on the propagation info for the type) - // - return none(); + auto operand = inst->getOperand(0); + auto operandInfo = tryGetInfo(context, operand); + + if (!operandInfo) + return none(); + + if (as(operandInfo)) + return makeUnbounded(); + + if (auto taggedUnion = as(operandInfo)) + return cast(taggedUnion->getOperand(0)); } IRTypeFlowData* analyzeSpecialize(IRInst* context, IRSpecialize* inst) @@ -924,7 +1016,7 @@ struct DynamicInstLoweringContext "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); } - if (auto collection = as(operandInfo)) + if (auto collectionTag = as(operandInfo)) { List specializationArgs; for (auto i = 0; i < inst->getArgCount(); ++i) @@ -950,12 +1042,13 @@ struct DynamicInstLoweringContext "set"); } - if (auto argCollection = as(argInfo)) + if (auto argCollectionTag = as(argInfo)) { - if (getCollectionCount(argCollection) == 1) - specializationArgs.add(getCollectionElement(argCollection, 0)); + if (getCollectionCount(argCollectionTag) == 1) + specializationArgs.add(getCollectionElement(argCollectionTag, 0)); else - specializationArgs.add(argCollection); + specializationArgs.add( + cast(argCollectionTag->getOperand(0))); } else { @@ -972,12 +1065,12 @@ struct DynamicInstLoweringContext { if (auto info = tryGetInfo(context, type)) { - if (auto infoCollection = as(info)) + if (auto infoCollectionTag = as(info)) { - if (getCollectionCount(infoCollection) == 1) - return getCollectionElement(infoCollection, 0); + if (getCollectionCount(infoCollectionTag) == 1) + return getCollectionElement(infoCollectionTag, 0); else - return infoCollection; + return as(infoCollectionTag->getOperand(0)); } else return type; @@ -1004,7 +1097,7 @@ struct DynamicInstLoweringContext // Specialize each element in the set HashSet specializedSet; forEachInCollection( - collection, + collectionTag, [&](IRInst* arg) { // Create a new specialized instruction for each argument @@ -1013,7 +1106,8 @@ struct DynamicInstLoweringContext specializedSet.add( builder.emitSpecializeInst(typeOfSpecialization, arg, specializationArgs)); }); - return makeSet(specializedSet); + + return makeTagType(makeSet(specializedSet)); } SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); @@ -1329,80 +1423,66 @@ struct DynamicInstLoweringContext } } - IRTypeFlowData* unionPropagationInfo(const List& infos) + template + T* unionCollection(T* collection1, T* collection2) { - if (infos.getCount() == 0) - { - return none(); - } + SLANG_ASSERT(as(collection1) && as(collection2)); + SLANG_ASSERT(collection1->getOp() == collection2->getOp()); + + if (!collection1) + return collection2; + if (!collection2) + return collection1; + if (collection1 == collection2) + return collection1; + + HashSet allValues; + // Collect all values from both collections + forEachInCollection(collection1, [&](IRInst* value) { allValues.add(value); }); + forEachInCollection(collection2, [&](IRInst* value) { allValues.add(value); }); + + return as(createCollection( + collection1->getOp(), + allValues)); // Create a new collection with the union of values + } + + IRTypeFlowData* unionPropagationInfo(IRTypeFlowData* info1, IRTypeFlowData* info2) + { + if (!info1) + return info2; + if (!info2) + return info1; - if (infos.getCount() == 1) + if (as(info1) && as(info2)) { - return infos[0]; + // If either info is unbounded, the union is unbounded + return makeUnbounded(); } - // Check if all infos are the same - bool allSame = true; - for (Index i = 1; i < infos.getCount(); i++) + if (as(info1) && as(info2)) { - if (!areInfosEqual(infos[0], infos[i])) - { - allSame = false; - break; - } + return makeExistential( + unionCollection( + cast(info1->getOperand(1)), + cast(info2->getOperand(1)))); } - if (allSame) + if (as(info1) && as(info2)) { - return infos[0]; + return makeTagType( + unionCollection( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); } - // Need to create a union - collect all possible values based on IR instruction types - HashSet allValues; - - // Determine the union type and collect values - bool hasUnbounded = false; - bool hasExistential = false; - - for (auto info : infos) + if (as(info1) && as(info2)) { - if (!info) - continue; - - if (as(info)) - { - // If any info is unbounded, the union is unbounded - return makeUnbounded(); - } - else if (auto taggedUnion = as(info)) - { - hasExistential = true; - auto tableCollection = as(taggedUnion->getOperand(1)); - forEachInCollection(tableCollection, [&](IRInst* value) { allValues.add(value); }); - } - else if (auto collection = as(info)) - { - forEachInCollection(collection, [&](IRInst* value) { allValues.add(value); }); - } + return unionCollection( + cast(info1), + cast(info2)); } - if (hasExistential && allValues.getCount() > 0) - return makeExistential( - as(createCollection(kIROp_TableCollection, allValues))); - - if (allValues.getCount() > 0) - return makeSet(allValues); - else - return none(); - } - - IRTypeFlowData* unionPropagationInfo(IRTypeFlowData* info1, IRTypeFlowData* info2) - { - // Union the two infos - List infos; - infos.add(info1); - infos.add(info2); - return unionPropagationInfo(infos); + SLANG_UNEXPECTED("Unhandled propagation info types in unionPropagationInfo"); } IRTypeFlowData* analyzeDefault(IRInst* context, IRInst* inst) @@ -1416,6 +1496,91 @@ struct DynamicInstLoweringContext return none(); // Default case, no propagation info } + bool lowerInstsInBlock(IRInst* context, IRBlock* block) + { + List instsToLower; + bool hasChanges = false; + for (auto inst : block->getChildren()) + instsToLower.add(inst); + + UIndex paramIndex = 0; + for (auto inst : instsToLower) + hasChanges |= lowerInst(context, inst); + + return hasChanges; + } + + bool lowerFunc(IRFunc* func) + { + bool hasChanges = false; + for (auto block : func->getBlocks()) + hasChanges |= lowerInstsInBlock(func, block); + + for (auto block : func->getBlocks()) + { + UIndex paramIndex = 0; + // Process each parameter in this block (these are phi parameters) + for (auto param : block->getParams()) + { + auto paramInfo = tryGetInfo(param); + if (!paramInfo) + { + paramIndex++; + continue; + } + + // Find all predecessors of this block + for (auto pred : block->getPredecessors()) + { + auto terminator = pred->getTerminator(); + if (auto unconditionalBranch = as(terminator)) + { + auto arg = unconditionalBranch->getArg(paramIndex); + auto newArg = + upcastCollection(func, arg, as(param->getDataType())); + + if (newArg != arg) + { + hasChanges = true; + // Replace the argument in the branch instruction + SLANG_ASSERT(!as(unconditionalBranch)); + unconditionalBranch->setOperand(1 + paramIndex, newArg); + } + } + } + + paramIndex++; + } + + // Is the terminator a return instruction? + if (auto returnInst = as(block->getTerminator())) + { + if (!as(returnInst->getVal()->getDataType())) + { + auto funcReturnInfo = getFuncReturnInfo(func); + auto newReturnVal = + upcastCollection(func, returnInst->getVal(), funcReturnInfo); + if (newReturnVal != returnInst->getVal()) + { + // Replace the return value with the reinterpreted value + hasChanges = true; + returnInst->setOperand(0, newReturnVal); + } + } + } + } + + auto effectiveFuncType = getEffectiveFuncType(func); + if (effectiveFuncType != func->getFullType()) + { + hasChanges = true; + func->setFullType(effectiveFuncType); + } + + return hasChanges; + } + + /* bool lowerInstsInFunc(IRFunc* func) { // Collect all instructions that need lowering @@ -1501,22 +1666,24 @@ struct DynamicInstLoweringContext return hasChanges; } + */ bool performDynamicInstLowering() { - List funcsForTypeReplacement; + // List funcsForTypeReplacement; List funcsToProcess; for (auto globalInst : module->getGlobalInsts()) if (auto func = as(globalInst)) { - funcsForTypeReplacement.add(func); + // funcsForTypeReplacement.add(func); funcsToProcess.add(func); } bool hasChanges = false; do { + /* while (funcsForTypeReplacement.getCount() > 0) { auto func = funcsForTypeReplacement.getLast(); @@ -1525,6 +1692,7 @@ struct DynamicInstLoweringContext // Replace the function type with a concrete type if it has existential return types hasChanges |= replaceFuncType(func, this->funcReturnInfo[func]); } + */ while (funcsToProcess.getCount() > 0) { @@ -1532,7 +1700,7 @@ struct DynamicInstLoweringContext funcsToProcess.removeLast(); // Lower the instructions in the function - hasChanges |= lowerInstsInFunc(func); + hasChanges |= lowerFunc(func); } // The above loops might have added new contexts to lower. @@ -1540,16 +1708,16 @@ struct DynamicInstLoweringContext { hasChanges |= lowerContext(context); auto newFunc = cast(this->loweredContexts[context]); - funcsForTypeReplacement.add(newFunc); funcsToProcess.add(newFunc); } this->contextsToLower.clear(); - } while (funcsForTypeReplacement.getCount() > 0 || funcsToProcess.getCount() > 0); + } while (funcsToProcess.getCount() > 0); return hasChanges; } + /* bool replaceFuncType(IRFunc* func, IRTypeFlowData* returnTypeInfo) { IRFuncType* origFuncType = as(func->getFullType()); @@ -1583,14 +1751,15 @@ struct DynamicInstLoweringContext func->setFullType(newFuncType); return true; } + */ IRType* getTypeForExistential(IRCollectionTaggedUnionType* taggedUnion) { - // Replace type with Tuple + // Replace type with Tuple IRBuilder builder(module); builder.setInsertInto(module); - HashSet types; + /*HashSet types; auto tableCollection = as(taggedUnion->getOperand(1)); forEachInCollection( tableCollection, @@ -1599,29 +1768,62 @@ struct DynamicInstLoweringContext if (auto witnessTable = as(table)) if (auto concreteType = witnessTable->getConcreteType()) types.add(concreteType); - }); + });*/ + + auto typeCollection = cast(taggedUnion->getOperand(0)); + auto tableCollection = cast(taggedUnion->getOperand(1)); + return builder.getTupleType( + List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); + } + + IRType* getLoweredType(IRTypeFlowData* info) + { + if (!info) + return nullptr; + + if (as(info)) + return nullptr; + + if (auto taggedUnion = as(info)) + { + // If this is a tagged union, we need to create a tuple type + return getTypeForExistential(taggedUnion); + } + + if (auto collectionTag = as(info)) + { + // If this is a collection tag, we can return the collection type + return (IRType*)collectionTag; + } + + if (auto collection = as(info)) + { + if (getCollectionCount(collection) == 1) + { + // If there's only one type in the collection, return it directly + return (IRType*)getCollectionElement(collection, 0); + } + + // If this is a concrete collection, return it directly + return (IRType*)collection; + } - SLANG_ASSERT(types.getCount() > 0); - auto unionType = types.getCount() > 1 ? createAnyValueTypeFromInsts(types) : *types.begin(); - return builder.getTupleType(List({builder.getUIntType(), (IRType*)unionType})); + SLANG_UNEXPECTED("Unhandled IRTypeFlowData type in getLoweredType"); } bool replaceType(IRInst* context, IRInst* inst) { auto info = tryGetInfo(context, inst); - auto taggedUnion = as(info); - if (!taggedUnion) - return false; - if (auto ptrType = as(inst->getDataType())) { IRBuilder builder(module); - inst->setFullType( - builder.getPtrTypeWithAddressSpace(getTypeForExistential(taggedUnion), ptrType)); + if (auto loweredType = getLoweredType(info)) + inst->setFullType(builder.getPtrTypeWithAddressSpace(loweredType, ptrType)); } else { - inst->setFullType(getTypeForExistential(taggedUnion)); + if (auto loweredType = getLoweredType(info)) + inst->setFullType(loweredType); } return true; } @@ -1646,8 +1848,14 @@ struct DynamicInstLoweringContext return lowerMakeExistential(context, as(inst)); case kIROp_CreateExistentialObject: return lowerCreateExistentialObject(context, as(inst)); + case kIROp_Store: + SLANG_UNEXPECTED("handle this"); default: - return false; + { + if (auto info = tryGetInfo(context, inst)) + return replaceType(context, info); + return false; + } } } @@ -1657,10 +1865,14 @@ struct DynamicInstLoweringContext if (!info) return false; + auto collectionTagType = as(info); + if (!collectionTagType) + return false; + IRBuilder builder(inst); builder.setInsertBefore(inst); - if (auto collection = as(info)) + /*if (auto collection = as(info)) { if (getCollectionCount(collection) == 1) { @@ -1669,51 +1881,77 @@ struct DynamicInstLoweringContext inst->removeAndDeallocate(); return true; } - else + else if (auto typeCollection = as(collection)) { // Set of types. - if (inst->getDataType()->getOp() == kIROp_TypeKind) - { - // Create an any-value type based on the set of types - auto typeSet = collectionToHashSet(collection); - auto unionType = typeSet.getCount() > 1 ? createAnyValueTypeFromInsts(typeSet) - : *typeSet.begin(); - - // Store the mapping for later use - loweredInstToAnyValueType[inst] = unionType; - - // Replace the instruction with the any-value type - inst->replaceUsesWith(unionType); - inst->removeAndDeallocate(); - return true; - } - else - { - // Get the witness table operand info - auto witnessTableInst = inst->getWitnessTable(); - auto witnessTableInfo = tryGetInfo(context, witnessTableInst); + // Create an any-value type based on the set of types + auto typeSet = collectionToHashSet(collection); + auto unionType = typeSet.getCount() > 1 ? createAnyValueTypeFromInsts(typeSet) + : *typeSet.begin(); - if (auto witnessTableCollection = as(witnessTableInfo)) - { - // Create a key mapping function - auto keyMappingFunc = createKeyMappingFunc( - inst->getRequirementKey(), - collectionToHashSet(witnessTableCollection), - collectionToHashSet(collection)); - - // Replace with call to key mapping function - auto witnessTableId = builder.emitCallInst( - builder.getUIntType(), - keyMappingFunc, - List({inst->getWitnessTable()})); - inst->replaceUsesWith(witnessTableId); - propagationMap[Element(context, witnessTableId)] = info; - inst->removeAndDeallocate(); - return true; - } - } + // Store the mapping for later use + loweredInstToAnyValueType[inst] = unionType; + + // Replace the instruction with the any-value type + inst->replaceUsesWith(typeCollection); + inst->removeAndDeallocate(); + return true; } } + else*/ + + if (getCollectionCount(collectionTagType) == 1) + { + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(getCollectionElement(collectionTagType, 0)); + inst->removeAndDeallocate(); + return true; + } + + if (auto typeCollection = as(collectionTagType->getOperand(0))) + { + // If this is a type collection, we can replace it with the collection type + // We don't currently care about the tag of a type. + // + inst->replaceUsesWith(typeCollection); + inst->removeAndDeallocate(); + return true; + } + + // Get the witness table operand info + auto witnessTableInst = inst->getWitnessTable(); + auto witnessTableInfo = tryGetInfo(context, witnessTableInst); + + SLANG_ASSERT(as(witnessTableInfo)); + List operands = {witnessTableInst, inst->getRequirementKey()}; + + auto newInst = builder.emitIntrinsicInst( + (IRType*)info, + kIROp_GetTagForMappedCollection, + operands.getCount(), + operands.getBuffer()); + inst->replaceUsesWith(newInst); + propagationMap[Element(context, newInst)] = info; + inst->removeAndDeallocate(); + + /*if (auto witnessTableCollection = as(witnessTableInfo)) + { + // Create a key mapping function + auto keyMappingFunc = createKeyMappingFunc( + inst->getRequirementKey(), + collectionToHashSet(witnessTableCollection), + collectionToHashSet(collection)); + + // Replace with call to key mapping function + auto witnessTableId = builder.emitCallInst( + builder.getUIntType(), + keyMappingFunc, + List({inst->getWitnessTable()})); + inst->replaceUsesWith(witnessTableId); + propagationMap[Element(context, witnessTableId)] = info; + inst->removeAndDeallocate(); + return true; + }*/ return false; } @@ -1729,27 +1967,27 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - if (auto collection = as(info)) + auto collectionTagType = as(info); + if (!collectionTagType) + return false; + + if (getCollectionCount(collectionTagType) == 1) { - if (getCollectionCount(collection) == 1) - { - // Found a single possible type. Simple replacement. - inst->replaceUsesWith(getCollectionElement(collection, 0)); - inst->removeAndDeallocate(); - return true; - } - else - { - // Replace with GetElement(loweredInst, 0) -> uint - auto operand = inst->getOperand(0); - auto element = builder.emitGetTupleElement(builder.getUIntType(), operand, 0); - inst->replaceUsesWith(element); - propagationMap[Element(context, element)] = info; - inst->removeAndDeallocate(); - return true; - } + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(getCollectionElement(collectionTagType, 0)); + inst->removeAndDeallocate(); + return true; + } + else + { + // Replace with GetElement(loweredInst, 0) -> uint + auto operand = inst->getOperand(0); + auto element = builder.emitGetTupleElement((IRType*)collectionTagType, operand, 0); + inst->replaceUsesWith(element); + propagationMap[Element(context, element)] = info; + inst->removeAndDeallocate(); + return true; } - return false; } bool lowerExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) @@ -1759,20 +1997,26 @@ struct DynamicInstLoweringContext if (!taggedUnion) return false; - IRBuilder builder(inst); - builder.setInsertBefore(inst); - + /* // Check if we have a lowered any-value type for the result auto resultType = inst->getDataType(); auto loweredType = loweredInstToAnyValueType.tryGetValue(inst); if (loweredType) { resultType = (IRType*)*loweredType; - } + }*/ + + auto info = tryGetInfo(context, inst); + auto typeCollection = as(info); + if (!typeCollection) + return false; + + IRBuilder builder(inst); + builder.setInsertBefore(inst); - // Replace with GetElement(loweredInst, 1) -> AnyValueType + // Replace with GetElement(loweredInst, 1) : TypeCollection auto operand = inst->getOperand(0); - auto element = builder.emitGetTupleElement(resultType, operand, 1); + auto element = builder.emitGetTupleElement((IRType*)info, operand, 1); inst->replaceUsesWith(element); inst->removeAndDeallocate(); return true; @@ -1781,31 +2025,32 @@ struct DynamicInstLoweringContext bool lowerExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { auto info = tryGetInfo(context, inst); - auto collection = as(info); - if (!collection) + auto collectionTagType = as(info); + if (!collectionTagType) return false; IRBuilder builder(inst); builder.setInsertBefore(inst); - if (getCollectionCount(collection) == 1) + if (getCollectionCount(collectionTagType) == 1) { // Found a single possible type. Simple replacement. - auto singletonValue = getCollectionElement(collection, 0); + auto singletonValue = getCollectionElement(collectionTagType, 0); inst->replaceUsesWith(singletonValue); inst->removeAndDeallocate(); - loweredInstToAnyValueType[inst] = singletonValue; + // loweredInstToAnyValueType[inst] = singletonValue; return true; } // Create an any-value type based on the set of types + /* auto anyValueType = createAnyValueTypeFromInsts(collectionToHashSet(collection)); // Store the mapping for later use - loweredInstToAnyValueType[inst] = anyValueType; + loweredInstToAnyValueType[inst] = anyValueType;*/ - // Replace the instruction with the any-value type - inst->replaceUsesWith(anyValueType); + // Replace the instruction with the collection type. + inst->replaceUsesWith(collectionTagType->getOperand(0)); inst->removeAndDeallocate(); return true; } @@ -1833,14 +2078,148 @@ struct DynamicInstLoweringContext return {ParameterDirection::kParameterDirection_In, paramType}; } - IRFuncType* getExpectedFuncType(IRInst* context, IRCall* inst) + IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirection direction, IRType* type) { - IRBuilder builder(module); - builder.setInsertInto(module); - - // We'll retreive just the parameter directions from the callee's func-type, - // since that can't be different before & after the type-flow lowering. - // + switch (direction) + { + case ParameterDirection::kParameterDirection_In: + return type; + case ParameterDirection::kParameterDirection_Out: + return builder->getOutType(type); + case ParameterDirection::kParameterDirection_InOut: + return builder->getInOutType(type); + case ParameterDirection::kParameterDirection_ConstRef: + return builder->getConstRefType(type); + default: + SLANG_UNEXPECTED("Unhandled parameter direction in fromDirectionAndType"); + return nullptr; + } + } + + IRFuncType* getEffectiveFuncType(IRInst* callee) + { + IRBuilder builder(module); + + List paramTypes; + IRType* resultType = nullptr; + + auto updateType = [&](IRType* currentType, IRType* newType) -> IRType* + { + if (auto collection = as(currentType)) + { + List collectionElements; + forEachInCollection( + collection, + [&](IRInst* element) + { + if (auto loweredType = getLoweredType(tryGetInfo(callee, element))) + collectionElements.add(loweredType); + else + collectionElements.add(element); + }); + collectionElements.add(newType); + + // If this is a collection, we need to create a new collection with the new type + auto newCollection = createCollection(collection->getOp(), collectionElements); + return (IRType*)newCollection; + } + else if (currentType == newType) + { + return currentType; + } + else if (currentType == nullptr) + { + return newType; + } + else // Need to create a new collection. + { + List collectionElements; + + SLANG_ASSERT(!as(currentType) && !as(newType)); + + collectionElements.add(currentType); + collectionElements.add(newType); + + // If this is a collection, we need to create a new collection with the new type + auto newCollection = createCollection(currentType->getOp(), collectionElements); + return (IRType*)newCollection; + } + }; + + auto updateParamType = [&](UInt index, IRType* paramType) -> IRType* + { + if (paramTypes.getCount() <= index) + { + // If we don't have enough types, just add the new type + paramTypes.add(paramType); + return paramType; + } + else + { + // Otherwise, update the existing type + auto [currentDirection, currentType] = + getParameterDirectionAndType(paramTypes[index]); + auto [newDirection, newType] = getParameterDirectionAndType(paramType); + auto updatedType = updateType(currentType, paramType); + SLANG_ASSERT(currentDirection == newDirection); + paramTypes[index] = fromDirectionAndType(&builder, currentDirection, updatedType); + return updatedType; + } + }; + + List contextsToProcess; + if (auto collection = as(callee)) + { + forEachInCollection(collection, [&](IRInst* func) { contextsToProcess.add(func); }); + } + else if (auto collectionTagType = as(callee->getDataType())) + { + forEachInCollection( + collectionTagType, + [&](IRInst* func) { contextsToProcess.add(func); }); + } + else + { + // Otherwise, just process the single function + contextsToProcess.add(callee); + } + + for (auto func : contextsToProcess) + { + auto paramInfos = getParamInfos(callee); + auto paramDirections = getParamDirections(callee); + for (UInt i = 0; i < paramInfos.getCount(); i++) + { + if (auto loweredType = getLoweredType(paramInfos[i])) + updateParamType( + i, + fromDirectionAndType(&builder, paramDirections[i], loweredType)); + else + SLANG_UNEXPECTED("Unhandled parameter type in getEffectiveFuncType"); + } + + auto returnType = getFuncReturnInfo(func); + if (auto newResultType = getLoweredType(returnType)) + { + resultType = updateType(resultType, newResultType); + } + else + { + resultType = updateType(resultType, (IRType*)returnType); + } + } + + return builder.getFuncType(paramTypes, resultType); + } + + /*IRFuncType* getExpectedFuncType(IRInst* context, IRCall* inst) + { + IRBuilder builder(module); + builder.setInsertInto(module); + + // We'll retreive just the parameter directions from the callee's func-type, + // since that can't be different before & after the type-flow lowering. + // List paramDirections; auto calleeInfo = tryGetInfo(context, inst->getCallee()); auto calleeCollection = as(calleeInfo); @@ -1905,7 +2284,7 @@ struct DynamicInstLoweringContext } return builder.getFuncType(paramTypes, resultType); - } + }*/ bool isDynamicGeneric(IRInst* callee) { @@ -1943,31 +2322,26 @@ struct DynamicInstLoweringContext IRCloneEnv cloneEnv; Index argIndex = 0; - UCount extraIndices = 0; + List extraParamTypes; // Map the generic's parameters to the specialized arguments. for (auto param : generic->getFirstBlock()->getParams()) { auto specArg = specializeInst->getArg(argIndex++); - if (as(specArg)) + if (auto collection = as(specArg)) { // We're dealing with a set of types. if (as(param->getDataType())) { - HashSet collectionSet; - for (auto index = 0; index < specArg->getOperandCount(); index++) - { - auto operand = specArg->getOperand(index); - collectionSet.add(operand); - } - - auto unionType = createAnyValueTypeFromInsts(collectionSet); - cloneEnv.mapOldValToNew[param] = unionType; + // auto unionType = createAnyValueTypeFromInsts(collectionSet); + cloneEnv.mapOldValToNew[param] = collection; } else if (as(param->getDataType())) { // Add an integer param to the func. - cloneEnv.mapOldValToNew[param] = builder.emitParam(builder.getUIntType()); - extraIndices++; + auto tagType = (IRType*)makeTagType(collection); + cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); + extraParamTypes.add(tagType); + // extraIndices++; } } else @@ -2036,8 +2410,8 @@ struct DynamicInstLoweringContext // Add extra indices to the func-type parameters List funcTypeParams; - for (Index i = 0; i < extraIndices; i++) - funcTypeParams.add(builder.getUIntType()); + for (Index i = 0; i < extraParamTypes.getCount(); i++) + funcTypeParams.add(extraParamTypes[i]); for (auto paramType : loweredFuncType->getParamTypes()) funcTypeParams.add(paramType); @@ -2092,6 +2466,38 @@ struct DynamicInstLoweringContext return context; } + List getArgsForDynamicSpecialization(IRSpecialize* specializedCallee) + { + List callArgs; + for (auto ii = 0; ii < specializedCallee->getArgCount(); ii++) + { + auto specArg = specializedCallee->getArg(ii); + auto argInfo = specArg->getDataType(); + if (auto argCollection = as(argInfo)) + { + if (as(getCollectionElement(argCollection, 0))) + { + // Needs an index (spec-arg will carry an index, we'll + // just need to append it to the call) + // + callArgs.add(specArg); + } + else if (as(getCollectionElement(argCollection, 0))) + { + // Needs no dynamic information. Skip. + } + else + { + // If it's a witness table, we need to handle it differently + // For now, we will not lower this case. + SLANG_UNEXPECTED("Unhandled type-flow-collection in dynamic generic call"); + } + } + } + + return callArgs; + } + bool lowerCallToDynamicGeneric(IRInst* context, IRCall* inst) { auto specializedCallee = as(inst->getCallee()); @@ -2147,49 +2553,154 @@ struct DynamicInstLoweringContext bool lowerCall(IRInst* context, IRCall* inst) { auto callee = inst->getCallee(); - auto calleeInfo = tryGetInfo(context, callee); + auto expectedFuncType = getEffectiveFuncType(callee); - auto calleeCollection = as(calleeInfo); - if (!calleeCollection) - return false; + // First, we'll legalize all operands by upcasting if necessary. + // This needs to be done even if the callee is not a collection. + // + // List paramTypeFlows = getParamInfos(callee); + // List paramDirections = getParamDirections(callee); + bool changed = false; + List newArgs; + for (UInt i = 0; i < inst->getArgCount(); i++) + { + auto arg = inst->getArg(i); + const auto [paramDirection, paramType] = + getParameterDirectionAndType(expectedFuncType->getParamType(i)); + if (!as(paramType)) + { + SLANG_ASSERT(!as(arg->getDataType())); + newArgs.add(arg); // No need to change the argument + continue; + } + + IRInst* newArg = nullptr; + switch (paramDirection) + { + case kParameterDirection_In: + newArgs.add(upcastCollection(context, arg, as(paramType))); + break; + default: + SLANG_UNEXPECTED("Unhandled parameter direction in lowerCall"); + } + + /*if (newArg != arg) + { + // If the argument changed, replace the old one. + changed = true; + IRBuilder builder(inst->getModule()); + builder.setInsertBefore(inst); + builder.replaceOperand(&inst->getArgs()[i + 1], newArg); + }*/ + } + + // New we need to determine the new callee. + IRInst* newCallee = nullptr; + + List extraArgs; - if (getCollectionCount(calleeCollection) == 1) + // auto calleeInfo = tryGetInfo(context, callee); + auto calleeInfo = as(callee->getDataType()); + auto calleeCollection = as(calleeInfo); + if (!calleeCollection) + newCallee = callee; // Not a collection, no need to lower + else if (getCollectionCount(calleeCollection) == 1) { auto singletonValue = getCollectionElement(calleeCollection, 0); + if (singletonValue == callee) + { + newCallee = callee; + } + else + { + changed = true; + if (isDynamicGeneric(singletonValue)) + extraArgs = getArgsForDynamicSpecialization(cast(singletonValue)); + + newCallee = singletonValue; + } + + /* if (isDynamicGeneric(singletonValue)) return lowerCallToDynamicGeneric(context, inst); if (singletonValue == callee) return false; + */ + + // IRBuilder builder(inst->getModule()); + // builder.replaceOperand(inst->getCalleeUse(), singletonValue); + // newCallee = singletonValue; // Replace with the single value + // return true; // Replaced with a single function + } + else + { + changed = true; + // Multiple elements in the collection. + extraArgs.add(callee); + auto funcCollection = cast(callee->getOperand(0)); + + // Check if the first element is a dynamic generic (this should imply that all + // elements are similar dynamic generics, but we might want to check for that..) + // + if (isDynamicGeneric(getCollectionElement(funcCollection, 0))) + { + SLANG_UNEXPECTED("Dynamic generic in a collection call"); + auto dynamicSpecArgs = getArgsForDynamicSpecialization( + cast(getCollectionElement(funcCollection, 0))); + for (auto& arg : dynamicSpecArgs) + extraArgs.add(arg); + } + + if (!as(funcCollection->getDataType())) + { + auto typeForCollection = getEffectiveFuncType(funcCollection); + funcCollection->setFullType(typeForCollection); + } - IRBuilder builder(inst->getModule()); - builder.replaceOperand(inst->getCalleeUse(), singletonValue); - return true; // Replaced with a single function + newCallee = funcCollection; } IRBuilder builder(inst); builder.setInsertBefore(inst); - auto expectedFuncType = getExpectedFuncType(context, inst); - // Create dispatch function + + /* Create dispatch function auto dispatchFunc = - createDispatchFunc(collectionToHashSet(calleeCollection), expectedFuncType); + createDispatchFunc(collectionToHashSet(calleeCollection), expectedFuncType);*/ + // Replace call with dispatch - List newArgs; + /*List newArgs; newArgs.add(callee); // Add the lookup as first argument (will get lowered into an uint tag) for (UInt i = 1; i < inst->getOperandCount(); i++) { newArgs.add(inst->getOperand(i)); - } + }*/ - auto newCall = builder.emitCallInst(inst->getDataType(), dispatchFunc, newArgs); - inst->replaceUsesWith(newCall); - if (auto info = tryGetInfo(context, inst)) - propagationMap[Element(context, newCall)] = info; - replaceType(context, newCall); // "maybe replace type" - inst->removeAndDeallocate(); - return true; + if (changed) + { + List callArgs; + + auto newCall = + builder.emitCallInst(expectedFuncType->getResultType(), newCallee, newArgs); + inst->replaceUsesWith(newCall); + inst->removeAndDeallocate(); + return true; + } + else if (expectedFuncType->getResultType() != inst->getDataType()) + { + // If we didn't change the callee or the arguments, we still might + // need to update the result type. + // + inst->setFullType(expectedFuncType->getResultType()); + return true; + } + else + { + // Nothing changed. + return false; + } } bool lowerMakeExistential(IRInst* context, IRMakeExistential* inst) @@ -2202,56 +2713,53 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - auto witnessTableInfo = tryGetInfo(context, inst->getWitnessTable()); - auto witnessTableCollection = as(witnessTableInfo); - if (!witnessTableCollection) - return false; // Witness table must be a set of tables + // auto witnessTableInfo = tryGetInfo(context, inst->getWitnessTable()); + + // Collect types from the witness tables to determine the any-value type + auto tableCollection = as(taggedUnion->getOperand(1)); + auto typeCollection = as(taggedUnion->getOperand(0)); IRInst* witnessTableID = nullptr; - if (getCollectionCount(witnessTableCollection) == 1) + if (auto witnessTable = as(inst->getWitnessTable())) { // Get unique ID for the witness table. - witnessTableID = builder.getIntValue( + /*witnessTableID = builder.getIntValue( builder.getUIntType(), - getUniqueID(getCollectionElement(witnessTableCollection, 0))); - } - else - { - // Dynamic. Use the witness table inst as an integer key. + getUniqueID(getCollectionElement(witnessTableCollection, 0)));*/ + auto singletonTagType = makeTagType(makeSingletonSet(witnessTable)); + auto zeroValueOfTagType = builder.getIntValue((IRType*)singletonTagType, 0); + witnessTableID = builder.emitIntrinsicInst( + (IRType*)makeTagType(tableCollection), + kIROp_GetTagForSuperCollection, + 1, + &zeroValueOfTagType); + } + else if ( + auto witnessTableTag = as(inst->getWitnessTable()->getDataType())) + { + // Dynamic. Use the witness table inst as a tag witnessTableID = inst->getWitnessTable(); } - // Collect types from the witness tables to determine the any-value type - HashSet types; - auto tableCollection = as(taggedUnion->getOperand(1)); - forEachInCollection( - tableCollection, - [&](IRInst* table) - { - if (auto witnessTableInst = as(table)) - { - if (auto concreteType = witnessTableInst->getConcreteType()) - { - types.add(concreteType); - } - } - }); // Create the appropriate any-value type - SLANG_ASSERT(types.getCount() > 0); - auto unionType = types.getCount() > 1 ? createAnyValueType(types) : *types.begin(); + auto collectionType = getCollectionCount(typeCollection) == 1 + ? (IRType*)typeCollection->getOperand(0) + : (IRType*)typeCollection; // Pack the value - auto packedValue = builder.emitPackAnyValue(unionType, inst->getWrappedValue()); + auto packedValue = builder.emitPackAnyValue(collectionType, inst->getWrappedValue()); + + auto taggedUnionTupleType = getLoweredType(taggedUnion); // Create tuple (table_unique_id, PackAnyValue(val)) - auto tupleType = builder.getTupleType( - List({builder.getUIntType(), packedValue->getDataType()})); IRInst* tupleArgs[] = {witnessTableID, packedValue}; - auto tuple = builder.emitMakeTuple(tupleType, 2, tupleArgs); + auto tuple = builder.emitMakeTuple(taggedUnionTupleType, 2, tupleArgs); + /* if (auto info = tryGetInfo(context, inst)) propagationMap[Element(context, tuple)] = info; + */ inst->replaceUsesWith(tuple); inst->removeAndDeallocate(); @@ -2259,6 +2767,40 @@ struct DynamicInstLoweringContext } bool lowerCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) + { + auto info = tryGetInfo(context, inst); + auto taggedUnion = as(info); + if (!taggedUnion) + return false; + + auto taggedUnionTupleType = getLoweredType(taggedUnion); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + List args; + args.add(inst->getDataType()); + args.add(inst->getTypeID()); + auto translatedTag = builder.emitIntrinsicInst( + (IRType*)taggedUnionTupleType->getOperand(0), + kIROp_GetTagFromSequentialID, + args.getCount(), + args.getBuffer()); + + auto packedValue = builder.emitPackAnyValue( + (IRType*)taggedUnionTupleType->getOperand(1), + inst->getValue()); + + auto newInst = builder.emitMakeTuple( + taggedUnionTupleType, + List({translatedTag, packedValue})); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + + /*bool _lowerCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { auto info = tryGetInfo(context, inst); auto taggedUnion = as(info); @@ -2302,7 +2844,7 @@ struct DynamicInstLoweringContext inst->replaceUsesWith(existentialTuple); inst->removeAndDeallocate(); return true; - } + }*/ UInt getUniqueID(IRInst* funcOrTable) { @@ -2315,70 +2857,7 @@ struct DynamicInstLoweringContext return newId; } - IRFunc* createIntegerMappingFunc(Dictionary& mapping) - { - // Create a function that maps input IDs to output IDs - IRBuilder builder(module); - - auto funcType = - builder.getFuncType(List({builder.getUIntType()}), builder.getUIntType()); - auto func = builder.createFunc(); - builder.setInsertInto(func); - func->setFullType(funcType); - - auto entryBlock = builder.emitBlock(); - builder.setInsertInto(entryBlock); - - auto param = builder.emitParam(builder.getUIntType()); - - // Create default block that returns 0 - auto defaultBlock = builder.emitBlock(); - builder.setInsertInto(defaultBlock); - builder.emitReturn(builder.getIntValue(builder.getUIntType(), 0)); - - // Go back to entry block and create switch - builder.setInsertInto(entryBlock); - - // Create case blocks for each input table - List caseValues; - List caseBlocks; - - for (auto item : mapping) - { - // Create case block - auto caseBlock = builder.emitBlock(); - builder.setInsertInto(caseBlock); - builder.emitReturn(builder.getIntValue(builder.getUIntType(), item.second)); - - caseValues.add(builder.getIntValue(builder.getUIntType(), item.first)); - caseBlocks.add(caseBlock); - } - - // Create flattened case arguments array - List flattenedCaseArgs; - for (Index i = 0; i < caseValues.getCount(); i++) - { - flattenedCaseArgs.add(caseValues[i]); - flattenedCaseArgs.add(caseBlocks[i]); - } - - // Emit an unreachable block for the break block. - auto unreachableBlock = builder.emitBlock(); - builder.setInsertInto(unreachableBlock); - builder.emitUnreachable(); - - // Go back to entry and emit switch - builder.setInsertInto(entryBlock); - builder.emitSwitch( - param, - unreachableBlock, - defaultBlock, - flattenedCaseArgs.getCount(), - flattenedCaseArgs.getBuffer()); - - return func; - } - + /* IRFunc* createKeyMappingFunc( IRInst* key, const HashSet& inputTables, @@ -2394,183 +2873,34 @@ struct DynamicInstLoweringContext mapping[inputId] = outputId; } - return createIntegerMappingFunc(mapping); + return createIntegerMappingFunc(module, mapping); } + */ - IRFunc* createDispatchFunc(const HashSet& funcs, IRFuncType* expectedFuncType) - { - // Create a dispatch function with switch-case for each function - IRBuilder builder(module); - - // Extract parameter types from the first function in the set - List paramTypes; - paramTypes.add(builder.getUIntType()); // ID parameter - - // Get parameter types from first function - List funcArray; - for (auto func : funcs) - funcArray.add(func); - - for (UInt i = 0; i < expectedFuncType->getParamCount(); i++) - { - paramTypes.add(expectedFuncType->getParamType(i)); - } - - auto resultType = expectedFuncType->getResultType(); - auto funcType = builder.getFuncType(paramTypes, resultType); - auto func = builder.createFunc(); - builder.setInsertInto(func); - func->setFullType(funcType); + bool isExistentialType(IRType* type) { return as(type) != nullptr; } - auto entryBlock = builder.emitBlock(); - builder.setInsertInto(entryBlock); + bool isInterfaceType(IRType* type) { return as(type) != nullptr; } - auto idParam = builder.emitParam(builder.getUIntType()); + HashSet collectExistentialTables(IRInterfaceType* interfaceType) + { + HashSet tables; - // Create parameters for the original function arguments - List originalParams; - for (UInt i = 1; i < paramTypes.getCount(); i++) + IRWitnessTableType* targetTableType = nullptr; + // First, find the IRWitnessTableType that wraps the given interfaceType + for (auto use = interfaceType->firstUse; use; use = use->nextUse) { - originalParams.add(builder.emitParam(paramTypes[i])); + if (auto wtType = as(use->getUser())) + { + if (wtType->getConformanceType() == interfaceType) + { + targetTableType = wtType; + break; + } + } } - // Create default block - auto defaultBlock = builder.emitBlock(); - builder.setInsertInto(defaultBlock); - if (resultType->getOp() == kIROp_VoidType) - { - builder.emitReturn(); - } - else - { - // Return a default-constructed value - auto defaultValue = builder.emitDefaultConstruct(resultType); - builder.emitReturn(defaultValue); - } - - // Go back to entry block and create switch - builder.setInsertInto(entryBlock); - - // Create case blocks for each function - List caseValues; - List caseBlocks; - - for (auto funcInst : funcs) - { - auto funcId = getUniqueID(funcInst); - - auto wrapperFunc = - emitWitnessTableWrapper(funcInst->getModule(), funcInst, expectedFuncType); - - // Create case block - auto caseBlock = builder.emitBlock(); - builder.setInsertInto(caseBlock); - - List callArgs; - auto wrappedFuncType = as(wrapperFunc->getDataType()); - for (UIndex ii = 0; ii < originalParams.getCount(); ii++) - { - callArgs.add(originalParams[ii]); - } - - // Call the specific function - auto callResult = - builder.emitCallInst(wrappedFuncType->getResultType(), wrapperFunc, callArgs); - - if (resultType->getOp() == kIROp_VoidType) - { - builder.emitReturn(); - } - else - { - builder.emitReturn(callResult); - } - - caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); - caseBlocks.add(caseBlock); - } - - // Create flattened case arguments array - List flattenedCaseArgs; - for (Index i = 0; i < caseValues.getCount(); i++) - { - flattenedCaseArgs.add(caseValues[i]); - flattenedCaseArgs.add(caseBlocks[i]); - } - - // Create an unreachable block for the break block. - auto unreachableBlock = builder.emitBlock(); - builder.setInsertInto(unreachableBlock); - builder.emitUnreachable(); - - // Go back to entry and emit switch - builder.setInsertInto(entryBlock); - builder.emitSwitch( - idParam, - unreachableBlock, - defaultBlock, - flattenedCaseArgs.getCount(), - flattenedCaseArgs.getBuffer()); - - return func; - } - - IRAnyValueType* createAnyValueType(const HashSet& types) - { - IRBuilder builder(module); - auto size = calculateAnyValueSize(types); - return builder.getAnyValueType(size); - } - - IRAnyValueType* createAnyValueTypeFromInsts(const HashSet& typeInsts) - { - HashSet types; - for (auto inst : typeInsts) - { - if (auto type = as(inst)) - { - types.add(type); - } - } - return createAnyValueType(types); - } - - SlangInt calculateAnyValueSize(const HashSet& types) - { - SlangInt maxSize = 0; - for (auto type : types) - { - auto size = getAnyValueSize(type); - if (size > maxSize) - maxSize = size; - } - return maxSize; - } - - bool isExistentialType(IRType* type) { return as(type) != nullptr; } - - bool isInterfaceType(IRType* type) { return as(type) != nullptr; } - - HashSet collectExistentialTables(IRInterfaceType* interfaceType) - { - HashSet tables; - - IRWitnessTableType* targetTableType = nullptr; - // First, find the IRWitnessTableType that wraps the given interfaceType - for (auto use = interfaceType->firstUse; use; use = use->nextUse) - { - if (auto wtType = as(use->getUser())) - { - if (wtType->getConformanceType() == interfaceType) - { - targetTableType = wtType; - break; - } - } - } - - // If the target witness table type was found, gather all witness tables using it - if (targetTableType) + // If the target witness table type was found, gather all witness tables using it + if (targetTableType) { for (auto use = targetTableType->firstUse; use; use = use->nextUse) { @@ -2587,7 +2917,7 @@ struct DynamicInstLoweringContext return tables; } - bool lowerCollectionTypes() + /*bool lowerTypeCollections() { bool hasChanges = false; @@ -2615,8 +2945,45 @@ struct DynamicInstLoweringContext } } + return hasChanges; + }*/ + + /* + bool transferDataToInstTypes() + { + bool hasChanges = false; + + for (auto& pair : propagationMap) + { + auto instWithContext = pair.first; + auto flowData = pair.second; + + if (!flowData) + continue; // No propagation data + + if (as(flowData)) + { + // If the flow data is an unbounded collection, don't touch + // the types. + continue; + } + + auto inst = instWithContext.inst; + auto context = instWithContext.context; + + // Only transfer data for insts that are in top-level + // contexts. We'll come back to specialized contexts later. + // + if (context->getOp() == kIROp_Func) + { + inst->setFullType((IRType*)flowData); + hasChanges = true; + } + } + return hasChanges; } + */ bool processModule() { @@ -2628,14 +2995,14 @@ struct DynamicInstLoweringContext // Phase 1.5: Insert reinterprets for points where sets merge // e.g. phi, return, call // - hasChanges |= insertReinterprets(); + // hasChanges |= insertReinterprets(); // Phase 2: Dynamic Instruction Lowering hasChanges |= performDynamicInstLowering(); // Phase 3: Lower collection types. - if (hasChanges) - lowerCollectionTypes(); + // if (hasChanges) + // lowerTypeCollections(); return hasChanges; } @@ -2675,6 +3042,418 @@ struct DynamicInstLoweringContext Dictionary loweredContexts; }; +SlangInt calculateAnyValueSize(const HashSet& types) +{ + SlangInt maxSize = 0; + for (auto type : types) + { + auto size = getAnyValueSize(type); + if (size > maxSize) + maxSize = size; + } + return maxSize; +} + +IRAnyValueType* createAnyValueType(IRBuilder* builder, const HashSet& types) +{ + auto size = calculateAnyValueSize(types); + return builder->getAnyValueType(size); +} + +IRFunc* createDispatchFunc(IRFuncCollection* collection) +{ + // An effective func type should have been set during the dynamic-inst-lowering + // pass. + // + IRFuncType* expectedFuncType = cast(collection->getFullType()); + + // Create a dispatch function with switch-case for each function + IRBuilder builder(collection->getModule()); + + List paramTypes; + paramTypes.add(builder.getUIntType()); // ID parameter + for (UInt i = 0; i < expectedFuncType->getParamCount(); i++) + paramTypes.add(expectedFuncType->getParamType(i)); + + auto resultType = expectedFuncType->getResultType(); + auto funcType = builder.getFuncType(paramTypes, resultType); + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(funcType); + + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); + + auto idParam = builder.emitParam(builder.getUIntType()); + + // Create parameters for the original function arguments + List originalParams; + for (UInt i = 1; i < paramTypes.getCount(); i++) + { + originalParams.add(builder.emitParam(paramTypes[i])); + } + + // Create default block + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + // Return a default-constructed value + auto defaultValue = builder.emitDefaultConstruct(resultType); + builder.emitReturn(defaultValue); + } + + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); + + // Create case blocks for each function + List caseValues; + List caseBlocks; + + UIndex funcSeqID = 0; + forEachInCollection( + collection, + [&](IRInst* funcInst) + { + auto funcId = funcSeqID++; + auto wrapperFunc = + emitWitnessTableWrapper(funcInst->getModule(), funcInst, expectedFuncType); + + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + + List callArgs; + auto wrappedFuncType = as(wrapperFunc->getDataType()); + for (UIndex ii = 0; ii < originalParams.getCount(); ii++) + { + callArgs.add(originalParams[ii]); + } + + // Call the specific function + auto callResult = + builder.emitCallInst(wrappedFuncType->getResultType(), wrapperFunc, callArgs); + + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + builder.emitReturn(callResult); + } + + caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); + caseBlocks.add(caseBlock); + }); + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } + + // Create an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + idParam, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; +} + + +IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mapping) +{ + // Create a function that maps input IDs to output IDs + IRBuilder builder(module); + + auto funcType = + builder.getFuncType(List({builder.getUIntType()}), builder.getUIntType()); + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(funcType); + + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); + + auto param = builder.emitParam(builder.getUIntType()); + + // Create default block that returns 0 + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), 0)); + + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); + + // Create case blocks for each input table + List caseValues; + List caseBlocks; + + for (auto item : mapping) + { + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), item.second)); + + caseValues.add(builder.getIntValue(builder.getUIntType(), item.first)); + caseBlocks.add(caseBlock); + } + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } + + // Emit an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + param, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; +} + +// This context lowers `IRGetTagFromSequentialID`, +// `IRGetTagForSuperCollection`, and `IRGetTagForMappedCollection` instructions, +// + +struct TagOpsLoweringContext : public InstPassBase +{ + TagOpsLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) + { + auto srcInterfaceType = cast(inst->getOperand(0)); + auto srcSeqID = inst->getOperand(1); + + Dictionary mapping; + + // Map from sequential ID to unique ID + auto destCollection = + cast(cast(inst->getDataType())->getOperand(0)); + + UIndex dstSeqID = 0; + forEachInCollection( + destCollection, + [&](IRInst* table) + { + // Get unique ID for the witness table + auto witnessTable = cast(table); + auto outputId = dstSeqID++; + auto seqDecoration = table->findDecoration(); + if (seqDecoration) + { + auto inputId = seqDecoration->getSequentialID(); + mapping[inputId] = outputId; // Map ID to itself for now + } + }); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto translatedID = builder.emitCallInst( + inst->getDataType(), + createIntegerMappingFunc(builder.getModule(), mapping), + List({srcSeqID})); + + inst->replaceUsesWith(translatedID); + inst->removeAndDeallocate(); + } + + void lowerGetTagForSuperCollection(IRGetTagForSuperCollection* inst) + { + auto srcCollection = cast( + cast(inst->getOperand(0)->getDataType())->getOperand(0)); + auto destCollection = + cast(cast(inst->getDataType())->getOperand(0)); + + IRBuilder builder(inst->getModule()); + + List indices; + for (UInt i = 0; i < srcCollection->getOperandCount(); i++) + { + // Find in destCollection + auto srcElement = srcCollection->getOperand(i); + for (UInt j = 0; j < destCollection->getOperandCount(); j++) + { + auto destElement = destCollection->getOperand(j); + if (srcElement == destElement) + { + indices.add(builder.getIntValue(builder.getUIntType(), j)); + break; // Found the index + } + } + // destCollection must be a super-set + SLANG_UNEXPECTED("Element not found in destination collection"); + } + + // Create an array for the lookup + auto lookupArrayType = builder.getArrayType( + builder.getUIntType(), + builder.getIntValue(builder.getUIntType(), indices.getCount())); + auto lookupArray = + builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); + auto resultID = + builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); + inst->replaceUsesWith(resultID); + inst->removeAndDeallocate(); + } + + void lowerGetTagForMappedCollection(IRGetTagForMappedCollection* inst) + { + auto srcCollection = cast( + cast(inst->getOperand(0)->getDataType())->getOperand(0)); + auto destCollection = + cast(cast(inst->getDataType())->getOperand(0)); + auto key = cast(inst->getOperand(1)); + + IRBuilder builder(inst->getModule()); + + List indices; + for (UInt i = 0; i < srcCollection->getOperandCount(); i++) + { + // Find in destCollection + auto srcElement = findEntryInConcreteTable(srcCollection->getOperand(i), key); + for (UInt j = 0; j < destCollection->getOperandCount(); j++) + { + auto destElement = destCollection->getOperand(j); + if (srcElement == destElement) + { + indices.add(builder.getIntValue(builder.getUIntType(), j)); + break; // Found the index + } + } + + // destCollection must be a super-set + SLANG_UNEXPECTED("Element not found in destination collection"); + } + + // Create an array for the lookup + auto lookupArrayType = builder.getArrayType( + builder.getUIntType(), + builder.getIntValue(builder.getUIntType(), indices.getCount())); + auto lookupArray = + builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); + auto resultID = + builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); + inst->replaceUsesWith(resultID); + inst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_GetTagFromSequentialID: + lowerGetTagFromSequentialID(as(inst)); + break; + case kIROp_GetTagForSuperCollection: + lowerGetTagForSuperCollection(as(inst)); + break; + case kIROp_GetTagForMappedCollection: + lowerGetTagForMappedCollection(as(inst)); + break; + default: + break; + } + } + + void processModule() + { + processAllReachableInsts([&](IRInst* inst) { return processInst(inst); }); + } +}; + +// This context lowers `IRTypeCollection` and `IRFuncCollection` instructions +struct CollectionLoweringContext : public InstPassBase +{ + CollectionLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void lowerTypeCollection(IRTypeCollection* collection) + { + HashSet types; + for (UInt i = 0; i < collection->getOperandCount(); i++) + { + if (auto type = as(collection->getOperand(i))) + { + types.add(type); + } + } + + IRBuilder builder(collection->getModule()); + auto anyValueType = createAnyValueType(&builder, types); + collection->replaceUsesWith(anyValueType); + } + + void lowerFuncCollection(IRFuncCollection* collection) + { + IRBuilder builder(collection->getModule()); + auto dispatchFunc = createDispatchFunc(collection); + collection->replaceUsesWith(dispatchFunc); + } + + void processModule() + { + processInstsOfType( + kIROp_FuncCollection, + [&](IRFuncCollection* inst) { return lowerFuncCollection(inst); }); + + processInstsOfType( + kIROp_TypeCollection, + [&](IRTypeCollection* inst) { return lowerTypeCollection(inst); }); + + processInstsOfType( + kIROp_CollectionTagType, + [&](IRCollectionTagType* inst) + { + IRBuilder builder(inst->getModule()); + inst->replaceUsesWith(builder.getUIntType()); + }); + } +}; + +void lowerCollectionAndTagInsts(IRModule* module, DiagnosticSink* sink) +{ + TagOpsLoweringContext tagContext(module); + tagContext.processModule(); + + CollectionLoweringContext context(module); + context.processModule(); +} + // Main entry point bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink) { diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-dynamic-insts.h index 9e7d779aac6..32a782b7285 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.h +++ b/source/slang/slang-ir-lower-dynamic-insts.h @@ -8,4 +8,5 @@ namespace Slang { // Main entry point for the pass bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink); +void lowerCollectionAndTagInsts(IRModule* module, DiagnosticSink* sink); } // namespace Slang diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 7e131226e91..2a96cfc490d 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3048,6 +3048,8 @@ void finalizeSpecialization(IRModule* module) break; } } + + lowerCollectionAndTagInsts(module, nullptr); } IRInst* specializeGenericImpl( From 015e93bc19f6497ff0a36cef99b390db2a636bfa Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 4 Aug 2025 11:40:06 -0400 Subject: [PATCH 020/105] Dispatch tests passing with overhauled approach --- source/slang/slang-ir-lower-dynamic-insts.cpp | 780 +++++++++++++----- source/slang/slang-ir-specialize.cpp | 188 +++++ .../slang/slang-ir-witness-table-wrapper.cpp | 144 +++- 3 files changed, 882 insertions(+), 230 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index a35cc81d72a..554b77ce206 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -4,6 +4,7 @@ #include "slang-ir-clone.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-insts.h" +#include "slang-ir-specialize.h" #include "slang-ir-util.h" #include "slang-ir-witness-table-wrapper.h" #include "slang-ir.h" @@ -237,15 +238,6 @@ struct DynamicInstLoweringContext { // Helper methods for creating canonical collections IRCollectionBase* createCollection(IROp op, const HashSet& elements) - { - List sortedElements; - for (auto element : elements) - sortedElements.add(element); - - return createCollection(op, sortedElements); - } - - IRCollectionBase* createCollection(IROp op, const List& elements) { SLANG_ASSERT( op == kIROp_TypeCollection || op == kIROp_FuncCollection || @@ -259,8 +251,11 @@ struct DynamicInstLoweringContext if (element->getParent()->getOp() != kIROp_ModuleInst) SLANG_ASSERT_FAILURE("createCollection called with non-global operands"); + List sortedElements; + for (auto element : elements) + sortedElements.add(element); + // Sort elements by their unique IDs to ensure canonical ordering - List sortedElements = elements; sortedElements.sort( [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); @@ -278,16 +273,17 @@ struct DynamicInstLoweringContext IROp getCollectionTypeForInst(IRInst* inst) { + if (as(inst)) + return kIROp_GenericCollection; + if (as(inst->getDataType())) return kIROp_TypeCollection; - else if (as(inst) && !as(inst)) - return kIROp_TypeCollection; else if (as(inst->getDataType())) return kIROp_FuncCollection; + else if (as(inst) && !as(inst)) + return kIROp_TypeCollection; else if (as(inst->getDataType())) return kIROp_TableCollection; - else if (as(inst->getDataType())) - return kIROp_GenericCollection; else SLANG_UNEXPECTED("Unsupported collection type for instruction"); } @@ -399,16 +395,37 @@ struct DynamicInstLoweringContext IRTypeFlowData* tryGetInfo(IRInst* context, IRInst* inst) { + if (auto typeFlowData = as(inst->getDataType())) + { + // If the instruction already has a stablilized type flow data, + // return it directly. + // + return typeFlowData; + } + if (!inst->getParent()) return none(); // If this is a global instruction (parent is module), return concrete info if (as(inst->getParent())) + { if (as(inst) || as(inst) || as(inst) || as(inst)) + { + // We won't directly handle interface types, but rather treat objects of interface + // type as objects that can be specialized with collections. + // + if (as(inst)) + return none(); + + if (as(inst) && as(getGenericReturnVal(inst))) + return none(); + return makeSingletonSet(inst); + } else return none(); + } return tryGetInfo(Element(context, inst)); } @@ -588,18 +605,51 @@ struct DynamicInstLoweringContext } } - IRInst* upcastCollection(IRInst* context, IRInst* arg, IRTypeFlowData* destInfo) + IRInst* upcastCollection(IRInst* context, IRInst* arg, IRType* destInfo) { - /*auto argInfo = tryGetInfo(context, arg); - if (!argInfo || !destInfo) - return arg;*/ - auto argInfo = as(arg->getDataType()); + auto argInfo = arg->getDataType(); if (!argInfo || !destInfo) return arg; - if (as(argInfo) && as(destInfo)) + if (as(argInfo) && as(destInfo)) { - if (getCollectionCount(as(argInfo)) != + auto argTupleType = as(argInfo); + auto destTupleType = as(destInfo); + + List upcastedElements; + bool hasUpcastedElements = false; + + IRBuilder builder(module); + builder.setInsertAfter(arg); + + // Upcast each element of the tuple + for (UInt i = 0; i < argTupleType->getOperandCount(); ++i) + { + auto argElementType = argTupleType->getOperand(i); + auto destElementType = destTupleType->getOperand(i); + + // If the element types are different, we need to reinterpret + if (argElementType != destElementType) + { + hasUpcastedElements = true; + upcastedElements.add(upcastCollection( + context, + builder.emitGetTupleElement((IRType*)argElementType, arg, i), + (IRType*)destElementType)); + } + else + { + upcastedElements.add( + builder.emitGetTupleElement((IRType*)argElementType, arg, i)); + } + } + + if (hasUpcastedElements) + { + return builder.emitMakeTuple(upcastedElements); + } + + /*if (getCollectionCount(as(argInfo)) != getCollectionCount(as(destInfo))) { // If the sets of witness tables are not equal, reinterpret to the parameter type @@ -622,7 +672,7 @@ struct DynamicInstLoweringContext cast(destInfo->getOperand(0))); return builder.emitMakeTuple(newTag, newValue); - } + }*/ } else if (as(argInfo) && as(destInfo)) { @@ -646,6 +696,14 @@ struct DynamicInstLoweringContext return builder.emitReinterpret((IRType*)destInfo, arg); } } + else if (!as(argInfo) && as(destInfo)) + { + IRBuilder builder(module); + builder.setInsertAfter(arg); + return builder.emitPackAnyValue((IRType*)destInfo, arg); + } + + return arg; // Can use as-is. } /* @@ -1016,20 +1074,24 @@ struct DynamicInstLoweringContext "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); } - if (auto collectionTag = as(operandInfo)) + if (as(operandInfo) || as(operandInfo)) { + // If any of the specialization arguments need a tag (or the generic itself is a tag), + // we need the result to also be wrapped in a tag type. + bool needsTag = false; + List specializationArgs; for (auto i = 0; i < inst->getArgCount(); ++i) { - // For integer args, add as is (also applies to any value args) - if (as(inst->getArg(i))) + // For concrete args, add as-is. + if (isGlobalInst(inst->getArg(i))) { specializationArgs.add(inst->getArg(i)); continue; } - // For type args, we need to replace any dynamic args with - // their sets. + // For dynamic args, we need to replace them with + // their sets (if available) // auto argInfo = tryGetInfo(context, inst->getArg(i)); if (!argInfo) @@ -1047,8 +1109,11 @@ struct DynamicInstLoweringContext if (getCollectionCount(argCollectionTag) == 1) specializationArgs.add(getCollectionElement(argCollectionTag, 0)); else + { + needsTag = true; specializationArgs.add( cast(argCollectionTag->getOperand(0))); + } } else { @@ -1070,7 +1135,9 @@ struct DynamicInstLoweringContext if (getCollectionCount(infoCollectionTag) == 1) return getCollectionElement(infoCollectionTag, 0); else + { return as(infoCollectionTag->getOperand(0)); + } } else return type; @@ -1094,10 +1161,21 @@ struct DynamicInstLoweringContext SLANG_ASSERT_FAILURE("Unexpected data type for specialization instruction"); } + IRCollectionBase* collection = nullptr; + if (auto _collection = as(operandInfo)) + { + collection = _collection; + } + else if (auto collectionTagType = as(operandInfo)) + { + needsTag = true; + collection = cast(collectionTagType->getOperand(0)); + } + // Specialize each element in the set HashSet specializedSet; forEachInCollection( - collectionTag, + collection, [&](IRInst* arg) { // Create a new specialized instruction for each argument @@ -1107,7 +1185,10 @@ struct DynamicInstLoweringContext builder.emitSpecializeInst(typeOfSpecialization, arg, specializationArgs)); }); - return makeTagType(makeSet(specializedSet)); + if (needsTag) + return makeTagType(makeSet(specializedSet)); + else + return makeSet(specializedSet); } SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); @@ -1154,7 +1235,7 @@ struct DynamicInstLoweringContext if (auto collection = as(arg)) { - updateInfo(context, param, collection, workQueue); + updateInfo(context, param, makeTagType(collection), workQueue); } else if (as(arg) || as(arg)) { @@ -1219,9 +1300,14 @@ struct DynamicInstLoweringContext WorkItem(InterproceduralEdge::Direction::CallToFunc, context, inst, callee)); }; - if (auto collection = as(calleeInfo)) + if (auto collectionTag = as(calleeInfo)) { // If we have a set of functions, register each one + forEachInCollection(collectionTag, [&](IRInst* func) { propagateToCallSite(func); }); + } + else if (auto collection = as(calleeInfo)) + { + // If we have a collection of functions, register each one forEachInCollection(collection, [&](IRInst* func) { propagateToCallSite(func); }); } @@ -1268,6 +1354,46 @@ struct DynamicInstLoweringContext } } + bool isGlobalInst(IRInst* inst) { return inst->getParent()->getOp() == kIROp_ModuleInst; } + + List getParamEffectiveTypes(IRInst* context) + { + List effectiveTypes; + IRFunc* func = nullptr; + if (as(context)) + { + func = as(context); + } + else if (auto specialize = as(context)) + { + auto generic = specialize->getBase(); + auto innerFunc = getGenericReturnVal(generic); + func = cast(innerFunc); + } + else + { + // If it's not a function or a specialization, we can't get parameter info + SLANG_UNEXPECTED("Unexpected context type for parameter info retrieval"); + } + + UIndex idx = 0; + for (auto param : func->getParams()) + { + if (auto newType = tryGetInfo(context, param)) + effectiveTypes.add((IRType*)newType); + else + { + const auto [direction, type] = getParameterDirectionAndType( + as(context->getDataType())->getParamType(idx)); + SLANG_ASSERT(isGlobalInst(type)); + effectiveTypes.add((IRType*)type); + } + idx++; + } + + return effectiveTypes; + } + List getParamInfos(IRInst* context) { List infos; @@ -1345,6 +1471,16 @@ struct DynamicInstLoweringContext auto arg = callInst->getOperand(argIndex); if (auto argInfo = tryGetInfo(edge.callerContext, arg)) { + const auto [paramDirection, paramType] = + getParameterDirectionAndType(param->getDataType()); + + // Only update if the parameter is abstract type. + if (isGlobalInst(paramType) && !(as(paramType))) + { + argIndex++; + continue; + } + // Use centralized update method updateInfo(edge.targetContext, param, argInfo, workQueue); } @@ -1536,8 +1672,7 @@ struct DynamicInstLoweringContext if (auto unconditionalBranch = as(terminator)) { auto arg = unconditionalBranch->getArg(paramIndex); - auto newArg = - upcastCollection(func, arg, as(param->getDataType())); + auto newArg = upcastCollection(func, arg, param->getDataType()); if (newArg != arg) { @@ -1557,14 +1692,16 @@ struct DynamicInstLoweringContext { if (!as(returnInst->getVal()->getDataType())) { - auto funcReturnInfo = getFuncReturnInfo(func); - auto newReturnVal = - upcastCollection(func, returnInst->getVal(), funcReturnInfo); - if (newReturnVal != returnInst->getVal()) + if (auto loweredType = getLoweredType(getFuncReturnInfo(func))) { - // Replace the return value with the reinterpreted value - hasChanges = true; - returnInst->setOperand(0, newReturnVal); + auto newReturnVal = + upcastCollection(func, returnInst->getVal(), loweredType); + if (newReturnVal != returnInst->getVal()) + { + // Replace the return value with the reinterpreted value + hasChanges = true; + returnInst->setOperand(0, newReturnVal); + } } } } @@ -1704,13 +1841,13 @@ struct DynamicInstLoweringContext } // The above loops might have added new contexts to lower. - for (auto context : this->contextsToLower) + /*for (auto context : this->contextsToLower) { hasChanges |= lowerContext(context); auto newFunc = cast(this->loweredContexts[context]); funcsToProcess.add(newFunc); } - this->contextsToLower.clear(); + this->contextsToLower.clear();*/ } while (funcsToProcess.getCount() > 0); @@ -1772,10 +1909,29 @@ struct DynamicInstLoweringContext auto typeCollection = cast(taggedUnion->getOperand(0)); auto tableCollection = cast(taggedUnion->getOperand(1)); + + if (getCollectionCount(typeCollection) == 1) + return builder.getTupleType( + List( + {(IRType*)makeTagType(tableCollection), + (IRType*)getCollectionElement(typeCollection, 0)})); + return builder.getTupleType( List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); } + IRType* lowerTypeForInst(IRInst* context, IRInst* inst) + { + if (auto info = tryGetInfo(context, inst)) + { + return getLoweredType(info); + } + else + { + return inst->getDataType(); // If no info, return the original type + } + } + IRType* getLoweredType(IRTypeFlowData* info) { if (!info) @@ -1808,24 +1964,46 @@ struct DynamicInstLoweringContext return (IRType*)collection; } + if (as(info) || as(info)) + { + // Don't lower these collections.. they should be used through + // tag types, or be processed out during lowering. + // + return nullptr; + } + SLANG_UNEXPECTED("Unhandled IRTypeFlowData type in getLoweredType"); } bool replaceType(IRInst* context, IRInst* inst) { - auto info = tryGetInfo(context, inst); - if (auto ptrType = as(inst->getDataType())) - { - IRBuilder builder(module); - if (auto loweredType = getLoweredType(info)) - inst->setFullType(builder.getPtrTypeWithAddressSpace(loweredType, ptrType)); - } - else + if (auto info = tryGetInfo(context, inst)) { - if (auto loweredType = getLoweredType(info)) - inst->setFullType(loweredType); + if (auto ptrType = as(inst->getDataType())) + { + IRBuilder builder(module); + if (auto loweredType = getLoweredType(info)) + { + auto loweredPtrType = builder.getPtrTypeWithAddressSpace(loweredType, ptrType); + if (loweredPtrType == inst->getDataType()) + return false; // No change + inst->setFullType(loweredPtrType); + return true; + } + } + else + { + if (auto loweredType = getLoweredType(info)) + { + if (loweredType == inst->getDataType()) + return false; // No change + inst->setFullType(loweredType); + return true; + } + } } - return true; + + return false; } bool lowerInst(IRInst* context, IRInst* inst) @@ -1849,11 +2027,11 @@ struct DynamicInstLoweringContext case kIROp_CreateExistentialObject: return lowerCreateExistentialObject(context, as(inst)); case kIROp_Store: - SLANG_UNEXPECTED("handle this"); + return lowerStore(context, as(inst)); default: { if (auto info = tryGetInfo(context, inst)) - return replaceType(context, info); + return replaceType(context, inst); return false; } } @@ -1861,6 +2039,15 @@ struct DynamicInstLoweringContext bool lowerLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { + // Handle trivial case. + if (auto witnessTable = as(inst->getWitnessTable())) + { + inst->replaceUsesWith( + findEntryInConcreteTable(witnessTable, inst->getRequirementKey())); + inst->removeAndDeallocate(); + return true; + } + auto info = tryGetInfo(context, inst); if (!info) return false; @@ -2096,55 +2283,85 @@ struct DynamicInstLoweringContext } } - IRFuncType* getEffectiveFuncType(IRInst* callee) + bool isTaggedUnionType(IRInst* type) { - IRBuilder builder(module); + if (auto tupleType = as(type)) + return as(tupleType->getOperand(0)) != nullptr; - List paramTypes; - IRType* resultType = nullptr; + return false; + } - auto updateType = [&](IRType* currentType, IRType* newType) -> IRType* + IRType* updateType(IRType* currentType, IRType* newType) + { + // TODO: This is feeling very similar to the unionCollection logic. + // Maybe unify? + if (auto collection = as(currentType)) { - if (auto collection = as(currentType)) - { - List collectionElements; - forEachInCollection( - collection, - [&](IRInst* element) - { - if (auto loweredType = getLoweredType(tryGetInfo(callee, element))) - collectionElements.add(loweredType); - else - collectionElements.add(element); - }); - collectionElements.add(newType); + HashSet collectionElements; + forEachInCollection( + collection, + [&](IRInst* element) { collectionElements.add(element); }); - // If this is a collection, we need to create a new collection with the new type - auto newCollection = createCollection(collection->getOp(), collectionElements); - return (IRType*)newCollection; - } - else if (currentType == newType) + if (auto newCollection = as(newType)) { - return currentType; + // If the new type is also a collection, merge the two collections + forEachInCollection( + newCollection, + [&](IRInst* element) { collectionElements.add(element); }); } - else if (currentType == nullptr) + else { - return newType; + // Otherwise, just add the new type to the collection + collectionElements.add(newType); } - else // Need to create a new collection. - { - List collectionElements; - SLANG_ASSERT(!as(currentType) && !as(newType)); + // If this is a collection, we need to create a new collection with the new type + auto newCollection = createCollection(collection->getOp(), collectionElements); + return (IRType*)newCollection; + } + else if (currentType == newType) + { + return currentType; + } + else if (currentType == nullptr) + { + return newType; + } + else if (isTaggedUnionType(currentType) && isTaggedUnionType(newType)) + { + IRBuilder builder(module); + // Merge the elements of both tagged unions into a new tuple type + return builder.getTupleType( + List( + {(IRType*)makeTagType( + as(updateType( + (IRType*)currentType->getOperand(0)->getOperand(0), + (IRType*)newType->getOperand(0)->getOperand(0)))), + (IRType*)updateType( + (IRType*)currentType->getOperand(1), + (IRType*)newType->getOperand(1))})); + } + else // Need to create a new collection. + { + HashSet collectionElements; + + SLANG_ASSERT(!as(currentType) && !as(newType)); - collectionElements.add(currentType); - collectionElements.add(newType); + collectionElements.add(currentType); + collectionElements.add(newType); - // If this is a collection, we need to create a new collection with the new type - auto newCollection = createCollection(currentType->getOp(), collectionElements); - return (IRType*)newCollection; - } - }; + // If this is a collection, we need to create a new collection with the new type + auto newCollection = createCollection(kIROp_TypeCollection, collectionElements); + return (IRType*)newCollection; + } + } + + IRFuncType* getEffectiveFuncType(IRInst* callee) + { + IRBuilder builder(module); + + List paramTypes; + IRType* resultType = nullptr; auto updateParamType = [&](UInt index, IRType* paramType) -> IRType* { @@ -2160,7 +2377,7 @@ struct DynamicInstLoweringContext auto [currentDirection, currentType] = getParameterDirectionAndType(paramTypes[index]); auto [newDirection, newType] = getParameterDirectionAndType(paramType); - auto updatedType = updateType(currentType, paramType); + auto updatedType = updateType(currentType, newType); SLANG_ASSERT(currentDirection == newDirection); paramTypes[index] = fromDirectionAndType(&builder, currentDirection, updatedType); return updatedType; @@ -2184,32 +2401,84 @@ struct DynamicInstLoweringContext contextsToProcess.add(callee); } - for (auto func : contextsToProcess) + for (auto context : contextsToProcess) { - auto paramInfos = getParamInfos(callee); - auto paramDirections = getParamDirections(callee); - for (UInt i = 0; i < paramInfos.getCount(); i++) + auto paramEffectiveTypes = getParamEffectiveTypes(context); + auto paramDirections = getParamDirections(context); + for (UInt i = 0; i < paramEffectiveTypes.getCount(); i++) { - if (auto loweredType = getLoweredType(paramInfos[i])) + if (auto collectionType = as(paramEffectiveTypes[i])) + updateParamType( + i, + fromDirectionAndType( + &builder, + paramDirections[i], + getLoweredType(collectionType))); + else if (paramEffectiveTypes[i] != nullptr) updateParamType( i, - fromDirectionAndType(&builder, paramDirections[i], loweredType)); + fromDirectionAndType( + &builder, + paramDirections[i], + (IRType*)paramEffectiveTypes[i])); else SLANG_UNEXPECTED("Unhandled parameter type in getEffectiveFuncType"); } - auto returnType = getFuncReturnInfo(func); + auto returnType = getFuncReturnInfo(context); if (auto newResultType = getLoweredType(returnType)) { resultType = updateType(resultType, newResultType); } + else if (auto funcType = as(context->getDataType())) + { + SLANG_ASSERT(isGlobalInst(funcType->getResultType())); + resultType = updateType(resultType, funcType->getResultType()); + } else { - resultType = updateType(resultType, (IRType*)returnType); + SLANG_UNEXPECTED("Cannot determine result type for context"); } } - return builder.getFuncType(paramTypes, resultType); + // + // Add in extra parameter types for a call to the callee. + // + + List extraParamTypes; + // If the callee is a collection, then we need a tag as input. + if (auto funcCollection = as(callee)) + { + // If this is a non-trivial collection, we need to add a tag type for the collection + // as the first parameter. + if (getCollectionCount(funcCollection) > 1) + extraParamTypes.add((IRType*)makeTagType(funcCollection)); + // extraParamTypes.add((IRType*)makeTagType(funcCollection)); + } + + // If the any of the elements in the callee (or the callee itself in case + // of a singleton) is a dynamic specialization, each non-singleton TableCollection, + // requries a corresponding tag input. + // + auto calleeToCheck = as(callee) + ? getCollectionElement(as(callee), 0) + : callee; + if (isDynamicGeneric(calleeToCheck)) + { + auto specializeInst = as(calleeToCheck); + + // If this is a dynamic generic, we need to add a tag type for each + // TableCollection in the callee. + for (UIndex i = 0; i < specializeInst->getArgCount(); i++) + if (auto tableCollection = as(specializeInst->getArg(i))) + extraParamTypes.add((IRType*)makeTagType(tableCollection)); + } + + List allParamTypes; + allParamTypes.addRange(extraParamTypes); + allParamTypes.addRange(paramTypes); + + return builder.getFuncType(allParamTypes, resultType); } /*IRFuncType* getExpectedFuncType(IRInst* context, IRCall* inst) @@ -2305,6 +2574,7 @@ struct DynamicInstLoweringContext return false; } + /* bool lowerContext(IRInst* context) { auto specializeInst = cast(context); @@ -2452,6 +2722,7 @@ struct DynamicInstLoweringContext this->loweredContexts[context] = loweredFunc; return true; } + */ IRInst* getCalleeForContext(IRInst* context) { @@ -2473,26 +2744,12 @@ struct DynamicInstLoweringContext { auto specArg = specializedCallee->getArg(ii); auto argInfo = specArg->getDataType(); - if (auto argCollection = as(argInfo)) - { - if (as(getCollectionElement(argCollection, 0))) - { - // Needs an index (spec-arg will carry an index, we'll - // just need to append it to the call) - // - callArgs.add(specArg); - } - else if (as(getCollectionElement(argCollection, 0))) - { - // Needs no dynamic information. Skip. - } - else - { - // If it's a witness table, we need to handle it differently - // For now, we will not lower this case. - SLANG_UNEXPECTED("Unhandled type-flow-collection in dynamic generic call"); - } - } + + // Pull all tag-type arguments from the specialization arguments + // and add them to the call arguments. + // + if (as(argInfo)) + callArgs.add(specArg); } return callArgs; @@ -2553,55 +2810,26 @@ struct DynamicInstLoweringContext bool lowerCall(IRInst* context, IRCall* inst) { auto callee = inst->getCallee(); - auto expectedFuncType = getEffectiveFuncType(callee); + IRInst* calleeTagInst = nullptr; - // First, we'll legalize all operands by upcasting if necessary. - // This needs to be done even if the callee is not a collection. + // If we're calling using a tag, place a call to the collection, + // with the tag as the first argument. So the callee is + // the collection itself. // - // List paramTypeFlows = getParamInfos(callee); - // List paramDirections = getParamDirections(callee); - bool changed = false; - List newArgs; - for (UInt i = 0; i < inst->getArgCount(); i++) + if (auto collectionTag = as(callee->getDataType())) { - auto arg = inst->getArg(i); - const auto [paramDirection, paramType] = - getParameterDirectionAndType(expectedFuncType->getParamType(i)); - if (!as(paramType)) - { - SLANG_ASSERT(!as(arg->getDataType())); - newArgs.add(arg); // No need to change the argument - continue; - } - - IRInst* newArg = nullptr; - switch (paramDirection) - { - case kParameterDirection_In: - newArgs.add(upcastCollection(context, arg, as(paramType))); - break; - default: - SLANG_UNEXPECTED("Unhandled parameter direction in lowerCall"); - } - - /*if (newArg != arg) - { - // If the argument changed, replace the old one. - changed = true; - IRBuilder builder(inst->getModule()); - builder.setInsertBefore(inst); - builder.replaceOperand(&inst->getArgs()[i + 1], newArg); - }*/ + if (getCollectionCount(collectionTag) > 1) + calleeTagInst = callee; // Only keep the tag if there are multiple elements. + callee = collectionTag->getOperand(0); } - // New we need to determine the new callee. - IRInst* newCallee = nullptr; + auto expectedFuncType = getEffectiveFuncType(callee); - List extraArgs; + List newArgs; + IRInst* newCallee = nullptr; - // auto calleeInfo = tryGetInfo(context, callee); - auto calleeInfo = as(callee->getDataType()); - auto calleeCollection = as(calleeInfo); + // Determine a new callee. + auto calleeCollection = as(callee); if (!calleeCollection) newCallee = callee; // Not a collection, no need to lower else if (getCollectionCount(calleeCollection) == 1) @@ -2613,43 +2841,29 @@ struct DynamicInstLoweringContext } else { - changed = true; if (isDynamicGeneric(singletonValue)) - extraArgs = getArgsForDynamicSpecialization(cast(singletonValue)); + newArgs.addRange( + getArgsForDynamicSpecialization(cast(inst->getCallee()))); newCallee = singletonValue; } - - /* - if (isDynamicGeneric(singletonValue)) - return lowerCallToDynamicGeneric(context, inst); - - if (singletonValue == callee) - return false; - */ - - // IRBuilder builder(inst->getModule()); - // builder.replaceOperand(inst->getCalleeUse(), singletonValue); - // newCallee = singletonValue; // Replace with the single value - // return true; // Replaced with a single function } else { - changed = true; // Multiple elements in the collection. - extraArgs.add(callee); - auto funcCollection = cast(callee->getOperand(0)); + if (calleeTagInst) + newArgs.add(calleeTagInst); + auto funcCollection = cast(calleeCollection); // Check if the first element is a dynamic generic (this should imply that all // elements are similar dynamic generics, but we might want to check for that..) // if (isDynamicGeneric(getCollectionElement(funcCollection, 0))) { - SLANG_UNEXPECTED("Dynamic generic in a collection call"); - auto dynamicSpecArgs = getArgsForDynamicSpecialization( - cast(getCollectionElement(funcCollection, 0))); + auto dynamicSpecArgs = + getArgsForDynamicSpecialization(cast(inst->getCallee())); for (auto& arg : dynamicSpecArgs) - extraArgs.add(arg); + newArgs.add(arg); } if (!as(funcCollection->getDataType())) @@ -2661,10 +2875,68 @@ struct DynamicInstLoweringContext newCallee = funcCollection; } + // First, we'll legalize all operands by upcasting if necessary. + // This needs to be done even if the callee is not a collection. + // + // List paramTypeFlows = getParamInfos(callee); + // List paramDirections = getParamDirections(callee); + UCount extraArgCount = newArgs.getCount(); + for (UInt i = 0; i < inst->getArgCount(); i++) + { + auto arg = inst->getArg(i); + const auto [paramDirection, paramType] = + getParameterDirectionAndType(expectedFuncType->getParamType(i + extraArgCount)); + + IRInst* newArg = nullptr; + switch (paramDirection) + { + case kParameterDirection_In: + newArgs.add(upcastCollection(context, arg, paramType)); + break; + case kParameterDirection_Out: + case kParameterDirection_InOut: + { + auto argValueType = as(arg->getDataType())->getValueType(); + if (argValueType != paramType) + { + SLANG_UNEXPECTED("ptr-typed parameters should have matching types"); + /* + IRBuilder varBuilder(inst->getModule()); + varBuilder.setInsertAfter(arg); + auto callVar = varBuilder.emitVar(paramType); + + if (paramDirection == kParameterDirection_InOut) + { + varBuilder.emitStore( + callVar, + upcastCollection(context, varBuilder.emitLoad(arg), paramType)); + } + + varBuilder.emitStore(arg, varBuilder.emitLoad(callVar));*/ + } + else + { + newArgs.add(arg); + } + break; + } + default: + SLANG_UNEXPECTED("Unhandled parameter direction in lowerCall"); + } + + /*if (newArg != arg) + { + // If the argument changed, replace the old one. + changed = true; + IRBuilder builder(inst->getModule()); + builder.setInsertBefore(inst); + builder.replaceOperand(&inst->getArgs()[i + 1], newArg); + }*/ + } + IRBuilder builder(inst); builder.setInsertBefore(inst); - /* Create dispatch function auto dispatchFunc = createDispatchFunc(collectionToHashSet(calleeCollection), expectedFuncType);*/ @@ -2678,10 +2950,23 @@ struct DynamicInstLoweringContext newArgs.add(inst->getOperand(i)); }*/ - if (changed) + bool changed = false; + for (UInt i = 0; i < newArgs.getCount(); i++) { - List callArgs; + if (newArgs[i] != inst->getArg(i)) + { + changed = true; + break; + } + } + if (newCallee != inst->getCallee()) + { + changed = true; + } + + if (changed) + { auto newCall = builder.emitCallInst(expectedFuncType->getResultType(), newCallee, newArgs); inst->replaceUsesWith(newCall); @@ -2741,14 +3026,15 @@ struct DynamicInstLoweringContext witnessTableID = inst->getWitnessTable(); } - // Create the appropriate any-value type auto collectionType = getCollectionCount(typeCollection) == 1 ? (IRType*)typeCollection->getOperand(0) : (IRType*)typeCollection; // Pack the value - auto packedValue = builder.emitPackAnyValue(collectionType, inst->getWrappedValue()); + auto packedValue = as(collectionType) + ? builder.emitPackAnyValue(collectionType, inst->getWrappedValue()) + : inst->getWrappedValue(); auto taggedUnionTupleType = getLoweredType(taggedUnion); @@ -2800,6 +3086,26 @@ struct DynamicInstLoweringContext return true; } + bool lowerStore(IRInst* context, IRStore* inst) + { + auto ptr = inst->getPtr(); + auto ptrInfo = as(ptr->getDataType())->getValueType(); + + auto valInfo = inst->getVal()->getDataType(); + + auto loweredVal = upcastCollection(context, inst->getVal(), valInfo); + + if (loweredVal != inst->getVal()) + { + // If the value was changed, we need to update the store instruction. + IRBuilder builder(inst); + builder.replaceOperand(inst->getValUse(), loweredVal); + return true; + } + + return false; + } + /*bool _lowerCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { auto info = tryGetInfo(context, inst); @@ -3065,21 +3371,31 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) // An effective func type should have been set during the dynamic-inst-lowering // pass. // - IRFuncType* expectedFuncType = cast(collection->getFullType()); + IRFuncType* dispatchFuncType = cast(collection->getFullType()); // Create a dispatch function with switch-case for each function IRBuilder builder(collection->getModule()); - List paramTypes; + /*List paramTypes; paramTypes.add(builder.getUIntType()); // ID parameter for (UInt i = 0; i < expectedFuncType->getParamCount(); i++) - paramTypes.add(expectedFuncType->getParamType(i)); + paramTypes.add(expectedFuncType->getParamType(i));*/ + + // auto resultType = expectedFuncType->getResultType(); + // auto funcType = builder.getFuncType(paramTypes, resultType); + + // Consume the first parameter of the expected function type + List innerParamTypes; + for (auto paramType : dispatchFuncType->getParamTypes()) + innerParamTypes.add(paramType); + innerParamTypes.removeAt(0); // Remove the first parameter (ID) + + auto resultType = dispatchFuncType->getResultType(); + auto innerFuncType = builder.getFuncType(innerParamTypes, resultType); - auto resultType = expectedFuncType->getResultType(); - auto funcType = builder.getFuncType(paramTypes, resultType); auto func = builder.createFunc(); builder.setInsertInto(func); - func->setFullType(funcType); + func->setFullType(dispatchFuncType); auto entryBlock = builder.emitBlock(); builder.setInsertInto(entryBlock); @@ -3088,9 +3404,9 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) // Create parameters for the original function arguments List originalParams; - for (UInt i = 1; i < paramTypes.getCount(); i++) + for (UInt i = 0; i < innerParamTypes.getCount(); i++) { - originalParams.add(builder.emitParam(paramTypes[i])); + originalParams.add(builder.emitParam(innerParamTypes[i])); } // Create default block @@ -3121,7 +3437,7 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) { auto funcId = funcSeqID++; auto wrapperFunc = - emitWitnessTableWrapper(funcInst->getModule(), funcInst, expectedFuncType); + emitWitnessTableWrapper(funcInst->getModule(), funcInst, innerFuncType); // Create case block auto caseBlock = builder.emitBlock(); @@ -3280,7 +3596,7 @@ struct TagOpsLoweringContext : public InstPassBase }); IRBuilder builder(inst); - builder.setInsertBefore(inst); + builder.setInsertAfter(inst); auto translatedID = builder.emitCallInst( inst->getDataType(), createIntegerMappingFunc(builder.getModule(), mapping), @@ -3298,23 +3614,31 @@ struct TagOpsLoweringContext : public InstPassBase cast(cast(inst->getDataType())->getOperand(0)); IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); List indices; for (UInt i = 0; i < srcCollection->getOperandCount(); i++) { // Find in destCollection auto srcElement = srcCollection->getOperand(i); + + bool found = false; for (UInt j = 0; j < destCollection->getOperandCount(); j++) { auto destElement = destCollection->getOperand(j); if (srcElement == destElement) { + found = true; indices.add(builder.getIntValue(builder.getUIntType(), j)); break; // Found the index } } - // destCollection must be a super-set - SLANG_UNEXPECTED("Element not found in destination collection"); + + if (!found) + { + // destCollection must be a super-set + SLANG_UNEXPECTED("Element not found in destination collection"); + } } // Create an array for the lookup @@ -3338,24 +3662,30 @@ struct TagOpsLoweringContext : public InstPassBase auto key = cast(inst->getOperand(1)); IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); List indices; for (UInt i = 0; i < srcCollection->getOperandCount(); i++) { // Find in destCollection + bool found = false; auto srcElement = findEntryInConcreteTable(srcCollection->getOperand(i), key); for (UInt j = 0; j < destCollection->getOperandCount(); j++) { auto destElement = destCollection->getOperand(j); if (srcElement == destElement) { + found = true; indices.add(builder.getIntValue(builder.getUIntType(), j)); break; // Found the index } } - // destCollection must be a super-set - SLANG_UNEXPECTED("Element not found in destination collection"); + if (!found) + { + // destCollection must be a super-set + SLANG_UNEXPECTED("Element not found in destination collection"); + } } // Create an array for the lookup @@ -3388,9 +3718,37 @@ struct TagOpsLoweringContext : public InstPassBase } } + void lowerFuncCollection(IRFuncCollection* collection) + { + IRBuilder builder(collection->getModule()); + if (collection->hasUses() && collection->getDataType() != nullptr) + { + auto dispatchFunc = createDispatchFunc(collection); + traverseUses( + collection, + [&](IRUse* use) + { + if (auto callInst = as(use->getUser())) + { + // If the call is a collection call, replace it with the dispatch function + if (callInst->getCallee() == collection) + { + IRBuilder callBuilder(callInst); + callBuilder.setInsertBefore(callInst); + callBuilder.replaceOperand(callInst->getCalleeUse(), dispatchFunc); + } + } + }); + } + } + void processModule() { - processAllReachableInsts([&](IRInst* inst) { return processInst(inst); }); + processInstsOfType( + kIROp_FuncCollection, + [&](IRFuncCollection* inst) { return lowerFuncCollection(inst); }); + + processAllInsts([&](IRInst* inst) { return processInst(inst); }); } }; @@ -3418,19 +3776,9 @@ struct CollectionLoweringContext : public InstPassBase collection->replaceUsesWith(anyValueType); } - void lowerFuncCollection(IRFuncCollection* collection) - { - IRBuilder builder(collection->getModule()); - auto dispatchFunc = createDispatchFunc(collection); - collection->replaceUsesWith(dispatchFunc); - } void processModule() { - processInstsOfType( - kIROp_FuncCollection, - [&](IRFuncCollection* inst) { return lowerFuncCollection(inst); }); - processInstsOfType( kIROp_TypeCollection, [&](IRTypeCollection* inst) { return lowerTypeCollection(inst); }); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 2a96cfc490d..c4e6eb1370b 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1159,6 +1159,8 @@ struct SpecializationContext if (options.lowerWitnessLookups) { iterChanged = lowerDynamicInsts(module, sink); + if (iterChanged) + eliminateDeadCode(module->getModuleInst()); } if (!iterChanged || sink->getErrorCount()) @@ -3052,12 +3054,194 @@ void finalizeSpecialization(IRModule* module) lowerCollectionAndTagInsts(module, nullptr); } + +// DUPLICATE: merge. +static bool isDynamicGeneric(IRInst* callee) +{ + // If the callee is a specialization, and at least one of its arguments + // is a type-flow-collection, then it is a dynamic generic. + // + if (auto specialize = as(callee)) + { + for (UInt i = 0; i < specialize->getArgCount(); i++) + { + auto arg = specialize->getArg(i); + if (as(arg)) + return true; // Found a type-flow-collection argument + } + return false; // No type-flow-collection arguments found + } + + return false; +} + +static IRCollectionTagType* makeTagType(IRCollectionBase* collection) +{ + IRInst* collectionInst = collection; + // Create the tag type from the collection + IRBuilder builder(collection->getModule()); + return as( + builder.emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collectionInst)); +} + +static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) +{ + auto generic = cast(specializeInst->getBase()); + auto genericReturnVal = findGenericReturnVal(generic); + + IRBuilder builder(specializeInst->getModule()); + builder.setInsertInto(specializeInst->getModule()); + + // Let's start by creating the function itself. + auto loweredFunc = builder.createFunc(); + builder.setInsertInto(loweredFunc); + builder.setInsertInto(builder.emitBlock()); + // loweredFunc->setFullType(context->getFullType()); + + IRCloneEnv cloneEnv; + Index argIndex = 0; + List extraParamTypes; + // Map the generic's parameters to the specialized arguments. + for (auto param : generic->getFirstBlock()->getParams()) + { + auto specArg = specializeInst->getArg(argIndex++); + if (auto collection = as(specArg)) + { + // We're dealing with a set of types. + if (as(param->getDataType())) + { + // auto unionType = createAnyValueTypeFromInsts(collectionSet); + cloneEnv.mapOldValToNew[param] = collection; + } + else if (as(param->getDataType())) + { + // Add an integer param to the func. + auto tagType = (IRType*)makeTagType(collection); + cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); + extraParamTypes.add(tagType); + // extraIndices++; + } + } + else + { + // For everything else, just set the parameter type to the argument; + SLANG_ASSERT(specArg->getParent()->getOp() == kIROp_ModuleInst); + cloneEnv.mapOldValToNew[param] = specArg; + } + } + + // Clone in the rest of the generic's body including the blocks of the returned func. + for (auto inst = generic->getFirstBlock()->getFirstOrdinaryInst(); inst; + inst = inst->getNextInst()) + { + if (inst == genericReturnVal) + { + auto returnedFunc = cast(inst); + auto funcFirstBlock = returnedFunc->getFirstBlock(); + + // cloneEnv.mapOldValToNew[funcFirstBlock] = loweredFunc->getFirstBlock(); + builder.setInsertInto(loweredFunc); + for (auto block : returnedFunc->getBlocks()) + { + // Merge the first block of the generic with the first block of the + // returned function to merge the parameter lists. + // + // if (block != funcFirstBlock) + //{ + cloneEnv.mapOldValToNew[block] = cloneInstAndOperands(&cloneEnv, &builder, block); + //} + } + + builder.setInsertInto(loweredFunc->getFirstBlock()); + builder.emitBranch(as(cloneEnv.mapOldValToNew[funcFirstBlock])); + + for (auto param : funcFirstBlock->getParams()) + { + // Clone the parameters of the first block. + builder.setInsertAfter(loweredFunc->getFirstBlock()->getLastParam()); + cloneInst(&cloneEnv, &builder, param); + } + + builder.setInsertInto(as(cloneEnv.mapOldValToNew[funcFirstBlock])); + for (auto inst = funcFirstBlock->getFirstOrdinaryInst(); inst; + inst = inst->getNextInst()) + { + // Clone the instructions in the first block. + cloneInst(&cloneEnv, &builder, inst); + } + + for (auto block : returnedFunc->getBlocks()) + { + if (block == funcFirstBlock) + continue; // Already cloned the first block + cloneInstDecorationsAndChildren( + &cloneEnv, + builder.getModule(), + block, + cloneEnv.mapOldValToNew[block]); + } + + builder.setInsertInto(builder.getModule()); + auto loweredFuncType = as( + cloneInst(&cloneEnv, &builder, as(returnedFunc->getFullType()))); + + // Add extra indices to the func-type parameters + List funcTypeParams; + for (Index i = 0; i < extraParamTypes.getCount(); i++) + funcTypeParams.add(extraParamTypes[i]); + + for (auto paramType : loweredFuncType->getParamTypes()) + funcTypeParams.add(paramType); + + // Set the new function type with the extra indices + loweredFunc->setFullType( + builder.getFuncType(funcTypeParams, loweredFuncType->getResultType())); + } + else if (!as(inst)) + { + // Keep cloning insts in the generic + cloneInst(&cloneEnv, &builder, inst); + } + } + + /* + // Transfer propagation info. + for (auto& [oldVal, newVal] : cloneEnv.mapOldValToNew) + { + if (propagationMap.containsKey(Element(context, oldVal))) + { + // If we have propagation info for the old value, transfer it to the new value + if (auto info = propagationMap[Element(context, oldVal)]) + { + if (newVal->getParent()->getOp() != kIROp_ModuleInst) + propagationMap[Element(loweredFunc, newVal)] = info; + } + } + } + + // Transfer func-return value info. + if (this->funcReturnInfo.containsKey(context)) + { + this->funcReturnInfo[loweredFunc] = this->funcReturnInfo[context]; + } + + + context->replaceUsesWith(loweredFunc); + */ + // context->removeAndDeallocate(); + // this->loweredContexts[context] = loweredFunc; + return loweredFunc; +} + IRInst* specializeGenericImpl( IRGeneric* genericVal, IRSpecialize* specializeInst, IRModule* module, SpecializationContext* context) { + if (isDynamicGeneric(specializeInst)) + return specializeDynamicGeneric(specializeInst); + // Effectively, specializing a generic amounts to "calling" the generic // on its concrete argument values and computing the // result it returns. @@ -3196,6 +3380,10 @@ IRInst* specializeGeneric(IRSpecialize* specializeInst) if (!module) return specializeInst; + if (isDynamicGeneric(specializeInst)) + return specializeDynamicGeneric(specializeInst); + + // Standard static specialization of generic. return specializeGenericImpl(baseGeneric, specializeInst, module, nullptr); } diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index 481e4ae0219..5aed1a99bcb 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -123,13 +123,123 @@ struct GenerateWitnessTableWrapperContext } }; + +// DUPLICATES... put into common file. + +static bool isTaggedUnionType(IRInst* type) +{ + if (auto tupleType = as(type)) + return as(tupleType->getOperand(0)) != nullptr; + + return false; +} + +static UCount getCollectionCount(IRCollectionBase* collection) +{ + if (!collection) + return 0; + return collection->getOperandCount(); +} + +static UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) +{ + auto typeCollection = taggedUnion->getOperand(0); + return getCollectionCount(as(typeCollection)); +} + +static UCount getCollectionCount(IRCollectionTagType* tagType) +{ + auto collection = tagType->getOperand(0); + return getCollectionCount(as(collection)); +} + +static IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) +{ + if (!collection || index >= collection->getOperandCount()) + return nullptr; + return collection->getOperand(index); +} + +static IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) +{ + auto typeCollection = collectionTagType->getOperand(0); + return getCollectionElement(as(typeCollection), index); +} + +static IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) +{ + auto argInfo = arg->getDataType(); + if (!argInfo || !destInfo) + return arg; + + if (isTaggedUnionType(argInfo) && isTaggedUnionType(destInfo)) + { + auto argTupleType = as(argInfo); + auto destTupleType = as(destInfo); + + List upcastedElements; + bool hasUpcastedElements = false; + + // Upcast each element of the tuple + for (UInt i = 0; i < argTupleType->getOperandCount(); ++i) + { + auto argElementType = argTupleType->getOperand(i); + auto destElementType = destTupleType->getOperand(i); + + // If the element types are different, we need to reinterpret + if (argElementType != destElementType) + { + hasUpcastedElements = true; + upcastedElements.add(upcastCollection( + builder, + builder->emitGetTupleElement((IRType*)argElementType, arg, i), + (IRType*)destElementType)); + } + else + { + upcastedElements.add(builder->emitGetTupleElement((IRType*)argElementType, arg, i)); + } + } + + if (hasUpcastedElements) + { + return builder->emitMakeTuple(upcastedElements); + } + } + else if (as(argInfo) && as(destInfo)) + { + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) + { + return builder + ->emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); + } + } + else if (as(argInfo) && as(destInfo)) + { + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) + { + // If the sets of witness tables are not equal, reinterpret to the parameter type + return builder->emitReinterpret((IRType*)destInfo, arg); + } + } + else if (!as(argInfo) && as(destInfo)) + { + return builder->emitPackAnyValue((IRType*)destInfo, arg); + } + + return arg; // Can use as-is. +} + + // Represents a work item for packing `inout` or `out` arguments after a concrete call. struct ArgumentPackWorkItem { enum Kind { Pack, - Reinterpret + UpCast, } kind = Pack; // A `AnyValue` typed destination. @@ -172,20 +282,23 @@ IRInst* maybeUnpackArg( if (auto argPtrType = as(argType)) { argValType = argPtrType->getValueType(); - argVal = builder->emitLoad(arg); } + // Unpack `arg` if the parameter expects concrete type but // `arg` is an AnyValue. if (!isAnyValueType(paramValType) && isAnyValueType(argValType)) { - auto unpackedArgVal = builder->emitUnpackAnyValue(paramValType, argVal); // if parameter expects an `out` pointer, store the unpacked val into a // variable and pass in a pointer to that variable. if (as(paramType)) { auto tempVar = builder->emitVar(paramValType); - builder->emitStore(tempVar, unpackedArgVal); + if (as(paramType)) + builder->emitStore( + tempVar, + builder->emitUnpackAnyValue(paramValType, builder->emitLoad(arg))); + // tempVar needs to be unpacked into original var after the call. packAfterCall.kind = ArgumentPackWorkItem::Kind::Pack; packAfterCall.dstArg = arg; @@ -194,7 +307,7 @@ IRInst* maybeUnpackArg( } else { - return unpackedArgVal; + return builder->emitUnpackAnyValue(paramValType, argVal); } } @@ -203,24 +316,24 @@ IRInst* maybeUnpackArg( // by checking if the types are different, but this should be // encoded in the types. // - if (paramValType != argValType) + if (isTaggedUnionType(paramValType) && isTaggedUnionType(argValType) && + paramValType != argValType) { - auto reinterpretedArgVal = builder->emitReinterpret(paramValType, argVal); // if parameter expects an `out` pointer, store the unpacked val into a // variable and pass in a pointer to that variable. - if (as(paramType)) + if (as(paramType)) { auto tempVar = builder->emitVar(paramValType); - builder->emitStore(tempVar, reinterpretedArgVal); + // tempVar needs to be unpacked into original var after the call. - packAfterCall.kind = ArgumentPackWorkItem::Kind::Reinterpret; + packAfterCall.kind = ArgumentPackWorkItem::Kind::UpCast; packAfterCall.dstArg = arg; packAfterCall.concreteArg = tempVar; return tempVar; } else { - return reinterpretedArgVal; + SLANG_UNEXPECTED("Unexpected upcast for non-out parameter"); } } return arg; @@ -244,6 +357,7 @@ IRStringLit* _getWitnessTableWrapperFuncName(IRModule* module, IRFunc* func) return nullptr; } + IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* interfaceRequirementVal) { auto funcTypeInInterface = cast(interfaceRequirementVal); @@ -300,7 +414,7 @@ IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* inte auto concreteVal = builder->emitLoad(item.concreteArg); auto packedVal = (item.kind == ArgumentPackWorkItem::Kind::Pack) ? builder->emitPackAnyValue(anyValType, concreteVal) - : builder->emitReinterpret(anyValType, concreteVal); + : upcastCollection(builder, concreteVal, anyValType); builder->emitStore(item.dstArg, packedVal); } @@ -311,9 +425,11 @@ IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* inte auto pack = builder->emitPackAnyValue(funcTypeInInterface->getResultType(), call); builder->emitReturn(pack); } - else if (call->getDataType() != funcTypeInInterface->getResultType()) + else if ( + isTaggedUnionType(call->getDataType()) && + isTaggedUnionType(funcTypeInInterface->getResultType())) { - auto reinterpret = builder->emitReinterpret(funcTypeInInterface->getResultType(), call); + auto reinterpret = upcastCollection(builder, call, funcTypeInInterface->getResultType()); builder->emitReturn(reinterpret); } else From a716f4df36a2f3058bc11d5814aae4b13d776740 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 4 Aug 2025 11:47:27 -0400 Subject: [PATCH 021/105] Delete commented-out lines --- source/slang/slang-ir-lower-dynamic-insts.cpp | 819 ------------------ 1 file changed, 819 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 554b77ce206..43f86830595 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -648,31 +648,6 @@ struct DynamicInstLoweringContext { return builder.emitMakeTuple(upcastedElements); } - - /*if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) - { - // If the sets of witness tables are not equal, reinterpret to the parameter type - IRBuilder builder(module); - builder.setInsertAfter(arg); - - auto argTupleType = as(arg->getDataType()); - auto tag = - builder.emitGetTupleElement((IRType*)argTupleType->getOperand(0), arg, 0); - auto value = - builder.emitGetTupleElement((IRType*)argTupleType->getOperand(1), arg, 1); - - auto newTag = upcastCollection( - context, - tag, - makeTagType(cast(destInfo->getOperand(1)))); - auto newValue = upcastCollection( - context, - value, - cast(destInfo->getOperand(0))); - - return builder.emitMakeTuple(newTag, newValue); - }*/ } else if (as(argInfo) && as(destInfo)) { @@ -706,166 +681,6 @@ struct DynamicInstLoweringContext return arg; // Can use as-is. } - /* - IRInst* maybeReinterpret(IRInst* context, IRInst* arg, IRTypeFlowData* destInfo) - { - auto argInfo = tryGetInfo(context, arg); - - if (!argInfo || !destInfo) - return arg; - - if (as(argInfo) && as(destInfo)) - { - if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) - { - // If the sets of witness tables are not equal, reinterpret to the parameter type - IRBuilder builder(module); - builder.setInsertAfter(arg); - - // We'll use nulltype for the reinterpret since the type is going to be re-written - // and if it doesn't, this will help catch it before code-gen. - // - auto reinterpret = builder.emitReinterpret(nullptr, arg); - propagationMap[Element(reinterpret)] = destInfo; - return reinterpret; // Return the reinterpret instruction - } - } - - return arg; // Can use as-is. - } - - bool insertReinterprets() - { - bool changed = false; - // Process each function in the module - for (auto inst : module->getGlobalInsts()) - { - if (auto func = as(inst)) - { - auto context = func; - // Skip the first block as it contains function parameters, not phi parameters - for (auto block = func->getFirstBlock()->getNextBlock(); block; - block = block->getNextBlock()) - { - // Process each parameter in this block (these are phi parameters) - for (auto param : block->getParams()) - { - auto paramInfo = tryGetInfo(param); - if (!paramInfo) - continue; - - // Check all predecessors and their corresponding arguments - Index paramIndex = 0; - for (auto p : block->getParams()) - { - if (p == param) - break; - paramIndex++; - } - - // Find all predecessors of this block - for (auto pred : block->getPredecessors()) - { - auto terminator = pred->getTerminator(); - if (!terminator) - continue; - - if (auto unconditionalBranch = as(terminator)) - { - // Get the argument at the same index as this parameter - if (paramIndex < unconditionalBranch->getArgCount()) - { - auto arg = unconditionalBranch->getArg(paramIndex); - auto newArg = maybeReinterpret(context, arg, tryGetInfo(param)); - - if (newArg != arg) - { - changed = true; - // Replace the argument in the branch instruction - SLANG_ASSERT(!as(unconditionalBranch)); - unconditionalBranch->setOperand(1 + paramIndex, newArg); - } - } - } - } - } - - // Is the terminator a return instruction? - if (auto returnInst = as(block->getTerminator())) - { - if (!as(returnInst->getVal()->getDataType())) - { - auto funcReturnInfo = tryGetFuncReturnInfo(func); - auto newReturnVal = - maybeReinterpret(context, returnInst->getVal(), funcReturnInfo); - if (newReturnVal != returnInst->getVal()) - { - // Replace the return value with the reinterpreted value - changed = true; - returnInst->setOperand(0, newReturnVal); - } - } - } - - List callInsts; - List storeInsts; - // Collect all call instructions in this block - for (auto inst : block->getChildren()) - { - if (auto callInst = as(inst)) - callInsts.add(callInst); - else if (auto storeInst = as(inst)) - storeInsts.add(storeInst); - } - - // Look at all the args and reinterpret them if necessary - for (auto callInst : callInsts) - { - if (auto irFunc = as(callInst->getCallee())) - { - List params; - List args; - Index i = 0; - for (auto param : irFunc->getParams()) - { - auto newArg = maybeReinterpret( - context, - callInst->getArg(i), - tryGetInfo(param)); - if (newArg != callInst->getArg(i)) - { - // Replace the argument in the call instruction - changed = true; - callInst->setArg(i, newArg); - } - i++; - } - } - } - - // Look at all the stores and reinterpret them if necessary - for (auto storeInst : storeInsts) - { - auto newValToStore = maybeReinterpret( - context, - storeInst->getVal(), - tryGetInfo(storeInst->getPtr())); - if (newValToStore != storeInst->getVal()) - { - // Replace the value in the store instruction - changed = true; - storeInst->setOperand(1, newValToStore); - } - } - } - } - } - - return changed; - } - */ - void processInstForPropagation(IRInst* context, IRInst* inst, LinkedList& workQueue) { IRTypeFlowData* info; @@ -1717,120 +1532,19 @@ struct DynamicInstLoweringContext return hasChanges; } - /* - bool lowerInstsInFunc(IRFunc* func) - { - // Collect all instructions that need lowering - List typeInstsToLower; - List valueInstsToLower; - List instWithReplacementTypes; - List funcTypesToProcess; - - bool hasChanges = false; - auto context = func; - // Process each function's instructions - for (auto block : func->getBlocks()) - { - for (auto child : block->getChildren()) - { - if (as(child)) - continue; // Skip parameters and terminators - - switch (child->getOp()) - { - case kIROp_LookupWitnessMethod: - { - if (child->getDataType()->getOp() == kIROp_TypeKind) - typeInstsToLower.add(Element(context, child)); - else - valueInstsToLower.add(Element(context, child)); - break; - } - case kIROp_ExtractExistentialType: - typeInstsToLower.add(Element(context, child)); - break; - case kIROp_ExtractExistentialWitnessTable: - case kIROp_ExtractExistentialValue: - case kIROp_MakeExistential: - case kIROp_CreateExistentialObject: - valueInstsToLower.add(Element(context, child)); - break; - case kIROp_Call: - { - auto callee = as(child)->getCallee(); - if (auto info = tryGetInfo(context, child)) - if (as(info)) - instWithReplacementTypes.add(Element(context, child)); - - if (auto calleeInfo = tryGetInfo(context, callee)) - if (as(calleeInfo)) - valueInstsToLower.add(Element(context, child)); - - if (as(callee)) - valueInstsToLower.add(Element(context, child)); - } - break; - default: - if (auto info = tryGetInfo(context, child)) - if (as(info)) - // If this instruction has a set of types, tables, or funcs, - // we need to lower it to a unified type. - instWithReplacementTypes.add(Element(context, child)); - } - } - } - - for (auto instWithCtx : typeInstsToLower) - { - if (instWithCtx.inst->getParent() == nullptr) - continue; - hasChanges |= lowerInst(instWithCtx.context, instWithCtx.inst); - } - - for (auto instWithCtx : valueInstsToLower) - { - if (instWithCtx.inst->getParent() == nullptr) - continue; - hasChanges |= lowerInst(instWithCtx.context, instWithCtx.inst); - } - - for (auto instWithCtx : instWithReplacementTypes) - { - if (instWithCtx.inst->getParent() == nullptr) - continue; - hasChanges |= replaceType(instWithCtx.context, instWithCtx.inst); - } - - return hasChanges; - } - */ - bool performDynamicInstLowering() { - // List funcsForTypeReplacement; List funcsToProcess; for (auto globalInst : module->getGlobalInsts()) if (auto func = as(globalInst)) { - // funcsForTypeReplacement.add(func); funcsToProcess.add(func); } bool hasChanges = false; do { - /* - while (funcsForTypeReplacement.getCount() > 0) - { - auto func = funcsForTypeReplacement.getLast(); - funcsForTypeReplacement.removeLast(); - - // Replace the function type with a concrete type if it has existential return types - hasChanges |= replaceFuncType(func, this->funcReturnInfo[func]); - } - */ - while (funcsToProcess.getCount() > 0) { auto func = funcsToProcess.getLast(); @@ -1839,74 +1553,17 @@ struct DynamicInstLoweringContext // Lower the instructions in the function hasChanges |= lowerFunc(func); } - - // The above loops might have added new contexts to lower. - /*for (auto context : this->contextsToLower) - { - hasChanges |= lowerContext(context); - auto newFunc = cast(this->loweredContexts[context]); - funcsToProcess.add(newFunc); - } - this->contextsToLower.clear();*/ - } while (funcsToProcess.getCount() > 0); return hasChanges; } - /* - bool replaceFuncType(IRFunc* func, IRTypeFlowData* returnTypeInfo) - { - IRFuncType* origFuncType = as(func->getFullType()); - IRType* returnType = origFuncType->getResultType(); - if (auto taggedUnion = as(returnTypeInfo)) - { - // If the return type is existential, we need to replace it with a tuple type - returnType = getTypeForExistential(taggedUnion); - } - - List paramTypes; - for (auto param : func->getFirstBlock()->getParams()) - { - // Extract the existential type from the parameter if it exists - auto paramInfo = tryGetInfo(param); - if (auto paramTaggedUnion = as(paramInfo)) - { - paramTypes.add(getTypeForExistential(paramTaggedUnion)); - } - else - paramTypes.add(param->getDataType()); - } - - IRBuilder builder(module); - builder.setInsertBefore(func); - - auto newFuncType = builder.getFuncType(paramTypes, returnType); - if (newFuncType == func->getFullType()) - return false; // No change - - func->setFullType(newFuncType); - return true; - } - */ - IRType* getTypeForExistential(IRCollectionTaggedUnionType* taggedUnion) { // Replace type with Tuple IRBuilder builder(module); builder.setInsertInto(module); - /*HashSet types; - auto tableCollection = as(taggedUnion->getOperand(1)); - forEachInCollection( - tableCollection, - [&](IRInst* table) - { - if (auto witnessTable = as(table)) - if (auto concreteType = witnessTable->getConcreteType()) - types.add(concreteType); - });*/ - auto typeCollection = cast(taggedUnion->getOperand(0)); auto tableCollection = cast(taggedUnion->getOperand(1)); @@ -2059,34 +1716,6 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - /*if (auto collection = as(info)) - { - if (getCollectionCount(collection) == 1) - { - // Found a single possible type. Simple replacement. - inst->replaceUsesWith(getCollectionElement(collection, 0)); - inst->removeAndDeallocate(); - return true; - } - else if (auto typeCollection = as(collection)) - { - // Set of types. - // Create an any-value type based on the set of types - auto typeSet = collectionToHashSet(collection); - auto unionType = typeSet.getCount() > 1 ? createAnyValueTypeFromInsts(typeSet) - : *typeSet.begin(); - - // Store the mapping for later use - loweredInstToAnyValueType[inst] = unionType; - - // Replace the instruction with the any-value type - inst->replaceUsesWith(typeCollection); - inst->removeAndDeallocate(); - return true; - } - } - else*/ - if (getCollectionCount(collectionTagType) == 1) { // Found a single possible type. Simple replacement. @@ -2121,25 +1750,6 @@ struct DynamicInstLoweringContext propagationMap[Element(context, newInst)] = info; inst->removeAndDeallocate(); - /*if (auto witnessTableCollection = as(witnessTableInfo)) - { - // Create a key mapping function - auto keyMappingFunc = createKeyMappingFunc( - inst->getRequirementKey(), - collectionToHashSet(witnessTableCollection), - collectionToHashSet(collection)); - - // Replace with call to key mapping function - auto witnessTableId = builder.emitCallInst( - builder.getUIntType(), - keyMappingFunc, - List({inst->getWitnessTable()})); - inst->replaceUsesWith(witnessTableId); - propagationMap[Element(context, witnessTableId)] = info; - inst->removeAndDeallocate(); - return true; - }*/ - return false; } @@ -2184,15 +1794,6 @@ struct DynamicInstLoweringContext if (!taggedUnion) return false; - /* - // Check if we have a lowered any-value type for the result - auto resultType = inst->getDataType(); - auto loweredType = loweredInstToAnyValueType.tryGetValue(inst); - if (loweredType) - { - resultType = (IRType*)*loweredType; - }*/ - auto info = tryGetInfo(context, inst); auto typeCollection = as(info); if (!typeCollection) @@ -2225,17 +1826,9 @@ struct DynamicInstLoweringContext auto singletonValue = getCollectionElement(collectionTagType, 0); inst->replaceUsesWith(singletonValue); inst->removeAndDeallocate(); - // loweredInstToAnyValueType[inst] = singletonValue; return true; } - // Create an any-value type based on the set of types - /* - auto anyValueType = createAnyValueTypeFromInsts(collectionToHashSet(collection)); - - // Store the mapping for later use - loweredInstToAnyValueType[inst] = anyValueType;*/ - // Replace the instruction with the collection type. inst->replaceUsesWith(collectionTagType->getOperand(0)); inst->removeAndDeallocate(); @@ -2481,80 +2074,6 @@ struct DynamicInstLoweringContext return builder.getFuncType(allParamTypes, resultType); } - /*IRFuncType* getExpectedFuncType(IRInst* context, IRCall* inst) - { - IRBuilder builder(module); - builder.setInsertInto(module); - - // We'll retreive just the parameter directions from the callee's func-type, - // since that can't be different before & after the type-flow lowering. - // - List paramDirections; - auto calleeInfo = tryGetInfo(context, inst->getCallee()); - auto calleeCollection = as(calleeInfo); - if (!calleeCollection) - return nullptr; - - auto funcType = as(getCollectionElement(calleeCollection, 0)->getDataType()); - for (auto paramType : funcType->getParamTypes()) - { - auto [direction, type] = getParameterDirectionAndType(paramType); - paramDirections.add(direction); - } - - // Translate argument types into expected function type. - List paramTypes; - for (UInt i = 0; i < inst->getArgCount(); i++) - { - auto arg = inst->getArg(i); - - switch (paramDirections[i]) - { - case ParameterDirection::kParameterDirection_In: - { - auto argInfo = tryGetInfo(context, arg); - if (auto argTaggedUnion = as(argInfo)) - paramTypes.add(getTypeForExistential(argTaggedUnion)); - else - paramTypes.add(arg->getDataType()); - break; - } - case ParameterDirection::kParameterDirection_Out: - { - auto argInfo = tryGetInfo(context, arg); - if (auto argTaggedUnion = as(argInfo)) - paramTypes.add(builder.getOutType(getTypeForExistential(argTaggedUnion))); - else - paramTypes.add(builder.getOutType( - as(arg->getDataType())->getValueType())); - break; - } - case ParameterDirection::kParameterDirection_InOut: - { - auto argInfo = tryGetInfo(context, arg); - if (auto argTaggedUnion = as(argInfo)) - paramTypes.add(builder.getInOutType(getTypeForExistential(argTaggedUnion))); - else - paramTypes.add(builder.getInOutType( - as(arg->getDataType())->getValueType())); - break; - } - default: - SLANG_UNEXPECTED("Unhandled parameter direction in getExpectedFuncType"); - } - } - - // Translate result type. - IRType* resultType = inst->getDataType(); - auto returnInfo = tryGetInfo(context, inst); - if (auto returnTaggedUnion = as(returnInfo)) - { - resultType = getTypeForExistential(returnTaggedUnion); - } - - return builder.getFuncType(paramTypes, resultType); - }*/ - bool isDynamicGeneric(IRInst* callee) { // If the callee is a specialization, and at least one of its arguments @@ -2574,156 +2093,6 @@ struct DynamicInstLoweringContext return false; } - /* - bool lowerContext(IRInst* context) - { - auto specializeInst = cast(context); - auto generic = cast(specializeInst->getBase()); - auto genericReturnVal = findGenericReturnVal(generic); - - IRBuilder builder(module); - builder.setInsertInto(module); - - // Let's start by creating the function itself. - auto loweredFunc = builder.createFunc(); - builder.setInsertInto(loweredFunc); - builder.setInsertInto(builder.emitBlock()); - // loweredFunc->setFullType(context->getFullType()); - - IRCloneEnv cloneEnv; - Index argIndex = 0; - List extraParamTypes; - // Map the generic's parameters to the specialized arguments. - for (auto param : generic->getFirstBlock()->getParams()) - { - auto specArg = specializeInst->getArg(argIndex++); - if (auto collection = as(specArg)) - { - // We're dealing with a set of types. - if (as(param->getDataType())) - { - // auto unionType = createAnyValueTypeFromInsts(collectionSet); - cloneEnv.mapOldValToNew[param] = collection; - } - else if (as(param->getDataType())) - { - // Add an integer param to the func. - auto tagType = (IRType*)makeTagType(collection); - cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); - extraParamTypes.add(tagType); - // extraIndices++; - } - } - else - { - // For everything else, just set the parameter type to the argument; - SLANG_ASSERT(specArg->getParent()->getOp() == kIROp_ModuleInst); - cloneEnv.mapOldValToNew[param] = specArg; - } - } - - // Clone in the rest of the generic's body including the blocks of the returned func. - for (auto inst = generic->getFirstBlock()->getFirstOrdinaryInst(); inst; - inst = inst->getNextInst()) - { - if (inst == genericReturnVal) - { - auto returnedFunc = cast(inst); - auto funcFirstBlock = returnedFunc->getFirstBlock(); - - // cloneEnv.mapOldValToNew[funcFirstBlock] = loweredFunc->getFirstBlock(); - builder.setInsertInto(loweredFunc); - for (auto block : returnedFunc->getBlocks()) - { - // Merge the first block of the generic with the first block of the - // returned function to merge the parameter lists. - // - // if (block != funcFirstBlock) - //{ - cloneEnv.mapOldValToNew[block] = - cloneInstAndOperands(&cloneEnv, &builder, block); - //} - } - - builder.setInsertInto(loweredFunc->getFirstBlock()); - builder.emitBranch(as(cloneEnv.mapOldValToNew[funcFirstBlock])); - - for (auto param : funcFirstBlock->getParams()) - { - // Clone the parameters of the first block. - builder.setInsertAfter(loweredFunc->getFirstBlock()->getLastParam()); - cloneInst(&cloneEnv, &builder, param); - } - - builder.setInsertInto(as(cloneEnv.mapOldValToNew[funcFirstBlock])); - for (auto inst = funcFirstBlock->getFirstOrdinaryInst(); inst; - inst = inst->getNextInst()) - { - // Clone the instructions in the first block. - cloneInst(&cloneEnv, &builder, inst); - } - - for (auto block : returnedFunc->getBlocks()) - { - if (block == funcFirstBlock) - continue; // Already cloned the first block - cloneInstDecorationsAndChildren( - &cloneEnv, - builder.getModule(), - block, - cloneEnv.mapOldValToNew[block]); - } - - builder.setInsertInto(builder.getModule()); - auto loweredFuncType = as( - cloneInst(&cloneEnv, &builder, as(returnedFunc->getFullType()))); - - // Add extra indices to the func-type parameters - List funcTypeParams; - for (Index i = 0; i < extraParamTypes.getCount(); i++) - funcTypeParams.add(extraParamTypes[i]); - - for (auto paramType : loweredFuncType->getParamTypes()) - funcTypeParams.add(paramType); - - // Set the new function type with the extra indices - loweredFunc->setFullType( - builder.getFuncType(funcTypeParams, loweredFuncType->getResultType())); - } - else if (!as(inst)) - { - // Keep cloning insts in the generic - cloneInst(&cloneEnv, &builder, inst); - } - } - - // Transfer propagation info. - for (auto& [oldVal, newVal] : cloneEnv.mapOldValToNew) - { - if (propagationMap.containsKey(Element(context, oldVal))) - { - // If we have propagation info for the old value, transfer it to the new value - if (auto info = propagationMap[Element(context, oldVal)]) - { - if (newVal->getParent()->getOp() != kIROp_ModuleInst) - propagationMap[Element(loweredFunc, newVal)] = info; - } - } - } - - // Transfer func-return value info. - if (this->funcReturnInfo.containsKey(context)) - { - this->funcReturnInfo[loweredFunc] = this->funcReturnInfo[context]; - } - - context->replaceUsesWith(loweredFunc); - // context->removeAndDeallocate(); - this->loweredContexts[context] = loweredFunc; - return true; - } - */ - IRInst* getCalleeForContext(IRInst* context) { if (this->contextsToLower.contains(context)) @@ -2900,19 +2269,6 @@ struct DynamicInstLoweringContext if (argValueType != paramType) { SLANG_UNEXPECTED("ptr-typed parameters should have matching types"); - /* - IRBuilder varBuilder(inst->getModule()); - varBuilder.setInsertAfter(arg); - auto callVar = varBuilder.emitVar(paramType); - - if (paramDirection == kParameterDirection_InOut) - { - varBuilder.emitStore( - callVar, - upcastCollection(context, varBuilder.emitLoad(arg), paramType)); - } - - varBuilder.emitStore(arg, varBuilder.emitLoad(callVar));*/ } else { @@ -2923,33 +2279,11 @@ struct DynamicInstLoweringContext default: SLANG_UNEXPECTED("Unhandled parameter direction in lowerCall"); } - - /*if (newArg != arg) - { - // If the argument changed, replace the old one. - changed = true; - IRBuilder builder(inst->getModule()); - builder.setInsertBefore(inst); - builder.replaceOperand(&inst->getArgs()[i + 1], newArg); - }*/ } IRBuilder builder(inst); builder.setInsertBefore(inst); - /* Create dispatch function - auto dispatchFunc = - createDispatchFunc(collectionToHashSet(calleeCollection), expectedFuncType);*/ - - - // Replace call with dispatch - /*List newArgs; - newArgs.add(callee); // Add the lookup as first argument (will get lowered into an uint tag) - for (UInt i = 1; i < inst->getOperandCount(); i++) - { - newArgs.add(inst->getOperand(i)); - }*/ - bool changed = false; for (UInt i = 0; i < newArgs.getCount(); i++) { @@ -2998,8 +2332,6 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertBefore(inst); - // auto witnessTableInfo = tryGetInfo(context, inst->getWitnessTable()); - // Collect types from the witness tables to determine the any-value type auto tableCollection = as(taggedUnion->getOperand(1)); auto typeCollection = as(taggedUnion->getOperand(0)); @@ -3007,10 +2339,6 @@ struct DynamicInstLoweringContext IRInst* witnessTableID = nullptr; if (auto witnessTable = as(inst->getWitnessTable())) { - // Get unique ID for the witness table. - /*witnessTableID = builder.getIntValue( - builder.getUIntType(), - getUniqueID(getCollectionElement(witnessTableCollection, 0)));*/ auto singletonTagType = makeTagType(makeSingletonSet(witnessTable)); auto zeroValueOfTagType = builder.getIntValue((IRType*)singletonTagType, 0); witnessTableID = builder.emitIntrinsicInst( @@ -3042,11 +2370,6 @@ struct DynamicInstLoweringContext IRInst* tupleArgs[] = {witnessTableID, packedValue}; auto tuple = builder.emitMakeTuple(taggedUnionTupleType, 2, tupleArgs); - /* - if (auto info = tryGetInfo(context, inst)) - propagationMap[Element(context, tuple)] = info; - */ - inst->replaceUsesWith(tuple); inst->removeAndDeallocate(); return true; @@ -3106,52 +2429,6 @@ struct DynamicInstLoweringContext return false; } - /*bool _lowerCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) - { - auto info = tryGetInfo(context, inst); - auto taggedUnion = as(info); - if (!taggedUnion) - return false; - - Dictionary mapping; - auto tableCollection = as(taggedUnion->getOperand(1)); - forEachInCollection( - tableCollection, - [&](IRInst* table) - { - // Get unique ID for the witness table - auto witnessTable = cast(table); - auto outputId = getUniqueID(witnessTable); - auto seqDecoration = table->findDecoration(); - if (seqDecoration) - { - auto inputId = seqDecoration->getSequentialID(); - mapping[inputId] = outputId; // Map ID to itself for now - } - }); - - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto translatedID = builder.emitCallInst( - builder.getUIntType(), - createIntegerMappingFunc(mapping), - List({inst->getTypeID()})); - - auto existentialTupleType = as(getTypeForExistential(taggedUnion)); - auto existentialTuple = builder.emitMakeTuple( - existentialTupleType, - List( - {translatedID, - builder.emitReinterpret(existentialTupleType->getOperand(1), inst->getValue())})); - - if (auto info = tryGetInfo(context, inst)) - propagationMap[Element(context, existentialTuple)] = info; - - inst->replaceUsesWith(existentialTuple); - inst->removeAndDeallocate(); - return true; - }*/ - UInt getUniqueID(IRInst* funcOrTable) { auto existingId = uniqueIds.tryGetValue(funcOrTable); @@ -3163,26 +2440,6 @@ struct DynamicInstLoweringContext return newId; } - /* - IRFunc* createKeyMappingFunc( - IRInst* key, - const HashSet& inputTables, - const HashSet& outputVals) - { - Dictionary mapping; - - // Create a mapping. - for (auto table : inputTables) - { - auto inputId = getUniqueID(table); - auto outputId = getUniqueID(findEntryInConcreteTable(table, key)); - mapping[inputId] = outputId; - } - - return createIntegerMappingFunc(module, mapping); - } - */ - bool isExistentialType(IRType* type) { return as(type) != nullptr; } bool isInterfaceType(IRType* type) { return as(type) != nullptr; } @@ -3223,74 +2480,6 @@ struct DynamicInstLoweringContext return tables; } - /*bool lowerTypeCollections() - { - bool hasChanges = false; - - // Lower all global scope ``IRCollectionBase`` objects that - // are made up of types. - // - for (auto inst : module->getGlobalInsts()) - { - if (auto collection = as(inst)) - { - if (collection->getOp() == kIROp_TypeCollection) - { - HashSet types; - for (UInt i = 0; i < collection->getOperandCount(); i++) - { - if (auto type = as(collection->getOperand(i))) - { - types.add(type); - } - } - auto anyValueType = createAnyValueType(types); - collection->replaceUsesWith(anyValueType); - hasChanges = true; - } - } - } - - return hasChanges; - }*/ - - /* - bool transferDataToInstTypes() - { - bool hasChanges = false; - - for (auto& pair : propagationMap) - { - auto instWithContext = pair.first; - auto flowData = pair.second; - - if (!flowData) - continue; // No propagation data - - if (as(flowData)) - { - // If the flow data is an unbounded collection, don't touch - // the types. - continue; - } - - auto inst = instWithContext.inst; - auto context = instWithContext.context; - - // Only transfer data for insts that are in top-level - // contexts. We'll come back to specialized contexts later. - // - if (context->getOp() == kIROp_Func) - { - inst->setFullType((IRType*)flowData); - hasChanges = true; - } - } - - return hasChanges; - } - */ - bool processModule() { bool hasChanges = false; @@ -3376,14 +2565,6 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) // Create a dispatch function with switch-case for each function IRBuilder builder(collection->getModule()); - /*List paramTypes; - paramTypes.add(builder.getUIntType()); // ID parameter - for (UInt i = 0; i < expectedFuncType->getParamCount(); i++) - paramTypes.add(expectedFuncType->getParamType(i));*/ - - // auto resultType = expectedFuncType->getResultType(); - // auto funcType = builder.getFuncType(paramTypes, resultType); - // Consume the first parameter of the expected function type List innerParamTypes; for (auto paramType : dispatchFuncType->getParamTypes()) From 89794ff60abeb754f67a5bc68a5a565f17fb4453 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 5 Aug 2025 11:30:42 -0400 Subject: [PATCH 022/105] More fixes to get dynamic dispatch tests passing --- source/slang/slang-emit.cpp | 3 + source/slang/slang-ir-lower-dynamic-insts.cpp | 459 +++++++++++++++--- source/slang/slang-ir-lower-dynamic-insts.h | 8 +- source/slang/slang-ir-lower-generics.cpp | 4 + source/slang/slang-ir-specialize.cpp | 2 +- 5 files changed, 399 insertions(+), 77 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ccb055e4d69..223632db4be 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1140,6 +1140,9 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink)); } + lowerTagInsts(irModule, sink); + lowerTypeCollections(irModule, sink); + if (requiredLoweringPassSet.reinterpret) lowerReinterpret(targetProgram, irModule, sink); diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 43f86830595..8bcc12cdbb5 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -9,9 +9,11 @@ #include "slang-ir-witness-table-wrapper.h" #include "slang-ir.h" + namespace Slang { +constexpr IRIntegerValue kDefaultAnyValueSize = 16; // Elements for which we keep track of propagation information. struct Element { @@ -605,6 +607,32 @@ struct DynamicInstLoweringContext } } + IRIntegerValue getInterfaceAnyValueSize(IRInst* type) + { + if (auto decor = type->findDecoration()) + { + return decor->getSize(); + } + + // We could conceivably make it an error to have an interface + // without an `[anyValueSize(...)]` attribute, but then we risk + // producing error messages even when doing 100% static specialization. + // + // It is simpler to use a reasonable default size and treat any + // type without an explicit attribute as using that size. + // + return kDefaultAnyValueSize; + } + + IRType* lowerInterfaceType(IRInterfaceType* interfaceType) + { + IRBuilder builder(module); + auto anyValueType = builder.getAnyValueType(getInterfaceAnyValueSize(interfaceType)); + auto witnessTableType = builder.getWitnessTableIDType((IRType*)interfaceType); + auto rttiType = builder.getRTTIHandleType(); + return builder.getTupleType({rttiType, witnessTableType, anyValueType}); + } + IRInst* upcastCollection(IRInst* context, IRInst* arg, IRType* destInfo) { auto argInfo = arg->getDataType(); @@ -677,6 +705,35 @@ struct DynamicInstLoweringContext builder.setInsertAfter(arg); return builder.emitPackAnyValue((IRType*)destInfo, arg); } + else if (as(argInfo) && as(destInfo)) + { + auto loweredInterfaceType = lowerInterfaceType(as(argInfo)); + IRBuilder builder(module); + builder.setInsertAfter(arg); + auto witnessTable = + builder.emitGetTupleElement(builder.getWitnessTableIDType(argInfo), arg, 1); + auto tableID = builder.emitGetSequentialIDInst(witnessTable); + auto tableCollection = cast(destInfo->getOperand(1)); + auto typeCollection = cast(destInfo->getOperand(0)); + + List getTagOperands; + getTagOperands.add(argInfo); + getTagOperands.add(tableID); + auto tableTag = builder.emitIntrinsicInst( + (IRType*)makeTagType(tableCollection), + kIROp_GetTagFromSequentialID, + getTagOperands.getCount(), + getTagOperands.getBuffer()); + + return builder.emitMakeTuple( + {tableTag, + builder.emitReinterpret( + (IRType*)typeCollection, + builder.emitGetTupleElement( + (IRType*)loweredInterfaceType->getOperand(0), + arg, + 2))}); + } return arg; // Can use as-is. } @@ -714,11 +771,16 @@ struct DynamicInstLoweringContext info = analyzeSpecialize(context, as(inst)); break; case kIROp_Load: - info = analyzeLoad(context, as(inst)); + case kIROp_RWStructuredBufferLoad: + case kIROp_StructuredBufferLoad: + info = analyzeLoad(context, inst); break; case kIROp_Store: info = analyzeStore(context, as(inst), workQueue); break; + case kIROp_GetElementPtr: + info = analyzeGetElementPtr(context, as(inst), workQueue); + break; default: info = analyzeDefault(context, inst); break; @@ -736,19 +798,18 @@ struct DynamicInstLoweringContext if (auto interfaceType = as(inst->getDataType())) { - if (!interfaceType->findDecoration()) - { - auto tables = collectExistentialTables(interfaceType); - if (tables.getCount() > 0) - return makeExistential( - as(createCollection(kIROp_TableCollection, tables))); - else - return none(); - } - else + if (isComInterfaceType(interfaceType) || isBuiltin(interfaceType)) { + // If this is a COM interface, we treat it as unbounded return makeUnbounded(); } + + auto tables = collectExistentialTables(interfaceType); + if (tables.getCount() > 0) + return makeExistential( + as(createCollection(kIROp_TableCollection, tables))); + else + return none(); } return none(); @@ -778,11 +839,63 @@ struct DynamicInstLoweringContext SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } - IRTypeFlowData* analyzeLoad(IRInst* context, IRLoad* loadInst) + bool isResourcePointer(IRInst* inst) { - // Transfer the prop info from the address to the loaded value - auto address = loadInst->getPtr(); - return tryGetInfo(context, address); + return isPointerToResourceType(inst->getDataType()) || + inst->getOp() == kIROp_RWStructuredBufferGetElementPtr; + } + + IRTypeFlowData* analyzeLoad(IRInst* context, IRInst* inst) + { + // Default: Transfer the prop info from the address to the loaded value + if (auto loadInst = as(inst)) + { + if (isResourcePointer(loadInst->getPtr())) + { + if (auto interfaceType = as(loadInst->getDataType())) + { + if (!isComInterfaceType(interfaceType) && !isBuiltin(interfaceType)) + { + auto tables = collectExistentialTables(interfaceType); + if (tables.getCount() > 0) + return makeExistential( + as( + createCollection(kIROp_TableCollection, tables))); + else + return none(); + } + else + { + return makeUnbounded(); + } + } + } + + // If the load is from a pointer, we can transfer the info directly + auto address = as(loadInst)->getPtr(); + return tryGetInfo(context, address); + } + else if (as(inst) || as(inst)) + { + if (auto interfaceType = as(inst->getDataType())) + { + if (!isComInterfaceType(interfaceType) && !isBuiltin(interfaceType)) + { + auto tables = collectExistentialTables(interfaceType); + if (tables.getCount() > 0) + return makeExistential( + as(createCollection(kIROp_TableCollection, tables))); + else + return none(); + } + else + { + return makeUnbounded(); + } + } + } + + return none(); // No info for other load types } IRTypeFlowData* analyzeStore( @@ -1054,7 +1167,11 @@ struct DynamicInstLoweringContext } else if (as(arg) || as(arg)) { - updateInfo(context, param, makeSingletonSet(arg), workQueue); + updateInfo( + context, + param, + makeTagType(makeSingletonSet(arg)), + workQueue); } else { @@ -1132,6 +1249,16 @@ struct DynamicInstLoweringContext return none(); } + IRTypeFlowData* analyzeGetElementPtr( + IRInst* context, + IRGetElementPtr* inst, + LinkedList& workQueue) + { + IRTypeFlowData* thisInstInfo = tryGetInfo(context, inst); + updateInfo(context, inst->getBase(), thisInstInfo, workQueue); + return thisInstInfo; + } + void propagateWithinFuncEdge(IRInst* context, IREdge edge, LinkedList& workQueue) { // Handle intra-procedural edge (original logic) @@ -1362,7 +1489,7 @@ struct DynamicInstLoweringContext if (auto interfaceType = as(paramType)) { - if (interfaceType->findDecoration()) + if (isComInterfaceType(interfaceType) || isBuiltin(interfaceType)) propagationMap[Element(context, param)] = makeUnbounded(); else propagationMap[Element(context, param)] = none(); // Initialize to none. @@ -1438,13 +1565,30 @@ struct DynamicInstLoweringContext IRTypeFlowData* analyzeDefault(IRInst* context, IRInst* inst) { - // Check if this is a global type, witness table, or function. + // Check if this is a global concrete type, witness table, or function. // If so, it's a concrete element. We'll create a singleton set for it. - if (inst->getParent()->getOp() == kIROp_ModuleInst && - (as(inst) || as(inst) || as(inst))) + if (isGlobalInst(inst) && + (!as(inst) && + (as(inst) || as(inst) || as(inst)))) return makeSingletonSet(inst); - else - return none(); // Default case, no propagation info + + auto instType = inst->getDataType(); + if (isGlobalInst(inst)) + { + if (as(instType) && !(as(instType))) + return none(); // We'll avoid storing propagation info for concrete insts. (can just + // use the inst directly) + + if (as(instType)) + { + // As a general rule, if none of the non-default cases handled this inst that is + // producing an existential type, then we assume that we can't constrain it + // + return makeUnbounded(); + } + } + + return none(); // Default case, no propagation info } bool lowerInstsInBlock(IRInst* context, IRBlock* block) @@ -1683,6 +1827,11 @@ struct DynamicInstLoweringContext return lowerMakeExistential(context, as(inst)); case kIROp_CreateExistentialObject: return lowerCreateExistentialObject(context, as(inst)); + case kIROp_RWStructuredBufferLoad: + case kIROp_StructuredBufferLoad: + return lowerStructuredBufferLoad(context, inst); + case kIROp_Load: + return lowerLoad(context, inst); case kIROp_Store: return lowerStore(context, as(inst)); default: @@ -1781,7 +1930,7 @@ struct DynamicInstLoweringContext auto operand = inst->getOperand(0); auto element = builder.emitGetTupleElement((IRType*)collectionTagType, operand, 0); inst->replaceUsesWith(element); - propagationMap[Element(context, element)] = info; + // propagationMap[Element(context, element)] = info; inst->removeAndDeallocate(); return true; } @@ -1789,6 +1938,22 @@ struct DynamicInstLoweringContext bool lowerExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { + auto existential = inst->getOperand(0); + auto existentialInfo = existential->getDataType(); + if (isTaggedUnionType(existentialInfo)) + { + auto valType = existentialInfo->getOperand(1); + IRBuilder builder(inst); + builder.setInsertAfter(inst); + + auto val = builder.emitGetTupleElement((IRType*)valType, existential, 1); + inst->replaceUsesWith(val); + inst->removeAndDeallocate(); + return true; + } + + return false; + /* auto operandInfo = tryGetInfo(context, inst->getOperand(0)); auto taggedUnion = as(operandInfo); if (!taggedUnion) @@ -1807,7 +1972,7 @@ struct DynamicInstLoweringContext auto element = builder.emitGetTupleElement((IRType*)info, operand, 1); inst->replaceUsesWith(element); inst->removeAndDeallocate(); - return true; + return true;*/ } bool lowerExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) @@ -2188,10 +2353,24 @@ struct DynamicInstLoweringContext if (auto collectionTag = as(callee->getDataType())) { if (getCollectionCount(collectionTag) > 1) + { calleeTagInst = callee; // Only keep the tag if there are multiple elements. + + // If we're placing a specialized call, use the base tag since the + // specialization arguments will also become arguments to the call. + // + if (auto specializedTag = as(calleeTagInst)) + calleeTagInst = specializedTag->getBase(); + } callee = collectionTag->getOperand(0); } + // If by this point, we haven't resolved our callee into a global inst ( + // either a collection or a single function), then we can't lower it (likely unbounded) + // + if (!isGlobalInst(callee)) + return false; + auto expectedFuncType = getEffectiveFuncType(callee); List newArgs; @@ -2396,9 +2575,17 @@ struct DynamicInstLoweringContext args.getCount(), args.getBuffer()); - auto packedValue = builder.emitPackAnyValue( - (IRType*)taggedUnionTupleType->getOperand(1), - inst->getValue()); + IRInst* packedValue = nullptr; + if (auto collection = as(taggedUnionTupleType->getOperand(1))) + { + packedValue = builder.emitPackAnyValue((IRType*)collection, inst->getValue()); + } + else + { + packedValue = builder.emitReinterpret( + (IRType*)taggedUnionTupleType->getOperand(1), + inst->getValue()); + } auto newInst = builder.emitMakeTuple( taggedUnionTupleType, @@ -2409,6 +2596,92 @@ struct DynamicInstLoweringContext return true; } + bool lowerStructuredBufferLoad(IRInst* context, IRInst* inst) + { + auto valInfo = tryGetInfo(context, inst); + + if (!valInfo) + return false; + + auto bufferType = (IRType*)inst->getOperand(0)->getDataType(); + auto bufferBaseType = (IRType*)bufferType->getOperand(0); + + if (bufferBaseType != (IRType*)getLoweredType(valInfo)) + { + IRBuilder builder(inst); + builder.setInsertAfter(inst); + + IRCloneEnv cloneEnv; + auto newLoad = cloneInst(&cloneEnv, &builder, inst); + + auto loweredVal = upcastCollection(context, newLoad, (IRType*)valInfo); + + // TODO: this is a hack. Encode this in the type-flow-data. + if (as(bufferBaseType) && !isComInterfaceType(inst->getDataType()) && + !isBuiltin(inst->getDataType())) + { + newLoad->setFullType(lowerInterfaceType(as(bufferBaseType))); + } + + inst->replaceUsesWith(loweredVal); + inst->removeAndDeallocate(); + return true; + } + else if (inst->getDataType() != bufferBaseType) + { + // If the data type is not the same, we need to update it. + inst->setFullType((IRType*)getLoweredType(valInfo)); + return true; + } + else + { + // No change needed. + return false; + } + } + + bool lowerLoad(IRInst* context, IRInst* inst) + { + auto valInfo = tryGetInfo(context, inst); + + if (!valInfo) + return false; + + IRType* ptrValType = nullptr; + ptrValType = as(as(inst)->getPtr()->getDataType())->getValueType(); + + if (ptrValType != (IRType*)getLoweredType(valInfo)) + { + SLANG_ASSERT(!as(inst)); + IRBuilder builder(inst); + builder.setInsertAfter(inst); + + IRCloneEnv cloneEnv; + auto newLoad = cloneInst(&cloneEnv, &builder, inst); + + auto loweredVal = upcastCollection(context, newLoad, (IRType*)valInfo); + + // TODO: this is a hack. Encode this in the type-flow-data. + if (as(ptrValType) && !isComInterfaceType(inst->getDataType()) && + !isBuiltin(inst->getDataType())) + { + newLoad->setFullType(lowerInterfaceType(as(ptrValType))); + } + + inst->replaceUsesWith(loweredVal); + inst->removeAndDeallocate(); + + return true; + } + else if (inst->getDataType() != ptrValType) + { + inst->setFullType((IRType*)getLoweredType(valInfo)); + return true; + } + + return false; + } + bool lowerStore(IRInst* context, IRStore* inst) { auto ptr = inst->getPtr(); @@ -2416,7 +2689,7 @@ struct DynamicInstLoweringContext auto valInfo = inst->getVal()->getDataType(); - auto loweredVal = upcastCollection(context, inst->getVal(), valInfo); + auto loweredVal = upcastCollection(context, inst->getVal(), ptrInfo); if (loweredVal != inst->getVal()) { @@ -2749,44 +3022,6 @@ struct TagOpsLoweringContext : public InstPassBase { } - void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) - { - auto srcInterfaceType = cast(inst->getOperand(0)); - auto srcSeqID = inst->getOperand(1); - - Dictionary mapping; - - // Map from sequential ID to unique ID - auto destCollection = - cast(cast(inst->getDataType())->getOperand(0)); - - UIndex dstSeqID = 0; - forEachInCollection( - destCollection, - [&](IRInst* table) - { - // Get unique ID for the witness table - auto witnessTable = cast(table); - auto outputId = dstSeqID++; - auto seqDecoration = table->findDecoration(); - if (seqDecoration) - { - auto inputId = seqDecoration->getSequentialID(); - mapping[inputId] = outputId; // Map ID to itself for now - } - }); - - IRBuilder builder(inst); - builder.setInsertAfter(inst); - auto translatedID = builder.emitCallInst( - inst->getDataType(), - createIntegerMappingFunc(builder.getModule(), mapping), - List({srcSeqID})); - - inst->replaceUsesWith(translatedID); - inst->removeAndDeallocate(); - } - void lowerGetTagForSuperCollection(IRGetTagForSuperCollection* inst) { auto srcCollection = cast( @@ -2885,9 +3120,6 @@ struct TagOpsLoweringContext : public InstPassBase { switch (inst->getOp()) { - case kIROp_GetTagFromSequentialID: - lowerGetTagFromSequentialID(as(inst)); - break; case kIROp_GetTagForSuperCollection: lowerGetTagForSuperCollection(as(inst)); break; @@ -2957,13 +3189,93 @@ struct CollectionLoweringContext : public InstPassBase collection->replaceUsesWith(anyValueType); } - void processModule() { processInstsOfType( kIROp_TypeCollection, [&](IRTypeCollection* inst) { return lowerTypeCollection(inst); }); + } +}; + +void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) +{ + CollectionLoweringContext context(module); + context.processModule(); +} + +struct SequentialIDTagLoweringContext : public InstPassBase +{ + SequentialIDTagLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) + { + auto srcInterfaceType = cast(inst->getOperand(0)); + auto srcSeqID = inst->getOperand(1); + + Dictionary mapping; + + // Map from sequential ID to unique ID + auto destCollection = + cast(cast(inst->getDataType())->getOperand(0)); + + UIndex dstSeqID = 0; + forEachInCollection( + destCollection, + [&](IRInst* table) + { + // Get unique ID for the witness table + auto witnessTable = cast(table); + auto outputId = dstSeqID++; + auto seqDecoration = table->findDecoration(); + if (seqDecoration) + { + auto inputId = seqDecoration->getSequentialID(); + mapping[inputId] = outputId; // Map ID to itself for now + } + }); + + IRBuilder builder(inst); + builder.setInsertAfter(inst); + auto translatedID = builder.emitCallInst( + inst->getDataType(), + createIntegerMappingFunc(builder.getModule(), mapping), + List({srcSeqID})); + + inst->replaceUsesWith(translatedID); + inst->removeAndDeallocate(); + } + + void processModule() + { + processInstsOfType( + kIROp_GetTagFromSequentialID, + [&](IRGetTagFromSequentialID* inst) { return lowerGetTagFromSequentialID(inst); }); + } +}; + +void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) +{ + SequentialIDTagLoweringContext context(module); + context.processModule(); +} +void lowerTagInsts(IRModule* module, DiagnosticSink* sink) +{ + TagOpsLoweringContext tagContext(module); + tagContext.processModule(); +} + +struct TagTypeLoweringContext : public InstPassBase +{ + TagTypeLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void processModule() + { processInstsOfType( kIROp_CollectionTagType, [&](IRCollectionTagType* inst) @@ -2974,12 +3286,9 @@ struct CollectionLoweringContext : public InstPassBase } }; -void lowerCollectionAndTagInsts(IRModule* module, DiagnosticSink* sink) +void lowerTagTypes(IRModule* module) { - TagOpsLoweringContext tagContext(module); - tagContext.processModule(); - - CollectionLoweringContext context(module); + TagTypeLoweringContext context(module); context.processModule(); } diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-dynamic-insts.h index 32a782b7285..eae821a0a9d 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.h +++ b/source/slang/slang-ir-lower-dynamic-insts.h @@ -8,5 +8,11 @@ namespace Slang { // Main entry point for the pass bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink); -void lowerCollectionAndTagInsts(IRModule* module, DiagnosticSink* sink); +// void lowerCollectionAndTagInsts(IRModule* module, DiagnosticSink* sink); + +void lowerTypeCollections(IRModule* module, DiagnosticSink* sink); +void lowerTagInsts(IRModule* module, DiagnosticSink* sink); + +void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink); +void lowerTagTypes(IRModule* module); } // namespace Slang diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index be21a66ba74..8854693e7e7 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -9,6 +9,7 @@ #include "slang-ir-generics-lowering-context.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-layout.h" +#include "slang-ir-lower-dynamic-insts.h" #include "slang-ir-lower-existential.h" #include "slang-ir-lower-generic-call.h" #include "slang-ir-lower-generic-function.h" @@ -137,6 +138,9 @@ void specializeRTTIObjects(SharedGenericsLoweringContext* sharedContext, Diagnos if (sink->getErrorCount() != 0) return; + lowerSequentialIDTagCasts(sharedContext->module, sharedContext->sink); + lowerTagTypes(sharedContext->module); + lowerIsTypeInsts(sharedContext); specializeDynamicAssociatedTypeLookup(sharedContext); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index c4e6eb1370b..8d654f31faa 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3051,7 +3051,7 @@ void finalizeSpecialization(IRModule* module) } } - lowerCollectionAndTagInsts(module, nullptr); + // lowerCollectionAndTagInsts(module, nullptr); } From 10b8118e096a51e5d155c0757854b543dc36f480 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 5 Aug 2025 15:21:32 -0400 Subject: [PATCH 023/105] Fix up to get all dynamic-dispatch tests to pass --- source/slang/slang-ir-lower-dynamic-insts.cpp | 266 ++++++++++++------ 1 file changed, 175 insertions(+), 91 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 8bcc12cdbb5..59f1ab9fd74 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -386,7 +386,7 @@ struct DynamicInstLoweringContext return result; } - IRTypeFlowData* tryGetInfo(Element element) + IRInst* tryGetInfo(Element element) { // For non-global instructions, look up in the map auto found = propagationMap.tryGetValue(element); @@ -395,7 +395,7 @@ struct DynamicInstLoweringContext return none(); } - IRTypeFlowData* tryGetInfo(IRInst* context, IRInst* inst) + IRInst* tryGetInfo(IRInst* context, IRInst* inst) { if (auto typeFlowData = as(inst->getDataType())) { @@ -432,7 +432,7 @@ struct DynamicInstLoweringContext return tryGetInfo(Element(context, inst)); } - IRTypeFlowData* tryGetFuncReturnInfo(IRFunc* func) + IRInst* tryGetFuncReturnInfo(IRFunc* func) { auto found = funcReturnInfo.tryGetValue(func); if (found) @@ -443,11 +443,7 @@ struct DynamicInstLoweringContext // Centralized method to update propagation info and manage work queue // Use this when you want to propagate new information to an existing instruction // This will union the new info with existing info and add users to work queue if changed - void updateInfo( - IRInst* context, - IRInst* inst, - IRTypeFlowData* newInfo, - LinkedList& workQueue) + void updateInfo(IRInst* context, IRInst* inst, IRInst* newInfo, LinkedList& workQueue) { auto existingInfo = tryGetInfo(context, inst); auto unionedInfo = unionPropagationInfo(existingInfo, newInfo); @@ -468,7 +464,7 @@ struct DynamicInstLoweringContext void addUsersToWorkQueue( IRInst* context, IRInst* inst, - IRTypeFlowData* info, + IRInst* info, LinkedList& workQueue) { for (auto use = inst->firstUse; use; use = use->nextUse) @@ -525,10 +521,7 @@ struct DynamicInstLoweringContext } // Helper method to update function return info and propagate to call sites - void updateFuncReturnInfo( - IRInst* callable, - IRTypeFlowData* returnInfo, - LinkedList& workQueue) + void updateFuncReturnInfo(IRInst* callable, IRInst* returnInfo, LinkedList& workQueue) { auto existingReturnInfo = getFuncReturnInfo(callable); auto newReturnInfo = unionPropagationInfo(existingReturnInfo, returnInfo); @@ -740,7 +733,7 @@ struct DynamicInstLoweringContext void processInstForPropagation(IRInst* context, IRInst* inst, LinkedList& workQueue) { - IRTypeFlowData* info; + IRInst* info; switch (inst->getOp()) { @@ -779,7 +772,7 @@ struct DynamicInstLoweringContext info = analyzeStore(context, as(inst), workQueue); break; case kIROp_GetElementPtr: - info = analyzeGetElementPtr(context, as(inst), workQueue); + info = analyzeGetElementPtr(context, as(inst)); break; default: info = analyzeDefault(context, inst); @@ -789,7 +782,7 @@ struct DynamicInstLoweringContext updateInfo(context, inst, info, workQueue); } - IRTypeFlowData* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) + IRInst* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { // // TODO: Actually use the integer<->type map present in the linkage to @@ -815,7 +808,7 @@ struct DynamicInstLoweringContext return none(); } - IRTypeFlowData* analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) + IRInst* analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) { auto witnessTable = inst->getWitnessTable(); auto value = inst->getWrappedValue(); @@ -845,7 +838,7 @@ struct DynamicInstLoweringContext inst->getOp() == kIROp_RWStructuredBufferGetElementPtr; } - IRTypeFlowData* analyzeLoad(IRInst* context, IRInst* inst) + IRInst* analyzeLoad(IRInst* context, IRInst* inst) { // Default: Transfer the prop info from the address to the loaded value if (auto loadInst = as(inst)) @@ -873,7 +866,10 @@ struct DynamicInstLoweringContext // If the load is from a pointer, we can transfer the info directly auto address = as(loadInst)->getPtr(); - return tryGetInfo(context, address); + if (auto addrInfo = tryGetInfo(context, address)) + return as(addrInfo)->getValueType(); + else + return none(); // No info for the address } else if (as(inst) || as(inst)) { @@ -898,18 +894,43 @@ struct DynamicInstLoweringContext return none(); // No info for other load types } - IRTypeFlowData* analyzeStore( - IRInst* context, - IRStore* storeInst, - LinkedList& workQueue) + IRInst* analyzeStore(IRInst* context, IRStore* storeInst, LinkedList& workQueue) { // Transfer the prop info from stored value to the address auto address = storeInst->getPtr(); - updateInfo(context, address, tryGetInfo(context, storeInst->getVal()), workQueue); + if (auto valInfo = tryGetInfo(context, storeInst->getVal())) + { + IRBuilder builder(module); + auto ptrInfo = builder.getPtrTypeWithAddressSpace( + (IRType*)valInfo, + as(address->getDataType())); + + // Update the base instruction for the entire access chain + maybeUpdatePtr(context, address, ptrInfo, workQueue); + } + return none(); // The store itself doesn't have any info. } - IRTypeFlowData* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) + IRInst* analyzeGetElementPtr(IRInst* context, IRGetElementPtr* getElementPtr) + { + // The base info should be in Ptr> form, so we just need to unpack and + // return Ptr as the result. + // + IRBuilder builder(module); + builder.setInsertAfter(getElementPtr); + auto basePtr = getElementPtr->getBase(); + if (auto ptrType = as(tryGetInfo(context, basePtr))) + { + auto arrayType = as(ptrType->getValueType()); + SLANG_ASSERT(arrayType); + return builder.getPtrTypeWithAddressSpace(arrayType->getElementType(), ptrType); + } + + return none(); // No info for the base pointer + } + + IRInst* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { auto key = inst->getRequirementKey(); @@ -934,7 +955,7 @@ struct DynamicInstLoweringContext SLANG_UNEXPECTED("Unexpected witness table info type in analyzeLookupWitnessMethod"); } - IRTypeFlowData* analyzeExtractExistentialWitnessTable( + IRInst* analyzeExtractExistentialWitnessTable( IRInst* context, IRExtractExistentialWitnessTable* inst) { @@ -953,7 +974,7 @@ struct DynamicInstLoweringContext SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialWitnessTable"); } - IRTypeFlowData* analyzeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) + IRInst* analyzeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); @@ -970,7 +991,7 @@ struct DynamicInstLoweringContext SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialType"); } - IRTypeFlowData* analyzeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) + IRInst* analyzeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); @@ -985,7 +1006,7 @@ struct DynamicInstLoweringContext return cast(taggedUnion->getOperand(0)); } - IRTypeFlowData* analyzeSpecialize(IRInst* context, IRSpecialize* inst) + IRInst* analyzeSpecialize(IRInst* context, IRSpecialize* inst) { auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); @@ -1194,7 +1215,7 @@ struct DynamicInstLoweringContext } } - IRTypeFlowData* analyzeCall(IRInst* context, IRCall* inst, LinkedList& workQueue) + IRInst* analyzeCall(IRInst* context, IRCall* inst, LinkedList& workQueue) { auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(context, callee); @@ -1249,14 +1270,41 @@ struct DynamicInstLoweringContext return none(); } - IRTypeFlowData* analyzeGetElementPtr( + void maybeUpdatePtr( IRInst* context, - IRGetElementPtr* inst, + IRInst* inst, + IRInst* info, LinkedList& workQueue) { - IRTypeFlowData* thisInstInfo = tryGetInfo(context, inst); - updateInfo(context, inst->getBase(), thisInstInfo, workQueue); - return thisInstInfo; + if (auto getElementPtr = as(inst)) + { + if (auto thisPtrInfo = as(info)) + { + auto thisValueType = thisPtrInfo->getValueType(); + + IRInst* baseValueType = + as(getElementPtr->getBase()->getDataType())->getValueType(); + SLANG_ASSERT(as(baseValueType)); + + // Propagate 'this' information to the base by wrapping it as a pointer to array. + IRBuilder builder(module); + auto baseInfo = builder.getPtrTypeWithAddressSpace( + builder.getArrayType( + (IRType*)thisValueType, + as(baseValueType)->getElementCount()), + as(getElementPtr->getBase()->getDataType())); + maybeUpdatePtr(context, getElementPtr->getBase(), baseInfo, workQueue); + } + } + else if (auto var = as(inst)) + { + updateInfo(context, var, info, workQueue); + } + else + { + // Do nothing.. + return; + } } void propagateWithinFuncEdge(IRInst* context, IREdge edge, LinkedList& workQueue) @@ -1321,7 +1369,7 @@ struct DynamicInstLoweringContext UIndex idx = 0; for (auto param : func->getParams()) { - if (auto newType = tryGetInfo(context, param)) + /*if (auto newType = tryGetInfo(context, param)) effectiveTypes.add((IRType*)newType); else { @@ -1329,16 +1377,21 @@ struct DynamicInstLoweringContext as(context->getDataType())->getParamType(idx)); SLANG_ASSERT(isGlobalInst(type)); effectiveTypes.add((IRType*)type); - } + }*/ + if (auto newType = tryGetInfo(context, param)) + effectiveTypes.add((IRType*)newType); + else + effectiveTypes.add( + (IRType*)as(context->getDataType())->getParamType(idx)); idx++; } return effectiveTypes; } - List getParamInfos(IRInst* context) + List getParamInfos(IRInst* context) { - List infos; + List infos; if (as(context)) { for (auto param : as(context)->getParams()) @@ -1423,8 +1476,29 @@ struct DynamicInstLoweringContext continue; } - // Use centralized update method - updateInfo(edge.targetContext, param, argInfo, workQueue); + switch (paramDirection) + { + case kParameterDirection_Out: + case kParameterDirection_InOut: + { + IRBuilder builder(module); + auto newInfo = fromDirectionAndType( + &builder, + paramDirection, + as(argInfo)->getValueType()); + updateInfo(edge.targetContext, param, newInfo, workQueue); + break; + } + case kParameterDirection_In: + { + // Use centralized update method + updateInfo(edge.targetContext, param, argInfo, workQueue); + break; + } + default: + SLANG_UNEXPECTED( + "Unhandled parameter direction in interprocedural edge"); + } } } argIndex++; @@ -1450,9 +1524,13 @@ struct DynamicInstLoweringContext if (paramDirections[argIndex] == kParameterDirection_Out || paramDirections[argIndex] == kParameterDirection_InOut) { + auto arg = callInst->getArg(argIndex); + auto argPtrType = as(arg->getDataType()); + + IRBuilder builder(module); updateInfo( edge.callerContext, - callInst->getArg(argIndex), + builder.getPtrTypeWithAddressSpace((IRType*)paramInfo, argPtrType), paramInfo, workQueue); } @@ -1467,7 +1545,7 @@ struct DynamicInstLoweringContext } } - IRTypeFlowData* getFuncReturnInfo(IRInst* callee) + IRInst* getFuncReturnInfo(IRInst* callee) { funcReturnInfo.addIfNotExists(callee, none()); return funcReturnInfo[callee]; @@ -1524,13 +1602,36 @@ struct DynamicInstLoweringContext allValues)); // Create a new collection with the union of values } - IRTypeFlowData* unionPropagationInfo(IRTypeFlowData* info1, IRTypeFlowData* info2) + IRInst* unionPropagationInfo(IRInst* info1, IRInst* info2) { if (!info1) return info2; if (!info2) return info1; + if (as(info1) && as(info2)) + { + SLANG_ASSERT(info1->getOperand(1) == info2->getOperand(1)); + // If both are array types, union their element types + IRBuilder builder(module); + builder.setInsertInto(module); + return builder.getArrayType( + (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), + info1->getOperand(1)); // Keep the same size + } + + if (as(info1) && as(info2)) + { + SLANG_ASSERT(info1->getOp() == info2->getOp()); + + // If both are array types, union their element types + IRBuilder builder(module); + builder.setInsertInto(module); + return builder.getPtrTypeWithAddressSpace( + (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), + as(info1)); + } + if (as(info1) && as(info2)) { // If either info is unbounded, the union is unbounded @@ -1721,19 +1822,7 @@ struct DynamicInstLoweringContext List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); } - IRType* lowerTypeForInst(IRInst* context, IRInst* inst) - { - if (auto info = tryGetInfo(context, inst)) - { - return getLoweredType(info); - } - else - { - return inst->getDataType(); // If no info, return the original type - } - } - - IRType* getLoweredType(IRTypeFlowData* info) + IRType* getLoweredType(IRInst* info) { if (!info) return nullptr; @@ -1741,6 +1830,22 @@ struct DynamicInstLoweringContext if (as(info)) return nullptr; + if (auto ptrType = as(info)) + { + IRBuilder builder(module); + return builder.getPtrTypeWithAddressSpace( + (IRType*)getLoweredType(ptrType->getValueType()), + ptrType); + } + + if (auto arrayType = as(info)) + { + IRBuilder builder(module); + return builder.getArrayType( + (IRType*)getLoweredType(arrayType->getElementType()), + arrayType->getElementCount()); + } + if (auto taggedUnion = as(info)) { // If this is a tagged union, we need to create a tuple type @@ -1773,37 +1878,22 @@ struct DynamicInstLoweringContext return nullptr; } - SLANG_UNEXPECTED("Unhandled IRTypeFlowData type in getLoweredType"); + return (IRType*)info; + // SLANG_UNEXPECTED("Unhandled IRTypeFlowData type in getLoweredType"); } bool replaceType(IRInst* context, IRInst* inst) { if (auto info = tryGetInfo(context, inst)) { - if (auto ptrType = as(inst->getDataType())) + if (auto loweredType = getLoweredType(info)) { - IRBuilder builder(module); - if (auto loweredType = getLoweredType(info)) - { - auto loweredPtrType = builder.getPtrTypeWithAddressSpace(loweredType, ptrType); - if (loweredPtrType == inst->getDataType()) - return false; // No change - inst->setFullType(loweredPtrType); - return true; - } - } - else - { - if (auto loweredType = getLoweredType(info)) - { - if (loweredType == inst->getDataType()) - return false; // No change - inst->setFullType(loweredType); - return true; - } + if (loweredType == inst->getDataType()) + return false; // No change + inst->setFullType(loweredType); + return true; } } - return false; } @@ -2165,22 +2255,16 @@ struct DynamicInstLoweringContext auto paramDirections = getParamDirections(context); for (UInt i = 0; i < paramEffectiveTypes.getCount(); i++) { - if (auto collectionType = as(paramEffectiveTypes[i])) - updateParamType( - i, - fromDirectionAndType( - &builder, - paramDirections[i], - getLoweredType(collectionType))); + updateParamType(i, getLoweredType(paramEffectiveTypes[i])); + /*if (auto collectionType = as(paramEffectiveTypes[i])) + else if (paramEffectiveTypes[i] != nullptr) updateParamType( i, fromDirectionAndType( &builder, paramDirections[i], - (IRType*)paramEffectiveTypes[i])); - else - SLANG_UNEXPECTED("Unhandled parameter type in getEffectiveFuncType"); + (IRType*)paramEffectiveTypes[i]));*/ } auto returnType = getFuncReturnInfo(context); @@ -2785,10 +2869,10 @@ struct DynamicInstLoweringContext DiagnosticSink* sink; // Mapping from instruction to propagation information - Dictionary propagationMap; + Dictionary propagationMap; // Mapping from function to return value propagation information - Dictionary funcReturnInfo; + Dictionary funcReturnInfo; // Mapping from functions to call-sites. Dictionary> funcCallSites; From 2f858898363c391e89af2589fed61c3e75d320d6 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 7 Aug 2025 17:26:39 -0400 Subject: [PATCH 024/105] Multiple fixes to auto-diff and specialization pass. All dynamic-dispatch tests passing --- source/slang/slang-ir-autodiff-fwd.cpp | 12 +- source/slang/slang-ir-autodiff-rev.cpp | 34 +- .../slang-ir-autodiff-transcriber-base.cpp | 9 + source/slang/slang-ir-insts-stable-names.lua | 3 +- source/slang/slang-ir-insts.lua | 3 +- source/slang/slang-ir-lower-dynamic-insts.cpp | 486 +++++++++++++++++- source/slang/slang-ir-specialize.cpp | 90 ++-- source/slang/slang-ir-util.cpp | 23 +- source/slang/slang-lower-to-ir.cpp | 4 +- 9 files changed, 591 insertions(+), 73 deletions(-) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index cd729ce6bee..67ff8b00e98 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1194,7 +1194,11 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( // the generic args to specialize the primal function. This is true for all of our core // module functions, but we may need to rely on more general substitution logic here. auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + diffBaseSpecialize->getBase()->getDataType(), + args.getCount(), + args.getBuffer()), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); @@ -1209,7 +1213,11 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize( } diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + diffBase->getDataType(), + args.getCount(), + args.getBuffer()), diffBase, args.getCount(), args.getBuffer()); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index ca768cd66eb..b4cb1920286 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -1435,8 +1435,28 @@ InstPair BackwardDiffTranscriberBase::transcribeSpecialize( { args.add(primalSpecialize->getArg(i)); } + IRType* typeForSpecialization = nullptr; + if ((*diffBase)->getDataType()->getOp() == kIROp_TypeKind || + (*diffBase)->getDataType()->getOp() == kIROp_GenericKind) + { + typeForSpecialization = (*diffBase)->getDataType(); + } + else if ((*diffBase)->getDataType()->getOp() == kIROp_Generic) + { + typeForSpecialization = (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + (*diffBase)->getDataType(), + args.getCount(), + args.getBuffer()); + } + else + { + // Default to type kind for now. + typeForSpecialization = builder->getTypeKind(); + } + auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + typeForSpecialization, *diffBase, args.getCount(), args.getBuffer()); @@ -1472,7 +1492,11 @@ InstPair BackwardDiffTranscriberBase::transcribeSpecialize( // the generic args to specialize the primal function. This is true for all of our core // module functions, but we may need to rely on more general substitution logic here. auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + diffBaseSpecialize->getBase()->getDataType(), + args.getCount(), + args.getBuffer()), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); @@ -1488,7 +1512,11 @@ InstPair BackwardDiffTranscriberBase::transcribeSpecialize( } auto diffCallee = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), + (IRType*)builder->emitSpecializeInst( + builder->getTypeKind(), + diffCallee->getDataType(), + args.getCount(), + args.getBuffer()), diffCallee, args.getCount(), args.getBuffer()); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index a4934dc282e..f5203b44a69 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1032,6 +1032,15 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene mapDifferentialInst(origGeneric->getFirstBlock(), bodyBlock); transcribeBlockImpl(&builder, origGeneric->getFirstBlock(), instsToSkip); + auto diffReturnVal = getGenericReturnVal(diffGeneric); + if (auto func = as(diffReturnVal)) + { + IRInst* outSpecializedValue = nullptr; + auto hoistedFuncType = + hoistValueFromGeneric(builder, func->getDataType(), outSpecializedValue); + diffGeneric->setFullType((IRType*)hoistedFuncType); + } + return InstPair(primalGeneric, diffGeneric); } diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index e5dc000ebc6..73b89a94d90 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -678,5 +678,6 @@ return { ["TypeFlowData.CollectionTaggedUnionType"] = 674, ["GetTagForSuperCollection"] = 675, ["GetTagForMappedCollection"] = 676, - ["GetTagFromSequentialID"] = 677 + ["GetTagFromSequentialID"] = 677, + ["GetSequentialIDFromTag"] = 678 } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 6069af690a8..4015c594793 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2191,8 +2191,9 @@ local insts = { { GetTagForSuperCollection = {} }, -- Translate a tag from a set to its equivalent in a super-set { GetTagForMappedCollection = {} }, -- Translate a tag from a set to its equivalent in a different set -- based on a mapping induced by a lookup key - { GetTagFromSequentialID = {} } -- Translate an existing sequential ID & and interface type into a tag + { GetTagFromSequentialID = {} }, -- Translate an existing sequential ID & and interface type into a tag -- the provided collection. + { GetSequentialIDFromTag = {} } -- Translate a tag from the given collection to a sequential ID. } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 59f1ab9fd74..4f200d91441 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -13,6 +13,9 @@ namespace Slang { +// Forward-declare.. (TODO: Just include this from the header instead) +IRInst* specializeGeneric(IRSpecialize* specializeInst); + constexpr IRIntegerValue kDefaultAnyValueSize = 16; // Elements for which we keep track of propagation information. struct Element @@ -287,7 +290,7 @@ struct DynamicInstLoweringContext else if (as(inst->getDataType())) return kIROp_TableCollection; else - SLANG_UNEXPECTED("Unsupported collection type for instruction"); + return kIROp_Invalid; // Return invalid IROp when not supported } // Factory methods for PropagationInfo @@ -423,6 +426,9 @@ struct DynamicInstLoweringContext if (as(inst) && as(getGenericReturnVal(inst))) return none(); + // TODO: We really should return something like Singleton(collectionInst) here + // instead of directly returning the collection. + // return makeSingletonSet(inst); } else @@ -641,7 +647,7 @@ struct DynamicInstLoweringContext bool hasUpcastedElements = false; IRBuilder builder(module); - builder.setInsertAfter(arg); + setInsertAfterOrdinaryInst(&builder, arg); // Upcast each element of the tuple for (UInt i = 0; i < argTupleType->getOperandCount(); ++i) @@ -768,12 +774,21 @@ struct DynamicInstLoweringContext case kIROp_StructuredBufferLoad: info = analyzeLoad(context, inst); break; + case kIROp_MakeStruct: + info = analyzeMakeStruct(context, as(inst), workQueue); + break; case kIROp_Store: info = analyzeStore(context, as(inst), workQueue); break; case kIROp_GetElementPtr: info = analyzeGetElementPtr(context, as(inst)); break; + case kIROp_FieldAddress: + info = analyzeFieldAddress(context, as(inst)); + break; + case kIROp_FieldExtract: + info = analyzeFieldExtract(context, as(inst)); + break; default: info = analyzeDefault(context, inst); break; @@ -832,6 +847,41 @@ struct DynamicInstLoweringContext SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } + IRInst* analyzeMakeStruct( + IRInst* context, + IRMakeStruct* makeStruct, + LinkedList& workQueue) + { + // We'll process this in the same way as a field-address, but for + // all fields of the struct. + // + auto structType = as(makeStruct->getDataType()); + UIndex operandIndex = 0; + for (auto field : structType->getFields()) + { + auto operand = makeStruct->getOperand(operandIndex); + if (auto fieldInfo = tryGetInfo(context, operand)) + { + IRInst* existingInfo = nullptr; + this->fieldInfo.tryGetValue(field, existingInfo); + auto newInfo = unionPropagationInfo(existingInfo, fieldInfo); + if (newInfo && !areInfosEqual(existingInfo, newInfo)) + { + // Update the field info map + this->fieldInfo[field] = newInfo; + + if (this->fieldUseSites.containsKey(field)) + for (auto useSite : this->fieldUseSites[field]) + workQueue.addLast(WorkItem(useSite.context, useSite.inst)); + } + } + + operandIndex++; + } + + return none(); // the make struct itself doesn't have any info. + } + bool isResourcePointer(IRInst* inst) { return isPointerToResourceType(inst->getDataType()) || @@ -930,6 +980,57 @@ struct DynamicInstLoweringContext return none(); // No info for the base pointer } + IRInst* analyzeFieldAddress(IRInst* context, IRFieldAddress* fieldAddress) + { + // The base info should be in Ptr form, so we just need to return Ptr as the result. + // + IRBuilder builder(module); + builder.setInsertAfter(fieldAddress); + auto basePtr = fieldAddress->getBase(); + + if (auto basePtrType = as(basePtr->getDataType())) + { + auto structType = as(basePtrType->getValueType()); + SLANG_ASSERT(structType); + auto structField = + findStructField(structType, as(fieldAddress->getField())); + + // Register this as a user of the field so updates will invoke this function again. + this->fieldUseSites.addIfNotExists(structField, HashSet()); + this->fieldUseSites[structField].add(Element(context, fieldAddress)); + + if (this->fieldInfo.containsKey(structField)) + { + IRBuilder builder(module); + return builder.getPtrTypeWithAddressSpace( + (IRType*)this->fieldInfo[structField], + as(fieldAddress->getDataType())); + } + } + return none(); + } + + IRInst* analyzeFieldExtract(IRInst* context, IRFieldExtract* fieldExtract) + { + IRBuilder builder(module); + + auto structType = as(fieldExtract->getBase()->getDataType()); + SLANG_ASSERT(structType); + auto structField = findStructField(structType, as(fieldExtract->getField())); + + // Register this as a user of the field so updates will invoke this function again. + this->fieldUseSites.addIfNotExists(structField, HashSet()); + this->fieldUseSites[structField].add(Element(context, fieldExtract)); + + if (this->fieldInfo.containsKey(structField)) + { + IRBuilder builder(module); + return this->fieldInfo[structField]; + } + + return none(); + } + IRInst* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { auto key = inst->getRequirementKey(); @@ -1105,9 +1206,30 @@ struct DynamicInstLoweringContext newParamTypes.getBuffer(), (IRType*)substituteSets(funcType->getResultType())); } + else if (auto typeInfo = tryGetInfo(context, inst->getDataType())) + { + // There's one other case we'd like to handle, where the func-type itself is a + // dynamic IRSpecialize. In this situation, we'd want to use the type inst's info to + // find the collection-based specialization and create a func-type from it. + // + if (auto tag = as(typeInfo)) + { + SLANG_ASSERT(getCollectionCount(tag) == 1); + auto specializeInst = as(getCollectionElement(tag, 0)); + auto funcType = as(specializeGeneric(specializeInst)); + if (!funcType) + { + SLANG_UNEXPECTED( + "Unexpected IRSpecialize in analyzeSpecialize for func type"); + return none(); + } + typeOfSpecialization = funcType; + } + } else { - SLANG_ASSERT_FAILURE("Unexpected data type for specialization instruction"); + // We don't have a type we can work with just yet. + return none(); // No info for the type } IRCollectionBase* collection = nullptr; @@ -1296,6 +1418,47 @@ struct DynamicInstLoweringContext maybeUpdatePtr(context, getElementPtr->getBase(), baseInfo, workQueue); } } + else if (auto fieldAddress = as(inst)) + { + // If this is a field address, update the fieldInfos map. + if (auto thisPtrInfo = as(info)) + { + auto thisValueType = thisPtrInfo->getValueType(); + IRBuilder builder(module); + auto baseStructPtrType = as(fieldAddress->getBase()->getDataType()); + auto baseStructType = as(baseStructPtrType->getValueType()); + if (!baseStructType) + return; // Do nothing.. + + if (auto fieldKey = as(fieldAddress->getField())) + { + IRStructField* foundField = findStructField(baseStructType, fieldKey); + IRInst* existingInfo = nullptr; + this->fieldInfo.tryGetValue(foundField, existingInfo); + + if (existingInfo) + existingInfo = builder.getPtrTypeWithAddressSpace( + (IRType*)existingInfo, + as(fieldAddress->getDataType())); + + if (auto newInfo = unionPropagationInfo(info, existingInfo)) + { + if (newInfo != existingInfo) + { + auto newInfoValType = cast(newInfo)->getValueType(); + + // Update the field info map + this->fieldInfo[foundField] = newInfoValType; + + // Add a work item to update the field extract + if (this->fieldUseSites.containsKey(foundField)) + for (auto useSite : this->fieldUseSites[foundField]) + workQueue.addLast(WorkItem(useSite.context, useSite.inst)); + } + } + } + } + } else if (auto var = as(inst)) { updateInfo(context, var, info, workQueue); @@ -1701,7 +1864,30 @@ struct DynamicInstLoweringContext UIndex paramIndex = 0; for (auto inst : instsToLower) + { hasChanges |= lowerInst(context, inst); + } + + return hasChanges; + } + + bool lowerStructType(IRStructType* structType) + { + bool hasChanges = false; + for (auto field : structType->getFields()) + { + IRInst* info = nullptr; + this->fieldInfo.tryGetValue(field, info); + if (!info) + continue; + + auto loweredFieldType = getLoweredType(info); + if (loweredFieldType != field->getDataType()) + { + hasChanges = true; + field->setFieldType(loweredFieldType); + } + } return hasChanges; } @@ -1780,25 +1966,26 @@ struct DynamicInstLoweringContext bool performDynamicInstLowering() { List funcsToProcess; + List structsToProcess; for (auto globalInst : module->getGlobalInsts()) + { if (auto func = as(globalInst)) - { funcsToProcess.add(func); - } + else if (auto structType = as(globalInst)) + structsToProcess.add(structType); + } bool hasChanges = false; - do - { - while (funcsToProcess.getCount() > 0) - { - auto func = funcsToProcess.getLast(); - funcsToProcess.removeLast(); - // Lower the instructions in the function - hasChanges |= lowerFunc(func); - } - } while (funcsToProcess.getCount() > 0); + // Lower struct types first so that data access can be + // marshalled properly during func lowering. + // + for (auto structType : structsToProcess) + hasChanges |= lowerStructType(structType); + + for (auto func : funcsToProcess) + hasChanges |= lowerFunc(func); return hasChanges; } @@ -1884,8 +2071,32 @@ struct DynamicInstLoweringContext bool replaceType(IRInst* context, IRInst* inst) { + if (as(inst->getParent())) + { + if (as(inst) || as(inst) || as(inst) || + as(inst)) + { + // Don't replace global concrete vals. + return false; + } + } + if (auto info = tryGetInfo(context, inst)) { + /* // Special cast for type collections. + if (auto collectionTagType = as(info)) + { + if (as(collectionTagType->getOperand(0))) + { + // Remove the tag and replace the inst itself with the type + // in the collection. + // + inst->replaceUsesWith(getLoweredType(collectionTagType->getOperand(0))); + inst->removeAndDeallocate(); + return true; + } + }*/ + if (auto loweredType = getLoweredType(info)) { if (loweredType == inst->getDataType()) @@ -1915,15 +2126,21 @@ struct DynamicInstLoweringContext return lowerCall(context, as(inst)); case kIROp_MakeExistential: return lowerMakeExistential(context, as(inst)); + case kIROp_MakeStruct: + return lowerMakeStruct(context, as(inst)); case kIROp_CreateExistentialObject: return lowerCreateExistentialObject(context, as(inst)); case kIROp_RWStructuredBufferLoad: case kIROp_StructuredBufferLoad: return lowerStructuredBufferLoad(context, inst); + case kIROp_Specialize: + return lowerSpecialize(context, as(inst)); case kIROp_Load: return lowerLoad(context, inst); case kIROp_Store: return lowerStore(context, as(inst)); + case kIROp_GetSequentialID: + return lowerGetSequentialID(context, as(inst)); default: { if (auto info = tryGetInfo(context, inst)) @@ -2529,14 +2746,14 @@ struct DynamicInstLoweringContext case kParameterDirection_InOut: { auto argValueType = as(arg->getDataType())->getValueType(); - if (argValueType != paramType) + /*if (argValueType != paramType) { SLANG_UNEXPECTED("ptr-typed parameters should have matching types"); } else - { - newArgs.add(arg); - } + {*/ + newArgs.add(arg); + //} break; } default: @@ -2548,12 +2765,17 @@ struct DynamicInstLoweringContext builder.setInsertBefore(inst); bool changed = false; - for (UInt i = 0; i < newArgs.getCount(); i++) + if (newArgs.getCount() != inst->getArgCount()) + changed = true; + else { - if (newArgs[i] != inst->getArg(i)) + for (UInt i = 0; i < newArgs.getCount(); i++) { - changed = true; - break; + if (newArgs[i] != inst->getArg(i)) + { + changed = true; + break; + } } } @@ -2585,6 +2807,30 @@ struct DynamicInstLoweringContext } } + bool lowerMakeStruct(IRInst* context, IRMakeStruct* inst) + { + auto structType = as(inst->getDataType()); + + // Reinterpret any of the arguments as necessary. + bool changed = false; + UIndex operandIndex = 0; + for (auto field : structType->getFields()) + { + auto arg = inst->getOperand(operandIndex); + auto newArg = upcastCollection(context, arg, field->getFieldType()); + + if (arg != newArg) + { + changed = true; + inst->setOperand(operandIndex, newArg); + } + + operandIndex++; + } + + return changed; + } + bool lowerMakeExistential(IRInst* context, IRMakeExistential* inst) { auto info = tryGetInfo(context, inst); @@ -2724,6 +2970,72 @@ struct DynamicInstLoweringContext } } + bool lowerSpecialize(IRInst* context, IRSpecialize* inst) + { + auto returnVal = getGenericReturnVal(inst->getBase()); + + // Functions should be handled at the call site (in lowerCall) + // since witness table specialization arguments must be inlined into the call. + // + if (as(returnVal)) + { + // TODO: Maybe make this the 'default' behavior if a lowering call + // returns false. + // + if (auto info = tryGetInfo(context, inst)) + return replaceType(context, inst); + else + return false; + } + + // For all other specializations, we'll 'drop' the dyanamic tag information. + bool changed = false; + List args; + for (UIndex i = 0; i < inst->getArgCount(); i++) + { + auto arg = inst->getArg(i); + auto argDataType = arg->getDataType(); + if (auto collectionTagType = as(argDataType)) + { + // If this is a tag type, replace with collection. + changed = true; + args.add(collectionTagType->getOperand(0)); + } + else + { + args.add(arg); + } + } + + /*IRType* typeForSpecialization = nullptr; + if (auto info = tryGetInfo(context, inst)) + { + changed = true; + typeForSpecialization = getLoweredType(info); + } + else + { + typeForSpecialization = inst->getDataType(); + }*/ + IRBuilder builder(inst); + IRType* typeForSpecialization = builder.getTypeKind(); + + if (changed) + { + auto newInst = builder.emitSpecializeInst( + typeForSpecialization, + inst->getBase(), + args.getCount(), + args.getBuffer()); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + + return false; + } + bool lowerLoad(IRInst* context, IRInst* inst) { auto valInfo = tryGetInfo(context, inst); @@ -2766,6 +3078,23 @@ struct DynamicInstLoweringContext return false; } + bool handleDefaultStore(IRInst* context, IRStore* inst) + { + SLANG_ASSERT(inst->getVal()->getOp() == kIROp_DefaultConstruct); + auto ptr = inst->getPtr(); + auto destInfo = as(ptr->getDataType())->getValueType(); + auto valInfo = inst->getVal()->getDataType(); + + // "Legalize" the store type. + if (destInfo != valInfo) + { + inst->getVal()->setFullType(destInfo); + return true; + } + else + return false; + } + bool lowerStore(IRInst* context, IRStore* inst) { auto ptr = inst->getPtr(); @@ -2773,6 +3102,15 @@ struct DynamicInstLoweringContext auto valInfo = inst->getVal()->getDataType(); + // Special case for default initialization: + // + // Raw default initialization has been almost entirely + // removed from Slang, but the auto-diff process can sometimes + // produce a store of default-constructed value. + // + if (auto defaultConstruct = as(inst->getVal())) + return handleDefaultStore(context, inst); + auto loweredVal = upcastCollection(context, inst->getVal(), ptrInfo); if (loweredVal != inst->getVal()) @@ -2786,15 +3124,41 @@ struct DynamicInstLoweringContext return false; } - UInt getUniqueID(IRInst* funcOrTable) + bool lowerGetSequentialID(IRInst* context, IRGetSequentialID* inst) + { + auto arg = inst->getOperand(0); + if (auto tagType = as(arg->getDataType())) + { + IRBuilder builder(inst); + setInsertAfterOrdinaryInst(&builder, inst); + auto firstElement = + getCollectionElement(as(tagType->getOperand(0)), 0); + auto interfaceType = + as(as(firstElement)->getConformanceType()); + List args = {interfaceType, arg}; + auto newInst = builder.emitIntrinsicInst( + (IRType*)builder.getUIntType(), + kIROp_GetSequentialIDFromTag, + args.getCount(), + args.getBuffer()); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + + return false; + } + + UInt getUniqueID(IRInst* inst) { - auto existingId = uniqueIds.tryGetValue(funcOrTable); + auto existingId = uniqueIds.tryGetValue(inst); if (existingId) return *existingId; - UInt newId = nextUniqueId++; - uniqueIds[funcOrTable] = newId; - return newId; + // If we reach here, this instruction wasn't assigned an ID during initialization + // This should only happen for instructions that don't support collection types + SLANG_UNEXPECTED("getUniqueID called on instruction without pre-assigned ID"); } bool isExistentialType(IRType* type) { return as(type) != nullptr; } @@ -2862,6 +3226,22 @@ struct DynamicInstLoweringContext DynamicInstLoweringContext(IRModule* module, DiagnosticSink* sink) : module(module), sink(sink) { + initializeUniqueIDs(); + } + + // Initialize unique IDs for all global instructions that can be part of collections + void initializeUniqueIDs() + { + UInt currentID = 1; + for (auto inst : module->getGlobalInsts()) + { + // Only assign IDs to instructions that can be part of collections + IROp collectionType = getCollectionTypeForInst(inst); + if (collectionType != kIROp_Invalid) + { + uniqueIds[inst] = currentID++; + } + } } // Basic context @@ -2874,9 +3254,15 @@ struct DynamicInstLoweringContext // Mapping from function to return value propagation information Dictionary funcReturnInfo; + // Mapping from struct fields to propagation information + Dictionary fieldInfo; + // Mapping from functions to call-sites. Dictionary> funcCallSites; + // Mapping from fields to use-sites. + Dictionary> fieldUseSites; + // Unique ID assignment for functions and witness tables Dictionary uniqueIds; UInt nextUniqueId = 1; @@ -3293,6 +3679,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase : InstPassBase(module) { } + void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) { auto srcInterfaceType = cast(inst->getOperand(0)); @@ -3331,11 +3718,54 @@ struct SequentialIDTagLoweringContext : public InstPassBase inst->removeAndDeallocate(); } + + void lowerGetSequentialIDFromTag(IRGetSequentialIDFromTag* inst) + { + auto srcInterfaceType = cast(inst->getOperand(0)); + auto srcTagInst = inst->getOperand(1); + + Dictionary mapping; + + // Map from sequential ID to unique ID + auto destCollection = cast( + cast(srcTagInst->getDataType())->getOperand(0)); + + UIndex dstSeqID = 0; + forEachInCollection( + destCollection, + [&](IRInst* table) + { + // Get unique ID for the witness table + auto witnessTable = cast(table); + auto outputId = dstSeqID++; + auto seqDecoration = table->findDecoration(); + if (seqDecoration) + { + auto inputId = seqDecoration->getSequentialID(); + mapping.add({outputId, inputId}); + } + }); + + IRBuilder builder(inst); + builder.setInsertAfter(inst); + auto translatedID = builder.emitCallInst( + inst->getDataType(), + createIntegerMappingFunc(builder.getModule(), mapping), + List({srcTagInst})); + + inst->replaceUsesWith(translatedID); + inst->removeAndDeallocate(); + } + void processModule() { processInstsOfType( kIROp_GetTagFromSequentialID, [&](IRGetTagFromSequentialID* inst) { return lowerGetTagFromSequentialID(inst); }); + + processInstsOfType( + kIROp_GetSequentialIDFromTag, + [&](IRGetSequentialIDFromTag* inst) { return lowerGetSequentialIDFromTag(inst); }); } }; diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 8d654f31faa..9e62f45756a 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1142,7 +1142,7 @@ struct SpecializationContext if (iterChanged) { this->changed = true; - eliminateDeadCode(module->getModuleInst()); + // eliminateDeadCode(module->getModuleInst()); applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); } @@ -1160,7 +1160,24 @@ struct SpecializationContext { iterChanged = lowerDynamicInsts(module, sink); if (iterChanged) + { + // We'll write out the specialization info to an inst, + // and read it back again so we can remove entries + // for specializations that are no longer needed. + // + // If we don't do this, we'll end up with deallocated + // references in the specialization dictionaries, and + // can't reliably handle situations where the same specialization + // is requested again in the future once a different function + // has been specialized. + // + writeSpecializationDictionaries(); + genericSpecializations.clear(); + existentialSpecializedFuncs.clear(); + existentialSpecializedStructs.clear(); eliminateDeadCode(module->getModuleInst()); + readSpecializationDictionaries(); + } } if (!iterChanged || sink->getErrorCount()) @@ -3063,6 +3080,12 @@ static bool isDynamicGeneric(IRInst* callee) // if (auto specialize = as(callee)) { + auto generic = as(specialize->getBase()); + + // Only functions need dynamic-aware specialization. + if (getGenericReturnVal(generic)->getOp() != kIROp_Func) + return false; + for (UInt i = 0; i < specialize->getArgCount(); i++) { auto arg = specialize->getArg(i); @@ -3096,9 +3119,16 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) auto loweredFunc = builder.createFunc(); builder.setInsertInto(loweredFunc); builder.setInsertInto(builder.emitBlock()); - // loweredFunc->setFullType(context->getFullType()); IRCloneEnv cloneEnv; + cloneEnv.squashChildrenMapping = true; + + IRCloneEnv staticCloningEnv; + // Use this as the child to 'override' certain elements in the parent environment with + // their static versions. + // + staticCloningEnv.parent = &cloneEnv; + Index argIndex = 0; List extraParamTypes; // Map the generic's parameters to the specialized arguments. @@ -3115,11 +3145,17 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) } else if (as(param->getDataType())) { - // Add an integer param to the func. + // For cloning parameter types, we want to just use the + // collection. + // + staticCloningEnv.mapOldValToNew[param] = collection; + + // We'll create an integer parameter for all the rest of + // the insts which will may need the runtime tag. + // auto tagType = (IRType*)makeTagType(collection); cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); extraParamTypes.add(tagType); - // extraIndices++; } } else @@ -3146,10 +3182,7 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) // Merge the first block of the generic with the first block of the // returned function to merge the parameter lists. // - // if (block != funcFirstBlock) - //{ cloneEnv.mapOldValToNew[block] = cloneInstAndOperands(&cloneEnv, &builder, block); - //} } builder.setInsertInto(loweredFunc->getFirstBlock()); @@ -3159,7 +3192,8 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) { // Clone the parameters of the first block. builder.setInsertAfter(loweredFunc->getFirstBlock()->getLastParam()); - cloneInst(&cloneEnv, &builder, param); + auto newParam = cloneInst(&staticCloningEnv, &builder, param); + cloneEnv.mapOldValToNew[param] = newParam; // Transfer the param to the dynamic env } builder.setInsertInto(as(cloneEnv.mapOldValToNew[funcFirstBlock])); @@ -3174,6 +3208,7 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) { if (block == funcFirstBlock) continue; // Already cloned the first block + cloneInstDecorationsAndChildren( &cloneEnv, builder.getModule(), @@ -3199,37 +3234,22 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) } else if (!as(inst)) { - // Keep cloning insts in the generic + // Clone insts in the generic under two different environments: + // One that is "static" (for cloning types), and one that is "dynamic" + // which uses tags for witness tables instead of a static collection. + // + // We'll want to use the dynamic environment for cloning everything in the function + // body, but the static environment for cloning parameter types. + // + // Note that this can result in some types inside the function (i.e. not used in the + // parameter types) being cloned under the dynamic environment, but + // a subsequent pass of dynamic-inst-lowering will convert those to the static form. + // + cloneInst(&staticCloningEnv, &builder, inst); cloneInst(&cloneEnv, &builder, inst); } } - /* - // Transfer propagation info. - for (auto& [oldVal, newVal] : cloneEnv.mapOldValToNew) - { - if (propagationMap.containsKey(Element(context, oldVal))) - { - // If we have propagation info for the old value, transfer it to the new value - if (auto info = propagationMap[Element(context, oldVal)]) - { - if (newVal->getParent()->getOp() != kIROp_ModuleInst) - propagationMap[Element(loweredFunc, newVal)] = info; - } - } - } - - // Transfer func-return value info. - if (this->funcReturnInfo.containsKey(context)) - { - this->funcReturnInfo[loweredFunc] = this->funcReturnInfo[context]; - } - - - context->replaceUsesWith(loweredFunc); - */ - // context->removeAndDeallocate(); - // this->loweredContexts[context] = loweredFunc; return loweredFunc; } diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 9b852b80314..a6eed86cdc0 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -175,8 +175,29 @@ IRInst* specializeWithGeneric( { genArgs.add(param); } + + // Default to type kind for now. + IRType* typeForSpecialization = builder.getTypeKind(); + + auto dataType = genericToSpecialize->getDataType(); + if (dataType) + { + if (dataType->getOp() == kIROp_TypeKind || dataType->getOp() == kIROp_GenericKind) + { + typeForSpecialization = (genericToSpecialize)->getDataType(); + } + else if (dataType->getOp() == kIROp_Generic) + { + typeForSpecialization = (IRType*)builder.emitSpecializeInst( + builder.getTypeKind(), + (genericToSpecialize)->getDataType(), + genArgs.getCount(), + genArgs.getBuffer()); + } + } + return builder.emitSpecializeInst( - builder.getTypeKind(), + typeForSpecialization, genericToSpecialize, (UInt)genArgs.getCount(), genArgs.getBuffer()); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 55c175a83cf..33bc7c1c1ae 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3330,8 +3330,8 @@ void collectParameterLists( auto thisType = getThisParamTypeForContainer(context, parentDeclRef); if (thisType) { - thisType = as( - thisType->substitute(getCurrentASTBuilder(), SubstitutionSet(declRef))); + /*thisType = as( + thisType->substitute(getCurrentASTBuilder(), SubstitutionSet(declRef)));*/ if (declRef.getDecl()->findModifier()) { auto noDiffAttr = context->astBuilder->getNoDiffModifierVal(); From 1d9977709bb6ffb0b090606fc0aa1bc340878da4 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 8 Aug 2025 16:56:52 -0400 Subject: [PATCH 025/105] More test fixes (COM interfaces are properly left alone now) --- 2.hlsl | 262 +++++++++++++++++ source/slang/slang-ir-lower-dynamic-insts.cpp | 263 ++++++++++++------ source/slang/slang-ir-specialize.cpp | 6 + tests/autodiff/custom-intrinsic-1.slang | 69 +++-- 4 files changed, 486 insertions(+), 114 deletions(-) create mode 100644 2.hlsl diff --git a/2.hlsl b/2.hlsl new file mode 100644 index 00000000000..c4f1b11161c --- /dev/null +++ b/2.hlsl @@ -0,0 +1,262 @@ +#pragma pack_matrix(row_major) + +#line 18 "tests/compute/dynamic-dispatch-17.slang" +struct UserDefinedPackedType_0 +{ + float3 val_0; + uint flags_0; +}; + + +#line 28 +RWStructuredBuffer gObj_0 : register(u1); + + +#line 25 +RWStructuredBuffer gOutputBuffer_0 : register(u0); + + +#line 25 +struct AnyValue16 +{ + uint field0_0; + uint field1_0; + uint field2_0; + uint field3_0; +}; + + +#line 25 +AnyValue16 packAnyValue16_0(UserDefinedPackedType_0 _S1) +{ + +#line 25 + AnyValue16 _S2; + +#line 25 + _S2.field0_0 = 0U; + +#line 25 + _S2.field1_0 = 0U; + +#line 25 + _S2.field2_0 = 0U; + +#line 25 + _S2.field3_0 = 0U; + +#line 25 + _S2.field0_0 = (uint)(asuint(_S1.val_0[int(0)])); + +#line 25 + _S2.field1_0 = (uint)(asuint(_S1.val_0[int(1)])); + +#line 25 + _S2.field2_0 = (uint)(asuint(_S1.val_0[int(2)])); + +#line 25 + _S2.field3_0 = _S1.flags_0; + +#line 25 + return _S2; +} + + +#line 39 +uint _S3(uint _S4) +{ + +#line 39 + switch(_S4) + { + case 3U: + { + +#line 39 + return 0U; + } + case 4U: + { + +#line 39 + return 1U; + } + default: + { + +#line 39 + return 0U; + } + } + +#line 39 +} + + +#line 50 +struct FloatVal_0 +{ + float val_1; +}; + + +#line 48 +float ReturnsZero_get_0() +{ + +#line 48 + return 0.0f; +} + + + +float FloatVal_run_0(FloatVal_0 this_0) +{ + + float _S5 = ReturnsZero_get_0(); + +#line 56 + return this_0.val_1 + _S5; +} + + +#line 56 +float U_S4main8FloatVal3rung2TCGP04main12IReturnsZerop0pfG0_ST4main11ReturnsZero1_SW4main11ReturnsZero4main12IReturnsZero_wtwrapper_0(AnyValue16 _S6) +{ + +#line 56 + float _S7 = FloatVal_run_0(_S6); + +#line 56 + return _S7; +} + +struct Float4Struct_0 +{ + float4 val_2; +}; + + +#line 60 +struct Float4Val_0 +{ + Float4Struct_0 val_3; +}; + + +#line 63 +float Float4Val_run_0(Float4Val_0 this_1) +{ + + float _S8 = this_1.val_3.val_2.x + this_1.val_3.val_2.y; + +#line 66 + float _S9 = ReturnsZero_get_0(); + +#line 66 + return _S8 + _S9; +} + + +#line 66 +float U_S4main9Float4Val3rung2TCGP04main12IReturnsZerop0pfG0_ST4main11ReturnsZero1_SW4main11ReturnsZero4main12IReturnsZero_wtwrapper_0(AnyValue16 _S10) +{ + +#line 66 + float _S11 = Float4Val_run_0(_S10); + +#line 66 + return _S11; +} + + +#line 66 +float _S12(uint _S13, AnyValue16 _S14) +{ + +#line 66 + switch(_S13) + { + case 0U: + { + +#line 66 + float _S15 = U_S4main8FloatVal3rung2TCGP04main12IReturnsZerop0pfG0_ST4main11ReturnsZero1_SW4main11ReturnsZero4main12IReturnsZero_wtwrapper_0(_S14); + +#line 66 + return _S15; + } + case 1U: + { + +#line 66 + float _S16 = U_S4main9Float4Val3rung2TCGP04main12IReturnsZerop0pfG0_ST4main11ReturnsZero1_SW4main11ReturnsZero4main12IReturnsZero_wtwrapper_0(_S14); + +#line 66 + return _S16; + } + default: + { + +#line 66 + return 0.0f; + } + } + +#line 66 +} + + +#line 34 +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID_0 : SV_DispatchThreadID) +{ + int i_0 = int(0); + +#line 37 + float result_0 = 0.0f; + +#line 37 + for(;;) + { + +#line 37 + if(i_0 < int(2)) + { + } + else + { + +#line 37 + break; + } + UserDefinedPackedType_0 rawObj_0 = gObj_0.Load(i_0); + uint _S17 = _S3(rawObj_0.flags_0); + +#line 40 + uint _S18[int(2)] = { 0U, 1U }; + +#line 40 + AnyValue16 _S19 = packAnyValue16_0(rawObj_0); + float _S20 = _S12(_S18[_S17], _S19); + +#line 41 + float result_1 = result_0 + _S20; + +#line 37 + int i_1 = i_0 + int(1); + +#line 37 + i_0 = i_1; + +#line 37 + result_0 = result_1; + +#line 37 + } + +#line 43 + gOutputBuffer_0[int(0)] = result_0; + return; +} + diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 4f200d91441..dc7e5e82a62 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -239,6 +239,33 @@ static IRInst* findEntryInConcreteTable(IRInst* witnessTable, IRInst* key) return nullptr; // Not found } +struct WorkQueue +{ + List enqueueList; + List dequeueList; + UIndex dequeueIndex = 0; + + void enqueue(const WorkItem& item) { enqueueList.add(item); } + + WorkItem dequeue() + { + if (dequeueList.getCount() <= dequeueIndex) + { + dequeueList.swapWith(enqueueList); + enqueueList.clear(); + dequeueIndex = 0; + } + + SLANG_ASSERT(dequeueIndex < dequeueList.getCount()); + return dequeueList[dequeueIndex++]; + } + + bool hasItems() const + { + return (dequeueIndex < dequeueList.getCount()) || (enqueueList.getCount() > 0); + } +}; + struct DynamicInstLoweringContext { // Helper methods for creating canonical collections @@ -449,7 +476,7 @@ struct DynamicInstLoweringContext // Centralized method to update propagation info and manage work queue // Use this when you want to propagate new information to an existing instruction // This will union the new info with existing info and add users to work queue if changed - void updateInfo(IRInst* context, IRInst* inst, IRInst* newInfo, LinkedList& workQueue) + void updateInfo(IRInst* context, IRInst* inst, IRInst* newInfo, WorkQueue& workQueue) { auto existingInfo = tryGetInfo(context, inst); auto unionedInfo = unionPropagationInfo(existingInfo, newInfo); @@ -467,11 +494,7 @@ struct DynamicInstLoweringContext // Helper to add users of an instruction to the work queue based on how they use it // This handles intra-procedural edges, inter-procedural edges, and return value propagation - void addUsersToWorkQueue( - IRInst* context, - IRInst* inst, - IRInst* info, - LinkedList& workQueue) + void addUsersToWorkQueue(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) { for (auto use = inst->firstUse; use; use = use->nextUse) { @@ -480,7 +503,7 @@ struct DynamicInstLoweringContext // If user is in a different block (or the inst is a param), add that block to work // queue. // - workQueue.addLast(WorkItem(context, user)); + workQueue.enqueue(WorkItem(context, user)); // If user is a terminator, add intra-procedural edges if (auto terminator = as(user)) @@ -492,7 +515,7 @@ struct DynamicInstLoweringContext for (auto succIter = successors.begin(); succIter != successors.end(); ++succIter) { - workQueue.addLast(WorkItem(context, succIter.getEdge())); + workQueue.enqueue(WorkItem(context, succIter.getEdge())); } } } @@ -515,7 +538,7 @@ struct DynamicInstLoweringContext if (this->funcCallSites.containsKey(context)) for (auto callSite : this->funcCallSites[context]) { - workQueue.addLast(WorkItem( + workQueue.enqueue(WorkItem( InterproceduralEdge::Direction::FuncToCall, callSite.context, as(callSite.inst), @@ -527,7 +550,7 @@ struct DynamicInstLoweringContext } // Helper method to update function return info and propagate to call sites - void updateFuncReturnInfo(IRInst* callable, IRInst* returnInfo, LinkedList& workQueue) + void updateFuncReturnInfo(IRInst* callable, IRInst* returnInfo, WorkQueue& workQueue) { auto existingReturnInfo = getFuncReturnInfo(callable); auto newReturnInfo = unionPropagationInfo(existingReturnInfo, returnInfo); @@ -541,7 +564,7 @@ struct DynamicInstLoweringContext { for (auto callSite : funcCallSites[callable]) { - workQueue.addLast(WorkItem( + workQueue.enqueue(WorkItem( InterproceduralEdge::Direction::FuncToCall, callSite.context, as(callSite.inst), @@ -551,7 +574,7 @@ struct DynamicInstLoweringContext } } - void processBlock(IRInst* context, IRBlock* block, LinkedList& workQueue) + void processBlock(IRInst* context, IRBlock* block, WorkQueue& workQueue) { for (auto inst : block->getChildren()) { @@ -571,7 +594,7 @@ struct DynamicInstLoweringContext void performInformationPropagation() { // Global worklist for interprocedural analysis - LinkedList workQueue; + WorkQueue workQueue; // Add all global functions to worklist for (auto inst : module->getGlobalInsts()) @@ -579,11 +602,9 @@ struct DynamicInstLoweringContext discoverContext(func, workQueue); // Process until fixed point - while (workQueue.getCount() > 0) + while (workQueue.hasItems()) { - // Pop work item from front - auto item = workQueue.getFirst(); - workQueue.getFirstNode()->removeAndDelete(); + auto item = workQueue.dequeue(); switch (item.type) { @@ -724,10 +745,14 @@ struct DynamicInstLoweringContext getTagOperands.getCount(), getTagOperands.getBuffer()); + auto effectiveValType = getCollectionCount(typeCollection) > 1 + ? typeCollection + : getCollectionElement(typeCollection, 0); + return builder.emitMakeTuple( {tableTag, builder.emitReinterpret( - (IRType*)typeCollection, + (IRType*)effectiveValType, builder.emitGetTupleElement( (IRType*)loweredInterfaceType->getOperand(0), arg, @@ -737,7 +762,7 @@ struct DynamicInstLoweringContext return arg; // Can use as-is. } - void processInstForPropagation(IRInst* context, IRInst* inst, LinkedList& workQueue) + void processInstForPropagation(IRInst* context, IRInst* inst, WorkQueue& workQueue) { IRInst* info; @@ -829,6 +854,12 @@ struct DynamicInstLoweringContext auto value = inst->getWrappedValue(); auto valueType = value->getDataType(); + // If we're building an existential for a COM pointer, + // we won't try to lower that. + // + if (isComInterfaceType(inst->getDataType())) + return makeUnbounded(); + // Get the witness table info auto witnessTableInfo = tryGetInfo(context, witnessTable); @@ -847,10 +878,7 @@ struct DynamicInstLoweringContext SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } - IRInst* analyzeMakeStruct( - IRInst* context, - IRMakeStruct* makeStruct, - LinkedList& workQueue) + IRInst* analyzeMakeStruct(IRInst* context, IRMakeStruct* makeStruct, WorkQueue& workQueue) { // We'll process this in the same way as a field-address, but for // all fields of the struct. @@ -872,7 +900,7 @@ struct DynamicInstLoweringContext if (this->fieldUseSites.containsKey(field)) for (auto useSite : this->fieldUseSites[field]) - workQueue.addLast(WorkItem(useSite.context, useSite.inst)); + workQueue.enqueue(WorkItem(useSite.context, useSite.inst)); } } @@ -884,8 +912,20 @@ struct DynamicInstLoweringContext bool isResourcePointer(IRInst* inst) { - return isPointerToResourceType(inst->getDataType()) || - inst->getOp() == kIROp_RWStructuredBufferGetElementPtr; + if (isPointerToResourceType(inst->getDataType()) || + inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) + return true; + + if (as(inst)) + return true; + + if (auto ptr = as(inst)) + return isResourcePointer(ptr->getBase()); + + if (auto fieldAddress = as(inst)) + return isResourcePointer(fieldAddress->getBase()); + + return false; } IRInst* analyzeLoad(IRInst* context, IRInst* inst) @@ -944,7 +984,7 @@ struct DynamicInstLoweringContext return none(); // No info for other load types } - IRInst* analyzeStore(IRInst* context, IRStore* storeInst, LinkedList& workQueue) + IRInst* analyzeStore(IRInst* context, IRStore* storeInst, WorkQueue& workQueue) { // Transfer the prop info from stored value to the address auto address = storeInst->getPtr(); @@ -990,21 +1030,22 @@ struct DynamicInstLoweringContext if (auto basePtrType = as(basePtr->getDataType())) { - auto structType = as(basePtrType->getValueType()); - SLANG_ASSERT(structType); - auto structField = - findStructField(structType, as(fieldAddress->getField())); + if (auto structType = as(basePtrType->getValueType())) + { + auto structField = + findStructField(structType, as(fieldAddress->getField())); - // Register this as a user of the field so updates will invoke this function again. - this->fieldUseSites.addIfNotExists(structField, HashSet()); - this->fieldUseSites[structField].add(Element(context, fieldAddress)); + // Register this as a user of the field so updates will invoke this function again. + this->fieldUseSites.addIfNotExists(structField, HashSet()); + this->fieldUseSites[structField].add(Element(context, fieldAddress)); - if (this->fieldInfo.containsKey(structField)) - { - IRBuilder builder(module); - return builder.getPtrTypeWithAddressSpace( - (IRType*)this->fieldInfo[structField], - as(fieldAddress->getDataType())); + if (this->fieldInfo.containsKey(structField)) + { + IRBuilder builder(module); + return builder.getPtrTypeWithAddressSpace( + (IRType*)this->fieldInfo[structField], + as(fieldAddress->getDataType())); + } } } return none(); @@ -1014,20 +1055,21 @@ struct DynamicInstLoweringContext { IRBuilder builder(module); - auto structType = as(fieldExtract->getBase()->getDataType()); - SLANG_ASSERT(structType); - auto structField = findStructField(structType, as(fieldExtract->getField())); + if (auto structType = as(fieldExtract->getBase()->getDataType())) + { + auto structField = + findStructField(structType, as(fieldExtract->getField())); - // Register this as a user of the field so updates will invoke this function again. - this->fieldUseSites.addIfNotExists(structField, HashSet()); - this->fieldUseSites[structField].add(Element(context, fieldExtract)); + // Register this as a user of the field so updates will invoke this function again. + this->fieldUseSites.addIfNotExists(structField, HashSet()); + this->fieldUseSites[structField].add(Element(context, fieldExtract)); - if (this->fieldInfo.containsKey(structField)) - { - IRBuilder builder(module); - return this->fieldInfo[structField]; + if (this->fieldInfo.containsKey(structField)) + { + IRBuilder builder(module); + return this->fieldInfo[structField]; + } } - return none(); } @@ -1265,7 +1307,7 @@ struct DynamicInstLoweringContext SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); } - void discoverContext(IRInst* context, LinkedList& workQueue) + void discoverContext(IRInst* context, WorkQueue& workQueue) { if (this->availableContexts.add(context)) { @@ -1283,7 +1325,7 @@ struct DynamicInstLoweringContext // Add all blocks to the work queue for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) - workQueue.addLast(WorkItem(context, block)); + workQueue.enqueue(WorkItem(context, block)); break; } case kIROp_Specialize: @@ -1328,16 +1370,16 @@ struct DynamicInstLoweringContext // Add all blocks to the work queue for an initial sweep for (auto block = generic->getFirstBlock(); block; block = block->getNextBlock()) - workQueue.addLast(WorkItem(context, block)); + workQueue.enqueue(WorkItem(context, block)); for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) - workQueue.addLast(WorkItem(context, block)); + workQueue.enqueue(WorkItem(context, block)); } } } } - IRInst* analyzeCall(IRInst* context, IRCall* inst, LinkedList& workQueue) + IRInst* analyzeCall(IRInst* context, IRCall* inst, WorkQueue& workQueue) { auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(context, callee); @@ -1368,10 +1410,10 @@ struct DynamicInstLoweringContext { // If this is a new call site, add a propagation task to the queue (in case there's // already information about this function) - workQueue.addLast( + workQueue.enqueue( WorkItem(InterproceduralEdge::Direction::FuncToCall, context, inst, callee)); } - workQueue.addLast( + workQueue.enqueue( WorkItem(InterproceduralEdge::Direction::CallToFunc, context, inst, callee)); }; @@ -1392,11 +1434,7 @@ struct DynamicInstLoweringContext return none(); } - void maybeUpdatePtr( - IRInst* context, - IRInst* inst, - IRInst* info, - LinkedList& workQueue) + void maybeUpdatePtr(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) { if (auto getElementPtr = as(inst)) { @@ -1453,7 +1491,7 @@ struct DynamicInstLoweringContext // Add a work item to update the field extract if (this->fieldUseSites.containsKey(foundField)) for (auto useSite : this->fieldUseSites[foundField]) - workQueue.addLast(WorkItem(useSite.context, useSite.inst)); + workQueue.enqueue(WorkItem(useSite.context, useSite.inst)); } } } @@ -1470,7 +1508,7 @@ struct DynamicInstLoweringContext } } - void propagateWithinFuncEdge(IRInst* context, IREdge edge, LinkedList& workQueue) + void propagateWithinFuncEdge(IRInst* context, IREdge edge, WorkQueue& workQueue) { // Handle intra-procedural edge (original logic) auto predecessorBlock = edge.getPredecessor(); @@ -1509,6 +1547,24 @@ struct DynamicInstLoweringContext bool isGlobalInst(IRInst* inst) { return inst->getParent()->getOp() == kIROp_ModuleInst; } + bool isIntrinsic(IRInst* inst) + { + auto func = as(inst); + if (auto specialize = as(inst)) + { + auto generic = specialize->getBase(); + func = as(getGenericReturnVal(generic)); + } + + if (!func) + return false; + + if (func->getFirstBlock() == nullptr) + return true; + + return false; + } + List getParamEffectiveTypes(IRInst* context) { List effectiveTypes; @@ -1545,7 +1601,8 @@ struct DynamicInstLoweringContext effectiveTypes.add((IRType*)newType); else effectiveTypes.add( - (IRType*)as(context->getDataType())->getParamType(idx)); + //(IRType*)as(context->getDataType())->getParamType(idx) + param->getDataType()); idx++; } @@ -1606,7 +1663,7 @@ struct DynamicInstLoweringContext return directions; } - void propagateInterproceduralEdge(InterproceduralEdge edge, LinkedList& workQueue) + void propagateInterproceduralEdge(InterproceduralEdge edge, WorkQueue& workQueue) { // Handle interprocedural edge auto callInst = edge.callInst; @@ -2135,6 +2192,8 @@ struct DynamicInstLoweringContext return lowerStructuredBufferLoad(context, inst); case kIROp_Specialize: return lowerSpecialize(context, as(inst)); + case kIROp_GetValueFromBoundInterface: + return lowerGetValueFromBoundInterface(context, as(inst)); case kIROp_Load: return lowerLoad(context, inst); case kIROp_Store: @@ -2669,7 +2728,7 @@ struct DynamicInstLoweringContext // If by this point, we haven't resolved our callee into a global inst ( // either a collection or a single function), then we can't lower it (likely unbounded) // - if (!isGlobalInst(callee)) + if (!isGlobalInst(callee) || isIntrinsic(callee)) return false; auto expectedFuncType = getEffectiveFuncType(callee); @@ -2744,16 +2803,10 @@ struct DynamicInstLoweringContext break; case kParameterDirection_Out: case kParameterDirection_InOut: + case kParameterDirection_ConstRef: + case kParameterDirection_Ref: { - auto argValueType = as(arg->getDataType())->getValueType(); - /*if (argValueType != paramType) - { - SLANG_UNEXPECTED("ptr-typed parameters should have matching types"); - } - else - {*/ newArgs.add(arg); - //} break; } default: @@ -2972,12 +3025,22 @@ struct DynamicInstLoweringContext bool lowerSpecialize(IRInst* context, IRSpecialize* inst) { - auto returnVal = getGenericReturnVal(inst->getBase()); + bool isFuncReturn = false; + + // TODO: Would checking this inst's info be enough instead? + // This seems long-winded. + if (auto concreteGeneric = as(inst->getBase())) + isFuncReturn = as(getGenericReturnVal(concreteGeneric)) != nullptr; + else if (auto tagType = as(inst->getBase()->getDataType())) + { + auto firstConcreteGeneric = as(getCollectionElement(tagType, 0)); + isFuncReturn = as(getGenericReturnVal(firstConcreteGeneric)) != nullptr; + } - // Functions should be handled at the call site (in lowerCall) + // Functions/Collections of Functions should be handled at the call site (in lowerCall) // since witness table specialization arguments must be inlined into the call. // - if (as(returnVal)) + if (isFuncReturn) { // TODO: Maybe make this the 'default' behavior if a lowering call // returns false. @@ -3007,16 +3070,6 @@ struct DynamicInstLoweringContext } } - /*IRType* typeForSpecialization = nullptr; - if (auto info = tryGetInfo(context, inst)) - { - changed = true; - typeForSpecialization = getLoweredType(info); - } - else - { - typeForSpecialization = inst->getDataType(); - }*/ IRBuilder builder(inst); IRType* typeForSpecialization = builder.getTypeKind(); @@ -3036,6 +3089,25 @@ struct DynamicInstLoweringContext return false; } + + bool lowerGetValueFromBoundInterface(IRInst* context, IRGetValueFromBoundInterface* inst) + { + auto destType = inst->getDataType(); + auto operandInfo = inst->getOperand(0)->getDataType(); + if (auto taggedUnionTupleType = as(operandInfo)) + { + SLANG_ASSERT(taggedUnionTupleType->getOperand(1) == destType); + + IRBuilder builder(inst); + setInsertAfterOrdinaryInst(&builder, inst); + auto newInst = builder.emitGetTupleElement((IRType*)destType, inst->getOperand(0), 1); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + return false; + } + bool lowerLoad(IRInst* context, IRInst* inst) { auto valInfo = tryGetInfo(context, inst); @@ -3156,9 +3228,20 @@ struct DynamicInstLoweringContext if (existingId) return *existingId; + // If we reach here, the instruction was not assigned an ID during initialization. + // This can happen for instructions that are generated during the dynamic analysis + // process. + // + // We will ensure that they are moved to the end of the module, and assign them a new ID. + // This will ensure a stable ordering on subsequent passes. + // + inst->moveToEnd(); + uniqueIds[inst] = nextUniqueId; + return nextUniqueId++; + // If we reach here, this instruction wasn't assigned an ID during initialization // This should only happen for instructions that don't support collection types - SLANG_UNEXPECTED("getUniqueID called on instruction without pre-assigned ID"); + // SLANG_UNEXPECTED("getUniqueID called on instruction without pre-assigned ID"); } bool isExistentialType(IRType* type) { return as(type) != nullptr; } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 9e62f45756a..e615bacf24c 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1142,6 +1142,12 @@ struct SpecializationContext if (iterChanged) { this->changed = true; + writeSpecializationDictionaries(); + genericSpecializations.clear(); + existentialSpecializedFuncs.clear(); + existentialSpecializedStructs.clear(); + eliminateDeadCode(module->getModuleInst()); + readSpecializationDictionaries(); // eliminateDeadCode(module->getModuleInst()); applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); } diff --git a/tests/autodiff/custom-intrinsic-1.slang b/tests/autodiff/custom-intrinsic-1.slang index e2ad6010b7a..f9616ca08bd 100644 --- a/tests/autodiff/custom-intrinsic-1.slang +++ b/tests/autodiff/custom-intrinsic-1.slang @@ -12,15 +12,22 @@ typealias IDFloat = __BuiltinFloatingPointType & IDifferentiable; namespace myintrinsiclib { __generic - __target_intrinsic(hlsl, "exp($0)") - __target_intrinsic(glsl, "exp($0)") - __target_intrinsic(cuda, "$P_exp($0)") - __target_intrinsic(cpp, "$P_exp($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 27 _0") - __target_intrinsic(metal, "exp($0)") - __target_intrinsic(wgsl, "exp($0)") [ForwardDerivative(d_myexp)] - T myexp(T x); + T myexp(T x) + { + __target_switch + { + case cpp: __intrinsic_asm "$P_exp($0)"; + case cuda: __intrinsic_asm "$P_exp($0)"; + case glsl: __intrinsic_asm "exp"; + case hlsl: __intrinsic_asm "exp"; + case metal: __intrinsic_asm "exp"; + case spirv: return spirv_asm { + OpExtInst $$T result glsl450 27 $x + }; + case wgsl: __intrinsic_asm "exp"; + } + } __generic DifferentialPair d_myexp(DifferentialPair dpx) @@ -33,15 +40,22 @@ namespace myintrinsiclib // Sine __generic - __target_intrinsic(hlsl, "sin($0)") - __target_intrinsic(glsl, "sin($0)") - __target_intrinsic(metal, "sin($0)") - __target_intrinsic(cuda, "$P_sin($0)") - __target_intrinsic(cpp, "$P_sin($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 13 _0") - __target_intrinsic(wgsl, "sin($0)") [ForwardDerivative(d_mysin)] - T mysin(T x); + T mysin(T x) + { + __target_switch + { + case hlsl: __intrinsic_asm "sin($0)"; + case glsl: __intrinsic_asm "sin($0)"; + case metal: __intrinsic_asm "sin($0)"; + case cuda: __intrinsic_asm "$P_sin($0)"; + case cpp: __intrinsic_asm "$P_sin($0)"; + case spirv: return spirv_asm { + OpExtInst $$T result glsl450 13 $x + }; + case wgsl: __intrinsic_asm "sin($0)"; + } + } __generic DifferentialPair d_mysin(DifferentialPair dpx) @@ -53,15 +67,22 @@ namespace myintrinsiclib // Cosine __generic - __target_intrinsic(hlsl, "cos($0)") - __target_intrinsic(glsl, "cos($0)") - __target_intrinsic(metal, "cos($0)") - __target_intrinsic(cuda, "$P_cos($0)") - __target_intrinsic(cpp, "$P_cos($0)") - __target_intrinsic(spirv, "12 resultType resultId glsl450 14 _0") - __target_intrinsic(wgsl, "cos($0)") [ForwardDerivative(d_mycos)] - T mycos(T x); + T mycos(T x) + { + __target_switch + { + case hlsl: __intrinsic_asm "cos($0)"; + case glsl: __intrinsic_asm "cos($0)"; + case metal: __intrinsic_asm "cos($0)"; + case cuda: __intrinsic_asm "$P_cos($0)"; + case cpp: __intrinsic_asm "$P_cos($0)"; + case spirv: return spirv_asm { + OpExtInst $$T result glsl450 14 $x + }; + case wgsl: __intrinsic_asm "cos($0)"; + } + } __generic DifferentialPair d_mycos(DifferentialPair dpx) From b579ea2d56b810444abe51a672651db8edc7cb67 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 12:53:14 -0400 Subject: [PATCH 026/105] Fix up for one more test --- source/slang/slang-ir-lower-dynamic-insts.cpp | 37 +++++++++++++++---- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index dc7e5e82a62..9d81e4d97e6 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -1257,16 +1257,21 @@ struct DynamicInstLoweringContext if (auto tag = as(typeInfo)) { SLANG_ASSERT(getCollectionCount(tag) == 1); - auto specializeInst = as(getCollectionElement(tag, 0)); - auto funcType = as(specializeGeneric(specializeInst)); - if (!funcType) - { - SLANG_UNEXPECTED( - "Unexpected IRSpecialize in analyzeSpecialize for func type"); - return none(); - } + auto specializeInst = cast(getCollectionElement(tag, 0)); + auto funcType = cast(specializeGeneric(specializeInst)); typeOfSpecialization = funcType; } + else if (auto collection = as(typeInfo)) + { + SLANG_ASSERT(getCollectionCount(collection) == 1); + auto specializeInst = cast(getCollectionElement(collection, 0)); + auto funcType = cast(specializeGeneric(specializeInst)); + typeOfSpecialization = funcType; + } + else + { + return none(); + } } else { @@ -2701,11 +2706,27 @@ struct DynamicInstLoweringContext return true; } + void maybeSpecializeCalleeType(IRInst* callee) + { + if (auto specializeInst = as(callee->getDataType())) + { + if (isGlobalInst(specializeInst)) + callee->setFullType((IRType*)specializeGeneric(specializeInst)); + } + } + bool lowerCall(IRInst* context, IRCall* inst) { auto callee = inst->getCallee(); IRInst* calleeTagInst = nullptr; + // This is a bit of a workaround for specialized callee's + // whose function types haven't been specialized yet (can + // occur for concrete IRSpecialize insts that are created + // during the lowering process). + // + maybeSpecializeCalleeType(callee); + // If we're calling using a tag, place a call to the collection, // with the tag as the first argument. So the callee is // the collection itself. From 11a1fee289809914982b5ac884cd4c515b46f141 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 13:11:35 -0400 Subject: [PATCH 027/105] Remove commented out code. --- source/slang/slang-ir-lower-dynamic-insts.cpp | 89 ++----------------- 1 file changed, 7 insertions(+), 82 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 9d81e4d97e6..9335877922b 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -474,8 +474,10 @@ struct DynamicInstLoweringContext } // Centralized method to update propagation info and manage work queue - // Use this when you want to propagate new information to an existing instruction + // + // Use this when you want to propagate new information to an existing instruction. // This will union the new info with existing info and add users to work queue if changed + // void updateInfo(IRInst* context, IRInst* inst, IRInst* newInfo, WorkQueue& workQueue) { auto existingInfo = tryGetInfo(context, inst); @@ -824,11 +826,6 @@ struct DynamicInstLoweringContext IRInst* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { - // - // TODO: Actually use the integer<->type map present in the linkage to - // extract a set of possible witness tables (if the index is a compile-time constant). - // - if (auto interfaceType = as(inst->getDataType())) { if (isComInterfaceType(interfaceType) || isBuiltin(interfaceType)) @@ -1593,21 +1590,10 @@ struct DynamicInstLoweringContext UIndex idx = 0; for (auto param : func->getParams()) { - /*if (auto newType = tryGetInfo(context, param)) - effectiveTypes.add((IRType*)newType); - else - { - const auto [direction, type] = getParameterDirectionAndType( - as(context->getDataType())->getParamType(idx)); - SLANG_ASSERT(isGlobalInst(type)); - effectiveTypes.add((IRType*)type); - }*/ if (auto newType = tryGetInfo(context, param)) effectiveTypes.add((IRType*)newType); else - effectiveTypes.add( - //(IRType*)as(context->getDataType())->getParamType(idx) - param->getDataType()); + effectiveTypes.add(param->getDataType()); idx++; } @@ -2145,20 +2131,6 @@ struct DynamicInstLoweringContext if (auto info = tryGetInfo(context, inst)) { - /* // Special cast for type collections. - if (auto collectionTagType = as(info)) - { - if (as(collectionTagType->getOperand(0))) - { - // Remove the tag and replace the inst itself with the type - // in the collection. - // - inst->replaceUsesWith(getLoweredType(collectionTagType->getOperand(0))); - inst->removeAndDeallocate(); - return true; - } - }*/ - if (auto loweredType = getLoweredType(info)) { if (loweredType == inst->getDataType()) @@ -2301,7 +2273,6 @@ struct DynamicInstLoweringContext auto operand = inst->getOperand(0); auto element = builder.emitGetTupleElement((IRType*)collectionTagType, operand, 0); inst->replaceUsesWith(element); - // propagationMap[Element(context, element)] = info; inst->removeAndDeallocate(); return true; } @@ -2324,26 +2295,6 @@ struct DynamicInstLoweringContext } return false; - /* - auto operandInfo = tryGetInfo(context, inst->getOperand(0)); - auto taggedUnion = as(operandInfo); - if (!taggedUnion) - return false; - - auto info = tryGetInfo(context, inst); - auto typeCollection = as(info); - if (!typeCollection) - return false; - - IRBuilder builder(inst); - builder.setInsertBefore(inst); - - // Replace with GetElement(loweredInst, 1) : TypeCollection - auto operand = inst->getOperand(0); - auto element = builder.emitGetTupleElement((IRType*)info, operand, 1); - inst->replaceUsesWith(element); - inst->removeAndDeallocate(); - return true;*/ } bool lowerExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) @@ -2534,19 +2485,9 @@ struct DynamicInstLoweringContext { auto paramEffectiveTypes = getParamEffectiveTypes(context); auto paramDirections = getParamDirections(context); + for (UInt i = 0; i < paramEffectiveTypes.getCount(); i++) - { updateParamType(i, getLoweredType(paramEffectiveTypes[i])); - /*if (auto collectionType = as(paramEffectiveTypes[i])) - - else if (paramEffectiveTypes[i] != nullptr) - updateParamType( - i, - fromDirectionAndType( - &builder, - paramDirections[i], - (IRType*)paramEffectiveTypes[i]));*/ - } auto returnType = getFuncReturnInfo(context); if (auto newResultType = getLoweredType(returnType)) @@ -2565,7 +2506,7 @@ struct DynamicInstLoweringContext } // - // Add in extra parameter types for a call to the callee. + // Add in extra parameter types for a call to a non-concrete callee. // List extraParamTypes; @@ -2576,7 +2517,6 @@ struct DynamicInstLoweringContext // as the first parameter. if (getCollectionCount(funcCollection) > 1) extraParamTypes.add((IRType*)makeTagType(funcCollection)); - // extraParamTypes.add((IRType*)makeTagType(funcCollection)); } // If the any of the elements in the callee (or the callee itself in case @@ -2695,7 +2635,6 @@ struct DynamicInstLoweringContext callArgs.add(inst->getArg(ii)); IRBuilder builder(inst->getModule()); - // builder.replaceOperand(inst->getCalleeUse(), specializedCallee); builder.setInsertBefore(inst); auto newCallInst = builder.emitCallInst( as(targetContext->getDataType())->getResultType(), @@ -3250,8 +3189,7 @@ struct DynamicInstLoweringContext return *existingId; // If we reach here, the instruction was not assigned an ID during initialization. - // This can happen for instructions that are generated during the dynamic analysis - // process. + // This can happen for instructions that are generated during the analysis. // // We will ensure that they are moved to the end of the module, and assign them a new ID. // This will ensure a stable ordering on subsequent passes. @@ -3259,10 +3197,6 @@ struct DynamicInstLoweringContext inst->moveToEnd(); uniqueIds[inst] = nextUniqueId; return nextUniqueId++; - - // If we reach here, this instruction wasn't assigned an ID during initialization - // This should only happen for instructions that don't support collection types - // SLANG_UNEXPECTED("getUniqueID called on instruction without pre-assigned ID"); } bool isExistentialType(IRType* type) { return as(type) != nullptr; } @@ -3312,18 +3246,9 @@ struct DynamicInstLoweringContext // Phase 1: Information Propagation performInformationPropagation(); - // Phase 1.5: Insert reinterprets for points where sets merge - // e.g. phi, return, call - // - // hasChanges |= insertReinterprets(); - // Phase 2: Dynamic Instruction Lowering hasChanges |= performDynamicInstLowering(); - // Phase 3: Lower collection types. - // if (hasChanges) - // lowerTypeCollections(); - return hasChanges; } From 8f1c28b199d57deeed28d2e61efd88cfedb901a2 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 13:37:13 -0400 Subject: [PATCH 028/105] Remove unused function --- source/slang/slang-ir-lower-dynamic-insts.cpp | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 9335877922b..bdc5ede0c07 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -128,8 +128,8 @@ struct WorkItem enum class Type { None, // Invalid - Inst, // Propagate through a single instruction - Block, // Propagate information within a block + Inst, // Propagate through a single instruction. + Block, // Propagate through each instruction in a block. IntraProc, // Propagate through within-function edge (IREdge) InterProc // Propagate across function call/return (InterproceduralEdge) }; @@ -408,14 +408,6 @@ struct DynamicInstLoweringContext IRTypeFlowData* none() { return nullptr; } - // Helper to convert collection to HashSet - HashSet collectionToHashSet(IRCollectionBase* info) - { - HashSet result; - forEachInCollection(info, [&](IRInst* element) { result.add(element); }); - return result; - } - IRInst* tryGetInfo(Element element) { // For non-global instructions, look up in the map @@ -1501,11 +1493,12 @@ struct DynamicInstLoweringContext } else if (auto var = as(inst)) { + // If we hit a local var, we'll update it's info. updateInfo(context, var, info, workQueue); } else { - // Do nothing.. + // If we hit something unsupported, assume no information. return; } } From e3081fe5002f0463dd158422fef11bae8c2d203b Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 14:33:26 -0400 Subject: [PATCH 029/105] Fix formatting and warnings --- source/slang/slang-emit.cpp | 9 +- source/slang/slang-ir-lower-dynamic-insts.cpp | 118 ++++++++---------- source/slang/slang-ir-specialize.cpp | 5 +- 3 files changed, 61 insertions(+), 71 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 223632db4be..aeeeb8f8df8 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -2206,11 +2206,10 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr& outAr if (sourceMap) { - auto sourceMapArtifact = ArtifactUtil::createArtifact( - ArtifactDesc::make( - ArtifactKind::Json, - ArtifactPayload::SourceMap, - ArtifactStyle::None)); + auto sourceMapArtifact = ArtifactUtil::createArtifact(ArtifactDesc::make( + ArtifactKind::Json, + ArtifactPayload::SourceMap, + ArtifactStyle::None)); sourceMapArtifact->addRepresentation(sourceMap); diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index bdc5ede0c07..f6149309044 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -34,7 +34,7 @@ struct Element validateElement(); } - bool validateElement() const + void validateElement() const { switch (context->getOp()) { @@ -243,7 +243,7 @@ struct WorkQueue { List enqueueList; List dequeueList; - UIndex dequeueIndex = 0; + Index dequeueIndex = 0; void enqueue(const WorkItem& item) { enqueueList.add(item); } @@ -515,7 +515,7 @@ struct DynamicInstLoweringContext } // If user is a return instruction, handle function return propagation - if (auto returnInst = as(user)) + if (as(user)) { updateFuncReturnInfo(context, info, workQueue); } @@ -818,6 +818,7 @@ struct DynamicInstLoweringContext IRInst* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { + SLANG_UNUSED(context); if (auto interfaceType = as(inst->getDataType())) { if (isComInterfaceType(interfaceType) || isBuiltin(interfaceType)) @@ -840,8 +841,6 @@ struct DynamicInstLoweringContext IRInst* analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) { auto witnessTable = inst->getWitnessTable(); - auto value = inst->getWrappedValue(); - auto valueType = value->getDataType(); // If we're building an existential for a COM pointer, // we won't try to lower that. @@ -877,11 +876,11 @@ struct DynamicInstLoweringContext for (auto field : structType->getFields()) { auto operand = makeStruct->getOperand(operandIndex); - if (auto fieldInfo = tryGetInfo(context, operand)) + if (auto operandInfo = tryGetInfo(context, operand)) { IRInst* existingInfo = nullptr; this->fieldInfo.tryGetValue(field, existingInfo); - auto newInfo = unionPropagationInfo(existingInfo, fieldInfo); + auto newInfo = unionPropagationInfo(existingInfo, operandInfo); if (newInfo && !areInfosEqual(existingInfo, newInfo)) { // Update the field info map @@ -930,9 +929,8 @@ struct DynamicInstLoweringContext { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeExistential( - as( - createCollection(kIROp_TableCollection, tables))); + return makeExistential(as( + createCollection(kIROp_TableCollection, tables))); else return none(); } @@ -1030,7 +1028,6 @@ struct DynamicInstLoweringContext if (this->fieldInfo.containsKey(structField)) { - IRBuilder builder(module); return builder.getPtrTypeWithAddressSpace( (IRType*)this->fieldInfo[structField], as(fieldAddress->getDataType())); @@ -1055,7 +1052,6 @@ struct DynamicInstLoweringContext if (this->fieldInfo.containsKey(structField)) { - IRBuilder builder(module); return this->fieldInfo[structField]; } } @@ -1136,6 +1132,8 @@ struct DynamicInstLoweringContext if (auto taggedUnion = as(operandInfo)) return cast(taggedUnion->getOperand(0)); + + return none(); } IRInst* analyzeSpecialize(IRInst* context, IRSpecialize* inst) @@ -1162,7 +1160,7 @@ struct DynamicInstLoweringContext bool needsTag = false; List specializationArgs; - for (auto i = 0; i < inst->getArgCount(); ++i) + for (UInt i = 0; i < inst->getArgCount(); ++i) { // For concrete args, add as-is. if (isGlobalInst(inst->getArg(i))) @@ -1247,15 +1245,15 @@ struct DynamicInstLoweringContext { SLANG_ASSERT(getCollectionCount(tag) == 1); auto specializeInst = cast(getCollectionElement(tag, 0)); - auto funcType = cast(specializeGeneric(specializeInst)); - typeOfSpecialization = funcType; + auto specializedFuncType = cast(specializeGeneric(specializeInst)); + typeOfSpecialization = specializedFuncType; } else if (auto collection = as(typeInfo)) { SLANG_ASSERT(getCollectionCount(collection) == 1); auto specializeInst = cast(getCollectionElement(collection, 0)); - auto funcType = cast(specializeGeneric(specializeInst)); - typeOfSpecialization = funcType; + auto specializedFuncType = cast(specializeGeneric(specializeInst)); + typeOfSpecialization = specializedFuncType; } else { @@ -1332,7 +1330,7 @@ struct DynamicInstLoweringContext // first generic block. // IRParam* param = generic->getFirstBlock()->getFirstParam(); - for (auto index = 0; index < specialize->getArgCount() && param; + for (UInt index = 0; index < specialize->getArgCount() && param; ++index, param = param->getNextParam()) { // Map the specialization argument to the corresponding parameter @@ -1378,8 +1376,6 @@ struct DynamicInstLoweringContext auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(context, callee); - auto funcType = as(callee->getDataType()); - // // Propagate the input judgments to the call & append a work item // for inter-procedural propagation. @@ -1455,7 +1451,6 @@ struct DynamicInstLoweringContext // If this is a field address, update the fieldInfos map. if (auto thisPtrInfo = as(info)) { - auto thisValueType = thisPtrInfo->getValueType(); IRBuilder builder(module); auto baseStructPtrType = as(fieldAddress->getBase()->getDataType()); auto baseStructType = as(baseStructPtrType->getValueType()); @@ -1524,7 +1519,7 @@ struct DynamicInstLoweringContext return; // Collect propagation info for each argument and update corresponding parameter - Index paramIndex = 0; + UInt paramIndex = 0; for (auto param : successorBlock->getParams()) { if (paramIndex < unconditionalBranch->getArgCount()) @@ -1580,14 +1575,12 @@ struct DynamicInstLoweringContext SLANG_UNEXPECTED("Unexpected context type for parameter info retrieval"); } - UIndex idx = 0; for (auto param : func->getParams()) { if (auto newType = tryGetInfo(context, param)) effectiveTypes.add((IRType*)newType); else effectiveTypes.add(param->getDataType()); - idx++; } return effectiveTypes; @@ -1662,7 +1655,7 @@ struct DynamicInstLoweringContext if (!firstBlock) return; - Index argIndex = 1; // Skip callee (operand 0) + UInt argIndex = 1; // Skip callee (operand 0) for (auto param : firstBlock->getParams()) { if (argIndex < callInst->getOperandCount()) @@ -1844,18 +1837,16 @@ struct DynamicInstLoweringContext if (as(info1) && as(info2)) { - return makeExistential( - unionCollection( - cast(info1->getOperand(1)), - cast(info2->getOperand(1)))); + return makeExistential(unionCollection( + cast(info1->getOperand(1)), + cast(info2->getOperand(1)))); } if (as(info1) && as(info2)) { - return makeTagType( - unionCollection( - cast(info1->getOperand(0)), - cast(info2->getOperand(0)))); + return makeTagType(unionCollection( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); } if (as(info1) && as(info2)) @@ -1870,6 +1861,7 @@ struct DynamicInstLoweringContext IRTypeFlowData* analyzeDefault(IRInst* context, IRInst* inst) { + SLANG_UNUSED(context); // Check if this is a global concrete type, witness table, or function. // If so, it's a concrete element. We'll create a singleton set for it. if (isGlobalInst(inst) && @@ -1903,7 +1895,6 @@ struct DynamicInstLoweringContext for (auto inst : block->getChildren()) instsToLower.add(inst); - UIndex paramIndex = 0; for (auto inst : instsToLower) { hasChanges |= lowerInst(context, inst); @@ -2041,10 +2032,9 @@ struct DynamicInstLoweringContext auto tableCollection = cast(taggedUnion->getOperand(1)); if (getCollectionCount(typeCollection) == 1) - return builder.getTupleType( - List( - {(IRType*)makeTagType(tableCollection), - (IRType*)getCollectionElement(typeCollection, 0)})); + return builder.getTupleType(List( + {(IRType*)makeTagType(tableCollection), + (IRType*)getCollectionElement(typeCollection, 0)})); return builder.getTupleType( List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); @@ -2172,7 +2162,7 @@ struct DynamicInstLoweringContext return lowerGetSequentialID(context, as(inst)); default: { - if (auto info = tryGetInfo(context, inst)) + if (tryGetInfo(context, inst)) return replaceType(context, inst); return false; } @@ -2273,6 +2263,8 @@ struct DynamicInstLoweringContext bool lowerExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { + SLANG_UNUSED(context); + auto existential = inst->getOperand(0); auto existentialInfo = existential->getDataType(); if (isTaggedUnionType(existentialInfo)) @@ -2352,7 +2344,6 @@ struct DynamicInstLoweringContext return builder->getConstRefType(type); default: SLANG_UNEXPECTED("Unhandled parameter direction in fromDirectionAndType"); - return nullptr; } } @@ -2404,15 +2395,13 @@ struct DynamicInstLoweringContext { IRBuilder builder(module); // Merge the elements of both tagged unions into a new tuple type - return builder.getTupleType( - List( - {(IRType*)makeTagType( - as(updateType( - (IRType*)currentType->getOperand(0)->getOperand(0), - (IRType*)newType->getOperand(0)->getOperand(0)))), - (IRType*)updateType( - (IRType*)currentType->getOperand(1), - (IRType*)newType->getOperand(1))})); + return builder.getTupleType(List( + {(IRType*)makeTagType(as(updateType( + (IRType*)currentType->getOperand(0)->getOperand(0), + (IRType*)newType->getOperand(0)->getOperand(0)))), + (IRType*)updateType( + (IRType*)currentType->getOperand(1), + (IRType*)newType->getOperand(1))})); } else // Need to create a new collection. { @@ -2436,7 +2425,7 @@ struct DynamicInstLoweringContext List paramTypes; IRType* resultType = nullptr; - auto updateParamType = [&](UInt index, IRType* paramType) -> IRType* + auto updateParamType = [&](Index index, IRType* paramType) -> IRType* { if (paramTypes.getCount() <= index) { @@ -2479,7 +2468,7 @@ struct DynamicInstLoweringContext auto paramEffectiveTypes = getParamEffectiveTypes(context); auto paramDirections = getParamDirections(context); - for (UInt i = 0; i < paramEffectiveTypes.getCount(); i++) + for (Index i = 0; i < paramEffectiveTypes.getCount(); i++) updateParamType(i, getLoweredType(paramEffectiveTypes[i])); auto returnType = getFuncReturnInfo(context); @@ -2748,7 +2737,6 @@ struct DynamicInstLoweringContext const auto [paramDirection, paramType] = getParameterDirectionAndType(expectedFuncType->getParamType(i + extraArgCount)); - IRInst* newArg = nullptr; switch (paramDirection) { case kParameterDirection_In: @@ -2771,13 +2759,13 @@ struct DynamicInstLoweringContext builder.setInsertBefore(inst); bool changed = false; - if (newArgs.getCount() != inst->getArgCount()) + if (((UInt)newArgs.getCount()) != inst->getArgCount()) changed = true; else { - for (UInt i = 0; i < newArgs.getCount(); i++) + for (Index i = 0; i < newArgs.getCount(); i++) { - if (newArgs[i] != inst->getArg(i)) + if (newArgs[i] != inst->getArg((UInt)i)) { changed = true; break; @@ -3045,6 +3033,7 @@ struct DynamicInstLoweringContext bool lowerGetValueFromBoundInterface(IRInst* context, IRGetValueFromBoundInterface* inst) { + SLANG_UNUSED(context); auto destType = inst->getDataType(); auto operandInfo = inst->getOperand(0)->getDataType(); if (auto taggedUnionTupleType = as(operandInfo)) @@ -3105,6 +3094,7 @@ struct DynamicInstLoweringContext bool handleDefaultStore(IRInst* context, IRStore* inst) { + SLANG_UNUSED(context); SLANG_ASSERT(inst->getVal()->getOp() == kIROp_DefaultConstruct); auto ptr = inst->getPtr(); auto destInfo = as(ptr->getDataType())->getValueType(); @@ -3125,8 +3115,6 @@ struct DynamicInstLoweringContext auto ptr = inst->getPtr(); auto ptrInfo = as(ptr->getDataType())->getValueType(); - auto valInfo = inst->getVal()->getDataType(); - // Special case for default initialization: // // Raw default initialization has been almost entirely @@ -3151,6 +3139,7 @@ struct DynamicInstLoweringContext bool lowerGetSequentialID(IRInst* context, IRGetSequentialID* inst) { + SLANG_UNUSED(context); auto arg = inst->getOperand(0); if (auto tagType = as(arg->getDataType())) { @@ -3350,7 +3339,7 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) // Create parameters for the original function arguments List originalParams; - for (UInt i = 0; i < innerParamTypes.getCount(); i++) + for (Index i = 0; i < innerParamTypes.getCount(); i++) { originalParams.add(builder.emitParam(innerParamTypes[i])); } @@ -3391,7 +3380,7 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) List callArgs; auto wrappedFuncType = as(wrapperFunc->getDataType()); - for (UIndex ii = 0; ii < originalParams.getCount(); ii++) + for (Index ii = 0; ii < originalParams.getCount(); ii++) { callArgs.add(originalParams[ii]); } @@ -3691,6 +3680,7 @@ struct CollectionLoweringContext : public InstPassBase void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) { + SLANG_UNUSED(sink); CollectionLoweringContext context(module); context.processModule(); } @@ -3704,7 +3694,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) { - auto srcInterfaceType = cast(inst->getOperand(0)); + SLANG_UNUSED(cast(inst->getOperand(0))); auto srcSeqID = inst->getOperand(1); Dictionary mapping; @@ -3719,7 +3709,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase [&](IRInst* table) { // Get unique ID for the witness table - auto witnessTable = cast(table); + SLANG_UNUSED(cast(table)); auto outputId = dstSeqID++; auto seqDecoration = table->findDecoration(); if (seqDecoration) @@ -3743,7 +3733,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase void lowerGetSequentialIDFromTag(IRGetSequentialIDFromTag* inst) { - auto srcInterfaceType = cast(inst->getOperand(0)); + SLANG_UNUSED(cast(inst->getOperand(0))); auto srcTagInst = inst->getOperand(1); Dictionary mapping; @@ -3758,7 +3748,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase [&](IRInst* table) { // Get unique ID for the witness table - auto witnessTable = cast(table); + SLANG_UNUSED(cast(table)); auto outputId = dstSeqID++; auto seqDecoration = table->findDecoration(); if (seqDecoration) @@ -3793,12 +3783,14 @@ struct SequentialIDTagLoweringContext : public InstPassBase void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) { + SLANG_UNUSED(sink); SequentialIDTagLoweringContext context(module); context.processModule(); } void lowerTagInsts(IRModule* module, DiagnosticSink* sink) { + SLANG_UNUSED(sink); TagOpsLoweringContext tagContext(module); tagContext.processModule(); } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index e615bacf24c..7647b2d50b0 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3312,9 +3312,8 @@ IRInst* specializeGenericImpl( builder->setInsertBefore(genericVal); List pendingWorkList; - SLANG_DEFER( - for (Index ii = pendingWorkList.getCount() - 1; ii >= 0; ii--) if (context) - context->addToWorkList(pendingWorkList[ii]);); + SLANG_DEFER(for (Index ii = pendingWorkList.getCount() - 1; ii >= 0; ii--) if (context) + context->addToWorkList(pendingWorkList[ii]);); // Now we will run through the body of the generic and // clone each of its instructions into the global scope, From 5f3b841c352ff31abc04e98c7387b06f8da15fef Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 14:43:33 -0400 Subject: [PATCH 030/105] Update slang-ir-witness-table-wrapper.cpp --- .../slang/slang-ir-witness-table-wrapper.cpp | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index 5aed1a99bcb..a29c08f893b 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -141,31 +141,12 @@ static UCount getCollectionCount(IRCollectionBase* collection) return collection->getOperandCount(); } -static UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) -{ - auto typeCollection = taggedUnion->getOperand(0); - return getCollectionCount(as(typeCollection)); -} - static UCount getCollectionCount(IRCollectionTagType* tagType) { auto collection = tagType->getOperand(0); return getCollectionCount(as(collection)); } -static IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) -{ - if (!collection || index >= collection->getOperandCount()) - return nullptr; - return collection->getOperand(index); -} - -static IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) -{ - auto typeCollection = collectionTagType->getOperand(0); - return getCollectionElement(as(typeCollection), index); -} - static IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) { auto argInfo = arg->getDataType(); From b313c0a8fad4fa80fbd1c23c3f9c4a471b89f5f2 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 14:59:41 -0400 Subject: [PATCH 031/105] Update slang-ir-specialize.cpp --- source/slang/slang-ir-specialize.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 7647b2d50b0..64ef15414fd 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3203,11 +3203,11 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) } builder.setInsertInto(as(cloneEnv.mapOldValToNew[funcFirstBlock])); - for (auto inst = funcFirstBlock->getFirstOrdinaryInst(); inst; - inst = inst->getNextInst()) + for (auto _inst = funcFirstBlock->getFirstOrdinaryInst(); _inst; + _inst = _inst->getNextInst()) { // Clone the instructions in the first block. - cloneInst(&cloneEnv, &builder, inst); + cloneInst(&cloneEnv, &builder, _inst); } for (auto block : returnedFunc->getBlocks()) From aeaa4b41542ade8ecf3771b0499bbd820c66b4c8 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 15:14:58 -0400 Subject: [PATCH 032/105] Clean up, fix warnings and add a new inst for a cleaner way to handle casting interfaces to tagged unions --- source/slang/slang-ir-insts-stable-names.lua | 9 ++++---- source/slang/slang-ir-insts.lua | 23 ++++++++++++++----- source/slang/slang-ir-lower-dynamic-insts.cpp | 11 +++------ source/slang/slang-ir-specialize.cpp | 2 -- 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 73b89a94d90..bc6be8a3e82 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -676,8 +676,9 @@ return { ["TypeFlowData.UnboundedCollection"] = 672, ["TypeFlowData.CollectionTagType"] = 673, ["TypeFlowData.CollectionTaggedUnionType"] = 674, - ["GetTagForSuperCollection"] = 675, - ["GetTagForMappedCollection"] = 676, - ["GetTagFromSequentialID"] = 677, - ["GetSequentialIDFromTag"] = 678 + ["CastInterfaceToTaggedUnionPtr"] = 675, + ["GetTagForSuperCollection"] = 676, + ["GetTagForMappedCollection"] = 677, + ["GetTagFromSequentialID"] = 678, + ["GetSequentialIDFromTag"] = 679 } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 4015c594793..a5f22858fda 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2188,12 +2188,23 @@ local insts = { { CollectionTaggedUnionType = {}} -- Operand is TypeCollection, TableCollection for existential }, }, - { GetTagForSuperCollection = {} }, -- Translate a tag from a set to its equivalent in a super-set - { GetTagForMappedCollection = {} }, -- Translate a tag from a set to its equivalent in a different set - -- based on a mapping induced by a lookup key - { GetTagFromSequentialID = {} }, -- Translate an existing sequential ID & and interface type into a tag - -- the provided collection. - { GetSequentialIDFromTag = {} } -- Translate a tag from the given collection to a sequential ID. + { CastInterfaceToTaggedUnionPtr = { + -- Cast an interface-typed pointer to a tagged-union pointer with a known set. + } }, + { GetTagForSuperCollection = { + -- Translate a tag from a set to its equivalent in a super-set + } }, + { GetTagForMappedCollection = { + -- Translate a tag from a set to its equivalent in a different set + -- based on a mapping induced by a lookup key + } }, + { GetTagFromSequentialID = { + -- Translate an existing sequential ID (a 'global' ID) & and interface type into a tag + -- the provided collection (a 'local' ID) + } }, + { GetSequentialIDFromTag = { + -- Translate a tag from the given collection (a 'local' ID) to a sequential ID (a 'global' ID) + } } } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index f6149309044..e8e15f5f225 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -84,11 +84,6 @@ struct Element context = func; } - Element(const Element& other) - : context(other.context), inst(other.inst) - { - } - bool operator==(const Element& other) const { return context == other.context && inst == other.inst; @@ -2561,7 +2556,7 @@ struct DynamicInstLoweringContext List getArgsForDynamicSpecialization(IRSpecialize* specializedCallee) { List callArgs; - for (auto ii = 0; ii < specializedCallee->getArgCount(); ii++) + for (UInt ii = 0; ii < specializedCallee->getArgCount(); ii++) { auto specArg = specializedCallee->getArg(ii); auto argInfo = specArg->getDataType(); @@ -2587,7 +2582,7 @@ struct DynamicInstLoweringContext auto targetContext = getCollectionElement(calleeCollection, 0); List callArgs; - for (auto ii = 0; ii < specializedCallee->getArgCount(); ii++) + for (UInt ii = 0; ii < specializedCallee->getArgCount(); ii++) { auto specArg = specializedCallee->getArg(ii); auto argInfo = tryGetInfo(context, specArg); @@ -2613,7 +2608,7 @@ struct DynamicInstLoweringContext } } - for (auto ii = 0; ii < inst->getArgCount(); ii++) + for (UInt ii = 0; ii < inst->getArgCount(); ii++) callArgs.add(inst->getArg(ii)); IRBuilder builder(inst->getModule()); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 64ef15414fd..f35908f89e6 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3073,8 +3073,6 @@ void finalizeSpecialization(IRModule* module) break; } } - - // lowerCollectionAndTagInsts(module, nullptr); } From f554ff76078fb26c3d16bdfd8270dbb9d3c7f742 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 17:53:54 -0400 Subject: [PATCH 033/105] Clean up implementation of interface -> tagged-union casts --- source/slang/slang-emit.cpp | 3 + source/slang/slang-ir-insts-stable-names.lua | 9 +- source/slang/slang-ir-insts.lua | 3 + source/slang/slang-ir-lower-dynamic-insts.cpp | 400 ++++++++++++------ source/slang/slang-ir-lower-dynamic-insts.h | 2 + 5 files changed, 272 insertions(+), 145 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index aeeeb8f8df8..ec1ebe8e251 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1140,6 +1140,9 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink)); } + if (lowerTaggedUnionPtrCasts(irModule, sink)) + requiredLoweringPassSet.reinterpret = true; // TODO: Is this the right way to handle this? + lowerTagInsts(irModule, sink); lowerTypeCollections(irModule, sink); diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index bc6be8a3e82..09d776096cb 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -677,8 +677,9 @@ return { ["TypeFlowData.CollectionTagType"] = 673, ["TypeFlowData.CollectionTaggedUnionType"] = 674, ["CastInterfaceToTaggedUnionPtr"] = 675, - ["GetTagForSuperCollection"] = 676, - ["GetTagForMappedCollection"] = 677, - ["GetTagFromSequentialID"] = 678, - ["GetSequentialIDFromTag"] = 679 + ["CastTaggedUnionToInterfacePtr"] = 676, + ["GetTagForSuperCollection"] = 677, + ["GetTagForMappedCollection"] = 678, + ["GetTagFromSequentialID"] = 679, + ["GetSequentialIDFromTag"] = 680 } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index a5f22858fda..2b4579469b1 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2191,6 +2191,9 @@ local insts = { { CastInterfaceToTaggedUnionPtr = { -- Cast an interface-typed pointer to a tagged-union pointer with a known set. } }, + { CastTaggedUnionToInterfacePtr = { + -- Cast a tagged-union pointer with a known set to a corresponding interface-typed pointer. + } }, { GetTagForSuperCollection = { -- Translate a tag from a set to its equivalent in a super-set } }, diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index e8e15f5f225..60035a5d0b4 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -261,6 +261,51 @@ struct WorkQueue } }; + +// TODO: Move to utilities + +IRCollectionTagType* makeTagType(IRCollectionBase* collection) +{ + IRInst* collectionInst = collection; + // Create the tag type from the collection + IRBuilder builder(collection->getModule()); + return as( + builder.emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collectionInst)); +} + +UCount getCollectionCount(IRCollectionBase* collection) +{ + if (!collection) + return 0; + return collection->getOperandCount(); +} + +UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) +{ + auto typeCollection = taggedUnion->getOperand(0); + return getCollectionCount(as(typeCollection)); +} + +UCount getCollectionCount(IRCollectionTagType* tagType) +{ + auto collection = tagType->getOperand(0); + return getCollectionCount(as(collection)); +} + +IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) +{ + if (!collection || index >= collection->getOperandCount()) + return nullptr; + return collection->getOperand(index); +} + +IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) +{ + auto typeCollection = collectionTagType->getOperand(0); + return getCollectionElement(as(typeCollection), index); +} + + struct DynamicInstLoweringContext { // Helper methods for creating canonical collections @@ -353,47 +398,6 @@ struct DynamicInstLoweringContext elements.getBuffer())); } - IRCollectionTagType* makeTagType(IRCollectionBase* collection) - { - IRInst* collectionInst = collection; - // Create the tag type from the collection - IRBuilder builder(module); - return as( - builder.emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collectionInst)); - } - - UCount getCollectionCount(IRCollectionBase* collection) - { - if (!collection) - return 0; - return collection->getOperandCount(); - } - - UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) - { - auto typeCollection = taggedUnion->getOperand(0); - return getCollectionCount(as(typeCollection)); - } - - UCount getCollectionCount(IRCollectionTagType* tagType) - { - auto collection = tagType->getOperand(0); - return getCollectionCount(as(collection)); - } - - IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) - { - if (!collection || index >= collection->getOperandCount()) - return nullptr; - return collection->getOperand(index); - } - - IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) - { - auto typeCollection = collectionTagType->getOperand(0); - return getCollectionElement(as(typeCollection), index); - } - IRUnboundedCollection* makeUnbounded() { IRBuilder builder(module); @@ -616,32 +620,6 @@ struct DynamicInstLoweringContext } } - IRIntegerValue getInterfaceAnyValueSize(IRInst* type) - { - if (auto decor = type->findDecoration()) - { - return decor->getSize(); - } - - // We could conceivably make it an error to have an interface - // without an `[anyValueSize(...)]` attribute, but then we risk - // producing error messages even when doing 100% static specialization. - // - // It is simpler to use a reasonable default size and treat any - // type without an explicit attribute as using that size. - // - return kDefaultAnyValueSize; - } - - IRType* lowerInterfaceType(IRInterfaceType* interfaceType) - { - IRBuilder builder(module); - auto anyValueType = builder.getAnyValueType(getInterfaceAnyValueSize(interfaceType)); - auto witnessTableType = builder.getWitnessTableIDType((IRType*)interfaceType); - auto rttiType = builder.getRTTIHandleType(); - return builder.getTupleType({rttiType, witnessTableType, anyValueType}); - } - IRInst* upcastCollection(IRInst* context, IRInst* arg, IRType* destInfo) { auto argInfo = arg->getDataType(); @@ -714,39 +692,6 @@ struct DynamicInstLoweringContext builder.setInsertAfter(arg); return builder.emitPackAnyValue((IRType*)destInfo, arg); } - else if (as(argInfo) && as(destInfo)) - { - auto loweredInterfaceType = lowerInterfaceType(as(argInfo)); - IRBuilder builder(module); - builder.setInsertAfter(arg); - auto witnessTable = - builder.emitGetTupleElement(builder.getWitnessTableIDType(argInfo), arg, 1); - auto tableID = builder.emitGetSequentialIDInst(witnessTable); - auto tableCollection = cast(destInfo->getOperand(1)); - auto typeCollection = cast(destInfo->getOperand(0)); - - List getTagOperands; - getTagOperands.add(argInfo); - getTagOperands.add(tableID); - auto tableTag = builder.emitIntrinsicInst( - (IRType*)makeTagType(tableCollection), - kIROp_GetTagFromSequentialID, - getTagOperands.getCount(), - getTagOperands.getBuffer()); - - auto effectiveValType = getCollectionCount(typeCollection) > 1 - ? typeCollection - : getCollectionElement(typeCollection, 0); - - return builder.emitMakeTuple( - {tableTag, - builder.emitReinterpret( - (IRType*)effectiveValType, - builder.emitGetTupleElement( - (IRType*)loweredInterfaceType->getOperand(0), - arg, - 2))}); - } return arg; // Can use as-is. } @@ -2925,26 +2870,39 @@ struct DynamicInstLoweringContext auto bufferType = (IRType*)inst->getOperand(0)->getDataType(); auto bufferBaseType = (IRType*)bufferType->getOperand(0); - if (bufferBaseType != (IRType*)getLoweredType(valInfo)) + auto loweredValType = (IRType*)getLoweredType(valInfo); + if (bufferBaseType != loweredValType) { - IRBuilder builder(inst); - builder.setInsertAfter(inst); - - IRCloneEnv cloneEnv; - auto newLoad = cloneInst(&cloneEnv, &builder, inst); - - auto loweredVal = upcastCollection(context, newLoad, (IRType*)valInfo); - - // TODO: this is a hack. Encode this in the type-flow-data. - if (as(bufferBaseType) && !isComInterfaceType(inst->getDataType()) && - !isBuiltin(inst->getDataType())) + if (as(bufferBaseType)) { - newLoad->setFullType(lowerInterfaceType(as(bufferBaseType))); - } + // If we're dealing with a loading a known tagged union value from + // an interface-typed pointer, we'll cast the pointer itself and + // defer the lowering of the load until later. + // + // This avoids having to change the source pointer type + // and confusing any future runs of the type flow + // analysis pass. + // + IRBuilder builder(inst); + builder.setInsertAfter(inst); + auto bufferHandle = inst->getOperand(0); + auto newHandle = builder.emitIntrinsicInst( + builder.getPtrType(loweredValType), + kIROp_CastInterfaceToTaggedUnionPtr, + 1, + &bufferHandle); + List newLoadOperands = {newHandle, inst->getOperand(1)}; + auto newLoad = builder.emitIntrinsicInst( + loweredValType, + inst->getOp(), + newLoadOperands.getCount(), + newLoadOperands.getBuffer()); + + inst->replaceUsesWith(newLoad); + inst->removeAndDeallocate(); - inst->replaceUsesWith(loweredVal); - inst->removeAndDeallocate(); - return true; + return true; + } } else if (inst->getDataType() != bufferBaseType) { @@ -2952,11 +2910,8 @@ struct DynamicInstLoweringContext inst->setFullType((IRType*)getLoweredType(valInfo)); return true; } - else - { - // No change needed. - return false; - } + + return false; } bool lowerSpecialize(IRInst* context, IRSpecialize* inst) @@ -3052,31 +3007,39 @@ struct DynamicInstLoweringContext if (!valInfo) return false; - IRType* ptrValType = nullptr; - ptrValType = as(as(inst)->getPtr()->getDataType())->getValueType(); + auto loadPtr = as(inst)->getPtr(); + auto loadPtrType = as(loadPtr->getDataType()); + auto ptrValType = loadPtrType->getValueType(); - if (ptrValType != (IRType*)getLoweredType(valInfo)) + IRType* loweredType = (IRType*)getLoweredType(valInfo); + if (ptrValType != loweredType) { SLANG_ASSERT(!as(inst)); - IRBuilder builder(inst); - builder.setInsertAfter(inst); - - IRCloneEnv cloneEnv; - auto newLoad = cloneInst(&cloneEnv, &builder, inst); - auto loweredVal = upcastCollection(context, newLoad, (IRType*)valInfo); - - // TODO: this is a hack. Encode this in the type-flow-data. - if (as(ptrValType) && !isComInterfaceType(inst->getDataType()) && - !isBuiltin(inst->getDataType())) + if (as(ptrValType)) { - newLoad->setFullType(lowerInterfaceType(as(ptrValType))); - } - - inst->replaceUsesWith(loweredVal); - inst->removeAndDeallocate(); + // If we're dealing with a loading a known tagged union value from + // an interface-typed pointer, we'll cast the pointer itself and + // defer the lowering of the load until later. + // + // This avoids having to change the source pointer type + // and confusing any future runs of the type flow + // analysis pass. + // + IRBuilder builder(inst); + builder.setInsertAfter(inst); + auto newLoadPtr = builder.emitIntrinsicInst( + builder.getPtrTypeWithAddressSpace(loweredType, loadPtrType), + kIROp_CastInterfaceToTaggedUnionPtr, + 1, + &loadPtr); + auto newLoad = builder.emitLoad(loweredType, newLoadPtr); + + inst->replaceUsesWith(newLoad); + inst->removeAndDeallocate(); - return true; + return true; + } } else if (inst->getDataType() != ptrValType) { @@ -3815,6 +3778,161 @@ void lowerTagTypes(IRModule* module) context.processModule(); } +// This context lowers `CastInterfaceToTaggedUnionPtr` and +// `CastTaggedUnionToInterfacePtr` by finding all `IRLoad` and +// `IRStore` uses of these insts, and upcasting the tagged-union +// tuple to the the interface-based tuple (of the loaded inst or before +// storing the val, as necessary) +// +struct TaggedUnionCastLoweringContext : public InstPassBase +{ + TaggedUnionCastLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + IRInst* convertToTaggedUnion( + IRBuilder* builder, + IRInst* val, + IRInst* interfaceType, + IRInst* targetType) + { + auto baseInterfaceValue = val; + auto witnessTable = builder->emitExtractExistentialWitnessTable(baseInterfaceValue); + auto tableID = builder->emitGetSequentialIDInst(witnessTable); + + auto taggedUnionTupleType = cast(targetType); + + List getTagOperands; + getTagOperands.add(interfaceType); + getTagOperands.add(tableID); + auto tableTag = builder->emitIntrinsicInst( + (IRType*)taggedUnionTupleType->getOperand(0), + kIROp_GetTagFromSequentialID, + getTagOperands.getCount(), + getTagOperands.getBuffer()); + + return builder->emitMakeTuple( + {tableTag, + builder->emitReinterpret( + (IRType*)taggedUnionTupleType->getOperand(1), + builder->emitExtractExistentialValue( + (IRType*)builder->emitExtractExistentialType(baseInterfaceValue), + baseInterfaceValue))}); + } + + void lowerCastInterfaceToTaggedUnionPtr(IRCastInterfaceToTaggedUnionPtr* inst) + { + // Find all uses of the inst + traverseUses( + inst, + [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_Load: + { + auto baseInterfacePtr = inst->getOperand(0); + auto baseInterfaceType = as( + as(baseInterfacePtr->getDataType())->getValueType()); + + // Rewrite the load to use the original ptr and load + // an interface-typed object. + // + IRBuilder builder(module); + builder.setInsertAfter(user); + builder.replaceOperand(user->getOperands() + 0, baseInterfacePtr); + builder.replaceOperand(&user->typeUse, baseInterfaceType); + + // Then, we'll rewrite it. + List oldUses; + traverseUses(user, [&](IRUse* oldUse) { oldUses.add(oldUse); }); + + auto newVal = convertToTaggedUnion( + &builder, + user, + baseInterfaceType, + as(inst->getDataType())->getValueType()); + for (auto oldUse : oldUses) + { + builder.replaceOperand(oldUse, newVal); + } + break; + } + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + auto baseInterfacePtr = inst->getOperand(0); + auto baseInterfaceType = + as((baseInterfacePtr->getDataType())->getOperand(0)); + + IRBuilder builder(module); + builder.setInsertAfter(user); + builder.replaceOperand(user->getOperands() + 0, baseInterfacePtr); + builder.replaceOperand(&user->typeUse, baseInterfaceType); + + // Then, we'll rewrite it. + List oldUses; + traverseUses(user, [&](IRUse* oldUse) { oldUses.add(oldUse); }); + + auto newVal = convertToTaggedUnion( + &builder, + user, + baseInterfaceType, + as(inst->getDataType())->getValueType()); + for (auto oldUse : oldUses) + { + builder.replaceOperand(oldUse, newVal); + } + break; + } + default: + SLANG_UNEXPECTED("Unexpected user of CastInterfaceToTaggedUnionPtr"); + } + }); + + SLANG_ASSERT(!inst->hasUses()); + inst->removeAndDeallocate(); + } + + void lowerCastTaggedUnionToInterfacePtr(IRCastTaggedUnionToInterfacePtr* inst) + { + SLANG_UNUSED(inst); + SLANG_UNEXPECTED("Unexpected inst of CastTaggedUnionToInterfacePtr"); + } + + bool processModule() + { + bool hasCastInsts = false; + processInstsOfType( + kIROp_CastInterfaceToTaggedUnionPtr, + [&](IRCastInterfaceToTaggedUnionPtr* inst) + { + hasCastInsts = true; + return lowerCastInterfaceToTaggedUnionPtr(inst); + }); + + processInstsOfType( + kIROp_CastTaggedUnionToInterfacePtr, + [&](IRCastTaggedUnionToInterfacePtr* inst) + { + hasCastInsts = true; + return lowerCastTaggedUnionToInterfacePtr(inst); + }); + + return hasCastInsts; + } +}; + +bool lowerTaggedUnionPtrCasts(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + + TaggedUnionCastLoweringContext context(module); + return context.processModule(); +} + // Main entry point bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink) { diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-dynamic-insts.h index eae821a0a9d..ab36ac114b2 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.h +++ b/source/slang/slang-ir-lower-dynamic-insts.h @@ -15,4 +15,6 @@ void lowerTagInsts(IRModule* module, DiagnosticSink* sink); void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink); void lowerTagTypes(IRModule* module); + +bool lowerTaggedUnionPtrCasts(IRModule* module, DiagnosticSink* sink); } // namespace Slang From d7ef2a2901135ab7de7a4d001405b858772ac3c3 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 11 Aug 2025 18:13:21 -0400 Subject: [PATCH 034/105] Some more fixups --- source/slang/slang-ir-any-value-marshalling.cpp | 4 ++++ source/slang/slang-ir-layout.cpp | 7 +++++++ source/slang/slang-ir-lower-dynamic-insts.h | 1 - 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 18862313ff2..75a0842bcad 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -1018,6 +1018,10 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) { return alignUp(offset, 4) + kRTTIHandleSize; } + case kIROp_CollectionTagType: + { + return alignUp(offset, 4) + 4; + } case kIROp_InterfaceType: { auto interfaceType = cast(type); diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index ac0c7960111..23dd83cfcbd 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -309,6 +309,13 @@ static Result _calcSizeAndAlignment( return SLANG_OK; } break; + case kIROp_CollectionTagType: + { + outSizeAndAlignment->size = 4; + outSizeAndAlignment->alignment = 4; + return SLANG_OK; + } + break; case kIROp_InterfaceType: { auto interfaceType = cast(type); diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-dynamic-insts.h index ab36ac114b2..5b255c61a04 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.h +++ b/source/slang/slang-ir-lower-dynamic-insts.h @@ -8,7 +8,6 @@ namespace Slang { // Main entry point for the pass bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink); -// void lowerCollectionAndTagInsts(IRModule* module, DiagnosticSink* sink); void lowerTypeCollections(IRModule* module, DiagnosticSink* sink); void lowerTagInsts(IRModule* module, DiagnosticSink* sink); From 209031e66306ce641e0952f2515618aac7ddb736 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 12 Aug 2025 16:09:35 -0400 Subject: [PATCH 035/105] More fixes to get all tests passing --- source/slang/slang-emit.cpp | 2 +- source/slang/slang-ir-lower-dynamic-insts.cpp | 294 +++++++++++++----- source/slang/slang-ir-lower-dynamic-insts.h | 2 +- .../no-type-conformance.slang.expected | 4 +- tests/ir/dump-module-info.slang | 2 +- tests/wgsl/switch-case.slang | 32 +- 6 files changed, 242 insertions(+), 94 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ec1ebe8e251..c6439343422 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1140,7 +1140,7 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink)); } - if (lowerTaggedUnionPtrCasts(irModule, sink)) + if (lowerTaggedUnionTypes(irModule, sink)) requiredLoweringPassSet.reinterpret = true; // TODO: Is this the right way to handle this? lowerTagInsts(irModule, sink); diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 60035a5d0b4..49af666d15a 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -485,10 +485,34 @@ struct DynamicInstLoweringContext addUsersToWorkQueue(context, inst, unionedInfo, workQueue); } + bool isFuncParam(IRParam* param) + { + auto paramBlock = as(param->getParent()); + auto paramFunc = as(paramBlock->getParent()); + return (paramFunc && paramFunc->getFirstBlock() == paramBlock); + } + + void addContextUsersToWorkQueue(IRInst* context, WorkQueue& workQueue) + { + if (this->funcCallSites.containsKey(context)) + for (auto callSite : this->funcCallSites[context]) + { + workQueue.enqueue(WorkItem( + InterproceduralEdge::Direction::FuncToCall, + callSite.context, + as(callSite.inst), + context)); + } + } + // Helper to add users of an instruction to the work queue based on how they use it // This handles intra-procedural edges, inter-procedural edges, and return value propagation void addUsersToWorkQueue(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) { + if (auto param = as(inst)) + if (isFuncParam(param)) + addContextUsersToWorkQueue(context, workQueue); + for (auto use = inst->firstUse; use; use = use->nextUse) { auto user = use->getUser(); @@ -515,30 +539,14 @@ struct DynamicInstLoweringContext // If user is a return instruction, handle function return propagation if (as(user)) - { updateFuncReturnInfo(context, info, workQueue); - } // If the user is a top-level inout/out parameter, we need to handle it // like we would a func-return. // if (auto param = as(user)) - { - auto paramBlock = as(param->getParent()); - auto paramFunc = as(paramBlock->getParent()); - if (paramFunc && paramFunc->getFirstBlock() == paramBlock) - { - if (this->funcCallSites.containsKey(context)) - for (auto callSite : this->funcCallSites[context]) - { - workQueue.enqueue(WorkItem( - InterproceduralEdge::Direction::FuncToCall, - callSite.context, - as(callSite.inst), - context)); - } - } - } + if (isFuncParam(param)) + addContextUsersToWorkQueue(context, workQueue); } } @@ -626,7 +634,39 @@ struct DynamicInstLoweringContext if (!argInfo || !destInfo) return arg; - if (as(argInfo) && as(destInfo)) + if (as(argInfo) && as(destInfo)) + { + // Handle upcasting between collection tagged unions + auto argTUType = as(argInfo); + auto destTUType = as(destInfo); + + if (getCollectionCount(argTUType) != getCollectionCount(destTUType)) + { + // Technically, IRCollectionTaggedUnionType is not a TupleType, + // but in practice it works the same way so we'll re-use Slang's + // tuple accessors & constructors + // + IRBuilder builder(arg->getModule()); + setInsertAfterOrdinaryInst(&builder, arg); + auto argTableTag = builder.emitGetTupleElement( + (IRType*)makeTagType(as(argTUType->getOperand(1))), + arg, + 0); + auto reinterpretedTag = upcastCollection( + context, + argTableTag, + (IRType*)makeTagType(as(destTUType->getOperand(1)))); + + auto argVal = + builder.emitGetTupleElement((IRType*)argTUType->getOperand(0), arg, 1); + auto reinterpretedVal = + upcastCollection(context, argVal, (IRType*)destTUType->getOperand(0)); + return builder.emitMakeTuple( + (IRType*)destTUType, + {reinterpretedTag, reinterpretedVal}); + } + } + else if (as(argInfo) && as(destInfo)) { auto argTupleType = as(argInfo); auto destTupleType = as(destInfo); @@ -812,6 +852,9 @@ struct DynamicInstLoweringContext // all fields of the struct. // auto structType = as(makeStruct->getDataType()); + if (!structType) + return none(); + UIndex operandIndex = 0; for (auto field : structType->getFields()) { @@ -1431,6 +1474,18 @@ struct DynamicInstLoweringContext // If we hit a local var, we'll update it's info. updateInfo(context, var, info, workQueue); } + else if (auto param = as(inst)) + { + // We'll also update function parameters, + // but first change the info from PtrTypeBase + // to the specific pointer type for the parameter. + // + IRBuilder builder(param->getModule()); + auto newInfo = builder.getPtrTypeWithAddressSpace( + (IRType*)as(info)->getValueType(), + as(param->getDataType())); + updateInfo(context, param, newInfo, workQueue); + } else { // If we hit something unsupported, assume no information. @@ -1591,7 +1646,13 @@ struct DynamicInstLoweringContext case InterproceduralEdge::Direction::CallToFunc: { // Propagate argument info from call site to function parameters - auto firstBlock = targetCallee->getFirstBlock(); + IRBlock* firstBlock = nullptr; + + if (as(targetCallee)) + firstBlock = targetCallee->getFirstBlock(); + else if (auto specInst = as(targetCallee)) + firstBlock = getGenericReturnVal(specInst->getBase())->getFirstBlock(); + if (!firstBlock) return; @@ -1601,18 +1662,25 @@ struct DynamicInstLoweringContext if (argIndex < callInst->getOperandCount()) { auto arg = callInst->getOperand(argIndex); - if (auto argInfo = tryGetInfo(edge.callerContext, arg)) + const auto [paramDirection, paramType] = + getParameterDirectionAndType(param->getDataType()); + + // Only update if + // 1. The paramType is a global inst and an interface type + // 2. The paramType is a local inst. + // all other cases, continue. + if (isGlobalInst(paramType) && !as(paramType)) { - const auto [paramDirection, paramType] = - getParameterDirectionAndType(param->getDataType()); + argIndex++; + continue; + } - // Only update if the parameter is abstract type. - if (isGlobalInst(paramType) && !(as(paramType))) - { - argIndex++; - continue; - } + IRInst* argInfo = tryGetInfo(edge.callerContext, arg); + if (!argInfo && isGlobalInst(arg->getDataType())) + argInfo = arg->getDataType(); + if (argInfo) + { switch (paramDirection) { case kParameterDirection_Out: @@ -1658,18 +1726,23 @@ struct DynamicInstLoweringContext UIndex argIndex = 0; for (auto paramInfo : paramInfos) { - if (paramDirections[argIndex] == kParameterDirection_Out || - paramDirections[argIndex] == kParameterDirection_InOut) + if (paramInfo) { - auto arg = callInst->getArg(argIndex); - auto argPtrType = as(arg->getDataType()); + if (paramDirections[argIndex] == kParameterDirection_Out || + paramDirections[argIndex] == kParameterDirection_InOut) + { + auto arg = callInst->getArg(argIndex); + auto argPtrType = as(arg->getDataType()); - IRBuilder builder(module); - updateInfo( - edge.callerContext, - builder.getPtrTypeWithAddressSpace((IRType*)paramInfo, argPtrType), - paramInfo, - workQueue); + IRBuilder builder(module); + updateInfo( + edge.callerContext, + arg, + builder.getPtrTypeWithAddressSpace( + (IRType*)as(paramInfo)->getValueType(), + argPtrType), + workQueue); + } } argIndex++; } @@ -1746,6 +1819,9 @@ struct DynamicInstLoweringContext if (!info2) return info1; + if (info1 == info2) + return info1; + if (as(info1) && as(info2)) { SLANG_ASSERT(info1->getOperand(1) == info2->getOperand(1)); @@ -1866,6 +1942,10 @@ struct DynamicInstLoweringContext bool lowerFunc(IRFunc* func) { + // Don't make any changes to non-global or intrinsic functions + if (!isGlobalInst(func) || isIntrinsic(func)) + return false; + bool hasChanges = false; for (auto block : func->getBlocks()) hasChanges |= lowerInstsInBlock(func, block); @@ -1962,24 +2042,6 @@ struct DynamicInstLoweringContext return hasChanges; } - IRType* getTypeForExistential(IRCollectionTaggedUnionType* taggedUnion) - { - // Replace type with Tuple - IRBuilder builder(module); - builder.setInsertInto(module); - - auto typeCollection = cast(taggedUnion->getOperand(0)); - auto tableCollection = cast(taggedUnion->getOperand(1)); - - if (getCollectionCount(typeCollection) == 1) - return builder.getTupleType(List( - {(IRType*)makeTagType(tableCollection), - (IRType*)getCollectionElement(typeCollection, 0)})); - - return builder.getTupleType( - List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); - } - IRType* getLoweredType(IRInst* info) { if (!info) @@ -2007,7 +2069,8 @@ struct DynamicInstLoweringContext if (auto taggedUnion = as(info)) { // If this is a tagged union, we need to create a tuple type - return getTypeForExistential(taggedUnion); + // return getTypeForExistential(taggedUnion); + return (IRType*)taggedUnion; } if (auto collectionTag = as(info)) @@ -2100,6 +2163,8 @@ struct DynamicInstLoweringContext return lowerStore(context, as(inst)); case kIROp_GetSequentialID: return lowerGetSequentialID(context, as(inst)); + case kIROp_IsType: + return lowerIsType(context, as(inst)); default: { if (tryGetInfo(context, inst)) @@ -2207,9 +2272,9 @@ struct DynamicInstLoweringContext auto existential = inst->getOperand(0); auto existentialInfo = existential->getDataType(); - if (isTaggedUnionType(existentialInfo)) + if (as(existentialInfo)) { - auto valType = existentialInfo->getOperand(1); + auto valType = existentialInfo->getOperand(0); IRBuilder builder(inst); builder.setInsertAfter(inst); @@ -2331,6 +2396,15 @@ struct DynamicInstLoweringContext { return newType; } + else if ( + as(currentType) && + as(newType)) + { + // Merge the elements of both tagged unions into a new tuple type + return (IRType*)makeExistential((as(updateType( + (IRType*)currentType->getOperand(0)->getOperand(0), + (IRType*)newType->getOperand(0)->getOperand(0))))); + } else if (isTaggedUnionType(currentType) && isTaggedUnionType(newType)) { IRBuilder builder(module); @@ -2613,6 +2687,12 @@ struct DynamicInstLoweringContext if (!isGlobalInst(callee) || isIntrinsic(callee)) return false; + // One other case to avoid is if the function is a global LookupWitnessMethod + // which can be created when optional witnesses are specialized. + // + if (as(callee)) + return false; + auto expectedFuncType = getEffectiveFuncType(callee); List newArgs; @@ -2744,6 +2824,8 @@ struct DynamicInstLoweringContext bool lowerMakeStruct(IRInst* context, IRMakeStruct* inst) { auto structType = as(inst->getDataType()); + if (!structType) + return false; // Reinterpret any of the arguments as necessary. bool changed = false; @@ -2807,11 +2889,11 @@ struct DynamicInstLoweringContext ? builder.emitPackAnyValue(collectionType, inst->getWrappedValue()) : inst->getWrappedValue(); - auto taggedUnionTupleType = getLoweredType(taggedUnion); + auto taggedUnionType = getLoweredType(taggedUnion); // Create tuple (table_unique_id, PackAnyValue(val)) IRInst* tupleArgs[] = {witnessTableID, packedValue}; - auto tuple = builder.emitMakeTuple(taggedUnionTupleType, 2, tupleArgs); + auto tuple = builder.emitMakeTuple(taggedUnionType, 2, tupleArgs); inst->replaceUsesWith(tuple); inst->removeAndDeallocate(); @@ -2825,7 +2907,7 @@ struct DynamicInstLoweringContext if (!taggedUnion) return false; - auto taggedUnionTupleType = getLoweredType(taggedUnion); + auto taggedUnionType = getLoweredType(taggedUnion); IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -2834,26 +2916,24 @@ struct DynamicInstLoweringContext args.add(inst->getDataType()); args.add(inst->getTypeID()); auto translatedTag = builder.emitIntrinsicInst( - (IRType*)taggedUnionTupleType->getOperand(0), + (IRType*)makeTagType(as(taggedUnionType->getOperand(1))), kIROp_GetTagFromSequentialID, args.getCount(), args.getBuffer()); IRInst* packedValue = nullptr; - if (auto collection = as(taggedUnionTupleType->getOperand(1))) + if (auto collection = as(taggedUnionType->getOperand(0))) { packedValue = builder.emitPackAnyValue((IRType*)collection, inst->getValue()); } else { - packedValue = builder.emitReinterpret( - (IRType*)taggedUnionTupleType->getOperand(1), - inst->getValue()); + packedValue = + builder.emitReinterpret((IRType*)taggedUnionType->getOperand(0), inst->getValue()); } - auto newInst = builder.emitMakeTuple( - taggedUnionTupleType, - List({translatedTag, packedValue})); + auto newInst = + builder.emitMakeTuple(taggedUnionType, List({translatedTag, packedValue})); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); @@ -3122,6 +3202,41 @@ struct DynamicInstLoweringContext return false; } + bool lowerIsType(IRInst* context, IRIsType* inst) + { + SLANG_UNUSED(context); + auto witnessTableArg = inst->getValueWitness(); + if (auto tagType = as(witnessTableArg->getDataType())) + { + IRBuilder builder(inst); + setInsertAfterOrdinaryInst(&builder, inst); + auto firstElement = + getCollectionElement(as(tagType->getOperand(0)), 0); + auto interfaceType = + as(as(firstElement)->getConformanceType()); + + // TODO: This is a rather suboptimal implementation that involves using + // global sequential IDs even though we could do it via local IDs. + // + + List args = {interfaceType, witnessTableArg}; + auto valueSeqID = builder.emitIntrinsicInst( + (IRType*)builder.getUIntType(), + kIROp_GetSequentialIDFromTag, + args.getCount(), + args.getBuffer()); + + auto targetSeqID = builder.emitGetSequentialIDInst(inst->getTargetWitness()); + auto eqlInst = builder.emitEql(valueSeqID, targetSeqID); + + inst->replaceUsesWith(eqlInst); + inst->removeAndDeallocate(); + return true; + } + + return false; + } + UInt getUniqueID(IRInst* inst) { auto existingId = uniqueIds.tryGetValue(inst); @@ -3784,9 +3899,9 @@ void lowerTagTypes(IRModule* module) // tuple to the the interface-based tuple (of the loaded inst or before // storing the val, as necessary) // -struct TaggedUnionCastLoweringContext : public InstPassBase +struct TaggedUnionLoweringContext : public InstPassBase { - TaggedUnionCastLoweringContext(IRModule* module) + TaggedUnionLoweringContext(IRModule* module) : InstPassBase(module) { } @@ -3902,8 +4017,37 @@ struct TaggedUnionCastLoweringContext : public InstPassBase SLANG_UNEXPECTED("Unexpected inst of CastTaggedUnionToInterfacePtr"); } + IRType* convertToTupleType(IRCollectionTaggedUnionType* taggedUnion) + { + // Replace type with Tuple + IRBuilder builder(module); + builder.setInsertInto(module); + + auto typeCollection = cast(taggedUnion->getOperand(0)); + auto tableCollection = cast(taggedUnion->getOperand(1)); + + if (getCollectionCount(typeCollection) == 1) + return builder.getTupleType(List( + {(IRType*)makeTagType(tableCollection), + (IRType*)getCollectionElement(typeCollection, 0)})); + + return builder.getTupleType( + List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); + } + bool processModule() { + // First, we'll lower all CollectionTaggedUnionType insts + // into tuples. + // + processInstsOfType( + kIROp_CollectionTaggedUnionType, + [&](IRCollectionTaggedUnionType* inst) + { + inst->replaceUsesWith(convertToTupleType(inst)); + inst->removeAndDeallocate(); + }); + bool hasCastInsts = false; processInstsOfType( kIROp_CastInterfaceToTaggedUnionPtr, @@ -3925,11 +4069,11 @@ struct TaggedUnionCastLoweringContext : public InstPassBase } }; -bool lowerTaggedUnionPtrCasts(IRModule* module, DiagnosticSink* sink) +bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); - TaggedUnionCastLoweringContext context(module); + TaggedUnionLoweringContext context(module); return context.processModule(); } diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-dynamic-insts.h index 5b255c61a04..169d1be851e 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.h +++ b/source/slang/slang-ir-lower-dynamic-insts.h @@ -15,5 +15,5 @@ void lowerTagInsts(IRModule* module, DiagnosticSink* sink); void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink); void lowerTagTypes(IRModule* module); -bool lowerTaggedUnionPtrCasts(IRModule* module, DiagnosticSink* sink); +bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink); } // namespace Slang diff --git a/tests/diagnostics/no-type-conformance.slang.expected b/tests/diagnostics/no-type-conformance.slang.expected index 5f5eda6afe7..ebd9482ec0b 100644 --- a/tests/diagnostics/no-type-conformance.slang.expected +++ b/tests/diagnostics/no-type-conformance.slang.expected @@ -1,8 +1,8 @@ result code = -1 standard error = { tests/diagnostics/no-type-conformance.slang(12): error 50100: No type conformances are found for interface 'IFoo'. Code generation for current target requires at least one implementation type present in the linkage. - obj.get(); - ^ +interface IFoo + ^~~~ } standard output = { } diff --git a/tests/ir/dump-module-info.slang b/tests/ir/dump-module-info.slang index c7753b44082..67a43b274fe 100644 --- a/tests/ir/dump-module-info.slang +++ b/tests/ir/dump-module-info.slang @@ -6,7 +6,7 @@ module "foo"; // CHECK: Module Name: foo // This will need bumping whenever we bump the ir module version -// CHECK: Module Version: 1 +// CHECK: Module Version: 2 // Just check that this is in the output with some string // CHECK: Compiler Version: {{.+}} diff --git a/tests/wgsl/switch-case.slang b/tests/wgsl/switch-case.slang index c4ff0996e30..133982cd030 100644 --- a/tests/wgsl/switch-case.slang +++ b/tests/wgsl/switch-case.slang @@ -70,17 +70,21 @@ func fs_main(VertexOutput input)->FragmentOutput return output; } -//WGSL: fn _S9( _S10 : Tuple_0) -> f32 -//WGSL-NEXT: { -//WGSL-NEXT: switch(_S10.value1_0.x) -//WGSL-NEXT: { -//WGSL-NEXT: case u32(0): -//WGSL-NEXT: { -//WGSL-NEXT: return Circle_getArea_0(unpackAnyValue16_0(_S10.value2_0)); -//WGSL-NEXT: } -//WGSL-NEXT: default : -//WGSL-NEXT: { -//WGSL-NEXT: return Rectangle_getArea_0(unpackAnyValue16_1(_S10.value2_0)); -//WGSL-NEXT: } -//WGSL-NEXT: } -//WGSL-NEXT: } +//WGSL: fn _S13( _S14 : u32, _S15 : AnyValue8) -> f32 +//WGSL-NEXT:{ +//WGSL-NEXT: switch(_S14) +//WGSL-NEXT: { +//WGSL-NEXT: case u32(0): +//WGSL-NEXT: { +//WGSL-NEXT: return U_SR14switch_2Dxcase6Circle7getAreap0pf_wtwrapper_0(_S15); +//WGSL-NEXT: } +//WGSL-NEXT: case u32(1): +//WGSL-NEXT: { +//WGSL-NEXT: return U_SR14switch_2Dxcase9Rectangle7getAreap0pf_wtwrapper_0(_S15); +//WGSL-NEXT: } +//WGSL-NEXT: default : +//WGSL-NEXT: { +//WGSL-NEXT: return 0.0f; +//WGSL-NEXT: } +//WGSL-NEXT: } + //WGSL-NEXT:} \ No newline at end of file From 9347b4f29edd297f951f8ddc3936199fc619609e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 12 Aug 2025 17:19:52 -0400 Subject: [PATCH 036/105] Update slang-ir-lower-dynamic-insts.cpp --- source/slang/slang-ir-lower-dynamic-insts.cpp | 62 +++++++++++-------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 49af666d15a..9afc4d2c727 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -1676,34 +1676,46 @@ struct DynamicInstLoweringContext } IRInst* argInfo = tryGetInfo(edge.callerContext, arg); - if (!argInfo && isGlobalInst(arg->getDataType())) - argInfo = arg->getDataType(); - if (argInfo) + switch (paramDirection) { - switch (paramDirection) + case kParameterDirection_Out: + case kParameterDirection_InOut: { - case kParameterDirection_Out: - case kParameterDirection_InOut: + IRBuilder builder(module); + if (!argInfo) { - IRBuilder builder(module); - auto newInfo = fromDirectionAndType( - &builder, - paramDirection, - as(argInfo)->getValueType()); - updateInfo(edge.targetContext, param, newInfo, workQueue); - break; + if (isGlobalInst(arg->getDataType()) && + !as( + as(arg->getDataType())->getValueType())) + argInfo = arg->getDataType(); } - case kParameterDirection_In: - { - // Use centralized update method - updateInfo(edge.targetContext, param, argInfo, workQueue); + + if (!argInfo) break; + + auto newInfo = fromDirectionAndType( + &builder, + paramDirection, + as(argInfo)->getValueType()); + updateInfo(edge.targetContext, param, newInfo, workQueue); + break; + } + case kParameterDirection_In: + { + // Use centralized update method + if (!argInfo) + { + if (isGlobalInst(arg->getDataType()) && + !as(arg->getDataType())) + argInfo = arg->getDataType(); } - default: - SLANG_UNEXPECTED( - "Unhandled parameter direction in interprocedural edge"); + updateInfo(edge.targetContext, param, argInfo, workQueue); + break; } + default: + SLANG_UNEXPECTED( + "Unhandled parameter direction in interprocedural edge"); } } argIndex++; @@ -1930,7 +1942,7 @@ struct DynamicInstLoweringContext continue; auto loweredFieldType = getLoweredType(info); - if (loweredFieldType != field->getDataType()) + if (loweredFieldType != field->getFieldType()) { hasChanges = true; field->setFieldType(loweredFieldType); @@ -2401,9 +2413,8 @@ struct DynamicInstLoweringContext as(newType)) { // Merge the elements of both tagged unions into a new tuple type - return (IRType*)makeExistential((as(updateType( - (IRType*)currentType->getOperand(0)->getOperand(0), - (IRType*)newType->getOperand(0)->getOperand(0))))); + return (IRType*)makeExistential((as( + updateType((IRType*)currentType->getOperand(1), (IRType*)newType->getOperand(1))))); } else if (isTaggedUnionType(currentType) && isTaggedUnionType(newType)) { @@ -2922,7 +2933,8 @@ struct DynamicInstLoweringContext args.getBuffer()); IRInst* packedValue = nullptr; - if (auto collection = as(taggedUnionType->getOperand(0))) + auto collection = as(taggedUnionType->getOperand(0)); + if (getCollectionCount(collection) > 1) { packedValue = builder.emitPackAnyValue((IRType*)collection, inst->getValue()); } From 1a4021c6626f786a38d2a13a320666956738a0ee Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 12 Aug 2025 19:11:52 -0400 Subject: [PATCH 037/105] Fix the last two tests --- source/slang/slang-ir-lower-dynamic-insts.cpp | 57 ++++++++++++------- source/slang/slang-ir-specialize.cpp | 15 +++-- 2 files changed, 48 insertions(+), 24 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 9afc4d2c727..5cd0da28bb6 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -1249,6 +1249,14 @@ struct DynamicInstLoweringContext return none(); // No info for the type } + if (!isGlobalInst(typeOfSpecialization)) + { + // Our func-type operand is not yet been lifted. + // For now, we can't say anything. + // + return none(); + } + IRCollectionBase* collection = nullptr; if (auto _collection = as(operandInfo)) { @@ -1432,7 +1440,7 @@ struct DynamicInstLoweringContext else if (auto fieldAddress = as(inst)) { // If this is a field address, update the fieldInfos map. - if (auto thisPtrInfo = as(info)) + if (as(info)) { IRBuilder builder(module); auto baseStructPtrType = as(fieldAddress->getBase()->getDataType()); @@ -1572,10 +1580,12 @@ struct DynamicInstLoweringContext for (auto param : func->getParams()) { - if (auto newType = tryGetInfo(context, param)) - effectiveTypes.add((IRType*)newType); - else - effectiveTypes.add(param->getDataType()); + if (auto newInfo = tryGetInfo(context, param)) + if (getLoweredType(newInfo) != nullptr) // Check that info isn't unbounded + effectiveTypes.add((IRType*)newInfo); + + // Fallback.. no new info, just use the param type. + effectiveTypes.add(param->getDataType()); } return effectiveTypes; @@ -2065,17 +2075,25 @@ struct DynamicInstLoweringContext if (auto ptrType = as(info)) { IRBuilder builder(module); - return builder.getPtrTypeWithAddressSpace( - (IRType*)getLoweredType(ptrType->getValueType()), - ptrType); + if (auto loweredValueType = getLoweredType(ptrType->getValueType())) + { + return builder.getPtrTypeWithAddressSpace((IRType*)loweredValueType, ptrType); + } + else + return nullptr; } if (auto arrayType = as(info)) { IRBuilder builder(module); - return builder.getArrayType( - (IRType*)getLoweredType(arrayType->getElementType()), - arrayType->getElementCount()); + if (auto loweredElementType = getLoweredType(arrayType->getElementType())) + { + return builder.getArrayType( + (IRType*)loweredElementType, + arrayType->getElementCount()); + } + else + return nullptr; } if (auto taggedUnion = as(info)) @@ -2883,8 +2901,7 @@ struct DynamicInstLoweringContext 1, &zeroValueOfTagType); } - else if ( - auto witnessTableTag = as(inst->getWitnessTable()->getDataType())) + else if (as(inst->getWitnessTable()->getDataType())) { // Dynamic. Use the witness table inst as a tag witnessTableID = inst->getWitnessTable(); @@ -2965,7 +2982,8 @@ struct DynamicInstLoweringContext auto loweredValType = (IRType*)getLoweredType(valInfo); if (bufferBaseType != loweredValType) { - if (as(bufferBaseType)) + if (as(bufferBaseType) && !isComInterfaceType(bufferBaseType) && + !isBuiltin(bufferBaseType)) { // If we're dealing with a loading a known tagged union value from // an interface-typed pointer, we'll cast the pointer itself and @@ -3028,7 +3046,7 @@ struct DynamicInstLoweringContext // TODO: Maybe make this the 'default' behavior if a lowering call // returns false. // - if (auto info = tryGetInfo(context, inst)) + if (tryGetInfo(context, inst)) return replaceType(context, inst); else return false; @@ -3078,9 +3096,9 @@ struct DynamicInstLoweringContext SLANG_UNUSED(context); auto destType = inst->getDataType(); auto operandInfo = inst->getOperand(0)->getDataType(); - if (auto taggedUnionTupleType = as(operandInfo)) + if (auto taggedUnionTupleType = as(operandInfo)) { - SLANG_ASSERT(taggedUnionTupleType->getOperand(1) == destType); + // SLANG_ASSERT(taggedUnionTupleType->getOperand(1) == destType); IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); @@ -3108,7 +3126,8 @@ struct DynamicInstLoweringContext { SLANG_ASSERT(!as(inst)); - if (as(ptrValType)) + if (as(ptrValType) && !isComInterfaceType(ptrValType) && + !isBuiltin(ptrValType)) { // If we're dealing with a loading a known tagged union value from // an interface-typed pointer, we'll cast the pointer itself and @@ -3171,7 +3190,7 @@ struct DynamicInstLoweringContext // removed from Slang, but the auto-diff process can sometimes // produce a store of default-constructed value. // - if (auto defaultConstruct = as(inst->getVal())) + if (as(inst->getVal())) return handleDefaultStore(context, inst); auto loweredVal = upcastCollection(context, inst->getVal(), ptrInfo); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index f35908f89e6..41317476f0f 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3189,13 +3189,22 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) cloneEnv.mapOldValToNew[block] = cloneInstAndOperands(&cloneEnv, &builder, block); } + builder.setInsertInto(builder.getModule()); + auto loweredFuncType = + as(cloneInst(&staticCloningEnv, &builder, returnedFunc->getFullType())); + loweredFunc->setFullType((IRType*)loweredFuncType); + builder.setInsertInto(loweredFunc->getFirstBlock()); builder.emitBranch(as(cloneEnv.mapOldValToNew[funcFirstBlock])); for (auto param : funcFirstBlock->getParams()) { // Clone the parameters of the first block. - builder.setInsertAfter(loweredFunc->getFirstBlock()->getLastParam()); + if (loweredFunc->getFirstBlock()->getFirstParam() == nullptr) + builder.setInsertBefore(loweredFunc->getFirstBlock()->getFirstChild()); + else + builder.setInsertAfter(loweredFunc->getFirstBlock()->getLastParam()); + auto newParam = cloneInst(&staticCloningEnv, &builder, param); cloneEnv.mapOldValToNew[param] = newParam; // Transfer the param to the dynamic env } @@ -3220,10 +3229,6 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) cloneEnv.mapOldValToNew[block]); } - builder.setInsertInto(builder.getModule()); - auto loweredFuncType = as( - cloneInst(&cloneEnv, &builder, as(returnedFunc->getFullType()))); - // Add extra indices to the func-type parameters List funcTypeParams; for (Index i = 0; i < extraParamTypes.getCount(); i++) From dbd7402a4da00236a7d9a0901f5dc6dbd6db003d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 13 Aug 2025 11:01:15 -0400 Subject: [PATCH 038/105] Fixup minor bug --- source/slang/slang-ir-lower-dynamic-insts.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 5cd0da28bb6..33f0f2a6394 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -1582,7 +1582,10 @@ struct DynamicInstLoweringContext { if (auto newInfo = tryGetInfo(context, param)) if (getLoweredType(newInfo) != nullptr) // Check that info isn't unbounded + { effectiveTypes.add((IRType*)newInfo); + continue; + } // Fallback.. no new info, just use the param type. effectiveTypes.add(param->getDataType()); From 79773d35d9bc80573a548c8c99dec93371d3e91e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 13 Aug 2025 11:17:22 -0400 Subject: [PATCH 039/105] Fix unused var --- source/slang/slang-ir-lower-dynamic-insts.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 33f0f2a6394..dccf543dc39 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -3099,10 +3099,8 @@ struct DynamicInstLoweringContext SLANG_UNUSED(context); auto destType = inst->getDataType(); auto operandInfo = inst->getOperand(0)->getDataType(); - if (auto taggedUnionTupleType = as(operandInfo)) + if (as(operandInfo)) { - // SLANG_ASSERT(taggedUnionTupleType->getOperand(1) == destType); - IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); auto newInst = builder.emitGetTupleElement((IRType*)destType, inst->getOperand(0), 1); From 7efc97ce9257fa6acd8416dbd0cffd673c58e8da Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 13 Aug 2025 18:12:20 -0400 Subject: [PATCH 040/105] More fixes for specialization --- source/slang/slang-ir-lower-dynamic-insts.cpp | 1 - source/slang/slang-ir-specialize.cpp | 21 ++++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index dccf543dc39..cc3033b0b94 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -16,7 +16,6 @@ namespace Slang // Forward-declare.. (TODO: Just include this from the header instead) IRInst* specializeGeneric(IRSpecialize* specializeInst); -constexpr IRIntegerValue kDefaultAnyValueSize = 16; // Elements for which we keep track of propagation information. struct Element { diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 41317476f0f..7ad9923ec03 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -964,18 +964,33 @@ struct SpecializationContext void readSpecializationDictionaries() { auto moduleInst = module->getModuleInst(); + ShortList dictInsts; for (auto child : moduleInst->getChildren()) { switch (child->getOp()) { case kIROp_GenericSpecializationDictionary: - _readSpecializationDictionaryImpl(genericSpecializations, child); + case kIROp_ExistentialFuncSpecializationDictionary: + case kIROp_ExistentialTypeSpecializationDictionary: + dictInsts.add(child); + break; + default: + continue; + } + } + + for (auto dict : dictInsts) + { + switch (dict->getOp()) + { + case kIROp_GenericSpecializationDictionary: + _readSpecializationDictionaryImpl(genericSpecializations, dict); break; case kIROp_ExistentialFuncSpecializationDictionary: - _readSpecializationDictionaryImpl(existentialSpecializedFuncs, child); + _readSpecializationDictionaryImpl(existentialSpecializedFuncs, dict); break; case kIROp_ExistentialTypeSpecializationDictionary: - _readSpecializationDictionaryImpl(existentialSpecializedStructs, child); + _readSpecializationDictionaryImpl(existentialSpecializedStructs, dict); break; default: continue; From a5cf6b2cc781f8c14d9d63c98c7f5de737725ff5 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 14 Aug 2025 17:51:53 -0400 Subject: [PATCH 041/105] More fixes for Falcor --- source/slang/slang-ir-link.cpp | 1 + source/slang/slang-ir-lower-dynamic-insts.cpp | 31 +++++++++++++------ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index a3466c8c716..26495770ac5 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -302,6 +302,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) case kIROp_GlobalGenericParam: case kIROp_WitnessTable: case kIROp_InterfaceType: + case kIROp_EnumType: return cloneGlobalValue(this, originalValue); case kIROp_BoolLit: diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index cc3033b0b94..040716b47e8 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -468,10 +468,15 @@ struct DynamicInstLoweringContext // Use this when you want to propagate new information to an existing instruction. // This will union the new info with existing info and add users to work queue if changed // - void updateInfo(IRInst* context, IRInst* inst, IRInst* newInfo, WorkQueue& workQueue) + void updateInfo( + IRInst* context, + IRInst* inst, + IRInst* newInfo, + bool takeUnion, + WorkQueue& workQueue) { auto existingInfo = tryGetInfo(context, inst); - auto unionedInfo = unionPropagationInfo(existingInfo, newInfo); + auto unionedInfo = (takeUnion) ? unionPropagationInfo(existingInfo, newInfo) : newInfo; // Only proceed if info actually changed if (areInfosEqual(existingInfo, unionedInfo)) @@ -792,7 +797,10 @@ struct DynamicInstLoweringContext break; } - updateInfo(context, inst, info, workQueue); + // TODO: Remove this workaround.. there are a few insts + // where we shouldn't + bool takeUnion = !as(inst); + updateInfo(context, inst, info, takeUnion, workQueue); } IRInst* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) @@ -1330,7 +1338,7 @@ struct DynamicInstLoweringContext if (auto collection = as(arg)) { - updateInfo(context, param, makeTagType(collection), workQueue); + updateInfo(context, param, makeTagType(collection), true, workQueue); } else if (as(arg) || as(arg)) { @@ -1338,6 +1346,7 @@ struct DynamicInstLoweringContext context, param, makeTagType(makeSingletonSet(arg)), + true, workQueue); } else @@ -1479,7 +1488,7 @@ struct DynamicInstLoweringContext else if (auto var = as(inst)) { // If we hit a local var, we'll update it's info. - updateInfo(context, var, info, workQueue); + updateInfo(context, var, info, true, workQueue); } else if (auto param = as(inst)) { @@ -1491,7 +1500,7 @@ struct DynamicInstLoweringContext auto newInfo = builder.getPtrTypeWithAddressSpace( (IRType*)as(info)->getValueType(), as(param->getDataType())); - updateInfo(context, param, newInfo, workQueue); + updateInfo(context, param, newInfo, true, workQueue); } else { @@ -1530,7 +1539,7 @@ struct DynamicInstLoweringContext if (auto argInfo = tryGetInfo(context, arg)) { // Use centralized update method - updateInfo(context, param, argInfo, workQueue); + updateInfo(context, param, argInfo, true, workQueue); } } paramIndex++; @@ -1693,6 +1702,7 @@ struct DynamicInstLoweringContext { case kParameterDirection_Out: case kParameterDirection_InOut: + case kParameterDirection_ConstRef: { IRBuilder builder(module); if (!argInfo) @@ -1710,7 +1720,7 @@ struct DynamicInstLoweringContext &builder, paramDirection, as(argInfo)->getValueType()); - updateInfo(edge.targetContext, param, newInfo, workQueue); + updateInfo(edge.targetContext, param, newInfo, true, workQueue); break; } case kParameterDirection_In: @@ -1722,7 +1732,7 @@ struct DynamicInstLoweringContext !as(arg->getDataType())) argInfo = arg->getDataType(); } - updateInfo(edge.targetContext, param, argInfo, workQueue); + updateInfo(edge.targetContext, param, argInfo, true, workQueue); break; } default: @@ -1741,7 +1751,7 @@ struct DynamicInstLoweringContext if (returnInfo) { // Use centralized update method - updateInfo(edge.callerContext, callInst, *returnInfo, workQueue); + updateInfo(edge.callerContext, callInst, *returnInfo, true, workQueue); } // Also update infos of any out parameters @@ -1765,6 +1775,7 @@ struct DynamicInstLoweringContext builder.getPtrTypeWithAddressSpace( (IRType*)as(paramInfo)->getValueType(), argPtrType), + true, workQueue); } } From 95c38f2aff9210fc25a713b6eb23da6556fd1c7c Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 15 Aug 2025 13:11:57 -0400 Subject: [PATCH 042/105] Fix insts getting inserted into the middle of the param list --- source/slang/slang-ir-lower-dynamic-insts.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index 040716b47e8..d35bdb14d25 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -714,7 +714,7 @@ struct DynamicInstLoweringContext getCollectionCount(as(destInfo))) { IRBuilder builder(module); - builder.setInsertAfter(arg); + setInsertAfterOrdinaryInst(&builder, arg); return builder .emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); } @@ -726,14 +726,14 @@ struct DynamicInstLoweringContext { // If the sets of witness tables are not equal, reinterpret to the parameter type IRBuilder builder(module); - builder.setInsertAfter(arg); + setInsertAfterOrdinaryInst(&builder, arg); return builder.emitReinterpret((IRType*)destInfo, arg); } } else if (!as(argInfo) && as(destInfo)) { IRBuilder builder(module); - builder.setInsertAfter(arg); + setInsertAfterOrdinaryInst(&builder, arg); return builder.emitPackAnyValue((IRType*)destInfo, arg); } From 76b39ef50c0afdaa72f2adfb60b71157d1b0d913 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 19 Aug 2025 17:15:12 -0400 Subject: [PATCH 043/105] Make the default conformance be the last one instead of the first one Fixes slangpy conformances test --- source/slang/slang-ir-lower-dynamic-insts.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp index d35bdb14d25..c638c613604 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-insts.cpp @@ -3543,7 +3543,7 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) } -IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mapping) +IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mapping, UInt defaultVal) { // Create a function that maps input IDs to output IDs IRBuilder builder(module); @@ -3562,7 +3562,7 @@ IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mappi // Create default block that returns 0 auto defaultBlock = builder.emitBlock(); builder.setInsertInto(defaultBlock); - builder.emitReturn(builder.getIntValue(builder.getUIntType(), 0)); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), defaultVal)); // Go back to entry block and create switch builder.setInsertInto(entryBlock); @@ -3836,9 +3836,10 @@ struct SequentialIDTagLoweringContext : public InstPassBase IRBuilder builder(inst); builder.setInsertAfter(inst); + UInt defaultID = dstSeqID - 1; // Default to last available conformance. auto translatedID = builder.emitCallInst( inst->getDataType(), - createIntegerMappingFunc(builder.getModule(), mapping), + createIntegerMappingFunc(builder.getModule(), mapping, defaultID), List({srcSeqID})); inst->replaceUsesWith(translatedID); @@ -3877,7 +3878,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase builder.setInsertAfter(inst); auto translatedID = builder.emitCallInst( inst->getDataType(), - createIntegerMappingFunc(builder.getModule(), mapping), + createIntegerMappingFunc(builder.getModule(), mapping, 0), List({srcTagInst})); inst->replaceUsesWith(translatedID); From e1485c0adc7730d2b5066f8608ad3aaa9cbf8bdc Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 20 Aug 2025 11:41:40 -0400 Subject: [PATCH 044/105] Delete 2.hlsl --- 2.hlsl | 262 --------------------------------------------------------- 1 file changed, 262 deletions(-) delete mode 100644 2.hlsl diff --git a/2.hlsl b/2.hlsl deleted file mode 100644 index c4f1b11161c..00000000000 --- a/2.hlsl +++ /dev/null @@ -1,262 +0,0 @@ -#pragma pack_matrix(row_major) - -#line 18 "tests/compute/dynamic-dispatch-17.slang" -struct UserDefinedPackedType_0 -{ - float3 val_0; - uint flags_0; -}; - - -#line 28 -RWStructuredBuffer gObj_0 : register(u1); - - -#line 25 -RWStructuredBuffer gOutputBuffer_0 : register(u0); - - -#line 25 -struct AnyValue16 -{ - uint field0_0; - uint field1_0; - uint field2_0; - uint field3_0; -}; - - -#line 25 -AnyValue16 packAnyValue16_0(UserDefinedPackedType_0 _S1) -{ - -#line 25 - AnyValue16 _S2; - -#line 25 - _S2.field0_0 = 0U; - -#line 25 - _S2.field1_0 = 0U; - -#line 25 - _S2.field2_0 = 0U; - -#line 25 - _S2.field3_0 = 0U; - -#line 25 - _S2.field0_0 = (uint)(asuint(_S1.val_0[int(0)])); - -#line 25 - _S2.field1_0 = (uint)(asuint(_S1.val_0[int(1)])); - -#line 25 - _S2.field2_0 = (uint)(asuint(_S1.val_0[int(2)])); - -#line 25 - _S2.field3_0 = _S1.flags_0; - -#line 25 - return _S2; -} - - -#line 39 -uint _S3(uint _S4) -{ - -#line 39 - switch(_S4) - { - case 3U: - { - -#line 39 - return 0U; - } - case 4U: - { - -#line 39 - return 1U; - } - default: - { - -#line 39 - return 0U; - } - } - -#line 39 -} - - -#line 50 -struct FloatVal_0 -{ - float val_1; -}; - - -#line 48 -float ReturnsZero_get_0() -{ - -#line 48 - return 0.0f; -} - - - -float FloatVal_run_0(FloatVal_0 this_0) -{ - - float _S5 = ReturnsZero_get_0(); - -#line 56 - return this_0.val_1 + _S5; -} - - -#line 56 -float U_S4main8FloatVal3rung2TCGP04main12IReturnsZerop0pfG0_ST4main11ReturnsZero1_SW4main11ReturnsZero4main12IReturnsZero_wtwrapper_0(AnyValue16 _S6) -{ - -#line 56 - float _S7 = FloatVal_run_0(_S6); - -#line 56 - return _S7; -} - -struct Float4Struct_0 -{ - float4 val_2; -}; - - -#line 60 -struct Float4Val_0 -{ - Float4Struct_0 val_3; -}; - - -#line 63 -float Float4Val_run_0(Float4Val_0 this_1) -{ - - float _S8 = this_1.val_3.val_2.x + this_1.val_3.val_2.y; - -#line 66 - float _S9 = ReturnsZero_get_0(); - -#line 66 - return _S8 + _S9; -} - - -#line 66 -float U_S4main9Float4Val3rung2TCGP04main12IReturnsZerop0pfG0_ST4main11ReturnsZero1_SW4main11ReturnsZero4main12IReturnsZero_wtwrapper_0(AnyValue16 _S10) -{ - -#line 66 - float _S11 = Float4Val_run_0(_S10); - -#line 66 - return _S11; -} - - -#line 66 -float _S12(uint _S13, AnyValue16 _S14) -{ - -#line 66 - switch(_S13) - { - case 0U: - { - -#line 66 - float _S15 = U_S4main8FloatVal3rung2TCGP04main12IReturnsZerop0pfG0_ST4main11ReturnsZero1_SW4main11ReturnsZero4main12IReturnsZero_wtwrapper_0(_S14); - -#line 66 - return _S15; - } - case 1U: - { - -#line 66 - float _S16 = U_S4main9Float4Val3rung2TCGP04main12IReturnsZerop0pfG0_ST4main11ReturnsZero1_SW4main11ReturnsZero4main12IReturnsZero_wtwrapper_0(_S14); - -#line 66 - return _S16; - } - default: - { - -#line 66 - return 0.0f; - } - } - -#line 66 -} - - -#line 34 -[numthreads(1, 1, 1)] -void computeMain(uint3 dispatchThreadID_0 : SV_DispatchThreadID) -{ - int i_0 = int(0); - -#line 37 - float result_0 = 0.0f; - -#line 37 - for(;;) - { - -#line 37 - if(i_0 < int(2)) - { - } - else - { - -#line 37 - break; - } - UserDefinedPackedType_0 rawObj_0 = gObj_0.Load(i_0); - uint _S17 = _S3(rawObj_0.flags_0); - -#line 40 - uint _S18[int(2)] = { 0U, 1U }; - -#line 40 - AnyValue16 _S19 = packAnyValue16_0(rawObj_0); - float _S20 = _S12(_S18[_S17], _S19); - -#line 41 - float result_1 = result_0 + _S20; - -#line 37 - int i_1 = i_0 + int(1); - -#line 37 - i_0 = i_1; - -#line 37 - result_0 = result_1; - -#line 37 - } - -#line 43 - gOutputBuffer_0[int(0)] = result_0; - return; -} - From 0c9fa6b4bd92f51b41b5e01e98bab19fdf917ccb Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:49:37 -0400 Subject: [PATCH 045/105] Split files into 3 (collection utilities, specialization and lowering) --- source/slang/slang-emit.cpp | 3 +- source/slang/slang-ir-lower-generics.cpp | 2 +- .../slang/slang-ir-lower-typeflow-insts.cpp | 727 +++++++++++ ...nsts.h => slang-ir-lower-typeflow-insts.h} | 6 +- source/slang/slang-ir-specialize.cpp | 55 +- source/slang/slang-ir-typeflow-collection.cpp | 133 ++ source/slang/slang-ir-typeflow-collection.h | 53 + ...s.cpp => slang-ir-typeflow-specialize.cpp} | 1096 ++--------------- source/slang/slang-ir-typeflow-specialize.h | 9 + source/slang/slang-ir.h | 8 + .../dynamic-specialization-4.slang | 72 ++ .../dynamic-specialization-5.slang | 66 + 12 files changed, 1224 insertions(+), 1006 deletions(-) create mode 100644 source/slang/slang-ir-lower-typeflow-insts.cpp rename source/slang/{slang-ir-lower-dynamic-insts.h => slang-ir-lower-typeflow-insts.h} (64%) create mode 100644 source/slang/slang-ir-typeflow-collection.cpp create mode 100644 source/slang/slang-ir-typeflow-collection.h rename source/slang/{slang-ir-lower-dynamic-insts.cpp => slang-ir-typeflow-specialize.cpp} (75%) create mode 100644 source/slang/slang-ir-typeflow-specialize.h create mode 100644 tests/language-feature/dynamic-dispatch/dynamic-specialization-4.slang create mode 100644 tests/language-feature/dynamic-dispatch/dynamic-specialization-5.slang diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c6439343422..d7f45975e28 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -76,7 +76,6 @@ #include "slang-ir-lower-buffer-element-type.h" #include "slang-ir-lower-combined-texture-sampler.h" #include "slang-ir-lower-coopvec.h" -#include "slang-ir-lower-dynamic-insts.h" #include "slang-ir-lower-dynamic-resource-heap.h" #include "slang-ir-lower-enum-type.h" #include "slang-ir-lower-generics.h" @@ -86,6 +85,7 @@ #include "slang-ir-lower-reinterpret.h" #include "slang-ir-lower-result-type.h" #include "slang-ir-lower-tuple-types.h" +#include "slang-ir-lower-typeflow-insts.h" #include "slang-ir-metadata.h" #include "slang-ir-metal-legalize.h" #include "slang-ir-missing-return.h" @@ -113,6 +113,7 @@ #include "slang-ir-synthesize-active-mask.h" #include "slang-ir-transform-params-to-constref.h" #include "slang-ir-translate-global-varying-var.h" +#include "slang-ir-typeflow-specialize.h" #include "slang-ir-undo-param-copy.h" #include "slang-ir-uniformity.h" #include "slang-ir-user-type-hint.h" diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 8854693e7e7..6ad6e38ef25 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -9,12 +9,12 @@ #include "slang-ir-generics-lowering-context.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-layout.h" -#include "slang-ir-lower-dynamic-insts.h" #include "slang-ir-lower-existential.h" #include "slang-ir-lower-generic-call.h" #include "slang-ir-lower-generic-function.h" #include "slang-ir-lower-generic-type.h" #include "slang-ir-lower-tuple-types.h" +#include "slang-ir-lower-typeflow-insts.h" #include "slang-ir-specialize-dispatch.h" #include "slang-ir-specialize-dynamic-associatedtype-lookup.h" #include "slang-ir-ssa-simplification.h" diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp new file mode 100644 index 00000000000..ac418ade1f1 --- /dev/null +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -0,0 +1,727 @@ +#include "slang-ir-lower-typeflow-insts.h" + +#include "slang-ir-any-value-marshalling.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-insts.h" +#include "slang-ir-typeflow-collection.h" +#include "slang-ir-util.h" +#include "slang-ir-witness-table-wrapper.h" +#include "slang-ir.h" + +namespace Slang +{ +SlangInt calculateAnyValueSize(const HashSet& types) +{ + SlangInt maxSize = 0; + for (auto type : types) + { + auto size = getAnyValueSize(type); + if (size > maxSize) + maxSize = size; + } + return maxSize; +} + +IRAnyValueType* createAnyValueType(IRBuilder* builder, const HashSet& types) +{ + auto size = calculateAnyValueSize(types); + return builder->getAnyValueType(size); +} + +IRFunc* createDispatchFunc(IRFuncCollection* collection) +{ + // An effective func type should have been set during the dynamic-inst-lowering + // pass. + // + IRFuncType* dispatchFuncType = cast(collection->getFullType()); + + // Create a dispatch function with switch-case for each function + IRBuilder builder(collection->getModule()); + + // Consume the first parameter of the expected function type + List innerParamTypes; + for (auto paramType : dispatchFuncType->getParamTypes()) + innerParamTypes.add(paramType); + innerParamTypes.removeAt(0); // Remove the first parameter (ID) + + auto resultType = dispatchFuncType->getResultType(); + auto innerFuncType = builder.getFuncType(innerParamTypes, resultType); + + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(dispatchFuncType); + + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); + + auto idParam = builder.emitParam(builder.getUIntType()); + + // Create parameters for the original function arguments + List originalParams; + for (Index i = 0; i < innerParamTypes.getCount(); i++) + { + originalParams.add(builder.emitParam(innerParamTypes[i])); + } + + // Create default block + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + // Return a default-constructed value + auto defaultValue = builder.emitDefaultConstruct(resultType); + builder.emitReturn(defaultValue); + } + + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); + + // Create case blocks for each function + List caseValues; + List caseBlocks; + + UIndex funcSeqID = 0; + forEachInCollection( + collection, + [&](IRInst* funcInst) + { + auto funcId = funcSeqID++; + auto wrapperFunc = + emitWitnessTableWrapper(funcInst->getModule(), funcInst, innerFuncType); + + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + + List callArgs; + auto wrappedFuncType = as(wrapperFunc->getDataType()); + for (Index ii = 0; ii < originalParams.getCount(); ii++) + { + callArgs.add(originalParams[ii]); + } + + // Call the specific function + auto callResult = + builder.emitCallInst(wrappedFuncType->getResultType(), wrapperFunc, callArgs); + + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + builder.emitReturn(callResult); + } + + caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); + caseBlocks.add(caseBlock); + }); + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } + + // Create an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + idParam, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; +} + + +IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mapping, UInt defaultVal) +{ + // Create a function that maps input IDs to output IDs + IRBuilder builder(module); + + auto funcType = + builder.getFuncType(List({builder.getUIntType()}), builder.getUIntType()); + auto func = builder.createFunc(); + builder.setInsertInto(func); + func->setFullType(funcType); + + auto entryBlock = builder.emitBlock(); + builder.setInsertInto(entryBlock); + + auto param = builder.emitParam(builder.getUIntType()); + + // Create default block that returns 0 + auto defaultBlock = builder.emitBlock(); + builder.setInsertInto(defaultBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), defaultVal)); + + // Go back to entry block and create switch + builder.setInsertInto(entryBlock); + + // Create case blocks for each input table + List caseValues; + List caseBlocks; + + for (auto item : mapping) + { + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); + builder.emitReturn(builder.getIntValue(builder.getUIntType(), item.second)); + + caseValues.add(builder.getIntValue(builder.getUIntType(), item.first)); + caseBlocks.add(caseBlock); + } + + // Create flattened case arguments array + List flattenedCaseArgs; + for (Index i = 0; i < caseValues.getCount(); i++) + { + flattenedCaseArgs.add(caseValues[i]); + flattenedCaseArgs.add(caseBlocks[i]); + } + + // Emit an unreachable block for the break block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Go back to entry and emit switch + builder.setInsertInto(entryBlock); + builder.emitSwitch( + param, + unreachableBlock, + defaultBlock, + flattenedCaseArgs.getCount(), + flattenedCaseArgs.getBuffer()); + + return func; +} + +// This context lowers `IRGetTagFromSequentialID`, +// `IRGetTagForSuperCollection`, and `IRGetTagForMappedCollection` instructions, +// + +struct TagOpsLoweringContext : public InstPassBase +{ + TagOpsLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void lowerGetTagForSuperCollection(IRGetTagForSuperCollection* inst) + { + auto srcCollection = cast( + cast(inst->getOperand(0)->getDataType())->getOperand(0)); + auto destCollection = + cast(cast(inst->getDataType())->getOperand(0)); + + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + + List indices; + for (UInt i = 0; i < srcCollection->getOperandCount(); i++) + { + // Find in destCollection + auto srcElement = srcCollection->getOperand(i); + + bool found = false; + for (UInt j = 0; j < destCollection->getOperandCount(); j++) + { + auto destElement = destCollection->getOperand(j); + if (srcElement == destElement) + { + found = true; + indices.add(builder.getIntValue(builder.getUIntType(), j)); + break; // Found the index + } + } + + if (!found) + { + // destCollection must be a super-set + SLANG_UNEXPECTED("Element not found in destination collection"); + } + } + + // Create an array for the lookup + auto lookupArrayType = builder.getArrayType( + builder.getUIntType(), + builder.getIntValue(builder.getUIntType(), indices.getCount())); + auto lookupArray = + builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); + auto resultID = + builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); + inst->replaceUsesWith(resultID); + inst->removeAndDeallocate(); + } + + void lowerGetTagForMappedCollection(IRGetTagForMappedCollection* inst) + { + auto srcCollection = cast( + cast(inst->getOperand(0)->getDataType())->getOperand(0)); + auto destCollection = + cast(cast(inst->getDataType())->getOperand(0)); + auto key = cast(inst->getOperand(1)); + + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + + List indices; + for (UInt i = 0; i < srcCollection->getOperandCount(); i++) + { + // Find in destCollection + bool found = false; + auto srcElement = + findWitnessTableEntry(cast(srcCollection->getOperand(i)), key); + for (UInt j = 0; j < destCollection->getOperandCount(); j++) + { + auto destElement = destCollection->getOperand(j); + if (srcElement == destElement) + { + found = true; + indices.add(builder.getIntValue(builder.getUIntType(), j)); + break; // Found the index + } + } + + if (!found) + { + // destCollection must be a super-set + SLANG_UNEXPECTED("Element not found in destination collection"); + } + } + + // Create an array for the lookup + auto lookupArrayType = builder.getArrayType( + builder.getUIntType(), + builder.getIntValue(builder.getUIntType(), indices.getCount())); + auto lookupArray = + builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); + auto resultID = + builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); + inst->replaceUsesWith(resultID); + inst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_GetTagForSuperCollection: + lowerGetTagForSuperCollection(as(inst)); + break; + case kIROp_GetTagForMappedCollection: + lowerGetTagForMappedCollection(as(inst)); + break; + default: + break; + } + } + + void lowerFuncCollection(IRFuncCollection* collection) + { + IRBuilder builder(collection->getModule()); + if (collection->hasUses() && collection->getDataType() != nullptr) + { + auto dispatchFunc = createDispatchFunc(collection); + traverseUses( + collection, + [&](IRUse* use) + { + if (auto callInst = as(use->getUser())) + { + // If the call is a collection call, replace it with the dispatch function + if (callInst->getCallee() == collection) + { + IRBuilder callBuilder(callInst); + callBuilder.setInsertBefore(callInst); + callBuilder.replaceOperand(callInst->getCalleeUse(), dispatchFunc); + } + } + }); + } + } + + void processModule() + { + processInstsOfType( + kIROp_FuncCollection, + [&](IRFuncCollection* inst) { return lowerFuncCollection(inst); }); + + processAllInsts([&](IRInst* inst) { return processInst(inst); }); + } +}; + +// This context lowers `IRTypeCollection` and `IRFuncCollection` instructions +struct CollectionLoweringContext : public InstPassBase +{ + CollectionLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void lowerTypeCollection(IRTypeCollection* collection) + { + HashSet types; + for (UInt i = 0; i < collection->getOperandCount(); i++) + { + if (auto type = as(collection->getOperand(i))) + { + types.add(type); + } + } + + IRBuilder builder(collection->getModule()); + auto anyValueType = createAnyValueType(&builder, types); + collection->replaceUsesWith(anyValueType); + } + + void processModule() + { + processInstsOfType( + kIROp_TypeCollection, + [&](IRTypeCollection* inst) { return lowerTypeCollection(inst); }); + } +}; + +void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + CollectionLoweringContext context(module); + context.processModule(); +} + +struct SequentialIDTagLoweringContext : public InstPassBase +{ + SequentialIDTagLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) + { + SLANG_UNUSED(cast(inst->getOperand(0))); + auto srcSeqID = inst->getOperand(1); + + Dictionary mapping; + + // Map from sequential ID to unique ID + auto destCollection = + cast(cast(inst->getDataType())->getOperand(0)); + + UIndex dstSeqID = 0; + forEachInCollection( + destCollection, + [&](IRInst* table) + { + // Get unique ID for the witness table + SLANG_UNUSED(cast(table)); + auto outputId = dstSeqID++; + auto seqDecoration = table->findDecoration(); + if (seqDecoration) + { + auto inputId = seqDecoration->getSequentialID(); + mapping[inputId] = outputId; // Map ID to itself for now + } + }); + + IRBuilder builder(inst); + builder.setInsertAfter(inst); + UInt defaultID = dstSeqID - 1; // Default to last available conformance. + auto translatedID = builder.emitCallInst( + inst->getDataType(), + createIntegerMappingFunc(builder.getModule(), mapping, defaultID), + List({srcSeqID})); + + inst->replaceUsesWith(translatedID); + inst->removeAndDeallocate(); + } + + + void lowerGetSequentialIDFromTag(IRGetSequentialIDFromTag* inst) + { + SLANG_UNUSED(cast(inst->getOperand(0))); + auto srcTagInst = inst->getOperand(1); + + Dictionary mapping; + + // Map from sequential ID to unique ID + auto destCollection = cast( + cast(srcTagInst->getDataType())->getOperand(0)); + + UIndex dstSeqID = 0; + forEachInCollection( + destCollection, + [&](IRInst* table) + { + // Get unique ID for the witness table + SLANG_UNUSED(cast(table)); + auto outputId = dstSeqID++; + auto seqDecoration = table->findDecoration(); + if (seqDecoration) + { + auto inputId = seqDecoration->getSequentialID(); + mapping.add({outputId, inputId}); + } + }); + + IRBuilder builder(inst); + builder.setInsertAfter(inst); + auto translatedID = builder.emitCallInst( + inst->getDataType(), + createIntegerMappingFunc(builder.getModule(), mapping, 0), + List({srcTagInst})); + + inst->replaceUsesWith(translatedID); + inst->removeAndDeallocate(); + } + + void processModule() + { + processInstsOfType( + kIROp_GetTagFromSequentialID, + [&](IRGetTagFromSequentialID* inst) { return lowerGetTagFromSequentialID(inst); }); + + processInstsOfType( + kIROp_GetSequentialIDFromTag, + [&](IRGetSequentialIDFromTag* inst) { return lowerGetSequentialIDFromTag(inst); }); + } +}; + +void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + SequentialIDTagLoweringContext context(module); + context.processModule(); +} + +void lowerTagInsts(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + TagOpsLoweringContext tagContext(module); + tagContext.processModule(); +} + +struct TagTypeLoweringContext : public InstPassBase +{ + TagTypeLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void processModule() + { + processInstsOfType( + kIROp_CollectionTagType, + [&](IRCollectionTagType* inst) + { + IRBuilder builder(inst->getModule()); + inst->replaceUsesWith(builder.getUIntType()); + }); + } +}; + +void lowerTagTypes(IRModule* module) +{ + TagTypeLoweringContext context(module); + context.processModule(); +} + +// This context lowers `CastInterfaceToTaggedUnionPtr` and +// `CastTaggedUnionToInterfacePtr` by finding all `IRLoad` and +// `IRStore` uses of these insts, and upcasting the tagged-union +// tuple to the the interface-based tuple (of the loaded inst or before +// storing the val, as necessary) +// +struct TaggedUnionLoweringContext : public InstPassBase +{ + TaggedUnionLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + IRInst* convertToTaggedUnion( + IRBuilder* builder, + IRInst* val, + IRInst* interfaceType, + IRInst* targetType) + { + auto baseInterfaceValue = val; + auto witnessTable = builder->emitExtractExistentialWitnessTable(baseInterfaceValue); + auto tableID = builder->emitGetSequentialIDInst(witnessTable); + + auto taggedUnionTupleType = cast(targetType); + + List getTagOperands; + getTagOperands.add(interfaceType); + getTagOperands.add(tableID); + auto tableTag = builder->emitIntrinsicInst( + (IRType*)taggedUnionTupleType->getOperand(0), + kIROp_GetTagFromSequentialID, + getTagOperands.getCount(), + getTagOperands.getBuffer()); + + return builder->emitMakeTuple( + {tableTag, + builder->emitReinterpret( + (IRType*)taggedUnionTupleType->getOperand(1), + builder->emitExtractExistentialValue( + (IRType*)builder->emitExtractExistentialType(baseInterfaceValue), + baseInterfaceValue))}); + } + + void lowerCastInterfaceToTaggedUnionPtr(IRCastInterfaceToTaggedUnionPtr* inst) + { + // Find all uses of the inst + traverseUses( + inst, + [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_Load: + { + auto baseInterfacePtr = inst->getOperand(0); + auto baseInterfaceType = as( + as(baseInterfacePtr->getDataType())->getValueType()); + + // Rewrite the load to use the original ptr and load + // an interface-typed object. + // + IRBuilder builder(module); + builder.setInsertAfter(user); + builder.replaceOperand(user->getOperands() + 0, baseInterfacePtr); + builder.replaceOperand(&user->typeUse, baseInterfaceType); + + // Then, we'll rewrite it. + List oldUses; + traverseUses(user, [&](IRUse* oldUse) { oldUses.add(oldUse); }); + + auto newVal = convertToTaggedUnion( + &builder, + user, + baseInterfaceType, + as(inst->getDataType())->getValueType()); + for (auto oldUse : oldUses) + { + builder.replaceOperand(oldUse, newVal); + } + break; + } + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + auto baseInterfacePtr = inst->getOperand(0); + auto baseInterfaceType = + as((baseInterfacePtr->getDataType())->getOperand(0)); + + IRBuilder builder(module); + builder.setInsertAfter(user); + builder.replaceOperand(user->getOperands() + 0, baseInterfacePtr); + builder.replaceOperand(&user->typeUse, baseInterfaceType); + + // Then, we'll rewrite it. + List oldUses; + traverseUses(user, [&](IRUse* oldUse) { oldUses.add(oldUse); }); + + auto newVal = convertToTaggedUnion( + &builder, + user, + baseInterfaceType, + as(inst->getDataType())->getValueType()); + for (auto oldUse : oldUses) + { + builder.replaceOperand(oldUse, newVal); + } + break; + } + default: + SLANG_UNEXPECTED("Unexpected user of CastInterfaceToTaggedUnionPtr"); + } + }); + + SLANG_ASSERT(!inst->hasUses()); + inst->removeAndDeallocate(); + } + + void lowerCastTaggedUnionToInterfacePtr(IRCastTaggedUnionToInterfacePtr* inst) + { + SLANG_UNUSED(inst); + SLANG_UNEXPECTED("Unexpected inst of CastTaggedUnionToInterfacePtr"); + } + + IRType* convertToTupleType(IRCollectionTaggedUnionType* taggedUnion) + { + // Replace type with Tuple + IRBuilder builder(module); + builder.setInsertInto(module); + + auto typeCollection = cast(taggedUnion->getOperand(0)); + auto tableCollection = cast(taggedUnion->getOperand(1)); + + if (getCollectionCount(typeCollection) == 1) + return builder.getTupleType(List( + {(IRType*)makeTagType(tableCollection), + (IRType*)getCollectionElement(typeCollection, 0)})); + + return builder.getTupleType( + List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); + } + + bool processModule() + { + // First, we'll lower all CollectionTaggedUnionType insts + // into tuples. + // + processInstsOfType( + kIROp_CollectionTaggedUnionType, + [&](IRCollectionTaggedUnionType* inst) + { + inst->replaceUsesWith(convertToTupleType(inst)); + inst->removeAndDeallocate(); + }); + + bool hasCastInsts = false; + processInstsOfType( + kIROp_CastInterfaceToTaggedUnionPtr, + [&](IRCastInterfaceToTaggedUnionPtr* inst) + { + hasCastInsts = true; + return lowerCastInterfaceToTaggedUnionPtr(inst); + }); + + processInstsOfType( + kIROp_CastTaggedUnionToInterfacePtr, + [&](IRCastTaggedUnionToInterfacePtr* inst) + { + hasCastInsts = true; + return lowerCastTaggedUnionToInterfacePtr(inst); + }); + + return hasCastInsts; + } +}; + +bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + + TaggedUnionLoweringContext context(module); + return context.processModule(); +} +}; // namespace Slang \ No newline at end of file diff --git a/source/slang/slang-ir-lower-dynamic-insts.h b/source/slang/slang-ir-lower-typeflow-insts.h similarity index 64% rename from source/slang/slang-ir-lower-dynamic-insts.h rename to source/slang/slang-ir-lower-typeflow-insts.h index 169d1be851e..c3be5d27339 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.h +++ b/source/slang/slang-ir-lower-typeflow-insts.h @@ -1,13 +1,9 @@ -// slang-ir-lower-dynamic-insts.h +// slang-ir-typeflow-specialize.h #pragma once -#include "../core/slang-linked-list.h" -#include "../core/slang-smart-pointer.h" #include "slang-ir.h" namespace Slang { -// Main entry point for the pass -bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink); void lowerTypeCollections(IRModule* module, DiagnosticSink* sink); void lowerTagInsts(IRModule* module, DiagnosticSink* sink); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 7ad9923ec03..e04ea1db552 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -5,11 +5,12 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-insts.h" -#include "slang-ir-lower-dynamic-insts.h" #include "slang-ir-lower-witness-lookup.h" #include "slang-ir-peephole.h" #include "slang-ir-sccp.h" #include "slang-ir-ssa-simplification.h" +#include "slang-ir-typeflow-collection.h" +#include "slang-ir-typeflow-specialize.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -849,7 +850,45 @@ struct SpecializationContext auto witnessTable = as(lookupInst->getWitnessTable()); if (!witnessTable) { - return false; + if (auto collection = as(lookupInst->getWitnessTable())) + { + auto requirementKey = lookupInst->getRequirementKey(); + + HashSet satisfyingValSet; + bool skipSpecialization = false; + forEachInCollection( + collection, + [&](IRInst* instElement) + { + if (auto table = as(instElement)) + { + if (auto satisfyingVal = findWitnessVal(table, requirementKey)) + { + satisfyingValSet.add(satisfyingVal); + return; + } + } + + // If we reach here, we didn't find a satisfying value. + skipSpecialization = true; + }); + + if (!skipSpecialization) + { + CollectionBuilder cBuilder(lookupInst->getModule()); + auto newCollection = cBuilder.makeSet(satisfyingValSet); + addUsersToWorkList(lookupInst); + lookupInst->replaceUsesWith(newCollection); + lookupInst->removeAndDeallocate(); + return true; + } + else + return false; + } + else + { + return false; + } } // Because we have a concrete witness table, we can @@ -1179,7 +1218,7 @@ struct SpecializationContext // if (options.lowerWitnessLookups) { - iterChanged = lowerDynamicInsts(module, sink); + iterChanged = specializeDynamicInsts(module, sink); if (iterChanged) { // We'll write out the specialization info to an inst, @@ -3090,7 +3129,6 @@ void finalizeSpecialization(IRModule* module) } } - // DUPLICATE: merge. static bool isDynamicGeneric(IRInst* callee) { @@ -3117,15 +3155,6 @@ static bool isDynamicGeneric(IRInst* callee) return false; } -static IRCollectionTagType* makeTagType(IRCollectionBase* collection) -{ - IRInst* collectionInst = collection; - // Create the tag type from the collection - IRBuilder builder(collection->getModule()); - return as( - builder.emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collectionInst)); -} - static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) { auto generic = cast(specializeInst->getBase()); diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-collection.cpp new file mode 100644 index 00000000000..54b9b113653 --- /dev/null +++ b/source/slang/slang-ir-typeflow-collection.cpp @@ -0,0 +1,133 @@ +#include "slang-ir-typeflow-collection.h" + +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ + +IRCollectionTagType* makeTagType(IRCollectionBase* collection) +{ + IRInst* collectionInst = collection; + // Create the tag type from the collection + IRBuilder builder(collection->getModule()); + return as( + builder.emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collectionInst)); +} + +UCount getCollectionCount(IRCollectionBase* collection) +{ + if (!collection) + return 0; + return collection->getOperandCount(); +} + +UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) +{ + auto typeCollection = taggedUnion->getOperand(0); + return getCollectionCount(as(typeCollection)); +} + +UCount getCollectionCount(IRCollectionTagType* tagType) +{ + auto collection = tagType->getOperand(0); + return getCollectionCount(as(collection)); +} + +IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) +{ + if (!collection || index >= collection->getOperandCount()) + return nullptr; + return collection->getOperand(index); +} + +IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) +{ + auto typeCollection = collectionTagType->getOperand(0); + return getCollectionElement(as(typeCollection), index); +} + +CollectionBuilder::CollectionBuilder(IRModule* module) + : module(module) +{ + this->uniqueIds = module->getUniqueIdMap(); +} + +UInt CollectionBuilder::getUniqueID(IRInst* inst) +{ + auto existingId = uniqueIds->tryGetValue(inst); + if (existingId) + return *existingId; + + auto id = uniqueIds->getCount(); + uniqueIds->add(inst, id); + return id; +} + +// Helper methods for creating canonical collections +IRCollectionBase* CollectionBuilder::createCollection(IROp op, const HashSet& elements) +{ + SLANG_ASSERT( + op == kIROp_TypeCollection || op == kIROp_FuncCollection || op == kIROp_TableCollection || + op == kIROp_GenericCollection); + + if (elements.getCount() == 0) + return nullptr; + + // Verify that all operands are global instructions + for (auto element : elements) + if (element->getParent()->getOp() != kIROp_ModuleInst) + SLANG_ASSERT_FAILURE("createCollection called with non-global operands"); + + List sortedElements; + for (auto element : elements) + sortedElements.add(element); + + // Sort elements by their unique IDs to ensure canonical ordering + sortedElements.sort( + [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); + + // Create the collection instruction + IRBuilder builder(module); + builder.setInsertInto(module); + + // Use makeTuple as a temporary implementation until IRCollection is available + return as(builder.emitIntrinsicInst( + nullptr, + op, + sortedElements.getCount(), + sortedElements.getBuffer())); +} + +IROp CollectionBuilder::getCollectionTypeForInst(IRInst* inst) +{ + if (as(inst)) + return kIROp_GenericCollection; + + if (as(inst->getDataType())) + return kIROp_TypeCollection; + else if (as(inst->getDataType())) + return kIROp_FuncCollection; + else if (as(inst) && !as(inst)) + return kIROp_TypeCollection; + else if (as(inst->getDataType())) + return kIROp_TableCollection; + else + return kIROp_Invalid; // Return invalid IROp when not supported +} + +// Factory methods for PropagationInfo +IRCollectionBase* CollectionBuilder::makeSingletonSet(IRInst* value) +{ + HashSet singleSet; + singleSet.add(value); + return createCollection(getCollectionTypeForInst(value), singleSet); +} + +IRCollectionBase* CollectionBuilder::makeSet(const HashSet& values) +{ + SLANG_ASSERT(values.getCount() > 0); + return createCollection(getCollectionTypeForInst(*values.begin()), values); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-typeflow-collection.h b/source/slang/slang-ir-typeflow-collection.h new file mode 100644 index 00000000000..ca399a4b208 --- /dev/null +++ b/source/slang/slang-ir-typeflow-collection.h @@ -0,0 +1,53 @@ +// slang-ir-typeflow-collection.h +#pragma once +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ + +IRCollectionTagType* makeTagType(IRCollectionBase* collection); + +UCount getCollectionCount(IRCollectionBase* collection); +UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion); +UCount getCollectionCount(IRCollectionTagType* tagType); + +IRInst* getCollectionElement(IRCollectionBase* collection, UInt index); +IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index); + +// Helper to iterate over collection elements + +template +void forEachInCollection(IRCollectionBase* info, F func) +{ + for (UInt i = 0; i < info->getOperandCount(); ++i) + func(info->getOperand(i)); +} + +template +void forEachInCollection(IRCollectionTagType* tagType, F func) +{ + forEachInCollection(as(tagType->getOperand(0)), func); +} + +struct CollectionBuilder +{ + CollectionBuilder(IRModule* module); + + UInt getUniqueID(IRInst* inst); + + // Helper methods for creating canonical collections + IRCollectionBase* createCollection(IROp op, const HashSet& elements); + IROp getCollectionTypeForInst(IRInst* inst); + IRCollectionBase* makeSingletonSet(IRInst* value); + IRCollectionBase* makeSet(const HashSet& values); + +private: + // Reference to parent module + IRModule* module; + + // Unique ID assignment for functions and witness tables + Dictionary* uniqueIds; +}; + +} // namespace Slang diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-typeflow-specialize.cpp similarity index 75% rename from source/slang/slang-ir-lower-dynamic-insts.cpp rename to source/slang/slang-ir-typeflow-specialize.cpp index c638c613604..eabebdf3085 100644 --- a/source/slang/slang-ir-lower-dynamic-insts.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -1,10 +1,11 @@ -#include "slang-ir-lower-dynamic-insts.h" +#include "slang-ir-typeflow-specialize.h" #include "slang-ir-any-value-marshalling.h" #include "slang-ir-clone.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-insts.h" #include "slang-ir-specialize.h" +#include "slang-ir-typeflow-collection.h" #include "slang-ir-util.h" #include "slang-ir-witness-table-wrapper.h" #include "slang-ir.h" @@ -210,29 +211,6 @@ bool areInfosEqual(IRInst* a, IRInst* b) return a == b; } -// Helper to iterate over collection elements -template -void forEachInCollection(IRCollectionBase* info, F func) -{ - for (UInt i = 0; i < info->getOperandCount(); ++i) - func(info->getOperand(i)); -} - -template -void forEachInCollection(IRCollectionTagType* tagType, F func) -{ - forEachInCollection(as(tagType->getOperand(0)), func); -} - -static IRInst* findEntryInConcreteTable(IRInst* witnessTable, IRInst* key) -{ - if (auto concreteTable = as(witnessTable)) - for (auto entry : concreteTable->getEntries()) - if (entry->getRequirementKey() == key) - return entry->getSatisfyingVal(); - return nullptr; // Not found -} - struct WorkQueue { List enqueueList; @@ -260,119 +238,8 @@ struct WorkQueue } }; - -// TODO: Move to utilities - -IRCollectionTagType* makeTagType(IRCollectionBase* collection) +struct TypeFlowSpecializationContext { - IRInst* collectionInst = collection; - // Create the tag type from the collection - IRBuilder builder(collection->getModule()); - return as( - builder.emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collectionInst)); -} - -UCount getCollectionCount(IRCollectionBase* collection) -{ - if (!collection) - return 0; - return collection->getOperandCount(); -} - -UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) -{ - auto typeCollection = taggedUnion->getOperand(0); - return getCollectionCount(as(typeCollection)); -} - -UCount getCollectionCount(IRCollectionTagType* tagType) -{ - auto collection = tagType->getOperand(0); - return getCollectionCount(as(collection)); -} - -IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) -{ - if (!collection || index >= collection->getOperandCount()) - return nullptr; - return collection->getOperand(index); -} - -IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) -{ - auto typeCollection = collectionTagType->getOperand(0); - return getCollectionElement(as(typeCollection), index); -} - - -struct DynamicInstLoweringContext -{ - // Helper methods for creating canonical collections - IRCollectionBase* createCollection(IROp op, const HashSet& elements) - { - SLANG_ASSERT( - op == kIROp_TypeCollection || op == kIROp_FuncCollection || - op == kIROp_TableCollection || op == kIROp_GenericCollection); - - if (elements.getCount() == 0) - return nullptr; - - // Verify that all operands are global instructions - for (auto element : elements) - if (element->getParent()->getOp() != kIROp_ModuleInst) - SLANG_ASSERT_FAILURE("createCollection called with non-global operands"); - - List sortedElements; - for (auto element : elements) - sortedElements.add(element); - - // Sort elements by their unique IDs to ensure canonical ordering - sortedElements.sort( - [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); - - // Create the collection instruction - IRBuilder builder(module); - builder.setInsertInto(module); - - // Use makeTuple as a temporary implementation until IRCollection is available - return as(builder.emitIntrinsicInst( - nullptr, - op, - sortedElements.getCount(), - sortedElements.getBuffer())); - } - - IROp getCollectionTypeForInst(IRInst* inst) - { - if (as(inst)) - return kIROp_GenericCollection; - - if (as(inst->getDataType())) - return kIROp_TypeCollection; - else if (as(inst->getDataType())) - return kIROp_FuncCollection; - else if (as(inst) && !as(inst)) - return kIROp_TypeCollection; - else if (as(inst->getDataType())) - return kIROp_TableCollection; - else - return kIROp_Invalid; // Return invalid IROp when not supported - } - - // Factory methods for PropagationInfo - IRCollectionBase* makeSingletonSet(IRInst* value) - { - HashSet singleSet; - singleSet.add(value); - return createCollection(getCollectionTypeForInst(value), singleSet); - } - - IRCollectionBase* makeSet(const HashSet& values) - { - SLANG_ASSERT(values.getCount() > 0); - return createCollection(getCollectionTypeForInst(*values.begin()), values); - } - IRCollectionTaggedUnionType* makeExistential(IRTableCollection* tableCollection) { HashSet typeSet; @@ -385,7 +252,7 @@ struct DynamicInstLoweringContext typeSet.add(table->getConcreteType()); }); - auto typeCollection = createCollection(kIROp_TypeCollection, typeSet); + auto typeCollection = cBuilder.createCollection(kIROp_TypeCollection, typeSet); // Create the tagged union type IRBuilder builder(module); @@ -446,7 +313,7 @@ struct DynamicInstLoweringContext // TODO: We really should return something like Singleton(collectionInst) here // instead of directly returning the collection. // - return makeSingletonSet(inst); + return cBuilder.makeSingletonSet(inst); } else return none(); @@ -816,8 +683,8 @@ struct DynamicInstLoweringContext auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeExistential( - as(createCollection(kIROp_TableCollection, tables))); + return makeExistential(as( + cBuilder.createCollection(kIROp_TableCollection, tables))); else return none(); } @@ -845,7 +712,7 @@ struct DynamicInstLoweringContext return makeUnbounded(); if (as(witnessTable)) - return makeExistential(as(makeSingletonSet(witnessTable))); + return makeExistential(as(cBuilder.makeSingletonSet(witnessTable))); if (auto collectionTag = as(witnessTableInfo)) return makeExistential(cast(collectionTag->getOperand(0))); @@ -920,7 +787,7 @@ struct DynamicInstLoweringContext auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) return makeExistential(as( - createCollection(kIROp_TableCollection, tables))); + cBuilder.createCollection(kIROp_TableCollection, tables))); else return none(); } @@ -946,8 +813,8 @@ struct DynamicInstLoweringContext { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeExistential( - as(createCollection(kIROp_TableCollection, tables))); + return makeExistential(as( + cBuilder.createCollection(kIROp_TableCollection, tables))); else return none(); } @@ -1066,8 +933,9 @@ struct DynamicInstLoweringContext HashSet results; forEachInCollection( cast(tagType->getOperand(0)), - [&](IRInst* table) { results.add(findEntryInConcreteTable(table, key)); }); - return makeTagType(makeSet(results)); + [&](IRInst* table) + { results.add(findWitnessTableEntry(cast(table), key)); }); + return makeTagType(cBuilder.makeSet(results)); } SLANG_UNEXPECTED("Unexpected witness table info type in analyzeLookupWitnessMethod"); @@ -1289,9 +1157,9 @@ struct DynamicInstLoweringContext }); if (needsTag) - return makeTagType(makeSet(specializedSet)); + return makeTagType(cBuilder.makeSet(specializedSet)); else - return makeSet(specializedSet); + return cBuilder.makeSet(specializedSet); } SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); @@ -1345,7 +1213,7 @@ struct DynamicInstLoweringContext updateInfo( context, param, - makeTagType(makeSingletonSet(arg)), + makeTagType(cBuilder.makeSingletonSet(arg)), true, workQueue); } @@ -1842,7 +1710,7 @@ struct DynamicInstLoweringContext forEachInCollection(collection1, [&](IRInst* value) { allValues.add(value); }); forEachInCollection(collection2, [&](IRInst* value) { allValues.add(value); }); - return as(createCollection( + return as(cBuilder.createCollection( collection1->getOp(), allValues)); // Create a new collection with the union of values } @@ -1918,7 +1786,7 @@ struct DynamicInstLoweringContext if (isGlobalInst(inst) && (!as(inst) && (as(inst) || as(inst) || as(inst)))) - return makeSingletonSet(inst); + return cBuilder.makeSingletonSet(inst); auto instType = inst->getDataType(); if (isGlobalInst(inst)) @@ -1939,7 +1807,7 @@ struct DynamicInstLoweringContext return none(); // Default case, no propagation info } - bool lowerInstsInBlock(IRInst* context, IRBlock* block) + bool specializeInstsInBlock(IRInst* context, IRBlock* block) { List instsToLower; bool hasChanges = false; @@ -1948,13 +1816,13 @@ struct DynamicInstLoweringContext for (auto inst : instsToLower) { - hasChanges |= lowerInst(context, inst); + hasChanges |= specializeInst(context, inst); } return hasChanges; } - bool lowerStructType(IRStructType* structType) + bool specializeStructType(IRStructType* structType) { bool hasChanges = false; for (auto field : structType->getFields()) @@ -1964,18 +1832,18 @@ struct DynamicInstLoweringContext if (!info) continue; - auto loweredFieldType = getLoweredType(info); - if (loweredFieldType != field->getFieldType()) + auto specializedFieldType = getLoweredType(info); + if (specializedFieldType != field->getFieldType()) { hasChanges = true; - field->setFieldType(loweredFieldType); + field->setFieldType(specializedFieldType); } } return hasChanges; } - bool lowerFunc(IRFunc* func) + bool specializeFunc(IRFunc* func) { // Don't make any changes to non-global or intrinsic functions if (!isGlobalInst(func) || isIntrinsic(func)) @@ -1983,7 +1851,7 @@ struct DynamicInstLoweringContext bool hasChanges = false; for (auto block : func->getBlocks()) - hasChanges |= lowerInstsInBlock(func, block); + hasChanges |= specializeInstsInBlock(func, block); for (auto block : func->getBlocks()) { @@ -2025,10 +1893,10 @@ struct DynamicInstLoweringContext { if (!as(returnInst->getVal()->getDataType())) { - if (auto loweredType = getLoweredType(getFuncReturnInfo(func))) + if (auto specializedType = getLoweredType(getFuncReturnInfo(func))) { auto newReturnVal = - upcastCollection(func, returnInst->getVal(), loweredType); + upcastCollection(func, returnInst->getVal(), specializedType); if (newReturnVal != returnInst->getVal()) { // Replace the return value with the reinterpreted value @@ -2066,13 +1934,13 @@ struct DynamicInstLoweringContext bool hasChanges = false; // Lower struct types first so that data access can be - // marshalled properly during func lowering. + // marshalled properly during func specializeing. // for (auto structType : structsToProcess) - hasChanges |= lowerStructType(structType); + hasChanges |= specializeStructType(structType); for (auto func : funcsToProcess) - hasChanges |= lowerFunc(func); + hasChanges |= specializeFunc(func); return hasChanges; } @@ -2088,9 +1956,9 @@ struct DynamicInstLoweringContext if (auto ptrType = as(info)) { IRBuilder builder(module); - if (auto loweredValueType = getLoweredType(ptrType->getValueType())) + if (auto specializedValueType = getLoweredType(ptrType->getValueType())) { - return builder.getPtrTypeWithAddressSpace((IRType*)loweredValueType, ptrType); + return builder.getPtrTypeWithAddressSpace((IRType*)specializedValueType, ptrType); } else return nullptr; @@ -2099,10 +1967,10 @@ struct DynamicInstLoweringContext if (auto arrayType = as(info)) { IRBuilder builder(module); - if (auto loweredElementType = getLoweredType(arrayType->getElementType())) + if (auto specializedElementType = getLoweredType(arrayType->getElementType())) { return builder.getArrayType( - (IRType*)loweredElementType, + (IRType*)specializedElementType, arrayType->getElementCount()); } else @@ -2136,8 +2004,8 @@ struct DynamicInstLoweringContext if (as(info) || as(info)) { - // Don't lower these collections.. they should be used through - // tag types, or be processed out during lowering. + // Don't specialize these collections.. they should be used through + // tag types, or be processed out during specializeing. // return nullptr; } @@ -2160,56 +2028,59 @@ struct DynamicInstLoweringContext if (auto info = tryGetInfo(context, inst)) { - if (auto loweredType = getLoweredType(info)) + if (auto specializedType = getLoweredType(info)) { - if (loweredType == inst->getDataType()) + if (specializedType == inst->getDataType()) return false; // No change - inst->setFullType(loweredType); + inst->setFullType(specializedType); return true; } } return false; } - bool lowerInst(IRInst* context, IRInst* inst) + bool specializeInst(IRInst* context, IRInst* inst) { switch (inst->getOp()) { case kIROp_LookupWitnessMethod: - return lowerLookupWitnessMethod(context, as(inst)); + return specializeLookupWitnessMethod(context, as(inst)); case kIROp_ExtractExistentialWitnessTable: - return lowerExtractExistentialWitnessTable( + return specializeExtractExistentialWitnessTable( context, as(inst)); case kIROp_ExtractExistentialType: - return lowerExtractExistentialType(context, as(inst)); + return specializeExtractExistentialType(context, as(inst)); case kIROp_ExtractExistentialValue: - return lowerExtractExistentialValue(context, as(inst)); + return specializeExtractExistentialValue(context, as(inst)); case kIROp_Call: - return lowerCall(context, as(inst)); + return specializeCall(context, as(inst)); case kIROp_MakeExistential: - return lowerMakeExistential(context, as(inst)); + return specializeMakeExistential(context, as(inst)); case kIROp_MakeStruct: - return lowerMakeStruct(context, as(inst)); + return specializeMakeStruct(context, as(inst)); case kIROp_CreateExistentialObject: - return lowerCreateExistentialObject(context, as(inst)); + return specializeCreateExistentialObject(context, as(inst)); case kIROp_RWStructuredBufferLoad: case kIROp_StructuredBufferLoad: - return lowerStructuredBufferLoad(context, inst); + return specializeStructuredBufferLoad(context, inst); case kIROp_Specialize: - return lowerSpecialize(context, as(inst)); + return specializeSpecialize(context, as(inst)); case kIROp_GetValueFromBoundInterface: - return lowerGetValueFromBoundInterface(context, as(inst)); + return specializeGetValueFromBoundInterface( + context, + as(inst)); case kIROp_Load: - return lowerLoad(context, inst); + return specializeLoad(context, inst); case kIROp_Store: - return lowerStore(context, as(inst)); + return specializeStore(context, as(inst)); case kIROp_GetSequentialID: - return lowerGetSequentialID(context, as(inst)); + return specializeGetSequentialID(context, as(inst)); case kIROp_IsType: - return lowerIsType(context, as(inst)); + return specializeIsType(context, as(inst)); default: { + // Default case: replace inst type with specialized type (if available) if (tryGetInfo(context, inst)) return replaceType(context, inst); return false; @@ -2217,13 +2088,12 @@ struct DynamicInstLoweringContext } } - bool lowerLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) + bool specializeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { // Handle trivial case. if (auto witnessTable = as(inst->getWitnessTable())) { - inst->replaceUsesWith( - findEntryInConcreteTable(witnessTable, inst->getRequirementKey())); + inst->replaceUsesWith(findWitnessTableEntry(witnessTable, inst->getRequirementKey())); inst->removeAndDeallocate(); return true; } @@ -2276,7 +2146,7 @@ struct DynamicInstLoweringContext return false; } - bool lowerExtractExistentialWitnessTable( + bool specializeExtractExistentialWitnessTable( IRInst* context, IRExtractExistentialWitnessTable* inst) { @@ -2300,7 +2170,7 @@ struct DynamicInstLoweringContext } else { - // Replace with GetElement(loweredInst, 0) -> uint + // Replace with GetElement(specializedInst, 0) -> uint auto operand = inst->getOperand(0); auto element = builder.emitGetTupleElement((IRType*)collectionTagType, operand, 0); inst->replaceUsesWith(element); @@ -2309,7 +2179,7 @@ struct DynamicInstLoweringContext } } - bool lowerExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) + bool specializeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { SLANG_UNUSED(context); @@ -2330,7 +2200,7 @@ struct DynamicInstLoweringContext return false; } - bool lowerExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) + bool specializeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { auto info = tryGetInfo(context, inst); auto collectionTagType = as(info); @@ -2428,7 +2298,7 @@ struct DynamicInstLoweringContext } // If this is a collection, we need to create a new collection with the new type - auto newCollection = createCollection(collection->getOp(), collectionElements); + auto newCollection = cBuilder.createCollection(collection->getOp(), collectionElements); return (IRType*)newCollection; } else if (currentType == newType) @@ -2469,7 +2339,8 @@ struct DynamicInstLoweringContext collectionElements.add(newType); // If this is a collection, we need to create a new collection with the new type - auto newCollection = createCollection(kIROp_TypeCollection, collectionElements); + auto newCollection = + cBuilder.createCollection(kIROp_TypeCollection, collectionElements); return (IRType*)newCollection; } } @@ -2604,7 +2475,7 @@ struct DynamicInstLoweringContext IRInst* getCalleeForContext(IRInst* context) { if (this->contextsToLower.contains(context)) - return context; // Not lowered yet. + return context; // Not specialized yet. if (this->loweredContexts.containsKey(context)) return this->loweredContexts[context]; @@ -2632,7 +2503,7 @@ struct DynamicInstLoweringContext return callArgs; } - bool lowerCallToDynamicGeneric(IRInst* context, IRCall* inst) + bool specializeCallToDynamicGeneric(IRInst* context, IRCall* inst) { auto specializedCallee = as(inst->getCallee()); auto calleeInfo = tryGetInfo(context, specializedCallee); @@ -2663,7 +2534,7 @@ struct DynamicInstLoweringContext else { // If it's a witness table, we need to handle it differently - // For now, we will not lower this case. + // For now, we will not specialize this case. SLANG_UNEXPECTED("Unhandled type-flow-collection in dynamic generic call"); } } @@ -2692,7 +2563,7 @@ struct DynamicInstLoweringContext } } - bool lowerCall(IRInst* context, IRCall* inst) + bool specializeCall(IRInst* context, IRCall* inst) { auto callee = inst->getCallee(); IRInst* calleeTagInst = nullptr; @@ -2700,7 +2571,7 @@ struct DynamicInstLoweringContext // This is a bit of a workaround for specialized callee's // whose function types haven't been specialized yet (can // occur for concrete IRSpecialize insts that are created - // during the lowering process). + // during the specializeing process). // maybeSpecializeCalleeType(callee); @@ -2724,7 +2595,7 @@ struct DynamicInstLoweringContext } // If by this point, we haven't resolved our callee into a global inst ( - // either a collection or a single function), then we can't lower it (likely unbounded) + // either a collection or a single function), then we can't specialize it (likely unbounded) // if (!isGlobalInst(callee) || isIntrinsic(callee)) return false; @@ -2743,7 +2614,7 @@ struct DynamicInstLoweringContext // Determine a new callee. auto calleeCollection = as(callee); if (!calleeCollection) - newCallee = callee; // Not a collection, no need to lower + newCallee = callee; // Not a collection, no need to specialize else if (getCollectionCount(calleeCollection) == 1) { auto singletonValue = getCollectionElement(calleeCollection, 0); @@ -2813,7 +2684,7 @@ struct DynamicInstLoweringContext break; } default: - SLANG_UNEXPECTED("Unhandled parameter direction in lowerCall"); + SLANG_UNEXPECTED("Unhandled parameter direction in specializeCall"); } } @@ -2863,7 +2734,7 @@ struct DynamicInstLoweringContext } } - bool lowerMakeStruct(IRInst* context, IRMakeStruct* inst) + bool specializeMakeStruct(IRInst* context, IRMakeStruct* inst) { auto structType = as(inst->getDataType()); if (!structType) @@ -2889,7 +2760,7 @@ struct DynamicInstLoweringContext return changed; } - bool lowerMakeExistential(IRInst* context, IRMakeExistential* inst) + bool specializeMakeExistential(IRInst* context, IRMakeExistential* inst) { auto info = tryGetInfo(context, inst); auto taggedUnion = as(info); @@ -2906,7 +2777,7 @@ struct DynamicInstLoweringContext IRInst* witnessTableID = nullptr; if (auto witnessTable = as(inst->getWitnessTable())) { - auto singletonTagType = makeTagType(makeSingletonSet(witnessTable)); + auto singletonTagType = makeTagType(cBuilder.makeSingletonSet(witnessTable)); auto zeroValueOfTagType = builder.getIntValue((IRType*)singletonTagType, 0); witnessTableID = builder.emitIntrinsicInst( (IRType*)makeTagType(tableCollection), @@ -2941,7 +2812,7 @@ struct DynamicInstLoweringContext return true; } - bool lowerCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) + bool specializeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { auto info = tryGetInfo(context, inst); auto taggedUnion = as(info); @@ -2982,7 +2853,7 @@ struct DynamicInstLoweringContext return true; } - bool lowerStructuredBufferLoad(IRInst* context, IRInst* inst) + bool specializeStructuredBufferLoad(IRInst* context, IRInst* inst) { auto valInfo = tryGetInfo(context, inst); @@ -2992,15 +2863,15 @@ struct DynamicInstLoweringContext auto bufferType = (IRType*)inst->getOperand(0)->getDataType(); auto bufferBaseType = (IRType*)bufferType->getOperand(0); - auto loweredValType = (IRType*)getLoweredType(valInfo); - if (bufferBaseType != loweredValType) + auto specializedValType = (IRType*)getLoweredType(valInfo); + if (bufferBaseType != specializedValType) { if (as(bufferBaseType) && !isComInterfaceType(bufferBaseType) && !isBuiltin(bufferBaseType)) { // If we're dealing with a loading a known tagged union value from // an interface-typed pointer, we'll cast the pointer itself and - // defer the lowering of the load until later. + // defer the specializeing of the load until later. // // This avoids having to change the source pointer type // and confusing any future runs of the type flow @@ -3010,13 +2881,13 @@ struct DynamicInstLoweringContext builder.setInsertAfter(inst); auto bufferHandle = inst->getOperand(0); auto newHandle = builder.emitIntrinsicInst( - builder.getPtrType(loweredValType), + builder.getPtrType(specializedValType), kIROp_CastInterfaceToTaggedUnionPtr, 1, &bufferHandle); List newLoadOperands = {newHandle, inst->getOperand(1)}; auto newLoad = builder.emitIntrinsicInst( - loweredValType, + specializedValType, inst->getOp(), newLoadOperands.getCount(), newLoadOperands.getBuffer()); @@ -3037,7 +2908,7 @@ struct DynamicInstLoweringContext return false; } - bool lowerSpecialize(IRInst* context, IRSpecialize* inst) + bool specializeSpecialize(IRInst* context, IRSpecialize* inst) { bool isFuncReturn = false; @@ -3051,12 +2922,12 @@ struct DynamicInstLoweringContext isFuncReturn = as(getGenericReturnVal(firstConcreteGeneric)) != nullptr; } - // Functions/Collections of Functions should be handled at the call site (in lowerCall) + // Functions/Collections of Functions should be handled at the call site (in specializeCall) // since witness table specialization arguments must be inlined into the call. // if (isFuncReturn) { - // TODO: Maybe make this the 'default' behavior if a lowering call + // TODO: Maybe make this the 'default' behavior if a specializeing call // returns false. // if (tryGetInfo(context, inst)) @@ -3104,7 +2975,7 @@ struct DynamicInstLoweringContext } - bool lowerGetValueFromBoundInterface(IRInst* context, IRGetValueFromBoundInterface* inst) + bool specializeGetValueFromBoundInterface(IRInst* context, IRGetValueFromBoundInterface* inst) { SLANG_UNUSED(context); auto destType = inst->getDataType(); @@ -3121,7 +2992,7 @@ struct DynamicInstLoweringContext return false; } - bool lowerLoad(IRInst* context, IRInst* inst) + bool specializeLoad(IRInst* context, IRInst* inst) { auto valInfo = tryGetInfo(context, inst); @@ -3132,8 +3003,8 @@ struct DynamicInstLoweringContext auto loadPtrType = as(loadPtr->getDataType()); auto ptrValType = loadPtrType->getValueType(); - IRType* loweredType = (IRType*)getLoweredType(valInfo); - if (ptrValType != loweredType) + IRType* specializedType = (IRType*)getLoweredType(valInfo); + if (ptrValType != specializedType) { SLANG_ASSERT(!as(inst)); @@ -3142,7 +3013,7 @@ struct DynamicInstLoweringContext { // If we're dealing with a loading a known tagged union value from // an interface-typed pointer, we'll cast the pointer itself and - // defer the lowering of the load until later. + // defer the specializeing of the load until later. // // This avoids having to change the source pointer type // and confusing any future runs of the type flow @@ -3151,11 +3022,11 @@ struct DynamicInstLoweringContext IRBuilder builder(inst); builder.setInsertAfter(inst); auto newLoadPtr = builder.emitIntrinsicInst( - builder.getPtrTypeWithAddressSpace(loweredType, loadPtrType), + builder.getPtrTypeWithAddressSpace(specializedType, loadPtrType), kIROp_CastInterfaceToTaggedUnionPtr, 1, &loadPtr); - auto newLoad = builder.emitLoad(loweredType, newLoadPtr); + auto newLoad = builder.emitLoad(specializedType, newLoadPtr); inst->replaceUsesWith(newLoad); inst->removeAndDeallocate(); @@ -3190,7 +3061,7 @@ struct DynamicInstLoweringContext return false; } - bool lowerStore(IRInst* context, IRStore* inst) + bool specializeStore(IRInst* context, IRStore* inst) { auto ptr = inst->getPtr(); auto ptrInfo = as(ptr->getDataType())->getValueType(); @@ -3204,20 +3075,20 @@ struct DynamicInstLoweringContext if (as(inst->getVal())) return handleDefaultStore(context, inst); - auto loweredVal = upcastCollection(context, inst->getVal(), ptrInfo); + auto specializedVal = upcastCollection(context, inst->getVal(), ptrInfo); - if (loweredVal != inst->getVal()) + if (specializedVal != inst->getVal()) { // If the value was changed, we need to update the store instruction. IRBuilder builder(inst); - builder.replaceOperand(inst->getValUse(), loweredVal); + builder.replaceOperand(inst->getValUse(), specializedVal); return true; } return false; } - bool lowerGetSequentialID(IRInst* context, IRGetSequentialID* inst) + bool specializeGetSequentialID(IRInst* context, IRGetSequentialID* inst) { SLANG_UNUSED(context); auto arg = inst->getOperand(0); @@ -3244,7 +3115,7 @@ struct DynamicInstLoweringContext return false; } - bool lowerIsType(IRInst* context, IRIsType* inst) + bool specializeIsType(IRInst* context, IRIsType* inst) { SLANG_UNUSED(context); auto witnessTableArg = inst->getValueWitness(); @@ -3279,23 +3150,6 @@ struct DynamicInstLoweringContext return false; } - UInt getUniqueID(IRInst* inst) - { - auto existingId = uniqueIds.tryGetValue(inst); - if (existingId) - return *existingId; - - // If we reach here, the instruction was not assigned an ID during initialization. - // This can happen for instructions that are generated during the analysis. - // - // We will ensure that they are moved to the end of the module, and assign them a new ID. - // This will ensure a stable ordering on subsequent passes. - // - inst->moveToEnd(); - uniqueIds[inst] = nextUniqueId; - return nextUniqueId++; - } - bool isExistentialType(IRType* type) { return as(type) != nullptr; } bool isInterfaceType(IRType* type) { return as(type) != nullptr; } @@ -3349,26 +3203,11 @@ struct DynamicInstLoweringContext return hasChanges; } - DynamicInstLoweringContext(IRModule* module, DiagnosticSink* sink) - : module(module), sink(sink) + TypeFlowSpecializationContext(IRModule* module, DiagnosticSink* sink) + : module(module), sink(sink), cBuilder(module) { - initializeUniqueIDs(); } - // Initialize unique IDs for all global instructions that can be part of collections - void initializeUniqueIDs() - { - UInt currentID = 1; - for (auto inst : module->getGlobalInsts()) - { - // Only assign IDs to instructions that can be part of collections - IROp collectionType = getCollectionTypeForInst(inst); - if (collectionType != kIROp_Invalid) - { - uniqueIds[inst] = currentID++; - } - } - } // Basic context IRModule* module; @@ -3389,11 +3228,7 @@ struct DynamicInstLoweringContext // Mapping from fields to use-sites. Dictionary> fieldUseSites; - // Unique ID assignment for functions and witness tables - Dictionary uniqueIds; - UInt nextUniqueId = 1; - - // Mapping from lowered instruction to their any-value types + // Mapping from specialized instruction to their any-value types Dictionary loweredInstToAnyValueType; // Set of open contexts @@ -3404,726 +3239,15 @@ struct DynamicInstLoweringContext // Lowered contexts. Dictionary loweredContexts; -}; - -SlangInt calculateAnyValueSize(const HashSet& types) -{ - SlangInt maxSize = 0; - for (auto type : types) - { - auto size = getAnyValueSize(type); - if (size > maxSize) - maxSize = size; - } - return maxSize; -} - -IRAnyValueType* createAnyValueType(IRBuilder* builder, const HashSet& types) -{ - auto size = calculateAnyValueSize(types); - return builder->getAnyValueType(size); -} - -IRFunc* createDispatchFunc(IRFuncCollection* collection) -{ - // An effective func type should have been set during the dynamic-inst-lowering - // pass. - // - IRFuncType* dispatchFuncType = cast(collection->getFullType()); - - // Create a dispatch function with switch-case for each function - IRBuilder builder(collection->getModule()); - - // Consume the first parameter of the expected function type - List innerParamTypes; - for (auto paramType : dispatchFuncType->getParamTypes()) - innerParamTypes.add(paramType); - innerParamTypes.removeAt(0); // Remove the first parameter (ID) - - auto resultType = dispatchFuncType->getResultType(); - auto innerFuncType = builder.getFuncType(innerParamTypes, resultType); - - auto func = builder.createFunc(); - builder.setInsertInto(func); - func->setFullType(dispatchFuncType); - - auto entryBlock = builder.emitBlock(); - builder.setInsertInto(entryBlock); - - auto idParam = builder.emitParam(builder.getUIntType()); - - // Create parameters for the original function arguments - List originalParams; - for (Index i = 0; i < innerParamTypes.getCount(); i++) - { - originalParams.add(builder.emitParam(innerParamTypes[i])); - } - - // Create default block - auto defaultBlock = builder.emitBlock(); - builder.setInsertInto(defaultBlock); - if (resultType->getOp() == kIROp_VoidType) - { - builder.emitReturn(); - } - else - { - // Return a default-constructed value - auto defaultValue = builder.emitDefaultConstruct(resultType); - builder.emitReturn(defaultValue); - } - - // Go back to entry block and create switch - builder.setInsertInto(entryBlock); - - // Create case blocks for each function - List caseValues; - List caseBlocks; - - UIndex funcSeqID = 0; - forEachInCollection( - collection, - [&](IRInst* funcInst) - { - auto funcId = funcSeqID++; - auto wrapperFunc = - emitWitnessTableWrapper(funcInst->getModule(), funcInst, innerFuncType); - - // Create case block - auto caseBlock = builder.emitBlock(); - builder.setInsertInto(caseBlock); - - List callArgs; - auto wrappedFuncType = as(wrapperFunc->getDataType()); - for (Index ii = 0; ii < originalParams.getCount(); ii++) - { - callArgs.add(originalParams[ii]); - } - - // Call the specific function - auto callResult = - builder.emitCallInst(wrappedFuncType->getResultType(), wrapperFunc, callArgs); - - if (resultType->getOp() == kIROp_VoidType) - { - builder.emitReturn(); - } - else - { - builder.emitReturn(callResult); - } - - caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); - caseBlocks.add(caseBlock); - }); - - // Create flattened case arguments array - List flattenedCaseArgs; - for (Index i = 0; i < caseValues.getCount(); i++) - { - flattenedCaseArgs.add(caseValues[i]); - flattenedCaseArgs.add(caseBlocks[i]); - } - - // Create an unreachable block for the break block. - auto unreachableBlock = builder.emitBlock(); - builder.setInsertInto(unreachableBlock); - builder.emitUnreachable(); - - // Go back to entry and emit switch - builder.setInsertInto(entryBlock); - builder.emitSwitch( - idParam, - unreachableBlock, - defaultBlock, - flattenedCaseArgs.getCount(), - flattenedCaseArgs.getBuffer()); - - return func; -} - - -IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mapping, UInt defaultVal) -{ - // Create a function that maps input IDs to output IDs - IRBuilder builder(module); - - auto funcType = - builder.getFuncType(List({builder.getUIntType()}), builder.getUIntType()); - auto func = builder.createFunc(); - builder.setInsertInto(func); - func->setFullType(funcType); - - auto entryBlock = builder.emitBlock(); - builder.setInsertInto(entryBlock); - - auto param = builder.emitParam(builder.getUIntType()); - - // Create default block that returns 0 - auto defaultBlock = builder.emitBlock(); - builder.setInsertInto(defaultBlock); - builder.emitReturn(builder.getIntValue(builder.getUIntType(), defaultVal)); - - // Go back to entry block and create switch - builder.setInsertInto(entryBlock); - - // Create case blocks for each input table - List caseValues; - List caseBlocks; - - for (auto item : mapping) - { - // Create case block - auto caseBlock = builder.emitBlock(); - builder.setInsertInto(caseBlock); - builder.emitReturn(builder.getIntValue(builder.getUIntType(), item.second)); - - caseValues.add(builder.getIntValue(builder.getUIntType(), item.first)); - caseBlocks.add(caseBlock); - } - - // Create flattened case arguments array - List flattenedCaseArgs; - for (Index i = 0; i < caseValues.getCount(); i++) - { - flattenedCaseArgs.add(caseValues[i]); - flattenedCaseArgs.add(caseBlocks[i]); - } - - // Emit an unreachable block for the break block. - auto unreachableBlock = builder.emitBlock(); - builder.setInsertInto(unreachableBlock); - builder.emitUnreachable(); - - // Go back to entry and emit switch - builder.setInsertInto(entryBlock); - builder.emitSwitch( - param, - unreachableBlock, - defaultBlock, - flattenedCaseArgs.getCount(), - flattenedCaseArgs.getBuffer()); - - return func; -} - -// This context lowers `IRGetTagFromSequentialID`, -// `IRGetTagForSuperCollection`, and `IRGetTagForMappedCollection` instructions, -// - -struct TagOpsLoweringContext : public InstPassBase -{ - TagOpsLoweringContext(IRModule* module) - : InstPassBase(module) - { - } - - void lowerGetTagForSuperCollection(IRGetTagForSuperCollection* inst) - { - auto srcCollection = cast( - cast(inst->getOperand(0)->getDataType())->getOperand(0)); - auto destCollection = - cast(cast(inst->getDataType())->getOperand(0)); - - IRBuilder builder(inst->getModule()); - builder.setInsertAfter(inst); - - List indices; - for (UInt i = 0; i < srcCollection->getOperandCount(); i++) - { - // Find in destCollection - auto srcElement = srcCollection->getOperand(i); - - bool found = false; - for (UInt j = 0; j < destCollection->getOperandCount(); j++) - { - auto destElement = destCollection->getOperand(j); - if (srcElement == destElement) - { - found = true; - indices.add(builder.getIntValue(builder.getUIntType(), j)); - break; // Found the index - } - } - - if (!found) - { - // destCollection must be a super-set - SLANG_UNEXPECTED("Element not found in destination collection"); - } - } - - // Create an array for the lookup - auto lookupArrayType = builder.getArrayType( - builder.getUIntType(), - builder.getIntValue(builder.getUIntType(), indices.getCount())); - auto lookupArray = - builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); - auto resultID = - builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); - inst->replaceUsesWith(resultID); - inst->removeAndDeallocate(); - } - - void lowerGetTagForMappedCollection(IRGetTagForMappedCollection* inst) - { - auto srcCollection = cast( - cast(inst->getOperand(0)->getDataType())->getOperand(0)); - auto destCollection = - cast(cast(inst->getDataType())->getOperand(0)); - auto key = cast(inst->getOperand(1)); - - IRBuilder builder(inst->getModule()); - builder.setInsertAfter(inst); - - List indices; - for (UInt i = 0; i < srcCollection->getOperandCount(); i++) - { - // Find in destCollection - bool found = false; - auto srcElement = findEntryInConcreteTable(srcCollection->getOperand(i), key); - for (UInt j = 0; j < destCollection->getOperandCount(); j++) - { - auto destElement = destCollection->getOperand(j); - if (srcElement == destElement) - { - found = true; - indices.add(builder.getIntValue(builder.getUIntType(), j)); - break; // Found the index - } - } - - if (!found) - { - // destCollection must be a super-set - SLANG_UNEXPECTED("Element not found in destination collection"); - } - } - // Create an array for the lookup - auto lookupArrayType = builder.getArrayType( - builder.getUIntType(), - builder.getIntValue(builder.getUIntType(), indices.getCount())); - auto lookupArray = - builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); - auto resultID = - builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); - inst->replaceUsesWith(resultID); - inst->removeAndDeallocate(); - } - - void processInst(IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_GetTagForSuperCollection: - lowerGetTagForSuperCollection(as(inst)); - break; - case kIROp_GetTagForMappedCollection: - lowerGetTagForMappedCollection(as(inst)); - break; - default: - break; - } - } - - void lowerFuncCollection(IRFuncCollection* collection) - { - IRBuilder builder(collection->getModule()); - if (collection->hasUses() && collection->getDataType() != nullptr) - { - auto dispatchFunc = createDispatchFunc(collection); - traverseUses( - collection, - [&](IRUse* use) - { - if (auto callInst = as(use->getUser())) - { - // If the call is a collection call, replace it with the dispatch function - if (callInst->getCallee() == collection) - { - IRBuilder callBuilder(callInst); - callBuilder.setInsertBefore(callInst); - callBuilder.replaceOperand(callInst->getCalleeUse(), dispatchFunc); - } - } - }); - } - } - - void processModule() - { - processInstsOfType( - kIROp_FuncCollection, - [&](IRFuncCollection* inst) { return lowerFuncCollection(inst); }); - - processAllInsts([&](IRInst* inst) { return processInst(inst); }); - } -}; - -// This context lowers `IRTypeCollection` and `IRFuncCollection` instructions -struct CollectionLoweringContext : public InstPassBase -{ - CollectionLoweringContext(IRModule* module) - : InstPassBase(module) - { - } - - void lowerTypeCollection(IRTypeCollection* collection) - { - HashSet types; - for (UInt i = 0; i < collection->getOperandCount(); i++) - { - if (auto type = as(collection->getOperand(i))) - { - types.add(type); - } - } - - IRBuilder builder(collection->getModule()); - auto anyValueType = createAnyValueType(&builder, types); - collection->replaceUsesWith(anyValueType); - } - - void processModule() - { - processInstsOfType( - kIROp_TypeCollection, - [&](IRTypeCollection* inst) { return lowerTypeCollection(inst); }); - } -}; - -void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) -{ - SLANG_UNUSED(sink); - CollectionLoweringContext context(module); - context.processModule(); -} - -struct SequentialIDTagLoweringContext : public InstPassBase -{ - SequentialIDTagLoweringContext(IRModule* module) - : InstPassBase(module) - { - } - - void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) - { - SLANG_UNUSED(cast(inst->getOperand(0))); - auto srcSeqID = inst->getOperand(1); - - Dictionary mapping; - - // Map from sequential ID to unique ID - auto destCollection = - cast(cast(inst->getDataType())->getOperand(0)); - - UIndex dstSeqID = 0; - forEachInCollection( - destCollection, - [&](IRInst* table) - { - // Get unique ID for the witness table - SLANG_UNUSED(cast(table)); - auto outputId = dstSeqID++; - auto seqDecoration = table->findDecoration(); - if (seqDecoration) - { - auto inputId = seqDecoration->getSequentialID(); - mapping[inputId] = outputId; // Map ID to itself for now - } - }); - - IRBuilder builder(inst); - builder.setInsertAfter(inst); - UInt defaultID = dstSeqID - 1; // Default to last available conformance. - auto translatedID = builder.emitCallInst( - inst->getDataType(), - createIntegerMappingFunc(builder.getModule(), mapping, defaultID), - List({srcSeqID})); - - inst->replaceUsesWith(translatedID); - inst->removeAndDeallocate(); - } - - - void lowerGetSequentialIDFromTag(IRGetSequentialIDFromTag* inst) - { - SLANG_UNUSED(cast(inst->getOperand(0))); - auto srcTagInst = inst->getOperand(1); - - Dictionary mapping; - - // Map from sequential ID to unique ID - auto destCollection = cast( - cast(srcTagInst->getDataType())->getOperand(0)); - - UIndex dstSeqID = 0; - forEachInCollection( - destCollection, - [&](IRInst* table) - { - // Get unique ID for the witness table - SLANG_UNUSED(cast(table)); - auto outputId = dstSeqID++; - auto seqDecoration = table->findDecoration(); - if (seqDecoration) - { - auto inputId = seqDecoration->getSequentialID(); - mapping.add({outputId, inputId}); - } - }); - - IRBuilder builder(inst); - builder.setInsertAfter(inst); - auto translatedID = builder.emitCallInst( - inst->getDataType(), - createIntegerMappingFunc(builder.getModule(), mapping, 0), - List({srcTagInst})); - - inst->replaceUsesWith(translatedID); - inst->removeAndDeallocate(); - } - - void processModule() - { - processInstsOfType( - kIROp_GetTagFromSequentialID, - [&](IRGetTagFromSequentialID* inst) { return lowerGetTagFromSequentialID(inst); }); - - processInstsOfType( - kIROp_GetSequentialIDFromTag, - [&](IRGetSequentialIDFromTag* inst) { return lowerGetSequentialIDFromTag(inst); }); - } -}; - -void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) -{ - SLANG_UNUSED(sink); - SequentialIDTagLoweringContext context(module); - context.processModule(); -} - -void lowerTagInsts(IRModule* module, DiagnosticSink* sink) -{ - SLANG_UNUSED(sink); - TagOpsLoweringContext tagContext(module); - tagContext.processModule(); -} - -struct TagTypeLoweringContext : public InstPassBase -{ - TagTypeLoweringContext(IRModule* module) - : InstPassBase(module) - { - } - - void processModule() - { - processInstsOfType( - kIROp_CollectionTagType, - [&](IRCollectionTagType* inst) - { - IRBuilder builder(inst->getModule()); - inst->replaceUsesWith(builder.getUIntType()); - }); - } -}; - -void lowerTagTypes(IRModule* module) -{ - TagTypeLoweringContext context(module); - context.processModule(); -} - -// This context lowers `CastInterfaceToTaggedUnionPtr` and -// `CastTaggedUnionToInterfacePtr` by finding all `IRLoad` and -// `IRStore` uses of these insts, and upcasting the tagged-union -// tuple to the the interface-based tuple (of the loaded inst or before -// storing the val, as necessary) -// -struct TaggedUnionLoweringContext : public InstPassBase -{ - TaggedUnionLoweringContext(IRModule* module) - : InstPassBase(module) - { - } - - IRInst* convertToTaggedUnion( - IRBuilder* builder, - IRInst* val, - IRInst* interfaceType, - IRInst* targetType) - { - auto baseInterfaceValue = val; - auto witnessTable = builder->emitExtractExistentialWitnessTable(baseInterfaceValue); - auto tableID = builder->emitGetSequentialIDInst(witnessTable); - - auto taggedUnionTupleType = cast(targetType); - - List getTagOperands; - getTagOperands.add(interfaceType); - getTagOperands.add(tableID); - auto tableTag = builder->emitIntrinsicInst( - (IRType*)taggedUnionTupleType->getOperand(0), - kIROp_GetTagFromSequentialID, - getTagOperands.getCount(), - getTagOperands.getBuffer()); - - return builder->emitMakeTuple( - {tableTag, - builder->emitReinterpret( - (IRType*)taggedUnionTupleType->getOperand(1), - builder->emitExtractExistentialValue( - (IRType*)builder->emitExtractExistentialType(baseInterfaceValue), - baseInterfaceValue))}); - } - - void lowerCastInterfaceToTaggedUnionPtr(IRCastInterfaceToTaggedUnionPtr* inst) - { - // Find all uses of the inst - traverseUses( - inst, - [&](IRUse* use) - { - auto user = use->getUser(); - switch (user->getOp()) - { - case kIROp_Load: - { - auto baseInterfacePtr = inst->getOperand(0); - auto baseInterfaceType = as( - as(baseInterfacePtr->getDataType())->getValueType()); - - // Rewrite the load to use the original ptr and load - // an interface-typed object. - // - IRBuilder builder(module); - builder.setInsertAfter(user); - builder.replaceOperand(user->getOperands() + 0, baseInterfacePtr); - builder.replaceOperand(&user->typeUse, baseInterfaceType); - - // Then, we'll rewrite it. - List oldUses; - traverseUses(user, [&](IRUse* oldUse) { oldUses.add(oldUse); }); - - auto newVal = convertToTaggedUnion( - &builder, - user, - baseInterfaceType, - as(inst->getDataType())->getValueType()); - for (auto oldUse : oldUses) - { - builder.replaceOperand(oldUse, newVal); - } - break; - } - case kIROp_StructuredBufferLoad: - case kIROp_RWStructuredBufferLoad: - { - auto baseInterfacePtr = inst->getOperand(0); - auto baseInterfaceType = - as((baseInterfacePtr->getDataType())->getOperand(0)); - - IRBuilder builder(module); - builder.setInsertAfter(user); - builder.replaceOperand(user->getOperands() + 0, baseInterfacePtr); - builder.replaceOperand(&user->typeUse, baseInterfaceType); - - // Then, we'll rewrite it. - List oldUses; - traverseUses(user, [&](IRUse* oldUse) { oldUses.add(oldUse); }); - - auto newVal = convertToTaggedUnion( - &builder, - user, - baseInterfaceType, - as(inst->getDataType())->getValueType()); - for (auto oldUse : oldUses) - { - builder.replaceOperand(oldUse, newVal); - } - break; - } - default: - SLANG_UNEXPECTED("Unexpected user of CastInterfaceToTaggedUnionPtr"); - } - }); - - SLANG_ASSERT(!inst->hasUses()); - inst->removeAndDeallocate(); - } - - void lowerCastTaggedUnionToInterfacePtr(IRCastTaggedUnionToInterfacePtr* inst) - { - SLANG_UNUSED(inst); - SLANG_UNEXPECTED("Unexpected inst of CastTaggedUnionToInterfacePtr"); - } - - IRType* convertToTupleType(IRCollectionTaggedUnionType* taggedUnion) - { - // Replace type with Tuple - IRBuilder builder(module); - builder.setInsertInto(module); - - auto typeCollection = cast(taggedUnion->getOperand(0)); - auto tableCollection = cast(taggedUnion->getOperand(1)); - - if (getCollectionCount(typeCollection) == 1) - return builder.getTupleType(List( - {(IRType*)makeTagType(tableCollection), - (IRType*)getCollectionElement(typeCollection, 0)})); - - return builder.getTupleType( - List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); - } - - bool processModule() - { - // First, we'll lower all CollectionTaggedUnionType insts - // into tuples. - // - processInstsOfType( - kIROp_CollectionTaggedUnionType, - [&](IRCollectionTaggedUnionType* inst) - { - inst->replaceUsesWith(convertToTupleType(inst)); - inst->removeAndDeallocate(); - }); - - bool hasCastInsts = false; - processInstsOfType( - kIROp_CastInterfaceToTaggedUnionPtr, - [&](IRCastInterfaceToTaggedUnionPtr* inst) - { - hasCastInsts = true; - return lowerCastInterfaceToTaggedUnionPtr(inst); - }); - - processInstsOfType( - kIROp_CastTaggedUnionToInterfacePtr, - [&](IRCastTaggedUnionToInterfacePtr* inst) - { - hasCastInsts = true; - return lowerCastTaggedUnionToInterfacePtr(inst); - }); - - return hasCastInsts; - } + // Context for building collection insts + CollectionBuilder cBuilder; }; -bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink) -{ - SLANG_UNUSED(sink); - - TaggedUnionLoweringContext context(module); - return context.processModule(); -} - // Main entry point -bool lowerDynamicInsts(IRModule* module, DiagnosticSink* sink) +bool specializeDynamicInsts(IRModule* module, DiagnosticSink* sink) { - DynamicInstLoweringContext context(module, sink); + TypeFlowSpecializationContext context(module, sink); return context.processModule(); } diff --git a/source/slang/slang-ir-typeflow-specialize.h b/source/slang/slang-ir-typeflow-specialize.h new file mode 100644 index 00000000000..1503cc03927 --- /dev/null +++ b/source/slang/slang-ir-typeflow-specialize.h @@ -0,0 +1,9 @@ +// slang-ir-typeflow-specialize.h +#pragma once +#include "slang-ir.h" + +namespace Slang +{ +// Main entry point for the pass +bool specializeDynamicInsts(IRModule* module, DiagnosticSink* sink); +} // namespace Slang diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 1f641803683..aab3679f2d1 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -2342,6 +2342,8 @@ struct IRModule : RefObject IRDeduplicationContext* getDeduplicationContext() const { return &m_deduplicationContext; } + Dictionary* getUniqueIdMap() { return &m_mapInstToUniqueId; } + IRDominatorTree* findDominatorTree(IRGlobalValueWithCode* func) { IRAnalysis* analysis = m_mapInstToAnalysis.tryGetValue(func); @@ -2458,6 +2460,12 @@ struct IRModule : RefObject Dictionary m_mapInstToAnalysis; Dictionary> m_mapMangledNameToGlobalInst; + + /// Hold a mapping for inst -> uniqueID. This mapping is generated on + /// demand if passes need them, rather than eagerly storing them on + /// insts when unnecessary. + /// + Dictionary m_mapInstToUniqueId; }; diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-4.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-4.slang new file mode 100644 index 00000000000..752f5b62c3c --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-4.slang @@ -0,0 +1,72 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + This withOffset(float offset); + float calc(float x); +} + +struct A : IInterface +{ + float factor; + A withOffset(float offset) { return {factor + offset}; } + float calc(float x) { return factor * x * x * x; } +}; + +struct B : IInterface +{ + float factor1; + float factor2; + B withOffset(float offset) { return {factor1 + offset, factor2 + offset}; } + float calc(float x) { return factor1 * x * x + factor2 * x; } +}; + +struct C : IInterface +{ + float factor; + C withOffset(float offset) { return {factor + offset}; } + float calc(float x) { return x; } +}; + +T transfer(T obj) +{ + return obj.withOffset(2.5f); +} + +struct Foo +{ + T a; + T b; +} + +Foo make(T obj) +{ + return {obj, obj.withOffset(1.0f)}; +} + +float calc(Foo obj, float y) +{ + return obj.a.calc(y) + obj.b.calc(y); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(x); + else + obj = B(x, x + 1); + + return calc(make(obj), x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 40 + outputBuffer[1] = f(1, 2); // CHECK: 34 +} \ No newline at end of file diff --git a/tests/language-feature/dynamic-dispatch/dynamic-specialization-5.slang b/tests/language-feature/dynamic-dispatch/dynamic-specialization-5.slang new file mode 100644 index 00000000000..0d92ded6825 --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/dynamic-specialization-5.slang @@ -0,0 +1,66 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + associatedtype Extras; + This withOffset(float offset); + Extras preCalc(float x); + float postCalc(Extras extras); +} + +struct A : IInterface +{ + float factor; + typealias Extras = float; + A withOffset(float offset) { return {factor + offset}; } + float preCalc(float x) { return factor * x * x * x; } + float postCalc(float extras) { return extras; } +}; + +struct B : IInterface +{ + float factor1; + float factor2; + typealias Extras = float2; + B withOffset(float offset) { return {factor1 + offset, factor2 + offset}; } + float preCalc(float x) { return factor1 * x * x + factor2 * x; } + float postCalc(Extras extras) { return dot(extras, float2(1, 1)); } +}; + +struct Foo +{ + T a; + T.Extras extras; +} + +Foo make(T obj, float y) +{ + return {obj, obj.preCalc(y)}; +} + +float calc(Foo obj) +{ + return obj.a.postCalc(obj.extras); +} + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(x); + else + obj = B(x, x + 1); + + return calc(make(obj, x)); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 16 + outputBuffer[1] = f(1, 2); // CHECK: 28 +} \ No newline at end of file From f9e4d7173ec176398932195c3f5d5915f747c599 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:05:22 -0400 Subject: [PATCH 046/105] Update slang-ir.h --- source/slang/slang-ir.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index aab3679f2d1..b6207d921bf 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -2410,7 +2410,7 @@ struct IRModule : RefObject // anything to do with serialization format // const static UInt k_minSupportedModuleVersion = 1; - const static UInt k_maxSupportedModuleVersion = 1; + const static UInt k_maxSupportedModuleVersion = 2; static_assert(k_minSupportedModuleVersion <= k_maxSupportedModuleVersion); private: From 0e1ba02080038b224222536c845f2a6e397c019d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:07:16 -0400 Subject: [PATCH 047/105] Fix formatting --- source/slang/slang-lower-to-ir.cpp | 45 +++++++++++++----------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 33bc7c1c1ae..a6a0a44c23b 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1875,11 +1875,10 @@ struct ValLoweringVisitor : ValVisitorgetMidToSup()); } - return LoweredValInfo::simple( - getBuilder()->emitLookupInterfaceMethodInst( - getBuilder()->getWitnessTableType(lowerType(context, val->getSup())), - baseWitnessTable, - midToSup)); + return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( + getBuilder()->getWitnessTableType(lowerType(context, val->getSup())), + baseWitnessTable, + midToSup)); } LoweredValInfo visitForwardDifferentiateVal(ForwardDifferentiateVal* val) @@ -4196,20 +4195,18 @@ struct ExprLoweringVisitorBase : public ExprVisitor auto baseVal = lowerSubExpr(expr->baseFunction); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - return LoweredValInfo::simple( - getBuilder()->emitForwardDifferentiateInst( - lowerType(context, expr->type), - baseVal.val)); + return LoweredValInfo::simple(getBuilder()->emitForwardDifferentiateInst( + lowerType(context, expr->type), + baseVal.val)); } LoweredValInfo visitDetachExpr(DetachExpr* expr) { auto baseVal = lowerRValueExpr(context, expr->inner); - return LoweredValInfo::simple( - getBuilder()->emitDetachDerivative( - lowerType(context, expr->type), - getSimpleVal(context, baseVal))); + return LoweredValInfo::simple(getBuilder()->emitDetachDerivative( + lowerType(context, expr->type), + getSimpleVal(context, baseVal))); } LoweredValInfo visitPrimalSubstituteExpr(PrimalSubstituteExpr* expr) @@ -4295,10 +4292,9 @@ struct ExprLoweringVisitorBase : public ExprVisitor auto baseVal = lowerSubExpr(expr->baseFunction); SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); - return LoweredValInfo::simple( - getBuilder()->emitBackwardDifferentiateInst( - lowerType(context, expr->type), - baseVal.val)); + return LoweredValInfo::simple(getBuilder()->emitBackwardDifferentiateInst( + lowerType(context, expr->type), + baseVal.val)); } LoweredValInfo visitDispatchKernelExpr(DispatchKernelExpr* expr) @@ -4309,14 +4305,13 @@ struct ExprLoweringVisitorBase : public ExprVisitor auto groupSize = lowerRValueExpr(context, expr->dispatchSize); // Actual arguments to be filled in when we lower the actual call expr. // This is handled in `emitCallToVal`. - return LoweredValInfo::simple( - getBuilder()->emitDispatchKernelInst( - lowerType(context, expr->type), - baseVal.val, - getSimpleVal(context, threadSize), - getSimpleVal(context, groupSize), - 0, - nullptr)); + return LoweredValInfo::simple(getBuilder()->emitDispatchKernelInst( + lowerType(context, expr->type), + baseVal.val, + getSimpleVal(context, threadSize), + getSimpleVal(context, groupSize), + 0, + nullptr)); } LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr) From 70d1899f934a78b040f43c65072ac354e37d1d95 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 20 Aug 2025 16:40:00 -0400 Subject: [PATCH 048/105] Update switch-case.slang --- tests/wgsl/switch-case.slang | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/tests/wgsl/switch-case.slang b/tests/wgsl/switch-case.slang index 133982cd030..ad87d5a3e08 100644 --- a/tests/wgsl/switch-case.slang +++ b/tests/wgsl/switch-case.slang @@ -70,21 +70,8 @@ func fs_main(VertexOutput input)->FragmentOutput return output; } -//WGSL: fn _S13( _S14 : u32, _S15 : AnyValue8) -> f32 +//WGSL: fn _S{{[0-9]+}}( _S{{[0-9]+}} : u32, _S{{[0-9]+}} : AnyValue8) -> f32 //WGSL-NEXT:{ -//WGSL-NEXT: switch(_S14) -//WGSL-NEXT: { -//WGSL-NEXT: case u32(0): -//WGSL-NEXT: { -//WGSL-NEXT: return U_SR14switch_2Dxcase6Circle7getAreap0pf_wtwrapper_0(_S15); -//WGSL-NEXT: } -//WGSL-NEXT: case u32(1): -//WGSL-NEXT: { -//WGSL-NEXT: return U_SR14switch_2Dxcase9Rectangle7getAreap0pf_wtwrapper_0(_S15); -//WGSL-NEXT: } -//WGSL-NEXT: default : -//WGSL-NEXT: { -//WGSL-NEXT: return 0.0f; -//WGSL-NEXT: } -//WGSL-NEXT: } - //WGSL-NEXT:} \ No newline at end of file +//WGSL-DAG: return U_SR14switch_2Dxcase6Circle7getAreap0pf_wtwrapper_0(_S{{[0-9]+}}); +//WGSL-DAG: return U_SR14switch_2Dxcase9Rectangle7getAreap0pf_wtwrapper_0(_S{{[0-9]+}}); +//WGSL:} \ No newline at end of file From cfb846908877dd160271500e96bfdc4ebcac87eb Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 21 Aug 2025 12:06:33 -0400 Subject: [PATCH 049/105] Fix default value selection --- source/slang/slang-ir-lower-typeflow-insts.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index ac418ade1f1..8f914ae456a 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -163,7 +163,7 @@ IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mappi auto param = builder.emitParam(builder.getUIntType()); - // Create default block that returns 0 + // Create default block that returns defaultVal auto defaultBlock = builder.emitBlock(); builder.setInsertInto(defaultBlock); builder.emitReturn(builder.getIntValue(builder.getUIntType(), defaultVal)); @@ -441,10 +441,18 @@ struct SequentialIDTagLoweringContext : public InstPassBase IRBuilder builder(inst); builder.setInsertAfter(inst); - UInt defaultID = dstSeqID - 1; // Default to last available conformance. + + // Default to largest available sequential ID. + UInt defaultSeqID = 0; + for (auto [inputId, outputId] : mapping) + { + if (inputId > defaultSeqID) + defaultSeqID = inputId; + } + auto translatedID = builder.emitCallInst( inst->getDataType(), - createIntegerMappingFunc(builder.getModule(), mapping, defaultID), + createIntegerMappingFunc(builder.getModule(), mapping, mapping[defaultSeqID]), List({srcSeqID})); inst->replaceUsesWith(translatedID); From 704d381d49a8946ab95241334407c42e90f40e52 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 21 Aug 2025 14:14:44 -0400 Subject: [PATCH 050/105] Update slang-ir-insts-stable-names.lua --- source/slang/slang-ir-insts-stable-names.lua | 26 ++++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 0920266c371..6af90b76a35 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -670,17 +670,17 @@ return { ["SPIRVAsmOperand.__imageType"] = 666, ["SPIRVAsmOperand.__sampledImageType"] = 667, ["Type.CLayout"] = 668, - ["TypeFlowData.CollectionBase.TypeCollection"] = 668, - ["TypeFlowData.CollectionBase.FuncCollection"] = 669, - ["TypeFlowData.CollectionBase.TableCollection"] = 670, - ["TypeFlowData.CollectionBase.GenericCollection"] = 671, - ["TypeFlowData.UnboundedCollection"] = 672, - ["TypeFlowData.CollectionTagType"] = 673, - ["TypeFlowData.CollectionTaggedUnionType"] = 674, - ["CastInterfaceToTaggedUnionPtr"] = 675, - ["CastTaggedUnionToInterfacePtr"] = 676, - ["GetTagForSuperCollection"] = 677, - ["GetTagForMappedCollection"] = 678, - ["GetTagFromSequentialID"] = 679, - ["GetSequentialIDFromTag"] = 680 + ["TypeFlowData.CollectionBase.TypeCollection"] = 669, + ["TypeFlowData.CollectionBase.FuncCollection"] = 670, + ["TypeFlowData.CollectionBase.TableCollection"] = 671, + ["TypeFlowData.CollectionBase.GenericCollection"] = 672, + ["TypeFlowData.UnboundedCollection"] = 673, + ["TypeFlowData.CollectionTagType"] = 674, + ["TypeFlowData.CollectionTaggedUnionType"] = 675, + ["CastInterfaceToTaggedUnionPtr"] = 676, + ["CastTaggedUnionToInterfacePtr"] = 677, + ["GetTagForSuperCollection"] = 678, + ["GetTagForMappedCollection"] = 679, + ["GetTagFromSequentialID"] = 680, + ["GetSequentialIDFromTag"] = 681 } From c30387803a6b7ff6a4583fb20cc6c4ba75ce56a3 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 2 Sep 2025 15:28:30 -0400 Subject: [PATCH 051/105] Add `GetTagForSpecializedCollection` to represent the dynamic result of a specialization --- source/slang/slang-ir-insts-stable-names.lua | 5 +- source/slang/slang-ir-insts.lua | 8 ++- .../slang/slang-ir-lower-typeflow-insts.cpp | 57 ++++++++++++++++++ source/slang/slang-ir-typeflow-specialize.cpp | 60 +++++++++++++++---- 4 files changed, 115 insertions(+), 15 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 6af90b76a35..b95f64bbbc1 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -681,6 +681,7 @@ return { ["CastTaggedUnionToInterfacePtr"] = 677, ["GetTagForSuperCollection"] = 678, ["GetTagForMappedCollection"] = 679, - ["GetTagFromSequentialID"] = 680, - ["GetSequentialIDFromTag"] = 681 + ["GetTagForSpecializedCollection"] = 680, + ["GetTagFromSequentialID"] = 681, + ["GetSequentialIDFromTag"] = 682 } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 25a5b18cb11..944234f82f1 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2201,14 +2201,18 @@ local insts = { { GetTagForMappedCollection = { -- Translate a tag from a set to its equivalent in a different set -- based on a mapping induced by a lookup key - } }, + } }, + { GetTagForSpecializedCollection = { + -- Translate a tag from a generic set to its equivalent in a specialized set + -- based on a mapping that is encoded in the operands of this tag instruction + } }, { GetTagFromSequentialID = { -- Translate an existing sequential ID (a 'global' ID) & and interface type into a tag -- the provided collection (a 'local' ID) } }, { GetSequentialIDFromTag = { -- Translate a tag from the given collection (a 'local' ID) to a sequential ID (a 'global' ID) - } } + } } } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 8f914ae456a..6d0798d6de8 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -317,6 +317,60 @@ struct TagOpsLoweringContext : public InstPassBase inst->removeAndDeallocate(); } + void lowerGetTagForSpecializedCollection(IRGetTagForSpecializedCollection* inst) + { + auto srcCollection = cast( + cast(inst->getOperand(0)->getDataType())->getOperand(0)); + auto destCollection = + cast(cast(inst->getDataType())->getOperand(0)); + Dictionary mapping; + + for (UInt i = 1; i < inst->getOperandCount(); i += 2) + { + auto srcElement = inst->getOperand(i); + auto destElement = inst->getOperand(i + 1); + mapping[srcElement] = destElement; + } + + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + + List indices; + for (UInt i = 0; i < srcCollection->getOperandCount(); i++) + { + // Find in destCollection + bool found = false; + auto mappedElement = mapping[srcCollection->getOperand(i)]; + for (UInt j = 0; j < destCollection->getOperandCount(); j++) + { + auto destElement = destCollection->getOperand(j); + if (mappedElement == destElement) + { + found = true; + indices.add(builder.getIntValue(builder.getUIntType(), j)); + break; // Found the index + } + } + + if (!found) + { + SLANG_UNEXPECTED("Element not found in specialized collection"); + } + } + + // Create an array for the lookup + auto lookupArrayType = builder.getArrayType( + builder.getUIntType(), + builder.getIntValue(builder.getUIntType(), indices.getCount())); + auto lookupArray = + builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); + auto resultID = + builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); + inst->replaceUsesWith(resultID); + inst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) { switch (inst->getOp()) @@ -327,6 +381,9 @@ struct TagOpsLoweringContext : public InstPassBase case kIROp_GetTagForMappedCollection: lowerGetTagForMappedCollection(as(inst)); break; + case kIROp_GetTagForSpecializedCollection: + lowerGetTagForSpecializedCollection(as(inst)); + break; default: break; } diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index eabebdf3085..1997366f1aa 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -2588,8 +2588,8 @@ struct TypeFlowSpecializationContext // If we're placing a specialized call, use the base tag since the // specialization arguments will also become arguments to the call. // - if (auto specializedTag = as(calleeTagInst)) - calleeTagInst = specializedTag->getBase(); + // if (auto specializedTag = as(calleeTagInst)) + // calleeTagInst = specializedTag->getBase(); } callee = collectionTag->getOperand(0); } @@ -2922,21 +2922,59 @@ struct TypeFlowSpecializationContext isFuncReturn = as(getGenericReturnVal(firstConcreteGeneric)) != nullptr; } - // Functions/Collections of Functions should be handled at the call site (in specializeCall) - // since witness table specialization arguments must be inlined into the call. - // + // We'll emit a dynamic tag inst if the result is a func collection with multiple elements if (isFuncReturn) { - // TODO: Maybe make this the 'default' behavior if a specializeing call - // returns false. - // - if (tryGetInfo(context, inst)) - return replaceType(context, inst); + if (auto info = tryGetInfo(context, inst)) + { + // If our inst represents a collection directly (no run-time info), + // there's nothing to do except replace the type (if necessary) + // + if (as(info)) + return replaceType(context, inst); + + auto specializedCollectionTag = as(info); + + // If the inst represents a singleton collection, there's nothing + // to do except replace the type (if necessary) + // + if (getCollectionCount(specializedCollectionTag) <= 1) + return replaceType(context, inst); + + List mappingOperands; + + // Add the base tag as the first operand. The mapping operands follow + mappingOperands.add(inst->getBase()); + + forEachInCollection( + specializedCollectionTag, + [&](IRInst* element) + { + // Emit the GetTagForSpecializedCollection for each element. + auto specInst = cast(element); + auto baseGeneric = cast(specInst->getBase()); + + mappingOperands.add(baseGeneric); + mappingOperands.add(specInst); + }); + + IRBuilder builder(inst); + setInsertBeforeOrdinaryInst(&builder, inst); + auto newInst = builder.emitIntrinsicInst( + (IRType*)info, + kIROp_GetTagForSpecializedCollection, + mappingOperands.getCount(), + mappingOperands.getBuffer()); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } else return false; } - // For all other specializations, we'll 'drop' the dyanamic tag information. + // For all other specializations, we'll 'drop' the dynamic tag information. bool changed = false; List args; for (UIndex i = 0; i < inst->getArgCount(); i++) From 40239f0758c765cd0d0feb38f6a55ec7668603b5 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 2 Sep 2025 15:56:43 -0400 Subject: [PATCH 052/105] Update direction info to include address space. --- .../a3-02-reference-capability-atoms.md | 3 + source/slang/slang-ir-typeflow-specialize.cpp | 95 +++++++++++++------ 2 files changed, 68 insertions(+), 30 deletions(-) diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md index 01510431004..7f759ca24f8 100644 --- a/docs/user-guide/a3-02-reference-capability-atoms.md +++ b/docs/user-guide/a3-02-reference-capability-atoms.md @@ -984,6 +984,9 @@ Compound Capabilities `cpp_cuda_hlsl_spirv` > CPP, CUDA, HLSL, and SPIRV code-gen targets +`cpp_cuda_metal_spirv` +> CPP, CUDA, Metal, and SPIRV code-gen targets + `cpp_cuda_spirv` > CPP, CUDA and SPIRV code-gen targets diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 1997366f1aa..2febce34933 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -240,6 +240,37 @@ struct WorkQueue struct TypeFlowSpecializationContext { + struct ParameterDirectionInfo + { + enum Kind + { + In, + Out, + InOut, + Ref, + ConstRef + } kind; + + // For Ref and ConstRef + AddressSpace addressSpace; + + ParameterDirectionInfo(Kind kind, AddressSpace addressSpace = (AddressSpace)0) + : kind(kind), addressSpace(addressSpace) + { + } + + ParameterDirectionInfo() + : kind(Kind::In), addressSpace((AddressSpace)0) + { + } + + bool operator==(const ParameterDirectionInfo& other) const + { + return kind == other.kind && addressSpace == other.addressSpace; + } + }; + + IRCollectionTaggedUnionType* makeExistential(IRTableCollection* tableCollection) { HashSet typeSet; @@ -1494,9 +1525,9 @@ struct TypeFlowSpecializationContext return infos; } - List getParamDirections(IRInst* context) + List getParamDirections(IRInst* context) { - List directions; + List directions; if (as(context)) { for (auto param : as(context)->getParams()) @@ -1566,11 +1597,11 @@ struct TypeFlowSpecializationContext IRInst* argInfo = tryGetInfo(edge.callerContext, arg); - switch (paramDirection) + switch (paramDirection.kind) { - case kParameterDirection_Out: - case kParameterDirection_InOut: - case kParameterDirection_ConstRef: + case ParameterDirectionInfo::Kind::Out: + case ParameterDirectionInfo::Kind::InOut: + case ParameterDirectionInfo::Kind::ConstRef: { IRBuilder builder(module); if (!argInfo) @@ -1591,7 +1622,7 @@ struct TypeFlowSpecializationContext updateInfo(edge.targetContext, param, newInfo, true, workQueue); break; } - case kParameterDirection_In: + case ParameterDirectionInfo::Kind::In: { // Use centralized update method if (!argInfo) @@ -1630,8 +1661,8 @@ struct TypeFlowSpecializationContext { if (paramInfo) { - if (paramDirections[argIndex] == kParameterDirection_Out || - paramDirections[argIndex] == kParameterDirection_InOut) + if (paramDirections[argIndex].kind == ParameterDirectionInfo::Kind::Out || + paramDirections[argIndex].kind == ParameterDirectionInfo::Kind::InOut) { auto arg = callInst->getArg(argIndex); auto argPtrType = as(arg->getDataType()); @@ -2226,40 +2257,46 @@ struct TypeFlowSpecializationContext } // Split into direction and type - std::tuple getParameterDirectionAndType(IRType* paramType) + std::tuple getParameterDirectionAndType(IRType* paramType) { if (as(paramType)) return { - ParameterDirection::kParameterDirection_Out, + ParameterDirectionInfo(ParameterDirectionInfo::Kind::Out), as(paramType)->getValueType()}; else if (as(paramType)) return { - ParameterDirection::kParameterDirection_InOut, + ParameterDirectionInfo(ParameterDirectionInfo::Kind::InOut), as(paramType)->getValueType()}; else if (as(paramType)) return { - ParameterDirection::kParameterDirection_Ref, + ParameterDirectionInfo( + ParameterDirectionInfo::Kind::Ref, + as(paramType)->getAddressSpace()), as(paramType)->getValueType()}; else if (as(paramType)) return { - ParameterDirection::kParameterDirection_ConstRef, + ParameterDirectionInfo( + ParameterDirectionInfo::Kind::ConstRef, + as(paramType)->getAddressSpace()), as(paramType)->getValueType()}; else - return {ParameterDirection::kParameterDirection_In, paramType}; + return {ParameterDirectionInfo(ParameterDirectionInfo::Kind::In), paramType}; } - IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirection direction, IRType* type) + IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo direction, IRType* type) { - switch (direction) + switch (direction.kind) { - case ParameterDirection::kParameterDirection_In: + case ParameterDirectionInfo::Kind::In: return type; - case ParameterDirection::kParameterDirection_Out: + case ParameterDirectionInfo::Kind::Out: return builder->getOutType(type); - case ParameterDirection::kParameterDirection_InOut: + case ParameterDirectionInfo::Kind::InOut: return builder->getInOutType(type); - case ParameterDirection::kParameterDirection_ConstRef: - return builder->getConstRefType(type); + case ParameterDirectionInfo::Kind::ConstRef: + return builder->getConstRefType(type, direction.addressSpace); + case ParameterDirectionInfo::Kind::Ref: + return builder->getRefType(type, direction.addressSpace); default: SLANG_UNEXPECTED("Unhandled parameter direction in fromDirectionAndType"); } @@ -2661,8 +2698,6 @@ struct TypeFlowSpecializationContext // First, we'll legalize all operands by upcasting if necessary. // This needs to be done even if the callee is not a collection. // - // List paramTypeFlows = getParamInfos(callee); - // List paramDirections = getParamDirections(callee); UCount extraArgCount = newArgs.getCount(); for (UInt i = 0; i < inst->getArgCount(); i++) { @@ -2670,15 +2705,15 @@ struct TypeFlowSpecializationContext const auto [paramDirection, paramType] = getParameterDirectionAndType(expectedFuncType->getParamType(i + extraArgCount)); - switch (paramDirection) + switch (paramDirection.kind) { - case kParameterDirection_In: + case ParameterDirectionInfo::Kind::In: newArgs.add(upcastCollection(context, arg, paramType)); break; - case kParameterDirection_Out: - case kParameterDirection_InOut: - case kParameterDirection_ConstRef: - case kParameterDirection_Ref: + case ParameterDirectionInfo::Kind::Out: + case ParameterDirectionInfo::Kind::InOut: + case ParameterDirectionInfo::Kind::ConstRef: + case ParameterDirectionInfo::Kind::Ref: { newArgs.add(arg); break; From 01a1869959222ae9f9cbf106dccc448c1118ac1c Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 2 Sep 2025 16:50:04 -0400 Subject: [PATCH 053/105] Update specialization pass to cache its specialization entries at the IR level. --- source/slang/slang-ir-specialize.cpp | 176 +++++++++++---------------- 1 file changed, 72 insertions(+), 104 deletions(-) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index e04ea1db552..45cd97ce955 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -54,6 +54,7 @@ struct SpecializationContext DiagnosticSink* sink; TargetProgram* targetProgram; SpecializationOptions options; + Dictionary irDictionaryMap; bool changed = false; @@ -357,6 +358,7 @@ struct SpecializationContext // this generic again for the same arguments. // genericSpecializations.add(key, specializedVal); + addEntryToIRDictionary(kIROp_GenericSpecializationDictionary, key.vals, specializedVal); return specializedVal; } @@ -962,28 +964,30 @@ struct SpecializationContext for (auto child = dictInst->getFirstChild(); child; child = child->next) childrenCount++; dict.reserve(Index{1} << Math::Log2Ceil(childrenCount * 2)); + + List invalidItems; for (auto child : dictInst->getChildren()) { auto item = as(child); if (!item) continue; IRSimpleSpecializationKey key; - bool shouldSkip = false; + bool isInvalid = false; for (UInt i = 0; i < item->getOperandCount(); i++) { if (item->getOperand(i) == nullptr) { - shouldSkip = true; + isInvalid = true; break; } if (item->getOperand(i)->getParent() == nullptr) { - shouldSkip = true; + isInvalid = true; break; } if (item->getOperand(i)->getOp() == kIROp_Undefined) { - shouldSkip = true; + isInvalid = true; break; } if (i > 0) @@ -991,97 +995,74 @@ struct SpecializationContext key.vals.add(item->getOperand(i)); } } - if (shouldSkip) + if (isInvalid) + { + invalidItems.add(item); + if (dict.containsKey(key)) + dict.remove(key); continue; + } auto value = as::type>( item->getOperand(0)); SLANG_ASSERT(value); dict[key] = value; } - dictInst->removeAndDeallocate(); + + // Clean up the IR dictionary + for (auto item : invalidItems) + item->removeAndDeallocate(); } void readSpecializationDictionaries() { auto moduleInst = module->getModuleInst(); - ShortList dictInsts; - for (auto child : moduleInst->getChildren()) - { - switch (child->getOp()) - { - case kIROp_GenericSpecializationDictionary: - case kIROp_ExistentialFuncSpecializationDictionary: - case kIROp_ExistentialTypeSpecializationDictionary: - dictInsts.add(child); - break; - default: - continue; - } - } - for (auto dict : dictInsts) - { - switch (dict->getOp()) - { - case kIROp_GenericSpecializationDictionary: - _readSpecializationDictionaryImpl(genericSpecializations, dict); - break; - case kIROp_ExistentialFuncSpecializationDictionary: - _readSpecializationDictionaryImpl(existentialSpecializedFuncs, dict); - break; - case kIROp_ExistentialTypeSpecializationDictionary: - _readSpecializationDictionaryImpl(existentialSpecializedStructs, dict); - break; - default: - continue; - } - } + _readSpecializationDictionaryImpl( + genericSpecializations, + getOrCreateIRDictionary(kIROp_GenericSpecializationDictionary)); + + _readSpecializationDictionaryImpl( + existentialSpecializedFuncs, + getOrCreateIRDictionary(kIROp_ExistentialFuncSpecializationDictionary)); + + _readSpecializationDictionaryImpl( + existentialSpecializedStructs, + getOrCreateIRDictionary(kIROp_ExistentialTypeSpecializationDictionary)); } - template - void _writeSpecializationDictionaryImpl(TDict& dict, IROp dictOp, IRInst* moduleInst) + IRInst* getOrCreateIRDictionary(IROp dictOp) { - IRBuilder builder(moduleInst); - builder.setInsertInto(moduleInst); - auto dictInst = builder.emitIntrinsicInst(nullptr, dictOp, 0, nullptr); - builder.setInsertInto(dictInst); - List args; - for (const auto& [key, value] : dict) + if (irDictionaryMap.containsKey(dictOp)) + return irDictionaryMap[dictOp]; + + for (auto child : module->getModuleInst()->getChildren()) { - if (!value->parent) - continue; - for (auto keyVal : key.vals) + if (child->getOp() == dictOp) { - if (!keyVal->parent) - goto next; + irDictionaryMap[dictOp] = child; + return child; } - { - args.clear(); - args.add(value); - args.addRange(key.vals); - builder.emitIntrinsicInst( - nullptr, - kIROp_SpecializationDictionaryItem, - (UInt)args.getCount(), - args.getBuffer()); - } - next:; } + + IRBuilder builder(module); + builder.setInsertInto(module); + auto dictInst = builder.emitIntrinsicInst(nullptr, dictOp, 0, nullptr); + irDictionaryMap[dictOp] = dictInst; + return dictInst; } - void writeSpecializationDictionaries() + + void addEntryToIRDictionary(IROp dictOp, const List& key, IRInst* val) { - auto moduleInst = module->getModuleInst(); - _writeSpecializationDictionaryImpl( - genericSpecializations, - kIROp_GenericSpecializationDictionary, - moduleInst); - _writeSpecializationDictionaryImpl( - existentialSpecializedFuncs, - kIROp_ExistentialFuncSpecializationDictionary, - moduleInst); - _writeSpecializationDictionaryImpl( - existentialSpecializedStructs, - kIROp_ExistentialTypeSpecializationDictionary, - moduleInst); + auto dictInst = getOrCreateIRDictionary(dictOp); + List args; + args.add(val); + args.addRange(key); + IRBuilder builder(module); + builder.setInsertInto(dictInst); + builder.emitIntrinsicInst( + nullptr, + kIROp_SpecializationDictionaryItem, + (UInt)args.getCount(), + args.getBuffer()); } // All of the machinery for generic specialization @@ -1090,9 +1071,9 @@ struct SpecializationContext // void processModule() { - // Read specialization dictionary from module if it is defined. - // This prevents us from generating duplicated specializations - // when this pass is invoked iteratively. + // Sync local dictionaries with the IR specialization + // dictionaries for faster lookup. + // readSpecializationDictionaries(); // The unspecialized IR we receive as input will have @@ -1196,14 +1177,11 @@ struct SpecializationContext if (iterChanged) { this->changed = true; - writeSpecializationDictionaries(); - genericSpecializations.clear(); - existentialSpecializedFuncs.clear(); - existentialSpecializedStructs.clear(); eliminateDeadCode(module->getModuleInst()); - readSpecializationDictionaries(); - // eliminateDeadCode(module->getModuleInst()); applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); + + // Sync our local dictionary with the one in the IR. + readSpecializationDictionaries(); } // Once the work list has gone dry, we should have the invariant @@ -1221,21 +1199,9 @@ struct SpecializationContext iterChanged = specializeDynamicInsts(module, sink); if (iterChanged) { - // We'll write out the specialization info to an inst, - // and read it back again so we can remove entries - // for specializations that are no longer needed. - // - // If we don't do this, we'll end up with deallocated - // references in the specialization dictionaries, and - // can't reliably handle situations where the same specialization - // is requested again in the future once a different function - // has been specialized. - // - writeSpecializationDictionaries(); - genericSpecializations.clear(); - existentialSpecializedFuncs.clear(); - existentialSpecializedStructs.clear(); eliminateDeadCode(module->getModuleInst()); + + // Sync our local dictionary with the one in the IR. readSpecializationDictionaries(); } } @@ -1243,12 +1209,6 @@ struct SpecializationContext if (!iterChanged || sink->getErrorCount()) break; } - - - // For functions that still have `specialize` uses left, we need to preserve the - // its specializations in resulting IR so they can be reconstructed when this - // specialization pass gets invoked again. - writeSpecializationDictionaries(); } void addInstsToWorkListRec(IRInst* inst) @@ -1573,6 +1533,10 @@ struct SpecializationContext // specializedCallee = createExistentialSpecializedFunc(inst, calleeFunc); existentialSpecializedFuncs.add(key, specializedCallee); + addEntryToIRDictionary( + kIROp_ExistentialFuncSpecializationDictionary, + key.vals, + specializedCallee); } // At this point we have found or generated a specialized version @@ -2727,6 +2691,10 @@ struct SpecializationContext } existentialSpecializedStructs.add(key, newStructType); + addEntryToIRDictionary( + kIROp_ExistentialTypeSpecializationDictionary, + key.vals, + newStructType); } type->replaceUsesWith(newStructType); From 0fe0f8e1d36b9769af6f2f26269afcd82ed507b5 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 2 Sep 2025 18:11:52 -0400 Subject: [PATCH 054/105] Make type flow insts easier to use + fix specialization pass --- source/slang/slang-ir-insts.h | 48 ++++++++ .../slang/slang-ir-lower-typeflow-insts.cpp | 41 ++++--- source/slang/slang-ir-specialize.cpp | 10 +- source/slang/slang-ir-typeflow-collection.cpp | 1 - source/slang/slang-ir-typeflow-specialize.cpp | 107 ++++++++---------- 5 files changed, 118 insertions(+), 89 deletions(-) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9208e546c5c..cd52e8c0094 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3505,6 +3505,54 @@ struct IREmbeddedDownstreamIR : IRInst IRBlobLit* getBlob() { return cast(getOperand(1)); } }; +FIDDLE() +struct IRTypeFlowData : IRInst +{ + FIDDLE(baseInst()) +}; + +FIDDLE() +struct IRCollectionBase : IRTypeFlowData +{ + FIDDLE(baseInst()) + UInt getCount() { return getOperandCount(); } + IRInst* getElement(UInt idx) { return getOperand(idx); } + bool isSingleton() { return getOperandCount() == 1; } +}; + +FIDDLE() +struct IRTableCollection : IRCollectionBase +{ + FIDDLE(leafInst()) +}; + + +FIDDLE() +struct IRTypeCollection : IRCollectionBase +{ + FIDDLE(leafInst()) +}; + +FIDDLE() +struct IRCollectionTagType : IRTypeFlowData +{ + FIDDLE(leafInst()) + IRCollectionBase* getCollection() { return as(getOperand(0)); } + bool isSingleton() { return getCollection()->isSingleton(); } +}; + +FIDDLE() +struct IRCollectionTaggedUnionType : IRTypeFlowData +{ + FIDDLE(leafInst()) + IRTypeCollection* getTypeCollection() { return as(getOperand(0)); } + IRTableCollection* getTableCollection() { return as(getOperand(1)); } + bool isSingleton() + { + return getTypeCollection()->isSingleton() && getTableCollection()->isSingleton(); + } +}; + FIDDLE(allOtherInstStructs()) struct IRBuilderSourceLocRAII; diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 6d0798d6de8..e49853fd210 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -233,15 +233,15 @@ struct TagOpsLoweringContext : public InstPassBase builder.setInsertAfter(inst); List indices; - for (UInt i = 0; i < srcCollection->getOperandCount(); i++) + for (UInt i = 0; i < srcCollection->getCount(); i++) { // Find in destCollection - auto srcElement = srcCollection->getOperand(i); + auto srcElement = srcCollection->getElement(i); bool found = false; - for (UInt j = 0; j < destCollection->getOperandCount(); j++) + for (UInt j = 0; j < destCollection->getCount(); j++) { - auto destElement = destCollection->getOperand(j); + auto destElement = destCollection->getElement(j); if (srcElement == destElement) { found = true; @@ -281,15 +281,15 @@ struct TagOpsLoweringContext : public InstPassBase builder.setInsertAfter(inst); List indices; - for (UInt i = 0; i < srcCollection->getOperandCount(); i++) + for (UInt i = 0; i < srcCollection->getCount(); i++) { // Find in destCollection bool found = false; auto srcElement = - findWitnessTableEntry(cast(srcCollection->getOperand(i)), key); - for (UInt j = 0; j < destCollection->getOperandCount(); j++) + findWitnessTableEntry(cast(srcCollection->getElement(i)), key); + for (UInt j = 0; j < destCollection->getCount(); j++) { - auto destElement = destCollection->getOperand(j); + auto destElement = destCollection->getElement(j); if (srcElement == destElement) { found = true; @@ -319,10 +319,9 @@ struct TagOpsLoweringContext : public InstPassBase void lowerGetTagForSpecializedCollection(IRGetTagForSpecializedCollection* inst) { - auto srcCollection = cast( - cast(inst->getOperand(0)->getDataType())->getOperand(0)); - auto destCollection = - cast(cast(inst->getDataType())->getOperand(0)); + auto srcCollection = + cast(inst->getOperand(0)->getDataType())->getCollection(); + auto destCollection = cast(inst->getDataType())->getCollection(); Dictionary mapping; for (UInt i = 1; i < inst->getOperandCount(); i += 2) @@ -336,14 +335,14 @@ struct TagOpsLoweringContext : public InstPassBase builder.setInsertAfter(inst); List indices; - for (UInt i = 0; i < srcCollection->getOperandCount(); i++) + for (UInt i = 0; i < srcCollection->getCount(); i++) { // Find in destCollection bool found = false; - auto mappedElement = mapping[srcCollection->getOperand(i)]; - for (UInt j = 0; j < destCollection->getOperandCount(); j++) + auto mappedElement = mapping[srcCollection->getElement(i)]; + for (UInt j = 0; j < destCollection->getCount(); j++) { - auto destElement = destCollection->getOperand(j); + auto destElement = destCollection->getElement(j); if (mappedElement == destElement) { found = true; @@ -434,9 +433,9 @@ struct CollectionLoweringContext : public InstPassBase void lowerTypeCollection(IRTypeCollection* collection) { HashSet types; - for (UInt i = 0; i < collection->getOperandCount(); i++) + for (UInt i = 0; i < collection->getCount(); i++) { - if (auto type = as(collection->getOperand(i))) + if (auto type = as(collection->getElement(i))) { types.add(type); } @@ -477,8 +476,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase Dictionary mapping; // Map from sequential ID to unique ID - auto destCollection = - cast(cast(inst->getDataType())->getOperand(0)); + auto destCollection = cast(inst->getDataType())->getCollection(); UIndex dstSeqID = 0; forEachInCollection( @@ -525,8 +523,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase Dictionary mapping; // Map from sequential ID to unique ID - auto destCollection = cast( - cast(srcTagInst->getDataType())->getOperand(0)); + auto destCollection = cast(srcTagInst->getDataType())->getCollection(); UIndex dstSeqID = 0; forEachInCollection( diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 45cd97ce955..7cd2b587674 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -978,18 +978,16 @@ struct SpecializationContext if (item->getOperand(i) == nullptr) { isInvalid = true; - break; } - if (item->getOperand(i)->getParent() == nullptr) + else if (item->getOperand(i)->getParent() == nullptr) { isInvalid = true; - break; } - if (item->getOperand(i)->getOp() == kIROp_Undefined) + else if (item->getOperand(i)->getOp() == kIROp_Undefined) { isInvalid = true; - break; } + if (i > 0) { key.vals.add(item->getOperand(i)); @@ -1014,8 +1012,6 @@ struct SpecializationContext } void readSpecializationDictionaries() { - auto moduleInst = module->getModuleInst(); - _readSpecializationDictionaryImpl( genericSpecializations, getOrCreateIRDictionary(kIROp_GenericSpecializationDictionary)); diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-collection.cpp index 54b9b113653..ad383cb83e7 100644 --- a/source/slang/slang-ir-typeflow-collection.cpp +++ b/source/slang/slang-ir-typeflow-collection.cpp @@ -91,7 +91,6 @@ IRCollectionBase* CollectionBuilder::createCollection(IROp op, const HashSet(builder.emitIntrinsicInst( nullptr, op, diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 2febce34933..d1e7c560702 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -551,18 +551,18 @@ struct TypeFlowSpecializationContext IRBuilder builder(arg->getModule()); setInsertAfterOrdinaryInst(&builder, arg); auto argTableTag = builder.emitGetTupleElement( - (IRType*)makeTagType(as(argTUType->getOperand(1))), + (IRType*)makeTagType(argTUType->getTableCollection()), arg, 0); auto reinterpretedTag = upcastCollection( context, argTableTag, - (IRType*)makeTagType(as(destTUType->getOperand(1)))); + (IRType*)makeTagType(destTUType->getTableCollection())); auto argVal = - builder.emitGetTupleElement((IRType*)argTUType->getOperand(0), arg, 1); + builder.emitGetTupleElement((IRType*)argTUType->getTypeCollection(), arg, 1); auto reinterpretedVal = - upcastCollection(context, argVal, (IRType*)destTUType->getOperand(0)); + upcastCollection(context, argVal, (IRType*)destTUType->getTypeCollection()); return builder.emitMakeTuple( (IRType*)destTUType, {reinterpretedTag, reinterpretedVal}); @@ -746,7 +746,7 @@ struct TypeFlowSpecializationContext return makeExistential(as(cBuilder.makeSingletonSet(witnessTable))); if (auto collectionTag = as(witnessTableInfo)) - return makeExistential(cast(collectionTag->getOperand(0))); + return makeExistential(cast(collectionTag->getCollection())); SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } @@ -963,7 +963,7 @@ struct TypeFlowSpecializationContext { HashSet results; forEachInCollection( - cast(tagType->getOperand(0)), + cast(tagType->getCollection()), [&](IRInst* table) { results.add(findWitnessTableEntry(cast(table), key)); }); return makeTagType(cBuilder.makeSet(results)); @@ -986,7 +986,7 @@ struct TypeFlowSpecializationContext return makeUnbounded(); if (auto taggedUnion = as(operandInfo)) - return makeTagType(cast(taggedUnion->getOperand(1))); + return makeTagType(taggedUnion->getTableCollection()); SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialWitnessTable"); } @@ -1003,7 +1003,7 @@ struct TypeFlowSpecializationContext return makeUnbounded(); if (auto taggedUnion = as(operandInfo)) - return makeTagType(cast(taggedUnion->getOperand(0))); + return makeTagType(taggedUnion->getTypeCollection()); SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialType"); } @@ -1020,14 +1020,14 @@ struct TypeFlowSpecializationContext return makeUnbounded(); if (auto taggedUnion = as(operandInfo)) - return cast(taggedUnion->getOperand(0)); + return taggedUnion->getTypeCollection(); return none(); } IRInst* analyzeSpecialize(IRInst* context, IRSpecialize* inst) { - auto operand = inst->getOperand(0); + auto operand = inst->getBase(); auto operandInfo = tryGetInfo(context, operand); if (!operandInfo) @@ -1074,13 +1074,12 @@ struct TypeFlowSpecializationContext if (auto argCollectionTag = as(argInfo)) { - if (getCollectionCount(argCollectionTag) == 1) - specializationArgs.add(getCollectionElement(argCollectionTag, 0)); + if (argCollectionTag->isSingleton()) + specializationArgs.add(argCollectionTag->getCollection()->getElement(0)); else { needsTag = true; - specializationArgs.add( - cast(argCollectionTag->getOperand(0))); + specializationArgs.add(argCollectionTag->getCollection()); } } else @@ -1100,12 +1099,10 @@ struct TypeFlowSpecializationContext { if (auto infoCollectionTag = as(info)) { - if (getCollectionCount(infoCollectionTag) == 1) - return getCollectionElement(infoCollectionTag, 0); + if (infoCollectionTag->isSingleton()) + return infoCollectionTag->getCollection()->getElement(0); else - { - return as(infoCollectionTag->getOperand(0)); - } + return infoCollectionTag->getCollection(); } else return type; @@ -1132,15 +1129,15 @@ struct TypeFlowSpecializationContext // if (auto tag = as(typeInfo)) { - SLANG_ASSERT(getCollectionCount(tag) == 1); - auto specializeInst = cast(getCollectionElement(tag, 0)); + SLANG_ASSERT(tag->isSingleton()); + auto specializeInst = cast(tag->getCollection()->getElement(0)); auto specializedFuncType = cast(specializeGeneric(specializeInst)); typeOfSpecialization = specializedFuncType; } else if (auto collection = as(typeInfo)) { - SLANG_ASSERT(getCollectionCount(collection) == 1); - auto specializeInst = cast(getCollectionElement(collection, 0)); + SLANG_ASSERT(collection->isSingleton()); + auto specializeInst = cast(collection->getElement(0)); auto specializedFuncType = cast(specializeGeneric(specializeInst)); typeOfSpecialization = specializedFuncType; } @@ -1171,7 +1168,7 @@ struct TypeFlowSpecializationContext else if (auto collectionTagType = as(operandInfo)) { needsTag = true; - collection = cast(collectionTagType->getOperand(0)); + collection = collectionTagType->getCollection(); } // Specialize each element in the set @@ -2148,7 +2145,7 @@ struct TypeFlowSpecializationContext return true; } - if (auto typeCollection = as(collectionTagType->getOperand(0))) + if (auto typeCollection = as(collectionTagType->getCollection())) { // If this is a type collection, we can replace it with the collection type // We don't currently care about the tag of a type. @@ -2241,17 +2238,17 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); builder.setInsertBefore(inst); - if (getCollectionCount(collectionTagType) == 1) + if (collectionTagType->isSingleton()) { // Found a single possible type. Simple replacement. - auto singletonValue = getCollectionElement(collectionTagType, 0); + auto singletonValue = collectionTagType->getCollection()->getElement(0); inst->replaceUsesWith(singletonValue); inst->removeAndDeallocate(); return true; } // Replace the instruction with the collection type. - inst->replaceUsesWith(collectionTagType->getOperand(0)); + inst->replaceUsesWith(collectionTagType->getCollection()); inst->removeAndDeallocate(); return true; } @@ -2351,8 +2348,9 @@ struct TypeFlowSpecializationContext as(newType)) { // Merge the elements of both tagged unions into a new tuple type - return (IRType*)makeExistential((as( - updateType((IRType*)currentType->getOperand(1), (IRType*)newType->getOperand(1))))); + return (IRType*)makeExistential((as(updateType( + (IRType*)as(currentType)->getTableCollection(), + (IRType*)as(newType)->getTableCollection())))); } else if (isTaggedUnionType(currentType) && isTaggedUnionType(newType)) { @@ -2618,17 +2616,9 @@ struct TypeFlowSpecializationContext // if (auto collectionTag = as(callee->getDataType())) { - if (getCollectionCount(collectionTag) > 1) - { + if (!collectionTag->isSingleton()) calleeTagInst = callee; // Only keep the tag if there are multiple elements. - - // If we're placing a specialized call, use the base tag since the - // specialization arguments will also become arguments to the call. - // - // if (auto specializedTag = as(calleeTagInst)) - // calleeTagInst = specializedTag->getBase(); - } - callee = collectionTag->getOperand(0); + callee = collectionTag->getCollection(); } // If by this point, we haven't resolved our callee into a global inst ( @@ -2806,8 +2796,8 @@ struct TypeFlowSpecializationContext builder.setInsertBefore(inst); // Collect types from the witness tables to determine the any-value type - auto tableCollection = as(taggedUnion->getOperand(1)); - auto typeCollection = as(taggedUnion->getOperand(0)); + auto tableCollection = taggedUnion->getTableCollection(); + auto typeCollection = taggedUnion->getTypeCollection(); IRInst* witnessTableID = nullptr; if (auto witnessTable = as(inst->getWitnessTable())) @@ -2827,9 +2817,8 @@ struct TypeFlowSpecializationContext } // Create the appropriate any-value type - auto collectionType = getCollectionCount(typeCollection) == 1 - ? (IRType*)typeCollection->getOperand(0) - : (IRType*)typeCollection; + auto collectionType = typeCollection->isSingleton() ? (IRType*)typeCollection->getElement(0) + : (IRType*)typeCollection; // Pack the value auto packedValue = as(collectionType) @@ -2854,7 +2843,7 @@ struct TypeFlowSpecializationContext if (!taggedUnion) return false; - auto taggedUnionType = getLoweredType(taggedUnion); + auto taggedUnionType = as(getLoweredType(taggedUnion)); IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -2863,25 +2852,27 @@ struct TypeFlowSpecializationContext args.add(inst->getDataType()); args.add(inst->getTypeID()); auto translatedTag = builder.emitIntrinsicInst( - (IRType*)makeTagType(as(taggedUnionType->getOperand(1))), + (IRType*)makeTagType(taggedUnionType->getTableCollection()), kIROp_GetTagFromSequentialID, args.getCount(), args.getBuffer()); IRInst* packedValue = nullptr; - auto collection = as(taggedUnionType->getOperand(0)); - if (getCollectionCount(collection) > 1) + auto collection = taggedUnionType->getTypeCollection(); + if (!collection->isSingleton()) { packedValue = builder.emitPackAnyValue((IRType*)collection, inst->getValue()); } else { - packedValue = - builder.emitReinterpret((IRType*)taggedUnionType->getOperand(0), inst->getValue()); + packedValue = builder.emitReinterpret( + (IRType*)taggedUnionType->getTypeCollection(), + inst->getValue()); } - auto newInst = - builder.emitMakeTuple(taggedUnionType, List({translatedTag, packedValue})); + auto newInst = builder.emitMakeTuple( + (IRType*)taggedUnionType, + List({translatedTag, packedValue})); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); @@ -2953,7 +2944,7 @@ struct TypeFlowSpecializationContext isFuncReturn = as(getGenericReturnVal(concreteGeneric)) != nullptr; else if (auto tagType = as(inst->getBase()->getDataType())) { - auto firstConcreteGeneric = as(getCollectionElement(tagType, 0)); + auto firstConcreteGeneric = as(tagType->getCollection()->getElement(0)); isFuncReturn = as(getGenericReturnVal(firstConcreteGeneric)) != nullptr; } @@ -3020,7 +3011,7 @@ struct TypeFlowSpecializationContext { // If this is a tag type, replace with collection. changed = true; - args.add(collectionTagType->getOperand(0)); + args.add(collectionTagType->getCollection()); } else { @@ -3169,8 +3160,7 @@ struct TypeFlowSpecializationContext { IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); - auto firstElement = - getCollectionElement(as(tagType->getOperand(0)), 0); + auto firstElement = tagType->getCollection()->getElement(0); auto interfaceType = as(as(firstElement)->getConformanceType()); List args = {interfaceType, arg}; @@ -3196,8 +3186,7 @@ struct TypeFlowSpecializationContext { IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); - auto firstElement = - getCollectionElement(as(tagType->getOperand(0)), 0); + auto firstElement = tagType->getCollection()->getElement(0); auto interfaceType = as(as(firstElement)->getConformanceType()); From c5f0a5f61e80c48ad7d486533ace4850a7b79be1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 7 Oct 2025 16:13:19 -0400 Subject: [PATCH 055/105] Merge with ToT, add some comments for inst, make specialization dictionaries faster --- source/slang/slang-emit.cpp | 3 +- source/slang/slang-ir-autodiff-rev.cpp | 18 +-- source/slang/slang-ir-insts.lua | 66 ++++++++++- source/slang/slang-ir-link.cpp | 19 ++++ source/slang/slang-ir-specialize.cpp | 99 ++++++++++++----- source/slang/slang-ir-typeflow-specialize.cpp | 103 +++++++++--------- source/slang/slang-ir-typeflow-specialize.h | 2 + .../slang/slang-ir-witness-table-wrapper.cpp | 4 +- source/slang/slang-lower-to-ir.cpp | 20 +++- 9 files changed, 232 insertions(+), 102 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 24d8d0d69e1..5a07c930654 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1143,8 +1143,9 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink)); } + // Tagged union type lowering typically generates more reinterpret instructions. if (lowerTaggedUnionTypes(irModule, sink)) - requiredLoweringPassSet.reinterpret = true; // TODO: Is this the right way to handle this? + requiredLoweringPassSet.reinterpret = true; lowerTagInsts(irModule, sink); lowerTypeCollections(irModule, sink); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 3cb7fc36449..be1687289b7 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -1435,24 +1435,24 @@ InstPair BackwardDiffTranscriberBase::transcribeSpecialize( { args.add(primalSpecialize->getArg(i)); } + IRType* typeForSpecialization = nullptr; - if ((*diffBase)->getDataType()->getOp() == kIROp_TypeKind || - (*diffBase)->getDataType()->getOp() == kIROp_GenericKind) + switch ((*diffBase)->getDataType()->getOp()) { + case kIROp_TypeKind: + case kIROp_GenericKind: typeForSpecialization = (*diffBase)->getDataType(); - } - else if ((*diffBase)->getDataType()->getOp() == kIROp_Generic) - { + break; + case kIROp_Generic: typeForSpecialization = (IRType*)builder->emitSpecializeInst( builder->getTypeKind(), (*diffBase)->getDataType(), args.getCount(), args.getBuffer()); - } - else - { - // Default to type kind for now. + break; + default: typeForSpecialization = builder->getTypeKind(); + break; } auto diffSpecialize = builder->emitSpecializeInst( diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index e36d2d2735c..b5223ddd436 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2187,21 +2187,77 @@ local insts = { }, }, { - -- A collection of IR instructions used for propagation analysis - -- The operands are the elements of the set, sorted by unique ID to ensure canonical ordering TypeFlowData = { + -- A collection of IR instructions used for propagation analysis. hoistable = true, { CollectionBase = { + -- Base class for all collection types. + -- + -- Semantically, collections model sets of concrete values, and use Slang's de-duplication infrastructure + -- to allow set-equality to be the same as inst identity. + -- + -- - Collection ops have one or more operands that represent the elements of the collection + -- + -- - Collection ops must have at least one operand. A zero-operand collection is illegal. + -- The type-flow pass will represent this case using nullptr, so that uniqueness is preserved. + -- + -- - All operands of a collection _must_ be concrete, individual insts + -- - Operands should NOT be an interface or abstract type. + -- - Operands should NOT be type parameters or existentail types (i.e. insts that appear in blocks) + -- - Operands should NOT be collections (i.e. collections should be flat and never heirarchical) + -- + -- - Since collections are hositable, collection ops should (consequently) only appear in the global scope. + -- + -- - Collection operands must be consistently sorted. i.e. a TypeCollection(A, B) and TypeCollection(B, A) + -- cannot exist at the same time, but either one is okay. + -- + -- - To help with the implementation of collections, the CollectionBuilder class is provided + -- in slang-ir-typeflow-collection.h. + -- All collection insts must be built using the CollectionBuilder, which uses a persistent map on the module + -- inst to ensure stable ordering. + -- { TypeCollection = {} }, { FuncCollection = {} }, { TableCollection = {} }, { GenericCollection = {} }, }, }, - { UnboundedCollection = {} }, - { CollectionTagType = {} }, -- Operand is TypeCollection/FuncCollection/TableCollection (funcs/tables) - { CollectionTaggedUnionType = {}} -- Operand is TypeCollection, TableCollection for existential + { UnboundedCollection = { + -- + -- A catch-all opcode to represent unbounded collections during + -- the type-flow specialization pass. + -- + -- This op is usually used to mark insts that can contain a dynamic type + -- whose information cannot be gleaned from the type-flow analysis. + -- + -- E.g. COM interface objects, whose implementations can be fully external to + -- the linkage + -- + -- This op is only used to denote that an inst is unbounded so the specialization + -- pass does not attempt to specialize it. It should not appear in the code after + -- the specialization pass. + -- + } }, + { CollectionTagType = { + -- Represents a tag-type for a collection. + -- + -- An inst whose type is CollectionTagType(collection) is semantically carrying a run-time value that points to + -- one of the elements of the collection operand. + -- + -- Only operand is a CollectionBase + } }, + { CollectionTaggedUnionType = { + -- Represents a tagged union type. + -- + -- An inst whose type is a CollectionTaggedUnionType(typeCollection, tableCollection) is semantically carrying a tuple of + -- two values: a value of CollectionTagType(tableCollection) to represent the tag, and a payload value of type + -- typeCollection (which conceptually represents a union/"anyvalue" type) + -- + -- This is most commonly used to specialize the type of existential insts once the possibilities can be statically determined. + -- + -- Operands are a TypeCollection and a TableCollection that represent the possibilities of the existential + }} }, }, { CastInterfaceToTaggedUnionPtr = { diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 45d8cf8594e..0c386484195 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -806,6 +806,18 @@ IRStructType* cloneStructTypeImpl( return clonedStruct; } +IREnumType* cloneEnumTypeImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IREnumType* originalEnum, + IROriginalValuesForClone const& originalValues) +{ + auto clonedEnum = + builder->createEnumType(cloneType(context, (IRType*)originalEnum->getOperand(0))); + cloneSimpleGlobalValueImpl(context, originalEnum, originalValues, clonedEnum); + return clonedEnum; +} + IRInterfaceType* cloneInterfaceTypeImpl( IRSpecContextBase* context, @@ -1389,6 +1401,9 @@ IRInst* cloneInst( cast(originalInst), originalValues); + case kIROp_EnumType: + return cloneEnumTypeImpl(context, builder, cast(originalInst), originalValues); + case kIROp_InterfaceType: return cloneInterfaceTypeImpl( context, @@ -2401,6 +2416,10 @@ struct IRPrelinkContext : IRSpecContext case kIROp_ClassType: clonedInst = builderForClone->createClassType(); break; + case kIROp_EnumType: + clonedInst = builderForClone->createEnumType( + cloneType(this, (IRType*)cast(originalVal)->getOperand(0))); + break; default: return completeClonedInst(IRSpecContext::maybeCloneValue(originalVal)); } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index cc14532f93e..b3dc273c3a3 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -272,7 +272,7 @@ struct SpecializationContext // using the simple key type defined as part of the IR cloning infrastructure. // typedef IRSimpleSpecializationKey Key; - Dictionary genericSpecializations; + Dictionary genericSpecializations; // Now let's look at the task of finding or generation a @@ -329,9 +329,14 @@ struct SpecializationContext // existing specialization that has been registered. // If one is found, our work is done. // - IRInst* specializedVal = nullptr; - if (genericSpecializations.tryGetValue(key, specializedVal)) - return specializedVal; + IRSpecializationDictionaryItem* specializationEntry = nullptr; + if (genericSpecializations.tryGetValue(key, specializationEntry)) + { + if (specializationEntry->getOperand(0)->getOp() != kIROp_Undefined) + return specializationEntry->getOperand(0); + else + genericSpecializations.remove(key); + } } // If no existing specialization is found, we need @@ -357,8 +362,12 @@ struct SpecializationContext // specializations so that we don't instantiate // this generic again for the same arguments. // - genericSpecializations.add(key, specializedVal); - addEntryToIRDictionary(kIROp_GenericSpecializationDictionary, key.vals, specializedVal); + genericSpecializations.add( + key, + addEntryToIRDictionary( + kIROp_GenericSpecializationDictionary, + key.vals, + specializedVal)); return specializedVal; } @@ -1039,10 +1048,7 @@ struct SpecializationContext dict.remove(key); continue; } - auto value = as::type>( - item->getOperand(0)); - SLANG_ASSERT(value); - dict[key] = value; + dict[key] = item; } // Clean up the IR dictionary @@ -1085,7 +1091,10 @@ struct SpecializationContext return dictInst; } - void addEntryToIRDictionary(IROp dictOp, const List& key, IRInst* val) + IRSpecializationDictionaryItem* addEntryToIRDictionary( + IROp dictOp, + const List& key, + IRInst* val) { auto dictInst = getOrCreateIRDictionary(dictOp); List args; @@ -1093,11 +1102,11 @@ struct SpecializationContext args.addRange(key); IRBuilder builder(module); builder.setInsertInto(dictInst); - builder.emitIntrinsicInst( + return cast(builder.emitIntrinsicInst( nullptr, kIROp_SpecializationDictionaryItem, (UInt)args.getCount(), - args.getBuffer()); + args.getBuffer())); } // All of the machinery for generic specialization @@ -1216,7 +1225,7 @@ struct SpecializationContext applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); // Sync our local dictionary with the one in the IR. - readSpecializationDictionaries(); + // readSpecializationDictionaries(); } // Once the work list has gone dry, we should have the invariant @@ -1237,7 +1246,7 @@ struct SpecializationContext eliminateDeadCode(module->getModuleInst()); // Sync our local dictionary with the one in the IR. - readSpecializationDictionaries(); + // readSpecializationDictionaries(); } } @@ -1560,18 +1569,32 @@ struct SpecializationContext // Once we've constructed our key, we can try to look for an // existing specialization of the callee that we can use. // + IRSpecializationDictionaryItem* specializedCalleeEntry = nullptr; IRFunc* specializedCallee = nullptr; - if (!existentialSpecializedFuncs.tryGetValue(key, specializedCallee)) + if (existentialSpecializedFuncs.tryGetValue(key, specializedCalleeEntry)) + { + if (specializedCalleeEntry->getOperand(0)->getOp() != kIROp_Undefined) + { + specializedCallee = cast(specializedCalleeEntry->getOperand(0)); + } + else + { + existentialSpecializedFuncs.remove(key); + } + } + + if (!specializedCallee) { // If we didn't find a specialized callee already made, then we // will go ahead and create one, and then register it in our cache. // specializedCallee = createExistentialSpecializedFunc(inst, calleeFunc); - existentialSpecializedFuncs.add(key, specializedCallee); - addEntryToIRDictionary( - kIROp_ExistentialFuncSpecializationDictionary, - key.vals, - specializedCallee); + existentialSpecializedFuncs.add( + key, + addEntryToIRDictionary( + kIROp_ExistentialFuncSpecializationDictionary, + key.vals, + specializedCallee)); } // At this point we have found or generated a specialized version @@ -1795,7 +1818,8 @@ struct SpecializationContext // In order to cache and re-use functions that have had existential-type // parameters specialized, we need storage for the cache. // - Dictionary existentialSpecializedFuncs; + Dictionary + existentialSpecializedFuncs; // The logic for creating a specialized callee function by plugging // in concrete types for existentials is similar to other cases of @@ -2600,7 +2624,8 @@ struct SpecializationContext } } - Dictionary existentialSpecializedStructs; + Dictionary + existentialSpecializedStructs; bool maybeSpecializeBindExistentialsType(IRBindExistentialsType* type) { @@ -2694,10 +2719,23 @@ struct SpecializationContext key.vals.add(type->getExistentialArg(ii)); } - IRStructType* newStructType = nullptr; + IRSpecializationDictionaryItem* newStructTypeEntry = nullptr; addUsersToWorkList(type); - if (!existentialSpecializedStructs.tryGetValue(key, newStructType)) + IRStructType* newStructType = nullptr; + if (existentialSpecializedStructs.tryGetValue(key, newStructTypeEntry)) + { + if (newStructTypeEntry->getOperand(0)->getOp() != kIROp_Undefined) + { + newStructType = cast(newStructTypeEntry->getOperand(0)); + } + else + { + existentialSpecializedStructs.remove(key); + } + } + + if (!newStructType) { builder.setInsertBefore(baseStructType); newStructType = builder.createStructType(); @@ -2725,11 +2763,12 @@ struct SpecializationContext builder.createStructField(newStructType, oldField->getKey(), newFieldType); } - existentialSpecializedStructs.add(key, newStructType); - addEntryToIRDictionary( - kIROp_ExistentialTypeSpecializationDictionary, - key.vals, - newStructType); + existentialSpecializedStructs.add( + key, + addEntryToIRDictionary( + kIROp_ExistentialTypeSpecializationDictionary, + key.vals, + newStructType)); } type->replaceUsesWith(newStructType); diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index d1e7c560702..95f94914ce0 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -238,6 +238,30 @@ struct WorkQueue } }; +bool isDynamicGeneric(IRInst* callee) +{ + // If the callee is a specialization, and at least one of its arguments + // is a type-flow-collection, then it is a dynamic generic. + // + if (auto specialize = as(callee)) + { + for (UInt i = 0; i < specialize->getArgCount(); i++) + { + // Only functions need dynamic-aware specialization. + auto generic = specialize->getBase(); + if (getGenericReturnVal(generic)->getOp() != kIROp_Func) + return false; + + auto arg = specialize->getArg(i); + if (as(arg)) + return true; // Found a type-flow-collection argument + } + return false; // No type-flow-collection arguments found + } + + return false; +} + struct TypeFlowSpecializationContext { struct ParameterDirectionInfo @@ -245,13 +269,13 @@ struct TypeFlowSpecializationContext enum Kind { In, + BorrowIn, Out, - InOut, - Ref, - ConstRef + BorrowInOut, + Ref } kind; - // For Ref and ConstRef + // For Ref and BorrowInOut AddressSpace addressSpace; ParameterDirectionInfo(Kind kind, AddressSpace addressSpace = (AddressSpace)0) @@ -341,9 +365,6 @@ struct TypeFlowSpecializationContext if (as(inst) && as(getGenericReturnVal(inst))) return none(); - // TODO: We really should return something like Singleton(collectionInst) here - // instead of directly returning the collection. - // return cBuilder.makeSingletonSet(inst); } else @@ -695,8 +716,6 @@ struct TypeFlowSpecializationContext break; } - // TODO: Remove this workaround.. there are a few insts - // where we shouldn't bool takeUnion = !as(inst); updateInfo(context, inst, info, takeUnion, workQueue); } @@ -1597,8 +1616,8 @@ struct TypeFlowSpecializationContext switch (paramDirection.kind) { case ParameterDirectionInfo::Kind::Out: - case ParameterDirectionInfo::Kind::InOut: - case ParameterDirectionInfo::Kind::ConstRef: + case ParameterDirectionInfo::Kind::BorrowInOut: + case ParameterDirectionInfo::Kind::BorrowIn: { IRBuilder builder(module); if (!argInfo) @@ -1659,7 +1678,8 @@ struct TypeFlowSpecializationContext if (paramInfo) { if (paramDirections[argIndex].kind == ParameterDirectionInfo::Kind::Out || - paramDirections[argIndex].kind == ParameterDirectionInfo::Kind::InOut) + paramDirections[argIndex].kind == + ParameterDirectionInfo::Kind::BorrowInOut) { auto arg = callInst->getArg(argIndex); auto argPtrType = as(arg->getDataType()); @@ -2256,26 +2276,26 @@ struct TypeFlowSpecializationContext // Split into direction and type std::tuple getParameterDirectionAndType(IRType* paramType) { - if (as(paramType)) + if (as(paramType)) return { ParameterDirectionInfo(ParameterDirectionInfo::Kind::Out), - as(paramType)->getValueType()}; - else if (as(paramType)) + as(paramType)->getValueType()}; + else if (as(paramType)) return { - ParameterDirectionInfo(ParameterDirectionInfo::Kind::InOut), - as(paramType)->getValueType()}; - else if (as(paramType)) + ParameterDirectionInfo(ParameterDirectionInfo::Kind::BorrowInOut), + as(paramType)->getValueType()}; + else if (as(paramType)) return { ParameterDirectionInfo( ParameterDirectionInfo::Kind::Ref, - as(paramType)->getAddressSpace()), - as(paramType)->getValueType()}; - else if (as(paramType)) + as(paramType)->getAddressSpace()), + as(paramType)->getValueType()}; + else if (as(paramType)) return { ParameterDirectionInfo( - ParameterDirectionInfo::Kind::ConstRef, - as(paramType)->getAddressSpace()), - as(paramType)->getValueType()}; + ParameterDirectionInfo::Kind::BorrowIn, + as(paramType)->getAddressSpace()), + as(paramType)->getValueType()}; else return {ParameterDirectionInfo(ParameterDirectionInfo::Kind::In), paramType}; } @@ -2287,13 +2307,13 @@ struct TypeFlowSpecializationContext case ParameterDirectionInfo::Kind::In: return type; case ParameterDirectionInfo::Kind::Out: - return builder->getOutType(type); - case ParameterDirectionInfo::Kind::InOut: - return builder->getInOutType(type); - case ParameterDirectionInfo::Kind::ConstRef: - return builder->getConstRefType(type, direction.addressSpace); + return builder->getOutParamType(type); + case ParameterDirectionInfo::Kind::BorrowInOut: + return builder->getBorrowInOutParamType(type); + case ParameterDirectionInfo::Kind::BorrowIn: + return builder->getBorrowInParamType(type, direction.addressSpace); case ParameterDirectionInfo::Kind::Ref: - return builder->getRefType(type, direction.addressSpace); + return builder->getRefParamType(type, direction.addressSpace); default: SLANG_UNEXPECTED("Unhandled parameter direction in fromDirectionAndType"); } @@ -2488,25 +2508,6 @@ struct TypeFlowSpecializationContext return builder.getFuncType(allParamTypes, resultType); } - bool isDynamicGeneric(IRInst* callee) - { - // If the callee is a specialization, and at least one of its arguments - // is a type-flow-collection, then it is a dynamic generic. - // - if (auto specialize = as(callee)) - { - for (UInt i = 0; i < specialize->getArgCount(); i++) - { - auto arg = specialize->getArg(i); - if (as(arg)) - return true; // Found a type-flow-collection argument - } - return false; // No type-flow-collection arguments found - } - - return false; - } - IRInst* getCalleeForContext(IRInst* context) { if (this->contextsToLower.contains(context)) @@ -2701,8 +2702,8 @@ struct TypeFlowSpecializationContext newArgs.add(upcastCollection(context, arg, paramType)); break; case ParameterDirectionInfo::Kind::Out: - case ParameterDirectionInfo::Kind::InOut: - case ParameterDirectionInfo::Kind::ConstRef: + case ParameterDirectionInfo::Kind::BorrowInOut: + case ParameterDirectionInfo::Kind::BorrowIn: case ParameterDirectionInfo::Kind::Ref: { newArgs.add(arg); diff --git a/source/slang/slang-ir-typeflow-specialize.h b/source/slang/slang-ir-typeflow-specialize.h index 1503cc03927..1a134275085 100644 --- a/source/slang/slang-ir-typeflow-specialize.h +++ b/source/slang/slang-ir-typeflow-specialize.h @@ -6,4 +6,6 @@ namespace Slang { // Main entry point for the pass bool specializeDynamicInsts(IRModule* module, DiagnosticSink* sink); + +bool isDynamicGeneric(IRInst* callee); } // namespace Slang diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index a29c08f893b..888ddfdf901 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -275,7 +275,7 @@ IRInst* maybeUnpackArg( if (as(paramType)) { auto tempVar = builder->emitVar(paramValType); - if (as(paramType)) + if (as(paramType)) builder->emitStore( tempVar, builder->emitUnpackAnyValue(paramValType, builder->emitLoad(arg))); @@ -302,7 +302,7 @@ IRInst* maybeUnpackArg( { // if parameter expects an `out` pointer, store the unpacked val into a // variable and pass in a pointer to that variable. - if (as(paramType)) + if (as(paramType)) { auto tempVar = builder->emitVar(paramValType); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d866f617e57..b71e36adb0f 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10368,10 +10368,22 @@ struct DeclLoweringVisitor : DeclVisitor for (auto param : paramCloneInfos) { typeBuilder.setInsertInto(param.clonedParam); - param.clonedParam->setFullType((IRType*)cloneInst( - &cloneEnv, - &typeBuilder, - param.originalParam->getFullType())); + + // If the type is present in the generic (i.e. not in module scope), + // then we need to make a clone since it is likely to be dependent on the + // new parameters we just emitted and we want to obtain the effective type + // in the current context. + // + // If it's in the global scope, we're good if we use the same type. + // + if (!as(param.originalParam->getFullType()->getParent())) + param.clonedParam->setFullType((IRType*)cloneInst( + &cloneEnv, + &typeBuilder, + param.originalParam->getFullType())); + else + param.clonedParam->setFullType(param.originalParam->getFullType()); + cloneInstDecorationsAndChildren( &cloneEnv, typeBuilder.getModule(), From bf29b0b575e29a65909d48b5799beb5fa6b7c7aa Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 9 Oct 2025 13:39:26 -0400 Subject: [PATCH 056/105] Cleanup + add details comments --- source/slang/slang-ir-specialize.h | 2 + source/slang/slang-ir-typeflow-specialize.cpp | 2034 +++++++++++------ 2 files changed, 1336 insertions(+), 700 deletions(-) diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h index 72f2c6130f7..eb92e7faa85 100644 --- a/source/slang/slang-ir-specialize.h +++ b/source/slang/slang-ir-specialize.h @@ -24,4 +24,6 @@ bool specializeModule( void finalizeSpecialization(IRModule* module); +IRInst* specializeGeneric(IRSpecialize* specInst); + } // namespace Slang diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 95f94914ce0..00d7cbdd4b1 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -14,27 +14,43 @@ namespace Slang { -// Forward-declare.. (TODO: Just include this from the header instead) -IRInst* specializeGeneric(IRSpecialize* specializeInst); - -// Elements for which we keep track of propagation information. -struct Element +// Basic unit for which we keep track of propagation information. +// +// This unit has two components: an 'inst' and a 'context' under which we +// are recording propagation info. +// +// The 'inst' must be inside a block (with either generic or func parent), since +// we assume everything in the global scope is concrete. +// +// The 'context' can be one of two cases: +// 1. an IRFunc ONLY if it is not generic (func is in the global scope). 'inst' must +// be inside the func. +// 2. an IRSpecialize(generic, ...). 'inst' must be inside the generic. the +// specialization args must all be global values (either concrete types/values, or collections). +// +// All other possibilites for 'context' are illegal. +// `InstWithContext::validateInstWithContext` enforces these rules. +// +// For an inst inside a generic, it is possible to have different propagation information +// depending on the specialization args, which is why we use the pair to keep track of the context. +// +struct InstWithContext { IRInst* context; IRInst* inst; - Element() + InstWithContext() : context(nullptr), inst(nullptr) { } - Element(IRInst* context, IRInst* inst) + InstWithContext(IRInst* context, IRInst* inst) : context(context), inst(inst) { - validateElement(); + validateInstWithContext(); } - void validateElement() const + void validateInstWithContext() const { switch (context->getOp()) { @@ -61,15 +77,15 @@ struct Element break; default: { - SLANG_UNEXPECTED("Invalid context for Element"); + SLANG_UNEXPECTED("Invalid context for InstWithContext"); } } } - // Create element from an instruction that has a - // concrete parent (i.e. global IRFunc) + // If a context is not specified, we assume it is not in a generic, and + // simply use the parent func. // - Element(IRInst* inst) + InstWithContext(IRInst* inst) : inst(inst) { auto block = cast(inst->getParent()); @@ -84,16 +100,32 @@ struct Element context = func; } - bool operator==(const Element& other) const + bool operator==(const InstWithContext& other) const { return context == other.context && inst == other.inst; } - // getHashCode() HashCode64 getHashCode() const { return combineHash(HashCode(context), HashCode(inst)); } }; -// Data structures for interprocedural data-flow analysis +// Test if inst represents a pointer to a global resource. +bool isResourcePointer(IRInst* inst) +{ + if (isPointerToResourceType(inst->getDataType()) || + inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) + return true; + + if (as(inst)) + return true; + + if (auto ptr = as(inst)) + return isResourcePointer(ptr->getBase()); + + if (auto fieldAddress = as(inst)) + return isResourcePointer(fieldAddress->getBase()); + + return false; +} // Represents an interprocedural edge between call sites and functions struct InterproceduralEdge @@ -107,7 +139,7 @@ struct InterproceduralEdge Direction direction; IRInst* callerContext; // The context of the call (e.g. function or specialized generic) IRCall* callInst; // The call instruction - IRInst* targetContext; // The function/specialized-generic being called/returned from + IRInst* targetContext; // The function/specialized-generic being called or returned from InterproceduralEdge() = default; InterproceduralEdge(Direction dir, IRInst* callerContext, IRCall* call, IRInst* func) @@ -116,8 +148,13 @@ struct InterproceduralEdge } }; - -// Union type representing either an intra-procedural or interprocedural edge +// Representation of a work item used to register work for the main propagation queue. +// When the propagation information for a particular inst is modified non-trivially, new +// 'WorkItem' objects are added to the queue to further propagate the changes. +// +// The "Type" captures the granularity & type of the propagation work, while the union +// holds on to any auxiliary information. +// struct WorkItem { enum class Type @@ -149,7 +186,7 @@ struct WorkItem { SLANG_ASSERT(context != nullptr && inst != nullptr); // Validate that the context is appropriate for the instruction - Element(context, inst).validateElement(); + InstWithContext(context, inst).validateInstWithContext(); } WorkItem(IRInst* context, IRBlock* block) @@ -157,7 +194,7 @@ struct WorkItem { SLANG_ASSERT(context != nullptr && block != nullptr); // Validate that the context is appropriate for the block - Element(context, block->getFirstChild()).validateElement(); + InstWithContext(context, block->getFirstChild()).validateInstWithContext(); } WorkItem(IRInst* context, IREdge edge) @@ -206,11 +243,22 @@ struct WorkItem } }; +// Returns true if the two propagation infos are equal. bool areInfosEqual(IRInst* a, IRInst* b) { + // Since all inst opcodes that are used to represent propagation information + // are hoistable and automatically de-duplicated by the Slang IR infrastructure, + // we can simply test pointer equality + // return a == b; } +// Helper data-structure for efficient enqueueing/dequeueing of work items. +// +// Our 'List' data-structure are currently designed to be efficient at operating as a stack +// but have poor performance for queue-like operations, so this rolls two stacks into a queue +// structure. +// struct WorkQueue { List enqueueList; @@ -238,15 +286,22 @@ struct WorkQueue } }; +// Tests whether a generic can be fully specialized, or if it requires a dynamic information. +// +// This test is primarily used to determine if additional parameters are requried to place a call to +// this callee. +// bool isDynamicGeneric(IRInst* callee) { + // // If the callee is a specialization, and at least one of its arguments - // is a type-flow-collection, then it is a dynamic generic. + // is a collection, then it needs dynamic-dispatch logic to be generated. // if (auto specialize = as(callee)) { for (UInt i = 0; i < specialize->getArgCount(); i++) { + // Only functions need dynamic-aware specialization. auto generic = specialize->getBase(); if (getGenericReturnVal(generic)->getOp() != kIROp_Func) @@ -262,43 +317,136 @@ bool isDynamicGeneric(IRInst* callee) return false; } -struct TypeFlowSpecializationContext +// +// Helper struct to represent a parameter's direction and type component. +// This is used by the type flow system to figure out which direction to propagate +// information for each parameter. +// +struct ParameterDirectionInfo { - struct ParameterDirectionInfo + enum Kind { - enum Kind - { - In, - BorrowIn, - Out, - BorrowInOut, - Ref - } kind; + In, + BorrowIn, + Out, + BorrowInOut, + Ref + } kind; - // For Ref and BorrowInOut - AddressSpace addressSpace; + // For Ref and BorrowInOut + AddressSpace addressSpace; - ParameterDirectionInfo(Kind kind, AddressSpace addressSpace = (AddressSpace)0) - : kind(kind), addressSpace(addressSpace) - { - } + ParameterDirectionInfo(Kind kind, AddressSpace addressSpace = (AddressSpace)0) + : kind(kind), addressSpace(addressSpace) + { + } - ParameterDirectionInfo() - : kind(Kind::In), addressSpace((AddressSpace)0) - { - } + ParameterDirectionInfo() + : kind(Kind::In), addressSpace((AddressSpace)0) + { + } - bool operator==(const ParameterDirectionInfo& other) const - { - return kind == other.kind && addressSpace == other.addressSpace; - } - }; + bool operator==(const ParameterDirectionInfo& other) const + { + return kind == other.kind && addressSpace == other.addressSpace; + } +}; + +// Split parameter type into a direction and a type +std::tuple splitParameterDirectionAndType(IRType* paramType) +{ + if (as(paramType)) + return { + ParameterDirectionInfo(ParameterDirectionInfo::Kind::Out), + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirectionInfo(ParameterDirectionInfo::Kind::BorrowInOut), + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirectionInfo( + ParameterDirectionInfo::Kind::Ref, + as(paramType)->getAddressSpace()), + as(paramType)->getValueType()}; + else if (as(paramType)) + return { + ParameterDirectionInfo( + ParameterDirectionInfo::Kind::BorrowIn, + as(paramType)->getAddressSpace()), + as(paramType)->getValueType()}; + else + return {ParameterDirectionInfo(ParameterDirectionInfo::Kind::In), paramType}; +} + +// Join parameter direction and a type back into a parameter type +IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo direction, IRType* type) +{ + switch (direction.kind) + { + case ParameterDirectionInfo::Kind::In: + return type; + case ParameterDirectionInfo::Kind::Out: + return builder->getOutParamType(type); + case ParameterDirectionInfo::Kind::BorrowInOut: + return builder->getBorrowInOutParamType(type); + case ParameterDirectionInfo::Kind::BorrowIn: + return builder->getBorrowInParamType(type, direction.addressSpace); + case ParameterDirectionInfo::Kind::Ref: + return builder->getRefParamType(type, direction.addressSpace); + default: + SLANG_UNEXPECTED("Unhandled parameter direction in fromDirectionAndType"); + } +} + +// Helper to check if an IRParam is a function parameter (vs. a phi param or generic param) +bool isFuncParam(IRParam* param) +{ + auto paramBlock = as(param->getParent()); + auto paramFunc = as(paramBlock->getParent()); + return (paramFunc && paramFunc->getFirstBlock() == paramBlock); +} + +// Helper to test if an inst is in the global scope. +bool isGlobalInst(IRInst* inst) +{ + return inst->getParent()->getOp() == kIROp_ModuleInst; +} + +// Helper to test if a function or generic contains a body (i.e. is intrinsic/external) +// For the purposes of type-flow, if a function body is not available, we can't analyze it. +// +bool isIntrinsic(IRInst* inst) +{ + auto func = as(inst); + if (auto specialize = as(inst)) + { + auto generic = specialize->getBase(); + func = as(getGenericReturnVal(generic)); + } + + if (!func) + return false; + + if (func->getFirstBlock() == nullptr) + return true; + return false; +} +// Parent context for the full type-flow pass. +struct TypeFlowSpecializationContext +{ + // Create a tagged-union-type out of a given collection of tables. + // + // This type can be used for insts that are semantically a tuple of a tag (to select a table) + // and a payload to contain the existential value. + // IRCollectionTaggedUnionType* makeExistential(IRTableCollection* tableCollection) { HashSet typeSet; - // Collect all types from the witness tables + + // Create a type collection out of the base types from each table. forEachInCollection( tableCollection, [&](IRInst* witnessTable) @@ -309,7 +457,7 @@ struct TypeFlowSpecializationContext auto typeCollection = cBuilder.createCollection(kIROp_TypeCollection, typeSet); - // Create the tagged union type + // Create the tagged union type out of the type and table collection. IRBuilder builder(module); List elements = {typeCollection, tableCollection}; return as(builder.emitIntrinsicInst( @@ -319,6 +467,14 @@ struct TypeFlowSpecializationContext elements.getBuffer())); } + // Create an unbounded collection. + // + // This collection is a catch-all for + // all cases where we can't enumerate the possibilites. We use this as + // a sentinel value to figure out when NOT to specialize a given inst. + // + // Most commonly occurs with COM interface types. + // IRUnboundedCollection* makeUnbounded() { IRBuilder builder(module); @@ -326,17 +482,30 @@ struct TypeFlowSpecializationContext builder.emitIntrinsicInst(nullptr, kIROp_UnboundedCollection, 0, nullptr)); } + // Creates an 'empty' inst (denoted by nullptr), that + // can be used to denote one of two things: + // + // 1. This inst does not have any dynamic components to specialize. + // e.g. an inst with a concrete int-type. + // 2. No possibilties have been propagated for this inst yet. This is the + // default starting state of all insts. + // + // From an order-theoretic perspective, 'none' is the bottom of the lattice. + // IRTypeFlowData* none() { return nullptr; } - IRInst* tryGetInfo(Element element) + IRInst* _tryGetInfo(InstWithContext element) { - // For non-global instructions, look up in the map auto found = propagationMap.tryGetValue(element); if (found) return *found; - return none(); + return none(); // Default info for any inst that we haven't registered. } + // + // Bottleneck method to fetch the current propagation info + // for a given instruction under context. + // IRInst* tryGetInfo(IRInst* context, IRInst* inst) { if (auto typeFlowData = as(inst->getDataType())) @@ -347,10 +516,17 @@ struct TypeFlowSpecializationContext return typeFlowData; } + // A small check for de-allocated insts. if (!inst->getParent()) return none(); - // If this is a global instruction (parent is module), return concrete info + // If this is a global instruction (parent is module), return a singleton set of + // that inst. + // + // Since it's easy to tell when an inst is representing a concrete + // entity, we do this on demand rather than trying to put it in the + // propagation map. + // if (as(inst->getParent())) { if (as(inst) || as(inst) || as(inst) || @@ -371,21 +547,127 @@ struct TypeFlowSpecializationContext return none(); } - return tryGetInfo(Element(context, inst)); + return _tryGetInfo(InstWithContext(context, inst)); } - IRInst* tryGetFuncReturnInfo(IRFunc* func) + // Performs set-union over the two collections, and returns a new + // inst to represent the collection. + // + template + T* unionCollection(T* collection1, T* collection2) { - auto found = funcReturnInfo.tryGetValue(func); - if (found) - return *found; - return none(); + // It may be possible to accelerate this further, but we usually + // don't have to deal with overly large sets (usually 3-20 elements) + // + + SLANG_ASSERT(as(collection1) && as(collection2)); + SLANG_ASSERT(collection1->getOp() == collection2->getOp()); + + if (!collection1) + return collection2; + if (!collection2) + return collection1; + if (collection1 == collection2) + return collection1; + + HashSet allValues; + // Collect all values from both collections + forEachInCollection(collection1, [&](IRInst* value) { allValues.add(value); }); + forEachInCollection(collection2, [&](IRInst* value) { allValues.add(value); }); + + return as(cBuilder.createCollection( + collection1->getOp(), + allValues)); // Create a new collection with the union of values } - // Centralized method to update propagation info and manage work queue + // Find the union of two propagation info insts, and return and + // inst representing the result. // - // Use this when you want to propagate new information to an existing instruction. - // This will union the new info with existing info and add users to work queue if changed + IRInst* unionPropagationInfo(IRInst* info1, IRInst* info2) + { + // This is similar to unionCollection, but must consider structures that + // can be built out of collections. + // + // We allow some level of nesting of collections into other type instructions, + // to let us propagate information elegantly for pointers, parameters, arrays + // and existential tuples. + // + // A few interesting cases are missing, but could be added in easily in the future: + // - TupleType (will allow us to propagate information for each tuple element) + // - OptionalType + // + + // Basic cases: if either is null, it is considered "empty" + // if they're equal, union must be the same inst. + + if (!info1) + return info2; + + if (!info2) + return info1; + + if (areInfosEqual(info1, info2)) + return info1; + + if (as(info1) && as(info2)) + { + SLANG_ASSERT(info1->getOperand(1) == info2->getOperand(1)); + // If both are array types, union their element types + IRBuilder builder(module); + builder.setInsertInto(module); + return builder.getArrayType( + (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), + info1->getOperand(1)); // Keep the same size + } + + if (as(info1) && as(info2)) + { + SLANG_ASSERT(info1->getOp() == info2->getOp()); + + // If both are array types, union their element types + IRBuilder builder(module); + builder.setInsertInto(module); + return builder.getPtrTypeWithAddressSpace( + (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), + as(info1)); + } + + if (as(info1) && as(info2)) + { + // If either info is unbounded, the union is unbounded + return makeUnbounded(); + } + + // For all other cases which are structured composites of collections, + // we simply take the collection union for all the collection operands. + // + + if (as(info1) && as(info2)) + { + return makeExistential(unionCollection( + cast(info1->getOperand(1)), + cast(info2->getOperand(1)))); + } + + if (as(info1) && as(info2)) + { + return makeTagType(unionCollection( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); + } + + if (as(info1) && as(info2)) + { + return unionCollection( + cast(info1), + cast(info2)); + } + + SLANG_UNEXPECTED("Unhandled propagation info types in unionPropagationInfo"); + } + + // Centralized method to update propagation info and add + // relevant work items to the work queue if the info changed. // void updateInfo( IRInst* context, @@ -402,19 +684,12 @@ struct TypeFlowSpecializationContext return; // Update the propagation map - propagationMap[Element(context, inst)] = unionedInfo; + propagationMap[InstWithContext(context, inst)] = unionedInfo; // Add all users to appropriate work items addUsersToWorkQueue(context, inst, unionedInfo, workQueue); } - bool isFuncParam(IRParam* param) - { - auto paramBlock = as(param->getParent()); - auto paramFunc = as(paramBlock->getParent()); - return (paramFunc && paramFunc->getFirstBlock() == paramBlock); - } - void addContextUsersToWorkQueue(IRInst* context, WorkQueue& workQueue) { if (this->funcCallSites.containsKey(context)) @@ -428,10 +703,28 @@ struct TypeFlowSpecializationContext } } - // Helper to add users of an instruction to the work queue based on how they use it - // This handles intra-procedural edges, inter-procedural edges, and return value propagation + // Helper method to add new work items to the queue based on the + // flavor of the instruction whose info has been updated. + // void addUsersToWorkQueue(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) { + // This method is responsible for ensuring the following property: + // + // If inst's information has changed, then all insts that may potentially depend + // (directly) on it should be added to the work queue. + // + // This includes a few cases: + // + // 1. In the default case, we simply add all users of the insts. + // + // 2. For insts that are used as phi-arguments, we add an intra-procedural + // edge to the target block. + // + // 3. For insts that are used as return values, we add an inter-procedural edge + // to all call sites. We do this by calling `updateFuncReturnInfo`, which takes + // care of modifying the workQueue. + // + if (auto param = as(inst)) if (isFuncParam(param)) addContextUsersToWorkQueue(context, workQueue); @@ -473,7 +766,7 @@ struct TypeFlowSpecializationContext } } - // Helper method to update function return info and propagate to call sites + // Helper method to update function's return info and propagate back to call sites void updateFuncReturnInfo(IRInst* callable, IRInst* returnInfo, WorkQueue& workQueue) { auto existingReturnInfo = getFuncReturnInfo(callable); @@ -483,7 +776,7 @@ struct TypeFlowSpecializationContext { funcReturnInfo[callable] = newReturnInfo; - // Add interprocedural edges to all call sites + // Add interprocedural edges from the function back to all callsites. if (funcCallSites.containsKey(callable)) { for (auto callSite : funcCallSites[callable]) @@ -498,34 +791,42 @@ struct TypeFlowSpecializationContext } } - void processBlock(IRInst* context, IRBlock* block, WorkQueue& workQueue) - { - for (auto inst : block->getChildren()) - { - // Skip parameters & terminator - if (as(inst) || as(inst)) - continue; - processInstForPropagation(context, inst, workQueue); - } - - if (auto returnInfo = as(block->getTerminator())) - { - auto valInfo = returnInfo->getVal(); - updateFuncReturnInfo(context, tryGetInfo(context, valInfo), workQueue); - } - }; - void performInformationPropagation() { - // Global worklist for interprocedural analysis + // This method contains the main loop responsible for propagating information across all + // relevant functions, generics in the call graph. + // + // The mechanism is similar to data-flow analysis: + // 1. We start by initializing the propagation info for all instructions in functions + // that may be externally called. + // + // 2. For each instruction that received a non-trivial update, we add their users to the + // queue + // for further propagation. + // + // 3. Continue (2) until no more information has changed. + // + // This process is guaranteed to terminate because our propagation 'states' (i.e. + // collection insts and their wrapped versions) form a lattice. This is an order-theoretic + // structure that implies that + // (i) each update moves us strictly 'upward', and + // (ii) that there are a finite number of possible states. + // + + // Global worklist for interprocedural analysis. WorkQueue workQueue; - // Add all global functions to worklist + // Add all global functions to worklist. + // + // This could potentially be narrowed down to just entry points, but for + // now we are being conservative. Missing a potential entry point is worse + // than analyzing something that isn't used. + // for (auto inst : module->getGlobalInsts()) if (auto func = as(inst)) discoverContext(func, workQueue); - // Process until fixed point + // Process until fixed point. while (workQueue.hasItems()) { auto item = workQueue.dequeue(); @@ -551,119 +852,12 @@ struct TypeFlowSpecializationContext } } - IRInst* upcastCollection(IRInst* context, IRInst* arg, IRType* destInfo) - { - auto argInfo = arg->getDataType(); - if (!argInfo || !destInfo) - return arg; - - if (as(argInfo) && as(destInfo)) - { - // Handle upcasting between collection tagged unions - auto argTUType = as(argInfo); - auto destTUType = as(destInfo); - if (getCollectionCount(argTUType) != getCollectionCount(destTUType)) - { - // Technically, IRCollectionTaggedUnionType is not a TupleType, - // but in practice it works the same way so we'll re-use Slang's - // tuple accessors & constructors - // - IRBuilder builder(arg->getModule()); - setInsertAfterOrdinaryInst(&builder, arg); - auto argTableTag = builder.emitGetTupleElement( - (IRType*)makeTagType(argTUType->getTableCollection()), - arg, - 0); - auto reinterpretedTag = upcastCollection( - context, - argTableTag, - (IRType*)makeTagType(destTUType->getTableCollection())); + void processInstForPropagation(IRInst* context, IRInst* inst, WorkQueue& workQueue) + { + IRInst* info; - auto argVal = - builder.emitGetTupleElement((IRType*)argTUType->getTypeCollection(), arg, 1); - auto reinterpretedVal = - upcastCollection(context, argVal, (IRType*)destTUType->getTypeCollection()); - return builder.emitMakeTuple( - (IRType*)destTUType, - {reinterpretedTag, reinterpretedVal}); - } - } - else if (as(argInfo) && as(destInfo)) - { - auto argTupleType = as(argInfo); - auto destTupleType = as(destInfo); - - List upcastedElements; - bool hasUpcastedElements = false; - - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - - // Upcast each element of the tuple - for (UInt i = 0; i < argTupleType->getOperandCount(); ++i) - { - auto argElementType = argTupleType->getOperand(i); - auto destElementType = destTupleType->getOperand(i); - - // If the element types are different, we need to reinterpret - if (argElementType != destElementType) - { - hasUpcastedElements = true; - upcastedElements.add(upcastCollection( - context, - builder.emitGetTupleElement((IRType*)argElementType, arg, i), - (IRType*)destElementType)); - } - else - { - upcastedElements.add( - builder.emitGetTupleElement((IRType*)argElementType, arg, i)); - } - } - - if (hasUpcastedElements) - { - return builder.emitMakeTuple(upcastedElements); - } - } - else if (as(argInfo) && as(destInfo)) - { - if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) - { - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - return builder - .emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); - } - } - else if (as(argInfo) && as(destInfo)) - { - if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) - { - // If the sets of witness tables are not equal, reinterpret to the parameter type - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - return builder.emitReinterpret((IRType*)destInfo, arg); - } - } - else if (!as(argInfo) && as(destInfo)) - { - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - return builder.emitPackAnyValue((IRType*)destInfo, arg); - } - - return arg; // Can use as-is. - } - - void processInstForPropagation(IRInst* context, IRInst* inst, WorkQueue& workQueue) - { - IRInst* info; - - switch (inst->getOp()) + switch (inst->getOp()) { case kIROp_CreateExistentialObject: info = analyzeCreateExistentialObject(context, as(inst)); @@ -720,6 +914,195 @@ struct TypeFlowSpecializationContext updateInfo(context, inst, info, takeUnion, workQueue); } + void processBlock(IRInst* context, IRBlock* block, WorkQueue& workQueue) + { + for (auto inst : block->getChildren()) + { + // Skip parameters & terminator + if (as(inst) || as(inst)) + continue; + processInstForPropagation(context, inst, workQueue); + } + + if (auto returnInfo = as(block->getTerminator())) + { + auto valInfo = returnInfo->getVal(); + updateFuncReturnInfo(context, tryGetInfo(context, valInfo), workQueue); + } + }; + + void propagateWithinFuncEdge(IRInst* context, IREdge edge, WorkQueue& workQueue) + { + // Handle intra-procedural edge (original logic) + auto predecessorBlock = edge.getPredecessor(); + auto successorBlock = edge.getSuccessor(); + + // Get the terminator instruction and extract arguments + auto terminator = predecessorBlock->getTerminator(); + if (!terminator) + return; + + // Right now, only unconditional branches can propagate new information + auto unconditionalBranch = as(terminator); + if (!unconditionalBranch) + return; + + // Find which successor this edge leads to (should be the target) + if (unconditionalBranch->getTargetBlock() != successorBlock) + return; + + // Collect propagation info for each argument and update corresponding parameter + UInt paramIndex = 0; + for (auto param : successorBlock->getParams()) + { + if (paramIndex < unconditionalBranch->getArgCount()) + { + auto arg = unconditionalBranch->getArg(paramIndex); + if (auto argInfo = tryGetInfo(context, arg)) + { + // Use centralized update method + updateInfo(context, param, argInfo, true, workQueue); + } + } + paramIndex++; + } + } + + void propagateInterproceduralEdge(InterproceduralEdge edge, WorkQueue& workQueue) + { + // Handle interprocedural edge + auto callInst = edge.callInst; + auto targetCallee = edge.targetContext; + + switch (edge.direction) + { + case InterproceduralEdge::Direction::CallToFunc: + { + // Propagate argument info from call site to function parameters + IRBlock* firstBlock = nullptr; + + if (as(targetCallee)) + firstBlock = targetCallee->getFirstBlock(); + else if (auto specInst = as(targetCallee)) + firstBlock = getGenericReturnVal(specInst->getBase())->getFirstBlock(); + + if (!firstBlock) + return; + + UInt argIndex = 1; // Skip callee (operand 0) + for (auto param : firstBlock->getParams()) + { + if (argIndex < callInst->getOperandCount()) + { + auto arg = callInst->getOperand(argIndex); + const auto [paramDirection, paramType] = + splitParameterDirectionAndType(param->getDataType()); + + // Only update if + // 1. The paramType is a global inst and an interface type + // 2. The paramType is a local inst. + // all other cases, continue. + if (isGlobalInst(paramType) && !as(paramType)) + { + argIndex++; + continue; + } + + IRInst* argInfo = tryGetInfo(edge.callerContext, arg); + + switch (paramDirection.kind) + { + case ParameterDirectionInfo::Kind::Out: + case ParameterDirectionInfo::Kind::BorrowInOut: + case ParameterDirectionInfo::Kind::BorrowIn: + { + IRBuilder builder(module); + if (!argInfo) + { + if (isGlobalInst(arg->getDataType()) && + !as( + as(arg->getDataType())->getValueType())) + argInfo = arg->getDataType(); + } + + if (!argInfo) + break; + + auto newInfo = fromDirectionAndType( + &builder, + paramDirection, + as(argInfo)->getValueType()); + updateInfo(edge.targetContext, param, newInfo, true, workQueue); + break; + } + case ParameterDirectionInfo::Kind::In: + { + // Use centralized update method + if (!argInfo) + { + if (isGlobalInst(arg->getDataType()) && + !as(arg->getDataType())) + argInfo = arg->getDataType(); + } + updateInfo(edge.targetContext, param, argInfo, true, workQueue); + break; + } + default: + SLANG_UNEXPECTED( + "Unhandled parameter direction in interprocedural edge"); + } + } + argIndex++; + } + break; + } + case InterproceduralEdge::Direction::FuncToCall: + { + // Propagate return value info from function to call site + auto returnInfo = funcReturnInfo.tryGetValue(targetCallee); + if (returnInfo) + { + // Use centralized update method + updateInfo(edge.callerContext, callInst, *returnInfo, true, workQueue); + } + + // Also update infos of any out parameters + auto paramInfos = getParamInfos(edge.targetContext); + auto paramDirections = getParamDirections(edge.targetContext); + UIndex argIndex = 0; + for (auto paramInfo : paramInfos) + { + if (paramInfo) + { + if (paramDirections[argIndex].kind == ParameterDirectionInfo::Kind::Out || + paramDirections[argIndex].kind == + ParameterDirectionInfo::Kind::BorrowInOut) + { + auto arg = callInst->getArg(argIndex); + auto argPtrType = as(arg->getDataType()); + + IRBuilder builder(module); + updateInfo( + edge.callerContext, + arg, + builder.getPtrTypeWithAddressSpace( + (IRType*)as(paramInfo)->getValueType(), + argPtrType), + true, + workQueue); + } + } + argIndex++; + } + + break; + } + default: + SLANG_UNEXPECTED("Unhandled interprocedural edge direction"); + return; + } + } + IRInst* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { SLANG_UNUSED(context); @@ -746,8 +1129,9 @@ struct TypeFlowSpecializationContext { auto witnessTable = inst->getWitnessTable(); - // If we're building an existential for a COM pointer, - // we won't try to lower that. + // If we're building an existential for a COM interface, + // we always assume it is unbounded, since we can receive + // types that we know nothing about in the current linkage. // if (isComInterfaceType(inst->getDataType())) return makeUnbounded(); @@ -805,29 +1189,22 @@ struct TypeFlowSpecializationContext return none(); // the make struct itself doesn't have any info. } - bool isResourcePointer(IRInst* inst) + IRInst* analyzeLoad(IRInst* context, IRInst* inst) { - if (isPointerToResourceType(inst->getDataType()) || - inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) - return true; - - if (as(inst)) - return true; + if (auto loadInst = as(inst)) + { + // If we have a simple load, theres one of two cases: + // + // 1. If we're loading from a resource pointer, we need to treat it + // as unspecializable. If it's a COM interface, we consider it truly + // unbounded. Otherwise, we can simply enumerate all tables for the interface + // type. + // + // 2. In the default case, we can look up the registered information + // for the pointer, which should be of the form PtrTypeBase(valueInfo), and + // use the valueInfo + // - if (auto ptr = as(inst)) - return isResourcePointer(ptr->getBase()); - - if (auto fieldAddress = as(inst)) - return isResourcePointer(fieldAddress->getBase()); - - return false; - } - - IRInst* analyzeLoad(IRInst* context, IRInst* inst) - { - // Default: Transfer the prop info from the address to the loaded value - if (auto loadInst = as(inst)) - { if (isResourcePointer(loadInst->getPtr())) { if (auto interfaceType = as(loadInst->getDataType())) @@ -857,6 +1234,9 @@ struct TypeFlowSpecializationContext } else if (as(inst) || as(inst)) { + // In case of a buffer load, we know we're dealing with a location that cannot + // be specialized, so the logic is similar to case (1) from above. + // if (auto interfaceType = as(inst->getDataType())) { if (!isComInterfaceType(interfaceType) && !isBuiltin(interfaceType)) @@ -880,7 +1260,16 @@ struct TypeFlowSpecializationContext IRInst* analyzeStore(IRInst* context, IRStore* storeInst, WorkQueue& workQueue) { - // Transfer the prop info from stored value to the address + // For a simple store, we will attempt to update the location with + // the information from the stored value. + // + // Since the pointer can be an access chain, we have to recursively transfer + // the information down to the base. This logic is handled by `maybeUpdatePtr` + // + // If the value has "info", we construct an appropriate PtrType(info) and + // update the ptr with it. + // + auto address = storeInst->getPtr(); if (auto valInfo = tryGetInfo(context, storeInst->getVal())) { @@ -889,11 +1278,12 @@ struct TypeFlowSpecializationContext (IRType*)valInfo, as(address->getDataType())); - // Update the base instruction for the entire access chain + // Propagate the information up the access chain to the base location. maybeUpdatePtr(context, address, ptrInfo, workQueue); } - return none(); // The store itself doesn't have any info. + // The store inst itself doesn't produce anything, so it has no info + return none(); } IRInst* analyzeGetElementPtr(IRInst* context, IRGetElementPtr* getElementPtr) @@ -911,12 +1301,17 @@ struct TypeFlowSpecializationContext return builder.getPtrTypeWithAddressSpace(arrayType->getElementType(), ptrType); } - return none(); // No info for the base pointer + return none(); // No info for the base pointer => no info for the result. } IRInst* analyzeFieldAddress(IRInst* context, IRFieldAddress* fieldAddress) { - // The base info should be in Ptr form, so we just need to return Ptr as the result. + // In this case, we don't look up the base's info, but rather, we find + // the IRStructField being accessed, and look up the info in the fieldInfos + // map. + // + // This info will be the in the value form, so we need to wrap it in a + // pointer since the result is an address. // IRBuilder builder(module); builder.setInsertAfter(fieldAddress); @@ -930,8 +1325,8 @@ struct TypeFlowSpecializationContext findStructField(structType, as(fieldAddress->getField())); // Register this as a user of the field so updates will invoke this function again. - this->fieldUseSites.addIfNotExists(structField, HashSet()); - this->fieldUseSites[structField].add(Element(context, fieldAddress)); + this->fieldUseSites.addIfNotExists(structField, HashSet()); + this->fieldUseSites[structField].add(InstWithContext(context, fieldAddress)); if (this->fieldInfo.containsKey(structField)) { @@ -941,11 +1336,16 @@ struct TypeFlowSpecializationContext } } } - return none(); + + return none(); // No info for the field => no info for the result. } IRInst* analyzeFieldExtract(IRInst* context, IRFieldExtract* fieldExtract) { + // Very similar logic to `analyzeFieldAddress`, but without having to + // wrap the result in a pointer. + // + IRBuilder builder(module); if (auto structType = as(fieldExtract->getBase()->getDataType())) @@ -954,8 +1354,8 @@ struct TypeFlowSpecializationContext findStructField(structType, as(fieldExtract->getField())); // Register this as a user of the field so updates will invoke this function again. - this->fieldUseSites.addIfNotExists(structField, HashSet()); - this->fieldUseSites[structField].add(Element(context, fieldExtract)); + this->fieldUseSites.addIfNotExists(structField, HashSet()); + this->fieldUseSites[structField].add(InstWithContext(context, fieldExtract)); if (this->fieldInfo.containsKey(structField)) { @@ -967,6 +1367,17 @@ struct TypeFlowSpecializationContext IRInst* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { + // A LookupWitnessMethod is assumed to by dynamic, so we + // (i) construct a collection of the results by looking up the given + // key in each of the input witness tables + // (ii) wrap the result in a tag type, since the lookup inst is logically holding + // on to run-time information about which element of the collection is active. + // + // Note that the input must be a set of concrete witness tables (or none/unbounded). + // If this is not the case and we see anything abstract, then something has gone + // wrong somewhere when analyzing a previous instruction. + // + auto key = inst->getRequirementKey(); auto witnessTable = inst->getWitnessTable(); @@ -995,6 +1406,15 @@ struct TypeFlowSpecializationContext IRInst* context, IRExtractExistentialWitnessTable* inst) { + // An ExtractExistentialWitnessTable inst is assumed to by dynamic, so we + // extract the set of witness tables from the input existential and + // state that the info of the result is a tag-type of that collection. + // + // Note that since ExtractExistentialWitnessTable can only be used on + // an existential, the input info must be a CollectionTaggedUnionType of + // concrete table and type collections (or none/unbounded) + // + auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); @@ -1012,6 +1432,15 @@ struct TypeFlowSpecializationContext IRInst* analyzeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { + // An ExtractExistentialType inst is assumed to be dynamic, so we + // extract the set of witness tables from the input existential and + // state that the info of the result is a tag-type of that collection. + // + // Note: Since ExtractExistentialType can only be used on + // an existential, the input info must be a CollectionTaggedUnionType of + // concrete table and type collections (or none/unbounded) + // + auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); @@ -1029,6 +1458,17 @@ struct TypeFlowSpecializationContext IRInst* analyzeExtractExistentialValue(IRInst* context, IRExtractExistentialValue* inst) { + // Logically, an ExtractExistentialValue inst is carrying a payload + // of a union type. + // + // We represent this by setting its info to be equal to the type-collection, + // which will later lower into an any-value-type. + // + // Note that there is no 'tag' here since ExtractExistentialValue is not representing + // tag information about which type in the collection is active, but is representing + // a value of the collection's union type. + // + auto operand = inst->getOperand(0); auto operandInfo = tryGetInfo(context, operand); @@ -1046,6 +1486,28 @@ struct TypeFlowSpecializationContext IRInst* analyzeSpecialize(IRInst* context, IRSpecialize* inst) { + // Analyzing an IRSpecialize inst is an interesting case. + // + // If we hit this case, it means we are encountering this instruction + // inside a block, so the arguments to the specialization likely have some + // dynamic types or witness tables. + // + // We'll first look at the specialization base, which may be a single generic + // or a collection of generics. + // + // Then, for each generic, we'll create a specialized version by using the + // collection info for each argument in place of the argument. + // e.g. Specialize(G, A0, A1) becomes Specialize(G, info(A1).collection, + // info(A2).collection) + // (i.e. if the args are tag-types, we only use the collection part) + // + // This transformation is important to lift the 'dynamic' specialize instruction into a + // global specialize instruction while still retaining the information about what types and + // tables the resulting generic should support. + // + // Finally, we put all the specialized vesions back into a collection and return that info. + // + auto operand = inst->getBase(); auto operandInfo = tryGetInfo(context, operand); @@ -1081,14 +1543,14 @@ struct TypeFlowSpecializationContext // their sets (if available) // auto argInfo = tryGetInfo(context, inst->getArg(i)); + + // If any of the args are 'empty' sets, we can't generate a specialization just yet. if (!argInfo) - return none(); // Can't determine the result just yet. + return none(); if (as(argInfo) || as(argInfo)) { - SLANG_UNEXPECTED( - "Unexpected Existential operand in specialization argument. Should be " - "set"); + SLANG_UNEXPECTED("Unexpected Existential operand in specialization argument."); } if (auto argCollectionTag = as(argInfo)) @@ -1107,6 +1569,10 @@ struct TypeFlowSpecializationContext } } + // This part creates a correct type for the specialization, by following the same + // process: replace all operands in the composite type with their propagated collection. + // + IRType* typeOfSpecialization = nullptr; if (inst->getDataType()->getParent()->getOp() == kIROp_ModuleInst) typeOfSpecialization = inst->getDataType(); @@ -1214,6 +1680,18 @@ struct TypeFlowSpecializationContext void discoverContext(IRInst* context, WorkQueue& workQueue) { + // "Discovering" a context, essentially means we check if this is the first + // time we're trying to propagate information into this context. A context + // is a global-scope IRFunc or IRSpecialize. + // + // If it is the first, we enqueue some work to perform initialization of all + // the insts in the body of the func. + // + // Since discover context is only called 'on-demand' as the type-flow propagation + // happens, we avoid having to deal with functions/generics that are never used, + // and minimize the amount of work being performed. + // + if (this->availableContexts.add(context)) { IRFunc* func = nullptr; @@ -1287,17 +1765,14 @@ struct TypeFlowSpecializationContext IRInst* analyzeCall(IRInst* context, IRCall* inst, WorkQueue& workQueue) { + // We don't perform the propagation here, but instead we add inter-procedural + // edges to the work queue. + // The propagation logic is handled in `propagateInterproceduralEdge()` + // + auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(context, callee); - // - // Propagate the input judgments to the call & append a work item - // for inter-procedural propagation. - // - - // For now, we'll handle just a concrete func. But the logic for multiple functions - // is exactly the same (add an edge for each function). - // auto propagateToCallSite = [&](IRInst* callee) { // Register the call site in the map to allow for the @@ -1309,8 +1784,8 @@ struct TypeFlowSpecializationContext // discoverContext(callee, workQueue); - this->funcCallSites.addIfNotExists(callee, HashSet()); - if (this->funcCallSites[callee].add(Element(context, inst))) + this->funcCallSites.addIfNotExists(callee, HashSet()); + if (this->funcCallSites[callee].add(InstWithContext(context, inst))) { // If this is a new call site, add a propagation task to the queue (in case there's // already information about this function) @@ -1321,14 +1796,15 @@ struct TypeFlowSpecializationContext WorkItem(InterproceduralEdge::Direction::CallToFunc, context, inst, callee)); }; + // If we have a collection of functions (with or without a dynamic tag), register + // each one. + // if (auto collectionTag = as(calleeInfo)) { - // If we have a set of functions, register each one forEachInCollection(collectionTag, [&](IRInst* func) { propagateToCallSite(func); }); } else if (auto collection = as(calleeInfo)) { - // If we have a collection of functions, register each one forEachInCollection(collection, [&](IRInst* func) { propagateToCallSite(func); }); } @@ -1338,13 +1814,33 @@ struct TypeFlowSpecializationContext return none(); } + // Updates the information for an address. void maybeUpdatePtr(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) { + // This method recursively walks up the access chain until it hits a location. + // + // Pointers don't have any unique information attached to them because two pointers to the + // same location (directly or indirectly) must mirror the same propagation info. + // + // Thus, our approach will be to update the info for the base value (usually an IRVar or + // IRParam), and then register users for updates to propagate the updated information + // to any access chain instructions. + // + if (auto getElementPtr = as(inst)) { if (auto thisPtrInfo = as(info)) { - auto thisValueType = thisPtrInfo->getValueType(); + // For get-element-ptr, we propagate information by + // wrapping the result's info into the array type. + // + // a : PtrType(ArrayType(T, count)) = ... + // b : PtrType(T) = GetElementPtr(a) + // + // info(a) = ArrayType(info(b), count) + // + + auto thisValueInfo = thisPtrInfo->getValueType(); IRInst* baseValueType = as(getElementPtr->getBase()->getDataType())->getValueType(); @@ -1354,17 +1850,29 @@ struct TypeFlowSpecializationContext IRBuilder builder(module); auto baseInfo = builder.getPtrTypeWithAddressSpace( builder.getArrayType( - (IRType*)thisValueType, + (IRType*)thisValueInfo, as(baseValueType)->getElementCount()), as(getElementPtr->getBase()->getDataType())); + + // Recursively try to update the base pointer. maybeUpdatePtr(context, getElementPtr->getBase(), baseInfo, workQueue); } } else if (auto fieldAddress = as(inst)) { - // If this is a field address, update the fieldInfos map. if (as(info)) { + // Field address is also treated as a base case for the recursion. + // + // For field-address, we record the information against the field itself + // by using the fieldInfos map (after unwrapping the pointer) + // + // a : PtrType(S) = ... + // b : PtrType(T) = GetFieldAddress(a, fieldKey) + // + // infos[findField(S, fieldKey)] = info(T) + // + IRBuilder builder(module); auto baseStructPtrType = as(fieldAddress->getBase()->getDataType()); auto baseStructType = as(baseStructPtrType->getValueType()); @@ -1382,6 +1890,11 @@ struct TypeFlowSpecializationContext (IRType*)existingInfo, as(fieldAddress->getDataType())); + // Manually update the prop info add work items for all users of this field. + // + // This case is not handled by updateInfo(), though in the future + // it makes sense to include this as a case in updateInfo() + // if (auto newInfo = unionPropagationInfo(info, existingInfo)) { if (newInfo != existingInfo) @@ -1391,7 +1904,6 @@ struct TypeFlowSpecializationContext // Update the field info map this->fieldInfo[foundField] = newInfoValType; - // Add a work item to update the field extract if (this->fieldUseSites.containsKey(foundField)) for (auto useSite : this->fieldUseSites[foundField]) workQueue.enqueue(WorkItem(useSite.context, useSite.inst)); @@ -1402,7 +1914,10 @@ struct TypeFlowSpecializationContext } else if (auto var = as(inst)) { - // If we hit a local var, we'll update it's info. + // If we hit a local var, we'll update its info. + // + // This is one of the base cases for the recursion. + // updateInfo(context, var, info, true, workQueue); } else if (auto param = as(inst)) @@ -1411,6 +1926,11 @@ struct TypeFlowSpecializationContext // but first change the info from PtrTypeBase // to the specific pointer type for the parameter. // + // (e.g. parameter may use a BorrowInOutType, but the info + // may be some other PtrType) + // + // This is one of the base cases for the recursion. + // IRBuilder builder(param->getModule()); auto newInfo = builder.getPtrTypeWithAddressSpace( (IRType*)as(info)->getValueType(), @@ -1419,70 +1939,23 @@ struct TypeFlowSpecializationContext } else { - // If we hit something unsupported, assume no information. + // If we hit something unsupported, assume there's nothing to update. return; } } - void propagateWithinFuncEdge(IRInst* context, IREdge edge, WorkQueue& workQueue) - { - // Handle intra-procedural edge (original logic) - auto predecessorBlock = edge.getPredecessor(); - auto successorBlock = edge.getSuccessor(); - - // Get the terminator instruction and extract arguments - auto terminator = predecessorBlock->getTerminator(); - if (!terminator) - return; - - // Right now, only unconditional branches can propagate new information - auto unconditionalBranch = as(terminator); - if (!unconditionalBranch) - return; - - // Find which successor this edge leads to (should be the target) - if (unconditionalBranch->getTargetBlock() != successorBlock) - return; - - // Collect propagation info for each argument and update corresponding parameter - UInt paramIndex = 0; - for (auto param : successorBlock->getParams()) - { - if (paramIndex < unconditionalBranch->getArgCount()) - { - auto arg = unconditionalBranch->getArg(paramIndex); - if (auto argInfo = tryGetInfo(context, arg)) - { - // Use centralized update method - updateInfo(context, param, argInfo, true, workQueue); - } - } - paramIndex++; - } - } - - bool isGlobalInst(IRInst* inst) { return inst->getParent()->getOp() == kIROp_ModuleInst; } - - bool isIntrinsic(IRInst* inst) + // Returns the effective parameter types for a given calling context, after + // the type-flow propagation is complete. + // + List getEffectiveParamTypes(IRInst* context) { - auto func = as(inst); - if (auto specialize = as(inst)) - { - auto generic = specialize->getBase(); - func = as(getGenericReturnVal(generic)); - } - - if (!func) - return false; - - if (func->getFirstBlock() == nullptr) - return true; - - return false; - } + // This proceeds by looking at the propagation info for each parameter, + // then returning the info if one exists. + // + // If one does not exist, it means the parameter has a concrete type + // (not dynamic or generic), and we can just use that for the parameter. + // - List getParamEffectiveTypes(IRInst* context) - { List effectiveTypes; IRFunc* func = nullptr; if (as(context)) @@ -1517,6 +1990,7 @@ struct TypeFlowSpecializationContext return effectiveTypes; } + // Helper to get any recorded propagation info for each parameter of a calling context. List getParamInfos(IRInst* context) { List infos; @@ -1541,14 +2015,19 @@ struct TypeFlowSpecializationContext return infos; } + // Helper to extract the directions of each parameter for a calling context. List getParamDirections(IRInst* context) { + // Note that this method does not actually have to retreive any propagation info, + // since the directions/address-spaces of parameters are always concrete. + // + List directions; if (as(context)) { for (auto param : as(context)->getParams()) { - const auto [direction, type] = getParameterDirectionAndType(param->getDataType()); + const auto [direction, type] = splitParameterDirectionAndType(param->getDataType()); directions.add(direction); } } @@ -1558,7 +2037,7 @@ struct TypeFlowSpecializationContext auto innerFunc = getGenericReturnVal(generic); for (auto param : as(innerFunc)->getParams()) { - const auto [direction, type] = getParameterDirectionAndType(param->getDataType()); + const auto [direction, type] = splitParameterDirectionAndType(param->getDataType()); directions.add(direction); } } @@ -1571,154 +2050,38 @@ struct TypeFlowSpecializationContext return directions; } - void propagateInterproceduralEdge(InterproceduralEdge edge, WorkQueue& workQueue) + // Extract the return value information for a given calling context + IRInst* getFuncReturnInfo(IRInst* context) { - // Handle interprocedural edge - auto callInst = edge.callInst; - auto targetCallee = edge.targetContext; - - switch (edge.direction) - { - case InterproceduralEdge::Direction::CallToFunc: - { - // Propagate argument info from call site to function parameters - IRBlock* firstBlock = nullptr; - - if (as(targetCallee)) - firstBlock = targetCallee->getFirstBlock(); - else if (auto specInst = as(targetCallee)) - firstBlock = getGenericReturnVal(specInst->getBase())->getFirstBlock(); - - if (!firstBlock) - return; - - UInt argIndex = 1; // Skip callee (operand 0) - for (auto param : firstBlock->getParams()) - { - if (argIndex < callInst->getOperandCount()) - { - auto arg = callInst->getOperand(argIndex); - const auto [paramDirection, paramType] = - getParameterDirectionAndType(param->getDataType()); - - // Only update if - // 1. The paramType is a global inst and an interface type - // 2. The paramType is a local inst. - // all other cases, continue. - if (isGlobalInst(paramType) && !as(paramType)) - { - argIndex++; - continue; - } - - IRInst* argInfo = tryGetInfo(edge.callerContext, arg); - - switch (paramDirection.kind) - { - case ParameterDirectionInfo::Kind::Out: - case ParameterDirectionInfo::Kind::BorrowInOut: - case ParameterDirectionInfo::Kind::BorrowIn: - { - IRBuilder builder(module); - if (!argInfo) - { - if (isGlobalInst(arg->getDataType()) && - !as( - as(arg->getDataType())->getValueType())) - argInfo = arg->getDataType(); - } - - if (!argInfo) - break; - - auto newInfo = fromDirectionAndType( - &builder, - paramDirection, - as(argInfo)->getValueType()); - updateInfo(edge.targetContext, param, newInfo, true, workQueue); - break; - } - case ParameterDirectionInfo::Kind::In: - { - // Use centralized update method - if (!argInfo) - { - if (isGlobalInst(arg->getDataType()) && - !as(arg->getDataType())) - argInfo = arg->getDataType(); - } - updateInfo(edge.targetContext, param, argInfo, true, workQueue); - break; - } - default: - SLANG_UNEXPECTED( - "Unhandled parameter direction in interprocedural edge"); - } - } - argIndex++; - } - break; - } - case InterproceduralEdge::Direction::FuncToCall: - { - // Propagate return value info from function to call site - auto returnInfo = funcReturnInfo.tryGetValue(targetCallee); - if (returnInfo) - { - // Use centralized update method - updateInfo(edge.callerContext, callInst, *returnInfo, true, workQueue); - } - - // Also update infos of any out parameters - auto paramInfos = getParamInfos(edge.targetContext); - auto paramDirections = getParamDirections(edge.targetContext); - UIndex argIndex = 0; - for (auto paramInfo : paramInfos) - { - if (paramInfo) - { - if (paramDirections[argIndex].kind == ParameterDirectionInfo::Kind::Out || - paramDirections[argIndex].kind == - ParameterDirectionInfo::Kind::BorrowInOut) - { - auto arg = callInst->getArg(argIndex); - auto argPtrType = as(arg->getDataType()); - - IRBuilder builder(module); - updateInfo( - edge.callerContext, - arg, - builder.getPtrTypeWithAddressSpace( - (IRType*)as(paramInfo)->getValueType(), - argPtrType), - true, - workQueue); - } - } - argIndex++; - } - - break; - } - default: - SLANG_UNEXPECTED("Unhandled interprocedural edge direction"); - return; - } - } - - IRInst* getFuncReturnInfo(IRInst* callee) - { - funcReturnInfo.addIfNotExists(callee, none()); - return funcReturnInfo[callee]; + // We record the information in a separate map, rather than using + // a specific inst. + // + // This is because we need the union of the infos of all Return instructions + // in the function, but there's no physical instruction that represents this. + // (unlike the block control flow case, where phi params exist) + // + // Effectively, this is a 'virtual' inst that represents the union of all + // the return values. + // + funcReturnInfo.addIfNotExists(context, none()); + return funcReturnInfo[context]; } + // Set up initial information for parameters based on their types. void initializeFirstBlockParameters(IRInst* context, IRFunc* func) { + // This method primarily just initializes known COM & Builtin interface + // types to 'unbounded', to avoid specializing any instructions derived from + // these parameters. + // + auto firstBlock = func->getFirstBlock(); if (!firstBlock) return; - // Initialize parameters based on their types + // Initialize parameters with COM/Builtin interface types to 'unbounded' and + // everything else to none. + // for (auto param : firstBlock->getParams()) { auto paramType = param->getDataType(); @@ -1729,103 +2092,22 @@ struct TypeFlowSpecializationContext if (auto interfaceType = as(paramType)) { if (isComInterfaceType(interfaceType) || isBuiltin(interfaceType)) - propagationMap[Element(context, param)] = makeUnbounded(); + propagationMap[InstWithContext(context, param)] = makeUnbounded(); else - propagationMap[Element(context, param)] = none(); // Initialize to none. + propagationMap[InstWithContext(context, param)] = none(); } else { - propagationMap[Element(context, param)] = none(); + propagationMap[InstWithContext(context, param)] = none(); } } } - template - T* unionCollection(T* collection1, T* collection2) - { - SLANG_ASSERT(as(collection1) && as(collection2)); - SLANG_ASSERT(collection1->getOp() == collection2->getOp()); - - if (!collection1) - return collection2; - if (!collection2) - return collection1; - if (collection1 == collection2) - return collection1; - - HashSet allValues; - // Collect all values from both collections - forEachInCollection(collection1, [&](IRInst* value) { allValues.add(value); }); - forEachInCollection(collection2, [&](IRInst* value) { allValues.add(value); }); - - return as(cBuilder.createCollection( - collection1->getOp(), - allValues)); // Create a new collection with the union of values - } - - IRInst* unionPropagationInfo(IRInst* info1, IRInst* info2) - { - if (!info1) - return info2; - if (!info2) - return info1; - - if (info1 == info2) - return info1; - - if (as(info1) && as(info2)) - { - SLANG_ASSERT(info1->getOperand(1) == info2->getOperand(1)); - // If both are array types, union their element types - IRBuilder builder(module); - builder.setInsertInto(module); - return builder.getArrayType( - (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), - info1->getOperand(1)); // Keep the same size - } - - if (as(info1) && as(info2)) - { - SLANG_ASSERT(info1->getOp() == info2->getOp()); - - // If both are array types, union their element types - IRBuilder builder(module); - builder.setInsertInto(module); - return builder.getPtrTypeWithAddressSpace( - (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), - as(info1)); - } - - if (as(info1) && as(info2)) - { - // If either info is unbounded, the union is unbounded - return makeUnbounded(); - } - - if (as(info1) && as(info2)) - { - return makeExistential(unionCollection( - cast(info1->getOperand(1)), - cast(info2->getOperand(1)))); - } - - if (as(info1) && as(info2)) - { - return makeTagType(unionCollection( - cast(info1->getOperand(0)), - cast(info2->getOperand(0)))); - } - - if (as(info1) && as(info2)) - { - return unionCollection( - cast(info1), - cast(info2)); - } - - SLANG_UNEXPECTED("Unhandled propagation info types in unionPropagationInfo"); - } - + // Default catch-all analysis method for any unhandled case. + // + // TODO: This technically shouldn't get invoked, since global + // insts shouldn't enter analysis at all + // IRTypeFlowData* analyzeDefault(IRInst* context, IRInst* inst) { SLANG_UNUSED(context); @@ -1855,21 +2137,9 @@ struct TypeFlowSpecializationContext return none(); // Default case, no propagation info } - bool specializeInstsInBlock(IRInst* context, IRBlock* block) - { - List instsToLower; - bool hasChanges = false; - for (auto inst : block->getChildren()) - instsToLower.add(inst); - - for (auto inst : instsToLower) - { - hasChanges |= specializeInst(context, inst); - } - - return hasChanges; - } - + // Specialize the fields of a struct type based on the recorded field info (if we + // have a non-trivial specilialization) + // bool specializeStructType(IRStructType* structType) { bool hasChanges = false; @@ -1891,9 +2161,53 @@ struct TypeFlowSpecializationContext return hasChanges; } - bool specializeFunc(IRFunc* func) - { - // Don't make any changes to non-global or intrinsic functions + bool specializeInstsInBlock(IRInst* context, IRBlock* block) + { + List instsToLower; + bool hasChanges = false; + for (auto inst : block->getChildren()) + instsToLower.add(inst); + + for (auto inst : instsToLower) + hasChanges |= specializeInst(context, inst); + + return hasChanges; + } + + bool specializeFunc(IRFunc* func) + { + // When specializing a func, we + // (i) rewrite the types and insts by calling `specializeInstsInBlock` and + // (ii) handle 'merge' points where the collections need to be upcasted. + // + // The merge points are places where a specialized inst might be passed as + // argument to a parameter that has a 'wider' type. + // + // This frequently occurs with phi parameters. + // + // For example: + // A B + // \ / + // C + // + // After specialization, A could pass a value of type TagType(TableCollection{T1, T2}) + // while B passes a value of type TagType(TableCollection{T2, T3}), while the phi + // parameter's type in C has the union type `TagType(TableCollection{T1, T2, T3})` + // + // In this case, we use `upcastCollection` to insert a cast from TagType(TableCollection{T1, + // T2}) -> TagType(TableCollection{T1, T2, T3}) before passing the result as a phi argument. + // + // The same logic applies for the return values. The function's caller expects a union type + // of all possible return statements, so we cast each return inst if there is a mismatch. + // + + // Don't make any changes to non-global or intrinsic functions. + // + // If a function is inside a generic, we wait until the main specialization pass + // turns it into a regular func and the typeflow pass is re-run again. + // This approach is much simpler that trying to incorporate generic parameters into the + // typeflow specialization logic. + // if (!isGlobalInst(func) || isIntrinsic(func)) return false; @@ -1904,10 +2218,9 @@ struct TypeFlowSpecializationContext for (auto block : func->getBlocks()) { UIndex paramIndex = 0; - // Process each parameter in this block (these are phi parameters) for (auto param : block->getParams()) { - auto paramInfo = tryGetInfo(param); + auto paramInfo = _tryGetInfo(InstWithContext(func, param)); if (!paramInfo) { paramIndex++; @@ -1926,9 +2239,14 @@ struct TypeFlowSpecializationContext if (newArg != arg) { hasChanges = true; - // Replace the argument in the branch instruction - SLANG_ASSERT(!as(unconditionalBranch)); - unconditionalBranch->setOperand(1 + paramIndex, newArg); + + // Replace the argument in the branch instruction with the + // properly casted argument. + // + if (auto loop = as(unconditionalBranch)) + loop->setOperand(3 + paramIndex, newArg); + else + unconditionalBranch->setOperand(1 + paramIndex, newArg); } } } @@ -1936,7 +2254,9 @@ struct TypeFlowSpecializationContext paramIndex++; } - // Is the terminator a return instruction? + // If the terminator is a return instruction, perform the same upcasting to + // match the registered return value type for this function. + // if (auto returnInst = as(block->getTerminator())) { if (!as(returnInst->getVal()->getDataType())) @@ -1956,6 +2276,7 @@ struct TypeFlowSpecializationContext } } + // Update the func type for this func accordingly. auto effectiveFuncType = getEffectiveFuncType(func); if (effectiveFuncType != func->getFullType()) { @@ -1966,6 +2287,19 @@ struct TypeFlowSpecializationContext return hasChanges; } + // Main entry point for the second phase of the type-flow analysis pass. + // + // This method is called after information propagation is complete and + // stabilized, and it replaces dynamic insts and types with specialized versions + // based on the collected information. + // + // After this pass is run, there should be no dynamic insts or types remaining, + // _except_ for those that are considered unbounded. + // + // i.e. ExtractExistentialType, ExtractExistentialWitnessTable, ExtractExistentialValue, + // MakeExistential, LookupWitness (and more) are rewritten to concrete tag translation + // insts. + // bool performDynamicInstLowering() { List funcsToProcess; @@ -1993,6 +2327,14 @@ struct TypeFlowSpecializationContext return hasChanges; } + // Returns an effective type to use for an inst, given its + // info. + // + // This basically recursively walks the info and applies the array/ptr-type + // wrappers, while replacing unbounded collection with a nullptr. + // + // If the result of this is null, then the inst should keep its current type. + // IRType* getLoweredType(IRInst* info) { if (!info) @@ -2059,17 +2401,17 @@ struct TypeFlowSpecializationContext } return (IRType*)info; - // SLANG_UNEXPECTED("Unhandled IRTypeFlowData type in getLoweredType"); } + // Replace an insts type with its effective type as determined by the analysis. bool replaceType(IRInst* context, IRInst* inst) { + // If the inst is a global val, we won't modify it. if (as(inst->getParent())) { if (as(inst) || as(inst) || as(inst) || as(inst)) { - // Don't replace global concrete vals. return false; } } @@ -2138,7 +2480,7 @@ struct TypeFlowSpecializationContext bool specializeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { - // Handle trivial case. + // Handle trivial case where inst's operand is a concrete table. if (auto witnessTable = as(inst->getWitnessTable())) { inst->replaceUsesWith(findWitnessTableEntry(witnessTable, inst->getRequirementKey())); @@ -2146,6 +2488,7 @@ struct TypeFlowSpecializationContext return true; } + // Otherwise, we go off the info for the inst. auto info = tryGetInfo(context, inst); if (!info) return false; @@ -2157,14 +2500,17 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); builder.setInsertBefore(inst); + // If there's a single element, we can do a simple replacement. if (getCollectionCount(collectionTagType) == 1) { - // Found a single possible type. Simple replacement. inst->replaceUsesWith(getCollectionElement(collectionTagType, 0)); inst->removeAndDeallocate(); return true; } + // If the collection is a type-collection, we'll still do a direct replacement + // effectively dropping the tag information + // if (auto typeCollection = as(collectionTagType->getCollection())) { // If this is a type collection, we can replace it with the collection type @@ -2175,7 +2521,13 @@ struct TypeFlowSpecializationContext return true; } - // Get the witness table operand info + // If we reach here, we have a truly dynamic case. Multiple elements and not a type + // collection We need to emit a run-time inst to keep track of the tag. + // + // We use the GetTagForMappedCollection inst to do this, and set its data type to + // the appropriate tag-type. + // + auto witnessTableInst = inst->getWitnessTable(); auto witnessTableInfo = tryGetInfo(context, witnessTableInst); @@ -2188,7 +2540,9 @@ struct TypeFlowSpecializationContext operands.getCount(), operands.getBuffer()); inst->replaceUsesWith(newInst); - propagationMap[Element(context, newInst)] = info; + + // We'll register the info for the newInst so any users of the new inst can use it. + propagationMap[InstWithContext(context, newInst)] = info; inst->removeAndDeallocate(); return false; @@ -2198,6 +2552,15 @@ struct TypeFlowSpecializationContext IRInst* context, IRExtractExistentialWitnessTable* inst) { + // If we have a non-trivial info registered, it must of + // CollectionTagType(TableCollection(...)) + // + // Futher, the operand must be an existential (CollectionTaggedUnionType), which is + // conceptually lowered to a TupleType(TagType(tableCollection), typeCollection) + // + // We will simply extract the first element of this tuple. + // + auto info = tryGetInfo(context, inst); if (!info) return false; @@ -2273,52 +2636,6 @@ struct TypeFlowSpecializationContext return true; } - // Split into direction and type - std::tuple getParameterDirectionAndType(IRType* paramType) - { - if (as(paramType)) - return { - ParameterDirectionInfo(ParameterDirectionInfo::Kind::Out), - as(paramType)->getValueType()}; - else if (as(paramType)) - return { - ParameterDirectionInfo(ParameterDirectionInfo::Kind::BorrowInOut), - as(paramType)->getValueType()}; - else if (as(paramType)) - return { - ParameterDirectionInfo( - ParameterDirectionInfo::Kind::Ref, - as(paramType)->getAddressSpace()), - as(paramType)->getValueType()}; - else if (as(paramType)) - return { - ParameterDirectionInfo( - ParameterDirectionInfo::Kind::BorrowIn, - as(paramType)->getAddressSpace()), - as(paramType)->getValueType()}; - else - return {ParameterDirectionInfo(ParameterDirectionInfo::Kind::In), paramType}; - } - - IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo direction, IRType* type) - { - switch (direction.kind) - { - case ParameterDirectionInfo::Kind::In: - return type; - case ParameterDirectionInfo::Kind::Out: - return builder->getOutParamType(type); - case ParameterDirectionInfo::Kind::BorrowInOut: - return builder->getBorrowInOutParamType(type); - case ParameterDirectionInfo::Kind::BorrowIn: - return builder->getBorrowInParamType(type, direction.addressSpace); - case ParameterDirectionInfo::Kind::Ref: - return builder->getRefParamType(type, direction.addressSpace); - default: - SLANG_UNEXPECTED("Unhandled parameter direction in fromDirectionAndType"); - } - } - bool isTaggedUnionType(IRInst* type) { if (auto tupleType = as(type)) @@ -2400,8 +2717,28 @@ struct TypeFlowSpecializationContext } } + // Get an effective func type to use for the callee. + // The callee may be a collection, in which case, this returns a union-ed functype./ + // IRFuncType* getEffectiveFuncType(IRInst* callee) { + // The effective func type for a callee is calculated as follows: + // + // (i) we build up the effective parameter types for the callee + // by taking the union of each parameter type + // for each callee in the collection. + // + // (ii) build up the effective result type in a similar manner. + // + // (iii) add extra tag parameters as necessary: + // + // - if we have multiple callees, then a parameter of TagType(callee) is appended + // to the beginning to select the callee. + // + // - if our callee is Specialize inst with collection args, then for each + // table-collection argument, a tag is required as input. + // + IRBuilder builder(module); List paramTypes; @@ -2411,16 +2748,19 @@ struct TypeFlowSpecializationContext { if (paramTypes.getCount() <= index) { - // If we don't have enough types, just add the new type - paramTypes.add(paramType); + // If this index hasn't been seen yet, expand the buffer and initialize + // the type. + // + paramTypes.growToCount(index + 1); + paramTypes[index] = paramType; return paramType; } else { // Otherwise, update the existing type auto [currentDirection, currentType] = - getParameterDirectionAndType(paramTypes[index]); - auto [newDirection, newType] = getParameterDirectionAndType(paramType); + splitParameterDirectionAndType(paramTypes[index]); + auto [newDirection, newType] = splitParameterDirectionAndType(paramType); auto updatedType = updateType(currentType, newType); SLANG_ASSERT(currentDirection == newDirection); paramTypes[index] = fromDirectionAndType(&builder, currentDirection, updatedType); @@ -2447,7 +2787,7 @@ struct TypeFlowSpecializationContext for (auto context : contextsToProcess) { - auto paramEffectiveTypes = getParamEffectiveTypes(context); + auto paramEffectiveTypes = getEffectiveParamTypes(context); auto paramDirections = getParamDirections(context); for (Index i = 0; i < paramEffectiveTypes.getCount(); i++) @@ -2496,6 +2836,7 @@ struct TypeFlowSpecializationContext // If this is a dynamic generic, we need to add a tag type for each // TableCollection in the callee. + // for (UIndex i = 0; i < specializeInst->getArgCount(); i++) if (auto tableCollection = as(specializeInst->getArg(i))) extraParamTypes.add((IRType*)makeTagType(tableCollection)); @@ -2508,6 +2849,158 @@ struct TypeFlowSpecializationContext return builder.getFuncType(allParamTypes, resultType); } + // Upcast the value in 'arg' to match the destInfo type. This method inserts + // any necessary reinterprets or tag translation instructions. + // + IRInst* upcastCollection(IRInst* context, IRInst* arg, IRType* destInfo) + { + // The upcasting process inserts the appropriate instructions + // to make arg's type match the type provided by destInfo. + // + // This process depends on the structure of arg and destInfo. + // + // We only deal with the type-flow data-types that are created in + // our pass (CollectionBase/CollectionTaggedUnionType/CollectionTagType/any other + // composites of these insts) + // + + auto argInfo = arg->getDataType(); + if (!argInfo || !destInfo) + return arg; + + if (as(argInfo) && as(destInfo)) + { + // A collection tagged union is essentially a tuple(TagType(tableCollection), + // typeCollection) We simply extract the two components, upcast each one, and put it + // back together. + // + + auto argTUType = as(argInfo); + auto destTUType = as(destInfo); + + if (getCollectionCount(argTUType) != getCollectionCount(destTUType)) + { + // Technically, IRCollectionTaggedUnionType is not a TupleType, + // but in practice it works the same way so we'll re-use Slang's + // tuple accessors & constructors + // + IRBuilder builder(arg->getModule()); + setInsertAfterOrdinaryInst(&builder, arg); + auto argTableTag = builder.emitGetTupleElement( + (IRType*)makeTagType(argTUType->getTableCollection()), + arg, + 0); + auto reinterpretedTag = upcastCollection( + context, + argTableTag, + (IRType*)makeTagType(destTUType->getTableCollection())); + + auto argVal = + builder.emitGetTupleElement((IRType*)argTUType->getTypeCollection(), arg, 1); + auto reinterpretedVal = + upcastCollection(context, argVal, (IRType*)destTUType->getTypeCollection()); + return builder.emitMakeTuple( + (IRType*)destTUType, + {reinterpretedTag, reinterpretedVal}); + } + } + else if (as(argInfo) && as(destInfo)) + { + // TODO: This case should not occur anymore since we replaced the bare tuple-type + // with collection-tagged-union-type. + // + SLANG_UNEXPECTED("Should not happen"); + + auto argTupleType = as(argInfo); + auto destTupleType = as(destInfo); + + List upcastedElements; + bool hasUpcastedElements = false; + + IRBuilder builder(module); + setInsertAfterOrdinaryInst(&builder, arg); + + // Upcast each element of the tuple + for (UInt i = 0; i < argTupleType->getOperandCount(); ++i) + { + auto argElementType = argTupleType->getOperand(i); + auto destElementType = destTupleType->getOperand(i); + + // If the element types are different, we need to reinterpret + if (argElementType != destElementType) + { + hasUpcastedElements = true; + upcastedElements.add(upcastCollection( + context, + builder.emitGetTupleElement((IRType*)argElementType, arg, i), + (IRType*)destElementType)); + } + else + { + upcastedElements.add( + builder.emitGetTupleElement((IRType*)argElementType, arg, i)); + } + } + + if (hasUpcastedElements) + { + return builder.emitMakeTuple(upcastedElements); + } + } + else if (as(argInfo) && as(destInfo)) + { + // If the arg represents a tag of a colleciton, but the dest is a _different_ + // collection, then we need to emit a tag operation to reinterpret the + // tag. + // + // Note that, by the invariant provided by the typeflow analysis, the target + // collection must necessarily be a super-set. + // + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) + { + IRBuilder builder(module); + setInsertAfterOrdinaryInst(&builder, arg); + return builder + .emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); + } + } + else if (as(argInfo) && as(destInfo)) + { + // If the arg has a collection type, but the dest is a _different_ collection, + // we need to perform a reinterpret. + // + // e.g. TypeCollection({T1, T2}) may lower to AnyValueType(N), while + // TypeCollection({T1, T2, T3}) may lower to AnyValueType(M). Since the target + // is necessarily a super-set, the target any-value-type is always larger (M >= N), + // so we only need a simple reinterpret. + // + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) + { + // If the sets of witness tables are not equal, reinterpret to the parameter type + IRBuilder builder(module); + setInsertAfterOrdinaryInst(&builder, arg); + return builder.emitReinterpret((IRType*)destInfo, arg); + } + } + else if (!as(argInfo) && as(destInfo)) + { + // If the arg is not a collection-type, but the dest is a collection, + // we need to perform a pack operation. + // + // This case only arises when passing a value of type T to a parameter + // of a type-collection that contains T. + // + IRBuilder builder(module); + setInsertAfterOrdinaryInst(&builder, arg); + return builder.emitPackAnyValue((IRType*)destInfo, arg); + } + + return arg; // Can use as-is. + } + + // TODO: Is this required? IRInst* getCalleeForContext(IRInst* context) { if (this->contextsToLower.contains(context)) @@ -2521,6 +3014,11 @@ struct TypeFlowSpecializationContext return context; } + // Helper function for specializing calls. + // + // For a `Specialize` instruction that has dynamic tag arguments, + // extract all the tags and return them as a list. + // List getArgsForDynamicSpecialization(IRSpecialize* specializedCallee) { List callArgs; @@ -2539,57 +3037,6 @@ struct TypeFlowSpecializationContext return callArgs; } - bool specializeCallToDynamicGeneric(IRInst* context, IRCall* inst) - { - auto specializedCallee = as(inst->getCallee()); - auto calleeInfo = tryGetInfo(context, specializedCallee); - auto calleeCollection = as(calleeInfo); - if (!calleeCollection || getCollectionCount(calleeCollection) != 1) - return false; - - auto targetContext = getCollectionElement(calleeCollection, 0); - - List callArgs; - for (UInt ii = 0; ii < specializedCallee->getArgCount(); ii++) - { - auto specArg = specializedCallee->getArg(ii); - auto argInfo = tryGetInfo(context, specArg); - if (auto argCollection = as(argInfo)) - { - if (as(getCollectionElement(argCollection, 0))) - { - // Needs an index (spec-arg will carry an index, we'll - // just need to append it to the call) - // - callArgs.add(specArg); - } - else if (as(getCollectionElement(argCollection, 0))) - { - // Needs no dynamic information. Skip. - } - else - { - // If it's a witness table, we need to handle it differently - // For now, we will not specialize this case. - SLANG_UNEXPECTED("Unhandled type-flow-collection in dynamic generic call"); - } - } - } - - for (UInt ii = 0; ii < inst->getArgCount(); ii++) - callArgs.add(inst->getArg(ii)); - - IRBuilder builder(inst->getModule()); - builder.setInsertBefore(inst); - auto newCallInst = builder.emitCallInst( - as(targetContext->getDataType())->getResultType(), - getCalleeForContext(targetContext), - callArgs); - inst->replaceUsesWith(newCallInst); - inst->removeAndDeallocate(); - return true; - } - void maybeSpecializeCalleeType(IRInst* callee) { if (auto specializeInst = as(callee->getDataType())) @@ -2601,6 +3048,80 @@ struct TypeFlowSpecializationContext bool specializeCall(IRInst* context, IRCall* inst) { + // The overall goal is to remove any dynamic-ness in the call inst + // (i.e. the callee as well as types of arguments should be global + // insts) + // + // There are a few cases we need to handle when specializing a call + // inst. + // + // First, we handle the callee: + // + // - If the callee is already a concrete function, there's nothing to do + // + // - If the callee is a dynamic inst of tag type, we replace + // the callee with the collection itself, and pass the tag inst as + // the first operand. Effectively, we are placing a call to a set of functions + // and using the tag to specify which function to call. + // + // e.g. + // let tag : TagType(funcCollection) = /* ... */; + // let val = Call(tag, arg1, arg2, ...); + // becomes + // let tag : TagType(funcCollection) = /* ... */; + // let val = Call(funcCollection, tag, arg1, arg2, ...); + // + // - If any the callee is a dynamic specialization of a generic, we need to add any dynamic + // witness + // table insts as arguments to the call. + // + // e.g.: + // Call( + // Specialize(g, specArgs...), callArgs...); + // where atleast one of specialization args is a dynamic tag inst. + // + // Our convention for dynamic generics is that the dynamic witness table + // operands are added to the front of the regular call arguments. + // + // So, we'll turn this into: + // Call( + // Specialize(g, staticFormOfSpecArgs...), dynamicSpecArgs..., callArgs...); + // where the new callee is a specialization where any dynamic insts are + // replaced with their static collections. + // + // // --- before specialization --- + // let s1 : TagType(TableCollection(tA, tB, tC)) = /* ... */; + // let s2 : TagType(TypeCollection(A, B, C)) = /* ... */; + // let specCallee = Specialize(generic, s1, s2); + // let val = Call(specCallee, /* call args */); + // + // // --- after specialization --- + // let s1 : TagType(TableCollection(tA, tB, tC)) = /* ... */; + // let s2 : TagType(TypeCollection(A, B, C)) = /* ... */; + // let newSpecCallee = Specialize(generic, + // TableCollection(tA, tB, tC), TypeCollection(A, B, C)); + // let newVal = Call(newSpecCallee, s1, /* call args */); + // + // + // - In case the callee is a collection of dynamically specialized + // generics, _both_ of the above transformations are applied, with + // the callee's tag going first, followed by any witness table tags + // and finally the regular call arguments. + // However, this case is NOT currently well supported because the func-collection + // tag does not encode the additional tags that need to be passed, so this + // is likely to fail currently. + // This is a rare scenario that only occurs on trying to specialize an existential + // method with existential arguments, which we don't officially support. + // + // Secondly, we handle the argument types: + // It is possible that + // the parameters in the callee have been specialized to accept + // a wider collection compared to the arguments from this call site + // + // In this case, we just upcast them using `upcastCollection` before + // creating a new call inst + // + auto callee = inst->getCallee(); IRInst* calleeTagInst = nullptr; @@ -2694,7 +3215,7 @@ struct TypeFlowSpecializationContext { auto arg = inst->getArg(i); const auto [paramDirection, paramType] = - getParameterDirectionAndType(expectedFuncType->getParamType(i + extraArgCount)); + splitParameterDirectionAndType(expectedFuncType->getParamType(i + extraArgCount)); switch (paramDirection.kind) { @@ -2762,6 +3283,11 @@ struct TypeFlowSpecializationContext bool specializeMakeStruct(IRInst* context, IRMakeStruct* inst) { + // The main thing to handle here is that we might have specialized + // the fields of the struct, so we need to upcast the arguments + // if necessary. + // + auto structType = as(inst->getDataType()); if (!structType) return false; @@ -2788,6 +3314,13 @@ struct TypeFlowSpecializationContext bool specializeMakeExistential(IRInst* context, IRMakeExistential* inst) { + // After specialization, existentials (that are not unbounded) are treated as tuples + // of a TableCollection tag and a value of type TypeCollection. + // + // A MakeExistential is just converted into a MakeTuple, with any necessary + // upcasts. + // + auto info = tryGetInfo(context, inst); auto taggedUnion = as(info); if (!taggedUnion) @@ -2839,6 +3372,16 @@ struct TypeFlowSpecializationContext bool specializeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { + // A CreateExistentialObject uses an user-provided ID to create an object. + // Note that this ID is not the same as the tags we use. The user-provided ID must be + // compared against the SequentialID, which is a globally consistent & public ID present + // on the witness tables. + // + // The tags are a locally consistent ID whose semantics are only meaningful within the + // function. We use a special op `GetTagFromSequentialID` to convert from the user-provided + // global ID to a local tag ID. + // + auto info = tryGetInfo(context, inst); auto taggedUnion = as(info); if (!taggedUnion) @@ -2882,6 +3425,14 @@ struct TypeFlowSpecializationContext bool specializeStructuredBufferLoad(IRInst* context, IRInst* inst) { + // The key thing to take care of here is a load from an + // interface-typed pointer. + // + // Our type-flow analysis will convert the + // result into a collection of all available implementations of this + // interface, so we need to cast the result. + // + auto valInfo = tryGetInfo(context, inst); if (!valInfo) @@ -2898,12 +3449,18 @@ struct TypeFlowSpecializationContext { // If we're dealing with a loading a known tagged union value from // an interface-typed pointer, we'll cast the pointer itself and - // defer the specializeing of the load until later. + // defer the specializing of the load until a later lowering + // pass. // // This avoids having to change the source pointer type // and confusing any future runs of the type flow // analysis pass. // + // This is slightly different from how a local 'load' is handled, + // because we don't want to modify the pointer (and consequently the global + // buffer) type, since it is a publicly visible type that is laid out + // in a certain way. + // IRBuilder builder(inst); builder.setInsertAfter(inst); auto bufferHandle = inst->getOperand(0); @@ -2937,6 +3494,31 @@ struct TypeFlowSpecializationContext bool specializeSpecialize(IRInst* context, IRSpecialize* inst) { + // When specializing a `Specialize` instruction, we have a few nuances. + // + // If we're dealing with specializing a type, witness table, or any other + // generic, we simply drop all dynamic tag information, and replace all + // operands with their collection variants. + // + // If we're dealing with a function, there are two cases: + // - A single function when dynamic specialization arguments. + // Removing the dynamic tag information will result in the eventual 'call' + // inst not have access to these insts. + // + // Instead, we'll just replace the type, and retain the `Specialize` inst with + // the dynamic args. It will be specialized out in `specializeCall` instead. + // + // - A collection of functions with concrete specialization arguments. + // In this case, we will emit an instruction to map from the input generic collection + // to the output specialized collection via `GetTagForSpecializedCollection`. + // This inst encodes the key-value mapping in its operands: + // e.g.(input_tag, key0, value0, key1, value1, ...) + // + // - The case where there is a collection of functions with dynamic specialization arguments + // is not currently properly handled. This case should not arise naturally since we + // don't advertise support for it. + // + bool isFuncReturn = false; // TODO: Would checking this inst's info be enough instead? @@ -3039,9 +3621,12 @@ struct TypeFlowSpecializationContext return false; } - bool specializeGetValueFromBoundInterface(IRInst* context, IRGetValueFromBoundInterface* inst) { + // GetValueFromBoundInterface is essentially accessing the value component of + // an existential. We turn it into a tuple element access. + // + SLANG_UNUSED(context); auto destType = inst->getDataType(); auto operandInfo = inst->getOperand(0)->getDataType(); @@ -3059,6 +3644,20 @@ struct TypeFlowSpecializationContext bool specializeLoad(IRInst* context, IRInst* inst) { + // There's two cases to handle.. + // + // (i) For a simple load, the pointer itself is already specialized. + // so we just need to replace the type of the load with specialized type. + // + // (ii) if there is a mismatch between the two types, the most likely + // case is that we're trying to load from an interface typed location + // (whose type we cannot modify), and cast it into a tagged union tuple. + // + // This case is similar to `specializeStructuredBufferLoad`, where we + // cast the _pointer_ to convert its type, and defer the legalization of + // the load to a later lowering pass. + // + auto valInfo = tryGetInfo(context, inst); if (!valInfo) @@ -3110,6 +3709,16 @@ struct TypeFlowSpecializationContext bool handleDefaultStore(IRInst* context, IRStore* inst) { + // This handles a rare case in the compiler, where we + // try to use default-construct to initialize a field. + // + // This is not technically supported, but it can occur + // during some corner cases during higher-order auto-diff. + // + // In case we've specialized the field, we just need to + // modify the default-construct operand's type to + // match the field. + // SLANG_UNUSED(context); SLANG_ASSERT(inst->getVal()->getOp() == kIROp_DefaultConstruct); auto ptr = inst->getPtr(); @@ -3128,6 +3737,11 @@ struct TypeFlowSpecializationContext bool specializeStore(IRInst* context, IRStore* inst) { + // Similar to `specializeLoad`, we handle cases where + // the pointer has been specialized, so that we upcast + // our value to match the type before writing to the location. + // + auto ptr = inst->getPtr(); auto ptrInfo = as(ptr->getDataType())->getValueType(); @@ -3155,6 +3769,12 @@ struct TypeFlowSpecializationContext bool specializeGetSequentialID(IRInst* context, IRGetSequentialID* inst) { + // A SequentialID is a globally unique ID for a witness table, while the + // the tags we use in the specialization are only locally consistent. + // + // To extract the global ID, we'll use a separate op code `GetSequentialIDFromTag` + // for now and lower it later once all the global sequential IDs have been assigned. + // SLANG_UNUSED(context); auto arg = inst->getOperand(0); if (auto tagType = as(arg->getDataType())) @@ -3181,6 +3801,13 @@ struct TypeFlowSpecializationContext bool specializeIsType(IRInst* context, IRIsType* inst) { + // The is-type checks equality between two witness tables + // via their sequential IDs. + // + // If the dynamic part has been specialized into a tag, we emit + // a `GetSequentialIDFromTag` inst to extract the ID and emit + // an equality test. + // SLANG_UNUSED(context); auto witnessTableArg = inst->getValueWitness(); if (auto tagType = as(witnessTableArg->getDataType())) @@ -3213,10 +3840,6 @@ struct TypeFlowSpecializationContext return false; } - bool isExistentialType(IRType* type) { return as(type) != nullptr; } - - bool isInterfaceType(IRType* type) { return as(type) != nullptr; } - HashSet collectExistentialTables(IRInterfaceType* interfaceType) { HashSet tables; @@ -3258,9 +3881,15 @@ struct TypeFlowSpecializationContext bool hasChanges = false; // Phase 1: Information Propagation + // This phase propagates type information through the module + // and records them into different maps in the current context. + // performInformationPropagation(); - // Phase 2: Dynamic Instruction Lowering + // Phase 2: Dynamic Instruction Specialization + // Re-write dynamic instructions into specialized versions based on the + // type information in the previous phase. + // hasChanges |= performDynamicInstLowering(); return hasChanges; @@ -3276,25 +3905,30 @@ struct TypeFlowSpecializationContext IRModule* module; DiagnosticSink* sink; - // Mapping from instruction to propagation information - Dictionary propagationMap; + // Mapping from (context, inst) --> propagated info + Dictionary propagationMap; - // Mapping from function to return value propagation information + // Mapping from context --> return value info Dictionary funcReturnInfo; - // Mapping from struct fields to propagation information + // Mapping from (struct field) --> propagated info Dictionary fieldInfo; - // Mapping from functions to call-sites. - Dictionary> funcCallSites; - - // Mapping from fields to use-sites. - Dictionary> fieldUseSites; + // Mapping from context --> Set<(context, inst)> + // + // Maintains a mapping from a callable context to all call-sites + // (and caller contexts) + // + Dictionary> funcCallSites; - // Mapping from specialized instruction to their any-value types - Dictionary loweredInstToAnyValueType; + // Mapping from (struct-field) --> Set<(context, inst)> + // + // Maintains a mapping from a struct field to all accesses of that + // field + // + Dictionary> fieldUseSites; - // Set of open contexts + // Set of already discovered contexts. HashSet availableContexts; // Contexts requiring lowering @@ -3303,7 +3937,7 @@ struct TypeFlowSpecializationContext // Lowered contexts. Dictionary loweredContexts; - // Context for building collection insts + // Helper for building collections. CollectionBuilder cBuilder; }; From af27e084e178661be30133ceeaaea005d9a57a23 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:38:49 -0400 Subject: [PATCH 057/105] Add comments for typeflow inst lowering pass. --- .../slang/slang-ir-lower-typeflow-insts.cpp | 185 +++++++++++++++--- 1 file changed, 155 insertions(+), 30 deletions(-) diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index e49853fd210..1c7b5a3a349 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -28,11 +28,12 @@ IRAnyValueType* createAnyValueType(IRBuilder* builder, const HashSet& t return builder->getAnyValueType(size); } +// Generate a single function that dispatches to each function in the collection. +// The resulting function will have one additional parameter to accept the tag +// indicating which function to call. +// IRFunc* createDispatchFunc(IRFuncCollection* collection) { - // An effective func type should have been set during the dynamic-inst-lowering - // pass. - // IRFuncType* dispatchFuncType = cast(collection->getFullType()); // Create a dispatch function with switch-case for each function @@ -146,10 +147,11 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) return func; } - +// Create a function that maps input integers to output integers based on the provided mapping. IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mapping, UInt defaultVal) { - // Create a function that maps input IDs to output IDs + // Emit a switch statement with the inputs as case labels and outputs as return values. + IRBuilder builder(module); auto funcType = @@ -211,10 +213,9 @@ IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mappi return func; } -// This context lowers `IRGetTagFromSequentialID`, -// `IRGetTagForSuperCollection`, and `IRGetTagForMappedCollection` instructions, +// This context lowers `GetTagForSpecializedCollection`, +// `GetTagForSuperCollection`, and `GetTagForMappedCollection` instructions, // - struct TagOpsLoweringContext : public InstPassBase { TagOpsLoweringContext(IRModule* module) @@ -224,6 +225,25 @@ struct TagOpsLoweringContext : public InstPassBase void lowerGetTagForSuperCollection(IRGetTagForSuperCollection* inst) { + // We use the result type and the type of the operand + // to figure out the source and destination collections. + // + // We then replace this with an array access, where the i'th + // element of the array is the corresponding index in the super + // collection. + // + // e.g. + // let a : TagType(TableCollection(B, C)) = /* ... */; + // let b : TagType(TableCollection(A, B, C)) = GetTagForSuperCollection(a); + // becomes + // let a : TagType(TableCollection(B, C)) = /* ... */; + // let lookupArr : ArrayType = [1, 2]; // B is at index 1, C is at index 2 + // let b : TagType(TableCollection(A, B, C)) = ElementExtract(lookupArr, a); + // + // Note that we leave the tag-types of the output intact since we may need to lower + // later tag operations. + // + auto srcCollection = cast( cast(inst->getOperand(0)->getDataType())->getOperand(0)); auto destCollection = @@ -271,6 +291,26 @@ struct TagOpsLoweringContext : public InstPassBase void lowerGetTagForMappedCollection(IRGetTagForMappedCollection* inst) { + // We use the result type and the type of the operand + // to figure out the source and destination collections. + // + // We then replace this with an array access, where the i'th + // element of the array is the corresponding index in the mapped + // collection. + // + // e.g. + // let a : TagType(TableCollection(B, C)) = /* ... */; + // let b : TagType(FuncCollection(C_key, B_key)) = + // GetTagForMappedCollection(a, key); + // becomes + // let a : TagType(TableCollection(B, C)) = /* ... */; + // let lookupArr : ArrayType = [1, 0]; // B is at index 1, C is at index 0 + // let b : TagType(FuncCollection(C_key, B_key)) = ElementExtract(lookupArr, a); + // + // Note that we leave the tag-types of the output intact since we may need to lower + // later tag operations. + // + auto srcCollection = cast( cast(inst->getOperand(0)->getDataType())->getOperand(0)); auto destCollection = @@ -319,6 +359,28 @@ struct TagOpsLoweringContext : public InstPassBase void lowerGetTagForSpecializedCollection(IRGetTagForSpecializedCollection* inst) { + // We use the result type and the type of the operand + // to figure out the source and destination collections. + // + // We then replace this with an array access, where the i'th + // element of the array is the corresponding index in the mapped + // collection. + // + // The mapping between elements is provided as pairs of operands to the instruction. + // + // e.g. + // let a : TagType(GenericCollection(B, C)) = /* ... */; + // let b : TagType(FuncCollection(E, F)) = + // GetTagForSpecializedCollection(a, B, F, C, E); + // becomes + // let a : TagType(GenericCollection(B, C)) = /* ... */; + // let lookupArr : ArrayType = [1, 0]; // B->F, C->E + // let b : TagType(FuncCollection(E, F)) = ElementExtract(lookupArr, a); + // + // Note that we leave the tag-types of the output intact since we may need to lower + // later tag operations. + // + auto srcCollection = cast(inst->getOperand(0)->getDataType())->getCollection(); auto destCollection = cast(inst->getDataType())->getCollection(); @@ -369,7 +431,6 @@ struct TagOpsLoweringContext : public InstPassBase inst->removeAndDeallocate(); } - void processInst(IRInst* inst) { switch (inst->getOp()) @@ -390,6 +451,12 @@ struct TagOpsLoweringContext : public InstPassBase void lowerFuncCollection(IRFuncCollection* collection) { + // Replace the `IRFuncCollection` with a dispatch function, + // which takes an extra first parameter for the tag (i.e. ID) + // + // We'll also replace the callee in all 'call' insts. + // + IRBuilder builder(collection->getModule()); if (collection->hasUses() && collection->getDataType() != nullptr) { @@ -422,7 +489,7 @@ struct TagOpsLoweringContext : public InstPassBase } }; -// This context lowers `IRTypeCollection` and `IRFuncCollection` instructions +// This context lowers `TypeCollection` instructions. struct CollectionLoweringContext : public InstPassBase { CollectionLoweringContext(IRModule* module) @@ -432,6 +499,10 @@ struct CollectionLoweringContext : public InstPassBase void lowerTypeCollection(IRTypeCollection* collection) { + // Type collections are replaced with `AnyValueType` large enough to hold + // any of the types in the collection. + // + HashSet types; for (UInt i = 0; i < collection->getCount(); i++) { @@ -454,6 +525,7 @@ struct CollectionLoweringContext : public InstPassBase } }; +// Lower `TypeCollection` instructions void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -461,16 +533,30 @@ void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) context.processModule(); } +// This context lowers `IRGetTagFromSequentialID` and `IRGetSequentialIDFromTag` instructions. +// Note: This pass requires that sequential ID decorations have been created for all witness +// tables. +// struct SequentialIDTagLoweringContext : public InstPassBase { SequentialIDTagLoweringContext(IRModule* module) : InstPassBase(module) { } - void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) { - SLANG_UNUSED(cast(inst->getOperand(0))); + // We use the result type to figure out the destination collection + // for which we need to generate the tag. + // + // We then replace this with call into an integer mapping function, + // which takes the sequential ID and returns the local ID (i.e. tag). + // + // To construct, the mapping, we lookup the sequential ID decorator on + // each element of the destination collection, and map it to the table's + // operand index in the collection. + // + + // We use the result type and the type of the operand auto srcSeqID = inst->getOperand(1); Dictionary mapping; @@ -484,7 +570,6 @@ struct SequentialIDTagLoweringContext : public InstPassBase [&](IRInst* table) { // Get unique ID for the witness table - SLANG_UNUSED(cast(table)); auto outputId = dstSeqID++; auto seqDecoration = table->findDecoration(); if (seqDecoration) @@ -497,7 +582,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase IRBuilder builder(inst); builder.setInsertAfter(inst); - // Default to largest available sequential ID. + // By default, use the tag for the largest available sequential ID. UInt defaultSeqID = 0; for (auto [inputId, outputId] : mapping) { @@ -517,6 +602,10 @@ struct SequentialIDTagLoweringContext : public InstPassBase void lowerGetSequentialIDFromTag(IRGetSequentialIDFromTag* inst) { + // Similar logic to the `GetTagFromSequentialID` case, except that + // we reverse the mapping. + // + SLANG_UNUSED(cast(inst->getOperand(0))); auto srcTagInst = inst->getOperand(1); @@ -564,6 +653,7 @@ struct SequentialIDTagLoweringContext : public InstPassBase } }; +// Lower `GetTagFromSequentialID` and `GetSequentialIDFromTag` instructions void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -571,6 +661,7 @@ void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) context.processModule(); } +// Lower `FuncCollection`, `GetTagForSuperCollection`, `GetTagForMappedCollection` void lowerTagInsts(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -578,6 +669,8 @@ void lowerTagInsts(IRModule* module, DiagnosticSink* sink) tagContext.processModule(); } +// This context lowers `IRCollectionTagType` instructions, by replacing +// them with a suitable integer type. struct TagTypeLoweringContext : public InstPassBase { TagTypeLoweringContext(IRModule* module) @@ -648,7 +741,39 @@ struct TaggedUnionLoweringContext : public InstPassBase void lowerCastInterfaceToTaggedUnionPtr(IRCastInterfaceToTaggedUnionPtr* inst) { - // Find all uses of the inst + // `CastInterfaceToTaggedUnionPtr` is used to 'reinterpret' a pointer to an interface-typed + // location into a tagged union type. Usually this is to avoid changing the type of the + // base location because it is externally visible, and to avoid touching the external layout + // of the interface type. + // + // To lower this, we won't actually change the pointer or the base location, but instead + // rewrite all loads and stores out of this pointer by converting the existential into a + // tagged union tuple. + // + // e.g. + // + // let basePtr : PtrType(InterfaceType(I)) = /* ... */; + // let tuPtr : PtrType(CollectionTaggedUnionType(types, tables)) = + // CastInterfaceToTaggedUnionPtr(basePtr); + // let loadedVal : CollectionTaggedUnionType(...) = Load(tuPtr); + // + // becomes + // + // let basePtr : PtrType(InterfaceType(I)) = /* ... */; + // let intermediateVal : InterfaceType(I) = Load(basePtr); + // let loadedTableID : TagType(tables) = + // GetTagFromSequentialID( + // InterfaceType(I), + // GetSequentialID( + // ExtractExistentialWitnessTable(intermediateVal))); + // let loadedVal : types = ExtractExistentialValue(intermediateVal); + // let loadedTuple : TupleType(TagType(tables), types) = + // MakeTuple(loadedTableID, loadedVal); + // + // The logic is similar for StructuredBufferLoad and RWStructuredBufferLoad, + // but the operands structure is slightly different. + // + traverseUses( inst, [&](IRUse* use) @@ -721,15 +846,19 @@ struct TaggedUnionLoweringContext : public InstPassBase inst->removeAndDeallocate(); } - void lowerCastTaggedUnionToInterfacePtr(IRCastTaggedUnionToInterfacePtr* inst) - { - SLANG_UNUSED(inst); - SLANG_UNEXPECTED("Unexpected inst of CastTaggedUnionToInterfacePtr"); - } - IRType* convertToTupleType(IRCollectionTaggedUnionType* taggedUnion) { - // Replace type with Tuple + // Replace `CollectionTaggedUnionType(typeCollection, tableCollection)` with + // `TupleType(CollectionTagType(tableCollection), typeCollection)` + // + // Unless the collection has a single element, in which case we + // replace it with `TupleType(CollectionTagType(tableCollection), elementType)` + // + // We still maintain a tuple type (even though it's not really necesssary) to avoid + // breaking any operations that assumed this is a tuple. + // In the single element case, the tuple should be optimized away. + // + IRBuilder builder(module); builder.setInsertInto(module); @@ -758,6 +887,7 @@ struct TaggedUnionLoweringContext : public InstPassBase inst->removeAndDeallocate(); }); + // Then, convert any loads/stores from reinterpreted pointers. bool hasCastInsts = false; processInstsOfType( kIROp_CastInterfaceToTaggedUnionPtr, @@ -767,18 +897,13 @@ struct TaggedUnionLoweringContext : public InstPassBase return lowerCastInterfaceToTaggedUnionPtr(inst); }); - processInstsOfType( - kIROp_CastTaggedUnionToInterfacePtr, - [&](IRCastTaggedUnionToInterfacePtr* inst) - { - hasCastInsts = true; - return lowerCastTaggedUnionToInterfacePtr(inst); - }); - return hasCastInsts; } }; +// Lower `CollectionTaggedUnion`and `CastInterfaceToTaggedUnionPtr` instructions +// May create new `Reinterpret` instructions. +// bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); From 625d521f1c8770dd3b97d13c7f292070bdbb8ff4 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:44:16 -0400 Subject: [PATCH 058/105] More comment fixes --- source/slang/slang-ir-lower-typeflow-insts.cpp | 5 ----- source/slang/slang-ir-lower-typeflow-insts.h | 12 ++++++++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 1c7b5a3a349..7b48c6c2c77 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -653,7 +653,6 @@ struct SequentialIDTagLoweringContext : public InstPassBase } }; -// Lower `GetTagFromSequentialID` and `GetSequentialIDFromTag` instructions void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -661,7 +660,6 @@ void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) context.processModule(); } -// Lower `FuncCollection`, `GetTagForSuperCollection`, `GetTagForMappedCollection` void lowerTagInsts(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -901,9 +899,6 @@ struct TaggedUnionLoweringContext : public InstPassBase } }; -// Lower `CollectionTaggedUnion`and `CastInterfaceToTaggedUnionPtr` instructions -// May create new `Reinterpret` instructions. -// bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); diff --git a/source/slang/slang-ir-lower-typeflow-insts.h b/source/slang/slang-ir-lower-typeflow-insts.h index c3be5d27339..82f1cf5fc48 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.h +++ b/source/slang/slang-ir-lower-typeflow-insts.h @@ -5,11 +5,23 @@ namespace Slang { +// Lower `TypeCollection` instructions void lowerTypeCollections(IRModule* module, DiagnosticSink* sink); + +// Lower `FuncCollection`, `GetTagForSuperCollection`, `GetTagForMappedCollection` and +// `GetTagForSpecializedCollection` instructions +// void lowerTagInsts(IRModule* module, DiagnosticSink* sink); +// Lower `GetTagFromSequentialID` and `GetSequentialIDFromTag` instructions void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink); + +// Lower `CollectionTagType` instructions void lowerTagTypes(IRModule* module); +// Lower `CollectionTaggedUnion`and `CastInterfaceToTaggedUnionPtr` instructions +// May create new `Reinterpret` instructions. +// bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink); + } // namespace Slang From e50190a680d6dff8117bd2d56aabd07c45e9cf85 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Oct 2025 11:37:27 -0400 Subject: [PATCH 059/105] More documentation --- source/slang/slang-ir-specialize.cpp | 73 ++++++++++++------- source/slang/slang-ir-specialize.h | 2 + source/slang/slang-ir-typeflow-specialize.cpp | 45 ------------ 3 files changed, 48 insertions(+), 72 deletions(-) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index b3dc273c3a3..ce1bb0c56ce 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3171,36 +3171,55 @@ void finalizeSpecialization(IRModule* module) } } -/* -// DUPLICATE: merge. -static bool isDynamicGeneric(IRInst* callee) +// Evaluate a `Specialize` inst where the arguments are collections rather than +// concrete singleton types and the generic returns a function. +// +// This needs to be slightly different from the usual case because the function +// needs dynamic information to select a specific element from each collection +// at runtime. +// +// The resulting function will therefore have additional parameters at the beginning +// to accept this information. +// +static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) { - // If the callee is a specialization, and at least one of its arguments - // is a type-flow-collection, then it is a dynamic generic. + // The high-level logic for specializing a generic to operate over collections + // is similar to specializing a simple generic: + // We "evaluate" the instructions in the first block of the generic and return + // the function that is returned by the generic. + // + // The key difference is that in the static case, all generic parameters, and instructions + // in the generic's body are guaranteed to be "baked out" into concrete types or witness tables. + // + // In the dynamic case, some generic parameters may turn into function parameters that accept a + // tag, and any lookup instructions might then have to be cloned into the function body. + // + // This is a slightly complex transformation that proceeds as follows: + // + // - Create an empty function that represents the final product. + // + // - Add any dynamic parameters of the generic to the function's first block. Keep track of the + // first block for later. For now, we only treat `WitnessTableType` parameters that have + // `TableCollection` arguments (with atleast 2 distinct elements) as dynamic. Each such + // parameter will get a corresponding parameter of `TagType(tableCollection)` + // + // - Clone in the rest of the generic's body into the first block of the function. + // The tricky part here is that we may have parameter types that depend on other parameters. + // This is a pattern that is allowed in the generic, but not in functions. + // + // To handle this, we maintain two cloning environments, a regular `cloneEnv` that registers + // the parameters, and a `staticCloneEnv` that is a child of `cloneEnv`, but overrides the + // dynamic parameters with their static collection. The `staticCloneEnv` is used to clone in + // the parameter and function types, while the `cloneEnv` is used for the rest. + // + // - When we reach the return value (i.e. the function inside the generic), we clone in the + // parameters + // of the function into the first block, and place them _after_ the parameters derived from + // the generic. Then the rest of the inner function's first block is cloned in. + // + // - All other blocks can be cloned in as usual. // - if (auto specialize = as(callee)) - { - auto generic = as(specialize->getBase()); - - // Only functions need dynamic-aware specialization. - if (getGenericReturnVal(generic)->getOp() != kIROp_Func) - return false; - - for (UInt i = 0; i < specialize->getArgCount(); i++) - { - auto arg = specialize->getArg(i); - if (as(arg)) - return true; // Found a type-flow-collection argument - } - return false; // No type-flow-collection arguments found - } - - return false; -} -*/ -static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) -{ auto generic = cast(specializeInst->getBase()); auto genericReturnVal = findGenericReturnVal(generic); diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h index eb92e7faa85..3d657e4875f 100644 --- a/source/slang/slang-ir-specialize.h +++ b/source/slang/slang-ir-specialize.h @@ -4,6 +4,8 @@ namespace Slang { struct IRModule; +struct IRInst; +struct IRSpecialize; class DiagnosticSink; class TargetProgram; diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 00d7cbdd4b1..c89b0e9eef4 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -2646,8 +2646,6 @@ struct TypeFlowSpecializationContext IRType* updateType(IRType* currentType, IRType* newType) { - // TODO: This is feeling very similar to the unionCollection logic. - // Maybe unify? if (auto collection = as(currentType)) { HashSet collectionElements; @@ -2904,49 +2902,6 @@ struct TypeFlowSpecializationContext {reinterpretedTag, reinterpretedVal}); } } - else if (as(argInfo) && as(destInfo)) - { - // TODO: This case should not occur anymore since we replaced the bare tuple-type - // with collection-tagged-union-type. - // - SLANG_UNEXPECTED("Should not happen"); - - auto argTupleType = as(argInfo); - auto destTupleType = as(destInfo); - - List upcastedElements; - bool hasUpcastedElements = false; - - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - - // Upcast each element of the tuple - for (UInt i = 0; i < argTupleType->getOperandCount(); ++i) - { - auto argElementType = argTupleType->getOperand(i); - auto destElementType = destTupleType->getOperand(i); - - // If the element types are different, we need to reinterpret - if (argElementType != destElementType) - { - hasUpcastedElements = true; - upcastedElements.add(upcastCollection( - context, - builder.emitGetTupleElement((IRType*)argElementType, arg, i), - (IRType*)destElementType)); - } - else - { - upcastedElements.add( - builder.emitGetTupleElement((IRType*)argElementType, arg, i)); - } - } - - if (hasUpcastedElements) - { - return builder.emitMakeTuple(upcastedElements); - } - } else if (as(argInfo) && as(destInfo)) { // If the arg represents a tag of a colleciton, but the dest is a _different_ From ad663cb778a19cc7c6e7558618c9bf4cdf07422f Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Oct 2025 15:46:08 -0400 Subject: [PATCH 060/105] More comments + minor changes --- source/slang/slang-ir-insts.lua | 6 +-- source/slang/slang-ir-typeflow-collection.cpp | 10 ++-- source/slang/slang-ir-typeflow-collection.h | 46 ++++++++++++++++--- 3 files changed, 48 insertions(+), 14 deletions(-) diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 374c8e7c431..af9002725d9 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2248,8 +2248,8 @@ local insts = { { CollectionTagType = { -- Represents a tag-type for a collection. -- - -- An inst whose type is CollectionTagType(collection) is semantically carrying a run-time value that points to - -- one of the elements of the collection operand. + -- An inst whose type is CollectionTagType(collection) is semantically carrying a + -- run-time value that "picks" one of the elements of the collection operand. -- -- Only operand is a CollectionBase } }, @@ -2263,7 +2263,7 @@ local insts = { -- This is most commonly used to specialize the type of existential insts once the possibilities can be statically determined. -- -- Operands are a TypeCollection and a TableCollection that represent the possibilities of the existential - }} + } } }, }, { CastInterfaceToTaggedUnionPtr = { diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-collection.cpp index ad383cb83e7..b6a7441d068 100644 --- a/source/slang/slang-ir-typeflow-collection.cpp +++ b/source/slang/slang-ir-typeflow-collection.cpp @@ -24,13 +24,13 @@ UCount getCollectionCount(IRCollectionBase* collection) UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) { - auto typeCollection = taggedUnion->getOperand(0); + auto typeCollection = taggedUnion->getTypeCollection(); return getCollectionCount(as(typeCollection)); } UCount getCollectionCount(IRCollectionTagType* tagType) { - auto collection = tagType->getOperand(0); + auto collection = tagType->getCollection(); return getCollectionCount(as(collection)); } @@ -43,8 +43,8 @@ IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) { - auto typeCollection = collectionTagType->getOperand(0); - return getCollectionElement(as(typeCollection), index); + auto collection = collectionTagType->getCollection(); + return getCollectionElement(as(collection), index); } CollectionBuilder::CollectionBuilder(IRModule* module) @@ -64,7 +64,7 @@ UInt CollectionBuilder::getUniqueID(IRInst* inst) return id; } -// Helper methods for creating canonical collections +// Helper method for creating canonical collections IRCollectionBase* CollectionBuilder::createCollection(IROp op, const HashSet& elements) { SLANG_ASSERT( diff --git a/source/slang/slang-ir-typeflow-collection.h b/source/slang/slang-ir-typeflow-collection.h index ca399a4b208..f681a0ab85a 100644 --- a/source/slang/slang-ir-typeflow-collection.h +++ b/source/slang/slang-ir-typeflow-collection.h @@ -8,6 +8,10 @@ namespace Slang IRCollectionTagType* makeTagType(IRCollectionBase* collection); +// +// Count and indexing helpers +// + UCount getCollectionCount(IRCollectionBase* collection); UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion); UCount getCollectionCount(IRCollectionTagType* tagType); @@ -15,7 +19,9 @@ UCount getCollectionCount(IRCollectionTagType* tagType); IRInst* getCollectionElement(IRCollectionBase* collection, UInt index); IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index); -// Helper to iterate over collection elements +// +// Helpers to iterate over elements of a collection. +// template void forEachInCollection(IRCollectionBase* info, F func) @@ -27,26 +33,54 @@ void forEachInCollection(IRCollectionBase* info, F func) template void forEachInCollection(IRCollectionTagType* tagType, F func) { - forEachInCollection(as(tagType->getOperand(0)), func); + forEachInCollection(as(tagType->getCollection()), func); } +// Builder class that helps greatly with struct CollectionBuilder { + // Get a collection builder for 'module'. CollectionBuilder(IRModule* module); - UInt getUniqueID(IRInst* inst); - - // Helper methods for creating canonical collections + // Create an inst to represent the elements in the set. + // + // All insts in `elements` must be global and concrete. They must not + // be collections themselves. + // + // Op must be one of the ops in `CollectionBase` + // + // For a given set, the returned inst is always the same within a single + // module. + // IRCollectionBase* createCollection(IROp op, const HashSet& elements); + + // Get a suitable collection op-code to use for an set containing 'inst'. IROp getCollectionTypeForInst(IRInst* inst); + + // Create a collection with a single element IRCollectionBase* makeSingletonSet(IRInst* value); + + // Create a collection with the given elements (the collection op will be + // automatically deduced using` getCollectionTypeForInst`) + // IRCollectionBase* makeSet(const HashSet& values); private: + // Return a unique ID for the inst. Assuming the module pointer + // is consistent, this should always be the same for a given inst. + // + UInt getUniqueID(IRInst* inst); + // Reference to parent module IRModule* module; - // Unique ID assignment for functions and witness tables + // Unique ID assignment for functions and witness tables. + // + // This is a pointer to a shared dictionary (typically + // a part of the module inst) so that all CollectionBuilder + // objects for the same module will always produce the same + // ordering. + // Dictionary* uniqueIds; }; From 9dc2b974bfdb96058bca8af69272cc694a8ba112 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Oct 2025 16:50:33 -0400 Subject: [PATCH 061/105] Remove commented code --- source/slang/slang-ir-specialize.cpp | 6 ------ source/slang/slang-lower-to-ir.cpp | 2 -- 2 files changed, 8 deletions(-) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 36db3b210f5..b1702c8d390 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1224,9 +1224,6 @@ struct SpecializationContext this->changed = true; eliminateDeadCode(module->getModuleInst()); applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); - - // Sync our local dictionary with the one in the IR. - // readSpecializationDictionaries(); } // Once the work list has gone dry, we should have the invariant @@ -1245,9 +1242,6 @@ struct SpecializationContext if (iterChanged) { eliminateDeadCode(module->getModuleInst()); - - // Sync our local dictionary with the one in the IR. - // readSpecializationDictionaries(); } } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 09e0f4efd86..b01d40a3daa 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3357,8 +3357,6 @@ void collectParameterLists( auto thisType = getThisParamTypeForContainer(context, parentDeclRef); if (thisType) { - /*thisType = as( - thisType->substitute(getCurrentASTBuilder(), SubstitutionSet(declRef)));*/ if (declRef.getDecl()->findModifier()) { auto noDiffAttr = context->astBuilder->getNoDiffModifierVal(); From e40d50f4d68bf127cd86ad597c45677990b6fe13 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Oct 2025 19:20:41 -0400 Subject: [PATCH 062/105] Fix some comments --- source/slang/slang-ir-typeflow-specialize.cpp | 12 ++++-------- source/slang/slang-ir-typeflow-specialize.h | 13 ++++++++++++- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index c89b0e9eef4..6fb458a6219 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -2104,10 +2104,6 @@ struct TypeFlowSpecializationContext } // Default catch-all analysis method for any unhandled case. - // - // TODO: This technically shouldn't get invoked, since global - // insts shouldn't enter analysis at all - // IRTypeFlowData* analyzeDefault(IRInst* context, IRInst* inst) { SLANG_UNUSED(context); @@ -2287,7 +2283,7 @@ struct TypeFlowSpecializationContext return hasChanges; } - // Main entry point for the second phase of the type-flow analysis pass. + // Implements phase 2 of the type-flow specialization pass. // // This method is called after information propagation is complete and // stabilized, and it replaces dynamic insts and types with specialized versions @@ -2296,9 +2292,9 @@ struct TypeFlowSpecializationContext // After this pass is run, there should be no dynamic insts or types remaining, // _except_ for those that are considered unbounded. // - // i.e. ExtractExistentialType, ExtractExistentialWitnessTable, ExtractExistentialValue, - // MakeExistential, LookupWitness (and more) are rewritten to concrete tag translation - // insts. + // i.e. `ExtractExistentialType`, `ExtractExistentialWitnessTable`, `ExtractExistentialValue`, + // `MakeExistential`, `LookupWitness` (and more) are rewritten to concrete tag translation + // insts (e.g. `GetTagForMappedCollection`, `GetTagForSpecializedCollection`, etc.) // bool performDynamicInstLowering() { diff --git a/source/slang/slang-ir-typeflow-specialize.h b/source/slang/slang-ir-typeflow-specialize.h index 1a134275085..8863ac915b8 100644 --- a/source/slang/slang-ir-typeflow-specialize.h +++ b/source/slang/slang-ir-typeflow-specialize.h @@ -4,7 +4,18 @@ namespace Slang { -// Main entry point for the pass + +// Convert dynamic insts such as `LookupWitnessMethod`, `ExtractExistentialValue`, +// `ExtractExistentialType`, `ExtractExistentialWitnessTable` and more into specialized versions +// based on the possible values at at the use sites, based on a data-flow-style interprocedural +// analysis. +// +// This pass is intended to be run after all specialization insts with concrete arguments have +// already been processed. +// +// This pass may generate more `Specialize` insts, so it should be run in a loop with +// the standard specialization pass until a no more changes can be made. +// bool specializeDynamicInsts(IRModule* module, DiagnosticSink* sink); bool isDynamicGeneric(IRInst* callee); From ad2e3219b53e73fd9f1bad9502966f397312e2cd Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Oct 2025 19:54:27 -0400 Subject: [PATCH 063/105] Update slang-ir-typeflow-collection.h --- source/slang/slang-ir-typeflow-collection.h | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/source/slang/slang-ir-typeflow-collection.h b/source/slang/slang-ir-typeflow-collection.h index f681a0ab85a..27283ce8288 100644 --- a/source/slang/slang-ir-typeflow-collection.h +++ b/source/slang/slang-ir-typeflow-collection.h @@ -36,7 +36,13 @@ void forEachInCollection(IRCollectionTagType* tagType, F func) forEachInCollection(as(tagType->getCollection()), func); } -// Builder class that helps greatly with +// Builder class that helps greatly with constructing `CollectionBase` instructions, +// which conceptually represent sets, and maintain the property that the equal sets +// should always be represented by the same instruction. +// +// Uses a unique ID assignment to keep stable ordering throughout the lifetime of the +// module. +// struct CollectionBuilder { // Get a collection builder for 'module'. From f09ce0cb55134b641ec1be4c32523ea5093ad002 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 14 Oct 2025 14:12:07 -0400 Subject: [PATCH 064/105] Put the speciliazation dictionary checks for `Undefined` into a separate function --- source/slang/slang-ir-specialize.cpp | 62 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index b1702c8d390..39052bf6e00 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -329,14 +329,8 @@ struct SpecializationContext // existing specialization that has been registered. // If one is found, our work is done. // - IRSpecializationDictionaryItem* specializationEntry = nullptr; - if (genericSpecializations.tryGetValue(key, specializationEntry)) - { - if (specializationEntry->getOperand(0)->getOp() != kIROp_Undefined) - return specializationEntry->getOperand(0); - else - genericSpecializations.remove(key); - } + if (auto specializedVal = tryGetDictionaryEntry(genericSpecializations, key)) + return specializedVal; } // If no existing specialization is found, we need @@ -1109,6 +1103,30 @@ struct SpecializationContext args.getBuffer())); } + // Look up an entry in the given specialization dictionary. + // + // Takes care of cases where the entry is not longer valid + // by removing the entry in the dictionary and returning null. + // + IRInst* tryGetDictionaryEntry( + Dictionary& dict, + const IRSimpleSpecializationKey& key) + { + IRSpecializationDictionaryItem* item = nullptr; + if (dict.tryGetValue(key, item)) + { + if (item->getOperand(0)->getOp() != kIROp_Undefined) + return item->getOperand(0); + else + { + dict.remove(key); + return nullptr; + } + } + + return nullptr; + } + // All of the machinery for generic specialization // has been defined above, so we will now walk // through the flow of the overall specialization pass. @@ -1565,18 +1583,8 @@ struct SpecializationContext // existing specialization of the callee that we can use. // IRSpecializationDictionaryItem* specializedCalleeEntry = nullptr; - IRFunc* specializedCallee = nullptr; - if (existentialSpecializedFuncs.tryGetValue(key, specializedCalleeEntry)) - { - if (specializedCalleeEntry->getOperand(0)->getOp() != kIROp_Undefined) - { - specializedCallee = cast(specializedCalleeEntry->getOperand(0)); - } - else - { - existentialSpecializedFuncs.remove(key); - } - } + IRFunc* specializedCallee = + cast(tryGetDictionaryEntry(existentialSpecializedFuncs, key)); if (!specializedCallee) { @@ -2718,18 +2726,8 @@ struct SpecializationContext IRSpecializationDictionaryItem* newStructTypeEntry = nullptr; addUsersToWorkList(type); - IRStructType* newStructType = nullptr; - if (existentialSpecializedStructs.tryGetValue(key, newStructTypeEntry)) - { - if (newStructTypeEntry->getOperand(0)->getOp() != kIROp_Undefined) - { - newStructType = cast(newStructTypeEntry->getOperand(0)); - } - else - { - existentialSpecializedStructs.remove(key); - } - } + IRStructType* newStructType = + cast(tryGetDictionaryEntry(existentialSpecializedStructs, key)); if (!newStructType) { From 9f976864b93704cb869cb8654b41abe092129a30 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 14 Oct 2025 17:25:23 -0400 Subject: [PATCH 065/105] Add support for specializing optional existential types. --- source/slang/slang-ir-typeflow-specialize.cpp | 276 ++++++++++++++++++ 1 file changed, 276 insertions(+) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 6fb458a6219..dd8ae57a962 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -905,6 +905,12 @@ struct TypeFlowSpecializationContext case kIROp_FieldExtract: info = analyzeFieldExtract(context, as(inst)); break; + case kIROp_MakeOptionalNone: + info = analyzeMakeOptionalNone(context, as(inst)); + break; + case kIROp_MakeOptionalValue: + info = analyzeMakeOptionalValue(context, as(inst)); + break; default: info = analyzeDefault(context, inst); break; @@ -1365,6 +1371,124 @@ struct TypeFlowSpecializationContext return none(); } + // Locate the 'none' witness table in the global scope + // of the module in context. This will be the table + // that conforms to 'nullptr' and has 'void' as the concrete type + // + IRWitnessTable* findNoneWitness() + { + IRBuilder builder(module); + auto voidType = builder.getVoidType(); + for (auto inst : module->getGlobalInsts()) + { + if (auto witnessTable = as(inst)) + { + if (witnessTable->getConcreteType() == voidType && + witnessTable->getConformanceType() == nullptr) + return witnessTable; + } + } + + return nullptr; + } + + // Get the witness table inst to be used for the 'none' case of + // an optional witness table. + // + IRWitnessTable* getNoneWitness() + { + if (auto table = findNoneWitness()) + return table; + + IRBuilder builder(module); + auto voidType = builder.getVoidType(); + + return builder.createWitnessTable(voidType, nullptr); + } + + // Returns true if the inst is of the form OptionalType + bool isOptionalExistentialType(IRInst* inst) + { + if (auto optionalType = as(inst)) + return as(optionalType->getValueType()) != nullptr; + return false; + } + + IRInst* analyzeMakeOptionalNone(IRInst* context, IRMakeOptionalNone* inst) + { + // If the optional type we're dealing with is an optional concrete type, we won't + // touch this case, since there's nothing dynamic to specialize. + // + // If the type inside the optional is an interface type, then we will treat it slightly + // differently by including 'none' as one of the possible candidates of the existential + // value. + // + // The `MakeOptionalNone` case represents the creating of an existential out of the + // 'none' witness table and a void value, so we'll represent that using the tagged union + // type. + // + SLANG_UNUSED(context); + if (isOptionalExistentialType(inst->getDataType())) + { + IRBuilder builder(module); + auto noneTableSet = cast( + cBuilder.createCollection(kIROp_TableCollection, getNoneWitness())); + return makeExistential(noneTableSet); + } + + return none(); + } + + IRInst* analyzeMakeOptionalValue(IRInst* context, IRMakeOptionalValue* inst) + { + // If the optional type we're dealing with is an optional concrete type, we won't + // touch this case, since there's nothing dynamic to specialize. + // + // If the type inside the optional is an interface type, then we will treat it slightly + // differently, by conceptually treating it as an interface type that has all the possible + // elements of the interface type plus an additional 'none' element. + // + // The `MakeOptionalValue` case is then very similar to the `MakeExistential` case, only we + // already have an existential as input. + // + // Thus, we simply pass the input existential info as-is. + // + // Note: we don't actually have to add a new 'none' table to the collection, since that will + // automatically occur if this value ever merges with a value created using + // `MakeOptionalNone` + // + if (isOptionalExistentialType(inst->getDataType())) + { + if (auto info = tryGetInfo(context, inst->getValue())) + { + SLANG_ASSERT(as(info)); + return info; + } + } + + return none(); + } + + IRInst* analyzeGetOptionalValue(IRInst* context, IRGetOptionalValue* inst) + { + if (isOptionalExistentialType(inst->getDataType())) + { + // This is an interesting case.. technically, at this point we could go + // from a larger collection to a smaller one (without the none-type). + // + // However, for simplicitly reasons, we currently only allow up-casting, + // so for now we'll just passthrough all types (so the result will + // assume that 'none-type' is a possiblity even though we statically know + // that it isn't). + // + if (auto info = tryGetInfo(context, inst->getOperand(0))) + { + SLANG_ASSERT(as(info)); + return info; + } + } + } + IRInst* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { // A LookupWitnessMethod is assumed to by dynamic, so we @@ -2464,6 +2588,14 @@ struct TypeFlowSpecializationContext return specializeGetSequentialID(context, as(inst)); case kIROp_IsType: return specializeIsType(context, as(inst)); + case kIROp_MakeOptionalNone: + return specializeMakeOptionalNone(context, as(inst)); + case kIROp_MakeOptionalValue: + return specializeMakeOptionalValue(context, as(inst)); + case kIROp_OptionalHasValue: + return specializeOptionalHasValue(context, as(inst)); + case kIROp_GetOptionalValue: + return specializeGetOptionalValue(context, as(inst)); default: { // Default case: replace inst type with specialized type (if available) @@ -3791,6 +3923,150 @@ struct TypeFlowSpecializationContext return false; } + bool specializeMakeOptionalNone(IRInst* context, IRMakeOptionalNone* inst) + { + if (auto taggedUnionType = as(inst->getDataType())) + { + // If we're dealing with a `MakeOptionalNone` for an existential type, then + // this just becomes a tagged union tuple where the set of tables is {none} + // (i.e. singleton set of none witness) + // + + IRBuilder builder(module); + builder.setInsertBefore(inst); + + // Create a tuple for the empty type.. + SLANG_ASSERT(taggedUnionType->getTableCollection()->isSingleton()); + auto noneWitnessTable = taggedUnionType->getTableCollection()->getElement(0); + + auto singletonTagType = makeTagType(cBuilder.makeSingletonSet(noneWitnessTable)); + auto zeroValueOfTagType = builder.getIntValue((IRType*)singletonTagType, 0); + + List tupleOperands; + tupleOperands.add(zeroValueOfTagType); + tupleOperands.add( + builder.emitDefaultConstruct((IRType*)taggedUnionType->getTypeCollection())); + + auto newTuple = builder.emitMakeTuple(tupleOperands); + inst->replaceUsesWith(newTuple); + propagationMap[InstWithContext(context, newTuple)] = taggedUnionType; + inst->removeAndDeallocate(); + + return true; + } + + return false; + } + + bool specializeMakeOptionalValue(IRInst* context, IRMakeOptionalValue* inst) + { + SLANG_UNUSED(context); + if (auto taggedUnionType = as(inst->getValue()->getDataType())) + { + // If we're dealing with a `MakeOptionalValue` for an existential type, + // we don't actually have to change anything, since logically, the input and output + // represent the same set of types and tables. + // + // We'll do a simple replace. + // + + auto newInst = inst->getValue(); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + + return true; + } + + return false; + } + + bool specializeGetOptionalValue(IRInst* context, IRGetOptionalValue* inst) + { + SLANG_UNUSED(context); + if (auto taggedUnionType = + as(inst->getOptionalOperand()->getDataType())) + { + // Since `GetOptionalValue` is the reverse of `MakeOptionalValue`, and we treat + // the latter as a no-op, then `GetOptionalValue` is also a no-op (we simply pass + // the inner existential value as-is) + // + + auto newInst = inst->getOptionalOperand(); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + return false; + } + + bool specializeOptionalHasValue(IRInst* context, IROptionalHasValue* inst) + { + SLANG_UNUSED(context); + if (auto taggedUnionType = + as(inst->getOptionalOperand()->getDataType())) + { + // The logic here is similar to specializing IsType, but we'll directly compare + // tags instead of trying to use sequential ID. + // + // There's two cases to handle here: + // 1. We statically know that it cannot be a 'none' because the + // input's collection type doesn't have a 'none'. In this case + // we just return a true. + // + // 2. 'none' is a possibility. In this case, we create a 0 value of + // type TagType(TableCollection(NoneWitness)) and then upcast it + // to TagType(inputTableCollection). This will convert the value + // to the corresponding value of 'none' in the input's table collection + // allowing us to directly compare it against the tag part of the + // input tagged union. + // + + IRBuilder builder(inst); + + bool containsNone = false; + forEachInCollection( + taggedUnionType->getTableCollection(), + [&](IRInst* wt) + { + if (wt == getNoneWitness()) + containsNone = true; + }); + + if (!containsNone) + { + auto trueVal = builder.getBoolValue(true); + inst->replaceUsesWith(trueVal); + inst->removeAndDeallocate(); + return true; + } + else + { + auto dynTag = builder.emitGetTupleElement( + (IRType*)makeTagType(taggedUnionType->getTableCollection()), + inst->getOptionalOperand(), + 0); + + IRInst* noneWitnessTagType = + makeTagType(cBuilder.makeSingletonSet(getNoneWitness())); + IRInst* noneSingletonWitnessTag = + builder.getIntValue((IRType*)noneWitnessTagType, 0); + + // Cast tag to super collection + auto noneWitnessTag = builder.emitIntrinsicInst( + (IRType*)makeTagType(taggedUnionType->getTableCollection()), + kIROp_GetTagForSuperCollection, + 1, + &noneSingletonWitnessTag); + + auto newInst = builder.emitNeq(dynTag, noneWitnessTag); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } + return false; + } + HashSet collectExistentialTables(IRInterfaceType* interfaceType) { HashSet tables; From 80f21932ba56868c6c87de4a10f875360577e895 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 14 Oct 2025 17:26:25 -0400 Subject: [PATCH 066/105] Update slang-ir-typeflow-specialize.cpp --- source/slang/slang-ir-typeflow-specialize.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index dd8ae57a962..6b81082f674 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -4034,6 +4034,10 @@ struct TypeFlowSpecializationContext if (!containsNone) { + // If 'none' isn't a part of the collection, statically set + // to true. + // + auto trueVal = builder.getBoolValue(true); inst->replaceUsesWith(trueVal); inst->removeAndDeallocate(); @@ -4041,6 +4045,10 @@ struct TypeFlowSpecializationContext } else { + // Otherwise, we'll extract the tag and compare against + // the value for 'none' (in the context of the tag's collection) + // + auto dynTag = builder.emitGetTupleElement( (IRType*)makeTagType(taggedUnionType->getTableCollection()), inst->getOptionalOperand(), @@ -4051,7 +4059,9 @@ struct TypeFlowSpecializationContext IRInst* noneSingletonWitnessTag = builder.getIntValue((IRType*)noneWitnessTagType, 0); - // Cast tag to super collection + // Cast the singleton tag to the target collection tag (will convert the + // value to the corresponding value for the larger set) + // auto noneWitnessTag = builder.emitIntrinsicInst( (IRType*)makeTagType(taggedUnionType->getTableCollection()), kIROp_GetTagForSuperCollection, From c9b3965e6995318cd2d30ed2cdcea40cad725957 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 14 Oct 2025 17:57:42 -0400 Subject: [PATCH 067/105] More fixes for optional existential types --- source/slang/slang-ir-specialize-dispatch.cpp | 7 +++++++ source/slang/slang-ir-typeflow-specialize.cpp | 8 ++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index b862b3dd087..8d2a08a618d 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -200,6 +200,13 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex else { auto witnessTableType = as(inst->getDataType()); + + if (witnessTableType && witnessTableType->getConformanceType() == nullptr) + { + // Ignore witness tables that represent 'none' for optional witness table types. + continue; + } + if (witnessTableType && witnessTableType->getConformanceType() ->findDecoration()) { diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 6b81082f674..7685c9c8bdc 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -1403,7 +1403,7 @@ struct TypeFlowSpecializationContext IRBuilder builder(module); auto voidType = builder.getVoidType(); - return builder.createWitnessTable(voidType, nullptr); + return builder.createWitnessTable(nullptr, voidType); } // Returns true if the inst is of the form OptionalType @@ -3925,7 +3925,7 @@ struct TypeFlowSpecializationContext bool specializeMakeOptionalNone(IRInst* context, IRMakeOptionalNone* inst) { - if (auto taggedUnionType = as(inst->getDataType())) + if (auto taggedUnionType = as(tryGetInfo(context, inst))) { // If we're dealing with a `MakeOptionalNone` for an existential type, then // this just becomes a tagged union tuple where the set of tables is {none} @@ -3947,7 +3947,7 @@ struct TypeFlowSpecializationContext tupleOperands.add( builder.emitDefaultConstruct((IRType*)taggedUnionType->getTypeCollection())); - auto newTuple = builder.emitMakeTuple(tupleOperands); + auto newTuple = builder.emitMakeTuple((IRType*)taggedUnionType, tupleOperands); inst->replaceUsesWith(newTuple); propagationMap[InstWithContext(context, newTuple)] = taggedUnionType; inst->removeAndDeallocate(); @@ -4048,7 +4048,7 @@ struct TypeFlowSpecializationContext // Otherwise, we'll extract the tag and compare against // the value for 'none' (in the context of the tag's collection) // - + builder.setInsertBefore(inst); auto dynTag = builder.emitGetTupleElement( (IRType*)makeTagType(taggedUnionType->getTableCollection()), inst->getOptionalOperand(), From 43a84f3d153e48e4daa8ef5ba32f15e462868bb6 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 15 Oct 2025 12:58:31 -0400 Subject: [PATCH 068/105] More fixes for existential IFoo (add another test) --- source/slang/slang-ir-insts.lua | 4 ++- source/slang/slang-ir-typeflow-specialize.cpp | 35 ++++++++++--------- .../dynamic-dispatch/func-call-input-1.slang | 2 +- ...ional-ifoo.slang => optional-ifoo-1.slang} | 0 .../types/optional-ifoo-2.slang | 33 +++++++++++++++++ 5 files changed, 56 insertions(+), 18 deletions(-) rename tests/language-feature/types/{optional-ifoo.slang => optional-ifoo-1.slang} (100%) create mode 100644 tests/language-feature/types/optional-ifoo-2.slang diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 7f593c7ec3f..85ab97d667f 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2229,7 +2229,7 @@ local insts = { -- { TypeCollection = {} }, { FuncCollection = {} }, - { TableCollection = {} }, + { TableCollection = {} }, -- TODO: Rename to WitnessTableCollection { GenericCollection = {} }, }, }, @@ -2248,6 +2248,8 @@ local insts = { -- pass does not attempt to specialize it. It should not appear in the code after -- the specialization pass. -- + -- TODO: Consider the scenario where we can combine the unbounded case with known cases. + -- unbounded collection should probably be an element and not a separate op. } }, { CollectionTagType = { -- Represents a tag-type for a collection. diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 7685c9c8bdc..3e91eaaeac0 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -1410,7 +1410,8 @@ struct TypeFlowSpecializationContext bool isOptionalExistentialType(IRInst* inst) { if (auto optionalType = as(inst)) - return as(optionalType->getValueType()) != nullptr; + if (auto interfaceType = as(optionalType->getValueType())) + return !isComInterfaceType(interfaceType) && !isBuiltin(interfaceType); return false; } @@ -3061,7 +3062,23 @@ struct TypeFlowSpecializationContext if (getCollectionCount(as(argInfo)) != getCollectionCount(as(destInfo))) { - // If the sets of witness tables are not equal, reinterpret to the parameter type + auto argCollection = as(argInfo); + if (argCollection->isSingleton() && as(argCollection->getElement(0))) + { + // There's a specific case where we're trying to reinterpret a value of 'void' + // type. We'll avoid emitting a reinterpret in this case, and emit a + // default-construct instead. + // + IRBuilder builder(module); + setInsertAfterOrdinaryInst(&builder, arg); + return builder.emitDefaultConstruct((IRType*)destInfo); + } + + // General case: + // + // If the sets of witness tables are not equal, reinterpret to the + // parameter type + // IRBuilder builder(module); setInsertAfterOrdinaryInst(&builder, arg); return builder.emitReinterpret((IRType*)destInfo, arg); @@ -3083,20 +3100,6 @@ struct TypeFlowSpecializationContext return arg; // Can use as-is. } - // TODO: Is this required? - IRInst* getCalleeForContext(IRInst* context) - { - if (this->contextsToLower.contains(context)) - return context; // Not specialized yet. - - if (this->loweredContexts.containsKey(context)) - return this->loweredContexts[context]; - else - this->contextsToLower.add(context); - - return context; - } - // Helper function for specializing calls. // // For a `Specialize` instruction that has dynamic tag arguments, diff --git a/tests/language-feature/dynamic-dispatch/func-call-input-1.slang b/tests/language-feature/dynamic-dispatch/func-call-input-1.slang index 60e2ad547c0..e93ededa7ae 100644 --- a/tests/language-feature/dynamic-dispatch/func-call-input-1.slang +++ b/tests/language-feature/dynamic-dispatch/func-call-input-1.slang @@ -31,7 +31,7 @@ IInterface factoryAB(uint id, float x) return B(); } -// Should lower to accept A, B, C or D (but not E) +// Should lower to accept A, B (but not C) float calc(IInterface obj, float y) { return obj.calc(y); diff --git a/tests/language-feature/types/optional-ifoo.slang b/tests/language-feature/types/optional-ifoo-1.slang similarity index 100% rename from tests/language-feature/types/optional-ifoo.slang rename to tests/language-feature/types/optional-ifoo-1.slang diff --git a/tests/language-feature/types/optional-ifoo-2.slang b/tests/language-feature/types/optional-ifoo-2.slang new file mode 100644 index 00000000000..0a19fa9ad35 --- /dev/null +++ b/tests/language-feature/types/optional-ifoo-2.slang @@ -0,0 +1,33 @@ +//TEST:INTERPRET(filecheck=CHECK): + +interface IFoo +{ + int get_result(); +} + +struct FooImpl : IFoo +{ + int get_result() { return data; } + int data; +} + +Optional generate_foo(int i) +{ + if (i % 5 == 0) + { + FooImpl result = {}; + result.data = i; + return { result }; + } + else + { + return {}; + } +} + +void main() +{ + // CHECK: hasValue: 1 + let result_foo = generate_foo(100); + printf("hasValue: %d\n", (int)result_foo.hasValue); +} From 091139e534eb6e4240928ffeaf5185897b3d690f Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 15 Oct 2025 14:43:36 -0400 Subject: [PATCH 069/105] Fix VM emit for unreachable blocks. --- source/slang/slang-emit-vm.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/source/slang/slang-emit-vm.cpp b/source/slang/slang-emit-vm.cpp index 80d3762aab0..b38565cc5a9 100644 --- a/source/slang/slang-emit-vm.cpp +++ b/source/slang/slang-emit-vm.cpp @@ -1079,6 +1079,11 @@ class ByteCodeEmitter } } + bool isUnreachableBlock(IRBlock* block) + { + return as(block->getTerminator()) != nullptr; + } + void emitFunction(IRFunc* func) { VMByteCodeFunctionBuilder funcBuilder; @@ -1096,6 +1101,9 @@ class ByteCodeEmitter for (auto block : func->getBlocks()) { + if (isUnreachableBlock(block)) + continue; + mapBlockToByteOffset[block] = funcBuilder.code.getCount(); for (auto inst : block->getChildren()) From b0225349474c1d21541ab03bc061d100379eb6d9 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 15 Oct 2025 14:47:36 -0400 Subject: [PATCH 070/105] Update slang-ir-specialize.cpp --- source/slang/slang-ir-specialize.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 39052bf6e00..b8bcc7eca47 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1582,7 +1582,6 @@ struct SpecializationContext // Once we've constructed our key, we can try to look for an // existing specialization of the callee that we can use. // - IRSpecializationDictionaryItem* specializedCalleeEntry = nullptr; IRFunc* specializedCallee = cast(tryGetDictionaryEntry(existentialSpecializedFuncs, key)); @@ -2723,7 +2722,6 @@ struct SpecializationContext key.vals.add(type->getExistentialArg(ii)); } - IRSpecializationDictionaryItem* newStructTypeEntry = nullptr; addUsersToWorkList(type); IRStructType* newStructType = From 589e0dc8ef7e4de279fe995011fe3ac3b803407d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:56:05 -0400 Subject: [PATCH 071/105] Overhaul semantics of the propagation info insts + several fixes and refinements --- source/slang/slang-emit.cpp | 3 +- source/slang/slang-ir-autodiff-transpose.h | 2 +- source/slang/slang-ir-insts-stable-names.lua | 18 +- source/slang/slang-ir-insts.h | 149 +- source/slang/slang-ir-insts.lua | 117 +- source/slang/slang-ir-lower-dynamic-insts.cpp | 0 .../slang/slang-ir-lower-typeflow-insts.cpp | 356 ++++- source/slang/slang-ir-lower-typeflow-insts.h | 21 +- ...ecialize-dynamic-associatedtype-lookup.cpp | 6 + source/slang/slang-ir-specialize.cpp | 52 +- source/slang/slang-ir-specialize.h | 5 + source/slang/slang-ir-typeflow-collection.cpp | 123 +- source/slang/slang-ir-typeflow-collection.h | 5 + source/slang/slang-ir-typeflow-specialize.cpp | 1248 +++++++++++------ .../slang/slang-ir-witness-table-wrapper.cpp | 18 +- source/slang/slang-ir.cpp | 15 + ...es.slang => dependent-assoc-types-1.slang} | 0 .../dynamic-dispatch/generic-method.slang | 46 + 18 files changed, 1659 insertions(+), 525 deletions(-) create mode 100644 source/slang/slang-ir-lower-dynamic-insts.cpp rename tests/language-feature/dynamic-dispatch/{dependent-assoc-types.slang => dependent-assoc-types-1.slang} (100%) create mode 100644 tests/language-feature/dynamic-dispatch/generic-method.slang diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index fe9efd9bb79..d9fe06ba077 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1143,11 +1143,12 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, sink)); } + lowerTagInsts(irModule, sink); + // Tagged union type lowering typically generates more reinterpret instructions. if (lowerTaggedUnionTypes(irModule, sink)) requiredLoweringPassSet.reinterpret = true; - lowerTagInsts(irModule, sink); lowerTypeCollections(irModule, sink); if (requiredLoweringPassSet.reinterpret) diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index c70374e7706..b24ed50095f 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2212,7 +2212,7 @@ struct DiffTransposePass // If we reach this point, revValue must be a differentiable type. auto revTypeWitness = diffTypeContext.tryGetDifferentiableWitness( builder, - primalType, + fwdInst->getDataType(), DiffConformanceKind::Value); SLANG_ASSERT(revTypeWitness); diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index fd050918f68..1e8b1285a16 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -683,16 +683,26 @@ return { ["Undefined.LoadFromUninitializedMemory"] = 679, ["TypeFlowData.CollectionBase.TypeCollection"] = 680, ["TypeFlowData.CollectionBase.FuncCollection"] = 681, - ["TypeFlowData.CollectionBase.TableCollection"] = 682, + ["TypeFlowData.CollectionBase.WitnessTableCollection"] = 682, ["TypeFlowData.CollectionBase.GenericCollection"] = 683, ["TypeFlowData.UnboundedCollection"] = 684, - ["TypeFlowData.CollectionTagType"] = 685, - ["TypeFlowData.CollectionTaggedUnionType"] = 686, + ["Type.CollectionTagType"] = 685, + ["Type.CollectionTaggedUnionType"] = 686, ["CastInterfaceToTaggedUnionPtr"] = 687, ["CastTaggedUnionToInterfacePtr"] = 688, ["GetTagForSuperCollection"] = 689, ["GetTagForMappedCollection"] = 690, ["GetTagForSpecializedCollection"] = 691, ["GetTagFromSequentialID"] = 692, - ["GetSequentialIDFromTag"] = 693 + ["GetSequentialIDFromTag"] = 693, + ["GetElementFromTag"] = 694, + ["GetDispatcher"] = 695, + ["GetSpecializedDispatcher"] = 696, + ["GetTagFromTaggedUnion"] = 697, + ["GetValueFromTaggedUnion"] = 698, + ["Type.ValueOfCollectionType"] = 699, + ["Type.ElementOfCollectionType"] = 700, + ["MakeTaggedUnion"] = 701, + ["GetTypeTagFromTaggedUnion"] = 702, + ["GetTagOfElementInCollection"] = 703, } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fff70b90835..478cca930eb 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3537,7 +3537,7 @@ struct IRCollectionBase : IRTypeFlowData }; FIDDLE() -struct IRTableCollection : IRCollectionBase +struct IRWitnessTableCollection : IRCollectionBase { FIDDLE(leafInst()) }; @@ -3550,7 +3550,7 @@ struct IRTypeCollection : IRCollectionBase }; FIDDLE() -struct IRCollectionTagType : IRTypeFlowData +struct IRCollectionTagType : IRType { FIDDLE(leafInst()) IRCollectionBase* getCollection() { return as(getOperand(0)); } @@ -3558,17 +3558,34 @@ struct IRCollectionTagType : IRTypeFlowData }; FIDDLE() -struct IRCollectionTaggedUnionType : IRTypeFlowData +struct IRCollectionTaggedUnionType : IRType { FIDDLE(leafInst()) - IRTypeCollection* getTypeCollection() { return as(getOperand(0)); } - IRTableCollection* getTableCollection() { return as(getOperand(1)); } + IRWitnessTableCollection* getWitnessTableCollection() + { + return as(getOperand(0)); + } + IRTypeCollection* getTypeCollection() { return as(getOperand(1)); } bool isSingleton() { - return getTypeCollection()->isSingleton() && getTableCollection()->isSingleton(); + return getTypeCollection()->isSingleton() && getWitnessTableCollection()->isSingleton(); } }; +FIDDLE() +struct IRElementOfCollectionType : IRType +{ + FIDDLE(leafInst()) + IRCollectionBase* getCollection() { return as(getOperand(0)); } +}; + +FIDDLE() +struct IRValueOfCollectionType : IRType +{ + FIDDLE(leafInst()) + IRCollectionBase* getCollection() { return as(getOperand(0)); } +}; + FIDDLE(allOtherInstStructs()) struct IRBuilderSourceLocRAII; @@ -4193,6 +4210,12 @@ struct IRBuilder return emitMakeTuple(SLANG_COUNT_OF(args), args); } + IRMakeTaggedUnion* emitMakeTaggedUnion(IRType* type, IRInst* tag, IRInst* value) + { + IRInst* args[] = {tag, value}; + return cast(emitIntrinsicInst(type, kIROp_MakeTaggedUnion, 2, args)); + } + IRInst* emitMakeValuePack(IRType* type, UInt count, IRInst* const* args); IRInst* emitMakeValuePack(UInt count, IRInst* const* args); @@ -4702,6 +4725,120 @@ struct IRBuilder IRMetalSetPrimitive* emitMetalSetPrimitive(IRInst* index, IRInst* primitive); IRMetalSetIndices* emitMetalSetIndices(IRInst* index, IRInst* indices); + // TODO: Move all the collection-based ops into the builder. + IRUnboundedCollection* emitUnboundedCollection() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedCollection, 0, nullptr)); + } + + IRGetElementFromTag* emitGetElementFromTag(IRInst* tag) + { + auto tagType = cast(tag->getDataType()); + IRInst* collection = tagType->getCollection(); + auto elementType = cast( + emitIntrinsicInst(nullptr, kIROp_ElementOfCollectionType, 1, &collection)); + return cast( + emitIntrinsicInst(elementType, kIROp_GetElementFromTag, 1, &tag)); + } + + IRGetTagFromTaggedUnion* emitGetTagFromTaggedUnion(IRInst* tag) + { + auto taggedUnionType = cast(tag->getDataType()); + + IRInst* collection = taggedUnionType->getWitnessTableCollection(); + auto tableTagType = cast( + emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collection)); + + return cast( + emitIntrinsicInst(tableTagType, kIROp_GetTagFromTaggedUnion, 1, &tag)); + } + + IRGetTypeTagFromTaggedUnion* emitGetTypeTagFromTaggedUnion(IRInst* tag) + { + auto taggedUnionType = cast(tag->getDataType()); + + IRInst* collection = taggedUnionType->getTypeCollection(); + auto typeTagType = cast( + emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collection)); + + return cast( + emitIntrinsicInst(typeTagType, kIROp_GetTypeTagFromTaggedUnion, 1, &tag)); + } + + IRGetValueFromTaggedUnion* emitGetValueFromTaggedUnion(IRInst* taggedUnion) + { + auto taggedUnionType = cast(taggedUnion->getDataType()); + + IRInst* typeCollection = taggedUnionType->getTypeCollection(); + auto valueOfTypeCollectionType = cast( + emitIntrinsicInst(nullptr, kIROp_ValueOfCollectionType, 1, &typeCollection)); + + return cast(emitIntrinsicInst( + valueOfTypeCollectionType, + kIROp_GetValueFromTaggedUnion, + 1, + &taggedUnion)); + } + + IRGetDispatcher* emitGetDispatcher( + IRFuncType* funcType, + IRWitnessTableCollection* witnessTableCollection, + IRStructKey* key) + { + IRInst* args[] = {witnessTableCollection, key}; + return cast(emitIntrinsicInst(funcType, kIROp_GetDispatcher, 2, args)); + } + + IRGetSpecializedDispatcher* emitGetSpecializedDispatcher( + IRFuncType* funcType, + IRWitnessTableCollection* witnessTableCollection, + IRStructKey* key, + List const& specArgs) + { + List args; + args.add(witnessTableCollection); + args.add(key); + for (auto specArg : specArgs) + { + args.add(specArg); + } + return cast(emitIntrinsicInst( + funcType, + kIROp_GetSpecializedDispatcher, + (UInt)args.getCount(), + args.getBuffer())); + } + + IRValueOfCollectionType* getValueOfCollectionType(IRInst* operand) + { + return as( + emitIntrinsicInst(nullptr, kIROp_ValueOfCollectionType, 1, &operand)); + } + + IRElementOfCollectionType* getElementOfCollectionType(IRInst* operand) + { + return as( + emitIntrinsicInst(nullptr, kIROp_ElementOfCollectionType, 1, &operand)); + } + + IRGetTagOfElementInCollection* emitGetTagOfElementInCollection( + IRType* tagType, + IRInst* element, + IRInst* collection) + { + SLANG_ASSERT(tagType->getOp() == kIROp_CollectionTagType); + IRInst* args[] = {element, collection}; + return cast( + emitIntrinsicInst(tagType, kIROp_GetTagOfElementInCollection, 2, args)); + } + + IRCollectionTagType* getCollectionTagType(IRInst* collection) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collection)); + } + // // Decorations // diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index be21a3f5ea2..21b7ffd4291 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -445,6 +445,35 @@ local insts = { }, }, }, + { ValueOfCollectionType = { + hoistable = true, + -- A type that represents that the value's _type_ is one of types in the collection operand. + } }, + { ElementOfCollectionType = { + hoistable = true, + -- A type that represents that the value must be an element of the collection operand. + } }, + { CollectionTagType = { + hoistable = true, + -- Represents a tag-type for a collection. + -- + -- An inst whose type is CollectionTagType(collection) is semantically carrying a + -- run-time value that "picks" one of the elements of the collection operand. + -- + -- Only operand is a CollectionBase + } }, + { CollectionTaggedUnionType = { + hoistable = true, + -- Represents a tagged union type. + -- + -- An inst whose type is a CollectionTaggedUnionType(typeCollection, witnessTableCollection) is semantically carrying a tuple of + -- two values: a value of CollectionTagType(witnessTableCollection) to represent the tag, and a payload value of type + -- ValueOfCollectionType(typeCollection), which conceptually represents a union/"anyvalue" type. + -- + -- This is most commonly used to specialize the type of existential insts once the possibilities can be statically determined. + -- + -- Operands are a TypeCollection and a WitnessTableCollection that represent the possibilities of the existential + } } }, }, -- IRGlobalValueWithCode @@ -2269,7 +2298,7 @@ local insts = { -- { TypeCollection = {} }, { FuncCollection = {} }, - { TableCollection = {} }, -- TODO: Rename to WitnessTableCollection + { WitnessTableCollection = {} }, { GenericCollection = {} }, }, }, @@ -2291,25 +2320,6 @@ local insts = { -- TODO: Consider the scenario where we can combine the unbounded case with known cases. -- unbounded collection should probably be an element and not a separate op. } }, - { CollectionTagType = { - -- Represents a tag-type for a collection. - -- - -- An inst whose type is CollectionTagType(collection) is semantically carrying a - -- run-time value that "picks" one of the elements of the collection operand. - -- - -- Only operand is a CollectionBase - } }, - { CollectionTaggedUnionType = { - -- Represents a tagged union type. - -- - -- An inst whose type is a CollectionTaggedUnionType(typeCollection, tableCollection) is semantically carrying a tuple of - -- two values: a value of CollectionTagType(tableCollection) to represent the tag, and a payload value of type - -- typeCollection (which conceptually represents a union/"anyvalue" type) - -- - -- This is most commonly used to specialize the type of existential insts once the possibilities can be statically determined. - -- - -- Operands are a TypeCollection and a TableCollection that represent the possibilities of the existential - } } }, }, { CastInterfaceToTaggedUnionPtr = { @@ -2320,6 +2330,7 @@ local insts = { } }, { GetTagForSuperCollection = { -- Translate a tag from a set to its equivalent in a super-set + -- TODO: Lower using a global ID and not local IDs + mapping ops. } }, { GetTagForMappedCollection = { -- Translate a tag from a set to its equivalent in a different set @@ -2335,7 +2346,71 @@ local insts = { } }, { GetSequentialIDFromTag = { -- Translate a tag from the given collection (a 'local' ID) to a sequential ID (a 'global' ID) - } } + } }, + { GetElementFromTag = { + -- Translate a tag to its corresponding element in the collection. + -- Input's type: CollectionTagType(collection). + -- Output's type: ElementOfCollectionType(collection) + -- + operands = {{"tag"}} + } }, + { GetDispatcher = { + -- Get a dispatcher function for a given witness table set + key. + -- + -- Inputs: set of witness tables to create a dispatched for and the key to use to identify the + -- entry that needs to be dispatched to. All witness tables must have an entry for the given key. + -- or else this is a malformed inst. + -- + -- Output: a value of 'FuncType' that can be called. + -- This func-type will take a `TagType(witnessTableCollection)` as the first parameter to + -- discriminate which witness table to use, and the rest of the parameters. + -- + hoistable = true, + operands = {{"witnessTableCollection", "IRWitnessTableCollection"}, {"lookupKey", "IRStructKey"}} + } }, + { GetSpecializedDispatcher = { + -- Get a specialized dispatcher function for a given witness table set + key, where + -- the key points to a generic function. + -- + -- Inputs: set of witness tables to create a dispatched for and the key to use to identify the + -- entry that needs to be dispatched to. All witness tables must have an entry for the given key. + -- or else this is a malformed inst. + -- A set of specialization arguments (these must be concrete/global types or collections) + -- + -- Output: a value of `FuncType` that can be called. + -- This func-type will take a `TagType(witnessTableCollection)` as the first parameter to + -- discriminate which generic to use, and the rest of the parameters. + -- + hoistable = true + } }, + { GetTagFromTaggedUnion = { + -- Translate a tagged-union value to its corresponding tag in the tagged-union's set. + -- Input's type: CollectionTaggedUnionType(typeCollection, tableCollection) + -- Output's type: CollectionTagType(tableCollection) + operands = {{"taggedUnionValue"}} + } }, + { GetTypeTagFromTaggedUnion = { + -- Translate a tagged-union value to its corresponding type tag in the tagged-union's set. + -- Input's type: CollectionTaggedUnionType(typeCollection, tableCollection) + -- Output's type: CollectionTagType(typeCollection) + operands = {{"taggedUnionValue"}} + } }, + { GetValueFromTaggedUnion = { + -- Translate a tagged-union value to its corresponding value in the tagged-union's set. + -- Input's type: CollectionTaggedUnionType(typeCollection, tableCollection) + -- Output's type: ValueOfCollectionType(typeCollection) + operands = {{"taggedUnionValue"}} + } }, + { MakeTaggedUnion = { + -- Create a tagged-union value from a tag and a value. + -- Input's type: CollectionTagType(tableCollection), ValueOfCollectionType(typeCollection) + -- Output's type: CollectionTaggedUnionType(typeCollection, tableCollection) + operands = { { "tag" }, { "value" } }, + } }, + { GetTagOfElementInCollection = { + -- Get the tag corresponding to an element in a collection. + hoistable = true + } }, } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp new file mode 100644 index 00000000000..e69de29bb2d diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 7b48c6c2c77..f664157c9a8 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -3,6 +3,7 @@ #include "slang-ir-any-value-marshalling.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-insts.h" +#include "slang-ir-specialize.h" #include "slang-ir-typeflow-collection.h" #include "slang-ir-util.h" #include "slang-ir-witness-table-wrapper.h" @@ -32,12 +33,10 @@ IRAnyValueType* createAnyValueType(IRBuilder* builder, const HashSet& t // The resulting function will have one additional parameter to accept the tag // indicating which function to call. // -IRFunc* createDispatchFunc(IRFuncCollection* collection) +IRFunc* createDispatchFunc(IRFuncType* dispatchFuncType, Dictionary& mapping) { - IRFuncType* dispatchFuncType = cast(collection->getFullType()); - // Create a dispatch function with switch-case for each function - IRBuilder builder(collection->getModule()); + IRBuilder builder(dispatchFuncType->getModule()); // Consume the first parameter of the expected function type List innerParamTypes; @@ -85,42 +84,40 @@ IRFunc* createDispatchFunc(IRFuncCollection* collection) List caseValues; List caseBlocks; - UIndex funcSeqID = 0; - forEachInCollection( - collection, - [&](IRInst* funcInst) - { - auto funcId = funcSeqID++; - auto wrapperFunc = - emitWitnessTableWrapper(funcInst->getModule(), funcInst, innerFuncType); + for (auto kvPair : mapping) + { + auto funcInst = kvPair.second; + auto funcTag = kvPair.first; - // Create case block - auto caseBlock = builder.emitBlock(); - builder.setInsertInto(caseBlock); + auto wrapperFunc = emitWitnessTableWrapper(funcInst->getModule(), funcInst, innerFuncType); - List callArgs; - auto wrappedFuncType = as(wrapperFunc->getDataType()); - for (Index ii = 0; ii < originalParams.getCount(); ii++) - { - callArgs.add(originalParams[ii]); - } + // Create case block + auto caseBlock = builder.emitBlock(); + builder.setInsertInto(caseBlock); - // Call the specific function - auto callResult = - builder.emitCallInst(wrappedFuncType->getResultType(), wrapperFunc, callArgs); + List callArgs; + auto wrappedFuncType = as(wrapperFunc->getDataType()); + for (Index ii = 0; ii < originalParams.getCount(); ii++) + { + callArgs.add(originalParams[ii]); + } - if (resultType->getOp() == kIROp_VoidType) - { - builder.emitReturn(); - } - else - { - builder.emitReturn(callResult); - } + // Call the specific function + auto callResult = + builder.emitCallInst(wrappedFuncType->getResultType(), wrapperFunc, callArgs); + + if (resultType->getOp() == kIROp_VoidType) + { + builder.emitReturn(); + } + else + { + builder.emitReturn(callResult); + } - caseValues.add(builder.getIntValue(builder.getUIntType(), funcId)); - caseBlocks.add(caseBlock); - }); + caseValues.add(funcTag); + caseBlocks.add(caseBlock); + } // Create flattened case arguments array List flattenedCaseArgs; @@ -233,12 +230,12 @@ struct TagOpsLoweringContext : public InstPassBase // collection. // // e.g. - // let a : TagType(TableCollection(B, C)) = /* ... */; - // let b : TagType(TableCollection(A, B, C)) = GetTagForSuperCollection(a); + // let a : TagType(WitnessTableCollection(B, C)) = /* ... */; + // let b : TagType(WitnessTableCollection(A, B, C)) = GetTagForSuperCollection(a); // becomes - // let a : TagType(TableCollection(B, C)) = /* ... */; + // let a : TagType(WitnessTableCollection(B, C)) = /* ... */; // let lookupArr : ArrayType = [1, 2]; // B is at index 1, C is at index 2 - // let b : TagType(TableCollection(A, B, C)) = ElementExtract(lookupArr, a); + // let b : TagType(WitnessTableCollection(A, B, C)) = ElementExtract(lookupArr, a); // // Note that we leave the tag-types of the output intact since we may need to lower // later tag operations. @@ -299,11 +296,11 @@ struct TagOpsLoweringContext : public InstPassBase // collection. // // e.g. - // let a : TagType(TableCollection(B, C)) = /* ... */; + // let a : TagType(WitnessTableCollection(B, C)) = /* ... */; // let b : TagType(FuncCollection(C_key, B_key)) = // GetTagForMappedCollection(a, key); // becomes - // let a : TagType(TableCollection(B, C)) = /* ... */; + // let a : TagType(WitnessTableCollection(B, C)) = /* ... */; // let lookupArr : ArrayType = [1, 0]; // B is at index 1, C is at index 0 // let b : TagType(FuncCollection(C_key, B_key)) = ElementExtract(lookupArr, a); // @@ -311,7 +308,7 @@ struct TagOpsLoweringContext : public InstPassBase // later tag operations. // - auto srcCollection = cast( + auto srcCollection = cast( cast(inst->getOperand(0)->getDataType())->getOperand(0)); auto destCollection = cast(cast(inst->getDataType())->getOperand(0)); @@ -431,6 +428,30 @@ struct TagOpsLoweringContext : public InstPassBase inst->removeAndDeallocate(); } + void lowerGetTagOfElementInCollection(IRGetTagOfElementInCollection* inst) + { + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + + // find the index of the element in the collection + auto collection = cast(inst->getOperand(1)); + auto element = inst->getOperand(0); + UInt foundIndex = UInt(-1); + for (UInt i = 0; i < collection->getCount(); i++) + { + if (collection->getElement(i) == element) + { + foundIndex = i; + break; + } + } + + SLANG_ASSERT(foundIndex != UInt(-1)); + auto resultValue = builder.getIntValue(inst->getDataType(), foundIndex); + inst->replaceUsesWith(resultValue); + inst->removeAndDeallocate(); + } + void processInst(IRInst* inst) { switch (inst->getOp()) @@ -444,31 +465,138 @@ struct TagOpsLoweringContext : public InstPassBase case kIROp_GetTagForSpecializedCollection: lowerGetTagForSpecializedCollection(as(inst)); break; + case kIROp_GetTagOfElementInCollection: + lowerGetTagOfElementInCollection(as(inst)); + break; default: break; } } - void lowerFuncCollection(IRFuncCollection* collection) + + void processModule() + { + processAllInsts([&](IRInst* inst) { return processInst(inst); }); + } +}; + +struct DispatcherLoweringContext : public InstPassBase +{ + DispatcherLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + void lowerGetSpecializedDispatcher(IRGetSpecializedDispatcher* dispatcher) + { + // Replace the `IRGetSpecializedDispatcher` with a dispatch function, + // which takes an extra first parameter for the tag (i.e. ID) + // + // We'll also replace the callee in all 'call' insts. + // + + auto witnessTableCollection = cast(dispatcher->getOperand(0)); + auto key = cast(dispatcher->getOperand(1)); + + List specArgs; + for (UIndex i = 2; i < dispatcher->getOperandCount(); i++) + { + specArgs.add(dispatcher->getOperand(i)); + } + + Dictionary elements; + UInt index = 0; + IRBuilder builder(dispatcher->getModule()); + forEachInCollection( + witnessTableCollection, + [&](IRInst* table) + { + auto generic = + cast(findWitnessTableEntry(cast(table), key)); + + auto specializedFuncType = + (IRType*)specializeGeneric(cast(builder.emitSpecializeInst( + builder.getTypeKind(), + generic->getDataType(), + specArgs.getCount(), + specArgs.getBuffer()))); + + auto specializedFunc = builder.emitSpecializeInst( + specializedFuncType, + generic, + specArgs.getCount(), + specArgs.getBuffer()); + + auto singletonTag = builder.emitGetTagOfElementInCollection( + builder.getCollectionTagType(witnessTableCollection), + table, + witnessTableCollection); + + elements.add(singletonTag, specializedFunc); + }); + + if (dispatcher->hasUses() && dispatcher->getDataType() != nullptr) + { + auto dispatchFunc = + createDispatchFunc(cast(dispatcher->getDataType()), elements); + traverseUses( + dispatcher, + [&](IRUse* use) + { + if (auto callInst = as(use->getUser())) + { + // Replace callee with the generated dispatchFunc. + if (callInst->getCallee() == dispatcher) + { + IRBuilder callBuilder(callInst); + callBuilder.setInsertBefore(callInst); + callBuilder.replaceOperand(callInst->getCalleeUse(), dispatchFunc); + } + } + }); + } + } + + void lowerGetDispatcher(IRGetDispatcher* dispatcher) { - // Replace the `IRFuncCollection` with a dispatch function, + // Replace the `IRGetDispatcher` with a dispatch function, // which takes an extra first parameter for the tag (i.e. ID) // // We'll also replace the callee in all 'call' insts. // - IRBuilder builder(collection->getModule()); - if (collection->hasUses() && collection->getDataType() != nullptr) + auto witnessTableCollection = cast(dispatcher->getOperand(0)); + auto key = cast(dispatcher->getOperand(1)); + + IRBuilder builder(dispatcher->getModule()); + + Dictionary elements; + UInt index = 0; + forEachInCollection( + witnessTableCollection, + [&](IRInst* table) + { + auto tag = builder.emitGetTagOfElementInCollection( + builder.getCollectionTagType(witnessTableCollection), + table, + witnessTableCollection); + elements.add( + tag, + cast(findWitnessTableEntry(cast(table), key))); + }); + + if (dispatcher->hasUses() && dispatcher->getDataType() != nullptr) { - auto dispatchFunc = createDispatchFunc(collection); + auto dispatchFunc = + createDispatchFunc(cast(dispatcher->getDataType()), elements); traverseUses( - collection, + dispatcher, [&](IRUse* use) { if (auto callInst = as(use->getUser())) { - // If the call is a collection call, replace it with the dispatch function - if (callInst->getCallee() == collection) + // Replace callee with the generated dispatchFunc. + if (callInst->getCallee() == dispatcher) { IRBuilder callBuilder(callInst); callBuilder.setInsertBefore(callInst); @@ -481,14 +609,24 @@ struct TagOpsLoweringContext : public InstPassBase void processModule() { - processInstsOfType( - kIROp_FuncCollection, - [&](IRFuncCollection* inst) { return lowerFuncCollection(inst); }); + processInstsOfType( + kIROp_GetDispatcher, + [&](IRGetDispatcher* inst) { return lowerGetDispatcher(inst); }); - processAllInsts([&](IRInst* inst) { return processInst(inst); }); + processInstsOfType( + kIROp_GetSpecializedDispatcher, + [&](IRGetSpecializedDispatcher* inst) { return lowerGetSpecializedDispatcher(inst); }); } }; +bool lowerDispatchers(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + DispatcherLoweringContext context(module); + context.processModule(); + return true; +} + // This context lowers `TypeCollection` instructions. struct CollectionLoweringContext : public InstPassBase { @@ -497,35 +635,37 @@ struct CollectionLoweringContext : public InstPassBase { } - void lowerTypeCollection(IRTypeCollection* collection) + void lowerValueOfCollectionType(IRValueOfCollectionType* valueOfCollectionType) { // Type collections are replaced with `AnyValueType` large enough to hold // any of the types in the collection. // HashSet types; - for (UInt i = 0; i < collection->getCount(); i++) + for (UInt i = 0; i < valueOfCollectionType->getCollection()->getCount(); i++) { - if (auto type = as(collection->getElement(i))) + if (auto type = as(valueOfCollectionType->getCollection()->getElement(i))) { types.add(type); } } - IRBuilder builder(collection->getModule()); + IRBuilder builder(module); auto anyValueType = createAnyValueType(&builder, types); - collection->replaceUsesWith(anyValueType); + valueOfCollectionType->replaceUsesWith(anyValueType); } void processModule() { - processInstsOfType( - kIROp_TypeCollection, - [&](IRTypeCollection* inst) { return lowerTypeCollection(inst); }); + processInstsOfType( + kIROp_ValueOfCollectionType, + [&](IRValueOfCollectionType* inst) { return lowerValueOfCollectionType(inst); }); } }; -// Lower `TypeCollection` instructions +// Lower `ValueOfCollectionType(TypeCollection(...))` instructions by replacing them with +// appropriate `AnyValueType` instructions. +// void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -860,18 +1000,79 @@ struct TaggedUnionLoweringContext : public InstPassBase IRBuilder builder(module); builder.setInsertInto(module); - auto typeCollection = cast(taggedUnion->getOperand(0)); - auto tableCollection = cast(taggedUnion->getOperand(1)); + auto typeCollection = builder.getValueOfCollectionType(taggedUnion->getTypeCollection()); + auto tableCollection = taggedUnion->getWitnessTableCollection(); - if (getCollectionCount(typeCollection) == 1) + if (taggedUnion->getTypeCollection()->isSingleton()) return builder.getTupleType(List( {(IRType*)makeTagType(tableCollection), - (IRType*)getCollectionElement(typeCollection, 0)})); + (IRType*)taggedUnion->getTypeCollection()->getElement(0)})); return builder.getTupleType( List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); } + bool lowerGetValueFromTaggedUnion(IRGetValueFromTaggedUnion* inst) + { + // We replace `GetValueFromTaggedUnion(taggedUnionVal)` with + // `GetTupleElement(taggedUnionVal, 1)` + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + auto tupleVal = inst->getOperand(0); + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)as(tupleVal->getDataType())->getOperand(1), + tupleVal, + 1)); + inst->removeAndDeallocate(); + return true; + } + + bool lowerGetTagFromTaggedUnion(IRGetTagFromTaggedUnion* inst) + { + // We replace `GetTagFromTaggedUnion(taggedUnionVal)` with + // `GetTupleElement(taggedUnionVal, 0)` + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + auto tupleVal = inst->getOperand(0); + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)as(tupleVal->getDataType())->getOperand(0), + tupleVal, + 0)); + inst->removeAndDeallocate(); + return true; + } + + bool lowerGetTypeTagFromTaggedUnion(IRGetTypeTagFromTaggedUnion* inst) + { + // We don't use type tags anywhere, so this instruction should have no + // uses. + // + SLANG_ASSERT(inst->hasUses() == false); + inst->removeAndDeallocate(); + return true; + } + + bool lowerMakeTaggedUnion(IRMakeTaggedUnion* inst) + { + // We replace `MakeTaggedUnion(tag, val)` with `MakeTuple(tag, val)` + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + auto tag = inst->getOperand(0); + auto val = inst->getOperand(1); + inst->replaceUsesWith(builder.emitMakeTuple((IRType*)inst->getDataType(), {tag, val})); + inst->removeAndDeallocate(); + return true; + } + bool processModule() { // First, we'll lower all CollectionTaggedUnionType insts @@ -885,6 +1086,27 @@ struct TaggedUnionLoweringContext : public InstPassBase inst->removeAndDeallocate(); }); + // TODO: Is this repeated scanning of the module inefficient? + // It feels like this form could be very efficient if it's automatically + // 'fused' together. + // + processInstsOfType( + kIROp_GetTagFromTaggedUnion, + [&](IRGetTagFromTaggedUnion* inst) { return lowerGetTagFromTaggedUnion(inst); }); + + processInstsOfType( + kIROp_GetTypeTagFromTaggedUnion, + [&](IRGetTypeTagFromTaggedUnion* inst) + { return lowerGetTypeTagFromTaggedUnion(inst); }); + + processInstsOfType( + kIROp_GetValueFromTaggedUnion, + [&](IRGetValueFromTaggedUnion* inst) { return lowerGetValueFromTaggedUnion(inst); }); + + processInstsOfType( + kIROp_MakeTaggedUnion, + [&](IRMakeTaggedUnion* inst) { return lowerMakeTaggedUnion(inst); }); + // Then, convert any loads/stores from reinterpreted pointers. bool hasCastInsts = false; processInstsOfType( diff --git a/source/slang/slang-ir-lower-typeflow-insts.h b/source/slang/slang-ir-lower-typeflow-insts.h index 82f1cf5fc48..573587a9360 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.h +++ b/source/slang/slang-ir-lower-typeflow-insts.h @@ -5,10 +5,18 @@ namespace Slang { -// Lower `TypeCollection` instructions +// Lower `ValueOfCollectionType` types. void lowerTypeCollections(IRModule* module, DiagnosticSink* sink); -// Lower `FuncCollection`, `GetTagForSuperCollection`, `GetTagForMappedCollection` and +// Lower `CollectionTaggedUnion` and `CastInterfaceToTaggedUnionPtr` instructions +// May create new `Reinterpret` instructions. +// +bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink); + +// Lower `CollectionTagType` types +void lowerTagTypes(IRModule* module); + +// Lower `GetTagForSuperCollection`, `GetTagForMappedCollection` and // `GetTagForSpecializedCollection` instructions // void lowerTagInsts(IRModule* module, DiagnosticSink* sink); @@ -16,12 +24,7 @@ void lowerTagInsts(IRModule* module, DiagnosticSink* sink); // Lower `GetTagFromSequentialID` and `GetSequentialIDFromTag` instructions void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink); -// Lower `CollectionTagType` instructions -void lowerTagTypes(IRModule* module); - -// Lower `CollectionTaggedUnion`and `CastInterfaceToTaggedUnionPtr` instructions -// May create new `Reinterpret` instructions. -// -bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink); +// Lower `GetDispatcher` and `GetSpecializedDispatcher` instructions +bool lowerDispatchers(IRModule* module, DiagnosticSink* sink); } // namespace Slang diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index 5bbc62bed5a..8d9c58b1b6c 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -148,6 +148,12 @@ struct AssociatedTypeLookupSpecializationContext cast(witnessTableType)->getConformanceType()); if (!interfaceType) return; + + List tables = + sharedContext->getWitnessTablesFromInterfaceType(interfaceType); + if (tables.getCount() == 0) + return; + auto key = inst->getRequirementKey(); IRFunc* func = nullptr; if (!sharedContext->mapInterfaceRequirementKeyToDispatchMethods.tryGetValue(key, func)) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 8b448b59ccd..26d05cdf1ae 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -5,6 +5,7 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-insts.h" +#include "slang-ir-lower-typeflow-insts.h" #include "slang-ir-lower-witness-lookup.h" #include "slang-ir-peephole.h" #include "slang-ir-sccp.h" @@ -863,7 +864,7 @@ struct SpecializationContext IRInterfaceType* interfaceType = nullptr; if (!witnessTable) { - if (auto collection = as(lookupInst->getWitnessTable())) + if (auto collection = as(lookupInst->getWitnessTable())) { auto requirementKey = lookupInst->getRequirementKey(); @@ -891,8 +892,24 @@ struct SpecializationContext CollectionBuilder cBuilder(lookupInst->getModule()); auto newCollection = cBuilder.makeSet(satisfyingValSet); addUsersToWorkList(lookupInst); - lookupInst->replaceUsesWith(newCollection); - lookupInst->removeAndDeallocate(); + if (as(newCollection)) + { + IRBuilder builder(module); + lookupInst->replaceUsesWith( + builder.getValueOfCollectionType(newCollection)); + lookupInst->removeAndDeallocate(); + } + else if (as(newCollection)) + { + lookupInst->replaceUsesWith(newCollection); + lookupInst->removeAndDeallocate(); + } + else + { + // Should not see any other case. + SLANG_UNREACHABLE("unexpected collection type"); + } + return true; } else @@ -1115,7 +1132,7 @@ struct SpecializationContext IRSpecializationDictionaryItem* item = nullptr; if (dict.tryGetValue(key, item)) { - if (as(item->getOperand(0))) + if (!as(item->getOperand(0))) return item->getOperand(0); else { @@ -1260,6 +1277,7 @@ struct SpecializationContext if (iterChanged) { eliminateDeadCode(module->getModuleInst()); + lowerDispatchers(module, sink); } } @@ -3173,7 +3191,7 @@ void finalizeSpecialization(IRModule* module) // The resulting function will therefore have additional parameters at the beginning // to accept this information. // -static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) +IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) { // The high-level logic for specializing a generic to operate over collections // is similar to specializing a simple generic: @@ -3192,7 +3210,7 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) // // - Add any dynamic parameters of the generic to the function's first block. Keep track of the // first block for later. For now, we only treat `WitnessTableType` parameters that have - // `TableCollection` arguments (with atleast 2 distinct elements) as dynamic. Each such + // `WitnessTableCollection` arguments (with atleast 2 distinct elements) as dynamic. Each such // parameter will get a corresponding parameter of `TagType(tableCollection)` // // - Clone in the rest of the generic's body into the first block of the function. @@ -3234,6 +3252,7 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) Index argIndex = 0; List extraParamTypes; + OrderedDictionary extraParamMap; // Map the generic's parameters to the specialized arguments. for (auto param : generic->getFirstBlock()->getParams()) { @@ -3243,7 +3262,8 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) // We're dealing with a set of types. if (as(param->getDataType())) { - cloneEnv.mapOldValToNew[param] = collection; + SLANG_ASSERT("Should not happen"); + cloneEnv.mapOldValToNew[param] = builder.getValueOfCollectionType(collection); } else if (as(param->getDataType())) { @@ -3256,7 +3276,8 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) // the insts which will may need the runtime tag. // auto tagType = (IRType*)makeTagType(collection); - cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); + // cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); + extraParamMap.add(param, builder.emitParam(tagType)); extraParamTypes.add(tagType); } } @@ -3268,6 +3289,21 @@ static IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) } } + // The parameters we've used so far are merely tags, we can't use them directly + // without turning them into elements. + // + // We'll emit a `GetElementFromTag` on the parameter and map that to the original generic + // parameter. + // + for (auto paramPair : extraParamMap) + { + auto originalParam = paramPair.key; + auto newTagParam = paramPair.value; + + auto getElementInst = builder.emitGetElementFromTag(newTagParam); + cloneEnv.mapOldValToNew[originalParam] = getElementInst; + } + // Clone in the rest of the generic's body including the blocks of the returned func. for (auto inst = generic->getFirstBlock()->getFirstOrdinaryInst(); inst; inst = inst->getNextInst()) diff --git a/source/slang/slang-ir-specialize.h b/source/slang/slang-ir-specialize.h index 3d657e4875f..a82f53e8ae7 100644 --- a/source/slang/slang-ir-specialize.h +++ b/source/slang/slang-ir-specialize.h @@ -28,4 +28,9 @@ void finalizeSpecialization(IRModule* module); IRInst* specializeGeneric(IRSpecialize* specInst); +// Specialize a generic with one or more arguments that are collections rather +// than single concrete values. +// +IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst); + } // namespace Slang diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-collection.cpp index b6a7441d068..dff7eb45a9f 100644 --- a/source/slang/slang-ir-typeflow-collection.cpp +++ b/source/slang/slang-ir-typeflow-collection.cpp @@ -1,6 +1,7 @@ #include "slang-ir-typeflow-collection.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" #include "slang-ir.h" namespace Slang @@ -68,8 +69,8 @@ UInt CollectionBuilder::getUniqueID(IRInst* inst) IRCollectionBase* CollectionBuilder::createCollection(IROp op, const HashSet& elements) { SLANG_ASSERT( - op == kIROp_TypeCollection || op == kIROp_FuncCollection || op == kIROp_TableCollection || - op == kIROp_GenericCollection); + op == kIROp_TypeCollection || op == kIROp_FuncCollection || + op == kIROp_WitnessTableCollection || op == kIROp_GenericCollection); if (elements.getCount() == 0) return nullptr; @@ -110,7 +111,7 @@ IROp CollectionBuilder::getCollectionTypeForInst(IRInst* inst) else if (as(inst) && !as(inst)) return kIROp_TypeCollection; else if (as(inst->getDataType())) - return kIROp_TableCollection; + return kIROp_WitnessTableCollection; else return kIROp_Invalid; // Return invalid IROp when not supported } @@ -129,4 +130,120 @@ IRCollectionBase* CollectionBuilder::makeSet(const HashSet& values) return createCollection(getCollectionTypeForInst(*values.begin()), values); } +// Upcast the value in 'arg' to match the destInfo type. This method inserts +// any necessary reinterprets or tag translation instructions. +// +IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) +{ + // The upcasting process inserts the appropriate instructions + // to make arg's type match the type provided by destInfo. + // + // This process depends on the structure of arg and destInfo. + // + // We only deal with the type-flow data-types that are created in + // our pass (CollectionBase/CollectionTaggedUnionType/CollectionTagType/any other + // composites of these insts) + // + + auto argInfo = arg->getDataType(); + if (!argInfo || !destInfo) + return arg; + + if (as(argInfo) && as(destInfo)) + { + // A collection tagged union is essentially a tuple(TagType(tableCollection), + // typeCollection) We simply extract the two components, upcast each one, and put it + // back together. + // + + auto argTUType = as(argInfo); + auto destTUType = as(destInfo); + + if (getCollectionCount(argTUType) != getCollectionCount(destTUType)) + { + // Technically, IRCollectionTaggedUnionType is not a TupleType, + // but in practice it works the same way so we'll re-use Slang's + // tuple accessors & constructors + // + // IRBuilder builder(module); + // setInsertAfterOrdinaryInst(&builder, arg); + auto argTableTag = builder->emitGetTagFromTaggedUnion(arg); + auto reinterpretedTag = upcastCollection( + builder, + argTableTag, + makeTagType(destTUType->getWitnessTableCollection())); + + auto argVal = builder->emitGetValueFromTaggedUnion(arg); + auto reinterpretedVal = upcastCollection( + builder, + argVal, + builder->getValueOfCollectionType(destTUType->getTypeCollection())); + return builder->emitMakeTaggedUnion(destTUType, reinterpretedTag, reinterpretedVal); + } + } + else if (as(argInfo) && as(destInfo)) + { + // If the arg represents a tag of a colleciton, but the dest is a _different_ + // collection, then we need to emit a tag operation to reinterpret the + // tag. + // + // Note that, by the invariant provided by the typeflow analysis, the target + // collection must necessarily be a super-set. + // + if (getCollectionCount(as(argInfo)) != + getCollectionCount(as(destInfo))) + { + return builder + ->emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); + } + } + else if (as(argInfo) && as(destInfo)) + { + // If the arg has a collection type, but the dest is a _different_ collection, + // we need to perform a reinterpret. + // + // e.g. TypeCollection({T1, T2}) may lower to AnyValueType(N), while + // TypeCollection({T1, T2, T3}) may lower to AnyValueType(M). Since the target + // is necessarily a super-set, the target any-value-type is always larger (M >= N), + // so we only need a simple reinterpret. + // + if (getCollectionCount(as(argInfo)->getCollection()) != + getCollectionCount(as(destInfo)->getCollection())) + { + auto argCollection = as(argInfo)->getCollection(); + if (argCollection->isSingleton() && as(argCollection->getElement(0))) + { + // There's a specific case where we're trying to reinterpret a value of 'void' + // type. We'll avoid emitting a reinterpret in this case, and emit a + // default-construct instead. + // + // IRBuilder builder(module); + // setInsertAfterOrdinaryInst(&builder, arg); + return builder->emitDefaultConstruct((IRType*)destInfo); + } + + // General case: + // + // If the sets of witness tables are not equal, reinterpret to the + // parameter type + // + // IRBuilder builder(module); + // setInsertAfterOrdinaryInst(&builder, arg); + return builder->emitReinterpret((IRType*)destInfo, arg); + } + } + else if (!as(argInfo) && as(destInfo)) + { + // If the arg is not a collection-type, but the dest is a collection, + // we need to perform a pack operation. + // + // This case only arises when passing a value of type T to a parameter + // of a type-collection that contains T. + // + return builder->emitPackAnyValue((IRType*)destInfo, arg); + } + + return arg; // Can use as-is. +} + } // namespace Slang diff --git a/source/slang/slang-ir-typeflow-collection.h b/source/slang/slang-ir-typeflow-collection.h index 27283ce8288..110be0bff7d 100644 --- a/source/slang/slang-ir-typeflow-collection.h +++ b/source/slang/slang-ir-typeflow-collection.h @@ -36,6 +36,11 @@ void forEachInCollection(IRCollectionTagType* tagType, F func) forEachInCollection(as(tagType->getCollection()), func); } +// Upcast the value in 'arg' to match the destInfo type. This method inserts +// any necessary reinterprets or tag translation instructions. +// +IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo); + // Builder class that helps greatly with constructing `CollectionBase` instructions, // which conceptually represent sets, and maintain the property that the equal sets // should always be represented by the same instruction. diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 3e91eaaeac0..3a5443b2724 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -442,7 +442,7 @@ struct TypeFlowSpecializationContext // This type can be used for insts that are semantically a tuple of a tag (to select a table) // and a payload to contain the existential value. // - IRCollectionTaggedUnionType* makeExistential(IRTableCollection* tableCollection) + IRCollectionTaggedUnionType* makeTaggedUnionType(IRWitnessTableCollection* tableCollection) { HashSet typeSet; @@ -459,7 +459,7 @@ struct TypeFlowSpecializationContext // Create the tagged union type out of the type and table collection. IRBuilder builder(module); - List elements = {typeCollection, tableCollection}; + List elements = {tableCollection, typeCollection}; return as(builder.emitIntrinsicInst( nullptr, kIROp_CollectionTaggedUnionType, @@ -494,6 +494,22 @@ struct TypeFlowSpecializationContext // IRTypeFlowData* none() { return nullptr; } + IRValueOfCollectionType* makeValueOfCollectionType(IRTypeCollection* typeCollection) + { + IRBuilder builder(module); + IRInst* operand = typeCollection; + return cast( + builder.emitIntrinsicInst(nullptr, kIROp_ValueOfCollectionType, 1, &operand)); + } + + IRElementOfCollectionType* makeElementOfCollectionType(IRCollectionBase* collection) + { + IRBuilder builder(module); + IRInst* operand = collection; + return cast( + builder.emitIntrinsicInst(nullptr, kIROp_ElementOfCollectionType, 1, &operand)); + } + IRInst* _tryGetInfo(InstWithContext element) { auto found = propagationMap.tryGetValue(element); @@ -502,18 +518,89 @@ struct TypeFlowSpecializationContext return none(); // Default info for any inst that we haven't registered. } + bool isConcreteType(IRInst* inst) + { + if (!isGlobalInst(inst) || as(inst) || + as(inst) && as(inst)) + return false; + + if (as(inst)) + { + auto ptrType = as(inst); + return isConcreteType(ptrType->getValueType()); + } + + if (as(inst)) + { + auto arrayType = as(inst); + return isConcreteType(arrayType->getElementType()) && + isGlobalInst(arrayType->getElementCount()); + } + + if (as(inst)) + { + auto optionalType = as(inst); + return isConcreteType(optionalType->getValueType()); + } + + return true; + } + + IRInst* tryGetArgInfo(IRInst* context, IRInst* inst) + { + if (auto info = tryGetInfo(context, inst)) + return info; + + + IRBuilder builder(module); + if (auto ptrType = as(inst->getDataType())) + { + if (isConcreteType(ptrType->getValueType())) + return builder.getPtrTypeWithAddressSpace( + builder.getValueOfCollectionType( + cast(cBuilder.makeSingletonSet(ptrType->getValueType()))), + ptrType); + else + return none(); + } + + if (auto arrayType = as(inst->getDataType())) + { + if (isConcreteType(arrayType)) + { + return builder.getArrayType( + builder.getValueOfCollectionType(cast( + cBuilder.makeSingletonSet(arrayType->getElementType()))), + arrayType->getElementCount()); + } + else + return none(); + } + + if (isConcreteType(inst->getDataType())) + return builder.getValueOfCollectionType( + cast(cBuilder.makeSingletonSet(inst->getDataType()))); + else + return none(); + } + // // Bottleneck method to fetch the current propagation info // for a given instruction under context. // IRInst* tryGetInfo(IRInst* context, IRInst* inst) { - if (auto typeFlowData = as(inst->getDataType())) + if (inst->getDataType()) { - // If the instruction already has a stablilized type flow data, - // return it directly. - // - return typeFlowData; + switch (inst->getDataType()->getOp()) + { + case kIROp_CollectionTaggedUnionType: + case kIROp_ValueOfCollectionType: + case kIROp_ElementOfCollectionType: + // These insts directly represent type-flow information, + // so we return them directly. + return inst->getDataType(); + } } // A small check for de-allocated insts. @@ -527,6 +614,7 @@ struct TypeFlowSpecializationContext // entity, we do this on demand rather than trying to put it in the // propagation map. // + /* if (as(inst->getParent())) { if (as(inst) || as(inst) || as(inst) || @@ -541,11 +629,16 @@ struct TypeFlowSpecializationContext if (as(inst) && as(getGenericReturnVal(inst))) return none(); - return cBuilder.makeSingletonSet(inst); + return makeElementOfCollectionType(cBuilder.makeSingletonSet(inst)); } else return none(); } + */ + if (as(inst->getParent())) + { + return none(); + } return _tryGetInfo(InstWithContext(context, inst)); } @@ -644,9 +737,9 @@ struct TypeFlowSpecializationContext if (as(info1) && as(info2)) { - return makeExistential(unionCollection( - cast(info1->getOperand(1)), - cast(info2->getOperand(1)))); + return makeTaggedUnionType(unionCollection( + as(info1)->getWitnessTableCollection(), + as(info2)->getWitnessTableCollection())); } if (as(info1) && as(info2)) @@ -656,8 +749,23 @@ struct TypeFlowSpecializationContext cast(info2->getOperand(0)))); } + if (as(info1) && as(info2)) + { + return makeElementOfCollectionType(unionCollection( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); + } + + if (as(info1) && as(info2)) + { + return makeValueOfCollectionType(unionCollection( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); + } + if (as(info1) && as(info2)) { + SLANG_UNEXPECTED("Should not see 'raw' collection types anymore"); return unionCollection( cast(info1), cast(info2)); @@ -763,12 +871,23 @@ struct TypeFlowSpecializationContext if (auto param = as(user)) if (isFuncParam(param)) addContextUsersToWorkQueue(context, workQueue); + + // TODO: Stopgap workaround. + // Add an analyzeFuncType for this.. + if (auto funcType = as(user)) + if (as(funcType->getParent())) + addUsersToWorkQueue(context, funcType, none(), workQueue); } } // Helper method to update function's return info and propagate back to call sites void updateFuncReturnInfo(IRInst* callable, IRInst* returnInfo, WorkQueue& workQueue) { + // Don't update info if the callee has a concrete return type. + auto callableFuncType = cast(callable->getDataType()); + if (isConcreteType(callableFuncType->getResultType())) + return; + auto existingReturnInfo = getFuncReturnInfo(callable); auto newReturnInfo = unionPropagationInfo(existingReturnInfo, returnInfo); @@ -855,7 +974,7 @@ struct TypeFlowSpecializationContext void processInstForPropagation(IRInst* context, IRInst* inst, WorkQueue& workQueue) { - IRInst* info; + IRInst* info = nullptr; switch (inst->getOp()) { @@ -882,6 +1001,9 @@ struct TypeFlowSpecializationContext case kIROp_Call: info = analyzeCall(context, as(inst), workQueue); break; + // case kIROp_Param: + // info = analyzeParam(context, as(inst)); + // break; case kIROp_Specialize: info = analyzeSpecialize(context, as(inst)); break; @@ -911,13 +1033,22 @@ struct TypeFlowSpecializationContext case kIROp_MakeOptionalValue: info = analyzeMakeOptionalValue(context, as(inst)); break; - default: - info = analyzeDefault(context, inst); - break; } - bool takeUnion = !as(inst); - updateInfo(context, inst, info, takeUnion, workQueue); + if (!info && inst->getDataType()) + { + if (auto dataTypeInfo = tryGetInfo(context, inst->getDataType())) + { + if (auto elementOfCollectionType = as(dataTypeInfo)) + { + info = makeValueOfCollectionType( + cast(elementOfCollectionType->getCollection())); + } + } + } + + if (info) + updateInfo(context, inst, info, false, workQueue); } void processBlock(IRInst* context, IRBlock* block, WorkQueue& workQueue) @@ -932,8 +1063,8 @@ struct TypeFlowSpecializationContext if (auto returnInfo = as(block->getTerminator())) { - auto valInfo = returnInfo->getVal(); - updateFuncReturnInfo(context, tryGetInfo(context, valInfo), workQueue); + auto val = returnInfo->getVal(); + updateFuncReturnInfo(context, tryGetArgInfo(context, val), workQueue); } }; @@ -964,7 +1095,7 @@ struct TypeFlowSpecializationContext if (paramIndex < unconditionalBranch->getArgCount()) { auto arg = unconditionalBranch->getArg(paramIndex); - if (auto argInfo = tryGetInfo(context, arg)) + if (auto argInfo = tryGetArgInfo(context, arg)) { // Use centralized update method updateInfo(context, param, argInfo, true, workQueue); @@ -1008,13 +1139,20 @@ struct TypeFlowSpecializationContext // 1. The paramType is a global inst and an interface type // 2. The paramType is a local inst. // all other cases, continue. - if (isGlobalInst(paramType) && !as(paramType)) + // + // This is primarily just an optimization. Without this, + // we'd be storing 'singleton' sets for parameters with + // regular concrete types (i.e. 99% of cases). + // This optimization ignores them and re-derives the info + // from the data-type. + // + if (isConcreteType(paramType)) { argIndex++; continue; } - IRInst* argInfo = tryGetInfo(edge.callerContext, arg); + IRInst* argInfo = tryGetArgInfo(edge.callerContext, arg); switch (paramDirection.kind) { @@ -1023,6 +1161,7 @@ struct TypeFlowSpecializationContext case ParameterDirectionInfo::Kind::BorrowIn: { IRBuilder builder(module); + /* if (!argInfo) { if (isGlobalInst(arg->getDataType()) && @@ -1030,6 +1169,7 @@ struct TypeFlowSpecializationContext as(arg->getDataType())->getValueType())) argInfo = arg->getDataType(); } + */ if (!argInfo) break; @@ -1043,13 +1183,14 @@ struct TypeFlowSpecializationContext } case ParameterDirectionInfo::Kind::In: { - // Use centralized update method + /* if (!argInfo) { if (isGlobalInst(arg->getDataType()) && !as(arg->getDataType())) argInfo = arg->getDataType(); } + */ updateInfo(edge.targetContext, param, argInfo, true, workQueue); break; } @@ -1064,12 +1205,36 @@ struct TypeFlowSpecializationContext } case InterproceduralEdge::Direction::FuncToCall: { - // Propagate return value info from function to call site - auto returnInfo = funcReturnInfo.tryGetValue(targetCallee); - if (returnInfo) + // If the call inst cannot accept anything dynamic, then + // no need to propagate anything to the result of the call inst. + // + // We'll still need to consider out parameters separately. + // + if (!isConcreteType(callInst->getDataType())) { - // Use centralized update method - updateInfo(edge.callerContext, callInst, *returnInfo, true, workQueue); + auto returnInfoPtr = funcReturnInfo.tryGetValue(targetCallee); + auto returnInfo = (returnInfoPtr) ? *returnInfoPtr : nullptr; + if (!returnInfo) + { + // If the targetCallee's return type is concrete, but the + // callInst's return type is not, we should still propagate the + // known concrete type. + // + auto concreteReturnType = + cast(targetCallee->getDataType())->getResultType(); + if (isConcreteType(concreteReturnType)) + { + IRBuilder builder(module); + returnInfo = builder.getValueOfCollectionType(cast( + cBuilder.makeSingletonSet(concreteReturnType))); + } + } + + if (returnInfo) + { + // Use centralized update method + updateInfo(edge.callerContext, callInst, returnInfo, true, workQueue); + } } // Also update infos of any out parameters @@ -1122,8 +1287,8 @@ struct TypeFlowSpecializationContext auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeExistential(as( - cBuilder.createCollection(kIROp_TableCollection, tables))); + return makeTaggedUnionType(as( + cBuilder.createCollection(kIROp_WitnessTableCollection, tables))); else return none(); } @@ -1142,6 +1307,11 @@ struct TypeFlowSpecializationContext if (isComInterfaceType(inst->getDataType())) return makeUnbounded(); + // Concrete case. + if (as(witnessTable)) + return makeTaggedUnionType( + as(cBuilder.makeSingletonSet(witnessTable))); + // Get the witness table info auto witnessTableInfo = tryGetInfo(context, witnessTable); @@ -1151,11 +1321,9 @@ struct TypeFlowSpecializationContext if (as(witnessTableInfo)) return makeUnbounded(); - if (as(witnessTable)) - return makeExistential(as(cBuilder.makeSingletonSet(witnessTable))); - - if (auto collectionTag = as(witnessTableInfo)) - return makeExistential(cast(collectionTag->getCollection())); + if (auto elementOfCollectionType = as(witnessTableInfo)) + return makeTaggedUnionType( + cast(elementOfCollectionType->getCollection())); SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } @@ -1219,8 +1387,8 @@ struct TypeFlowSpecializationContext { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeExistential(as( - cBuilder.createCollection(kIROp_TableCollection, tables))); + return makeTaggedUnionType(as( + cBuilder.createCollection(kIROp_WitnessTableCollection, tables))); else return none(); } @@ -1249,8 +1417,8 @@ struct TypeFlowSpecializationContext { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeExistential(as( - cBuilder.createCollection(kIROp_TableCollection, tables))); + return makeTaggedUnionType(as( + cBuilder.createCollection(kIROp_WitnessTableCollection, tables))); else return none(); } @@ -1432,9 +1600,9 @@ struct TypeFlowSpecializationContext if (isOptionalExistentialType(inst->getDataType())) { IRBuilder builder(module); - auto noneTableSet = cast( - cBuilder.createCollection(kIROp_TableCollection, getNoneWitness())); - return makeExistential(noneTableSet); + auto noneTableSet = cast( + cBuilder.createCollection(kIROp_WitnessTableCollection, getNoneWitness())); + return makeTaggedUnionType(noneTableSet); } return none(); @@ -1508,22 +1676,22 @@ struct TypeFlowSpecializationContext auto witnessTable = inst->getWitnessTable(); auto witnessTableInfo = tryGetInfo(context, witnessTable); - if (!witnessTableInfo) - return none(); - - if (as(witnessTableInfo)) - return makeUnbounded(); - - if (auto tagType = as(witnessTableInfo)) + if (auto elementOfCollectionType = as(witnessTableInfo)) { HashSet results; forEachInCollection( - cast(tagType->getCollection()), + cast(elementOfCollectionType->getCollection()), [&](IRInst* table) { results.add(findWitnessTableEntry(cast(table), key)); }); - return makeTagType(cBuilder.makeSet(results)); + return makeElementOfCollectionType(cBuilder.makeSet(results)); } + if (!witnessTableInfo) + return none(); + + if (as(witnessTableInfo)) + return makeUnbounded(); + SLANG_UNEXPECTED("Unexpected witness table info type in analyzeLookupWitnessMethod"); } @@ -1550,7 +1718,7 @@ struct TypeFlowSpecializationContext return makeUnbounded(); if (auto taggedUnion = as(operandInfo)) - return makeTagType(taggedUnion->getTableCollection()); + return makeElementOfCollectionType(taggedUnion->getWitnessTableCollection()); SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialWitnessTable"); } @@ -1576,7 +1744,7 @@ struct TypeFlowSpecializationContext return makeUnbounded(); if (auto taggedUnion = as(operandInfo)) - return makeTagType(taggedUnion->getTypeCollection()); + return makeElementOfCollectionType(taggedUnion->getTypeCollection()); SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialType"); } @@ -1604,7 +1772,7 @@ struct TypeFlowSpecializationContext return makeUnbounded(); if (auto taggedUnion = as(operandInfo)) - return taggedUnion->getTypeCollection(); + return makeValueOfCollectionType(taggedUnion->getTypeCollection()); return none(); } @@ -1636,9 +1804,6 @@ struct TypeFlowSpecializationContext auto operand = inst->getBase(); auto operandInfo = tryGetInfo(context, operand); - if (!operandInfo) - return none(); - if (as(operandInfo)) return makeUnbounded(); @@ -1648,11 +1813,12 @@ struct TypeFlowSpecializationContext "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); } - if (as(operandInfo) || as(operandInfo)) + // Handle the 'many' or 'one' cases. + if (as(operandInfo) || isGlobalInst(operand)) { // If any of the specialization arguments need a tag (or the generic itself is a tag), // we need the result to also be wrapped in a tag type. - bool needsTag = false; + bool needsElement = false; List specializationArgs; for (UInt i = 0; i < inst->getArgCount(); ++i) @@ -1678,14 +1844,29 @@ struct TypeFlowSpecializationContext SLANG_UNEXPECTED("Unexpected Existential operand in specialization argument."); } - if (auto argCollectionTag = as(argInfo)) + if (auto elementOfCollectionType = as(argInfo)) { - if (argCollectionTag->isSingleton()) - specializationArgs.add(argCollectionTag->getCollection()->getElement(0)); + if (elementOfCollectionType->getCollection()->isSingleton()) + specializationArgs.add( + elementOfCollectionType->getCollection()->getElement(0)); else { - needsTag = true; - specializationArgs.add(argCollectionTag->getCollection()); + needsElement = true; + if (auto typeCollection = + as(elementOfCollectionType->getCollection())) + { + specializationArgs.add(makeValueOfCollectionType(typeCollection)); + } + else if (as( + elementOfCollectionType->getCollection())) + { + specializationArgs.add(elementOfCollectionType->getCollection()); + } + else + { + SLANG_UNEXPECTED( + "Unexpected collection type in specialization argument."); + } } } else @@ -1707,12 +1888,13 @@ struct TypeFlowSpecializationContext { if (auto info = tryGetInfo(context, type)) { - if (auto infoCollectionTag = as(info)) + if (auto elementOfCollectionType = as(info)) { - if (infoCollectionTag->isSingleton()) - return infoCollectionTag->getCollection()->getElement(0); + if (elementOfCollectionType->getCollection()->isSingleton()) + return elementOfCollectionType->getCollection()->getElement(0); else - return infoCollectionTag->getCollection(); + return makeValueOfCollectionType(cast( + elementOfCollectionType->getCollection())); } else return type; @@ -1737,15 +1919,17 @@ struct TypeFlowSpecializationContext // dynamic IRSpecialize. In this situation, we'd want to use the type inst's info to // find the collection-based specialization and create a func-type from it. // - if (auto tag = as(typeInfo)) + if (auto elementOfCollectionType = as(typeInfo)) { - SLANG_ASSERT(tag->isSingleton()); - auto specializeInst = cast(tag->getCollection()->getElement(0)); + SLANG_ASSERT(elementOfCollectionType->getCollection()->isSingleton()); + auto specializeInst = + cast(elementOfCollectionType->getCollection()->getElement(0)); auto specializedFuncType = cast(specializeGeneric(specializeInst)); typeOfSpecialization = specializedFuncType; } else if (auto collection = as(typeInfo)) { + SLANG_UNEXPECTED("shouldn't see this case"); SLANG_ASSERT(collection->isSingleton()); auto specializeInst = cast(collection->getElement(0)); auto specializedFuncType = cast(specializeGeneric(specializeInst)); @@ -1770,36 +1954,42 @@ struct TypeFlowSpecializationContext return none(); } + // Specialize each element in the set + HashSet specializedSet; + IRCollectionBase* collection = nullptr; - if (auto _collection = as(operandInfo)) + if (auto elementOfCollectionType = as(operandInfo)) { - collection = _collection; + collection = elementOfCollectionType->getCollection(); + + forEachInCollection( + collection, + [&](IRInst* arg) + { + // Create a new specialized instruction for each argument + IRBuilder builder(module); + builder.setInsertInto(module); + specializedSet.add(builder.emitSpecializeInst( + typeOfSpecialization, + arg, + specializationArgs)); + }); } - else if (auto collectionTagType = as(operandInfo)) + else { - needsTag = true; - collection = collectionTagType->getCollection(); + // Concrete case.. + IRBuilder builder(module); + builder.setInsertInto(module); + specializedSet.add( + builder.emitSpecializeInst(typeOfSpecialization, operand, specializationArgs)); } - // Specialize each element in the set - HashSet specializedSet; - forEachInCollection( - collection, - [&](IRInst* arg) - { - // Create a new specialized instruction for each argument - IRBuilder builder(module); - builder.setInsertInto(module); - specializedSet.add( - builder.emitSpecializeInst(typeOfSpecialization, arg, specializationArgs)); - }); - - if (needsTag) - return makeTagType(cBuilder.makeSet(specializedSet)); - else - return cBuilder.makeSet(specializedSet); + return makeElementOfCollectionType(cBuilder.makeSet(specializedSet)); } + if (!operandInfo) + return none(); + SLANG_UNEXPECTED("Unhandled PropagationJudgment in analyzeExtractExistentialWitnessTable"); } @@ -1856,14 +2046,19 @@ struct TypeFlowSpecializationContext if (auto collection = as(arg)) { - updateInfo(context, param, makeTagType(collection), true, workQueue); + updateInfo( + context, + param, + makeElementOfCollectionType(collection), + true, + workQueue); } else if (as(arg) || as(arg)) { updateInfo( context, param, - makeTagType(cBuilder.makeSingletonSet(arg)), + makeElementOfCollectionType(cBuilder.makeSingletonSet(arg)), true, workQueue); } @@ -1924,13 +2119,15 @@ struct TypeFlowSpecializationContext // If we have a collection of functions (with or without a dynamic tag), register // each one. // - if (auto collectionTag = as(calleeInfo)) + if (auto elementOfCollectionType = as(calleeInfo)) { - forEachInCollection(collectionTag, [&](IRInst* func) { propagateToCallSite(func); }); + forEachInCollection( + elementOfCollectionType->getCollection(), + [&](IRInst* func) { propagateToCallSite(func); }); } - else if (auto collection = as(calleeInfo)) + else if (isGlobalInst(callee)) { - forEachInCollection(collection, [&](IRInst* func) { propagateToCallSite(func); }); + propagateToCallSite(callee); } if (auto callInfo = tryGetInfo(context, inst)) @@ -1939,6 +2136,23 @@ struct TypeFlowSpecializationContext return none(); } + IRInst* analyzeParam(IRInst* context, IRInst* inst, WorkQueue& workQueue) + { + /* We only need to handle one case, where we calculate the info based on the + // info for the data-type. + // + if (auto info = tryGetInfo(context, inst->getDataType())) + { + if (auto elementOfCollection = as(info)) + { + return makeValueOfCollectionType( + cast(elementOfCollection->getCollection())); + } + } + + return none();*/ + } + // Updates the information for an address. void maybeUpdatePtr(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) { @@ -2122,14 +2336,14 @@ struct TypeFlowSpecializationContext if (as(context)) { for (auto param : as(context)->getParams()) - infos.add(tryGetInfo(context, param)); + infos.add(tryGetArgInfo(context, param)); } else if (auto specialize = as(context)) { auto generic = specialize->getBase(); auto innerFunc = getGenericReturnVal(generic); for (auto param : as(innerFunc)->getParams()) - infos.add(tryGetInfo(context, param)); + infos.add(tryGetArgInfo(context, param)); } else { @@ -2311,12 +2525,14 @@ struct TypeFlowSpecializationContext // \ / // C // - // After specialization, A could pass a value of type TagType(TableCollection{T1, T2}) - // while B passes a value of type TagType(TableCollection{T2, T3}), while the phi - // parameter's type in C has the union type `TagType(TableCollection{T1, T2, T3})` + // After specialization, A could pass a value of type TagType(WitnessTableCollection{T1, + // T2}) while B passes a value of type TagType(WitnessTableCollection{T2, T3}), while the + // phi parameter's type in C has the union type `TagType(WitnessTableCollection{T1, T2, + // T3})` // - // In this case, we use `upcastCollection` to insert a cast from TagType(TableCollection{T1, - // T2}) -> TagType(TableCollection{T1, T2, T3}) before passing the result as a phi argument. + // In this case, we use `upcastCollection` to insert a cast from + // TagType(WitnessTableCollection{T1, T2}) -> TagType(WitnessTableCollection{T1, T2, T3}) + // before passing the result as a phi argument. // // The same logic applies for the return values. The function's caller expects a union type // of all possible return statements, so we cast each return inst if there is a mismatch. @@ -2355,7 +2571,9 @@ struct TypeFlowSpecializationContext if (auto unconditionalBranch = as(terminator)) { auto arg = unconditionalBranch->getArg(paramIndex); - auto newArg = upcastCollection(func, arg, param->getDataType()); + IRBuilder builder(module); + builder.setInsertBefore(unconditionalBranch); + auto newArg = upcastCollection(&builder, arg, param->getDataType()); if (newArg != arg) { @@ -2384,8 +2602,10 @@ struct TypeFlowSpecializationContext { if (auto specializedType = getLoweredType(getFuncReturnInfo(func))) { + IRBuilder builder(module); + builder.setInsertBefore(returnInst); auto newReturnVal = - upcastCollection(func, returnInst->getVal(), specializedType); + upcastCollection(&builder, returnInst->getVal(), specializedType); if (newReturnVal != returnInst->getVal()) { // Replace the return value with the reinterpreted value @@ -2442,6 +2662,23 @@ struct TypeFlowSpecializationContext for (auto structType : structsToProcess) hasChanges |= specializeStructType(structType); + /*while (funcsToProcess.getCount() > 0) + { + for (auto func : funcsToProcess) + hasChanges |= specializeFunc(func); + + funcsToProcess.clear(); + for (auto context : contextsToLower) + { + hasChanges = true; + auto dynGenericFunc = specializeDynamicGeneric(cast(context)); + context->replaceUsesWith(dynGenericFunc); + context->removeAndDeallocate(); + funcsToProcess.add(cast(dynGenericFunc)); + } + + contextsToLower.clear(); + }*/ for (auto func : funcsToProcess) hasChanges |= specializeFunc(func); @@ -2490,30 +2727,27 @@ struct TypeFlowSpecializationContext if (auto taggedUnion = as(info)) { - // If this is a tagged union, we need to create a tuple type - // return getTypeForExistential(taggedUnion); return (IRType*)taggedUnion; } - if (auto collectionTag = as(info)) + if (auto elementOfCollectionType = as(info)) { - // If this is a collection tag, we can return the collection type - return (IRType*)collectionTag; + // Replace element-of-collection types with tag types. + return makeTagType(elementOfCollectionType->getCollection()); } - if (auto collection = as(info)) + if (auto valOfCollectionType = as(info)) { - if (getCollectionCount(collection) == 1) + if (valOfCollectionType->getCollection()->isSingleton()) { // If there's only one type in the collection, return it directly - return (IRType*)getCollectionElement(collection, 0); + return (IRType*)valOfCollectionType->getCollection()->getElement(0); } - // If this is a concrete collection, return it directly - return (IRType*)collection; + return valOfCollectionType; } - if (as(info) || as(info)) + if (as(info) || as(info)) { // Don't specialize these collections.. they should be used through // tag types, or be processed out during specializeing. @@ -2581,6 +2815,8 @@ struct TypeFlowSpecializationContext return specializeGetValueFromBoundInterface( context, as(inst)); + case kIROp_GetElementFromTag: + return specializeGetElementFromTag(context, as(inst)); case kIROp_Load: return specializeLoad(context, inst); case kIROp_Store: @@ -2622,17 +2858,19 @@ struct TypeFlowSpecializationContext if (!info) return false; - auto collectionTagType = as(info); - if (!collectionTagType) + // If we didn't resolve anything for this inst, don't modify it. + auto elementOfCollectionType = as(info); + if (!elementOfCollectionType) return false; IRBuilder builder(inst); builder.setInsertBefore(inst); // If there's a single element, we can do a simple replacement. - if (getCollectionCount(collectionTagType) == 1) + if (elementOfCollectionType->getCollection()->getCount() == 1) { - inst->replaceUsesWith(getCollectionElement(collectionTagType, 0)); + auto element = elementOfCollectionType->getCollection()->getElement(0); + inst->replaceUsesWith(element); inst->removeAndDeallocate(); return true; } @@ -2640,7 +2878,7 @@ struct TypeFlowSpecializationContext // If the collection is a type-collection, we'll still do a direct replacement // effectively dropping the tag information // - if (auto typeCollection = as(collectionTagType->getCollection())) + if (auto typeCollection = as(elementOfCollectionType->getCollection())) { // If this is a type collection, we can replace it with the collection type // We don't currently care about the tag of a type. @@ -2658,21 +2896,26 @@ struct TypeFlowSpecializationContext // auto witnessTableInst = inst->getWitnessTable(); - auto witnessTableInfo = tryGetInfo(context, witnessTableInst); + auto witnessTableInfo = witnessTableInst->getDataType(); - SLANG_ASSERT(as(witnessTableInfo)); - List operands = {witnessTableInst, inst->getRequirementKey()}; + if (auto witnessTableOperandTagType = as(witnessTableInfo)) + { + auto thisInstInfo = cast(tryGetInfo(context, inst)); + if (thisInstInfo->getCollection() != nullptr) + { + List operands = {witnessTableInst, inst->getRequirementKey()}; - auto newInst = builder.emitIntrinsicInst( - (IRType*)info, - kIROp_GetTagForMappedCollection, - operands.getCount(), - operands.getBuffer()); - inst->replaceUsesWith(newInst); + auto newInst = builder.emitIntrinsicInst( + (IRType*)makeTagType(thisInstInfo->getCollection()), + kIROp_GetTagForMappedCollection, + operands.getCount(), + operands.getBuffer()); - // We'll register the info for the newInst so any users of the new inst can use it. - propagationMap[InstWithContext(context, newInst)] = info; - inst->removeAndDeallocate(); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } return false; } @@ -2682,7 +2925,7 @@ struct TypeFlowSpecializationContext IRExtractExistentialWitnessTable* inst) { // If we have a non-trivial info registered, it must of - // CollectionTagType(TableCollection(...)) + // CollectionTagType(WitnessTableCollection(...)) // // Futher, the operand must be an existential (CollectionTaggedUnionType), which is // conceptually lowered to a TupleType(TagType(tableCollection), typeCollection) @@ -2697,22 +2940,25 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); builder.setInsertBefore(inst); - auto collectionTagType = as(info); - if (!collectionTagType) + auto elementOfCollectionType = as(info); + if (!elementOfCollectionType) return false; - if (getCollectionCount(collectionTagType) == 1) + if (elementOfCollectionType->getCollection()->getCount() == 1) { // Found a single possible type. Simple replacement. - inst->replaceUsesWith(getCollectionElement(collectionTagType, 0)); + inst->replaceUsesWith(elementOfCollectionType->getCollection()->getElement(0)); inst->removeAndDeallocate(); return true; } else { - // Replace with GetElement(specializedInst, 0) -> uint + // Replace with GetElement(specializedInst, 0) -> TagType(tableCollection) + // which retreives a 'tag' (i.e. a run-time identifier for one of the elements + // of the collection) + // auto operand = inst->getOperand(0); - auto element = builder.emitGetTupleElement((IRType*)collectionTagType, operand, 0); + auto element = builder.emitGetTagFromTaggedUnion(operand); inst->replaceUsesWith(element); inst->removeAndDeallocate(); return true; @@ -2727,11 +2973,11 @@ struct TypeFlowSpecializationContext auto existentialInfo = existential->getDataType(); if (as(existentialInfo)) { - auto valType = existentialInfo->getOperand(0); IRBuilder builder(inst); builder.setInsertAfter(inst); - auto val = builder.emitGetTupleElement((IRType*)valType, existential, 1); + // auto val = builder.emitGetTupleElement((IRType*)valType, existential, 1); + auto val = builder.emitGetValueFromTaggedUnion(existential); inst->replaceUsesWith(val); inst->removeAndDeallocate(); return true; @@ -2742,7 +2988,7 @@ struct TypeFlowSpecializationContext bool specializeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { - auto info = tryGetInfo(context, inst); + /*auto info = tryGetInfo(context, inst); auto collectionTagType = as(info); if (!collectionTagType) return false; @@ -2762,31 +3008,53 @@ struct TypeFlowSpecializationContext // Replace the instruction with the collection type. inst->replaceUsesWith(collectionTagType->getCollection()); inst->removeAndDeallocate(); - return true; + return true;*/ + + auto info = tryGetInfo(context, inst); + if (auto elementOfCollectionType = as(info)) + { + if (elementOfCollectionType->getCollection()->isSingleton()) + { + // Found a single possible type. Statically known concrete type. + auto singletonValue = elementOfCollectionType->getCollection()->getElement(0); + inst->replaceUsesWith(singletonValue); + inst->removeAndDeallocate(); + return true; + } + else + { + // Multiple elements, emit a tag inst. + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newInst = builder.emitGetTypeTagFromTaggedUnion(inst->getOperand(0)); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } + + return false; } bool isTaggedUnionType(IRInst* type) { - if (auto tupleType = as(type)) - return as(tupleType->getOperand(0)) != nullptr; - - return false; + return as(type) != nullptr; } IRType* updateType(IRType* currentType, IRType* newType) { - if (auto collection = as(currentType)) + if (auto valOfCollectionType = as(currentType)) { HashSet collectionElements; forEachInCollection( - collection, + valOfCollectionType->getCollection(), [&](IRInst* element) { collectionElements.add(element); }); - if (auto newCollection = as(newType)) + if (auto newValOfCollectionType = as(newType)) { // If the new type is also a collection, merge the two collections forEachInCollection( - newCollection, + newValOfCollectionType->getCollection(), [&](IRInst* element) { collectionElements.add(element); }); } else @@ -2796,8 +3064,9 @@ struct TypeFlowSpecializationContext } // If this is a collection, we need to create a new collection with the new type - auto newCollection = cBuilder.createCollection(collection->getOp(), collectionElements); - return (IRType*)newCollection; + auto newCollection = + cBuilder.createCollection(kIROp_TypeCollection, collectionElements); + return makeValueOfCollectionType(cast(newCollection)); } else if (currentType == newType) { @@ -2812,21 +3081,9 @@ struct TypeFlowSpecializationContext as(newType)) { // Merge the elements of both tagged unions into a new tuple type - return (IRType*)makeExistential((as(updateType( - (IRType*)as(currentType)->getTableCollection(), - (IRType*)as(newType)->getTableCollection())))); - } - else if (isTaggedUnionType(currentType) && isTaggedUnionType(newType)) - { - IRBuilder builder(module); - // Merge the elements of both tagged unions into a new tuple type - return builder.getTupleType(List( - {(IRType*)makeTagType(as(updateType( - (IRType*)currentType->getOperand(0)->getOperand(0), - (IRType*)newType->getOperand(0)->getOperand(0)))), - (IRType*)updateType( - (IRType*)currentType->getOperand(1), - (IRType*)newType->getOperand(1))})); + return (IRType*)makeTaggedUnionType((unionCollection( + as(currentType)->getWitnessTableCollection(), + as(newType)->getWitnessTableCollection()))); } else // Need to create a new collection. { @@ -2840,8 +3097,98 @@ struct TypeFlowSpecializationContext // If this is a collection, we need to create a new collection with the new type auto newCollection = cBuilder.createCollection(kIROp_TypeCollection, collectionElements); - return (IRType*)newCollection; + return makeValueOfCollectionType(cast(newCollection)); } + + SLANG_UNEXPECTED("Unhandled case in updateType"); + } + + IRFuncType* getEffectiveFuncTypeForDispatcher( + IRWitnessTableCollection* tableCollection, + IRStructKey* key, + IRFuncCollection* resultFuncCollection) + { + List specArgs; + return getEffectiveFuncTypeForDispatcher( + tableCollection, + key, + resultFuncCollection, + specArgs); + } + + IRFuncType* getEffectiveFuncTypeForDispatcher( + IRWitnessTableCollection* tableCollection, + IRStructKey* key, + IRFuncCollection* resultFuncCollection, + List& specArgs) + { + SLANG_UNUSED(key); + + List paramTypes; + IRType* resultType = nullptr; + + List extraParamTypes; + extraParamTypes.add((IRType*)makeTagType(tableCollection)); + + for (auto specArg : specArgs) + if (as(specArg)) + extraParamTypes.add((IRType*)makeTagType(as(specArg))); + + IRBuilder builder(module); + auto updateParamType = [&](Index index, IRType* paramType) -> IRType* + { + if (paramTypes.getCount() <= index) + { + // If this index hasn't been seen yet, expand the buffer and initialize + // the type. + // + paramTypes.growToCount(index + 1); + paramTypes[index] = paramType; + return paramType; + } + else + { + // Otherwise, update the existing type + auto [currentDirection, currentType] = + splitParameterDirectionAndType(paramTypes[index]); + auto [newDirection, newType] = splitParameterDirectionAndType(paramType); + auto updatedType = updateType(currentType, newType); + SLANG_ASSERT(currentDirection == newDirection); + paramTypes[index] = fromDirectionAndType(&builder, currentDirection, updatedType); + return updatedType; + } + }; + + forEachInCollection( + resultFuncCollection, + [&](IRInst* func) + { + auto paramEffectiveTypes = getEffectiveParamTypes(func); + auto paramDirections = getParamDirections(func); + + for (Index i = 0; i < paramEffectiveTypes.getCount(); i++) + updateParamType(i, getLoweredType(paramEffectiveTypes[i])); + + auto returnType = getFuncReturnInfo(func); + if (auto newResultType = getLoweredType(returnType)) + { + resultType = updateType(resultType, newResultType); + } + else if (auto funcType = as(func->getDataType())) + { + SLANG_ASSERT(isGlobalInst(funcType->getResultType())); + resultType = updateType(resultType, funcType->getResultType()); + } + else + { + SLANG_UNEXPECTED("Cannot determine result type for context"); + } + }); + + List allParamTypes; + allParamTypes.addRange(extraParamTypes); + allParamTypes.addRange(paramTypes); + return builder.getFuncType(allParamTypes, resultType); } // Get an effective func type to use for the callee. @@ -2902,6 +3249,8 @@ struct TypeFlowSpecializationContext } else if (auto collectionTagType = as(callee->getDataType())) { + SLANG_UNEXPECTED( + "Should never try to directly call a tag type. Not semantically meaningful"); forEachInCollection( collectionTagType, [&](IRInst* func) { contextsToProcess.add(func); }); @@ -2951,7 +3300,7 @@ struct TypeFlowSpecializationContext } // If the any of the elements in the callee (or the callee itself in case - // of a singleton) is a dynamic specialization, each non-singleton TableCollection, + // of a singleton) is a dynamic specialization, each non-singleton WitnessTableCollection, // requries a corresponding tag input. // auto calleeToCheck = as(callee) @@ -2962,10 +3311,10 @@ struct TypeFlowSpecializationContext auto specializeInst = as(calleeToCheck); // If this is a dynamic generic, we need to add a tag type for each - // TableCollection in the callee. + // WitnessTableCollection in the callee. // for (UIndex i = 0; i < specializeInst->getArgCount(); i++) - if (auto tableCollection = as(specializeInst->getArg(i))) + if (auto tableCollection = as(specializeInst->getArg(i))) extraParamTypes.add((IRType*)makeTagType(tableCollection)); } @@ -2976,130 +3325,6 @@ struct TypeFlowSpecializationContext return builder.getFuncType(allParamTypes, resultType); } - // Upcast the value in 'arg' to match the destInfo type. This method inserts - // any necessary reinterprets or tag translation instructions. - // - IRInst* upcastCollection(IRInst* context, IRInst* arg, IRType* destInfo) - { - // The upcasting process inserts the appropriate instructions - // to make arg's type match the type provided by destInfo. - // - // This process depends on the structure of arg and destInfo. - // - // We only deal with the type-flow data-types that are created in - // our pass (CollectionBase/CollectionTaggedUnionType/CollectionTagType/any other - // composites of these insts) - // - - auto argInfo = arg->getDataType(); - if (!argInfo || !destInfo) - return arg; - - if (as(argInfo) && as(destInfo)) - { - // A collection tagged union is essentially a tuple(TagType(tableCollection), - // typeCollection) We simply extract the two components, upcast each one, and put it - // back together. - // - - auto argTUType = as(argInfo); - auto destTUType = as(destInfo); - - if (getCollectionCount(argTUType) != getCollectionCount(destTUType)) - { - // Technically, IRCollectionTaggedUnionType is not a TupleType, - // but in practice it works the same way so we'll re-use Slang's - // tuple accessors & constructors - // - IRBuilder builder(arg->getModule()); - setInsertAfterOrdinaryInst(&builder, arg); - auto argTableTag = builder.emitGetTupleElement( - (IRType*)makeTagType(argTUType->getTableCollection()), - arg, - 0); - auto reinterpretedTag = upcastCollection( - context, - argTableTag, - (IRType*)makeTagType(destTUType->getTableCollection())); - - auto argVal = - builder.emitGetTupleElement((IRType*)argTUType->getTypeCollection(), arg, 1); - auto reinterpretedVal = - upcastCollection(context, argVal, (IRType*)destTUType->getTypeCollection()); - return builder.emitMakeTuple( - (IRType*)destTUType, - {reinterpretedTag, reinterpretedVal}); - } - } - else if (as(argInfo) && as(destInfo)) - { - // If the arg represents a tag of a colleciton, but the dest is a _different_ - // collection, then we need to emit a tag operation to reinterpret the - // tag. - // - // Note that, by the invariant provided by the typeflow analysis, the target - // collection must necessarily be a super-set. - // - if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) - { - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - return builder - .emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); - } - } - else if (as(argInfo) && as(destInfo)) - { - // If the arg has a collection type, but the dest is a _different_ collection, - // we need to perform a reinterpret. - // - // e.g. TypeCollection({T1, T2}) may lower to AnyValueType(N), while - // TypeCollection({T1, T2, T3}) may lower to AnyValueType(M). Since the target - // is necessarily a super-set, the target any-value-type is always larger (M >= N), - // so we only need a simple reinterpret. - // - if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) - { - auto argCollection = as(argInfo); - if (argCollection->isSingleton() && as(argCollection->getElement(0))) - { - // There's a specific case where we're trying to reinterpret a value of 'void' - // type. We'll avoid emitting a reinterpret in this case, and emit a - // default-construct instead. - // - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - return builder.emitDefaultConstruct((IRType*)destInfo); - } - - // General case: - // - // If the sets of witness tables are not equal, reinterpret to the - // parameter type - // - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - return builder.emitReinterpret((IRType*)destInfo, arg); - } - } - else if (!as(argInfo) && as(destInfo)) - { - // If the arg is not a collection-type, but the dest is a collection, - // we need to perform a pack operation. - // - // This case only arises when passing a value of type T to a parameter - // of a type-collection that contains T. - // - IRBuilder builder(module); - setInsertAfterOrdinaryInst(&builder, arg); - return builder.emitPackAnyValue((IRType*)destInfo, arg); - } - - return arg; // Can use as-is. - } - // Helper function for specializing calls. // // For a `Specialize` instruction that has dynamic tag arguments, @@ -3116,8 +3341,9 @@ struct TypeFlowSpecializationContext // Pull all tag-type arguments from the specialization arguments // and add them to the call arguments. // - if (as(argInfo)) - callArgs.add(specArg); + if (auto tagType = as(argInfo)) + if (as(tagType->getCollection())) + callArgs.add(specArg); } return callArgs; @@ -3176,16 +3402,16 @@ struct TypeFlowSpecializationContext // replaced with their static collections. // // // --- before specialization --- - // let s1 : TagType(TableCollection(tA, tB, tC)) = /* ... */; + // let s1 : TagType(WitnessTableCollection(tA, tB, tC)) = /* ... */; // let s2 : TagType(TypeCollection(A, B, C)) = /* ... */; // let specCallee = Specialize(generic, s1, s2); // let val = Call(specCallee, /* call args */); // // // --- after specialization --- - // let s1 : TagType(TableCollection(tA, tB, tC)) = /* ... */; + // let s1 : TagType(WitnessTableCollection(tA, tB, tC)) = /* ... */; // let s2 : TagType(TypeCollection(A, B, C)) = /* ... */; // let newSpecCallee = Specialize(generic, - // TableCollection(tA, tB, tC), TypeCollection(A, B, C)); + // WitnessTableCollection(tA, tB, tC), TypeCollection(A, B, C)); // let newVal = Call(newSpecCallee, s1, /* call args */); // // @@ -3209,7 +3435,12 @@ struct TypeFlowSpecializationContext // auto callee = inst->getCallee(); - IRInst* calleeTagInst = nullptr; + if (as(callee) || + as(callee)) // Already specialized + return false; + + // IRInst* calleeTagInst = nullptr; + List callArgs; // This is a bit of a workaround for specialized callee's // whose function types haven't been specialized yet (can @@ -3225,8 +3456,120 @@ struct TypeFlowSpecializationContext if (auto collectionTag = as(callee->getDataType())) { if (!collectionTag->isSingleton()) - calleeTagInst = callee; // Only keep the tag if there are multiple elements. - callee = collectionTag->getCollection(); + { + // Multiple callees case: + // + // If we need to use a tag, we'll do a bit of an optimization here.. + // + // Instead of building a dispatcher on then func-collection, we'll + // build it on the table collection that it is looked up from. This + // avoids the extra map. + // + // This works primarily because this is the only way to call a dynamic + // function. If we ever have the ability to pass functions around more + // flexibly, then this should just become a specific case. + + // TODO: Also handle the case where we need to perform a static specialization. + + if (auto tagMapOperand = as(callee)) + { + auto tableTag = tagMapOperand->getOperand(0); + auto lookupKey = cast(tagMapOperand->getOperand(1)); + + auto tableCollection = cast( + cast(tableTag->getDataType())->getCollection()); + IRBuilder builder(module); + + callee = builder.emitGetDispatcher( + getEffectiveFuncTypeForDispatcher( + tableCollection, + lookupKey, + cast(collectionTag->getCollection())), + tableCollection, + lookupKey); + + callArgs.add(tableTag); + } + else if ( + auto specializedTagMapOperand = as(callee)) + { + auto tagMapOperand = + cast(specializedTagMapOperand->getOperand(0)); + auto tableTag = tagMapOperand->getOperand(0); + auto tableCollection = cast( + cast(tableTag->getDataType())->getCollection()); + auto lookupKey = cast(tagMapOperand->getOperand(1)); + + List specArgs; + for (auto argIdx = 1; argIdx < specializedTagMapOperand->getOperandCount(); + ++argIdx) + { + auto arg = specializedTagMapOperand->getOperand(argIdx); + if (auto tagType = as(arg->getDataType())) + { + SLANG_ASSERT(!tagType->getCollection()->isSingleton()); + if (as(tagType->getCollection())) + { + callArgs.add(arg); + specArgs.add(tagType->getCollection()); + } + else + { + specArgs.add(tagType->getCollection()); + } + } + else + { + SLANG_ASSERT(isGlobalInst(arg)); + specArgs.add(arg); + } + } + + IRBuilder builder(module); + builder.setInsertBefore(callee); + callee = builder.emitGetSpecializedDispatcher( + getEffectiveFuncTypeForDispatcher( + tableCollection, + lookupKey, + cast(collectionTag->getCollection()), + specArgs), + tableCollection, + lookupKey, + specArgs); + + callArgs.add(tableTag); + } + else + { + SLANG_UNEXPECTED( + "Cannot specialize call with non-singleton collection tag callee"); + } + } + else if (isDynamicGeneric(collectionTag->getCollection()->getElement(0))) + { + // Single element which is a dynamic generic specialization. + callArgs.addRange(getArgsForDynamicSpecialization(cast(callee))); + callee = collectionTag->getCollection()->getElement(0); + + auto funcType = getEffectiveFuncType(callee); + callee->setFullType(funcType); + + // contextsToLower.add(callee); + } + else + { + SLANG_ASSERT("Shouldn't get here.."); + // Single element which is a concrete function. + callee = collectionTag->getCollection()->getElement(0); + + auto funcType = getEffectiveFuncType(callee); + callee->setFullType(funcType); + } + } + else + { + auto funcType = getEffectiveFuncType(callee); + callee->setFullType(funcType); } // If by this point, we haven't resolved our callee into a global inst ( @@ -3241,8 +3584,9 @@ struct TypeFlowSpecializationContext if (as(callee)) return false; - auto expectedFuncType = getEffectiveFuncType(callee); + // auto expectedFuncType = getEffectiveFuncType(callee); + /* List newArgs; IRInst* newCallee = nullptr; @@ -3292,7 +3636,46 @@ struct TypeFlowSpecializationContext newCallee = funcCollection; } + */ + + // First, we'll legalize all operands by upcasting if necessary. + // This needs to be done even if the callee is not a collection. + // + UCount extraArgCount = callArgs.getCount(); + for (UInt i = 0; i < inst->getArgCount(); i++) + { + auto arg = inst->getArg(i); + const auto [paramDirection, paramType] = splitParameterDirectionAndType( + cast(callee->getFullType())->getParamType(i + extraArgCount)); + + switch (paramDirection.kind) + { + // We'll upcast any in-parameters. + case ParameterDirectionInfo::Kind::In: + { + IRBuilder builder(context); + builder.setInsertBefore(inst); + callArgs.add(upcastCollection(&builder, arg, paramType)); + break; + } + + // Out parameters are handled at the callee's end + case ParameterDirectionInfo::Kind::Out: + + // For all other modes, collections must match ('subtyping' is not allowed) + case ParameterDirectionInfo::Kind::BorrowInOut: + case ParameterDirectionInfo::Kind::BorrowIn: + case ParameterDirectionInfo::Kind::Ref: + { + callArgs.add(arg); + break; + } + default: + SLANG_UNEXPECTED("Unhandled parameter direction in specializeCall"); + } + } + /* // First, we'll legalize all operands by upcasting if necessary. // This needs to be done even if the callee is not a collection. // @@ -3320,18 +3703,19 @@ struct TypeFlowSpecializationContext SLANG_UNEXPECTED("Unhandled parameter direction in specializeCall"); } } + */ IRBuilder builder(inst); builder.setInsertBefore(inst); bool changed = false; - if (((UInt)newArgs.getCount()) != inst->getArgCount()) + if (((UInt)callArgs.getCount()) != inst->getArgCount()) changed = true; else { - for (Index i = 0; i < newArgs.getCount(); i++) + for (Index i = 0; i < callArgs.getCount(); i++) { - if (newArgs[i] != inst->getArg((UInt)i)) + if (callArgs[i] != inst->getArg((UInt)i)) { changed = true; break; @@ -3339,25 +3723,26 @@ struct TypeFlowSpecializationContext } } - if (newCallee != inst->getCallee()) + if (callee != inst->getCallee()) { changed = true; } + auto calleeFuncType = cast(callee->getFullType()); + if (changed) { - auto newCall = - builder.emitCallInst(expectedFuncType->getResultType(), newCallee, newArgs); + auto newCall = builder.emitCallInst(calleeFuncType->getResultType(), callee, callArgs); inst->replaceUsesWith(newCall); inst->removeAndDeallocate(); return true; } - else if (expectedFuncType->getResultType() != inst->getDataType()) + else if (calleeFuncType->getResultType() != inst->getFullType()) { // If we didn't change the callee or the arguments, we still might // need to update the result type. // - inst->setFullType(expectedFuncType->getResultType()); + inst->setFullType(calleeFuncType->getResultType()); return true; } else @@ -3384,7 +3769,9 @@ struct TypeFlowSpecializationContext for (auto field : structType->getFields()) { auto arg = inst->getOperand(operandIndex); - auto newArg = upcastCollection(context, arg, field->getFieldType()); + IRBuilder builder(context); + builder.setInsertBefore(inst); + auto newArg = upcastCollection(&builder, arg, field->getFieldType()); if (arg != newArg) { @@ -3401,7 +3788,7 @@ struct TypeFlowSpecializationContext bool specializeMakeExistential(IRInst* context, IRMakeExistential* inst) { // After specialization, existentials (that are not unbounded) are treated as tuples - // of a TableCollection tag and a value of type TypeCollection. + // of a WitnessTableCollection tag and a value of type TypeCollection. // // A MakeExistential is just converted into a MakeTuple, with any necessary // upcasts. @@ -3416,29 +3803,33 @@ struct TypeFlowSpecializationContext builder.setInsertBefore(inst); // Collect types from the witness tables to determine the any-value type - auto tableCollection = taggedUnion->getTableCollection(); + auto tableCollection = taggedUnion->getWitnessTableCollection(); auto typeCollection = taggedUnion->getTypeCollection(); - IRInst* witnessTableID = nullptr; + IRInst* witnessTableTag = nullptr; if (auto witnessTable = as(inst->getWitnessTable())) { auto singletonTagType = makeTagType(cBuilder.makeSingletonSet(witnessTable)); - auto zeroValueOfTagType = builder.getIntValue((IRType*)singletonTagType, 0); - witnessTableID = builder.emitIntrinsicInst( + IRInst* tagValue = builder.emitGetTagOfElementInCollection( + (IRType*)singletonTagType, + witnessTable, + tableCollection); + witnessTableTag = builder.emitIntrinsicInst( (IRType*)makeTagType(tableCollection), kIROp_GetTagForSuperCollection, 1, - &zeroValueOfTagType); + &tagValue); } else if (as(inst->getWitnessTable()->getDataType())) { // Dynamic. Use the witness table inst as a tag - witnessTableID = inst->getWitnessTable(); + witnessTableTag = inst->getWitnessTable(); } // Create the appropriate any-value type - auto collectionType = typeCollection->isSingleton() ? (IRType*)typeCollection->getElement(0) - : (IRType*)typeCollection; + auto collectionType = typeCollection->isSingleton() + ? (IRType*)typeCollection->getElement(0) + : builder.getValueOfCollectionType((IRType*)typeCollection); // Pack the value auto packedValue = as(collectionType) @@ -3447,9 +3838,7 @@ struct TypeFlowSpecializationContext auto taggedUnionType = getLoweredType(taggedUnion); - // Create tuple (table_unique_id, PackAnyValue(val)) - IRInst* tupleArgs[] = {witnessTableID, packedValue}; - auto tuple = builder.emitMakeTuple(taggedUnionType, 2, tupleArgs); + auto tuple = builder.emitMakeTaggedUnion(taggedUnionType, witnessTableTag, packedValue); inst->replaceUsesWith(tuple); inst->removeAndDeallocate(); @@ -3478,31 +3867,30 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); builder.setInsertBefore(inst); - List args; - args.add(inst->getDataType()); - args.add(inst->getTypeID()); + IRInst* args[] = {inst->getDataType(), inst->getTypeID()}; auto translatedTag = builder.emitIntrinsicInst( - (IRType*)makeTagType(taggedUnionType->getTableCollection()), + (IRType*)makeTagType(taggedUnionType->getWitnessTableCollection()), kIROp_GetTagFromSequentialID, - args.getCount(), - args.getBuffer()); + 2, + args); IRInst* packedValue = nullptr; auto collection = taggedUnionType->getTypeCollection(); if (!collection->isSingleton()) { - packedValue = builder.emitPackAnyValue((IRType*)collection, inst->getValue()); + packedValue = builder.emitPackAnyValue( + (IRType*)builder.getValueOfCollectionType(collection), + inst->getValue()); } else { packedValue = builder.emitReinterpret( - (IRType*)taggedUnionType->getTypeCollection(), + (IRType*)builder.getValueOfCollectionType(collection), inst->getValue()); } - auto newInst = builder.emitMakeTuple( - (IRType*)taggedUnionType, - List({translatedTag, packedValue})); + auto newInst = + builder.emitMakeTaggedUnion((IRType*)taggedUnionType, translatedTag, packedValue); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); @@ -3622,51 +4010,96 @@ struct TypeFlowSpecializationContext { if (auto info = tryGetInfo(context, inst)) { - // If our inst represents a collection directly (no run-time info), - // there's nothing to do except replace the type (if necessary) - // - if (as(info)) - return replaceType(context, inst); + if (auto elementOfCollectionType = as(info)) + { + // TODO: Should we make it such that the `GetTagForSpecializedCollection` + // is emitted in the single func case too? + // + // Basically, as long as any of the specialization operands are dynamic, + // we should probably emit a tag. + // + // Currently, if the func is a singleton, we leave it as a Specialize inst + // with dynamic args to be handled in specializeCall. + // + if (elementOfCollectionType->getCollection()->isSingleton()) + { + // If the result is a singleton collection, we can just + // replace the type (if necessary) and be done with it. + return replaceType(context, inst); + } + else + { + // Otherwise, we'll emit a tag mapping instruction. + IRBuilder builder(inst); + setInsertBeforeOrdinaryInst(&builder, inst); - auto specializedCollectionTag = as(info); + List specOperands; + specOperands.add(inst->getBase()); - // If the inst represents a singleton collection, there's nothing - // to do except replace the type (if necessary) - // - if (getCollectionCount(specializedCollectionTag) <= 1) - return replaceType(context, inst); + for (auto ii = 0; ii < inst->getArgCount(); ii++) + specOperands.add(inst->getArg(ii)); - List mappingOperands; + auto newInst = builder.emitIntrinsicInst( + (IRType*)makeTagType(elementOfCollectionType->getCollection()), + kIROp_GetTagForSpecializedCollection, + specOperands.getCount(), + specOperands.getBuffer()); - // Add the base tag as the first operand. The mapping operands follow - mappingOperands.add(inst->getBase()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + } + } + else + { + SLANG_UNEXPECTED( + "Expected element-of-collection type for function specialization"); + } + } + /* + // If our inst represents a collection directly (no run-time info), + // there's nothing to do except replace the type (if necessary) + // + if (as(info)) + return replaceType(context, inst); - forEachInCollection( - specializedCollectionTag, - [&](IRInst* element) - { - // Emit the GetTagForSpecializedCollection for each element. - auto specInst = cast(element); - auto baseGeneric = cast(specInst->getBase()); + auto specializedCollectionTag = as(info); - mappingOperands.add(baseGeneric); - mappingOperands.add(specInst); - }); + // If the inst represents a singleton collection, there's nothing + // to do except replace the type (if necessary) + // + if (getCollectionCount(specializedCollectionTag) <= 1) + return replaceType(context, inst); - IRBuilder builder(inst); - setInsertBeforeOrdinaryInst(&builder, inst); - auto newInst = builder.emitIntrinsicInst( - (IRType*)info, - kIROp_GetTagForSpecializedCollection, - mappingOperands.getCount(), - mappingOperands.getBuffer()); + List mappingOperands; - inst->replaceUsesWith(newInst); - inst->removeAndDeallocate(); - return true; - } - else - return false; + // Add the base tag as the first operand. The mapping operands follow + mappingOperands.add(inst->getBase()); + + forEachInCollection( + specializedCollectionTag, + [&](IRInst* element) + { + // Emit the GetTagForSpecializedCollection for each element. + auto specInst = cast(element); + auto baseGeneric = cast(specInst->getBase()); + + mappingOperands.add(baseGeneric); + mappingOperands.add(specInst); + }); + + IRBuilder builder(inst); + setInsertBeforeOrdinaryInst(&builder, inst); + auto newInst = builder.emitIntrinsicInst( + (IRType*)info, + kIROp_GetTagForSpecializedCollection, + mappingOperands.getCount(), + mappingOperands.getBuffer()); + + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return true; + */ } // For all other specializations, we'll 'drop' the dynamic tag information. @@ -3680,7 +4113,16 @@ struct TypeFlowSpecializationContext { // If this is a tag type, replace with collection. changed = true; - args.add(collectionTagType->getCollection()); + if (as(collectionTagType->getCollection())) + { + args.add(collectionTagType->getCollection()); + } + else if ( + auto typeCollection = as(collectionTagType->getCollection())) + { + IRBuilder builder(inst); + args.add(builder.getValueOfCollectionType(typeCollection)); + } } else { @@ -3714,13 +4156,15 @@ struct TypeFlowSpecializationContext // SLANG_UNUSED(context); - auto destType = inst->getDataType(); + // auto destType = inst->getDataType(); auto operandInfo = inst->getOperand(0)->getDataType(); if (as(operandInfo)) { IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); - auto newInst = builder.emitGetTupleElement((IRType*)destType, inst->getOperand(0), 1); + /*auto newInst = builder.emitGetTupleElement((IRType*)destType, inst->getOperand(0), + * 1);*/ + auto newInst = builder.emitGetValueFromTaggedUnion(inst->getOperand(0)); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); return true; @@ -3728,6 +4172,13 @@ struct TypeFlowSpecializationContext return false; } + bool specializeGetElementFromTag(IRInst* context, IRGetElementFromTag* inst) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + return true; + } + bool specializeLoad(IRInst* context, IRInst* inst) { // There's two cases to handle.. @@ -3840,7 +4291,9 @@ struct TypeFlowSpecializationContext if (as(inst->getVal())) return handleDefaultStore(context, inst); - auto specializedVal = upcastCollection(context, inst->getVal(), ptrInfo); + IRBuilder builder(context); + builder.setInsertBefore(inst); + auto specializedVal = upcastCollection(&builder, inst->getVal(), ptrInfo); if (specializedVal != inst->getVal()) { @@ -3939,18 +4392,21 @@ struct TypeFlowSpecializationContext builder.setInsertBefore(inst); // Create a tuple for the empty type.. - SLANG_ASSERT(taggedUnionType->getTableCollection()->isSingleton()); - auto noneWitnessTable = taggedUnionType->getTableCollection()->getElement(0); + SLANG_ASSERT(taggedUnionType->getWitnessTableCollection()->isSingleton()); + auto noneWitnessTable = taggedUnionType->getWitnessTableCollection()->getElement(0); auto singletonTagType = makeTagType(cBuilder.makeSingletonSet(noneWitnessTable)); - auto zeroValueOfTagType = builder.getIntValue((IRType*)singletonTagType, 0); + IRInst* zeroValueOfTagType = builder.emitGetTagOfElementInCollection( + (IRType*)singletonTagType, + noneWitnessTable, + taggedUnionType->getWitnessTableCollection()); - List tupleOperands; - tupleOperands.add(zeroValueOfTagType); - tupleOperands.add( - builder.emitDefaultConstruct((IRType*)taggedUnionType->getTypeCollection())); + auto newTuple = builder.emitMakeTaggedUnion( + (IRType*)taggedUnionType, + zeroValueOfTagType, + builder.emitDefaultConstruct( + makeValueOfCollectionType(taggedUnionType->getTypeCollection()))); - auto newTuple = builder.emitMakeTuple((IRType*)taggedUnionType, tupleOperands); inst->replaceUsesWith(newTuple); propagationMap[InstWithContext(context, newTuple)] = taggedUnionType; inst->removeAndDeallocate(); @@ -4017,8 +4473,8 @@ struct TypeFlowSpecializationContext // we just return a true. // // 2. 'none' is a possibility. In this case, we create a 0 value of - // type TagType(TableCollection(NoneWitness)) and then upcast it - // to TagType(inputTableCollection). This will convert the value + // type TagType(WitnessTableCollection(NoneWitness)) and then upcast it + // to TagType(inputWitnessTableCollection). This will convert the value // to the corresponding value of 'none' in the input's table collection // allowing us to directly compare it against the tag part of the // input tagged union. @@ -4028,7 +4484,7 @@ struct TypeFlowSpecializationContext bool containsNone = false; forEachInCollection( - taggedUnionType->getTableCollection(), + taggedUnionType->getWitnessTableCollection(), [&](IRInst* wt) { if (wt == getNoneWitness()) @@ -4052,8 +4508,10 @@ struct TypeFlowSpecializationContext // the value for 'none' (in the context of the tag's collection) // builder.setInsertBefore(inst); + + // TODO: Use proper op-codes and don't rely on tuple ops. auto dynTag = builder.emitGetTupleElement( - (IRType*)makeTagType(taggedUnionType->getTableCollection()), + (IRType*)makeTagType(taggedUnionType->getWitnessTableCollection()), inst->getOptionalOperand(), 0); @@ -4066,7 +4524,7 @@ struct TypeFlowSpecializationContext // value to the corresponding value for the larger set) // auto noneWitnessTag = builder.emitIntrinsicInst( - (IRType*)makeTagType(taggedUnionType->getTableCollection()), + (IRType*)makeTagType(taggedUnionType->getWitnessTableCollection()), kIROp_GetTagForSuperCollection, 1, &noneSingletonWitnessTag); diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index 888ddfdf901..470501676e5 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -4,6 +4,7 @@ #include "slang-ir-clone.h" #include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" +#include "slang-ir-typeflow-collection.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -126,14 +127,12 @@ struct GenerateWitnessTableWrapperContext // DUPLICATES... put into common file. + static bool isTaggedUnionType(IRInst* type) { - if (auto tupleType = as(type)) - return as(tupleType->getOperand(0)) != nullptr; - - return false; + return as(type) != nullptr; } - +/* static UCount getCollectionCount(IRCollectionBase* collection) { if (!collection) @@ -207,12 +206,13 @@ static IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInf } else if (!as(argInfo) && as(destInfo)) { + SLANG_UNEXPECTED("Raw collections should not appear"); return builder->emitPackAnyValue((IRType*)destInfo, arg); } return arg; // Can use as-is. } - +*/ // Represents a work item for packing `inout` or `out` arguments after a concrete call. struct ArgumentPackWorkItem @@ -231,7 +231,7 @@ struct ArgumentPackWorkItem bool isAnyValueType(IRType* type) { - if (as(type) || as(type)) + if (as(type) || as(type)) return true; return false; } @@ -406,9 +406,7 @@ IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* inte auto pack = builder->emitPackAnyValue(funcTypeInInterface->getResultType(), call); builder->emitReturn(pack); } - else if ( - isTaggedUnionType(call->getDataType()) && - isTaggedUnionType(funcTypeInInterface->getResultType())) + else if (call->getDataType() != funcTypeInInterface->getResultType()) { auto reinterpret = upcastCollection(builder, call, funcTypeInInterface->getResultType()); builder->emitReturn(reinterpret); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 8c96a3fe4ee..f6bcd6b5be5 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8816,6 +8816,21 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetPerVertexInputArray: case kIROp_MetalCastToDepthTexture: case kIROp_GetCurrentStage: + case kIROp_GetDispatcher: + case kIROp_GetSpecializedDispatcher: + case kIROp_GetTagForMappedCollection: + case kIROp_GetTagForSpecializedCollection: + case kIROp_GetTagForSuperCollection: + case kIROp_GetTagFromSequentialID: + case kIROp_GetSequentialIDFromTag: + case kIROp_CastInterfaceToTaggedUnionPtr: + case kIROp_CastTaggedUnionToInterfacePtr: + case kIROp_GetElementFromTag: + case kIROp_GetTagFromTaggedUnion: + case kIROp_GetTypeTagFromTaggedUnion: + case kIROp_GetValueFromTaggedUnion: + case kIROp_MakeTaggedUnion: + case kIROp_GetTagOfElementInCollection: return false; case kIROp_ForwardDifferentiate: diff --git a/tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang b/tests/language-feature/dynamic-dispatch/dependent-assoc-types-1.slang similarity index 100% rename from tests/language-feature/dynamic-dispatch/dependent-assoc-types.slang rename to tests/language-feature/dynamic-dispatch/dependent-assoc-types-1.slang diff --git a/tests/language-feature/dynamic-dispatch/generic-method.slang b/tests/language-feature/dynamic-dispatch/generic-method.slang new file mode 100644 index 00000000000..a2df3b97c5e --- /dev/null +++ b/tests/language-feature/dynamic-dispatch/generic-method.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IInterface +{ + T calc(T x); +} + +struct A : IInterface +{ + T calc(T x) { return x * x * x; } +}; + +struct B : IInterface +{ + T calc(T x) { return x * x; } +}; + +struct C : IInterface +{ + T calc(T x) { return x; } +}; + +float f(uint id, float x) +{ + IInterface obj; + + if (id == 0) + obj = A(); + else if (id == 1) + obj = B(); + else if (id == 2) + obj = C(); + + return obj.calc(x); +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(0, 2); // CHECK: 8 + outputBuffer[1] = f(1, 2); // CHECK: 4 +} \ No newline at end of file From 44b0108e607599ae649d98910006226a7e661ffe Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:48:14 -0400 Subject: [PATCH 072/105] Some more testing. Change tags to use global unique IDs instead of local IDs --- .../slang/slang-ir-lower-typeflow-insts.cpp | 221 +------- source/slang/slang-ir-lower-typeflow-insts.h | 4 +- source/slang/slang-ir-typeflow-collection.h | 2 +- source/slang/slang-ir-typeflow-specialize.cpp | 509 ++++-------------- 4 files changed, 125 insertions(+), 611 deletions(-) diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index f664157c9a8..d56648df978 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -210,104 +210,24 @@ IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mappi return func; } -// This context lowers `GetTagForSpecializedCollection`, +// This context lowers `GetTagOfElementInCollection`, // `GetTagForSuperCollection`, and `GetTagForMappedCollection` instructions, // struct TagOpsLoweringContext : public InstPassBase { TagOpsLoweringContext(IRModule* module) - : InstPassBase(module) + : InstPassBase(module), cBuilder(module) { } void lowerGetTagForSuperCollection(IRGetTagForSuperCollection* inst) { - // We use the result type and the type of the operand - // to figure out the source and destination collections. - // - // We then replace this with an array access, where the i'th - // element of the array is the corresponding index in the super - // collection. - // - // e.g. - // let a : TagType(WitnessTableCollection(B, C)) = /* ... */; - // let b : TagType(WitnessTableCollection(A, B, C)) = GetTagForSuperCollection(a); - // becomes - // let a : TagType(WitnessTableCollection(B, C)) = /* ... */; - // let lookupArr : ArrayType = [1, 2]; // B is at index 1, C is at index 2 - // let b : TagType(WitnessTableCollection(A, B, C)) = ElementExtract(lookupArr, a); - // - // Note that we leave the tag-types of the output intact since we may need to lower - // later tag operations. - // - - auto srcCollection = cast( - cast(inst->getOperand(0)->getDataType())->getOperand(0)); - auto destCollection = - cast(cast(inst->getDataType())->getOperand(0)); - - IRBuilder builder(inst->getModule()); - builder.setInsertAfter(inst); - - List indices; - for (UInt i = 0; i < srcCollection->getCount(); i++) - { - // Find in destCollection - auto srcElement = srcCollection->getElement(i); - - bool found = false; - for (UInt j = 0; j < destCollection->getCount(); j++) - { - auto destElement = destCollection->getElement(j); - if (srcElement == destElement) - { - found = true; - indices.add(builder.getIntValue(builder.getUIntType(), j)); - break; // Found the index - } - } - - if (!found) - { - // destCollection must be a super-set - SLANG_UNEXPECTED("Element not found in destination collection"); - } - } - - // Create an array for the lookup - auto lookupArrayType = builder.getArrayType( - builder.getUIntType(), - builder.getIntValue(builder.getUIntType(), indices.getCount())); - auto lookupArray = - builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); - auto resultID = - builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); - inst->replaceUsesWith(resultID); + inst->replaceUsesWith(inst->getOperand(0)); inst->removeAndDeallocate(); } void lowerGetTagForMappedCollection(IRGetTagForMappedCollection* inst) { - // We use the result type and the type of the operand - // to figure out the source and destination collections. - // - // We then replace this with an array access, where the i'th - // element of the array is the corresponding index in the mapped - // collection. - // - // e.g. - // let a : TagType(WitnessTableCollection(B, C)) = /* ... */; - // let b : TagType(FuncCollection(C_key, B_key)) = - // GetTagForMappedCollection(a, key); - // becomes - // let a : TagType(WitnessTableCollection(B, C)) = /* ... */; - // let lookupArr : ArrayType = [1, 0]; // B is at index 1, C is at index 0 - // let b : TagType(FuncCollection(C_key, B_key)) = ElementExtract(lookupArr, a); - // - // Note that we leave the tag-types of the output intact since we may need to lower - // later tag operations. - // - auto srcCollection = cast( cast(inst->getOperand(0)->getDataType())->getOperand(0)); auto destCollection = @@ -317,20 +237,25 @@ struct TagOpsLoweringContext : public InstPassBase IRBuilder builder(inst->getModule()); builder.setInsertAfter(inst); - List indices; + Dictionary mapping; for (UInt i = 0; i < srcCollection->getCount(); i++) { // Find in destCollection bool found = false; - auto srcElement = + auto srcMappedElement = findWitnessTableEntry(cast(srcCollection->getElement(i)), key); for (UInt j = 0; j < destCollection->getCount(); j++) { auto destElement = destCollection->getElement(j); - if (srcElement == destElement) + if (srcMappedElement == destElement) { found = true; - indices.add(builder.getIntValue(builder.getUIntType(), j)); + // We rely on the fact that if the element ever appeared in a collection, + // it must have been assigned a unique ID. + // + mapping.add( + cBuilder.getUniqueID(srcCollection->getElement(i)), + cBuilder.getUniqueID(destElement)); break; // Found the index } } @@ -342,88 +267,13 @@ struct TagOpsLoweringContext : public InstPassBase } } - // Create an array for the lookup - auto lookupArrayType = builder.getArrayType( - builder.getUIntType(), - builder.getIntValue(builder.getUIntType(), indices.getCount())); - auto lookupArray = - builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); - auto resultID = - builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); - inst->replaceUsesWith(resultID); - inst->removeAndDeallocate(); - } - - void lowerGetTagForSpecializedCollection(IRGetTagForSpecializedCollection* inst) - { - // We use the result type and the type of the operand - // to figure out the source and destination collections. - // - // We then replace this with an array access, where the i'th - // element of the array is the corresponding index in the mapped - // collection. - // - // The mapping between elements is provided as pairs of operands to the instruction. - // - // e.g. - // let a : TagType(GenericCollection(B, C)) = /* ... */; - // let b : TagType(FuncCollection(E, F)) = - // GetTagForSpecializedCollection(a, B, F, C, E); - // becomes - // let a : TagType(GenericCollection(B, C)) = /* ... */; - // let lookupArr : ArrayType = [1, 0]; // B->F, C->E - // let b : TagType(FuncCollection(E, F)) = ElementExtract(lookupArr, a); - // - // Note that we leave the tag-types of the output intact since we may need to lower - // later tag operations. - // - - auto srcCollection = - cast(inst->getOperand(0)->getDataType())->getCollection(); - auto destCollection = cast(inst->getDataType())->getCollection(); - Dictionary mapping; - - for (UInt i = 1; i < inst->getOperandCount(); i += 2) - { - auto srcElement = inst->getOperand(i); - auto destElement = inst->getOperand(i + 1); - mapping[srcElement] = destElement; - } - - IRBuilder builder(inst->getModule()); - builder.setInsertAfter(inst); - - List indices; - for (UInt i = 0; i < srcCollection->getCount(); i++) - { - // Find in destCollection - bool found = false; - auto mappedElement = mapping[srcCollection->getElement(i)]; - for (UInt j = 0; j < destCollection->getCount(); j++) - { - auto destElement = destCollection->getElement(j); - if (mappedElement == destElement) - { - found = true; - indices.add(builder.getIntValue(builder.getUIntType(), j)); - break; // Found the index - } - } - - if (!found) - { - SLANG_UNEXPECTED("Element not found in specialized collection"); - } - } + // Create an index mapping func and call that + auto mappingFunc = createIntegerMappingFunc(inst->getModule(), mapping, 0); - // Create an array for the lookup - auto lookupArrayType = builder.getArrayType( - builder.getUIntType(), - builder.getIntValue(builder.getUIntType(), indices.getCount())); - auto lookupArray = - builder.emitMakeArray(lookupArrayType, indices.getCount(), indices.getBuffer()); - auto resultID = - builder.emitElementExtract(inst->getDataType(), lookupArray, inst->getOperand(0)); + auto resultID = builder.emitCallInst( + inst->getDataType(), + mappingFunc, + List({inst->getOperand(0)})); inst->replaceUsesWith(resultID); inst->removeAndDeallocate(); } @@ -433,21 +283,8 @@ struct TagOpsLoweringContext : public InstPassBase IRBuilder builder(inst->getModule()); builder.setInsertAfter(inst); - // find the index of the element in the collection - auto collection = cast(inst->getOperand(1)); - auto element = inst->getOperand(0); - UInt foundIndex = UInt(-1); - for (UInt i = 0; i < collection->getCount(); i++) - { - if (collection->getElement(i) == element) - { - foundIndex = i; - break; - } - } - - SLANG_ASSERT(foundIndex != UInt(-1)); - auto resultValue = builder.getIntValue(inst->getDataType(), foundIndex); + auto uniqueId = cBuilder.getUniqueID(inst->getOperand(0)); + auto resultValue = builder.getIntValue(inst->getDataType(), uniqueId); inst->replaceUsesWith(resultValue); inst->removeAndDeallocate(); } @@ -462,9 +299,6 @@ struct TagOpsLoweringContext : public InstPassBase case kIROp_GetTagForMappedCollection: lowerGetTagForMappedCollection(as(inst)); break; - case kIROp_GetTagForSpecializedCollection: - lowerGetTagForSpecializedCollection(as(inst)); - break; case kIROp_GetTagOfElementInCollection: lowerGetTagOfElementInCollection(as(inst)); break; @@ -473,11 +307,12 @@ struct TagOpsLoweringContext : public InstPassBase } } - void processModule() { processAllInsts([&](IRInst* inst) { return processInst(inst); }); } + + CollectionBuilder cBuilder; }; struct DispatcherLoweringContext : public InstPassBase @@ -505,7 +340,6 @@ struct DispatcherLoweringContext : public InstPassBase } Dictionary elements; - UInt index = 0; IRBuilder builder(dispatcher->getModule()); forEachInCollection( witnessTableCollection, @@ -571,7 +405,6 @@ struct DispatcherLoweringContext : public InstPassBase IRBuilder builder(dispatcher->getModule()); Dictionary elements; - UInt index = 0; forEachInCollection( witnessTableCollection, [&](IRInst* table) @@ -680,7 +513,7 @@ void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) struct SequentialIDTagLoweringContext : public InstPassBase { SequentialIDTagLoweringContext(IRModule* module) - : InstPassBase(module) + : InstPassBase(module), cBuilder(module) { } void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) @@ -704,13 +537,12 @@ struct SequentialIDTagLoweringContext : public InstPassBase // Map from sequential ID to unique ID auto destCollection = cast(inst->getDataType())->getCollection(); - UIndex dstSeqID = 0; forEachInCollection( destCollection, [&](IRInst* table) { // Get unique ID for the witness table - auto outputId = dstSeqID++; + auto outputId = cBuilder.getUniqueID(table); auto seqDecoration = table->findDecoration(); if (seqDecoration) { @@ -754,14 +586,13 @@ struct SequentialIDTagLoweringContext : public InstPassBase // Map from sequential ID to unique ID auto destCollection = cast(srcTagInst->getDataType())->getCollection(); - UIndex dstSeqID = 0; forEachInCollection( destCollection, [&](IRInst* table) { // Get unique ID for the witness table SLANG_UNUSED(cast(table)); - auto outputId = dstSeqID++; + auto outputId = cBuilder.getUniqueID(table); auto seqDecoration = table->findDecoration(); if (seqDecoration) { @@ -791,6 +622,8 @@ struct SequentialIDTagLoweringContext : public InstPassBase kIROp_GetSequentialIDFromTag, [&](IRGetSequentialIDFromTag* inst) { return lowerGetSequentialIDFromTag(inst); }); } + + CollectionBuilder cBuilder; }; void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) diff --git a/source/slang/slang-ir-lower-typeflow-insts.h b/source/slang/slang-ir-lower-typeflow-insts.h index 573587a9360..cac05626d02 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.h +++ b/source/slang/slang-ir-lower-typeflow-insts.h @@ -16,8 +16,8 @@ bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink); // Lower `CollectionTagType` types void lowerTagTypes(IRModule* module); -// Lower `GetTagForSuperCollection`, `GetTagForMappedCollection` and -// `GetTagForSpecializedCollection` instructions +// Lower `GetTagOfElementInCollection`, +// `GetTagForSuperCollection`, and `GetTagForMappedCollection` instructions, // void lowerTagInsts(IRModule* module, DiagnosticSink* sink); diff --git a/source/slang/slang-ir-typeflow-collection.h b/source/slang/slang-ir-typeflow-collection.h index 110be0bff7d..13f51f4ecd1 100644 --- a/source/slang/slang-ir-typeflow-collection.h +++ b/source/slang/slang-ir-typeflow-collection.h @@ -76,12 +76,12 @@ struct CollectionBuilder // IRCollectionBase* makeSet(const HashSet& values); -private: // Return a unique ID for the inst. Assuming the module pointer // is consistent, this should always be the same for a given inst. // UInt getUniqueID(IRInst* inst); +private: // Reference to parent module IRModule* module; diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 3a5443b2724..77a06e437c2 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -127,6 +127,19 @@ bool isResourcePointer(IRInst* inst) return false; } +bool isNoneCallee(IRInst* callee) +{ + if (auto lookupWitness = as(callee)) + { + if (auto table = as(callee->getOperand(0))) + { + return table->getConcreteType()->getOp() == kIROp_VoidType; + } + } + + return false; +} + // Represents an interprocedural edge between call sites and functions struct InterproceduralEdge { @@ -551,7 +564,6 @@ struct TypeFlowSpecializationContext if (auto info = tryGetInfo(context, inst)) return info; - IRBuilder builder(module); if (auto ptrType = as(inst->getDataType())) { @@ -607,38 +619,9 @@ struct TypeFlowSpecializationContext if (!inst->getParent()) return none(); - // If this is a global instruction (parent is module), return a singleton set of - // that inst. - // - // Since it's easy to tell when an inst is representing a concrete - // entity, we do this on demand rather than trying to put it in the - // propagation map. - // - /* - if (as(inst->getParent())) - { - if (as(inst) || as(inst) || as(inst) || - as(inst)) - { - // We won't directly handle interface types, but rather treat objects of interface - // type as objects that can be specialized with collections. - // - if (as(inst)) - return none(); - - if (as(inst) && as(getGenericReturnVal(inst))) - return none(); - - return makeElementOfCollectionType(cBuilder.makeSingletonSet(inst)); - } - else - return none(); - } - */ + // Global insts always have no info. if (as(inst->getParent())) - { return none(); - } return _tryGetInfo(InstWithContext(context, inst)); } @@ -763,14 +746,6 @@ struct TypeFlowSpecializationContext cast(info2->getOperand(0)))); } - if (as(info1) && as(info2)) - { - SLANG_UNEXPECTED("Should not see 'raw' collection types anymore"); - return unionCollection( - cast(info1), - cast(info2)); - } - SLANG_UNEXPECTED("Unhandled propagation info types in unionPropagationInfo"); } @@ -1001,9 +976,6 @@ struct TypeFlowSpecializationContext case kIROp_Call: info = analyzeCall(context, as(inst), workQueue); break; - // case kIROp_Param: - // info = analyzeParam(context, as(inst)); - // break; case kIROp_Specialize: info = analyzeSpecialize(context, as(inst)); break; @@ -1135,14 +1107,12 @@ struct TypeFlowSpecializationContext const auto [paramDirection, paramType] = splitParameterDirectionAndType(param->getDataType()); - // Only update if - // 1. The paramType is a global inst and an interface type - // 2. The paramType is a local inst. - // all other cases, continue. + // Only update if the parameter is not a concrete type. // - // This is primarily just an optimization. Without this, - // we'd be storing 'singleton' sets for parameters with - // regular concrete types (i.e. 99% of cases). + // This is primarily just an optimization. + // Without this, we'd be storing 'singleton' sets for parameters with + // regular concrete types (i.e. 99% of cases), which can clog up + // the propagation dictionary when analyzing large modules. // This optimization ignores them and re-derives the info // from the data-type. // @@ -1161,16 +1131,6 @@ struct TypeFlowSpecializationContext case ParameterDirectionInfo::Kind::BorrowIn: { IRBuilder builder(module); - /* - if (!argInfo) - { - if (isGlobalInst(arg->getDataType()) && - !as( - as(arg->getDataType())->getValueType())) - argInfo = arg->getDataType(); - } - */ - if (!argInfo) break; @@ -1183,14 +1143,6 @@ struct TypeFlowSpecializationContext } case ParameterDirectionInfo::Kind::In: { - /* - if (!argInfo) - { - if (isGlobalInst(arg->getDataType()) && - !as(arg->getDataType())) - argInfo = arg->getDataType(); - } - */ updateInfo(edge.targetContext, param, argInfo, true, workQueue); break; } @@ -1290,7 +1242,13 @@ struct TypeFlowSpecializationContext return makeTaggedUnionType(as( cBuilder.createCollection(kIROp_WitnessTableCollection, tables))); else + { + sink->diagnose( + inst, + Diagnostics::noTypeConformancesFoundForInterface, + interfaceType); return none(); + } } return none(); @@ -1927,14 +1885,6 @@ struct TypeFlowSpecializationContext auto specializedFuncType = cast(specializeGeneric(specializeInst)); typeOfSpecialization = specializedFuncType; } - else if (auto collection = as(typeInfo)) - { - SLANG_UNEXPECTED("shouldn't see this case"); - SLANG_ASSERT(collection->isSingleton()); - auto specializeInst = cast(collection->getElement(0)); - auto specializedFuncType = cast(specializeGeneric(specializeInst)); - typeOfSpecialization = specializedFuncType; - } else { return none(); @@ -2093,6 +2043,9 @@ struct TypeFlowSpecializationContext auto callee = inst->getCallee(); auto calleeInfo = tryGetInfo(context, callee); + if (isNoneCallee(callee)) + return none(); + auto propagateToCallSite = [&](IRInst* callee) { // Register the call site in the map to allow for the @@ -2136,23 +2089,6 @@ struct TypeFlowSpecializationContext return none(); } - IRInst* analyzeParam(IRInst* context, IRInst* inst, WorkQueue& workQueue) - { - /* We only need to handle one case, where we calculate the info based on the - // info for the data-type. - // - if (auto info = tryGetInfo(context, inst->getDataType())) - { - if (auto elementOfCollection = as(info)) - { - return makeValueOfCollectionType( - cast(elementOfCollection->getCollection())); - } - } - - return none();*/ - } - // Updates the information for an address. void maybeUpdatePtr(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) { @@ -2662,23 +2598,6 @@ struct TypeFlowSpecializationContext for (auto structType : structsToProcess) hasChanges |= specializeStructType(structType); - /*while (funcsToProcess.getCount() > 0) - { - for (auto func : funcsToProcess) - hasChanges |= specializeFunc(func); - - funcsToProcess.clear(); - for (auto context : contextsToLower) - { - hasChanges = true; - auto dynGenericFunc = specializeDynamicGeneric(cast(context)); - context->replaceUsesWith(dynGenericFunc); - context->removeAndDeallocate(); - funcsToProcess.add(cast(dynGenericFunc)); - } - - contextsToLower.clear(); - }*/ for (auto func : funcsToProcess) hasChanges |= specializeFunc(func); @@ -2976,7 +2895,6 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); builder.setInsertAfter(inst); - // auto val = builder.emitGetTupleElement((IRType*)valType, existential, 1); auto val = builder.emitGetValueFromTaggedUnion(existential); inst->replaceUsesWith(val); inst->removeAndDeallocate(); @@ -2988,28 +2906,6 @@ struct TypeFlowSpecializationContext bool specializeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { - /*auto info = tryGetInfo(context, inst); - auto collectionTagType = as(info); - if (!collectionTagType) - return false; - - IRBuilder builder(inst); - builder.setInsertBefore(inst); - - if (collectionTagType->isSingleton()) - { - // Found a single possible type. Simple replacement. - auto singletonValue = collectionTagType->getCollection()->getElement(0); - inst->replaceUsesWith(singletonValue); - inst->removeAndDeallocate(); - return true; - } - - // Replace the instruction with the collection type. - inst->replaceUsesWith(collectionTagType->getCollection()); - inst->removeAndDeallocate(); - return true;*/ - auto info = tryGetInfo(context, inst); if (auto elementOfCollectionType = as(info)) { @@ -3099,8 +2995,6 @@ struct TypeFlowSpecializationContext cBuilder.createCollection(kIROp_TypeCollection, collectionElements); return makeValueOfCollectionType(cast(newCollection)); } - - SLANG_UNEXPECTED("Unhandled case in updateType"); } IRFuncType* getEffectiveFuncTypeForDispatcher( @@ -3124,77 +3018,29 @@ struct TypeFlowSpecializationContext { SLANG_UNUSED(key); - List paramTypes; - IRType* resultType = nullptr; - List extraParamTypes; extraParamTypes.add((IRType*)makeTagType(tableCollection)); + /* for (auto specArg : specArgs) if (as(specArg)) extraParamTypes.add((IRType*)makeTagType(as(specArg))); + */ - IRBuilder builder(module); - auto updateParamType = [&](Index index, IRType* paramType) -> IRType* - { - if (paramTypes.getCount() <= index) - { - // If this index hasn't been seen yet, expand the buffer and initialize - // the type. - // - paramTypes.growToCount(index + 1); - paramTypes[index] = paramType; - return paramType; - } - else - { - // Otherwise, update the existing type - auto [currentDirection, currentType] = - splitParameterDirectionAndType(paramTypes[index]); - auto [newDirection, newType] = splitParameterDirectionAndType(paramType); - auto updatedType = updateType(currentType, newType); - SLANG_ASSERT(currentDirection == newDirection); - paramTypes[index] = fromDirectionAndType(&builder, currentDirection, updatedType); - return updatedType; - } - }; - - forEachInCollection( - resultFuncCollection, - [&](IRInst* func) - { - auto paramEffectiveTypes = getEffectiveParamTypes(func); - auto paramDirections = getParamDirections(func); - - for (Index i = 0; i < paramEffectiveTypes.getCount(); i++) - updateParamType(i, getLoweredType(paramEffectiveTypes[i])); - - auto returnType = getFuncReturnInfo(func); - if (auto newResultType = getLoweredType(returnType)) - { - resultType = updateType(resultType, newResultType); - } - else if (auto funcType = as(func->getDataType())) - { - SLANG_ASSERT(isGlobalInst(funcType->getResultType())); - resultType = updateType(resultType, funcType->getResultType()); - } - else - { - SLANG_UNEXPECTED("Cannot determine result type for context"); - } - }); - + auto innerFuncType = getEffectiveFuncTypeForCollection(resultFuncCollection); List allParamTypes; allParamTypes.addRange(extraParamTypes); - allParamTypes.addRange(paramTypes); - return builder.getFuncType(allParamTypes, resultType); + for (auto paramType : innerFuncType->getParamTypes()) + allParamTypes.add(paramType); + + IRBuilder builder(module); + return builder.getFuncType(allParamTypes, innerFuncType->getResultType()); } // Get an effective func type to use for the callee. - // The callee may be a collection, in which case, this returns a union-ed functype./ + // The callee may be a collection, in which case, this returns a union-ed functype. // - IRFuncType* getEffectiveFuncType(IRInst* callee) + IRFuncType* getEffectiveFuncTypeForCollection(IRFuncCollection* calleeCollection) { // The effective func type for a callee is calculated as follows: // @@ -3242,26 +3088,10 @@ struct TypeFlowSpecializationContext } }; - List contextsToProcess; - if (auto collection = as(callee)) - { - forEachInCollection(collection, [&](IRInst* func) { contextsToProcess.add(func); }); - } - else if (auto collectionTagType = as(callee->getDataType())) - { - SLANG_UNEXPECTED( - "Should never try to directly call a tag type. Not semantically meaningful"); - forEachInCollection( - collectionTagType, - [&](IRInst* func) { contextsToProcess.add(func); }); - } - else - { - // Otherwise, just process the single function - contextsToProcess.add(callee); - } + List calleesToProcess; + forEachInCollection(calleeCollection, [&](IRInst* func) { calleesToProcess.add(func); }); - for (auto context : contextsToProcess) + for (auto context : calleesToProcess) { auto paramEffectiveTypes = getEffectiveParamTypes(context); auto paramDirections = getParamDirections(context); @@ -3286,29 +3116,18 @@ struct TypeFlowSpecializationContext } // - // Add in extra parameter types for a call to a non-concrete callee. + // Add in extra parameter types for a call to a dynamic generic callee // List extraParamTypes; - // If the callee is a collection, then we need a tag as input. - if (auto funcCollection = as(callee)) - { - // If this is a non-trivial collection, we need to add a tag type for the collection - // as the first parameter. - if (getCollectionCount(funcCollection) > 1) - extraParamTypes.add((IRType*)makeTagType(funcCollection)); - } // If the any of the elements in the callee (or the callee itself in case // of a singleton) is a dynamic specialization, each non-singleton WitnessTableCollection, // requries a corresponding tag input. // - auto calleeToCheck = as(callee) - ? getCollectionElement(as(callee), 0) - : callee; - if (isDynamicGeneric(calleeToCheck)) + if (calleeCollection->isSingleton() && isDynamicGeneric(calleeCollection->getElement(0))) { - auto specializeInst = as(calleeToCheck); + auto specializeInst = as(calleeCollection->getElement(0)); // If this is a dynamic generic, we need to add a tag type for each // WitnessTableCollection in the callee. @@ -3325,6 +3144,12 @@ struct TypeFlowSpecializationContext return builder.getFuncType(allParamTypes, resultType); } + IRFuncType* getEffectiveFuncType(IRInst* callee) + { + return getEffectiveFuncTypeForCollection( + cast(cBuilder.makeSingletonSet(callee))); + } + // Helper function for specializing calls. // // For a `Specialize` instruction that has dynamic tag arguments, @@ -3435,10 +3260,15 @@ struct TypeFlowSpecializationContext // auto callee = inst->getCallee(); + + // TODO: Can remove this workaround since we're lowering these immediately. if (as(callee) || as(callee)) // Already specialized return false; + if (isNoneCallee(callee)) + return false; + // IRInst* calleeTagInst = nullptr; List callArgs; @@ -3469,8 +3299,6 @@ struct TypeFlowSpecializationContext // function. If we ever have the ability to pass functions around more // flexibly, then this should just become a specific case. - // TODO: Also handle the case where we need to perform a static specialization. - if (auto tagMapOperand = as(callee)) { auto tableTag = tagMapOperand->getOperand(0); @@ -3493,12 +3321,12 @@ struct TypeFlowSpecializationContext else if ( auto specializedTagMapOperand = as(callee)) { - auto tagMapOperand = + auto innerTagMapOperand = cast(specializedTagMapOperand->getOperand(0)); - auto tableTag = tagMapOperand->getOperand(0); + auto tableTag = innerTagMapOperand->getOperand(0); auto tableCollection = cast( cast(tableTag->getDataType())->getCollection()); - auto lookupKey = cast(tagMapOperand->getOperand(1)); + auto lookupKey = cast(innerTagMapOperand->getOperand(1)); List specArgs; for (auto argIdx = 1; argIdx < specializedTagMapOperand->getOperandCount(); @@ -3553,21 +3381,24 @@ struct TypeFlowSpecializationContext auto funcType = getEffectiveFuncType(callee); callee->setFullType(funcType); - - // contextsToLower.add(callee); } else { - SLANG_ASSERT("Shouldn't get here.."); - // Single element which is a concrete function. - callee = collectionTag->getCollection()->getElement(0); - - auto funcType = getEffectiveFuncType(callee); - callee->setFullType(funcType); + // If we reach here, then something is wrong. If our callee is an inst of tag-type, + // we expect it to either be a `GetTagForMappedCollection`, `Specialize` or + // `GetTagForSpecializedCollection`. + // Any other case should never occur (in the current design of the compiler) + // + SLANG_UNEXPECTED( + "Unexpected operand type for type-flow specialization of Call inst"); } } - else + else if (isGlobalInst(callee) && !isIntrinsic(callee)) { + // If our callee is not a tag-type, then it is necessarily a simple concrete function. + // We will fix-up the function type so that it has the effective types as determined + // by the analysis. + // auto funcType = getEffectiveFuncType(callee); callee->setFullType(funcType); } @@ -3584,60 +3415,6 @@ struct TypeFlowSpecializationContext if (as(callee)) return false; - // auto expectedFuncType = getEffectiveFuncType(callee); - - /* - List newArgs; - IRInst* newCallee = nullptr; - - // Determine a new callee. - auto calleeCollection = as(callee); - if (!calleeCollection) - newCallee = callee; // Not a collection, no need to specialize - else if (getCollectionCount(calleeCollection) == 1) - { - auto singletonValue = getCollectionElement(calleeCollection, 0); - if (singletonValue == callee) - { - newCallee = callee; - } - else - { - if (isDynamicGeneric(singletonValue)) - newArgs.addRange( - getArgsForDynamicSpecialization(cast(inst->getCallee()))); - - newCallee = singletonValue; - } - } - else - { - // Multiple elements in the collection. - if (calleeTagInst) - newArgs.add(calleeTagInst); - auto funcCollection = cast(calleeCollection); - - // Check if the first element is a dynamic generic (this should imply that all - // elements are similar dynamic generics, but we might want to check for that..) - // - if (isDynamicGeneric(getCollectionElement(funcCollection, 0))) - { - auto dynamicSpecArgs = - getArgsForDynamicSpecialization(cast(inst->getCallee())); - for (auto& arg : dynamicSpecArgs) - newArgs.add(arg); - } - - if (!as(funcCollection->getDataType())) - { - auto typeForCollection = getEffectiveFuncType(funcCollection); - funcCollection->setFullType(typeForCollection); - } - - newCallee = funcCollection; - } - */ - // First, we'll legalize all operands by upcasting if necessary. // This needs to be done even if the callee is not a collection. // @@ -3675,36 +3452,6 @@ struct TypeFlowSpecializationContext } } - /* - // First, we'll legalize all operands by upcasting if necessary. - // This needs to be done even if the callee is not a collection. - // - UCount extraArgCount = newArgs.getCount(); - for (UInt i = 0; i < inst->getArgCount(); i++) - { - auto arg = inst->getArg(i); - const auto [paramDirection, paramType] = - splitParameterDirectionAndType(expectedFuncType->getParamType(i + extraArgCount)); - - switch (paramDirection.kind) - { - case ParameterDirectionInfo::Kind::In: - newArgs.add(upcastCollection(context, arg, paramType)); - break; - case ParameterDirectionInfo::Kind::Out: - case ParameterDirectionInfo::Kind::BorrowInOut: - case ParameterDirectionInfo::Kind::BorrowIn: - case ParameterDirectionInfo::Kind::Ref: - { - newArgs.add(arg); - break; - } - default: - SLANG_UNEXPECTED("Unhandled parameter direction in specializeCall"); - } - } - */ - IRBuilder builder(inst); builder.setInsertBefore(inst); @@ -3790,7 +3537,7 @@ struct TypeFlowSpecializationContext // After specialization, existentials (that are not unbounded) are treated as tuples // of a WitnessTableCollection tag and a value of type TypeCollection. // - // A MakeExistential is just converted into a MakeTuple, with any necessary + // A MakeExistential is just converted into a MakeTaggedUnion, with any necessary // upcasts. // @@ -3995,8 +3742,6 @@ struct TypeFlowSpecializationContext bool isFuncReturn = false; - // TODO: Would checking this inst's info be enough instead? - // This seems long-winded. if (auto concreteGeneric = as(inst->getBase())) isFuncReturn = as(getGenericReturnVal(concreteGeneric)) != nullptr; else if (auto tagType = as(inst->getBase()->getDataType())) @@ -4012,7 +3757,8 @@ struct TypeFlowSpecializationContext { if (auto elementOfCollectionType = as(info)) { - // TODO: Should we make it such that the `GetTagForSpecializedCollection` + // Note for future reworks: + // Should we make it such that the `GetTagForSpecializedCollection` // is emitted in the single func case too? // // Basically, as long as any of the specialization operands are dynamic, @@ -4021,6 +3767,7 @@ struct TypeFlowSpecializationContext // Currently, if the func is a singleton, we leave it as a Specialize inst // with dynamic args to be handled in specializeCall. // + if (elementOfCollectionType->getCollection()->isSingleton()) { // If the result is a singleton collection, we can just @@ -4056,50 +3803,6 @@ struct TypeFlowSpecializationContext "Expected element-of-collection type for function specialization"); } } - /* - // If our inst represents a collection directly (no run-time info), - // there's nothing to do except replace the type (if necessary) - // - if (as(info)) - return replaceType(context, inst); - - auto specializedCollectionTag = as(info); - - // If the inst represents a singleton collection, there's nothing - // to do except replace the type (if necessary) - // - if (getCollectionCount(specializedCollectionTag) <= 1) - return replaceType(context, inst); - - List mappingOperands; - - // Add the base tag as the first operand. The mapping operands follow - mappingOperands.add(inst->getBase()); - - forEachInCollection( - specializedCollectionTag, - [&](IRInst* element) - { - // Emit the GetTagForSpecializedCollection for each element. - auto specInst = cast(element); - auto baseGeneric = cast(specInst->getBase()); - - mappingOperands.add(baseGeneric); - mappingOperands.add(specInst); - }); - - IRBuilder builder(inst); - setInsertBeforeOrdinaryInst(&builder, inst); - auto newInst = builder.emitIntrinsicInst( - (IRType*)info, - kIROp_GetTagForSpecializedCollection, - mappingOperands.getCount(), - mappingOperands.getBuffer()); - - inst->replaceUsesWith(newInst); - inst->removeAndDeallocate(); - return true; - */ } // For all other specializations, we'll 'drop' the dynamic tag information. @@ -4151,19 +3854,18 @@ struct TypeFlowSpecializationContext bool specializeGetValueFromBoundInterface(IRInst* context, IRGetValueFromBoundInterface* inst) { - // GetValueFromBoundInterface is essentially accessing the value component of - // an existential. We turn it into a tuple element access. + // `GetValueFromBoundInterface` is essentially accessing the value component of + // an existential. If the operand has been specialized into a tagged-union, then we can + // turn it into a `GetValueFromTaggedUnion`. // SLANG_UNUSED(context); - // auto destType = inst->getDataType(); + auto operandInfo = inst->getOperand(0)->getDataType(); if (as(operandInfo)) { IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); - /*auto newInst = builder.emitGetTupleElement((IRType*)destType, inst->getOperand(0), - * 1);*/ auto newInst = builder.emitGetValueFromTaggedUnion(inst->getOperand(0)); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); @@ -4174,6 +3876,7 @@ struct TypeFlowSpecializationContext bool specializeGetElementFromTag(IRInst* context, IRGetElementFromTag* inst) { + SLANG_UNUSED(context); inst->replaceUsesWith(inst->getOperand(0)); inst->removeAndDeallocate(); return true; @@ -4298,7 +4001,6 @@ struct TypeFlowSpecializationContext if (specializedVal != inst->getVal()) { // If the value was changed, we need to update the store instruction. - IRBuilder builder(inst); builder.replaceOperand(inst->getValUse(), specializedVal); return true; } @@ -4308,7 +4010,7 @@ struct TypeFlowSpecializationContext bool specializeGetSequentialID(IRInst* context, IRGetSequentialID* inst) { - // A SequentialID is a globally unique ID for a witness table, while the + // A sequential ID is a globally unique ID for a witness table, while the // the tags we use in the specialization are only locally consistent. // // To extract the global ID, we'll use a separate op code `GetSequentialIDFromTag` @@ -4353,23 +4055,12 @@ struct TypeFlowSpecializationContext { IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); - auto firstElement = tagType->getCollection()->getElement(0); - auto interfaceType = - as(as(firstElement)->getConformanceType()); - - // TODO: This is a rather suboptimal implementation that involves using - // global sequential IDs even though we could do it via local IDs. - // - - List args = {interfaceType, witnessTableArg}; - auto valueSeqID = builder.emitIntrinsicInst( - (IRType*)builder.getUIntType(), - kIROp_GetSequentialIDFromTag, - args.getCount(), - args.getBuffer()); - auto targetSeqID = builder.emitGetSequentialIDInst(inst->getTargetWitness()); - auto eqlInst = builder.emitEql(valueSeqID, targetSeqID); + auto targetTag = builder.emitGetTagOfElementInCollection( + (IRType*)tagType, + inst->getTargetWitness(), + tagType->getCollection()); + auto eqlInst = builder.emitEql(targetTag, witnessTableArg); inst->replaceUsesWith(eqlInst); inst->removeAndDeallocate(); @@ -4509,25 +4200,15 @@ struct TypeFlowSpecializationContext // builder.setInsertBefore(inst); - // TODO: Use proper op-codes and don't rely on tuple ops. - auto dynTag = builder.emitGetTupleElement( - (IRType*)makeTagType(taggedUnionType->getWitnessTableCollection()), - inst->getOptionalOperand(), - 0); - - IRInst* noneWitnessTagType = - makeTagType(cBuilder.makeSingletonSet(getNoneWitness())); - IRInst* noneSingletonWitnessTag = - builder.getIntValue((IRType*)noneWitnessTagType, 0); + auto dynTag = builder.emitGetTagFromTaggedUnion(inst->getOptionalOperand()); // Cast the singleton tag to the target collection tag (will convert the // value to the corresponding value for the larger set) // - auto noneWitnessTag = builder.emitIntrinsicInst( + auto noneWitnessTag = builder.emitGetTagOfElementInCollection( (IRType*)makeTagType(taggedUnionType->getWitnessTableCollection()), - kIROp_GetTagForSuperCollection, - 1, - &noneSingletonWitnessTag); + getNoneWitness(), + taggedUnionType->getWitnessTableCollection()); auto newInst = builder.emitNeq(dynTag, noneWitnessTag); inst->replaceUsesWith(newInst); @@ -4584,6 +4265,12 @@ struct TypeFlowSpecializationContext // performInformationPropagation(); + if (sink->getErrorCount() > 0) + { + // If there were errors during propagation, we bail out early. + return false; + } + // Phase 2: Dynamic Instruction Specialization // Re-write dynamic instructions into specialized versions based on the // type information in the previous phase. @@ -4629,12 +4316,6 @@ struct TypeFlowSpecializationContext // Set of already discovered contexts. HashSet availableContexts; - // Contexts requiring lowering - HashSet contextsToLower; - - // Lowered contexts. - Dictionary loweredContexts; - // Helper for building collections. CollectionBuilder cBuilder; }; From bba9b1a92f37c01e5c380bc810c6517fa83ce586 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 10:34:36 -0400 Subject: [PATCH 073/105] Fix unused vars --- external/slang-rhi | 2 +- external/spirv-headers | 2 +- external/spirv-tools | 2 +- source/slang/slang-ir-typeflow-specialize.cpp | 66 ++++--------------- 4 files changed, 14 insertions(+), 58 deletions(-) diff --git a/external/slang-rhi b/external/slang-rhi index 3ca6ff61a45..5b847c52a47 160000 --- a/external/slang-rhi +++ b/external/slang-rhi @@ -1 +1 @@ -Subproject commit 3ca6ff61a45c41012192aa46044efed3d39a7836 +Subproject commit 5b847c52a476888257a6f4beddef5801c0a5f0b7 diff --git a/external/spirv-headers b/external/spirv-headers index 9268f305735..01e0577914a 160000 --- a/external/spirv-headers +++ b/external/spirv-headers @@ -1 +1 @@ -Subproject commit 9268f3057354a2cb65991ba5f38b16d81e803692 +Subproject commit 01e0577914a75a2569c846778c2f93aa8e6feddd diff --git a/external/spirv-tools b/external/spirv-tools index 42c2fe22df0..7f2d9ee926f 160000 --- a/external/spirv-tools +++ b/external/spirv-tools @@ -1 +1 @@ -Subproject commit 42c2fe22df033e5f0ab1c2ba4d0079a2a72a16b8 +Subproject commit 7f2d9ee926f98fc77a3ed1e1e0f113b8c9c49458 diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 77a06e437c2..31f8a635911 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -131,7 +131,7 @@ bool isNoneCallee(IRInst* callee) { if (auto lookupWitness = as(callee)) { - if (auto table = as(callee->getOperand(0))) + if (auto table = as(lookupWitness->getWitnessTable())) { return table->getConcreteType()->getOp() == kIROp_VoidType; } @@ -1774,10 +1774,6 @@ struct TypeFlowSpecializationContext // Handle the 'many' or 'one' cases. if (as(operandInfo) || isGlobalInst(operand)) { - // If any of the specialization arguments need a tag (or the generic itself is a tag), - // we need the result to also be wrapped in a tag type. - bool needsElement = false; - List specializationArgs; for (UInt i = 0; i < inst->getArgCount(); ++i) { @@ -1809,7 +1805,6 @@ struct TypeFlowSpecializationContext elementOfCollectionType->getCollection()->getElement(0)); else { - needsElement = true; if (auto typeCollection = as(elementOfCollectionType->getCollection())) { @@ -2794,21 +2789,8 @@ struct TypeFlowSpecializationContext return true; } - // If the collection is a type-collection, we'll still do a direct replacement - // effectively dropping the tag information - // - if (auto typeCollection = as(elementOfCollectionType->getCollection())) - { - // If this is a type collection, we can replace it with the collection type - // We don't currently care about the tag of a type. - // - inst->replaceUsesWith(typeCollection); - inst->removeAndDeallocate(); - return true; - } - - // If we reach here, we have a truly dynamic case. Multiple elements and not a type - // collection We need to emit a run-time inst to keep track of the tag. + // If we reach here, we have a truly dynamic case. Multiple elements. + // We need to emit a run-time inst to keep track of the tag. // // We use the GetTagForMappedCollection inst to do this, and set its data type to // the appropriate tag-type. @@ -2817,7 +2799,7 @@ struct TypeFlowSpecializationContext auto witnessTableInst = inst->getWitnessTable(); auto witnessTableInfo = witnessTableInst->getDataType(); - if (auto witnessTableOperandTagType = as(witnessTableInfo)) + if (as(witnessTableInfo)) { auto thisInstInfo = cast(tryGetInfo(context, inst)); if (thisInstInfo->getCollection() != nullptr) @@ -2846,8 +2828,9 @@ struct TypeFlowSpecializationContext // If we have a non-trivial info registered, it must of // CollectionTagType(WitnessTableCollection(...)) // - // Futher, the operand must be an existential (CollectionTaggedUnionType), which is - // conceptually lowered to a TupleType(TagType(tableCollection), typeCollection) + // Further, the operand must be an existential (CollectionTaggedUnionType), which is + // conceptually a pair of TagType(tableCollection) and a + // ValueOfCollectionType(typeCollection) // // We will simply extract the first element of this tuple. // @@ -3001,32 +2984,12 @@ struct TypeFlowSpecializationContext IRWitnessTableCollection* tableCollection, IRStructKey* key, IRFuncCollection* resultFuncCollection) - { - List specArgs; - return getEffectiveFuncTypeForDispatcher( - tableCollection, - key, - resultFuncCollection, - specArgs); - } - - IRFuncType* getEffectiveFuncTypeForDispatcher( - IRWitnessTableCollection* tableCollection, - IRStructKey* key, - IRFuncCollection* resultFuncCollection, - List& specArgs) { SLANG_UNUSED(key); List extraParamTypes; extraParamTypes.add((IRType*)makeTagType(tableCollection)); - /* - for (auto specArg : specArgs) - if (as(specArg)) - extraParamTypes.add((IRType*)makeTagType(as(specArg))); - */ - auto innerFuncType = getEffectiveFuncTypeForCollection(resultFuncCollection); List allParamTypes; allParamTypes.addRange(extraParamTypes); @@ -3261,11 +3224,6 @@ struct TypeFlowSpecializationContext auto callee = inst->getCallee(); - // TODO: Can remove this workaround since we're lowering these immediately. - if (as(callee) || - as(callee)) // Already specialized - return false; - if (isNoneCallee(callee)) return false; @@ -3329,7 +3287,7 @@ struct TypeFlowSpecializationContext auto lookupKey = cast(innerTagMapOperand->getOperand(1)); List specArgs; - for (auto argIdx = 1; argIdx < specializedTagMapOperand->getOperandCount(); + for (UInt argIdx = 1; argIdx < specializedTagMapOperand->getOperandCount(); ++argIdx) { auto arg = specializedTagMapOperand->getOperand(argIdx); @@ -3359,8 +3317,7 @@ struct TypeFlowSpecializationContext getEffectiveFuncTypeForDispatcher( tableCollection, lookupKey, - cast(collectionTag->getCollection()), - specArgs), + cast(collectionTag->getCollection())), tableCollection, lookupKey, specArgs); @@ -3783,7 +3740,7 @@ struct TypeFlowSpecializationContext List specOperands; specOperands.add(inst->getBase()); - for (auto ii = 0; ii < inst->getArgCount(); ii++) + for (UInt ii = 0; ii < inst->getArgCount(); ii++) specOperands.add(inst->getArg(ii)); auto newInst = builder.emitIntrinsicInst( @@ -4133,8 +4090,7 @@ struct TypeFlowSpecializationContext bool specializeGetOptionalValue(IRInst* context, IRGetOptionalValue* inst) { SLANG_UNUSED(context); - if (auto taggedUnionType = - as(inst->getOptionalOperand()->getDataType())) + if (as(inst->getOptionalOperand()->getDataType())) { // Since `GetOptionalValue` is the reverse of `MakeOptionalValue`, and we treat // the latter as a no-op, then `GetOptionalValue` is also a no-op (we simply pass From 924e33258da4060a5355e50728770636c49d95c4 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 10:50:05 -0400 Subject: [PATCH 074/105] Update slang-ir-typeflow-specialize.cpp --- source/slang/slang-ir-typeflow-specialize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 31f8a635911..dd9d7c9ea72 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -4068,7 +4068,7 @@ struct TypeFlowSpecializationContext bool specializeMakeOptionalValue(IRInst* context, IRMakeOptionalValue* inst) { SLANG_UNUSED(context); - if (auto taggedUnionType = as(inst->getValue()->getDataType())) + if (as(inst->getValue()->getDataType())) { // If we're dealing with a `MakeOptionalValue` for an existential type, // we don't actually have to change anything, since logically, the input and output From 84395e9aded6c444e3f357296e95f06d58036b70 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 11:34:55 -0400 Subject: [PATCH 075/105] Move set builder logic into IRBuilder --- source/slang/slang-ir-insts.h | 26 ++++ .../slang/slang-ir-lower-typeflow-insts.cpp | 35 ++--- source/slang/slang-ir-specialize.cpp | 7 +- source/slang/slang-ir-typeflow-collection.cpp | 139 +----------------- source/slang/slang-ir-typeflow-collection.h | 67 --------- source/slang/slang-ir-typeflow-specialize.cpp | 98 ++++++------ source/slang/slang-ir.cpp | 68 +++++++++ 7 files changed, 169 insertions(+), 271 deletions(-) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 8e0b2a99c2c..e44f4ba233d 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4207,6 +4207,22 @@ struct IRBuilder emitIntrinsicInst(nullptr, kIROp_ElementOfCollectionType, 1, &operand)); } + IRCollectionTaggedUnionType* getCollectionTaggedUnionType( + IRWitnessTableCollection* tables, + IRTypeCollection* types) + { + IRInst* operands[] = {tables, types}; + return as( + emitIntrinsicInst(nullptr, kIROp_CollectionTaggedUnionType, 2, operands)); + } + + IRCollectionTagType* getCollectionTagType(IRCollectionBase* collection) + { + IRInst* operands[] = {collection}; + return cast( + emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, operands)); + } + IRGetTagOfElementInCollection* emitGetTagOfElementInCollection( IRType* tagType, IRInst* element, @@ -5000,6 +5016,16 @@ struct IRBuilder } void addRayPayloadDecoration(IRType* inst) { addDecoration(inst, kIROp_RayPayloadDecoration); } + + IRCollectionBase* getCollection(IROp op, const HashSet& elements); + IRCollectionBase* getCollection(const HashSet& elements); + + IRCollectionBase* getSingletonCollection(IROp op, IRInst* element); + IRCollectionBase* getSingletonCollection(IRInst* element); + + UInt getUniqueID(IRInst* inst); + + IROp getCollectionTypeForInst(IRInst* inst); }; // Helper to establish the source location that will be used diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index d56648df978..fc5aa612331 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -216,7 +216,7 @@ IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mappi struct TagOpsLoweringContext : public InstPassBase { TagOpsLoweringContext(IRModule* module) - : InstPassBase(module), cBuilder(module) + : InstPassBase(module) { } @@ -254,8 +254,8 @@ struct TagOpsLoweringContext : public InstPassBase // it must have been assigned a unique ID. // mapping.add( - cBuilder.getUniqueID(srcCollection->getElement(i)), - cBuilder.getUniqueID(destElement)); + builder.getUniqueID(srcCollection->getElement(i)), + builder.getUniqueID(destElement)); break; // Found the index } } @@ -283,7 +283,7 @@ struct TagOpsLoweringContext : public InstPassBase IRBuilder builder(inst->getModule()); builder.setInsertAfter(inst); - auto uniqueId = cBuilder.getUniqueID(inst->getOperand(0)); + auto uniqueId = builder.getUniqueID(inst->getOperand(0)); auto resultValue = builder.getIntValue(inst->getDataType(), uniqueId); inst->replaceUsesWith(resultValue); inst->removeAndDeallocate(); @@ -311,8 +311,6 @@ struct TagOpsLoweringContext : public InstPassBase { processAllInsts([&](IRInst* inst) { return processInst(inst); }); } - - CollectionBuilder cBuilder; }; struct DispatcherLoweringContext : public InstPassBase @@ -513,7 +511,7 @@ void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) struct SequentialIDTagLoweringContext : public InstPassBase { SequentialIDTagLoweringContext(IRModule* module) - : InstPassBase(module), cBuilder(module) + : InstPassBase(module) { } void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) @@ -537,12 +535,15 @@ struct SequentialIDTagLoweringContext : public InstPassBase // Map from sequential ID to unique ID auto destCollection = cast(inst->getDataType())->getCollection(); + IRBuilder builder(inst); + builder.setInsertAfter(inst); + forEachInCollection( destCollection, [&](IRInst* table) { // Get unique ID for the witness table - auto outputId = cBuilder.getUniqueID(table); + auto outputId = builder.getUniqueID(table); auto seqDecoration = table->findDecoration(); if (seqDecoration) { @@ -551,9 +552,6 @@ struct SequentialIDTagLoweringContext : public InstPassBase } }); - IRBuilder builder(inst); - builder.setInsertAfter(inst); - // By default, use the tag for the largest available sequential ID. UInt defaultSeqID = 0; for (auto [inputId, outputId] : mapping) @@ -586,13 +584,16 @@ struct SequentialIDTagLoweringContext : public InstPassBase // Map from sequential ID to unique ID auto destCollection = cast(srcTagInst->getDataType())->getCollection(); + IRBuilder builder(inst); + builder.setInsertAfter(inst); + forEachInCollection( destCollection, [&](IRInst* table) { // Get unique ID for the witness table SLANG_UNUSED(cast(table)); - auto outputId = cBuilder.getUniqueID(table); + auto outputId = builder.getUniqueID(table); auto seqDecoration = table->findDecoration(); if (seqDecoration) { @@ -601,8 +602,6 @@ struct SequentialIDTagLoweringContext : public InstPassBase } }); - IRBuilder builder(inst); - builder.setInsertAfter(inst); auto translatedID = builder.emitCallInst( inst->getDataType(), createIntegerMappingFunc(builder.getModule(), mapping, 0), @@ -622,8 +621,6 @@ struct SequentialIDTagLoweringContext : public InstPassBase kIROp_GetSequentialIDFromTag, [&](IRGetSequentialIDFromTag* inst) { return lowerGetSequentialIDFromTag(inst); }); } - - CollectionBuilder cBuilder; }; void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) @@ -838,11 +835,11 @@ struct TaggedUnionLoweringContext : public InstPassBase if (taggedUnion->getTypeCollection()->isSingleton()) return builder.getTupleType(List( - {(IRType*)makeTagType(tableCollection), + {(IRType*)builder.getCollectionTagType(tableCollection), (IRType*)taggedUnion->getTypeCollection()->getElement(0)})); - return builder.getTupleType( - List({(IRType*)makeTagType(tableCollection), (IRType*)typeCollection})); + return builder.getTupleType(List( + {(IRType*)builder.getCollectionTagType(tableCollection), (IRType*)typeCollection})); } bool lowerGetValueFromTaggedUnion(IRGetValueFromTaggedUnion* inst) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 26d05cdf1ae..a0f1d2704f7 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -889,12 +889,11 @@ struct SpecializationContext if (!skipSpecialization) { - CollectionBuilder cBuilder(lookupInst->getModule()); - auto newCollection = cBuilder.makeSet(satisfyingValSet); + IRBuilder builder(module); + auto newCollection = builder.getCollection(satisfyingValSet); addUsersToWorkList(lookupInst); if (as(newCollection)) { - IRBuilder builder(module); lookupInst->replaceUsesWith( builder.getValueOfCollectionType(newCollection)); lookupInst->removeAndDeallocate(); @@ -3275,7 +3274,7 @@ IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) // We'll create an integer parameter for all the rest of // the insts which will may need the runtime tag. // - auto tagType = (IRType*)makeTagType(collection); + auto tagType = (IRType*)builder.getCollectionTagType(collection); // cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); extraParamMap.add(param, builder.emitParam(tagType)); extraParamTypes.add(tagType); diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-collection.cpp index dff7eb45a9f..1ed7026dfe9 100644 --- a/source/slang/slang-ir-typeflow-collection.cpp +++ b/source/slang/slang-ir-typeflow-collection.cpp @@ -7,129 +7,6 @@ namespace Slang { -IRCollectionTagType* makeTagType(IRCollectionBase* collection) -{ - IRInst* collectionInst = collection; - // Create the tag type from the collection - IRBuilder builder(collection->getModule()); - return as( - builder.emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collectionInst)); -} - -UCount getCollectionCount(IRCollectionBase* collection) -{ - if (!collection) - return 0; - return collection->getOperandCount(); -} - -UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion) -{ - auto typeCollection = taggedUnion->getTypeCollection(); - return getCollectionCount(as(typeCollection)); -} - -UCount getCollectionCount(IRCollectionTagType* tagType) -{ - auto collection = tagType->getCollection(); - return getCollectionCount(as(collection)); -} - -IRInst* getCollectionElement(IRCollectionBase* collection, UInt index) -{ - if (!collection || index >= collection->getOperandCount()) - return nullptr; - return collection->getOperand(index); -} - -IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index) -{ - auto collection = collectionTagType->getCollection(); - return getCollectionElement(as(collection), index); -} - -CollectionBuilder::CollectionBuilder(IRModule* module) - : module(module) -{ - this->uniqueIds = module->getUniqueIdMap(); -} - -UInt CollectionBuilder::getUniqueID(IRInst* inst) -{ - auto existingId = uniqueIds->tryGetValue(inst); - if (existingId) - return *existingId; - - auto id = uniqueIds->getCount(); - uniqueIds->add(inst, id); - return id; -} - -// Helper method for creating canonical collections -IRCollectionBase* CollectionBuilder::createCollection(IROp op, const HashSet& elements) -{ - SLANG_ASSERT( - op == kIROp_TypeCollection || op == kIROp_FuncCollection || - op == kIROp_WitnessTableCollection || op == kIROp_GenericCollection); - - if (elements.getCount() == 0) - return nullptr; - - // Verify that all operands are global instructions - for (auto element : elements) - if (element->getParent()->getOp() != kIROp_ModuleInst) - SLANG_ASSERT_FAILURE("createCollection called with non-global operands"); - - List sortedElements; - for (auto element : elements) - sortedElements.add(element); - - // Sort elements by their unique IDs to ensure canonical ordering - sortedElements.sort( - [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); - - // Create the collection instruction - IRBuilder builder(module); - builder.setInsertInto(module); - - return as(builder.emitIntrinsicInst( - nullptr, - op, - sortedElements.getCount(), - sortedElements.getBuffer())); -} - -IROp CollectionBuilder::getCollectionTypeForInst(IRInst* inst) -{ - if (as(inst)) - return kIROp_GenericCollection; - - if (as(inst->getDataType())) - return kIROp_TypeCollection; - else if (as(inst->getDataType())) - return kIROp_FuncCollection; - else if (as(inst) && !as(inst)) - return kIROp_TypeCollection; - else if (as(inst->getDataType())) - return kIROp_WitnessTableCollection; - else - return kIROp_Invalid; // Return invalid IROp when not supported -} - -// Factory methods for PropagationInfo -IRCollectionBase* CollectionBuilder::makeSingletonSet(IRInst* value) -{ - HashSet singleSet; - singleSet.add(value); - return createCollection(getCollectionTypeForInst(value), singleSet); -} - -IRCollectionBase* CollectionBuilder::makeSet(const HashSet& values) -{ - SLANG_ASSERT(values.getCount() > 0); - return createCollection(getCollectionTypeForInst(*values.begin()), values); -} - // Upcast the value in 'arg' to match the destInfo type. This method inserts // any necessary reinterprets or tag translation instructions. // @@ -159,19 +36,17 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) auto argTUType = as(argInfo); auto destTUType = as(destInfo); - if (getCollectionCount(argTUType) != getCollectionCount(destTUType)) + if (argTUType != destTUType) { // Technically, IRCollectionTaggedUnionType is not a TupleType, // but in practice it works the same way so we'll re-use Slang's // tuple accessors & constructors // - // IRBuilder builder(module); - // setInsertAfterOrdinaryInst(&builder, arg); auto argTableTag = builder->emitGetTagFromTaggedUnion(arg); auto reinterpretedTag = upcastCollection( builder, argTableTag, - makeTagType(destTUType->getWitnessTableCollection())); + builder->getCollectionTagType(destTUType->getWitnessTableCollection())); auto argVal = builder->emitGetValueFromTaggedUnion(arg); auto reinterpretedVal = upcastCollection( @@ -190,8 +65,7 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) // Note that, by the invariant provided by the typeflow analysis, the target // collection must necessarily be a super-set. // - if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) + if (argInfo != destInfo) { return builder ->emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); @@ -207,8 +81,7 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) // is necessarily a super-set, the target any-value-type is always larger (M >= N), // so we only need a simple reinterpret. // - if (getCollectionCount(as(argInfo)->getCollection()) != - getCollectionCount(as(destInfo)->getCollection())) + if (argInfo != destInfo) { auto argCollection = as(argInfo)->getCollection(); if (argCollection->isSingleton() && as(argCollection->getElement(0))) @@ -217,8 +90,6 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) // type. We'll avoid emitting a reinterpret in this case, and emit a // default-construct instead. // - // IRBuilder builder(module); - // setInsertAfterOrdinaryInst(&builder, arg); return builder->emitDefaultConstruct((IRType*)destInfo); } @@ -227,8 +98,6 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) // If the sets of witness tables are not equal, reinterpret to the // parameter type // - // IRBuilder builder(module); - // setInsertAfterOrdinaryInst(&builder, arg); return builder->emitReinterpret((IRType*)destInfo, arg); } } diff --git a/source/slang/slang-ir-typeflow-collection.h b/source/slang/slang-ir-typeflow-collection.h index 13f51f4ecd1..5c3bb233e5d 100644 --- a/source/slang/slang-ir-typeflow-collection.h +++ b/source/slang/slang-ir-typeflow-collection.h @@ -6,19 +6,6 @@ namespace Slang { -IRCollectionTagType* makeTagType(IRCollectionBase* collection); - -// -// Count and indexing helpers -// - -UCount getCollectionCount(IRCollectionBase* collection); -UCount getCollectionCount(IRCollectionTaggedUnionType* taggedUnion); -UCount getCollectionCount(IRCollectionTagType* tagType); - -IRInst* getCollectionElement(IRCollectionBase* collection, UInt index); -IRInst* getCollectionElement(IRCollectionTagType* collectionTagType, UInt index); - // // Helpers to iterate over elements of a collection. // @@ -41,58 +28,4 @@ void forEachInCollection(IRCollectionTagType* tagType, F func) // IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo); -// Builder class that helps greatly with constructing `CollectionBase` instructions, -// which conceptually represent sets, and maintain the property that the equal sets -// should always be represented by the same instruction. -// -// Uses a unique ID assignment to keep stable ordering throughout the lifetime of the -// module. -// -struct CollectionBuilder -{ - // Get a collection builder for 'module'. - CollectionBuilder(IRModule* module); - - // Create an inst to represent the elements in the set. - // - // All insts in `elements` must be global and concrete. They must not - // be collections themselves. - // - // Op must be one of the ops in `CollectionBase` - // - // For a given set, the returned inst is always the same within a single - // module. - // - IRCollectionBase* createCollection(IROp op, const HashSet& elements); - - // Get a suitable collection op-code to use for an set containing 'inst'. - IROp getCollectionTypeForInst(IRInst* inst); - - // Create a collection with a single element - IRCollectionBase* makeSingletonSet(IRInst* value); - - // Create a collection with the given elements (the collection op will be - // automatically deduced using` getCollectionTypeForInst`) - // - IRCollectionBase* makeSet(const HashSet& values); - - // Return a unique ID for the inst. Assuming the module pointer - // is consistent, this should always be the same for a given inst. - // - UInt getUniqueID(IRInst* inst); - -private: - // Reference to parent module - IRModule* module; - - // Unique ID assignment for functions and witness tables. - // - // This is a pointer to a shared dictionary (typically - // a part of the module inst) so that all CollectionBuilder - // objects for the same module will always produce the same - // ordering. - // - Dictionary* uniqueIds; -}; - } // namespace Slang diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index dd9d7c9ea72..7f9b7eccd81 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -393,9 +393,9 @@ std::tuple splitParameterDirectionAndType(IRTyp } // Join parameter direction and a type back into a parameter type -IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo direction, IRType* type) +IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo info, IRType* type) { - switch (direction.kind) + switch (info.kind) { case ParameterDirectionInfo::Kind::In: return type; @@ -404,11 +404,11 @@ IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo directio case ParameterDirectionInfo::Kind::BorrowInOut: return builder->getBorrowInOutParamType(type); case ParameterDirectionInfo::Kind::BorrowIn: - return builder->getBorrowInParamType(type, direction.addressSpace); + return builder->getBorrowInParamType(type, info.addressSpace); case ParameterDirectionInfo::Kind::Ref: - return builder->getRefParamType(type, direction.addressSpace); + return builder->getRefParamType(type, info.addressSpace); default: - SLANG_UNEXPECTED("Unhandled parameter direction in fromDirectionAndType"); + SLANG_UNEXPECTED("Unhandled parameter info in fromDirectionAndType"); } } @@ -457,6 +457,7 @@ struct TypeFlowSpecializationContext // IRCollectionTaggedUnionType* makeTaggedUnionType(IRWitnessTableCollection* tableCollection) { + IRBuilder builder(module); HashSet typeSet; // Create a type collection out of the base types from each table. @@ -468,16 +469,11 @@ struct TypeFlowSpecializationContext typeSet.add(table->getConcreteType()); }); - auto typeCollection = cBuilder.createCollection(kIROp_TypeCollection, typeSet); + auto typeCollection = + cast(builder.getCollection(kIROp_TypeCollection, typeSet)); // Create the tagged union type out of the type and table collection. - IRBuilder builder(module); - List elements = {tableCollection, typeCollection}; - return as(builder.emitIntrinsicInst( - nullptr, - kIROp_CollectionTaggedUnionType, - elements.getCount(), - elements.getBuffer())); + return builder.getCollectionTaggedUnionType(tableCollection, typeCollection); } // Create an unbounded collection. @@ -510,17 +506,19 @@ struct TypeFlowSpecializationContext IRValueOfCollectionType* makeValueOfCollectionType(IRTypeCollection* typeCollection) { IRBuilder builder(module); - IRInst* operand = typeCollection; - return cast( - builder.emitIntrinsicInst(nullptr, kIROp_ValueOfCollectionType, 1, &operand)); + return builder.getValueOfCollectionType(typeCollection); } IRElementOfCollectionType* makeElementOfCollectionType(IRCollectionBase* collection) { IRBuilder builder(module); - IRInst* operand = collection; - return cast( - builder.emitIntrinsicInst(nullptr, kIROp_ElementOfCollectionType, 1, &operand)); + return builder.getElementOfCollectionType(collection); + } + + IRCollectionTagType* makeTagType(IRCollectionBase* collection) + { + IRBuilder builder(module); + return builder.getCollectionTagType(collection); } IRInst* _tryGetInfo(InstWithContext element) @@ -569,8 +567,8 @@ struct TypeFlowSpecializationContext { if (isConcreteType(ptrType->getValueType())) return builder.getPtrTypeWithAddressSpace( - builder.getValueOfCollectionType( - cast(cBuilder.makeSingletonSet(ptrType->getValueType()))), + builder.getValueOfCollectionType(cast( + builder.getSingletonCollection(ptrType->getValueType()))), ptrType); else return none(); @@ -582,7 +580,7 @@ struct TypeFlowSpecializationContext { return builder.getArrayType( builder.getValueOfCollectionType(cast( - cBuilder.makeSingletonSet(arrayType->getElementType()))), + builder.getSingletonCollection(arrayType->getElementType()))), arrayType->getElementCount()); } else @@ -591,7 +589,7 @@ struct TypeFlowSpecializationContext if (isConcreteType(inst->getDataType())) return builder.getValueOfCollectionType( - cast(cBuilder.makeSingletonSet(inst->getDataType()))); + cast(builder.getSingletonCollection(inst->getDataType()))); else return none(); } @@ -651,7 +649,8 @@ struct TypeFlowSpecializationContext forEachInCollection(collection1, [&](IRInst* value) { allValues.add(value); }); forEachInCollection(collection2, [&](IRInst* value) { allValues.add(value); }); - return as(cBuilder.createCollection( + IRBuilder builder(module); + return as(builder.getCollection( collection1->getOp(), allValues)); // Create a new collection with the union of values } @@ -1178,7 +1177,7 @@ struct TypeFlowSpecializationContext { IRBuilder builder(module); returnInfo = builder.getValueOfCollectionType(cast( - cBuilder.makeSingletonSet(concreteReturnType))); + builder.getSingletonCollection(concreteReturnType))); } } @@ -1229,6 +1228,8 @@ struct TypeFlowSpecializationContext IRInst* analyzeCreateExistentialObject(IRInst* context, IRCreateExistentialObject* inst) { SLANG_UNUSED(context); + + IRBuilder builder(module); if (auto interfaceType = as(inst->getDataType())) { if (isComInterfaceType(interfaceType) || isBuiltin(interfaceType)) @@ -1240,7 +1241,7 @@ struct TypeFlowSpecializationContext auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) return makeTaggedUnionType(as( - cBuilder.createCollection(kIROp_WitnessTableCollection, tables))); + builder.getCollection(kIROp_WitnessTableCollection, tables))); else { sink->diagnose( @@ -1256,6 +1257,7 @@ struct TypeFlowSpecializationContext IRInst* analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) { + IRBuilder builder(module); auto witnessTable = inst->getWitnessTable(); // If we're building an existential for a COM interface, @@ -1268,7 +1270,7 @@ struct TypeFlowSpecializationContext // Concrete case. if (as(witnessTable)) return makeTaggedUnionType( - as(cBuilder.makeSingletonSet(witnessTable))); + as(builder.getSingletonCollection(witnessTable))); // Get the witness table info auto witnessTableInfo = tryGetInfo(context, witnessTable); @@ -1323,6 +1325,7 @@ struct TypeFlowSpecializationContext IRInst* analyzeLoad(IRInst* context, IRInst* inst) { + IRBuilder builder(module); if (auto loadInst = as(inst)) { // If we have a simple load, theres one of two cases: @@ -1346,7 +1349,7 @@ struct TypeFlowSpecializationContext auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) return makeTaggedUnionType(as( - cBuilder.createCollection(kIROp_WitnessTableCollection, tables))); + builder.getCollection(kIROp_WitnessTableCollection, tables))); else return none(); } @@ -1376,7 +1379,7 @@ struct TypeFlowSpecializationContext auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) return makeTaggedUnionType(as( - cBuilder.createCollection(kIROp_WitnessTableCollection, tables))); + builder.getCollection(kIROp_WitnessTableCollection, tables))); else return none(); } @@ -1555,11 +1558,11 @@ struct TypeFlowSpecializationContext // type. // SLANG_UNUSED(context); + IRBuilder builder(module); if (isOptionalExistentialType(inst->getDataType())) { - IRBuilder builder(module); auto noneTableSet = cast( - cBuilder.createCollection(kIROp_WitnessTableCollection, getNoneWitness())); + builder.getCollection(kIROp_WitnessTableCollection, getNoneWitness())); return makeTaggedUnionType(noneTableSet); } @@ -1636,12 +1639,13 @@ struct TypeFlowSpecializationContext if (auto elementOfCollectionType = as(witnessTableInfo)) { + IRBuilder builder(module); HashSet results; forEachInCollection( cast(elementOfCollectionType->getCollection()), [&](IRInst* table) { results.add(findWitnessTableEntry(cast(table), key)); }); - return makeElementOfCollectionType(cBuilder.makeSet(results)); + return makeElementOfCollectionType(builder.getCollection(results)); } if (!witnessTableInfo) @@ -1929,7 +1933,8 @@ struct TypeFlowSpecializationContext builder.emitSpecializeInst(typeOfSpecialization, operand, specializationArgs)); } - return makeElementOfCollectionType(cBuilder.makeSet(specializedSet)); + IRBuilder builder(module); + return makeElementOfCollectionType(builder.getCollection(specializedSet)); } if (!operandInfo) @@ -2000,10 +2005,11 @@ struct TypeFlowSpecializationContext } else if (as(arg) || as(arg)) { + IRBuilder builder(module); updateInfo( context, param, - makeElementOfCollectionType(cBuilder.makeSingletonSet(arg)), + makeElementOfCollectionType(builder.getSingletonCollection(arg)), true, workQueue); } @@ -2377,12 +2383,14 @@ struct TypeFlowSpecializationContext IRTypeFlowData* analyzeDefault(IRInst* context, IRInst* inst) { SLANG_UNUSED(context); + IRBuilder builder(module); + // Check if this is a global concrete type, witness table, or function. // If so, it's a concrete element. We'll create a singleton set for it. if (isGlobalInst(inst) && (!as(inst) && (as(inst) || as(inst) || as(inst)))) - return cBuilder.makeSingletonSet(inst); + return builder.getSingletonCollection(inst); auto instType = inst->getDataType(); if (isGlobalInst(inst)) @@ -2943,8 +2951,8 @@ struct TypeFlowSpecializationContext } // If this is a collection, we need to create a new collection with the new type - auto newCollection = - cBuilder.createCollection(kIROp_TypeCollection, collectionElements); + IRBuilder builder(module); + auto newCollection = builder.getCollection(kIROp_TypeCollection, collectionElements); return makeValueOfCollectionType(cast(newCollection)); } else if (currentType == newType) @@ -2974,8 +2982,8 @@ struct TypeFlowSpecializationContext collectionElements.add(newType); // If this is a collection, we need to create a new collection with the new type - auto newCollection = - cBuilder.createCollection(kIROp_TypeCollection, collectionElements); + IRBuilder builder(module); + auto newCollection = builder.getCollection(kIROp_TypeCollection, collectionElements); return makeValueOfCollectionType(cast(newCollection)); } } @@ -3109,8 +3117,9 @@ struct TypeFlowSpecializationContext IRFuncType* getEffectiveFuncType(IRInst* callee) { + IRBuilder builder(module); return getEffectiveFuncTypeForCollection( - cast(cBuilder.makeSingletonSet(callee))); + cast(builder.getSingletonCollection(callee))); } // Helper function for specializing calls. @@ -3513,7 +3522,7 @@ struct TypeFlowSpecializationContext IRInst* witnessTableTag = nullptr; if (auto witnessTable = as(inst->getWitnessTable())) { - auto singletonTagType = makeTagType(cBuilder.makeSingletonSet(witnessTable)); + auto singletonTagType = makeTagType(builder.getSingletonCollection(witnessTable)); IRInst* tagValue = builder.emitGetTagOfElementInCollection( (IRType*)singletonTagType, witnessTable, @@ -4043,7 +4052,7 @@ struct TypeFlowSpecializationContext SLANG_ASSERT(taggedUnionType->getWitnessTableCollection()->isSingleton()); auto noneWitnessTable = taggedUnionType->getWitnessTableCollection()->getElement(0); - auto singletonTagType = makeTagType(cBuilder.makeSingletonSet(noneWitnessTable)); + auto singletonTagType = makeTagType(builder.getSingletonCollection(noneWitnessTable)); IRInst* zeroValueOfTagType = builder.emitGetTagOfElementInCollection( (IRType*)singletonTagType, noneWitnessTable, @@ -4237,7 +4246,7 @@ struct TypeFlowSpecializationContext } TypeFlowSpecializationContext(IRModule* module, DiagnosticSink* sink) - : module(module), sink(sink), cBuilder(module) + : module(module), sink(sink) { } @@ -4271,9 +4280,6 @@ struct TypeFlowSpecializationContext // Set of already discovered contexts. HashSet availableContexts; - - // Helper for building collections. - CollectionBuilder cBuilder; }; // Main entry point diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 031af91482b..f1463686a44 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6554,6 +6554,74 @@ IREntryPointLayout* IRBuilder::getEntryPointLayout( operands)); } +IRCollectionBase* IRBuilder::getCollection(IROp op, const HashSet& elements) +{ + if (elements.getCount() == 0) + return nullptr; + + // Verify that all operands are global instructions + for (auto element : elements) + if (element->getParent()->getOp() != kIROp_ModuleInst) + SLANG_ASSERT_FAILURE("createCollection called with non-global operands"); + + List sortedElements; + for (auto element : elements) + sortedElements.add(element); + + // Sort elements by their unique IDs to ensure canonical ordering + sortedElements.sort( + [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); + + return as( + emitIntrinsicInst(nullptr, op, sortedElements.getCount(), sortedElements.getBuffer())); +} + +IRCollectionBase* IRBuilder::getCollection(const HashSet& elements) +{ + SLANG_ASSERT(elements.getCount() > 0); + auto firstElement = *elements.begin(); + return getCollection(getCollectionTypeForInst(firstElement), elements); +} + +IRCollectionBase* IRBuilder::getSingletonCollection(IROp op, IRInst* element) +{ + return getCollection(op, {element}); +} + +IRCollectionBase* IRBuilder::getSingletonCollection(IRInst* element) +{ + return getCollection(getCollectionTypeForInst(element), {element}); +} + +UInt IRBuilder::getUniqueID(IRInst* inst) +{ + auto uniqueIDMap = getModule()->getUniqueIdMap(); + auto existingId = uniqueIDMap->tryGetValue(inst); + if (existingId) + return *existingId; + + auto id = uniqueIDMap->getCount(); + uniqueIDMap->add(inst, id); + return id; +} + +IROp IRBuilder::getCollectionTypeForInst(IRInst* inst) +{ + if (as(inst)) + return kIROp_GenericCollection; + + if (as(inst->getDataType())) + return kIROp_TypeCollection; + else if (as(inst->getDataType())) + return kIROp_FuncCollection; + else if (as(inst) && !as(inst)) + return kIROp_TypeCollection; + else if (as(inst->getDataType())) + return kIROp_WitnessTableCollection; + else + return kIROp_Invalid; // Return invalid IROp when not supported +} + // struct IRDumpContext From ea362d6649ecc21fe33a18c977733cdc35397c9d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 12:09:24 -0400 Subject: [PATCH 076/105] Bulk rename collection instructions --- source/slang/slang-emit.cpp | 2 +- .../slang/slang-ir-any-value-marshalling.cpp | 2 +- source/slang/slang-ir-insts-stable-names.lua | 26 +- source/slang/slang-ir-insts.h | 131 ++-- source/slang/slang-ir-insts.lua | 76 +-- source/slang/slang-ir-layout.cpp | 2 +- .../slang/slang-ir-lower-typeflow-insts.cpp | 147 +++-- source/slang/slang-ir-lower-typeflow-insts.h | 12 +- source/slang/slang-ir-specialize.cpp | 39 +- source/slang/slang-ir-typeflow-collection.cpp | 43 +- source/slang/slang-ir-typeflow-collection.h | 10 +- source/slang/slang-ir-typeflow-specialize.cpp | 584 +++++++++--------- source/slang/slang-ir-typeflow-specialize.h | 2 +- .../slang/slang-ir-witness-table-wrapper.cpp | 98 +-- source/slang/slang-ir.cpp | 38 +- 15 files changed, 540 insertions(+), 672 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index ef90309ca62..96296b2e081 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1150,7 +1150,7 @@ Result linkAndOptimizeIR( if (lowerTaggedUnionTypes(irModule, sink)) requiredLoweringPassSet.reinterpret = true; - lowerTypeCollections(irModule, sink); + lowerUntaggedUnionTypes(irModule, sink); if (requiredLoweringPassSet.reinterpret) lowerReinterpret(targetProgram, irModule, sink); diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 75a0842bcad..99f5698ddb0 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -1018,7 +1018,7 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) { return alignUp(offset, 4) + kRTTIHandleSize; } - case kIROp_CollectionTagType: + case kIROp_SetTagType: { return alignUp(offset, 4) + 4; } diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index b4eb342cd0e..cf06de96fa0 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -682,18 +682,18 @@ return { ["Attr.MemoryScope"] = 678, ["Undefined.LoadFromUninitializedMemory"] = 679, ["CUDA_LDG"] = 680, - ["TypeFlowData.CollectionBase.TypeCollection"] = 681, - ["TypeFlowData.CollectionBase.FuncCollection"] = 682, - ["TypeFlowData.CollectionBase.WitnessTableCollection"] = 683, - ["TypeFlowData.CollectionBase.GenericCollection"] = 684, - ["TypeFlowData.UnboundedCollection"] = 685, - ["Type.CollectionTagType"] = 686, - ["Type.CollectionTaggedUnionType"] = 687, + ["TypeFlowData.SetBase.TypeSet"] = 681, + ["TypeFlowData.SetBase.FuncSet"] = 682, + ["TypeFlowData.SetBase.WitnessTableSet"] = 683, + ["TypeFlowData.SetBase.GenericSet"] = 684, + ["TypeFlowData.UnboundedSet"] = 685, + ["Type.SetTagType"] = 686, + ["Type.TaggedUnionType"] = 687, ["CastInterfaceToTaggedUnionPtr"] = 688, ["CastTaggedUnionToInterfacePtr"] = 689, - ["GetTagForSuperCollection"] = 690, - ["GetTagForMappedCollection"] = 691, - ["GetTagForSpecializedCollection"] = 692, + ["GetTagForSuperSet"] = 690, + ["GetTagForMappedSet"] = 691, + ["GetTagForSpecializedSet"] = 692, ["GetTagFromSequentialID"] = 693, ["GetSequentialIDFromTag"] = 694, ["GetElementFromTag"] = 695, @@ -701,9 +701,9 @@ return { ["GetSpecializedDispatcher"] = 697, ["GetTagFromTaggedUnion"] = 698, ["GetValueFromTaggedUnion"] = 699, - ["Type.ValueOfCollectionType"] = 700, - ["Type.ElementOfCollectionType"] = 701, + ["Type.UntaggedUnionType"] = 700, + ["Type.ElementOfSetType"] = 701, ["MakeTaggedUnion"] = 702, ["GetTypeTagFromTaggedUnion"] = 703, - ["GetTagOfElementInCollection"] = 704 + ["GetTagOfElementInSet"] = 704 } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index e44f4ba233d..29c5b3707f9 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2843,7 +2843,7 @@ struct IRTypeFlowData : IRInst }; FIDDLE() -struct IRCollectionBase : IRTypeFlowData +struct IRSetBase : IRTypeFlowData { FIDDLE(baseInst()) UInt getCount() { return getOperandCount(); } @@ -2852,53 +2852,50 @@ struct IRCollectionBase : IRTypeFlowData }; FIDDLE() -struct IRWitnessTableCollection : IRCollectionBase +struct IRWitnessTableSet : IRSetBase { FIDDLE(leafInst()) }; FIDDLE() -struct IRTypeCollection : IRCollectionBase +struct IRTypeSet : IRSetBase { FIDDLE(leafInst()) }; FIDDLE() -struct IRCollectionTagType : IRType +struct IRSetTagType : IRType { FIDDLE(leafInst()) - IRCollectionBase* getCollection() { return as(getOperand(0)); } - bool isSingleton() { return getCollection()->isSingleton(); } + IRSetBase* getSet() { return as(getOperand(0)); } + bool isSingleton() { return getSet()->isSingleton(); } }; FIDDLE() -struct IRCollectionTaggedUnionType : IRType +struct IRTaggedUnionType : IRType { FIDDLE(leafInst()) - IRWitnessTableCollection* getWitnessTableCollection() - { - return as(getOperand(0)); - } - IRTypeCollection* getTypeCollection() { return as(getOperand(1)); } + IRWitnessTableSet* getWitnessTableSet() { return as(getOperand(0)); } + IRTypeSet* getTypeSet() { return as(getOperand(1)); } bool isSingleton() { - return getTypeCollection()->isSingleton() && getWitnessTableCollection()->isSingleton(); + return getTypeSet()->isSingleton() && getWitnessTableSet()->isSingleton(); } }; FIDDLE() -struct IRElementOfCollectionType : IRType +struct IRElementOfSetType : IRType { FIDDLE(leafInst()) - IRCollectionBase* getCollection() { return as(getOperand(0)); } + IRSetBase* getSet() { return as(getOperand(0)); } }; FIDDLE() -struct IRValueOfCollectionType : IRType +struct IRUntaggedUnionType : IRType { FIDDLE(leafInst()) - IRCollectionBase* getCollection() { return as(getOperand(0)); } + IRSetBase* getSet() { return as(getOperand(0)); } }; // Generate struct definitions for all IR instructions not explicitly defined in this file @@ -4111,29 +4108,28 @@ struct IRBuilder IRMetalSetIndices* emitMetalSetIndices(IRInst* index, IRInst* indices); // TODO: Move all the collection-based ops into the builder. - IRUnboundedCollection* emitUnboundedCollection() + IRUnboundedSet* emitUnboundedSet() { - return cast( - emitIntrinsicInst(nullptr, kIROp_UnboundedCollection, 0, nullptr)); + return cast(emitIntrinsicInst(nullptr, kIROp_UnboundedSet, 0, nullptr)); } IRGetElementFromTag* emitGetElementFromTag(IRInst* tag) { - auto tagType = cast(tag->getDataType()); - IRInst* collection = tagType->getCollection(); - auto elementType = cast( - emitIntrinsicInst(nullptr, kIROp_ElementOfCollectionType, 1, &collection)); + auto tagType = cast(tag->getDataType()); + IRInst* collection = tagType->getSet(); + auto elementType = cast( + emitIntrinsicInst(nullptr, kIROp_ElementOfSetType, 1, &collection)); return cast( emitIntrinsicInst(elementType, kIROp_GetElementFromTag, 1, &tag)); } IRGetTagFromTaggedUnion* emitGetTagFromTaggedUnion(IRInst* tag) { - auto taggedUnionType = cast(tag->getDataType()); + auto taggedUnionType = cast(tag->getDataType()); - IRInst* collection = taggedUnionType->getWitnessTableCollection(); - auto tableTagType = cast( - emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collection)); + IRInst* collection = taggedUnionType->getWitnessTableSet(); + auto tableTagType = + cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &collection)); return cast( emitIntrinsicInst(tableTagType, kIROp_GetTagFromTaggedUnion, 1, &tag)); @@ -4141,11 +4137,11 @@ struct IRBuilder IRGetTypeTagFromTaggedUnion* emitGetTypeTagFromTaggedUnion(IRInst* tag) { - auto taggedUnionType = cast(tag->getDataType()); + auto taggedUnionType = cast(tag->getDataType()); - IRInst* collection = taggedUnionType->getTypeCollection(); - auto typeTagType = cast( - emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collection)); + IRInst* collection = taggedUnionType->getTypeSet(); + auto typeTagType = + cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &collection)); return cast( emitIntrinsicInst(typeTagType, kIROp_GetTypeTagFromTaggedUnion, 1, &tag)); @@ -4153,36 +4149,33 @@ struct IRBuilder IRGetValueFromTaggedUnion* emitGetValueFromTaggedUnion(IRInst* taggedUnion) { - auto taggedUnionType = cast(taggedUnion->getDataType()); + auto taggedUnionType = cast(taggedUnion->getDataType()); - IRInst* typeCollection = taggedUnionType->getTypeCollection(); - auto valueOfTypeCollectionType = cast( - emitIntrinsicInst(nullptr, kIROp_ValueOfCollectionType, 1, &typeCollection)); + IRInst* typeSet = taggedUnionType->getTypeSet(); + auto valueOfTypeSetType = cast( + emitIntrinsicInst(nullptr, kIROp_UntaggedUnionType, 1, &typeSet)); - return cast(emitIntrinsicInst( - valueOfTypeCollectionType, - kIROp_GetValueFromTaggedUnion, - 1, - &taggedUnion)); + return cast( + emitIntrinsicInst(valueOfTypeSetType, kIROp_GetValueFromTaggedUnion, 1, &taggedUnion)); } IRGetDispatcher* emitGetDispatcher( IRFuncType* funcType, - IRWitnessTableCollection* witnessTableCollection, + IRWitnessTableSet* witnessTableSet, IRStructKey* key) { - IRInst* args[] = {witnessTableCollection, key}; + IRInst* args[] = {witnessTableSet, key}; return cast(emitIntrinsicInst(funcType, kIROp_GetDispatcher, 2, args)); } IRGetSpecializedDispatcher* emitGetSpecializedDispatcher( IRFuncType* funcType, - IRWitnessTableCollection* witnessTableCollection, + IRWitnessTableSet* witnessTableSet, IRStructKey* key, List const& specArgs) { List args; - args.add(witnessTableCollection); + args.add(witnessTableSet); args.add(key); for (auto specArg : specArgs) { @@ -4195,49 +4188,45 @@ struct IRBuilder args.getBuffer())); } - IRValueOfCollectionType* getValueOfCollectionType(IRInst* operand) + IRUntaggedUnionType* getUntaggedUnionType(IRInst* operand) { - return as( - emitIntrinsicInst(nullptr, kIROp_ValueOfCollectionType, 1, &operand)); + return as( + emitIntrinsicInst(nullptr, kIROp_UntaggedUnionType, 1, &operand)); } - IRElementOfCollectionType* getElementOfCollectionType(IRInst* operand) + IRElementOfSetType* getElementOfSetType(IRInst* operand) { - return as( - emitIntrinsicInst(nullptr, kIROp_ElementOfCollectionType, 1, &operand)); + return as( + emitIntrinsicInst(nullptr, kIROp_ElementOfSetType, 1, &operand)); } - IRCollectionTaggedUnionType* getCollectionTaggedUnionType( - IRWitnessTableCollection* tables, - IRTypeCollection* types) + IRTaggedUnionType* getTaggedUnionType(IRWitnessTableSet* tables, IRTypeSet* types) { IRInst* operands[] = {tables, types}; - return as( - emitIntrinsicInst(nullptr, kIROp_CollectionTaggedUnionType, 2, operands)); + return as( + emitIntrinsicInst(nullptr, kIROp_TaggedUnionType, 2, operands)); } - IRCollectionTagType* getCollectionTagType(IRCollectionBase* collection) + IRSetTagType* getSetTagType(IRSetBase* collection) { IRInst* operands[] = {collection}; - return cast( - emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, operands)); + return cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, operands)); } - IRGetTagOfElementInCollection* emitGetTagOfElementInCollection( + IRGetTagOfElementInSet* emitGetTagOfElementInSet( IRType* tagType, IRInst* element, IRInst* collection) { - SLANG_ASSERT(tagType->getOp() == kIROp_CollectionTagType); + SLANG_ASSERT(tagType->getOp() == kIROp_SetTagType); IRInst* args[] = {element, collection}; - return cast( - emitIntrinsicInst(tagType, kIROp_GetTagOfElementInCollection, 2, args)); + return cast( + emitIntrinsicInst(tagType, kIROp_GetTagOfElementInSet, 2, args)); } - IRCollectionTagType* getCollectionTagType(IRInst* collection) + IRSetTagType* getSetTagType(IRInst* collection) { - return cast( - emitIntrinsicInst(nullptr, kIROp_CollectionTagType, 1, &collection)); + return cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &collection)); } // @@ -5017,15 +5006,15 @@ struct IRBuilder void addRayPayloadDecoration(IRType* inst) { addDecoration(inst, kIROp_RayPayloadDecoration); } - IRCollectionBase* getCollection(IROp op, const HashSet& elements); - IRCollectionBase* getCollection(const HashSet& elements); + IRSetBase* getSet(IROp op, const HashSet& elements); + IRSetBase* getSet(const HashSet& elements); - IRCollectionBase* getSingletonCollection(IROp op, IRInst* element); - IRCollectionBase* getSingletonCollection(IRInst* element); + IRSetBase* getSingletonSet(IROp op, IRInst* element); + IRSetBase* getSingletonSet(IRInst* element); UInt getUniqueID(IRInst* inst); - IROp getCollectionTypeForInst(IRInst* inst); + IROp getSetTypeForInst(IRInst* inst); }; // Helper to establish the source location that will be used diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index da501d0575b..0b03bbda3aa 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -677,34 +677,34 @@ local insts = { }, }, }, - { ValueOfCollectionType = { + { UntaggedUnionType = { hoistable = true, -- A type that represents that the value's _type_ is one of types in the collection operand. } }, - { ElementOfCollectionType = { + { ElementOfSetType = { hoistable = true, -- A type that represents that the value must be an element of the collection operand. } }, - { CollectionTagType = { + { SetTagType = { hoistable = true, -- Represents a tag-type for a collection. -- - -- An inst whose type is CollectionTagType(collection) is semantically carrying a + -- An inst whose type is SetTagType(collection) is semantically carrying a -- run-time value that "picks" one of the elements of the collection operand. -- - -- Only operand is a CollectionBase + -- Only operand is a SetBase } }, - { CollectionTaggedUnionType = { + { TaggedUnionType = { hoistable = true, -- Represents a tagged union type. -- - -- An inst whose type is a CollectionTaggedUnionType(typeCollection, witnessTableCollection) is semantically carrying a tuple of - -- two values: a value of CollectionTagType(witnessTableCollection) to represent the tag, and a payload value of type - -- ValueOfCollectionType(typeCollection), which conceptually represents a union/"anyvalue" type. + -- An inst whose type is a TaggedUnionType(typeSet, witnessTableSet) is semantically carrying a tuple of + -- two values: a value of SetTagType(witnessTableSet) to represent the tag, and a payload value of type + -- UntaggedUnionType(typeSet), which conceptually represents a union/"anyvalue" type. -- -- This is most commonly used to specialize the type of existential insts once the possibilities can be statically determined. -- - -- Operands are a TypeCollection and a WitnessTableCollection that represent the possibilities of the existential + -- Operands are a TypeSet and a WitnessTableSet that represent the possibilities of the existential } } }, }, @@ -2683,15 +2683,15 @@ local insts = { -- A collection of IR instructions used for propagation analysis. hoistable = true, { - CollectionBase = { + SetBase = { -- Base class for all collection types. -- -- Semantically, collections model sets of concrete values, and use Slang's de-duplication infrastructure -- to allow set-equality to be the same as inst identity. -- - -- - Collection ops have one or more operands that represent the elements of the collection + -- - Set ops have one or more operands that represent the elements of the collection -- - -- - Collection ops must have at least one operand. A zero-operand collection is illegal. + -- - Set ops must have at least one operand. A zero-operand collection is illegal. -- The type-flow pass will represent this case using nullptr, so that uniqueness is preserved. -- -- - All operands of a collection _must_ be concrete, individual insts @@ -2701,21 +2701,21 @@ local insts = { -- -- - Since collections are hositable, collection ops should (consequently) only appear in the global scope. -- - -- - Collection operands must be consistently sorted. i.e. a TypeCollection(A, B) and TypeCollection(B, A) + -- - Set operands must be consistently sorted. i.e. a TypeSet(A, B) and TypeSet(B, A) -- cannot exist at the same time, but either one is okay. -- - -- - To help with the implementation of collections, the CollectionBuilder class is provided + -- - To help with the implementation of collections, the SetBuilder class is provided -- in slang-ir-typeflow-collection.h. - -- All collection insts must be built using the CollectionBuilder, which uses a persistent map on the module + -- All collection insts must be built using the SetBuilder, which uses a persistent map on the module -- inst to ensure stable ordering. -- - { TypeCollection = {} }, - { FuncCollection = {} }, - { WitnessTableCollection = {} }, - { GenericCollection = {} }, + { TypeSet = {} }, + { FuncSet = {} }, + { WitnessTableSet = {} }, + { GenericSet = {} }, }, }, - { UnboundedCollection = { + { UnboundedSet = { -- -- A catch-all opcode to represent unbounded collections during -- the type-flow specialization pass. @@ -2741,15 +2741,15 @@ local insts = { { CastTaggedUnionToInterfacePtr = { -- Cast a tagged-union pointer with a known set to a corresponding interface-typed pointer. } }, - { GetTagForSuperCollection = { + { GetTagForSuperSet = { -- Translate a tag from a set to its equivalent in a super-set -- TODO: Lower using a global ID and not local IDs + mapping ops. } }, - { GetTagForMappedCollection = { + { GetTagForMappedSet = { -- Translate a tag from a set to its equivalent in a different set -- based on a mapping induced by a lookup key } }, - { GetTagForSpecializedCollection = { + { GetTagForSpecializedSet = { -- Translate a tag from a generic set to its equivalent in a specialized set -- based on a mapping that is encoded in the operands of this tag instruction } }, @@ -2762,8 +2762,8 @@ local insts = { } }, { GetElementFromTag = { -- Translate a tag to its corresponding element in the collection. - -- Input's type: CollectionTagType(collection). - -- Output's type: ElementOfCollectionType(collection) + -- Input's type: SetTagType(collection). + -- Output's type: ElementOfSetType(collection) -- operands = {{"tag"}} } }, @@ -2775,11 +2775,11 @@ local insts = { -- or else this is a malformed inst. -- -- Output: a value of 'FuncType' that can be called. - -- This func-type will take a `TagType(witnessTableCollection)` as the first parameter to + -- This func-type will take a `TagType(witnessTableSet)` as the first parameter to -- discriminate which witness table to use, and the rest of the parameters. -- hoistable = true, - operands = {{"witnessTableCollection", "IRWitnessTableCollection"}, {"lookupKey", "IRStructKey"}} + operands = {{"witnessTableSet", "IRWitnessTableSet"}, {"lookupKey", "IRStructKey"}} } }, { GetSpecializedDispatcher = { -- Get a specialized dispatcher function for a given witness table set + key, where @@ -2791,36 +2791,36 @@ local insts = { -- A set of specialization arguments (these must be concrete/global types or collections) -- -- Output: a value of `FuncType` that can be called. - -- This func-type will take a `TagType(witnessTableCollection)` as the first parameter to + -- This func-type will take a `TagType(witnessTableSet)` as the first parameter to -- discriminate which generic to use, and the rest of the parameters. -- hoistable = true } }, { GetTagFromTaggedUnion = { -- Translate a tagged-union value to its corresponding tag in the tagged-union's set. - -- Input's type: CollectionTaggedUnionType(typeCollection, tableCollection) - -- Output's type: CollectionTagType(tableCollection) + -- Input's type: TaggedUnionType(typeSet, tableSet) + -- Output's type: SetTagType(tableSet) operands = {{"taggedUnionValue"}} } }, { GetTypeTagFromTaggedUnion = { -- Translate a tagged-union value to its corresponding type tag in the tagged-union's set. - -- Input's type: CollectionTaggedUnionType(typeCollection, tableCollection) - -- Output's type: CollectionTagType(typeCollection) + -- Input's type: TaggedUnionType(typeSet, tableSet) + -- Output's type: SetTagType(typeSet) operands = {{"taggedUnionValue"}} } }, { GetValueFromTaggedUnion = { -- Translate a tagged-union value to its corresponding value in the tagged-union's set. - -- Input's type: CollectionTaggedUnionType(typeCollection, tableCollection) - -- Output's type: ValueOfCollectionType(typeCollection) + -- Input's type: TaggedUnionType(typeSet, tableSet) + -- Output's type: UntaggedUnionType(typeSet) operands = {{"taggedUnionValue"}} } }, { MakeTaggedUnion = { -- Create a tagged-union value from a tag and a value. - -- Input's type: CollectionTagType(tableCollection), ValueOfCollectionType(typeCollection) - -- Output's type: CollectionTaggedUnionType(typeCollection, tableCollection) + -- Input's type: SetTagType(tableSet), UntaggedUnionType(typeSet) + -- Output's type: TaggedUnionType(typeSet, tableSet) operands = { { "tag" }, { "value" } }, } }, - { GetTagOfElementInCollection = { + { GetTagOfElementInSet = { -- Get the tag corresponding to an element in a collection. hoistable = true } }, diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index 594dbd6cb2c..f282be81a48 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -300,7 +300,7 @@ Result IRTypeLayoutRules::calcSizeAndAlignment( return SLANG_OK; } break; - case kIROp_CollectionTagType: + case kIROp_SetTagType: { outSizeAndAlignment->size = 4; outSizeAndAlignment->alignment = 4; diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index fc5aa612331..b1ab94013ee 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -210,8 +210,8 @@ IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mappi return func; } -// This context lowers `GetTagOfElementInCollection`, -// `GetTagForSuperCollection`, and `GetTagForMappedCollection` instructions, +// This context lowers `GetTagOfElementInSet`, +// `GetTagForSuperSet`, and `GetTagForMappedSet` instructions, // struct TagOpsLoweringContext : public InstPassBase { @@ -220,33 +220,32 @@ struct TagOpsLoweringContext : public InstPassBase { } - void lowerGetTagForSuperCollection(IRGetTagForSuperCollection* inst) + void lowerGetTagForSuperSet(IRGetTagForSuperSet* inst) { inst->replaceUsesWith(inst->getOperand(0)); inst->removeAndDeallocate(); } - void lowerGetTagForMappedCollection(IRGetTagForMappedCollection* inst) + void lowerGetTagForMappedSet(IRGetTagForMappedSet* inst) { - auto srcCollection = cast( - cast(inst->getOperand(0)->getDataType())->getOperand(0)); - auto destCollection = - cast(cast(inst->getDataType())->getOperand(0)); + auto srcSet = cast( + cast(inst->getOperand(0)->getDataType())->getOperand(0)); + auto destSet = cast(cast(inst->getDataType())->getOperand(0)); auto key = cast(inst->getOperand(1)); IRBuilder builder(inst->getModule()); builder.setInsertAfter(inst); Dictionary mapping; - for (UInt i = 0; i < srcCollection->getCount(); i++) + for (UInt i = 0; i < srcSet->getCount(); i++) { - // Find in destCollection + // Find in destSet bool found = false; auto srcMappedElement = - findWitnessTableEntry(cast(srcCollection->getElement(i)), key); - for (UInt j = 0; j < destCollection->getCount(); j++) + findWitnessTableEntry(cast(srcSet->getElement(i)), key); + for (UInt j = 0; j < destSet->getCount(); j++) { - auto destElement = destCollection->getElement(j); + auto destElement = destSet->getElement(j); if (srcMappedElement == destElement) { found = true; @@ -254,7 +253,7 @@ struct TagOpsLoweringContext : public InstPassBase // it must have been assigned a unique ID. // mapping.add( - builder.getUniqueID(srcCollection->getElement(i)), + builder.getUniqueID(srcSet->getElement(i)), builder.getUniqueID(destElement)); break; // Found the index } @@ -262,7 +261,7 @@ struct TagOpsLoweringContext : public InstPassBase if (!found) { - // destCollection must be a super-set + // destSet must be a super-set SLANG_UNEXPECTED("Element not found in destination collection"); } } @@ -278,7 +277,7 @@ struct TagOpsLoweringContext : public InstPassBase inst->removeAndDeallocate(); } - void lowerGetTagOfElementInCollection(IRGetTagOfElementInCollection* inst) + void lowerGetTagOfElementInSet(IRGetTagOfElementInSet* inst) { IRBuilder builder(inst->getModule()); builder.setInsertAfter(inst); @@ -293,14 +292,14 @@ struct TagOpsLoweringContext : public InstPassBase { switch (inst->getOp()) { - case kIROp_GetTagForSuperCollection: - lowerGetTagForSuperCollection(as(inst)); + case kIROp_GetTagForSuperSet: + lowerGetTagForSuperSet(as(inst)); break; - case kIROp_GetTagForMappedCollection: - lowerGetTagForMappedCollection(as(inst)); + case kIROp_GetTagForMappedSet: + lowerGetTagForMappedSet(as(inst)); break; - case kIROp_GetTagOfElementInCollection: - lowerGetTagOfElementInCollection(as(inst)); + case kIROp_GetTagOfElementInSet: + lowerGetTagOfElementInSet(as(inst)); break; default: break; @@ -328,7 +327,7 @@ struct DispatcherLoweringContext : public InstPassBase // We'll also replace the callee in all 'call' insts. // - auto witnessTableCollection = cast(dispatcher->getOperand(0)); + auto witnessTableSet = cast(dispatcher->getOperand(0)); auto key = cast(dispatcher->getOperand(1)); List specArgs; @@ -339,8 +338,8 @@ struct DispatcherLoweringContext : public InstPassBase Dictionary elements; IRBuilder builder(dispatcher->getModule()); - forEachInCollection( - witnessTableCollection, + forEachInSet( + witnessTableSet, [&](IRInst* table) { auto generic = @@ -359,10 +358,10 @@ struct DispatcherLoweringContext : public InstPassBase specArgs.getCount(), specArgs.getBuffer()); - auto singletonTag = builder.emitGetTagOfElementInCollection( - builder.getCollectionTagType(witnessTableCollection), + auto singletonTag = builder.emitGetTagOfElementInSet( + builder.getSetTagType(witnessTableSet), table, - witnessTableCollection); + witnessTableSet); elements.add(singletonTag, specializedFunc); }); @@ -397,20 +396,20 @@ struct DispatcherLoweringContext : public InstPassBase // We'll also replace the callee in all 'call' insts. // - auto witnessTableCollection = cast(dispatcher->getOperand(0)); + auto witnessTableSet = cast(dispatcher->getOperand(0)); auto key = cast(dispatcher->getOperand(1)); IRBuilder builder(dispatcher->getModule()); Dictionary elements; - forEachInCollection( - witnessTableCollection, + forEachInSet( + witnessTableSet, [&](IRInst* table) { - auto tag = builder.emitGetTagOfElementInCollection( - builder.getCollectionTagType(witnessTableCollection), + auto tag = builder.emitGetTagOfElementInSet( + builder.getSetTagType(witnessTableSet), table, - witnessTableCollection); + witnessTableSet); elements.add( tag, cast(findWitnessTableEntry(cast(table), key))); @@ -458,24 +457,24 @@ bool lowerDispatchers(IRModule* module, DiagnosticSink* sink) return true; } -// This context lowers `TypeCollection` instructions. -struct CollectionLoweringContext : public InstPassBase +// This context lowers `TypeSet` instructions. +struct SetLoweringContext : public InstPassBase { - CollectionLoweringContext(IRModule* module) + SetLoweringContext(IRModule* module) : InstPassBase(module) { } - void lowerValueOfCollectionType(IRValueOfCollectionType* valueOfCollectionType) + void lowerUntaggedUnionType(IRUntaggedUnionType* valueOfSetType) { // Type collections are replaced with `AnyValueType` large enough to hold // any of the types in the collection. // HashSet types; - for (UInt i = 0; i < valueOfCollectionType->getCollection()->getCount(); i++) + for (UInt i = 0; i < valueOfSetType->getSet()->getCount(); i++) { - if (auto type = as(valueOfCollectionType->getCollection()->getElement(i))) + if (auto type = as(valueOfSetType->getSet()->getElement(i))) { types.add(type); } @@ -483,24 +482,24 @@ struct CollectionLoweringContext : public InstPassBase IRBuilder builder(module); auto anyValueType = createAnyValueType(&builder, types); - valueOfCollectionType->replaceUsesWith(anyValueType); + valueOfSetType->replaceUsesWith(anyValueType); } void processModule() { - processInstsOfType( - kIROp_ValueOfCollectionType, - [&](IRValueOfCollectionType* inst) { return lowerValueOfCollectionType(inst); }); + processInstsOfType( + kIROp_UntaggedUnionType, + [&](IRUntaggedUnionType* inst) { return lowerUntaggedUnionType(inst); }); } }; -// Lower `ValueOfCollectionType(TypeCollection(...))` instructions by replacing them with +// Lower `UntaggedUnionType(TypeSet(...))` instructions by replacing them with // appropriate `AnyValueType` instructions. // -void lowerTypeCollections(IRModule* module, DiagnosticSink* sink) +void lowerUntaggedUnionTypes(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); - CollectionLoweringContext context(module); + SetLoweringContext context(module); context.processModule(); } @@ -533,13 +532,13 @@ struct SequentialIDTagLoweringContext : public InstPassBase Dictionary mapping; // Map from sequential ID to unique ID - auto destCollection = cast(inst->getDataType())->getCollection(); + auto destSet = cast(inst->getDataType())->getSet(); IRBuilder builder(inst); builder.setInsertAfter(inst); - forEachInCollection( - destCollection, + forEachInSet( + destSet, [&](IRInst* table) { // Get unique ID for the witness table @@ -582,13 +581,13 @@ struct SequentialIDTagLoweringContext : public InstPassBase Dictionary mapping; // Map from sequential ID to unique ID - auto destCollection = cast(srcTagInst->getDataType())->getCollection(); + auto destSet = cast(srcTagInst->getDataType())->getSet(); IRBuilder builder(inst); builder.setInsertAfter(inst); - forEachInCollection( - destCollection, + forEachInSet( + destSet, [&](IRInst* table) { // Get unique ID for the witness table @@ -637,7 +636,7 @@ void lowerTagInsts(IRModule* module, DiagnosticSink* sink) tagContext.processModule(); } -// This context lowers `IRCollectionTagType` instructions, by replacing +// This context lowers `IRSetTagType` instructions, by replacing // them with a suitable integer type. struct TagTypeLoweringContext : public InstPassBase { @@ -648,9 +647,9 @@ struct TagTypeLoweringContext : public InstPassBase void processModule() { - processInstsOfType( - kIROp_CollectionTagType, - [&](IRCollectionTagType* inst) + processInstsOfType( + kIROp_SetTagType, + [&](IRSetTagType* inst) { IRBuilder builder(inst->getModule()); inst->replaceUsesWith(builder.getUIntType()); @@ -721,9 +720,9 @@ struct TaggedUnionLoweringContext : public InstPassBase // e.g. // // let basePtr : PtrType(InterfaceType(I)) = /* ... */; - // let tuPtr : PtrType(CollectionTaggedUnionType(types, tables)) = + // let tuPtr : PtrType(TaggedUnionType(types, tables)) = // CastInterfaceToTaggedUnionPtr(basePtr); - // let loadedVal : CollectionTaggedUnionType(...) = Load(tuPtr); + // let loadedVal : TaggedUnionType(...) = Load(tuPtr); // // becomes // @@ -814,13 +813,13 @@ struct TaggedUnionLoweringContext : public InstPassBase inst->removeAndDeallocate(); } - IRType* convertToTupleType(IRCollectionTaggedUnionType* taggedUnion) + IRType* convertToTupleType(IRTaggedUnionType* taggedUnion) { - // Replace `CollectionTaggedUnionType(typeCollection, tableCollection)` with - // `TupleType(CollectionTagType(tableCollection), typeCollection)` + // Replace `TaggedUnionType(typeSet, tableSet)` with + // `TupleType(SetTagType(tableSet), typeSet)` // // Unless the collection has a single element, in which case we - // replace it with `TupleType(CollectionTagType(tableCollection), elementType)` + // replace it with `TupleType(SetTagType(tableSet), elementType)` // // We still maintain a tuple type (even though it's not really necesssary) to avoid // breaking any operations that assumed this is a tuple. @@ -830,16 +829,16 @@ struct TaggedUnionLoweringContext : public InstPassBase IRBuilder builder(module); builder.setInsertInto(module); - auto typeCollection = builder.getValueOfCollectionType(taggedUnion->getTypeCollection()); - auto tableCollection = taggedUnion->getWitnessTableCollection(); + auto typeSet = builder.getUntaggedUnionType(taggedUnion->getTypeSet()); + auto tableSet = taggedUnion->getWitnessTableSet(); - if (taggedUnion->getTypeCollection()->isSingleton()) + if (taggedUnion->getTypeSet()->isSingleton()) return builder.getTupleType(List( - {(IRType*)builder.getCollectionTagType(tableCollection), - (IRType*)taggedUnion->getTypeCollection()->getElement(0)})); + {(IRType*)builder.getSetTagType(tableSet), + (IRType*)taggedUnion->getTypeSet()->getElement(0)})); - return builder.getTupleType(List( - {(IRType*)builder.getCollectionTagType(tableCollection), (IRType*)typeCollection})); + return builder.getTupleType( + List({(IRType*)builder.getSetTagType(tableSet), (IRType*)typeSet})); } bool lowerGetValueFromTaggedUnion(IRGetValueFromTaggedUnion* inst) @@ -905,12 +904,12 @@ struct TaggedUnionLoweringContext : public InstPassBase bool processModule() { - // First, we'll lower all CollectionTaggedUnionType insts + // First, we'll lower all TaggedUnionType insts // into tuples. // - processInstsOfType( - kIROp_CollectionTaggedUnionType, - [&](IRCollectionTaggedUnionType* inst) + processInstsOfType( + kIROp_TaggedUnionType, + [&](IRTaggedUnionType* inst) { inst->replaceUsesWith(convertToTupleType(inst)); inst->removeAndDeallocate(); diff --git a/source/slang/slang-ir-lower-typeflow-insts.h b/source/slang/slang-ir-lower-typeflow-insts.h index cac05626d02..930bf4fd830 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.h +++ b/source/slang/slang-ir-lower-typeflow-insts.h @@ -5,19 +5,19 @@ namespace Slang { -// Lower `ValueOfCollectionType` types. -void lowerTypeCollections(IRModule* module, DiagnosticSink* sink); +// Lower `UntaggedUnionType` types. +void lowerUntaggedUnionTypes(IRModule* module, DiagnosticSink* sink); -// Lower `CollectionTaggedUnion` and `CastInterfaceToTaggedUnionPtr` instructions +// Lower `SetTaggedUnion` and `CastInterfaceToTaggedUnionPtr` instructions // May create new `Reinterpret` instructions. // bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink); -// Lower `CollectionTagType` types +// Lower `SetTagType` types void lowerTagTypes(IRModule* module); -// Lower `GetTagOfElementInCollection`, -// `GetTagForSuperCollection`, and `GetTagForMappedCollection` instructions, +// Lower `GetTagOfElementInSet`, +// `GetTagForSuperSet`, and `GetTagForMappedSet` instructions, // void lowerTagInsts(IRModule* module, DiagnosticSink* sink); diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index a0f1d2704f7..89db89f9109 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -864,13 +864,13 @@ struct SpecializationContext IRInterfaceType* interfaceType = nullptr; if (!witnessTable) { - if (auto collection = as(lookupInst->getWitnessTable())) + if (auto collection = as(lookupInst->getWitnessTable())) { auto requirementKey = lookupInst->getRequirementKey(); HashSet satisfyingValSet; bool skipSpecialization = false; - forEachInCollection( + forEachInSet( collection, [&](IRInst* instElement) { @@ -890,17 +890,16 @@ struct SpecializationContext if (!skipSpecialization) { IRBuilder builder(module); - auto newCollection = builder.getCollection(satisfyingValSet); + auto newSet = builder.getSet(satisfyingValSet); addUsersToWorkList(lookupInst); - if (as(newCollection)) + if (as(newSet)) { - lookupInst->replaceUsesWith( - builder.getValueOfCollectionType(newCollection)); + lookupInst->replaceUsesWith(builder.getUntaggedUnionType(newSet)); lookupInst->removeAndDeallocate(); } - else if (as(newCollection)) + else if (as(newSet)) { - lookupInst->replaceUsesWith(newCollection); + lookupInst->replaceUsesWith(newSet); lookupInst->removeAndDeallocate(); } else @@ -3180,8 +3179,8 @@ void finalizeSpecialization(IRModule* module) } } -// Evaluate a `Specialize` inst where the arguments are collections rather than -// concrete singleton types and the generic returns a function. +// Evaluate a `Specialize` inst where the arguments are sets of types (or witness tables) +// rather than concrete singleton types and the generic returns a function. // // This needs to be slightly different from the usual case because the function // needs dynamic information to select a specific element from each collection @@ -3190,7 +3189,7 @@ void finalizeSpecialization(IRModule* module) // The resulting function will therefore have additional parameters at the beginning // to accept this information. // -IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) +IRInst* specializeGenericWithSetArgs(IRSpecialize* specializeInst) { // The high-level logic for specializing a generic to operate over collections // is similar to specializing a simple generic: @@ -3209,8 +3208,8 @@ IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) // // - Add any dynamic parameters of the generic to the function's first block. Keep track of the // first block for later. For now, we only treat `WitnessTableType` parameters that have - // `WitnessTableCollection` arguments (with atleast 2 distinct elements) as dynamic. Each such - // parameter will get a corresponding parameter of `TagType(tableCollection)` + // `WitnessTableSet` arguments (with atleast 2 distinct elements) as dynamic. Each such + // parameter will get a corresponding parameter of `TagType(tableSet)` // // - Clone in the rest of the generic's body into the first block of the function. // The tricky part here is that we may have parameter types that depend on other parameters. @@ -3256,13 +3255,13 @@ IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) for (auto param : generic->getFirstBlock()->getParams()) { auto specArg = specializeInst->getArg(argIndex++); - if (auto collection = as(specArg)) + if (auto collection = as(specArg)) { // We're dealing with a set of types. if (as(param->getDataType())) { SLANG_ASSERT("Should not happen"); - cloneEnv.mapOldValToNew[param] = builder.getValueOfCollectionType(collection); + cloneEnv.mapOldValToNew[param] = builder.getUntaggedUnionType(collection); } else if (as(param->getDataType())) { @@ -3274,7 +3273,7 @@ IRInst* specializeDynamicGeneric(IRSpecialize* specializeInst) // We'll create an integer parameter for all the rest of // the insts which will may need the runtime tag. // - auto tagType = (IRType*)builder.getCollectionTagType(collection); + auto tagType = (IRType*)builder.getSetTagType(collection); // cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); extraParamMap.add(param, builder.emitParam(tagType)); extraParamTypes.add(tagType); @@ -3400,8 +3399,8 @@ IRInst* specializeGenericImpl( IRModule* module, SpecializationContext* context) { - if (isDynamicGeneric(specializeInst)) - return specializeDynamicGeneric(specializeInst); + if (isSetSpecializedGeneric(specializeInst)) + return specializeGenericWithSetArgs(specializeInst); // Effectively, specializing a generic amounts to "calling" the generic // on its concrete argument values and computing the @@ -3540,8 +3539,8 @@ IRInst* specializeGeneric(IRSpecialize* specializeInst) if (!module) return specializeInst; - if (isDynamicGeneric(specializeInst)) - return specializeDynamicGeneric(specializeInst); + if (isSetSpecializedGeneric(specializeInst)) + return specializeGenericWithSetArgs(specializeInst); // Standard static specialization of generic. return specializeGenericImpl(baseGeneric, specializeInst, module, nullptr); diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-collection.cpp index 1ed7026dfe9..721110710e8 100644 --- a/source/slang/slang-ir-typeflow-collection.cpp +++ b/source/slang/slang-ir-typeflow-collection.cpp @@ -10,7 +10,7 @@ namespace Slang // Upcast the value in 'arg' to match the destInfo type. This method inserts // any necessary reinterprets or tag translation instructions. // -IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) +IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo) { // The upcasting process inserts the appropriate instructions // to make arg's type match the type provided by destInfo. @@ -18,7 +18,7 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) // This process depends on the structure of arg and destInfo. // // We only deal with the type-flow data-types that are created in - // our pass (CollectionBase/CollectionTaggedUnionType/CollectionTagType/any other + // our pass (SetBase/TaggedUnionType/SetTagType/any other // composites of these insts) // @@ -26,37 +26,35 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) if (!argInfo || !destInfo) return arg; - if (as(argInfo) && as(destInfo)) + if (as(argInfo) && as(destInfo)) { - // A collection tagged union is essentially a tuple(TagType(tableCollection), - // typeCollection) We simply extract the two components, upcast each one, and put it + // A collection tagged union is essentially a tuple(TagType(tableSet), + // typeSet) We simply extract the two components, upcast each one, and put it // back together. // - auto argTUType = as(argInfo); - auto destTUType = as(destInfo); + auto argTUType = as(argInfo); + auto destTUType = as(destInfo); if (argTUType != destTUType) { - // Technically, IRCollectionTaggedUnionType is not a TupleType, + // Technically, IRTaggedUnionType is not a TupleType, // but in practice it works the same way so we'll re-use Slang's // tuple accessors & constructors // auto argTableTag = builder->emitGetTagFromTaggedUnion(arg); - auto reinterpretedTag = upcastCollection( + auto reinterpretedTag = upcastSet( builder, argTableTag, - builder->getCollectionTagType(destTUType->getWitnessTableCollection())); + builder->getSetTagType(destTUType->getWitnessTableSet())); auto argVal = builder->emitGetValueFromTaggedUnion(arg); - auto reinterpretedVal = upcastCollection( - builder, - argVal, - builder->getValueOfCollectionType(destTUType->getTypeCollection())); + auto reinterpretedVal = + upcastSet(builder, argVal, builder->getUntaggedUnionType(destTUType->getTypeSet())); return builder->emitMakeTaggedUnion(destTUType, reinterpretedTag, reinterpretedVal); } } - else if (as(argInfo) && as(destInfo)) + else if (as(argInfo) && as(destInfo)) { // If the arg represents a tag of a colleciton, but the dest is a _different_ // collection, then we need to emit a tag operation to reinterpret the @@ -67,24 +65,23 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) // if (argInfo != destInfo) { - return builder - ->emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); + return builder->emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperSet, 1, &arg); } } - else if (as(argInfo) && as(destInfo)) + else if (as(argInfo) && as(destInfo)) { // If the arg has a collection type, but the dest is a _different_ collection, // we need to perform a reinterpret. // - // e.g. TypeCollection({T1, T2}) may lower to AnyValueType(N), while - // TypeCollection({T1, T2, T3}) may lower to AnyValueType(M). Since the target + // e.g. TypeSet({T1, T2}) may lower to AnyValueType(N), while + // TypeSet({T1, T2, T3}) may lower to AnyValueType(M). Since the target // is necessarily a super-set, the target any-value-type is always larger (M >= N), // so we only need a simple reinterpret. // if (argInfo != destInfo) { - auto argCollection = as(argInfo)->getCollection(); - if (argCollection->isSingleton() && as(argCollection->getElement(0))) + auto argSet = as(argInfo)->getSet(); + if (argSet->isSingleton() && as(argSet->getElement(0))) { // There's a specific case where we're trying to reinterpret a value of 'void' // type. We'll avoid emitting a reinterpret in this case, and emit a @@ -101,7 +98,7 @@ IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) return builder->emitReinterpret((IRType*)destInfo, arg); } } - else if (!as(argInfo) && as(destInfo)) + else if (!as(argInfo) && as(destInfo)) { // If the arg is not a collection-type, but the dest is a collection, // we need to perform a pack operation. diff --git a/source/slang/slang-ir-typeflow-collection.h b/source/slang/slang-ir-typeflow-collection.h index 5c3bb233e5d..bac8960d8ff 100644 --- a/source/slang/slang-ir-typeflow-collection.h +++ b/source/slang/slang-ir-typeflow-collection.h @@ -11,21 +11,15 @@ namespace Slang // template -void forEachInCollection(IRCollectionBase* info, F func) +void forEachInSet(IRSetBase* info, F func) { for (UInt i = 0; i < info->getOperandCount(); ++i) func(info->getOperand(i)); } -template -void forEachInCollection(IRCollectionTagType* tagType, F func) -{ - forEachInCollection(as(tagType->getCollection()), func); -} - // Upcast the value in 'arg' to match the destInfo type. This method inserts // any necessary reinterprets or tag translation instructions. // -IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo); +IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo); } // namespace Slang diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 7f9b7eccd81..8a75445f507 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -304,9 +304,8 @@ struct WorkQueue // This test is primarily used to determine if additional parameters are requried to place a call to // this callee. // -bool isDynamicGeneric(IRInst* callee) +bool isSetSpecializedGeneric(IRInst* callee) { - // // If the callee is a specialization, and at least one of its arguments // is a collection, then it needs dynamic-dispatch logic to be generated. // @@ -314,17 +313,16 @@ bool isDynamicGeneric(IRInst* callee) { for (UInt i = 0; i < specialize->getArgCount(); i++) { - - // Only functions need dynamic-aware specialization. + // Only functions need set-aware specialization. auto generic = specialize->getBase(); if (getGenericReturnVal(generic)->getOp() != kIROp_Func) return false; auto arg = specialize->getArg(i); - if (as(arg)) - return true; // Found a type-flow-collection argument + if (as(arg)) + return true; // Found a set argument } - return false; // No type-flow-collection arguments found + return false; // No set arguments found } return false; @@ -455,25 +453,24 @@ struct TypeFlowSpecializationContext // This type can be used for insts that are semantically a tuple of a tag (to select a table) // and a payload to contain the existential value. // - IRCollectionTaggedUnionType* makeTaggedUnionType(IRWitnessTableCollection* tableCollection) + IRTaggedUnionType* makeTaggedUnionType(IRWitnessTableSet* tableSet) { IRBuilder builder(module); HashSet typeSet; // Create a type collection out of the base types from each table. - forEachInCollection( - tableCollection, + forEachInSet( + tableSet, [&](IRInst* witnessTable) { if (auto table = as(witnessTable)) typeSet.add(table->getConcreteType()); }); - auto typeCollection = - cast(builder.getCollection(kIROp_TypeCollection, typeSet)); - // Create the tagged union type out of the type and table collection. - return builder.getCollectionTaggedUnionType(tableCollection, typeCollection); + return builder.getTaggedUnionType( + tableSet, + cast(builder.getSet(kIROp_TypeSet, typeSet))); } // Create an unbounded collection. @@ -484,11 +481,11 @@ struct TypeFlowSpecializationContext // // Most commonly occurs with COM interface types. // - IRUnboundedCollection* makeUnbounded() + IRUnboundedSet* makeUnbounded() { IRBuilder builder(module); - return as( - builder.emitIntrinsicInst(nullptr, kIROp_UnboundedCollection, 0, nullptr)); + return as( + builder.emitIntrinsicInst(nullptr, kIROp_UnboundedSet, 0, nullptr)); } // Creates an 'empty' inst (denoted by nullptr), that @@ -503,22 +500,22 @@ struct TypeFlowSpecializationContext // IRTypeFlowData* none() { return nullptr; } - IRValueOfCollectionType* makeValueOfCollectionType(IRTypeCollection* typeCollection) + IRUntaggedUnionType* makeUntaggedUnionType(IRTypeSet* typeSet) { IRBuilder builder(module); - return builder.getValueOfCollectionType(typeCollection); + return builder.getUntaggedUnionType(typeSet); } - IRElementOfCollectionType* makeElementOfCollectionType(IRCollectionBase* collection) + IRElementOfSetType* makeElementOfSetType(IRSetBase* collection) { IRBuilder builder(module); - return builder.getElementOfCollectionType(collection); + return builder.getElementOfSetType(collection); } - IRCollectionTagType* makeTagType(IRCollectionBase* collection) + IRSetTagType* makeTagType(IRSetBase* collection) { IRBuilder builder(module); - return builder.getCollectionTagType(collection); + return builder.getSetTagType(collection); } IRInst* _tryGetInfo(InstWithContext element) @@ -567,8 +564,8 @@ struct TypeFlowSpecializationContext { if (isConcreteType(ptrType->getValueType())) return builder.getPtrTypeWithAddressSpace( - builder.getValueOfCollectionType(cast( - builder.getSingletonCollection(ptrType->getValueType()))), + builder.getUntaggedUnionType( + cast(builder.getSingletonSet(ptrType->getValueType()))), ptrType); else return none(); @@ -579,8 +576,8 @@ struct TypeFlowSpecializationContext if (isConcreteType(arrayType)) { return builder.getArrayType( - builder.getValueOfCollectionType(cast( - builder.getSingletonCollection(arrayType->getElementType()))), + builder.getUntaggedUnionType( + cast(builder.getSingletonSet(arrayType->getElementType()))), arrayType->getElementCount()); } else @@ -588,8 +585,8 @@ struct TypeFlowSpecializationContext } if (isConcreteType(inst->getDataType())) - return builder.getValueOfCollectionType( - cast(builder.getSingletonCollection(inst->getDataType()))); + return builder.getUntaggedUnionType( + cast(builder.getSingletonSet(inst->getDataType()))); else return none(); } @@ -604,9 +601,9 @@ struct TypeFlowSpecializationContext { switch (inst->getDataType()->getOp()) { - case kIROp_CollectionTaggedUnionType: - case kIROp_ValueOfCollectionType: - case kIROp_ElementOfCollectionType: + case kIROp_TaggedUnionType: + case kIROp_UntaggedUnionType: + case kIROp_ElementOfSetType: // These insts directly represent type-flow information, // so we return them directly. return inst->getDataType(); @@ -628,13 +625,13 @@ struct TypeFlowSpecializationContext // inst to represent the collection. // template - T* unionCollection(T* collection1, T* collection2) + T* unionSet(T* collection1, T* collection2) { // It may be possible to accelerate this further, but we usually // don't have to deal with overly large sets (usually 3-20 elements) // - SLANG_ASSERT(as(collection1) && as(collection2)); + SLANG_ASSERT(as(collection1) && as(collection2)); SLANG_ASSERT(collection1->getOp() == collection2->getOp()); if (!collection1) @@ -646,11 +643,11 @@ struct TypeFlowSpecializationContext HashSet allValues; // Collect all values from both collections - forEachInCollection(collection1, [&](IRInst* value) { allValues.add(value); }); - forEachInCollection(collection2, [&](IRInst* value) { allValues.add(value); }); + forEachInSet(collection1, [&](IRInst* value) { allValues.add(value); }); + forEachInSet(collection2, [&](IRInst* value) { allValues.add(value); }); IRBuilder builder(module); - return as(builder.getCollection( + return as(builder.getSet( collection1->getOp(), allValues)); // Create a new collection with the union of values } @@ -660,7 +657,7 @@ struct TypeFlowSpecializationContext // IRInst* unionPropagationInfo(IRInst* info1, IRInst* info2) { - // This is similar to unionCollection, but must consider structures that + // This is similar to unionSet, but must consider structures that // can be built out of collections. // // We allow some level of nesting of collections into other type instructions, @@ -707,7 +704,7 @@ struct TypeFlowSpecializationContext as(info1)); } - if (as(info1) && as(info2)) + if (as(info1) && as(info2)) { // If either info is unbounded, the union is unbounded return makeUnbounded(); @@ -717,32 +714,32 @@ struct TypeFlowSpecializationContext // we simply take the collection union for all the collection operands. // - if (as(info1) && as(info2)) + if (as(info1) && as(info2)) { - return makeTaggedUnionType(unionCollection( - as(info1)->getWitnessTableCollection(), - as(info2)->getWitnessTableCollection())); + return makeTaggedUnionType(unionSet( + as(info1)->getWitnessTableSet(), + as(info2)->getWitnessTableSet())); } - if (as(info1) && as(info2)) + if (as(info1) && as(info2)) { - return makeTagType(unionCollection( - cast(info1->getOperand(0)), - cast(info2->getOperand(0)))); + return makeTagType(unionSet( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); } - if (as(info1) && as(info2)) + if (as(info1) && as(info2)) { - return makeElementOfCollectionType(unionCollection( - cast(info1->getOperand(0)), - cast(info2->getOperand(0)))); + return makeElementOfSetType(unionSet( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); } - if (as(info1) && as(info2)) + if (as(info1) && as(info2)) { - return makeValueOfCollectionType(unionCollection( - cast(info1->getOperand(0)), - cast(info2->getOperand(0)))); + return makeUntaggedUnionType(unionSet( + cast(info1->getOperand(0)), + cast(info2->getOperand(0)))); } SLANG_UNEXPECTED("Unhandled propagation info types in unionPropagationInfo"); @@ -1010,10 +1007,9 @@ struct TypeFlowSpecializationContext { if (auto dataTypeInfo = tryGetInfo(context, inst->getDataType())) { - if (auto elementOfCollectionType = as(dataTypeInfo)) + if (auto elementOfSetType = as(dataTypeInfo)) { - info = makeValueOfCollectionType( - cast(elementOfCollectionType->getCollection())); + info = makeUntaggedUnionType(cast(elementOfSetType->getSet())); } } } @@ -1176,8 +1172,8 @@ struct TypeFlowSpecializationContext if (isConcreteType(concreteReturnType)) { IRBuilder builder(module); - returnInfo = builder.getValueOfCollectionType(cast( - builder.getSingletonCollection(concreteReturnType))); + returnInfo = builder.getUntaggedUnionType( + cast(builder.getSingletonSet(concreteReturnType))); } } @@ -1240,8 +1236,8 @@ struct TypeFlowSpecializationContext auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeTaggedUnionType(as( - builder.getCollection(kIROp_WitnessTableCollection, tables))); + return makeTaggedUnionType( + as(builder.getSet(kIROp_WitnessTableSet, tables))); else { sink->diagnose( @@ -1270,7 +1266,7 @@ struct TypeFlowSpecializationContext // Concrete case. if (as(witnessTable)) return makeTaggedUnionType( - as(builder.getSingletonCollection(witnessTable))); + as(builder.getSingletonSet(witnessTable))); // Get the witness table info auto witnessTableInfo = tryGetInfo(context, witnessTable); @@ -1278,12 +1274,11 @@ struct TypeFlowSpecializationContext if (!witnessTableInfo) return none(); - if (as(witnessTableInfo)) + if (as(witnessTableInfo)) return makeUnbounded(); - if (auto elementOfCollectionType = as(witnessTableInfo)) - return makeTaggedUnionType( - cast(elementOfCollectionType->getCollection())); + if (auto elementOfSetType = as(witnessTableInfo)) + return makeTaggedUnionType(cast(elementOfSetType->getSet())); SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } @@ -1348,8 +1343,8 @@ struct TypeFlowSpecializationContext { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeTaggedUnionType(as( - builder.getCollection(kIROp_WitnessTableCollection, tables))); + return makeTaggedUnionType(as( + builder.getSet(kIROp_WitnessTableSet, tables))); else return none(); } @@ -1378,8 +1373,8 @@ struct TypeFlowSpecializationContext { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) - return makeTaggedUnionType(as( - builder.getCollection(kIROp_WitnessTableCollection, tables))); + return makeTaggedUnionType( + as(builder.getSet(kIROp_WitnessTableSet, tables))); else return none(); } @@ -1561,8 +1556,8 @@ struct TypeFlowSpecializationContext IRBuilder builder(module); if (isOptionalExistentialType(inst->getDataType())) { - auto noneTableSet = cast( - builder.getCollection(kIROp_WitnessTableCollection, getNoneWitness())); + auto noneTableSet = + cast(builder.getSet(kIROp_WitnessTableSet, getNoneWitness())); return makeTaggedUnionType(noneTableSet); } @@ -1591,7 +1586,7 @@ struct TypeFlowSpecializationContext { if (auto info = tryGetInfo(context, inst->getValue())) { - SLANG_ASSERT(as(info)); + SLANG_ASSERT(as(info)); return info; } } @@ -1613,7 +1608,7 @@ struct TypeFlowSpecializationContext // if (auto info = tryGetInfo(context, inst->getOperand(0))) { - SLANG_ASSERT(as(info)); + SLANG_ASSERT(as(info)); return info; } } @@ -1637,21 +1632,21 @@ struct TypeFlowSpecializationContext auto witnessTable = inst->getWitnessTable(); auto witnessTableInfo = tryGetInfo(context, witnessTable); - if (auto elementOfCollectionType = as(witnessTableInfo)) + if (auto elementOfSetType = as(witnessTableInfo)) { IRBuilder builder(module); HashSet results; - forEachInCollection( - cast(elementOfCollectionType->getCollection()), + forEachInSet( + cast(elementOfSetType->getSet()), [&](IRInst* table) { results.add(findWitnessTableEntry(cast(table), key)); }); - return makeElementOfCollectionType(builder.getCollection(results)); + return makeElementOfSetType(builder.getSet(results)); } if (!witnessTableInfo) return none(); - if (as(witnessTableInfo)) + if (as(witnessTableInfo)) return makeUnbounded(); SLANG_UNEXPECTED("Unexpected witness table info type in analyzeLookupWitnessMethod"); @@ -1666,7 +1661,7 @@ struct TypeFlowSpecializationContext // state that the info of the result is a tag-type of that collection. // // Note that since ExtractExistentialWitnessTable can only be used on - // an existential, the input info must be a CollectionTaggedUnionType of + // an existential, the input info must be a TaggedUnionType of // concrete table and type collections (or none/unbounded) // @@ -1676,11 +1671,11 @@ struct TypeFlowSpecializationContext if (!operandInfo) return none(); - if (as(operandInfo)) + if (as(operandInfo)) return makeUnbounded(); - if (auto taggedUnion = as(operandInfo)) - return makeElementOfCollectionType(taggedUnion->getWitnessTableCollection()); + if (auto taggedUnion = as(operandInfo)) + return makeElementOfSetType(taggedUnion->getWitnessTableSet()); SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialWitnessTable"); } @@ -1692,7 +1687,7 @@ struct TypeFlowSpecializationContext // state that the info of the result is a tag-type of that collection. // // Note: Since ExtractExistentialType can only be used on - // an existential, the input info must be a CollectionTaggedUnionType of + // an existential, the input info must be a TaggedUnionType of // concrete table and type collections (or none/unbounded) // @@ -1702,11 +1697,11 @@ struct TypeFlowSpecializationContext if (!operandInfo) return none(); - if (as(operandInfo)) + if (as(operandInfo)) return makeUnbounded(); - if (auto taggedUnion = as(operandInfo)) - return makeElementOfCollectionType(taggedUnion->getTypeCollection()); + if (auto taggedUnion = as(operandInfo)) + return makeElementOfSetType(taggedUnion->getTypeSet()); SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialType"); } @@ -1730,11 +1725,11 @@ struct TypeFlowSpecializationContext if (!operandInfo) return none(); - if (as(operandInfo)) + if (as(operandInfo)) return makeUnbounded(); - if (auto taggedUnion = as(operandInfo)) - return makeValueOfCollectionType(taggedUnion->getTypeCollection()); + if (auto taggedUnion = as(operandInfo)) + return makeUntaggedUnionType(taggedUnion->getTypeSet()); return none(); } @@ -1766,17 +1761,17 @@ struct TypeFlowSpecializationContext auto operand = inst->getBase(); auto operandInfo = tryGetInfo(context, operand); - if (as(operandInfo)) + if (as(operandInfo)) return makeUnbounded(); - if (as(operandInfo)) + if (as(operandInfo)) { SLANG_UNEXPECTED( "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); } // Handle the 'many' or 'one' cases. - if (as(operandInfo) || isGlobalInst(operand)) + if (as(operandInfo) || isGlobalInst(operand)) { List specializationArgs; for (UInt i = 0; i < inst->getArgCount(); ++i) @@ -1797,27 +1792,24 @@ struct TypeFlowSpecializationContext if (!argInfo) return none(); - if (as(argInfo) || as(argInfo)) + if (as(argInfo) || as(argInfo)) { SLANG_UNEXPECTED("Unexpected Existential operand in specialization argument."); } - if (auto elementOfCollectionType = as(argInfo)) + if (auto elementOfSetType = as(argInfo)) { - if (elementOfCollectionType->getCollection()->isSingleton()) - specializationArgs.add( - elementOfCollectionType->getCollection()->getElement(0)); + if (elementOfSetType->getSet()->isSingleton()) + specializationArgs.add(elementOfSetType->getSet()->getElement(0)); else { - if (auto typeCollection = - as(elementOfCollectionType->getCollection())) + if (auto typeSet = as(elementOfSetType->getSet())) { - specializationArgs.add(makeValueOfCollectionType(typeCollection)); + specializationArgs.add(makeUntaggedUnionType(typeSet)); } - else if (as( - elementOfCollectionType->getCollection())) + else if (as(elementOfSetType->getSet())) { - specializationArgs.add(elementOfCollectionType->getCollection()); + specializationArgs.add(elementOfSetType->getSet()); } else { @@ -1845,13 +1837,13 @@ struct TypeFlowSpecializationContext { if (auto info = tryGetInfo(context, type)) { - if (auto elementOfCollectionType = as(info)) + if (auto elementOfSetType = as(info)) { - if (elementOfCollectionType->getCollection()->isSingleton()) - return elementOfCollectionType->getCollection()->getElement(0); + if (elementOfSetType->getSet()->isSingleton()) + return elementOfSetType->getSet()->getElement(0); else - return makeValueOfCollectionType(cast( - elementOfCollectionType->getCollection())); + return makeUntaggedUnionType( + cast(elementOfSetType->getSet())); } else return type; @@ -1876,11 +1868,11 @@ struct TypeFlowSpecializationContext // dynamic IRSpecialize. In this situation, we'd want to use the type inst's info to // find the collection-based specialization and create a func-type from it. // - if (auto elementOfCollectionType = as(typeInfo)) + if (auto elementOfSetType = as(typeInfo)) { - SLANG_ASSERT(elementOfCollectionType->getCollection()->isSingleton()); + SLANG_ASSERT(elementOfSetType->getSet()->isSingleton()); auto specializeInst = - cast(elementOfCollectionType->getCollection()->getElement(0)); + cast(elementOfSetType->getSet()->getElement(0)); auto specializedFuncType = cast(specializeGeneric(specializeInst)); typeOfSpecialization = specializedFuncType; } @@ -1906,12 +1898,12 @@ struct TypeFlowSpecializationContext // Specialize each element in the set HashSet specializedSet; - IRCollectionBase* collection = nullptr; - if (auto elementOfCollectionType = as(operandInfo)) + IRSetBase* collection = nullptr; + if (auto elementOfSetType = as(operandInfo)) { - collection = elementOfCollectionType->getCollection(); + collection = elementOfSetType->getSet(); - forEachInCollection( + forEachInSet( collection, [&](IRInst* arg) { @@ -1934,7 +1926,7 @@ struct TypeFlowSpecializationContext } IRBuilder builder(module); - return makeElementOfCollectionType(builder.getCollection(specializedSet)); + return makeElementOfSetType(builder.getSet(specializedSet)); } if (!operandInfo) @@ -1994,12 +1986,12 @@ struct TypeFlowSpecializationContext if (as(arg)) continue; - if (auto collection = as(arg)) + if (auto collection = as(arg)) { updateInfo( context, param, - makeElementOfCollectionType(collection), + makeElementOfSetType(collection), true, workQueue); } @@ -2009,7 +2001,7 @@ struct TypeFlowSpecializationContext updateInfo( context, param, - makeElementOfCollectionType(builder.getSingletonCollection(arg)), + makeElementOfSetType(builder.getSingletonSet(arg)), true, workQueue); } @@ -2073,10 +2065,10 @@ struct TypeFlowSpecializationContext // If we have a collection of functions (with or without a dynamic tag), register // each one. // - if (auto elementOfCollectionType = as(calleeInfo)) + if (auto elementOfSetType = as(calleeInfo)) { - forEachInCollection( - elementOfCollectionType->getCollection(), + forEachInSet( + elementOfSetType->getSet(), [&](IRInst* func) { propagateToCallSite(func); }); } else if (isGlobalInst(callee)) @@ -2390,7 +2382,7 @@ struct TypeFlowSpecializationContext if (isGlobalInst(inst) && (!as(inst) && (as(inst) || as(inst) || as(inst)))) - return builder.getSingletonCollection(inst); + return builder.getSingletonSet(inst); auto instType = inst->getDataType(); if (isGlobalInst(inst)) @@ -2464,13 +2456,13 @@ struct TypeFlowSpecializationContext // \ / // C // - // After specialization, A could pass a value of type TagType(WitnessTableCollection{T1, - // T2}) while B passes a value of type TagType(WitnessTableCollection{T2, T3}), while the - // phi parameter's type in C has the union type `TagType(WitnessTableCollection{T1, T2, + // After specialization, A could pass a value of type TagType(WitnessTableSet{T1, + // T2}) while B passes a value of type TagType(WitnessTableSet{T2, T3}), while the + // phi parameter's type in C has the union type `TagType(WitnessTableSet{T1, T2, // T3})` // - // In this case, we use `upcastCollection` to insert a cast from - // TagType(WitnessTableCollection{T1, T2}) -> TagType(WitnessTableCollection{T1, T2, T3}) + // In this case, we use `upcastSet` to insert a cast from + // TagType(WitnessTableSet{T1, T2}) -> TagType(WitnessTableSet{T1, T2, T3}) // before passing the result as a phi argument. // // The same logic applies for the return values. The function's caller expects a union type @@ -2512,7 +2504,7 @@ struct TypeFlowSpecializationContext auto arg = unconditionalBranch->getArg(paramIndex); IRBuilder builder(module); builder.setInsertBefore(unconditionalBranch); - auto newArg = upcastCollection(&builder, arg, param->getDataType()); + auto newArg = upcastSet(&builder, arg, param->getDataType()); if (newArg != arg) { @@ -2544,7 +2536,7 @@ struct TypeFlowSpecializationContext IRBuilder builder(module); builder.setInsertBefore(returnInst); auto newReturnVal = - upcastCollection(&builder, returnInst->getVal(), specializedType); + upcastSet(&builder, returnInst->getVal(), specializedType); if (newReturnVal != returnInst->getVal()) { // Replace the return value with the reinterpreted value @@ -2578,7 +2570,7 @@ struct TypeFlowSpecializationContext // // i.e. `ExtractExistentialType`, `ExtractExistentialWitnessTable`, `ExtractExistentialValue`, // `MakeExistential`, `LookupWitness` (and more) are rewritten to concrete tag translation - // insts (e.g. `GetTagForMappedCollection`, `GetTagForSpecializedCollection`, etc.) + // insts (e.g. `GetTagForMappedSet`, `GetTagForSpecializedSet`, etc.) // bool performDynamicInstLowering() { @@ -2620,7 +2612,7 @@ struct TypeFlowSpecializationContext if (!info) return nullptr; - if (as(info)) + if (as(info)) return nullptr; if (auto ptrType = as(info)) @@ -2647,29 +2639,29 @@ struct TypeFlowSpecializationContext return nullptr; } - if (auto taggedUnion = as(info)) + if (auto taggedUnion = as(info)) { return (IRType*)taggedUnion; } - if (auto elementOfCollectionType = as(info)) + if (auto elementOfSetType = as(info)) { // Replace element-of-collection types with tag types. - return makeTagType(elementOfCollectionType->getCollection()); + return makeTagType(elementOfSetType->getSet()); } - if (auto valOfCollectionType = as(info)) + if (auto valOfSetType = as(info)) { - if (valOfCollectionType->getCollection()->isSingleton()) + if (valOfSetType->getSet()->isSingleton()) { // If there's only one type in the collection, return it directly - return (IRType*)valOfCollectionType->getCollection()->getElement(0); + return (IRType*)valOfSetType->getSet()->getElement(0); } - return valOfCollectionType; + return valOfSetType; } - if (as(info) || as(info)) + if (as(info) || as(info)) { // Don't specialize these collections.. they should be used through // tag types, or be processed out during specializeing. @@ -2781,17 +2773,17 @@ struct TypeFlowSpecializationContext return false; // If we didn't resolve anything for this inst, don't modify it. - auto elementOfCollectionType = as(info); - if (!elementOfCollectionType) + auto elementOfSetType = as(info); + if (!elementOfSetType) return false; IRBuilder builder(inst); builder.setInsertBefore(inst); // If there's a single element, we can do a simple replacement. - if (elementOfCollectionType->getCollection()->getCount() == 1) + if (elementOfSetType->getSet()->getCount() == 1) { - auto element = elementOfCollectionType->getCollection()->getElement(0); + auto element = elementOfSetType->getSet()->getElement(0); inst->replaceUsesWith(element); inst->removeAndDeallocate(); return true; @@ -2800,23 +2792,23 @@ struct TypeFlowSpecializationContext // If we reach here, we have a truly dynamic case. Multiple elements. // We need to emit a run-time inst to keep track of the tag. // - // We use the GetTagForMappedCollection inst to do this, and set its data type to + // We use the GetTagForMappedSet inst to do this, and set its data type to // the appropriate tag-type. // auto witnessTableInst = inst->getWitnessTable(); auto witnessTableInfo = witnessTableInst->getDataType(); - if (as(witnessTableInfo)) + if (as(witnessTableInfo)) { - auto thisInstInfo = cast(tryGetInfo(context, inst)); - if (thisInstInfo->getCollection() != nullptr) + auto thisInstInfo = cast(tryGetInfo(context, inst)); + if (thisInstInfo->getSet() != nullptr) { List operands = {witnessTableInst, inst->getRequirementKey()}; auto newInst = builder.emitIntrinsicInst( - (IRType*)makeTagType(thisInstInfo->getCollection()), - kIROp_GetTagForMappedCollection, + (IRType*)makeTagType(thisInstInfo->getSet()), + kIROp_GetTagForMappedSet, operands.getCount(), operands.getBuffer()); @@ -2834,11 +2826,11 @@ struct TypeFlowSpecializationContext IRExtractExistentialWitnessTable* inst) { // If we have a non-trivial info registered, it must of - // CollectionTagType(WitnessTableCollection(...)) + // SetTagType(WitnessTableSet(...)) // - // Further, the operand must be an existential (CollectionTaggedUnionType), which is - // conceptually a pair of TagType(tableCollection) and a - // ValueOfCollectionType(typeCollection) + // Further, the operand must be an existential (TaggedUnionType), which is + // conceptually a pair of TagType(tableSet) and a + // UntaggedUnionType(typeSet) // // We will simply extract the first element of this tuple. // @@ -2850,20 +2842,20 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); builder.setInsertBefore(inst); - auto elementOfCollectionType = as(info); - if (!elementOfCollectionType) + auto elementOfSetType = as(info); + if (!elementOfSetType) return false; - if (elementOfCollectionType->getCollection()->getCount() == 1) + if (elementOfSetType->getSet()->getCount() == 1) { // Found a single possible type. Simple replacement. - inst->replaceUsesWith(elementOfCollectionType->getCollection()->getElement(0)); + inst->replaceUsesWith(elementOfSetType->getSet()->getElement(0)); inst->removeAndDeallocate(); return true; } else { - // Replace with GetElement(specializedInst, 0) -> TagType(tableCollection) + // Replace with GetElement(specializedInst, 0) -> TagType(tableSet) // which retreives a 'tag' (i.e. a run-time identifier for one of the elements // of the collection) // @@ -2881,7 +2873,7 @@ struct TypeFlowSpecializationContext auto existential = inst->getOperand(0); auto existentialInfo = existential->getDataType(); - if (as(existentialInfo)) + if (as(existentialInfo)) { IRBuilder builder(inst); builder.setInsertAfter(inst); @@ -2898,12 +2890,12 @@ struct TypeFlowSpecializationContext bool specializeExtractExistentialType(IRInst* context, IRExtractExistentialType* inst) { auto info = tryGetInfo(context, inst); - if (auto elementOfCollectionType = as(info)) + if (auto elementOfSetType = as(info)) { - if (elementOfCollectionType->getCollection()->isSingleton()) + if (elementOfSetType->getSet()->isSingleton()) { // Found a single possible type. Statically known concrete type. - auto singletonValue = elementOfCollectionType->getCollection()->getElement(0); + auto singletonValue = elementOfSetType->getSet()->getElement(0); inst->replaceUsesWith(singletonValue); inst->removeAndDeallocate(); return true; @@ -2923,25 +2915,22 @@ struct TypeFlowSpecializationContext return false; } - bool isTaggedUnionType(IRInst* type) - { - return as(type) != nullptr; - } + bool isTaggedUnionType(IRInst* type) { return as(type) != nullptr; } IRType* updateType(IRType* currentType, IRType* newType) { - if (auto valOfCollectionType = as(currentType)) + if (auto valOfSetType = as(currentType)) { HashSet collectionElements; - forEachInCollection( - valOfCollectionType->getCollection(), + forEachInSet( + valOfSetType->getSet(), [&](IRInst* element) { collectionElements.add(element); }); - if (auto newValOfCollectionType = as(newType)) + if (auto newValOfSetType = as(newType)) { // If the new type is also a collection, merge the two collections - forEachInCollection( - newValOfCollectionType->getCollection(), + forEachInSet( + newValOfSetType->getSet(), [&](IRInst* element) { collectionElements.add(element); }); } else @@ -2952,8 +2941,8 @@ struct TypeFlowSpecializationContext // If this is a collection, we need to create a new collection with the new type IRBuilder builder(module); - auto newCollection = builder.getCollection(kIROp_TypeCollection, collectionElements); - return makeValueOfCollectionType(cast(newCollection)); + auto newSet = builder.getSet(kIROp_TypeSet, collectionElements); + return makeUntaggedUnionType(cast(newSet)); } else if (currentType == newType) { @@ -2963,42 +2952,40 @@ struct TypeFlowSpecializationContext { return newType; } - else if ( - as(currentType) && - as(newType)) + else if (as(currentType) && as(newType)) { // Merge the elements of both tagged unions into a new tuple type - return (IRType*)makeTaggedUnionType((unionCollection( - as(currentType)->getWitnessTableCollection(), - as(newType)->getWitnessTableCollection()))); + return (IRType*)makeTaggedUnionType((unionSet( + as(currentType)->getWitnessTableSet(), + as(newType)->getWitnessTableSet()))); } else // Need to create a new collection. { HashSet collectionElements; - SLANG_ASSERT(!as(currentType) && !as(newType)); + SLANG_ASSERT(!as(currentType) && !as(newType)); collectionElements.add(currentType); collectionElements.add(newType); // If this is a collection, we need to create a new collection with the new type IRBuilder builder(module); - auto newCollection = builder.getCollection(kIROp_TypeCollection, collectionElements); - return makeValueOfCollectionType(cast(newCollection)); + auto newSet = builder.getSet(kIROp_TypeSet, collectionElements); + return makeUntaggedUnionType(cast(newSet)); } } IRFuncType* getEffectiveFuncTypeForDispatcher( - IRWitnessTableCollection* tableCollection, + IRWitnessTableSet* tableSet, IRStructKey* key, - IRFuncCollection* resultFuncCollection) + IRFuncSet* resultFuncSet) { SLANG_UNUSED(key); List extraParamTypes; - extraParamTypes.add((IRType*)makeTagType(tableCollection)); + extraParamTypes.add((IRType*)makeTagType(tableSet)); - auto innerFuncType = getEffectiveFuncTypeForCollection(resultFuncCollection); + auto innerFuncType = getEffectiveFuncTypeForSet(resultFuncSet); List allParamTypes; allParamTypes.addRange(extraParamTypes); for (auto paramType : innerFuncType->getParamTypes()) @@ -3011,7 +2998,7 @@ struct TypeFlowSpecializationContext // Get an effective func type to use for the callee. // The callee may be a collection, in which case, this returns a union-ed functype. // - IRFuncType* getEffectiveFuncTypeForCollection(IRFuncCollection* calleeCollection) + IRFuncType* getEffectiveFuncTypeForSet(IRFuncSet* calleeSet) { // The effective func type for a callee is calculated as follows: // @@ -3060,7 +3047,7 @@ struct TypeFlowSpecializationContext }; List calleesToProcess; - forEachInCollection(calleeCollection, [&](IRInst* func) { calleesToProcess.add(func); }); + forEachInSet(calleeSet, [&](IRInst* func) { calleesToProcess.add(func); }); for (auto context : calleesToProcess) { @@ -3093,19 +3080,19 @@ struct TypeFlowSpecializationContext List extraParamTypes; // If the any of the elements in the callee (or the callee itself in case - // of a singleton) is a dynamic specialization, each non-singleton WitnessTableCollection, + // of a singleton) is a dynamic specialization, each non-singleton WitnessTableSet, // requries a corresponding tag input. // - if (calleeCollection->isSingleton() && isDynamicGeneric(calleeCollection->getElement(0))) + if (calleeSet->isSingleton() && isSetSpecializedGeneric(calleeSet->getElement(0))) { - auto specializeInst = as(calleeCollection->getElement(0)); + auto specializeInst = as(calleeSet->getElement(0)); // If this is a dynamic generic, we need to add a tag type for each - // WitnessTableCollection in the callee. + // WitnessTableSet in the callee. // for (UIndex i = 0; i < specializeInst->getArgCount(); i++) - if (auto tableCollection = as(specializeInst->getArg(i))) - extraParamTypes.add((IRType*)makeTagType(tableCollection)); + if (auto tableSet = as(specializeInst->getArg(i))) + extraParamTypes.add((IRType*)makeTagType(tableSet)); } List allParamTypes; @@ -3118,8 +3105,7 @@ struct TypeFlowSpecializationContext IRFuncType* getEffectiveFuncType(IRInst* callee) { IRBuilder builder(module); - return getEffectiveFuncTypeForCollection( - cast(builder.getSingletonCollection(callee))); + return getEffectiveFuncTypeForSet(cast(builder.getSingletonSet(callee))); } // Helper function for specializing calls. @@ -3127,7 +3113,7 @@ struct TypeFlowSpecializationContext // For a `Specialize` instruction that has dynamic tag arguments, // extract all the tags and return them as a list. // - List getArgsForDynamicSpecialization(IRSpecialize* specializedCallee) + List getArgsForSetSpecializedGeneric(IRSpecialize* specializedCallee) { List callArgs; for (UInt ii = 0; ii < specializedCallee->getArgCount(); ii++) @@ -3138,8 +3124,8 @@ struct TypeFlowSpecializationContext // Pull all tag-type arguments from the specialization arguments // and add them to the call arguments. // - if (auto tagType = as(argInfo)) - if (as(tagType->getCollection())) + if (auto tagType = as(argInfo)) + if (as(tagType->getSet())) callArgs.add(specArg); } @@ -3174,11 +3160,11 @@ struct TypeFlowSpecializationContext // and using the tag to specify which function to call. // // e.g. - // let tag : TagType(funcCollection) = /* ... */; + // let tag : TagType(funcSet) = /* ... */; // let val = Call(tag, arg1, arg2, ...); // becomes - // let tag : TagType(funcCollection) = /* ... */; - // let val = Call(funcCollection, tag, arg1, arg2, ...); + // let tag : TagType(funcSet) = /* ... */; + // let val = Call(funcSet, tag, arg1, arg2, ...); // // - If any the callee is a dynamic specialization of a generic, we need to add any dynamic // witness @@ -3199,16 +3185,16 @@ struct TypeFlowSpecializationContext // replaced with their static collections. // // // --- before specialization --- - // let s1 : TagType(WitnessTableCollection(tA, tB, tC)) = /* ... */; - // let s2 : TagType(TypeCollection(A, B, C)) = /* ... */; + // let s1 : TagType(WitnessTableSet(tA, tB, tC)) = /* ... */; + // let s2 : TagType(TypeSet(A, B, C)) = /* ... */; // let specCallee = Specialize(generic, s1, s2); // let val = Call(specCallee, /* call args */); // // // --- after specialization --- - // let s1 : TagType(WitnessTableCollection(tA, tB, tC)) = /* ... */; - // let s2 : TagType(TypeCollection(A, B, C)) = /* ... */; + // let s1 : TagType(WitnessTableSet(tA, tB, tC)) = /* ... */; + // let s2 : TagType(TypeSet(A, B, C)) = /* ... */; // let newSpecCallee = Specialize(generic, - // WitnessTableCollection(tA, tB, tC), TypeCollection(A, B, C)); + // WitnessTableSet(tA, tB, tC), TypeSet(A, B, C)); // let newVal = Call(newSpecCallee, s1, /* call args */); // // @@ -3227,7 +3213,7 @@ struct TypeFlowSpecializationContext // the parameters in the callee have been specialized to accept // a wider collection compared to the arguments from this call site // - // In this case, we just upcast them using `upcastCollection` before + // In this case, we just upcast them using `upcastSet` before // creating a new call inst // @@ -3250,7 +3236,7 @@ struct TypeFlowSpecializationContext // with the tag as the first argument. So the callee is // the collection itself. // - if (auto collectionTag = as(callee->getDataType())) + if (auto collectionTag = as(callee->getDataType())) { if (!collectionTag->isSingleton()) { @@ -3266,33 +3252,32 @@ struct TypeFlowSpecializationContext // function. If we ever have the ability to pass functions around more // flexibly, then this should just become a specific case. - if (auto tagMapOperand = as(callee)) + if (auto tagMapOperand = as(callee)) { auto tableTag = tagMapOperand->getOperand(0); auto lookupKey = cast(tagMapOperand->getOperand(1)); - auto tableCollection = cast( - cast(tableTag->getDataType())->getCollection()); + auto tableSet = cast( + cast(tableTag->getDataType())->getSet()); IRBuilder builder(module); callee = builder.emitGetDispatcher( getEffectiveFuncTypeForDispatcher( - tableCollection, + tableSet, lookupKey, - cast(collectionTag->getCollection())), - tableCollection, + cast(collectionTag->getSet())), + tableSet, lookupKey); callArgs.add(tableTag); } - else if ( - auto specializedTagMapOperand = as(callee)) + else if (auto specializedTagMapOperand = as(callee)) { auto innerTagMapOperand = - cast(specializedTagMapOperand->getOperand(0)); + cast(specializedTagMapOperand->getOperand(0)); auto tableTag = innerTagMapOperand->getOperand(0); - auto tableCollection = cast( - cast(tableTag->getDataType())->getCollection()); + auto tableSet = cast( + cast(tableTag->getDataType())->getSet()); auto lookupKey = cast(innerTagMapOperand->getOperand(1)); List specArgs; @@ -3300,17 +3285,17 @@ struct TypeFlowSpecializationContext ++argIdx) { auto arg = specializedTagMapOperand->getOperand(argIdx); - if (auto tagType = as(arg->getDataType())) + if (auto tagType = as(arg->getDataType())) { - SLANG_ASSERT(!tagType->getCollection()->isSingleton()); - if (as(tagType->getCollection())) + SLANG_ASSERT(!tagType->getSet()->isSingleton()); + if (as(tagType->getSet())) { callArgs.add(arg); - specArgs.add(tagType->getCollection()); + specArgs.add(tagType->getSet()); } else { - specArgs.add(tagType->getCollection()); + specArgs.add(tagType->getSet()); } } else @@ -3324,10 +3309,10 @@ struct TypeFlowSpecializationContext builder.setInsertBefore(callee); callee = builder.emitGetSpecializedDispatcher( getEffectiveFuncTypeForDispatcher( - tableCollection, + tableSet, lookupKey, - cast(collectionTag->getCollection())), - tableCollection, + cast(collectionTag->getSet())), + tableSet, lookupKey, specArgs); @@ -3339,11 +3324,11 @@ struct TypeFlowSpecializationContext "Cannot specialize call with non-singleton collection tag callee"); } } - else if (isDynamicGeneric(collectionTag->getCollection()->getElement(0))) + else if (isSetSpecializedGeneric(collectionTag->getSet()->getElement(0))) { - // Single element which is a dynamic generic specialization. - callArgs.addRange(getArgsForDynamicSpecialization(cast(callee))); - callee = collectionTag->getCollection()->getElement(0); + // Single element which is a set specialized generic. + callArgs.addRange(getArgsForSetSpecializedGeneric(cast(callee))); + callee = collectionTag->getSet()->getElement(0); auto funcType = getEffectiveFuncType(callee); callee->setFullType(funcType); @@ -3351,8 +3336,8 @@ struct TypeFlowSpecializationContext else { // If we reach here, then something is wrong. If our callee is an inst of tag-type, - // we expect it to either be a `GetTagForMappedCollection`, `Specialize` or - // `GetTagForSpecializedCollection`. + // we expect it to either be a `GetTagForMappedSet`, `Specialize` or + // `GetTagForSpecializedSet`. // Any other case should never occur (in the current design of the compiler) // SLANG_UNEXPECTED( @@ -3398,7 +3383,7 @@ struct TypeFlowSpecializationContext { IRBuilder builder(context); builder.setInsertBefore(inst); - callArgs.add(upcastCollection(&builder, arg, paramType)); + callArgs.add(upcastSet(&builder, arg, paramType)); break; } @@ -3484,7 +3469,7 @@ struct TypeFlowSpecializationContext auto arg = inst->getOperand(operandIndex); IRBuilder builder(context); builder.setInsertBefore(inst); - auto newArg = upcastCollection(&builder, arg, field->getFieldType()); + auto newArg = upcastSet(&builder, arg, field->getFieldType()); if (arg != newArg) { @@ -3501,14 +3486,14 @@ struct TypeFlowSpecializationContext bool specializeMakeExistential(IRInst* context, IRMakeExistential* inst) { // After specialization, existentials (that are not unbounded) are treated as tuples - // of a WitnessTableCollection tag and a value of type TypeCollection. + // of a WitnessTableSet tag and a value of type TypeSet. // // A MakeExistential is just converted into a MakeTaggedUnion, with any necessary // upcasts. // auto info = tryGetInfo(context, inst); - auto taggedUnion = as(info); + auto taggedUnion = as(info); if (!taggedUnion) return false; @@ -3516,36 +3501,34 @@ struct TypeFlowSpecializationContext builder.setInsertBefore(inst); // Collect types from the witness tables to determine the any-value type - auto tableCollection = taggedUnion->getWitnessTableCollection(); - auto typeCollection = taggedUnion->getTypeCollection(); + auto tableSet = taggedUnion->getWitnessTableSet(); + auto typeSet = taggedUnion->getTypeSet(); IRInst* witnessTableTag = nullptr; if (auto witnessTable = as(inst->getWitnessTable())) { - auto singletonTagType = makeTagType(builder.getSingletonCollection(witnessTable)); - IRInst* tagValue = builder.emitGetTagOfElementInCollection( - (IRType*)singletonTagType, - witnessTable, - tableCollection); + auto singletonTagType = makeTagType(builder.getSingletonSet(witnessTable)); + IRInst* tagValue = + builder.emitGetTagOfElementInSet((IRType*)singletonTagType, witnessTable, tableSet); witnessTableTag = builder.emitIntrinsicInst( - (IRType*)makeTagType(tableCollection), - kIROp_GetTagForSuperCollection, + (IRType*)makeTagType(tableSet), + kIROp_GetTagForSuperSet, 1, &tagValue); } - else if (as(inst->getWitnessTable()->getDataType())) + else if (as(inst->getWitnessTable()->getDataType())) { // Dynamic. Use the witness table inst as a tag witnessTableTag = inst->getWitnessTable(); } // Create the appropriate any-value type - auto collectionType = typeCollection->isSingleton() - ? (IRType*)typeCollection->getElement(0) - : builder.getValueOfCollectionType((IRType*)typeCollection); + auto collectionType = typeSet->isSingleton() + ? (IRType*)typeSet->getElement(0) + : builder.getUntaggedUnionType((IRType*)typeSet); // Pack the value - auto packedValue = as(collectionType) + auto packedValue = as(collectionType) ? builder.emitPackAnyValue(collectionType, inst->getWrappedValue()) : inst->getWrappedValue(); @@ -3571,34 +3554,34 @@ struct TypeFlowSpecializationContext // auto info = tryGetInfo(context, inst); - auto taggedUnion = as(info); + auto taggedUnion = as(info); if (!taggedUnion) return false; - auto taggedUnionType = as(getLoweredType(taggedUnion)); + auto taggedUnionType = as(getLoweredType(taggedUnion)); IRBuilder builder(inst); builder.setInsertBefore(inst); IRInst* args[] = {inst->getDataType(), inst->getTypeID()}; auto translatedTag = builder.emitIntrinsicInst( - (IRType*)makeTagType(taggedUnionType->getWitnessTableCollection()), + (IRType*)makeTagType(taggedUnionType->getWitnessTableSet()), kIROp_GetTagFromSequentialID, 2, args); IRInst* packedValue = nullptr; - auto collection = taggedUnionType->getTypeCollection(); + auto collection = taggedUnionType->getTypeSet(); if (!collection->isSingleton()) { packedValue = builder.emitPackAnyValue( - (IRType*)builder.getValueOfCollectionType(collection), + (IRType*)builder.getUntaggedUnionType(collection), inst->getValue()); } else { packedValue = builder.emitReinterpret( - (IRType*)builder.getValueOfCollectionType(collection), + (IRType*)builder.getUntaggedUnionType(collection), inst->getValue()); } @@ -3697,7 +3680,7 @@ struct TypeFlowSpecializationContext // // - A collection of functions with concrete specialization arguments. // In this case, we will emit an instruction to map from the input generic collection - // to the output specialized collection via `GetTagForSpecializedCollection`. + // to the output specialized collection via `GetTagForSpecializedSet`. // This inst encodes the key-value mapping in its operands: // e.g.(input_tag, key0, value0, key1, value1, ...) // @@ -3710,9 +3693,9 @@ struct TypeFlowSpecializationContext if (auto concreteGeneric = as(inst->getBase())) isFuncReturn = as(getGenericReturnVal(concreteGeneric)) != nullptr; - else if (auto tagType = as(inst->getBase()->getDataType())) + else if (auto tagType = as(inst->getBase()->getDataType())) { - auto firstConcreteGeneric = as(tagType->getCollection()->getElement(0)); + auto firstConcreteGeneric = as(tagType->getSet()->getElement(0)); isFuncReturn = as(getGenericReturnVal(firstConcreteGeneric)) != nullptr; } @@ -3721,10 +3704,10 @@ struct TypeFlowSpecializationContext { if (auto info = tryGetInfo(context, inst)) { - if (auto elementOfCollectionType = as(info)) + if (auto elementOfSetType = as(info)) { // Note for future reworks: - // Should we make it such that the `GetTagForSpecializedCollection` + // Should we make it such that the `GetTagForSpecializedSet` // is emitted in the single func case too? // // Basically, as long as any of the specialization operands are dynamic, @@ -3734,7 +3717,7 @@ struct TypeFlowSpecializationContext // with dynamic args to be handled in specializeCall. // - if (elementOfCollectionType->getCollection()->isSingleton()) + if (elementOfSetType->getSet()->isSingleton()) { // If the result is a singleton collection, we can just // replace the type (if necessary) and be done with it. @@ -3753,8 +3736,8 @@ struct TypeFlowSpecializationContext specOperands.add(inst->getArg(ii)); auto newInst = builder.emitIntrinsicInst( - (IRType*)makeTagType(elementOfCollectionType->getCollection()), - kIROp_GetTagForSpecializedCollection, + (IRType*)makeTagType(elementOfSetType->getSet()), + kIROp_GetTagForSpecializedSet, specOperands.getCount(), specOperands.getBuffer()); @@ -3778,19 +3761,18 @@ struct TypeFlowSpecializationContext { auto arg = inst->getArg(i); auto argDataType = arg->getDataType(); - if (auto collectionTagType = as(argDataType)) + if (auto collectionTagType = as(argDataType)) { // If this is a tag type, replace with collection. changed = true; - if (as(collectionTagType->getCollection())) + if (as(collectionTagType->getSet())) { - args.add(collectionTagType->getCollection()); + args.add(collectionTagType->getSet()); } - else if ( - auto typeCollection = as(collectionTagType->getCollection())) + else if (auto typeSet = as(collectionTagType->getSet())) { IRBuilder builder(inst); - args.add(builder.getValueOfCollectionType(typeCollection)); + args.add(builder.getUntaggedUnionType(typeSet)); } } else @@ -3828,7 +3810,7 @@ struct TypeFlowSpecializationContext SLANG_UNUSED(context); auto operandInfo = inst->getOperand(0)->getDataType(); - if (as(operandInfo)) + if (as(operandInfo)) { IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); @@ -3962,7 +3944,7 @@ struct TypeFlowSpecializationContext IRBuilder builder(context); builder.setInsertBefore(inst); - auto specializedVal = upcastCollection(&builder, inst->getVal(), ptrInfo); + auto specializedVal = upcastSet(&builder, inst->getVal(), ptrInfo); if (specializedVal != inst->getVal()) { @@ -3984,11 +3966,11 @@ struct TypeFlowSpecializationContext // SLANG_UNUSED(context); auto arg = inst->getOperand(0); - if (auto tagType = as(arg->getDataType())) + if (auto tagType = as(arg->getDataType())) { IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); - auto firstElement = tagType->getCollection()->getElement(0); + auto firstElement = tagType->getSet()->getElement(0); auto interfaceType = as(as(firstElement)->getConformanceType()); List args = {interfaceType, arg}; @@ -4017,15 +3999,15 @@ struct TypeFlowSpecializationContext // SLANG_UNUSED(context); auto witnessTableArg = inst->getValueWitness(); - if (auto tagType = as(witnessTableArg->getDataType())) + if (auto tagType = as(witnessTableArg->getDataType())) { IRBuilder builder(inst); setInsertAfterOrdinaryInst(&builder, inst); - auto targetTag = builder.emitGetTagOfElementInCollection( + auto targetTag = builder.emitGetTagOfElementInSet( (IRType*)tagType, inst->getTargetWitness(), - tagType->getCollection()); + tagType->getSet()); auto eqlInst = builder.emitEql(targetTag, witnessTableArg); inst->replaceUsesWith(eqlInst); @@ -4038,7 +4020,7 @@ struct TypeFlowSpecializationContext bool specializeMakeOptionalNone(IRInst* context, IRMakeOptionalNone* inst) { - if (auto taggedUnionType = as(tryGetInfo(context, inst))) + if (auto taggedUnionType = as(tryGetInfo(context, inst))) { // If we're dealing with a `MakeOptionalNone` for an existential type, then // this just becomes a tagged union tuple where the set of tables is {none} @@ -4049,20 +4031,19 @@ struct TypeFlowSpecializationContext builder.setInsertBefore(inst); // Create a tuple for the empty type.. - SLANG_ASSERT(taggedUnionType->getWitnessTableCollection()->isSingleton()); - auto noneWitnessTable = taggedUnionType->getWitnessTableCollection()->getElement(0); + SLANG_ASSERT(taggedUnionType->getWitnessTableSet()->isSingleton()); + auto noneWitnessTable = taggedUnionType->getWitnessTableSet()->getElement(0); - auto singletonTagType = makeTagType(builder.getSingletonCollection(noneWitnessTable)); - IRInst* zeroValueOfTagType = builder.emitGetTagOfElementInCollection( + auto singletonTagType = makeTagType(builder.getSingletonSet(noneWitnessTable)); + IRInst* zeroValueOfTagType = builder.emitGetTagOfElementInSet( (IRType*)singletonTagType, noneWitnessTable, - taggedUnionType->getWitnessTableCollection()); + taggedUnionType->getWitnessTableSet()); auto newTuple = builder.emitMakeTaggedUnion( (IRType*)taggedUnionType, zeroValueOfTagType, - builder.emitDefaultConstruct( - makeValueOfCollectionType(taggedUnionType->getTypeCollection()))); + builder.emitDefaultConstruct(makeUntaggedUnionType(taggedUnionType->getTypeSet()))); inst->replaceUsesWith(newTuple); propagationMap[InstWithContext(context, newTuple)] = taggedUnionType; @@ -4077,7 +4058,7 @@ struct TypeFlowSpecializationContext bool specializeMakeOptionalValue(IRInst* context, IRMakeOptionalValue* inst) { SLANG_UNUSED(context); - if (as(inst->getValue()->getDataType())) + if (as(inst->getValue()->getDataType())) { // If we're dealing with a `MakeOptionalValue` for an existential type, // we don't actually have to change anything, since logically, the input and output @@ -4099,7 +4080,7 @@ struct TypeFlowSpecializationContext bool specializeGetOptionalValue(IRInst* context, IRGetOptionalValue* inst) { SLANG_UNUSED(context); - if (as(inst->getOptionalOperand()->getDataType())) + if (as(inst->getOptionalOperand()->getDataType())) { // Since `GetOptionalValue` is the reverse of `MakeOptionalValue`, and we treat // the latter as a no-op, then `GetOptionalValue` is also a no-op (we simply pass @@ -4117,8 +4098,7 @@ struct TypeFlowSpecializationContext bool specializeOptionalHasValue(IRInst* context, IROptionalHasValue* inst) { SLANG_UNUSED(context); - if (auto taggedUnionType = - as(inst->getOptionalOperand()->getDataType())) + if (auto taggedUnionType = as(inst->getOptionalOperand()->getDataType())) { // The logic here is similar to specializing IsType, but we'll directly compare // tags instead of trying to use sequential ID. @@ -4129,8 +4109,8 @@ struct TypeFlowSpecializationContext // we just return a true. // // 2. 'none' is a possibility. In this case, we create a 0 value of - // type TagType(WitnessTableCollection(NoneWitness)) and then upcast it - // to TagType(inputWitnessTableCollection). This will convert the value + // type TagType(WitnessTableSet(NoneWitness)) and then upcast it + // to TagType(inputWitnessTableSet). This will convert the value // to the corresponding value of 'none' in the input's table collection // allowing us to directly compare it against the tag part of the // input tagged union. @@ -4139,8 +4119,8 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); bool containsNone = false; - forEachInCollection( - taggedUnionType->getWitnessTableCollection(), + forEachInSet( + taggedUnionType->getWitnessTableSet(), [&](IRInst* wt) { if (wt == getNoneWitness()) @@ -4170,10 +4150,10 @@ struct TypeFlowSpecializationContext // Cast the singleton tag to the target collection tag (will convert the // value to the corresponding value for the larger set) // - auto noneWitnessTag = builder.emitGetTagOfElementInCollection( - (IRType*)makeTagType(taggedUnionType->getWitnessTableCollection()), + auto noneWitnessTag = builder.emitGetTagOfElementInSet( + (IRType*)makeTagType(taggedUnionType->getWitnessTableSet()), getNoneWitness(), - taggedUnionType->getWitnessTableCollection()); + taggedUnionType->getWitnessTableSet()); auto newInst = builder.emitNeq(dynTag, noneWitnessTag); inst->replaceUsesWith(newInst); diff --git a/source/slang/slang-ir-typeflow-specialize.h b/source/slang/slang-ir-typeflow-specialize.h index 8863ac915b8..6f3e562675a 100644 --- a/source/slang/slang-ir-typeflow-specialize.h +++ b/source/slang/slang-ir-typeflow-specialize.h @@ -18,5 +18,5 @@ namespace Slang // bool specializeDynamicInsts(IRModule* module, DiagnosticSink* sink); -bool isDynamicGeneric(IRInst* callee); +bool isSetSpecializedGeneric(IRInst* callee); } // namespace Slang diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp index 470501676e5..61d5ef5ccb4 100644 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ b/source/slang/slang-ir-witness-table-wrapper.cpp @@ -124,96 +124,6 @@ struct GenerateWitnessTableWrapperContext } }; - -// DUPLICATES... put into common file. - - -static bool isTaggedUnionType(IRInst* type) -{ - return as(type) != nullptr; -} -/* -static UCount getCollectionCount(IRCollectionBase* collection) -{ - if (!collection) - return 0; - return collection->getOperandCount(); -} - -static UCount getCollectionCount(IRCollectionTagType* tagType) -{ - auto collection = tagType->getOperand(0); - return getCollectionCount(as(collection)); -} - -static IRInst* upcastCollection(IRBuilder* builder, IRInst* arg, IRType* destInfo) -{ - auto argInfo = arg->getDataType(); - if (!argInfo || !destInfo) - return arg; - - if (isTaggedUnionType(argInfo) && isTaggedUnionType(destInfo)) - { - auto argTupleType = as(argInfo); - auto destTupleType = as(destInfo); - - List upcastedElements; - bool hasUpcastedElements = false; - - // Upcast each element of the tuple - for (UInt i = 0; i < argTupleType->getOperandCount(); ++i) - { - auto argElementType = argTupleType->getOperand(i); - auto destElementType = destTupleType->getOperand(i); - - // If the element types are different, we need to reinterpret - if (argElementType != destElementType) - { - hasUpcastedElements = true; - upcastedElements.add(upcastCollection( - builder, - builder->emitGetTupleElement((IRType*)argElementType, arg, i), - (IRType*)destElementType)); - } - else - { - upcastedElements.add(builder->emitGetTupleElement((IRType*)argElementType, arg, i)); - } - } - - if (hasUpcastedElements) - { - return builder->emitMakeTuple(upcastedElements); - } - } - else if (as(argInfo) && as(destInfo)) - { - if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) - { - return builder - ->emitIntrinsicInst((IRType*)destInfo, kIROp_GetTagForSuperCollection, 1, &arg); - } - } - else if (as(argInfo) && as(destInfo)) - { - if (getCollectionCount(as(argInfo)) != - getCollectionCount(as(destInfo))) - { - // If the sets of witness tables are not equal, reinterpret to the parameter type - return builder->emitReinterpret((IRType*)destInfo, arg); - } - } - else if (!as(argInfo) && as(destInfo)) - { - SLANG_UNEXPECTED("Raw collections should not appear"); - return builder->emitPackAnyValue((IRType*)destInfo, arg); - } - - return arg; // Can use as-is. -} -*/ - // Represents a work item for packing `inout` or `out` arguments after a concrete call. struct ArgumentPackWorkItem { @@ -231,7 +141,7 @@ struct ArgumentPackWorkItem bool isAnyValueType(IRType* type) { - if (as(type) || as(type)) + if (as(type) || as(type)) return true; return false; } @@ -297,7 +207,7 @@ IRInst* maybeUnpackArg( // by checking if the types are different, but this should be // encoded in the types. // - if (isTaggedUnionType(paramValType) && isTaggedUnionType(argValType) && + if (as(paramValType) && as(argValType) && paramValType != argValType) { // if parameter expects an `out` pointer, store the unpacked val into a @@ -395,7 +305,7 @@ IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* inte auto concreteVal = builder->emitLoad(item.concreteArg); auto packedVal = (item.kind == ArgumentPackWorkItem::Kind::Pack) ? builder->emitPackAnyValue(anyValType, concreteVal) - : upcastCollection(builder, concreteVal, anyValType); + : upcastSet(builder, concreteVal, anyValType); builder->emitStore(item.dstArg, packedVal); } @@ -408,7 +318,7 @@ IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* inte } else if (call->getDataType() != funcTypeInInterface->getResultType()) { - auto reinterpret = upcastCollection(builder, call, funcTypeInInterface->getResultType()); + auto reinterpret = upcastSet(builder, call, funcTypeInInterface->getResultType()); builder->emitReturn(reinterpret); } else diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f1463686a44..141bdc30591 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6554,7 +6554,7 @@ IREntryPointLayout* IRBuilder::getEntryPointLayout( operands)); } -IRCollectionBase* IRBuilder::getCollection(IROp op, const HashSet& elements) +IRSetBase* IRBuilder::getSet(IROp op, const HashSet& elements) { if (elements.getCount() == 0) return nullptr; @@ -6562,7 +6562,7 @@ IRCollectionBase* IRBuilder::getCollection(IROp op, const HashSet& elem // Verify that all operands are global instructions for (auto element : elements) if (element->getParent()->getOp() != kIROp_ModuleInst) - SLANG_ASSERT_FAILURE("createCollection called with non-global operands"); + SLANG_ASSERT_FAILURE("createSet called with non-global operands"); List sortedElements; for (auto element : elements) @@ -6572,25 +6572,25 @@ IRCollectionBase* IRBuilder::getCollection(IROp op, const HashSet& elem sortedElements.sort( [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); - return as( + return as( emitIntrinsicInst(nullptr, op, sortedElements.getCount(), sortedElements.getBuffer())); } -IRCollectionBase* IRBuilder::getCollection(const HashSet& elements) +IRSetBase* IRBuilder::getSet(const HashSet& elements) { SLANG_ASSERT(elements.getCount() > 0); auto firstElement = *elements.begin(); - return getCollection(getCollectionTypeForInst(firstElement), elements); + return getSet(getSetTypeForInst(firstElement), elements); } -IRCollectionBase* IRBuilder::getSingletonCollection(IROp op, IRInst* element) +IRSetBase* IRBuilder::getSingletonSet(IROp op, IRInst* element) { - return getCollection(op, {element}); + return getSet(op, {element}); } -IRCollectionBase* IRBuilder::getSingletonCollection(IRInst* element) +IRSetBase* IRBuilder::getSingletonSet(IRInst* element) { - return getCollection(getCollectionTypeForInst(element), {element}); + return getSet(getSetTypeForInst(element), {element}); } UInt IRBuilder::getUniqueID(IRInst* inst) @@ -6605,19 +6605,19 @@ UInt IRBuilder::getUniqueID(IRInst* inst) return id; } -IROp IRBuilder::getCollectionTypeForInst(IRInst* inst) +IROp IRBuilder::getSetTypeForInst(IRInst* inst) { if (as(inst)) - return kIROp_GenericCollection; + return kIROp_GenericSet; if (as(inst->getDataType())) - return kIROp_TypeCollection; + return kIROp_TypeSet; else if (as(inst->getDataType())) - return kIROp_FuncCollection; + return kIROp_FuncSet; else if (as(inst) && !as(inst)) - return kIROp_TypeCollection; + return kIROp_TypeSet; else if (as(inst->getDataType())) - return kIROp_WitnessTableCollection; + return kIROp_WitnessTableSet; else return kIROp_Invalid; // Return invalid IROp when not supported } @@ -8545,9 +8545,9 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetCurrentStage: case kIROp_GetDispatcher: case kIROp_GetSpecializedDispatcher: - case kIROp_GetTagForMappedCollection: - case kIROp_GetTagForSpecializedCollection: - case kIROp_GetTagForSuperCollection: + case kIROp_GetTagForMappedSet: + case kIROp_GetTagForSpecializedSet: + case kIROp_GetTagForSuperSet: case kIROp_GetTagFromSequentialID: case kIROp_GetSequentialIDFromTag: case kIROp_CastInterfaceToTaggedUnionPtr: @@ -8557,7 +8557,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetTypeTagFromTaggedUnion: case kIROp_GetValueFromTaggedUnion: case kIROp_MakeTaggedUnion: - case kIROp_GetTagOfElementInCollection: + case kIROp_GetTagOfElementInSet: return false; case kIROp_ForwardDifferentiate: From 71678499ae3214e6aec99c592deeedd358d4dc04 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 12:24:30 -0400 Subject: [PATCH 077/105] Remove `TypeFlowData` base inst. --- source/slang/slang-ir-insts-stable-names.lua | 10 +- source/slang/slang-ir-insts.h | 8 +- source/slang/slang-ir-insts.lua | 103 +++++++++--------- source/slang/slang-ir-typeflow-specialize.cpp | 46 ++------ source/slang/slang-ir.cpp | 3 +- 5 files changed, 67 insertions(+), 103 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index cf06de96fa0..0d35424ee59 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -682,11 +682,11 @@ return { ["Attr.MemoryScope"] = 678, ["Undefined.LoadFromUninitializedMemory"] = 679, ["CUDA_LDG"] = 680, - ["TypeFlowData.SetBase.TypeSet"] = 681, - ["TypeFlowData.SetBase.FuncSet"] = 682, - ["TypeFlowData.SetBase.WitnessTableSet"] = 683, - ["TypeFlowData.SetBase.GenericSet"] = 684, - ["TypeFlowData.UnboundedSet"] = 685, + ["SetBase.TypeSet"] = 681, + ["SetBase.FuncSet"] = 682, + ["SetBase.WitnessTableSet"] = 683, + ["SetBase.GenericSet"] = 684, + ["UnboundedSet"] = 685, ["Type.SetTagType"] = 686, ["Type.TaggedUnionType"] = 687, ["CastInterfaceToTaggedUnionPtr"] = 688, diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 29c5b3707f9..6b6ef5c120d 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2837,13 +2837,7 @@ struct IREmbeddedDownstreamIR : IRInst }; FIDDLE() -struct IRTypeFlowData : IRInst -{ - FIDDLE(baseInst()) -}; - -FIDDLE() -struct IRSetBase : IRTypeFlowData +struct IRSetBase : IRInst { FIDDLE(baseInst()) UInt getCount() { return getOperandCount(); } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 0b03bbda3aa..631a933cfa0 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2679,62 +2679,59 @@ local insts = { }, }, { - TypeFlowData = { - -- A collection of IR instructions used for propagation analysis. + SetBase = { + -- Base class for all set representation.s + -- + -- Semantically, `SetBase` types model sets of concrete values, and use Slang's de-duplication infrastructure + -- to allow set-equality to be the same as inst identity. + -- + -- - Set ops have one or more operands that represent the elements of the set + -- + -- - Set ops must have at least one operand. A zero-operand set is illegal. + -- The type-flow pass will represent this case using nullptr, so that uniqueness is preserved. + -- + -- - All operands of a set _must_ be concrete, individual insts + -- - Operands should NOT be an interface or abstract type. + -- - Operands should NOT be type parameters or existentail types (i.e. insts that appear in blocks) + -- - Operands should NOT be sets (i.e. sets should be flat and never heirarchical) + -- + -- - Since sets are hositable, set ops should (consequently) only appear in the global scope. + -- + -- - Set operands must be consistently sorted. i.e. a TypeSet(A, B) and TypeSet(B, A) + -- cannot exist at the same time, but either one is okay. + -- + -- - To help with the implementation of sets, the IRBuilder class provides operations such as `getSet` + -- that will ensure the above invariants are maintained, and uses a persistent unique ID map to + -- ensure stable ordering of set elements. + -- + -- Set representations should never be manually constructed to avoid breaking these invariants. + -- hoistable = true, - { - SetBase = { - -- Base class for all collection types. - -- - -- Semantically, collections model sets of concrete values, and use Slang's de-duplication infrastructure - -- to allow set-equality to be the same as inst identity. - -- - -- - Set ops have one or more operands that represent the elements of the collection - -- - -- - Set ops must have at least one operand. A zero-operand collection is illegal. - -- The type-flow pass will represent this case using nullptr, so that uniqueness is preserved. - -- - -- - All operands of a collection _must_ be concrete, individual insts - -- - Operands should NOT be an interface or abstract type. - -- - Operands should NOT be type parameters or existentail types (i.e. insts that appear in blocks) - -- - Operands should NOT be collections (i.e. collections should be flat and never heirarchical) - -- - -- - Since collections are hositable, collection ops should (consequently) only appear in the global scope. - -- - -- - Set operands must be consistently sorted. i.e. a TypeSet(A, B) and TypeSet(B, A) - -- cannot exist at the same time, but either one is okay. - -- - -- - To help with the implementation of collections, the SetBuilder class is provided - -- in slang-ir-typeflow-collection.h. - -- All collection insts must be built using the SetBuilder, which uses a persistent map on the module - -- inst to ensure stable ordering. - -- - { TypeSet = {} }, - { FuncSet = {} }, - { WitnessTableSet = {} }, - { GenericSet = {} }, - }, - }, - { UnboundedSet = { - -- - -- A catch-all opcode to represent unbounded collections during - -- the type-flow specialization pass. - -- - -- This op is usually used to mark insts that can contain a dynamic type - -- whose information cannot be gleaned from the type-flow analysis. - -- - -- E.g. COM interface objects, whose implementations can be fully external to - -- the linkage - -- - -- This op is only used to denote that an inst is unbounded so the specialization - -- pass does not attempt to specialize it. It should not appear in the code after - -- the specialization pass. - -- - -- TODO: Consider the scenario where we can combine the unbounded case with known cases. - -- unbounded collection should probably be an element and not a separate op. - } }, + { TypeSet = {} }, + { FuncSet = {} }, + { WitnessTableSet = {} }, + { GenericSet = {} }, }, }, + { UnboundedSet = { + hoistable = true, + -- + -- A catch-all opcode to represent unbounded collections during + -- the type-flow specialization pass. + -- + -- This op is usually used to mark insts that can contain a dynamic type + -- whose information cannot be gleaned from the type-flow analysis. + -- + -- E.g. COM interface objects, whose implementations can be fully external to + -- the linkage + -- + -- This op is only used to denote that an inst is unbounded so the specialization + -- pass does not attempt to specialize it. It should not appear in the code after + -- the specialization pass. + -- + -- TODO: Consider the scenario where we can combine the unbounded case with known cases. + -- unbounded collection should probably be an element and not a separate op. + } }, { CastInterfaceToTaggedUnionPtr = { -- Cast an interface-typed pointer to a tagged-union pointer with a known set. } }, diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 8a75445f507..5822e092489 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -498,7 +498,7 @@ struct TypeFlowSpecializationContext // // From an order-theoretic perspective, 'none' is the bottom of the lattice. // - IRTypeFlowData* none() { return nullptr; } + IRInst* none() { return nullptr; } IRUntaggedUnionType* makeUntaggedUnionType(IRTypeSet* typeSet) { @@ -1394,7 +1394,7 @@ struct TypeFlowSpecializationContext // the information from the stored value. // // Since the pointer can be an access chain, we have to recursively transfer - // the information down to the base. This logic is handled by `maybeUpdatePtr` + // the information down to the base. This logic is handled by `maybeUpdateInfoForAddress` // // If the value has "info", we construct an appropriate PtrType(info) and // update the ptr with it. @@ -1409,7 +1409,7 @@ struct TypeFlowSpecializationContext as(address->getDataType())); // Propagate the information up the access chain to the base location. - maybeUpdatePtr(context, address, ptrInfo, workQueue); + maybeUpdateInfoForAddress(context, address, ptrInfo, workQueue); } // The store inst itself doesn't produce anything, so it has no info @@ -2083,7 +2083,11 @@ struct TypeFlowSpecializationContext } // Updates the information for an address. - void maybeUpdatePtr(IRInst* context, IRInst* inst, IRInst* info, WorkQueue& workQueue) + void maybeUpdateInfoForAddress( + IRInst* context, + IRInst* inst, + IRInst* info, + WorkQueue& workQueue) { // This method recursively walks up the access chain until it hits a location. // @@ -2123,7 +2127,7 @@ struct TypeFlowSpecializationContext as(getElementPtr->getBase()->getDataType())); // Recursively try to update the base pointer. - maybeUpdatePtr(context, getElementPtr->getBase(), baseInfo, workQueue); + maybeUpdateInfoForAddress(context, getElementPtr->getBase(), baseInfo, workQueue); } } else if (auto fieldAddress = as(inst)) @@ -2371,38 +2375,6 @@ struct TypeFlowSpecializationContext } } - // Default catch-all analysis method for any unhandled case. - IRTypeFlowData* analyzeDefault(IRInst* context, IRInst* inst) - { - SLANG_UNUSED(context); - IRBuilder builder(module); - - // Check if this is a global concrete type, witness table, or function. - // If so, it's a concrete element. We'll create a singleton set for it. - if (isGlobalInst(inst) && - (!as(inst) && - (as(inst) || as(inst) || as(inst)))) - return builder.getSingletonSet(inst); - - auto instType = inst->getDataType(); - if (isGlobalInst(inst)) - { - if (as(instType) && !(as(instType))) - return none(); // We'll avoid storing propagation info for concrete insts. (can just - // use the inst directly) - - if (as(instType)) - { - // As a general rule, if none of the non-default cases handled this inst that is - // producing an existential type, then we assume that we can't constrain it - // - return makeUnbounded(); - } - } - - return none(); // Default case, no propagation info - } - // Specialize the fields of a struct type based on the recorded field info (if we // have a non-trivial specilialization) // diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 141bdc30591..e5671bbe3b3 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8352,7 +8352,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) if (as(this)) return false; - if (as(this)) + if (as(this)) return false; switch (getOp()) @@ -8558,6 +8558,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetValueFromTaggedUnion: case kIROp_MakeTaggedUnion: case kIROp_GetTagOfElementInSet: + case kIROp_UnboundedSet: return false; case kIROp_ForwardDifferentiate: From f7a872230592dfee58ba7fec03f69bd62f474011 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 12:30:33 -0400 Subject: [PATCH 078/105] Update slang-ir-typeflow-specialize.cpp --- source/slang/slang-ir-typeflow-specialize.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 5822e092489..8b0c3beafc7 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -1612,6 +1612,8 @@ struct TypeFlowSpecializationContext return info; } } + + return none(); } IRInst* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) From a4f538aa46f811f29d12704993205ce9df085804 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 13:03:51 -0400 Subject: [PATCH 079/105] Update slang-ir-typeflow-specialize.cpp --- source/slang/slang-ir-typeflow-specialize.cpp | 118 +++++++++--------- 1 file changed, 62 insertions(+), 56 deletions(-) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 8b0c3beafc7..1e7fa0f26b4 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -410,6 +410,61 @@ IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo info, IR } } +bool isConcreteType(IRInst* inst) +{ + bool isInstGlobal = isGlobalInst(inst); + if (!isInstGlobal) + return false; + + switch (inst->getOp()) + { + case kIROp_InterfaceType: + return false; + case kIROp_WitnessTableType: + case kIROp_FuncType: + return isInstGlobal; + case kIROp_ArrayType: + return isConcreteType(cast(inst)->getElementType()) && + isGlobalInst(cast(inst)->getElementCount()); + case kIROp_OptionalType: + return isConcreteType(cast(inst)->getValueType()); + default: + break; + } + + if (as(inst)) + { + auto ptrType = as(inst); + return isConcreteType(ptrType->getValueType()); + } + + return true; +} + +IRInst* makeInfoForConcreteType(IRModule* module, IRInst* inst) +{ + SLANG_ASSERT(isConcreteType(inst)); + IRBuilder builder(module); + if (auto ptrType = as(inst->getDataType())) + { + return builder.getPtrTypeWithAddressSpace( + builder.getUntaggedUnionType( + cast(builder.getSingletonSet(ptrType->getValueType()))), + ptrType); + } + + if (auto arrayType = as(inst->getDataType())) + { + return builder.getArrayType( + builder.getUntaggedUnionType( + cast(builder.getSingletonSet(arrayType->getElementType()))), + arrayType->getElementCount()); + } + + return builder.getUntaggedUnionType( + cast(builder.getSingletonSet(inst->getDataType()))); +} + // Helper to check if an IRParam is a function parameter (vs. a phi param or generic param) bool isFuncParam(IRParam* param) { @@ -526,69 +581,20 @@ struct TypeFlowSpecializationContext return none(); // Default info for any inst that we haven't registered. } - bool isConcreteType(IRInst* inst) - { - if (!isGlobalInst(inst) || as(inst) || - as(inst) && as(inst)) - return false; - - if (as(inst)) - { - auto ptrType = as(inst); - return isConcreteType(ptrType->getValueType()); - } - - if (as(inst)) - { - auto arrayType = as(inst); - return isConcreteType(arrayType->getElementType()) && - isGlobalInst(arrayType->getElementCount()); - } - - if (as(inst)) - { - auto optionalType = as(inst); - return isConcreteType(optionalType->getValueType()); - } - - return true; - } - IRInst* tryGetArgInfo(IRInst* context, IRInst* inst) { if (auto info = tryGetInfo(context, inst)) return info; - IRBuilder builder(module); - if (auto ptrType = as(inst->getDataType())) - { - if (isConcreteType(ptrType->getValueType())) - return builder.getPtrTypeWithAddressSpace( - builder.getUntaggedUnionType( - cast(builder.getSingletonSet(ptrType->getValueType()))), - ptrType); - else - return none(); - } - - if (auto arrayType = as(inst->getDataType())) + if (isConcreteType(inst->getDataType())) { - if (isConcreteType(arrayType)) - { - return builder.getArrayType( - builder.getUntaggedUnionType( - cast(builder.getSingletonSet(arrayType->getElementType()))), - arrayType->getElementCount()); - } - else - return none(); + // If the inst has a concrete type, we can make a default info for it, + // considering it as a singleton set. Note that this needs to be + // nested into any relevant type structure that we want to propagate through + // `makeInfoForConcreteType` handles this logic. + // + return makeInfoForConcreteType(module, inst->getDataType()); } - - if (isConcreteType(inst->getDataType())) - return builder.getUntaggedUnionType( - cast(builder.getSingletonSet(inst->getDataType()))); - else - return none(); } // From 17e38e2e66eabbb3be4e1520fa448ee3b136bea0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 15:21:29 -0400 Subject: [PATCH 080/105] More CI fixes --- source/slang/slang-ir-insts.h | 1 + source/slang/slang-ir-typeflow-specialize.cpp | 112 ++++++++++++------ source/slang/slang-ir.cpp | 12 ++ 3 files changed, 87 insertions(+), 38 deletions(-) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 6b6ef5c120d..20cc7e334e3 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3079,6 +3079,7 @@ struct IRBuilder IRPtrLit* getNullPtrValue(IRType* type); IRPtrLit* getNullVoidPtrValue() { return getNullPtrValue(getPtrType(getVoidType())); } IRVoidLit* getVoidValue(); + IRVoidLit* getVoidValue(IRType* type); IRInst* getCapabilityValue(CapabilitySet const& caps); IRBasicType* getBasicType(BaseType baseType); diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 1e7fa0f26b4..704d2f646e2 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -410,6 +410,27 @@ IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo info, IR } } +IRInst* lookupWitnessTableEntry(IRWitnessTable* table, IRInst* key) +{ + if (auto entry = findWitnessTableEntry(table, key)) + { + return entry; + } + else if (as(table->getConcreteType())) + { + IRBuilder builder(table->getModule()); + return builder.getVoidValue(); + } + + return nullptr; +} + +// Helper to test if an inst is in the global scope. +bool isGlobalInst(IRInst* inst) +{ + return inst->getParent()->getOp() == kIROp_ModuleInst; +} + bool isConcreteType(IRInst* inst) { bool isInstGlobal = isGlobalInst(inst); @@ -441,28 +462,25 @@ bool isConcreteType(IRInst* inst) return true; } -IRInst* makeInfoForConcreteType(IRModule* module, IRInst* inst) +IRInst* makeInfoForConcreteType(IRModule* module, IRInst* type) { - SLANG_ASSERT(isConcreteType(inst)); + SLANG_ASSERT(isConcreteType(type)); IRBuilder builder(module); - if (auto ptrType = as(inst->getDataType())) + if (auto ptrType = as(type)) { return builder.getPtrTypeWithAddressSpace( - builder.getUntaggedUnionType( - cast(builder.getSingletonSet(ptrType->getValueType()))), + (IRType*)makeInfoForConcreteType(module, ptrType->getValueType()), ptrType); } - if (auto arrayType = as(inst->getDataType())) + if (auto arrayType = as(type)) { return builder.getArrayType( - builder.getUntaggedUnionType( - cast(builder.getSingletonSet(arrayType->getElementType()))), + (IRType*)makeInfoForConcreteType(module, arrayType->getElementType()), arrayType->getElementCount()); } - return builder.getUntaggedUnionType( - cast(builder.getSingletonSet(inst->getDataType()))); + return builder.getUntaggedUnionType(cast(builder.getSingletonSet(type))); } // Helper to check if an IRParam is a function parameter (vs. a phi param or generic param) @@ -473,12 +491,6 @@ bool isFuncParam(IRParam* param) return (paramFunc && paramFunc->getFirstBlock() == paramBlock); } -// Helper to test if an inst is in the global scope. -bool isGlobalInst(IRInst* inst) -{ - return inst->getParent()->getOp() == kIROp_ModuleInst; -} - // Helper to test if a function or generic contains a body (i.e. is intrinsic/external) // For the purposes of type-flow, if a function body is not available, we can't analyze it. // @@ -595,6 +607,8 @@ struct TypeFlowSpecializationContext // return makeInfoForConcreteType(module, inst->getDataType()); } + + return none(); } // @@ -1037,7 +1051,8 @@ struct TypeFlowSpecializationContext if (auto returnInfo = as(block->getTerminator())) { auto val = returnInfo->getVal(); - updateFuncReturnInfo(context, tryGetArgInfo(context, val), workQueue); + if (!as(val->getDataType())) + updateFuncReturnInfo(context, tryGetArgInfo(context, val), workQueue); } }; @@ -1071,7 +1086,8 @@ struct TypeFlowSpecializationContext if (auto argInfo = tryGetArgInfo(context, arg)) { // Use centralized update method - updateInfo(context, param, argInfo, true, workQueue); + if (!isConcreteType(param->getDataType())) + updateInfo(context, param, argInfo, true, workQueue); } } paramIndex++; @@ -1139,12 +1155,14 @@ struct TypeFlowSpecializationContext &builder, paramDirection, as(argInfo)->getValueType()); - updateInfo(edge.targetContext, param, newInfo, true, workQueue); + if (!isConcreteType(param->getDataType())) + updateInfo(edge.targetContext, param, newInfo, true, workQueue); break; } case ParameterDirectionInfo::Kind::In: { - updateInfo(edge.targetContext, param, argInfo, true, workQueue); + if (!isConcreteType(param->getDataType())) + updateInfo(edge.targetContext, param, argInfo, true, workQueue); break; } default: @@ -1206,14 +1224,15 @@ struct TypeFlowSpecializationContext auto argPtrType = as(arg->getDataType()); IRBuilder builder(module); - updateInfo( - edge.callerContext, - arg, - builder.getPtrTypeWithAddressSpace( - (IRType*)as(paramInfo)->getValueType(), - argPtrType), - true, - workQueue); + if (!isConcreteType(arg->getDataType())) + updateInfo( + edge.callerContext, + arg, + builder.getPtrTypeWithAddressSpace( + (IRType*)as(paramInfo)->getValueType(), + argPtrType), + true, + workQueue); } } argIndex++; @@ -1647,7 +1666,7 @@ struct TypeFlowSpecializationContext forEachInSet( cast(elementOfSetType->getSet()), [&](IRInst* table) - { results.add(findWitnessTableEntry(cast(table), key)); }); + { results.add(lookupWitnessTableEntry(cast(table), key)); }); return makeElementOfSetType(builder.getSet(results)); } @@ -2049,6 +2068,9 @@ struct TypeFlowSpecializationContext auto propagateToCallSite = [&](IRInst* callee) { + if (as(callee)) + return; + // Register the call site in the map to allow for the // return-edge to be created. // @@ -2198,7 +2220,8 @@ struct TypeFlowSpecializationContext // // This is one of the base cases for the recursion. // - updateInfo(context, var, info, true, workQueue); + if (!isConcreteType(var->getDataType())) + updateInfo(context, var, info, true, workQueue); } else if (auto param = as(inst)) { @@ -2215,7 +2238,9 @@ struct TypeFlowSpecializationContext auto newInfo = builder.getPtrTypeWithAddressSpace( (IRType*)as(info)->getValueType(), as(param->getDataType())); - updateInfo(context, param, newInfo, true, workQueue); + + if (!isConcreteType(param->getDataType())) + updateInfo(context, param, newInfo, true, workQueue); } else { @@ -2742,7 +2767,7 @@ struct TypeFlowSpecializationContext // Handle trivial case where inst's operand is a concrete table. if (auto witnessTable = as(inst->getWitnessTable())) { - inst->replaceUsesWith(findWitnessTableEntry(witnessTable, inst->getRequirementKey())); + inst->replaceUsesWith(lookupWitnessTableEntry(witnessTable, inst->getRequirementKey())); inst->removeAndDeallocate(); return true; } @@ -3324,6 +3349,23 @@ struct TypeFlowSpecializationContext "Unexpected operand type for type-flow specialization of Call inst"); } } + else if (as(callee)) + { + // Occasionally, we will determine that there are absolutely no possible callees + // for a call site. This typically happens to impossible branches. + // + // The correct way to handle such cases is to improve the analysis to avoid + // branches that are impossible. For now, we will just remove the callee and + // replace with a default value. The exact value doesn't matter since we've determined + // that this code is unreachable. + // + IRBuilder builder(context); + builder.setInsertBefore(inst); + auto defaultVal = builder.emitDefaultConstruct(inst->getDataType()); + inst->replaceUsesWith(defaultVal); + inst->removeAndDeallocate(); + return true; + } else if (isGlobalInst(callee) && !isIntrinsic(callee)) { // If our callee is not a tag-type, then it is necessarily a simple concrete function. @@ -3340,12 +3382,6 @@ struct TypeFlowSpecializationContext if (!isGlobalInst(callee) || isIntrinsic(callee)) return false; - // One other case to avoid is if the function is a global LookupWitnessMethod - // which can be created when optional witnesses are specialized. - // - if (as(callee)) - return false; - // First, we'll legalize all operands by upcasting if necessary. // This needs to be done even if the callee is not a collection. // diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index e5671bbe3b3..be2e808ed6d 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2483,6 +2483,16 @@ IRVoidLit* IRBuilder::getVoidValue() return (IRVoidLit*)_findOrEmitConstant(keyInst); } +IRVoidLit* IRBuilder::getVoidValue(IRType* type) +{ + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.m_op = kIROp_VoidLit; + keyInst.typeUse.usedValue = type; + keyInst.value.intVal = 0; + return (IRVoidLit*)_findOrEmitConstant(keyInst); +} + IRInst* IRBuilder::getCapabilityValue(CapabilitySet const& caps) { IRType* capabilityAtomType = getIntType(); @@ -6618,6 +6628,8 @@ IROp IRBuilder::getSetTypeForInst(IRInst* inst) return kIROp_TypeSet; else if (as(inst->getDataType())) return kIROp_WitnessTableSet; + else if (as(inst)) + return kIROp_TypeSet; // TODO: this feels wrong... else return kIROp_Invalid; // Return invalid IROp when not supported } From 5b0263a1f148ecd85a5bfd6de8a8b2b0c3819e45 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 22 Oct 2025 17:23:07 -0400 Subject: [PATCH 081/105] Fix up documentation --- .../slang/slang-ir-lower-typeflow-insts.cpp | 111 ++-- source/slang/slang-ir-typeflow-specialize.cpp | 513 ++++++++++-------- 2 files changed, 365 insertions(+), 259 deletions(-) diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index b1ab94013ee..bcd969f7551 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -215,6 +215,16 @@ IRFunc* createIntegerMappingFunc(IRModule* module, Dictionary& mappi // struct TagOpsLoweringContext : public InstPassBase { + // Our strategy for lowering tag operations is to + // assign each element to a unique integer ID that is stable + // across the same module. This is acheived via `getUniqueID`, + // on the IRBuilder, which uses a dicionary on the module inst + // to keep track of assignments. + // + // Then, tag operations can be lowered to mapping functions that + // take an integer in and return an integer out, based on the + // input and output sets (and any other operands) + // TagOpsLoweringContext(IRModule* module) : InstPassBase(module) { @@ -222,12 +232,22 @@ struct TagOpsLoweringContext : public InstPassBase void lowerGetTagForSuperSet(IRGetTagForSuperSet* inst) { + // `GetTagForSuperSet` is a no-op since we want to translate the tag + // for an element in the sub-set to a tag for the same element in the super-set. + // + // Since all elements have a unique ID across the module, this is the identity operation. + // + inst->replaceUsesWith(inst->getOperand(0)); inst->removeAndDeallocate(); } void lowerGetTagForMappedSet(IRGetTagForMappedSet* inst) { + // `GetTagForMappedSet` turns into a integer mapping from + // the unique ID of each input set element to the unique ID of the + // corresponding element (as determined by witness table lookup) in the destination set. + // auto srcSet = cast( cast(inst->getOperand(0)->getDataType())->getOperand(0)); auto destSet = cast(cast(inst->getDataType())->getOperand(0)); @@ -266,7 +286,7 @@ struct TagOpsLoweringContext : public InstPassBase } } - // Create an index mapping func and call that + // Create an index mapping func and call that. auto mappingFunc = createIntegerMappingFunc(inst->getModule(), mapping, 0); auto resultID = builder.emitCallInst( @@ -279,6 +299,12 @@ struct TagOpsLoweringContext : public InstPassBase void lowerGetTagOfElementInSet(IRGetTagOfElementInSet* inst) { + // `GetTagOfElementInSet` simply gets replaced by the element's + // unique ID (as an integer literal value) + // + // Note: the element must be a concrete global inst (cannot by a + // dynamic value) + // IRBuilder builder(inst->getModule()); builder.setInsertAfter(inst); @@ -319,51 +345,40 @@ struct DispatcherLoweringContext : public InstPassBase { } - void lowerGetSpecializedDispatcher(IRGetSpecializedDispatcher* dispatcher) + void lowerGetDispatcher(IRGetDispatcher* dispatcher) { - // Replace the `IRGetSpecializedDispatcher` with a dispatch function, + // Replace the `IRGetDispatcher` with a dispatch function, // which takes an extra first parameter for the tag (i.e. ID) // // We'll also replace the callee in all 'call' insts. // + // The generated dispatch function uses a switch-case to call the + // appropriate function based on the integer tag. Since tags + // may not yet be lowered into actual integers, we use `GetTagOfElementInSet` + // as a placeholder literal. + // + // Note that before each function is called, it needs to be wrapped in a + // method (a 'witness table wrapper') that handles marshalling between the input types + // to the dispatcher and the actual function types (which may be different) + // auto witnessTableSet = cast(dispatcher->getOperand(0)); auto key = cast(dispatcher->getOperand(1)); - List specArgs; - for (UIndex i = 2; i < dispatcher->getOperandCount(); i++) - { - specArgs.add(dispatcher->getOperand(i)); - } + IRBuilder builder(dispatcher->getModule()); Dictionary elements; - IRBuilder builder(dispatcher->getModule()); forEachInSet( witnessTableSet, [&](IRInst* table) { - auto generic = - cast(findWitnessTableEntry(cast(table), key)); - - auto specializedFuncType = - (IRType*)specializeGeneric(cast(builder.emitSpecializeInst( - builder.getTypeKind(), - generic->getDataType(), - specArgs.getCount(), - specArgs.getBuffer()))); - - auto specializedFunc = builder.emitSpecializeInst( - specializedFuncType, - generic, - specArgs.getCount(), - specArgs.getBuffer()); - - auto singletonTag = builder.emitGetTagOfElementInSet( + auto tag = builder.emitGetTagOfElementInSet( builder.getSetTagType(witnessTableSet), table, witnessTableSet); - - elements.add(singletonTag, specializedFunc); + elements.add( + tag, + cast(findWitnessTableEntry(cast(table), key))); }); if (dispatcher->hasUses() && dispatcher->getDataType() != nullptr) @@ -388,31 +403,57 @@ struct DispatcherLoweringContext : public InstPassBase } } - void lowerGetDispatcher(IRGetDispatcher* dispatcher) + + void lowerGetSpecializedDispatcher(IRGetSpecializedDispatcher* dispatcher) { - // Replace the `IRGetDispatcher` with a dispatch function, + // Replace the `IRGetSpecializedDispatcher` with a dispatch function, // which takes an extra first parameter for the tag (i.e. ID) // // We'll also replace the callee in all 'call' insts. // + // The logic here is very similar to `lowerGetDispatcher`, except that we need to + // account for the specialization arguments when creating the dispatch function. + // We construct an `IRSpecialize` inst around each generic function before dispatching + // to it. + // auto witnessTableSet = cast(dispatcher->getOperand(0)); auto key = cast(dispatcher->getOperand(1)); - IRBuilder builder(dispatcher->getModule()); + List specArgs; + for (UIndex i = 2; i < dispatcher->getOperandCount(); i++) + { + specArgs.add(dispatcher->getOperand(i)); + } Dictionary elements; + IRBuilder builder(dispatcher->getModule()); forEachInSet( witnessTableSet, [&](IRInst* table) { - auto tag = builder.emitGetTagOfElementInSet( + auto generic = + cast(findWitnessTableEntry(cast(table), key)); + + auto specializedFuncType = + (IRType*)specializeGeneric(cast(builder.emitSpecializeInst( + builder.getTypeKind(), + generic->getDataType(), + specArgs.getCount(), + specArgs.getBuffer()))); + + auto specializedFunc = builder.emitSpecializeInst( + specializedFuncType, + generic, + specArgs.getCount(), + specArgs.getBuffer()); + + auto singletonTag = builder.emitGetTagOfElementInSet( builder.getSetTagType(witnessTableSet), table, witnessTableSet); - elements.add( - tag, - cast(findWitnessTableEntry(cast(table), key))); + + elements.add(singletonTag, specializedFunc); }); if (dispatcher->hasUses() && dispatcher->getDataType() != nullptr) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 704d2f646e2..9aadfd0ae25 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -127,6 +127,9 @@ bool isResourcePointer(IRInst* inst) return false; } +// Test if the callee represents an invalid call. This can arise from looking something up +// on a 'none' witness (in the presence of optional witness tables) +// bool isNoneCallee(IRInst* callee) { if (auto lookupWitness = as(callee)) @@ -512,10 +515,19 @@ bool isIntrinsic(IRInst* inst) return false; } +// Returns true if the inst is of the form OptionalType +bool isOptionalExistentialType(IRInst* inst) +{ + if (auto optionalType = as(inst)) + if (auto interfaceType = as(optionalType->getValueType())) + return !isComInterfaceType(interfaceType) && !isBuiltin(interfaceType); + return false; +} + // Parent context for the full type-flow pass. struct TypeFlowSpecializationContext { - // Create a tagged-union-type out of a given collection of tables. + // Make a tagged-union-type out of a given collection of tables. // // This type can be used for insts that are semantically a tuple of a tag (to select a table) // and a payload to contain the existential value. @@ -540,13 +552,13 @@ struct TypeFlowSpecializationContext cast(builder.getSet(kIROp_TypeSet, typeSet))); } - // Create an unbounded collection. + // Create an unbounded set. // - // This collection is a catch-all for - // all cases where we can't enumerate the possibilites. We use this as - // a sentinel value to figure out when NOT to specialize a given inst. + // This is a catch-all for cases where we can't enumerate the possibilites. + // We use this as a sentinel value to figure out when NOT to specialize a + // given inst. // - // Most commonly occurs with COM interface types. + // Most commonly occurs with COM objects & some builtin-in interface types. // IRUnboundedSet* makeUnbounded() { @@ -567,22 +579,49 @@ struct TypeFlowSpecializationContext // IRInst* none() { return nullptr; } + // Make an untagged-union type out of a given collection of types. + // + // This is used to denote insts whose value can be of multiple possible types, + // Note that unlike tagged-unions, untagged-unions do not have any information + // on which type is currently held. + // + // Typically used as the type of the value part of existential objects. + // IRUntaggedUnionType* makeUntaggedUnionType(IRTypeSet* typeSet) { IRBuilder builder(module); return builder.getUntaggedUnionType(typeSet); } - IRElementOfSetType* makeElementOfSetType(IRSetBase* collection) + // Make an element-of-set type out of a given set. + // + // ElementOfSetType can be used as the type of an inst whose + // _value_ is known to be one of the elements of the set. + // + // e.g. + // if we have IR of the form: + // %1 : ElementOfSetType> = ExtractExistentialType(%existentialObj) + // + // then %1 is Int or Float. + // (Note that this is different from saying %1's value is an int or a float) + // + IRElementOfSetType* makeElementOfSetType(IRSetBase* set) { IRBuilder builder(module); - return builder.getElementOfSetType(collection); + return builder.getElementOfSetType(set); } - IRSetTagType* makeTagType(IRSetBase* collection) + // Make a tag-type for a given set. + // + // `TagType(set)` is used to denote insts that are carrying an identifier for one of the + // elements of the set. + // These insts cannot be used directly as one of the values (must be used with `GetDispatcher` + // or `GetElementFromTag`) before they can be used as values. + // + IRSetTagType* makeTagType(IRSetBase* set) { IRBuilder builder(module); - return builder.getSetTagType(collection); + return builder.getSetTagType(set); } IRInst* _tryGetInfo(InstWithContext element) @@ -593,6 +632,21 @@ struct TypeFlowSpecializationContext return none(); // Default info for any inst that we haven't registered. } + // Find information for an inst that is being used as an argument + // under a given context. + // + // This is a bit different from `tryGetInfo`, in that we want to + // obtain an info that can be passed to a parameter (which can have multiple + // possibilities stemming from different argument types) + // + // This key difference is that even if the argument has no propagated info because it + // is not a dynamic or abstract type, it's possible that the parameter's type is + // abstract. + // Thus, we will construct an `UntaggedUnionType` of the concrete type so that it + // can be union-ed with other argument infos. + // + // Such 'argument' cases arise for phi-args and function args. + // IRInst* tryGetArgInfo(IRInst* context, IRInst* inst) { if (auto info = tryGetInfo(context, inst)) @@ -611,7 +665,6 @@ struct TypeFlowSpecializationContext return none(); } - // // Bottleneck method to fetch the current propagation info // for a given instruction under context. // @@ -619,13 +672,17 @@ struct TypeFlowSpecializationContext { if (inst->getDataType()) { + // If the data-type is already a collection type, then the refinement + // occured during a previous phase. For now, we simply re-use that info directly. + // + // In the future, it makes sense to ignore the pre-existing type and treat + // them as an upper-bound on the new info. + // switch (inst->getDataType()->getOp()) { case kIROp_TaggedUnionType: case kIROp_UntaggedUnionType: case kIROp_ElementOfSetType: - // These insts directly represent type-flow information, - // so we return them directly. return inst->getDataType(); } } @@ -641,35 +698,35 @@ struct TypeFlowSpecializationContext return _tryGetInfo(InstWithContext(context, inst)); } - // Performs set-union over the two collections, and returns a new - // inst to represent the collection. + // Performs set-union over the two sets, and returns a new + // inst to represent the set. // template - T* unionSet(T* collection1, T* collection2) + T* unionSet(T* set1, T* set2) { - // It may be possible to accelerate this further, but we usually + // It's possible to accelerate this further, but we usually // don't have to deal with overly large sets (usually 3-20 elements) // - SLANG_ASSERT(as(collection1) && as(collection2)); - SLANG_ASSERT(collection1->getOp() == collection2->getOp()); + SLANG_ASSERT(as(set1) && as(set2)); + SLANG_ASSERT(set1->getOp() == set2->getOp()); - if (!collection1) - return collection2; - if (!collection2) - return collection1; - if (collection1 == collection2) - return collection1; + if (!set1) + return set2; + if (!set2) + return set1; + if (set1 == set2) + return set1; HashSet allValues; - // Collect all values from both collections - forEachInSet(collection1, [&](IRInst* value) { allValues.add(value); }); - forEachInSet(collection2, [&](IRInst* value) { allValues.add(value); }); + // Collect all values from both sets + forEachInSet(set1, [&](IRInst* value) { allValues.add(value); }); + forEachInSet(set2, [&](IRInst* value) { allValues.add(value); }); IRBuilder builder(module); return as(builder.getSet( - collection1->getOp(), - allValues)); // Create a new collection with the union of values + set1->getOp(), + allValues)); // Create a new set with the union of values } // Find the union of two propagation info insts, and return and @@ -684,13 +741,14 @@ struct TypeFlowSpecializationContext // to let us propagate information elegantly for pointers, parameters, arrays // and existential tuples. // - // A few interesting cases are missing, but could be added in easily in the future: + // A few cases are missing, but could be added in easily in the future: // - TupleType (will allow us to propagate information for each tuple element) - // - OptionalType - // + // - Vector/Matrix types + // - TypePack - // Basic cases: if either is null, it is considered "empty" + // Basic cases: if either info is null, it is considered "empty" // if they're equal, union must be the same inst. + // if (!info1) return info2; @@ -730,8 +788,8 @@ struct TypeFlowSpecializationContext return makeUnbounded(); } - // For all other cases which are structured composites of collections, - // we simply take the collection union for all the collection operands. + // For all other cases which are structured composites of sets, + // we simply take the set union for all the set operands. // if (as(info1) && as(info2)) @@ -775,6 +833,14 @@ struct TypeFlowSpecializationContext bool takeUnion, WorkQueue& workQueue) { + if (isConcreteType(inst->getDataType())) + { + // No need to update info for insts with already defined concrete types, + // since these can't be refined any further. + // + return; + } + auto existingInfo = tryGetInfo(context, inst); auto unionedInfo = (takeUnion) ? unionPropagationInfo(existingInfo, newInfo) : newInfo; @@ -789,6 +855,7 @@ struct TypeFlowSpecializationContext addUsersToWorkQueue(context, inst, unionedInfo, workQueue); } + // Helper method to add work items for all call sites of a function/generic. void addContextUsersToWorkQueue(IRInst* context, WorkQueue& workQueue) { if (this->funcCallSites.containsKey(context)) @@ -863,8 +930,10 @@ struct TypeFlowSpecializationContext if (isFuncParam(param)) addContextUsersToWorkQueue(context, workQueue); - // TODO: Stopgap workaround. - // Add an analyzeFuncType for this.. + // TODO: This is a tiny bit of a hack.. but we don't currently need to + // analyze FuncType insts, but we do want to make sure that any changes + // are propagated through to their users. + // if (auto funcType = as(user)) if (as(funcType->getParent())) addUsersToWorkQueue(context, funcType, none(), workQueue); @@ -962,6 +1031,21 @@ struct TypeFlowSpecializationContext } } + IRInst* analyzeByType(IRInst* context, IRInst* inst) + { + if (!inst->getDataType()) + return none(); + + if (auto dataTypeInfo = tryGetInfo(context, inst->getDataType())) + { + if (auto elementOfSetType = as(dataTypeInfo)) + { + return makeUntaggedUnionType(cast(elementOfSetType->getSet())); + } + } + + return none(); + } void processInstForPropagation(IRInst* context, IRInst* inst, WorkQueue& workQueue) { @@ -1023,16 +1107,11 @@ struct TypeFlowSpecializationContext break; } - if (!info && inst->getDataType()) - { - if (auto dataTypeInfo = tryGetInfo(context, inst->getDataType())) - { - if (auto elementOfSetType = as(dataTypeInfo)) - { - info = makeUntaggedUnionType(cast(elementOfSetType->getSet())); - } - } - } + // If we didn't get any info from inst-specific analysis, we'll try to get + // info from the data-type's info. + // + if (!info) + info = analyzeByType(context, inst); if (info) updateInfo(context, inst, info, false, workQueue); @@ -1086,8 +1165,7 @@ struct TypeFlowSpecializationContext if (auto argInfo = tryGetArgInfo(context, arg)) { // Use centralized update method - if (!isConcreteType(param->getDataType())) - updateInfo(context, param, argInfo, true, workQueue); + updateInfo(context, param, argInfo, true, workQueue); } } paramIndex++; @@ -1155,14 +1233,12 @@ struct TypeFlowSpecializationContext &builder, paramDirection, as(argInfo)->getValueType()); - if (!isConcreteType(param->getDataType())) - updateInfo(edge.targetContext, param, newInfo, true, workQueue); + updateInfo(edge.targetContext, param, newInfo, true, workQueue); break; } case ParameterDirectionInfo::Kind::In: { - if (!isConcreteType(param->getDataType())) - updateInfo(edge.targetContext, param, argInfo, true, workQueue); + updateInfo(edge.targetContext, param, argInfo, true, workQueue); break; } default: @@ -1224,15 +1300,14 @@ struct TypeFlowSpecializationContext auto argPtrType = as(arg->getDataType()); IRBuilder builder(module); - if (!isConcreteType(arg->getDataType())) - updateInfo( - edge.callerContext, - arg, - builder.getPtrTypeWithAddressSpace( - (IRType*)as(paramInfo)->getValueType(), - argPtrType), - true, - workQueue); + updateInfo( + edge.callerContext, + arg, + builder.getPtrTypeWithAddressSpace( + (IRType*)as(paramInfo)->getValueType(), + argPtrType), + true, + workQueue); } } argIndex++; @@ -1555,15 +1630,6 @@ struct TypeFlowSpecializationContext return builder.createWitnessTable(nullptr, voidType); } - // Returns true if the inst is of the form OptionalType - bool isOptionalExistentialType(IRInst* inst) - { - if (auto optionalType = as(inst)) - if (auto interfaceType = as(optionalType->getValueType())) - return !isComInterfaceType(interfaceType) && !isBuiltin(interfaceType); - return false; - } - IRInst* analyzeMakeOptionalNone(IRInst* context, IRMakeOptionalNone* inst) { // If the optional type we're dealing with is an optional concrete type, we won't @@ -1603,7 +1669,7 @@ struct TypeFlowSpecializationContext // // Thus, we simply pass the input existential info as-is. // - // Note: we don't actually have to add a new 'none' table to the collection, since that will + // Note: we don't actually have to add a new 'none' table to the set, since that will // automatically occur if this value ever merges with a value created using // `MakeOptionalNone` // @@ -1624,7 +1690,7 @@ struct TypeFlowSpecializationContext if (isOptionalExistentialType(inst->getDataType())) { // This is an interesting case.. technically, at this point we could go - // from a larger collection to a smaller one (without the none-type). + // from a larger set to a smaller one (without the none-type). // // However, for simplicitly reasons, we currently only allow up-casting, // so for now we'll just passthrough all types (so the result will @@ -1644,10 +1710,10 @@ struct TypeFlowSpecializationContext IRInst* analyzeLookupWitnessMethod(IRInst* context, IRLookupWitnessMethod* inst) { // A LookupWitnessMethod is assumed to by dynamic, so we - // (i) construct a collection of the results by looking up the given + // (i) construct a set of the results by looking up the given // key in each of the input witness tables // (ii) wrap the result in a tag type, since the lookup inst is logically holding - // on to run-time information about which element of the collection is active. + // on to run-time information about which element of the set is active. // // Note that the input must be a set of concrete witness tables (or none/unbounded). // If this is not the case and we see anything abstract, then something has gone @@ -1685,11 +1751,11 @@ struct TypeFlowSpecializationContext { // An ExtractExistentialWitnessTable inst is assumed to by dynamic, so we // extract the set of witness tables from the input existential and - // state that the info of the result is a tag-type of that collection. + // state that the info of the result is a tag-type of that set. // // Note that since ExtractExistentialWitnessTable can only be used on // an existential, the input info must be a TaggedUnionType of - // concrete table and type collections (or none/unbounded) + // concrete table and type sets (or none/unbounded) // auto operand = inst->getOperand(0); @@ -1711,11 +1777,11 @@ struct TypeFlowSpecializationContext { // An ExtractExistentialType inst is assumed to be dynamic, so we // extract the set of witness tables from the input existential and - // state that the info of the result is a tag-type of that collection. + // state that the info of the result is a tag-type of that set. // // Note: Since ExtractExistentialType can only be used on // an existential, the input info must be a TaggedUnionType of - // concrete table and type collections (or none/unbounded) + // concrete table and type sets (or none/unbounded) // auto operand = inst->getOperand(0); @@ -1738,12 +1804,12 @@ struct TypeFlowSpecializationContext // Logically, an ExtractExistentialValue inst is carrying a payload // of a union type. // - // We represent this by setting its info to be equal to the type-collection, + // We represent this by setting its info to be equal to the type-set, // which will later lower into an any-value-type. // // Note that there is no 'tag' here since ExtractExistentialValue is not representing - // tag information about which type in the collection is active, but is representing - // a value of the collection's union type. + // tag information about which type in the set is active, but is representing + // a value of the set's union type. // auto operand = inst->getOperand(0); @@ -1770,19 +1836,19 @@ struct TypeFlowSpecializationContext // dynamic types or witness tables. // // We'll first look at the specialization base, which may be a single generic - // or a collection of generics. + // or a set of generics. // // Then, for each generic, we'll create a specialized version by using the - // collection info for each argument in place of the argument. - // e.g. Specialize(G, A0, A1) becomes Specialize(G, info(A1).collection, - // info(A2).collection) - // (i.e. if the args are tag-types, we only use the collection part) + // set info for each argument in place of the argument. + // e.g. Specialize(G, A0, A1) becomes Specialize(G, info(A1).set, + // info(A2).set) + // (i.e. if the args are tag-types, we only use the set part) // // This transformation is important to lift the 'dynamic' specialize instruction into a // global specialize instruction while still retaining the information about what types and // tables the resulting generic should support. // - // Finally, we put all the specialized vesions back into a collection and return that info. + // Finally, we put all the specialized vesions back into a set and return that info. // auto operand = inst->getBase(); @@ -1840,8 +1906,7 @@ struct TypeFlowSpecializationContext } else { - SLANG_UNEXPECTED( - "Unexpected collection type in specialization argument."); + SLANG_UNEXPECTED("Unexpected set type in specialization argument."); } } } @@ -1852,7 +1917,7 @@ struct TypeFlowSpecializationContext } // This part creates a correct type for the specialization, by following the same - // process: replace all operands in the composite type with their propagated collection. + // process: replace all operands in the composite type with their propagated set. // IRType* typeOfSpecialization = nullptr; @@ -1893,7 +1958,7 @@ struct TypeFlowSpecializationContext { // There's one other case we'd like to handle, where the func-type itself is a // dynamic IRSpecialize. In this situation, we'd want to use the type inst's info to - // find the collection-based specialization and create a func-type from it. + // find the set-based specialization and create a func-type from it. // if (auto elementOfSetType = as(typeInfo)) { @@ -1925,13 +1990,13 @@ struct TypeFlowSpecializationContext // Specialize each element in the set HashSet specializedSet; - IRSetBase* collection = nullptr; + IRSetBase* set = nullptr; if (auto elementOfSetType = as(operandInfo)) { - collection = elementOfSetType->getSet(); + set = elementOfSetType->getSet(); forEachInSet( - collection, + set, [&](IRInst* arg) { // Create a new specialized instruction for each argument @@ -2013,14 +2078,9 @@ struct TypeFlowSpecializationContext if (as(arg)) continue; - if (auto collection = as(arg)) + if (auto set = as(arg)) { - updateInfo( - context, - param, - makeElementOfSetType(collection), - true, - workQueue); + updateInfo(context, param, makeElementOfSetType(set), true, workQueue); } else if (as(arg) || as(arg)) { @@ -2092,7 +2152,7 @@ struct TypeFlowSpecializationContext WorkItem(InterproceduralEdge::Direction::CallToFunc, context, inst, callee)); }; - // If we have a collection of functions (with or without a dynamic tag), register + // If we have a set of functions (with or without a dynamic tag), register // each one. // if (auto elementOfSetType = as(calleeInfo)) @@ -2220,8 +2280,7 @@ struct TypeFlowSpecializationContext // // This is one of the base cases for the recursion. // - if (!isConcreteType(var->getDataType())) - updateInfo(context, var, info, true, workQueue); + updateInfo(context, var, info, true, workQueue); } else if (auto param = as(inst)) { @@ -2239,8 +2298,7 @@ struct TypeFlowSpecializationContext (IRType*)as(info)->getValueType(), as(param->getDataType())); - if (!isConcreteType(param->getDataType())) - updateInfo(context, param, newInfo, true, workQueue); + updateInfo(context, param, newInfo, true, workQueue); } else { @@ -2449,7 +2507,7 @@ struct TypeFlowSpecializationContext { // When specializing a func, we // (i) rewrite the types and insts by calling `specializeInstsInBlock` and - // (ii) handle 'merge' points where the collections need to be upcasted. + // (ii) handle 'merge' points where the sets need to be upcasted. // // The merge points are places where a specialized inst might be passed as // argument to a parameter that has a 'wider' type. @@ -2608,7 +2666,7 @@ struct TypeFlowSpecializationContext // info. // // This basically recursively walks the info and applies the array/ptr-type - // wrappers, while replacing unbounded collection with a nullptr. + // wrappers, while replacing unbounded set with a nullptr. // // If the result of this is null, then the inst should keep its current type. // @@ -2651,7 +2709,7 @@ struct TypeFlowSpecializationContext if (auto elementOfSetType = as(info)) { - // Replace element-of-collection types with tag types. + // Replace element-of-set types with tag types. return makeTagType(elementOfSetType->getSet()); } @@ -2659,7 +2717,7 @@ struct TypeFlowSpecializationContext { if (valOfSetType->getSet()->isSingleton()) { - // If there's only one type in the collection, return it directly + // If there's only one type in the set, return it directly return (IRType*)valOfSetType->getSet()->getElement(0); } @@ -2668,7 +2726,7 @@ struct TypeFlowSpecializationContext if (as(info) || as(info)) { - // Don't specialize these collections.. they should be used through + // Don't specialize these sets.. they should be used through // tag types, or be processed out during specializeing. // return nullptr; @@ -2862,7 +2920,7 @@ struct TypeFlowSpecializationContext { // Replace with GetElement(specializedInst, 0) -> TagType(tableSet) // which retreives a 'tag' (i.e. a run-time identifier for one of the elements - // of the collection) + // of the set) // auto operand = inst->getOperand(0); auto element = builder.emitGetTagFromTaggedUnion(operand); @@ -2926,27 +2984,27 @@ struct TypeFlowSpecializationContext { if (auto valOfSetType = as(currentType)) { - HashSet collectionElements; + HashSet setElements; forEachInSet( valOfSetType->getSet(), - [&](IRInst* element) { collectionElements.add(element); }); + [&](IRInst* element) { setElements.add(element); }); if (auto newValOfSetType = as(newType)) { - // If the new type is also a collection, merge the two collections + // If the new type is also a set, merge the two sets forEachInSet( newValOfSetType->getSet(), - [&](IRInst* element) { collectionElements.add(element); }); + [&](IRInst* element) { setElements.add(element); }); } else { - // Otherwise, just add the new type to the collection - collectionElements.add(newType); + // Otherwise, just add the new type to the set + setElements.add(newType); } - // If this is a collection, we need to create a new collection with the new type + // If this is a set, we need to create a new set with the new type IRBuilder builder(module); - auto newSet = builder.getSet(kIROp_TypeSet, collectionElements); + auto newSet = builder.getSet(kIROp_TypeSet, setElements); return makeUntaggedUnionType(cast(newSet)); } else if (currentType == newType) @@ -2964,18 +3022,18 @@ struct TypeFlowSpecializationContext as(currentType)->getWitnessTableSet(), as(newType)->getWitnessTableSet()))); } - else // Need to create a new collection. + else // Need to create a new set. { - HashSet collectionElements; + HashSet setElements; SLANG_ASSERT(!as(currentType) && !as(newType)); - collectionElements.add(currentType); - collectionElements.add(newType); + setElements.add(currentType); + setElements.add(newType); - // If this is a collection, we need to create a new collection with the new type + // If this is a set, we need to create a new set with the new type IRBuilder builder(module); - auto newSet = builder.getSet(kIROp_TypeSet, collectionElements); + auto newSet = builder.getSet(kIROp_TypeSet, setElements); return makeUntaggedUnionType(cast(newSet)); } } @@ -3001,7 +3059,7 @@ struct TypeFlowSpecializationContext } // Get an effective func type to use for the callee. - // The callee may be a collection, in which case, this returns a union-ed functype. + // The callee may be a set, in which case, this returns a union-ed functype. // IRFuncType* getEffectiveFuncTypeForSet(IRFuncSet* calleeSet) { @@ -3009,7 +3067,7 @@ struct TypeFlowSpecializationContext // // (i) we build up the effective parameter types for the callee // by taking the union of each parameter type - // for each callee in the collection. + // for each callee in the set. // // (ii) build up the effective result type in a similar manner. // @@ -3018,8 +3076,8 @@ struct TypeFlowSpecializationContext // - if we have multiple callees, then a parameter of TagType(callee) is appended // to the beginning to select the callee. // - // - if our callee is Specialize inst with collection args, then for each - // table-collection argument, a tag is required as input. + // - if our callee is Specialize inst with set args, then for each + // table-set argument, a tag is required as input. // IRBuilder builder(module); @@ -3157,37 +3215,54 @@ struct TypeFlowSpecializationContext // // First, we handle the callee: // - // - If the callee is already a concrete function, there's nothing to do + // The callee can only be in a few specific patterns: + // + // 1. If the callee is already a concrete function, there's nothing to do + // + // 2. If the callee is a dynamic inst of tag type, then we need to look at the + // tag inst's structure: + // + // i. If inst is a GetTagForMappedSet (resulting from a lookup), // - // - If the callee is a dynamic inst of tag type, we replace - // the callee with the collection itself, and pass the tag inst as - // the first operand. Effectively, we are placing a call to a set of functions - // and using the tag to specify which function to call. + // let tableTag : TagType(witnessTableSet) = /* ... */; + // let tag : TagType(funcSet) = GetTagForMappedSet(tableTag, key); + // let val = Call(tag, arg1, arg2, ...); + // becomes + // let tableTag : TagType(witnessTableSet) = /* ... */; + // let dispatcher : FuncType(...) = GetDispatcher(witnessTableSet, key); + // let val = Call(dispatcher, tableTag, arg1, arg2, ...); + // where the dispatcher represents a dispatch function that selects the function + // based on the witness table tag. // - // e.g. - // let tag : TagType(funcSet) = /* ... */; - // let val = Call(tag, arg1, arg2, ...); - // becomes - // let tag : TagType(funcSet) = /* ... */; - // let val = Call(funcSet, tag, arg1, arg2, ...); + // ii. If the inst is a GetTagForSpecializedCollection (resulting from a specialization + // resulting from a lookup), // - // - If any the callee is a dynamic specialization of a generic, we need to add any dynamic - // witness - // table insts as arguments to the call. + // let tableTag : TagType(witnessTableSet) = /* ... */; + // let tag : TagType(genericSet) = GetTagForMappedSet(tableTag, key); + // let specializedTag : + // TagType(funcSet) = GetTagForSpecializedCollection(tag, specArgs...); + // let val = Call(specializedTag, arg1, arg2, ...); + // becomes + // let tableTag : TagType(witnessTableSet) = /* ... */; + // let dispatcher : FuncType(...) = + // GetSpecializedDispatcher(witnessTableSet, key, specArgs...); + // let val = Call(dispatcher, tableTag, arg1, arg2, ...); // - // e.g.: - // Call( - // Specialize(g, specArgs...), callArgs...); - // where atleast one of specialization args is a dynamic tag inst. + // iii. If the inst is a Specialize of a concrete generic, then + // it means that one or more specialization arguments are dynamic. // - // Our convention for dynamic generics is that the dynamic witness table - // operands are added to the front of the regular call arguments. + // let specCallee = Specialize(generic, specArgs...); + // let val = Call(specCallee, callArgs...); + // becomes + // let specCallee = Specialize(generic, staticFormOfSpecArgs...); + // let val = Call(specCallee, dynamicSpecArgs..., callArgs...); + // where the new dynamicSpecArgs are the tag insts of WitnessTableSets + // and the static form is the corresponding WitnessTableSet itself. // - // So, we'll turn this into: - // Call( - // Specialize(g, staticFormOfSpecArgs...), dynamicSpecArgs..., callArgs...); - // where the new callee is a specialization where any dynamic insts are - // replaced with their static collections. + // This creates a specialization that includes set arguments (and is handled + // by `specializeGenericWithSetArgs`) + // + // More concrete example for (iii): // // // --- before specialization --- // let s1 : TagType(WitnessTableSet(tA, tB, tC)) = /* ... */; @@ -3203,20 +3278,9 @@ struct TypeFlowSpecializationContext // let newVal = Call(newSpecCallee, s1, /* call args */); // // - // - In case the callee is a collection of dynamically specialized - // generics, _both_ of the above transformations are applied, with - // the callee's tag going first, followed by any witness table tags - // and finally the regular call arguments. - // However, this case is NOT currently well supported because the func-collection - // tag does not encode the additional tags that need to be passed, so this - // is likely to fail currently. - // This is a rare scenario that only occurs on trying to specialize an existential - // method with existential arguments, which we don't officially support. - // - // Secondly, we handle the argument types: - // It is possible that - // the parameters in the callee have been specialized to accept - // a wider collection compared to the arguments from this call site + // After the callee has been selected, we handle the argument types: + // It is possible that the parameters in the callee have been specialized + // to accept a super-set of types compared to the arguments from this call site // // In this case, we just upcast them using `upcastSet` before // creating a new call inst @@ -3237,20 +3301,20 @@ struct TypeFlowSpecializationContext // maybeSpecializeCalleeType(callee); - // If we're calling using a tag, place a call to the collection, + // If we're calling using a tag, place a call to the set, // with the tag as the first argument. So the callee is - // the collection itself. + // the set itself. // - if (auto collectionTag = as(callee->getDataType())) + if (auto setTag = as(callee->getDataType())) { - if (!collectionTag->isSingleton()) + if (!setTag->isSingleton()) { // Multiple callees case: // // If we need to use a tag, we'll do a bit of an optimization here.. // - // Instead of building a dispatcher on then func-collection, we'll - // build it on the table collection that it is looked up from. This + // Instead of building a dispatcher on then func-set, we'll + // build it on the table set that it is looked up from. This // avoids the extra map. // // This works primarily because this is the only way to call a dynamic @@ -3270,7 +3334,7 @@ struct TypeFlowSpecializationContext getEffectiveFuncTypeForDispatcher( tableSet, lookupKey, - cast(collectionTag->getSet())), + cast(setTag->getSet())), tableSet, lookupKey); @@ -3316,7 +3380,7 @@ struct TypeFlowSpecializationContext getEffectiveFuncTypeForDispatcher( tableSet, lookupKey, - cast(collectionTag->getSet())), + cast(setTag->getSet())), tableSet, lookupKey, specArgs); @@ -3325,15 +3389,14 @@ struct TypeFlowSpecializationContext } else { - SLANG_UNEXPECTED( - "Cannot specialize call with non-singleton collection tag callee"); + SLANG_UNEXPECTED("Cannot specialize call with non-singleton set tag callee"); } } - else if (isSetSpecializedGeneric(collectionTag->getSet()->getElement(0))) + else if (isSetSpecializedGeneric(setTag->getSet()->getElement(0))) { // Single element which is a set specialized generic. callArgs.addRange(getArgsForSetSpecializedGeneric(cast(callee))); - callee = collectionTag->getSet()->getElement(0); + callee = setTag->getSet()->getElement(0); auto funcType = getEffectiveFuncType(callee); callee->setFullType(funcType); @@ -3377,13 +3440,13 @@ struct TypeFlowSpecializationContext } // If by this point, we haven't resolved our callee into a global inst ( - // either a collection or a single function), then we can't specialize it (likely unbounded) + // either a set or a single function), then we can't specialize it (likely unbounded) // if (!isGlobalInst(callee) || isIntrinsic(callee)) return false; // First, we'll legalize all operands by upcasting if necessary. - // This needs to be done even if the callee is not a collection. + // This needs to be done even if the callee is not a set. // UCount extraArgCount = callArgs.getCount(); for (UInt i = 0; i < inst->getArgCount(); i++) @@ -3406,7 +3469,7 @@ struct TypeFlowSpecializationContext // Out parameters are handled at the callee's end case ParameterDirectionInfo::Kind::Out: - // For all other modes, collections must match ('subtyping' is not allowed) + // For all other modes, sets must match ('subtyping' is not allowed) case ParameterDirectionInfo::Kind::BorrowInOut: case ParameterDirectionInfo::Kind::BorrowIn: case ParameterDirectionInfo::Kind::Ref: @@ -3539,13 +3602,13 @@ struct TypeFlowSpecializationContext } // Create the appropriate any-value type - auto collectionType = typeSet->isSingleton() - ? (IRType*)typeSet->getElement(0) - : builder.getUntaggedUnionType((IRType*)typeSet); + auto effectiveType = typeSet->isSingleton() + ? (IRType*)typeSet->getElement(0) + : builder.getUntaggedUnionType((IRType*)typeSet); // Pack the value - auto packedValue = as(collectionType) - ? builder.emitPackAnyValue(collectionType, inst->getWrappedValue()) + auto packedValue = as(effectiveType) + ? builder.emitPackAnyValue(effectiveType, inst->getWrappedValue()) : inst->getWrappedValue(); auto taggedUnionType = getLoweredType(taggedUnion); @@ -3587,17 +3650,17 @@ struct TypeFlowSpecializationContext args); IRInst* packedValue = nullptr; - auto collection = taggedUnionType->getTypeSet(); - if (!collection->isSingleton()) + auto set = taggedUnionType->getTypeSet(); + if (!set->isSingleton()) { packedValue = builder.emitPackAnyValue( - (IRType*)builder.getUntaggedUnionType(collection), + (IRType*)builder.getUntaggedUnionType(set), inst->getValue()); } else { packedValue = builder.emitReinterpret( - (IRType*)builder.getUntaggedUnionType(collection), + (IRType*)builder.getUntaggedUnionType(set), inst->getValue()); } @@ -3615,7 +3678,7 @@ struct TypeFlowSpecializationContext // interface-typed pointer. // // Our type-flow analysis will convert the - // result into a collection of all available implementations of this + // result into a set of all available implementations of this // interface, so we need to cast the result. // @@ -3684,7 +3747,7 @@ struct TypeFlowSpecializationContext // // If we're dealing with specializing a type, witness table, or any other // generic, we simply drop all dynamic tag information, and replace all - // operands with their collection variants. + // operands with their set variants. // // If we're dealing with a function, there are two cases: // - A single function when dynamic specialization arguments. @@ -3694,13 +3757,13 @@ struct TypeFlowSpecializationContext // Instead, we'll just replace the type, and retain the `Specialize` inst with // the dynamic args. It will be specialized out in `specializeCall` instead. // - // - A collection of functions with concrete specialization arguments. - // In this case, we will emit an instruction to map from the input generic collection - // to the output specialized collection via `GetTagForSpecializedSet`. + // - A set of functions with concrete specialization arguments. + // In this case, we will emit an instruction to map from the input generic set + // to the output specialized set via `GetTagForSpecializedSet`. // This inst encodes the key-value mapping in its operands: // e.g.(input_tag, key0, value0, key1, value1, ...) // - // - The case where there is a collection of functions with dynamic specialization arguments + // - The case where there is a set of functions with dynamic specialization arguments // is not currently properly handled. This case should not arise naturally since we // don't advertise support for it. // @@ -3715,7 +3778,7 @@ struct TypeFlowSpecializationContext isFuncReturn = as(getGenericReturnVal(firstConcreteGeneric)) != nullptr; } - // We'll emit a dynamic tag inst if the result is a func collection with multiple elements + // We'll emit a dynamic tag inst if the result is a func set with multiple elements if (isFuncReturn) { if (auto info = tryGetInfo(context, inst)) @@ -3735,7 +3798,7 @@ struct TypeFlowSpecializationContext if (elementOfSetType->getSet()->isSingleton()) { - // If the result is a singleton collection, we can just + // If the result is a singleton set, we can just // replace the type (if necessary) and be done with it. return replaceType(context, inst); } @@ -3764,8 +3827,7 @@ struct TypeFlowSpecializationContext } else { - SLANG_UNEXPECTED( - "Expected element-of-collection type for function specialization"); + SLANG_UNEXPECTED("Expected element-of-set type for function specialization"); } } } @@ -3777,15 +3839,15 @@ struct TypeFlowSpecializationContext { auto arg = inst->getArg(i); auto argDataType = arg->getDataType(); - if (auto collectionTagType = as(argDataType)) + if (auto setTagType = as(argDataType)) { - // If this is a tag type, replace with collection. + // If this is a tag type, replace with set. changed = true; - if (as(collectionTagType->getSet())) + if (as(setTagType->getSet())) { - args.add(collectionTagType->getSet()); + args.add(setTagType->getSet()); } - else if (auto typeSet = as(collectionTagType->getSet())) + else if (auto typeSet = as(setTagType->getSet())) { IRBuilder builder(inst); args.add(builder.getUntaggedUnionType(typeSet)); @@ -3840,6 +3902,13 @@ struct TypeFlowSpecializationContext bool specializeGetElementFromTag(IRInst* context, IRGetElementFromTag* inst) { + // During specialization, we convert all "element-of" types into + // run-time tag types so that we can obtain and use this information. + // + // Thus, `GetElementFromTag` simply becomes a no-op. Any instructions using + // the result of `GetElementFromTag` should be specialized accordingly to + // use the tag operand. + // SLANG_UNUSED(context); inst->replaceUsesWith(inst->getOperand(0)); inst->removeAndDeallocate(); @@ -4007,11 +4076,10 @@ struct TypeFlowSpecializationContext bool specializeIsType(IRInst* context, IRIsType* inst) { // The is-type checks equality between two witness tables - // via their sequential IDs. // - // If the dynamic part has been specialized into a tag, we emit - // a `GetSequentialIDFromTag` inst to extract the ID and emit - // an equality test. + // We'll turn this into a tag comparison by extracting the tag + // for a specific element in the set, and comparing that to the + // dynamic witness table tag. // SLANG_UNUSED(context); auto witnessTableArg = inst->getValueWitness(); @@ -4121,15 +4189,12 @@ struct TypeFlowSpecializationContext // // There's two cases to handle here: // 1. We statically know that it cannot be a 'none' because the - // input's collection type doesn't have a 'none'. In this case + // input's set type doesn't have a 'none'. In this case // we just return a true. // - // 2. 'none' is a possibility. In this case, we create a 0 value of - // type TagType(WitnessTableSet(NoneWitness)) and then upcast it - // to TagType(inputWitnessTableSet). This will convert the value - // to the corresponding value of 'none' in the input's table collection - // allowing us to directly compare it against the tag part of the - // input tagged union. + // 2. 'none' is a possibility. In this case, we'll get the + // tag of 'none' in the set by emitting a `GetTagOfElementInSet` + // compare those to determine if we have a value. // IRBuilder builder(inst); @@ -4145,7 +4210,7 @@ struct TypeFlowSpecializationContext if (!containsNone) { - // If 'none' isn't a part of the collection, statically set + // If 'none' isn't a part of the set, statically set // to true. // @@ -4157,13 +4222,13 @@ struct TypeFlowSpecializationContext else { // Otherwise, we'll extract the tag and compare against - // the value for 'none' (in the context of the tag's collection) + // the value for 'none' (in the context of the tag's set) // builder.setInsertBefore(inst); auto dynTag = builder.emitGetTagFromTaggedUnion(inst->getOptionalOperand()); - // Cast the singleton tag to the target collection tag (will convert the + // Cast the singleton tag to the target set tag (will convert the // value to the corresponding value for the larger set) // auto noneWitnessTag = builder.emitGetTagOfElementInSet( From 8d386346c48fd83731e01a471baf4a75f92dbd4d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 23 Oct 2025 12:39:03 -0400 Subject: [PATCH 082/105] Update comments for op-codes --- source/slang/slang-ir-insts.lua | 62 +++++++++++++++++++++++---------- 1 file changed, 43 insertions(+), 19 deletions(-) diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 631a933cfa0..4b6c13b1ef4 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -679,18 +679,18 @@ local insts = { }, { UntaggedUnionType = { hoistable = true, - -- A type that represents that the value's _type_ is one of types in the collection operand. + -- A type that represents that the value's _type_ is one of types in the set operand. } }, { ElementOfSetType = { hoistable = true, - -- A type that represents that the value must be an element of the collection operand. + -- A type that represents that the value must be an element of the set operand. } }, { SetTagType = { hoistable = true, - -- Represents a tag-type for a collection. + -- Represents a tag-type for a set. -- - -- An inst whose type is SetTagType(collection) is semantically carrying a - -- run-time value that "picks" one of the elements of the collection operand. + -- An inst whose type is SetTagType(set) is semantically carrying a + -- run-time value that "picks" one of the elements of the set operand. -- -- Only operand is a SetBase } }, @@ -2730,7 +2730,7 @@ local insts = { -- the specialization pass. -- -- TODO: Consider the scenario where we can combine the unbounded case with known cases. - -- unbounded collection should probably be an element and not a separate op. + -- unbounded set should probably be an element and not a separate op. } }, { CastInterfaceToTaggedUnionPtr = { -- Cast an interface-typed pointer to a tagged-union pointer with a known set. @@ -2740,27 +2740,34 @@ local insts = { } }, { GetTagForSuperSet = { -- Translate a tag from a set to its equivalent in a super-set - -- TODO: Lower using a global ID and not local IDs + mapping ops. + -- + -- Operands: (the tag for the source set) + -- The source and destination sets are implied by the type of the operand and the type of the result } }, { GetTagForMappedSet = { -- Translate a tag from a set to its equivalent in a different set -- based on a mapping induced by a lookup key + -- + -- Operands: (the tag for the witness table set, the lookup key) } }, { GetTagForSpecializedSet = { - -- Translate a tag from a generic set to its equivalent in a specialized set - -- based on a mapping that is encoded in the operands of this tag instruction + -- Translate a tag from a set of generics to its equivalent in a specialized set + -- according to the set of specialization arguments that are encoded in the + -- operands of this instruction. + -- + -- Operands: (the tag for the generic set, any number of specialization arguments....) } }, { GetTagFromSequentialID = { -- Translate an existing sequential ID (a 'global' ID) & and interface type into a tag - -- the provided collection (a 'local' ID) + -- the provided set (a 'local' ID) } }, { GetSequentialIDFromTag = { - -- Translate a tag from the given collection (a 'local' ID) to a sequential ID (a 'global' ID) + -- Translate a tag from the given set (a 'local' ID) to a sequential ID (a 'global' ID) } }, { GetElementFromTag = { - -- Translate a tag to its corresponding element in the collection. - -- Input's type: SetTagType(collection). - -- Output's type: ElementOfSetType(collection) + -- Translate a tag to its corresponding element in the set. + -- Input's type: SetTagType(set). + -- Output's type: ElementOfSetType(set) -- operands = {{"tag"}} } }, @@ -2782,10 +2789,8 @@ local insts = { -- Get a specialized dispatcher function for a given witness table set + key, where -- the key points to a generic function. -- - -- Inputs: set of witness tables to create a dispatched for and the key to use to identify the - -- entry that needs to be dispatched to. All witness tables must have an entry for the given key. - -- or else this is a malformed inst. - -- A set of specialization arguments (these must be concrete/global types or collections) + -- Operands: (set of witness tables, lookup key, specialization args...) + -- -- -- Output: a value of `FuncType` that can be called. -- This func-type will take a `TagType(witnessTableSet)` as the first parameter to @@ -2795,30 +2800,49 @@ local insts = { } }, { GetTagFromTaggedUnion = { -- Translate a tagged-union value to its corresponding tag in the tagged-union's set. + -- -- Input's type: TaggedUnionType(typeSet, tableSet) + -- -- Output's type: SetTagType(tableSet) + -- operands = {{"taggedUnionValue"}} } }, { GetTypeTagFromTaggedUnion = { -- Translate a tagged-union value to its corresponding type tag in the tagged-union's set. + -- -- Input's type: TaggedUnionType(typeSet, tableSet) + -- -- Output's type: SetTagType(typeSet) + -- operands = {{"taggedUnionValue"}} } }, { GetValueFromTaggedUnion = { -- Translate a tagged-union value to its corresponding value in the tagged-union's set. + -- -- Input's type: TaggedUnionType(typeSet, tableSet) + -- -- Output's type: UntaggedUnionType(typeSet) + -- operands = {{"taggedUnionValue"}} } }, { MakeTaggedUnion = { -- Create a tagged-union value from a tag and a value. + -- -- Input's type: SetTagType(tableSet), UntaggedUnionType(typeSet) + -- -- Output's type: TaggedUnionType(typeSet, tableSet) + -- operands = { { "tag" }, { "value" } }, } }, { GetTagOfElementInSet = { - -- Get the tag corresponding to an element in a collection. + -- Get the tag corresponding to an element in a set. + -- + -- Operands: (element, set) + -- "element" must resolve into a concrete inst before lowering, + -- otherwise, this is an error. + -- + -- Output's type: SetTagType(set) + -- hoistable = true } }, } From 640d6a02b88129b3e69d324c8529b217ea5ac9a0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 23 Oct 2025 12:56:32 -0400 Subject: [PATCH 083/105] Update slang-ir-typeflow-specialize.cpp --- source/slang/slang-ir-typeflow-specialize.cpp | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 9aadfd0ae25..c08e9400032 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -434,6 +434,25 @@ bool isGlobalInst(IRInst* inst) return inst->getParent()->getOp() == kIROp_ModuleInst; } +// This is fairly fundamental check: +// This method checks whether a inst's type cannot accept any further refinement. +// +// e.g. an inst of `UInt` type cannot be further refined (under the current scope +// of the type-flow pass), since it has a concrete type, and we do not +// track values. +// +// an inst of `InterfaceType` represents a tagged union of any type that implements +// the interface, so it can be further refined by determining a smaller set of +// possibilities (i.e. via `TaggedUnionType(tableSet, typeSet)`). +// +// Similarly, an inst of `WitnessTableType` represents any witness table, +// so it can accept a further refinement into `ElementOfSetType(tableSet)`. +// +// In the future, we may want to extend this check to something more nuanced, +// which takes in both the inst and the refined type to determine if we want to +// accept the refinement (this is useful in cases like `UnsizedArrayType`, where +// we only want to refine it if we can determine a concrete size). +// bool isConcreteType(IRInst* inst) { bool isInstGlobal = isGlobalInst(inst); @@ -442,11 +461,16 @@ bool isConcreteType(IRInst* inst) switch (inst->getOp()) { - case kIROp_InterfaceType: + case kIROp_InterfaceType: // Can be refined to tagged unions + return false; + case kIROp_WitnessTableType: // Can be refined into set of concrete tables + return false; + case kIROp_FuncType: // Can be refined into set of concrete functions + return false; + case kIROp_GenericKind: // Can be refined into set of concrete generics + return false; + case kIROp_TypeKind: // Can be refined into set of concrete types return false; - case kIROp_WitnessTableType: - case kIROp_FuncType: - return isInstGlobal; case kIROp_ArrayType: return isConcreteType(cast(inst)->getElementType()) && isGlobalInst(cast(inst)->getElementCount()); From 98de01ec764d966846bddbc520fbabd2c2021cd0 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 30 Oct 2025 15:36:59 -0400 Subject: [PATCH 084/105] Disable old generics pass completely. Get tests passing. --- source/slang/slang-emit.cpp | 42 +- source/slang/slang-ir-any-value-inference.cpp | 31 +- source/slang/slang-ir-any-value-inference.h | 5 +- source/slang/slang-ir-dce.cpp | 4 +- source/slang/slang-ir-insts-stable-names.lua | 8 +- source/slang/slang-ir-insts.h | 64 +- source/slang/slang-ir-insts.lua | 22 + source/slang/slang-ir-lower-enum-type.cpp | 17 + source/slang/slang-ir-lower-existential.cpp | 4 +- source/slang/slang-ir-lower-generics.cpp | 44 +- source/slang/slang-ir-lower-reinterpret.cpp | 2 +- .../slang/slang-ir-lower-typeflow-insts.cpp | 628 ++++++++++++++++-- source/slang/slang-ir-lower-typeflow-insts.h | 11 +- .../slang-ir-propagate-func-properties.cpp | 22 +- source/slang/slang-ir-specialize-dispatch.cpp | 4 +- .../slang-ir-strip-legalization-insts.cpp | 8 +- source/slang/slang-ir-typeflow-collection.cpp | 16 +- source/slang/slang-ir-typeflow-specialize.cpp | 451 ++++++++++--- source/slang/slang-ir-util.cpp | 1 + source/slang/slang-ir.cpp | 11 + tests/compute/dynamic-dispatch-1.slang | 4 +- tests/compute/dynamic-dispatch-10.slang | 4 +- tests/compute/dynamic-dispatch-11.slang | 4 +- tests/compute/dynamic-dispatch-2.slang | 4 +- tests/compute/dynamic-dispatch-3.slang | 4 +- tests/compute/dynamic-dispatch-4.slang | 4 +- tests/compute/dynamic-dispatch-5.slang | 4 +- tests/compute/dynamic-dispatch-6.slang | 4 +- tests/compute/dynamic-dispatch-7.slang | 4 +- tests/compute/dynamic-dispatch-8.slang | 4 +- tests/compute/dynamic-dispatch-9.slang | 4 +- tests/compute/dynamic-generics-simple.slang | 4 +- .../interfaces/anyvalue-size-validation.slang | 31 +- .../anyvalue-size-validation.slang.expected | 12 +- .../interfaces/interface-extension.slang | 2 +- tests/diagnostics/no-type-conformance.slang | 4 +- 36 files changed, 1240 insertions(+), 252 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index a6f270eff54..764d44c29a0 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -896,6 +896,12 @@ Result linkAndOptimizeIR( // Lower all the LValue implict casts (used for out/inout/ref scenarios) lowerLValueCast(targetProgram, irModule); + // Lower enum types early since enums and enum casts may appear in + // specialization & not resolving them here would block specialization. + // + if (requiredLoweringPassSet.enumType) + lowerEnumType(irModule, sink); + IRSimplificationOptions defaultIRSimplificationOptions = IRSimplificationOptions::getDefault(targetProgram); IRSimplificationOptions fastIRSimplificationOptions = @@ -1144,23 +1150,12 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(performTypeInlining(irModule, targetProgram, sink)); } - lowerTagInsts(irModule, sink); - - // Tagged union type lowering typically generates more reinterpret instructions. - if (lowerTaggedUnionTypes(irModule, sink)) - requiredLoweringPassSet.reinterpret = true; - - lowerUntaggedUnionTypes(irModule, sink); - - if (requiredLoweringPassSet.reinterpret) - lowerReinterpret(targetProgram, irModule, sink); - if (sink->getErrorCount() != 0) return SLANG_FAIL; validateIRModuleIfEnabled(codeGenContext, irModule); - inferAnyValueSizeWhereNecessary(targetProgram, irModule); + inferAnyValueSizeWhereNecessary(targetProgram, irModule, sink); // If we have any witness tables that are marked as `KeepAlive`, // but are not used for dynamic dispatch, unpin them so we don't @@ -1176,6 +1171,26 @@ Result linkAndOptimizeIR( eliminateDeadCode(irModule, fastIRSimplificationOptions.deadCodeElimOptions); } + // Tagged union type lowering typically generates more reinterpret instructions. + if (lowerTaggedUnionTypes(irModule, sink)) + requiredLoweringPassSet.reinterpret = true; + + lowerUntaggedUnionTypes(irModule, targetProgram, sink); + + if (requiredLoweringPassSet.reinterpret) + lowerReinterpret(targetProgram, irModule, sink); + + lowerSequentialIDTagCasts(irModule, codeGenContext->getLinkage(), sink); + lowerTagInsts(irModule, sink); + lowerTagTypes(irModule); + + eliminateDeadCode(irModule, fastIRSimplificationOptions.deadCodeElimOptions); + + lowerExistentials(irModule, sink); + + if (sink->getErrorCount() != 0) + return SLANG_FAIL; + if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc) && targetProgram->getOptionSet().shouldRunNonEssentialValidation()) { @@ -1194,9 +1209,6 @@ Result linkAndOptimizeIR( cleanupGenerics(targetProgram, irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); - if (requiredLoweringPassSet.enumType) - lowerEnumType(irModule, sink); - // Don't need to run any further target-dependent passes if we are generating code // for host vm. if (target == CodeGenTarget::HostVM) diff --git a/source/slang/slang-ir-any-value-inference.cpp b/source/slang/slang-ir-any-value-inference.cpp index d26dff20305..831d20e49a5 100644 --- a/source/slang/slang-ir-any-value-inference.cpp +++ b/source/slang/slang-ir-any-value-inference.cpp @@ -90,7 +90,10 @@ List sortTopologically( return sortedInterfaceTypes; } -void inferAnyValueSizeWhereNecessary(TargetProgram* targetProgram, IRModule* module) +void inferAnyValueSizeWhereNecessary( + TargetProgram* targetProgram, + IRModule* module, + DiagnosticSink* sink) { // Go through the global insts and collect all interface types. // For each interface type, infer its any-value-size, by looking up @@ -130,9 +133,10 @@ void inferAnyValueSizeWhereNecessary(TargetProgram* targetProgram, IRModule* mod if (interfaceType->findDecoration()) continue; - // If the interface already has an explicit any-value-size, don't infer anything. + /* If the interface already has an explicit any-value-size, don't infer anything. if (interfaceType->findDecoration()) continue; + */ // Skip interfaces that are not implemented by any type. if (!implementedInterfaces.contains(interfaceType)) @@ -213,6 +217,12 @@ void inferAnyValueSizeWhereNecessary(TargetProgram* targetProgram, IRModule* mod for (auto interfaceType : sortedInterfaceTypes) { + IRIntegerValue existingMaxSize = (IRIntegerValue)kMaxInt; // Default to max int. + if (auto existingAnyValueDecor = interfaceType->findDecoration()) + { + existingMaxSize = existingAnyValueDecor->getSize(); + } + IRIntegerValue maxAnyValueSize = -1; for (auto implType : mapInterfaceToImplementations[interfaceType]) { @@ -223,6 +233,18 @@ void inferAnyValueSizeWhereNecessary(TargetProgram* targetProgram, IRModule* mod &sizeAndAlignment); maxAnyValueSize = Math::Max(maxAnyValueSize, sizeAndAlignment.size); + + // Diagnose if the existing any-value-size is smaller than the inferred size. + if (existingMaxSize < sizeAndAlignment.size) + { + sink->diagnose(implType, Diagnostics::typeDoesNotFitAnyValueSize, implType); + sink->diagnoseWithoutSourceView( + implType, + Diagnostics::typeAndLimit, + implType, + sizeAndAlignment.size, + existingMaxSize); + } } // Should not encounter interface types without any conforming implementations. @@ -232,7 +254,10 @@ void inferAnyValueSizeWhereNecessary(TargetProgram* targetProgram, IRModule* mod if (maxAnyValueSize >= 0) { IRBuilder builder(module); - builder.addAnyValueSizeDecoration(interfaceType, maxAnyValueSize); + if (!interfaceType->findDecoration()) + { + builder.addAnyValueSizeDecoration(interfaceType, maxAnyValueSize); + } } } } diff --git a/source/slang/slang-ir-any-value-inference.h b/source/slang/slang-ir-any-value-inference.h index 23e579b6778..7c520cc33bc 100644 --- a/source/slang/slang-ir-any-value-inference.h +++ b/source/slang/slang-ir-any-value-inference.h @@ -8,5 +8,8 @@ namespace Slang { -void inferAnyValueSizeWhereNecessary(TargetProgram* targetProgram, IRModule* module); +void inferAnyValueSizeWhereNecessary( + TargetProgram* targetProgram, + IRModule* module, + DiagnosticSink* sink); } diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 9578bf5e53f..3c20750aec6 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -671,10 +671,10 @@ bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex) // to be replaced with `undef`. switch (inst->getOp()) { - case kIROp_BoundInterfaceType: + /*case kIROp_BoundInterfaceType: if (inst->getOperand(operandIndex)->getOp() == kIROp_WitnessTable) return true; - break; + break;*/ case kIROp_SpecializationDictionaryItem: // Ignore all operands of SpecializationDictionaryItem. // This inst is used as a cache and shouldn't hold anything alive. diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 0d35424ee59..a237af33c40 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -705,5 +705,11 @@ return { ["Type.ElementOfSetType"] = 701, ["MakeTaggedUnion"] = 702, ["GetTypeTagFromTaggedUnion"] = 703, - ["GetTagOfElementInSet"] = 704 + ["GetTagOfElementInSet"] = 704, + ["UnboundedTypeElement"] = 705, + ["UnboundedFuncElement"] = 706, + ["UnboundedWitnessTableElement"] = 707, + ["UnboundedGenericElement"] = 708, + ["UninitializedTypeElement"] = 709, + ["UninitializedWitnessTableElement"] = 710, } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 20cc7e334e3..4620c49e53e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2842,7 +2842,23 @@ struct IRSetBase : IRInst FIDDLE(baseInst()) UInt getCount() { return getOperandCount(); } IRInst* getElement(UInt idx) { return getOperand(idx); } - bool isSingleton() { return getOperandCount() == 1; } + bool isSingleton() { return (getOperandCount() == 1) && !isUnbounded(); } + bool isUnbounded() + { + // This is an unbounded set if any of its elements are unbounded. + for (UInt ii = 0; ii < getOperandCount(); ++ii) + { + switch (getElement(ii)->getOp()) + { + case kIROp_UnboundedTypeElement: + case kIROp_UnboundedWitnessTableElement: + case kIROp_UnboundedFuncElement: + case kIROp_UnboundedGenericElement: + return true; + } + } + return false; + } }; FIDDLE() @@ -3587,10 +3603,14 @@ struct IRBuilder return emitMakeTuple(SLANG_COUNT_OF(args), args); } - IRMakeTaggedUnion* emitMakeTaggedUnion(IRType* type, IRInst* tag, IRInst* value) + IRMakeTaggedUnion* emitMakeTaggedUnion( + IRType* type, + IRInst* typeTag, + IRInst* witnessTableTag, + IRInst* value) { - IRInst* args[] = {tag, value}; - return cast(emitIntrinsicInst(type, kIROp_MakeTaggedUnion, 2, args)); + IRInst* args[] = {typeTag, witnessTableTag, value}; + return cast(emitIntrinsicInst(type, kIROp_MakeTaggedUnion, 3, args)); } IRInst* emitMakeValuePack(IRType* type, UInt count, IRInst* const* args); @@ -4208,6 +4228,42 @@ struct IRBuilder return cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, operands)); } + IRUnboundedTypeElement* getUnboundedTypeElement(IRInst* interfaceType) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedTypeElement, 1, &interfaceType)); + } + + IRUnboundedWitnessTableElement* getUnboundedWitnessTableElement(IRInst* interfaceType) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedWitnessTableElement, 1, &interfaceType)); + } + + IRUnboundedFuncElement* getUnboundedFuncElement() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedFuncElement, 0, nullptr)); + } + + IRUnboundedGenericElement* getUnboundedGenericElement() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UnboundedGenericElement, 0, nullptr)); + } + + IRUninitializedTypeElement* getUninitializedTypeElement(IRInst* interfaceType) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UninitializedTypeElement, 1, &interfaceType)); + } + + IRUninitializedWitnessTableElement* getUninitializedWitnessTableElement(IRInst* interfaceType) + { + return cast( + emitIntrinsicInst(nullptr, kIROp_UninitializedWitnessTableElement, 1, &interfaceType)); + } + IRGetTagOfElementInSet* emitGetTagOfElementInSet( IRType* tagType, IRInst* element, diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 4b6c13b1ef4..cc640973f76 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2845,6 +2845,28 @@ local insts = { -- hoistable = true } }, + { UnboundedTypeElement = { + hoistable = true, + operands = { {"baseInterfaceType"} } + } }, + { UnboundedFuncElement = { + hoistable = true, + } }, + { UnboundedWitnessTableElement = { + hoistable = true, + operands = { {"baseInterfaceType"} } + } }, + { UnboundedGenericElement = { + hoistable = true, + } }, + { UninitializedTypeElement = { + hoistable = true, + operands = { {"baseInterfaceType"} } + } }, + { UninitializedWitnessTableElement = { + hoistable = true, + operands = { {"baseInterfaceType"} } + } }, } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-lower-enum-type.cpp b/source/slang/slang-ir-lower-enum-type.cpp index 548c29a5182..b22539192df 100644 --- a/source/slang/slang-ir-lower-enum-type.cpp +++ b/source/slang/slang-ir-lower-enum-type.cpp @@ -49,6 +49,23 @@ struct EnumTypeLoweringContext if (!type) return nullptr; + if (auto attributedType = as(type)) + { + IRBuilder builder(module); + + List attrs; + for (auto attr : attributedType->getAllAttrs()) + attrs.add(attr); + + RefPtr info = new LoweredEnumTypeInfo(); + info->enumType = (IRType*)type; + info->loweredType = builder.getAttributedType( + getLoweredEnumType(type->getOperand(0))->loweredType, + attrs); + loweredEnumTypes[type] = info; + return info.Ptr(); + } + if (type->getOp() != kIROp_EnumType) return nullptr; diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp index ea7f906a9a6..a3ed369837b 100644 --- a/source/slang/slang-ir-lower-existential.cpp +++ b/source/slang/slang-ir-lower-existential.cpp @@ -6,7 +6,7 @@ #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" - +/* namespace Slang { bool isCPUTarget(TargetRequest* targetReq); @@ -311,3 +311,5 @@ void lowerExistentials(SharedGenericsLoweringContext* sharedContext) context.processModule(); } } // namespace Slang + +*/ \ No newline at end of file diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index b3223108684..0a87d634d28 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -9,7 +9,6 @@ #include "slang-ir-generics-lowering-context.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-layout.h" -#include "slang-ir-lower-existential.h" #include "slang-ir-lower-generic-call.h" #include "slang-ir-lower-generic-function.h" #include "slang-ir-lower-generic-type.h" @@ -105,27 +104,6 @@ void cleanUpInterfaceTypes(SharedGenericsLoweringContext* sharedContext) } } -void lowerIsTypeInsts(SharedGenericsLoweringContext* sharedContext) -{ - InstPassBase pass(sharedContext->module); - pass.processInstsOfType( - kIROp_IsType, - [&](IRIsType* inst) - { - auto witnessTableType = - as(inst->getValueWitness()->getDataType()); - if (witnessTableType && - isComInterfaceType((IRType*)witnessTableType->getConformanceType())) - return; - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto eqlInst = builder.emitEql( - builder.emitGetSequentialIDInst(inst->getValueWitness()), - builder.emitGetSequentialIDInst(inst->getTargetWitness())); - inst->replaceUsesWith(eqlInst); - inst->removeAndDeallocate(); - }); -} // Turn all references of witness table or RTTI objects into integer IDs, generate // specialized `switch` based dispatch functions based on witness table IDs, and remove @@ -134,16 +112,13 @@ void lowerIsTypeInsts(SharedGenericsLoweringContext* sharedContext) // no pointers are involved in RTTI / dynamic dispatch logic. void specializeRTTIObjects(SharedGenericsLoweringContext* sharedContext, DiagnosticSink* sink) { - specializeDispatchFunctions(sharedContext); + /*specializeDispatchFunctions(sharedContext); if (sink->getErrorCount() != 0) - return; - - lowerSequentialIDTagCasts(sharedContext->module, sharedContext->sink); - lowerTagTypes(sharedContext->module); + return;*/ - lowerIsTypeInsts(sharedContext); + // lowerIsTypeInsts(sharedContext); - specializeDynamicAssociatedTypeLookup(sharedContext); + /*specializeDynamicAssociatedTypeLookup(sharedContext); if (sink->getErrorCount() != 0) return; @@ -153,7 +128,7 @@ void specializeRTTIObjects(SharedGenericsLoweringContext* sharedContext, Diagnos cleanUpRTTIHandleTypes(sharedContext); - cleanUpInterfaceTypes(sharedContext); + cleanUpInterfaceTypes(sharedContext);*/ } void checkTypeConformanceExists(SharedGenericsLoweringContext* context) @@ -230,7 +205,7 @@ void lowerGenerics(TargetProgram* targetProgram, IRModule* module, DiagnosticSin sharedContext.targetProgram = targetProgram; sharedContext.sink = sink; - checkTypeConformanceExists(&sharedContext); + /*checkTypeConformanceExists(&sharedContext); // Replace all `makeExistential` insts with `makeExistentialWithRTTI` // before making any other changes. This is necessary because a parameter of @@ -250,17 +225,18 @@ void lowerGenerics(TargetProgram* targetProgram, IRModule* module, DiagnosticSin if (sink->getErrorCount() != 0) return; - lowerExistentials(&sharedContext); + lowerGenericCalls(&sharedContext); if (sink->getErrorCount() != 0) return; - lowerGenericCalls(&sharedContext); + generateWitnessTableWrapperFunctions(&sharedContext); if (sink->getErrorCount() != 0) return; - generateWitnessTableWrapperFunctions(&sharedContext); + lowerExistentials(&sharedContext); if (sink->getErrorCount() != 0) return; + */ // This optional step replaces all uses of witness tables and RTTI objects with // sequential IDs. Without this step, we will emit code that uses function pointers and diff --git a/source/slang/slang-ir-lower-reinterpret.cpp b/source/slang/slang-ir-lower-reinterpret.cpp index 6624b5f1a1a..28be3d1bbe8 100644 --- a/source/slang/slang-ir-lower-reinterpret.cpp +++ b/source/slang/slang-ir-lower-reinterpret.cpp @@ -94,7 +94,7 @@ void lowerReinterpret(TargetProgram* target, IRModule* module, DiagnosticSink* s // Before processing reinterpret insts, ensure that existential types without // user-defined sizes have inferred sizes where possible. // - inferAnyValueSizeWhereNecessary(target, module); + /*inferAnyValueSizeWhereNecessary(target, module, sink);*/ ReinterpretLoweringContext context; context.module = module; diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index bcd969f7551..3f22d17295d 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -3,6 +3,7 @@ #include "slang-ir-any-value-marshalling.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-insts.h" +#include "slang-ir-layout.h" #include "slang-ir-specialize.h" #include "slang-ir-typeflow-collection.h" #include "slang-ir-util.h" @@ -11,22 +12,11 @@ namespace Slang { -SlangInt calculateAnyValueSize(const HashSet& types) -{ - SlangInt maxSize = 0; - for (auto type : types) - { - auto size = getAnyValueSize(type); - if (size > maxSize) - maxSize = size; - } - return maxSize; -} -IRAnyValueType* createAnyValueType(IRBuilder* builder, const HashSet& types) +UInt getUniqueID(IRBuilder* builder, IRInst* inst) { - auto size = calculateAnyValueSize(types); - return builder->getAnyValueType(size); + // Fallback. + return builder->getUniqueID(inst); } // Generate a single function that dispatches to each function in the collection. @@ -238,7 +228,9 @@ struct TagOpsLoweringContext : public InstPassBase // Since all elements have a unique ID across the module, this is the identity operation. // - inst->replaceUsesWith(inst->getOperand(0)); + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + inst->replaceUsesWith(builder.emitCast(inst->getDataType(), inst->getOperand(0), true)); inst->removeAndDeallocate(); } @@ -273,8 +265,8 @@ struct TagOpsLoweringContext : public InstPassBase // it must have been assigned a unique ID. // mapping.add( - builder.getUniqueID(srcSet->getElement(i)), - builder.getUniqueID(destElement)); + getUniqueID(&builder, srcSet->getElement(i)), + getUniqueID(&builder, destElement)); break; // Found the index } } @@ -308,7 +300,7 @@ struct TagOpsLoweringContext : public InstPassBase IRBuilder builder(inst->getModule()); builder.setInsertAfter(inst); - auto uniqueId = builder.getUniqueID(inst->getOperand(0)); + auto uniqueId = getUniqueID(&builder, inst->getOperand(0)); auto resultValue = builder.getIntValue(inst->getDataType(), uniqueId); inst->replaceUsesWith(resultValue); inst->removeAndDeallocate(); @@ -501,9 +493,63 @@ bool lowerDispatchers(IRModule* module, DiagnosticSink* sink) // This context lowers `TypeSet` instructions. struct SetLoweringContext : public InstPassBase { - SetLoweringContext(IRModule* module) - : InstPassBase(module) + SetLoweringContext( + IRModule* module, + TargetProgram* targetProgram, + DiagnosticSink* sink = nullptr) + : InstPassBase(module), targetProgram(targetProgram), sink(sink) + { + } + + SlangInt tryCalculateAnyValueSize(const HashSet& types) + { + SlangInt maxSize = 0; + for (auto type : types) + { + auto size = getAnyValueSize(type); + if (size > maxSize) + maxSize = size; + + if (sink && !canTypeBeStored(type)) + { + sink->diagnose( + type->sourceLoc, + Slang::Diagnostics::typeCannotBePackedIntoAnyValue, + type); + } + } + + // Defaults to 0 if any type could not be sized. + return maxSize; + } + + IRAnyValueType* createAnyValueType(IRBuilder* builder, const HashSet& types) { + auto size = tryCalculateAnyValueSize(types); + return builder->getAnyValueType(size); + } + + bool canTypeBeStored(IRType* concreteType) + { + if (!areResourceTypesBindlessOnTarget(targetProgram->getTargetReq())) + { + IRType* opaqueType = nullptr; + if (isOpaqueType(concreteType, &opaqueType)) + { + return false; + } + } + + IRSizeAndAlignment sizeAndAlignment; + Result result = getNaturalSizeAndAlignment( + targetProgram->getOptionSet(), + concreteType, + &sizeAndAlignment); + + if (SLANG_FAILED(result)) + return false; + + return true; } void lowerUntaggedUnionType(IRUntaggedUnionType* valueOfSetType) @@ -532,15 +578,19 @@ struct SetLoweringContext : public InstPassBase kIROp_UntaggedUnionType, [&](IRUntaggedUnionType* inst) { return lowerUntaggedUnionType(inst); }); } + +private: + DiagnosticSink* sink; + TargetProgram* targetProgram; }; // Lower `UntaggedUnionType(TypeSet(...))` instructions by replacing them with // appropriate `AnyValueType` instructions. // -void lowerUntaggedUnionTypes(IRModule* module, DiagnosticSink* sink) +void lowerUntaggedUnionTypes(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink) { SLANG_UNUSED(sink); - SetLoweringContext context(module); + SetLoweringContext context(module, targetProgram, sink); context.processModule(); } @@ -550,8 +600,8 @@ void lowerUntaggedUnionTypes(IRModule* module, DiagnosticSink* sink) // struct SequentialIDTagLoweringContext : public InstPassBase { - SequentialIDTagLoweringContext(IRModule* module) - : InstPassBase(module) + SequentialIDTagLoweringContext(Linkage* linkage, IRModule* module) + : InstPassBase(module), m_linkage(linkage) { } void lowerGetTagFromSequentialID(IRGetTagFromSequentialID* inst) @@ -651,8 +701,112 @@ struct SequentialIDTagLoweringContext : public InstPassBase inst->removeAndDeallocate(); } + + // Ensures every witness table object has been assigned a sequential ID. + // All witness tables will have a SequentialID decoration after this function is run. + // The sequantial ID in the decoration will be the same as the one specified in the Linkage. + // Otherwise, a new ID will be generated and assigned to the witness table object, and + // the sequantial ID map in the Linkage will be updated to include the new ID, so they + // can be looked up by the user via future Slang API calls. + void ensureWitnessTableSequentialIDs() + { + StringBuilder generatedMangledName; + + auto linkage = getLinkage(); + for (auto inst : module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_WitnessTable) + { + UnownedStringSlice witnessTableMangledName; + if (auto instLinkage = inst->findDecoration()) + { + witnessTableMangledName = instLinkage->getMangledName(); + } + else + { + auto witnessTableType = as(inst->getDataType()); + + if (witnessTableType && witnessTableType->getConformanceType() == nullptr) + { + // Ignore witness tables that represent 'none' for optional witness table + // types. + continue; + } + + if (witnessTableType && witnessTableType->getConformanceType() + ->findDecoration()) + { + // The interface is for specialization only, it would be an error if dynamic + // dispatch is used through the interface. Skip assigning ID for the witness + // table. + continue; + } + + // generate a unique linkage for it. + static int32_t uniqueId = 0; + uniqueId++; + if (auto nameHint = inst->findDecoration()) + { + generatedMangledName << nameHint->getName(); + } + generatedMangledName << "_generated_witness_uuid_" << uniqueId; + witnessTableMangledName = generatedMangledName.getUnownedSlice(); + } + + // If the inst already has a SequentialIDDecoration, stop now. + if (inst->findDecoration()) + continue; + + // Get a sequential ID for the witness table using the map from the Linkage. + uint32_t seqID = 0; + if (!linkage->mapMangledNameToRTTIObjectIndex.tryGetValue( + witnessTableMangledName, + seqID)) + { + auto interfaceType = + cast(inst->getDataType())->getConformanceType(); + if (as(interfaceType)) + { + auto interfaceLinkage = + interfaceType->findDecoration(); + SLANG_ASSERT( + interfaceLinkage && "An interface type does not have a linkage," + "but a witness table associated with it has one."); + auto interfaceName = interfaceLinkage->getMangledName(); + auto idAllocator = + linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( + interfaceName); + if (!idAllocator) + { + linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = + 0; + idAllocator = + linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( + interfaceName); + } + seqID = *idAllocator; + ++(*idAllocator); + } + else + { + // NoneWitness, has special ID of -1. + seqID = uint32_t(-1); + } + linkage->mapMangledNameToRTTIObjectIndex[witnessTableMangledName] = seqID; + } + + // Add a decoration to the inst. + IRBuilder builder(module); + builder.setInsertBefore(inst); + builder.addSequentialIDDecoration(inst, seqID); + } + } + } + void processModule() { + ensureWitnessTableSequentialIDs(); + processInstsOfType( kIROp_GetTagFromSequentialID, [&](IRGetTagFromSequentialID* inst) { return lowerGetTagFromSequentialID(inst); }); @@ -661,12 +815,17 @@ struct SequentialIDTagLoweringContext : public InstPassBase kIROp_GetSequentialIDFromTag, [&](IRGetSequentialIDFromTag* inst) { return lowerGetSequentialIDFromTag(inst); }); } + + Linkage* getLinkage() { return m_linkage; } + +private: + Linkage* m_linkage; }; -void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink) +void lowerSequentialIDTagCasts(IRModule* module, Linkage* linkage, DiagnosticSink* sink) { SLANG_UNUSED(sink); - SequentialIDTagLoweringContext context(module); + SequentialIDTagLoweringContext context(linkage, module); context.processModule(); } @@ -854,7 +1013,7 @@ struct TaggedUnionLoweringContext : public InstPassBase inst->removeAndDeallocate(); } - IRType* convertToTupleType(IRTaggedUnionType* taggedUnion) + IRType* lowerTaggedUnionType(IRTaggedUnionType* taggedUnion) { // Replace `TaggedUnionType(typeSet, tableSet)` with // `TupleType(SetTagType(tableSet), typeSet)` @@ -920,25 +1079,35 @@ struct TaggedUnionLoweringContext : public InstPassBase bool lowerGetTypeTagFromTaggedUnion(IRGetTypeTagFromTaggedUnion* inst) { - // We don't use type tags anywhere, so this instruction should have no - // uses. - // - SLANG_ASSERT(inst->hasUses() == false); - inst->removeAndDeallocate(); + IRBuilder builder(module); + builder.setInsertAfter(inst); + inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); return true; } + bool lowerMakeTaggedUnion(IRMakeTaggedUnion* inst) { - // We replace `MakeTaggedUnion(tag, val)` with `MakeTuple(tag, val)` + // We replace `MakeTaggedUnion(typeTag, witnessTableTag, val)` with `MakeTuple(tag, val)` // IRBuilder builder(module); builder.setInsertAfter(inst); - auto tag = inst->getOperand(0); - auto val = inst->getOperand(1); - inst->replaceUsesWith(builder.emitMakeTuple((IRType*)inst->getDataType(), {tag, val})); + auto tuTupleType = cast(inst->getDataType()); + + // The current lowering logic is only for bounded tagged unions (finite sets) + SLANG_ASSERT(!as(tuTupleType->getOperand(0))->getSet()->isUnbounded()); + + auto typeTag = inst->getOperand(0); + // We'll ignore the type tag, since the table is the only thing we need. + // for the bounded case. + SLANG_UNUSED(typeTag); + + auto witnessTableTag = inst->getOperand(1); + auto val = inst->getOperand(2); + inst->replaceUsesWith( + builder.emitMakeTuple((IRType*)inst->getDataType(), {witnessTableTag, val})); inst->removeAndDeallocate(); return true; } @@ -952,7 +1121,7 @@ struct TaggedUnionLoweringContext : public InstPassBase kIROp_TaggedUnionType, [&](IRTaggedUnionType* inst) { - inst->replaceUsesWith(convertToTupleType(inst)); + inst->replaceUsesWith(lowerTaggedUnionType(inst)); inst->removeAndDeallocate(); }); @@ -998,4 +1167,389 @@ bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink) TaggedUnionLoweringContext context(module); return context.processModule(); } + +void lowerIsTypeInsts(IRModule* module) +{ + InstPassBase pass(module); + pass.processInstsOfType( + kIROp_IsType, + [&](IRIsType* inst) + { + auto witnessTableType = + as(inst->getValueWitness()->getDataType()); + if (witnessTableType && + isComInterfaceType((IRType*)witnessTableType->getConformanceType())) + return; + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto eqlInst = builder.emitEql( + builder.emitGetSequentialIDInst(inst->getValueWitness()), + builder.emitGetSequentialIDInst(inst->getTargetWitness())); + inst->replaceUsesWith(eqlInst); + inst->removeAndDeallocate(); + }); +} + +struct ExistentialLoweringContext : public InstPassBase +{ + ExistentialLoweringContext(IRModule* module) + : InstPassBase(module) + { + } + + + bool _canReplace(IRUse* use) + { + switch (use->getUser()->getOp()) + { + case kIROp_WitnessTableIDType: + case kIROp_WitnessTableType: + case kIROp_RTTIPointerType: + case kIROp_RTTIHandleType: + case kIROp_ComPtrType: + case kIROp_NativePtrType: + { + // Don't replace + return false; + } + case kIROp_ThisType: + { + // Appears replacable. + break; + } + case kIROp_PtrType: + { + // We can have ** and ComPtr*. + // If it's a pointer type it could be because it is a global. + break; + } + default: + break; + } + return true; + } + + // Replace all WitnessTableID type or RTTIHandleType with `uint2`. + void lowerHandleTypes() + { + List instsToRemove; + for (auto inst : module->getGlobalInsts()) + { + switch (inst->getOp()) + { + case kIROp_WitnessTableIDType: + if (isComInterfaceType((IRType*)inst->getOperand(0))) + continue; + // fall through + case kIROp_RTTIHandleType: + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + auto uint2Type = builder.getVectorType( + builder.getUIntType(), + builder.getIntValue(builder.getIntType(), 2)); + inst->replaceUsesWith(uint2Type); + instsToRemove.add(inst); + } + break; + } + } + for (auto inst : instsToRemove) + inst->removeAndDeallocate(); + } + + IRInst* lowerInterfaceType(IRInst* interfaceType) + { + if (isComInterfaceType((IRType*)interfaceType)) + return (IRType*)interfaceType; + + IRBuilder builder(module); + if (isBuiltin(interfaceType)) + return (IRType*)builder.getIntValue(builder.getIntType(), 0); + + IRIntegerValue anyValueSize = 0; + if (auto decor = interfaceType->findDecoration()) + { + anyValueSize = decor->getSize(); + } + + auto anyValueType = builder.getAnyValueType(anyValueSize); + auto witnessTableType = builder.getWitnessTableIDType((IRType*)interfaceType); + auto rttiType = builder.getRTTIHandleType(); + + return builder.getTupleType(rttiType, witnessTableType, anyValueType); + } + + IRInst* lowerBoundInterfaceType(IRBoundInterfaceType* boundInterfaceType) + { + IRBuilder builder(module); + + auto payloadType = boundInterfaceType->getConcreteType(); + auto witnessTableType = builder.getWitnessTableIDType( + (IRType*)as(boundInterfaceType->getWitnessTable()) + ->getConformanceType()); + auto rttiType = builder.getRTTIHandleType(); + auto interfaceType = boundInterfaceType->getInterfaceType(); + + IRIntegerValue anyValueSize = 16; + if (auto decor = interfaceType->findDecoration()) + { + anyValueSize = decor->getSize(); + } + + auto anyValueType = builder.getAnyValueType(anyValueSize); + + return builder.getTupleType( + rttiType, + witnessTableType, + builder.getPseudoPtrType(payloadType), + anyValueType); + } + + bool lowerExtractExistentialType(IRExtractExistentialType* inst) + { + // Replace with extraction of the value type from the tagged-union tuple. + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + if (auto tupleType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)tupleType->getOperand(0), + inst->getOperand(0), + 0)); + inst->removeAndDeallocate(); + } + else if (auto comPtrType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + } + return true; + } + + bool lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) + { + // Replace with extraction of the value from the tagged-union tuple. + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + if (auto tupleType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)tupleType->getOperand(1), + inst->getOperand(0), + 1)); + inst->removeAndDeallocate(); + return true; + } + else if (auto comPtrType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + return true; + } + else + { + SLANG_UNEXPECTED("Unexpected type for ExtractExistentialWitnessTable operand"); + } + return false; + } + + bool lowerGetValueFromBoundInterface(IRGetValueFromBoundInterface* inst) + { + // Replace with extraction of the value from the tagged-union tuple. + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + auto tupleType = as(inst->getOperand(0)->getDataType()); + + if (as(tupleType->getOperand(2))) + { + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)tupleType->getOperand(2), + inst->getOperand(0), + 2)); + inst->removeAndDeallocate(); + return true; + } + else + { + inst->replaceUsesWith(builder.emitUnpackAnyValue( + inst->getDataType(), + builder.emitGetTupleElement( + (IRType*)tupleType->getOperand(2), + inst->getOperand(0), + 2))); + inst->removeAndDeallocate(); + return true; + } + return true; + } + + bool lowerExtractExistentialValue(IRExtractExistentialValue* inst) + { + // Replace with extraction of the value from the tagged-union tuple. + // + + IRBuilder builder(module); + builder.setInsertAfter(inst); + + if (auto tupleType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(builder.emitGetTupleElement( + (IRType*)tupleType->getOperand(2), + inst->getOperand(0), + 2)); + inst->removeAndDeallocate(); + return true; + } + else if (auto comPtrType = as(inst->getOperand(0)->getDataType())) + { + inst->replaceUsesWith(inst->getOperand(0)); + inst->removeAndDeallocate(); + return true; + } + SLANG_UNEXPECTED("Unexpected type for ExtractExistentialValue operand"); + return false; + } + + bool processGetSequentialIDInst(IRGetSequentialID* inst) + { + // If the operand is a witness table, it is already replaced with a uint2 + // at this point, where the first element in the uint2 is the id of the + // witness table. + IRBuilder builder(module); + builder.setInsertBefore(inst); + + if (auto table = as(inst->getRTTIOperand())) + { + auto seqDecoration = table->findDecoration(); + SLANG_ASSERT(seqDecoration && "Witness table missing SequentialID decoration"); + auto id = builder.getIntValue(builder.getUIntType(), seqDecoration->getSequentialID()); + inst->replaceUsesWith(id); + inst->removeAndDeallocate(); + return true; + } + + + UInt index = 0; + auto id = builder.emitSwizzle(builder.getUIntType(), inst->getRTTIOperand(), 1, &index); + inst->replaceUsesWith(id); + inst->removeAndDeallocate(); + return true; + } + + void processModule() + { + // Then, start lowering the remaining non-COM/non-Builtin interface types + // At this point, we should only bea dealing with public facing uses of + // interface types (which must lower into a 3-tuple of RTTI, witness table ID, AnyValue) + // + processInstsOfType( + kIROp_InterfaceType, + [&](IRInterfaceType* inst) + { + IRBuilder builder(module); + builder.setInsertInto(module); + if (auto loweredInterfaceType = lowerInterfaceType(inst)) + { + if (loweredInterfaceType != inst) + { + + traverseUses( + inst, + [&](IRUse* use) + { + if (_canReplace(use)) + builder.replaceOperand(use, loweredInterfaceType); + }); + } + } + }); + + processInstsOfType( + kIROp_BoundInterfaceType, + [&](IRBoundInterfaceType* inst) + { + IRBuilder builder(module); + builder.setInsertInto(module); + if (auto loweredBoundInterfaceType = lowerBoundInterfaceType(inst)) + { + if (loweredBoundInterfaceType != inst) + { + traverseUses( + inst, + [&](IRUse* use) + { + if (_canReplace(use)) + builder.replaceOperand(use, loweredBoundInterfaceType); + }); + } + } + }); + + // Replace any other uses with dummy value 0. + // TODO: Ideally, we should replace it with IRPoison.. + { + IRBuilder builder(module); + builder.setInsertInto(module); + auto dummyInterfaceObj = builder.getIntValue(builder.getIntType(), 0); + processInstsOfType( + kIROp_InterfaceType, + [&](IRInterfaceType* inst) + { + if (!isComInterfaceType((IRType*)inst)) + { + inst->replaceUsesWith(dummyInterfaceObj); + inst->removeAndDeallocate(); + } + }); + } + + processAllInsts( + [&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_ExtractExistentialType: + lowerExtractExistentialType(cast(inst)); + break; + case kIROp_ExtractExistentialValue: + lowerExtractExistentialValue(cast(inst)); + break; + case kIROp_ExtractExistentialWitnessTable: + lowerExtractExistentialWitnessTable( + cast(inst)); + break; + case kIROp_GetValueFromBoundInterface: + lowerGetValueFromBoundInterface(cast(inst)); + break; + } + }); + + lowerIsTypeInsts(module); + + processInstsOfType( + kIROp_GetSequentialID, + [&](IRGetSequentialID* inst) { return processGetSequentialIDInst(inst); }); + + lowerHandleTypes(); + } +}; + +bool lowerExistentials(IRModule* module, DiagnosticSink* sink) +{ + SLANG_UNUSED(sink); + ExistentialLoweringContext context(module); + context.processModule(); + return true; +}; + }; // namespace Slang \ No newline at end of file diff --git a/source/slang/slang-ir-lower-typeflow-insts.h b/source/slang/slang-ir-lower-typeflow-insts.h index 930bf4fd830..72f595e5929 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.h +++ b/source/slang/slang-ir-lower-typeflow-insts.h @@ -4,9 +4,11 @@ namespace Slang { +class Linkage; +class TargetProgram; // Lower `UntaggedUnionType` types. -void lowerUntaggedUnionTypes(IRModule* module, DiagnosticSink* sink); +void lowerUntaggedUnionTypes(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink); // Lower `SetTaggedUnion` and `CastInterfaceToTaggedUnionPtr` instructions // May create new `Reinterpret` instructions. @@ -22,9 +24,14 @@ void lowerTagTypes(IRModule* module); void lowerTagInsts(IRModule* module, DiagnosticSink* sink); // Lower `GetTagFromSequentialID` and `GetSequentialIDFromTag` instructions -void lowerSequentialIDTagCasts(IRModule* module, DiagnosticSink* sink); +void lowerSequentialIDTagCasts(IRModule* module, Linkage* linkage, DiagnosticSink* sink); // Lower `GetDispatcher` and `GetSpecializedDispatcher` instructions bool lowerDispatchers(IRModule* module, DiagnosticSink* sink); +// Lower `ExtractExistentialValue`, `ExtractExistentialType`, `ExtractExistentialWitnessTable`, +// `InterfaceType`, `GetSequentialID`, `WitnessTableIDType` and `RTTIHandleType` instructions. +// +bool lowerExistentials(IRModule* module, DiagnosticSink* sink); + } // namespace Slang diff --git a/source/slang/slang-ir-propagate-func-properties.cpp b/source/slang/slang-ir-propagate-func-properties.cpp index 0e5ee731a38..ece196079ae 100644 --- a/source/slang/slang-ir-propagate-func-properties.cpp +++ b/source/slang/slang-ir-propagate-func-properties.cpp @@ -82,6 +82,20 @@ class ReadNoneFuncPropertyPropagationContext : public FuncPropertyPropagationCon return true; } + bool isDebugInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_DebugLine: + case kIROp_DebugScope: + case kIROp_DebugVar: + case kIROp_DebugValue: + return true; + default: + return false; + } + } + virtual bool propagate(IRBuilder& builder, IRFunc* f) override { bool hasReadNoneCall = false; @@ -90,7 +104,7 @@ class ReadNoneFuncPropertyPropagationContext : public FuncPropertyPropagationCon for (auto inst : block->getChildren()) { // Is this inst known to not have global side effect/analyzable? - if (!isKnownOpCodeWithSideEffect(inst->getOp())) + if (!isKnownOpCodeWithSideEffect(inst->getOp()) && !isDebugInst(inst)) { if (inst->mightHaveSideEffects() || isResourceLoad(inst->getOp())) { @@ -122,6 +136,12 @@ class ReadNoneFuncPropertyPropagationContext : public FuncPropertyPropagationCon } } + // If the inst is a debug instruction, skip it. + // these are only annotations + // + if (isDebugInst(inst)) // TODO: May not need this + continue; + // Do any operands defined have pointer type of global or // unknown source? Passing them into a readNone callee may cause // side effects that breaks the readNone property. diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index 8d2a08a618d..2a8374a42f0 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -183,6 +183,7 @@ IRFunc* specializeDispatchFunction( // Otherwise, a new ID will be generated and assigned to the witness table object, and // the sequantial ID map in the Linkage will be updated to include the new ID, so they // can be looked up by the user via future Slang API calls. +/* void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContext) { StringBuilder generatedMangledName; @@ -274,6 +275,7 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex } } } +*/ // Fixes up call sites of a dispatch function, so that the witness table argument is replaced with // its sequential ID. @@ -309,7 +311,7 @@ void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext) { // First we ensure that all witness table objects has a sequential ID assigned. - ensureWitnessTableSequentialIDs(sharedContext); + // ensureWitnessTableSequentialIDs(sharedContext); // Generate specialized dispatch functions and fixup call sites. for (const auto& [_, dispatchFunc] : sharedContext->mapInterfaceRequirementKeyToDispatchMethods) diff --git a/source/slang/slang-ir-strip-legalization-insts.cpp b/source/slang/slang-ir-strip-legalization-insts.cpp index f55d3f11974..316d091c6ca 100644 --- a/source/slang/slang-ir-strip-legalization-insts.cpp +++ b/source/slang/slang-ir-strip-legalization-insts.cpp @@ -68,12 +68,8 @@ void unpinWitnessTables(IRModule* module) if (!witnessTable) continue; - // If a witness table is not used for dynamic dispatch, unpin it. - if (!witnessTable->findDecoration()) - { - while (auto decor = witnessTable->findDecoration()) - decor->removeAndDeallocate(); - } + while (auto decor = witnessTable->findDecoration()) + decor->removeAndDeallocate(); } } diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-collection.cpp index 721110710e8..f1c0a1f76fa 100644 --- a/source/slang/slang-ir-typeflow-collection.cpp +++ b/source/slang/slang-ir-typeflow-collection.cpp @@ -38,20 +38,24 @@ IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo) if (argTUType != destTUType) { - // Technically, IRTaggedUnionType is not a TupleType, - // but in practice it works the same way so we'll re-use Slang's - // tuple accessors & constructors - // auto argTableTag = builder->emitGetTagFromTaggedUnion(arg); - auto reinterpretedTag = upcastSet( + auto reinterpretedTableTag = upcastSet( builder, argTableTag, builder->getSetTagType(destTUType->getWitnessTableSet())); + auto argTypeTag = builder->emitGetTypeTagFromTaggedUnion(arg); + auto reinterpretedTypeTag = + upcastSet(builder, argTypeTag, builder->getSetTagType(destTUType->getTypeSet())); + auto argVal = builder->emitGetValueFromTaggedUnion(arg); auto reinterpretedVal = upcastSet(builder, argVal, builder->getUntaggedUnionType(destTUType->getTypeSet())); - return builder->emitMakeTaggedUnion(destTUType, reinterpretedTag, reinterpretedVal); + return builder->emitMakeTaggedUnion( + destTUType, + reinterpretedTypeTag, + reinterpretedTableTag, + reinterpretedVal); } } else if (as(argInfo) && as(destInfo)) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index c08e9400032..64b790a3c2f 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -331,6 +331,13 @@ bool isSetSpecializedGeneric(IRInst* callee) return false; } +IRInst* getArrayStride(IRArrayType* arrayType) +{ + if (arrayType->getOperandCount() >= 3) + return arrayType->getStride(); + return nullptr; +} + // // Helper struct to represent a parameter's direction and type component. // This is used by the type flow system to figure out which direction to propagate @@ -462,7 +469,7 @@ bool isConcreteType(IRInst* inst) switch (inst->getOp()) { case kIROp_InterfaceType: // Can be refined to tagged unions - return false; + return isComInterfaceType(cast(inst)); case kIROp_WitnessTableType: // Can be refined into set of concrete tables return false; case kIROp_FuncType: // Can be refined into set of concrete functions @@ -471,6 +478,8 @@ bool isConcreteType(IRInst* inst) return false; case kIROp_TypeKind: // Can be refined into set of concrete types return false; + case kIROp_TypeType: // Can be refined into set of concrete types + return false; case kIROp_ArrayType: return isConcreteType(cast(inst)->getElementType()) && isGlobalInst(cast(inst)->getElementCount()); @@ -504,7 +513,8 @@ IRInst* makeInfoForConcreteType(IRModule* module, IRInst* type) { return builder.getArrayType( (IRType*)makeInfoForConcreteType(module, arrayType->getElementType()), - arrayType->getElementCount()); + arrayType->getElementCount(), + getArrayStride(arrayType)); } return builder.getUntaggedUnionType(cast(builder.getSingletonSet(type))); @@ -548,6 +558,28 @@ bool isOptionalExistentialType(IRInst* inst) return false; } +IRInst* maybeGetUninitializedElement(IRSetBase* set) +{ + IRInst* foundInst = nullptr; + forEachInSet( + set, + [&](IRInst* element) + { + if (auto uninitializedTypeElement = as(element)) + { + foundInst = uninitializedTypeElement; + } + else if ( + auto uninitializedWitnessTableElement = + as(element)) + { + foundInst = uninitializedWitnessTableElement; + } + }); + + return foundInst; +} + // Parent context for the full type-flow pass. struct TypeFlowSpecializationContext { @@ -566,8 +598,21 @@ struct TypeFlowSpecializationContext tableSet, [&](IRInst* witnessTable) { - if (auto table = as(witnessTable)) - typeSet.add(table->getConcreteType()); + switch (witnessTable->getOp()) + { + case kIROp_UnboundedWitnessTableElement: + typeSet.add(builder.getUnboundedTypeElement( + as(witnessTable)->getBaseInterfaceType())); + break; + case kIROp_WitnessTable: + typeSet.add(as(witnessTable)->getConcreteType()); + break; + case kIROp_UninitializedWitnessTableElement: + typeSet.add(builder.getUninitializedTypeElement( + as(witnessTable) + ->getBaseInterfaceType())); + break; + } }); // Create the tagged union type out of the type and table collection. @@ -783,6 +828,7 @@ struct TypeFlowSpecializationContext if (areInfosEqual(info1, info2)) return info1; + // TODO: Move into utility function to avoid dropping information. if (as(info1) && as(info2)) { SLANG_ASSERT(info1->getOperand(1) == info2->getOperand(1)); @@ -791,7 +837,8 @@ struct TypeFlowSpecializationContext builder.setInsertInto(module); return builder.getArrayType( (IRType*)unionPropagationInfo(info1->getOperand(0), info2->getOperand(0)), - info1->getOperand(1)); // Keep the same size + as(info1)->getElementCount(), + getArrayStride(as(info1))); // Keep the same size } if (as(info1) && as(info2)) @@ -806,11 +853,13 @@ struct TypeFlowSpecializationContext as(info1)); } + /* if (as(info1) && as(info2)) { // If either info is unbounded, the union is unbounded return makeUnbounded(); } + */ // For all other cases which are structured composites of sets, // we simply take the set union for all the set operands. @@ -1083,6 +1132,9 @@ struct TypeFlowSpecializationContext case kIROp_MakeExistential: info = analyzeMakeExistential(context, as(inst)); break; + // case kIROp_WrapExistential: + // info = analyzeWrapExistential(context, as(inst)); + // break; case kIROp_LookupWitnessMethod: info = analyzeLookupWitnessMethod(context, as(inst)); break; @@ -1108,6 +1160,9 @@ struct TypeFlowSpecializationContext case kIROp_StructuredBufferLoad: info = analyzeLoad(context, inst); break; + case kIROp_LoadFromUninitializedMemory: + info = analyzeLoadFromUninitializedMemory(context, inst); + break; case kIROp_MakeStruct: info = analyzeMakeStruct(context, as(inst), workQueue); break; @@ -1352,10 +1407,10 @@ struct TypeFlowSpecializationContext IRBuilder builder(module); if (auto interfaceType = as(inst->getDataType())) { - if (isComInterfaceType(interfaceType) || isBuiltin(interfaceType)) + if (isComInterfaceType(interfaceType)) { - // If this is a COM interface, we treat it as unbounded - return makeUnbounded(); + // If this is a COM interface, we ignore it. + return none(); } auto tables = collectExistentialTables(interfaceType); @@ -1378,15 +1433,19 @@ struct TypeFlowSpecializationContext IRInst* analyzeMakeExistential(IRInst* context, IRMakeExistential* inst) { IRBuilder builder(module); - auto witnessTable = inst->getWitnessTable(); // If we're building an existential for a COM interface, // we always assume it is unbounded, since we can receive // types that we know nothing about in the current linkage. // if (isComInterfaceType(inst->getDataType())) - return makeUnbounded(); + { + return none(); + // return builder.getComPtrType(inst->getDataType()); + // return makeUnbounded(); + } + auto witnessTable = inst->getWitnessTable(); // Concrete case. if (as(witnessTable)) return makeTaggedUnionType( @@ -1398,8 +1457,8 @@ struct TypeFlowSpecializationContext if (!witnessTableInfo) return none(); - if (as(witnessTableInfo)) - return makeUnbounded(); + // if (as(witnessTableInfo)) + // return makeUnbounded(); if (auto elementOfSetType = as(witnessTableInfo)) return makeTaggedUnionType(cast(elementOfSetType->getSet())); @@ -1407,6 +1466,28 @@ struct TypeFlowSpecializationContext SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } + /* + IRInst* analyzeWrapExistential(IRInst* context, IRWrapExistential* wrapExistential) + { + if (auto valInfo = tryGetInfo(context, wrapExistential->getWrappedValue())) + { + // We need a single possible value for the wrapped value. + auto taggedUnionType = cast(valInfo); + SLANG_ASSERT( + taggedUnionType->getTypeSet()->isSingleton() && + taggedUnionType->getWitnessTableSet()->isSingleton()); + // Since the inst's result is expected to be a concrete type, + // we'll return a 'none' here. The info won't be recorded anyway. + // + return none(); + } + else + { + return none(); + } + } + */ + IRInst* analyzeMakeStruct(IRInst* context, IRMakeStruct* makeStruct, WorkQueue& workQueue) { // We'll process this in the same way as a field-address, but for @@ -1442,6 +1523,20 @@ struct TypeFlowSpecializationContext return none(); // the make struct itself doesn't have any info. } + IRInst* analyzeLoadFromUninitializedMemory(IRInst* context, IRInst* inst) + { + IRBuilder builder(module); + if (as(inst->getDataType()) && !isConcreteType(inst->getDataType())) + { + auto uninitializedSet = builder.getSingletonSet( + kIROp_WitnessTableSet, + builder.getUninitializedWitnessTableElement(inst->getDataType())); + return makeTaggedUnionType(as(uninitializedSet)); + } + + return none(); + } + IRInst* analyzeLoad(IRInst* context, IRInst* inst) { IRBuilder builder(module); @@ -1463,7 +1558,7 @@ struct TypeFlowSpecializationContext { if (auto interfaceType = as(loadInst->getDataType())) { - if (!isComInterfaceType(interfaceType) && !isBuiltin(interfaceType)) + if (!isComInterfaceType(interfaceType)) { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) @@ -1474,9 +1569,22 @@ struct TypeFlowSpecializationContext } else { - return makeUnbounded(); + return none(); } } + else if ( + auto boundInterfaceType = as(loadInst->getDataType())) + { + IRBuilder builder(module); + return makeTaggedUnionType(cast( + builder.getSingletonSet(boundInterfaceType->getWitnessTable()))); + } + else + { + // Loading from a resource pointer that isn't an interface? + // Just return no info. + return none(); + } } // If the load is from a pointer, we can transfer the info directly @@ -1493,7 +1601,7 @@ struct TypeFlowSpecializationContext // if (auto interfaceType = as(inst->getDataType())) { - if (!isComInterfaceType(interfaceType) && !isBuiltin(interfaceType)) + if (!isComInterfaceType(interfaceType)) { auto tables = collectExistentialTables(interfaceType); if (tables.getCount() > 0) @@ -1504,9 +1612,15 @@ struct TypeFlowSpecializationContext } else { - return makeUnbounded(); + return none(); } } + else if (auto boundInterfaceType = as(inst->getDataType())) + { + IRBuilder builder(module); + return makeTaggedUnionType(cast( + builder.getSingletonSet(boundInterfaceType->getWitnessTable()))); + } } return none(); // No info for other load types @@ -1756,16 +1870,31 @@ struct TypeFlowSpecializationContext forEachInSet( cast(elementOfSetType->getSet()), [&](IRInst* table) - { results.add(lookupWitnessTableEntry(cast(table), key)); }); + { + if (as(table)) + { + if (inst->getDataType()->getOp() == kIROp_FuncType) + results.add(builder.getUnboundedFuncElement()); + else if (inst->getDataType()->getOp() == kIROp_WitnessTableType) + results.add(builder.getUnboundedWitnessTableElement( + as(inst)->getDataType())); + else if (inst->getDataType()->getOp() == kIROp_TypeKind) + { + SLANG_UNEXPECTED( + "TypeKind result from LookupWitnessMethod not supported"); + } + return; + } + + results.add(lookupWitnessTableEntry(cast(table), key)); + }); + return makeElementOfSetType(builder.getSet(results)); } if (!witnessTableInfo) return none(); - if (as(witnessTableInfo)) - return makeUnbounded(); - SLANG_UNEXPECTED("Unexpected witness table info type in analyzeLookupWitnessMethod"); } @@ -1788,11 +1917,24 @@ struct TypeFlowSpecializationContext if (!operandInfo) return none(); - if (as(operandInfo)) - return makeUnbounded(); - if (auto taggedUnion = as(operandInfo)) - return makeElementOfSetType(taggedUnion->getWitnessTableSet()); + { + auto tableSet = taggedUnion->getWitnessTableSet(); + if (auto uninitElement = maybeGetUninitializedElement(tableSet)) + { + // Uninitialized element should contain + sink->diagnose( + inst->sourceLoc, + Diagnostics::noTypeConformancesFoundForInterface, + uninitElement->getOperand(0)); + + return none(); // We'll return none so that the analysis doesn't + // crash early, before we can detect the error count + // and exit gracefully. + } + + return makeElementOfSetType(tableSet); + } SLANG_UNEXPECTED("Unhandled info type in analyzeExtractExistentialWitnessTable"); } @@ -1814,9 +1956,6 @@ struct TypeFlowSpecializationContext if (!operandInfo) return none(); - if (as(operandInfo)) - return makeUnbounded(); - if (auto taggedUnion = as(operandInfo)) return makeElementOfSetType(taggedUnion->getTypeSet()); @@ -1842,9 +1981,6 @@ struct TypeFlowSpecializationContext if (!operandInfo) return none(); - if (as(operandInfo)) - return makeUnbounded(); - if (auto taggedUnion = as(operandInfo)) return makeUntaggedUnionType(taggedUnion->getTypeSet()); @@ -1878,13 +2014,9 @@ struct TypeFlowSpecializationContext auto operand = inst->getBase(); auto operandInfo = tryGetInfo(context, operand); - if (as(operandInfo)) - return makeUnbounded(); - if (as(operandInfo)) { - SLANG_UNEXPECTED( - "Unexpected ExtractExistentialWitnessTable on Set (should be Existential)"); + SLANG_UNEXPECTED("Unexpected operand for IRSpecialize"); } // Handle the 'many' or 'one' cases. @@ -1909,7 +2041,7 @@ struct TypeFlowSpecializationContext if (!argInfo) return none(); - if (as(argInfo) || as(argInfo)) + if (as(argInfo)) { SLANG_UNEXPECTED("Unexpected Existential operand in specialization argument."); } @@ -1918,8 +2050,34 @@ struct TypeFlowSpecializationContext { if (elementOfSetType->getSet()->isSingleton()) specializationArgs.add(elementOfSetType->getSet()->getElement(0)); + else if (elementOfSetType->getSet()->isUnbounded()) + { + // Infinite set. + // + // While our sets allow for encoding an unbounded case along with known + // cases, when it comes to specializing a function or placing a call to a + // function, we will default to the single unbounded element case. + // + IRInst* unboundedElement; + forEachInSet( + elementOfSetType->getSet(), + [&](IRInst* element) + { + if (as(element) || + as(element)) + unboundedElement = element; + }); + IRBuilder builder(module); + SLANG_ASSERT(unboundedElement); + auto pureUnboundedSet = builder.getSingletonSet(unboundedElement); + if (auto typeSet = as(pureUnboundedSet)) + specializationArgs.add(makeUntaggedUnionType(typeSet)); + else + specializationArgs.add(pureUnboundedSet); + } else { + // Dealing with a non-singleton, but finite set. if (auto typeSet = as(elementOfSetType->getSet())) { specializationArgs.add(makeUntaggedUnionType(typeSet)); @@ -1957,6 +2115,21 @@ struct TypeFlowSpecializationContext { if (elementOfSetType->getSet()->isSingleton()) return elementOfSetType->getSet()->getElement(0); + else if (elementOfSetType->getSet()->isUnbounded()) + { + IRInst* unboundedElement; + forEachInSet( + elementOfSetType->getSet(), + [&](IRInst* element) + { + if (as(element)) + unboundedElement = element; + }); + SLANG_ASSERT(unboundedElement); + IRBuilder builder(module); + return makeUntaggedUnionType( + cast(builder.getSingletonSet(unboundedElement))); + } else return makeUntaggedUnionType( cast(elementOfSetType->getSet())); @@ -2026,6 +2199,18 @@ struct TypeFlowSpecializationContext // Create a new specialized instruction for each argument IRBuilder builder(module); builder.setInsertInto(module); + + if (as(arg)) + { + // Infinite set. + // + // We currently only support specializing generic functions + // in this way, so we'll assume its an unbounded-func-element + // + specializedSet.add(builder.getUnboundedFuncElement()); + return; + } + specializedSet.add(builder.emitSpecializeInst( typeOfSpecialization, arg, @@ -2057,7 +2242,7 @@ struct TypeFlowSpecializationContext // time we're trying to propagate information into this context. A context // is a global-scope IRFunc or IRSpecialize. // - // If it is the first, we enqueue some work to perform initialization of all + // If it is the first, we enqueue some work to perform initialization of all // the insts in the body of the func. // // Since discover context is only called 'on-demand' as the type-flow propagation @@ -2082,6 +2267,7 @@ struct TypeFlowSpecializationContext // Add all blocks to the work queue for (auto block = func->getFirstBlock(); block; block = block->getNextBlock()) workQueue.enqueue(WorkItem(context, block)); + break; } case kIROp_Specialize: @@ -2106,6 +2292,16 @@ struct TypeFlowSpecializationContext { updateInfo(context, param, makeElementOfSetType(set), true, workQueue); } + else if (auto untaggedUnion = as(arg)) + { + IRBuilder builder(module); + updateInfo( + context, + param, + makeElementOfSetType(untaggedUnion->getSet()), + true, + workQueue); + } else if (as(arg) || as(arg)) { IRBuilder builder(module); @@ -2155,6 +2351,14 @@ struct TypeFlowSpecializationContext if (as(callee)) return; + if (as(callee)) + { + // An unbounded element represents an unknown function, + // so we can't propagate anything in this case. + // + return; + } + // Register the call site in the map to allow for the // return-edge to be created. // @@ -2237,7 +2441,8 @@ struct TypeFlowSpecializationContext auto baseInfo = builder.getPtrTypeWithAddressSpace( builder.getArrayType( (IRType*)thisValueInfo, - as(baseValueType)->getElementCount()), + as(baseValueType)->getElementCount(), + getArrayStride(as(baseValueType))), as(getElementPtr->getBase()->getDataType())); // Recursively try to update the base pointer. @@ -2475,18 +2680,6 @@ struct TypeFlowSpecializationContext auto paramInfo = tryGetInfo(context, param); if (paramInfo) continue; // Already has some information - - if (auto interfaceType = as(paramType)) - { - if (isComInterfaceType(interfaceType) || isBuiltin(interfaceType)) - propagationMap[InstWithContext(context, param)] = makeUnbounded(); - else - propagationMap[InstWithContext(context, param)] = none(); - } - else - { - propagationMap[InstWithContext(context, param)] = none(); - } } } @@ -2699,9 +2892,6 @@ struct TypeFlowSpecializationContext if (!info) return nullptr; - if (as(info)) - return nullptr; - if (auto ptrType = as(info)) { IRBuilder builder(module); @@ -2720,7 +2910,8 @@ struct TypeFlowSpecializationContext { return builder.getArrayType( (IRType*)specializedElementType, - arrayType->getElementCount()); + arrayType->getElementCount(), + getArrayStride(arrayType)); } else return nullptr; @@ -2803,6 +2994,8 @@ struct TypeFlowSpecializationContext return specializeCall(context, as(inst)); case kIROp_MakeExistential: return specializeMakeExistential(context, as(inst)); + case kIROp_WrapExistential: + return specializeWrapExistential(context, as(inst)); case kIROp_MakeStruct: return specializeMakeStruct(context, as(inst)); case kIROp_CreateExistentialObject: @@ -2868,7 +3061,7 @@ struct TypeFlowSpecializationContext builder.setInsertBefore(inst); // If there's a single element, we can do a simple replacement. - if (elementOfSetType->getSet()->getCount() == 1) + if (elementOfSetType->getSet()->isSingleton()) { auto element = elementOfSetType->getSet()->getElement(0); inst->replaceUsesWith(element); @@ -2929,28 +3122,31 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); builder.setInsertBefore(inst); - auto elementOfSetType = as(info); - if (!elementOfSetType) - return false; - - if (elementOfSetType->getSet()->getCount() == 1) + if (auto elementOfSetType = as(info)) { - // Found a single possible type. Simple replacement. - inst->replaceUsesWith(elementOfSetType->getSet()->getElement(0)); - inst->removeAndDeallocate(); - return true; + if (elementOfSetType->getSet()->isSingleton()) + { + // Found a single possible type. Simple replacement. + inst->replaceUsesWith(elementOfSetType->getSet()->getElement(0)); + inst->removeAndDeallocate(); + return true; + } + else + { + // Replace with GetElement(specializedInst, 0) -> TagType(tableSet) + // which retreives a 'tag' (i.e. a run-time identifier for one of the elements + // of the set) + // + auto operand = inst->getOperand(0); + auto element = builder.emitGetTagFromTaggedUnion(operand); + inst->replaceUsesWith(element); + inst->removeAndDeallocate(); + return true; + } } else { - // Replace with GetElement(specializedInst, 0) -> TagType(tableSet) - // which retreives a 'tag' (i.e. a run-time identifier for one of the elements - // of the set) - // - auto operand = inst->getOperand(0); - auto element = builder.emitGetTagFromTaggedUnion(operand); - inst->replaceUsesWith(element); - inst->removeAndDeallocate(); - return true; + SLANG_UNEXPECTED("Unexpected info type for ExtractExistentialWitnessTable"); } } @@ -3104,6 +3300,20 @@ struct TypeFlowSpecializationContext // table-set argument, a tag is required as input. // + if (calleeSet->isUnbounded()) + { + IRUnboundedFuncElement* unboundedFuncElement = nullptr; + forEachInSet( + calleeSet, + [&](IRInst* func) + { + if (as(func)) + unboundedFuncElement = as(func); + }); + SLANG_ASSERT(unboundedFuncElement); + return cast(unboundedFuncElement->getOperand(0)); + } + IRBuilder builder(module); List paramTypes; @@ -3219,13 +3429,19 @@ struct TypeFlowSpecializationContext return callArgs; } - void maybeSpecializeCalleeType(IRInst* callee) + IRInst* maybeSpecializeCalleeType(IRInst* callee) { if (auto specializeInst = as(callee->getDataType())) { if (isGlobalInst(specializeInst)) - callee->setFullType((IRType*)specializeGeneric(specializeInst)); + { + // callee->setFullType((IRType*)specializeGeneric(specializeInst)); + IRBuilder builder(module); + return builder.replaceOperand(&callee->typeUse, specializeGeneric(specializeInst)); + } } + + return callee; } bool specializeCall(IRInst* context, IRCall* inst) @@ -3323,7 +3539,7 @@ struct TypeFlowSpecializationContext // occur for concrete IRSpecialize insts that are created // during the specializeing process). // - maybeSpecializeCalleeType(callee); + callee = maybeSpecializeCalleeType(callee); // If we're calling using a tag, place a call to the set, // with the tag as the first argument. So the callee is @@ -3423,7 +3639,10 @@ struct TypeFlowSpecializationContext callee = setTag->getSet()->getElement(0); auto funcType = getEffectiveFuncType(callee); - callee->setFullType(funcType); + // callee->setFullType(funcType); + IRBuilder builder(module); + builder.setInsertInto(module); + callee = builder.replaceOperand(&callee->typeUse, funcType); } else { @@ -3460,13 +3679,32 @@ struct TypeFlowSpecializationContext // by the analysis. // auto funcType = getEffectiveFuncType(callee); - callee->setFullType(funcType); + // callee->setFullType(funcType); + IRBuilder builder(module); + builder.setInsertInto(module); + callee = builder.replaceOperand(&callee->typeUse, funcType); + } + else if (isGlobalInst(callee)) + { + auto resultType = getLoweredType(getFuncReturnInfo(callee)); + if (resultType && resultType != inst->getFullType()) + { + auto oldFuncType = as(callee->getDataType()); + IRBuilder builder(module); + List paramTypes; + for (auto paramType : oldFuncType->getParamTypes()) + paramTypes.add(paramType); + auto newFuncType = builder.getFuncType(paramTypes, resultType); + // callee->setFullType(newFuncType); + builder.setInsertInto(module); + callee = builder.replaceOperand(&callee->typeUse, newFuncType); + } } // If by this point, we haven't resolved our callee into a global inst ( // either a set or a single function), then we can't specialize it (likely unbounded) // - if (!isGlobalInst(callee) || isIntrinsic(callee)) + if (!isGlobalInst(callee)) return false; // First, we'll legalize all operands by upcasting if necessary. @@ -3608,21 +3846,22 @@ struct TypeFlowSpecializationContext auto typeSet = taggedUnion->getTypeSet(); IRInst* witnessTableTag = nullptr; + IRInst* typeTag = nullptr; if (auto witnessTable = as(inst->getWitnessTable())) { - auto singletonTagType = makeTagType(builder.getSingletonSet(witnessTable)); - IRInst* tagValue = - builder.emitGetTagOfElementInSet((IRType*)singletonTagType, witnessTable, tableSet); - witnessTableTag = builder.emitIntrinsicInst( + witnessTableTag = builder.emitGetTagOfElementInSet( (IRType*)makeTagType(tableSet), - kIROp_GetTagForSuperSet, - 1, - &tagValue); + witnessTable, + tableSet); + typeTag = builder.emitGetTagOfElementInSet( + (IRType*)makeTagType(typeSet), + inst->getDataType(), + typeSet); } else if (as(inst->getWitnessTable()->getDataType())) { - // Dynamic. Use the witness table inst as a tag - witnessTableTag = inst->getWitnessTable(); + witnessTableTag = upcastSet(&builder, inst->getWitnessTable(), makeTagType(tableSet)); + typeTag = nullptr; } // Create the appropriate any-value type @@ -3637,9 +3876,18 @@ struct TypeFlowSpecializationContext auto taggedUnionType = getLoweredType(taggedUnion); - auto tuple = builder.emitMakeTaggedUnion(taggedUnionType, witnessTableTag, packedValue); + inst->replaceUsesWith(builder.emitMakeTaggedUnion( + taggedUnionType, + builder.emitPoison(makeTagType(typeSet)), + witnessTableTag, + packedValue)); + inst->removeAndDeallocate(); + return true; + } - inst->replaceUsesWith(tuple); + bool specializeWrapExistential(IRInst* context, IRWrapExistential* inst) + { + inst->replaceUsesWith(inst->getWrappedValue()); inst->removeAndDeallocate(); return true; } @@ -3688,8 +3936,11 @@ struct TypeFlowSpecializationContext inst->getValue()); } - auto newInst = - builder.emitMakeTaggedUnion((IRType*)taggedUnionType, translatedTag, packedValue); + auto newInst = builder.emitMakeTaggedUnion( + (IRType*)taggedUnionType, + builder.emitPoison(makeTagType(taggedUnionType->getTypeSet())), + translatedTag, + packedValue); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); @@ -4142,15 +4393,23 @@ struct TypeFlowSpecializationContext SLANG_ASSERT(taggedUnionType->getWitnessTableSet()->isSingleton()); auto noneWitnessTable = taggedUnionType->getWitnessTableSet()->getElement(0); - auto singletonTagType = makeTagType(builder.getSingletonSet(noneWitnessTable)); - IRInst* zeroValueOfTagType = builder.emitGetTagOfElementInSet( - (IRType*)singletonTagType, + auto singletonWitnessTableTagType = + makeTagType(builder.getSingletonSet(noneWitnessTable)); + IRInst* tableTag = builder.emitGetTagOfElementInSet( + (IRType*)singletonWitnessTableTagType, noneWitnessTable, taggedUnionType->getWitnessTableSet()); + auto singletonTypeTagType = makeTagType(builder.getSingletonSet(builder.getVoidType())); + IRInst* typeTag = builder.emitGetTagOfElementInSet( + (IRType*)singletonTypeTagType, + builder.getVoidType(), + taggedUnionType->getTypeSet()); + auto newTuple = builder.emitMakeTaggedUnion( (IRType*)taggedUnionType, - zeroValueOfTagType, + typeTag, + tableTag, builder.emitDefaultConstruct(makeUntaggedUnionType(taggedUnionType->getTypeSet()))); inst->replaceUsesWith(newTuple); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index db46a6d04b0..1803fa108bb 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1731,6 +1731,7 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* case kIROp_GlobalConstant: case kIROp_Var: case kIROp_Param: + case kIROp_DebugVar: break; case kIROp_Call: return true; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 520c87d2705..f7dc4a7dc88 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6617,6 +6617,15 @@ UInt IRBuilder::getUniqueID(IRInst* inst) IROp IRBuilder::getSetTypeForInst(IRInst* inst) { + if (as(inst) || as(inst)) + return kIROp_TypeSet; + if (as(inst)) + return kIROp_FuncSet; + if (as(inst) || as(inst)) + return kIROp_WitnessTableSet; + if (as(inst)) + return kIROp_GenericSet; + if (as(inst)) return kIROp_GenericSet; @@ -8573,6 +8582,8 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_MakeTaggedUnion: case kIROp_GetTagOfElementInSet: case kIROp_UnboundedSet: + case kIROp_MakeDifferentialPairUserCode: + case kIROp_MakeDifferentialPtrPair: return false; case kIROp_ForwardDifferentiate: diff --git a/tests/compute/dynamic-dispatch-1.slang b/tests/compute/dynamic-dispatch-1.slang index 10d3552ab88..f28704d47b3 100644 --- a/tests/compute/dynamic-dispatch-1.slang +++ b/tests/compute/dynamic-dispatch-1.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj // Test dynamic dispatch code gen for non-static member functions. diff --git a/tests/compute/dynamic-dispatch-10.slang b/tests/compute/dynamic-dispatch-10.slang index ddd51c4b920..9df6a6aaf37 100644 --- a/tests/compute/dynamic-dispatch-10.slang +++ b/tests/compute/dynamic-dispatch-10.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for specializing a generic with // an existential value. diff --git a/tests/compute/dynamic-dispatch-11.slang b/tests/compute/dynamic-dispatch-11.slang index 59e7ce58170..4b7a67324ad 100644 --- a/tests/compute/dynamic-dispatch-11.slang +++ b/tests/compute/dynamic-dispatch-11.slang @@ -6,8 +6,8 @@ //DISABLE_TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj [anyValueSize(8)] interface IInterface diff --git a/tests/compute/dynamic-dispatch-2.slang b/tests/compute/dynamic-dispatch-2.slang index 06aed6e71c9..3c83ee2a0f8 100644 --- a/tests/compute/dynamic-dispatch-2.slang +++ b/tests/compute/dynamic-dispatch-2.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for static member functions // of associated type. diff --git a/tests/compute/dynamic-dispatch-3.slang b/tests/compute/dynamic-dispatch-3.slang index 94430e638ab..2f8ca6504d3 100644 --- a/tests/compute/dynamic-dispatch-3.slang +++ b/tests/compute/dynamic-dispatch-3.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for static member functions // of associated type. diff --git a/tests/compute/dynamic-dispatch-4.slang b/tests/compute/dynamic-dispatch-4.slang index 996e706655b..6cb1c50d083 100644 --- a/tests/compute/dynamic-dispatch-4.slang +++ b/tests/compute/dynamic-dispatch-4.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for generic-typed local variables. diff --git a/tests/compute/dynamic-dispatch-5.slang b/tests/compute/dynamic-dispatch-5.slang index 61d4024af5a..4a7e12f9e5b 100644 --- a/tests/compute/dynamic-dispatch-5.slang +++ b/tests/compute/dynamic-dispatch-5.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for general `This` type. [anyValueSize(8)] diff --git a/tests/compute/dynamic-dispatch-6.slang b/tests/compute/dynamic-dispatch-6.slang index 69edb51bc56..6364274d57c 100644 --- a/tests/compute/dynamic-dispatch-6.slang +++ b/tests/compute/dynamic-dispatch-6.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for generic-typed return values. [anyValueSize(8)] diff --git a/tests/compute/dynamic-dispatch-7.slang b/tests/compute/dynamic-dispatch-7.slang index 7f516b0d556..1e5d7a639e0 100644 --- a/tests/compute/dynamic-dispatch-7.slang +++ b/tests/compute/dynamic-dispatch-7.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for associated-typed return values // and local variables. diff --git a/tests/compute/dynamic-dispatch-8.slang b/tests/compute/dynamic-dispatch-8.slang index 62d3f0a9d41..77faf771928 100644 --- a/tests/compute/dynamic-dispatch-8.slang +++ b/tests/compute/dynamic-dispatch-8.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for extential type parameters. diff --git a/tests/compute/dynamic-dispatch-9.slang b/tests/compute/dynamic-dispatch-9.slang index 77112a318b7..de213dc4224 100644 --- a/tests/compute/dynamic-dispatch-9.slang +++ b/tests/compute/dynamic-dispatch-9.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE:-dx11 -shaderobj //TEST(compute):COMPARE_COMPUTE:-vk -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -shaderobj -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cuda -shaderobj // Test dynamic dispatch code gen for initializing an extential value // from a generic value. diff --git a/tests/compute/dynamic-generics-simple.slang b/tests/compute/dynamic-generics-simple.slang index c3db4542041..ecf18486927 100644 --- a/tests/compute/dynamic-generics-simple.slang +++ b/tests/compute/dynamic-generics-simple.slang @@ -1,5 +1,5 @@ -//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -disable-specialization -//TEST(compute):COMPARE_COMPUTE:-cuda -xslang -disable-specialization +//TEST(compute):COMPARE_COMPUTE:-cpu +//TEST(compute):COMPARE_COMPUTE:-cuda // Test basic dynamic dispatch code gen diff --git a/tests/diagnostics/interfaces/anyvalue-size-validation.slang b/tests/diagnostics/interfaces/anyvalue-size-validation.slang index ffe968c678c..213c2718993 100644 --- a/tests/diagnostics/interfaces/anyvalue-size-validation.slang +++ b/tests/diagnostics/interfaces/anyvalue-size-validation.slang @@ -1,6 +1,6 @@ // anyvalue-size-validation.slang -//DIAGNOSTIC_TEST:SIMPLE:-target cpp -stage compute -entry main -disable-specialization +//TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute -conformance "A:IInterface=0" -conformance "B:IInterface=1" [anyValueSize(8)] interface IInterface @@ -8,7 +8,9 @@ interface IInterface int doSomething(); }; -struct S : IInterface +//CHECK: ([[# @LINE+2]]): error 41011: type 'A' does not fit in the size required by its conforming interface. +//CHECK-NEXT: struct A : IInterface +struct A : IInterface { uint a; uint b; @@ -16,16 +18,27 @@ struct S : IInterface int doSomething() { return 5; } }; -T test(T s) +//CHECK: note: sizeof(A) is 12, limit is 8 + +struct B : IInterface { - return s; -} + int doSomething() { return 6; } +}; + +//TEST_INPUT: +//TEST_INPUT: type_conformance B:IInterface = 1 -RWStructuredBuffer output; +RWStructuredBuffer data; +RWStructuredBuffer output; + +struct Data +{ + float x; + float y; +}; [numthreads(4, 1, 1)] -void main() +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - S s = S(1, 2, 3); - output[0] = test(s).a; + output[1] = data[0].doSomething(); } diff --git a/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected b/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected index e9465167114..a95188a0bb6 100644 --- a/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected +++ b/tests/diagnostics/interfaces/anyvalue-size-validation.slang.expected @@ -1,9 +1,11 @@ -result code = -1 +result code = 1 standard error = { -tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): error 41011: type 'S' does not fit in the size required by its conforming interface. -struct S : IInterface - ^ -tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): note: sizeof(S) is 12, limit is 8 } standard output = { } +debug layer = { +tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): error 41011: type 'A' does not fit in the size required by its conforming interface. +struct A : IInterface + ^ +tests/diagnostics/interfaces/anyvalue-size-validation.slang(11): note: sizeof(A) is 12, limit is 8 +} diff --git a/tests/diagnostics/interfaces/interface-extension.slang b/tests/diagnostics/interfaces/interface-extension.slang index b63b454abdc..7c42b3874e6 100644 --- a/tests/diagnostics/interfaces/interface-extension.slang +++ b/tests/diagnostics/interfaces/interface-extension.slang @@ -1,4 +1,4 @@ -//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):-target cpp -stage compute -entry main -disable-specialization +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):-target cpp -stage compute -entry main interface IFoo{} diff --git a/tests/diagnostics/no-type-conformance.slang b/tests/diagnostics/no-type-conformance.slang index ada4f663a75..05ddc61db4d 100644 --- a/tests/diagnostics/no-type-conformance.slang +++ b/tests/diagnostics/no-type-conformance.slang @@ -1,6 +1,6 @@ -//DIAGNOSTIC_TEST:COMMAND_LINE_SIMPLE:-target hlsl -entry computeMain -stage compute -o no-type-conformance.hlsl -// no type conformance linked +//TEST:SIMPLE(filecheck=CHECK):-target hlsl -entry computeMain -stage compute -o no-type-conformance.hlsl +//CHECK: error 50100: No type conformances are found for interface 'IFoo' interface IFoo { float get(); From c92ed7ecf0783f116ee90d5163c65c9586c8372e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 30 Oct 2025 16:47:47 -0400 Subject: [PATCH 085/105] Remove the old generics pass completely --- source/slang/slang-emit.cpp | 18 +- source/slang/slang-ir-any-value-inference.cpp | 1 - .../slang/slang-ir-any-value-marshalling.cpp | 79 +-- source/slang/slang-ir-any-value-marshalling.h | 4 +- .../slang-ir-generics-lowering-context.cpp | 443 ----------------- .../slang-ir-generics-lowering-context.h | 171 ------- source/slang/slang-ir-layout.cpp | 6 +- source/slang/slang-ir-loop-inversion.cpp | 1 - source/slang/slang-ir-lower-existential.cpp | 315 ------------ source/slang/slang-ir-lower-existential.h | 11 - source/slang/slang-ir-lower-generic-call.cpp | 404 ---------------- source/slang/slang-ir-lower-generic-call.h | 12 - .../slang/slang-ir-lower-generic-function.cpp | 449 ------------------ .../slang/slang-ir-lower-generic-function.h | 19 - source/slang/slang-ir-lower-generic-type.cpp | 93 ---- source/slang/slang-ir-lower-generic-type.h | 12 - source/slang/slang-ir-lower-generics.cpp | 289 ----------- source/slang/slang-ir-lower-generics.h | 19 - .../slang/slang-ir-lower-typeflow-insts.cpp | 208 +++++++- .../slang/slang-ir-lower-witness-lookup.cpp | 446 ----------------- source/slang/slang-ir-lower-witness-lookup.h | 15 - source/slang/slang-ir-specialize-dispatch.cpp | 328 ------------- source/slang/slang-ir-specialize-dispatch.h | 13 - ...ecialize-dynamic-associatedtype-lookup.cpp | 272 ----------- ...specialize-dynamic-associatedtype-lookup.h | 15 - source/slang/slang-ir-specialize.cpp | 1 - source/slang/slang-ir-typeflow-specialize.cpp | 31 +- source/slang/slang-ir-util.cpp | 19 + source/slang/slang-ir-util.h | 7 + .../slang/slang-ir-witness-table-wrapper.cpp | 341 ------------- source/slang/slang-ir-witness-table-wrapper.h | 30 -- .../dynamic-dispatch/generic-method.slang | 2 +- 32 files changed, 297 insertions(+), 3777 deletions(-) delete mode 100644 source/slang/slang-ir-generics-lowering-context.cpp delete mode 100644 source/slang/slang-ir-generics-lowering-context.h delete mode 100644 source/slang/slang-ir-lower-existential.cpp delete mode 100644 source/slang/slang-ir-lower-existential.h delete mode 100644 source/slang/slang-ir-lower-generic-call.cpp delete mode 100644 source/slang/slang-ir-lower-generic-call.h delete mode 100644 source/slang/slang-ir-lower-generic-function.cpp delete mode 100644 source/slang/slang-ir-lower-generic-function.h delete mode 100644 source/slang/slang-ir-lower-generic-type.cpp delete mode 100644 source/slang/slang-ir-lower-generic-type.h delete mode 100644 source/slang/slang-ir-lower-generics.cpp delete mode 100644 source/slang/slang-ir-lower-generics.h delete mode 100644 source/slang/slang-ir-lower-witness-lookup.cpp delete mode 100644 source/slang/slang-ir-lower-witness-lookup.h delete mode 100644 source/slang/slang-ir-specialize-dispatch.cpp delete mode 100644 source/slang/slang-ir-specialize-dispatch.h delete mode 100644 source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp delete mode 100644 source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h delete mode 100644 source/slang/slang-ir-witness-table-wrapper.cpp delete mode 100644 source/slang/slang-ir-witness-table-wrapper.h diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 764d44c29a0..028cc6af627 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -22,6 +22,7 @@ #include "slang-emit-vm.h" #include "slang-emit-wgsl.h" #include "slang-ir-any-value-inference.h" +#include "slang-ir-any-value-marshalling.h" #include "slang-ir-autodiff.h" #include "slang-ir-bind-existentials.h" #include "slang-ir-byte-address-legalize.h" @@ -79,7 +80,6 @@ #include "slang-ir-lower-coopvec.h" #include "slang-ir-lower-dynamic-resource-heap.h" #include "slang-ir-lower-enum-type.h" -#include "slang-ir-lower-generics.h" #include "slang-ir-lower-glsl-ssbo-types.h" #include "slang-ir-lower-l-value-cast.h" #include "slang-ir-lower-optional-type.h" @@ -1199,15 +1199,13 @@ Result linkAndOptimizeIR( SLANG_RETURN_ON_FAIL(checkGetStringHashInsts(irModule, sink)); } - // For targets that supports dynamic dispatch, we need to lower the - // generics / interface types to ordinary functions and types using - // function pointers. - dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-LOWER-GENERICS"); - if (requiredLoweringPassSet.generics) - lowerGenerics(targetProgram, irModule, sink); - else - cleanupGenerics(targetProgram, irModule, sink); - dumpIRIfEnabled(codeGenContext, irModule, "AFTER-LOWER-GENERICS"); + lowerTuples(irModule, sink); + if (sink->getErrorCount() != 0) + return SLANG_FAIL; + + generateAnyValueMarshallingFunctions(irModule); + if (sink->getErrorCount() != 0) + return SLANG_FAIL; // Don't need to run any further target-dependent passes if we are generating code // for host vm. diff --git a/source/slang/slang-ir-any-value-inference.cpp b/source/slang/slang-ir-any-value-inference.cpp index 831d20e49a5..00190c2b530 100644 --- a/source/slang/slang-ir-any-value-inference.cpp +++ b/source/slang/slang-ir-any-value-inference.cpp @@ -1,7 +1,6 @@ #include "slang-ir-any-value-inference.h" #include "../core/slang-func-ptr.h" -#include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" #include "slang-ir-layout.h" #include "slang-ir-util.h" diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 99f5698ddb0..dde2a4d3048 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -1,8 +1,8 @@ #include "slang-ir-any-value-marshalling.h" #include "../core/slang-math.h" -#include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" #include "slang-ir.h" #include "slang-legalize-types.h" @@ -14,7 +14,36 @@ namespace Slang // functions. struct AnyValueMarshallingContext { - SharedGenericsLoweringContext* sharedContext; + IRModule* module; + + // We will use a single work list of instructions that need + // to be considered for lowering. + // + InstWorkList workList; + InstHashSet workListSet; + + AnyValueMarshallingContext(IRModule* module) + : module(module), workList(module), workListSet(module) + { + } + + void addToWorkList(IRInst* inst) + { + if (!inst) + return; + + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as(ii)) + return; + } + + if (workListSet.contains(inst)) + return; + + workList.add(inst); + workListSet.add(inst); + } // Stores information about generated `AnyValue` struct types. struct AnyValueTypeInfo : RefObject @@ -54,7 +83,7 @@ struct AnyValueMarshallingContext if (auto typeInfo = generatedAnyValueTypes.tryGetValue(size)) return typeInfo->Ptr(); RefPtr info = new AnyValueTypeInfo(); - IRBuilder builder(sharedContext->module); + IRBuilder builder(module); builder.setInsertBefore(type); auto structType = builder.createStructType(); info->type = structType; @@ -498,7 +527,7 @@ struct AnyValueMarshallingContext IRFunc* generatePackingFunc(IRType* type, IRAnyValueType* anyValueType) { - IRBuilder builder(sharedContext->module); + IRBuilder builder(module); builder.setInsertBefore(type); auto anyValInfo = ensureAnyValueType(anyValueType); @@ -767,7 +796,7 @@ struct AnyValueMarshallingContext IRFunc* generateUnpackingFunc(IRType* type, IRAnyValueType* anyValueType) { - IRBuilder builder(sharedContext->module); + IRBuilder builder(module); builder.setInsertBefore(type); auto anyValInfo = ensureAnyValueType(anyValueType); @@ -822,7 +851,7 @@ struct AnyValueMarshallingContext auto func = ensureMarshallingFunc( operand->getDataType(), cast(packInst->getDataType())); - IRBuilder builderStorage(sharedContext->module); + IRBuilder builderStorage(module); auto builder = &builderStorage; builder->setInsertBefore(packInst); auto callInst = builder->emitCallInst(packInst->getDataType(), func.packFunc, 1, &operand); @@ -836,7 +865,7 @@ struct AnyValueMarshallingContext auto func = ensureMarshallingFunc( unpackInst->getDataType(), cast(operand->getDataType())); - IRBuilder builderStorage(sharedContext->module); + IRBuilder builderStorage(module); auto builder = &builderStorage; builder->setInsertBefore(unpackInst); auto callInst = @@ -869,37 +898,35 @@ struct AnyValueMarshallingContext // since we will re-use that state for any code we // generate along the way. // - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + addToWorkList(module->getModuleInst()); - while (sharedContext->workList.getCount() != 0) + while (workList.getCount() != 0) { - IRInst* inst = sharedContext->workList.getLast(); + IRInst* inst = workList.getLast(); - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); + workList.removeLast(); + workListSet.remove(inst); processInst(inst); for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) { - sharedContext->addToWorkList(child); + addToWorkList(child); } } // Finally, replace all `AnyValueType` with the actual struct type that implements it. - for (auto inst : sharedContext->module->getModuleInst()->getChildren()) + for (auto inst : module->getModuleInst()->getChildren()) { if (auto anyValueType = as(inst)) processAnyValueType(anyValueType); } - sharedContext->mapInterfaceRequirementKeyValue.clear(); } }; -void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext) +void generateAnyValueMarshallingFunctions(IRModule* module) { - AnyValueMarshallingContext context; - context.sharedContext = sharedContext; + AnyValueMarshallingContext context(module); context.processModule(); } @@ -1025,9 +1052,7 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) case kIROp_InterfaceType: { auto interfaceType = cast(type); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( - interfaceType, - interfaceType->sourceLoc); + auto size = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); size += kRTTIHeaderSize; return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } @@ -1045,18 +1070,14 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) { auto thisType = cast(type); auto interfaceType = thisType->getConstraintType(); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( - interfaceType, - interfaceType->sourceLoc); + auto size = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } case kIROp_ExtractExistentialType: { auto existentialValue = type->getOperand(0); auto interfaceType = cast(existentialValue->getDataType()); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( - interfaceType, - interfaceType->sourceLoc); + auto size = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } case kIROp_LookupWitnessMethod: @@ -1091,9 +1112,7 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset) { anyValueSize = Math::Min( anyValueSize, - SharedGenericsLoweringContext::getInterfaceAnyValueSize( - assocType->getOperand(i), - type->sourceLoc)); + getInterfaceAnyValueSize(assocType->getOperand(i), type->sourceLoc)); } if (anyValueSize == kInvalidAnyValueSize) diff --git a/source/slang/slang-ir-any-value-marshalling.h b/source/slang/slang-ir-any-value-marshalling.h index 941dae5e8c2..bfebf1b0123 100644 --- a/source/slang/slang-ir-any-value-marshalling.h +++ b/source/slang/slang-ir-any-value-marshalling.h @@ -6,13 +6,13 @@ namespace Slang { struct IRType; -struct SharedGenericsLoweringContext; +struct IRModule; /// Generates functions that pack and unpack `AnyValue`s, and replaces /// all `IRPackAnyValue` and `IRUnpackAnyValue` instructions with calls /// to these packing/unpacking functions. /// This is a sub-pass of lower-generics. -void generateAnyValueMarshallingFunctions(SharedGenericsLoweringContext* sharedContext); +void generateAnyValueMarshallingFunctions(IRModule* module); /// Get the AnyValue size required to hold a value of `type`. diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp deleted file mode 100644 index 0dee5631a93..00000000000 --- a/source/slang/slang-ir-generics-lowering-context.cpp +++ /dev/null @@ -1,443 +0,0 @@ -// slang-ir-generics-lowering-context.cpp - -#include "slang-ir-generics-lowering-context.h" - -#include "slang-ir-layout.h" -#include "slang-ir-util.h" - -namespace Slang -{ -bool isPolymorphicType(IRInst* typeInst) -{ - if (as(typeInst) && as(typeInst->getDataType())) - return true; - switch (typeInst->getOp()) - { - case kIROp_ThisType: - case kIROp_AssociatedType: - case kIROp_InterfaceType: - case kIROp_LookupWitnessMethod: - return true; - case kIROp_Specialize: - { - for (UInt i = 0; i < typeInst->getOperandCount(); i++) - { - if (isPolymorphicType(typeInst->getOperand(i))) - return true; - } - return false; - } - default: - break; - } - if (auto ptrType = as(typeInst)) - { - return isPolymorphicType(ptrType->getValueType()); - } - return false; -} - -bool isTypeValue(IRInst* typeInst) -{ - if (typeInst) - { - switch (typeInst->getOp()) - { - case kIROp_TypeType: - case kIROp_TypeKind: - return true; - default: - return false; - } - } - return false; -} - -IRInst* SharedGenericsLoweringContext::maybeEmitRTTIObject(IRInst* typeInst) -{ - IRInst* result = nullptr; - if (mapTypeToRTTIObject.tryGetValue(typeInst, result)) - return result; - IRBuilder builderStorage(module); - auto builder = &builderStorage; - builder->setInsertAfter(typeInst); - - result = builder->emitMakeRTTIObject(typeInst); - - // For now the only type info we encapsualte is type size. - IRSizeAndAlignment sizeAndAlignment; - getNaturalSizeAndAlignment(targetProgram->getOptionSet(), (IRType*)typeInst, &sizeAndAlignment); - builder->addRTTITypeSizeDecoration(result, sizeAndAlignment.size); - - // Give a name to the rtti object. - if (auto exportDecoration = typeInst->findDecoration()) - { - String rttiObjName = exportDecoration->getMangledName(); - builder->addExportDecoration(result, rttiObjName.getUnownedSlice()); - } - - // Make sure the RTTI object for an exported struct type is marked as export if the type is. - if (typeInst->findDecoration()) - { - builder->addHLSLExportDecoration(result); - builder->addKeepAliveDecoration(result); - } - mapTypeToRTTIObject[typeInst] = result; - return result; -} - -IRInst* SharedGenericsLoweringContext::findInterfaceRequirementVal( - IRInterfaceType* interfaceType, - IRInst* requirementKey) -{ - if (auto dict = mapInterfaceRequirementKeyValue.tryGetValue(interfaceType)) - return dict->getValue(requirementKey); - _builldInterfaceRequirementMap(interfaceType); - return findInterfaceRequirementVal(interfaceType, requirementKey); -} - -void SharedGenericsLoweringContext::_builldInterfaceRequirementMap(IRInterfaceType* interfaceType) -{ - mapInterfaceRequirementKeyValue.add(interfaceType, Dictionary()); - auto dict = mapInterfaceRequirementKeyValue.tryGetValue(interfaceType); - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - auto entry = cast(interfaceType->getOperand(i)); - (*dict)[entry->getRequirementKey()] = entry->getRequirementVal(); - } -} - -IRType* SharedGenericsLoweringContext::lowerAssociatedType(IRBuilder* builder, IRInst* type) -{ - if (type->getOp() != kIROp_AssociatedType) - return (IRType*)type; - IRIntegerValue anyValueSize = kInvalidAnyValueSize; - for (UInt i = 0; i < type->getOperandCount(); i++) - { - anyValueSize = - Math::Min(anyValueSize, getInterfaceAnyValueSize(type->getOperand(i), type->sourceLoc)); - } - if (anyValueSize == kInvalidAnyValueSize) - { - // We could conceivably make it an error to have an associated type - // without an `[anyValueSize(...)]` attribute, but then we risk - // producing error messages even when doing 100% static specialization. - // - // It is simpler to use a reasonable default size and treat any - // type without an explicit attribute as using that size. - // - anyValueSize = kDefaultAnyValueSize; - } - return builder->getAnyValueType(anyValueSize); -} - -IRType* SharedGenericsLoweringContext::lowerType( - IRBuilder* builder, - IRInst* paramType, - const Dictionary& typeMapping, - IRType* concreteType) -{ - if (!paramType) - return nullptr; - - IRInst* resultType; - if (typeMapping.tryGetValue(paramType, resultType)) - return (IRType*)resultType; - - if (isTypeValue(paramType)) - { - return builder->getRTTIHandleType(); - } - - switch (paramType->getOp()) - { - case kIROp_WitnessTableType: - case kIROp_WitnessTableIDType: - case kIROp_ExtractExistentialType: - // Do not translate these types. - return (IRType*)paramType; - case kIROp_Param: - { - if (auto anyValueSizeDecor = paramType->findDecoration()) - { - if (isBuiltin(anyValueSizeDecor->getConstraintType())) - return (IRType*)paramType; - auto anyValueSize = getInterfaceAnyValueSize( - anyValueSizeDecor->getConstraintType(), - paramType->sourceLoc); - return builder->getAnyValueType(anyValueSize); - } - // We could conceivably make it an error to have a generic parameter - // without an `[anyValueSize(...)]` attribute, but then we risk - // producing error messages even when doing 100% static specialization. - // - // It is simpler to use a reasonable default size and treat any - // type without an explicit attribute as using that size. - // - return builder->getAnyValueType(kDefaultAnyValueSize); - } - case kIROp_ThisType: - { - auto interfaceType = cast(paramType)->getConstraintType(); - - if (isBuiltin(interfaceType)) - return (IRType*)paramType; - - if (isComInterfaceType((IRType*)interfaceType)) - return (IRType*)interfaceType; - - auto anyValueSize = getInterfaceAnyValueSize( - cast(paramType)->getConstraintType(), - paramType->sourceLoc); - return builder->getAnyValueType(anyValueSize); - } - case kIROp_AssociatedType: - { - return lowerAssociatedType(builder, paramType); - } - case kIROp_InterfaceType: - { - if (isBuiltin(paramType)) - return (IRType*)paramType; - - if (isComInterfaceType((IRType*)paramType)) - return (IRType*)paramType; - - // In the dynamic-dispatch case, a value of interface type - // is going to be packed into the "any value" part of a tuple. - // The size of the "any value" part depends on the interface - // type (e.g., it might have an `[anyValueSize(8)]` attribute - // indicating that 8 bytes needs to be reserved). - // - auto anyValueSize = getInterfaceAnyValueSize(paramType, paramType->sourceLoc); - - // If there is a non-null `concreteType` parameter, then this - // interface type is one that has been statically bound (via - // specialization parameters) to hold a value of that concrete - // type. - // - IRType* pendingType = nullptr; - if (concreteType) - { - // Because static specialization is being used (at least in part), - // we do *not* have a guarantee that the `concreteType` is one - // that can fit into the `anyValueSize` of the interface. - // - // We will use the IR layout logic to see if we can compute - // a size for the type, which can lead to a few different outcomes: - // - // * If a size is computed successfully, and it is smaller than or - // equal to `anyValueSize`, then the concrete value will fit into - // the reserved area, and the layout will match the dynamic case. - // - // * If a size is computed successfully, and it is larger than - // `anyValueSize`, then the concrete value cannot fit into the - // reserved area, and it needs to be stored out-of-line. - // - // * If size cannot be computed, then that implies that the type - // includes non-ordinary data (e.g., a `Texture2D` on a D3D11 - // target), and cannot possible fit into the reserved area - // (which consists of only uniform bytes). In this case, the - // value must be stored out-of-line. - // - IRSizeAndAlignment sizeAndAlignment; - Result result = getNaturalSizeAndAlignment( - targetProgram->getOptionSet(), - concreteType, - &sizeAndAlignment); - if (SLANG_FAILED(result) || (sizeAndAlignment.size > anyValueSize)) - { - // If the value must be stored out-of-line, we construct - // a "pseudo pointer" to the concrete type, and the - // constructed tuple will contain such a pseudo pointer. - // - // Semantically, the pseudo pointer behaves a bit like - // a pointer to the concrete type, in that it can be - // (pseudo-)dereferenced to produce a value of the chosen - // type. - // - // In terms of layout, the pseudo pointer occupies no - // space in the parent tuple/type, and will be automatically - // moved out-of-line by a later type legalization pass. - // - pendingType = builder->getPseudoPtrType(concreteType); - } - } - - auto anyValueType = builder->getAnyValueType(anyValueSize); - auto witnessTableType = builder->getWitnessTableIDType((IRType*)paramType); - auto rttiType = builder->getRTTIHandleType(); - - IRType* tupleType = nullptr; - if (!pendingType) - { - // In the oridnary (dynamic) case, an existential type decomposes - // into a tuple of: - // - // (RTTI, witness table, any-value). - // - tupleType = builder->getTupleType(rttiType, witnessTableType, anyValueType); - } - else - { - // In the case where static specialization mandateds out-of-line storage, - // an existential type decomposes into a tuple of: - // - // (RTTI, witness table, pseudo pointer, any-value) - // - tupleType = - builder->getTupleType(rttiType, witnessTableType, pendingType, anyValueType); - // - // Note that in each of the cases, the third element of the tuple - // is a representation of the value being stored in the existential. - // - // Also note that each of these representations has the same - // size and alignment when only "ordinary" data is considered - // (the pseudo-pointer will eventually be legalized away, leaving - // behind a tuple with equivalent layout). - } - - return tupleType; - } - case kIROp_LookupWitnessMethod: - { - auto lookupInterface = static_cast(paramType); - auto witnessTableType = - as(lookupInterface->getWitnessTable()->getDataType()); - if (!witnessTableType) - return (IRType*)paramType; - auto interfaceType = as(witnessTableType->getConformanceType()); - if (!interfaceType || isBuiltin(interfaceType)) - return (IRType*)paramType; - // Make sure we are looking up inside the original interface type (prior to lowering). - // Only in the original interface type will an associated type entry have an - // IRAssociatedType value. We need to extract AnyValueSize from this IRAssociatedType. - // In lowered interface type, that entry is lowered into an Ptr(RTTIType) and this info - // is lost. - mapLoweredInterfaceToOriginal.tryGetValue(interfaceType, interfaceType); - auto reqVal = - findInterfaceRequirementVal(interfaceType, lookupInterface->getRequirementKey()); - SLANG_ASSERT(reqVal && reqVal->getOp() == kIROp_AssociatedType); - return lowerType(builder, reqVal, typeMapping, nullptr); - } - case kIROp_BoundInterfaceType: - { - // A bound interface type represents an existential together with - // static knowledge that the value stored in the extistential has - // a particular concrete type. - // - // We handle this case by lowering the underlying interface type, - // but pass along the concrete type so that it can impact the - // layout of the interface type. - // - auto boundInterfaceType = static_cast(paramType); - return lowerType( - builder, - boundInterfaceType->getInterfaceType(), - typeMapping, - boundInterfaceType->getConcreteType()); - } - default: - { - bool translated = false; - List loweredOperands; - for (UInt i = 0; i < paramType->getOperandCount(); i++) - { - loweredOperands.add( - lowerType(builder, paramType->getOperand(i), typeMapping, nullptr)); - if (loweredOperands.getLast() != paramType->getOperand(i)) - translated = true; - } - if (translated) - return builder->getType( - paramType->getOp(), - loweredOperands.getCount(), - loweredOperands.getBuffer()); - return (IRType*)paramType; - } - } -} - -List getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType) -{ - List witnessTables; - for (auto globalInst : module->getGlobalInsts()) - { - if (globalInst->getOp() == kIROp_WitnessTable && - cast(globalInst->getDataType())->getConformanceType() == - interfaceType) - { - witnessTables.add(cast(globalInst)); - } - } - return witnessTables; -} - -List SharedGenericsLoweringContext::getWitnessTablesFromInterfaceType( - IRInst* interfaceType) -{ - return Slang::getWitnessTablesFromInterfaceType(module, interfaceType); -} - -IRIntegerValue SharedGenericsLoweringContext::getInterfaceAnyValueSize( - IRInst* type, - SourceLoc usageLoc) -{ - SLANG_UNUSED(usageLoc); - - if (auto decor = type->findDecoration()) - { - return decor->getSize(); - } - - // We could conceivably make it an error to have an interface - // without an `[anyValueSize(...)]` attribute, but then we risk - // producing error messages even when doing 100% static specialization. - // - // It is simpler to use a reasonable default size and treat any - // type without an explicit attribute as using that size. - // - return kDefaultAnyValueSize; -} - - -bool SharedGenericsLoweringContext::doesTypeFitInAnyValue( - IRType* concreteType, - IRInterfaceType* interfaceType, - IRIntegerValue* outTypeSize, - IRIntegerValue* outLimit, - bool* outIsTypeOpaque) -{ - auto anyValueSize = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); - if (outLimit) - *outLimit = anyValueSize; - - if (!areResourceTypesBindlessOnTarget(targetProgram->getTargetReq())) - { - IRType* opaqueType = nullptr; - if (isOpaqueType(concreteType, &opaqueType)) - { - if (outIsTypeOpaque) - *outIsTypeOpaque = true; - return false; - } - } - IRSizeAndAlignment sizeAndAlignment; - Result result = - getNaturalSizeAndAlignment(targetProgram->getOptionSet(), concreteType, &sizeAndAlignment); - if (outTypeSize) - *outTypeSize = sizeAndAlignment.size; - - if (SLANG_FAILED(result) || (sizeAndAlignment.size > anyValueSize)) - { - // The value does not fit, either because it is too large, - // or because it includes types that cannot be stored - // in uniform/ordinary memory for this target. - // - return false; - } - - return true; -} - -} // namespace Slang diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h deleted file mode 100644 index 62848c4b7dd..00000000000 --- a/source/slang/slang-ir-generics-lowering-context.h +++ /dev/null @@ -1,171 +0,0 @@ -// slang-ir-generics-lowering-context.h -#pragma once - -#include "slang-ir-dce.h" -#include "slang-ir-insts.h" -#include "slang-ir-lower-generics.h" -#include "slang-ir.h" - -namespace Slang -{ -struct IRModule; - -constexpr IRIntegerValue kInvalidAnyValueSize = 0xFFFFFFFF; -constexpr IRIntegerValue kDefaultAnyValueSize = 16; -constexpr SlangInt kRTTIHeaderSize = 16; -constexpr SlangInt kRTTIHandleSize = 8; - -struct SharedGenericsLoweringContext -{ - // For convenience, we will keep a pointer to the module - // we are processing. - IRModule* module; - - TargetProgram* targetProgram; - - DiagnosticSink* sink; - - // RTTI objects for each type used to call a generic function. - OrderedDictionary mapTypeToRTTIObject; - - Dictionary loweredGenericFunctions; - Dictionary loweredInterfaceTypes; - Dictionary mapLoweredInterfaceToOriginal; - - // Dictionaries for interface type requirement key-value lookups. - // Used by `findInterfaceRequirementVal`. - Dictionary> mapInterfaceRequirementKeyValue; - - // Map from interface requirement keys to its corresponding dispatch method. - OrderedDictionary mapInterfaceRequirementKeyToDispatchMethods; - - // We will use a single work list of instructions that need - // to be considered for lowering. - // - InstWorkList workList; - InstHashSet workListSet; - - SharedGenericsLoweringContext(IRModule* inModule) - : module(inModule), workList(inModule), workListSet(inModule) - { - } - - void addToWorkList(IRInst* inst) - { - if (!inst) - return; - - for (auto ii = inst->getParent(); ii; ii = ii->getParent()) - { - if (as(ii)) - return; - } - - if (workListSet.contains(inst)) - return; - - workList.add(inst); - workListSet.add(inst); - } - - - void _builldInterfaceRequirementMap(IRInterfaceType* interfaceType); - - IRInst* findInterfaceRequirementVal(IRInterfaceType* interfaceType, IRInst* requirementKey); - - // Emits an IRRTTIObject containing type information for a given type. - IRInst* maybeEmitRTTIObject(IRInst* typeInst); - - static IRIntegerValue getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc); - static IRType* lowerAssociatedType(IRBuilder* builder, IRInst* type); - - IRType* lowerType( - IRBuilder* builder, - IRInst* paramType, - const Dictionary& typeMapping, - IRType* concreteType); - - IRType* lowerType(IRBuilder* builder, IRInst* paramType) - { - return lowerType(builder, paramType, Dictionary(), nullptr); - } - - // Get a list of all witness tables whose conformance type is `interfaceType`. - List getWitnessTablesFromInterfaceType(IRInst* interfaceType); - - /// Does the given `concreteType` fit within the any-value size deterined by `interfaceType`? - bool doesTypeFitInAnyValue( - IRType* concreteType, - IRInterfaceType* interfaceType, - IRIntegerValue* outTypeSize = nullptr, - IRIntegerValue* outLimit = nullptr, - bool* outIsTypeOpaque = nullptr); -}; - -List getWitnessTablesFromInterfaceType(IRModule* module, IRInst* interfaceType); - -bool isPolymorphicType(IRInst* typeInst); - -// Returns true if typeInst represents a type and should be lowered into -// Ptr(RTTIType). -bool isTypeValue(IRInst* typeInst); - -template -void workOnModule(SharedGenericsLoweringContext* sharedContext, const TFunc& func) -{ - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - func(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } -} - -template -void workOnCallGraph(SharedGenericsLoweringContext* sharedContext, const TFunc& func) -{ - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - IRDeadCodeEliminationOptions dceOptions; - dceOptions.keepExportsAlive = true; - dceOptions.keepLayoutsAlive = true; - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - - sharedContext->addToWorkList(inst->parent); - sharedContext->addToWorkList(inst->getFullType()); - - UInt operandCount = inst->getOperandCount(); - for (UInt ii = 0; ii < operandCount; ++ii) - { - if (!isWeakReferenceOperand(inst, ii)) - sharedContext->addToWorkList(inst->getOperand(ii)); - } - - if (auto call = as(inst)) - { - if (func(call)) - return; - } - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - if (shouldInstBeLiveIfParentIsLive(child, dceOptions)) - sharedContext->addToWorkList(child); - } - } -} -} // namespace Slang diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index f282be81a48..614ca2baa63 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -1,8 +1,8 @@ // slang-ir-layout.cpp #include "slang-ir-layout.h" -#include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" // This file implements facilities for computing and caching layout // information on IR types. @@ -310,9 +310,7 @@ Result IRTypeLayoutRules::calcSizeAndAlignment( case kIROp_InterfaceType: { auto interfaceType = cast(type); - auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize( - interfaceType, - interfaceType->sourceLoc); + auto size = getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); size += kRTTIHeaderSize; size = align(size, 4); IRSizeAndAlignment resultLayout; diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp index e7c24c511e1..2f255d10432 100644 --- a/source/slang/slang-ir-loop-inversion.cpp +++ b/source/slang/slang-ir-loop-inversion.cpp @@ -3,7 +3,6 @@ #include "slang-ir-clone.h" #include "slang-ir-dominators.h" #include "slang-ir-insts.h" -#include "slang-ir-lower-witness-lookup.h" #include "slang-ir-reachability.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-util.h" diff --git a/source/slang/slang-ir-lower-existential.cpp b/source/slang/slang-ir-lower-existential.cpp deleted file mode 100644 index a3ed369837b..00000000000 --- a/source/slang/slang-ir-lower-existential.cpp +++ /dev/null @@ -1,315 +0,0 @@ -// slang-ir-lower-generic-existential.cpp - -#include "slang-ir-lower-existential.h" - -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" -/* -namespace Slang -{ -bool isCPUTarget(TargetRequest* targetReq); -bool isCUDATarget(TargetRequest* targetReq); - -struct ExistentialLoweringContext -{ - SharedGenericsLoweringContext* sharedContext; - - void processMakeExistential(IRMakeExistentialWithRTTI* inst) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(inst); - auto value = inst->getWrappedValue(); - auto valueType = sharedContext->lowerType(builder, value->getDataType()); - if (valueType->getOp() == kIROp_ComPtrType) - return; - auto witnessTableType = - cast(inst->getWitnessTable()->getDataType()); - auto interfaceType = witnessTableType->getConformanceType(); - if (interfaceType->findDecoration()) - return; - auto witnessTableIdType = builder->getWitnessTableIDType((IRType*)interfaceType); - auto anyValueSize = sharedContext->getInterfaceAnyValueSize(interfaceType, inst->sourceLoc); - auto anyValueType = builder->getAnyValueType(anyValueSize); - auto rttiType = builder->getRTTIHandleType(); - auto tupleType = builder->getTupleType(rttiType, witnessTableIdType, anyValueType); - - IRInst* rttiObject = inst->getRTTI(); - if (auto type = as(rttiObject)) - { - rttiObject = sharedContext->maybeEmitRTTIObject(type); - rttiObject = builder->emitGetAddress(rttiType, rttiObject); - } - IRInst* packedValue = value; - if (valueType->getOp() != kIROp_AnyValueType) - packedValue = builder->emitPackAnyValue(anyValueType, value); - IRInst* tupleArgs[] = {rttiObject, inst->getWitnessTable(), packedValue}; - auto tuple = builder->emitMakeTuple(tupleType, 3, tupleArgs); - inst->replaceUsesWith(tuple); - inst->removeAndDeallocate(); - } - - // Translates `createExistentialObject` insts, which takes a user defined - // type id and user defined value and turns into an existential value, - // into a `makeTuple` inst that makes the tuple representing the lowered - // existential value. - void processCreateExistentialObject(IRCreateExistentialObject* inst) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(inst); - - // The result type of this `createExistentialObject` inst should already - // be lowered into a `TupleType(rttiType, WitnessTableIDType, AnyValueType)` - // in the previous `lowerGenericType` pass. - auto tupleType = inst->getDataType(); - auto witnessTableIdType = cast(tupleType->getOperand(1)); - auto anyValueType = cast(tupleType->getOperand(2)); - - // Create a standin value for `rttiObject` for now since it will not be used - // other than test for null in the case of `Optional`. - auto uint2Type = builder->getVectorType( - builder->getUIntType(), - builder->getIntValue(builder->getIntType(), 2)); - IRInst* standinVal = builder->getIntValue(builder->getUIntType(), 0xFFFFFFFF); - IRInst* zero = builder->getIntValue(builder->getUIntType(), 0); - IRInst* standinRTTIVectorArgs[] = {standinVal, zero}; - IRInst* rttiObject = builder->emitMakeVector(uint2Type, 2, standinRTTIVectorArgs); - - // Pack the user provided value into `AnyValue`. - IRInst* packedValue = inst->getValue(); - if (packedValue->getDataType()->getOp() != kIROp_AnyValueType) - packedValue = builder->emitPackAnyValue(anyValueType, packedValue); - - // Use the user provided `typeID` value as the witness table ID field in the - // newly constructed tuple. - // All `WitnessTableID` types are lowered into `uint2`s, so we need to create - // a `uint2` value from `typeID` to stay consistent with the convention. - IRInst* vectorArgs[2] = { - inst->getTypeID(), - builder->getIntValue(builder->getUIntType(), 0)}; - - IRInst* typeIdValue = builder->emitMakeVector(uint2Type, 2, vectorArgs); - typeIdValue = builder->emitBitCast(witnessTableIdType, typeIdValue); - IRInst* tupleArgs[] = {rttiObject, typeIdValue, packedValue}; - auto tuple = builder->emitMakeTuple(tupleType, 3, tupleArgs); - inst->replaceUsesWith(tuple); - inst->removeAndDeallocate(); - } - - IRInst* extractTupleElement(IRBuilder* builder, IRInst* value, UInt index) - { - auto tupleType = cast(sharedContext->lowerType(builder, value->getDataType())); - auto getElement = - builder->emitGetTupleElement((IRType*)tupleType->getOperand(index), value, index); - return getElement; - } - - void processExtractExistentialElement(IRInst* extractInst, UInt elementId) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(extractInst); - - IRInst* element = nullptr; - if (isComInterfaceType(extractInst->getOperand(0)->getDataType())) - { - // If this is an COM interface, the elements (witness table/rtti) are just the interface - // value itself. - element = extractInst->getOperand(0); - } - else - { - element = extractTupleElement(builder, extractInst->getOperand(0), elementId); - } - extractInst->replaceUsesWith(element); - extractInst->removeAndDeallocate(); - } - - void processExtractExistentialValue(IRExtractExistentialValue* inst) - { - processExtractExistentialElement(inst, 2); - } - - void processIsNullExistential(IRIsNullExistential* inst) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - - auto rttiElement = extractTupleElement(&builder, inst->getOperand(0), 0); - auto isNull = builder.emitNeq( - builder.emitGetElement(builder.getUIntType(), rttiElement, 0), - builder.getIntValue(builder.getUIntType(), 0)); - inst->replaceUsesWith(isNull); - inst->removeAndDeallocate(); - } - - void processExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) - { - processExtractExistentialElement(inst, 1); - } - - void processExtractExistentialType(IRExtractExistentialType* extractInst) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(extractInst); - - IRInst* element = nullptr; - IRInst* anyValueType = nullptr; - if (isComInterfaceType(extractInst->getOperand(0)->getDataType())) - { - // If this is an COM interface, the elements (witness table/rtti) are just the interface - // value itself. - element = extractInst->getOperand(0); - } - else - { - element = extractTupleElement(builder, extractInst->getOperand(0), 0); - if (auto tupleType = as(extractInst->getOperand(0)->getDataType())) - { - anyValueType = tupleType->getOperand(2); - } - } - - // If this instruction is used as a type, we need to replace it with the lowered type, - // which should be an AnyValueType. - // If it is used as a value, then we can replace it with the extracted element. - auto isTypeUse = [](IRUse* use) -> bool - { - auto user = use->getUser(); - if (as(user)) - return true; - if (use == &use->getUser()->typeUse) - return true; - return false; - }; - traverseUses( - extractInst, - [&](IRUse* use) - { - if (anyValueType && isTypeUse(use)) - { - builder->replaceOperand(use, anyValueType); - return; - } - builder->replaceOperand(use, element); - }); - } - - void processGetValueFromBoundInterface(IRGetValueFromBoundInterface* inst) - { - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(inst); - if (inst->getDataType()->getOp() == kIROp_ClassType) - { - return; - } - // A value of interface will lower as a tuple, and - // the third element of that tuple represents the - // concrete value that was put into the existential. - // - auto element = extractTupleElement(builder, inst->getOperand(0), 2); - auto elementType = element->getDataType(); - - // There are two cases we expect to see for that - // tuple element. - // - IRInst* replacement = nullptr; - if (as(elementType)) - { - // The first case is when legacy static specialization - // is applied, and the element is a "pseudo-pointer." - // - // Semantically, we should emit a (pseudo-)load from the pseudo-pointer - // to go from `PseudoPtr` to `T`. - // - // TODO: Actually introduce and emit a "psedudo-load" instruction - // here. For right now we are just using the value directly and - // downstream passes seem okay with it, but it isn't really - // type-correct to be doing this. - // - replacement = element; - } - else - { - // The second case is when the dynamic-dispatch layout is - // being used, and the element is an "any-value." - // - // In this case we need to emit an unpacking operation - // to get from `AnyValue` to `T`. - // - SLANG_ASSERT(as(elementType)); - replacement = builder->emitUnpackAnyValue(inst->getFullType(), element); - } - - inst->replaceUsesWith(replacement); - inst->removeAndDeallocate(); - } - - void processInst(IRInst* inst) - { - if (auto makeExistential = as(inst)) - { - processMakeExistential(makeExistential); - } - else if (auto createExistentialObject = as(inst)) - { - processCreateExistentialObject(createExistentialObject); - } - else if (auto getValueFromBoundInterface = as(inst)) - { - processGetValueFromBoundInterface(getValueFromBoundInterface); - } - else if (auto extractExistentialVal = as(inst)) - { - processExtractExistentialValue(extractExistentialVal); - } - else if (auto extractExistentialType = as(inst)) - { - processExtractExistentialType(extractExistentialType); - } - else if (auto extractExistentialWitnessTable = as(inst)) - { - processExtractExistentialWitnessTable(extractExistentialWitnessTable); - } - else if (auto isNullExistential = as(inst)) - { - processIsNullExistential(isNullExistential); - } - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - } -}; - - -void lowerExistentials(SharedGenericsLoweringContext* sharedContext) -{ - ExistentialLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); -} -} // namespace Slang - -*/ \ No newline at end of file diff --git a/source/slang/slang-ir-lower-existential.h b/source/slang/slang-ir-lower-existential.h deleted file mode 100644 index 73ae61f75a7..00000000000 --- a/source/slang/slang-ir-lower-existential.h +++ /dev/null @@ -1,11 +0,0 @@ -// slang-ir-lower-existential.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Lower existential types and related instructions to Tuple types. -void lowerExistentials(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp deleted file mode 100644 index e4fb5ac9814..00000000000 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ /dev/null @@ -1,404 +0,0 @@ -// slang-ir-lower-generic-call.cpp -#include "slang-ir-lower-generic-call.h" - -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-util.h" - -namespace Slang -{ -struct GenericCallLoweringContext -{ - SharedGenericsLoweringContext* sharedContext; - - // Represents a work item for unpacking `inout` or `out` arguments after a generic call. - struct ArgumentUnpackWorkItem - { - // Concrete typed destination. - IRInst* dstArg = nullptr; - // Packed argument. - IRInst* packedArg = nullptr; - }; - - // Packs `arg` into a `IRAnyValue` if necessary, to make it feedable into the parameter. - // If `arg` represents a concrete typed variable passed in to a generic `out` parameter, - // this function indicates that it needs to be unpacked after the call by setting - // `unpackAfterCall`. - IRInst* maybePackArgument( - IRBuilder* builder, - IRType* paramType, - IRInst* arg, - ArgumentUnpackWorkItem& unpackAfterCall) - { - unpackAfterCall.dstArg = nullptr; - unpackAfterCall.packedArg = nullptr; - - // If either paramType or argType is a pointer type - // (because of `inout` or `out` modifiers), we extract - // the underlying value type first. - IRType* paramValType = paramType; - IRType* argValType = arg->getDataType(); - IRInst* argVal = arg; - if (auto ptrType = as(paramType)) - { - paramValType = ptrType->getValueType(); - } - auto argType = arg->getDataType(); - if (auto argPtrType = as(argType)) - { - argValType = argPtrType->getValueType(); - argVal = builder->emitLoad(arg); - } - - // Pack `arg` if the parameter expects AnyValue but - // `arg` is not an AnyValue. - if (as(paramValType) && !as(argValType)) - { - auto packedArgVal = builder->emitPackAnyValue(paramValType, argVal); - // if parameter expects an `out` pointer, store the packed val into a - // variable and pass in a pointer to that variable. - if (as(paramType)) - { - auto tempVar = builder->emitVar(paramValType); - builder->emitStore(tempVar, packedArgVal); - // tempVar needs to be unpacked into original var after the call. - unpackAfterCall.dstArg = arg; - unpackAfterCall.packedArg = tempVar; - return tempVar; - } - else - { - return packedArgVal; - } - } - return arg; - } - - IRInst* maybeUnpackValue( - IRBuilder* builder, - IRType* expectedType, - IRType* actualType, - IRInst* value) - { - if (as(actualType) && !as(expectedType)) - { - auto unpack = builder->emitUnpackAnyValue(expectedType, value); - return unpack; - } - return value; - } - - // Create a dispatch function for a interface method. - // On CPU, the dispatch function is implemented as a witness table lookup followed by - // a function-pointer call. - // On GPU targets, we can modify the body of the dispatch function in a follow-up - // pass to implement it with a `switch` statement based on the type ID. - IRFunc* _createInterfaceDispatchMethod( - IRBuilder* builder, - IRInterfaceType* interfaceType, - IRInst* requirementKey, - IRInst* requirementVal) - { - auto func = builder->createFunc(); - if (auto linkage = requirementKey->findDecoration()) - { - builder->addNameHintDecoration(func, linkage->getMangledName()); - } - - auto reqFuncType = cast(requirementVal); - List paramTypes; - paramTypes.add(builder->getWitnessTableType(interfaceType)); - for (UInt i = 0; i < reqFuncType->getParamCount(); i++) - { - paramTypes.add(reqFuncType->getParamType(i)); - } - auto dispatchFuncType = builder->getFuncType(paramTypes, reqFuncType->getResultType()); - func->setFullType(dispatchFuncType); - builder->setInsertInto(func); - builder->emitBlock(); - List params; - IRParam* witnessTableParam = builder->emitParam(paramTypes[0]); - for (Index i = 1; i < paramTypes.getCount(); i++) - { - params.add(builder->emitParam(paramTypes[i])); - } - auto callee = - builder->emitLookupInterfaceMethodInst(reqFuncType, witnessTableParam, requirementKey); - auto call = (IRCall*)builder->emitCallInst(reqFuncType->getResultType(), callee, params); - if (call->getDataType()->getOp() == kIROp_VoidType) - builder->emitReturn(); - else - builder->emitReturn(call); - return func; - } - - // If an interface dispatch method is already created, return it. - // Otherwise, create the method. - IRFunc* getOrCreateInterfaceDispatchMethod( - IRBuilder* builder, - IRInterfaceType* interfaceType, - IRInst* requirementKey, - IRInst* requirementVal) - { - if (auto func = sharedContext->mapInterfaceRequirementKeyToDispatchMethods.tryGetValue( - requirementKey)) - return *func; - auto dispatchFunc = - _createInterfaceDispatchMethod(builder, interfaceType, requirementKey, requirementVal); - sharedContext->mapInterfaceRequirementKeyToDispatchMethods.addIfNotExists( - requirementKey, - dispatchFunc); - return dispatchFunc; - } - - // Translate `callInst` into a call of `newCallee`, and respect the new `funcType`. - // If `newCallee` is a lowered generic function, `specializeInst` contains the type - // arguments used to specialize the callee. - void translateCallInst( - IRCall* callInst, - IRFuncType* funcType, - IRInst* newCallee, - IRSpecialize* specializeInst) - { - List paramTypes; - for (UInt i = 0; i < funcType->getParamCount(); i++) - paramTypes.add(funcType->getParamType(i)); - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(callInst); - - // Process the argument list of the call. - // For each argument, we test if it needs to be packed into an `AnyValue` for the - // call. For `out` and `inout` parameters, they may also need to be unpacked after - // the call, in which case we add such the argument to `argsToUnpack` so it can be - // processed after the new call inst is emitted. - List args; - List argsToUnpack; - for (UInt i = 0; i < callInst->getArgCount(); i++) - { - auto arg = callInst->getArg(i); - ArgumentUnpackWorkItem unpackWorkItem; - auto newArg = maybePackArgument(builder, paramTypes[i], arg, unpackWorkItem); - args.add(newArg); - if (unpackWorkItem.packedArg) - argsToUnpack.add(unpackWorkItem); - } - if (specializeInst) - { - for (UInt i = 0; i < specializeInst->getArgCount(); i++) - { - auto arg = specializeInst->getArg(i); - // Translate Type arguments into RTTI object. - if (as(arg)) - { - // We are using a simple type to specialize a callee. - // Generate RTTI for this type. - auto rttiObject = sharedContext->maybeEmitRTTIObject(arg); - arg = builder->emitGetAddress(builder->getRTTIHandleType(), rttiObject); - } - else if (arg->getOp() == kIROp_Specialize) - { - // The type argument used to specialize a callee is itself a - // specialization of some generic type. - // TODO: generate RTTI object for specializations of generic types. - SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types"); - } - else if (arg->getOp() == kIROp_RTTIObject) - { - // We are inside a generic function and using a generic parameter - // to specialize another callee. The generic parameter of the caller - // has already been translated into an RTTI object, so we just need - // to pass this object down. - } - args.add(arg); - } - } - - // If callee returns `AnyValue` but we are expecting a concrete value, unpack it. - auto calleeRetType = funcType->getResultType(); - auto newCall = builder->emitCallInst(calleeRetType, newCallee, args); - auto callInstType = callInst->getDataType(); - auto unpackInst = maybeUnpackValue(builder, callInstType, calleeRetType, newCall); - // Unpack other `out` arguments. - for (auto& item : argsToUnpack) - { - auto packedVal = builder->emitLoad(item.packedArg); - auto originalValType = cast(item.dstArg->getDataType())->getValueType(); - auto unpackedVal = builder->emitUnpackAnyValue(originalValType, packedVal); - builder->emitStore(item.dstArg, unpackedVal); - } - callInst->replaceUsesWith(unpackInst); - callInst->removeAndDeallocate(); - } - - IRInst* findInnerMostSpecializingBase(IRSpecialize* inst) - { - auto result = inst->getBase(); - while (auto specialize = as(result)) - result = specialize->getBase(); - return result; - } - - void lowerCallToSpecializedFunc(IRCall* callInst, IRSpecialize* specializeInst) - { - // If we see a call(specialize(gFunc, Targs), args), - // translate it into call(gFunc, args, Targs). - auto loweredFunc = specializeInst->getBase(); - - // Don't process intrinsic functions. - UnownedStringSlice intrinsicDef; - IRInst* intrinsicInst; - if (findTargetIntrinsicDefinition( - getResolvedInstForDecorations(loweredFunc), - sharedContext->targetProgram->getTargetReq()->getTargetCaps(), - intrinsicDef, - intrinsicInst)) - return; - - // All callees should have already been lowered in lower-generic-functions pass. - // For intrinsic generic functions, they are left as is, and we also need to ignore - // them here. - if (loweredFunc->getOp() == kIROp_Generic) - { - return; - } - else if (loweredFunc->getOp() == kIROp_Specialize) - { - // All nested generic functions are supposed to be flattend before this pass. - // If they are not, they represent an intrinsic function that should not be - // modified in this pass. - SLANG_UNEXPECTED("Nested generics specialization."); - } - else if (loweredFunc->getOp() == kIROp_LookupWitnessMethod) - { - lowerCallToInterfaceMethod( - callInst, - cast(loweredFunc), - specializeInst); - return; - } - IRFuncType* funcType = cast(loweredFunc->getDataType()); - translateCallInst(callInst, funcType, loweredFunc, specializeInst); - } - - void lowerCallToInterfaceMethod( - IRCall* callInst, - IRLookupWitnessMethod* lookupInst, - IRSpecialize* specializeInst) - { - // If we see a call(lookup_interface_method(...), ...), we need to translate - // all occurences of associatedtypes. - - // If `w` in `lookup_interface_method(w, ...)` is a COM interface, bail. - if (isComInterfaceType(lookupInst->getWitnessTable()->getDataType())) - { - return; - } - - auto interfaceType = as( - cast(lookupInst->getWitnessTable()->getDataType()) - ->getConformanceType()); - - if (!interfaceType) - { - // NoneWitness -> remove call. - callInst->removeAndDeallocate(); - return; - } - - if (isBuiltin(interfaceType)) - return; - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(callInst); - - // Create interface dispatch method that bottlenecks the dispatch logic. - auto requirementKey = lookupInst->getRequirementKey(); - auto requirementVal = - sharedContext->findInterfaceRequirementVal(interfaceType, requirementKey); - - if (interfaceType->findDecoration()) - { - sharedContext->sink->diagnose( - callInst->sourceLoc, - Diagnostics::dynamicDispatchOnSpecializeOnlyInterface, - interfaceType); - } - auto dispatchFunc = getOrCreateInterfaceDispatchMethod( - builder, - interfaceType, - requirementKey, - requirementVal); - - auto parentFunc = getParentFunc(callInst); - // Don't process the call inst that is the one in the dispatch function itself. - if (parentFunc == dispatchFunc) - return; - - // Replace `callInst` with a new call inst that calls `dispatchFunc` instead, and - // with the witness table as first argument, - builder->setInsertBefore(callInst); - List newArgs; - newArgs.add(lookupInst->getWitnessTable()); - for (UInt i = 0; i < callInst->getArgCount(); i++) - newArgs.add(callInst->getArg(i)); - auto newCall = - (IRCall*)builder->emitCallInst(callInst->getFullType(), dispatchFunc, newArgs); - callInst->replaceUsesWith(newCall); - callInst->removeAndDeallocate(); - - // Translate the new call inst as normal, taking care of packing/unpacking inputs - // and outputs. - translateCallInst( - newCall, - cast(dispatchFunc->getFullType()), - dispatchFunc, - specializeInst); - } - - void lowerCall(IRCall* callInst) - { - if (auto specializeInst = as(callInst->getCallee())) - lowerCallToSpecializedFunc(callInst, specializeInst); - else if (auto lookupInst = as(callInst->getCallee())) - lowerCallToInterfaceMethod(callInst, lookupInst, nullptr); - } - - void processInst(IRInst* inst) - { - if (auto callInst = as(inst)) - { - lowerCall(callInst); - } - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - } -}; - -void lowerGenericCalls(SharedGenericsLoweringContext* sharedContext) -{ - GenericCallLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-call.h b/source/slang/slang-ir-lower-generic-call.h deleted file mode 100644 index a876634e6ed..00000000000 --- a/source/slang/slang-ir-lower-generic-call.h +++ /dev/null @@ -1,12 +0,0 @@ -// slang-ir-lower-generic-call.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Lower generic and interface-based code to ordinary types and functions using -/// dynamic dispatch mechanisms. -void lowerGenericCalls(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp deleted file mode 100644 index 6c2eb61939d..00000000000 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ /dev/null @@ -1,449 +0,0 @@ -// slang-ir-lower-generic-function.cpp -#include "slang-ir-lower-generic-function.h" - -#include "slang-ir-clone.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ -// This is a subpass of generics lowering IR transformation. -// This pass lowers all generic function types and function definitions, including -// the function types used in interface types, to ordinary functions that takes -// raw pointers in place of generic types. -struct GenericFunctionLoweringContext -{ - SharedGenericsLoweringContext* sharedContext; - - IRInst* lowerGenericFunction(IRInst* genericValue) - { - IRInst* result = nullptr; - if (sharedContext->loweredGenericFunctions.tryGetValue(genericValue, result)) - return result; - // Do not lower intrinsic functions. - if (genericValue->findDecoration()) - return genericValue; - auto genericParent = as(genericValue); - SLANG_ASSERT(genericParent); - SLANG_ASSERT(genericParent->getDataType()); - auto genericRetVal = findGenericReturnVal(genericParent); - auto func = as(genericRetVal); - if (!func) - { - // Nested generic functions are supposed to be flattened before entering - // this pass. The reason we are still seeing them must be that they are - // intrinsic functions. In this case we ignore the function. - if (as(genericRetVal)) - { - SLANG_ASSERT( - findInnerMostGenericReturnVal(genericParent) - ->findDecoration() != nullptr); - } - return genericValue; - } - SLANG_ASSERT(func); - // Do not lower intrinsic functions. - UnownedStringSlice intrinsicDef; - IRInst* intrinsicInst; - if (!func->isDefinition() || - findTargetIntrinsicDefinition( - func, - sharedContext->targetProgram->getTargetReq()->getTargetCaps(), - intrinsicDef, - intrinsicInst)) - { - sharedContext->loweredGenericFunctions[genericValue] = genericValue; - return genericValue; - } - IRCloneEnv cloneEnv; - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(genericParent); - // Do not clone func type (which would break IR def-use rules if we do it here) - // This is OK since we will lower the type immediately after the clone. - cloneEnv.mapOldValToNew[func->getFullType()] = builder.getTypeKind(); - auto loweredFunc = cast(cloneInstAndOperands(&cloneEnv, &builder, func)); - auto loweredGenericType = - lowerGenericFuncType(&builder, genericParent, cast(func->getFullType())); - SLANG_ASSERT(loweredGenericType); - loweredFunc->setFullType(loweredGenericType); - - OrderedHashSet childrenToDemote; - List clonedParams; - auto moduleInst = genericParent->getModule()->getModuleInst(); - for (auto genericChild : genericParent->getFirstBlock()->getChildren()) - { - switch (genericChild->getOp()) - { - case kIROp_Func: - continue; - case kIROp_Return: - continue; - } - // Process all generic parameters and local type definitions. - auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); - switch (clonedChild->getOp()) - { - case kIROp_Param: - { - auto paramType = clonedChild->getFullType(); - auto loweredParamType = sharedContext->lowerType(&builder, paramType); - if (loweredParamType != paramType) - { - clonedChild->setFullType((IRType*)loweredParamType); - } - clonedParams.add(clonedChild); - } - break; - case kIROp_Specialize: - case kIROp_LookupWitnessMethod: - childrenToDemote.add(clonedChild); - break; - default: - { - bool shouldDemote = false; - if (childrenToDemote.contains(clonedChild->getFullType())) - shouldDemote = true; - for (UInt i = 0; i < clonedChild->getOperandCount(); i++) - { - if (childrenToDemote.contains(clonedChild->getOperand(i))) - { - shouldDemote = true; - break; - } - } - if (shouldDemote && clonedChild->getParent() == moduleInst) - { - childrenToDemote.add(clonedChild); - } - continue; - } - } - } - cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, func, loweredFunc); - - auto block = as(loweredFunc->getFirstChild()); - for (auto param : clonedParams) - { - param->removeFromParent(); - block->addParam(as(param)); - } - - // Demote specialize and lookupWitness insts and their dependents down to function body. - auto insertPoint = block->getFirstOrdinaryInst(); - List childrenToDemoteList; - for (auto child : childrenToDemote) - childrenToDemoteList.add(child); - for (Index i = childrenToDemoteList.getCount() - 1; i >= 0; i--) - { - auto child = childrenToDemoteList[i]; - child->insertBefore(insertPoint); - } - - // Lower generic typed parameters into AnyValueType. - auto firstInst = loweredFunc->getFirstOrdinaryInst(); - builder.setInsertBefore(firstInst); - sharedContext->loweredGenericFunctions[genericValue] = loweredFunc; - sharedContext->addToWorkList(loweredFunc); - return loweredFunc; - } - - IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal, IRFuncType* funcType) - { - ShortList genericParamTypes; - Dictionary typeMapping; - for (auto genericParam : genericVal->getParams()) - { - genericParamTypes.add(sharedContext->lowerType(builder, genericParam->getFullType())); - if (auto anyValueSizeDecor = genericParam->findDecoration()) - { - auto anyValueSize = sharedContext->getInterfaceAnyValueSize( - anyValueSizeDecor->getConstraintType(), - genericParam->sourceLoc); - auto anyValueType = builder->getAnyValueType(anyValueSize); - typeMapping[genericParam] = anyValueType; - } - } - - auto innerType = (IRFuncType*)lowerFuncType( - builder, - funcType, - typeMapping, - genericParamTypes.getArrayView().arrayView); - - return innerType; - } - - IRType* lowerFuncType( - IRBuilder* builder, - IRFuncType* funcType, - const Dictionary& typeMapping, - ArrayView additionalParams) - { - List newOperands; - bool translated = false; - for (UInt i = 0; i < funcType->getOperandCount(); i++) - { - auto paramType = funcType->getOperand(i); - auto loweredParamType = - sharedContext->lowerType(builder, paramType, typeMapping, nullptr); - SLANG_ASSERT(loweredParamType); - translated = translated || (loweredParamType != paramType); - newOperands.add(loweredParamType); - } - if (!translated && additionalParams.getCount() == 0) - return funcType; - for (Index i = 0; i < additionalParams.getCount(); i++) - { - newOperands.add(additionalParams[i]); - } - auto newFuncType = builder->getFuncType( - newOperands.getCount() - 1, - (IRType**)(newOperands.begin() + 1), - (IRType*)newOperands[0]); - - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, funcType, newFuncType); - return newFuncType; - } - - IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) - { - IRInterfaceType* loweredType = nullptr; - if (sharedContext->loweredInterfaceTypes.tryGetValue(interfaceType, loweredType)) - return loweredType; - if (sharedContext->mapLoweredInterfaceToOriginal.containsKey(interfaceType)) - return interfaceType; - // Do not lower intrinsic interfaces. - if (isBuiltin(interfaceType)) - return interfaceType; - // Do not lower COM interfaces. - if (isComInterfaceType(interfaceType)) - return interfaceType; - - List newEntries; - - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(interfaceType); - - // Translate IRFuncType in interface requirements. - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - if (auto entry = as(interfaceType->getOperand(i))) - { - // Note: The logic that creates the `IRInterfaceRequirementEntry`s does - // not currently guarantee that the *value* part of each key-value pair - // gets filled in. We thus need to defend against a null `requirementVal` - // here, at least until the underlying issue gets resolved. - // - IRInst* requirementVal = entry->getRequirementVal(); - IRInst* loweredVal = nullptr; - if (!requirementVal) - { - } - else if (auto funcType = as(requirementVal)) - { - loweredVal = lowerFuncType( - &builder, - funcType, - Dictionary(), - ArrayView()); - } - else if (auto genericFuncType = as(requirementVal)) - { - loweredVal = lowerGenericFuncType( - &builder, - genericFuncType, - cast(findGenericReturnVal(genericFuncType))); - } - else if (requirementVal->getOp() == kIROp_AssociatedType) - { - loweredVal = builder.getRTTIHandleType(); - } - else - { - loweredVal = requirementVal; - } - auto newEntry = - builder.createInterfaceRequirementEntry(entry->getRequirementKey(), loweredVal); - newEntries.add(newEntry); - } - } - loweredType = - builder.createInterfaceType(newEntries.getCount(), (IRInst**)newEntries.getBuffer()); - loweredType->sourceLoc = interfaceType->sourceLoc; - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren( - &cloneEnv, - sharedContext->module, - interfaceType, - loweredType); - sharedContext->loweredInterfaceTypes.add(interfaceType, loweredType); - sharedContext->mapLoweredInterfaceToOriginal[loweredType] = interfaceType; - return loweredType; - } - - bool isTypeKindVal(IRInst* inst) - { - auto type = inst->getDataType(); - if (!type) - return false; - return type->getOp() == kIROp_TypeKind; - } - - // Lower items in a witness table. This triggers lowering of generic functions, - // and emission of wrapper functions. - void lowerWitnessTable(IRWitnessTable* witnessTable) - { - IRInterfaceType* conformanceType = as(witnessTable->getConformanceType()); - if (!conformanceType) - return; - - auto interfaceType = maybeLowerInterfaceType(conformanceType); - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(witnessTable); - if (interfaceType != witnessTable->getConformanceType()) - { - auto newWitnessTableType = builder->getWitnessTableType(interfaceType); - witnessTable->setFullType(newWitnessTableType); - } - if (isBuiltin(interfaceType)) - return; - for (auto child : witnessTable->getChildren()) - { - auto entry = as(child); - if (!entry) - continue; - if (auto genericVal = as(entry->getSatisfyingVal())) - { - // Lower generic functions in witness table. - if (findGenericReturnVal(genericVal)->getOp() == kIROp_Func) - { - auto loweredFunc = lowerGenericFunction(genericVal); - entry->satisfyingVal.set(loweredFunc); - } - } - else if (isTypeKindVal(entry->getSatisfyingVal())) - { - // Translate a Type value to an RTTI object pointer. - auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal()); - auto rttiObjectPtr = - builder->emitGetAddress(builder->getRTTIHandleType(), rttiObject); - entry->satisfyingVal.set(rttiObjectPtr); - } - else if (as(entry->getSatisfyingVal())) - { - // No processing needed here. - // The witness table will be processed from the work list. - } - } - } - - void lowerLookupInterfaceMethodInst(IRLookupWitnessMethod* lookupInst) - { - // Update the type of lookupInst to the lowered type of the corresponding interface - // requirement val. - - // If the requirement is a function, interfaceRequirementVal will be the lowered function - // type. If the requirement is an associatedtype, interfaceRequirementVal will be - // Ptr. - IRInst* interfaceRequirementVal = nullptr; - auto witnessTableType = - as(lookupInst->getWitnessTable()->getDataType()); - if (!witnessTableType) - return; - if (witnessTableType->getConformanceType()->findDecoration()) - return; - - IRInterfaceType* conformanceType = - as(witnessTableType->getConformanceType()); - - // NoneWitness generates conformance types which aren't interfaces. In - // that case, the method can just be skipped entirely, since there's no - // real witness for it and it should be in unreachable code at this - // point. - if (!conformanceType) - return; - auto interfaceType = maybeLowerInterfaceType(conformanceType); - interfaceRequirementVal = sharedContext->findInterfaceRequirementVal( - interfaceType, - lookupInst->getRequirementKey()); - IRBuilder builder(lookupInst); - builder.replaceOperand(&lookupInst->typeUse, interfaceRequirementVal); - } - - void lowerSpecialize(IRSpecialize* specializeInst) - { - // If we see a call(specialize(gFunc, Targs), args), - // translate it into call(gFunc, args, Targs). - IRInst* loweredFunc = nullptr; - auto funcToSpecialize = specializeInst->getBase(); - if (funcToSpecialize->getOp() == kIROp_Generic) - { - loweredFunc = lowerGenericFunction(funcToSpecialize); - if (loweredFunc != funcToSpecialize) - { - IRBuilder builder; - builder.replaceOperand(specializeInst->getOperands(), loweredFunc); - } - } - } - - void processInst(IRInst* inst) - { - if (auto specializeInst = as(inst)) - { - lowerSpecialize(specializeInst); - } - else if (auto lookupInterfaceMethod = as(inst)) - { - lowerLookupInterfaceMethodInst(lookupInterfaceMethod); - } - else if (auto witnessTable = as(inst)) - { - lowerWitnessTable(witnessTable); - } - else if (auto interfaceType = as(inst)) - { - maybeLowerInterfaceType(interfaceType); - } - } - - void replaceLoweredInterfaceTypes() - { - for (const auto& [loweredKey, loweredValue] : sharedContext->loweredInterfaceTypes) - loweredKey->replaceUsesWith(loweredValue); - sharedContext->mapInterfaceRequirementKeyValue.clear(); - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - - replaceLoweredInterfaceTypes(); - } -}; -void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext) -{ - GenericFunctionLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); -} -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-function.h b/source/slang/slang-ir-lower-generic-function.h deleted file mode 100644 index c70ad0ae6c6..00000000000 --- a/source/slang/slang-ir-lower-generic-function.h +++ /dev/null @@ -1,19 +0,0 @@ -// slang-ir-lower-generic-function.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Lower generic and interface-based code to ordinary types and functions using -/// dynamic dispatch mechanisms. -/// After this pass, generic type parameters will be lowered into `AnyValue` types, -/// and an existential type I in function signatures will be lowered into -/// `Tuple`. -/// Note that this pass mostly deals with function signatures and interface definitions, -/// and does not modify function bodies. -/// All variable declarations and type uses are handled in `lower-generic-type`, -/// and all call sites are handled in `lower-generic-call`. -void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-type.cpp b/source/slang/slang-ir-lower-generic-type.cpp deleted file mode 100644 index c4dcd92a9a1..00000000000 --- a/source/slang/slang-ir-lower-generic-type.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// slang-ir-lower-generic-type.cpp -#include "slang-ir-lower-generic-type.h" - -#include "slang-ir-clone.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir.h" - -namespace Slang -{ -// This is a subpass of generics lowering IR transformation. -// This pass lowers all generic/polymorphic types into IRAnyValueType. -struct GenericTypeLoweringContext -{ - SharedGenericsLoweringContext* sharedContext; - - IRInst* processInst(IRInst* inst) - { - // Ensure exported struct types has RTTI object defined. - if (as(inst)) - { - if (inst->findDecoration()) - { - sharedContext->maybeEmitRTTIObject(inst); - } - } - - // Don't modify type insts themselves. - if (as(inst)) - return inst; - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(inst); - - auto newType = sharedContext->lowerType(builder, inst->getFullType()); - if (newType != inst->getFullType()) - inst = builder->replaceOperand(&inst->typeUse, newType); - - switch (inst->getOp()) - { - default: - break; - case kIROp_StructField: - { - // Translate the struct field type. - auto structField = static_cast(inst); - auto loweredFieldType = - sharedContext->lowerType(builder, structField->getFieldType()); - structField->setOperand(1, loweredFieldType); - } - break; - case kIROp_DebugFunction: - { - auto oldFuncType = as(inst)->getDebugType(); - auto newFuncType = sharedContext->lowerType(builder, oldFuncType); - if (newFuncType != oldFuncType) - inst = builder->replaceOperand(inst->getOperandUse(4), newFuncType); - } - break; - } - return inst; - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - inst = processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - sharedContext->mapInterfaceRequirementKeyValue.clear(); - } -}; - -void lowerGenericType(SharedGenericsLoweringContext* sharedContext) -{ - GenericTypeLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); -} -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generic-type.h b/source/slang/slang-ir-lower-generic-type.h deleted file mode 100644 index 20d4fa7b33c..00000000000 --- a/source/slang/slang-ir-lower-generic-type.h +++ /dev/null @@ -1,12 +0,0 @@ -// slang-ir-lower-generic-type.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Lower all references to generic types (ThisType, AssociatedType, etc.) into IRAnyValueType, -/// and existential types into Tuple. -void lowerGenericType(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp deleted file mode 100644 index 0a87d634d28..00000000000 --- a/source/slang/slang-ir-lower-generics.cpp +++ /dev/null @@ -1,289 +0,0 @@ -// slang-ir-lower-generics.cpp -#include "slang-ir-lower-generics.h" - -#include "../core/slang-func-ptr.h" -#include "../core/slang-performance-profiler.h" -#include "slang-ir-any-value-inference.h" -#include "slang-ir-any-value-marshalling.h" -#include "slang-ir-augment-make-existential.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-inst-pass-base.h" -#include "slang-ir-layout.h" -#include "slang-ir-lower-generic-call.h" -#include "slang-ir-lower-generic-function.h" -#include "slang-ir-lower-generic-type.h" -#include "slang-ir-lower-tuple-types.h" -#include "slang-ir-lower-typeflow-insts.h" -#include "slang-ir-specialize-dispatch.h" -#include "slang-ir-specialize-dynamic-associatedtype-lookup.h" -#include "slang-ir-ssa-simplification.h" -#include "slang-ir-util.h" -#include "slang-ir-witness-table-wrapper.h" - -namespace Slang -{ -// Replace all uses of RTTI objects with its sequential ID. -// Currently we don't use RTTI objects other than check for null/invalid value, -// so all of them are 0xFFFFFFFF. -void specializeRTTIObjectReferences(SharedGenericsLoweringContext* sharedContext) -{ - uint32_t id = 0xFFFFFFFF; - for (auto rtti : sharedContext->mapTypeToRTTIObject) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(rtti.value); - IRUse* nextUse = nullptr; - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - IRInst* uint2Args[] = { - builder.getIntValue(builder.getUIntType(), id), - builder.getIntValue(builder.getUIntType(), 0)}; - auto idOperand = builder.emitMakeVector(uint2Type, 2, uint2Args); - for (auto use = rtti.value->firstUse; use; use = nextUse) - { - nextUse = use->nextUse; - if (use->getUser()->getOp() == kIROp_GetAddress) - { - use->getUser()->replaceUsesWith(idOperand); - } - } - } -} - -// Replace all WitnessTableID type or RTTIHandleType with `uint2`. -void cleanUpRTTIHandleTypes(SharedGenericsLoweringContext* sharedContext) -{ - List instsToRemove; - for (auto inst : sharedContext->module->getGlobalInsts()) - { - switch (inst->getOp()) - { - case kIROp_WitnessTableIDType: - if (isComInterfaceType((IRType*)inst->getOperand(0))) - continue; - // fall through - case kIROp_RTTIHandleType: - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - inst->replaceUsesWith(uint2Type); - instsToRemove.add(inst); - } - break; - } - } - for (auto inst : instsToRemove) - inst->removeAndDeallocate(); -} - -// Remove all interface types from module. -void cleanUpInterfaceTypes(SharedGenericsLoweringContext* sharedContext) -{ - IRBuilder builder(sharedContext->module); - builder.setInsertInto(sharedContext->module->getModuleInst()); - auto dummyInterfaceObj = builder.getIntValue(builder.getIntType(), 0); - List interfaceInsts; - for (auto inst : sharedContext->module->getGlobalInsts()) - { - if (inst->getOp() == kIROp_InterfaceType) - { - if (inst->findDecoration()) - continue; - - interfaceInsts.add(inst); - } - } - for (auto inst : interfaceInsts) - { - inst->replaceUsesWith(dummyInterfaceObj); - inst->removeAndDeallocate(); - } -} - - -// Turn all references of witness table or RTTI objects into integer IDs, generate -// specialized `switch` based dispatch functions based on witness table IDs, and remove -// all original witness table, RTTI object and interface definitions from IR module. -// With these transformations, the resulting code is compatible with D3D/Vulkan where -// no pointers are involved in RTTI / dynamic dispatch logic. -void specializeRTTIObjects(SharedGenericsLoweringContext* sharedContext, DiagnosticSink* sink) -{ - /*specializeDispatchFunctions(sharedContext); - if (sink->getErrorCount() != 0) - return;*/ - - // lowerIsTypeInsts(sharedContext); - - /*specializeDynamicAssociatedTypeLookup(sharedContext); - if (sink->getErrorCount() != 0) - return; - - sharedContext->mapInterfaceRequirementKeyValue.clear(); - - specializeRTTIObjectReferences(sharedContext); - - cleanUpRTTIHandleTypes(sharedContext); - - cleanUpInterfaceTypes(sharedContext);*/ -} - -void checkTypeConformanceExists(SharedGenericsLoweringContext* context) -{ - HashSet implementedInterfaces; - - // Add all interface type that are implemented by at least one type to a set. - for (auto inst : context->module->getGlobalInsts()) - { - if (inst->getOp() == kIROp_WitnessTable) - { - auto interfaceType = - cast(inst->getDataType())->getConformanceType(); - implementedInterfaces.add(interfaceType); - } - } - // Check if an interface type has any implementations. - workOnModule( - context, - [&](IRInst* inst) - { - if (auto lookupWitnessMethod = as(inst)) - { - auto witnessTableType = lookupWitnessMethod->getWitnessTable()->getDataType(); - if (!witnessTableType) - return; - auto interfaceType = - cast(witnessTableType)->getConformanceType(); - if (isComInterfaceType((IRType*)interfaceType)) - return; - if (!implementedInterfaces.contains(interfaceType)) - { - context->sink->diagnose( - interfaceType->sourceLoc, - Diagnostics::noTypeConformancesFoundForInterface, - interfaceType); - // Add to set to prevent duplicate diagnostic messages. - implementedInterfaces.add(interfaceType); - } - } - }); -} - -void stripWrapExistential(IRModule* module) -{ - InstWorkList workList(module); - - workList.add(module->getModuleInst()); - for (Index i = 0; i < workList.getCount(); i++) - { - auto inst = workList[i]; - switch (inst->getOp()) - { - case kIROp_WrapExistential: - { - auto operand = inst->getOperand(0); - inst->replaceUsesWith(operand); - inst->removeAndDeallocate(); - } - break; - default: - for (auto child : inst->getChildren()) - workList.add(child); - break; - } - } -} - -void lowerGenerics(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink) -{ - SLANG_PROFILE; - - SharedGenericsLoweringContext sharedContext(module); - sharedContext.targetProgram = targetProgram; - sharedContext.sink = sink; - - /*checkTypeConformanceExists(&sharedContext); - - // Replace all `makeExistential` insts with `makeExistentialWithRTTI` - // before making any other changes. This is necessary because a parameter of - // generic type will be lowered into `AnyValueType`, and after that we can no longer - // access the original generic type parameter from the lowered parameter value. - // This steps ensures that the generic type parameter is available via an - // explicit operand in `makeExistentialWithRTTI`, so that type parameter - // can be translated into an RTTI object during `lower-generic-type`, - // and used to create a tuple representing the existential value. - augmentMakeExistentialInsts(module); - - lowerGenericFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerGenericType(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerGenericCalls(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - generateWitnessTableWrapperFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - lowerExistentials(&sharedContext); - if (sink->getErrorCount() != 0) - return; - */ - - // This optional step replaces all uses of witness tables and RTTI objects with - // sequential IDs. Without this step, we will emit code that uses function pointers and - // real RTTI objects and witness tables. - specializeRTTIObjects(&sharedContext, sink); - - simplifyIR( - sharedContext.targetProgram, - module, - IRSimplificationOptions::getFast(sharedContext.targetProgram)); - - lowerTuples(module, sink); - if (sink->getErrorCount() != 0) - return; - - generateAnyValueMarshallingFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - // At this point, we should no longer need to care any `WrapExistential` insts, - // although they could still exist in the IR in order to call generic core module functions, - // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). - // We should remove them now. - stripWrapExistential(module); -} - -void cleanupGenerics(TargetProgram* program, IRModule* module, DiagnosticSink* sink) -{ - SharedGenericsLoweringContext sharedContext(module); - sharedContext.targetProgram = program; - sharedContext.sink = sink; - - specializeRTTIObjects(&sharedContext, sink); - - lowerTuples(module, sink); - if (sink->getErrorCount() != 0) - return; - - generateAnyValueMarshallingFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - - // At this point, we should no longer need to care any `WrapExistential` insts, - // although they could still exist in the IR in order to call generic core module functions, - // e.g. RWStucturedBuffer.Load(WrapExistential(sbuffer, type), index). - // We should remove them now. - stripWrapExistential(module); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-lower-generics.h b/source/slang/slang-ir-lower-generics.h deleted file mode 100644 index 8385597b698..00000000000 --- a/source/slang/slang-ir-lower-generics.h +++ /dev/null @@ -1,19 +0,0 @@ -// slang-ir-lower-generics.h -#pragma once - -#include "slang-ir.h" - -namespace Slang -{ -struct IRModule; -class DiagnosticSink; -class TargetProgram; - -/// Lower generic and interface-based code to ordinary types and functions using -/// dynamic dispatch mechanisms. -void lowerGenerics(TargetProgram* targetReq, IRModule* module, DiagnosticSink* sink); - -// Clean up any generic-related IR insts that are no longer needed. Called when -// it has been determined that no more dynamic dispatch code will be generated. -void cleanupGenerics(TargetProgram* targetReq, IRModule* module, DiagnosticSink* sink); -} // namespace Slang diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 3f22d17295d..347f3245563 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -7,12 +7,218 @@ #include "slang-ir-specialize.h" #include "slang-ir-typeflow-collection.h" #include "slang-ir-util.h" -#include "slang-ir-witness-table-wrapper.h" #include "slang-ir.h" namespace Slang { +// Represents a work item for packing `inout` or `out` arguments after a concrete call. +struct ArgumentPackWorkItem +{ + enum Kind + { + Pack, + UpCast, + } kind = Pack; + + // A `AnyValue` typed destination. + IRInst* dstArg = nullptr; + // A concrete value to be packed. + IRInst* concreteArg = nullptr; +}; + +bool isAnyValueType(IRType* type) +{ + if (as(type) || as(type)) + return true; + return false; +} + +// Unpack an `arg` of `IRAnyValue` into concrete type if necessary, to make it feedable into the +// parameter. If `arg` represents a AnyValue typed variable passed in to a concrete `out` +// parameter, this function indicates that it needs to be packed after the call by setting +// `packAfterCall`. +IRInst* maybeUnpackArg( + IRBuilder* builder, + IRType* paramType, + IRInst* arg, + ArgumentPackWorkItem& packAfterCall) +{ + packAfterCall.dstArg = nullptr; + packAfterCall.concreteArg = nullptr; + + // If either paramType or argType is a pointer type + // (because of `inout` or `out` modifiers), we extract + // the underlying value type first. + IRType* paramValType = paramType; + IRType* argValType = arg->getDataType(); + IRInst* argVal = arg; + if (auto ptrType = as(paramType)) + { + paramValType = ptrType->getValueType(); + } + auto argType = arg->getDataType(); + if (auto argPtrType = as(argType)) + { + argValType = argPtrType->getValueType(); + } + + + // Unpack `arg` if the parameter expects concrete type but + // `arg` is an AnyValue. + if (!isAnyValueType(paramValType) && isAnyValueType(argValType)) + { + // if parameter expects an `out` pointer, store the unpacked val into a + // variable and pass in a pointer to that variable. + if (as(paramType)) + { + auto tempVar = builder->emitVar(paramValType); + if (as(paramType)) + builder->emitStore( + tempVar, + builder->emitUnpackAnyValue(paramValType, builder->emitLoad(arg))); + + // tempVar needs to be unpacked into original var after the call. + packAfterCall.kind = ArgumentPackWorkItem::Kind::Pack; + packAfterCall.dstArg = arg; + packAfterCall.concreteArg = tempVar; + return tempVar; + } + else + { + return builder->emitUnpackAnyValue(paramValType, argVal); + } + } + + // Reinterpret 'arg' if it is being passed to a parameter with + // a different type collection. For now, we'll approximate this + // by checking if the types are different, but this should be + // encoded in the types. + // + if (as(paramValType) && as(argValType) && + paramValType != argValType) + { + // if parameter expects an `out` pointer, store the unpacked val into a + // variable and pass in a pointer to that variable. + if (as(paramType)) + { + auto tempVar = builder->emitVar(paramValType); + + // tempVar needs to be unpacked into original var after the call. + packAfterCall.kind = ArgumentPackWorkItem::Kind::UpCast; + packAfterCall.dstArg = arg; + packAfterCall.concreteArg = tempVar; + return tempVar; + } + else + { + SLANG_UNEXPECTED("Unexpected upcast for non-out parameter"); + } + } + return arg; +} + +IRStringLit* _getWitnessTableWrapperFuncName(IRModule* module, IRFunc* func) +{ + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(func); + if (auto linkageDecoration = func->findDecoration()) + { + return builder->getStringValue( + (String(linkageDecoration->getMangledName()) + "_wtwrapper").getUnownedSlice()); + } + if (auto namehintDecoration = func->findDecoration()) + { + return builder->getStringValue( + (String(namehintDecoration->getName()) + "_wtwrapper").getUnownedSlice()); + } + return nullptr; +} + + +IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* interfaceRequirementVal) +{ + auto funcTypeInInterface = cast(interfaceRequirementVal); + auto targetFuncType = as(funcInst->getDataType()); + + IRBuilder builderStorage(module); + auto builder = &builderStorage; + builder->setInsertBefore(funcInst); + + auto wrapperFunc = builder->createFunc(); + wrapperFunc->setFullType((IRType*)interfaceRequirementVal); + if (auto func = as(funcInst)) + if (auto name = _getWitnessTableWrapperFuncName(module, func)) + builder->addNameHintDecoration(wrapperFunc, name); + + builder->setInsertInto(wrapperFunc); + auto block = builder->emitBlock(); + builder->setInsertInto(block); + + ShortList params; + for (UInt i = 0; i < funcTypeInInterface->getParamCount(); i++) + { + params.add(builder->emitParam(funcTypeInInterface->getParamType(i))); + } + + List args; + List argsToPack; + + SLANG_ASSERT(params.getCount() == (Index)targetFuncType->getParamCount()); + for (UInt i = 0; i < targetFuncType->getParamCount(); i++) + { + auto wrapperParam = params[i]; + // Type of the parameter in the callee. + auto funcParamType = targetFuncType->getParamType(i); + + // If the implementation expects a concrete type + // (either in the form of a pointer for `out`/`inout` parameters, + // or in the form a value for `in` parameters, while + // the interface exposes an AnyValue type, + // we need to unpack the AnyValue argument to the appropriate + // concerete type. + ArgumentPackWorkItem packWorkItem; + auto newArg = maybeUnpackArg(builder, funcParamType, wrapperParam, packWorkItem); + args.add(newArg); + if (packWorkItem.concreteArg) + argsToPack.add(packWorkItem); + } + auto call = builder->emitCallInst(targetFuncType->getResultType(), funcInst, args); + + // Pack all `out` arguments. + for (auto item : argsToPack) + { + auto anyValType = cast(item.dstArg->getDataType())->getValueType(); + auto concreteVal = builder->emitLoad(item.concreteArg); + auto packedVal = (item.kind == ArgumentPackWorkItem::Kind::Pack) + ? builder->emitPackAnyValue(anyValType, concreteVal) + : upcastSet(builder, concreteVal, anyValType); + builder->emitStore(item.dstArg, packedVal); + } + + // Pack return value if necessary. + if (!isAnyValueType(call->getDataType()) && + isAnyValueType(funcTypeInInterface->getResultType())) + { + auto pack = builder->emitPackAnyValue(funcTypeInInterface->getResultType(), call); + builder->emitReturn(pack); + } + else if (call->getDataType() != funcTypeInInterface->getResultType()) + { + auto reinterpret = upcastSet(builder, call, funcTypeInInterface->getResultType()); + builder->emitReturn(reinterpret); + } + else + { + if (call->getDataType()->getOp() == kIROp_VoidType) + builder->emitReturn(); + else + builder->emitReturn(call); + } + return wrapperFunc; +} + UInt getUniqueID(IRBuilder* builder, IRInst* inst) { // Fallback. diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp deleted file mode 100644 index 8d6979913bc..00000000000 --- a/source/slang/slang-ir-lower-witness-lookup.cpp +++ /dev/null @@ -1,446 +0,0 @@ -// slang-ir-lower-generic-existential.cpp - -#include "slang-ir-lower-witness-lookup.h" - -#include "slang-ir-clone.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ - -struct WitnessLookupLoweringContext -{ - IRModule* module; - DiagnosticSink* sink; - - Dictionary witnessDispatchFunctions; - - void init() - { - // Reconstruct the witness dispatch functions map. - for (auto inst : module->getGlobalInsts()) - { - if (auto key = as(inst)) - { - for (auto decor : key->getDecorations()) - { - if (auto witnessDispatchFunc = as(decor)) - { - witnessDispatchFunctions.add(key, witnessDispatchFunc->getFunc()); - } - } - } - } - } - - bool hasAssocType(IRInst* type) - { - if (!type) - return false; - - InstHashSet processedSet(type->getModule()); - InstWorkList workList(type->getModule()); - workList.add(type); - processedSet.add(type); - for (Index i = 0; i < workList.getCount(); i++) - { - auto inst = workList[i]; - if (inst->getOp() == kIROp_AssociatedType) - return true; - - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - if (!inst->getOperand(j)) - continue; - if (processedSet.add(inst->getOperand(j))) - workList.add(inst->getOperand(j)); - } - } - return false; - } - - IRType* translateType(IRBuilder builder, IRInst* type) - { - if (!type) - return nullptr; - if (auto genType = as(type)) - { - IRCloneEnv cloneEnv; - builder.setInsertBefore(genType); - auto newGeneric = as(cloneInst(&cloneEnv, &builder, genType)); - newGeneric->setFullType(builder.getGenericKind()); - auto retVal = findGenericReturnVal(newGeneric); - builder.setInsertBefore(retVal); - auto translated = translateType(builder, retVal); - retVal->replaceUsesWith(translated); - return (IRType*)newGeneric; - } - else if (auto thisType = as(type)) - { - return (IRType*)thisType->getConstraintType(); - } - else if (auto assocType = as(type)) - { - return assocType; - } - - if (as(type)) - return (IRType*)type; - - switch (type->getOp()) - { - case kIROp_Param: - case kIROp_VectorType: - case kIROp_MatrixType: - case kIROp_StructType: - case kIROp_ClassType: - case kIROp_InterfaceType: - case kIROp_LookupWitnessMethod: - return (IRType*)type; - default: - { - List translatedOperands; - for (UInt i = 0; i < type->getOperandCount(); i++) - { - translatedOperands.add(translateType(builder, type->getOperand(i))); - } - auto translated = builder.emitIntrinsicInst( - type->getFullType(), - type->getOp(), - (UInt)translatedOperands.getCount(), - translatedOperands.getBuffer()); - return (IRType*)translated; - } - } - } - - IRInst* findOrCreateDispatchFunc(IRLookupWitnessMethod* lookupInst) - { - IRInst* func = nullptr; - auto requirementKey = cast(lookupInst->getRequirementKey()); - if (witnessDispatchFunctions.tryGetValue(requirementKey, func)) - { - return func; - } - - auto witnessTableOperand = lookupInst->getWitnessTable(); - auto witnessTableType = as(witnessTableOperand->getDataType()); - SLANG_RELEASE_ASSERT(witnessTableType); - auto interfaceType = - as(unwrapAttributedType(witnessTableType->getConformanceType())); - SLANG_RELEASE_ASSERT(interfaceType); - if (interfaceType->findDecoration()) - return nullptr; - auto requirementType = findInterfaceRequirement(interfaceType, requirementKey); - SLANG_RELEASE_ASSERT(requirementType); - - // We only lower non-static function requirement lookups for now. - // Our front end will stick a StaticRequirementDecoration on the IRStructKey for static - // member requirements. - if (lookupInst->getRequirementKey()->findDecoration()) - return nullptr; - auto interfaceMethodFuncType = - as(getResolvedInstForDecorations(requirementType)); - if (interfaceMethodFuncType) - { - // Detect cases that we currently does not support and exit. - - // If this is a non static function requirement, we should - // make sure the first parameter is the interface type. If not, something has gone - // wrong. - if (interfaceMethodFuncType->getParamCount() == 0) - return nullptr; - if (!as(unwrapAttributedType(interfaceMethodFuncType->getParamType(0)))) - return nullptr; - - // The function has any associated type parameter, we currently can't lower it early in - // this pass. We will lower it in the catch all generic lowering pass. - for (UInt i = 1; i < interfaceMethodFuncType->getParamCount(); i++) - { - if (hasAssocType(interfaceMethodFuncType->getParamType(i))) - return nullptr; - } - - // If return type is a composite type containing an assoc type, we won't lower it now. - // Supporting general use of assoc type is possible, but would require more complex - // logic in this pass to marshal things to and from existential types. - if (interfaceMethodFuncType->getResultType()->getOp() != kIROp_AssociatedType && - hasAssocType(interfaceMethodFuncType->getResultType())) - return nullptr; - } - else - { - return nullptr; - } - - - IRBuilder builder(module); - builder.setInsertBefore(getParentFunc(lookupInst)); - - // Create a dispatch func. - IRFunc* dispatchFunc = nullptr; - IRFuncType* dispatchFuncType = nullptr; - IRGeneric* parentGeneric = nullptr; - - // If requirementType is a generic, we need to create a new generic that has the same - // parameters. - if (auto genericRequirement = as(requirementType)) - { - IRCloneEnv cloneEnv; - parentGeneric = as(cloneInst(&cloneEnv, &builder, genericRequirement)); - - auto returnInst = as(parentGeneric->getFirstBlock()->getLastInst()); - SLANG_RELEASE_ASSERT(returnInst); - builder.setInsertBefore(returnInst); - auto oldDispatchFuncType = as(returnInst->getVal()); - if (!oldDispatchFuncType) - return nullptr; - - dispatchFuncType = as(translateType(builder, oldDispatchFuncType)); - - SLANG_RELEASE_ASSERT(dispatchFuncType); - - dispatchFunc = builder.createFunc(); - dispatchFunc->setFullType(dispatchFuncType); - builder.emitReturn(dispatchFunc); - returnInst->removeAndDeallocate(); - - parentGeneric->setFullType(translateType(builder, requirementType)); - } - else - { - dispatchFuncType = as(translateType(builder, requirementType)); - dispatchFunc = builder.createFunc(); - dispatchFunc->setFullType(dispatchFuncType); - } - - // We need to inline this function if the requirement is differentiable, - // so that the autodiff pass doesn't need to handle the dispatch function. - if (requirementKey->findDecoration() || - requirementKey->findDecoration()) - { - builder.addForceInlineDecoration(dispatchFunc); - } - - // Collect generic params. - List genericParams; - if (parentGeneric) - { - for (auto param : parentGeneric->getParams()) - genericParams.add(param); - } - - // Emit the body of the dispatch func. - builder.setInsertInto(dispatchFunc); - auto firstBlock = builder.emitBlock(); - auto firstBlockBuilder = builder; - // Emit parameters. - List params; - - for (UInt i = 0; i < dispatchFuncType->getParamCount(); i++) - { - params.add(builder.emitParam(dispatchFuncType->getParamType(i))); - } - auto witness = builder.emitExtractExistentialWitnessTable(params[0]); - - auto witnessTables = getWitnessTablesFromInterfaceType(module, interfaceType); - if (witnessTables.getCount() == 0) - { - // If there is no witness table, we should emit an error. - sink->diagnose( - lookupInst, - Diagnostics::noTypeConformancesFoundForInterface, - interfaceType); - return nullptr; - } - else - { - List cases; - for (auto witnessTable : witnessTables) - { - IRBlock* block = builder.emitBlock(); - auto caseValue = firstBlockBuilder.emitGetSequentialIDInst(witnessTable); - cases.add(caseValue); - cases.add(block); - auto entry = findWitnessTableEntry(witnessTable, requirementKey); - SLANG_RELEASE_ASSERT(entry); - // If the entry is a generic, we need to specialize it. - if (const auto genericEntry = as(entry)) - { - auto specializedFuncType = builder.emitSpecializeInst( - builder.getTypeKind(), - entry->getFullType(), - (UInt)genericParams.getCount(), - genericParams.getBuffer()); - entry = builder.emitSpecializeInst( - (IRType*)specializedFuncType, - entry, - (UInt)genericParams.getCount(), - genericParams.getBuffer()); - } - auto args = params; - // Reinterpret the first arg into the concrete type. - args[0] = builder.emitReinterpret( - witnessTable->getConcreteType(), - builder.emitExtractExistentialValue( - builder.emitExtractExistentialType(args[0]), - args[0])); - - auto calleeFuncType = - as(getResolvedInstForDecorations(entry)->getFullType()); - auto callReturnType = calleeFuncType->getResultType(); - if (callReturnType->getParent() != module->getModuleInst()) - { - // the return type is dependent on generic parameter, use the type from - // dispatchFuncType instead. - callReturnType = dispatchFuncType->getResultType(); - } - - IRInst* ret = builder.emitCallInst( - callReturnType, - entry, - (UInt)args.getCount(), - args.getBuffer()); - // If result type is an associated type, we need to pack it into an anyValue. - if (as(dispatchFuncType->getResultType())) - { - ret = builder.emitPackAnyValue(dispatchFuncType->getResultType(), ret); - } - builder.emitReturn(ret); - } - builder.setInsertInto(firstBlock); - if (witnessTables.getCount() == 1) - { - builder.emitBranch((IRBlock*)cases[1]); - } - else - { - auto witnessId = firstBlockBuilder.emitGetSequentialIDInst(witness); - auto breakLabel = builder.emitBlock(); - builder.emitUnreachable(); - firstBlockBuilder.emitSwitch( - witnessId, - breakLabel, - (IRBlock*)cases.getLast(), - (UInt)(cases.getCount() - 2), - cases.getBuffer()); - } - } - - // Stick a decoration on the requirement key so we can find the dispatch func later. - IRInst* resultValue = parentGeneric ? (IRInst*)parentGeneric : dispatchFunc; - builder.addDispatchFuncDecoration(requirementKey, resultValue); - - // Register the dispatch func to witnessDispatchFunctions dictionary. - witnessDispatchFunctions[requirementKey] = resultValue; - - return resultValue; - } - - void rewriteCallSite(IRCall* call, IRInst* dispatchFunc, IRInst* initialExistentialObject) - { - SLANG_RELEASE_ASSERT(call->getArgCount() != 0); - call->setOperand(0, dispatchFunc); - IRBuilder builder(call); - builder.setInsertBefore(call); - auto witnessTable = builder.emitExtractExistentialWitnessTable(initialExistentialObject); - auto newExistentialObject = builder.emitMakeExistential( - initialExistentialObject->getDataType(), - call->getOperand(1), - witnessTable); - call->setOperand(1, newExistentialObject); - } - - bool processWitnessLookup(IRLookupWitnessMethod* lookupInst) - { - auto witnessTableOperand = lookupInst->getWitnessTable(); - auto extractInst = as(witnessTableOperand); - if (!extractInst) - return false; - auto dispatchFunc = findOrCreateDispatchFunc(lookupInst); - if (!dispatchFunc) - return false; - bool changed = false; - auto existentialObject = extractInst->getOperand(0); - - IRBuilder builder(lookupInst); - builder.setInsertBefore(lookupInst); - traverseUses( - lookupInst, - [&](IRUse* use) - { - if (auto specialize = as(use->getUser())) - { - builder.setInsertBefore(use->getUser()); - List args; - for (UInt i = 0; i < specialize->getArgCount(); i++) - args.add(specialize->getArg(i)); - auto specializedType = builder.emitSpecializeInst( - builder.getTypeKind(), - dispatchFunc->getFullType(), - (UInt)args.getCount(), - args.getBuffer()); - auto newSpecialize = builder.emitSpecializeInst( - (IRType*)specializedType, - dispatchFunc, - (UInt)args.getCount(), - args.getBuffer()); - traverseUses( - specialize, - [&](IRUse* specializeUse) - { - if (auto call = as(specializeUse->getUser())) - { - changed = true; - rewriteCallSite(call, newSpecialize, existentialObject); - } - }); - } - else if (auto call = as(use->getUser())) - { - changed = true; - rewriteCallSite(call, dispatchFunc, existentialObject); - } - }); - return changed; - } - - bool processFunc(IRFunc* func) - { - bool changed = false; - for (auto bb : func->getBlocks()) - { - for (auto inst : bb->getModifiableChildren()) - { - if (auto witnessLookupInst = as(inst)) - { - changed |= processWitnessLookup(witnessLookupInst); - } - } - } - return changed; - } -}; - -bool lowerWitnessLookup(IRModule* module, DiagnosticSink* sink) -{ - bool changed = false; - WitnessLookupLoweringContext context; - context.module = module; - context.sink = sink; - context.init(); - - for (auto inst : module->getGlobalInsts()) - { - // Process all fully specialized functions and look for - // witness lookup instructions. If we see a lookup for a non-static function, - // create a dispatch function and replace the lookup with a call to the dispatch function. - if (auto func = as(inst)) - changed |= context.processFunc(func); - } - return changed; -} -} // namespace Slang diff --git a/source/slang/slang-ir-lower-witness-lookup.h b/source/slang/slang-ir-lower-witness-lookup.h deleted file mode 100644 index 434e97f3e86..00000000000 --- a/source/slang/slang-ir-lower-witness-lookup.h +++ /dev/null @@ -1,15 +0,0 @@ -// slang-ir-lower-witness-lookup.h -#pragma once - -namespace Slang -{ -struct IRModule; -class DiagnosticSink; - -/// Lower calls to a witness lookup into a call to a dispatch function. -/// For example, if we see call(witnessLookup(wt, key)), we will create a -/// dispatch function that calls into different implementations based on witness table -/// ID. The dispatch function will be called instead of witnessLookup. -bool lowerWitnessLookup(IRModule* module, DiagnosticSink* sink); - -} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp deleted file mode 100644 index 2a8374a42f0..00000000000 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ /dev/null @@ -1,328 +0,0 @@ -#include "slang-ir-specialize-dispatch.h" - -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ -IRFunc* specializeDispatchFunction( - SharedGenericsLoweringContext* sharedContext, - IRFunc* dispatchFunc) -{ - auto witnessTableType = cast(dispatchFunc->getDataType())->getParamType(0); - auto conformanceType = cast(witnessTableType)->getConformanceType(); - // Collect all witness tables of `witnessTableType` in current module. - List witnessTables = - sharedContext->getWitnessTablesFromInterfaceType(conformanceType); - - SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock()); - auto block = dispatchFunc->getFirstBlock(); - - // The dispatch function before modification must be in the form of - // call(lookup_interface_method(witnessTableParam, interfaceReqKey), args) - // We now find the relavent instructions. - IRCall* callInst = nullptr; - IRLookupWitnessMethod* lookupInst = nullptr; - // Only used in debug builds as a sanity check - [[maybe_unused]] IRReturn* returnInst = nullptr; - for (auto inst : block->getOrdinaryInsts()) - { - switch (inst->getOp()) - { - case kIROp_Call: - callInst = cast(inst); - break; - case kIROp_LookupWitnessMethod: - lookupInst = cast(inst); - break; - case kIROp_Return: - returnInst = cast(inst); - break; - default: - break; - } - } - SLANG_ASSERT(callInst && lookupInst && returnInst); - - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(dispatchFunc); - - // Create a new dispatch func to replace the existing one. - auto newDispatchFunc = builder->createFunc(); - - List paramTypes; - for (auto paramInst : dispatchFunc->getParams()) - { - paramTypes.add(paramInst->getFullType()); - } - - // Modify the first paramter from IRWitnessTable to IRWitnessTableID representing the sequential - // ID. - paramTypes[0] = builder->getWitnessTableIDType((IRType*)conformanceType); - - auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType()); - newDispatchFunc->setFullType(newDipsatchFuncType); - dispatchFunc->transferDecorationsTo(newDispatchFunc); - - builder->setInsertInto(newDispatchFunc); - auto newBlock = builder->emitBlock(); - - IRBlock* defaultBlock = nullptr; - - auto requirementKey = lookupInst->getRequirementKey(); - List params; - for (Index i = 0; i < paramTypes.getCount(); i++) - { - auto param = builder->emitParam(paramTypes[i]); - if (i > 0) - params.add(param); - } - auto witnessTableParam = newBlock->getFirstParam(); - - // `witnessTableParam` is expected to have `IRWitnessTableID` type, which - // will later lower into a `uint2`. We only use the first element of the uint2 - // to store the sequential ID and reserve the second 32-bit value for future - // pointer-compatibility. We insert a member extract inst right now - // to obtain the first element and use it in our switch statement. - UInt elemIdx = 0; - auto witnessTableSequentialID = - builder->emitSwizzle(builder->getUIntType(), witnessTableParam, 1, &elemIdx); - - // Generate case blocks for each possible witness table. - List caseBlocks; - for (Index i = 0; i < witnessTables.getCount(); i++) - { - auto witnessTable = witnessTables[i]; - auto seqIdDecoration = witnessTable->findDecoration(); - if (!seqIdDecoration) - { - sharedContext->sink->diagnose( - witnessTable->getConcreteType(), - Diagnostics::typeCannotBeUsedInDynamicDispatch, - witnessTable->getConcreteType()); - } - - if (i != witnessTables.getCount() - 1) - { - // Create a case block if we are not the last case. - caseBlocks.add(seqIdDecoration->getSequentialIDOperand()); - builder->setInsertInto(newDispatchFunc); - auto caseBlock = builder->emitBlock(); - caseBlocks.add(caseBlock); - } - else - { - // Generate code for the last possible value in the `default` block. - builder->setInsertInto(newDispatchFunc); - defaultBlock = builder->emitBlock(); - builder->setInsertInto(defaultBlock); - } - - auto callee = findWitnessTableEntry(witnessTable, requirementKey); - SLANG_ASSERT(callee); - auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params); - if (callInst->getDataType()->getOp() == kIROp_VoidType) - builder->emitReturn(); - else - builder->emitReturn(specializedCallInst); - } - - // Emit a switch statement to call the correct concrete function based on - // the witness table sequential ID passed in. - builder->setInsertInto(newDispatchFunc); - - - if (witnessTables.getCount() == 1) - { - // If there is only 1 case, no switch statement is necessary. - builder->setInsertInto(newBlock); - builder->emitBranch(defaultBlock); - } - else if (witnessTables.getCount() > 1) - { - auto breakBlock = builder->emitBlock(); - builder->setInsertInto(breakBlock); - builder->emitUnreachable(); - - builder->setInsertInto(newBlock); - builder->emitSwitch( - witnessTableSequentialID, - breakBlock, - defaultBlock, - caseBlocks.getCount(), - caseBlocks.getBuffer()); - } - else - { - // We have no witness tables that implements this interface. - // Just return a default value. - builder->setInsertInto(newBlock); - if (callInst->getDataType()->getOp() == kIROp_VoidType) - { - builder->emitReturn(); - } - else - { - auto defaultValue = builder->emitDefaultConstruct(callInst->getDataType()); - builder->emitReturn(defaultValue); - } - } - // Remove old implementation. - dispatchFunc->replaceUsesWith(newDispatchFunc); - dispatchFunc->removeAndDeallocate(); - - return newDispatchFunc; -} - -// Ensures every witness table object has been assigned a sequential ID. -// All witness tables will have a SequentialID decoration after this function is run. -// The sequantial ID in the decoration will be the same as the one specified in the Linkage. -// Otherwise, a new ID will be generated and assigned to the witness table object, and -// the sequantial ID map in the Linkage will be updated to include the new ID, so they -// can be looked up by the user via future Slang API calls. -/* -void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContext) -{ - StringBuilder generatedMangledName; - - auto linkage = sharedContext->targetProgram->getTargetReq()->getLinkage(); - for (auto inst : sharedContext->module->getGlobalInsts()) - { - if (inst->getOp() == kIROp_WitnessTable) - { - UnownedStringSlice witnessTableMangledName; - if (auto instLinkage = inst->findDecoration()) - { - witnessTableMangledName = instLinkage->getMangledName(); - } - else - { - auto witnessTableType = as(inst->getDataType()); - - if (witnessTableType && witnessTableType->getConformanceType() == nullptr) - { - // Ignore witness tables that represent 'none' for optional witness table types. - continue; - } - - if (witnessTableType && witnessTableType->getConformanceType() - ->findDecoration()) - { - // The interface is for specialization only, it would be an error if dynamic - // dispatch is used through the interface. Skip assigning ID for the witness - // table. - continue; - } - - // generate a unique linkage for it. - static int32_t uniqueId = 0; - uniqueId++; - if (auto nameHint = inst->findDecoration()) - { - generatedMangledName << nameHint->getName(); - } - generatedMangledName << "_generated_witness_uuid_" << uniqueId; - witnessTableMangledName = generatedMangledName.getUnownedSlice(); - } - - // If the inst already has a SequentialIDDecoration, stop now. - if (inst->findDecoration()) - continue; - - // Get a sequential ID for the witness table using the map from the Linkage. - uint32_t seqID = 0; - if (!linkage->mapMangledNameToRTTIObjectIndex.tryGetValue( - witnessTableMangledName, - seqID)) - { - auto interfaceType = - cast(inst->getDataType())->getConformanceType(); - if (as(interfaceType)) - { - auto interfaceLinkage = interfaceType->findDecoration(); - SLANG_ASSERT( - interfaceLinkage && "An interface type does not have a linkage," - "but a witness table associated with it has one."); - auto interfaceName = interfaceLinkage->getMangledName(); - auto idAllocator = - linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( - interfaceName); - if (!idAllocator) - { - linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0; - idAllocator = - linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( - interfaceName); - } - seqID = *idAllocator; - ++(*idAllocator); - } - else - { - // NoneWitness, has special ID of -1. - seqID = uint32_t(-1); - } - linkage->mapMangledNameToRTTIObjectIndex[witnessTableMangledName] = seqID; - } - - // Add a decoration to the inst. - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - builder.addSequentialIDDecoration(inst, seqID); - } - } -} -*/ - -// Fixes up call sites of a dispatch function, so that the witness table argument is replaced with -// its sequential ID. -void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* newDispatchFunc) -{ - List users; - for (auto use = newDispatchFunc->firstUse; use; use = use->nextUse) - { - users.add(use->getUser()); - } - for (auto user : users) - { - if (auto call = as(user)) - { - if (call->getCallee() != newDispatchFunc) - continue; - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(call); - List args; - for (UInt i = 0; i < call->getArgCount(); i++) - { - args.add(call->getArg(i)); - } - if (as(args[0]->getDataType())) - continue; - auto newCall = builder.emitCallInst(call->getFullType(), newDispatchFunc, args); - call->replaceUsesWith(newCall); - call->removeAndDeallocate(); - } - } -} - -void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext) -{ - // First we ensure that all witness table objects has a sequential ID assigned. - // ensureWitnessTableSequentialIDs(sharedContext); - - // Generate specialized dispatch functions and fixup call sites. - for (const auto& [_, dispatchFunc] : sharedContext->mapInterfaceRequirementKeyToDispatchMethods) - { - // Generate a specialized `switch` statement based dispatch func, - // from the witness tables present in the module. - auto newDispatchFunc = specializeDispatchFunction(sharedContext, dispatchFunc); - - // Fix up the call sites of newDispatchFunc to pass in sequential IDs instead of - // witness table objects. - fixupDispatchFuncCall(sharedContext, newDispatchFunc); - } -} -} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dispatch.h b/source/slang/slang-ir-specialize-dispatch.h deleted file mode 100644 index f176229cfc4..00000000000 --- a/source/slang/slang-ir-specialize-dispatch.h +++ /dev/null @@ -1,13 +0,0 @@ -// slang-ir-specialize-dispatch.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Modifies the body of interface dispatch functions to use branching instead -/// of function pointer calls to implement the dynamic dispatch logic. -/// This is only used on GPU targets where function pointers are not supported -/// or are not efficient. -void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext); -} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp deleted file mode 100644 index 8d9c58b1b6c..00000000000 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ /dev/null @@ -1,272 +0,0 @@ -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-specialize-dispatch.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ - -struct AssociatedTypeLookupSpecializationContext -{ - SharedGenericsLoweringContext* sharedContext; - - IRFunc* createWitnessTableLookupFunc(IRInterfaceType* interfaceType, IRInst* key) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(interfaceType); - - auto inputWitnessTableIDType = builder.getWitnessTableIDType(interfaceType); - auto requirementEntry = sharedContext->findInterfaceRequirementVal(interfaceType, key); - - auto resultWitnessTableType = cast(requirementEntry); - auto resultWitnessTableIDType = - builder.getWitnessTableIDType((IRType*)resultWitnessTableType->getConformanceType()); - - auto funcType = - builder.getFuncType(1, (IRType**)&inputWitnessTableIDType, resultWitnessTableIDType); - auto func = builder.createFunc(); - func->setFullType(funcType); - - if (auto linkage = key->findDecoration()) - builder.addNameHintDecoration(func, linkage->getMangledName()); - - builder.setInsertInto(func); - - auto block = builder.emitBlock(); - auto witnessTableParam = builder.emitParam(inputWitnessTableIDType); - - // `witnessTableParam` is expected to have `IRWitnessTableID` type, which - // will later lower into a `uint2`. We only use the first element of the uint2 - // to store the sequential ID and reserve the second 32-bit value for future - // pointer-compatibility. We insert a member extract inst right now - // to obtain the first element and use it in our switch statement. - UInt elemIdx = 0; - auto witnessTableSequentialID = - builder.emitSwizzle(builder.getUIntType(), witnessTableParam, 1, &elemIdx); - - // Collect all witness tables of `witnessTableType` in current module. - List witnessTables = - sharedContext->getWitnessTablesFromInterfaceType(interfaceType); - - // Generate case blocks for each possible witness table. - IRBlock* defaultBlock = nullptr; - List caseBlocks; - for (Index i = 0; i < witnessTables.getCount(); i++) - { - auto witnessTable = witnessTables[i]; - auto seqIdDecoration = witnessTable->findDecoration(); - SLANG_ASSERT(seqIdDecoration); - - if (i != witnessTables.getCount() - 1) - { - // Create a case block if we are not the last case. - caseBlocks.add(seqIdDecoration->getSequentialIDOperand()); - builder.setInsertInto(func); - auto caseBlock = builder.emitBlock(); - caseBlocks.add(caseBlock); - } - else - { - // Generate code for the last possible value in the `default` block. - builder.setInsertInto(func); - defaultBlock = builder.emitBlock(); - builder.setInsertInto(defaultBlock); - } - - auto resultWitnessTable = findWitnessTableEntry(witnessTable, key); - auto resultWitnessTableIDDecoration = - resultWitnessTable->findDecoration(); - SLANG_ASSERT(resultWitnessTableIDDecoration); - // Pack the resulting witness table ID into a `uint2`. - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - IRInst* uint2Args[] = { - resultWitnessTableIDDecoration->getSequentialIDOperand(), - builder.getIntValue(builder.getUIntType(), 0)}; - auto resultID = builder.emitMakeVector(uint2Type, 2, uint2Args); - builder.emitReturn(resultID); - } - - builder.setInsertInto(func); - - if (witnessTables.getCount() == 1) - { - // If there is only 1 case, no switch statement is necessary. - builder.setInsertInto(block); - builder.emitBranch(defaultBlock); - } - else - { - // If there are more than 1 cases, - // emit a switch statement to return the correct witness table ID based on - // the witness table ID passed in. - auto breakBlock = builder.emitBlock(); - builder.setInsertInto(breakBlock); - builder.emitUnreachable(); - - builder.setInsertInto(block); - builder.emitSwitch( - witnessTableSequentialID, - breakBlock, - defaultBlock, - caseBlocks.getCount(), - caseBlocks.getBuffer()); - } - - return func; - } - - void processLookupInterfaceMethodInst(IRLookupWitnessMethod* inst) - { - if (isComInterfaceType(inst->getWitnessTable()->getDataType())) - { - return; - } - - // Ignore lookups for RTTI objects for now, since they are not used anywhere. - if (!as(inst->getDataType())) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - auto zero = builder.getIntValue(builder.getUIntType(), 0); - IRInst* args[] = {zero, zero}; - auto zeroUint2 = builder.emitMakeVector(uint2Type, 2, args); - inst->replaceUsesWith(zeroUint2); - return; - } - - // Replace all witness table lookups with calls to specialized functions that directly - // returns the sequential ID of the resulting witness table, effectively getting rid - // of actual witness table objects in the target code (they all become IDs). - auto witnessTableType = inst->getWitnessTable()->getDataType(); - IRInterfaceType* interfaceType = cast( - cast(witnessTableType)->getConformanceType()); - if (!interfaceType) - return; - - List tables = - sharedContext->getWitnessTablesFromInterfaceType(interfaceType); - if (tables.getCount() == 0) - return; - - auto key = inst->getRequirementKey(); - IRFunc* func = nullptr; - if (!sharedContext->mapInterfaceRequirementKeyToDispatchMethods.tryGetValue(key, func)) - { - func = createWitnessTableLookupFunc(interfaceType, key); - sharedContext->mapInterfaceRequirementKeyToDispatchMethods[key] = func; - } - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - auto witnessTableArg = inst->getWitnessTable(); - auto callInst = builder.emitCallInst(func->getResultType(), func, witnessTableArg); - inst->replaceUsesWith(callInst); - inst->removeAndDeallocate(); - } - - void processGetSequentialIDInst(IRGetSequentialID* inst) - { - // If the operand is a witness table, it is already replaced with a uint2 - // at this point, where the first element in the uint2 is the id of the - // witness table. - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(inst); - UInt index = 0; - auto id = builder.emitSwizzle(builder.getUIntType(), inst->getRTTIOperand(), 1, &index); - inst->replaceUsesWith(id); - inst->removeAndDeallocate(); - } - - void processModule() - { - // Replace all `lookup_interface_method():IRWitnessTable` with call to specialized - // functions. - workOnModule( - sharedContext, - [this](IRInst* inst) - { - if (inst->getOp() == kIROp_LookupWitnessMethod) - { - processLookupInterfaceMethodInst(cast(inst)); - } - }); - - // Replace all direct uses of IRWitnessTables with its sequential ID. - workOnModule( - sharedContext, - [this](IRInst* inst) - { - if (inst->getOp() == kIROp_WitnessTable) - { - auto seqId = inst->findDecoration(); - if (!seqId) - return; - // Insert code to pack sequential ID into an uint2 at all use sites. - traverseUses( - inst, - [&](IRUse* use) - { - if (as(use->getUser())) - { - return; - } - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(use->getUser()); - auto uint2Type = builder.getVectorType( - builder.getUIntType(), - builder.getIntValue(builder.getIntType(), 2)); - IRInst* uint2Args[] = { - seqId->getSequentialIDOperand(), - builder.getIntValue(builder.getUIntType(), 0)}; - auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args); - builder.replaceOperand(use, uint2seqID); - }); - } - }); - - // Replace all `IRWitnessTableType`s with `IRWitnessTableIDType`. - for (auto globalInst : sharedContext->module->getGlobalInsts()) - { - if (globalInst->getOp() == kIROp_WitnessTableType) - { - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(globalInst); - auto witnessTableIDType = builder.getWitnessTableIDType( - (IRType*)cast(globalInst)->getConformanceType()); - traverseUses( - globalInst, - [&](IRUse* use) - { - if (use->getUser()->getOp() == kIROp_WitnessTable) - return; - builder.replaceOperand(use, witnessTableIDType); - }); - } - } - - // `GetSequentialID(WitnessTableIDOperand)` becomes just `WitnessTableIDOperand`. - workOnModule( - sharedContext, - [this](IRInst* inst) - { - if (inst->getOp() == kIROp_GetSequentialID) - { - processGetSequentialIDInst(cast(inst)); - } - }); - } -}; - -void specializeDynamicAssociatedTypeLookup(SharedGenericsLoweringContext* sharedContext) -{ - AssociatedTypeLookupSpecializationContext context; - context.sharedContext = sharedContext; - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h deleted file mode 100644 index 83039eca54c..00000000000 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.h +++ /dev/null @@ -1,15 +0,0 @@ -// slang-ir-specialize-dynamic-associatedtype-lookup.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; - -/// Modifies the lookup of associatedtype entries from witness tables into -/// calls to a specialized "lookup" function that takes a witness table id -/// and returns a witness table id. -/// This is used on GPU targets where all witness tables are replaced as -/// integral IDs instead of a real pointer table. -void specializeDynamicAssociatedTypeLookup(SharedGenericsLoweringContext* sharedContext); - -} // namespace Slang diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 89db89f9109..ecb05bd1aab 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -6,7 +6,6 @@ #include "slang-ir-dce.h" #include "slang-ir-insts.h" #include "slang-ir-lower-typeflow-insts.h" -#include "slang-ir-lower-witness-lookup.h" #include "slang-ir-peephole.h" #include "slang-ir-sccp.h" #include "slang-ir-ssa-simplification.h" diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 64b790a3c2f..9a1d311ae7f 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -7,7 +7,6 @@ #include "slang-ir-specialize.h" #include "slang-ir-typeflow-collection.h" #include "slang-ir-util.h" -#include "slang-ir-witness-table-wrapper.h" #include "slang-ir.h" @@ -1132,9 +1131,6 @@ struct TypeFlowSpecializationContext case kIROp_MakeExistential: info = analyzeMakeExistential(context, as(inst)); break; - // case kIROp_WrapExistential: - // info = analyzeWrapExistential(context, as(inst)); - // break; case kIROp_LookupWitnessMethod: info = analyzeLookupWitnessMethod(context, as(inst)); break; @@ -1466,28 +1462,6 @@ struct TypeFlowSpecializationContext SLANG_UNEXPECTED("Unexpected witness table info type in analyzeMakeExistential"); } - /* - IRInst* analyzeWrapExistential(IRInst* context, IRWrapExistential* wrapExistential) - { - if (auto valInfo = tryGetInfo(context, wrapExistential->getWrappedValue())) - { - // We need a single possible value for the wrapped value. - auto taggedUnionType = cast(valInfo); - SLANG_ASSERT( - taggedUnionType->getTypeSet()->isSingleton() && - taggedUnionType->getWitnessTableSet()->isSingleton()); - // Since the inst's result is expected to be a concrete type, - // we'll return a 'none' here. The info won't be recorded anyway. - // - return none(); - } - else - { - return none(); - } - } - */ - IRInst* analyzeMakeStruct(IRInst* context, IRMakeStruct* makeStruct, WorkQueue& workQueue) { // We'll process this in the same way as a field-address, but for @@ -1525,6 +1499,7 @@ struct TypeFlowSpecializationContext IRInst* analyzeLoadFromUninitializedMemory(IRInst* context, IRInst* inst) { + SLANG_UNUSED(context); IRBuilder builder(module); if (as(inst->getDataType()) && !isConcreteType(inst->getDataType())) { @@ -1575,7 +1550,6 @@ struct TypeFlowSpecializationContext else if ( auto boundInterfaceType = as(loadInst->getDataType())) { - IRBuilder builder(module); return makeTaggedUnionType(cast( builder.getSingletonSet(boundInterfaceType->getWitnessTable()))); } @@ -1617,7 +1591,6 @@ struct TypeFlowSpecializationContext } else if (auto boundInterfaceType = as(inst->getDataType())) { - IRBuilder builder(module); return makeTaggedUnionType(cast( builder.getSingletonSet(boundInterfaceType->getWitnessTable()))); } @@ -2676,7 +2649,6 @@ struct TypeFlowSpecializationContext // for (auto param : firstBlock->getParams()) { - auto paramType = param->getDataType(); auto paramInfo = tryGetInfo(context, param); if (paramInfo) continue; // Already has some information @@ -3887,6 +3859,7 @@ struct TypeFlowSpecializationContext bool specializeWrapExistential(IRInst* context, IRWrapExistential* inst) { + SLANG_UNUSED(context); inst->replaceUsesWith(inst->getWrappedValue()); inst->removeAndDeallocate(); return true; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 1803fa108bb..9889464458c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -2847,4 +2847,23 @@ bool canRelaxInstOrderRule(IRInst* inst, IRInst* useOfInst) return isSameBlock && isGenericParameter(useOfInst) && (useOfInst->getDataType() == inst); } +IRIntegerValue getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc) +{ + SLANG_UNUSED(usageLoc); + + if (auto decor = type->findDecoration()) + { + return decor->getSize(); + } + + // We could conceivably make it an error to have an interface + // without an `[anyValueSize(...)]` attribute, but then we risk + // producing error messages even when doing 100% static specialization. + // + // It is simpler to use a reasonable default size and treat any + // type without an explicit attribute as using that size. + // + return kDefaultAnyValueSize; +} + } // namespace Slang diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 4423a7e2bd8..bf19b0687cc 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -11,6 +11,11 @@ namespace Slang struct GenericChildrenMigrationContextImpl; struct IRCloneEnv; +constexpr IRIntegerValue kInvalidAnyValueSize = 0xFFFFFFFF; +constexpr IRIntegerValue kDefaultAnyValueSize = 16; +constexpr SlangInt kRTTIHeaderSize = 16; +constexpr SlangInt kRTTIHandleSize = 8; + // A helper class to clone children insts to a different generic parent that has equivalent set of // generic parameters. The clone will take care of substitution of equivalent generic parameters and // intermediate values between the two generic parents. @@ -460,6 +465,8 @@ bool isGenericParameter(IRInst* inst); // of the generic parameter. bool canRelaxInstOrderRule(IRInst* instToCheck, IRInst* otherInst); +IRIntegerValue getInterfaceAnyValueSize(IRInst* type, SourceLoc usageLoc); + } // namespace Slang #endif diff --git a/source/slang/slang-ir-witness-table-wrapper.cpp b/source/slang/slang-ir-witness-table-wrapper.cpp deleted file mode 100644 index 61d5ef5ccb4..00000000000 --- a/source/slang/slang-ir-witness-table-wrapper.cpp +++ /dev/null @@ -1,341 +0,0 @@ -// slang-ir-witness-table-wrapper.cpp -#include "slang-ir-witness-table-wrapper.h" - -#include "slang-ir-clone.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir-insts.h" -#include "slang-ir-typeflow-collection.h" -#include "slang-ir-util.h" -#include "slang-ir.h" - -namespace Slang -{ -struct GenericsLoweringContext; -IRFunc* emitWitnessTableWrapper( - IRModule* module, - IRInst* funcInst, - IRInst* interfaceRequirementVal); - -struct GenerateWitnessTableWrapperContext -{ - SharedGenericsLoweringContext* sharedContext; - - void lowerWitnessTable(IRWitnessTable* witnessTable) - { - auto interfaceType = as(witnessTable->getConformanceType()); - if (!interfaceType) - return; - if (isBuiltin(interfaceType)) - return; - if (isComInterfaceType(interfaceType)) - return; - - // We need to consider whether the concrete type that is conforming - // in this witness table actually fits within the declared any-value - // size for the interface. - // - // If the type doesn't fit then it would be invalid to use for dynamic - // dispatch, and the packing/unpacking operations we emit would fail - // to generate valid code. - // - // Such a type might still be useful for static specialization, so - // we can't consider this case a hard error. - // - auto concreteType = witnessTable->getConcreteType(); - IRIntegerValue typeSize, sizeLimit; - bool isTypeOpaque = false; - if (!sharedContext->doesTypeFitInAnyValue( - concreteType, - interfaceType, - &typeSize, - &sizeLimit, - &isTypeOpaque)) - { - HashSet visited; - if (isTypeOpaque) - { - sharedContext->sink->diagnose( - concreteType, - Diagnostics::typeCannotBePackedIntoAnyValue, - concreteType); - } - else - { - sharedContext->sink->diagnose( - concreteType, - Diagnostics::typeDoesNotFitAnyValueSize, - concreteType); - sharedContext->sink->diagnoseWithoutSourceView( - concreteType, - Diagnostics::typeAndLimit, - concreteType, - typeSize, - sizeLimit); - } - return; - } - - for (auto child : witnessTable->getChildren()) - { - auto entry = as(child); - if (!entry) - continue; - auto interfaceRequirementVal = sharedContext->findInterfaceRequirementVal( - interfaceType, - entry->getRequirementKey()); - if (auto ordinaryFunc = as(entry->getSatisfyingVal())) - { - auto wrapper = emitWitnessTableWrapper( - sharedContext->module, - ordinaryFunc, - interfaceRequirementVal); - entry->satisfyingVal.set(wrapper); - sharedContext->addToWorkList(wrapper); - } - } - } - - void processInst(IRInst* inst) - { - if (auto witnessTable = as(inst)) - { - lowerWitnessTable(witnessTable); - } - } - - void processModule() - { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); - - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } - } -}; - -// Represents a work item for packing `inout` or `out` arguments after a concrete call. -struct ArgumentPackWorkItem -{ - enum Kind - { - Pack, - UpCast, - } kind = Pack; - - // A `AnyValue` typed destination. - IRInst* dstArg = nullptr; - // A concrete value to be packed. - IRInst* concreteArg = nullptr; -}; - -bool isAnyValueType(IRType* type) -{ - if (as(type) || as(type)) - return true; - return false; -} - -// Unpack an `arg` of `IRAnyValue` into concrete type if necessary, to make it feedable into the -// parameter. If `arg` represents a AnyValue typed variable passed in to a concrete `out` -// parameter, this function indicates that it needs to be packed after the call by setting -// `packAfterCall`. -IRInst* maybeUnpackArg( - IRBuilder* builder, - IRType* paramType, - IRInst* arg, - ArgumentPackWorkItem& packAfterCall) -{ - packAfterCall.dstArg = nullptr; - packAfterCall.concreteArg = nullptr; - - // If either paramType or argType is a pointer type - // (because of `inout` or `out` modifiers), we extract - // the underlying value type first. - IRType* paramValType = paramType; - IRType* argValType = arg->getDataType(); - IRInst* argVal = arg; - if (auto ptrType = as(paramType)) - { - paramValType = ptrType->getValueType(); - } - auto argType = arg->getDataType(); - if (auto argPtrType = as(argType)) - { - argValType = argPtrType->getValueType(); - } - - - // Unpack `arg` if the parameter expects concrete type but - // `arg` is an AnyValue. - if (!isAnyValueType(paramValType) && isAnyValueType(argValType)) - { - // if parameter expects an `out` pointer, store the unpacked val into a - // variable and pass in a pointer to that variable. - if (as(paramType)) - { - auto tempVar = builder->emitVar(paramValType); - if (as(paramType)) - builder->emitStore( - tempVar, - builder->emitUnpackAnyValue(paramValType, builder->emitLoad(arg))); - - // tempVar needs to be unpacked into original var after the call. - packAfterCall.kind = ArgumentPackWorkItem::Kind::Pack; - packAfterCall.dstArg = arg; - packAfterCall.concreteArg = tempVar; - return tempVar; - } - else - { - return builder->emitUnpackAnyValue(paramValType, argVal); - } - } - - // Reinterpret 'arg' if it is being passed to a parameter with - // a different type collection. For now, we'll approximate this - // by checking if the types are different, but this should be - // encoded in the types. - // - if (as(paramValType) && as(argValType) && - paramValType != argValType) - { - // if parameter expects an `out` pointer, store the unpacked val into a - // variable and pass in a pointer to that variable. - if (as(paramType)) - { - auto tempVar = builder->emitVar(paramValType); - - // tempVar needs to be unpacked into original var after the call. - packAfterCall.kind = ArgumentPackWorkItem::Kind::UpCast; - packAfterCall.dstArg = arg; - packAfterCall.concreteArg = tempVar; - return tempVar; - } - else - { - SLANG_UNEXPECTED("Unexpected upcast for non-out parameter"); - } - } - return arg; -} - -IRStringLit* _getWitnessTableWrapperFuncName(IRModule* module, IRFunc* func) -{ - IRBuilder builderStorage(module); - auto builder = &builderStorage; - builder->setInsertBefore(func); - if (auto linkageDecoration = func->findDecoration()) - { - return builder->getStringValue( - (String(linkageDecoration->getMangledName()) + "_wtwrapper").getUnownedSlice()); - } - if (auto namehintDecoration = func->findDecoration()) - { - return builder->getStringValue( - (String(namehintDecoration->getName()) + "_wtwrapper").getUnownedSlice()); - } - return nullptr; -} - - -IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* interfaceRequirementVal) -{ - auto funcTypeInInterface = cast(interfaceRequirementVal); - auto targetFuncType = as(funcInst->getDataType()); - - IRBuilder builderStorage(module); - auto builder = &builderStorage; - builder->setInsertBefore(funcInst); - - auto wrapperFunc = builder->createFunc(); - wrapperFunc->setFullType((IRType*)interfaceRequirementVal); - if (auto func = as(funcInst)) - if (auto name = _getWitnessTableWrapperFuncName(module, func)) - builder->addNameHintDecoration(wrapperFunc, name); - - builder->setInsertInto(wrapperFunc); - auto block = builder->emitBlock(); - builder->setInsertInto(block); - - ShortList params; - for (UInt i = 0; i < funcTypeInInterface->getParamCount(); i++) - { - params.add(builder->emitParam(funcTypeInInterface->getParamType(i))); - } - - List args; - List argsToPack; - - SLANG_ASSERT(params.getCount() == (Index)targetFuncType->getParamCount()); - for (UInt i = 0; i < targetFuncType->getParamCount(); i++) - { - auto wrapperParam = params[i]; - // Type of the parameter in the callee. - auto funcParamType = targetFuncType->getParamType(i); - - // If the implementation expects a concrete type - // (either in the form of a pointer for `out`/`inout` parameters, - // or in the form a value for `in` parameters, while - // the interface exposes an AnyValue type, - // we need to unpack the AnyValue argument to the appropriate - // concerete type. - ArgumentPackWorkItem packWorkItem; - auto newArg = maybeUnpackArg(builder, funcParamType, wrapperParam, packWorkItem); - args.add(newArg); - if (packWorkItem.concreteArg) - argsToPack.add(packWorkItem); - } - auto call = builder->emitCallInst(targetFuncType->getResultType(), funcInst, args); - - // Pack all `out` arguments. - for (auto item : argsToPack) - { - auto anyValType = cast(item.dstArg->getDataType())->getValueType(); - auto concreteVal = builder->emitLoad(item.concreteArg); - auto packedVal = (item.kind == ArgumentPackWorkItem::Kind::Pack) - ? builder->emitPackAnyValue(anyValType, concreteVal) - : upcastSet(builder, concreteVal, anyValType); - builder->emitStore(item.dstArg, packedVal); - } - - // Pack return value if necessary. - if (!isAnyValueType(call->getDataType()) && - isAnyValueType(funcTypeInInterface->getResultType())) - { - auto pack = builder->emitPackAnyValue(funcTypeInInterface->getResultType(), call); - builder->emitReturn(pack); - } - else if (call->getDataType() != funcTypeInInterface->getResultType()) - { - auto reinterpret = upcastSet(builder, call, funcTypeInInterface->getResultType()); - builder->emitReturn(reinterpret); - } - else - { - if (call->getDataType()->getOp() == kIROp_VoidType) - builder->emitReturn(); - else - builder->emitReturn(call); - } - return wrapperFunc; -} - -void generateWitnessTableWrapperFunctions(SharedGenericsLoweringContext* sharedContext) -{ - GenerateWitnessTableWrapperContext context; - context.sharedContext = sharedContext; - context.processModule(); -} - -} // namespace Slang diff --git a/source/slang/slang-ir-witness-table-wrapper.h b/source/slang/slang-ir-witness-table-wrapper.h deleted file mode 100644 index acc69aad14d..00000000000 --- a/source/slang/slang-ir-witness-table-wrapper.h +++ /dev/null @@ -1,30 +0,0 @@ -// slang-ir-witness-table-wrapper.h -#pragma once - -namespace Slang -{ -struct SharedGenericsLoweringContext; -struct IRFunc; -struct IRInst; -struct IRModule; - -/// This pass generates wrapper functions for witness table function entries. -/// -/// Enabled for generation of dynamic dispatch code only. -/// -/// Functions that are used to satisfy interface requirement have concrete -/// type signatures for `this` and `associatedtype` parameters/return values. -/// However, when they are called from a witness table, the callee only have a -/// raw pointer for this arguments, since the conrete type is not known to the -/// callee. Therefore, we need to generate wrappers for each member function -/// callable through a witness table, so that the wrapper functions take general void* -/// pointer for arguments whose type is unknown at call sites, and convert them -/// to concrete types and calls the actual implementation. -void generateWitnessTableWrapperFunctions(SharedGenericsLoweringContext* sharedContext); - -IRFunc* emitWitnessTableWrapper( - IRModule* module, - IRInst* funcInst, - IRInst* interfaceRequirementVal); - -} // namespace Slang diff --git a/tests/language-feature/dynamic-dispatch/generic-method.slang b/tests/language-feature/dynamic-dispatch/generic-method.slang index a2df3b97c5e..ff839d0df36 100644 --- a/tests/language-feature/dynamic-dispatch/generic-method.slang +++ b/tests/language-feature/dynamic-dispatch/generic-method.slang @@ -31,7 +31,7 @@ float f(uint id, float x) obj = A(); else if (id == 1) obj = B(); - else if (id == 2) + else obj = C(); return obj.calc(x); From 150e0972fdb73d661b024dc5b8990894fa17d79e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 30 Oct 2025 17:09:16 -0400 Subject: [PATCH 086/105] Fix-up --- source/slang/slang-ir-dce.cpp | 1 + source/slang/slang-ir-insts-stable-names.lua | 63 +++++++++---------- .../slang/slang-ir-lower-typeflow-insts.cpp | 8 +-- 3 files changed, 35 insertions(+), 37 deletions(-) diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 3c20750aec6..903b7c1820a 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -665,6 +665,7 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex) { + SLANG_UNUSED(operandIndex); // There are some type of operands that needs to be treated as // "weak" references -- they can never hold things alive, and // whenever we delete the referenced value, these operands needs diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index c7941bca572..98f1be3eebc 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -682,40 +682,37 @@ return { ["Attr.MemoryScope"] = 678, ["Undefined.LoadFromUninitializedMemory"] = 679, ["CUDA_LDG"] = 680, -<<<<<<< HEAD - ["SetBase.TypeSet"] = 681, - ["SetBase.FuncSet"] = 682, - ["SetBase.WitnessTableSet"] = 683, - ["SetBase.GenericSet"] = 684, - ["UnboundedSet"] = 685, - ["Type.SetTagType"] = 686, - ["Type.TaggedUnionType"] = 687, - ["CastInterfaceToTaggedUnionPtr"] = 688, - ["CastTaggedUnionToInterfacePtr"] = 689, - ["GetTagForSuperSet"] = 690, - ["GetTagForMappedSet"] = 691, - ["GetTagForSpecializedSet"] = 692, - ["GetTagFromSequentialID"] = 693, - ["GetSequentialIDFromTag"] = 694, - ["GetElementFromTag"] = 695, - ["GetDispatcher"] = 696, - ["GetSpecializedDispatcher"] = 697, - ["GetTagFromTaggedUnion"] = 698, - ["GetValueFromTaggedUnion"] = 699, - ["Type.UntaggedUnionType"] = 700, - ["Type.ElementOfSetType"] = 701, - ["MakeTaggedUnion"] = 702, - ["GetTypeTagFromTaggedUnion"] = 703, - ["GetTagOfElementInSet"] = 704, - ["UnboundedTypeElement"] = 705, - ["UnboundedFuncElement"] = 706, - ["UnboundedWitnessTableElement"] = 707, - ["UnboundedGenericElement"] = 708, - ["UninitializedTypeElement"] = 709, - ["UninitializedWitnessTableElement"] = 710, -======= ["StoreBase.copyLogical"] = 681, ["MakeStorageTypeLoweringConfig"] = 682, ["Decoration.experimentalModule"] = 683, ->>>>>>> upstream + ["SetBase.TypeSet"] = 684, + ["SetBase.FuncSet"] = 685, + ["SetBase.WitnessTableSet"] = 686, + ["SetBase.GenericSet"] = 687, + ["UnboundedSet"] = 688, + ["Type.SetTagType"] = 689, + ["Type.TaggedUnionType"] = 690, + ["CastInterfaceToTaggedUnionPtr"] = 691, + ["CastTaggedUnionToInterfacePtr"] = 692, + ["GetTagForSuperSet"] = 693, + ["GetTagForMappedSet"] = 694, + ["GetTagForSpecializedSet"] = 695, + ["GetTagFromSequentialID"] = 696, + ["GetSequentialIDFromTag"] = 697, + ["GetElementFromTag"] = 698, + ["GetDispatcher"] = 699, + ["GetSpecializedDispatcher"] = 700, + ["GetTagFromTaggedUnion"] = 701, + ["GetValueFromTaggedUnion"] = 702, + ["Type.UntaggedUnionType"] = 703, + ["Type.ElementOfSetType"] = 704, + ["MakeTaggedUnion"] = 705, + ["GetTypeTagFromTaggedUnion"] = 706, + ["GetTagOfElementInSet"] = 707, + ["UnboundedTypeElement"] = 708, + ["UnboundedFuncElement"] = 709, + ["UnboundedWitnessTableElement"] = 710, + ["UnboundedGenericElement"] = 711, + ["UninitializedTypeElement"] = 712, + ["UninitializedWitnessTableElement"] = 713, } diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 347f3245563..6fc6596c2f9 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -1563,7 +1563,6 @@ struct ExistentialLoweringContext : public InstPassBase { SLANG_UNEXPECTED("Unexpected type for ExtractExistentialWitnessTable operand"); } - return false; } bool lowerGetValueFromBoundInterface(IRGetValueFromBoundInterface* inst) @@ -1596,7 +1595,6 @@ struct ExistentialLoweringContext : public InstPassBase inst->removeAndDeallocate(); return true; } - return true; } bool lowerExtractExistentialValue(IRExtractExistentialValue* inst) @@ -1622,8 +1620,10 @@ struct ExistentialLoweringContext : public InstPassBase inst->removeAndDeallocate(); return true; } - SLANG_UNEXPECTED("Unexpected type for ExtractExistentialValue operand"); - return false; + else + { + SLANG_UNEXPECTED("Unexpected type for ExtractExistentialValue operand"); + } } bool processGetSequentialIDInst(IRGetSequentialID* inst) From d386b064f1f439461db3f093f0a8640218a7b595 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 30 Oct 2025 18:17:48 -0400 Subject: [PATCH 087/105] Fix lowering of bound interface types --- source/slang/slang-emit.cpp | 2 +- .../slang/slang-ir-lower-typeflow-insts.cpp | 34 +++++++++++++------ source/slang/slang-ir-lower-typeflow-insts.h | 2 +- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index b1c5c465b1a..2044ae19253 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1186,7 +1186,7 @@ Result linkAndOptimizeIR( eliminateDeadCode(irModule, fastIRSimplificationOptions.deadCodeElimOptions); - lowerExistentials(irModule, sink); + lowerExistentials(irModule, targetProgram, sink); if (sink->getErrorCount() != 0) return SLANG_FAIL; diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 6fc6596c2f9..bf946215245 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -1398,12 +1398,13 @@ void lowerIsTypeInsts(IRModule* module) struct ExistentialLoweringContext : public InstPassBase { - ExistentialLoweringContext(IRModule* module) - : InstPassBase(module) + TargetProgram* targetProgram; + + ExistentialLoweringContext(IRModule* module, TargetProgram* targetProgram) + : InstPassBase(module), targetProgram(targetProgram) { } - bool _canReplace(IRUse* use) { switch (use->getUser()->getOp()) @@ -1505,11 +1506,24 @@ struct ExistentialLoweringContext : public InstPassBase auto anyValueType = builder.getAnyValueType(anyValueSize); - return builder.getTupleType( - rttiType, - witnessTableType, - builder.getPseudoPtrType(payloadType), - anyValueType); + IRSizeAndAlignment sizeAndAlignment; + Result result = getNaturalSizeAndAlignment( + targetProgram->getOptionSet(), + payloadType, + &sizeAndAlignment); + if (SLANG_FAILED(result) || sizeAndAlignment.size > anyValueSize) + { + return builder.getTupleType( + rttiType, + witnessTableType, + builder.getPseudoPtrType(payloadType), + anyValueType); + } + else + { + // Regular case (lower in the same way as unbound interface types) + return builder.getTupleType(rttiType, witnessTableType, anyValueType); + } } bool lowerExtractExistentialType(IRExtractExistentialType* inst) @@ -1750,10 +1764,10 @@ struct ExistentialLoweringContext : public InstPassBase } }; -bool lowerExistentials(IRModule* module, DiagnosticSink* sink) +bool lowerExistentials(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink) { SLANG_UNUSED(sink); - ExistentialLoweringContext context(module); + ExistentialLoweringContext context(module, targetProgram); context.processModule(); return true; }; diff --git a/source/slang/slang-ir-lower-typeflow-insts.h b/source/slang/slang-ir-lower-typeflow-insts.h index 72f595e5929..57a8e00420f 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.h +++ b/source/slang/slang-ir-lower-typeflow-insts.h @@ -32,6 +32,6 @@ bool lowerDispatchers(IRModule* module, DiagnosticSink* sink); // Lower `ExtractExistentialValue`, `ExtractExistentialType`, `ExtractExistentialWitnessTable`, // `InterfaceType`, `GetSequentialID`, `WitnessTableIDType` and `RTTIHandleType` instructions. // -bool lowerExistentials(IRModule* module, DiagnosticSink* sink); +bool lowerExistentials(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink); } // namespace Slang From 74a55d42acedeb0d253a552465f60bbd7e2bc3b1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 31 Oct 2025 11:22:45 -0400 Subject: [PATCH 088/105] Fix detection of COM objects during lowering --- .../slang/slang-ir-lower-typeflow-insts.cpp | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index bf946215245..06dbfab415c 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -1069,6 +1069,27 @@ void lowerTagTypes(IRModule* module) context.processModule(); } +bool isEffectivelyComPtrType(IRType* type) +{ + if (!type) + return false; + if (type->findDecoration() || type->getOp() == kIROp_ComPtrType) + { + return true; + } + if (auto witnessTableType = as(type)) + { + return isComInterfaceType((IRType*)witnessTableType->getConformanceType()); + } + if (auto ptrType = as(type)) + { + auto valueType = ptrType->getValueType(); + return valueType->findDecoration() != nullptr; + } + + return false; +} + // This context lowers `CastInterfaceToTaggedUnionPtr` and // `CastTaggedUnionToInterfacePtr` by finding all `IRLoad` and // `IRStore` uses of these insts, and upcasting the tagged-union @@ -1542,7 +1563,7 @@ struct ExistentialLoweringContext : public InstPassBase 0)); inst->removeAndDeallocate(); } - else if (auto comPtrType = as(inst->getOperand(0)->getDataType())) + else if (isEffectivelyComPtrType((IRType*)inst->getOperand(0)->getDataType())) { inst->replaceUsesWith(inst->getOperand(0)); inst->removeAndDeallocate(); @@ -1567,7 +1588,7 @@ struct ExistentialLoweringContext : public InstPassBase inst->removeAndDeallocate(); return true; } - else if (auto comPtrType = as(inst->getOperand(0)->getDataType())) + else if (isEffectivelyComPtrType((IRType*)inst->getOperand(0)->getDataType())) { inst->replaceUsesWith(inst->getOperand(0)); inst->removeAndDeallocate(); @@ -1628,7 +1649,7 @@ struct ExistentialLoweringContext : public InstPassBase inst->removeAndDeallocate(); return true; } - else if (auto comPtrType = as(inst->getOperand(0)->getDataType())) + else if (isEffectivelyComPtrType((IRType*)inst->getOperand(0)->getDataType())) { inst->replaceUsesWith(inst->getOperand(0)); inst->removeAndDeallocate(); From 74711f653b771340d5b128a4e52620f6111fd6f4 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 31 Oct 2025 11:49:13 -0400 Subject: [PATCH 089/105] Fix-up uninitialized vars --- source/slang/slang-ir-typeflow-specialize.cpp | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 9a1d311ae7f..376e94a1741 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -2031,7 +2031,7 @@ struct TypeFlowSpecializationContext // cases, when it comes to specializing a function or placing a call to a // function, we will default to the single unbounded element case. // - IRInst* unboundedElement; + IRInst* unboundedElement = nullptr; forEachInSet( elementOfSetType->getSet(), [&](IRInst* element) @@ -2090,7 +2090,7 @@ struct TypeFlowSpecializationContext return elementOfSetType->getSet()->getElement(0); else if (elementOfSetType->getSet()->isUnbounded()) { - IRInst* unboundedElement; + IRInst* unboundedElement = nullptr; forEachInSet( elementOfSetType->getSet(), [&](IRInst* element) @@ -3818,22 +3818,16 @@ struct TypeFlowSpecializationContext auto typeSet = taggedUnion->getTypeSet(); IRInst* witnessTableTag = nullptr; - IRInst* typeTag = nullptr; if (auto witnessTable = as(inst->getWitnessTable())) { witnessTableTag = builder.emitGetTagOfElementInSet( (IRType*)makeTagType(tableSet), witnessTable, tableSet); - typeTag = builder.emitGetTagOfElementInSet( - (IRType*)makeTagType(typeSet), - inst->getDataType(), - typeSet); } else if (as(inst->getWitnessTable()->getDataType())) { witnessTableTag = upcastSet(&builder, inst->getWitnessTable(), makeTagType(tableSet)); - typeTag = nullptr; } // Create the appropriate any-value type From 31f1e21e4031a651c18e36edc711f70cd70d768d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:46:30 -0400 Subject: [PATCH 090/105] Move some set utilities into `IRSetBase`, use memory pools for lists and hashsets --- source/slang/slang-ir-insts.h | 45 ++++ source/slang/slang-ir-typeflow-specialize.cpp | 249 +++++++++--------- source/slang/slang-ir.cpp | 14 +- 3 files changed, 176 insertions(+), 132 deletions(-) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 19190ad8f14..b661ca4371e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2863,6 +2863,51 @@ struct IRSetBase : IRInst } return false; } + + IRInst* tryGetUnboundedElement() + { + for (UInt ii = 0; ii < getOperandCount(); ++ii) + { + switch (getElement(ii)->getOp()) + { + case kIROp_UnboundedTypeElement: + case kIROp_UnboundedWitnessTableElement: + case kIROp_UnboundedFuncElement: + case kIROp_UnboundedGenericElement: + return getElement(ii); + } + } + return nullptr; + } + + bool containsUninitializedElement() + { + // This is a "potentially uninitialized" set if any of its elements are unbounded. + for (UInt ii = 0; ii < getOperandCount(); ++ii) + { + switch (getElement(ii)->getOp()) + { + case kIROp_UninitializedTypeElement: + case kIROp_UninitializedWitnessTableElement: + return true; + } + } + return false; + } + + IRInst* tryGetUninitializedElement() + { + for (UInt ii = 0; ii < getOperandCount(); ++ii) + { + switch (getElement(ii)->getOp()) + { + case kIROp_UninitializedTypeElement: + case kIROp_UninitializedWitnessTableElement: + return getElement(ii); + } + } + return nullptr; + } }; FIDDLE() diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 376e94a1741..eb350911bbb 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -494,6 +494,12 @@ bool isConcreteType(IRInst* inst) return isConcreteType(ptrType->getValueType()); } + if (auto generic = as(inst)) + { + if (as(getGenericReturnVal(generic))) + return false; // Can be refined into set of concrete generics. + } + return true; } @@ -557,28 +563,6 @@ bool isOptionalExistentialType(IRInst* inst) return false; } -IRInst* maybeGetUninitializedElement(IRSetBase* set) -{ - IRInst* foundInst = nullptr; - forEachInSet( - set, - [&](IRInst* element) - { - if (auto uninitializedTypeElement = as(element)) - { - foundInst = uninitializedTypeElement; - } - else if ( - auto uninitializedWitnessTableElement = - as(element)) - { - foundInst = uninitializedWitnessTableElement; - } - }); - - return foundInst; -} - // Parent context for the full type-flow pass. struct TypeFlowSpecializationContext { @@ -852,14 +836,6 @@ struct TypeFlowSpecializationContext as(info1)); } - /* - if (as(info1) && as(info2)) - { - // If either info is unbounded, the union is unbounded - return makeUnbounded(); - } - */ - // For all other cases which are structured composites of sets, // we simply take the set union for all the set operands. // @@ -1409,16 +1385,22 @@ struct TypeFlowSpecializationContext return none(); } - auto tables = collectExistentialTables(interfaceType); + HashSet& tables = *module->getContainerPool().getHashSet(); + collectExistentialTables(interfaceType, tables); if (tables.getCount() > 0) - return makeTaggedUnionType( + { + auto resultTaggedUnionType = makeTaggedUnionType( as(builder.getSet(kIROp_WitnessTableSet, tables))); + module->getContainerPool().free(&tables); + return resultTaggedUnionType; + } else { sink->diagnose( inst, Diagnostics::noTypeConformancesFoundForInterface, interfaceType); + module->getContainerPool().free(&tables); return none(); } } @@ -1437,8 +1419,6 @@ struct TypeFlowSpecializationContext if (isComInterfaceType(inst->getDataType())) { return none(); - // return builder.getComPtrType(inst->getDataType()); - // return makeUnbounded(); } auto witnessTable = inst->getWitnessTable(); @@ -1453,9 +1433,6 @@ struct TypeFlowSpecializationContext if (!witnessTableInfo) return none(); - // if (as(witnessTableInfo)) - // return makeUnbounded(); - if (auto elementOfSetType = as(witnessTableInfo)) return makeTaggedUnionType(cast(elementOfSetType->getSet())); @@ -1535,12 +1512,24 @@ struct TypeFlowSpecializationContext { if (!isComInterfaceType(interfaceType)) { - auto tables = collectExistentialTables(interfaceType); + HashSet& tables = *module->getContainerPool().getHashSet(); + collectExistentialTables(interfaceType, tables); if (tables.getCount() > 0) - return makeTaggedUnionType(as( + { + auto resultTaggedUnionType = makeTaggedUnionType(as( builder.getSet(kIROp_WitnessTableSet, tables))); + module->getContainerPool().free(&tables); + return resultTaggedUnionType; + } else + { + sink->diagnose( + loadInst, + Diagnostics::noTypeConformancesFoundForInterface, + interfaceType); + module->getContainerPool().free(&tables); return none(); + } } else { @@ -1577,12 +1566,20 @@ struct TypeFlowSpecializationContext { if (!isComInterfaceType(interfaceType)) { - auto tables = collectExistentialTables(interfaceType); + HashSet& tables = *module->getContainerPool().getHashSet(); + collectExistentialTables(interfaceType, tables); if (tables.getCount() > 0) - return makeTaggedUnionType( + { + auto resultTaggedUnionType = makeTaggedUnionType( as(builder.getSet(kIROp_WitnessTableSet, tables))); + module->getContainerPool().free(&tables); + return resultTaggedUnionType; + } else + { + module->getContainerPool().free(&tables); return none(); + } } else { @@ -1839,7 +1836,7 @@ struct TypeFlowSpecializationContext if (auto elementOfSetType = as(witnessTableInfo)) { IRBuilder builder(module); - HashSet results; + HashSet& results = *module->getContainerPool().getHashSet(); forEachInSet( cast(elementOfSetType->getSet()), [&](IRInst* table) @@ -1862,7 +1859,10 @@ struct TypeFlowSpecializationContext results.add(lookupWitnessTableEntry(cast(table), key)); }); - return makeElementOfSetType(builder.getSet(results)); + auto resultSetType = makeElementOfSetType(builder.getSet(results)); + module->getContainerPool().free(&results); + + return resultSetType; } if (!witnessTableInfo) @@ -1893,9 +1893,9 @@ struct TypeFlowSpecializationContext if (auto taggedUnion = as(operandInfo)) { auto tableSet = taggedUnion->getWitnessTableSet(); - if (auto uninitElement = maybeGetUninitializedElement(tableSet)) + if (auto uninitElement = tableSet->tryGetUninitializedElement()) { - // Uninitialized element should contain + // TODO: Need a better diagnostic here. sink->diagnose( inst->sourceLoc, Diagnostics::noTypeConformancesFoundForInterface, @@ -1995,7 +1995,7 @@ struct TypeFlowSpecializationContext // Handle the 'many' or 'one' cases. if (as(operandInfo) || isGlobalInst(operand)) { - List specializationArgs; + List& specializationArgs = *module->getContainerPool().getList(); for (UInt i = 0; i < inst->getArgCount(); ++i) { // For concrete args, add as-is. @@ -2012,7 +2012,10 @@ struct TypeFlowSpecializationContext // If any of the args are 'empty' sets, we can't generate a specialization just yet. if (!argInfo) + { + module->getContainerPool().free(&specializationArgs); return none(); + } if (as(argInfo)) { @@ -2023,7 +2026,9 @@ struct TypeFlowSpecializationContext { if (elementOfSetType->getSet()->isSingleton()) specializationArgs.add(elementOfSetType->getSet()->getElement(0)); - else if (elementOfSetType->getSet()->isUnbounded()) + else if ( + auto unboundedElement = + elementOfSetType->getSet()->tryGetUnboundedElement()) { // Infinite set. // @@ -2031,15 +2036,6 @@ struct TypeFlowSpecializationContext // cases, when it comes to specializing a function or placing a call to a // function, we will default to the single unbounded element case. // - IRInst* unboundedElement = nullptr; - forEachInSet( - elementOfSetType->getSet(), - [&](IRInst* element) - { - if (as(element) || - as(element)) - unboundedElement = element; - }); IRBuilder builder(module); SLANG_ASSERT(unboundedElement); auto pureUnboundedSet = builder.getSingletonSet(unboundedElement); @@ -2088,17 +2084,10 @@ struct TypeFlowSpecializationContext { if (elementOfSetType->getSet()->isSingleton()) return elementOfSetType->getSet()->getElement(0); - else if (elementOfSetType->getSet()->isUnbounded()) + else if ( + auto unboundedElement = + elementOfSetType->getSet()->tryGetUnboundedElement()) { - IRInst* unboundedElement = nullptr; - forEachInSet( - elementOfSetType->getSet(), - [&](IRInst* element) - { - if (as(element)) - unboundedElement = element; - }); - SLANG_ASSERT(unboundedElement); IRBuilder builder(module); return makeUntaggedUnionType( cast(builder.getSingletonSet(unboundedElement))); @@ -2114,7 +2103,7 @@ struct TypeFlowSpecializationContext return type; }; - List newParamTypes; + List& newParamTypes = *module->getContainerPool().getList(); for (auto paramType : funcType->getParamTypes()) newParamTypes.add((IRType*)substituteSets(paramType)); IRBuilder builder(module); @@ -2123,6 +2112,7 @@ struct TypeFlowSpecializationContext newParamTypes.getCount(), newParamTypes.getBuffer(), (IRType*)substituteSets(funcType->getResultType())); + module->getContainerPool().free(&newParamTypes); } else if (auto typeInfo = tryGetInfo(context, inst->getDataType())) { @@ -2140,12 +2130,14 @@ struct TypeFlowSpecializationContext } else { + module->getContainerPool().free(&specializationArgs); return none(); } } else { // We don't have a type we can work with just yet. + module->getContainerPool().free(&specializationArgs); return none(); // No info for the type } @@ -2154,11 +2146,12 @@ struct TypeFlowSpecializationContext // Our func-type operand is not yet been lifted. // For now, we can't say anything. // + module->getContainerPool().free(&specializationArgs); return none(); } // Specialize each element in the set - HashSet specializedSet; + HashSet& specializedSet = *module->getContainerPool().getHashSet(); IRSetBase* set = nullptr; if (auto elementOfSetType = as(operandInfo)) @@ -2200,7 +2193,10 @@ struct TypeFlowSpecializationContext } IRBuilder builder(module); - return makeElementOfSetType(builder.getSet(specializedSet)); + auto resultSetType = makeElementOfSetType(builder.getSet(specializedSet)); + module->getContainerPool().free(&specializedSet); + module->getContainerPool().free(&specializationArgs); + return resultSetType; } if (!operandInfo) @@ -2681,7 +2677,8 @@ struct TypeFlowSpecializationContext bool specializeInstsInBlock(IRInst* context, IRBlock* block) { - List instsToLower; + List& instsToLower = *module->getContainerPool().getList(); + bool hasChanges = false; for (auto inst : block->getChildren()) instsToLower.add(inst); @@ -2689,6 +2686,7 @@ struct TypeFlowSpecializationContext for (auto inst : instsToLower) hasChanges |= specializeInst(context, inst); + module->getContainerPool().free(&instsToLower); return hasChanges; } @@ -3056,13 +3054,13 @@ struct TypeFlowSpecializationContext auto thisInstInfo = cast(tryGetInfo(context, inst)); if (thisInstInfo->getSet() != nullptr) { - List operands = {witnessTableInst, inst->getRequirementKey()}; + IRInst* operands[] = {witnessTableInst, inst->getRequirementKey()}; auto newInst = builder.emitIntrinsicInst( (IRType*)makeTagType(thisInstInfo->getSet()), kIROp_GetTagForMappedSet, - operands.getCount(), - operands.getBuffer()); + 2, + operands); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); @@ -3176,7 +3174,7 @@ struct TypeFlowSpecializationContext { if (auto valOfSetType = as(currentType)) { - HashSet setElements; + HashSet& setElements = *module->getContainerPool().getHashSet(); forEachInSet( valOfSetType->getSet(), [&](IRInst* element) { setElements.add(element); }); @@ -3197,6 +3195,7 @@ struct TypeFlowSpecializationContext // If this is a set, we need to create a new set with the new type IRBuilder builder(module); auto newSet = builder.getSet(kIROp_TypeSet, setElements); + module->getContainerPool().free(&setElements); return makeUntaggedUnionType(cast(newSet)); } else if (currentType == newType) @@ -3216,7 +3215,7 @@ struct TypeFlowSpecializationContext } else // Need to create a new set. { - HashSet setElements; + HashSet& setElements = *module->getContainerPool().getHashSet(); SLANG_ASSERT(!as(currentType) && !as(newType)); @@ -3226,6 +3225,7 @@ struct TypeFlowSpecializationContext // If this is a set, we need to create a new set with the new type IRBuilder builder(module); auto newSet = builder.getSet(kIROp_TypeSet, setElements); + module->getContainerPool().free(&setElements); return makeUntaggedUnionType(cast(newSet)); } } @@ -3237,17 +3237,22 @@ struct TypeFlowSpecializationContext { SLANG_UNUSED(key); - List extraParamTypes; + List& extraParamTypes = *module->getContainerPool().getList(); extraParamTypes.add((IRType*)makeTagType(tableSet)); auto innerFuncType = getEffectiveFuncTypeForSet(resultFuncSet); - List allParamTypes; + List& allParamTypes = *module->getContainerPool().getList(); allParamTypes.addRange(extraParamTypes); for (auto paramType : innerFuncType->getParamTypes()) allParamTypes.add(paramType); IRBuilder builder(module); - return builder.getFuncType(allParamTypes, innerFuncType->getResultType()); + auto resultFuncType = builder.getFuncType(allParamTypes, innerFuncType->getResultType()); + + module->getContainerPool().free(&extraParamTypes); + module->getContainerPool().free(&allParamTypes); + + return resultFuncType; } // Get an effective func type to use for the callee. @@ -3274,21 +3279,14 @@ struct TypeFlowSpecializationContext if (calleeSet->isUnbounded()) { - IRUnboundedFuncElement* unboundedFuncElement = nullptr; - forEachInSet( - calleeSet, - [&](IRInst* func) - { - if (as(func)) - unboundedFuncElement = as(func); - }); - SLANG_ASSERT(unboundedFuncElement); + IRUnboundedFuncElement* unboundedFuncElement = + cast(calleeSet->tryGetUnboundedElement()); return cast(unboundedFuncElement->getOperand(0)); } IRBuilder builder(module); - List paramTypes; + List& paramTypes = *module->getContainerPool().getList(); IRType* resultType = nullptr; auto updateParamType = [&](Index index, IRType* paramType) -> IRType* @@ -3315,7 +3313,7 @@ struct TypeFlowSpecializationContext } }; - List calleesToProcess; + List& calleesToProcess = *module->getContainerPool().getList(); forEachInSet(calleeSet, [&](IRInst* func) { calleesToProcess.add(func); }); for (auto context : calleesToProcess) @@ -3342,11 +3340,13 @@ struct TypeFlowSpecializationContext } } + module->getContainerPool().free(&calleesToProcess); + // // Add in extra parameter types for a call to a dynamic generic callee // - List extraParamTypes; + List& extraParamTypes = *module->getContainerPool().getList(); // If the any of the elements in the callee (or the callee itself in case // of a singleton) is a dynamic specialization, each non-singleton WitnessTableSet, @@ -3364,11 +3364,17 @@ struct TypeFlowSpecializationContext extraParamTypes.add((IRType*)makeTagType(tableSet)); } - List allParamTypes; + List& allParamTypes = *module->getContainerPool().getList(); allParamTypes.addRange(extraParamTypes); allParamTypes.addRange(paramTypes); - return builder.getFuncType(allParamTypes, resultType); + auto resultFuncType = builder.getFuncType(allParamTypes, resultType); + + module->getContainerPool().free(¶mTypes); + module->getContainerPool().free(&extraParamTypes); + module->getContainerPool().free(&allParamTypes); + + return resultFuncType; } IRFuncType* getEffectiveFuncType(IRInst* callee) @@ -3382,9 +3388,10 @@ struct TypeFlowSpecializationContext // For a `Specialize` instruction that has dynamic tag arguments, // extract all the tags and return them as a list. // - List getArgsForSetSpecializedGeneric(IRSpecialize* specializedCallee) + void addArgsForSetSpecializedGeneric( + IRSpecialize* specializedCallee, + List& outCallArgs) { - List callArgs; for (UInt ii = 0; ii < specializedCallee->getArgCount(); ii++) { auto specArg = specializedCallee->getArg(ii); @@ -3395,10 +3402,8 @@ struct TypeFlowSpecializationContext // if (auto tagType = as(argInfo)) if (as(tagType->getSet())) - callArgs.add(specArg); + outCallArgs.add(specArg); } - - return callArgs; } IRInst* maybeSpecializeCalleeType(IRInst* callee) @@ -3503,8 +3508,7 @@ struct TypeFlowSpecializationContext if (isNoneCallee(callee)) return false; - // IRInst* calleeTagInst = nullptr; - List callArgs; + List& callArgs = *module->getContainerPool().getList(); // This is a bit of a workaround for specialized callee's // whose function types haven't been specialized yet (can @@ -3607,11 +3611,10 @@ struct TypeFlowSpecializationContext else if (isSetSpecializedGeneric(setTag->getSet()->getElement(0))) { // Single element which is a set specialized generic. - callArgs.addRange(getArgsForSetSpecializedGeneric(cast(callee))); + addArgsForSetSpecializedGeneric(cast(callee), callArgs); callee = setTag->getSet()->getElement(0); auto funcType = getEffectiveFuncType(callee); - // callee->setFullType(funcType); IRBuilder builder(module); builder.setInsertInto(module); callee = builder.replaceOperand(&callee->typeUse, funcType); @@ -3651,7 +3654,6 @@ struct TypeFlowSpecializationContext // by the analysis. // auto funcType = getEffectiveFuncType(callee); - // callee->setFullType(funcType); IRBuilder builder(module); builder.setInsertInto(module); callee = builder.replaceOperand(&callee->typeUse, funcType); @@ -3667,7 +3669,6 @@ struct TypeFlowSpecializationContext for (auto paramType : oldFuncType->getParamTypes()) paramTypes.add(paramType); auto newFuncType = builder.getFuncType(paramTypes, resultType); - // callee->setFullType(newFuncType); builder.setInsertInto(module); callee = builder.replaceOperand(&callee->typeUse, newFuncType); } @@ -3746,7 +3747,6 @@ struct TypeFlowSpecializationContext auto newCall = builder.emitCallInst(calleeFuncType->getResultType(), callee, callArgs); inst->replaceUsesWith(newCall); inst->removeAndDeallocate(); - return true; } else if (calleeFuncType->getResultType() != inst->getFullType()) { @@ -3754,13 +3754,12 @@ struct TypeFlowSpecializationContext // need to update the result type. // inst->setFullType(calleeFuncType->getResultType()); - return true; - } - else - { - // Nothing changed. - return false; + changed = true; } + + module->getContainerPool().free(&callArgs); + + return changed; } bool specializeMakeStruct(IRInst* context, IRMakeStruct* inst) @@ -3960,12 +3959,12 @@ struct TypeFlowSpecializationContext kIROp_CastInterfaceToTaggedUnionPtr, 1, &bufferHandle); - List newLoadOperands = {newHandle, inst->getOperand(1)}; + IRInst* newLoadOperands[] = {newHandle, inst->getOperand(1)}; auto newLoad = builder.emitIntrinsicInst( specializedValType, inst->getOp(), - newLoadOperands.getCount(), - newLoadOperands.getBuffer()); + 2, + newLoadOperands); inst->replaceUsesWith(newLoad); inst->removeAndDeallocate(); @@ -4050,7 +4049,7 @@ struct TypeFlowSpecializationContext IRBuilder builder(inst); setInsertBeforeOrdinaryInst(&builder, inst); - List specOperands; + List& specOperands = *module->getContainerPool().getList(); specOperands.add(inst->getBase()); for (UInt ii = 0; ii < inst->getArgCount(); ii++) @@ -4061,6 +4060,7 @@ struct TypeFlowSpecializationContext kIROp_GetTagForSpecializedSet, specOperands.getCount(), specOperands.getBuffer()); + module->getContainerPool().free(&specOperands); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); @@ -4076,7 +4076,7 @@ struct TypeFlowSpecializationContext // For all other specializations, we'll 'drop' the dynamic tag information. bool changed = false; - List args; + List& args = *module->getContainerPool().getList(); for (UIndex i = 0; i < inst->getArgCount(); i++) { auto arg = inst->getArg(i); @@ -4114,10 +4114,9 @@ struct TypeFlowSpecializationContext inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); - return true; } - - return false; + module->getContainerPool().free(&args); + return changed; } bool specializeGetValueFromBoundInterface(IRInst* context, IRGetValueFromBoundInterface* inst) @@ -4300,12 +4299,12 @@ struct TypeFlowSpecializationContext auto firstElement = tagType->getSet()->getElement(0); auto interfaceType = as(as(firstElement)->getConformanceType()); - List args = {interfaceType, arg}; + IRInst* args[] = {interfaceType, arg}; auto newInst = builder.emitIntrinsicInst( (IRType*)builder.getUIntType(), kIROp_GetSequentialIDFromTag, - args.getCount(), - args.getBuffer()); + 2, + args); inst->replaceUsesWith(newInst); inst->removeAndDeallocate(); @@ -4495,10 +4494,8 @@ struct TypeFlowSpecializationContext return false; } - HashSet collectExistentialTables(IRInterfaceType* interfaceType) + void collectExistentialTables(IRInterfaceType* interfaceType, HashSet& outTables) { - HashSet tables; - IRWitnessTableType* targetTableType = nullptr; // First, find the IRWitnessTableType that wraps the given interfaceType for (auto use = interfaceType->firstUse; use; use = use->nextUse) @@ -4522,13 +4519,11 @@ struct TypeFlowSpecializationContext { if (witnessTable->getDataType() == targetTableType) { - tables.add(witnessTable); + outTables.add(witnessTable); } } } } - - return tables; } bool processModule() diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 340c1386f52..bf4c93062b7 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6613,16 +6613,20 @@ IRSetBase* IRBuilder::getSet(IROp op, const HashSet& elements) if (element->getParent()->getOp() != kIROp_ModuleInst) SLANG_ASSERT_FAILURE("createSet called with non-global operands"); - List sortedElements; + List* sortedElements = getModule()->getContainerPool().getList(); for (auto element : elements) - sortedElements.add(element); + sortedElements->add(element); // Sort elements by their unique IDs to ensure canonical ordering - sortedElements.sort( + sortedElements->sort( [&](IRInst* a, IRInst* b) -> bool { return getUniqueID(a) < getUniqueID(b); }); - return as( - emitIntrinsicInst(nullptr, op, sortedElements.getCount(), sortedElements.getBuffer())); + auto setBaseInst = as( + emitIntrinsicInst(nullptr, op, sortedElements->getCount(), sortedElements->getBuffer())); + + getModule()->getContainerPool().free(sortedElements); + + return setBaseInst; } IRSetBase* IRBuilder::getSet(const HashSet& elements) From ba249f3133504fa02da4fc4addb0e64e7f4b126b Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 31 Oct 2025 15:05:37 -0400 Subject: [PATCH 091/105] Add uninitialized object diagnostic + test --- source/slang/slang-diagnostic-defs.h | 5 +++++ source/slang/slang-ir-typeflow-specialize.cpp | 4 +--- tests/diagnostics/no-type-conformance.slang | 2 +- .../no-type-conformance.slang.expected | 8 -------- .../uninitialized-existential.slang | 19 +++++++++++++++++++ 5 files changed, 26 insertions(+), 12 deletions(-) delete mode 100644 tests/diagnostics/no-type-conformance.slang.expected create mode 100644 tests/diagnostics/uninitialized-existential.slang diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 4cfbeb2ada6..0cba75d11a7 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2883,6 +2883,11 @@ DIAGNOSTIC( noTypeConformancesFoundForInterface, "No type conformances are found for interface '$0'. Code generation for current target " "requires at least one implementation type present in the linkage.") +DIAGNOSTIC( + 50101, + Error, + dynamicDispatchOnPotentiallyUninitializedExistential, + "Cannot dynamically dispatch on potentially uninitialized interface object '$0'.") DIAGNOSTIC( 52000, diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index eb350911bbb..b2442f49de0 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -811,7 +811,6 @@ struct TypeFlowSpecializationContext if (areInfosEqual(info1, info2)) return info1; - // TODO: Move into utility function to avoid dropping information. if (as(info1) && as(info2)) { SLANG_ASSERT(info1->getOperand(1) == info2->getOperand(1)); @@ -1895,10 +1894,9 @@ struct TypeFlowSpecializationContext auto tableSet = taggedUnion->getWitnessTableSet(); if (auto uninitElement = tableSet->tryGetUninitializedElement()) { - // TODO: Need a better diagnostic here. sink->diagnose( inst->sourceLoc, - Diagnostics::noTypeConformancesFoundForInterface, + Diagnostics::dynamicDispatchOnPotentiallyUninitializedExistential, uninitElement->getOperand(0)); return none(); // We'll return none so that the analysis doesn't diff --git a/tests/diagnostics/no-type-conformance.slang b/tests/diagnostics/no-type-conformance.slang index 05ddc61db4d..08147268871 100644 --- a/tests/diagnostics/no-type-conformance.slang +++ b/tests/diagnostics/no-type-conformance.slang @@ -8,7 +8,7 @@ interface IFoo void foo() { - IFoo obj; + IFoo obj = createDynamicObject(0, 0); obj.get(); } diff --git a/tests/diagnostics/no-type-conformance.slang.expected b/tests/diagnostics/no-type-conformance.slang.expected deleted file mode 100644 index ebd9482ec0b..00000000000 --- a/tests/diagnostics/no-type-conformance.slang.expected +++ /dev/null @@ -1,8 +0,0 @@ -result code = -1 -standard error = { -tests/diagnostics/no-type-conformance.slang(12): error 50100: No type conformances are found for interface 'IFoo'. Code generation for current target requires at least one implementation type present in the linkage. -interface IFoo - ^~~~ -} -standard output = { -} diff --git a/tests/diagnostics/uninitialized-existential.slang b/tests/diagnostics/uninitialized-existential.slang new file mode 100644 index 00000000000..26a5ed3a4ad --- /dev/null +++ b/tests/diagnostics/uninitialized-existential.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK):-target hlsl -entry computeMain -stage compute -o no-type-conformance.hlsl + +//CHECK: error 50101: Cannot dynamically dispatch on potentially uninitialized interface object 'IFoo' +interface IFoo +{ + float get(); +} + +void foo() +{ + IFoo obj; + obj.get(); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + foo(); +} \ No newline at end of file From 86947502699879084e291d452f8255bdaa3ffbdb Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 3 Nov 2025 12:06:05 -0500 Subject: [PATCH 092/105] Fix for Falcor test. --- source/slang/slang-ir-insts-stable-names.lua | 3 + source/slang/slang-ir-insts.h | 13 + source/slang/slang-ir-insts.lua | 14 +- .../slang/slang-ir-lower-typeflow-insts.cpp | 35 +++ source/slang/slang-ir-typeflow-collection.cpp | 4 +- source/slang/slang-ir-typeflow-specialize.cpp | 255 +++++++++++++----- source/slang/slang-ir.cpp | 15 +- .../types/optional-ifoo-3.slang | 35 +++ 8 files changed, 299 insertions(+), 75 deletions(-) create mode 100644 tests/language-feature/types/optional-ifoo-3.slang diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 98f1be3eebc..0d899c6275e 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -715,4 +715,7 @@ return { ["UnboundedGenericElement"] = 711, ["UninitializedTypeElement"] = 712, ["UninitializedWitnessTableElement"] = 713, + ["NoneTypeElement"] = 714, + ["NoneWitnessTableElement"] = 716, + ["GetTagForSubSet"] = 718 } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index b661ca4371e..59c4c840340 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2846,6 +2846,7 @@ struct IRSetBase : IRInst FIDDLE(baseInst()) UInt getCount() { return getOperandCount(); } IRInst* getElement(UInt idx) { return getOperand(idx); } + bool isEmpty() { return getOperandCount() == 0; } bool isSingleton() { return (getOperandCount() == 1) && !isUnbounded(); } bool isUnbounded() { @@ -4325,6 +4326,18 @@ struct IRBuilder emitIntrinsicInst(nullptr, kIROp_UninitializedWitnessTableElement, 1, &interfaceType)); } + IRNoneTypeElement* getNoneTypeElement() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_NoneTypeElement, 0, nullptr)); + } + + IRNoneWitnessTableElement* getNoneWitnessTableElement() + { + return cast( + emitIntrinsicInst(nullptr, kIROp_NoneWitnessTableElement, 0, nullptr)); + } + IRGetTagOfElementInSet* emitGetTagOfElementInSet( IRType* tagType, IRInst* element, diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 3f05d4954ce..65647ab2d0c 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2729,7 +2729,7 @@ local insts = { { TypeSet = {} }, { FuncSet = {} }, { WitnessTableSet = {} }, - { GenericSet = {} }, + { GenericSet = {} } }, }, { UnboundedSet = { @@ -2763,6 +2763,12 @@ local insts = { -- Operands: (the tag for the source set) -- The source and destination sets are implied by the type of the operand and the type of the result } }, + { GetTagForSubSet = { + -- Translate a tag from a set to its equivalent in a sub-set + -- + -- Operands: (the tag for the source set) + -- The source and destination sets are implied by the type of the operand and the type of the result + } }, { GetTagForMappedSet = { -- Translate a tag from a set to its equivalent in a different set -- based on a mapping induced by a lookup key @@ -2886,6 +2892,12 @@ local insts = { hoistable = true, operands = { {"baseInterfaceType"} } } }, + { NoneTypeElement = { + hoistable = true + } }, + { NoneWitnessTableElement = { + hoistable = true + } }, } -- A function to calculate some useful properties and put it in the table, diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 06dbfab415c..8efa64df572 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -440,6 +440,22 @@ struct TagOpsLoweringContext : public InstPassBase inst->removeAndDeallocate(); } + void lowerGetTagForSubSet(IRGetTagForSubSet* inst) + { + // `GetTagForSubSet` is a no-op since we want to translate the tag + // for an element in the sub-set to a tag for the same element in a sub-set. + // It is assumed that this operation has already been confirmed as safe. + // + // Since all elements have a unique ID across the module, this is the identity operation. + // + + IRBuilder builder(inst->getModule()); + builder.setInsertAfter(inst); + inst->replaceUsesWith(builder.emitCast(inst->getDataType(), inst->getOperand(0), true)); + inst->removeAndDeallocate(); + } + + void lowerGetTagForMappedSet(IRGetTagForMappedSet* inst) { // `GetTagForMappedSet` turns into a integer mapping from @@ -519,6 +535,9 @@ struct TagOpsLoweringContext : public InstPassBase case kIROp_GetTagForSuperSet: lowerGetTagForSuperSet(as(inst)); break; + case kIROp_GetTagForSubSet: + lowerGetTagForSubSet(as(inst)); + break; case kIROp_GetTagForMappedSet: lowerGetTagForMappedSet(as(inst)); break; @@ -771,6 +790,17 @@ struct SetLoweringContext : public InstPassBase { types.add(type); } + else if (auto noneType = as(valueOfSetType->getSet()->getElement(i))) + { + // Can safely skip. (effectively 0 size) + } + else + { + // Should either be a type or NoneTypeElement (if other cases are okay, need to + // handle them here) + // + SLANG_UNEXPECTED("Expected type element in UntaggedUnionType set"); + } } IRBuilder builder(module); @@ -783,6 +813,11 @@ struct SetLoweringContext : public InstPassBase processInstsOfType( kIROp_UntaggedUnionType, [&](IRUntaggedUnionType* inst) { return lowerUntaggedUnionType(inst); }); + + IRBuilder builder(module); + auto noneTypeElement = builder.getNoneTypeElement(); + noneTypeElement->replaceUsesWith(builder.getVoidType()); + noneTypeElement->removeAndDeallocate(); } private: diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-collection.cpp index f1c0a1f76fa..8c49ecc824b 100644 --- a/source/slang/slang-ir-typeflow-collection.cpp +++ b/source/slang/slang-ir-typeflow-collection.cpp @@ -85,9 +85,9 @@ IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo) if (argInfo != destInfo) { auto argSet = as(argInfo)->getSet(); - if (argSet->isSingleton() && as(argSet->getElement(0))) + if (argSet->isSingleton() && as(argSet->getElement(0))) { - // There's a specific case where we're trying to reinterpret a value of 'void' + // There's a specific case where we're trying to reinterpret a value of 'none' // type. We'll avoid emitting a reinterpret in this case, and emit a // default-construct instead. // diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index b2442f49de0..e11348d8810 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -427,8 +427,7 @@ IRInst* lookupWitnessTableEntry(IRWitnessTable* table, IRInst* key) } else if (as(table->getConcreteType())) { - IRBuilder builder(table->getModule()); - return builder.getVoidValue(); + SLANG_UNEXPECTED("Looking up entry on 'none' witness table"); } return nullptr; @@ -522,7 +521,36 @@ IRInst* makeInfoForConcreteType(IRModule* module, IRInst* type) getArrayStride(arrayType)); } - return builder.getUntaggedUnionType(cast(builder.getSingletonSet(type))); + return builder.getUntaggedUnionType( + cast(builder.getSingletonSet(kIROp_TypeSet, type))); +} + +IROp getSetOpFromDataType(IRType* type) +{ + switch (type->getOp()) + { + case kIROp_WitnessTableType: // Can be refined into set of concrete tables + return kIROp_WitnessTableSet; + case kIROp_FuncType: // Can be refined into set of concrete functions + return kIROp_FuncSet; + case kIROp_GenericKind: // Can be refined into set of concrete generics + return kIROp_GenericSet; + case kIROp_TypeKind: // Can be refined into set of concrete types + return kIROp_TypeSet; + case kIROp_TypeType: // Can be refined into set of concrete types + return kIROp_TypeSet; + default: + break; + } + + if (auto generic = as(type)) + { + auto innerValType = getGenericReturnVal(generic); + if (as(innerValType) || as(innerValType)) + return kIROp_GenericSet; // Can be refined into set of concrete generics. + } + + return kIROp_Invalid; } // Helper to check if an IRParam is a function parameter (vs. a phi param or generic param) @@ -595,6 +623,9 @@ struct TypeFlowSpecializationContext as(witnessTable) ->getBaseInterfaceType())); break; + case kIROp_NoneWitnessTableElement: + typeSet.add(builder.getNoneTypeElement()); + break; } }); @@ -1155,6 +1186,9 @@ struct TypeFlowSpecializationContext case kIROp_MakeOptionalValue: info = analyzeMakeOptionalValue(context, as(inst)); break; + case kIROp_GetOptionalValue: + info = analyzeGetOptionalValue(context, as(inst)); + break; } // If we didn't get any info from inst-specific analysis, we'll try to get @@ -1322,8 +1356,8 @@ struct TypeFlowSpecializationContext if (isConcreteType(concreteReturnType)) { IRBuilder builder(module); - returnInfo = builder.getUntaggedUnionType( - cast(builder.getSingletonSet(concreteReturnType))); + returnInfo = builder.getUntaggedUnionType(cast( + builder.getSingletonSet(kIROp_TypeSet, concreteReturnType))); } } @@ -1423,8 +1457,8 @@ struct TypeFlowSpecializationContext auto witnessTable = inst->getWitnessTable(); // Concrete case. if (as(witnessTable)) - return makeTaggedUnionType( - as(builder.getSingletonSet(witnessTable))); + return makeTaggedUnionType(as( + builder.getSingletonSet(kIROp_WitnessTableSet, witnessTable))); // Get the witness table info auto witnessTableInfo = tryGetInfo(context, witnessTable); @@ -1538,8 +1572,9 @@ struct TypeFlowSpecializationContext else if ( auto boundInterfaceType = as(loadInst->getDataType())) { - return makeTaggedUnionType(cast( - builder.getSingletonSet(boundInterfaceType->getWitnessTable()))); + return makeTaggedUnionType(cast(builder.getSingletonSet( + kIROp_WitnessTableSet, + boundInterfaceType->getWitnessTable()))); } else { @@ -1587,8 +1622,9 @@ struct TypeFlowSpecializationContext } else if (auto boundInterfaceType = as(inst->getDataType())) { - return makeTaggedUnionType(cast( - builder.getSingletonSet(boundInterfaceType->getWitnessTable()))); + return makeTaggedUnionType(cast(builder.getSingletonSet( + kIROp_WitnessTableSet, + boundInterfaceType->getWitnessTable()))); } } @@ -1702,39 +1738,13 @@ struct TypeFlowSpecializationContext return none(); } - // Locate the 'none' witness table in the global scope - // of the module in context. This will be the table - // that conforms to 'nullptr' and has 'void' as the concrete type - // - IRWitnessTable* findNoneWitness() - { - IRBuilder builder(module); - auto voidType = builder.getVoidType(); - for (auto inst : module->getGlobalInsts()) - { - if (auto witnessTable = as(inst)) - { - if (witnessTable->getConcreteType() == voidType && - witnessTable->getConformanceType() == nullptr) - return witnessTable; - } - } - - return nullptr; - } - // Get the witness table inst to be used for the 'none' case of // an optional witness table. // - IRWitnessTable* getNoneWitness() + IRInst* getNoneWitness() { - if (auto table = findNoneWitness()) - return table; - IRBuilder builder(module); - auto voidType = builder.getVoidType(); - - return builder.createWitnessTable(nullptr, voidType); + return builder.getNoneWitnessTableElement(); } IRInst* analyzeMakeOptionalNone(IRInst* context, IRMakeOptionalNone* inst) @@ -1794,20 +1804,18 @@ struct TypeFlowSpecializationContext IRInst* analyzeGetOptionalValue(IRInst* context, IRGetOptionalValue* inst) { - if (isOptionalExistentialType(inst->getDataType())) + if (isOptionalExistentialType(inst->getOperand(0)->getDataType())) { - // This is an interesting case.. technically, at this point we could go - // from a larger set to a smaller one (without the none-type). - // - // However, for simplicitly reasons, we currently only allow up-casting, - // so for now we'll just passthrough all types (so the result will - // assume that 'none-type' is a possiblity even though we statically know - // that it isn't). - // + // TODO: Document. if (auto info = tryGetInfo(context, inst->getOperand(0))) { SLANG_ASSERT(as(info)); - return info; + + IRBuilder builder(module); + auto taggedUnion = as(info); + return builder.getTaggedUnionType( + cast(filterNoneElements(taggedUnion->getWitnessTableSet())), + cast(filterNoneElements(taggedUnion->getTypeSet()))); } } @@ -1858,7 +1866,8 @@ struct TypeFlowSpecializationContext results.add(lookupWitnessTableEntry(cast(table), key)); }); - auto resultSetType = makeElementOfSetType(builder.getSet(results)); + auto setOp = getSetOpFromDataType(inst->getDataType()); + auto resultSetType = makeElementOfSetType(builder.getSet(setOp, results)); module->getContainerPool().free(&results); return resultSetType; @@ -2036,7 +2045,8 @@ struct TypeFlowSpecializationContext // IRBuilder builder(module); SLANG_ASSERT(unboundedElement); - auto pureUnboundedSet = builder.getSingletonSet(unboundedElement); + auto setOp = getSetOpFromDataType(inst->getArg(i)->getDataType()); + auto pureUnboundedSet = builder.getSingletonSet(setOp, unboundedElement); if (auto typeSet = as(pureUnboundedSet)) specializationArgs.add(makeUntaggedUnionType(typeSet)); else @@ -2087,8 +2097,8 @@ struct TypeFlowSpecializationContext elementOfSetType->getSet()->tryGetUnboundedElement()) { IRBuilder builder(module); - return makeUntaggedUnionType( - cast(builder.getSingletonSet(unboundedElement))); + return makeUntaggedUnionType(cast( + builder.getSingletonSet(kIROp_TypeSet, unboundedElement))); } else return makeUntaggedUnionType( @@ -2191,7 +2201,8 @@ struct TypeFlowSpecializationContext } IRBuilder builder(module); - auto resultSetType = makeElementOfSetType(builder.getSet(specializedSet)); + auto setOp = getSetOpFromDataType(inst->getDataType()); + auto resultSetType = makeElementOfSetType(builder.getSet(setOp, specializedSet)); module->getContainerPool().free(&specializedSet); module->getContainerPool().free(&specializationArgs); return resultSetType; @@ -2269,13 +2280,24 @@ struct TypeFlowSpecializationContext true, workQueue); } - else if (as(arg) || as(arg)) + else if (as(arg)) { IRBuilder builder(module); updateInfo( context, param, - makeElementOfSetType(builder.getSingletonSet(arg)), + makeElementOfSetType(builder.getSingletonSet(kIROp_TypeSet, arg)), + true, + workQueue); + } + else if (as(arg)) + { + IRBuilder builder(module); + updateInfo( + context, + param, + makeElementOfSetType( + builder.getSingletonSet(kIROp_WitnessTableSet, arg)), true, workQueue); } @@ -2315,9 +2337,6 @@ struct TypeFlowSpecializationContext auto propagateToCallSite = [&](IRInst* callee) { - if (as(callee)) - return; - if (as(callee)) { // An unbounded element represents an unknown function, @@ -3036,6 +3055,13 @@ struct TypeFlowSpecializationContext inst->removeAndDeallocate(); return true; } + else if (elementOfSetType->getSet()->isEmpty()) + { + auto poison = builder.emitPoison(inst->getDataType()); + inst->replaceUsesWith(poison); + inst->removeAndDeallocate(); + return true; + } // If we reach here, we have a truly dynamic case. Multiple elements. // We need to emit a run-time inst to keep track of the tag. @@ -3099,6 +3125,12 @@ struct TypeFlowSpecializationContext inst->removeAndDeallocate(); return true; } + else if (elementOfSetType->getSet()->isEmpty()) + { + inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); + inst->removeAndDeallocate(); + return true; + } else { // Replace with GetElement(specializedInst, 0) -> TagType(tableSet) @@ -3134,6 +3166,15 @@ struct TypeFlowSpecializationContext inst->removeAndDeallocate(); return true; } + else if (as(existential)) + { + IRBuilder builder(inst); + builder.setInsertAfter(inst); + + inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); + inst->removeAndDeallocate(); + return true; + } return false; } @@ -3151,6 +3192,14 @@ struct TypeFlowSpecializationContext inst->removeAndDeallocate(); return true; } + else if (elementOfSetType->getSet()->isEmpty()) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); + inst->removeAndDeallocate(); + return true; + } else { // Multiple elements, emit a tag inst. @@ -3378,7 +3427,8 @@ struct TypeFlowSpecializationContext IRFuncType* getEffectiveFuncType(IRInst* callee) { IRBuilder builder(module); - return getEffectiveFuncTypeForSet(cast(builder.getSingletonSet(callee))); + return getEffectiveFuncTypeForSet( + cast(builder.getSingletonSet(kIROp_FuncSet, callee))); } // Helper function for specializing calls. @@ -3419,6 +3469,40 @@ struct TypeFlowSpecializationContext return callee; } + IRSetBase* filterNoneElements(IRSetBase* set) + { + auto setOp = set->getOp(); + + IRBuilder builder(module); + HashSet& filteredElements = *module->getContainerPool().getHashSet(); + bool containsNone = false; + forEachInSet( + set, + [&](IRInst* element) + { + switch (element->getOp()) + { + case kIROp_NoneTypeElement: + case kIROp_NoneWitnessTableElement: + containsNone = true; + break; + default: + filteredElements.add(element); + break; + } + }); + + if (!containsNone) + { + module->getContainerPool().free(&filteredElements); + return set; + } + + auto newFuncSet = cast(builder.getSet(setOp, filteredElements)); + module->getContainerPool().free(&filteredElements); + return newFuncSet; + } + bool specializeCall(IRInst* context, IRCall* inst) { // The overall goal is to remove any dynamic-ness in the call inst @@ -3521,7 +3605,7 @@ struct TypeFlowSpecializationContext // if (auto setTag = as(callee->getDataType())) { - if (!setTag->isSingleton()) + if (!setTag->isSingleton() && !setTag->getSet()->isEmpty()) { // Multiple callees case: // @@ -3628,7 +3712,7 @@ struct TypeFlowSpecializationContext "Unexpected operand type for type-flow specialization of Call inst"); } } - else if (as(callee)) + else if (as(callee)) { // Occasionally, we will determine that there are absolutely no possible callees // for a call site. This typically happens to impossible branches. @@ -4358,16 +4442,17 @@ struct TypeFlowSpecializationContext auto noneWitnessTable = taggedUnionType->getWitnessTableSet()->getElement(0); auto singletonWitnessTableTagType = - makeTagType(builder.getSingletonSet(noneWitnessTable)); + makeTagType(builder.getSingletonSet(kIROp_WitnessTableSet, noneWitnessTable)); IRInst* tableTag = builder.emitGetTagOfElementInSet( (IRType*)singletonWitnessTableTagType, noneWitnessTable, taggedUnionType->getWitnessTableSet()); - auto singletonTypeTagType = makeTagType(builder.getSingletonSet(builder.getVoidType())); + auto singletonTypeTagType = + makeTagType(builder.getSingletonSet(kIROp_TypeSet, builder.getNoneTypeElement())); IRInst* typeTag = builder.emitGetTagOfElementInSet( (IRType*)singletonTypeTagType, - builder.getVoidType(), + builder.getNoneTypeElement(), taggedUnionType->getTypeSet()); auto newTuple = builder.emitMakeTaggedUnion( @@ -4411,16 +4496,48 @@ struct TypeFlowSpecializationContext bool specializeGetOptionalValue(IRInst* context, IRGetOptionalValue* inst) { SLANG_UNUSED(context); - if (as(inst->getOptionalOperand()->getDataType())) + if (auto srcTaggedUnionType = + as(inst->getOptionalOperand()->getDataType())) { // Since `GetOptionalValue` is the reverse of `MakeOptionalValue`, and we treat // the latter as a no-op, then `GetOptionalValue` is also a no-op (we simply pass // the inner existential value as-is) // - auto newInst = inst->getOptionalOperand(); - inst->replaceUsesWith(newInst); - inst->removeAndDeallocate(); + IRBuilder builder(inst); + auto destTaggedUnionType = cast(tryGetInfo(context, inst)); + if (destTaggedUnionType != srcTaggedUnionType) + { + // If the source and destination tagged-union types are different, + // we need to emit a cast. + builder.setInsertBefore(inst); + IRInst* tag = builder.emitGetTagFromTaggedUnion(inst->getOptionalOperand()); + auto downcastedTag = builder.emitIntrinsicInst( + (IRType*)makeTagType(destTaggedUnionType->getWitnessTableSet()), + kIROp_GetTagForSubSet, + 1, + &tag); + + auto unpackedValue = builder.emitUnpackAnyValue( + getLoweredType(makeUntaggedUnionType(destTaggedUnionType->getTypeSet())), + builder.emitGetValueFromTaggedUnion(inst->getOptionalOperand())); + + auto newTaggedUnion = builder.emitMakeTaggedUnion( + getLoweredType(destTaggedUnionType), + builder.emitPoison(makeTagType(destTaggedUnionType->getTypeSet())), + downcastedTag, + unpackedValue); + + + inst->replaceUsesWith(newTaggedUnion); + inst->removeAndDeallocate(); + } + else + { + inst->replaceUsesWith(inst->getOptionalOperand()); + inst->removeAndDeallocate(); + } + return true; } return false; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index bf4c93062b7..7058748ae86 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6605,9 +6605,6 @@ IREntryPointLayout* IRBuilder::getEntryPointLayout( IRSetBase* IRBuilder::getSet(IROp op, const HashSet& elements) { - if (elements.getCount() == 0) - return nullptr; - // Verify that all operands are global instructions for (auto element : elements) if (element->getParent()->getOp() != kIROp_ModuleInst) @@ -6631,6 +6628,7 @@ IRSetBase* IRBuilder::getSet(IROp op, const HashSet& elements) IRSetBase* IRBuilder::getSet(const HashSet& elements) { + // Cannot call getSet with an empty set of elements and no specific op-code. SLANG_ASSERT(elements.getCount() > 0); auto firstElement = *elements.begin(); return getSet(getSetTypeForInst(firstElement), elements); @@ -8614,6 +8612,7 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetTagForMappedSet: case kIROp_GetTagForSpecializedSet: case kIROp_GetTagForSuperSet: + case kIROp_GetTagForSubSet: case kIROp_GetTagFromSequentialID: case kIROp_GetSequentialIDFromTag: case kIROp_CastInterfaceToTaggedUnionPtr: @@ -8630,6 +8629,16 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_MakeStorageTypeLoweringConfig: return false; + case kIROp_UnboundedFuncElement: + case kIROp_UnboundedTypeElement: + case kIROp_UnboundedWitnessTableElement: + case kIROp_UnboundedGenericElement: + case kIROp_UninitializedTypeElement: + case kIROp_UninitializedWitnessTableElement: + case kIROp_NoneTypeElement: + case kIROp_NoneWitnessTableElement: + return false; + case kIROp_ForwardDifferentiate: case kIROp_BackwardDifferentiate: case kIROp_BackwardDifferentiatePrimal: diff --git a/tests/language-feature/types/optional-ifoo-3.slang b/tests/language-feature/types/optional-ifoo-3.slang new file mode 100644 index 00000000000..0299eb3e96f --- /dev/null +++ b/tests/language-feature/types/optional-ifoo-3.slang @@ -0,0 +1,35 @@ +//TEST:INTERPRET(filecheck=CHECK): + +interface IFoo +{ + T get_result(); +} + +struct FooImpl : IFoo +{ + T get_result() { return (T)data; } + int data; +} + +Optional generate_foo(int i) +{ + if (i % 5 == 0) + { + FooImpl result = {}; + result.data = i; + return { result }; + } + else + { + return {}; + } +} + +void main() +{ + // CHECK: hasValue: 1 + // CHECK-NEXT: result: 100.0 + let result_foo = generate_foo(100); + printf("hasValue: %d\n", (int)result_foo.hasValue); + printf("result: %f\n", (float)result_foo.value.get_result()); +} From 9640dc9db1b04a4384fa3e787ed96adc3328ada4 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:06:55 -0500 Subject: [PATCH 093/105] Fix CI --- source/slang/slang-ir-insts-stable-names.lua | 4 +- source/slang/slang-ir-insts.lua | 55 +++++++++++++++++++ .../slang/slang-ir-lower-typeflow-insts.cpp | 10 ++-- source/slang/slang-ir-typeflow-specialize.cpp | 12 ++++ 4 files changed, 74 insertions(+), 7 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 0d899c6275e..ea0ac735562 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -716,6 +716,6 @@ return { ["UninitializedTypeElement"] = 712, ["UninitializedWitnessTableElement"] = 713, ["NoneTypeElement"] = 714, - ["NoneWitnessTableElement"] = 716, - ["GetTagForSubSet"] = 718 + ["NoneWitnessTableElement"] = 715, + ["GetTagForSubSet"] = 716 } diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 65647ab2d0c..95ffae6b7cc 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2871,31 +2871,86 @@ local insts = { hoistable = true } }, { UnboundedTypeElement = { + -- An element of TypeSet that represents an unbounded set of types conforming to + -- the given interface type. + -- + -- Used in cases where a finite set of types cannot be determined during type-flow analysis. + -- + -- Note that this is a set element, not a set in itself, so a TypeSet(A, B, UnboundedTypeElement(I)) + -- represents a set where we know two concrete types A and B, and any number of other types that conform to interface I. + -- hoistable = true, operands = { {"baseInterfaceType"} } } }, { UnboundedFuncElement = { + -- An element of FuncSet that represents an unbounded set of functions of a certain + -- func-type + -- + -- Used in cases where a finite set of functions cannot be determined during type-flow analysis. + -- + -- Similar to UnboundedTypeElement, this is a set element, not a set in itself. + -- hoistable = true, + operands = { {"funcType"} } } }, { UnboundedWitnessTableElement = { + -- An element of WitnessTableSet that represents an unbounded set of witness tables of a certain + -- interface type + -- + -- Used in cases where a finite set of witness tables cannot be determined during type-flow analysis. + -- + -- Similar to UnboundedTypeElement, this is a set element, not a set in itself. + -- hoistable = true, operands = { {"baseInterfaceType"} } } }, { UnboundedGenericElement = { + -- An element of GenericSet that represents an unbounded set of generics of a certain + -- interface type + -- + -- Used in cases where a finite set of generics cannot be determined during type-flow analysis. + -- + -- Similar to UnboundedTypeElement, this is a set element, not a set in itself. + -- hoistable = true, } }, { UninitializedTypeElement = { + -- An element that represents an uninitialized type of a certain interface. + -- + -- Used to denote cases where the type represented may be garbage (e.g. from a `LoadFromUninitializedMemory`) + -- + -- Similar to UnboundedXYZElement IR ops described above, this is a set element, not a set in itself. + -- e.g. a `TypeSet(A, B, UninitializedTypeElement(I))` represents a set where we know two concrete types A and B, + -- and an uninitialized type that conforms to interface I. + -- + -- Note: In practice, having any uninitialized type in a TypeSet will likely force the entire set to be treated as + -- uninitialized, and this element is mainly so that we can provide useful errors during the type-flow specialization pass. + -- hoistable = true, operands = { {"baseInterfaceType"} } } }, { UninitializedWitnessTableElement = { + -- An element that represents an uninitialized witness table of a certain interface. + -- + -- Used to denote cases where the witness table information may be garbage (e.g. from a `LoadFromUninitializedMemory`) + -- + -- Similar to UninitializedTypeElement, this is a set element, not a set in itself. + -- hoistable = true, operands = { {"baseInterfaceType"} } } }, { NoneTypeElement = { + -- An element that represents a default 'none' case (only relevant in the context of OptionalType) + -- + -- Similar to UnboundedXYZElement IR ops described above, this is a set element, not a set in itself. + -- hoistable = true } }, { NoneWitnessTableElement = { + -- An element that represents a default 'none' case (only relevant in the context of OptionalType) + -- + -- Similar to UnboundedXYZElement IR ops described above, this is a set element, not a set in itself. + -- hoistable = true } }, } diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-typeflow-insts.cpp index 8efa64df572..565f5149f01 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-typeflow-insts.cpp @@ -777,20 +777,20 @@ struct SetLoweringContext : public InstPassBase return true; } - void lowerUntaggedUnionType(IRUntaggedUnionType* valueOfSetType) + void lowerUntaggedUnionType(IRUntaggedUnionType* untaggedUnionType) { // Type collections are replaced with `AnyValueType` large enough to hold // any of the types in the collection. // HashSet types; - for (UInt i = 0; i < valueOfSetType->getSet()->getCount(); i++) + for (UInt i = 0; i < untaggedUnionType->getSet()->getCount(); i++) { - if (auto type = as(valueOfSetType->getSet()->getElement(i))) + if (auto type = as(untaggedUnionType->getSet()->getElement(i))) { types.add(type); } - else if (auto noneType = as(valueOfSetType->getSet()->getElement(i))) + else if (as(untaggedUnionType->getSet()->getElement(i))) { // Can safely skip. (effectively 0 size) } @@ -805,7 +805,7 @@ struct SetLoweringContext : public InstPassBase IRBuilder builder(module); auto anyValueType = createAnyValueType(&builder, types); - valueOfSetType->replaceUsesWith(anyValueType); + untaggedUnionType->replaceUsesWith(anyValueType); } void processModule() diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index e11348d8810..84df3342e1c 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -550,6 +550,18 @@ IROp getSetOpFromDataType(IRType* type) return kIROp_GenericSet; // Can be refined into set of concrete generics. } + // Slight workaround for the fact that we can have cases where the type has not been specialized + // yet (particularly from auto-diff) + // + if (auto specialize = as(type)) + { + auto innerValType = getGenericReturnVal(specialize->getBase()); + if (as(innerValType)) + return kIROp_FuncSet; + if (as(innerValType)) + return kIROp_WitnessTableSet; + } + return kIROp_Invalid; } From 1ad37453418d48eb8970314ac5cb1ee18b324d22 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:56:30 -0500 Subject: [PATCH 094/105] Add comments, rename some files to better reflect their contents, update submodules --- external/spirv-headers | 2 +- external/spirv-tools | 2 +- source/slang/slang-emit.cpp | 2 +- ...slang-ir-lower-dynamic-dispatch-insts.cpp} | 185 +++++++++++++----- ...> slang-ir-lower-dynamic-dispatch-insts.h} | 4 +- source/slang/slang-ir-lower-dynamic-insts.cpp | 0 source/slang/slang-ir-specialize.cpp | 4 +- ...llection.cpp => slang-ir-typeflow-set.cpp} | 2 +- ...w-collection.h => slang-ir-typeflow-set.h} | 2 +- source/slang/slang-ir-typeflow-specialize.cpp | 86 ++++---- 10 files changed, 194 insertions(+), 95 deletions(-) rename source/slang/{slang-ir-lower-typeflow-insts.cpp => slang-ir-lower-dynamic-dispatch-insts.cpp} (90%) rename source/slang/{slang-ir-lower-typeflow-insts.h => slang-ir-lower-dynamic-dispatch-insts.h} (91%) delete mode 100644 source/slang/slang-ir-lower-dynamic-insts.cpp rename source/slang/{slang-ir-typeflow-collection.cpp => slang-ir-typeflow-set.cpp} (99%) rename source/slang/{slang-ir-typeflow-collection.h => slang-ir-typeflow-set.h} (94%) diff --git a/external/spirv-headers b/external/spirv-headers index 01e0577914a..6bb105b6c4b 160000 --- a/external/spirv-headers +++ b/external/spirv-headers @@ -1 +1 @@ -Subproject commit 01e0577914a75a2569c846778c2f93aa8e6feddd +Subproject commit 6bb105b6c4b3a246e1e6bb96366fe14c6dbfde83 diff --git a/external/spirv-tools b/external/spirv-tools index 7f2d9ee926f..2fa0e29ba95 160000 --- a/external/spirv-tools +++ b/external/spirv-tools @@ -1 +1 @@ -Subproject commit 7f2d9ee926f98fc77a3ed1e1e0f113b8c9c49458 +Subproject commit 2fa0e29ba9541a44296897c3258ebbc9d121ac96 diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 4d52a03b458..3b22c774b3f 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -78,6 +78,7 @@ #include "slang-ir-lower-buffer-element-type.h" #include "slang-ir-lower-combined-texture-sampler.h" #include "slang-ir-lower-coopvec.h" +#include "slang-ir-lower-dynamic-dispatch-insts.h" #include "slang-ir-lower-dynamic-resource-heap.h" #include "slang-ir-lower-enum-type.h" #include "slang-ir-lower-glsl-ssbo-types.h" @@ -86,7 +87,6 @@ #include "slang-ir-lower-reinterpret.h" #include "slang-ir-lower-result-type.h" #include "slang-ir-lower-tuple-types.h" -#include "slang-ir-lower-typeflow-insts.h" #include "slang-ir-metadata.h" #include "slang-ir-metal-legalize.h" #include "slang-ir-missing-return.h" diff --git a/source/slang/slang-ir-lower-typeflow-insts.cpp b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp similarity index 90% rename from source/slang/slang-ir-lower-typeflow-insts.cpp rename to source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp index 565f5149f01..ac5fa009f5b 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp @@ -1,11 +1,11 @@ -#include "slang-ir-lower-typeflow-insts.h" +#include "slang-ir-lower-dynamic-dispatch-insts.h" #include "slang-ir-any-value-marshalling.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-insts.h" #include "slang-ir-layout.h" #include "slang-ir-specialize.h" -#include "slang-ir-typeflow-collection.h" +#include "slang-ir-typeflow-set.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -136,7 +136,15 @@ IRStringLit* _getWitnessTableWrapperFuncName(IRModule* module, IRFunc* func) return nullptr; } - +// Create a wrapper function that makes a specific function's signature match it's type in the +// interface requirement. +// +// e.g. the signature of the function from the caller's side might look like: ((ThisType, float) -> +// ThisType) while the actual implementation function might be ((FooImpl, float) -> FooImpl). +// +// The witness table wrapper will marshal from the union-types (ThisType) to the concrete types +// (FooImpl) expected by the implementation. +// IRFunc* emitWitnessTableWrapper(IRModule* module, IRInst* funcInst, IRInst* interfaceRequirementVal) { auto funcTypeInInterface = cast(interfaceRequirementVal); @@ -285,6 +293,10 @@ IRFunc* createDispatchFunc(IRFuncType* dispatchFuncType, DictionarygetModule(), funcInst, innerFuncType); // Create case block @@ -716,9 +728,9 @@ bool lowerDispatchers(IRModule* module, DiagnosticSink* sink) } // This context lowers `TypeSet` instructions. -struct SetLoweringContext : public InstPassBase +struct UntaggedUnionLoweringContext : public InstPassBase { - SetLoweringContext( + UntaggedUnionLoweringContext( IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink = nullptr) @@ -735,6 +747,11 @@ struct SetLoweringContext : public InstPassBase if (size > maxSize) maxSize = size; + // We need to consider whether the concrete type can be stored in an `AnyValue`. + // + // For example, resource types that are not bindless on the target cannot be marshalled + // and we need to diagnose that. + // if (sink && !canTypeBeStored(type)) { sink->diagnose( @@ -808,16 +825,22 @@ struct SetLoweringContext : public InstPassBase untaggedUnionType->replaceUsesWith(anyValueType); } + // Replace any uses of `NoneTypeElement` with `VoidType`. + void replaceNoneTypeElementWithVoidType() + { + IRBuilder builder(module); + auto noneTypeElement = builder.getNoneTypeElement(); + noneTypeElement->replaceUsesWith(builder.getVoidType()); + noneTypeElement->removeAndDeallocate(); + } + void processModule() { processInstsOfType( kIROp_UntaggedUnionType, [&](IRUntaggedUnionType* inst) { return lowerUntaggedUnionType(inst); }); - IRBuilder builder(module); - auto noneTypeElement = builder.getNoneTypeElement(); - noneTypeElement->replaceUsesWith(builder.getVoidType()); - noneTypeElement->removeAndDeallocate(); + replaceNoneTypeElementWithVoidType(); } private: @@ -831,7 +854,7 @@ struct SetLoweringContext : public InstPassBase void lowerUntaggedUnionTypes(IRModule* module, TargetProgram* targetProgram, DiagnosticSink* sink) { SLANG_UNUSED(sink); - SetLoweringContext context(module, targetProgram, sink); + UntaggedUnionLoweringContext context(module, targetProgram, sink); context.processModule(); } @@ -924,12 +947,12 @@ struct SequentialIDTagLoweringContext : public InstPassBase { // Get unique ID for the witness table SLANG_UNUSED(cast(table)); - auto outputId = builder.getUniqueID(table); + auto inputId = builder.getUniqueID(table); auto seqDecoration = table->findDecoration(); if (seqDecoration) { - auto inputId = seqDecoration->getSequentialID(); - mapping.add({outputId, inputId}); + auto outputId = seqDecoration->getSequentialID(); + mapping.add({inputId, outputId}); } }); @@ -944,11 +967,15 @@ struct SequentialIDTagLoweringContext : public InstPassBase // Ensures every witness table object has been assigned a sequential ID. + // // All witness tables will have a SequentialID decoration after this function is run. - // The sequantial ID in the decoration will be the same as the one specified in the Linkage. + // + // The sequential ID in the decoration will be the same as the one specified in the Linkage. + // // Otherwise, a new ID will be generated and assigned to the witness table object, and - // the sequantial ID map in the Linkage will be updated to include the new ID, so they + // the sequential ID map in the Linkage will be updated to include the new ID, so they // can be looked up by the user via future Slang API calls. + // void ensureWitnessTableSequentialIDs() { StringBuilder generatedMangledName; @@ -1138,6 +1165,9 @@ struct TaggedUnionLoweringContext : public InstPassBase { } + // Extract the required components from an interface-typed value + // and create a tagged union tuple from the result. + // IRInst* convertToTaggedUnion( IRBuilder* builder, IRInst* val, @@ -1341,6 +1371,13 @@ struct TaggedUnionLoweringContext : public InstPassBase bool lowerGetTypeTagFromTaggedUnion(IRGetTypeTagFromTaggedUnion* inst) { + // `GetTypeTagFromTaggedUnion(taggedUnionVal)` is not expected to + // appear after lowering, since we currently don't need the type tag + // for anything. + // + // We'll replace it with a poison value so that any accidental uses will result in + // an error later on. + // IRBuilder builder(module); builder.setInsertAfter(inst); inst->replaceUsesWith(builder.emitPoison(inst->getDataType())); @@ -1387,37 +1424,35 @@ struct TaggedUnionLoweringContext : public InstPassBase inst->removeAndDeallocate(); }); - // TODO: Is this repeated scanning of the module inefficient? - // It feels like this form could be very efficient if it's automatically - // 'fused' together. - // - processInstsOfType( - kIROp_GetTagFromTaggedUnion, - [&](IRGetTagFromTaggedUnion* inst) { return lowerGetTagFromTaggedUnion(inst); }); - - processInstsOfType( - kIROp_GetTypeTagFromTaggedUnion, - [&](IRGetTypeTagFromTaggedUnion* inst) - { return lowerGetTypeTagFromTaggedUnion(inst); }); - - processInstsOfType( - kIROp_GetValueFromTaggedUnion, - [&](IRGetValueFromTaggedUnion* inst) { return lowerGetValueFromTaggedUnion(inst); }); - - processInstsOfType( - kIROp_MakeTaggedUnion, - [&](IRMakeTaggedUnion* inst) { return lowerMakeTaggedUnion(inst); }); - - // Then, convert any loads/stores from reinterpreted pointers. bool hasCastInsts = false; - processInstsOfType( - kIROp_CastInterfaceToTaggedUnionPtr, - [&](IRCastInterfaceToTaggedUnionPtr* inst) + processAllInsts( + [&](IRInst* inst) { - hasCastInsts = true; - return lowerCastInterfaceToTaggedUnionPtr(inst); + switch (inst->getOp()) + { + case kIROp_GetTagFromTaggedUnion: + lowerGetTagFromTaggedUnion(as(inst)); + break; + case kIROp_GetTypeTagFromTaggedUnion: + lowerGetTypeTagFromTaggedUnion(as(inst)); + break; + case kIROp_GetValueFromTaggedUnion: + lowerGetValueFromTaggedUnion(as(inst)); + break; + case kIROp_MakeTaggedUnion: + lowerMakeTaggedUnion(as(inst)); + break; + case kIROp_CastInterfaceToTaggedUnionPtr: + { + hasCastInsts = true; + lowerCastInterfaceToTaggedUnionPtr( + as(inst)); + } + break; + default: + break; + } }); - return hasCastInsts; } }; @@ -1430,6 +1465,9 @@ bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink) return context.processModule(); } +// Convert `IsType` insts into an boolean equality check on the sequential IDs of the +// witness table operands. +// void lowerIsTypeInsts(IRModule* module) { InstPassBase pass(module); @@ -1530,6 +1568,13 @@ struct ExistentialLoweringContext : public InstPassBase if (isBuiltin(interfaceType)) return (IRType*)builder.getIntValue(builder.getIntType(), 0); + // In the dynamic-dispatch case, a value of interface type + // is going to be packed into the "any value" part of a tuple. + // The size of the "any value" part depends on the interface + // type (e.g., it might have an `[anyValueSize(8)]` attribute + // indicating that 8 bytes needs to be reserved). + // + IRIntegerValue anyValueSize = 0; if (auto decor = interfaceType->findDecoration()) { @@ -1540,11 +1585,21 @@ struct ExistentialLoweringContext : public InstPassBase auto witnessTableType = builder.getWitnessTableIDType((IRType*)interfaceType); auto rttiType = builder.getRTTIHandleType(); + // In the ordinary (dynamic) case, an existential type decomposes + // into a tuple of: + // + // (RTTI, witness table, any-value). + // return builder.getTupleType(rttiType, witnessTableType, anyValueType); } IRInst* lowerBoundInterfaceType(IRBoundInterfaceType* boundInterfaceType) { + // A bound interface type represents an existential together with + // static knowledge that the value stored in the extistential has + // a particular concrete type. + // + IRBuilder builder(module); auto payloadType = boundInterfaceType->getConcreteType(); @@ -1562,6 +1617,27 @@ struct ExistentialLoweringContext : public InstPassBase auto anyValueType = builder.getAnyValueType(anyValueSize); + // Because static specialization is being used (at least in part), + // we do *not* have a guarantee that the `concreteType` is one + // that can fit into the `anyValueSize` of the interface. + // + // We will use the IR layout logic to see if we can compute + // a size for the type, which can lead to a few different outcomes: + // + // * If a size is computed successfully, and it is smaller than or + // equal to `anyValueSize`, then the concrete value will fit into + // the reserved area, and the layout will match the dynamic case. + // + // * If a size is computed successfully, and it is larger than + // `anyValueSize`, then the concrete value cannot fit into the + // reserved area, and it needs to be stored out-of-line. + // + // * If size cannot be computed, then that implies that the type + // includes non-ordinary data (e.g., a `Texture2D` on a D3D11 + // target), and cannot possible fit into the reserved area + // (which consists of only uniform bytes). In this case, the + // value must be stored out-of-line. + // IRSizeAndAlignment sizeAndAlignment; Result result = getNaturalSizeAndAlignment( targetProgram->getOptionSet(), @@ -1569,6 +1645,11 @@ struct ExistentialLoweringContext : public InstPassBase &sizeAndAlignment); if (SLANG_FAILED(result) || sizeAndAlignment.size > anyValueSize) { + // In the case where static specialization mandateds out-of-line storage, + // an existential type decomposes into a tuple of: + // + // (RTTI, witness table, pseudo pointer, any-value) + // return builder.getTupleType( rttiType, witnessTableType, @@ -1584,7 +1665,7 @@ struct ExistentialLoweringContext : public InstPassBase bool lowerExtractExistentialType(IRExtractExistentialType* inst) { - // Replace with extraction of the value type from the tagged-union tuple. + // Replace with extraction of the type as a value from the existential tuple. // IRBuilder builder(module); @@ -1608,7 +1689,7 @@ struct ExistentialLoweringContext : public InstPassBase bool lowerExtractExistentialWitnessTable(IRExtractExistentialWitnessTable* inst) { - // Replace with extraction of the value from the tagged-union tuple. + // Replace with extraction of the witness table identifier from the existential tuple. // IRBuilder builder(module); @@ -1669,7 +1750,7 @@ struct ExistentialLoweringContext : public InstPassBase bool lowerExtractExistentialValue(IRExtractExistentialValue* inst) { - // Replace with extraction of the value from the tagged-union tuple. + // Replace with extraction of the value payload from the existential tuple. // IRBuilder builder(module); @@ -1714,7 +1795,6 @@ struct ExistentialLoweringContext : public InstPassBase return true; } - UInt index = 0; auto id = builder.emitSwizzle(builder.getUIntType(), inst->getRTTIOperand(), 1, &index); inst->replaceUsesWith(id); @@ -1725,8 +1805,13 @@ struct ExistentialLoweringContext : public InstPassBase void processModule() { // Then, start lowering the remaining non-COM/non-Builtin interface types - // At this point, we should only bea dealing with public facing uses of - // interface types (which must lower into a 3-tuple of RTTI, witness table ID, AnyValue) + // At this point, we should only be dealing with public facing uses of + // interface types, as all internal uses would have been rewritten + // into known tagged-union types. + // + // These must lower into either a + // TupleType(RTTI, witness table ID, AnyValue) for regular interface types or a + // TupleType(RTTI, witness table ID, PseudoPtr, AnyValue) for bound interface types. // processInstsOfType( kIROp_InterfaceType, @@ -1772,7 +1857,7 @@ struct ExistentialLoweringContext : public InstPassBase }); // Replace any other uses with dummy value 0. - // TODO: Ideally, we should replace it with IRPoison.. + // TODO: Ideally, we should replace it with OpPoison? { IRBuilder builder(module); builder.setInsertInto(module); diff --git a/source/slang/slang-ir-lower-typeflow-insts.h b/source/slang/slang-ir-lower-dynamic-dispatch-insts.h similarity index 91% rename from source/slang/slang-ir-lower-typeflow-insts.h rename to source/slang/slang-ir-lower-dynamic-dispatch-insts.h index 57a8e00420f..761173a3b65 100644 --- a/source/slang/slang-ir-lower-typeflow-insts.h +++ b/source/slang/slang-ir-lower-dynamic-dispatch-insts.h @@ -18,8 +18,8 @@ bool lowerTaggedUnionTypes(IRModule* module, DiagnosticSink* sink); // Lower `SetTagType` types void lowerTagTypes(IRModule* module); -// Lower `GetTagOfElementInSet`, -// `GetTagForSuperSet`, and `GetTagForMappedSet` instructions, +// Lower `GetTagOfElementInSet`, `GetTagForSuperSet`, `GetTagForSubSet` and `GetTagForMappedSet` +// instructions, // void lowerTagInsts(IRModule* module, DiagnosticSink* sink); diff --git a/source/slang/slang-ir-lower-dynamic-insts.cpp b/source/slang/slang-ir-lower-dynamic-insts.cpp deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index ecb05bd1aab..6c78d10445d 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -5,11 +5,11 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-insts.h" -#include "slang-ir-lower-typeflow-insts.h" +#include "slang-ir-lower-dynamic-dispatch-insts.h" #include "slang-ir-peephole.h" #include "slang-ir-sccp.h" #include "slang-ir-ssa-simplification.h" -#include "slang-ir-typeflow-collection.h" +#include "slang-ir-typeflow-set.h" #include "slang-ir-typeflow-specialize.h" #include "slang-ir-util.h" #include "slang-ir.h" diff --git a/source/slang/slang-ir-typeflow-collection.cpp b/source/slang/slang-ir-typeflow-set.cpp similarity index 99% rename from source/slang/slang-ir-typeflow-collection.cpp rename to source/slang/slang-ir-typeflow-set.cpp index 8c49ecc824b..bdf8a8d0712 100644 --- a/source/slang/slang-ir-typeflow-collection.cpp +++ b/source/slang/slang-ir-typeflow-set.cpp @@ -1,4 +1,4 @@ -#include "slang-ir-typeflow-collection.h" +#include "slang-ir-typeflow-set.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" diff --git a/source/slang/slang-ir-typeflow-collection.h b/source/slang/slang-ir-typeflow-set.h similarity index 94% rename from source/slang/slang-ir-typeflow-collection.h rename to source/slang/slang-ir-typeflow-set.h index bac8960d8ff..6431af6ac0b 100644 --- a/source/slang/slang-ir-typeflow-collection.h +++ b/source/slang/slang-ir-typeflow-set.h @@ -1,4 +1,4 @@ -// slang-ir-typeflow-collection.h +// slang-ir-typeflow-set.h #pragma once #include "slang-ir-insts.h" #include "slang-ir.h" diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index 84df3342e1c..b7f1b1f7cdf 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -5,7 +5,7 @@ #include "slang-ir-inst-pass-base.h" #include "slang-ir-insts.h" #include "slang-ir-specialize.h" -#include "slang-ir-typeflow-collection.h" +#include "slang-ir-typeflow-set.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -419,20 +419,6 @@ IRType* fromDirectionAndType(IRBuilder* builder, ParameterDirectionInfo info, IR } } -IRInst* lookupWitnessTableEntry(IRWitnessTable* table, IRInst* key) -{ - if (auto entry = findWitnessTableEntry(table, key)) - { - return entry; - } - else if (as(table->getConcreteType())) - { - SLANG_UNEXPECTED("Looking up entry on 'none' witness table"); - } - - return nullptr; -} - // Helper to test if an inst is in the global scope. bool isGlobalInst(IRInst* inst) { @@ -442,15 +428,16 @@ bool isGlobalInst(IRInst* inst) // This is fairly fundamental check: // This method checks whether a inst's type cannot accept any further refinement. // -// e.g. an inst of `UInt` type cannot be further refined (under the current scope -// of the type-flow pass), since it has a concrete type, and we do not -// track values. +// Examples: +// 1. an inst of `UInt` type cannot be further refined (under the current scope +// of the type-flow pass), since it has a concrete type, and we cannot replace it +// with a "narrower" type. // -// an inst of `InterfaceType` represents a tagged union of any type that implements +// 2. an inst of `InterfaceType` represents a tagged union of any type that implements // the interface, so it can be further refined by determining a smaller set of // possibilities (i.e. via `TaggedUnionType(tableSet, typeSet)`). // -// Similarly, an inst of `WitnessTableType` represents any witness table, +// 3. Similarly, an inst of `WitnessTableType` represents any witness table, // so it can accept a further refinement into `ElementOfSetType(tableSet)`. // // In the future, we may want to extend this check to something more nuanced, @@ -525,7 +512,15 @@ IRInst* makeInfoForConcreteType(IRModule* module, IRInst* type) cast(builder.getSingletonSet(kIROp_TypeSet, type))); } -IROp getSetOpFromDataType(IRType* type) +// Determines a suitable set opcode to use to represent a set of elements of the given type. +// +// e.g. an inst of WitnessTableType can be be refined using a WitnessTableSet (but it doesn't make +// sense to use a TypeSet or FuncSet). +// +// This is primarily used for `LookupWitnessMethod`, where the resulting element could be +// a generic, witness table, type or function, depending on the type of the lookup inst. +// +IROp getSetOpFromType(IRType* type) { switch (type->getOp()) { @@ -1875,10 +1870,10 @@ struct TypeFlowSpecializationContext return; } - results.add(lookupWitnessTableEntry(cast(table), key)); + results.add(findWitnessTableEntry(cast(table), key)); }); - auto setOp = getSetOpFromDataType(inst->getDataType()); + auto setOp = getSetOpFromType(inst->getDataType()); auto resultSetType = makeElementOfSetType(builder.getSet(setOp, results)); module->getContainerPool().free(&results); @@ -2057,7 +2052,7 @@ struct TypeFlowSpecializationContext // IRBuilder builder(module); SLANG_ASSERT(unboundedElement); - auto setOp = getSetOpFromDataType(inst->getArg(i)->getDataType()); + auto setOp = getSetOpFromType(inst->getArg(i)->getDataType()); auto pureUnboundedSet = builder.getSingletonSet(setOp, unboundedElement); if (auto typeSet = as(pureUnboundedSet)) specializationArgs.add(makeUntaggedUnionType(typeSet)); @@ -2213,7 +2208,7 @@ struct TypeFlowSpecializationContext } IRBuilder builder(module); - auto setOp = getSetOpFromDataType(inst->getDataType()); + auto setOp = getSetOpFromType(inst->getDataType()); auto resultSetType = makeElementOfSetType(builder.getSet(setOp, specializedSet)); module->getContainerPool().free(&specializedSet); module->getContainerPool().free(&specializationArgs); @@ -3041,7 +3036,7 @@ struct TypeFlowSpecializationContext // Handle trivial case where inst's operand is a concrete table. if (auto witnessTable = as(inst->getWitnessTable())) { - inst->replaceUsesWith(lookupWitnessTableEntry(witnessTable, inst->getRequirementKey())); + inst->replaceUsesWith(findWitnessTableEntry(witnessTable, inst->getRequirementKey())); inst->removeAndDeallocate(); return true; } @@ -3481,6 +3476,9 @@ struct TypeFlowSpecializationContext return callee; } + // Filter out `NoneTypeElement` and `NoneWitnessTableElement` from a set (if any exist) + // and construct a new set of the same kind. + // IRSetBase* filterNoneElements(IRSetBase* set) { auto setOp = set->getOp(); @@ -3506,10 +3504,12 @@ struct TypeFlowSpecializationContext if (!containsNone) { + // Return the same set if there were no None elements module->getContainerPool().free(&filteredElements); return set; } + // Create a new set without the filtered elements auto newFuncSet = cast(builder.getSet(setOp, filteredElements)); module->getContainerPool().free(&filteredElements); return newFuncSet; @@ -3729,10 +3729,12 @@ struct TypeFlowSpecializationContext // Occasionally, we will determine that there are absolutely no possible callees // for a call site. This typically happens to impossible branches. // - // The correct way to handle such cases is to improve the analysis to avoid - // branches that are impossible. For now, we will just remove the callee and - // replace with a default value. The exact value doesn't matter since we've determined - // that this code is unreachable. + // If this happens, the inst representing the callee would have been replaced + // with a poison value. In this case, we're simply going to replace the entire call + // with a default-constructed value of the appropriate type. + // + // Note that it doesn't matter what we replace it with since this code should be + // effectively unreachable. // IRBuilder builder(context); builder.setInsertBefore(inst); @@ -3759,9 +3761,11 @@ struct TypeFlowSpecializationContext { auto oldFuncType = as(callee->getDataType()); IRBuilder builder(module); + List paramTypes; for (auto paramType : oldFuncType->getParamTypes()) paramTypes.add(paramType); + auto newFuncType = builder.getFuncType(paramTypes, resultType); builder.setInsertInto(module); callee = builder.replaceOperand(&callee->typeUse, newFuncType); @@ -3795,7 +3799,7 @@ struct TypeFlowSpecializationContext break; } - // Out parameters are handled at the callee's end + // Upcasting of out-parameters is the responsibility of the callee. case ParameterDirectionInfo::Kind::Out: // For all other modes, sets must match ('subtyping' is not allowed) @@ -4511,10 +4515,20 @@ struct TypeFlowSpecializationContext if (auto srcTaggedUnionType = as(inst->getOptionalOperand()->getDataType())) { - // Since `GetOptionalValue` is the reverse of `MakeOptionalValue`, and we treat - // the latter as a no-op, then `GetOptionalValue` is also a no-op (we simply pass - // the inner existential value as-is) + // How we handle `GetOptionalValue` depends on whether our analysis + // shows that there can be a 'none' type in the input. + // + // If not, we can do a simple replace with the operand since the + // input and output are equivalent (this is a no-op) + // + // If so, then we will cast the tagged-union's tag to a sub-set (without the none), + // and unpack the untagged union-type'd value into the smaller union type. // + // Note that the union is "smaller" only in the sense that it doesn't have the 'none' + // type. Since a 'none' value takes up 0 space, the two union types will end up having + // the same size in the end. + // + IRBuilder builder(inst); auto destTaggedUnionType = cast(tryGetInfo(context, inst)); @@ -4657,7 +4671,7 @@ struct TypeFlowSpecializationContext { bool hasChanges = false; - // Phase 1: Information Propagation + // Part 1: Information Propagation // This phase propagates type information through the module // and records them into different maps in the current context. // @@ -4669,7 +4683,7 @@ struct TypeFlowSpecializationContext return false; } - // Phase 2: Dynamic Instruction Specialization + // Part 2: Dynamic Instruction Specialization // Re-write dynamic instructions into specialized versions based on the // type information in the previous phase. // From 9b9927b41f7d0fa39253ebc7ddaa008844f8b8fd Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:19:14 -0500 Subject: [PATCH 095/105] More documentation tweaks. Remove unused ops. Re-order sets during inst replacement --- source/slang/slang-ir-insts-stable-names.lua | 2 - source/slang/slang-ir-insts.h | 20 ++++------ source/slang/slang-ir-insts.lua | 22 ----------- .../slang-ir-lower-dynamic-dispatch-insts.cpp | 39 ++++++++++++------- .../slang-ir-lower-dynamic-dispatch-insts.h | 2 +- source/slang/slang-ir-specialize.cpp | 16 ++++---- source/slang/slang-ir-typeflow-set.cpp | 10 ++--- source/slang/slang-ir-typeflow-specialize.cpp | 33 +++++----------- source/slang/slang-ir.cpp | 34 +++++++++++++++- 9 files changed, 87 insertions(+), 91 deletions(-) diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index ea0ac735562..922f3976ef4 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -689,11 +689,9 @@ return { ["SetBase.FuncSet"] = 685, ["SetBase.WitnessTableSet"] = 686, ["SetBase.GenericSet"] = 687, - ["UnboundedSet"] = 688, ["Type.SetTagType"] = 689, ["Type.TaggedUnionType"] = 690, ["CastInterfaceToTaggedUnionPtr"] = 691, - ["CastTaggedUnionToInterfacePtr"] = 692, ["GetTagForSuperSet"] = 693, ["GetTagForMappedSet"] = 694, ["GetTagForSpecializedSet"] = 695, diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 59c4c840340..1c685859832 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4184,18 +4184,12 @@ struct IRBuilder IRMetalSetPrimitive* emitMetalSetPrimitive(IRInst* index, IRInst* primitive); IRMetalSetIndices* emitMetalSetIndices(IRInst* index, IRInst* indices); - // TODO: Move all the collection-based ops into the builder. - IRUnboundedSet* emitUnboundedSet() - { - return cast(emitIntrinsicInst(nullptr, kIROp_UnboundedSet, 0, nullptr)); - } - IRGetElementFromTag* emitGetElementFromTag(IRInst* tag) { auto tagType = cast(tag->getDataType()); - IRInst* collection = tagType->getSet(); - auto elementType = cast( - emitIntrinsicInst(nullptr, kIROp_ElementOfSetType, 1, &collection)); + IRInst* set = tagType->getSet(); + auto elementType = + cast(emitIntrinsicInst(nullptr, kIROp_ElementOfSetType, 1, &set)); return cast( emitIntrinsicInst(elementType, kIROp_GetElementFromTag, 1, &tag)); } @@ -4204,9 +4198,9 @@ struct IRBuilder { auto taggedUnionType = cast(tag->getDataType()); - IRInst* collection = taggedUnionType->getWitnessTableSet(); + IRInst* set = taggedUnionType->getWitnessTableSet(); auto tableTagType = - cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &collection)); + cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &set)); return cast( emitIntrinsicInst(tableTagType, kIROp_GetTagFromTaggedUnion, 1, &tag)); @@ -4216,9 +4210,9 @@ struct IRBuilder { auto taggedUnionType = cast(tag->getDataType()); - IRInst* collection = taggedUnionType->getTypeSet(); + IRInst* typeSet = taggedUnionType->getTypeSet(); auto typeTagType = - cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &collection)); + cast(emitIntrinsicInst(nullptr, kIROp_SetTagType, 1, &typeSet)); return cast( emitIntrinsicInst(typeTagType, kIROp_GetTypeTagFromTaggedUnion, 1, &tag)); diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index 95ffae6b7cc..d4c553799d5 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -2732,31 +2732,9 @@ local insts = { { GenericSet = {} } }, }, - { UnboundedSet = { - hoistable = true, - -- - -- A catch-all opcode to represent unbounded collections during - -- the type-flow specialization pass. - -- - -- This op is usually used to mark insts that can contain a dynamic type - -- whose information cannot be gleaned from the type-flow analysis. - -- - -- E.g. COM interface objects, whose implementations can be fully external to - -- the linkage - -- - -- This op is only used to denote that an inst is unbounded so the specialization - -- pass does not attempt to specialize it. It should not appear in the code after - -- the specialization pass. - -- - -- TODO: Consider the scenario where we can combine the unbounded case with known cases. - -- unbounded set should probably be an element and not a separate op. - } }, { CastInterfaceToTaggedUnionPtr = { -- Cast an interface-typed pointer to a tagged-union pointer with a known set. } }, - { CastTaggedUnionToInterfacePtr = { - -- Cast a tagged-union pointer with a known set to a corresponding interface-typed pointer. - } }, { GetTagForSuperSet = { -- Translate a tag from a set to its equivalent in a super-set -- diff --git a/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp index ac5fa009f5b..c9d9b3f0b17 100644 --- a/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp @@ -467,7 +467,6 @@ struct TagOpsLoweringContext : public InstPassBase inst->removeAndDeallocate(); } - void lowerGetTagForMappedSet(IRGetTagForMappedSet* inst) { // `GetTagForMappedSet` turns into a integer mapping from @@ -632,7 +631,6 @@ struct DispatcherLoweringContext : public InstPassBase } } - void lowerGetSpecializedDispatcher(IRGetSpecializedDispatcher* dispatcher) { // Replace the `IRGetSpecializedDispatcher` with a dispatch function, @@ -1152,8 +1150,7 @@ bool isEffectivelyComPtrType(IRType* type) return false; } -// This context lowers `CastInterfaceToTaggedUnionPtr` and -// `CastTaggedUnionToInterfacePtr` by finding all `IRLoad` and +// This context lowers `CastInterfaceToTaggedUnionPtr` by finding all `IRLoad` and // `IRStore` uses of these insts, and upcasting the tagged-union // tuple to the the interface-based tuple (of the loaded inst or before // storing the val, as necessary) @@ -1725,24 +1722,34 @@ struct ExistentialLoweringContext : public InstPassBase builder.setInsertAfter(inst); auto tupleType = as(inst->getOperand(0)->getDataType()); - + auto element = + builder.emitGetTupleElement((IRType*)tupleType->getOperand(2), inst->getOperand(0), 2); if (as(tupleType->getOperand(2))) { - inst->replaceUsesWith(builder.emitGetTupleElement( - (IRType*)tupleType->getOperand(2), - inst->getOperand(0), - 2)); + // The first case is when legacy static specialization + // is applied, and the element is a "pseudo-pointer." + // + // Semantically, we should emit a (pseudo-)load from the pseudo-pointer + // to go from `PseudoPtr` to `T`. + // + // TODO: Actually introduce and emit a "psedudo-load" instruction + // here. For right now we are just using the value directly and + // downstream passes seem okay with it, but it isn't really + // type-correct to be doing this. + // + inst->replaceUsesWith(element); inst->removeAndDeallocate(); return true; } else { - inst->replaceUsesWith(builder.emitUnpackAnyValue( - inst->getDataType(), - builder.emitGetTupleElement( - (IRType*)tupleType->getOperand(2), - inst->getOperand(0), - 2))); + // The second case is when the dynamic-dispatch layout is + // being used, and the element is an "any-value." + // + // In this case we need to emit an unpacking operation + // to get from `AnyValue` to `T`. + // + inst->replaceUsesWith(builder.emitUnpackAnyValue(inst->getDataType(), element)); inst->removeAndDeallocate(); return true; } @@ -1782,6 +1789,8 @@ struct ExistentialLoweringContext : public InstPassBase // If the operand is a witness table, it is already replaced with a uint2 // at this point, where the first element in the uint2 is the id of the // witness table. + // + IRBuilder builder(module); builder.setInsertBefore(inst); diff --git a/source/slang/slang-ir-lower-dynamic-dispatch-insts.h b/source/slang/slang-ir-lower-dynamic-dispatch-insts.h index 761173a3b65..d4c4ff064ca 100644 --- a/source/slang/slang-ir-lower-dynamic-dispatch-insts.h +++ b/source/slang/slang-ir-lower-dynamic-dispatch-insts.h @@ -1,4 +1,4 @@ -// slang-ir-typeflow-specialize.h +// slang-ir-lower-dynamic-dispatch-insts.h #pragma once #include "slang-ir.h" diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 6c78d10445d..596df5474f0 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -863,14 +863,14 @@ struct SpecializationContext IRInterfaceType* interfaceType = nullptr; if (!witnessTable) { - if (auto collection = as(lookupInst->getWitnessTable())) + if (auto witnessTableSet = as(lookupInst->getWitnessTable())) { auto requirementKey = lookupInst->getRequirementKey(); HashSet satisfyingValSet; bool skipSpecialization = false; forEachInSet( - collection, + witnessTableSet, [&](IRInst* instElement) { if (auto table = as(instElement)) @@ -904,7 +904,7 @@ struct SpecializationContext else { // Should not see any other case. - SLANG_UNREACHABLE("unexpected collection type"); + SLANG_UNREACHABLE("unexpected set kind"); } return true; @@ -3254,25 +3254,25 @@ IRInst* specializeGenericWithSetArgs(IRSpecialize* specializeInst) for (auto param : generic->getFirstBlock()->getParams()) { auto specArg = specializeInst->getArg(argIndex++); - if (auto collection = as(specArg)) + if (auto set = as(specArg)) { // We're dealing with a set of types. if (as(param->getDataType())) { SLANG_ASSERT("Should not happen"); - cloneEnv.mapOldValToNew[param] = builder.getUntaggedUnionType(collection); + cloneEnv.mapOldValToNew[param] = builder.getUntaggedUnionType(set); } else if (as(param->getDataType())) { // For cloning parameter types, we want to just use the - // collection. + // set. // - staticCloningEnv.mapOldValToNew[param] = collection; + staticCloningEnv.mapOldValToNew[param] = set; // We'll create an integer parameter for all the rest of // the insts which will may need the runtime tag. // - auto tagType = (IRType*)builder.getSetTagType(collection); + auto tagType = (IRType*)builder.getSetTagType(set); // cloneEnv.mapOldValToNew[param] = builder.emitParam(tagType); extraParamMap.add(param, builder.emitParam(tagType)); extraParamTypes.add(tagType); diff --git a/source/slang/slang-ir-typeflow-set.cpp b/source/slang/slang-ir-typeflow-set.cpp index bdf8a8d0712..f70900d4f56 100644 --- a/source/slang/slang-ir-typeflow-set.cpp +++ b/source/slang/slang-ir-typeflow-set.cpp @@ -28,7 +28,7 @@ IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo) if (as(argInfo) && as(destInfo)) { - // A collection tagged union is essentially a tuple(TagType(tableSet), + // A tagged union is essentially a tuple(TagType(tableSet), // typeSet) We simply extract the two components, upcast each one, and put it // back together. // @@ -60,12 +60,12 @@ IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo) } else if (as(argInfo) && as(destInfo)) { - // If the arg represents a tag of a colleciton, but the dest is a _different_ - // collection, then we need to emit a tag operation to reinterpret the + // If the arg represents a tag of a set, but the dest is a _different_ + // set, then we need to emit a tag operation to reinterpret the // tag. // // Note that, by the invariant provided by the typeflow analysis, the target - // collection must necessarily be a super-set. + // set must necessarily be a super-set. // if (argInfo != destInfo) { @@ -74,7 +74,7 @@ IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo) } else if (as(argInfo) && as(destInfo)) { - // If the arg has a collection type, but the dest is a _different_ collection, + // If the arg has a untagged union type, but the dest is a _different_ untagged union, // we need to perform a reinterpret. // // e.g. TypeSet({T1, T2}) may lower to AnyValueType(N), while diff --git a/source/slang/slang-ir-typeflow-specialize.cpp b/source/slang/slang-ir-typeflow-specialize.cpp index b7f1b1f7cdf..3fc863d039e 100644 --- a/source/slang/slang-ir-typeflow-specialize.cpp +++ b/source/slang/slang-ir-typeflow-specialize.cpp @@ -601,7 +601,7 @@ bool isOptionalExistentialType(IRInst* inst) // Parent context for the full type-flow pass. struct TypeFlowSpecializationContext { - // Make a tagged-union-type out of a given collection of tables. + // Make a tagged-union-type out of a given set of tables. // // This type can be used for insts that are semantically a tuple of a tag (to select a table) // and a payload to contain the existential value. @@ -611,7 +611,7 @@ struct TypeFlowSpecializationContext IRBuilder builder(module); HashSet typeSet; - // Create a type collection out of the base types from each table. + // Create a type set out of the base types from each table. forEachInSet( tableSet, [&](IRInst* witnessTable) @@ -642,21 +642,6 @@ struct TypeFlowSpecializationContext cast(builder.getSet(kIROp_TypeSet, typeSet))); } - // Create an unbounded set. - // - // This is a catch-all for cases where we can't enumerate the possibilites. - // We use this as a sentinel value to figure out when NOT to specialize a - // given inst. - // - // Most commonly occurs with COM objects & some builtin-in interface types. - // - IRUnboundedSet* makeUnbounded() - { - IRBuilder builder(module); - return as( - builder.emitIntrinsicInst(nullptr, kIROp_UnboundedSet, 0, nullptr)); - } - // Creates an 'empty' inst (denoted by nullptr), that // can be used to denote one of two things: // @@ -669,7 +654,7 @@ struct TypeFlowSpecializationContext // IRInst* none() { return nullptr; } - // Make an untagged-union type out of a given collection of types. + // Make an untagged-union type out of a given set of types. // // This is used to denote insts whose value can be of multiple possible types, // Note that unlike tagged-unions, untagged-unions do not have any information @@ -762,11 +747,13 @@ struct TypeFlowSpecializationContext { if (inst->getDataType()) { - // If the data-type is already a collection type, then the refinement - // occured during a previous phase. For now, we simply re-use that info directly. + // If the data-type is already a tagged union or untagged union or + // element-of-set type, then the refinement occured during a previous phase. + // + // For now, we simply re-use that info directly. // - // In the future, it makes sense to ignore the pre-existing type and treat - // them as an upper-bound on the new info. + // In the future, it makes sense to treat it as non-concrete and use + // them as an upper-bound for further refinement. // switch (inst->getDataType()->getOp()) { @@ -1071,7 +1058,7 @@ struct TypeFlowSpecializationContext // 3. Continue (2) until no more information has changed. // // This process is guaranteed to terminate because our propagation 'states' (i.e. - // collection insts and their wrapped versions) form a lattice. This is an order-theoretic + // sets and their composites) form a lattice. This is an order-theoretic // structure that implies that // (i) each update moves us strictly 'upward', and // (ii) that there are a finite number of possible states. diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 7058748ae86..f06dd47b1a7 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8117,6 +8117,38 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) // Swap this use over to use the other value. uu->usedValue = other; + bool userIsSetInst = as(uu->getUser()) != nullptr; + if (userIsSetInst) + { + // Set insts need their operands sorted + auto module = user->getModule(); + SLANG_ASSERT(module); + + List& operands = *module->getContainerPool().getList(); + for (UInt i = 0; i < user->getOperandCount(); i++) + operands.add(user->getOperand(i)); + + auto getUniqueId = [&](IRInst* inst) + { + auto uniqueIDMap = module->getUniqueIdMap(); + auto existingId = uniqueIDMap->tryGetValue(inst); + if (existingId) + return *existingId; + + auto id = uniqueIDMap->getCount(); + uniqueIDMap->add(inst, id); + return id; + }; + + operands.sort([&](IRInst* a, IRInst* b) + { return getUniqueId(a) < getUniqueId(b); }); + + for (UInt i = 0; i < user->getOperandCount(); i++) + user->getOperandUse(i)->usedValue = operands[i]; + + module->getContainerPool().free(&operands); + } + // If `other` is hoistable, then we need to make sure `other` is hoisted // to a point before `user`, if it is not already so. _maybeHoistOperand(uu); @@ -8616,14 +8648,12 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetTagFromSequentialID: case kIROp_GetSequentialIDFromTag: case kIROp_CastInterfaceToTaggedUnionPtr: - case kIROp_CastTaggedUnionToInterfacePtr: case kIROp_GetElementFromTag: case kIROp_GetTagFromTaggedUnion: case kIROp_GetTypeTagFromTaggedUnion: case kIROp_GetValueFromTaggedUnion: case kIROp_MakeTaggedUnion: case kIROp_GetTagOfElementInSet: - case kIROp_UnboundedSet: case kIROp_MakeDifferentialPairUserCode: case kIROp_MakeDifferentialPtrPair: case kIROp_MakeStorageTypeLoweringConfig: From 1e833df36f61fa9dbc7218f76bda6ae94adceb40 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:48:29 -0500 Subject: [PATCH 096/105] Remove op-less set building logic. Fix warnings. More doc tweaks --- source/slang/slang-ir-insts.h | 4 -- .../slang-ir-lower-dynamic-dispatch-insts.cpp | 2 +- source/slang/slang-ir-specialize.cpp | 3 +- source/slang/slang-ir-typeflow-set.cpp | 2 +- source/slang/slang-ir-typeflow-specialize.h | 2 + source/slang/slang-ir.cpp | 49 ++----------------- 6 files changed, 10 insertions(+), 52 deletions(-) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 1c685859832..dcb2eed92d9 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -5126,14 +5126,10 @@ struct IRBuilder void addRayPayloadDecoration(IRType* inst) { addDecoration(inst, kIROp_RayPayloadDecoration); } IRSetBase* getSet(IROp op, const HashSet& elements); - IRSetBase* getSet(const HashSet& elements); IRSetBase* getSingletonSet(IROp op, IRInst* element); - IRSetBase* getSingletonSet(IRInst* element); UInt getUniqueID(IRInst* inst); - - IROp getSetTypeForInst(IRInst* inst); }; // Helper to establish the source location that will be used diff --git a/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp index c9d9b3f0b17..5bee559b497 100644 --- a/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp @@ -1307,7 +1307,7 @@ struct TaggedUnionLoweringContext : public InstPassBase // Replace `TaggedUnionType(typeSet, tableSet)` with // `TupleType(SetTagType(tableSet), typeSet)` // - // Unless the collection has a single element, in which case we + // Unless the set has a single element, in which case we // replace it with `TupleType(SetTagType(tableSet), elementType)` // // We still maintain a tuple type (even though it's not really necesssary) to avoid diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 596df5474f0..bd0e4e2cb49 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -889,7 +889,8 @@ struct SpecializationContext if (!skipSpecialization) { IRBuilder builder(module); - auto newSet = builder.getSet(satisfyingValSet); + auto setOp = getSetOpFromType(lookupInst->getDataType()); + auto newSet = builder.getSet(setOp, satisfyingValSet); addUsersToWorkList(lookupInst); if (as(newSet)) { diff --git a/source/slang/slang-ir-typeflow-set.cpp b/source/slang/slang-ir-typeflow-set.cpp index f70900d4f56..b712ffbb5eb 100644 --- a/source/slang/slang-ir-typeflow-set.cpp +++ b/source/slang/slang-ir-typeflow-set.cpp @@ -108,7 +108,7 @@ IRInst* upcastSet(IRBuilder* builder, IRInst* arg, IRType* destInfo) // we need to perform a pack operation. // // This case only arises when passing a value of type T to a parameter - // of a type-collection that contains T. + // of a type-set that contains T. // return builder->emitPackAnyValue((IRType*)destInfo, arg); } diff --git a/source/slang/slang-ir-typeflow-specialize.h b/source/slang/slang-ir-typeflow-specialize.h index 6f3e562675a..8a4dfd8efbf 100644 --- a/source/slang/slang-ir-typeflow-specialize.h +++ b/source/slang/slang-ir-typeflow-specialize.h @@ -19,4 +19,6 @@ namespace Slang bool specializeDynamicInsts(IRModule* module, DiagnosticSink* sink); bool isSetSpecializedGeneric(IRInst* callee); + +IROp getSetOpFromType(IRType* type); } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index f06dd47b1a7..0283517c4e1 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6626,24 +6626,11 @@ IRSetBase* IRBuilder::getSet(IROp op, const HashSet& elements) return setBaseInst; } -IRSetBase* IRBuilder::getSet(const HashSet& elements) -{ - // Cannot call getSet with an empty set of elements and no specific op-code. - SLANG_ASSERT(elements.getCount() > 0); - auto firstElement = *elements.begin(); - return getSet(getSetTypeForInst(firstElement), elements); -} - IRSetBase* IRBuilder::getSingletonSet(IROp op, IRInst* element) { return getSet(op, {element}); } -IRSetBase* IRBuilder::getSingletonSet(IRInst* element) -{ - return getSet(getSetTypeForInst(element), {element}); -} - UInt IRBuilder::getUniqueID(IRInst* inst) { auto uniqueIDMap = getModule()->getUniqueIdMap(); @@ -6656,34 +6643,6 @@ UInt IRBuilder::getUniqueID(IRInst* inst) return id; } -IROp IRBuilder::getSetTypeForInst(IRInst* inst) -{ - if (as(inst) || as(inst)) - return kIROp_TypeSet; - if (as(inst)) - return kIROp_FuncSet; - if (as(inst) || as(inst)) - return kIROp_WitnessTableSet; - if (as(inst)) - return kIROp_GenericSet; - - if (as(inst)) - return kIROp_GenericSet; - - if (as(inst->getDataType())) - return kIROp_TypeSet; - else if (as(inst->getDataType())) - return kIROp_FuncSet; - else if (as(inst) && !as(inst)) - return kIROp_TypeSet; - else if (as(inst->getDataType())) - return kIROp_WitnessTableSet; - else if (as(inst)) - return kIROp_TypeSet; // TODO: this feels wrong... - else - return kIROp_Invalid; // Return invalid IROp when not supported -} - // struct IRDumpContext @@ -8125,8 +8084,8 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) SLANG_ASSERT(module); List& operands = *module->getContainerPool().getList(); - for (UInt i = 0; i < user->getOperandCount(); i++) - operands.add(user->getOperand(i)); + for (UInt ii = 0; ii < user->getOperandCount(); ii++) + operands.add(user->getOperand(ii)); auto getUniqueId = [&](IRInst* inst) { @@ -8143,8 +8102,8 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) operands.sort([&](IRInst* a, IRInst* b) { return getUniqueId(a) < getUniqueId(b); }); - for (UInt i = 0; i < user->getOperandCount(); i++) - user->getOperandUse(i)->usedValue = operands[i]; + for (UInt ii = 0; ii < user->getOperandCount(); ii++) + user->getOperandUse(ii)->usedValue = operands[ii]; module->getContainerPool().free(&operands); } From 562c751a6c7da0c6d54d67686ac05aa0b468aea1 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 4 Nov 2025 11:43:40 -0500 Subject: [PATCH 097/105] Fix CI warning --- source/slang/slang-ir.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 3da1f929538..6017e67c93e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8054,7 +8054,7 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) for (UInt ii = 0; ii < user->getOperandCount(); ii++) operands.add(user->getOperand(ii)); - auto getUniqueId = [&](IRInst* inst) + auto getUniqueId = [&](IRInst* inst) -> UInt { auto uniqueIDMap = module->getUniqueIdMap(); auto existingId = uniqueIDMap->tryGetValue(inst); @@ -8063,11 +8063,11 @@ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) auto id = uniqueIDMap->getCount(); uniqueIDMap->add(inst, id); - return id; + return (UInt)id; }; - operands.sort([&](IRInst* a, IRInst* b) - { return getUniqueId(a) < getUniqueId(b); }); + operands.sort( + [&](IRInst* a, IRInst* b) -> bool { return getUniqueId(a) < getUniqueId(b); }); for (UInt ii = 0; ii < user->getOperandCount(); ii++) user->getOperandUse(ii)->usedValue = operands[ii]; From dbb29726190d4755ef84688c86e2fb94c5318669 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 4 Nov 2025 12:44:38 -0500 Subject: [PATCH 098/105] Copy decorations when specializing; Add name hints for dispatchers --- .../slang-ir-lower-dynamic-dispatch-insts.cpp | 23 +++++++++++++++++++ source/slang/slang-ir-specialize.cpp | 8 ++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp index 5bee559b497..3fde6d1b96d 100644 --- a/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp +++ b/source/slang/slang-ir-lower-dynamic-dispatch-insts.cpp @@ -613,6 +613,15 @@ struct DispatcherLoweringContext : public InstPassBase { auto dispatchFunc = createDispatchFunc(cast(dispatcher->getDataType()), elements); + + if (auto nameHint = dispatcher->getLookupKey()->findDecoration()) + { + builder.setInsertBefore(dispatchFunc); + StringBuilder sb; + sb << "s_dispatch_" << nameHint->getName() << ""; + builder.addNameHintDecoration(dispatchFunc, sb.getUnownedSlice()); + } + traverseUses( dispatcher, [&](IRUse* use) @@ -687,6 +696,20 @@ struct DispatcherLoweringContext : public InstPassBase { auto dispatchFunc = createDispatchFunc(cast(dispatcher->getDataType()), elements); + + if (auto keyNameHint = key->findDecoration()) + { + builder.setInsertBefore(dispatchFunc); + StringBuilder sb; + sb << "s_dispatch_" << keyNameHint->getName() << ""; + for (auto specArg : specArgs) + { + sb << "_"; + getTypeNameHint(sb, specArg); + } + builder.addNameHintDecoration(dispatchFunc, sb.getUnownedSlice()); + } + traverseUses( dispatcher, [&](IRUse* use) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index bd0e4e2cb49..bf747b6c17a 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3260,7 +3260,7 @@ IRInst* specializeGenericWithSetArgs(IRSpecialize* specializeInst) // We're dealing with a set of types. if (as(param->getDataType())) { - SLANG_ASSERT("Should not happen"); + // TODO: This case should not happen anymore. cloneEnv.mapOldValToNew[param] = builder.getUntaggedUnionType(set); } else if (as(param->getDataType())) @@ -3311,6 +3311,12 @@ IRInst* specializeGenericWithSetArgs(IRSpecialize* specializeInst) auto returnedFunc = cast(inst); auto funcFirstBlock = returnedFunc->getFirstBlock(); + builder.setInsertBefore(loweredFunc->getFirstBlock()); + for (auto decoration : returnedFunc->getDecorations()) + { + cloneInst(&staticCloningEnv, &builder, decoration); + } + builder.setInsertInto(loweredFunc); for (auto block : returnedFunc->getBlocks()) { From e12dcba96c4eacb3008628abe91c6173d9959d28 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 4 Nov 2025 13:36:59 -0500 Subject: [PATCH 099/105] Update func-type in the debug-func-decoration when legalizing --- source/slang/slang-ir-legalize-types.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index bd42efecbf4..e746c295815 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -2465,6 +2465,12 @@ struct LegalFuncBuilder auto newFuncType = irBuilder->getFuncType(m_paramTypes.getCount(), m_paramTypes.getBuffer(), m_resultType); irBuilder->setDataType(oldFunc, newFuncType); + if (auto debugFuncDecoration = oldFunc->findDecoration()) + { + auto debugFunc = as(debugFuncDecoration->getOperand(0)); + SLANG_ASSERT(as(debugFunc->getOperand(4))); + debugFunc->setOperand(4, newFuncType); + } // If the function required any new parameters to be created // to represent the result/return type, then we need to From 4011762619030125de11bd9efa919154b380736a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 4 Nov 2025 13:45:31 -0500 Subject: [PATCH 100/105] Update switch-case.slang --- tests/wgsl/switch-case.slang | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/wgsl/switch-case.slang b/tests/wgsl/switch-case.slang index 34d253a8b02..163178efc89 100644 --- a/tests/wgsl/switch-case.slang +++ b/tests/wgsl/switch-case.slang @@ -70,7 +70,7 @@ func fs_main(VertexOutput input)->FragmentOutput return output; } -//WGSL: fn _S{{[0-9]+}}( _S{{[0-9]+}} : u32, _S{{[0-9]+}} : AnyValue8) -> f32 +//WGSL: fn s_dispatch_IShape_getArea_{{[0-9]+}}( _S{{[0-9]+}} : u32, _S{{[0-9]+}} : AnyValue8) -> f32 //WGSL-NEXT:{ //WGSL-DAG: return U_SR14switch_2Dxcase6Circle7getAreap0pf_wtwrapper_0(_S{{[0-9]+}}); //WGSL-DAG: return U_SR14switch_2Dxcase9Rectangle7getAreap0pf_wtwrapper_0(_S{{[0-9]+}}); From 385728578b98c5defdf69fa83ab1abdca642f65c Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 4 Nov 2025 17:35:36 -0500 Subject: [PATCH 101/105] Update slang-ir-specialize.cpp --- source/slang/slang-ir-specialize.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index bf747b6c17a..4ae6acd5deb 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -3378,6 +3378,13 @@ IRInst* specializeGenericWithSetArgs(IRSpecialize* specializeInst) loweredFunc->setFullType( builder.getFuncType(funcTypeParams, loweredFuncType->getResultType())); } + else if (as(inst)) + { + // Emit out into the global scope. + IRBuilder globalBuilder(builder.getModule()); + globalBuilder.setInsertInto(builder.getModule()); + cloneInst(&staticCloningEnv, &globalBuilder, inst); + } else if (!as(inst)) { // Clone insts in the generic under two different environments: From 67bb72b17d44fc1a4bd94d482168d4dc6795d0a2 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:33:40 -0500 Subject: [PATCH 102/105] Fix debug info test to reflect emitted function signatures --- tests/spirv/debug-return-types.slang | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/spirv/debug-return-types.slang b/tests/spirv/debug-return-types.slang index cecca5872ee..f9d6e12cbef 100644 --- a/tests/spirv/debug-return-types.slang +++ b/tests/spirv/debug-return-types.slang @@ -17,8 +17,5 @@ float4 main(VSOutput input) : SV_TARGET { return float4(1.0, 1.0, 1.0, 1.0); } -// CHECK: %[[VECTOR:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeVector -// CHECK: %[[VSOUTPUT:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypePointer -// CHECK: {{.*}} = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeFunction %{{[a-zA-Z0-9_]+}} %[[VECTOR]] %[[VSOUTPUT]] // CHECK: %[[MATRIX:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeMatrix // CHECK: {{.*}} = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeFunction %{{[a-zA-Z0-9_]+}} %[[MATRIX]] %[[MATRIX]] From 355ca76e09834569a28e0f68fc64687c07727ce8 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 25 Nov 2025 14:52:52 -0500 Subject: [PATCH 103/105] Revert debug function change for now.. --- source/slang/slang-ir-legalize-types.cpp | 6 ------ tests/spirv/debug-return-types.slang | 3 +++ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index e746c295815..bd42efecbf4 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -2465,12 +2465,6 @@ struct LegalFuncBuilder auto newFuncType = irBuilder->getFuncType(m_paramTypes.getCount(), m_paramTypes.getBuffer(), m_resultType); irBuilder->setDataType(oldFunc, newFuncType); - if (auto debugFuncDecoration = oldFunc->findDecoration()) - { - auto debugFunc = as(debugFuncDecoration->getOperand(0)); - SLANG_ASSERT(as(debugFunc->getOperand(4))); - debugFunc->setOperand(4, newFuncType); - } // If the function required any new parameters to be created // to represent the result/return type, then we need to diff --git a/tests/spirv/debug-return-types.slang b/tests/spirv/debug-return-types.slang index f9d6e12cbef..cecca5872ee 100644 --- a/tests/spirv/debug-return-types.slang +++ b/tests/spirv/debug-return-types.slang @@ -17,5 +17,8 @@ float4 main(VSOutput input) : SV_TARGET { return float4(1.0, 1.0, 1.0, 1.0); } +// CHECK: %[[VECTOR:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeVector +// CHECK: %[[VSOUTPUT:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypePointer +// CHECK: {{.*}} = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeFunction %{{[a-zA-Z0-9_]+}} %[[VECTOR]] %[[VSOUTPUT]] // CHECK: %[[MATRIX:[0-9]+]] = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeMatrix // CHECK: {{.*}} = OpExtInst %void %{{[a-zA-Z0-9_]+}} DebugTypeFunction %{{[a-zA-Z0-9_]+}} %[[MATRIX]] %[[MATRIX]] From c9ddbaa511478becd3d6cb78a5c3d54ebb34b092 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 2 Dec 2025 11:49:42 -0500 Subject: [PATCH 104/105] Wrap passes with the new tracking tool --- source/slang/slang-emit.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 9623fed4d17..c1da2e5be8f 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1129,21 +1129,21 @@ Result linkAndOptimizeIR( } // Tagged union type lowering typically generates more reinterpret instructions. - if (lowerTaggedUnionTypes(irModule, sink)) + if (SLANG_PASS(lowerTaggedUnionTypes, sink)) requiredLoweringPassSet.reinterpret = true; - lowerUntaggedUnionTypes(irModule, targetProgram, sink); + SLANG_PASS(lowerUntaggedUnionTypes, targetProgram, sink); if (requiredLoweringPassSet.reinterpret) SLANG_PASS(lowerReinterpret, targetProgram, sink); - lowerSequentialIDTagCasts(irModule, codeGenContext->getLinkage(), sink); - lowerTagInsts(irModule, sink); - lowerTagTypes(irModule); + SLANG_PASS(lowerSequentialIDTagCasts, codeGenContext->getLinkage(), sink); + SLANG_PASS(lowerTagInsts, sink); + SLANG_PASS(lowerTagTypes); - eliminateDeadCode(irModule, fastIRSimplificationOptions.deadCodeElimOptions); + SLANG_PASS(eliminateDeadCode, fastIRSimplificationOptions.deadCodeElimOptions); - lowerExistentials(irModule, targetProgram, sink); + SLANG_PASS(lowerExistentials, targetProgram, sink); if (sink->getErrorCount() != 0) return SLANG_FAIL; From 7d02c4be37ab777af9c86e15fcff926b17911bcf Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 2 Dec 2025 11:49:53 -0500 Subject: [PATCH 105/105] Update slang-ir-specialize.cpp --- source/slang/slang-ir-specialize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 3ddc78cfbdd..0197826bde6 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -875,7 +875,7 @@ struct SpecializationContext { if (auto table = as(instElement)) { - if (auto satisfyingVal = findWitnessVal(table, requirementKey)) + if (auto satisfyingVal = findWitnessTableEntry(table, requirementKey)) { satisfyingValSet.add(satisfyingVal); return;