Skip to content

Commit

Permalink
riscv64: Improve pattern matching rules for FMA (bytecodealliance#8596)
Browse files Browse the repository at this point in the history
* riscv64: Improve pattern matching rules for FMA.

This commit reworks our FMA pattern matching to be slightly less verbose. It additionally adds the `(fma x (splat y) z)` pattern for vectors, which can be proven to be equivalent to `(fma (splat y) x z)` since multiplication is commutative.

Co-Authored-By:  Jamey Sharp <[email protected]>

* riscv64: Remove unused rule priority

---------

Co-authored-by: Jamey Sharp <[email protected]>
  • Loading branch information
afonso360 and jameysharp authored May 12, 2024
1 parent 7327304 commit 895a5ac
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 50 deletions.
99 changes: 49 additions & 50 deletions cranelift/codegen/src/isa/riscv64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -1509,56 +1509,55 @@

;;;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

;; fmadd: rs1 * rs2 + rs3
(rule 0 (lower (has_type (ty_scalar_float ty) (fma x y z)))
(rv_fmadd ty (FRM.RNE) x y z))

;; fmsub: rs1 * rs2 - rs3
(rule 1 (lower (has_type (ty_scalar_float ty) (fma x y (fneg z))))
(rv_fmsub ty (FRM.RNE) x y z))

;; fnmsub: -rs1 * rs2 + rs3
(rule 2 (lower (has_type (ty_scalar_float ty) (fma (fneg x) y z)))
(rv_fnmsub ty (FRM.RNE) x y z))

;; fnmadd: -rs1 * rs2 - rs3
(rule 3 (lower (has_type (ty_scalar_float ty) (fma (fneg x) y (fneg z))))
(rv_fnmadd ty (FRM.RNE) x y z))

;; (fma x y z) computes x * y + z
;; vfmacc computes vd[i] = +(vs1[i] * vs2[i]) + vd[i]
;; We need to reverse the order of the arguments

(rule 4 (lower (has_type (ty_vec_fits_in_register ty) (fma x y z)))
(rv_vfmacc_vv z y x (unmasked) ty))

(rule 5 (lower (has_type (ty_vec_fits_in_register ty) (fma (splat x) y z)))
(rv_vfmacc_vf z y x (unmasked) ty))

;; vfmsac computes vd[i] = +(vs1[i] * vs2[i]) - vd[i]

(rule 6 (lower (has_type (ty_vec_fits_in_register ty) (fma x y (fneg z))))
(rv_vfmsac_vv z y x (unmasked) ty))

(rule 9 (lower (has_type (ty_vec_fits_in_register ty) (fma (splat x) y (fneg z))))
(rv_vfmsac_vf z y x (unmasked) ty))

;; vfnmacc computes vd[i] = -(vs1[i] * vs2[i]) - vd[i]

(rule 7 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg x) y (fneg z))))
(rv_vfnmacc_vv z y x (unmasked) ty))

(rule 9 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg (splat x)) y (fneg z))))
(rv_vfnmacc_vf z y x (unmasked) ty))

;; vfnmsac computes vd[i] = -(vs1[i] * vs2[i]) + vd[i]

(rule 5 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg x) y z)))
(rv_vfnmsac_vv z y x (unmasked) ty))

(rule 8 (lower (has_type (ty_vec_fits_in_register ty) (fma (fneg (splat x)) y z)))
(rv_vfnmsac_vf z y x (unmasked) ty))

;; RISC-V has 4 FMA instructions that do a slightly different computation.
;;
;; fmadd: (rs1 * rs2) + rs3
;; fmsub: (rs1 * rs2) - rs3
;; fnmadd: -(rs1 * rs2) - rs3
;; fnmsub: -(rs1 * rs2) + rs3
;;
;; Additionally there are vector versions of these instructions with slightly different names.
;; The vector instructions also have two variants each. `.vv` and `.vf`, where `.vv` variants
;; take two vector operands and the `.vf` variants take a vector operand and a scalar operand.
;;
;; Due to this variation, they receive the arguments in a different order. So we need to swap
;; the arguments below.
;;
;; vfmacc: vd[i] = +(vs1[i] * vs2[i]) + vd[i]
;; vfmsac: vd[i] = +(vs1[i] * vs2[i]) - vd[i]
;; vfnmacc: vd[i] = -(vs1[i] * vs2[i]) - vd[i]
;; vfnmsac: vd[i] = -(vs1[i] * vs2[i]) + vd[i]

;; Result of peeling an optional outer `fneg` off a value:
;; `negate` is 1 if the input was `(fneg v)` (with `value` bound to the
;; inner `v`), or 0 otherwise (with `value` bound to the input unchanged).
(type IsFneg (enum (Result (negate u64) (value Value))))

;; Strip an outer `fneg`, if present, and report it as a parity bit.
(decl pure is_fneg (Value) IsFneg)
;; Higher priority: an explicit `fneg` is unwrapped with parity 1.
(rule 1 (is_fneg (fneg x)) (IsFneg.Result 1 x))
;; Fallback: any other value passes through with parity 0.
(rule 0 (is_fneg x) (IsFneg.Result 0 x))

;; Lower `fma` by extracting any `fneg` on each operand into parity bits.
;; Negations on x and y are combined with xor because they cancel in the
;; product ((-x)*(-y) == x*y); the negation on the addend z is independent
;; and tracked separately. `rv_fma` then selects the concrete instruction.
(rule (lower (has_type ty (fma x_src y_src z_src)))
(if-let (IsFneg.Result neg_x x) (is_fneg x_src))
(if-let (IsFneg.Result neg_y y) (is_fneg y_src))
(if-let (IsFneg.Result neg_z z) (is_fneg z_src))
(rv_fma ty (u64_xor neg_x neg_y) neg_z x y z))

;; Select the concrete RISC-V FMA instruction for `(maybe-neg (x * y)) +/- z`.
;; The two parity arguments indicate whether to negate the x*y product and
;; the z addend, respectively (0 = as-is, 1 = negated).
(decl rv_fma (Type u64 u64 Value Value Value) InstOutput)
;; Scalar floats map directly onto the four scalar FMA instructions
;; (see the fmadd/fmsub/fnmadd/fnmsub summary comment above).
(rule 0 (rv_fma (ty_scalar_float ty) 0 0 x y z) (rv_fmadd ty (FRM.RNE) x y z))
(rule 0 (rv_fma (ty_scalar_float ty) 0 1 x y z) (rv_fmsub ty (FRM.RNE) x y z))
(rule 0 (rv_fma (ty_scalar_float ty) 1 0 x y z) (rv_fnmsub ty (FRM.RNE) x y z))
(rule 0 (rv_fma (ty_scalar_float ty) 1 1 x y z) (rv_fnmadd ty (FRM.RNE) x y z))
;; Vector-vector (`.vv`) forms: the vector FMA instructions take the addend
;; in the destination register, so the arguments are passed as `z y x`
;; (see the argument-order comment above).
(rule 1 (rv_fma (ty_vec_fits_in_register ty) 0 0 x y z) (rv_vfmacc_vv z y x (unmasked) ty))
(rule 1 (rv_fma (ty_vec_fits_in_register ty) 0 1 x y z) (rv_vfmsac_vv z y x (unmasked) ty))
(rule 1 (rv_fma (ty_vec_fits_in_register ty) 1 0 x y z) (rv_vfnmsac_vv z y x (unmasked) ty))
(rule 1 (rv_fma (ty_vec_fits_in_register ty) 1 1 x y z) (rv_vfnmacc_vv z y x (unmasked) ty))
;; Vector-scalar (`.vf`) forms, preferred over `.vv` (priority 2): when the
;; first multiplicand is a splat, feed the splatted scalar directly to the
;; instruction's scalar operand instead of materializing the splat.
(rule 2 (rv_fma (ty_vec_fits_in_register ty) 0 0 (splat x) y z) (rv_vfmacc_vf z y x (unmasked) ty))
(rule 2 (rv_fma (ty_vec_fits_in_register ty) 0 1 (splat x) y z) (rv_vfmsac_vf z y x (unmasked) ty))
(rule 2 (rv_fma (ty_vec_fits_in_register ty) 1 0 (splat x) y z) (rv_vfnmsac_vf z y x (unmasked) ty))
(rule 2 (rv_fma (ty_vec_fits_in_register ty) 1 1 (splat x) y z) (rv_vfnmacc_vf z y x (unmasked) ty))
;; Splat on the second multiplicand: commute x and y (multiplication is
;; commutative, so x * splat(y) == splat(y) * x) and use the `.vf` form,
;; passing `z x y` so the non-splat operand takes the vector slot.
(rule 3 (rv_fma (ty_vec_fits_in_register ty) 0 0 x (splat y) z) (rv_vfmacc_vf z x y (unmasked) ty))
(rule 3 (rv_fma (ty_vec_fits_in_register ty) 0 1 x (splat y) z) (rv_vfmsac_vf z x y (unmasked) ty))
(rule 3 (rv_fma (ty_vec_fits_in_register ty) 1 0 x (splat y) z) (rv_vfnmsac_vf z x y (unmasked) ty))
(rule 3 (rv_fma (ty_vec_fits_in_register ty) 1 1 x (splat y) z) (rv_vfnmacc_vf z x y (unmasked) ty))

;;;; Rules for `sqrt` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule 0 (lower (has_type (ty_scalar_float ty) (sqrt x)))
Expand Down
87 changes: 87 additions & 0 deletions cranelift/filetests/filetests/isa/riscv64/simd-fma.clif
Original file line number Diff line number Diff line change
Expand Up @@ -362,3 +362,90 @@ block0(v0: f64, v1: f64x2, v2: f64x2):
; addi sp, sp, 0x10
; ret


function %fma_splat_y_f32x4(f32x4, f32, f32x4) -> f32x4 {
block0(v0: f32x4, v1: f32, v2: f32x4):
v3 = splat.f32x4 v1
v4 = fma v0, v3, v2
return v4
}

; VCode:
; addi sp,sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v9,-32(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v14,-16(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
; vfmacc.vf v14,v9,fa0 #avl=4, #vtype=(e32, m1, ta, ma)
; vse8.v v14,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; addi sp,sp,16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; mv s0, sp
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, sp, 0x10
; .byte 0x87, 0x84, 0x0f, 0x02
; addi t6, sp, 0x20
; .byte 0x07, 0x87, 0x0f, 0x02
; .byte 0x57, 0x70, 0x02, 0xcd
; .byte 0x57, 0x57, 0x95, 0xb2
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0x27, 0x07, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret

function %fma_splat_y_f64x2(f64x2, f64, f64x2) -> f64x2 {
block0(v0: f64x2, v1: f64, v2: f64x2):
v3 = splat.f64x2 v1
v4 = fma v0, v3, v2
return v4
}

; VCode:
; addi sp,sp,-16
; sd ra,8(sp)
; sd fp,0(sp)
; mv fp,sp
; block0:
; vle8.v v9,-32(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
; vle8.v v14,-16(incoming_arg) #avl=16, #vtype=(e8, m1, ta, ma)
; vfmacc.vf v14,v9,fa0 #avl=2, #vtype=(e64, m1, ta, ma)
; vse8.v v14,0(a0) #avl=16, #vtype=(e8, m1, ta, ma)
; ld ra,8(sp)
; ld fp,0(sp)
; addi sp,sp,16
; ret
;
; Disassembled:
; block0: ; offset 0x0
; addi sp, sp, -0x10
; sd ra, 8(sp)
; sd s0, 0(sp)
; mv s0, sp
; block1: ; offset 0x10
; .byte 0x57, 0x70, 0x08, 0xcc
; addi t6, sp, 0x10
; .byte 0x87, 0x84, 0x0f, 0x02
; addi t6, sp, 0x20
; .byte 0x07, 0x87, 0x0f, 0x02
; .byte 0x57, 0x70, 0x81, 0xcd
; .byte 0x57, 0x57, 0x95, 0xb2
; .byte 0x57, 0x70, 0x08, 0xcc
; .byte 0x27, 0x07, 0x05, 0x02
; ld ra, 8(sp)
; ld s0, 0(sp)
; addi sp, sp, 0x10
; ret

0 comments on commit 895a5ac

Please sign in to comment.