diff --git a/lib/SPIRV/OCLUtil.cpp b/lib/SPIRV/OCLUtil.cpp
index 5245752a17..9ca60144f0 100644
--- a/lib/SPIRV/OCLUtil.cpp
+++ b/lib/SPIRV/OCLUtil.cpp
@@ -799,12 +799,15 @@ unsigned getOCLVersion(Module *M, bool AllowMulti) {
   return encodeOCLVer(Ver.first, Ver.second, 0);
 }
 
-void decodeMDNode(MDNode *N, unsigned &X, unsigned &Y, unsigned &Z) {
-  if (N == NULL)
-    return;
-  X = getMDOperandAsInt(N, 0);
-  Y = getMDOperandAsInt(N, 1);
-  Z = getMDOperandAsInt(N, 2);
+SmallVector<unsigned, 3> decodeMDNode(MDNode *N) {
+  if (!N)
+    return {};
+  const unsigned NumOperands = N->getNumOperands();
+  SmallVector<unsigned, 3> ReadVals;
+  ReadVals.reserve(NumOperands);
+  for (unsigned I = 0; I < NumOperands; ++I)
+    ReadVals.push_back(getMDOperandAsInt(N, I));
+  return ReadVals;
 }
 
 /// Encode LLVM type by SPIR-V execution mode VecTypeHint
diff --git a/lib/SPIRV/OCLUtil.h b/lib/SPIRV/OCLUtil.h
index 9eb3166571..fbdb812468 100644
--- a/lib/SPIRV/OCLUtil.h
+++ b/lib/SPIRV/OCLUtil.h
@@ -444,7 +444,9 @@ std::tuple
 decodeOCLVer(unsigned Ver);
 
-/// Decode a MDNode assuming it contains three integer constants.
-void decodeMDNode(MDNode *N, unsigned &X, unsigned &Y, unsigned &Z);
+/// Decode a MDNode assuming all of its operands are integer constants.
+/// Returns one value per operand, in operand order; the result is empty when
+/// \p N is null.
+SmallVector<unsigned, 3> decodeMDNode(MDNode *N);
 
 /// Get full path from debug info metadata
 /// Return empty string if the path is not available.
diff --git a/lib/SPIRV/PreprocessMetadata.cpp b/lib/SPIRV/PreprocessMetadata.cpp
index 10e0d78675..da93596bbb 100644
--- a/lib/SPIRV/PreprocessMetadata.cpp
+++ b/lib/SPIRV/PreprocessMetadata.cpp
@@ -129,27 +129,32 @@ void PreprocessMetadataBase::visit(Module *M) {
 
     // !{void (i32 addrspace(1)*)* @kernel, i32 17, i32 X, i32 Y, i32 Z}
     if (MDNode *WGSize = Kernel.getMetadata(kSPIR2MD::WGSize)) {
-      unsigned X, Y, Z;
-      decodeMDNode(WGSize, X, Y, Z);
+      assert(WGSize->getNumOperands() >= 1 && WGSize->getNumOperands() <= 3 &&
+             "reqd_work_group_size does not have between 1 and 3 operands.");
+      SmallVector<unsigned, 3> DecodedVals = decodeMDNode(WGSize);
+      // Dimensions missing from the metadata default to 1.
       EM.addOp()
           .add(&Kernel)
           .add(spv::ExecutionModeLocalSize)
-          .add(X)
-          .add(Y)
-          .add(Z)
+          .add(DecodedVals[0])
+          .add(DecodedVals.size() >= 2 ? DecodedVals[1] : 1)
+          .add(DecodedVals.size() >= 3 ? DecodedVals[2] : 1)
           .done();
     }
 
     // !{void (i32 addrspace(1)*)* @kernel, i32 18, i32 X, i32 Y, i32 Z}
     if (MDNode *WGSizeHint = Kernel.getMetadata(kSPIR2MD::WGSizeHint)) {
-      unsigned X, Y, Z;
-      decodeMDNode(WGSizeHint, X, Y, Z);
+      assert(WGSizeHint->getNumOperands() >= 1 &&
+             WGSizeHint->getNumOperands() <= 3 &&
+             "work_group_size_hint does not have between 1 and 3 operands.");
+      SmallVector<unsigned, 3> DecodedVals = decodeMDNode(WGSizeHint);
+      // Dimensions missing from the metadata default to 1.
       EM.addOp()
           .add(&Kernel)
           .add(spv::ExecutionModeLocalSizeHint)
-          .add(X)
-          .add(Y)
-          .add(Z)
+          .add(DecodedVals[0])
+          .add(DecodedVals.size() >= 2 ? DecodedVals[1] : 1)
+          .add(DecodedVals.size() >= 3 ? DecodedVals[2] : 1)
           .done();
     }
 
@@ -175,14 +180,16 @@ void PreprocessMetadataBase::visit(Module *M) {
     // i32 Y, i32 Z}
     if (MDNode *MaxWorkgroupSizeINTEL =
             Kernel.getMetadata(kSPIR2MD::MaxWGSize)) {
-      unsigned X, Y, Z;
-      decodeMDNode(MaxWorkgroupSizeINTEL, X, Y, Z);
+      assert(MaxWorkgroupSizeINTEL->getNumOperands() == 3 &&
+             "max_work_group_size does not have 3 operands.");
+      SmallVector<unsigned, 3> DecodedVals =
+          decodeMDNode(MaxWorkgroupSizeINTEL);
       EM.addOp()
           .add(&Kernel)
           .add(spv::ExecutionModeMaxWorkgroupSizeINTEL)
-          .add(X)
-          .add(Y)
-          .add(Z)
+          .add(DecodedVals[0])
+          .add(DecodedVals[1])
+          .add(DecodedVals[2])
           .done();
     }
 
diff --git a/test/reqd_work_group_size_md.ll b/test/reqd_work_group_size_md.ll
new file mode 100644
index 0000000000..08d488b4e8
--- /dev/null
+++ b/test/reqd_work_group_size_md.ll
@@ -0,0 +1,37 @@
+; RUN: llvm-as %s -o %t.bc
+; RUN: llvm-spirv %t.bc -o %t.spv
+; RUN: llvm-spirv -to-text %t.spv -o %t.spt
+; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
+;
+; The purpose of this test is to check that the reqd_work_group_size metadata
+; is correctly converted to the LocalSize execution mode for the kernels it is
+; applied to.
+;
+; CHECK-SPIRV: EntryPoint 6 [[TEST1:[0-9]+]] "test1"
+; CHECK-SPIRV: EntryPoint 6 [[TEST2:[0-9]+]] "test2"
+; CHECK-SPIRV: EntryPoint 6 [[TEST3:[0-9]+]] "test3"
+; CHECK-SPIRV: ExecutionMode [[TEST1]] 17 1 2 3
+; CHECK-SPIRV: ExecutionMode [[TEST2]] 17 2 3 1
+; CHECK-SPIRV: ExecutionMode [[TEST3]] 17 3 1 1
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test1() !reqd_work_group_size !1 {
+entry:
+  ret void
+}
+
+define spir_kernel void @test2() !reqd_work_group_size !2 {
+entry:
+  ret void
+}
+
+define spir_kernel void @test3() !reqd_work_group_size !3 {
+entry:
+  ret void
+}
+
+!1 = !{i32 1, i32 2, i32 3}
+!2 = !{i32 2, i32 3}
+!3 = !{i32 3}
diff --git a/test/work_group_size_hint_md.ll b/test/work_group_size_hint_md.ll
new file mode 100644
index 0000000000..245c30657f
--- /dev/null
+++ b/test/work_group_size_hint_md.ll
@@ -0,0 +1,37 @@
+; RUN: llvm-as %s -o %t.bc
+; RUN: llvm-spirv %t.bc -o %t.spv
+; RUN: llvm-spirv -to-text %t.spv -o %t.spt
+; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
+;
+; The purpose of this test is to check that the work_group_size_hint metadata
+; is correctly converted to the LocalSizeHint execution mode for the kernels it
+; is applied to.
+;
+; CHECK-SPIRV: EntryPoint 6 [[TEST1:[0-9]+]] "test1"
+; CHECK-SPIRV: EntryPoint 6 [[TEST2:[0-9]+]] "test2"
+; CHECK-SPIRV: EntryPoint 6 [[TEST3:[0-9]+]] "test3"
+; CHECK-SPIRV: ExecutionMode [[TEST1]] 18 1 2 3
+; CHECK-SPIRV: ExecutionMode [[TEST2]] 18 2 3 1
+; CHECK-SPIRV: ExecutionMode [[TEST3]] 18 3 1 1
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spir64-unknown-unknown"
+
+define spir_kernel void @test1() !work_group_size_hint !1 {
+entry:
+  ret void
+}
+
+define spir_kernel void @test2() !work_group_size_hint !2 {
+entry:
+  ret void
+}
+
+define spir_kernel void @test3() !work_group_size_hint !3 {
+entry:
+  ret void
+}
+
+!1 = !{i32 1, i32 2, i32 3}
+!2 = !{i32 2, i32 3}
+!3 = !{i32 3}