diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index c1711505c6e..87ec9705921 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -1196,6 +1196,11 @@ TypeCoercionWitness* ASTBuilder::getTypeCoercionWitness( return getOrCreate(subType, superType, declRef.declRefBase); } +NoneWitness* ASTBuilder::getNoneWitness(Type* subType, Type* superType) +{ + return getOrCreate(subType, superType); +} + DeclRef _getMemberDeclRef(ASTBuilder* builder, DeclRef parent, Decl* decl) { return builder->getMemberDeclRef(parent, decl); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 15a947ef0c7..d386463accb 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -696,6 +696,8 @@ class ASTBuilder : public RefObject Type* toType, DeclRef declRef); + NoneWitness* getNoneWitness(Type* subType, Type* superType); + /// Helpers to get type info from the SharedASTBuilder SyntaxClass findSyntaxClass(const UnownedStringSlice& slice) { diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 3813d104f61..14d4ee343d8 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -901,12 +901,28 @@ Val* TypeCoercionWitness::_resolveImplOverride() void NoneWitness::_toTextOverride(StringBuilder& out) { - out.append("none"); + out << "NoneWitness("; + if (getSub()) + out << getSub(); + out << ","; + if (getSup()) + out << getSup(); + out << ")"; } Val* NoneWitness::_resolveImplOverride() { - return this; + int diff = 0; + auto newSub = as(getSub()->resolve()); + if (newSub != getSub()) + diff++; + auto newSup = as(getSup()->resolve()); + if (newSup != getSup()) + diff++; + + if (!diff) + return this; + return getCurrentASTBuilder()->getNoneWitness(newSub, newSup); } // UNormModifierVal diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index c5e95b33609..d267d23185c 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -843,7 +843,7 @@ class ExtractFromConjunctionSubtypeWitness : public SubtypeWitness /// A witness for the "none" value of optional constraints. FIDDLE() -class NoneWitness : public Witness +class NoneWitness : public SubtypeWitness { FIDDLE(...) diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 2cfe7901122..ee426d80312 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -717,7 +717,7 @@ DeclRef SemanticsVisitor::trySolveConstraintSystem( else if (!subTypeWitness && constraintIsOptional) { // Optional witness failed to resolve; not an error. - auto noneWitness = m_astBuilder->getOrCreate(); + auto noneWitness = m_astBuilder->getNoneWitness(sub, sup); args.add(noneWitness); outBaseCost += kConversionCost_FailedOptionalConstraint; } @@ -870,20 +870,21 @@ bool SemanticsVisitor::TryUnifyVals( // Two subtype witnesses can be unified if they exist (non-null) and // prove that some pair of types are subtypes of types that can be unified. // - const auto fstSubtypeWitness = as(fst); - const auto sndSubtypeWitness = as(snd); - const auto fstNoneWitness = as(fst); - const auto sndNoneWitness = as(snd); - if (fstSubtypeWitness && sndSubtypeWitness) + auto fstSubtypeWitness = as(fst); + auto sndSubtypeWitness = as(snd); + auto fstNoneWitness = as(fst); + auto sndNoneWitness = as(snd); + if ((fstNoneWitness && !sndNoneWitness) || (!fstNoneWitness && sndNoneWitness)) + { + // Don't confuse a NoneWitness with a real SubtypeWitness. + return false; + } + else if (fstSubtypeWitness && sndSubtypeWitness) return TryUnifyTypes( constraints, unifyCtx, fstSubtypeWitness->getSup(), sndSubtypeWitness->getSup()); - else if (fstNoneWitness && sndNoneWitness) - return true; - else if ((fstNoneWitness && sndSubtypeWitness) || (fstSubtypeWitness && sndNoneWitness)) - return false; SLANG_UNIMPLEMENTED_X("value unification case"); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 24b48091155..39d16a411a1 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1040,7 +1040,7 @@ bool SemanticsVisitor::TryCheckOverloadCandidateConstraints( } else if (!subTypeWitness && constraintIsOptional) { - newArgs.add(m_astBuilder->getOrCreate()); + newArgs.add(m_astBuilder->getNoneWitness(sub, sup)); } else { diff --git a/source/slang/slang-ir-insts-stable-names.lua b/source/slang/slang-ir-insts-stable-names.lua index 638e55902b5..88064f3aa73 100644 --- a/source/slang/slang-ir-insts-stable-names.lua +++ b/source/slang/slang-ir-insts-stable-names.lua @@ -685,4 +685,7 @@ return { ["StoreBase.copyLogical"] = 681, ["MakeStorageTypeLoweringConfig"] = 682, ["Decoration.experimentalModule"] = 683, + ["Type.WitnessTableTypeBase.witness_table_none_t"] = 684, + ["none_witness_table"] = 685, + ["CheckOptionalWitness"] = 686 } diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index e838f0cf18f..c66d6c9c424 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2209,6 +2209,21 @@ struct IRWitnessTable : IRInst IRType* getConcreteType() { return (IRType*)getOperand(0); } }; +// The NoneWitnessTable is used in place of a witness table when the constraint +// is optional and not being conformed to. +FIDDLE() +struct IRNoneWitnessTable : IRInst +{ + FIDDLE(leafInst()) + + IRInst* getConformanceType() + { + return cast(getDataType())->getConformanceType(); + } + + IRType* getConcreteType() { return (IRType*)getOperand(0); } +}; + /// Represents an RTTI object. /// An IRRTTIObject has 1 operand, specifying the type /// this RTTI object provides info for. @@ -3415,6 +3430,8 @@ struct IRBuilder IRInst* emitGetSequentialIDInst(IRInst* rttiObj); + IRInst* emitCheckOptionalWitness(IRInst* witness); + IRInst* emitAlloca(IRInst* type, IRInst* rttiObjPtr); IRInst* emitGlobalValueRef(IRInst* globalInst); @@ -3666,6 +3683,10 @@ struct IRBuilder IRInst* requirementKey, IRInst* satisfyingVal); + // Create an empty witness table, this is used for optional constraints + // when baseType does not conform to subType. + IRNoneWitnessTable* createNoneWitnessTable(IRType* baseType, IRType* subType); + IRInst* createThisTypeWitness(IRType* interfaceType); IRInst* getTypeEqualityWitness(IRType* witnessType, IRType* type1, IRType* type2); diff --git a/source/slang/slang-ir-insts.lua b/source/slang/slang-ir-insts.lua index ff686ca1e9a..ce0515ccc2e 100644 --- a/source/slang/slang-ir-insts.lua +++ b/source/slang/slang-ir-insts.lua @@ -665,6 +665,15 @@ local insts = { operands = { { "baseType", "IRType" } }, }, }, + { + witness_table_none_t = { + -- A type for NoneWitness, which is used to satisfy + -- optional constraints when the base type does not + -- conform. + struct_name = "WitnessTableNoneType", + operands = { { "baseType", "IRType" } }, + }, + }, { witness_table_id_t = { -- An integer type representing a witness table for targets where @@ -713,6 +722,7 @@ local insts = { { key = { struct_name = "StructKey", global = true } }, { global_generic_param = { global = true } }, { witness_table = { hoistable = true } }, + { none_witness_table = { hoistable = true } }, { indexedFieldKey = { operands = { { "baseType" }, { "index" } }, hoistable = true } }, -- A placeholder witness that ThisType implements the enclosing interface. -- Used only in interface definitions. @@ -852,6 +862,7 @@ local insts = { operands = { { "param", "IRGlobalGenericParam" }, { "val", "IRInst" } }, }, }, + { CheckOptionalWitness = { operands = { { "witness" } }, hoistable = true } }, { allocObj = {} }, { globalValueRef = { operands = { { "value" } } } }, { makeUInt64 = { operands = { { "low" }, { "high" } } } }, diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 8ba1f354d3c..7648b5047ed 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -794,6 +794,20 @@ IRWitnessTable* cloneWitnessTableWithoutRegistering( false); } +IRNoneWitnessTable* cloneNoneWitnessTableImpl( + IRSpecContextBase* context, + IRBuilder* builder, + IRNoneWitnessTable* originalTable, + IROriginalValuesForClone const& originalValues) +{ + auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); + auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType())); + auto clonedTable = builder->createNoneWitnessTable(clonedBaseType, clonedSubType); + registerClonedValue(context, clonedTable, originalValues); + + return clonedTable; +} + IRStructType* cloneStructTypeImpl( IRSpecContextBase* context, IRBuilder* builder, @@ -1381,6 +1395,13 @@ IRInst* cloneInst( cast(originalInst), originalValues); + case kIROp_NoneWitnessTable: + return cloneNoneWitnessTableImpl( + context, + builder, + cast(originalInst), + originalValues); + case kIROp_StructType: return cloneStructTypeImpl( context, diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index e4fb5ac9814..2f011c4459d 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -295,11 +295,10 @@ struct GenericCallLoweringContext return; } - auto interfaceType = as( - cast(lookupInst->getWitnessTable()->getDataType()) - ->getConformanceType()); + auto witnessTableType = cast(lookupInst->getWitnessTable()->getDataType()); + auto interfaceType = as(witnessTableType->getConformanceType()); - if (!interfaceType) + if (as(witnessTableType)) { // NoneWitness -> remove call. callInst->removeAndDeallocate(); diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index 87243ff5cee..311b3f118d9 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -297,11 +297,8 @@ struct GenericFunctionLoweringContext // and emission of wrapper functions. void lowerWitnessTable(IRWitnessTable* witnessTable) { - IRInterfaceType* conformanceType = as(witnessTable->getConformanceType()); - if (!conformanceType) - return; - - auto interfaceType = maybeLowerInterfaceType(conformanceType); + auto interfaceType = + maybeLowerInterfaceType(cast(witnessTable->getConformanceType())); IRBuilder builderStorage(sharedContext->module); auto builder = &builderStorage; builder->setInsertBefore(witnessTable); @@ -357,17 +354,8 @@ struct GenericFunctionLoweringContext return; if (witnessTableType->getConformanceType()->findDecoration()) return; - - IRInterfaceType* conformanceType = - as(witnessTableType->getConformanceType()); - - // NoneWitness generates conformance types which aren't interfaces. In - // that case, the method can just be skipped entirely, since there's no - // real witness for it and it should be in unreachable code at this - // point. - if (!conformanceType) - return; - auto interfaceType = maybeLowerInterfaceType(conformanceType); + auto interfaceType = + maybeLowerInterfaceType(cast(witnessTableType->getConformanceType())); interfaceRequirementVal = sharedContext->findInterfaceRequirementVal( interfaceType, lookupInst->getRequirementKey()); diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 8677973f51d..c4bcad2062c 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -152,6 +152,34 @@ void specializeRTTIObjects(SharedGenericsLoweringContext* sharedContext, Diagnos cleanUpInterfaceTypes(sharedContext); } +void lowerOptionalWitnesses(SharedGenericsLoweringContext* sharedContext) +{ + InstPassBase pass(sharedContext->module); + IRBuilder builder(sharedContext->module); + + pass.processInstsOfType( + kIROp_CheckOptionalWitness, + [&](IRCheckOptionalWitness* inst) + { + builder.setInsertBefore(inst); + auto checkInst = builder.getBoolValue(inst->getWitness()->getOp() != kIROp_NoneWitnessTable); + inst->replaceUsesWith(checkInst); + inst->removeAndDeallocate(); + }); + + // Remove all NoneWitnessTables, they're no longer referenced. + builder.setInsertInto(sharedContext->module->getModuleInst()); + + List noneWitnesses; + for (auto inst : sharedContext->module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_NoneWitnessTable) + noneWitnesses.add(inst); + } + for (auto inst : noneWitnesses) + inst->removeAndDeallocate(); +} + void checkTypeConformanceExists(SharedGenericsLoweringContext* context) { HashSet implementedInterfaces; @@ -177,7 +205,7 @@ void checkTypeConformanceExists(SharedGenericsLoweringContext* context) if (!witnessTableType) return; auto interfaceType = - cast(witnessTableType)->getConformanceType(); + cast(witnessTableType)->getConformanceType(); if (isComInterfaceType((IRType*)interfaceType)) return; if (!implementedInterfaces.contains(interfaceType)) @@ -258,6 +286,8 @@ void lowerGenerics(TargetProgram* targetProgram, IRModule* module, DiagnosticSin if (sink->getErrorCount() != 0) return; + lowerOptionalWitnesses(&sharedContext); + // This optional step replaces all uses of witness tables and RTTI objects with // sequential IDs. Without this step, we will emit code that uses function pointers and // real RTTI objects and witness tables. diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index b862b3dd087..69b85819ebd 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -232,31 +232,23 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex { auto interfaceType = cast(inst->getDataType())->getConformanceType(); - if (as(interfaceType)) + auto interfaceLinkage = interfaceType->findDecoration(); + SLANG_ASSERT( + interfaceLinkage && "An interface type does not have a linkage," + "but a witness table associated with it has one."); + auto interfaceName = interfaceLinkage->getMangledName(); + auto idAllocator = + linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( + interfaceName); + if (!idAllocator) { - auto interfaceLinkage = interfaceType->findDecoration(); - SLANG_ASSERT( - interfaceLinkage && "An interface type does not have a linkage," - "but a witness table associated with it has one."); - auto interfaceName = interfaceLinkage->getMangledName(); - auto idAllocator = + linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0; + idAllocator = linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( interfaceName); - if (!idAllocator) - { - linkage->mapInterfaceMangledNameToSequentialIDCounters[interfaceName] = 0; - idAllocator = - linkage->mapInterfaceMangledNameToSequentialIDCounters.tryGetValue( - interfaceName); - } - seqID = *idAllocator; - ++(*idAllocator); - } - else - { - // NoneWitness, has special ID of -1. - seqID = uint32_t(-1); } + seqID = *idAllocator; + ++(*idAllocator); linkage->mapMangledNameToRTTIObjectIndex[witnessTableMangledName] = seqID; } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 0cf8cc97f03..43af940387e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3530,6 +3530,14 @@ IRInst* IRBuilder::emitGetSequentialIDInst(IRInst* rttiObj) return inst; } +IRInst* IRBuilder::emitCheckOptionalWitness(IRInst* witness) +{ + auto inst = createInst( + this, kIROp_CheckOptionalWitness, getBoolType(), witness); + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitBitfieldExtract(IRType* type, IRInst* value, IRInst* offset, IRInst* bits) { auto inst = createInst(this, kIROp_BitfieldExtract, type, value, offset, bits); @@ -4569,6 +4577,15 @@ IRWitnessTableEntry* IRBuilder::createWitnessTableEntry( return entry; } +IRNoneWitnessTable* IRBuilder::createNoneWitnessTable(IRType* baseType, IRType* subType) +{ + return createInst( + this, + kIROp_NoneWitnessTable, + getWitnessTableNoneType(baseType), + subType); +} + IRInterfaceRequirementEntry* IRBuilder::createInterfaceRequirementEntry( IRInst* requirementKey, IRInst* requirementVal) diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 053b91aa1e6..f193b308a41 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1758,6 +1758,12 @@ struct IRWitnessTableType : IRWitnessTableTypeBase FIDDLE(leafInst()) }; +FIDDLE() +struct IRWitnessTableNoneType : IRWitnessTableTypeBase +{ + FIDDLE(leafInst()) +}; + FIDDLE() struct IRBindExistentialsTypeBase : IRType { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 78be183c304..43bd334a9b7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2104,11 +2104,13 @@ struct ValLoweringVisitor : ValVisitoremitGetTupleElement(elementType, conjunctionWitness, indexInConjunction)); } - LoweredValInfo visitNoneWitness(NoneWitness*) + LoweredValInfo visitNoneWitness(NoneWitness* val) { - auto builder = getBuilder(); - auto voidType = builder->getVoidType(); - return LoweredValInfo::simple(builder->createWitnessTable(voidType, voidType)); + auto subType = lowerType(context, val->getSub()); + auto supType = lowerType(context, val->getSup()); + return LoweredValInfo::simple(getBuilder()->createNoneWitnessTable( + supType, subType + )); } LoweredValInfo visitConstantIntVal(ConstantIntVal* val) @@ -6074,12 +6076,9 @@ struct ExprLoweringVisitorBase : public ExprVisitor if (declWitness && declWitness->isOptional()) { - // Optional constraint check. NoneWitness lowers to a specific - // ID, so that we can check for that here. - auto witnessID = builder->emitGetSequentialIDInst(witness); - auto noneWitnessID = builder->getIntValue(builder->getUIntType(), -1); - auto irVal = builder->emitNeq(witnessID, noneWitnessID); - return LoweredValInfo::simple(irVal); + // Optional constraint check. + auto val = builder->emitCheckOptionalWitness(witness); + return LoweredValInfo::simple(val); } else { // This is a run-time type check from for an existential type. diff --git a/tests/language-feature/generics/where-optional-6.slang b/tests/language-feature/generics/where-optional-6.slang new file mode 100644 index 00000000000..9761179ff6c --- /dev/null +++ b/tests/language-feature/generics/where-optional-6.slang @@ -0,0 +1,42 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface I +{ +} + +struct A where optional T: I +{ + + int f() + { + if (T is I) + return 0; + else + return 1; + } +} + +struct B +{ + // A NoneWitness is created in the IR for specializing A without satisfying + // I. This NoneWitness should be hoisted to the global scope; otherwise it + // gets generated _after_ the terminator of the outer generic block, + // breaking a whole lot of stuff. + A a; + A b; +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + B b; + // CHECK: 1 + // CHECK-NEXT: 2 + outputBuffer[0] = b.a.f(); + outputBuffer[1] = b.b.f()*2; +}