diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 7029d48d88..810c5aa5c4 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -13165,18 +13165,26 @@ SpirvInstruction *SpirvEmitter::processIntrinsicDP2a(const CallExpr *callExpr) { SpirvInstruction *arg1Instr = doExpr(arg1); SpirvInstruction *arg2Instr = doExpr(arg2); - // Create the dot product of the half2 vectors. - SpirvInstruction *dotInstr = spvBuilder.createBinaryOp( - spv::Op::OpDot, componentType, arg0Instr, arg1Instr, loc, range); + // Multiply the two half2 vectors and convert the result to float2. + SpirvInstruction *mulInstr = spvBuilder.createBinaryOp( + spv::Op::OpFMul, vecType, arg0Instr, arg1Instr, loc, range); + SpirvInstruction *convertInstr = spvBuilder.createUnaryOp( + spv::Op::OpFConvert, + astContext.getExtVectorType(astContext.FloatTy, vecSize), mulInstr, loc, + range); - // Convert dot product (half type) to result type (float). - QualType resultType = callExpr->getType(); - SpirvInstruction *floatDotInstr = spvBuilder.createUnaryOp( - spv::Op::OpFConvert, resultType, dotInstr, loc, range); + // Extract each float element and sum them up. + SpirvInstruction *extractedElem0 = spvBuilder.createCompositeExtract( + astContext.FloatTy, convertInstr, {0}, loc, range); + SpirvInstruction *extractedElem1 = spvBuilder.createCompositeExtract( + astContext.FloatTy, convertInstr, {1}, loc, range); + SpirvInstruction *dotInstr = + spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy, + extractedElem0, extractedElem1, loc, range); // Sum the dot product result and accumulator and return. 
- return spvBuilder.createBinaryOp(spv::Op::OpFAdd, resultType, floatDotInstr, - arg2Instr, loc, range); + return spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy, + dotInstr, arg2Instr, loc, range); } SpirvInstruction * diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl index 97c0189397..8798b8ba4d 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl @@ -12,9 +12,12 @@ float main() : SV_Target { // CHECK: [[input1A:%[0-9]+]] = OpLoad %v2half %input1A // CHECK: [[input1B:%[0-9]+]] = OpLoad %v2half %input1B // CHECK: [[acc1:%[0-9]+]] = OpLoad %float %acc1 -// CHECK: [[dot1_0:%[0-9]+]] = OpDot %half [[input1A]] [[input1B]] -// CHECK: [[dot1:%[0-9]+]] = OpFConvert %float [[dot1_0]] -// CHECK: [[res1:%[0-9]+]] = OpFAdd %float [[dot1]] [[acc1]] +// CHECK: [[mult1A:%[0-9]+]] = OpFMul %v2half [[input1A]] [[input1B]] +// CHECK: [[convert1A:%[0-9]+]] = OpFConvert %v2float [[mult1A]] +// CHECK: [[extract1A_0:%[0-9]+]] = OpCompositeExtract %float [[convert1A]] 0 +// CHECK: [[extract1A_1:%[0-9]+]] = OpCompositeExtract %float [[convert1A]] 1 +// CHECK: [[add1A:%[0-9]+]] = OpFAdd %float [[extract1A_0]] [[extract1A_1]] +// CHECK: [[res1:%[0-9]+]] = OpFAdd %float [[add1A]] [[acc1]] res += dot2add(input1A, input1B, acc1); half4 input2; @@ -25,9 +28,12 @@ float main() : SV_Target { // CHECK: [[input2B:%[0-9]+]] = OpVectorShuffle %v2half [[input2_1]] [[input2_1]] 2 3 // CHECK: [[acc2_0:%[0-9]+]] = OpLoad %int %acc2 // CHECK: [[acc2:%[0-9]+]] = OpConvertSToF %float [[acc2_0]] -// CHECK: [[dot2_0:%[0-9]+]] = OpDot %half [[input2A]] [[input2B]] -// CHECK: [[dot2:%[0-9]+]] = OpFConvert %float [[dot2_0]] -// CHECK: [[res2:%[0-9]+]] = OpFAdd %float [[dot2]] [[acc2]] +// CHECK: [[mult2A:%[0-9]+]] = OpFMul %v2half [[input2A]] [[input2B]] +// CHECK: [[convert2A:%[0-9]+]] = OpFConvert %v2float [[mult2A]] +// CHECK: 
[[extract2A_0:%[0-9]+]] = OpCompositeExtract %float [[convert2A]] 0 +// CHECK: [[extract2A_1:%[0-9]+]] = OpCompositeExtract %float [[convert2A]] 1 +// CHECK: [[add2A:%[0-9]+]] = OpFAdd %float [[extract2A_0]] [[extract2A_1]] +// CHECK: [[res2:%[0-9]+]] = OpFAdd %float [[add2A]] [[acc2]] res += dot2add(input2.xy, input2.zw, acc2); float input3A; @@ -44,9 +50,12 @@ float main() : SV_Target { // CHECK: [[input3B:%[0-9]+]] = OpConvertSToF %v2half [[input3B_5]] // CHECK: [[acc3_1:%[0-9]+]] = OpLoad %half %acc3 // CHECK: [[acc3:%[0-9]+]] = OpFConvert %float [[acc3_1]] -// CHECK: [[dot3_1:%[0-9]+]] = OpDot %half [[input3A]] [[input3B]] -// CHECK: [[dot3:%[0-9]+]] = OpFConvert %float [[dot3_1]] -// CHECK: [[res3:%[0-9]+]] = OpFAdd %float [[dot3]] [[acc3]] +// CHECK: [[mult3A:%[0-9]+]] = OpFMul %v2half [[input3A]] [[input3B]] +// CHECK: [[convert3A:%[0-9]+]] = OpFConvert %v2float [[mult3A]] +// CHECK: [[extract3A_0:%[0-9]+]] = OpCompositeExtract %float [[convert3A]] 0 +// CHECK: [[extract3A_1:%[0-9]+]] = OpCompositeExtract %float [[convert3A]] 1 +// CHECK: [[add3A:%[0-9]+]] = OpFAdd %float [[extract3A_0]] [[extract3A_1]] +// CHECK: [[res3:%[0-9]+]] = OpFAdd %float [[add3A]] [[acc3]] res += dot2add(input3A, input3B, acc3); return res;