Skip to content

Commit

Permalink
Merge pull request #2344 from billhollings/gather-constoffsts-arg-buffs
Browse files Browse the repository at this point in the history
MSL: Image gather ConstOffsets supports multiple address spaces.
  • Loading branch information
HansKristian-Work authored Jun 18, 2024
2 parents 98d9e42 + b5ccb0c commit 5d127b9
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 49 deletions.
26 changes: 25 additions & 1 deletion reference/opt/shaders-msl/frag/gather-compare-const-offsets.frag
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,31 @@ template<typename T> inline constexpr thread T&& spvForward(thread typename spvR
return static_cast<thread T&&>(x);
}

// Wrapper function that processes a texture gather with a constant offset array.
// Wrapper function that processes a device texture gather with a constant offset array.
// This overload accepts the texture by reference in the device address space.
//   t        - texture to gather from
//   s        - sampler passed as the first argument of every gather_compare() call
//   coffsets - indexable collection of offsets; entries 0 through 3 are read
//   params   - remaining arguments, forwarded to gather_compare() ahead of the offset
// Returns a 4-component vector built from the .w component of each of the four gathers.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
vec<T, 4> rslts[4];
// One gather_compare per offset entry.
for (uint i = 0; i < 4; i++)
{
rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);
}
// Keep only the .w component of each gather result.
return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);
}

// Wrapper function that processes a constant texture gather with a constant offset array.
// This overload accepts the texture by reference in the constant address space.
//   t        - texture to gather from
//   s        - sampler passed as the first argument of every gather_compare() call
//   coffsets - indexable collection of offsets; entries 0 through 3 are read
//   params   - remaining arguments, forwarded to gather_compare() ahead of the offset
// Returns a 4-component vector built from the .w component of each of the four gathers.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
vec<T, 4> rslts[4];
// One gather_compare per offset entry.
for (uint i = 0; i < 4; i++)
{
rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);
}
// Keep only the .w component of each gather result.
return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);
}

// Wrapper function that processes a thread texture gather with a constant offset array.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
Expand Down
54 changes: 53 additions & 1 deletion reference/opt/shaders-msl/frag/gather-const-offsets.frag
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,59 @@ template<typename T> inline constexpr thread T&& spvForward(thread typename spvR
return static_cast<thread T&&>(x);
}

// Wrapper function that processes a texture gather with a constant offset array.
// Wrapper function that processes a device texture gather with a constant offset array.
// This overload accepts the texture by reference in the device address space.
//   t        - texture to gather from
//   s        - sampler passed as the first argument of every gather() call
//   coffsets - indexable collection of offsets; entries 0 through 3 are read
//   c        - component to gather (marked METAL_CONST_ARG)
//   params   - remaining arguments, forwarded to gather() ahead of the offset
// Returns a 4-component vector built from the .w component of each of the four gathers.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
vec<T, 4> rslts[4];
// One gather per offset entry.
for (uint i = 0; i < 4; i++)
{
// Dispatch on c so that every gather() call is written with a literal component value
// rather than the variable itself.
switch (c)
{
case component::x:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);
break;
case component::y:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);
break;
case component::z:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);
break;
case component::w:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);
break;
}
}
// Keep only the .w component of each gather result.
return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);
}

// Wrapper function that processes a constant texture gather with a constant offset array.
// This overload accepts the texture by reference in the constant address space.
//   t        - texture to gather from
//   s        - sampler passed as the first argument of every gather() call
//   coffsets - indexable collection of offsets; entries 0 through 3 are read
//   c        - component to gather (marked METAL_CONST_ARG)
//   params   - remaining arguments, forwarded to gather() ahead of the offset
// Returns a 4-component vector built from the .w component of each of the four gathers.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
vec<T, 4> rslts[4];
// One gather per offset entry.
for (uint i = 0; i < 4; i++)
{
// Dispatch on c so that every gather() call is written with a literal component value
// rather than the variable itself.
switch (c)
{
case component::x:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);
break;
case component::y:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);
break;
case component::z:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);
break;
case component::w:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);
break;
}
}
// Keep only the .w component of each gather result.
return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);
}

// Wrapper function that processes a thread texture gather with a constant offset array.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
Expand Down
26 changes: 25 additions & 1 deletion reference/shaders-msl/frag/gather-compare-const-offsets.frag
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,31 @@ template<typename T> inline constexpr thread T&& spvForward(thread typename spvR
return static_cast<thread T&&>(x);
}

// Wrapper function that processes a texture gather with a constant offset array.
// Wrapper function that processes a device texture gather with a constant offset array.
// This overload accepts the texture by reference in the device address space.
//   t        - texture to gather from
//   s        - sampler passed as the first argument of every gather_compare() call
//   coffsets - indexable collection of offsets; entries 0 through 3 are read
//   params   - remaining arguments, forwarded to gather_compare() ahead of the offset
// Returns a 4-component vector built from the .w component of each of the four gathers.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
vec<T, 4> rslts[4];
// One gather_compare per offset entry.
for (uint i = 0; i < 4; i++)
{
rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);
}
// Keep only the .w component of each gather result.
return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);
}

// Wrapper function that processes a constant texture gather with a constant offset array.
// This overload accepts the texture by reference in the constant address space.
//   t        - texture to gather from
//   s        - sampler passed as the first argument of every gather_compare() call
//   coffsets - indexable collection of offsets; entries 0 through 3 are read
//   params   - remaining arguments, forwarded to gather_compare() ahead of the offset
// Returns a 4-component vector built from the .w component of each of the four gathers.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
vec<T, 4> rslts[4];
// One gather_compare per offset entry.
for (uint i = 0; i < 4; i++)
{
rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);
}
// Keep only the .w component of each gather result.
return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);
}

// Wrapper function that processes a thread texture gather with a constant offset array.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, Tp... params)
{
Expand Down
54 changes: 53 additions & 1 deletion reference/shaders-msl/frag/gather-const-offsets.frag
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,59 @@ template<typename T> inline constexpr thread T&& spvForward(thread typename spvR
return static_cast<thread T&&>(x);
}

// Wrapper function that processes a texture gather with a constant offset array.
// Wrapper function that processes a device texture gather with a constant offset array.
// This overload accepts the texture by reference in the device address space.
//   t        - texture to gather from
//   s        - sampler passed as the first argument of every gather() call
//   coffsets - indexable collection of offsets; entries 0 through 3 are read
//   c        - component to gather (marked METAL_CONST_ARG)
//   params   - remaining arguments, forwarded to gather() ahead of the offset
// Returns a 4-component vector built from the .w component of each of the four gathers.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
vec<T, 4> rslts[4];
// One gather per offset entry.
for (uint i = 0; i < 4; i++)
{
// Dispatch on c so that every gather() call is written with a literal component value
// rather than the variable itself.
switch (c)
{
case component::x:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);
break;
case component::y:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);
break;
case component::z:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);
break;
case component::w:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);
break;
}
}
// Keep only the .w component of each gather result.
return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);
}

// Wrapper function that processes a constant texture gather with a constant offset array.
// This overload accepts the texture by reference in the constant address space.
//   t        - texture to gather from
//   s        - sampler passed as the first argument of every gather() call
//   coffsets - indexable collection of offsets; entries 0 through 3 are read
//   c        - component to gather (marked METAL_CONST_ARG)
//   params   - remaining arguments, forwarded to gather() ahead of the offset
// Returns a 4-component vector built from the .w component of each of the four gathers.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
vec<T, 4> rslts[4];
// One gather per offset entry.
for (uint i = 0; i < 4; i++)
{
// Dispatch on c so that every gather() call is written with a literal component value
// rather than the variable itself.
switch (c)
{
case component::x:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);
break;
case component::y:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);
break;
case component::z:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);
break;
case component::w:
rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);
break;
}
}
// Keep only the .w component of each gather result.
return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);
}

// Wrapper function that processes a thread texture gather with a constant offset array.
template<typename T, template<typename, access = access::sample, typename = void> class Tex, typename Toff, typename... Tp>
inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)
{
Expand Down
102 changes: 57 additions & 45 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5615,6 +5615,10 @@ void CompilerMSL::emit_custom_templates()
// otherwise they will cause problems when linked together in a single Metallib.
void CompilerMSL::emit_custom_functions()
{
// Use when outputting overloaded functions to cover different address spaces.
// Each entry is spliced verbatim into the emitted MSL as the address-space qualifier
// of a texture reference parameter.
static const char *texture_addr_spaces[] = { "device", "constant", "thread" };
// Number of entries in texture_addr_spaces. Derived from the array's own element size and
// declared const so it always tracks the table and cannot be reassigned between calls.
static const uint32_t texture_addr_space_count =
    static_cast<uint32_t>(sizeof(texture_addr_spaces) / sizeof(texture_addr_spaces[0]));

if (spv_function_implementations.count(SPVFuncImplArrayCopyMultidim))
spv_function_implementations.insert(SPVFuncImplArrayCopy);

Expand Down Expand Up @@ -6264,54 +6268,62 @@ void CompilerMSL::emit_custom_functions()
break;

case SPVFuncImplGatherConstOffsets:
statement("// Wrapper function that processes a texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, "
"Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement("switch (c)");
begin_scope();
// Work around texture::gather() requiring its component parameter to be a constant expression
statement("case component::x:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
statement(" break;");
statement("case component::y:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
statement(" break;");
statement("case component::z:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
statement(" break;");
statement("case component::w:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
statement(" break;");
end_scope();
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
// Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
// Each iteration emits one complete spvGatherConstOffsets overload; the overloads differ only in the
// address-space qualifier taken from texture_addr_spaces[i] (in the leading comment and the signature).
for (uint32_t i = 0; i < texture_addr_space_count; i++)
{
statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherConstOffsets(const ", texture_addr_spaces[i], " Tex<T>& t, sampler s, "
"Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
// Emitted function body.
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
// Emitted for-loop body.
begin_scope();
statement("switch (c)");
// Emitted switch body.
begin_scope();
// Work around texture::gather() requiring its component parameter to be a constant expression
statement("case component::x:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
statement(" break;");
statement("case component::y:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
statement(" break;");
statement("case component::z:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
statement(" break;");
statement("case component::w:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
statement(" break;");
end_scope();
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
// Blank line separating this overload from whatever is emitted next.
statement("");
}
break;

case SPVFuncImplGatherCompareConstOffsets:
statement("// Wrapper function that processes a texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, "
"Toff coffsets, Tp... params)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement(" rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
// Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
// Each iteration emits one complete spvGatherCompareConstOffsets overload; the overloads differ only in the
// address-space qualifier taken from texture_addr_spaces[i] (in the leading comment and the signature).
for (uint32_t i = 0; i < texture_addr_space_count; i++)
{
statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherCompareConstOffsets(const ", texture_addr_spaces[i], " Tex<T>& t, sampler s, "
"Toff coffsets, Tp... params)");
// Emitted function body.
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
// Emitted for-loop body: one gather_compare per offset entry.
begin_scope();
statement(" rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
// Blank line separating this overload from whatever is emitted next.
statement("");
}
break;

case SPVFuncImplSubgroupBroadcast:
Expand Down

0 comments on commit 5d127b9

Please sign in to comment.