diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 20b721a2002..9433b560be3 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -2107,6 +2107,105 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } + void legalizeStructBlocks() + { + // SPIRV does not allow using a struct with a block declaration as a field + // of another struct. Only top-level usage (e.g., global parameter blocks) should + // have the block decoration. If a struct is used both as a field and as a block, + // we must move the top-level usage to a wrapper struct, and move the block + // decoration to the wrapper struct. + + HashSet embeddedBlockStructs; + List structGlobalParams; + for (auto globalInst : m_module->getGlobalInsts()) + { + if (auto outerStruct = as(globalInst)) + { + for (auto field : outerStruct->getFields()) + { + if (auto innerStruct = as(field->getFieldType())) + { + if (innerStruct->findDecorationImpl(kIROp_SPIRVBlockDecoration) || + innerStruct->findDecorationImpl(kIROp_SPIRVBufferBlockDecoration)) + { + embeddedBlockStructs.add(innerStruct); + } + } + } + } + else if (auto globalParam = as(globalInst)) + { + if (auto ptrType = as(globalParam->getDataType())) + { + if (as(ptrType->getValueType())) + { + structGlobalParams.add(globalParam); + } + } + } + } + + for (auto globalParam : structGlobalParams) + { + auto ptrType = as(globalParam->getDataType()); + auto structType = as(ptrType->getValueType()); + + if (!embeddedBlockStructs.contains(structType)) + continue; + + // Create a wrapper struct type + IRBuilder builder(globalParam); + builder.setInsertBefore(globalParam); + + auto wrapperStruct = builder.createStructType(); + auto key = builder.createStructKey(); + builder.createStructField(wrapperStruct, key, structType); + + // Copy the block decoration from the inner struct to the wrapper + if (structType->findDecorationImpl(kIROp_SPIRVBlockDecoration)) + { + builder.addDecorationIfNotExist(wrapperStruct, kIROp_SPIRVBlockDecoration); + } + if (structType->findDecorationImpl(kIROp_SPIRVBufferBlockDecoration)) + { + builder.addDecorationIfNotExist(wrapperStruct, kIROp_SPIRVBufferBlockDecoration); + } + + // Update the global param's type to use the wrapper struct + auto newPtrType = + builder.getPtrType(ptrType->getOp(), wrapperStruct, ptrType->getAddressSpace()); + globalParam->setFullType(newPtrType); + + // Traverse all uses of the global param and insert a FieldAddress to access the + // inner struct + traverseUses( + globalParam, + [&](IRUse* use) + { + builder.setInsertBefore(use->getUser()); + auto addr = builder.emitFieldAddress( + builder.getPtrType(kIROp_PtrType, structType, ptrType->getAddressSpace()), + globalParam, + key); + use->set(addr); + }); + } + + // Remove block/buffer block decorations from all embedded block structs + for (auto structType : embeddedBlockStructs) + { + if (auto blockDecor = structType->findDecorationImpl(kIROp_SPIRVBlockDecoration)) + { + blockDecor->removeAndDeallocate(); + } + if (auto bufferBlockDecor = + structType->findDecorationImpl(kIROp_SPIRVBufferBlockDecoration)) + { + bufferBlockDecor->removeAndDeallocate(); + } + } + } + void processModule() { determineSpirvVersion(); @@ -2196,6 +2295,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase m_module, bufferElementTypeLoweringOptions); + // Look for structs that are both used as fields and marked with Block + // decorations, and move the Block decoration to a wrapper struct. + legalizeStructBlocks(); + // Inline all pack/unpack storage type functions generated during buffer element // lowering pass. performForceInlining(m_module); diff --git a/tests/bugs/gh-7431.slang b/tests/bugs/gh-7431.slang new file mode 100644 index 00000000000..d6c8ff34231 --- /dev/null +++ b/tests/bugs/gh-7431.slang @@ -0,0 +1,37 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECKOUT):-vk -compute -output-using-type + +// CHECKOUT: 10 +// CHECKOUT-NEXT: 20 +// CHECKOUT-NEXT: 30 +// CHECKOUT-NEXT: 40 + +struct Test { + int f_int; +}; + +ParameterBlock pb_struct; +uniform Test u_struct; +uniform Test u_struct_array[2]; + +RWStructuredBuffer results; + +//TEST_INPUT: set pb_struct = new Test{10}; +//TEST_INPUT: uniform(data=[20 0 0 0]):name=u_struct.f_int +//TEST_INPUT: uniform(data=[30 0 0 0]):name=u_struct_array[0].f_int +//TEST_INPUT: uniform(data=[40 0 0 0]):name=u_struct_array[1].f_int + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=results + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 tid: SV_DispatchThreadID) +{ + if (any(tid != uint3(0))) + return; + + results[0] = pb_struct.f_int; + results[1] = u_struct.f_int; + results[2] = u_struct_array[0].f_int; + results[3] = u_struct_array[1].f_int; +} +