Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions source/val/validate_dot_product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ spv_result_t ValidateSameSignedDot(ValidationState_t& _,
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "'Vector 1' and 'Vector 2' must be the same type.";
} else if (is_vec_1_scalar && is_vec_2_scalar) {
if (!_.HasCapability(spv::Capability::DotProductInput4x8BitPacked)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "DotProductInput4x8BitPacked capability is required to use "
"scalar integers.";
}

// If both are scalar, spec doesn't say Signedness needs to match
const uint32_t vec_1_width = _.GetBitWidth(vec_1_id);
const uint32_t vec_2_width = _.GetBitWidth(vec_2_id);
Expand Down Expand Up @@ -139,6 +145,32 @@ spv_result_t ValidateSameSignedDot(ValidationState_t& _,
<< vec_1_width << ").";
}

if (!_.HasCapability(spv::Capability::DotProductInputAll)) {
// 4-wide 8-bit ints are special exception that has its own capability
if (vec_1_length == 4 && vec_1_width == 8) {
if (!_.HasCapability(spv::Capability::DotProductInput4x8Bit)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "DotProductInput4x8Bit or DotProductInputAll capability is "
"required to use 4-component vectors of 8-bit integers.";
}
} else {
// provide a more helpful message what is going on if we are here
// reporting this error
if (_.HasCapability(spv::Capability::DotProductInput4x8BitPacked)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "DotProductInputAll capability is required use vectors. "
"(DotProductInput4x8BitPacked capability declared allows "
"for only 32-bit int scalars)";
} else {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "DotProductInputAll capability is additionally required to "
"the DotProduct capability to use vectors. (It is possible "
"to set DotProductInput4x8BitPacked to only use 32-bit "
"scalars packed as a 4-wide 8-byte vector)";
}
}
}

if (opcode == spv::Op::OpUDot || opcode == spv::Op::OpUDotAccSat) {
const bool vec_1_unsigned =
_.IsIntScalarTypeWithSignedness(vec_1_type, 0);
Expand Down
116 changes: 116 additions & 0 deletions test/val/val_extension_spv_khr_integer_dot_product_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,8 @@ std::string GenerateShaderCode(const std::string& body) {
ss << R"(
OpCapability Shader
OpCapability DotProductKHR
OpCapability DotProductInputAll
OpCapability DotProductInput4x8BitPacked
OpCapability Int16
OpCapability Int64
OpCapability LongVectorEXT
Expand Down Expand Up @@ -1489,6 +1491,120 @@ TEST_F(ValidateIntegerDotProductSimple, UDotAccSat) {
HasSubstr("Result must be the same as the Accumulator type"));
}

TEST_F(ValidateIntegerDotProductSimple, CapabilityDotProductInput4x8BitPacked) {
const std::string ss = R"(
OpCapability Shader
OpCapability DotProductKHR
OpExtension "SPV_KHR_integer_dot_product"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%fn = OpTypeFunction %void
%int = OpTypeInt 32 1
%v2int = OpTypeVector %int 2
%int_1 = OpConstant %int 1
%main = OpFunction %void None %fn
%label = OpLabel
%x = OpSDot %int %int_1 %int_1
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(ss);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("DotProductInput4x8BitPacked capability is required to "
"use scalar integers"));
}

TEST_F(ValidateIntegerDotProductSimple, CapabilityDotProductInput4x8Bit) {
const std::string ss = R"(
OpCapability Shader
OpCapability DotProductKHR
OpCapability Int8
OpExtension "SPV_KHR_integer_dot_product"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%fn = OpTypeFunction %void
%char = OpTypeInt 8 1
%v4char = OpTypeVector %char 4
%char_1 = OpConstant %char 1
%v4char_1 = OpConstantComposite %v4char %char_1 %char_1 %char_1 %char_1
%main = OpFunction %void None %fn
%label = OpLabel
%x = OpSDot %char %v4char_1 %v4char_1
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(ss);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("DotProductInput4x8Bit or DotProductInputAll capability is "
"required to use 4-component vectors of 8-bit integers"));
}

TEST_F(ValidateIntegerDotProductSimple, CapabilityDotProductInputAll) {
const std::string ss = R"(
OpCapability Shader
OpCapability DotProductKHR
OpExtension "SPV_KHR_integer_dot_product"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%fn = OpTypeFunction %void
%int = OpTypeInt 32 1
%int_1 = OpConstant %int 1
%v4int = OpTypeVector %int 4
%v4int_1 = OpConstantComposite %v4int %int_1 %int_1 %int_1 %int_1
%main = OpFunction %void None %fn
%label = OpLabel
%x = OpSDot %int %v4int_1 %v4int_1
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(ss);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(
getDiagnosticString(),
HasSubstr("DotProductInputAll capability is additionally required to the "
"DotProduct capability to use vectors. (It is possible to set "
"DotProductInput4x8BitPacked to only use 32-bit scalars packed "
"as a 4-wide 8-byte vector)"));
}

TEST_F(ValidateIntegerDotProductSimple, CapabilityDotProductInputAll2) {
const std::string ss = R"(
OpCapability Shader
OpCapability DotProductKHR
OpCapability DotProductInput4x8BitPacked
OpExtension "SPV_KHR_integer_dot_product"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
%void = OpTypeVoid
%fn = OpTypeFunction %void
%int = OpTypeInt 32 1
%int_1 = OpConstant %int 1
%v4int = OpTypeVector %int 4
%v4int_1 = OpConstantComposite %v4int %int_1 %int_1 %int_1 %int_1
%main = OpFunction %void None %fn
%label = OpLabel
%x = OpSDot %int %v4int_1 %v4int_1
OpReturn
OpFunctionEnd
)";
CompileSuccessfully(ss);
EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
EXPECT_THAT(getDiagnosticString(),
HasSubstr("DotProductInputAll capability is required use "
"vectors. (DotProductInput4x8BitPacked capability "
"declared allows for only 32-bit int scalars)"));
}

} // namespace
} // namespace val
} // namespace spvtools