
Conversation

@lshpku
Contributor

@lshpku lshpku commented May 7, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

Align the call logic of paddle.linalg.solve with torch, improving performance by several hundred times and fixing errors on large Tensors.

Overview

linalg.solve solves the matrix equation A @ X = B, where A is [batch_size, n, n], X is [batch_size, n, nrhs], and B is [batch_size, n, nrhs] (batch_size may be absent or multi-dimensional).
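For intuition, a minimal usage sketch with illustrative shapes (the sizes here are arbitrary, not from the PR):

import numpy as np
import paddle

# A: [batch_size, n, n], B: [batch_size, n, nrhs] -- illustrative sizes only
A = paddle.to_tensor(np.random.normal(size=[4, 8, 8]).astype(np.float32))
B = paddle.to_tensor(np.random.normal(size=[4, 8, 3]).astype(np.float32))

X = paddle.linalg.solve(A, B)  # X: [batch_size, n, nrhs] == [4, 8, 3]
print(X.shape)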

Solution method

  1. First perform a PLU decomposition of A, obtaining P @ A = L @ U, where P is a permutation matrix (the actual implementation records it as a pivot array), L is a lower-triangular matrix with unit diagonal, and U is an upper-triangular matrix

  2. Rearrange P @ A = L @ U into the equivalent form A = P^T @ L @ U, then substitute this into the original equation A @ X = B, giving P^T @ L @ U @ X = B

  3. Move P^T to the right-hand side, giving L @ U @ X = P @ B; since P is a permutation matrix, this step only needs a scatter operation, not a real matrix multiplication

  4. Perform two triangular solves, moving L and then U to the right-hand side: L @ U @ X = P @ B ==solve 1=> U @ X = L^-1 @ P @ B ==solve 2=> X = U^-1 @ L^-1 @ P @ B

  5. X = U^-1 @ L^-1 @ P @ B is then the solution of the equation (see the sketch below)
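Steps 1-5 can be mirrored in NumPy/SciPy for a single (non-batched) matrix. This is a sketch for intuition only; note that scipy.linalg.lu returns A = P @ L @ U, so its P is the transpose of the P in the notation above:

import numpy as np
from scipy.linalg import lu, solve_triangular

rng = np.random.default_rng(0)
A = rng.normal(size=(16, 16))
B = rng.normal(size=(16, 4))

P, L, U = lu(A)            # step 1: A = P @ L @ U (scipy convention)
PB = P.T @ B               # step 3: permute B (in practice a scatter, not a matmul)
Y = solve_triangular(L, PB, lower=True, unit_diagonal=True)  # step 4, solve 1
X = solve_triangular(U, Y, lower=False)                      # step 4, solve 2

np.testing.assert_allclose(X, np.linalg.solve(A, B), rtol=1e-8, atol=1e-8)  # step 5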

Note: since cuBLAS stores matrices column-major while Paddle stores them row-major, for performance reasons A is not transposed on input (consistent with torch). Hence the A in the formulas above is actually A^T, and the equation actually solved is X = P^T @ (L^T)^-1 @ (U^T)^-1 @ B; the cuBLAS calls are largely the same, only the order differs.
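The transposed formulation can be checked the same way: factor A^T and run the two triangular solves with trans='T', which loosely mirrors how trsm would be invoked on the transposed factors. Again only a NumPy/SciPy sketch, not the actual kernel code:

import numpy as np
from scipy.linalg import lu, solve_triangular

rng = np.random.default_rng(0)
A = rng.normal(size=(16, 16))
B = rng.normal(size=(16, 4))

P, L, U = lu(A.T)          # factor A^T: A^T = P @ L @ U, hence A = U^T @ L^T @ P^T
Y = solve_triangular(U, B, trans='T', lower=False)                     # (U^T)^-1 @ B
Z = solve_triangular(L, Y, trans='T', lower=True, unit_diagonal=True)  # (L^T)^-1 @ ...
X = P @ Z                  # undo the permutation

np.testing.assert_allclose(X, np.linalg.solve(A, B), rtol=1e-8, atol=1e-8)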

Execution process

  1. PLU decomposition: call cublas getrfBatched
  2. Triangular solves: call cublas trsm/trsmBatched twice; this is the core change of this PR. The previous implementation used getrsBatched, which fuses trsm and the permutation into a single call; its performance is poor and it fails once the sizes grow even moderately large, so it had to be replaced
  3. Permutation: first call the custom UnpackPivot kernel to convert the pivot array into permutation indices, then call the phi scatter kernel to reorder (sketched below)
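A rough Python illustration of what the pivot-to-permutation conversion in step 3 does. The real UnpackPivot is a CUDA kernel; this only shows the semantics of LAPACK-style pivot arrays (getrfBatched returns 1-based pivots):

import numpy as np
from scipy.linalg import lu_factor

def unpack_pivot(ipiv):
    """Turn a 1-based LAPACK pivot array into a permutation index
    such that B[perm] applies the same row interchanges as P @ B."""
    n = len(ipiv)
    perm = np.arange(n)
    for i in range(n):
        j = ipiv[i] - 1                      # 1-based -> 0-based
        perm[i], perm[j] = perm[j], perm[i]  # replay the row swap
    return perm

A = np.random.default_rng(0).normal(size=(5, 5))
LU, piv = lu_factor(A)        # scipy's piv is 0-based, so add 1 for the sketch
perm = unpack_pivot(piv + 1)
L = np.tril(LU, -1) + np.eye(5)
U = np.triu(LU)
np.testing.assert_allclose(L @ U, A[perm], atol=1e-10)  # P @ A == L @ U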

Reference torch implementation: lu_solve_kernel

Alignment with torch

  1. When batch_size <= 16 and m <= 65535*8, or when batch_size <= 2, the result is fully aligned with torch (identical nsys traces)
  2. When batch_size > 16, torch calls non-cublas routines for the PLU decomposition step, which is genuinely hard to align with; there is also some precision difference, since the pivot-selection strategy for P in PA=LU affects the factorization result, so alignment is only achievable at the atol=1e-3, rtol=1e-3 level
  3. When batch_size > 2 and m > 65535*8, torch 2.6.0 is actually wrong! The trace shows it calls trsmBatched in an unsupported configuration; see the following test
import paddle, torch, numpy as np

x = np.random.normal(size=[3, 16, 16]).astype(np.float32)
y = np.random.normal(size=[3, 16, 65535*8+1]).astype(np.float32)  # nrhs just past the 65535*8 threshold

a = paddle.linalg.solve(paddle.to_tensor(x), paddle.to_tensor(y))
b = torch.linalg.solve(torch.from_numpy(x).to('cuda'), torch.from_numpy(y).to('cuda')).cpu()
c = np.linalg.solve(x, y)  # CPU reference

np.testing.assert_allclose(a, c, atol=1e-4, rtol=1e-4)  # OK
np.testing.assert_allclose(b, c, atol=1e-4, rtol=1e-4)  # Fail!

Notes on int usage

Note that in many places in the kernel I did not convert int to int64_t, including for n and nrhs, because:

  1. cublas takes int arguments, so using a wider type on our side gains nothing
  2. n is the side length of the matrix, and a side length cannot exceed the int range (even n^2 cannot exceed int; otherwise the system would be too numerically unstable to solve)
  3. batch_size * n * nrhs, however, can exceed int, so I declared batch_size as int64_t; this is only so the multiplication needs no extra cast, and batch_size itself still stays within the int range (see the sketch after this list)
  4. The supported data-type range now matches the upper limit of the underlying libraries; my own kernels also support int64_t, and the rest is up to the libraries
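A quick sanity check of point 3 in plain Python, just to show the magnitudes involved (the sizes are taken from the test above plus a hypothetical larger batch):

INT32_MAX = 2**31 - 1                       # 2147483647

batch_size, n, nrhs = 3, 16, 65535 * 8 + 1  # sizes from the test above
print(batch_size * n * nrhs)                # 25165488 -- still fits in int32

batch_size = 512                            # hypothetical larger batch
print(batch_size * n * nrhs > INT32_MAX)    # True -- the product needs int64_t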

Pcard-85711

@paddle-bot

paddle-bot bot commented May 7, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@lshpku lshpku force-pushed the fix-linalg-solve branch 2 times, most recently from caa4e2d to ad9dc8f on May 8, 2025 07:03
@lshpku lshpku force-pushed the fix-linalg-solve branch from ad9dc8f to dd0fb78 on May 8, 2025 07:04
Contributor

@wanghuancoder wanghuancoder left a comment


LGTM

@lshpku lshpku merged commit 3ed6bc2 into PaddlePaddle:develop May 9, 2025
47 of 49 checks passed
GITD245 pushed a commit to GITD245/Paddle that referenced this pull request May 14, 2025
wanghuancoder pushed a commit to wanghuancoder/Paddle that referenced this pull request May 27, 2025
wanghuancoder added a commit that referenced this pull request Jun 3, 2025
* refine forrange (#72360)

* refine forrange

* refine forrange

* reduce support big tensor (#71970)

* reduce support big tensor

* [PHI] Fix gridDim limit for reduce kernel (#72507)

* [API] isclose support bigtensor (#72516)

* isclose support bigtensor

* refine

* [API] isnan isinf isfinite support bigtensor (#72517)

* isnan isinf isfinite support bigtensor

* refine

* [PHI] Fix cum kernel for big tensor (#72562)

* [PHI] Preliminary fix for elementwise broadcast int32 shape overflow (#72584)

* [PHI] Align linalg.solve kernel with torch (#72608)

* Update strided copy kernel (#72662)

* [PHI] Fix grid sample kernel for big tensor (#72628)

* [PHI] Fix argsort big tensor bug (#72712)

* [PHI] Fixed argsort big tensor bug

* [PHI] Fixed shape mismatch problem.

* [PHI] Fix contiguous kernel for big tensor (#72705)

* [PHI] Fix flatten and split kernel for big tensor (#72634)

* [PHI] Fix out-of-bound issue of paddle.take_along_axis (#72757)

* [PHI] fix paddle.diag with big tensor (#72638)

* [API] fix paddle.cross with big tensor (#72652)

* [PHI] Fix paddle.where api for big tensor (#72717)

* [PHI] Fix bincount kernel for big tensor (#72706)

* fix bincount kernel for big tensor

* use HostAlloc to alloc memory

* add cpu test case

* [PHI] Fix full_like kernel for big tensor (#72831)

* [API] Fix int overflow and float16 support for paddle.frac (#72815)

* [PHI] Align paddle.inner with torch in matmul logic (#72843)

* [PHI] Fix paddle.var & paddle.std float16 overflow (#72650)

* [PHI] Fix logsumexp precision problem (#72681)

* [PHI] Debug for logsumexp, bug source found

* [PHI] Removed GetNumBlocks func to get correct logsumexp

* [PHI] Removed redundant debug VLOG

* [PHI] Elegant grid bounded solution

* [Accuracy diff No.55-56、76-77] Fix accuracy diff for var&std API (#72879)

* [Accuracy diff No.21] Fix accuracy diff for heaviside API (#72894)

---------

Co-authored-by: Shuhao Liang <[email protected]>
Co-authored-by: Qianyue He <[email protected]>
Co-authored-by: Lei Ding <[email protected]>
Co-authored-by: ggggxm <[email protected]>
Co-authored-by: xkkkkkk23 <[email protected]>
Co-authored-by: Zx <[email protected]>
Co-authored-by: huangjiyi <[email protected]>
Co-authored-by: ooo oo <[email protected]>
