Skip to content

Commit a4ee0e3

Browse files
AArch64: support bf16 to sf extensions [PR121853]
It looks like during the upstreaming of BF16 we didn't implement the extend optab for it. As a result we go through soft-float emulation, which results in a massive performance drop in projects using BF16. As an example, for float convert(__bf16 value) { return (float)value; } we generate: convert(__bf16): stp x29, x30, [sp, -16]! mov x29, sp bl __extendbfsf2 ldp x29, x30, [sp], 16 ret and after this patch convert: movi v31.4s, 0 ext v0.16b, v31.16b, v0.16b, #14 ret We generate an ext with movi because this has the same latency as a shift but twice the throughput. The zero vector is zero latency, so in real workloads this codegen is much better than using shifts. As a reminder, BF16 -> FP32 is just shifting left 16 bits. The expand pattern has to rely on generating multiple subregs due to a restriction that subregs can't change floating point size and type at the same time. I've tried alternative approaches like using the EXT as SF mode, but the paradoxical subreg of BF -> SF isn't allowed and using an extend doesn't work because extend is what we're defining. gcc/ChangeLog: PR target/121853 * config/aarch64/aarch64-simd.md (extendbfsf2): New. gcc/testsuite/ChangeLog: PR target/121853 * gcc.target/aarch64/pr121853_1.c: New test. * gcc.target/aarch64/pr121853_2.c: New test. (cherry picked from commit 58ee207)
1 parent 21866f2 commit a4ee0e3

File tree

3 files changed

+102
-0
lines changed

3 files changed

+102
-0
lines changed

gcc/config/aarch64/aarch64-simd.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3181,6 +3181,7 @@
31813181
DONE;
31823182
}
31833183
)
3184+
31843185
(define_insn "extend<mode><Vwide>2"
31853186
[(set (match_operand:<VWIDE> 0 "register_operand" "=w")
31863187
(float_extend:<VWIDE>
@@ -3190,6 +3191,29 @@
31903191
[(set_attr "type" "neon_fp_cvt_widen_s")]
31913192
)
31923193

3194+
/* A BF->SF is a shift left of 16, however shifts are expensive and the generic
3195+
middle-end expansion would force through DI move. Instead use EXT to do the
3196+
shift to get better throughput and don't go through GPRs. */
3197+
3198+
;; Expand a scalar BF -> SF float_extend into a single vector EXT against a
;; zero vector, followed by a lowpart move into the SF destination.
;; NOTE(review): operand constraints are normally not meaningful in a
;; define_expand -- confirm whether "=w"/"w" are needed here.
(define_expand "extendbfsf2"
3199+
[(set (match_operand:SF 0 "register_operand" "=w")
3200+
(float_extend:SF
3201+
(match_operand:BF 1 "register_operand" "w")))]
3202+
"TARGET_SIMD"
3203+
{
3204+
/* Zero V8BF vector used as the first EXT input.  */
rtx tmp0 = aarch64_gen_shareable_zero (V8BFmode);
3205+
/* View the scalar BF input as the low part of a V8BF vector so it can be
   fed to the vector EXT pattern.  */
rtx op0 = force_lowpart_subreg (V8BFmode, operands[1], BFmode);
3206+
rtx res = gen_reg_rtx (V8BFmode);
3207+
/* EXT of {zero, input} at element index 7.  With 16-bit elements this is
   the "#14" byte offset that the accompanying codegen test expects, and is
   what implements the 16-bit left shift described above.  */
emit_insn (gen_aarch64_extv8bf (res, tmp0, op0, gen_int_mode (7, SImode)));
3208+
/* Subregs between floating point modes aren't allowed to change size, so go
3209+
through V4SFmode. */
3210+
/* First reinterpret the full vector as V4SF (same size, different element
   type) ...  */
res = force_lowpart_subreg (V4SFmode, res, V8BFmode);
3211+
/* ... then take the low SF element of that vector as the scalar result.  */
res = force_lowpart_subreg (SFmode, res, V4SFmode);
3212+
emit_move_insn (operands[0], res);
3213+
/* Everything was emitted by hand; do not emit the pattern's RTL template.  */
DONE;
3214+
})
3215+
3216+
31933217
;; Float narrowing operations.
31943218

31953219
(define_insn "aarch64_float_trunc_rodd_df"
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* { dg-do run } */
2+
/* { dg-additional-options "-O2 -std=c99" } */
3+
4+
#include <stdint.h>
5+
#include <stdio.h>
6+
#include <string.h>
7+
8+
/* Function under test: converts a bf16 value to float with a plain cast.
   It is marked noipa so that GCC does not inline it into main or otherwise
   optimize across the call, which keeps the conversion as a real runtime
   operation.  */
__attribute__ ((noipa))
9+
float convert(__bf16 value) {
10+
return (float)value;
11+
}
12+
13+
/* Return the raw 32-bit object representation of F.  Uses memcpy rather than
   a pointer cast to copy the bytes into a uint32_t.  */
static inline uint32_t f32_bits(float f) {
14+
uint32_t u; memcpy(&u, &f, sizeof u); return u;
15+
}
16+
/* Build a __bf16 whose object representation is the 16-bit payload U.  The
   bytes are copied with memcpy, so no numeric conversion takes place.  */
static inline __bf16 bf16_from_bits(uint16_t u) {
17+
__bf16 b; memcpy(&b, &u, sizeof b); return b;
18+
}
19+
20+
/* Fixed bf16 inputs (as raw 16-bit payloads) covering edge cases. */
21+
static const uint16_t inputs[] = {
22+
0x0000, // +0
23+
0x8000, // -0
24+
0x7F80, // +inf
25+
0xFF80, // -inf
26+
0x7FC0, // qNaN (+) (quiet bit set in bf16)
27+
0xFFC0, // qNaN (-)
28+
0x7F01, // sNaN (+) (main expects the payload to pass through unchanged, i.e. not quieted)
29+
0xFF01, // sNaN (-) (likewise expected unchanged)
30+
0x0001, // smallest +subnormal
31+
0x007F, // largest +subnormal
32+
0x8001, // smallest -subnormal
33+
0x807F, // largest -subnormal
34+
0x0080, // smallest +normal
35+
0x3F80, // +1.0
36+
0xBF80, // -1.0
37+
0x3F00, // +0.5
38+
0xBF00, // -0.5
39+
0x3FC0, // +1.5
40+
0x7F7F, // max finite +
41+
0xFF7F, // max finite -
42+
};
43+
44+
int main(void) {
45+
const size_t N = sizeof(inputs)/sizeof(inputs[0]);
46+
size_t fails = 0;
47+
48+
for (size_t i = 0; i < N; ++i) {
49+
__bf16 in = bf16_from_bits(inputs[i]);
50+
float out = convert(in);
51+
uint32_t got = f32_bits(out);
52+
uint32_t exp = inputs[i] << 16;
53+
54+
if (got != exp) {
55+
printf("FAIL[%zu]: in_bf16=0x%04X exp_f32=0x%08X got_f32=0x%08X\n",
56+
i, inputs[i], exp, got);
57+
++fails;
58+
}
59+
}
60+
61+
if (fails != 0)
62+
__builtin_abort ();
63+
}
64+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/* { dg-do compile } */
2+
/* { dg-additional-options "-O1" } */
3+
/* { dg-final { check-function-bodies "**" "" } } */
4+
5+
/* Codegen fixture: the assembly generated for this function is matched by
   the check-function-bodies directive against the expected "convert:" body
   listed in the comment below, so the source must stay a single plain cast.  */
float convert(__bf16 value) {
6+
return (float)value;
7+
}
8+
9+
/*
10+
** convert:
11+
** movi v[0-9]+.4s, 0
12+
** ext v[0-9]+.16b, v[0-9]+.16b, v[0-9]+.16b, #14
13+
** ret
14+
*/

0 commit comments

Comments
 (0)