diff --git a/cranelift/codegen/src/isa/riscv64/lower.isle b/cranelift/codegen/src/isa/riscv64/lower.isle
index 72890d91a136..5a8f499e9fd0 100644
--- a/cranelift/codegen/src/isa/riscv64/lower.isle
+++ b/cranelift/codegen/src/isa/riscv64/lower.isle
@@ -1509,56 +1509,55 @@
 ;;;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 
-;; fmadd: rs1 * rs2 + rs3
-(rule 0 (lower (has_type (ty_scalar_float ty) (fma x y z)))
-  (rv_fmadd ty (FRM.RNE) x y z))
-
-;; fmsub: rs1 * rs2 - rs3
-(rule 1 (lower (has_type (ty_scalar_float ty) (fma x y (fneg z))))
-  (rv_fmsub ty (FRM.RNE) x y z))
-
-;; fnmsub: -rs1 * rs2 + rs3
-(rule 2 (lower (has_type (ty_scalar_float ty) (fma (fneg x) y z)))
-  (rv_fnmsub ty (FRM.RNE) x y z))
-
-;; fnmadd: -rs1 * rs2 - rs3
-(rule 3 (lower (has_type (ty_scalar_float ty) (fma (fneg x) y (fneg z))))
-  (rv_fnmadd ty (FRM.RNE) x y z))
-
-;; (fma x y z) computes x * y + z
-;; vfmacc computes vd[i] = +(vs1[i] * vs2[i]) + vd[i]
-;; We need to reverse the order of the arguments
-
-(rule 4 (lower (has_type (ty_vec_fits_in_register ty) (fma x y z)))
-  (rv_vfmacc_vv z y x (unmasked) ty))
-
-(rule 5 (lower (has_type (ty_vec_fits_in_register ty) (fma (splat x) y z)))
-  (rv_vfmacc_vf z y x (unmasked) ty))
-
-;; vfmsac computes vd[i] = +(vs1[i] * vs2[i]) - vd[i]
-
-(rule 6 (lower (has_type (ty_vec_fits_in_register ty) (fma x y (fneg z))))
-  (rv_vfmsac_vv z y x (unmasked) ty))
-
-(rule 9 (lower (has_type (ty_vec_fits_in_register ty) (fma (splat x) y (fneg z))))
-  (rv_vfmsac_vf z y x (unmasked) ty))
-
-;; vfnmacc computes vd[i] = -(vs1[i] * vs2[i]) - vd[i]
-
-(rule 7 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg x) y (fneg z))))
-  (rv_vfnmacc_vv z y x (unmasked) ty))
-
-(rule 9 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg (splat x)) y (fneg z))))
-  (rv_vfnmacc_vf z y x (unmasked) ty))
-
-;; vfnmsac computes vd[i] = -(vs1[i] * vs2[i]) + vd[i]
-
-(rule 5 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg x) y z)))
-  (rv_vfnmsac_vv z y x (unmasked) ty))
-
-(rule 8 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg (splat x)) y z)))
-  (rv_vfnmsac_vf z y x (unmasked) ty))
-
+;; RISC-V has 4 FMA instructions, each of which computes a slightly different expression:
+;;
+;; fmadd: (rs1 * rs2) + rs3
+;; fmsub: (rs1 * rs2) - rs3
+;; fnmadd: -(rs1 * rs2) - rs3
+;; fnmsub: -(rs1 * rs2) + rs3
+;;
+;; Additionally, there are vector versions of these instructions with slightly different names.
+;; Each vector instruction also comes in two variants, `.vv` and `.vf`, where the `.vv` variants
+;; take two vector operands and the `.vf` variants take a vector operand and a scalar operand.
+;;
+;; Because the vector instructions accumulate into `vd` (see the formulas below), they take their
+;; arguments in a different order than `fma`, so we swap the argument order in the rules below.
+;;
+;; vfmacc: vd[i] = +(vs1[i] * vs2[i]) + vd[i]
+;; vfmsac: vd[i] = +(vs1[i] * vs2[i]) - vd[i]
+;; vfnmacc: vd[i] = -(vs1[i] * vs2[i]) - vd[i]
+;; vfnmsac: vd[i] = -(vs1[i] * vs2[i]) + vd[i]
+
+(type IsFneg (enum (Result (negate u64) (value Value))))
+
+(decl pure is_fneg (Value) IsFneg)
+(rule 1 (is_fneg (fneg x)) (IsFneg.Result 1 x))
+(rule 0 (is_fneg x) (IsFneg.Result 0 x))
+
+(rule (lower (has_type ty (fma x_src y_src z_src)))
+  (if-let (IsFneg.Result neg_x x) (is_fneg x_src))
+  (if-let (IsFneg.Result neg_y y) (is_fneg y_src))
+  (if-let (IsFneg.Result neg_z z) (is_fneg z_src))
+  (rv_fma ty (u64_xor neg_x neg_y) neg_z x y z))
+
+;; The two parity arguments indicate whether to negate the x*y term and the z term, respectively.
+(decl rv_fma (Type u64 u64 Value Value Value) InstOutput)
+(rule 0 (rv_fma (ty_scalar_float ty) 0 0 x y z) (rv_fmadd ty (FRM.RNE) x y z))
+(rule 0 (rv_fma (ty_scalar_float ty) 0 1 x y z) (rv_fmsub ty (FRM.RNE) x y z))
+(rule 0 (rv_fma (ty_scalar_float ty) 1 0 x y z) (rv_fnmsub ty (FRM.RNE) x y z))
+(rule 0 (rv_fma (ty_scalar_float ty) 1 1 x y z) (rv_fnmadd ty (FRM.RNE) x y z))
+(rule 1 (rv_fma (ty_vec_fits_in_register ty) 0 0 x y z) (rv_vfmacc_vv z y x (unmasked) ty))
+(rule 1 (rv_fma (ty_vec_fits_in_register ty) 0 1 x y z) (rv_vfmsac_vv z y x (unmasked) ty))
+(rule 1 (rv_fma (ty_vec_fits_in_register ty) 1 0 x y z) (rv_vfnmsac_vv z y x (unmasked) ty))
+(rule 1 (rv_fma (ty_vec_fits_in_register ty) 1 1 x y z) (rv_vfnmacc_vv z y x (unmasked) ty))
+(rule 2 (rv_fma (ty_vec_fits_in_register ty) 0 0 (splat x) y z) (rv_vfmacc_vf z y x (unmasked) ty))
+(rule 2 (rv_fma (ty_vec_fits_in_register ty) 0 1 (splat x) y z) (rv_vfmsac_vf z y x (unmasked) ty))
+(rule 2 (rv_fma (ty_vec_fits_in_register ty) 1 0 (splat x) y z) (rv_vfnmsac_vf z y x (unmasked) ty))
+(rule 2 (rv_fma (ty_vec_fits_in_register ty) 1 1 (splat x) y z) (rv_vfnmacc_vf z y x (unmasked) ty))
+(rule 3 (rv_fma (ty_vec_fits_in_register ty) 0 0 x (splat y) z) (rv_vfmacc_vf z x y (unmasked) ty))
+(rule 3 (rv_fma (ty_vec_fits_in_register ty) 0 1 x (splat y) z) (rv_vfmsac_vf z x y (unmasked) ty))
+(rule 3 (rv_fma (ty_vec_fits_in_register ty) 1 0 x (splat y) z) (rv_vfnmsac_vf z x y (unmasked) ty))
+(rule 3 (rv_fma (ty_vec_fits_in_register ty) 1 1 x (splat y) z) (rv_vfnmacc_vf z x y (unmasked) ty))
 
 ;;;; Rules for `sqrt` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 
 (rule 0 (lower (has_type (ty_scalar_float ty) (sqrt x)))
diff --git a/cranelift/filetests/filetests/isa/riscv64/simd-fma.clif b/cranelift/filetests/filetests/isa/riscv64/simd-fma.clif
index 689763fd9f06..313afba1a15d 100644
--- a/cranelift/filetests/filetests/isa/riscv64/simd-fma.clif
+++ b/cranelift/filetests/filetests/isa/riscv64/simd-fma.clif
@@ -362,3 +362,90 @@ block0(v0: f64, v1: f64x2, v2: f64x2):
 ; addi sp, sp, 0x10
 ; ret
+
+function %fma_splat_y_f32x4(f32x4, f32, f32x4) -> f32x4 {
+block0(v0: f32x4, v1: f32, v2: f32x4):
+    v3 = splat.f32x4 v1
+    v4 = fma v0, v3, v2
+    return v4
+}
+
+; VCode:
+; addi sp,sp,-16
+; sd ra,8(sp)
+; sd fp,0(sp)
+; mv fp,sp
+; block0:
+; vle8.v v9,-32(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
+; vle8.v v14,-16(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
+; vfmacc.vf v14,v9,fa0 #avl=4, #vtype=(e32, m1, ta, ma)
+; vse8.v v14,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
+; ld ra,8(sp)
+; ld fp,0(sp)
+; addi sp,sp,16
+; ret
+;
+; Disassembled:
+; block0: ; offset 0x0
+; addi sp, sp, -0x10
+; sd ra, 8(sp)
+; sd s0, 0(sp)
+; mv s0, sp
+; block1: ; offset 0x10
+; .byte 0x57, 0x70, 0x08, 0xcc
+; addi t6, sp, 0x10
+; .byte 0x87, 0x84, 0x0f, 0x02
+; addi t6, sp, 0x20
+; .byte 0x07, 0x87, 0x0f, 0x02
+; .byte 0x57, 0x70, 0x02, 0xcd
+; .byte 0x57, 0x57, 0x95, 0xb2
+; .byte 0x57, 0x70, 0x08, 0xcc
+; .byte 0x27, 0x07, 0x05, 0x02
+; ld ra, 8(sp)
+; ld s0, 0(sp)
+; addi sp, sp, 0x10
+; ret
+
+function %fma_splat_y_f64x2(f64x2, f64, f64x2) -> f64x2 {
+block0(v0: f64x2, v1: f64, v2: f64x2):
+    v3 = splat.f64x2 v1
+    v4 = fma v0, v3, v2
+    return v4
+}
+
+; VCode:
+; addi sp,sp,-16
+; sd ra,8(sp)
+; sd fp,0(sp)
+; mv fp,sp
+; block0:
+; vle8.v v9,-32(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
+; vle8.v v14,-16(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
+; vfmacc.vf v14,v9,fa0 #avl=2, #vtype=(e64, m1, ta, ma)
+; vse8.v v14,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
+; ld ra,8(sp)
+; ld fp,0(sp)
+; addi sp,sp,16
+; ret
+;
+; Disassembled:
+; block0: ; offset 0x0
+; addi sp, sp, -0x10
+; sd ra, 8(sp)
+; sd s0, 0(sp)
+; mv s0, sp
+; block1: ; offset 0x10
+; .byte 0x57, 0x70, 0x08, 0xcc
+; addi t6, sp, 0x10
+; .byte 0x87, 0x84, 0x0f, 0x02
+; addi t6, sp, 0x20
+; .byte 0x07, 0x87, 0x0f, 0x02
+; .byte 0x57, 0x70, 0x81, 0xcd
+; .byte 0x57, 0x57, 0x95, 0xb2
+; .byte 0x57, 0x70, 0x08, 0xcc
+; .byte 0x27, 0x07, 0x05, 0x02
+; ld ra, 8(sp)
+; ld s0, 0(sp)
+; addi sp, sp, 0x10
+; ret
+
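For reference, here is a minimal CLIF sketch (an illustration only, not part of this patch) of an input that exercises the fneg-folding path added above: both the multiplicand and the addend are negated, so `is_fneg` yields parities (1, 1) and `rv_fma` selects `vfnmacc.vv`. The function name is hypothetical, and the `test compile` header and target flags of the surrounding simd-fma.clif file are assumed.

function %fma_fneg_x_fneg_z_f32x4(f32x4, f32x4, f32x4) -> f32x4 {
block0(v0: f32x4, v1: f32x4, v2: f32x4):
    ;; computes -(v0 * v1) - v2; both fnegs are folded into a single vfnmacc.vv
    v3 = fneg v0
    v4 = fneg v2
    v5 = fma v3, v1, v4
    return v5
}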