Skip to content

Commit e29d6d3

Browse files
Merge pull request #197 from ooooo-create/norm
[Accuracy diff No.28-29] Fix accuracy diff for paddle.linalg.norm API
2 parents c08a280 + 679a1f2 commit e29d6d3

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tester/paddle_to_torch/rules.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3872,7 +3872,7 @@ def apply(self, paddle_api: str) -> ConvertResult:
38723872
result = (x!= 0).sum(dim=axis, keepdim=True).to(x.dtype)
38733873
else:
38743874
result = (x!= 0).sum(dim=axis).to(x.dtype)
3875-
elif len(x.shape)>2 and axis is None:
3875+
elif len(x.shape)>=2 and axis is None:
38763876
if p==math.inf:
38773877
if keepdim:
38783878
result = x.abs().amax().reshape([1] * x.ndim)
@@ -3884,7 +3884,12 @@ def apply(self, paddle_api: str) -> ConvertResult:
38843884
else:
38853885
result = x.abs().amin()
38863886
else:
3887+
_kwargs["input"] = x.flatten()
3888+
if p == "fro":
3889+
_kwargs["ord"] = 2
38873890
result = {self.torch_api}(**_kwargs)
3891+
if keepdim:
3892+
result = result.reshape([1] * x.ndim)
38883893
else:
38893894
result = {self.torch_api}(**_kwargs)
38903895
"""

0 commit comments

Comments
 (0)