cpu: rv64: matmul: fix column-major weights transpose bug #4445

Merged
zhangjian29 merged 1 commit into uxlfoundation:main from zhangjian29:fix-matmul-col-major-bug
Dec 11, 2025

@zhangjian29
Contributor

Description

This PR fixes the bug reported here. Currently, rvv_matmul fails when the weights are in column-major layout. The mistake comes from PR #4363: rvv_matmul.hpp accepts column-major weights, but rvv_matmul.cpp treats them as row-major.
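
To illustrate why this mismatch produces wrong results, here is a minimal sketch, not taken from the oneDNN sources. It assumes a K x N weights matrix where tag ab means row-major and tag ba means column-major. The names offset_ab, offset_ba, wtag, and matvec are hypothetical helpers introduced only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helpers for illustration only; not oneDNN code.

// Offset of logical element (k, n) in a K x N matrix stored row-major (tag "ab").
inline std::size_t offset_ab(std::size_t k, std::size_t n, std::size_t K, std::size_t N) {
    assert(k < K && n < N);
    return k * N + n;
}

// Offset of the same logical element when stored column-major (tag "ba").
inline std::size_t offset_ba(std::size_t k, std::size_t n, std::size_t K, std::size_t N) {
    assert(k < K && n < N);
    return n * K + k;
}

enum class wtag { ab, ba };

// Computes dst[n] = sum_k src[k] * W(k, n) for a 1 x K source,
// interpreting the weights buffer according to `tag`.
std::vector<float> matvec(const std::vector<float> &src, const std::vector<float> &wei,
        std::size_t K, std::size_t N, wtag tag) {
    assert(src.size() == K && wei.size() == K * N);
    std::vector<float> dst(N, 0.0f);
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t k = 0; k < K; ++k) {
            const std::size_t off = (tag == wtag::ab) ? offset_ab(k, n, K, N)
                                                      : offset_ba(k, n, K, N);
            dst[n] += src[k] * wei[off];
        }
    }
    return dst;
}
```

With the logical weights {{1, 2, 3}, {4, 5, 6}} stored column-major as {1, 4, 2, 5, 3, 6} and a source of {1, 1}, reading the buffer with wtag::ba gives {5, 7, 9}, while reading the same buffer as wtag::ab gives {6, 7, 8}. This is the kind of silent mismatch the benchdnn output below reports.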

Bug Reproducing

  • Currently, any matmul test with --wtag=ba fails with wrong results.
[RiscvTest@paas-controller-0-0 build_rvv]$ LD_LIBRARY_PATH=$(pwd)/src:$LD_LIBRARY_PATH taskset -c 32 ./tests/benchdnn/benchdnn --matmul --wtag=ba 1x30:30x20
[   0][DST][0:0] exp_f32:        1906 exp:        1906 got:       -9785 diff:   11691 rdiff: 6.13379
[   1][DST][0:1] exp_f32:       -9992 exp:       -9992 got:       11236 diff:   21228 rdiff:  2.1245
[   2][DST][0:2] exp_f32:       17130 exp:       17130 got:       -6999 diff:   24129 rdiff: 1.40858
[   3][DST][0:3] exp_f32:       15640 exp:       15640 got:      -18716 diff:   34356 rdiff: 2.19668
[   4][DST][0:4] exp_f32:      -29999 exp:      -29999 got:        4629 diff:   34628 rdiff: 1.15431
[   5][DST][0:5] exp_f32:       15136 exp:       15136 got:       -5222 diff:   20358 rdiff: 1.34501
[   6][DST][0:6] exp_f32:      -23368 exp:      -23368 got:        2943 diff:   26311 rdiff: 1.12594
[   7][DST][0:7] exp_f32:        5931 exp:        5931 got:        3617 diff:    2314 rdiff:0.390153
[   8][DST][0:8] exp_f32:       11858 exp:       11858 got:        8105 diff:    3753 rdiff:0.316495
[   9][DST][0:9] exp_f32:         538 exp:         538 got:       18551 diff:   18013 rdiff: 33.4814
[COMPARE_STATS][DST]: trh=0 err_max_diff:   44271 err_max_rdiff: 33.4814 all_max_diff:   44271 all_max_rdiff: 33.4814
0:FAILED (errors:20 total:20) (2 ms) __REPRO: --matmul --wtag=ba 1x30:30x20
============================================================
= Implementation statistics (--summary=no-impl to disable) =
============================================================
| RISCV64GCV : 1 (100%)                                    |
============================================================
===========================================================
= Failed cases summary (--summary=no-failures to disable) =
===========================================================
0:FAILED (errors:20 total:20) (2 ms) __REPRO: --matmul --wtag=ba 1x30:30x20
============================
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:1 listed:0
total: 0.00s; create_pd: 0.00s (6%); create_prim: 0.00s (4%); fill: 0.00s (24%); execute: 0.00s (1%); compute_ref: 0.00s (1%); compare: 0.00s (20%);

Bug Fixing

We fix this bug by checking the weights layout and transposing the weights properly, so that the rvv_gemm_f32 kernel handles them correctly.
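
The general idea of checking the layout before calling a row-major kernel can be sketched as follows. This is an assumption-laden illustration, not the code merged in this PR: it repacks column-major weights into a temporary row-major buffer, whereas the actual change may handle the transpose differently (for example, through the kernel's own arguments). The names pack_ba_to_ab, gemm_row_major, and matmul_2d are hypothetical, and gemm_row_major is a plain scalar stand-in rather than rvv_gemm_f32.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: repack K x N weights from column-major ("ba")
// into row-major ("ab") storage.
std::vector<float> pack_ba_to_ab(const std::vector<float> &wei_ba, std::size_t K, std::size_t N) {
    assert(wei_ba.size() == K * N);
    std::vector<float> wei_ab(K * N);
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t k = 0; k < K; ++k)
            wei_ab[k * N + n] = wei_ba[n * K + k];
    return wei_ab;
}

// Scalar stand-in for a kernel that only understands row-major operands.
// Computes dst (M x N) = src (M x K) * wei (K x N).
void gemm_row_major(const float *src, const float *wei, float *dst,
        std::size_t M, std::size_t K, std::size_t N) {
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            float acc = 0.0f;
            for (std::size_t k = 0; k < K; ++k)
                acc += src[m * K + k] * wei[k * N + n];
            dst[m * N + n] = acc;
        }
    }
}

// Checks the weights layout first, then hands row-major data to the kernel.
std::vector<float> matmul_2d(const std::vector<float> &src, const std::vector<float> &wei,
        std::size_t M, std::size_t K, std::size_t N, bool weights_col_major) {
    assert(src.size() == M * K && wei.size() == K * N);
    std::vector<float> dst(M * N, 0.0f);
    if (weights_col_major) {
        // Column-major weights: transpose into a temporary row-major buffer.
        const std::vector<float> packed = pack_ba_to_ab(wei, K, N);
        gemm_row_major(src.data(), packed.data(), dst.data(), M, K, N);
    } else {
        gemm_row_major(src.data(), wei.data(), dst.data(), M, K, N);
    }
    return dst;
}
```

In this sketch both layouts produce the same destination for the same logical weights, which is the property the --wtag=ba,ab run below checks. Repacking costs an extra K x N buffer, so a production kernel would more likely read the weights with swapped strides instead.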

  • Now the tests pass with --matmul --wtag=ba,ab --batch=tests/benchdnn/inputs/matmul/shapes_2d_ci
[RiscvTest@paas-controller-0-0 build_rvv]$ LD_LIBRARY_PATH=$(pwd)/src:$LD_LIBRARY_PATH taskset -c 32 ./tests/benchdnn/benchdnn --matmul --wtag=ba,ab --batch=tests/benchdnn/inputs/matmul/shapes_2d_ci
0:PASSED (222 ms) __REPRO: --matmul --wtag=ba 2048x13:13x512_n"DLRM:0*1"
1:PASSED (223 ms) __REPRO: --matmul --wtag=ab 2048x13:13x512_n"DLRM:0*1"
2:PASSED (655 ms) __REPRO: --matmul --wtag=ba 2048x512:512x256_n"DLRM:1*2"
3:PASSED (659 ms) __REPRO: --matmul --wtag=ab 2048x512:512x256_n"DLRM:1*2"
4:PASSED (201 ms) __REPRO: --matmul --wtag=ba 2048x256:256x128_n"DLRM:2*1"
5:PASSED (193 ms) __REPRO: --matmul --wtag=ab 2048x256:256x128_n"DLRM:2*1"
6:PASSED (2349 ms) __REPRO: --matmul --wtag=ba 2048x479:479x1024_n"DLRM:3*1"
7:PASSED (2436 ms) __REPRO: --matmul --wtag=ab 2048x479:479x1024_n"DLRM:3*1"
8:PASSED (4927 ms) __REPRO: --matmul --wtag=ba 2048x1024:1024x1024_n"DLRM:4*1"
9:PASSED (5275 ms) __REPRO: --matmul --wtag=ab 2048x1024:1024x1024_n"DLRM:4*1"
10:PASSED (2488 ms) __REPRO: --matmul --wtag=ba 2048x1024:1024x512_n"DLRM:5*1"
11:PASSED (2503 ms) __REPRO: --matmul --wtag=ab 2048x1024:1024x512_n"DLRM:5*1"
12:PASSED (28 ms) __REPRO: --matmul --wtag=ba 2048x256:256x1_n"DLRM:7*1"
13:PASSED (25 ms) __REPRO: --matmul --wtag=ab 2048x256:256x1_n"DLRM:7*1"
14:PASSED (371 ms) __REPRO: --matmul --wtag=ba 2048x256:256x256_n"NCF:0*1"
15:PASSED (366 ms) __REPRO: --matmul --wtag=ab 2048x256:256x256_n"NCF:0*1"
16:PASSED (191 ms) __REPRO: --matmul --wtag=ba 2048x256:256x128_n"NCF:1*1"
17:PASSED (194 ms) __REPRO: --matmul --wtag=ab 2048x256:256x128_n"NCF:1*1"
18:PASSED (64 ms) __REPRO: --matmul --wtag=ba 2048x128:128x64_n"NCF:2*1"
19:PASSED (66 ms) __REPRO: --matmul --wtag=ab 2048x128:128x64_n"NCF:2*1"
20:PASSED (13 ms) __REPRO: --matmul --wtag=ba 2048x128:128x1_n"NCF:3*1"
21:PASSED (13 ms) __REPRO: --matmul --wtag=ab 2048x128:128x1_n"NCF:3*1"
22:PASSED (2440 ms) __REPRO: --matmul --wtag=ba 896x240:240x4096_n"RNN-T:Encoder_cell1_Input*2"
23:PASSED (2598 ms) __REPRO: --matmul --wtag=ab 896x240:240x4096_n"RNN-T:Encoder_cell1_Input*2"
24:PASSED (8711 ms) __REPRO: --matmul --wtag=ba 896x1024:1024x4096_n"RNN-T:Encoder_cell1_Hidden*11"
25:PASSED (10306 ms) __REPRO: --matmul --wtag=ab 896x1024:1024x4096_n"RNN-T:Encoder_cell1_Hidden*11"
26:PASSED (17482 ms) __REPRO: --matmul --wtag=ba 896x2048:2048x4096_n"RNN-T:Encoder_cell3_Input*1"
27:PASSED (20602 ms) __REPRO: --matmul --wtag=ab 896x2048:2048x4096_n"RNN-T:Encoder_cell3_Input*1"
28:PASSED (903 ms) __REPRO: --matmul --wtag=ba 896x320:320x1280_n"RNN-T:Prediction_Input*12"
29:PASSED (949 ms) __REPRO: --matmul --wtag=ab 896x320:320x1280_n"RNN-T:Prediction_Input*12"
30:PASSED (1296 ms) __REPRO: --matmul --wtag=ba 896x1344:1344x512_n"RNN-T:JointNet_Linear1*3"
31:PASSED (1374 ms) __REPRO: --matmul --wtag=ab 896x1344:1344x512_n"RNN-T:JointNet_Linear1*3"
32:PASSED (56 ms) __REPRO: --matmul --wtag=ba 896x512:512x29_n"RNN-T:JointNet_Linear2*3"
33:PASSED (67 ms) __REPRO: --matmul --wtag=ab 896x512:512x29_n"RNN-T:JointNet_Linear2*3"
============================================================
= Implementation statistics (--summary=no-impl to disable) =
============================================================
| RISCV64GCV : 34 (100%)                                   |
============================================================
tests:34 passed:34 skipped:0 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 90.28s; create_pd: 0.00s (0%); create_prim: 0.00s (0%); fill: 6.92s (8%); execute: 13.95s (15%); compute_ref: 65.55s (73%); compare: 3.23s (4%);

@zhangjian29 zhangjian29 requested a review from a team as a code owner December 10, 2025 09:01
@zhangjian29 zhangjian29 merged commit 914ba04 into uxlfoundation:main Dec 11, 2025
13 checks passed
@zhangjian29 zhangjian29 deleted the fix-matmul-col-major-bug branch December 11, 2025 01:16