diff --git a/source/val/validate_dot_product.cpp b/source/val/validate_dot_product.cpp index 21cc4be247..972c6445c6 100644 --- a/source/val/validate_dot_product.cpp +++ b/source/val/validate_dot_product.cpp @@ -59,6 +59,12 @@ spv_result_t ValidateSameSignedDot(ValidationState_t& _, return _.diag(SPV_ERROR_INVALID_DATA, inst) << "'Vector 1' and 'Vector 2' must be the same type."; } else if (is_vec_1_scalar && is_vec_2_scalar) { + if (!_.HasCapability(spv::Capability::DotProductInput4x8BitPacked)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "DotProductInput4x8BitPacked capability is required to use " + "scalar integers."; + } + // If both are scalar, spec doesn't say Signedness needs to match const uint32_t vec_1_width = _.GetBitWidth(vec_1_id); const uint32_t vec_2_width = _.GetBitWidth(vec_2_id); @@ -139,6 +145,32 @@ spv_result_t ValidateSameSignedDot(ValidationState_t& _, << vec_1_width << ")."; } + if (!_.HasCapability(spv::Capability::DotProductInputAll)) { + // 4-wide 8-bit ints are special exception that has its own capability + if (vec_1_length == 4 && vec_1_width == 8) { + if (!_.HasCapability(spv::Capability::DotProductInput4x8Bit)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "DotProductInput4x8Bit or DotProductInputAll capability is " + "required to use 4-component vectors of 8-bit integers."; + } + } else { + // provide a more helpful message what is going on if we are here + // reporting this error + if (_.HasCapability(spv::Capability::DotProductInput4x8BitPacked)) { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "DotProductInputAll capability is required use vectors. " + "(DotProductInput4x8BitPacked capability declared allows " + "for only 32-bit int scalars)"; + } else { + return _.diag(SPV_ERROR_INVALID_DATA, inst) + << "DotProductInputAll capability is additionally required to " + "the DotProduct capability to use vectors. (It is possible " + "to set DotProductInput4x8BitPacked to only use 32-bit " + "scalars packed as a 4-wide 8-byte vector)"; + } + } + } + if (opcode == spv::Op::OpUDot || opcode == spv::Op::OpUDotAccSat) { const bool vec_1_unsigned = _.IsIntScalarTypeWithSignedness(vec_1_type, 0); diff --git a/test/val/val_extension_spv_khr_integer_dot_product_test.cpp b/test/val/val_extension_spv_khr_integer_dot_product_test.cpp index df5df125ce..9e3d53aa87 100644 --- a/test/val/val_extension_spv_khr_integer_dot_product_test.cpp +++ b/test/val/val_extension_spv_khr_integer_dot_product_test.cpp @@ -1278,6 +1278,8 @@ std::string GenerateShaderCode(const std::string& body) { ss << R"( OpCapability Shader OpCapability DotProductKHR +OpCapability DotProductInputAll +OpCapability DotProductInput4x8BitPacked OpCapability Int16 OpCapability Int64 OpCapability LongVectorEXT @@ -1489,6 +1491,120 @@ TEST_F(ValidateIntegerDotProductSimple, UDotAccSat) { HasSubstr("Result must be the same as the Accumulator type")); } +TEST_F(ValidateIntegerDotProductSimple, CapabilityDotProductInput4x8BitPacked) { + const std::string ss = R"( + OpCapability Shader + OpCapability DotProductKHR + OpExtension "SPV_KHR_integer_dot_product" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + %void = OpTypeVoid + %fn = OpTypeFunction %void + %int = OpTypeInt 32 1 + %v2int = OpTypeVector %int 2 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %fn + %label = OpLabel + %x = OpSDot %int %int_1 %int_1 + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(ss); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("DotProductInput4x8BitPacked capability is required to " + "use scalar integers")); +} + +TEST_F(ValidateIntegerDotProductSimple, CapabilityDotProductInput4x8Bit) { + const std::string ss = R"( + OpCapability Shader + OpCapability DotProductKHR + OpCapability Int8 + OpExtension "SPV_KHR_integer_dot_product" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + %void = OpTypeVoid + %fn = OpTypeFunction %void + %char = OpTypeInt 8 1 + %v4char = OpTypeVector %char 4 + %char_1 = OpConstant %char 1 + %v4char_1 = OpConstantComposite %v4char %char_1 %char_1 %char_1 %char_1 + %main = OpFunction %void None %fn + %label = OpLabel + %x = OpSDot %char %v4char_1 %v4char_1 + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(ss); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("DotProductInput4x8Bit or DotProductInputAll capability is " + "required to use 4-component vectors of 8-bit integers")); +} + +TEST_F(ValidateIntegerDotProductSimple, CapabilityDotProductInputAll) { + const std::string ss = R"( + OpCapability Shader + OpCapability DotProductKHR + OpExtension "SPV_KHR_integer_dot_product" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + %void = OpTypeVoid + %fn = OpTypeFunction %void + %int = OpTypeInt 32 1 + %int_1 = OpConstant %int 1 + %v4int = OpTypeVector %int 4 + %v4int_1 = OpConstantComposite %v4int %int_1 %int_1 %int_1 %int_1 + %main = OpFunction %void None %fn + %label = OpLabel + %x = OpSDot %int %v4int_1 %v4int_1 + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(ss); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT( + getDiagnosticString(), + HasSubstr("DotProductInputAll capability is additionally required to the " + "DotProduct capability to use vectors. (It is possible to set " + "DotProductInput4x8BitPacked to only use 32-bit scalars packed " + "as a 4-wide 8-byte vector)")); +} + +TEST_F(ValidateIntegerDotProductSimple, CapabilityDotProductInputAll2) { + const std::string ss = R"( + OpCapability Shader + OpCapability DotProductKHR + OpCapability DotProductInput4x8BitPacked + OpExtension "SPV_KHR_integer_dot_product" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + %void = OpTypeVoid + %fn = OpTypeFunction %void + %int = OpTypeInt 32 1 + %int_1 = OpConstant %int 1 + %v4int = OpTypeVector %int 4 + %v4int_1 = OpConstantComposite %v4int %int_1 %int_1 %int_1 %int_1 + %main = OpFunction %void None %fn + %label = OpLabel + %x = OpSDot %int %v4int_1 %v4int_1 + OpReturn + OpFunctionEnd + )"; + CompileSuccessfully(ss); + EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions()); + EXPECT_THAT(getDiagnosticString(), + HasSubstr("DotProductInputAll capability is required use " + "vectors. (DotProductInput4x8BitPacked capability " + "declared allows for only 32-bit int scalars)")); +} + } // namespace } // namespace val } // namespace spvtools