-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 2】23、为 Paddle 新增 Softmax2D 组网API #40910
Changes from 8 commits
7c642fa
4a34616
03852ce
3a809f7
8afb8c6
294046c
548f9d1
48be37d
bd38188
59cab23
d0484ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个不需要 |
||
|
||
import unittest | ||
import numpy as np | ||
import paddle | ||
import paddle.fluid as fluid | ||
import paddle.fluid.core as core | ||
from test_softmax_op import ref_softmax | ||
|
||
paddle.enable_static() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个可以放到 def test_static_api(self) 中 |
||
np.random.seed(2022) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不需要固定随机种子 |
||
|
||
|
||
class TestSoftmax2DAPI(unittest.TestCase): | ||
# test paddle.nn.Softmax2D | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不必要的注释可以删去 |
||
def setUp(self): | ||
self.shape = [2, 6, 5, 4] | ||
self.dtype = 'float64' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 输入默认需要支持float32,可以增加一个float32的测试 |
||
self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') | ||
self.axis = -3 | ||
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ | ||
else paddle.CPUPlace() | ||
|
||
def test_static_api(self): | ||
with paddle.static.program_guard(paddle.static.Program()): | ||
x = paddle.fluid.data('X', self.x_np.shape, self.x_np.dtype) | ||
m = paddle.nn.Softmax2D() | ||
out = m(x) | ||
exe = paddle.static.Executor(self.place) | ||
res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) | ||
out_ref = ref_softmax(self.x_np, self.axis) | ||
self.assertTrue(np.allclose(out_ref, res)) | ||
|
||
def test_dygraph_api(self): | ||
paddle.disable_static(self.place) | ||
x = paddle.to_tensor(self.x_np) | ||
m = paddle.nn.Softmax2D() | ||
out = m(x) | ||
out_ref = ref_softmax(self.x_np, self.axis) | ||
self.assertTrue(np.allclose(out_ref, out.numpy())) | ||
paddle.enable_static() | ||
|
||
|
||
class TestSoftmax2DShape(TestSoftmax2DAPI): | ||
def setUp(self): | ||
self.shape = [2, 6, 4] | ||
self.dtype = 'float64' | ||
self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') | ||
self.axis = -3 | ||
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ | ||
else paddle.CPUPlace() | ||
|
||
|
||
class TestSoftmax2DCPU(TestSoftmax2DAPI): | ||
def setUp(self): | ||
self.shape = [2, 6, 4] | ||
self.dtype = 'float64' | ||
self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') | ||
self.axis = -3 | ||
self.place = paddle.CPUPlace() | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 增加GPU上单独测试、shape不是3或4维度的错误测试(可以使用assertRaises(error, func,input)) |
||
|
||
class TestSoftmax2DRepr(unittest.TestCase): | ||
def setUp(self): | ||
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ | ||
else paddle.CPUPlace() | ||
|
||
def test_extra_repr(self): | ||
paddle.disable_static(self.place) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里相应修改 |
||
m = paddle.nn.Softmax2D(name='test') | ||
self.assertTrue(m.extra_repr() == 'name=test') | ||
paddle.enable_static() | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
空格对齐