[AMP] Check call order of paddle.amp.decorate and paddle.DataParallel (#38785)

* check amp.decorate and DataParallel

* refine coverage

* fix layer dtype

* refine code
zhangbo9674 authored Jan 11, 2022
1 parent 9f34a07 commit fbb4028
Showing 3 changed files with 14 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -145,6 +145,10 @@ def check_models(models):
             raise RuntimeError(
                 "Current train mode is pure fp16, models should be paddle.nn.Layer, but receive {}.".
                 format(type(model)))
+        if isinstance(model, paddle.DataParallel):
+            raise RuntimeError(
+                "For distributed AMP training, you should first use paddle.amp.decorate() to decorate the original model, and then call paddle.DataParallel to get the distributed model."
+            )
 
 
 def check_optimizers(optimizers):
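For context, a minimal sketch of the call order this check enforces (decorate first, then wrap), assuming a distributed dygraph environment has already been initialized via paddle.distributed.init_parallel_env():

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
model = paddle.nn.Linear(10, 10)

# Correct order: decorate for pure-fp16 (O2) AMP first, ...
model = paddle.amp.decorate(models=model, level='O2')
model = paddle.DataParallel(model)  # ...then wrap for data parallelism.

# The reversed order now fails: decorating a DataParallel-wrapped model
# trips the isinstance(model, paddle.DataParallel) check added above.
# model = paddle.DataParallel(paddle.nn.Linear(10, 10))
# model = paddle.amp.decorate(models=model, level='O2')  # RuntimeError
```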
2 changes: 2 additions & 0 deletions python/paddle/fluid/dygraph/layers.py
@@ -1569,6 +1569,8 @@ def _apply(self, func, device, dtype, blocking, include_sublayers=True):
         for key, buf in self._buffers.items():
             self._buffers[key] = func(buf, device, dtype, blocking)
 
+        self._dtype = dtype
+
     def _to_impl(self,
                  device=None,
                  dtype=None,
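The layers.py change records the target dtype on the layer after _apply converts its parameters and buffers, so the layer's reported dtype stays in sync with its weights. A minimal sketch of the behavior this fixes, assuming the public Layer.to() entry point (which routes through _to_impl and _apply, per the diff above):

```python
import paddle

layer = paddle.nn.Linear(4, 4)
layer.to(dtype='float16')  # _apply() converts parameters and buffers

# Before this commit, the layer still reported its construction dtype;
# with the fix, the internal _dtype attribute tracks the converted weights.
print(layer._dtype)
```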
8 changes: 8 additions & 0 deletions
@@ -536,6 +536,14 @@ def __init__(self):

         self.assertRaises(TypeError, test_error_model)
 
+        def test_error_distributed_model():
+            model = fluid.dygraph.Conv2D(3, 2, 3, bias_attr=False, act=None)
+            model = paddle.DataParallel(model)
+            with fluid.dygraph.guard():
+                model = paddle.amp.decorate(models=model, level='O2')
+
+        self.assertRaises(RuntimeError, test_error_distributed_model)
+
         def test_error_optimizer():
             class MyOptimizer(object):
                 def __init__(self):
