[hapi] support dygraph amp O2 (#36441)
* [hapi] support dygraph amp O2

* fix problem of static pure fp16 in hapi

* fix bug

* fix format

* fix ut

* follow comments

* update ut

* update amp save/load

* fix ut

* refine code format
zhiqiu authored Oct 22, 2021
1 parent 6580ad1 commit 08248db
Showing 8 changed files with 208 additions and 63 deletions.
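
For orientation, here is a minimal sketch of the user-facing workflow this commit enables: pure float16 (level 'O2') training through the hapi Model API in dygraph mode. Only the prepare/amp_configs/level names come from the diff below; the network, dataset, and hyperparameters are illustrative, and a CUDA-enabled build is assumed.

import paddle
import paddle.nn as nn
import paddle.vision.transforms as T

# Assumes a GPU build; O2 only takes effect on GPU.
paddle.set_device('gpu')

net = paddle.vision.models.LeNet()
model = paddle.Model(net)
optim = paddle.optimizer.Momentum(
    learning_rate=0.001, momentum=0.9, parameters=model.parameters())

# level='O2' requests pure fp16; dygraph support for it is what this commit adds.
model.prepare(optim, nn.CrossEntropyLoss(), paddle.metric.Accuracy(),
              amp_configs={'level': 'O2'})

train_data = paddle.vision.datasets.MNIST(mode='train', transform=T.ToTensor())
model.fit(train_data, epochs=1, batch_size=64)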
17 changes: 15 additions & 2 deletions paddle/fluid/framework/data_type_transform.cc
@@ -65,11 +65,24 @@ struct CastDataType {
void TransDataType(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const Tensor& in,
Tensor* out) {
PADDLE_ENFORCE_EQ(in.type(), kernel_type_for_var.data_type_,
platform::errors::InvalidArgument(
"The src dtype(%s) of input tensor and kernel_type(%s) "
"are not conststent.",
DataTypeToString(in.type()),
DataTypeToString(kernel_type_for_var.data_type_)));
auto dst_type = expected_kernel_type.data_type_;
TransDataType(in, dst_type, out);
}

void TransDataType(const Tensor& in,
const paddle::framework::proto::VarType::Type& type,
Tensor* out) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

out->Resize(in.dims());
auto src_type = kernel_type_for_var.data_type_;
auto dst_type = expected_kernel_type.data_type_;
auto src_type = in.type();
auto dst_type = type;
auto ctx = pool.Get(in.place());

switch (src_type) {
3 changes: 3 additions & 0 deletions paddle/fluid/framework/data_type_transform.h
@@ -32,6 +32,9 @@ using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
void TransDataType(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const Tensor& in,
Tensor* out);
void TransDataType(const Tensor& in,
const paddle::framework::proto::VarType::Type& type,
Tensor* out);

/**
* Transform complex gradient to real data type.
5 changes: 4 additions & 1 deletion paddle/fluid/imperative/tracer.h
@@ -108,7 +108,10 @@ class Tracer {

void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }

void SetAmpLevel(AmpLevel level) { amp_level_ = level; }
void SetAmpLevel(AmpLevel level) {
VLOG(4) << "set amp_level to " << static_cast<unsigned int>(level);
amp_level_ = level;
}

AmpLevel GetAmpLevel() const { return amp_level_; }

10 changes: 10 additions & 0 deletions paddle/fluid/pybind/pybind.cc
@@ -30,6 +30,7 @@ limitations under the License. */

#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
@@ -1116,6 +1117,15 @@ PYBIND11_MODULE(core_noavx, m) {
ostr << self;
return ostr.str();
})
.def("_as_type",
[](const LoDTensor &self,
paddle::framework::proto::VarType::Type type) {
LoDTensor dst;
if (self.IsInitialized() && self.numel() > 0) {
TransDataType(self, type, &dst);
}
return dst;
})
.def("_copy", [](const LoDTensor &self, const platform::Place &place) {
// follow fetch_op's implementation
LoDTensor dst;
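
A short sketch of how the new _as_type binding behaves from Python, mirroring the unit test added further down; the equality check is illustrative, not part of the commit.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core

t = fluid.create_lod_tensor(
    np.array([[1], [2], [3], [4]]).astype('int64'), [[1, 3]], fluid.CPUPlace())
fp32_t = t._as_type(core.VarDesc.VarType.FP32)

# The cast keeps the values but changes the element type.
assert np.array(fp32_t).dtype == np.float32
assert np.array_equal(np.array(fp32_t), np.array(t).astype('float32'))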
10 changes: 7 additions & 3 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -235,9 +235,9 @@ def amp_guard(enable=True,
print(conv.dtype) # FP32
"""
if not (level in ['O1', 'O2']):
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O1 or O2, O1 represent AMP train mode, O2 represent Pure fp16 train mode."
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16 train mode."
)

tracer = _dygraph_tracer()
@@ -256,10 +256,14 @@ def amp_guard(enable=True,
amp_level = AMP_LEVEL.O1
_white_list = WHITE_LIST
_black_list = BLACK_LIST
else:
elif level == 'O2':
amp_level = AMP_LEVEL.O2
_white_list = PURE_FP16_WHITE_LIST
_black_list = PURE_FP16_BLACK_LIST
elif level == 'O0':
amp_level = AMP_LEVEL.O0
_white_list = WHITE_LIST
_black_list = BLACK_LIST

if custom_white_list or custom_black_list:
_white_list, _black_list = _update_list(custom_white_list,
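
A minimal sketch of what the extended level argument means for dygraph auto_cast (GPU assumed; on CPU the guard effectively disables itself). The dtype expectations follow the amp_guard docstring above.

import paddle

conv = paddle.nn.Conv2D(4, 6, (3, 3))
x = paddle.rand([10, 4, 32, 32])

# level='O1': mixed precision, conv runs in fp16 because it is on the white list.
with paddle.amp.auto_cast(level='O1'):
    print(conv(x).dtype)    # paddle.float16

# enable=False behaves like the newly accepted level='O0': plain fp32 execution.
with paddle.amp.auto_cast(enable=False):
    print(conv(x).dtype)    # paddle.float32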
7 changes: 7 additions & 0 deletions python/paddle/fluid/tests/test_lod_tensor.py
@@ -149,6 +149,13 @@ def test_dlpack_support(self):
np.array(gtensor_from_dlpack),
np.array([[1], [2], [3], [4]]).astype('int')))

def test_as_type(self):
tensor = fluid.create_lod_tensor(
np.array([[1], [2], [3], [4]]).astype('int'), [[1, 3]],
fluid.CPUPlace())
fp32_tensor = tensor._as_type(core.VarDesc.VarType.FP32)
print(fp32_tensor)


if __name__ == '__main__':
unittest.main()
104 changes: 69 additions & 35 deletions python/paddle/hapi/model.py
@@ -278,7 +278,7 @@ def __init__(self, model):
self._amp_level = "O0"
self._amp_configs = {}
self._amp_custom_lists = {}
self._use_fp16_guard = True
self._use_fp16_guard = None

@property
def mode(self):
@@ -338,6 +338,7 @@ def _save(state, path):

_save(optim, optim_path)

# TODO: support save/load scaler state in static graph
def load(self, param_state_pairs, optim_state):
if self._executor is None:
executor = fluid.Executor(fluid.CPUPlace())._default_executor
@@ -455,10 +456,19 @@ def _run(self, inputs, labels=None):

feed = {}
input_names = [v.name for v in self._input_vars[self.mode]]
input_dtypes = [v.dtype for v in self._input_vars[self.mode]]

for idx, n in enumerate(input_names):
# train and test may take different arguments
if inputs[idx] is not None:
feed[n] = inputs[idx]
if self._amp_level == 'O2' and input_dtypes[
idx] == core.VarDesc.VarType.FP16:
if isinstance(feed[n], core.LoDTensor):
feed[n] = feed[n]._as_type(core.VarDesc.VarType.FP16)
elif isinstance(feed[n], numpy.ndarray):
feed[n] = feed[n].astype('float16')

if labels is not None:
for idx, v in enumerate(self._label_vars[self.mode]):
feed[v.name] = labels[idx]
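
In static-graph O2 mode the program expects fp16 inputs, so _run now casts the feed before running the executor. The conversion it applies amounts to the following sketch (the helper name is illustrative):

import numpy
from paddle.fluid import core

def _cast_feed_to_fp16(value):
    # LoDTensor feeds go through the new _as_type binding; numpy feeds use astype.
    if isinstance(value, core.LoDTensor):
        return value._as_type(core.VarDesc.VarType.FP16)
    elif isinstance(value, numpy.ndarray):
        return value.astype('float16')
    return value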
@@ -592,7 +602,6 @@ def _make_program(self, mode):
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
**self.
_amp_custom_lists) if self._amp_custom_lists else None

self.model._optimizer = paddle.static.amp.decorate(
self.model._optimizer,
amp_lists=amp_lists,
@@ -702,26 +711,30 @@ def train_batch(self, inputs, labels=None, update=True):
labels = labels or []
labels = [to_variable(l) for l in to_list(labels)]

if self._amp_level != "O0":
scaler = paddle.amp.GradScaler(**self._amp_configs)
# scaler should be initialized only once
if self._amp_level != "O0" and self.model._scaler is None:
self.model._scaler = paddle.amp.GradScaler(**self._amp_configs)

with paddle.amp.auto_cast(
enable=self._amp_level != 'O0', **self._amp_custom_lists):
enable=self._amp_level != 'O0',
**self._amp_custom_lists,
level=self._amp_level):
if self._nranks > 1:
outputs = self.ddp_model.forward(
*[to_variable(x) for x in inputs])
else:
outputs = self.model.network.forward(
*[to_variable(x) for x in inputs])

losses = self.model._loss(*(to_list(outputs) + labels))
losses = to_list(losses)
final_loss = fluid.layers.sum(losses)
losses = self.model._loss(*(to_list(outputs) + labels))
losses = to_list(losses)
final_loss = fluid.layers.sum(losses)

if self._amp_level != "O0":
scaled = scaler.scale(final_loss)
scaled = self.model._scaler.scale(final_loss)
scaled.backward()
if update:
scaler.minimize(self.model._optimizer, scaled)
self.model._scaler.minimize(self.model._optimizer, scaled)
self.model.network.clear_gradients()
else:
final_loss.backward()
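
Outside of hapi, the dygraph O2 step that train_batch now performs looks roughly like this (a sketch with a placeholder network; GPU assumed):

import paddle

net = paddle.nn.Linear(10, 10)
opt = paddle.optimizer.Momentum(learning_rate=0.01, parameters=net.parameters())

# Cast parameters for pure fp16 and create the loss scaler once.
net, opt = paddle.amp.decorate(models=net, optimizers=opt, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.rand([4, 10])
with paddle.amp.auto_cast(level='O2'):
    loss = net(x).mean()

scaled = scaler.scale(loss)       # scale the loss to avoid fp16 underflow
scaled.backward()
scaler.minimize(opt, scaled)      # unscale gradients and apply the update
net.clear_gradients()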
@@ -804,17 +817,24 @@ def parameters(self, *args, **kwargs):
def save(self, path):
params = self.model.network.state_dict()
fluid.save_dygraph(params, path)
if self.model._optimizer is None:
return
if self.model._optimizer.state_dict():
optim = self.model._optimizer.state_dict()
fluid.save_dygraph(optim, path)

def load(self, param_state_pairs, optim_state):
if self.model._optimizer is not None:
if self.model._optimizer.state_dict():
optim = self.model._optimizer.state_dict()
fluid.save_dygraph(optim, path)
if hasattr(self.model, '_scaler') and self.model._scaler is not None:
if self.model._scaler.state_dict():
scaler = self.model._scaler.state_dict()
paddle.save(scaler, path + '.pdscaler')

def load(self, param_state_pairs, optim_state, scaler_state=None):
# restore parameter states
for param, state in param_state_pairs:
param.set_value(state)

if hasattr(self.model, '_scaler') and self.model._scaler is not None:
if scaler_state:
self.model._scaler.load_state_dict(scaler_state)

# restore optimizer states
if not self.model._optimizer or not optim_state:
return
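
The scaler round trip implemented here can be sketched on its own (paths illustrative; state_dict and load_state_dict are the GradScaler methods the adapter relies on):

import paddle

scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
# ... some scaled training steps ...
paddle.save(scaler.state_dict(), 'checkpoint/test.pdscaler')        # what save() writes

restored = paddle.amp.GradScaler()
restored.load_state_dict(paddle.load('checkpoint/test.pdscaler'))   # what load() restores

At the Model level this means save(path) now also emits path + '.pdscaler' next to the .pdparams and .pdopt files, and load picks it up in dygraph mode when the file exists.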
@@ -872,6 +892,16 @@ def load(self, param_state_pairs, optim_state):
else:
self.model._optimizer.set_state_dict(converted_state)

def prepare(self):
if self._amp_level == "O2" and self.model.mode == 'train' and core.is_compiled_with_cuda(
):
self.model.network, self.model._optimizer = paddle.amp.decorate(
models=self.model.network,
optimizers=self.model._optimizer,
level='O2')
if self._amp_level != "O0":
self.model._scaler = None
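
The effect of the new prepare() hook on the network can be observed directly; a sketch (GPU assumed) showing that decorate casts the parameters that the pure fp16 kernels will consume:

import paddle

net = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.Momentum(learning_rate=0.01, parameters=net.parameters())
print(net.weight.dtype)    # paddle.float32 before decoration

net, opt = paddle.amp.decorate(models=net, optimizers=opt, level='O2')
print(net.weight.dtype)    # paddle.float16 once decorated for pure fp16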


class Model(object):
"""
@@ -882,9 +912,9 @@ class Model(object):
instantiating a Model. The input description, i.e., paddle.static.InputSpec,
must be required for static graph.
When training on GPU, auto mixed precision (AMP) training is supported, and
pure float16 training is also supported in static mode while using Adam,
AdamW and Momentum optimizer. Before using pure float16 training,
When training on GPU, auto mixed precision (AMP O1) and pure float16
(AMP O2) training are both supported in static mode and dynamic mode.
In static graph mode, before training with pure float16 (AMP O2),
`multi_precision` could be set to True when creating optimizer, which can
avoid poor accuracy or slow convergence in a way, and inputs of dtype float
should be cast to float16 by users. `paddle.static.amp.fp16_guard` API
@@ -946,7 +976,8 @@ class Model(object):
2. An example using mixed precision training.
.. code-block:: python
# required: gpu
import paddle
import paddle.nn as nn
import paddle.vision.transforms as T
@@ -1331,7 +1362,18 @@ def _strip_postfix(path):

optim_state = None if reset_optimizer else _load_state_from_path(
path + ".pdopt")
return self._adapter.load(matched_param_state, optim_state)

# TODO: support save/load scaler state in static graph
if in_dygraph_mode():
scaler_state = None
if hasattr(self, '_scaler') and self._scaler is not None:
if os.path.exists(path + '.pdscaler'):
scaler_state = paddle.load(path + '.pdscaler')

return self._adapter.load(matched_param_state, optim_state,
scaler_state)
else:
return self._adapter.load(matched_param_state, optim_state)

def parameters(self, *args, **kwargs):
"""
@@ -1363,15 +1405,10 @@ def parameters(self, *args, **kwargs):
def _prepare_amp(self, amp_configs):
def _check_pure_fp16_configs():
# pure float16 training has some restrictions now
if self._adapter._amp_level == "O2":
if in_dygraph_mode():
warnings.warn(
"Pure float16 training is not supported in dygraph mode now, and it will be supported in future version."
)
else:
# grad clip is not supported in pure fp16 training now
assert self._optimizer._grad_clip is None, \
"Grad clip is not supported in pure float16 training now, and it will be supported in future version."
if self._adapter._amp_level == "O2" and self._optimizer._grad_clip:
# clip by value is not supported
assert isinstance(self._optimizer._grad_clip, (paddle.nn.ClipGradByGlobalNorm, paddle.nn.ClipGradByNorm)), \
"Only GradientClipByNorm and GradientClipByGlobalNorm are supported in amp training with level=O2 currently."

self._adapter._amp_custom_lists = {}
self._adapter._amp_configs = {}
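
The relaxed check above means gradient clipping now works with O2 as long as it is norm based. A sketch of an accepted configuration (model and hyperparameters are placeholders; a value-based clip would trip the assertion):

import paddle
import paddle.nn as nn

net = paddle.vision.models.LeNet()
model = paddle.Model(net)

clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)   # norm-based: accepted under O2
optim = paddle.optimizer.Momentum(
    learning_rate=0.001, parameters=model.parameters(), grad_clip=clip)
model.prepare(optim, nn.CrossEntropyLoss(), amp_configs={'level': 'O2'})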
@@ -1479,7 +1516,6 @@ def prepare(self, optimizer=None, loss=None, metrics=None,
Returns:
None
"""

self._place = _get_device()
if isinstance(self._place, fluid.CUDAPlace):
global _parallel_context_initialized
@@ -1515,8 +1551,7 @@ def prepare(self, optimizer=None, loss=None, metrics=None,
self._metrics = to_list(metrics)
self._prepare_amp(amp_configs)

if not in_dygraph_mode():
self._adapter.prepare()
self._adapter.prepare()

def fit(self,
train_data=None,
@@ -1667,7 +1702,6 @@ def fit(self,
epochs=2,
save_dir='mnist_checkpoint')
"""

assert train_data is not None, \
"train_data must be given!"

(The diff of the eighth changed file did not load and is not shown here.)
