-
Notifications
You must be signed in to change notification settings - Fork 575
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2266 from KhronosGroup/pr-2257
Land MSL integer dot products
- Loading branch information
Showing
4 changed files
with
331 additions
and
1 deletion.
There are no files selected for viewing
65 changes: 65 additions & 0 deletions
65
reference/shaders-msl-no-opt/comp/integer-dot-product.comp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
#pragma clang diagnostic ignored "-Wmissing-prototypes" | ||
|
||
#include <metal_stdlib> | ||
#include <simd/simd.h> | ||
|
||
using namespace metal; | ||
|
||
// Horizontal-add helpers: each overload returns the sum of all components of a | ||
// 2-, 3- or 4-component vector with element type T. | ||
template <typename T> | ||
T reduce_add(vec<T, 2> v) { return v.x + v.y; } | ||
template <typename T> | ||
T reduce_add(vec<T, 3> v) { return v.x + v.y + v.z; } | ||
template <typename T> | ||
T reduce_add(vec<T, 4> v) { return v.x + v.y + v.z + v.w; } | ||
|
||
// Buffer layout with two ushort4 inputs, an int accumulator and an int result slot. | ||
// main0 views the bound buffer through this struct as `comp3`. | ||
struct InOut3 | ||
{ | ||
ushort4 x; | ||
ushort4 y; | ||
int acc; | ||
int result; | ||
}; | ||
|
||
// Buffer layout with two uint inputs and a uint result slot. | ||
// main0 views the bound buffer through this struct as `comp2` and reinterprets | ||
// x and y as char4 / uchar4 via as_type for the packed dot products. | ||
struct InOut2 | ||
{ | ||
uint x; | ||
uint y; | ||
uint result; | ||
}; | ||
|
||
// Workgroup size constant, initialised from uint3(1u); marked [[maybe_unused]] and not read by main0. | ||
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); | ||
|
||
// Buffer layout with two uint4 inputs and an int result slot. | ||
// Not referenced by main0. | ||
struct InOut | ||
{ | ||
uint4 x; | ||
uint4 y; | ||
int result; | ||
}; | ||
|
||
// Compute entry point. The buffer bound at index 0 arrives as an untyped device pointer | ||
// and is viewed through two struct layouts (InOut3 and InOut2). Each integer dot-product | ||
// variant is evaluated into a local; nothing is written back to the buffer. | ||
kernel void main0(device void* spvBufferAliasSet0Binding1 [[buffer(0)]]) | ||
{ | ||
// Two typed views of the same device pointer. | ||
device auto& comp3 = *(device InOut3*)spvBufferAliasSet0Binding1; | ||
device auto& comp2 = *(device InOut2*)spvBufferAliasSet0Binding1; | ||
// Vector dot products on the ushort4 inputs: operands are widened to the result type, | ||
// multiplied component-wise and summed with reduce_add. The "s" forms convert both | ||
// operands through short4 first, the "u" form converts neither, and the "su" forms | ||
// convert only the first operand through short4. | ||
int sdot_int = reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))); | ||
uint sdot_uint = reduce_add(uint4(short4(comp3.x)) * uint4(short4(comp3.y))); | ||
uint udot_uint = reduce_add(uint4(comp3.x) * uint4(comp3.y)); | ||
int sudot_int = reduce_add(int4(short4(comp3.x)) * int4(comp3.y)); | ||
uint sudot_uint = reduce_add(uint4(short4(comp3.x)) * uint4(comp3.y)); | ||
// Packed forms, both uint words reinterpreted as char4, at 8-, 16- and 32-bit result widths. | ||
uchar spdot8 = reduce_add(uchar4(as_type<char4>(comp2.x)) * uchar4(as_type<char4>(comp2.y))); | ||
ushort spdot16 = reduce_add(ushort4(as_type<char4>(comp2.x)) * ushort4(as_type<char4>(comp2.y))); | ||
uint spdot32 = reduce_add(uint4(as_type<char4>(comp2.x)) * uint4(as_type<char4>(comp2.y))); | ||
int spdoti32 = reduce_add(int4(as_type<char4>(comp2.x)) * int4(as_type<char4>(comp2.y))); | ||
// Packed forms, both uint words reinterpreted as uchar4. | ||
uchar updot8 = reduce_add(uchar4(as_type<uchar4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y))); | ||
ushort updot16 = reduce_add(ushort4(as_type<uchar4>(comp2.x)) * ushort4(as_type<uchar4>(comp2.y))); | ||
uint updot32 = reduce_add(uint4(as_type<uchar4>(comp2.x)) * uint4(as_type<uchar4>(comp2.y))); | ||
// Packed forms, first word reinterpreted as char4 and second as uchar4. | ||
uchar supdot8 = reduce_add(uchar4(as_type<char4>(comp2.x)) * uchar4(as_type<uchar4>(comp2.y))); | ||
ushort supdot16 = reduce_add(ushort4(as_type<char4>(comp2.x)) * ushort4(as_type<uchar4>(comp2.y))); | ||
uint supdot32 = reduce_add(uint4(as_type<char4>(comp2.x)) * uint4(as_type<uchar4>(comp2.y))); | ||
int supdoti32 = reduce_add(int4(as_type<char4>(comp2.x)) * int4(as_type<uchar4>(comp2.y))); | ||
// Accumulating forms: the dot product is combined with comp3.acc through addsat and the | ||
// result is cast to the destination type. Only the unsigned form casts acc to uint first. | ||
int sdotaddsat_int = int(addsat(reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))), comp3.acc)); | ||
uint sdotaddsat_uint = uint(addsat(reduce_add(int4(short4(comp3.x)) * int4(short4(comp3.y))), comp3.acc)); | ||
uint udotaddsat_uint = uint(addsat(reduce_add(uint4(comp3.x) * uint4(comp3.y)), uint(comp3.acc))); | ||
int sudotaddsat_int = int(addsat(reduce_add(int4(short4(comp3.x)) * int4(comp3.y)), comp3.acc)); | ||
uint sudotaddsat_uint = uint(addsat(reduce_add(int4(short4(comp3.x)) * int4(comp3.y)), comp3.acc)); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
#version 450 | ||
#extension GL_EXT_shader_8bit_storage : require | ||
#extension GL_EXT_shader_16bit_storage : require | ||
#extension GL_EXT_shader_explicit_arithmetic_types : require | ||
#extension GL_EXT_spirv_intrinsics : require | ||
|
||
layout(local_size_x = 1) in; | ||
|
||
// std430 storage buffer at binding 0, instance `comp`. Not referenced by main(). | ||
layout(std430, binding = 0) buffer InOut { | ||
uvec4 x; | ||
uvec4 y; | ||
int result; | ||
} comp; | ||
|
||
// std430 storage buffer at binding 1, instance `comp2`. Its two uint words are the | ||
// operands of the packed dot-product calls in main(). | ||
layout(std430, binding = 1) buffer InOut2 { | ||
uint x; | ||
uint y; | ||
uint result; | ||
} comp2; | ||
|
||
// std430 storage buffer, instance `comp3`, declared at binding 1 -- the same binding as InOut2. | ||
// Supplies the 16-bit vector operands and the accumulator used in main(). | ||
layout(std430, binding = 1) buffer InOut3 { | ||
u16vec4 x; | ||
u16vec4 y; | ||
int acc; | ||
int result; | ||
} comp3; | ||
|
||
// Signed integer dot (instruction id 4450) declared with unsigned 16-bit vector operands, | ||
// once with an int result and once with a uint result. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450) | ||
int sdot_int_result(u16vec4 x, u16vec4 y); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450) | ||
uint sdot_uint_result(u16vec4 x, u16vec4 y); | ||
|
||
// Unsigned integer dot (instruction id 4451) with unsigned 16-bit vector operands. | ||
// Only unsigned result is allowed in SPIR-V. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451) | ||
uint udot_uint_result(u16vec4 x, u16vec4 y); | ||
|
||
// Mixed integer dot (instruction id 4452) declared with unsigned 16-bit vector operands, | ||
// once with an int result and once with a uint result. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452) | ||
int sudot_int_result(u16vec4 x, u16vec4 y); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452) | ||
uint sudot_uint_result(u16vec4 x, u16vec4 y); | ||
|
||
// Signed packed dot product (instruction id 4450) with different output widths. | ||
// Each form takes two uint words plus a literal packed-format operand; main() passes 0x0. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450) | ||
uint8_t spdot_to_8(uint x, uint y, spirv_literal uint packedFormat); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450) | ||
uint16_t spdot_to_16(uint x, uint y, spirv_literal uint packedFormat); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450) | ||
uint spdot_to_32(uint x, uint y, spirv_literal uint packedFormat); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4450) | ||
int spdot_to_i32(uint x, uint y, spirv_literal uint packedFormat); | ||
|
||
// Unsigned packed dot product (instruction id 4451) with different output widths. | ||
// Each form takes two uint words plus a literal packed-format operand; all results are unsigned. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451) | ||
uint8_t updot_to_8(uint x, uint y, spirv_literal uint packedFormat); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451) | ||
uint16_t updot_to_16(uint x, uint y, spirv_literal uint packedFormat); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4451) | ||
uint updot_to_32(uint x, uint y, spirv_literal uint packedFormat); | ||
|
||
// Mixed packed dot product (instruction id 4452) with different output widths. | ||
// Each form takes two uint words plus a literal packed-format operand. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452) | ||
uint8_t supdot_to_8(uint x, uint y, spirv_literal uint packedFormat); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452) | ||
uint16_t supdot_to_16(uint x, uint y, spirv_literal uint packedFormat); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452) | ||
uint supdot_to_32(uint x, uint y, spirv_literal uint packedFormat); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4452) | ||
int supdot_to_i32(uint x, uint y, spirv_literal uint packedFormat); | ||
|
||
// SDotAccSat (instruction id 4453) with unsigned input vectors and an int accumulator, | ||
// declared once with an int result and once with a uint result. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453) | ||
int sdotaddsat_int_result(u16vec4 x, u16vec4 y, int acc); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4453) | ||
uint sdotaddsat_uint_result(u16vec4 x, u16vec4 y, int acc); | ||
|
||
// UDotAccSat (instruction id 4454). Result type must be unsigned in SPIR-V. | ||
// NOTE(review): acc is declared int while the result is uint -- presumably deliberate, | ||
// to exercise the backend's handling of the signedness mismatch; confirm. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4454) | ||
uint udotaddsat(u16vec4 x, u16vec4 y, int acc); | ||
|
||
// SUDotAccSat (instruction id 4455) with unsigned input vectors and an int accumulator, | ||
// declared once with an int result and once with a uint result. | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4455) | ||
int sudotaddsat_int_result(u16vec4 x, u16vec4 y, int acc); | ||
spirv_instruction (extensions = ["SPV_KHR_integer_dot_product"], capabilities = [6019], id = 4455) | ||
uint sudotaddsat_uint_result(u16vec4 x, u16vec4 y, int acc); | ||
|
||
// Entry point: calls every declared dot-product intrinsic once and keeps each result in a | ||
// local. No result is written back to a buffer. | ||
void main() { | ||
// Vector forms on the 16-bit inputs from comp3. | ||
int sdot_int = sdot_int_result(comp3.x, comp3.y); | ||
uint sdot_uint = sdot_uint_result(comp3.x, comp3.y); | ||
uint udot_uint = udot_uint_result(comp3.x, comp3.y); | ||
int sudot_int = sudot_int_result(comp3.x, comp3.y); | ||
uint sudot_uint = sudot_uint_result(comp3.x, comp3.y); | ||
|
||
// Signed packed forms on the two uint words from comp2. | ||
uint8_t spdot8 = spdot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
uint16_t spdot16 = spdot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
uint spdot32 = spdot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
int spdoti32 = spdot_to_i32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
|
||
// Unsigned packed forms. | ||
uint8_t updot8 = updot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
uint16_t updot16 = updot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
uint updot32 = updot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
|
||
// Mixed packed forms. | ||
uint8_t supdot8 = supdot_to_8(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
uint16_t supdot16 = supdot_to_16(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
uint supdot32 = supdot_to_32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
int supdoti32 = supdot_to_i32(comp2.x, comp2.y, 0x0); // PackedVectorFormat4x8Bit | ||
|
||
// Accumulating forms; every call passes comp3.acc as the accumulator. | ||
int sdotaddsat_int = sdotaddsat_int_result(comp3.x, comp3.y, comp3.acc); | ||
uint sdotaddsat_uint = sdotaddsat_uint_result(comp3.x, comp3.y, comp3.acc); | ||
uint udotaddsat_uint = udotaddsat(comp3.x, comp3.y, comp3.acc); | ||
int sudotaddsat_int = sudotaddsat_int_result(comp3.x, comp3.y, comp3.acc); | ||
uint sudotaddsat_uint = sudotaddsat_uint_result(comp3.x, comp3.y, comp3.acc); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters