-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mutiply bug allow non-tensor data input #27690
Conversation
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
代码我觉得没有问题,测试有一些代码我觉得不是很好,已评论,因为ci已经过了,我这边先approve,要这个PR改还是下一个PR改请自己决定。
@@ -42,66 +43,140 @@ def __run_static_graph_case(self, x_data, y_data, axis=-1): | |||
res = outs[0] | |||
return res | |||
|
|||
def __run_static_graph_case_with_numpy_input(self, x_data, y_data, axis=-1): | |||
with program_guard(Program(), Program()): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
program是静态图才有的,所以最好paddle.enable_static写在with program_guard前面
@@ -26,6 +26,7 @@ class TestMultiplyAPI(unittest.TestCase): | |||
|
|||
def __run_static_graph_case(self, x_data, y_data, axis=-1): | |||
with program_guard(Program(), Program()): | |||
paddle.enable_static() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
program是静态图才有的,所以最好paddle.enable_static写在with program_guard前面
res = tensor.multiply(x_data, y_data, axis=axis) | ||
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( | ||
) else fluid.CPUPlace() | ||
exe = fluid.Executor(place) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
将来可以尽量用2.0 迁移的写法,比如paddle.CPUPlace, paddle.CUDAPlace,paddle.static.Executor
# test static computation graph: 1-d array | ||
x_data = np.random.rand(200) | ||
y_data = np.random.rand(200) | ||
res = self.__run_static_graph_case(x_data, y_data) | ||
self.assertTrue(np.allclose(res, np.multiply(x_data, y_data))) | ||
|
||
# test static computation graph: 1-d array |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
像我以前说过,python里面用注释分隔测试并不是最好的方法,你可以直接创建一个测试method:
def test_static_multiply_1d(self) 这类的
Mutiply allows non-tensor data input
PR types
Bug fixes
PR changes
APIs
Describe