Skip to content

Commit

Permalink
cpu: x64: gemm: fix small N kernels A mask loads
Browse files Browse the repository at this point in the history
We need to zero out when loading in case NaNs are present in zmm
registers for A matrix.
  • Loading branch information
aaraujom committed Sep 26, 2020
1 parent 6cd0c35 commit 5ce95ef
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ struct xbyak_gemm_smalln_tn : public jit_generator {
int endval = (MROW < 5) ? MROW : 5;
for (int ii = 8; ii < 8 + endval; ii++) {
// Storing A values in zmm_reg[8-12]
vmovups(zmm_reg[ii] | (krem ? k_rem : k0), ptr[AO2]);
vmovups(zmm_reg[ii] | (krem ? k_rem : k0) | T_z, ptr[AO2]);
add(AO2, LDA);
}
for (int ii = 0; ii < endval; ii++) {
Expand All @@ -614,7 +614,7 @@ struct xbyak_gemm_smalln_tn : public jit_generator {
? 8
: MROW; // Do not process more than 8 rows here.
for (int ii = 0; ii < MROW2; ii++) {
vmovups(zmm_reg[ii] | (krem ? k_rem : k0), ptr[AO2]);
vmovups(zmm_reg[ii] | (krem ? k_rem : k0) | T_z, ptr[AO2]);
add(AO2, LDA);
}
for (int ii = 0; ii < MROW2; ii++) {
Expand All @@ -627,7 +627,7 @@ struct xbyak_gemm_smalln_tn : public jit_generator {
if (MROW > 8) {
vmovaps(zmm_reg[0], zmm_reg[15]);
for (int ii = 8; ii < MROW; ii++) {
vmovups(zmm_reg[ii] | (krem ? k_rem : k0), ptr[AO2]);
vmovups(zmm_reg[ii] | (krem ? k_rem : k0) | T_z, ptr[AO2]);
add(AO2, LDA);
}
for (int ii = 8; ii < MROW; ii++)
Expand Down

0 comments on commit 5ce95ef

Please sign in to comment.