Skip to content

Commit

Permalink
Merge pull request #2266 from KhronosGroup/pr-2257
Browse files Browse the repository at this point in the history
Land MSL integer dot products
  • Loading branch information
HansKristian-Work authored Jan 16, 2024
2 parents 0a5e7b0 + e9851cc commit 64f64c8
Show file tree
Hide file tree
Showing 4 changed files with 331 additions and 1 deletion.
65 changes: 65 additions & 0 deletions reference/shaders-msl-no-opt/comp/integer-dot-product.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Horizontal add helpers: each overload sums every component of a 2-, 3- or
// 4-component vector and returns the sum converted to the element type T.
// One overload per vector width is provided.
template <typename T>
T reduce_add(vec<T, 2> v) { return v.x + v.y; }
template <typename T>
T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }
template <typename T>
T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }

// Buffer view used by main0 for the 16-bit vector dot products and their
// accumulating (addsat) variants.
struct InOut3
{
ushort4 x; // first dot product operand, four 16-bit lanes
ushort4 y; // second dot product operand, four 16-bit lanes
int acc; // accumulator passed to addsat() by the accumulating variants
int result; // not referenced by main0
};

// Buffer view used by main0 for the packed dot products: each 32-bit word is
// reinterpreted there as four 8-bit lanes via as_type<char4>/as_type<uchar4>.
struct InOut2
{
uint x; // first packed operand
uint y; // second packed operand
uint result; // not referenced by main0
};

// Workgroup size constant with all three components set to 1; marked
// [[maybe_unused]] because nothing in this file reads it.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);

// Buffer layout with 32-bit vector operands. Declared here but not referenced
// by main0.
struct InOut
{
uint4 x;
uint4 y;
int result;
};

// Compute entry point exercising the integer dot product lowerings.
// The single buffer argument is viewed through two different struct types
// (InOut3 and InOut2), and every dot product is expressed as
// reduce_add(cast(a) * cast(b)). The results are only kept in locals; nothing
// is written back to the buffer.
kernel void main0(device void* spvBufferAliasSet0Binding1 [[buffer(0)]])
{
// Two typed views of the same buffer pointer.
device auto& comp3 = *(device InOut3*)spvBufferAliasSet0Binding1;
device auto& comp2 = *(device InOut2*)spvBufferAliasSet0Binding1;
// Vector dot products on the 16-bit operands. A short4(...) cast marks an
// operand that is treated as signed before being widened to the result type.
int sdot_int = reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y)));
uint sdot_uint = reduce_add(uint4(short4(comp3.x)) * uint4(short4(comp3.y)));
uint udot_uint = reduce_add(uint4(comp3.x) * uint4(comp3.y));
int sudot_int = reduce_add(int4(short4(comp3.x)) * int4(comp3.y));
uint sudot_uint = reduce_add(uint4(short4(comp3.x)) * uint4(comp3.y));
// Packed dot products: each 32-bit word is reinterpreted as four 8-bit lanes
// (char4 = signed lanes, uchar4 = unsigned lanes), then widened to the result type.
uchar spdot8 = reduce_add(uchar4(as_type<char4>(comp2.x)) * uchar4(as_type<char4>(comp2.y)));
ushort spdot16 = reduce_add(ushort4(as_type<char4>(comp2.x)) * ushort4(as_type<char4>(comp2.y)));
uint spdot32 = reduce_add(uint4(as_type<char4>(comp2.x)) * uint4(as_type<char4>(comp2.y)));
int spdoti32 = reduce_add(int4(as_type<char4>(comp2.x)) * int4(as_type<char4>(comp2.y)));
uchar updot8 = reduce_add(uchar4(as_type<uchar4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y)));
ushort updot16 = reduce_add(ushort4(as_type<uchar4>(comp2.x)) * ushort4(as_type<uchar4>(comp2.y)));
uint updot32 = reduce_add(uint4(as_type<uchar4>(comp2.x)) * uint4(as_type<uchar4>(comp2.y)));
uchar supdot8 = reduce_add(uchar4(as_type<char4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y)));
ushort supdot16 = reduce_add(ushort4(as_type<char4>(comp2.x)) * ushort4(as_type<uchar4>(comp2.y)));
uint supdot32 = reduce_add(uint4(as_type<char4>(comp2.x)) * uint4(as_type<uchar4>(comp2.y)));
int supdoti32 = reduce_add(int4(as_type<char4>(comp2.x)) * int4(as_type<uchar4>(comp2.y)));
// Accumulating variants: the dot product is combined with comp3.acc through
// addsat(), then cast to the declared result type.
int sdotaddsat_int = int(addsat(reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))), comp3.acc));
uint sdotaddsat_uint = uint(addsat(reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))), comp3.acc));
uint udotaddsat_uint = uint(addsat(reduce_add(uint4(comp3.x) * uint4(comp3.y)), uint(comp3.acc)));
int sudotaddsat_int = int(addsat(reduce_add(int4(short4(comp3.x)) * int4(comp3.y)), comp3.acc));
uint sudotaddsat_uint = uint(addsat(reduce_add(int4(short4(comp3.x)) * int4(comp3.y)), comp3.acc));
}

114 changes: 114 additions & 0 deletions shaders-msl-no-opt/comp/integer-dot-product.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#version 450
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_spirv_intrinsics : require

layout(local_size_x = 1) in;

// Storage buffer at binding 0 with 32-bit vector operands.
// Not referenced by main() in this shader.
layout(std430, binding = 0) buffer InOut {
uvec4 x;
uvec4 y;
int result;
} comp;

// Storage buffer at binding 1 holding two 32-bit words that main() passes to
// the packed dot product intrinsics. Shares its binding with InOut3.
layout(std430, binding = 1) buffer InOut2 {
uint x;
uint y;
uint result; // not written by main()
} comp2;

// Storage buffer at binding 1 (same binding as InOut2) holding the 16-bit
// vector operands and the accumulator used by the *AccSat intrinsics in main().
layout(std430, binding = 1) buffer InOut3 {
u16vec4 x;
u16vec4 y;
int acc;
int result; // not written by main()
} comp3;

// Signed integer dot product (SPIR-V opcode 4450, OpSDot) fed with unsigned
// 16-bit vector operands, declared once with a signed and once with an
// unsigned result type.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
int sdot_int_result(u16vec4 x, u16vec4 y);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
uint sdot_uint_result(u16vec4 x, u16vec4 y);

// Unsigned integer dot product (SPIR-V opcode 4451, OpUDot) on unsigned 16-bit
// vector operands. Only unsigned result is allowed in SPIR-V.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint udot_uint_result(u16vec4 x, u16vec4 y);

// Mixed-signedness integer dot product (SPIR-V opcode 4452, OpSUDot) fed with
// unsigned 16-bit vector operands, declared once with a signed and once with
// an unsigned result type.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
int sudot_int_result(u16vec4 x, u16vec4 y);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
uint sudot_uint_result(u16vec4 x, u16vec4 y);

// Signed packed dot product (opcode 4450) with different output widths.
// x and y are 32-bit words carrying the packed operands; packedFormat is
// passed through as a SPIR-V literal operand selecting the packed vector format.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
uint8_t spdot_to_8(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
uint16_t spdot_to_16(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
uint spdot_to_32(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450)
int spdot_to_i32(uint x, uint y, spirv_literal uint packedFormat);

// Unsigned packed dot product (opcode 4451) with different output widths.
// Only unsigned result types are declared for this opcode.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint8_t updot_to_8(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint16_t updot_to_16(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451)
uint updot_to_32(uint x, uint y, spirv_literal uint packedFormat);

// Mixed-signedness packed dot product (opcode 4452) with different output
// widths, including a signed 32-bit result.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
uint8_t supdot_to_8(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
uint16_t supdot_to_16(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
uint supdot_to_32(uint x, uint y, spirv_literal uint packedFormat);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452)
int supdot_to_i32(uint x, uint y, spirv_literal uint packedFormat);

// SDotAccSat (opcode 4453) with unsigned vector inputs and a signed
// accumulator, declared once with a signed and once with an unsigned result type.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453)
int sdotaddsat_int_result(u16vec4 x, u16vec4 y, int acc);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453)
uint sdotaddsat_uint_result(u16vec4 x, u16vec4 y, int acc);

// UDotAccSat (opcode 4454). Result type must be unsigned in SPIR-V.
// Note that the accumulator parameter is still declared as a signed int here.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4454)
uint udotaddsat(u16vec4 x, u16vec4 y, int acc);

// SUDotAccSat (opcode 4455) with unsigned vector inputs and a signed
// accumulator, declared once with a signed and once with an unsigned result type.
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4455)
int sudotaddsat_int_result(u16vec4 x, u16vec4 y, int acc);
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4455)
uint sudotaddsat_uint_result(u16vec4 x, u16vec4 y, int acc);

// Calls every intrinsic declared above exactly once. The vector forms read
// their operands from comp3, the packed forms read theirs from comp2 and pass
// 0x0 as the packed vector format literal. Results are only stored in locals.
void main() {
// Vector dot products on 16-bit operands.
int sdot_int = sdot_int_result(comp3.x, comp3.y);
uint sdot_uint = sdot_uint_result(comp3.x, comp3.y);
uint udot_uint = udot_uint_result(comp3.x, comp3.y);
int sudot_int = sudot_int_result(comp3.x, comp3.y);
uint sudot_uint = sudot_uint_result(comp3.x, comp3.y);

// Signed packed dot products.
uint8_t spdot8 = spdot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint16_t spdot16 = spdot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint spdot32 = spdot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
int spdoti32 = spdot_to_i32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit

// Unsigned packed dot products.
uint8_t updot8 = updot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint16_t updot16 = updot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint updot32 = updot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit

// Mixed-signedness packed dot products.
uint8_t supdot8 = supdot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint16_t supdot16 = supdot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
uint supdot32 = supdot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit
int supdoti32 = supdot_to_i32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit

// Accumulating, saturating variants; comp3.acc is the accumulator.
int sdotaddsat_int = sdotaddsat_int_result(comp3.x, comp3.y, comp3.acc);
uint sdotaddsat_uint = sdotaddsat_uint_result(comp3.x, comp3.y, comp3.acc);
uint udotaddsat_uint = udotaddsat(comp3.x, comp3.y, comp3.acc);
int sudotaddsat_int = sudotaddsat_int_result(comp3.x, comp3.y, comp3.acc);
uint sudotaddsat_uint = sudotaddsat_uint_result(comp3.x, comp3.y, comp3.acc);
}
150 changes: 150 additions & 0 deletions spirv_msl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7458,6 +7458,22 @@ void CompilerMSL::emit_custom_functions()
statement("");
break;

case SPVFuncImplReduceAdd:
{
    // MSL offers neither __builtin_reduce_add nor simd_reduce_add, so a horizontal add helper is emitted here.
    // The other vector builtins are unavailable as well, which is why one overload per vector width
    // is written out instead of a single generic template.
    static const char *const reduce_add_overloads[] = {
        "T reduce_add(vec<T, 2> v) { return v.x + v.y; }",
        "T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; }",
        "T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; }",
    };

    for (const char *overload : reduce_add_overloads)
    {
        statement("template <typename T>");
        statement(overload);
    }

    statement("");
    break;
}


default:
break;
}
Expand Down Expand Up @@ -9641,6 +9657,132 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
break;
}

case OpSDot:
case OpUDot:
case OpSUDot:
{
    // Operand layout: result type, result id, both input vectors, then an optional packed vector format.
    const uint32_t result_type = ops[0];
    const uint32_t id = ops[1];
    const uint32_t vec1 = ops[2];
    const uint32_t vec2 = ops[3];

    auto &input_type1 = expression_type(vec1);
    auto &input_type2 = expression_type(vec2);

    // The first vector is read as signed unless the opcode is fully unsigned;
    // the second vector is read as signed only for the fully signed opcode.
    const bool vec1_is_signed = opcode != OpUDot;
    const bool vec2_is_signed = opcode == OpSDot;

    string vec1input, vec2input;
    auto input_size = input_type1.vecsize;

    const bool has_packed_format = instruction.length == 5;
    if (!has_packed_format)
    {
        // Plain vector inputs: reinterpret each operand with the signedness the opcode demands.
        // Inputs are then sign or zero-extended to their target width by the cast emitted below.
        SPIRType::BaseType vec1_expected_type = vec1_is_signed ? to_signed_basetype(input_type1.width) :
                                                                 to_unsigned_basetype(input_type1.width);
        SPIRType::BaseType vec2_expected_type = vec2_is_signed ? to_signed_basetype(input_type2.width) :
                                                                 to_unsigned_basetype(input_type2.width);

        vec1input = bitcast_expression(vec1_expected_type, vec1);
        vec2input = bitcast_expression(vec2_expected_type, vec2);
    }
    else if (ops[4] == PackedVectorFormatPackedVectorFormat4x8Bit)
    {
        // Packed inputs: view each 32-bit scalar as four 8-bit lanes of the required signedness.
        vec1input = join("as_type<", vec1_is_signed ? "char4" : "uchar4", ">(", to_expression(vec1), ")");
        vec2input = join("as_type<", vec2_is_signed ? "char4" : "uchar4", ">(", to_expression(vec2), ")");
        input_size = 4;
    }
    else
    {
        SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
    }

    // Casting to a vector of the result type performs the sign- or zero-extension dictated by
    // the input's signedness, whichever result type is requested.
    // The addition inside reduce_add is sign-invariant.
    auto &result = get<SPIRType>(result_type);
    const string widened_type = join(type_to_glsl(result), input_size);
    const string lhs = join(widened_type, "(", vec1input, ")");
    const string rhs = join(widened_type, "(", vec2input, ")");

    emit_op(result_type, id, join("reduce_add(", lhs, " * ", rhs, ")"),
            should_forward(vec1) && should_forward(vec2));
    inherit_expression_dependencies(id, vec1);
    inherit_expression_dependencies(id, vec2);
    break;
}


case OpSDotAccSat:
case OpUDotAccSat:
case OpSUDotAccSat:
{
    // Operand layout: result type, result id, both input vectors, the accumulator,
    // and (for packed inputs only) a trailing packed vector format.
    uint32_t result_type = ops[0];
    uint32_t id = ops[1];
    uint32_t vec1 = ops[2];
    uint32_t vec2 = ops[3];
    uint32_t acc = ops[4];

    // Deliberate copies: basetype and vecsize are overwritten below to build the cast type names.
    auto input_type1 = expression_type(vec1);
    auto input_type2 = expression_type(vec2);

    string vec1input, vec2input;
    if (instruction.length == 6)
    {
        if (ops[5] == PackedVectorFormatPackedVectorFormat4x8Bit)
        {
            // Packed inputs: view each 32-bit scalar as four 8-bit lanes of the required signedness.
            string type = opcode == OpSDotAccSat || opcode == OpSUDotAccSat ? "char4" : "uchar4";
            vec1input = join("as_type<", type, ">(", to_expression(vec1), ")");
            type = opcode == OpSDotAccSat ? "char4" : "uchar4";
            vec2input = join("as_type<", type, ">(", to_expression(vec2), ")");
            input_type1.vecsize = 4;
            input_type2.vecsize = 4;
        }
        else
            SPIRV_CROSS_THROW("Packed vector formats other than 4x8Bit for integer dot product is not supported.");
    }
    else
    {
        // Inputs are sign or zero-extended to their target width.
        SPIRType::BaseType vec1_expected_type =
            opcode != OpUDotAccSat ?
            to_signed_basetype(input_type1.width) :
            to_unsigned_basetype(input_type1.width);

        SPIRType::BaseType vec2_expected_type =
            opcode != OpSDotAccSat ?
            to_unsigned_basetype(input_type2.width) :
            to_signed_basetype(input_type2.width);

        vec1input = bitcast_expression(vec1_expected_type, vec1);
        vec2input = bitcast_expression(vec2_expected_type, vec2);
    }

    auto &type = get<SPIRType>(result_type);

    // The saturating add runs in a signed type unless the opcode is fully unsigned;
    // the outer cast then converts to the declared result type.
    SPIRType::BaseType pre_saturate_type =
        opcode != OpUDotAccSat ?
        to_signed_basetype(type.width) :
        to_unsigned_basetype(type.width);

    input_type1.basetype = pre_saturate_type;
    input_type2.basetype = pre_saturate_type;

    string exp = join(type_to_glsl(type), "(addsat(reduce_add(",
                      type_to_glsl(input_type1), "(", vec1input, ") * ",
                      type_to_glsl(input_type2), "(", vec2input, ")), ",
                      bitcast_expression(pre_saturate_type, acc), "))");

    // The accumulator is embedded in the emitted expression just like the two vectors, so it has to
    // take part in the forwarding decision and in dependency tracking. Otherwise a forwarded result
    // would not be invalidated when something the accumulator depends on is written.
    bool forward = should_forward(vec1) && should_forward(vec2) && should_forward(acc);
    emit_op(result_type, id, exp, forward);
    inherit_expression_dependencies(id, vec1);
    inherit_expression_dependencies(id, vec2);
    inherit_expression_dependencies(id, acc);
    break;
}


default:
CompilerGLSL::emit_instruction(instruction);
break;
Expand Down Expand Up @@ -17266,6 +17408,14 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
case OpGroupNonUniformQuadSwap:
return SPVFuncImplQuadSwap;

// Every integer dot product opcode is lowered to an expression that calls
// reduce_add(), so all of them request the reduce_add helper implementation.
case OpSDot:
case OpUDot:
case OpSUDot:
case OpSDotAccSat:
case OpUDotAccSat:
case OpSUDotAccSat:
return SPVFuncImplReduceAdd;

default:
break;
}
Expand Down
3 changes: 2 additions & 1 deletion spirv_msl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,8 @@ class CompilerMSL : public CompilerGLSL
SPVFuncImplVariableDescriptor,
SPVFuncImplVariableSizedDescriptor,
SPVFuncImplVariableDescriptorArray,
SPVFuncImplPaddedStd140
SPVFuncImplPaddedStd140,
SPVFuncImplReduceAdd
};

// If the underlying resource has been used for comparison then duplicate loads of that resource must be too
Expand Down

0 comments on commit 64f64c8

Please sign in to comment.