Fix dropout static when axis != None (#37223) (#37589)
* fix dropout static when axis != None

* update dropout test

* add dropout test

* fix test

* Update test_dropout_op.py

* Update test_dropout_op.py

* fix testcase

* fix testcase

* Update test_dropout_op.py

* fix testcase

* fix testcase

* optimize perf

* add new test

* fix testcase
smallv0221 authored Nov 29, 2021
1 parent 7d9c669 commit 3a0c550
Showing 2 changed files with 22 additions and 4 deletions.
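For context, the bug being fixed shows up when paddle.nn.functional.dropout is given a non-None axis inside a static-graph program whose input has dynamic (-1) dimensions: the broadcastable mask shape was built from x.shape, which still holds -1 placeholders at graph-build time. The following is a minimal sketch of that scenario, mirroring the updated unit test and assuming the Paddle 2.x fluid static-graph API in use at the time of this commit; the variable names are illustrative and not taken from the diff.

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    # Both dimensions are unknown until feed time, as in the updated test.
    x = fluid.data(name="x", shape=[-1, -1], dtype="float32")
    # Dropout along axis=1: the mask must pick up the real runtime extent
    # of dimension 1, not the -1 placeholder from x.shape.
    out = paddle.nn.functional.dropout(
        x, p=0.5, axis=1, training=True, mode='upscale_in_train')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
result, = exe.run(main_prog,
                  feed={"x": np.ones([40, 40], dtype="float32")},
                  fetch_list=[out])
print(result.shape)  # (40, 40)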
python/paddle/fluid/tests/unittests/test_dropout_op.py (14 additions, 2 deletions)
@@ -333,7 +333,7 @@ def setUp(self):

     def check_static_result(self, place):
         with fluid.program_guard(fluid.Program(), fluid.Program()):
-            input = fluid.data(name="input", shape=[40, 40], dtype="float32")
+            input = fluid.data(name="input", shape=[-1, -1], dtype="float32")
             res1 = paddle.nn.functional.dropout(x=input, p=0., training=False)
             res2 = paddle.nn.functional.dropout(
                 x=input, p=0., axis=0, training=True, mode='upscale_in_train')
@@ -380,7 +380,10 @@ def check_static_result(self, place):
                 training=False,
                 mode='upscale_in_train')

-            in_np = np.random.random([40, 40]).astype("float32")
+            res13 = paddle.nn.functional.dropout(
+                x=input, p=0.7, axis=1, training=True, mode='upscale_in_train')
+
+            in_np = np.ones([40, 40]).astype("float32")
             res_np = in_np
             res_np2 = np.zeros_like(in_np)

@@ -398,6 +401,9 @@ def check_static_result(self, place):
                                feed={"input": in_np},
                                fetch_list=[res10])
             self.assertTrue(np.allclose(fetches2[0], res_np2))
+            fetches3 = exe.run(fluid.default_main_program(),
+                               feed={"input": in_np},
+                               fetch_list=[res13])

     def test_static(self):
         for place in self.places:
@@ -471,6 +477,12 @@ def test_dygraph(self):
                     axis=(0, 1),
                     training=False,
                     mode='upscale_in_train')
+                res13 = paddle.nn.functional.dropout(
+                    x=input,
+                    p=0.5,
+                    axis=1,
+                    training=True,
+                    mode='upscale_in_train')

                 res_list = [
                     res1, res2, res3, res4, res5, res6, res7, res8, res9, res11,
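As a side note on what the new res13 case exercises in dygraph mode: with axis=1 the dropout mask has shape [1, num_cols], so whole columns are either kept and scaled by 1/(1-p) under 'upscale_in_train' or zeroed together. A hypothetical sanity check, not part of the test file, that illustrates this behavior:

import numpy as np
import paddle

paddle.disable_static()
x = paddle.ones([4, 6], dtype="float32")
out = paddle.nn.functional.dropout(
    x, p=0.5, axis=1, training=True, mode='upscale_in_train').numpy()

# Each column is constant: all zeros (dropped) or all 1 / (1 - p) = 2.0 (kept).
for col in out.T:
    assert np.all(col == col[0])
    assert col[0] in (0.0, 2.0)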
python/paddle/nn/functional/common.py (8 additions, 2 deletions)
@@ -939,6 +939,8 @@ def get_attrs(prog, dropout_prob, is_test, seed):

             #get mask shape
             input_shape = x.shape
+            if not in_dygraph_mode():
+                input_shape_tensor = paddle.shape(x)
             drop_axes = [axis] if isinstance(axis, int) else list(axis)
             if min(drop_axes) < 0 or max(drop_axes) > len(input_shape) - 1:
                 raise ValueError("axis value should be greater than or equal to 0 and less than dimensions of x:{}, but get axis value:{} " \
@@ -948,8 +950,12 @@ def get_attrs(prog, dropout_prob, is_test, seed):
                 "length of axis should not be greater than dimensions of x:{}, but get length of axis: {}".
                 format(len(input_shape), len(drop_axes)))
             mask_shape = [1] * len(input_shape)
-            for i in drop_axes:
-                mask_shape[i] = input_shape[i]
+            if not in_dygraph_mode():
+                for i in drop_axes:
+                    mask_shape[i] = input_shape_tensor[i]
+            else:
+                for i in drop_axes:
+                    mask_shape[i] = input_shape[i]

             #get mask
             random_tensor = paddle.uniform(
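The crux of this change is where the broadcastable mask shape comes from: in dygraph mode x.shape already holds concrete integers, but in static mode it can contain -1 for dynamic dimensions, so the dropped-axis extents are instead read from the runtime paddle.shape(x) tensor, producing a shape list that mixes Python ints with shape-tensor elements. A simplified sketch of this mask-shape logic, with illustrative names rather than the exact implementation:

import paddle

def build_mask_shape(x, drop_axes, in_dygraph):
    # Start from a shape of all ones so the mask broadcasts over the
    # axes that are not being dropped.
    mask_shape = [1] * len(x.shape)
    if in_dygraph:
        # Eager mode: x.shape holds concrete integers.
        for i in drop_axes:
            mask_shape[i] = x.shape[i]
    else:
        # Static mode: x.shape may contain -1, so take the real extents
        # from the runtime shape tensor instead.
        shape_tensor = paddle.shape(x)
        for i in drop_axes:
            mask_shape[i] = shape_tensor[i]
    return mask_shape

# Example in eager mode: dropping along axis 1 of a [4, 6] tensor
# yields a mask shape of [1, 6].
print(build_mask_shape(paddle.ones([4, 6]), [1], in_dygraph=True))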
