Skip to content

Commit

Permalink
Merge pull request #2343 from billhollings/recurs-desc-set-arg-buff
Browse files Browse the repository at this point in the history
MSL: Support descriptor sets with recursive content when using argument buffers.
  • Loading branch information
HansKristian-Work authored Jun 17, 2024
2 parents e818cfd + d7ad3d7 commit ab608ac
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

struct recurs;

struct recurs
{
int m1;
device recurs* m2;
};

struct recurs_1
{
int m1;
device recurs_1* m2;
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);

struct spvDescriptorSetBuffer0
{
device recurs* nums [[id(0)]];
texture2d<uint, access::write> tex [[id(1)]];
};

kernel void main0(constant void* spvDescriptorSet0_vp [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
constant auto& spvDescriptorSet0 = *(constant spvDescriptorSetBuffer0*)spvDescriptorSet0_vp;
spvDescriptorSet0.tex.write(uint4(uint(((*spvDescriptorSet0.nums).m1 + (*spvDescriptorSet0.nums).m2->m1) + (*spvDescriptorSet0.nums).m2->m2->m1), 0u, 0u, 1u), uint2(int2(gl_GlobalInvocationID.xy)));
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

struct recurs;

struct recurs
{
int m1;
device recurs* m2;
};

struct recurs_1
{
int m1;
device recurs_1* m2;
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);

struct spvDescriptorSetBuffer0
{
device recurs* nums [[id(0)]];
texture2d<uint, access::write> tex [[id(1)]];
};

kernel void main0(constant void* spvDescriptorSet0_vp [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
constant auto& spvDescriptorSet0 = *(constant spvDescriptorSetBuffer0*)spvDescriptorSet0_vp;
int rslt = 0;
rslt += (*spvDescriptorSet0.nums).m1;
rslt += (*spvDescriptorSet0.nums).m2->m1;
rslt += (*spvDescriptorSet0.nums).m2->m2->m1;
spvDescriptorSet0.tex.write(uint4(uint(rslt), 0u, 0u, 1u), uint2(int2(gl_GlobalInvocationID.xy)));
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#version 450
#extension GL_EXT_buffer_reference2 : require
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(buffer_reference) buffer recurs;
layout(buffer_reference, buffer_reference_align = 16, set = 0, binding = 1, std140) buffer recurs
{
int m1;
recurs m2;
} nums;

layout(set = 0, binding = 0, r32ui) uniform writeonly uimage2D tex;

void main()
{
int rslt = 0;
rslt += nums.m1;
rslt += nums.m2.m1;
rslt += nums.m2.m2.m1;
imageStore(tex, ivec2(gl_GlobalInvocationID.xy), uvec4(rslt, 0u, 0u, 1u));
}
29 changes: 24 additions & 5 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13607,7 +13607,13 @@ string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)

claimed_bindings.set(buffer_binding);

ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
ep_args += get_argument_address_space(var) + " ";

if (recursive_inputs.count(type.self))
ep_args += string("void* ") + to_restrict(id, true) + to_name(id) + "_vp";
else
ep_args += type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);

ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";

next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
Expand Down Expand Up @@ -14058,7 +14064,8 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
}
}

if (msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
if (!msl_options.argument_buffers &&
msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
(var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer))
{
Expand Down Expand Up @@ -18346,7 +18353,8 @@ void CompilerMSL::analyze_argument_buffers()
else
buffer_type.storage = StorageClassUniform;

set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
auto buffer_type_name = join("spvDescriptorSetBuffer", desc_set);
set_name(type_id, buffer_type_name);

auto &ptr_type = set<SPIRType>(ptr_type_id, OpTypePointer);
ptr_type = buffer_type;
Expand All @@ -18356,8 +18364,9 @@ void CompilerMSL::analyze_argument_buffers()
ptr_type.parent_type = type_id;

uint32_t buffer_variable_id = next_id;
set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
auto &buffer_var = set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
auto buffer_name = join("spvDescriptorSet", desc_set);
set_name(buffer_variable_id, buffer_name);

// Ids must be emitted in ID order.
stable_sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
Expand Down Expand Up @@ -18565,6 +18574,16 @@ void CompilerMSL::analyze_argument_buffers()
set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding);
member_index++;
}

if (msl_options.replace_recursive_inputs && type_contains_recursion(buffer_type))
{
recursive_inputs.insert(type_id);
auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
auto addr_space = get_argument_address_space(buffer_var);
entry_func.fixup_hooks_in.push_back([this, addr_space, buffer_name, buffer_type_name]() {
statement(addr_space, " auto& ", buffer_name, " = *(", addr_space, " ", buffer_type_name, "*)", buffer_name, "_vp;");
});
}
}
}

Expand Down

0 comments on commit ab608ac

Please sign in to comment.