From 08248db0789d22227589cd19767664bc6b6b25b6 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 22 Oct 2021 15:06:46 +0800 Subject: [PATCH] [hapi] support dygraph amp O2 (#36441) * [hapi] support dygrapg amp O2 * fix problem of static pure fp16 in hapi * fix bug * fix format * fix ut * follow comments * update ut * update amp save/load * fix ut * refine code format --- paddle/fluid/framework/data_type_transform.cc | 17 ++- paddle/fluid/framework/data_type_transform.h | 3 + paddle/fluid/imperative/tracer.h | 5 +- paddle/fluid/pybind/pybind.cc | 10 ++ python/paddle/fluid/dygraph/amp/auto_cast.py | 10 +- python/paddle/fluid/tests/test_lod_tensor.py | 7 ++ python/paddle/hapi/model.py | 104 ++++++++++------ python/paddle/tests/test_hapi_amp.py | 115 ++++++++++++++---- 8 files changed, 208 insertions(+), 63 deletions(-) diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index 888687c06ce90..faff846cf2a60 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -65,11 +65,24 @@ struct CastDataType { void TransDataType(const OpKernelType& kernel_type_for_var, const OpKernelType& expected_kernel_type, const Tensor& in, Tensor* out) { + PADDLE_ENFORCE_EQ(in.type(), kernel_type_for_var.data_type_, + platform::errors::InvalidArgument( + "The src dtype(%s) of input tensor and kernel_type(%s) " + "are not conststent.", + DataTypeToString(in.type()), + DataTypeToString(kernel_type_for_var.data_type_))); + auto dst_type = expected_kernel_type.data_type_; + TransDataType(in, dst_type, out); +} + +void TransDataType(const Tensor& in, + const paddle::framework::proto::VarType::Type& type, + Tensor* out) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); out->Resize(in.dims()); - auto src_type = kernel_type_for_var.data_type_; - auto dst_type = expected_kernel_type.data_type_; + auto src_type = in.type(); + auto dst_type = type; auto ctx = pool.Get(in.place()); switch (src_type) { diff --git a/paddle/fluid/framework/data_type_transform.h b/paddle/fluid/framework/data_type_transform.h index 499b133dadb17..678764430f0ff 100644 --- a/paddle/fluid/framework/data_type_transform.h +++ b/paddle/fluid/framework/data_type_transform.h @@ -32,6 +32,9 @@ using KernelTypePair = std::pair; void TransDataType(const OpKernelType& kernel_type_for_var, const OpKernelType& expected_kernel_type, const Tensor& in, Tensor* out); +void TransDataType(const Tensor& in, + const paddle::framework::proto::VarType::Type& type, + Tensor* out); /** * Transform complex gradient to real data type. diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 418b2069b5bb6..93f68f2054b9a 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -108,7 +108,10 @@ class Tracer { void SetHasGrad(bool has_grad) { has_grad_ = has_grad; } - void SetAmpLevel(AmpLevel level) { amp_level_ = level; } + void SetAmpLevel(AmpLevel level) { + VLOG(4) << "set amp_level to " << static_cast(level); + amp_level_ = level; + } AmpLevel GetAmpLevel() const { return amp_level_; } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 529e7c6dab8ce..b27c05d98a1c0 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -30,6 +30,7 @@ limitations under the License. */ #include "paddle/fluid/framework/custom_operator.h" #include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/data_type_transform.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor_cache.h" #include "paddle/fluid/framework/executor_gc_helper.h" @@ -1116,6 +1117,15 @@ PYBIND11_MODULE(core_noavx, m) { ostr << self; return ostr.str(); }) + .def("_as_type", + [](const LoDTensor &self, + paddle::framework::proto::VarType::Type type) { + LoDTensor dst; + if (self.IsInitialized() && self.numel() > 0) { + TransDataType(self, type, &dst); + } + return dst; + }) .def("_copy", [](const LoDTensor &self, const platform::Place &place) { // follow fetch_op's inplementation LoDTensor dst; diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index c807303621aea..ddde3e66c56dc 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -235,9 +235,9 @@ def amp_guard(enable=True, print(conv.dtype) # FP32 """ - if not (level in ['O1', 'O2']): + if not (level in ['O0', 'O1', 'O2']): raise ValueError( - "level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode." + "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16 train mode." ) tracer = _dygraph_tracer() @@ -256,10 +256,14 @@ def amp_guard(enable=True, amp_level = AMP_LEVEL.O1 _white_list = WHITE_LIST _black_list = BLACK_LIST - else: + elif level == 'O2': amp_level = AMP_LEVEL.O2 _white_list = PURE_FP16_WHITE_LIST _black_list = PURE_FP16_BLACK_LIST + elif level == 'O0': + amp_level = AMP_LEVEL.O0 + _white_list = WHITE_LIST + _black_list = BLACK_LIST if custom_white_list or custom_black_list: _white_list, _black_list = _update_list(custom_white_list, diff --git a/python/paddle/fluid/tests/test_lod_tensor.py b/python/paddle/fluid/tests/test_lod_tensor.py index 00bfb84602afd..e21224c909f58 100644 --- a/python/paddle/fluid/tests/test_lod_tensor.py +++ b/python/paddle/fluid/tests/test_lod_tensor.py @@ -149,6 +149,13 @@ def test_dlpack_support(self): np.array(gtensor_from_dlpack), np.array([[1], [2], [3], [4]]).astype('int'))) + def test_as_type(self): + tensor = fluid.create_lod_tensor( + np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]], + fluid.CPUPlace()) + fp32_tensor = tensor._as_type(core.VarDesc.VarType.FP32) + print(fp32_tensor) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index abc7aedbd8af7..15d5640b11fe5 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -278,7 +278,7 @@ def __init__(self, model): self._amp_level = "O0" self._amp_configs = {} self._amp_custom_lists = {} - self._use_fp16_guard = True + self._use_fp16_guard = None @property def mode(self): @@ -338,6 +338,7 @@ def _save(state, path): _save(optim, optim_path) + # TODO: support save/load scaler state in static graph def load(self, param_state_pairs, optim_state): if self._executor is None: executor = fluid.Executor(fluid.CPUPlace())._default_executor @@ -455,10 +456,19 @@ def _run(self, inputs, labels=None): feed = {} input_names = [v.name for v in self._input_vars[self.mode]] + input_dtypes = [v.dtype for v in self._input_vars[self.mode]] + for idx, n in enumerate(input_names): # train and test may take different arguments if inputs[idx] is not None: feed[n] = inputs[idx] + if self._amp_level == 'O2' and input_dtypes[ + idx] == core.VarDesc.VarType.FP16: + if isinstance(feed[n], core.LoDTensor): + feed[n] = feed[n]._as_type(core.VarDesc.VarType.FP16) + elif isinstance(feed[n], numpy.array): + feed[n] = feed[n].astype('float16') + if labels is not None: for idx, v in enumerate(self._label_vars[self.mode]): feed[v.name] = labels[idx] @@ -592,7 +602,6 @@ def _make_program(self, mode): amp_lists = paddle.static.amp.AutoMixedPrecisionLists( **self. _amp_custom_lists) if self._amp_custom_lists else None - self.model._optimizer = paddle.static.amp.decorate( self.model._optimizer, amp_lists=amp_lists, @@ -702,10 +711,14 @@ def train_batch(self, inputs, labels=None, update=True): labels = labels or [] labels = [to_variable(l) for l in to_list(labels)] - if self._amp_level != "O0": - scaler = paddle.amp.GradScaler(**self._amp_configs) + # scaler should be initialized only once + if self._amp_level != "O0" and self.model._scaler is None: + self.model._scaler = paddle.amp.GradScaler(**self._amp_configs) + with paddle.amp.auto_cast( - enable=self._amp_level != 'O0', **self._amp_custom_lists): + enable=self._amp_level != 'O0', + **self._amp_custom_lists, + level=self._amp_level): if self._nranks > 1: outputs = self.ddp_model.forward( *[to_variable(x) for x in inputs]) @@ -713,15 +726,15 @@ def train_batch(self, inputs, labels=None, update=True): outputs = self.model.network.forward( *[to_variable(x) for x in inputs]) - losses = self.model._loss(*(to_list(outputs) + labels)) - losses = to_list(losses) - final_loss = fluid.layers.sum(losses) + losses = self.model._loss(*(to_list(outputs) + labels)) + losses = to_list(losses) + final_loss = fluid.layers.sum(losses) if self._amp_level != "O0": - scaled = scaler.scale(final_loss) + scaled = self.model._scaler.scale(final_loss) scaled.backward() if update: - scaler.minimize(self.model._optimizer, scaled) + self.model._scaler.minimize(self.model._optimizer, scaled) self.model.network.clear_gradients() else: final_loss.backward() @@ -804,17 +817,24 @@ def parameters(self, *args, **kwargs): def save(self, path): params = self.model.network.state_dict() fluid.save_dygraph(params, path) - if self.model._optimizer is None: - return - if self.model._optimizer.state_dict(): - optim = self.model._optimizer.state_dict() - fluid.save_dygraph(optim, path) - - def load(self, param_state_pairs, optim_state): + if self.model._optimizer is not None: + if self.model._optimizer.state_dict(): + optim = self.model._optimizer.state_dict() + fluid.save_dygraph(optim, path) + if hasattr(self.model, '_scaler') and self.model._scaler is not None: + if self.model._scaler.state_dict(): + scaler = self.model._scaler.state_dict() + paddle.save(scaler, path + '.pdscaler') + + def load(self, param_state_pairs, optim_state, scaler_state=None): # restore parameter states for param, state in param_state_pairs: param.set_value(state) + if hasattr(self.model, '_scaler') and self.model._scaler is not None: + if scaler_state: + self.model._scaler.load_state_dict(scaler_state) + # resotre optimizer states if not self.model._optimizer or not optim_state: return @@ -872,6 +892,16 @@ def load(self, param_state_pairs, optim_state): else: self.model._optimizer.set_state_dict(converted_state) + def prepare(self): + if self._amp_level == "O2" and self.model.mode == 'train' and core.is_compiled_with_cuda( + ): + self.model.network, self.model._optimizer = paddle.amp.decorate( + models=self.model.network, + optimizers=self.model._optimizer, + level='O2') + if self._amp_level != "O0": + self.model._scaler = None + class Model(object): """ @@ -882,9 +912,9 @@ class Model(object): instantiating a Model. The input description, i.e, paddle.static.InputSpec, must be required for static graph. - When training on GPU, auto mixed precision (AMP) training is supported, and - pure float16 training is also supported in static mode while using Adam, - AdamW and Momentum optimizer. Before using pure float16 training, + When training on GPU, auto mixed precision (AMP O1) and pure float16 + (AMP O2) training are both supported in static mode and dynamic mode. + In static graph mode, before traing with pure float16 (AMP O2), `multi_precision` could be set to True when creating optimizer, which can avoid poor accuracy or slow convergence in a way, and inputs of dtype float should be cast to float16 by users. `paddle.static.amp.fp16_guard` API @@ -946,7 +976,8 @@ class Model(object): 2. An example using mixed precision training. .. code-block:: python - + + # required: gpu import paddle import paddle.nn as nn import paddle.vision.transforms as T @@ -1331,7 +1362,18 @@ def _strip_postfix(path): optim_state = None if reset_optimizer else _load_state_from_path( path + ".pdopt") - return self._adapter.load(matched_param_state, optim_state) + + # TODO: support save/load scaler state in static graph + if in_dygraph_mode(): + scaler_state = None + if hasattr(self, '_scaler') and self._scaler is not None: + if os.path.exists(path + '.pdscaler'): + scaler_state = paddle.load(path + '.pdscaler') + + return self._adapter.load(matched_param_state, optim_state, + scaler_state) + else: + return self._adapter.load(matched_param_state, optim_state) def parameters(self, *args, **kwargs): """ @@ -1363,15 +1405,10 @@ def parameters(self, *args, **kwargs): def _prepare_amp(self, amp_configs): def _check_pure_fp16_configs(): # pure float16 training has some restricts now - if self._adapter._amp_level == "O2": - if in_dygraph_mode(): - warnings.warn( - "Pure float16 training is not supported in dygraph mode now, and it will be supported in future version." - ) - else: - # grad clip is not supported in pure fp16 training now - assert self._optimizer._grad_clip is None, \ - "Grad clip is not supported in pure float16 training now, and it will be supported in future version." + if self._adapter._amp_level == "O2" and self._optimizer._grad_clip: + # clip by value is not supported + assert isinstance(self._optimizer._grad_clip, (paddle.nn.ClipGradByGlobalNorm, paddle.nn.ClipGradByNorm)), \ + "Only GradientClipByNorm and GradientClipByGlobalNorm are supported in amp training with level=O2 currently." self._adapter._amp_custom_lists = {} self._adapter._amp_configs = {} @@ -1479,7 +1516,6 @@ def prepare(self, optimizer=None, loss=None, metrics=None, Returns: None """ - self._place = _get_device() if isinstance(self._place, fluid.CUDAPlace): global _parallel_context_initialized @@ -1515,8 +1551,7 @@ def prepare(self, optimizer=None, loss=None, metrics=None, self._metrics = to_list(metrics) self._prepare_amp(amp_configs) - if not in_dygraph_mode(): - self._adapter.prepare() + self._adapter.prepare() def fit(self, train_data=None, @@ -1667,7 +1702,6 @@ def fit(self, epochs=2, save_dir='mnist_checkpoint') """ - assert train_data is not None, \ "train_data must be given!" diff --git a/python/paddle/tests/test_hapi_amp.py b/python/paddle/tests/test_hapi_amp.py index ecab4db7516d7..d17b6f3594713 100644 --- a/python/paddle/tests/test_hapi_amp.py +++ b/python/paddle/tests/test_hapi_amp.py @@ -15,6 +15,9 @@ from __future__ import division from __future__ import print_function +import os +os.environ['FLAGS_cudnn_deterministic'] = '1' + import unittest import numpy as np @@ -26,34 +29,102 @@ from paddle.static import InputSpec from paddle.nn.layer.loss import CrossEntropyLoss from paddle.vision.models import LeNet +from paddle.vision.datasets import MNIST +import paddle.vision.transforms as T @unittest.skipIf(not fluid.is_compiled_with_cuda(), 'CPU testing is not supported') -class TestDistTraningUsingAMP(unittest.TestCase): - def test_amp_training(self): - if not fluid.is_compiled_with_cuda(): - self.skipTest('module not tested when ONLY_CPU compling') - data = np.random.random(size=(4, 1, 28, 28)).astype(np.float32) - label = np.random.randint(0, 10, size=(4, 1)).astype(np.int64) - amp_level = "O1" +class TestHapiWithAmp(unittest.TestCase): + def get_model(self, amp_config): + net = LeNet() + inputs = InputSpec([None, 1, 28, 28], "float32", 'x') + labels = InputSpec([None, 1], "int64", "y") + model = Model(net, inputs, labels) + optim = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + model.prepare( + optimizer=optim, + loss=CrossEntropyLoss(reduction="sum"), + amp_configs=amp_config) + return model + + def run_model(self, model): + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = MNIST(mode='train', transform=transform) + model.fit(train_dataset, + epochs=1, + batch_size=64, + num_iters=2, + log_freq=1) + + def run_amp(self, amp_level): for dynamic in [True, False]: - if not fluid.is_compiled_with_cuda(): - self.skipTest('module not tested when ONLY_CPU compling') - paddle.enable_static() if not dynamic else None + if not dynamic and amp_level['level'] == 'O2': + amp_level['use_fp16_guard'] = False + print('dynamic' if dynamic else 'static', amp_level) + + paddle.seed(2021) + paddle.enable_static() if not dynamic else paddle.disable_static() paddle.set_device('gpu') - net = LeNet() - inputs = InputSpec([None, 1, 28, 28], "float32", 'x') - labels = InputSpec([None, 1], "int64", "y") - model = Model(net, inputs, labels) - optim = paddle.optimizer.Adam( - learning_rate=0.001, parameters=model.parameters()) - amp_configs = {"level": amp_level} - model.prepare( - optimizer=optim, - loss=CrossEntropyLoss(reduction="sum"), - amp_configs=amp_configs) - model.train_batch([data], [label]) + model = self.get_model(amp_level) + self.run_model(model) + + def test_pure_fp16(self): + amp_config = { + "level": "O2", + "init_loss_scaling": 128, + } + self.run_amp(amp_config) + + def test_amp(self): + amp_config = {"level": "O1", "init_loss_scaling": 128} + self.run_amp(amp_config) + + def test_fp32(self): + amp_config = {"level": "O0", } + self.run_amp(amp_config) + + def test_save_load(self): + paddle.disable_static() + paddle.set_device('gpu') + amp_level = {"level": "O1", "init_loss_scaling": 128} + paddle.seed(2021) + model = self.get_model(amp_level) + transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) + train_dataset = MNIST(mode='train', transform=transform) + model.fit(train_dataset, + epochs=1, + batch_size=64, + num_iters=2, + log_freq=1) + model.save('./lenet_amp') + + with paddle.fluid.unique_name.guard(): + paddle.seed(2021) + new_model = self.get_model(amp_level) + train_dataset = MNIST(mode='train', transform=transform) + new_model.fit(train_dataset, + epochs=1, + batch_size=64, + num_iters=1, + log_freq=1) + # not equal before load + self.assertNotEqual(new_model._scaler.state_dict()['incr_count'], + model._scaler.state_dict()['incr_count']) + print((new_model._scaler.state_dict()['incr_count'], + model._scaler.state_dict()['incr_count'])) + + # equal after load + new_model.load('./lenet_amp') + self.assertEqual(new_model._scaler.state_dict()['incr_count'], + model._scaler.state_dict()['incr_count']) + self.assertEqual(new_model._scaler.state_dict()['decr_count'], + model._scaler.state_dict()['decr_count']) + self.assertTrue( + np.array_equal(new_model._optimizer.state_dict( + )['conv2d_1.w_0_moment1_0'].numpy( + ), model._optimizer.state_dict()['conv2d_1.w_0_moment1_0'].numpy())) def test_dynamic_check_input(self): paddle.disable_static()