Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions source/slang/slang-ir-spirv-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,105 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
}
}

void legalizeStructBlocks()
{
// SPIRV does not allow using a struct with a block declaration as a field
// of another struct. Only top-level usage (e.g., global parameter blocks) should
// have the block decoration. If a struct is used both as a field and as a block,
// we must move the top-level usage to a wrapper struct, and move the block
// decoration to the wrapper struct.

HashSet<IRStructType*> embeddedBlockStructs;
List<IRGlobalParam*> structGlobalParams;
for (auto globalInst : m_module->getGlobalInsts())
{
if (auto outerStruct = as<IRStructType>(globalInst))
{
for (auto field : outerStruct->getFields())
{
if (auto innerStruct = as<IRStructType>(field->getFieldType()))
{
if (innerStruct->findDecorationImpl(kIROp_SPIRVBlockDecoration) ||
innerStruct->findDecorationImpl(kIROp_SPIRVBufferBlockDecoration))
{
embeddedBlockStructs.add(innerStruct);
}
}
}
}
else if (auto globalParam = as<IRGlobalParam>(globalInst))
{
if (auto ptrType = as<IRPtrTypeBase>(globalParam->getDataType()))
{
if (as<IRStructType>(ptrType->getValueType()))
{
structGlobalParams.add(globalParam);
}
}
}
}

for (auto globalParam : structGlobalParams)
{
auto ptrType = as<IRPtrTypeBase>(globalParam->getDataType());
auto structType = as<IRStructType>(ptrType->getValueType());

if (!embeddedBlockStructs.contains(structType))
continue;

// Create a wrapper struct type
IRBuilder builder(globalParam);
builder.setInsertBefore(globalParam);

auto wrapperStruct = builder.createStructType();
auto key = builder.createStructKey();
builder.createStructField(wrapperStruct, key, structType);

// Copy the block decoration from the inner struct to the wrapper
if (structType->findDecorationImpl(kIROp_SPIRVBlockDecoration))
{
builder.addDecorationIfNotExist(wrapperStruct, kIROp_SPIRVBlockDecoration);
}
if (structType->findDecorationImpl(kIROp_SPIRVBufferBlockDecoration))
{
builder.addDecorationIfNotExist(wrapperStruct, kIROp_SPIRVBufferBlockDecoration);
}

// Update the global param's type to use the wrapper struct
auto newPtrType =
builder.getPtrType(ptrType->getOp(), wrapperStruct, ptrType->getAddressSpace());
globalParam->setFullType(newPtrType);

// Traverse all uses of the global param and insert a FieldAddress to access the
// inner struct
traverseUses(
globalParam,
[&](IRUse* use)
{
builder.setInsertBefore(use->getUser());
auto addr = builder.emitFieldAddress(
builder.getPtrType(kIROp_PtrType, structType, ptrType->getAddressSpace()),
globalParam,
key);
use->set(addr);
});
}

// Remove block/buffer block decorations from all embedded block structs
for (auto structType : embeddedBlockStructs)
{
if (auto blockDecor = structType->findDecorationImpl(kIROp_SPIRVBlockDecoration))
{
blockDecor->removeAndDeallocate();
}
if (auto bufferBlockDecor =
structType->findDecorationImpl(kIROp_SPIRVBufferBlockDecoration))
{
bufferBlockDecor->removeAndDeallocate();
}
}
}

void processModule()
{
determineSpirvVersion();
Expand Down Expand Up @@ -2196,6 +2295,10 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
m_module,
bufferElementTypeLoweringOptions);

// Look for structs that are both used as fields and marked with Block
// decorations, and move the Block decoration to a wrapper struct.
legalizeStructBlocks();

// Inline all pack/unpack storage type functions generated during buffer element
// lowering pass.
performForceInlining(m_module);
Expand Down
37 changes: 37 additions & 0 deletions tests/bugs/gh-7431.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECKOUT):-vk -compute -output-using-type

// CHECKOUT: 10
// CHECKOUT-NEXT: 20
// CHECKOUT-NEXT: 30
// CHECKOUT-NEXT: 40

struct Test {
int f_int;
};

ParameterBlock<Test> pb_struct;
uniform Test u_struct;
uniform Test u_struct_array[2];

RWStructuredBuffer<int> results;

//TEST_INPUT: set pb_struct = new Test{10};
//TEST_INPUT: uniform(data=[20 0 0 0]):name=u_struct.f_int
//TEST_INPUT: uniform(data=[30 0 0 0]):name=u_struct_array[0].f_int
//TEST_INPUT: uniform(data=[40 0 0 0]):name=u_struct_array[1].f_int

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=results

[shader("compute")]
[numthreads(1, 1, 1)]
void computeMain(uint3 tid: SV_DispatchThreadID)
{
if (any(tid != uint3(0)))
return;

results[0] = pb_struct.f_int;
results[1] = u_struct.f_int;
results[2] = u_struct_array[0].f_int;
results[3] = u_struct_array[1].f_int;
}