diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py
index 258eced184..462b00e640 100644
--- a/paddleseg/core/train.py
+++ b/paddleseg/core/train.py
@@ -16,12 +16,14 @@
 import time
 from collections import deque
 import shutil
+from copy import deepcopy

 import paddle
 import paddle.nn.functional as F

 from paddleseg.utils import (TimeAverager, calculate_eta, resume, logger,
-                             worker_init_fn, train_profiler, op_flops_funs)
+                             worker_init_fn, train_profiler, op_flops_funs,
+                             init_ema_params, update_ema_model)
 from paddleseg.core.val import evaluate
@@ -68,6 +70,7 @@ def train(model,
           log_iters=10,
           num_workers=0,
           use_vdl=False,
+          use_ema=False,
           losses=None,
           keep_checkpoint_max=5,
           test_config=None,
@@ -102,6 +105,12 @@ def train(model,
         profiler_options (str, optional): The option of train profiler.
         to_static_training (bool, optional): Whether to use @to_static for training.
     """
+    if use_ema:
+        ema_model = deepcopy(model)
+        ema_model.eval()
+        for param in ema_model.parameters():
+            param.stop_gradient = True
+
     model.train()
     nranks = paddle.distributed.ParallelEnv().nranks
     local_rank = paddle.distributed.ParallelEnv().local_rank
@@ -154,14 +163,16 @@ def train(model,
     avg_loss_list = []
     iters_per_epoch = len(batch_sampler)
     best_mean_iou = -1.0
+    best_ema_mean_iou = -1.0
     best_model_iter = -1
     reader_cost_averager = TimeAverager()
     batch_cost_averager = TimeAverager()
     save_models = deque()
     batch_start = time.time()
-
     iter = start_iter
     while iter < iters:
+        if iter == start_iter and use_ema:
+            init_ema_params(ema_model, model)
         for data in loader:
             iter += 1
             if iter > iters:
@@ -290,6 +301,9 @@ def train(model,
                 reader_cost_averager.reset()
                 batch_cost_averager.reset()

+            if use_ema:
+                update_ema_model(ema_model, model, step=iter)
+
             if (iter % save_interval == 0 or iter == iters) and (val_dataset is not None):
                 num_workers = 1 if num_workers > 0 else 0
@@ -305,6 +319,15 @@ def train(model,
                     amp_level=amp_level,
                     **test_config)

+                if use_ema:
+                    ema_mean_iou, ema_acc, _, _, _ = evaluate(
+                        ema_model,
+                        val_dataset,
+                        num_workers=num_workers,
+                        precision=precision,
+                        amp_level=amp_level,
+                        **test_config)
+
                 model.train()

             if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
@@ -316,6 +339,12 @@ def train(model,
                             os.path.join(current_save_dir, 'model.pdparams'))
                 paddle.save(optimizer.state_dict(),
                             os.path.join(current_save_dir, 'model.pdopt'))
+
+                if use_ema:
+                    paddle.save(
+                        ema_model.state_dict(),
+                        os.path.join(current_save_dir, 'ema_model.pdparams'))
+
                 save_models.append(current_save_dir)
                 if len(save_models) > keep_checkpoint_max > 0:
                     model_to_remove = save_models.popleft()
@@ -332,10 +361,27 @@ def train(model,
                     logger.info(
                         '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                         .format(best_mean_iou, best_model_iter))
+                    if use_ema:
+                        if ema_mean_iou > best_ema_mean_iou:
+                            best_ema_mean_iou = ema_mean_iou
+                            best_ema_model_iter = iter
+                            best_ema_model_dir = os.path.join(save_dir,
+                                                              "ema_best_model")
+                            paddle.save(ema_model.state_dict(),
+                                        os.path.join(best_ema_model_dir,
+                                                     'ema_model.pdparams'))
+                        logger.info(
+                            '[EVAL] The EMA model with the best validation mIoU ({:.4f}) was saved at iter {}.'
+                            .format(best_ema_mean_iou, best_ema_model_iter))

                     if use_vdl:
                         log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
                         log_writer.add_scalar('Evaluate/Acc', acc, iter)
+                        if use_ema:
+                            log_writer.add_scalar('Evaluate/Ema_mIoU',
+                                                  ema_mean_iou, iter)
+                            log_writer.add_scalar('Evaluate/Ema_Acc', ema_acc,
+                                                  iter)
             batch_start = time.time()

     # Calculate flops.
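Note for reviewers: the train.py changes above reduce to the pattern sketched below. This is a minimal illustrative sketch, not code from the patch; `loss_fn`, `train_loader`, `optimizer`, and `iters` are placeholder names, and logging, AMP, checkpointing, and evaluation are omitted.

    # Sketch of the EMA flow added to train.py (placeholder model/loader/loss).
    from copy import deepcopy

    import paddle
    from paddleseg.utils import init_ema_params, update_ema_model


    def train_with_ema(model, train_loader, optimizer, loss_fn, iters):
        # Frozen copy that accumulates the exponential moving average.
        ema_model = deepcopy(model)
        ema_model.eval()
        for p in ema_model.parameters():
            p.stop_gradient = True
        init_ema_params(ema_model, model)  # start the EMA from the current weights

        step = 0
        while step < iters:
            for images, labels in train_loader:
                step += 1
                if step > iters:
                    break
                loss = loss_fn(model(images), labels)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
                # Blend the freshly updated weights into the EMA copy.
                update_ema_model(ema_model, model, step=step)
        # Evaluate/save ema_model alongside the raw model, as train.py now does.
        return ema_model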
diff --git a/paddleseg/utils/__init__.py b/paddleseg/utils/__init__.py
index dc01765dcc..a88f009298 100644
--- a/paddleseg/utils/__init__.py
+++ b/paddleseg/utils/__init__.py
@@ -19,4 +19,4 @@
 from .utils import *
 from .timer import TimeAverager, calculate_eta
 from . import visualize
-from .ema import EMA
+from .ema import *
diff --git a/paddleseg/utils/ema.py b/paddleseg/utils/ema.py
index 861200c3f0..26636e1056 100644
--- a/paddleseg/utils/ema.py
+++ b/paddleseg/utils/ema.py
@@ -16,89 +16,33 @@
 import paddle


-class EMA(object):
-    """
-    The implementation of Exponential Moving Average for the trainable parameters.
-
-    Args:
-        model (nn.Layer): The model for applying EMA.
-        decay (float, optional): Decay is used to calculate ema_variable by
-            `ema_variable = decay * ema_variable + (1 - decay) * new_variable`.
-            Default: 0.99.
-
-    Returns:
-        None
-
-    Examples:
-        .. code-block:: python
-
-            # 1. Define model and dataset
-
-            # 2. Create EMA
-            ema = EMA(model, decay=0.99)
-
-            # 3. Train stage
-            for data in dataloader():
-                ...
-                optimizer.step()
-                ema.step()
-
-            # 4. Evaluate stage
-            ema.apply()  # Use the EMA data to replace the origin data
-
-            for data in dataloader():
-                ...
-
-            ema.restore()  # Restore the origin data to the model
-
-    """
-
-    def __init__(self, model, decay=0.99):
-        super().__init__()
-
-        assert isinstance(model, paddle.nn.Layer), \
-            "The model should be the instance of paddle.nn.Layer."
-        assert decay >= 0 and decay <= 1.0, \
-            "The decay = {} should in [0.0, 1.0]".format(decay)
-
-        self._model = model
-        self._decay = decay
-        self._ema_data = {}
-        self._backup_data = {}
-
-        for name, param in self._model.named_parameters():
-            if not param.stop_gradient:
-                self._ema_data[name] = param.numpy()
-
-    def step(self):
-        """
-        Calculate the EMA data for all trainable parameters.
-        """
-        for name, param in self._model.named_parameters():
-            if not param.stop_gradient:
-                assert name in self._ema_data, \
-                    "The param ({}) isn't in the model".format(name)
-                self._ema_data[name] = self._decay * self._ema_data[name] \
-                    + (1.0 - self._decay) * param.numpy()
-
-    def apply(self):
-        """
-        Save the origin data and use the EMA data to replace the origin data.
-        """
-        for name, param in self._model.named_parameters():
-            if not param.stop_gradient:
-                assert name in self._ema_data, \
-                    "The param ({}) isn't in the model".format(name)
-                self._backup_data[name] = param.numpy()
-                param.set_value(self._ema_data[name])
-
-    def restore(self):
-        """
-        Restore the origin data to the model.
- """ - for name, param in self._model.named_parameters(): - if not param.stop_gradient: - assert name in self._backup_data, \ - "The param ({}) isn't in the model".format(name) - param.set_value(self._backup_data[name]) - self._backup_data = {} +def judge_params_equal(ema_model, model): + for ema_param, param in zip(ema_model.named_parameters(), + model.named_parameters()): + if not paddle.equal_all(ema_param[1], param[1]): + # print("Difference in", ema_param[0]) + return False + return True + + +def init_ema_params(ema_model, model): + state = {} + msd = model.state_dict() + for k, v in ema_model.state_dict().items(): + if paddle.is_floating_point(v): + v = msd[k].detach() + state[k] = v + ema_model.set_state_dict(state) + + +def update_ema_model(ema_model, model, step=0, decay=0.999): + with paddle.no_grad(): + state = {} + decay = min(1 - 1 / (step + 1), decay) + msd = model.state_dict() + for k, v in ema_model.state_dict().items(): + if paddle.is_floating_point(v): + v *= decay + v += (1.0 - decay) * msd[k].detach() + state[k] = v + ema_model.set_state_dict(state) \ No newline at end of file diff --git a/tools/train.py b/tools/train.py index d12d6878ec..09d864499a 100644 --- a/tools/train.py +++ b/tools/train.py @@ -53,6 +53,10 @@ def parse_args(): '--use_vdl', help='Whether to record the data to VisualDL in training.', action='store_true') + parser.add_argument( + '--use_ema', + help='Whether to ema the model in training.', + action='store_true') # Runntime params parser.add_argument( @@ -176,6 +180,7 @@ def main(args): log_iters=args.log_iters, num_workers=args.num_workers, use_vdl=args.use_vdl, + use_ema=args.use_ema, losses=loss, keep_checkpoint_max=args.keep_checkpoint_max, test_config=cfg.test_config,