diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp index 29a89e46bafba..5b4155e7b8f2e 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp @@ -1880,9 +1880,40 @@ CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, const CallExpr *expr) { return emitX86CvtF16ToFloatExpr(builder, loc, ops, convertType(expr->getType())); } - case X86::BI__builtin_ia32_cvtneps2bf16_128_mask: + case X86::BI__builtin_ia32_cvtneps2bf16_128_mask: { + mlir::Location loc = getLoc(expr->getExprLoc()); + cir::VectorType resTy = cast(convertType(expr->getType())); + + cir::VectorType inputTy = cast(ops[0].getType()); + unsigned numElts = inputTy.getSize(); + + mlir::Value mask = getMaskVecValue(builder, loc, ops[2], numElts); + + SmallVector args; + args.push_back(ops[0]); + args.push_back(ops[1]); + args.push_back(mask); + + return emitIntrinsicCallOp( + builder, loc, "x86.avx512bf16.mask.cvtneps2bf16.128", resTy, args); + } case X86::BI__builtin_ia32_cvtneps2bf16_256_mask: - case X86::BI__builtin_ia32_cvtneps2bf16_512_mask: + case X86::BI__builtin_ia32_cvtneps2bf16_512_mask: { + mlir::Location loc = getLoc(expr->getExprLoc()); + cir::VectorType resTy = cast(convertType(expr->getType())); + StringRef intrinsicName; + if (builtinID == X86::BI__builtin_ia32_cvtneps2bf16_256_mask) { + intrinsicName = "x86.avx512bf16.cvtneps2bf16.256"; + } else { + assert(builtinID == X86::BI__builtin_ia32_cvtneps2bf16_512_mask); + intrinsicName = "x86.avx512bf16.cvtneps2bf16.512"; + } + + mlir::Value res = emitIntrinsicCallOp(builder, loc, intrinsicName, resTy, + mlir::ValueRange{ops[0]}); + + return emitX86Select(builder, loc, ops[2], res, ops[1]); + } case X86::BI__cpuid: case X86::BI__cpuidex: case X86::BI__emul: diff --git a/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c b/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c new file mode 100644 index 0000000000000..d1e9a030e637c --- /dev/null +++ b/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c @@ -0,0 +1,81 @@ +// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -target-feature +avx512bf16 -fclangir -emit-cir -o %t.cir -Wall -Werror -Wsign-conversion +// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s +// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -target-feature +avx512bf16 -fclangir -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion +// RUN: FileCheck --check-prefixes=LLVM --input-file=%t.ll %s +// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -target-feature +avx512bf16 -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion +// RUN: FileCheck --check-prefixes=OGCG --input-file=%t.ll %s + +#include + +__m256bh test_mm512_mask_cvtneps_pbh(__m256bh src, __mmask16 k, __m512 a) { + // CIR-LABEL: test_mm512_mask_cvtneps_pbh + // CIR: cir.call @_mm512_mask_cvtneps_pbh({{.+}}, {{.+}}, {{.+}}) : (!cir.vector<16 x !cir.bf16>, !u16i, !cir.vector<16 x !cir.float>) -> !cir.vector<16 x !cir.bf16> + + // LLVM-LABEL: @test_mm512_mask_cvtneps_pbh + // LLVM: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512 + + // OGCG-LABEL: @test_mm512_mask_cvtneps_pbh + // OGCG: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512 + return _mm512_mask_cvtneps_pbh(src, k, a); +} + +__m256bh test_mm512_maskz_cvtneps_pbh(__mmask16 k, __m512 a) { + // CIR-LABEL: test_mm512_maskz_cvtneps_pbh + // CIR: cir.call @_mm512_maskz_cvtneps_pbh({{.+}}, {{.+}}) : (!u16i, !cir.vector<16 x !cir.float>) -> !cir.vector<16 x !cir.bf16> + + // LLVM-LABEL: @test_mm512_maskz_cvtneps_pbh + // LLVM: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> {{.+}}) + + // OGCG-LABEL: @test_mm512_maskz_cvtneps_pbh + // OGCG: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> {{.+}}) + return _mm512_maskz_cvtneps_pbh(k, a); +} + + +__m128bh test_mm256_mask_cvtneps_pbh(__m128bh src, __mmask8 k, __m256 a) { + // CIR-LABEL: test_mm256_mask_cvtneps_pbh + // CIR: cir.call @_mm256_mask_cvtneps_pbh({{.+}}, {{.+}}, {{.+}}) : (!cir.vector<8 x !cir.bf16>, !u8i, !cir.vector<8 x !cir.float>) -> !cir.vector<8 x !cir.bf16> + + // LLVM-LABEL: @test_mm256_mask_cvtneps_pbh + // LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> {{.+}}) + + // OGCG-LABEL: @test_mm256_mask_cvtneps_pbh + // OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> {{.+}}) + return _mm256_mask_cvtneps_pbh(src, k, a); +} + +__m128bh test_mm256_maskz_cvtneps_pbh(__mmask8 k, __m256 a) { + // CIR-LABEL: test_mm256_maskz_cvtneps_pbh + // CIR: cir.call @_mm256_maskz_cvtneps_pbh({{.+}}, {{.+}}) : (!u8i, !cir.vector<8 x !cir.float>) -> !cir.vector<8 x !cir.bf16> + + // LLVM-LABEL: @test_mm256_maskz_cvtneps_pbh + // LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> {{.+}}) + + // OGCG-LABEL: @test_mm256_maskz_cvtneps_pbh + // OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> {{.+}}) + return _mm256_maskz_cvtneps_pbh(k, a); +} + +__m128bh test_mm_mask_cvtneps_pbh(__m128bh src, __mmask8 k, __m128 a) { + // CIR-LABEL: test_mm_mask_cvtneps_pbh + // CIR: cir.call @_mm_mask_cvtneps_pbh({{.+}}, {{.+}}, {{.+}}) : (!cir.vector<8 x !cir.bf16>, !u8i, !cir.vector<4 x !cir.float>) -> !cir.vector<8 x !cir.bf16>{{.+}} + + // LLVM-LABEL: @test_mm_mask_cvtneps_pbh + // LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> {{.+}}, <8 x bfloat> {{.+}}, <4 x i1> {{.+}}) + + // OGCG-LABEL: @test_mm_mask_cvtneps_pbh + // OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> {{.+}}, <8 x bfloat> {{.+}}, <4 x i1> {{.+}}) + return _mm_mask_cvtneps_pbh(src, k, a); +} + +__m128bh test_mm_maskz_cvtneps_pbh(__mmask8 k, __m128 a) { + // CIR-LABEL: test_mm_maskz_cvtneps_pbh + // CIR: cir.call @_mm_maskz_cvtneps_pbh({{.+}}, {{.+}}) : (!u8i, !cir.vector<4 x !cir.float>) -> !cir.vector<8 x !cir.bf16> + + // LLVM-LABEL: @test_mm_maskz_cvtneps_pbh + // LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> {{.+}}, <8 x bfloat> {{.+}}, <4 x i1> {{.+}}) + + // OGCG-LABEL: @test_mm_maskz_cvtneps_pbh + // OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> {{.+}}, <8 x bfloat> {{.+}}, <4 x i1> {{.+}}) + return _mm_maskz_cvtneps_pbh(k, a); +}