From e5ec1941196140a43cf625ce3430e1506ce58a08 Mon Sep 17 00:00:00 2001 From: luciechoi Date: Fri, 31 Oct 2025 00:45:45 +0000 Subject: [PATCH 1/2] Fix precision for . --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 21 ++++++++++++------- .../test/CodeGenSPIRV/intrinsics.dot2add.hlsl | 15 +++++++------ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 7029d48d88..7ab49b876d 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -13165,17 +13165,24 @@ SpirvInstruction *SpirvEmitter::processIntrinsicDP2a(const CallExpr *callExpr) { SpirvInstruction *arg1Instr = doExpr(arg1); SpirvInstruction *arg2Instr = doExpr(arg2); - // Create the dot product of the half2 vectors. - SpirvInstruction *dotInstr = spvBuilder.createBinaryOp( - spv::Op::OpDot, componentType, arg0Instr, arg1Instr, loc, range); + // Convert the half vectors into float vectors to preserve precision + // as specified in the HLSL documentation. + arg0Instr = spvBuilder.createUnaryOp( + spv::Op::OpFConvert, + astContext.getExtVectorType(astContext.FloatTy, vecSize), arg0Instr, loc, + range); + arg1Instr = spvBuilder.createUnaryOp( + spv::Op::OpFConvert, + astContext.getExtVectorType(astContext.FloatTy, vecSize), arg1Instr, loc, + range); - // Convert dot product (half type) to result type (float). + // Create the dot product of the half2 vectors. QualType resultType = callExpr->getType(); - SpirvInstruction *floatDotInstr = spvBuilder.createUnaryOp( - spv::Op::OpFConvert, resultType, dotInstr, loc, range); + SpirvInstruction *dotInstr = spvBuilder.createBinaryOp( + spv::Op::OpDot, resultType, arg0Instr, arg1Instr, loc, range); // Sum the dot product result and accumulator and return. - return spvBuilder.createBinaryOp(spv::Op::OpFAdd, resultType, floatDotInstr, + return spvBuilder.createBinaryOp(spv::Op::OpFAdd, resultType, dotInstr, arg2Instr, loc, range); } diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl index 97c0189397..98ed064e67 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl @@ -12,8 +12,9 @@ float main() : SV_Target { // CHECK: [[input1A:%[0-9]+]] = OpLoad %v2half %input1A // CHECK: [[input1B:%[0-9]+]] = OpLoad %v2half %input1B // CHECK: [[acc1:%[0-9]+]] = OpLoad %float %acc1 -// CHECK: [[dot1_0:%[0-9]+]] = OpDot %half [[input1A]] [[input1B]] -// CHECK: [[dot1:%[0-9]+]] = OpFConvert %float [[dot1_0]] +// CHECK: [[input1A_0:%[0-9]+]] = OpFConvert %v2float [[input1A]] +// CHECK: [[input1B_0:%[0-9]+]] = OpFConvert %v2float [[input1B]] +// CHECK: [[dot1:%[0-9]+]] = OpDot %float [[input1A_0]] [[input1B_0]] // CHECK: [[res1:%[0-9]+]] = OpFAdd %float [[dot1]] [[acc1]] res += dot2add(input1A, input1B, acc1); @@ -25,8 +26,9 @@ float main() : SV_Target { // CHECK: [[input2B:%[0-9]+]] = OpVectorShuffle %v2half [[input2_1]] [[input2_1]] 2 3 // CHECK: [[acc2_0:%[0-9]+]] = OpLoad %int %acc2 // CHECK: [[acc2:%[0-9]+]] = OpConvertSToF %float [[acc2_0]] -// CHECK: [[dot2_0:%[0-9]+]] = OpDot %half [[input2A]] [[input2B]] -// CHECK: [[dot2:%[0-9]+]] = OpFConvert %float [[dot2_0]] +// CHECK: [[input2A_0:%[0-9]+]] = OpFConvert %v2float [[input2A]] +// CHECK: [[input2B_0:%[0-9]+]] = OpFConvert %v2float [[input2B]] +// CHECK: [[dot2:%[0-9]+]] = OpDot %float [[input2A_0]] [[input2B_0]] // CHECK: [[res2:%[0-9]+]] = OpFAdd %float [[dot2]] [[acc2]] res += dot2add(input2.xy, input2.zw, acc2); @@ -44,8 +46,9 @@ float main() : SV_Target { // CHECK: [[input3B:%[0-9]+]] = OpConvertSToF %v2half [[input3B_5]] // CHECK: [[acc3_1:%[0-9]+]] = OpLoad %half %acc3 // CHECK: [[acc3:%[0-9]+]] = OpFConvert %float [[acc3_1]] -// CHECK: [[dot3_1:%[0-9]+]] = OpDot %half [[input3A]] [[input3B]] -// CHECK: [[dot3:%[0-9]+]] = OpFConvert %float [[dot3_1]] +// CHECK: [[input3A_0:%[0-9]+]] = OpFConvert %v2float [[input3A]] +// CHECK: [[input3B_0:%[0-9]+]] = OpFConvert %v2float [[input3B]] +// CHECK: [[dot3:%[0-9]+]] = OpDot %float [[input3A_0]] [[input3B_0]] // CHECK: [[res3:%[0-9]+]] = OpFAdd %float [[dot3]] [[acc3]] res += dot2add(input3A, input3B, acc3); From 6a495e20b79aa548b2dfe13daaf4e4a402957c37 Mon Sep 17 00:00:00 2001 From: luciechoi Date: Sat, 1 Nov 2025 00:54:05 +0000 Subject: [PATCH 2/2] Multiplication happens with half precision --- tools/clang/lib/SPIRV/SpirvEmitter.cpp | 29 +++++++++--------- .../test/CodeGenSPIRV/intrinsics.dot2add.hlsl | 30 +++++++++++-------- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 7ab49b876d..810c5aa5c4 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -13165,25 +13165,26 @@ SpirvInstruction *SpirvEmitter::processIntrinsicDP2a(const CallExpr *callExpr) { SpirvInstruction *arg1Instr = doExpr(arg1); SpirvInstruction *arg2Instr = doExpr(arg2); - // Convert the half vectors into float vectors to preserve precision - // as specified in the HLSL documentation. - arg0Instr = spvBuilder.createUnaryOp( + // Multiply the two half2 vectors and convert the result to float2. + SpirvInstruction *mulInstr = spvBuilder.createBinaryOp( + spv::Op::OpFMul, vecType, arg0Instr, arg1Instr, loc, range); + SpirvInstruction *convertInstr = spvBuilder.createUnaryOp( spv::Op::OpFConvert, - astContext.getExtVectorType(astContext.FloatTy, vecSize), arg0Instr, loc, - range); - arg1Instr = spvBuilder.createUnaryOp( - spv::Op::OpFConvert, - astContext.getExtVectorType(astContext.FloatTy, vecSize), arg1Instr, loc, + astContext.getExtVectorType(astContext.FloatTy, vecSize), mulInstr, loc, range); - // Create the dot product of the half2 vectors. - QualType resultType = callExpr->getType(); - SpirvInstruction *dotInstr = spvBuilder.createBinaryOp( - spv::Op::OpDot, resultType, arg0Instr, arg1Instr, loc, range); + // Extract each float element and and sum them up. + SpirvInstruction *extractedElem0 = spvBuilder.createCompositeExtract( + astContext.FloatTy, convertInstr, {0}, loc, range); + SpirvInstruction *extractedElem1 = spvBuilder.createCompositeExtract( + astContext.FloatTy, convertInstr, {1}, loc, range); + SpirvInstruction *dotInstr = + spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy, + extractedElem0, extractedElem1, loc, range); // Sum the dot product result and accumulator and return. - return spvBuilder.createBinaryOp(spv::Op::OpFAdd, resultType, dotInstr, - arg2Instr, loc, range); + return spvBuilder.createBinaryOp(spv::Op::OpFAdd, astContext.FloatTy, + dotInstr, arg2Instr, loc, range); } SpirvInstruction * diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl index 98ed064e67..8798b8ba4d 100644 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.dot2add.hlsl @@ -12,10 +12,12 @@ float main() : SV_Target { // CHECK: [[input1A:%[0-9]+]] = OpLoad %v2half %input1A // CHECK: [[input1B:%[0-9]+]] = OpLoad %v2half %input1B // CHECK: [[acc1:%[0-9]+]] = OpLoad %float %acc1 -// CHECK: [[input1A_0:%[0-9]+]] = OpFConvert %v2float [[input1A]] -// CHECK: [[input1B_0:%[0-9]+]] = OpFConvert %v2float [[input1B]] -// CHECK: [[dot1:%[0-9]+]] = OpDot %float [[input1A_0]] [[input1B_0]] -// CHECK: [[res1:%[0-9]+]] = OpFAdd %float [[dot1]] [[acc1]] +// CHECK: [[mult1A:%[0-9]+]] = OpFMul %v2half [[input1A]] [[input1B]] +// CHECK: [[convert1A:%[0-9]+]] = OpFConvert %v2float [[mult1A]] +// CHECK: [[extract1A_0:%[0-9]+]] = OpCompositeExtract %float [[convert1A]] 0 +// CHECK: [[extract1A_1:%[0-9]+]] = OpCompositeExtract %float [[convert1A]] 1 +// CHECK: [[add1A:%[0-9]+]] = OpFAdd %float [[extract1A_0]] [[extract1A_1]] +// CHECK: [[res1:%[0-9]+]] = OpFAdd %float [[add1A]] [[acc1]] res += dot2add(input1A, input1B, acc1); half4 input2; @@ -26,10 +28,12 @@ float main() : SV_Target { // CHECK: [[input2B:%[0-9]+]] = OpVectorShuffle %v2half [[input2_1]] [[input2_1]] 2 3 // CHECK: [[acc2_0:%[0-9]+]] = OpLoad %int %acc2 // CHECK: [[acc2:%[0-9]+]] = OpConvertSToF %float [[acc2_0]] -// CHECK: [[input2A_0:%[0-9]+]] = OpFConvert %v2float [[input2A]] -// CHECK: [[input2B_0:%[0-9]+]] = OpFConvert %v2float [[input2B]] -// CHECK: [[dot2:%[0-9]+]] = OpDot %float [[input2A_0]] [[input2B_0]] -// CHECK: [[res2:%[0-9]+]] = OpFAdd %float [[dot2]] [[acc2]] +// CHECK: [[mult2A:%[0-9]+]] = OpFMul %v2half [[input2A]] [[input2B]] +// CHECK: [[convert2A:%[0-9]+]] = OpFConvert %v2float [[mult2A]] +// CHECK: [[extract2A_0:%[0-9]+]] = OpCompositeExtract %float [[convert2A]] 0 +// CHECK: [[extract2A_1:%[0-9]+]] = OpCompositeExtract %float [[convert2A]] 1 +// CHECK: [[add2A:%[0-9]+]] = OpFAdd %float [[extract2A_0]] [[extract2A_1]] +// CHECK: [[res2:%[0-9]+]] = OpFAdd %float [[add2A]] [[acc2]] res += dot2add(input2.xy, input2.zw, acc2); float input3A; @@ -46,10 +50,12 @@ float main() : SV_Target { // CHECK: [[input3B:%[0-9]+]] = OpConvertSToF %v2half [[input3B_5]] // CHECK: [[acc3_1:%[0-9]+]] = OpLoad %half %acc3 // CHECK: [[acc3:%[0-9]+]] = OpFConvert %float [[acc3_1]] -// CHECK: [[input3A_0:%[0-9]+]] = OpFConvert %v2float [[input3A]] -// CHECK: [[input3B_0:%[0-9]+]] = OpFConvert %v2float [[input3B]] -// CHECK: [[dot3:%[0-9]+]] = OpDot %float [[input3A_0]] [[input3B_0]] -// CHECK: [[res3:%[0-9]+]] = OpFAdd %float [[dot3]] [[acc3]] +// CHECK: [[mult3A:%[0-9]+]] = OpFMul %v2half [[input3A]] [[input3B]] +// CHECK: [[convert3A:%[0-9]+]] = OpFConvert %v2float [[mult3A]] +// CHECK: [[extract3A_0:%[0-9]+]] = OpCompositeExtract %float [[convert3A]] 0 +// CHECK: [[extract3A_1:%[0-9]+]] = OpCompositeExtract %float [[convert3A]] 1 +// CHECK: [[add3A:%[0-9]+]] = OpFAdd %float [[extract3A_0]] [[extract3A_1]] +// CHECK: [[res3:%[0-9]+]] = OpFAdd %float [[add3A]] [[acc3]] res += dot2add(input3A, input3B, acc3); return res;