diff --git a/llvm/test/tools/sycl-post-link/emit_program_metadata.ll b/llvm/test/tools/sycl-post-link/emit_program_metadata.ll index ba62b415974cc..7b2ad12057af7 100644 --- a/llvm/test/tools/sycl-post-link/emit_program_metadata.ll +++ b/llvm/test/tools/sycl-post-link/emit_program_metadata.ll @@ -7,8 +7,12 @@ target triple = "spir64-unknown-unknown" attributes #0 = { "sycl-work-group-size"="4,2,1" } +attributes #1 = { "sycl-work-group-size"="4,2" } +attributes #2 = { "sycl-work-group-size"="4" } !0 = !{i32 1, i32 2, i32 4} +!1 = !{i32 2, i32 4} +!2 = !{i32 4} define weak_odr spir_kernel void @SpirKernel1(float %arg1) !reqd_work_group_size !0 { call void @foo(float %arg1) @@ -20,12 +24,36 @@ define weak_odr spir_kernel void @SpirKernel2(float %arg1) #0 { ret void } +define weak_odr spir_kernel void @SpirKernel3(float %arg1) !reqd_work_group_size !1 { + call void @foo(float %arg1) + ret void +} + +define weak_odr spir_kernel void @SpirKernel4(float %arg1) #1 { + call void @foo(float %arg1) + ret void +} + +define weak_odr spir_kernel void @SpirKernel5(float %arg1) !reqd_work_group_size !2 { + call void @foo(float %arg1) + ret void +} + +define weak_odr spir_kernel void @SpirKernel6(float %arg1) #2 { + call void @foo(float %arg1) + ret void +} + declare void @foo(float) ; CHECK-PROP: [SYCL/program metadata] ; // Base64 encoding in the prop file (including 8 bytes length): ; CHECK-PROP-NEXT: SpirKernel1@reqd_work_group_size=2|gBAAAAAAAAQAAAAACAAAAQAAAAA ; CHECK-PROP-NEXT: SpirKernel2@reqd_work_group_size=2|gBAAAAAAAAQAAAAACAAAAQAAAAA +; CHECK-PROP-NEXT: SpirKernel3@reqd_work_group_size=2|ABAAAAAAAAgAAAAAEAAAAA +; CHECK-PROP-NEXT: SpirKernel4@reqd_work_group_size=2|ABAAAAAAAAgAAAAAEAAAAA +; CHECK-PROP-NEXT: SpirKernel5@reqd_work_group_size=2|gAAAAAAAAAABAAAA +; CHECK-PROP-NEXT: SpirKernel6@reqd_work_group_size=2|gAAAAAAAAAABAAAA ; CHECK-TABLE: [Code|Properties] ; CHECK-TABLE-NEXT: {{.*}}files_0.prop diff --git a/llvm/test/tools/sycl-post-link/kernel-properties.ll b/llvm/test/tools/sycl-post-link/kernel-properties.ll index 8ef9ae8bc0d80..02f3357607845 100644 --- a/llvm/test/tools/sycl-post-link/kernel-properties.ll +++ b/llvm/test/tools/sycl-post-link/kernel-properties.ll @@ -36,11 +36,10 @@ attributes #2 = { convergent norecurse "frame-pointer"="all" "min-legal-vector-w !3 = !{i32 1, !"wchar_size", i32 4} !4 = !{i32 7, !"frame-pointer", i32 2} -; Note that work-group sizes are padded with 1's after being reversed. ; CHECK-IR-DAG: ![[SGSizeMD0]] = !{i32 3} -; CHECK-IR-DAG: ![[WGSizeMD0]] = !{i{{[0-9]+}} 1, i{{[0-9]+}} 1, i{{[0-9]+}} 1} -; CHECK-IR-DAG: ![[WGSizeHintMD0]] = !{i{{[0-9]+}} 2, i{{[0-9]+}} 1, i{{[0-9]+}} 1} -; CHECK-IR-DAG: ![[WGSizeMD1]] = !{i{{[0-9]+}} 5, i{{[0-9]+}} 4, i{{[0-9]+}} 1} -; CHECK-IR-DAG: ![[WGSizeHintMD1]] = !{i{{[0-9]+}} 7, i{{[0-9]+}} 6, i{{[0-9]+}} 1} +; CHECK-IR-DAG: ![[WGSizeMD0]] = !{i{{[0-9]+}} 1} +; CHECK-IR-DAG: ![[WGSizeHintMD0]] = !{i{{[0-9]+}} 2} +; CHECK-IR-DAG: ![[WGSizeMD1]] = !{i{{[0-9]+}} 5, i{{[0-9]+}} 4} +; CHECK-IR-DAG: ![[WGSizeHintMD1]] = !{i{{[0-9]+}} 7, i{{[0-9]+}} 6} ; CHECK-IR-DAG: ![[WGSizeMD2]] = !{i{{[0-9]+}} 10, i{{[0-9]+}} 9, i{{[0-9]+}} 8} ; CHECK-IR-DAG: ![[WGSizeHintMD2]] = !{i{{[0-9]+}} 13, i{{[0-9]+}} 12, i{{[0-9]+}} 11} diff --git a/llvm/tools/sycl-post-link/CompileTimePropertiesPass.cpp b/llvm/tools/sycl-post-link/CompileTimePropertiesPass.cpp index e6774c42876e1..8afdfd899f320 100644 --- a/llvm/tools/sycl-post-link/CompileTimePropertiesPass.cpp +++ b/llvm/tools/sycl-post-link/CompileTimePropertiesPass.cpp @@ -186,12 +186,6 @@ attributeToExecModeMetadata(Module &M, const Attribute &Attr) { MDVals.push_back(ConstantAsMetadata::get(Constant::getIntegerValue( SizeTTy, APInt(SizeTBitSize, ValStr, 10)))); - // The SPIR-V translator expects 3 values, so we pad the remaining - // dimensions with 1. - for (size_t I = MDVals.size(); I < 3; ++I) - MDVals.push_back(ConstantAsMetadata::get( - Constant::getIntegerValue(SizeTTy, APInt(SizeTBitSize, 1)))); - const char *MDName = (AttrKindStr == "sycl-work-group-size") ? "reqd_work_group_size" : "work_group_size_hint"; diff --git a/llvm/tools/sycl-post-link/sycl-post-link.cpp b/llvm/tools/sycl-post-link/sycl-post-link.cpp index ad00d3e03dbc6..bafc67e3369b1 100644 --- a/llvm/tools/sycl-post-link/sycl-post-link.cpp +++ b/llvm/tools/sycl-post-link/sycl-post-link.cpp @@ -302,18 +302,17 @@ std::vector getKernelNamesUsingAssert(const Module &M) { // Gets reqd_work_group_size information for function Func. std::vector getKernelReqdWorkGroupSizeMetadata(const Function &Func) { - auto *ReqdWorkGroupSizeMD = Func.getMetadata("reqd_work_group_size"); + MDNode *ReqdWorkGroupSizeMD = Func.getMetadata("reqd_work_group_size"); if (!ReqdWorkGroupSizeMD) return {}; - // TODO: Remove 3-operand assumption when it is relaxed. - assert(ReqdWorkGroupSizeMD->getNumOperands() == 3); - uint32_t X = mdconst::extract(ReqdWorkGroupSizeMD->getOperand(0)) - ->getZExtValue(); - uint32_t Y = mdconst::extract(ReqdWorkGroupSizeMD->getOperand(1)) - ->getZExtValue(); - uint32_t Z = mdconst::extract(ReqdWorkGroupSizeMD->getOperand(2)) - ->getZExtValue(); - return {X, Y, Z}; + size_t NumOperands = ReqdWorkGroupSizeMD->getNumOperands(); + assert(NumOperands >= 1 && NumOperands <= 3 && + "reqd_work_group_size does not have between 1 and 3 operands."); + std::vector OutVals; + OutVals.reserve(NumOperands); + for (const MDOperand &MDOp : ReqdWorkGroupSizeMD->operands()) + OutVals.push_back(mdconst::extract(MDOp)->getZExtValue()); + return OutVals; } // Creates a filename based on current output filename, given extension,