Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,47 @@ static mlir::Value emitX86Muldq(CIRGenBuilderTy &builder, mlir::Location loc,
return builder.createMul(loc, lhs, rhs);
}

// Convert f16 half values to floats.
static mlir::Value emitX86CvtF16ToFloatExpr(CIRGenBuilderTy &builder,
mlir::Location loc,
llvm::ArrayRef<mlir::Value> ops,
mlir::Type dstTy) {
assert((ops.size() == 1 || ops.size() == 3 || ops.size() == 4) &&
"Unknown cvtph2ps intrinsic");

// If the SAE intrinsic doesn't use default rounding then we can't upgrade.
if (ops.size() == 4) {
auto constOp = ops[3].getDefiningOp<cir::ConstantOp>();
assert(constOp && "Expected constant operand");
if (constOp.getIntValue().getZExtValue() != 4) {
return emitIntrinsicCallOp(builder, loc, "x86.avx512.mask.vcvtph2ps.512",
dstTy, ops);
}
}

unsigned numElts = cast<cir::VectorType>(dstTy).getSize();
mlir::Value src = ops[0];

// Extract the subvector
if (numElts != cast<cir::VectorType>(src.getType()).getSize()) {
assert(numElts == 4 && "Unexpected vector size");
src = builder.createVecShuffle(loc, src, {0, 1, 2, 3});
}

// Bitcast from vXi16 to vXf16.
cir::VectorType halfTy =
cir::VectorType::get(cir::FP16Type::get(builder.getContext()), numElts);

src = builder.createCast(cir::CastKind::bitcast, src, halfTy);

// Perform the fp-extension
mlir::Value res = builder.createCast(cir::CastKind::floating, src, dstTy);

if (ops.size() >= 3)
res = emitX86Select(builder, loc, ops[2], res, ops[1]);
return res;
}

static mlir::Value emitX86vpcom(CIRGenBuilderTy &builder, mlir::Location loc,
llvm::SmallVector<mlir::Value> ops,
bool isSigned) {
Expand Down Expand Up @@ -1828,9 +1869,17 @@ CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, const CallExpr *expr) {
case X86::BI__builtin_ia32_cmpnltsd:
case X86::BI__builtin_ia32_cmpnlesd:
case X86::BI__builtin_ia32_cmpordsd:
cgm.errorNYI(expr->getSourceRange(),
std::string("unimplemented X86 builtin call: ") +
getContext().BuiltinInfo.getName(builtinID));
return {};
case X86::BI__builtin_ia32_vcvtph2ps_mask:
case X86::BI__builtin_ia32_vcvtph2ps256_mask:
case X86::BI__builtin_ia32_vcvtph2ps512_mask:
case X86::BI__builtin_ia32_vcvtph2ps512_mask: {
mlir::Location loc = getLoc(expr->getExprLoc());
return emitX86CvtF16ToFloatExpr(builder, loc, ops,
convertType(expr->getType()));
}
case X86::BI__builtin_ia32_cvtneps2bf16_128_mask:
case X86::BI__builtin_ia32_cvtneps2bf16_256_mask:
case X86::BI__builtin_ia32_cvtneps2bf16_512_mask:
Expand Down
147 changes: 147 additions & 0 deletions clang/test/CIR/CodeGenBuiltins/X86/avx512f16c-builtins.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512fp16 -target-feature +avx512f -target-feature +avx512vl -fclangir -emit-cir -o %t.cir -Wall -Werror -Wsign-conversion
// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512fp16 -target-feature +avx512f -target-feature +avx512vl -fclangir -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion
// RUN: FileCheck --check-prefixes=LLVM --input-file=%t.ll %s
// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512fp16 -target-feature +avx512f -target-feature +avx512vl -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion
// RUN: FileCheck --check-prefixes=OGCG --input-file=%t.ll %s

#include <immintrin.h>

__m128 test_vcvtph2ps_mask(__m128i a, __m128 src, __mmask8 k) {
// CIR-LABEL: cir.func no_inline dso_local @test_vcvtph2ps_mask
// CIR: %{{.*}} = cir.vec.shuffle(%{{.*}}, %{{.*}} : !cir.vector<8 x !s16i>) [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i] : !cir.vector<4 x !s16i>
// CIR: %{{.*}} = cir.cast floating %{{.*}} : !cir.vector<4 x !cir.f16> -> !cir.vector<4 x !cir.float>
// CIR: %{{.*}} = cir.vec.shuffle(%{{.*}}, %{{.*}} : !cir.vector<8 x !cir.int<s, 1>>) [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i] : !cir.vector<4 x !cir.int<s, 1>>
// CIR: %{{.*}} = cir.vec.ternary(%{{.*}}, %{{.*}}, %{{.*}}) : !cir.vector<4 x !cir.int<s, 1>>, !cir.vector<4 x !cir.float>

// LLVM-LABEL: @test_vcvtph2ps_mask
// LLVM: %{{.*}} = fpext <4 x half> %{{.*}} to <4 x float>
// LLVM: %{{.*}} = select <4 x i1> {{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}

// OGCG-LABEL: @test_vcvtph2ps_mask
// OGCG: %{{.*}} = fpext <4 x half> %{{.*}} to <4 x float>
// OGCG: %{{.*}} = select <4 x i1> {{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
typedef short __v8hi __attribute__((__vector_size__(16)));
return __builtin_ia32_vcvtph2ps_mask((__v8hi)a, src, k);
}

__m256 test_vcvtph2ps256_mask(__m128i a, __m256 src, __mmask8 k) {
// CIR-LABEL: cir.func no_inline dso_local @test_vcvtph2ps256_mask
// CIR: %{{.*}} = cir.cast floating %{{.*}} : !cir.vector<8 x !cir.f16> -> !cir.vector<8 x !cir.float>
// CIR: %{{.*}} = cir.vec.ternary(%{{.*}}, %{{.*}}, %{{.*}}) : !cir.vector<8 x !cir.int<s, 1>>, !cir.vector<8 x !cir.float>

// LLVM-LABEL: @test_vcvtph2ps256_mask
// LLVM: %{{.*}} = fpext <8 x half> %{{.*}} to <8 x float>
// LLVM: %{{.*}} = select <8 x i1> {{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}

// OGCG-LABEL: @test_vcvtph2ps256_mask
// OGCG: %{{.*}} = fpext <8 x half> %{{.*}} to <8 x float>
// OGCG: %{{.*}} = select <8 x i1> {{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
typedef short __v8hi __attribute__((__vector_size__(16)));
return __builtin_ia32_vcvtph2ps256_mask((__v8hi)a, src, k);
}

__m512 test_vcvtph2ps512_mask(__m256i a, __m512 src, __mmask16 k) {
// CIR-LABEL: cir.func no_inline dso_local @test_vcvtph2ps512_mask
// CIR: %{{.*}} = cir.cast floating %{{.*}} : !cir.vector<16 x !cir.f16> -> !cir.vector<16 x !cir.float>
// CIR: %{{.*}} = cir.vec.ternary(%{{.*}}, %{{.*}}, %{{.*}}) : !cir.vector<16 x !cir.int<s, 1>>, !cir.vector<16 x !cir.float>

// LLVM-LABEL: @test_vcvtph2ps512_mask
// LLVM: %{{.*}} = fpext <16 x half> %{{.*}} to <16 x float>
// LLVM: %{{.*}} = select <16 x i1> {{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}

// OGCG-LABEL: @test_vcvtph2ps512_mask
// OGCG: %{{.*}} = fpext <16 x half> %{{.*}} to <16 x float>
// OGCG: %{{.*}} = select <16 x i1> {{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
typedef short __v16hi __attribute__((__vector_size__(32)));
return __builtin_ia32_vcvtph2ps512_mask((__v16hi)a, src, k, 4);
}

__m128 test_vcvtph2ps_maskz(__m128i a, __mmask8 k) {
// CIR-LABEL: cir.func no_inline dso_local @test_vcvtph2ps_maskz
// CIR: %{{.*}} = cir.vec.shuffle(%{{.*}}, %{{.*}} : !cir.vector<8 x !s16i>) [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i] : !cir.vector<4 x !s16i>
// CIR: %{{.*}} = cir.cast floating %{{.*}} : !cir.vector<4 x !cir.f16> -> !cir.vector<4 x !cir.float>
// CIR: %{{.*}} = cir.vec.shuffle(%{{.*}}, %{{.*}} : !cir.vector<8 x !cir.int<s, 1>>) [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i] : !cir.vector<4 x !cir.int<s, 1>>
// CIR: %{{.*}} = cir.vec.ternary(%{{.*}}, %{{.*}}, %{{.*}}) : !cir.vector<4 x !cir.int<s, 1>>, !cir.vector<4 x !cir.float>

// LLVM-LABEL: @test_vcvtph2ps_maskz
// LLVM: %{{.*}} = fpext <4 x half> %{{.*}} to <4 x float>
// LLVM: %{{.*}} = select <4 x i1> {{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}

// OGCG-LABEL: @test_vcvtph2ps_maskz
// OGCG: %{{.*}} = fpext <4 x half> %{{.*}} to <4 x float>
// OGCG: %{{.*}} = select <4 x i1> {{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
typedef short __v8hi __attribute__((__vector_size__(16)));
return __builtin_ia32_vcvtph2ps_mask((__v8hi)a, _mm_setzero_ps(), k);
}

__m256 test_vcvtph2ps256_maskz(__m128i a, __mmask8 k) {
// CIR-LABEL: cir.func no_inline dso_local @test_vcvtph2ps256_maskz
// CIR: %{{.*}} = cir.cast floating %{{.*}} : !cir.vector<8 x !cir.f16> -> !cir.vector<8 x !cir.float>
// CIR: %{{.*}} = cir.vec.ternary(%{{.*}}, %{{.*}}, %{{.*}}) : !cir.vector<8 x !cir.int<s, 1>>, !cir.vector<8 x !cir.float>

// LLVM-LABEL: @test_vcvtph2ps256_maskz
// LLVM: %{{.*}} = fpext <8 x half> %{{.*}} to <8 x float>
// LLVM: %{{.*}} = select <8 x i1> {{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}

// OGCG-LABEL: @test_vcvtph2ps256_maskz
// OGCG: %{{.*}} = fpext <8 x half> %{{.*}} to <8 x float>
// OGCG: %{{.*}} = select <8 x i1> {{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
typedef short __v8hi __attribute__((__vector_size__(16)));
return __builtin_ia32_vcvtph2ps256_mask((__v8hi)a, _mm256_setzero_ps(), k);
}

__m512 test_vcvtph2ps512_maskz(__m256i a, __mmask16 k) {
// CIR-LABEL: cir.func no_inline dso_local @test_vcvtph2ps512_maskz
// CIR: %{{.*}} = cir.cast floating %{{.*}} : !cir.vector<16 x !cir.f16> -> !cir.vector<16 x !cir.float>
// CIR: %{{.*}} = cir.vec.ternary(%{{.*}}, %{{.*}}, %{{.*}}) : !cir.vector<16 x !cir.int<s, 1>>, !cir.vector<16 x !cir.float>

// LLVM-LABEL: @test_vcvtph2ps512_maskz
// LLVM: %{{.*}} = fpext <16 x half> %{{.*}} to <16 x float>
// LLVM: %{{.*}} = select <16 x i1> {{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}

// OGCG-LABEL: @test_vcvtph2ps512_maskz
// OGCG: %{{.*}} = fpext <16 x half> %{{.*}} to <16 x float>
// OGCG: %{{.*}} = select <16 x i1> {{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
typedef short __v16hi __attribute__((__vector_size__(32)));
return __builtin_ia32_vcvtph2ps512_mask((__v16hi)a, _mm512_setzero_ps(), k, 4);
}

__m512 test_mm512_cvt_roundph_ps(__m256i a) {
// CIR-LABEL: cir.func no_inline dso_local @test_mm512_cvt_roundph_ps
// CIR: %{{.*}} = cir.call_llvm_intrinsic "x86.avx512.mask.vcvtph2ps.512" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (!cir.vector<16 x !s16i>, !cir.vector<16 x !cir.float>, !u16i, !s32i) -> !cir.vector<16 x !cir.float>

// LLVM-LABEL: @test_mm512_cvt_roundph_ps
// LLVM: call <16 x float> @llvm.x86.avx512.mask.vcvtph2ps.512(<16 x i16> %{{.*}}, <16 x float> %{{.*}}, i16 -1, i32 8)

// OGCG-LABEL: @test_mm512_cvt_roundph_ps
// OGCG: call <16 x float> @llvm.x86.avx512.mask.vcvtph2ps.512(<16 x i16> %{{.*}}, <16 x float> zeroinitializer, i16 -1, i32 8)
typedef short __v16hi __attribute__((__vector_size__(32)));
return _mm512_cvt_roundph_ps((__v16hi)a, _MM_FROUND_NO_EXC);
}

__m512 test_mm512_mask_cvt_roundph_ps(__m512 w, __mmask16 u, __m256i a) {
// CIR-LABEL: cir.func no_inline dso_local @test_mm512_mask_cvt_roundph_ps
// CIR: %{{.*}} = cir.call_llvm_intrinsic "x86.avx512.mask.vcvtph2ps.512" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (!cir.vector<16 x !s16i>, !cir.vector<16 x !cir.float>, !u16i, !s32i) -> !cir.vector<16 x !cir.float>

// LLVM-LABEL: @test_mm512_mask_cvt_roundph_ps
// LLVM: call <16 x float> @llvm.x86.avx512.mask.vcvtph2ps.512(<16 x i16> %{{.*}}, <16 x float> %{{.*}}, i16 %{{.*}}, i32 8)

// OGCG-LABEL: @test_mm512_mask_cvt_roundph_ps
// OGCG: call <16 x float> @llvm.x86.avx512.mask.vcvtph2ps.512(<16 x i16> %{{.*}}, <16 x float> %{{.*}}, i16 %{{.*}}, i32 8)
typedef short __v16hi __attribute__((__vector_size__(32)));
return _mm512_mask_cvt_roundph_ps(w, u, (__v16hi)a, _MM_FROUND_NO_EXC);
}

__m512 test_mm512_maskz_cvt_roundph_ps(__mmask16 u, __m256i a) {
// CIR-LABEL: cir.func no_inline dso_local @test_mm512_maskz_cvt_roundph_ps
// CIR: %{{.*}} = cir.call_llvm_intrinsic "x86.avx512.mask.vcvtph2ps.512" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (!cir.vector<16 x !s16i>, !cir.vector<16 x !cir.float>, !u16i, !s32i) -> !cir.vector<16 x !cir.float>

// LLVM-LABEL: @test_mm512_maskz_cvt_roundph_ps
// LLVM: call <16 x float> @llvm.x86.avx512.mask.vcvtph2ps.512(<16 x i16> %{{.*}}, <16 x float> %{{.*}}, i16 %{{.*}}, i32 8)

// OGCG-LABEL: @test_mm512_maskz_cvt_roundph_ps
// OGCG: call <16 x float> @llvm.x86.avx512.mask.vcvtph2ps.512(<16 x i16> %{{.*}}, <16 x float> %{{.*}}, i16 %{{.*}}, i32 8)
typedef short __v16hi __attribute__((__vector_size__(32)));
return _mm512_maskz_cvt_roundph_ps(u, (__v16hi)a, _MM_FROUND_NO_EXC);
}