Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix docs api #475

Merged
merged 11 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paconvert/api_alias_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"torch.nn.modules.Dropout": "torch.nn.Dropout",
"torch.nn.modules.GroupNorm": "torch.nn.GroupNorm",
"torch.nn.modules.LSTM": "torch.nn.LSTM",
"torch.nn.modules.Linear": "torch.nn.linear",
"torch.nn.modules.Linear": "torch.nn.Linear",
"torch.nn.modules.Module": "torch.nn.Module",
"torch.nn.modules.RNN": "torch.nn.RNN",
"torch.nn.modules.RNNBase": "torch.nn.RNNBase",
Expand All @@ -141,7 +141,7 @@
"torch.nn.modules.batchnorm.SyncBatchNorm": "torch.nn.SyncBatchNorm",
"torch.nn.modules.conv.Conv2d": "torch.nn.Conv2d",
"torch.nn.modules.distance.CosineSimilarity": "torch.nn.CosineSimilarity",
"torch.nn.modules.linear.Linear": "torch.nn.linear",
"torch.nn.modules.linear.Linear": "torch.nn.Linear",
"torch.nn.modules.module.Module": "torch.nn.Module",
"torch.nn.modules.pooling.AvgPool1d": "torch.nn.AvgPool1d",
"torch.nn.modules.pooling.AvgPool2d": "torch.nn.AvgPool2d",
Expand Down
117 changes: 115 additions & 2 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -4226,6 +4226,22 @@
"out"
]
},
"torch.amp.autocast": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.amp.auto_cast",
"args_list": [
"device_type",
"dtype",
"enabled",
"cache_enabled"
],
"kwargs_change": {
"device_type": "",
"dtype": "dtype",
"enabled": "enable",
"cache_enabled": ""
}
},
"torch.angle": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.angle",
Expand Down Expand Up @@ -6173,6 +6189,25 @@
"validate_args": ""
}
},
"torch.distributions.Binomial": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Binomial",
"args_list": [
"total_count",
"probs",
"logits",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
},
"unsupport_args": [
"logits"
],
"paddle_default_kwargs": {
"total_count": "1"
}
},
"torch.distributions.Categorical": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Categorical",
Expand Down Expand Up @@ -6215,6 +6250,22 @@
"cache_size": ""
}
},
"torch.distributions.ContinuousBernoulli": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.ContinuousBernoulli",
"args_list": [
"probs",
"logits",
"lims",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
},
"unsupport_args": [
"logits"
]
},
"torch.distributions.Dirichlet": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Dirichlet",
Expand Down Expand Up @@ -6306,6 +6357,17 @@
"cache_size": ""
}
},
"torch.distributions.Exponential": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Exponential",
"args_list": [
"rate",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.ExponentialFamily": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.ExponentialFamily",
Expand Down Expand Up @@ -6421,6 +6483,20 @@
"logits"
]
},
"torch.distributions.MultivariateNormal": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.MultivariateNormal",
"args_list": [
"loc",
"covariance_matrix",
"precision_matrix",
"scale_tril",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.Normal": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Normal",
Expand Down Expand Up @@ -11570,6 +11646,17 @@
"input": "x"
}
},
"torch.nn.functional.channel_shuffle": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.channel_shuffle",
"args_list": [
"input",
"groups"
],
"kwargs_change": {
"input": "x"
}
},
"torch.nn.functional.conv1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.conv1d",
Expand Down Expand Up @@ -14904,8 +14991,34 @@
"input": "x"
}
},
"torch.special.gammainc": {},
"torch.special.gammaincc": {},
"torch.special.gammainc": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.gammainc",
"args_list": [
"input",
"other",
"*",
"out"
],
"kwargs_change": {
"input": "x",
"other": "y"
}
},
"torch.special.gammaincc": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.gammaincc",
"args_list": [
"input",
"other",
"*",
"out"
],
"kwargs_change": {
"input": "x",
"other": "y"
}
},
"torch.special.gammaln": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.lgamma",
Expand Down
23 changes: 21 additions & 2 deletions tests/distributed/load_lib.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,28 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import ctypes
import os


def load_cusparse_library():
# 获取当前脚本文件的绝对路径
script_dir = os.path.dirname(os.path.abspath(__file__))
# 构建完整的库文件路径
cusparse_lib_path = os.path.join(script_dir, 'libcusparse.so.12')
cusparse_lib_path = os.path.join(script_dir, "libcusparse.so.12")

# 检查库文件是否存在
if not os.path.exists(cusparse_lib_path):
Expand All @@ -23,6 +40,7 @@ def load_cusparse_library():

return libcusparse


def main():
try:
libcusparse = load_cusparse_library()
Expand All @@ -33,5 +51,6 @@ def main():
# print(f"Error occurred: {e}")
pass


if __name__ == "__main__":
main()
main()
154 changes: 154 additions & 0 deletions tests/test_amp_autocast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import textwrap

import numpy as np
import paddle
import pytest
from apibase import APIBase


class AmpAutocastBase(APIBase):
def compare(
self,
name,
pytorch_result,
paddle_result,
check_value=True,
check_dtype=True,
check_stop_gradient=True,
rtol=1.0e-6,
atol=0.0,
):
(
pytorch_numpy,
paddle_numpy,
) = pytorch_result.float().cpu().detach().numpy(), paddle_result.astype(
"float32"
).numpy(
False
)
assert (
pytorch_numpy.shape == paddle_numpy.shape
), "API ({}): shape mismatch, torch shape is {}, paddle shape is {}".format(
name, pytorch_numpy.shape, paddle_numpy.shape
)
assert (
pytorch_numpy.dtype == paddle_numpy.dtype
), "API ({}): dtype mismatch, torch dtype is {}, paddle dtype is {}".format(
name, pytorch_numpy.dtype, paddle_numpy.dtype
)
if check_value:
assert np.allclose(
pytorch_numpy, paddle_numpy, rtol=rtol, atol=atol
), "API ({}): paddle result has diff with pytorch result".format(name)


obj = AmpAutocastBase("torch.amp.autocast")


@pytest.mark.skipif(
condition=not paddle.device.is_compiled_with_cuda(),
reason="can only run on paddle with CUDA",
)
def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
model = torch.nn.Linear(10, 5, device="cuda")
input = torch.randn(4, 10, device="cuda")
with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=False, cache_enabled=True):
result = model(input)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


@pytest.mark.skipif(
condition=not paddle.device.is_compiled_with_cuda(),
reason="can only run on paddle with CUDA",
)
def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
model = torch.nn.Linear(10, 5, device="cuda")
input = torch.randn(4, 10, device="cuda")
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=False, cache_enabled=True):
result = model(input)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


@pytest.mark.skipif(
condition=not paddle.device.is_compiled_with_cuda(),
reason="can only run on paddle with CUDA",
)
def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
model = torch.nn.Linear(10, 5, device="cuda")
input = torch.randn(4, 10, device="cuda")
with torch.amp.autocast(device_type="cuda"):
result = model(input)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


@pytest.mark.skipif(
condition=not paddle.device.is_compiled_with_cuda(),
reason="can only run on paddle with CUDA",
)
def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
model = torch.nn.Linear(10, 5, device="cuda")
input = torch.randn(4, 10, device="cuda")
with torch.amp.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda", cache_enabled=True):
result = model(input)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


@pytest.mark.skipif(
condition=not paddle.device.is_compiled_with_cuda(),
reason="can only run on paddle with CUDA",
)
def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
model = torch.nn.Linear(10, 5, device="cuda")
input = torch.randn(4, 10, device="cuda")
with torch.amp.autocast("cuda", torch.float16, False, True):
result = model(input)
"""
)
obj.run(pytorch_code, ["result"], check_value=False)
Loading