[Feature EMA] update ema method (PaddlePaddle#3041)
Sunting78 authored Mar 20, 2023
1 parent 574be6f commit 1d1b7e0
Showing 4 changed files with 84 additions and 89 deletions.
50 changes: 48 additions & 2 deletions paddleseg/core/train.py
@@ -16,12 +16,14 @@
import time
from collections import deque
import shutil
from copy import deepcopy

import paddle
import paddle.nn.functional as F

from paddleseg.utils import (TimeAverager, calculate_eta, resume, logger,
                             worker_init_fn, train_profiler, op_flops_funs)
                             worker_init_fn, train_profiler, op_flops_funs,
                             init_ema_params, update_ema_model)
from paddleseg.core.val import evaluate


@@ -68,6 +70,7 @@ def train(model,
          log_iters=10,
          num_workers=0,
          use_vdl=False,
          use_ema=False,
          losses=None,
          keep_checkpoint_max=5,
          test_config=None,
@@ -102,6 +105,12 @@
        profiler_options (str, optional): The option of train profiler.
        to_static_training (bool, optional): Whether to use @to_static for training.
    """
    if use_ema:
        ema_model = deepcopy(model)
        ema_model.eval()
        for param in ema_model.parameters():
            param.stop_gradient = True

    model.train()
    nranks = paddle.distributed.ParallelEnv().nranks
    local_rank = paddle.distributed.ParallelEnv().local_rank
@@ -154,14 +163,16 @@
    avg_loss_list = []
    iters_per_epoch = len(batch_sampler)
    best_mean_iou = -1.0
    best_ema_mean_iou = -1.0
    best_model_iter = -1
    reader_cost_averager = TimeAverager()
    batch_cost_averager = TimeAverager()
    save_models = deque()
    batch_start = time.time()

    iter = start_iter
    while iter < iters:
        if iter == start_iter and use_ema:
            init_ema_params(ema_model, model)
        for data in loader:
            iter += 1
            if iter > iters:
@@ -290,6 +301,9 @@
                reader_cost_averager.reset()
                batch_cost_averager.reset()

            if use_ema:
                update_ema_model(ema_model, model, step=iter)

            if (iter % save_interval == 0 or
                    iter == iters) and (val_dataset is not None):
                num_workers = 1 if num_workers > 0 else 0
@@ -305,6 +319,15 @@
                    amp_level=amp_level,
                    **test_config)

                if use_ema:
                    ema_mean_iou, ema_acc, _, _, _ = evaluate(
                        ema_model,
                        val_dataset,
                        num_workers=num_workers,
                        precision=precision,
                        amp_level=amp_level,
                        **test_config)

                model.train()

            if (iter % save_interval == 0 or iter == iters) and local_rank == 0:
@@ -316,6 +339,12 @@
                            os.path.join(current_save_dir, 'model.pdparams'))
                paddle.save(optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))

                if use_ema:
                    paddle.save(
                        ema_model.state_dict(),
                        os.path.join(current_save_dir, 'ema_model.pdparams'))

                save_models.append(current_save_dir)
                if len(save_models) > keep_checkpoint_max > 0:
                    model_to_remove = save_models.popleft()
@@ -332,10 +361,27 @@
                    logger.info(
                        '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                        .format(best_mean_iou, best_model_iter))
                    if use_ema:
                        if ema_mean_iou > best_ema_mean_iou:
                            best_ema_mean_iou = ema_mean_iou
                            best_ema_model_iter = iter
                            best_ema_model_dir = os.path.join(save_dir,
                                                              "ema_best_model")
                            paddle.save(ema_model.state_dict(),
                                        os.path.join(best_ema_model_dir,
                                                     'ema_model.pdparams'))
                        logger.info(
                            '[EVAL] The EMA model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                            .format(best_ema_mean_iou, best_ema_model_iter))

                    if use_vdl:
                        log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
                        log_writer.add_scalar('Evaluate/Acc', acc, iter)
                        if use_ema:
                            log_writer.add_scalar('Evaluate/Ema_mIoU',
                                                  ema_mean_iou, iter)
                            log_writer.add_scalar('Evaluate/Ema_Acc', ema_acc,
                                                  iter)
            batch_start = time.time()

    # Calculate flops.
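Taken together, the train() changes follow a simple pattern: clone and freeze the model, sync the copy once at the first iteration, blend the raw weights into it after every optimizer step, then evaluate and save the copy alongside the raw model. A minimal sketch of the same wiring outside PaddleSeg's full loop (the toy network, optimizer, and random batches below are stand-ins for illustration, not part of this commit):

from copy import deepcopy

import paddle

from paddleseg.utils import init_ema_params, update_ema_model

# Stand-in network and optimizer; any paddle.nn.Layer works the same way.
model = paddle.nn.Sequential(paddle.nn.Conv2D(3, 8, 3), paddle.nn.ReLU())
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())

# Same setup as train(): a frozen deep copy holds the averaged weights.
ema_model = deepcopy(model)
ema_model.eval()
for param in ema_model.parameters():
    param.stop_gradient = True

init_ema_params(ema_model, model)  # one-time sync before the first update

for step in range(1, 101):
    x = paddle.rand([2, 3, 32, 32])  # stand-in batch
    loss = model(x).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    # Blend the fresh weights into the EMA copy after every optimizer step.
    update_ema_model(ema_model, model, step=step)

# ema_model now holds the averaged weights, ready to evaluate or paddle.save.
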
2 changes: 1 addition & 1 deletion paddleseg/utils/__init__.py
@@ -19,4 +19,4 @@
from .utils import *
from .timer import TimeAverager, calculate_eta
from . import visualize
from .ema import EMA
from .ema import *
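Replacing the single EMA export with a wildcard is what lets train.py pull init_ema_params and update_ema_model straight from paddleseg.utils, as the import change in the first file shows.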
116 changes: 30 additions & 86 deletions paddleseg/utils/ema.py
@@ -16,89 +16,33 @@
import paddle


class EMA(object):
    """
    The implementation of Exponential Moving Average for the trainable parameters.
    Args:
        model (nn.Layer): The model for applying EMA.
        decay (float, optional): Decay is used to calculate ema_variable by
            `ema_variable = decay * ema_variable + (1 - decay) * new_variable`.
            Default: 0.99.
    Returns:
        None
    Examples:
        .. code-block:: python
            # 1. Define model and dataset
            # 2. Create EMA
            ema = EMA(model, decay=0.99)
            # 3. Train stage
            for data in dataloader():
                ...
                optimizer.step()
                ema.step()
            # 4. Evaluate stage
            ema.apply() # Use the EMA data to replace the origin data
            for data in dataloader():
                ...
            ema.restore() # Restore the origin data to the model
    """

    def __init__(self, model, decay=0.99):
        super().__init__()

        assert isinstance(model, paddle.nn.Layer), \
            "The model should be the instance of paddle.nn.Layer."
        assert decay >= 0 and decay <= 1.0, \
            "The decay = {} should in [0.0, 1.0]".format(decay)

        self._model = model
        self._decay = decay
        self._ema_data = {}
        self._backup_data = {}

        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                self._ema_data[name] = param.numpy()

    def step(self):
        """
        Calculate the EMA data for all trainable parameters.
        """
        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                assert name in self._ema_data, \
                    "The param ({}) isn't in the model".format(name)
                self._ema_data[name] = self._decay * self._ema_data[name] \
                    + (1.0 - self._decay) * param.numpy()

    def apply(self):
        """
        Save the origin data and use the EMA data to replace the origin data.
        """
        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                assert name in self._ema_data, \
                    "The param ({}) isn't in the model".format(name)
                self._backup_data[name] = param.numpy()
                param.set_value(self._ema_data[name])

    def restore(self):
        """
        Restore the origin data to the model.
        """
        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                assert name in self._backup_data, \
                    "The param ({}) isn't in the model".format(name)
                param.set_value(self._backup_data[name])
        self._backup_data = {}
def judge_params_equal(ema_model, model):
    for ema_param, param in zip(ema_model.named_parameters(),
                                model.named_parameters()):
        if not paddle.equal_all(ema_param[1], param[1]):
            # print("Difference in", ema_param[0])
            return False
    return True


def init_ema_params(ema_model, model):
    state = {}
    msd = model.state_dict()
    for k, v in ema_model.state_dict().items():
        if paddle.is_floating_point(v):
            v = msd[k].detach()
        state[k] = v
    ema_model.set_state_dict(state)


def update_ema_model(ema_model, model, step=0, decay=0.999):
    with paddle.no_grad():
        state = {}
        decay = min(1 - 1 / (step + 1), decay)
        msd = model.state_dict()
        for k, v in ema_model.state_dict().items():
            if paddle.is_floating_point(v):
                v *= decay
                v += (1.0 - decay) * msd[k].detach()
            state[k] = v
        ema_model.set_state_dict(state)
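Note the decay warm-up in update_ema_model: the effective decay is min(1 - 1/(step + 1), decay), so the average tracks the raw weights closely early in training and only reaches the 0.999 cap around step 1000. Since init_ema_params copies the raw weights before the first update, the EMA model is usable from the first checkpoint onward. A quick check of the ramp:

# Effective decay at a few steps: 0 -> 0.0, 1 -> 0.5, 9 -> 0.9,
# 99 -> 0.99, 999 -> 0.999 (cap), and it stays at 0.999 from then on.
for step in [0, 1, 9, 99, 999, 10000]:
    print(step, min(1 - 1 / (step + 1), 0.999))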
5 changes: 5 additions & 0 deletions tools/train.py
@@ -53,6 +53,10 @@ def parse_args():
        '--use_vdl',
        help='Whether to record the data to VisualDL in training.',
        action='store_true')
    parser.add_argument(
        '--use_ema',
        help='Whether to use an exponential moving average (EMA) of the model weights during training.',
        action='store_true')

    # Runtime params
    parser.add_argument(
@@ -176,6 +180,7 @@ def main(args):
        log_iters=args.log_iters,
        num_workers=args.num_workers,
        use_vdl=args.use_vdl,
        use_ema=args.use_ema,
        losses=loss,
        keep_checkpoint_max=args.keep_checkpoint_max,
        test_config=cfg.test_config,
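With the flag threaded through to train(), EMA is switched on from the command line, e.g. python tools/train.py --config <your_config.yml> --use_ema --do_eval (the config path is whatever you already train with; --do_eval is PaddleSeg's existing flag that supplies the val_dataset the EMA evaluation needs). Each checkpoint directory then gains an ema_model.pdparams, and the best EMA weights land in save_dir/ema_best_model/.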
