Skip to content

Commit

Permalink
Merge pull request #27 from Xreki/benchmark
Browse files Browse the repository at this point in the history
Calculate the average time for benchmark.
  • Loading branch information
qingqing01 authored Sep 23, 2020
2 parents e41decb + 12c78ee commit 5815414
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 13 deletions.
34 changes: 21 additions & 13 deletions ppgan/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs
from ..utils.timer import TimeAverager
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim


Expand Down Expand Up @@ -61,30 +62,37 @@ def distributed_data_parallel(self):
paddle.DataParallel(net, strategy))

def train(self):
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()

for epoch in range(self.start_epoch, self.epochs):
self.current_epoch = epoch
start_time = step_start_time = time.time()
for i, data in enumerate(self.train_dataloader):
data_time = time.time()
reader_cost_averager.record(time.time() - step_start_time)

self.batch_id = i
# unpack data from dataset and apply preprocessing
# data input should be dict
self.model.set_input(data)
self.model.optimize_parameters()

self.data_time = data_time - step_start_time
self.step_time = time.time() - step_start_time
batch_cost_averager.record(time.time() - step_start_time)
if i % self.log_interval == 0:
self.data_time = reader_cost_averager.get_average()
self.step_time = batch_cost_averager.get_average()
self.print_log()

reader_cost_averager.reset()
batch_cost_averager.reset()

if i % self.visual_interval == 0:
self.visual('visual_train')

step_start_time = time.time()

self.logger.info('train one epoch time: {}'.format(time.time() -
start_time))
self.logger.info(
'train one epoch time: {}'.format(time.time() - start_time))
if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate()
self.model.lr_scheduler.step()
Expand All @@ -94,8 +102,8 @@ def train(self):

def validate(self):
if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(self.cfg.dataset.val,
is_train=False)
self.val_dataloader = build_dataloader(
self.cfg.dataset.val, is_train=False)

metric_result = {}

Expand Down Expand Up @@ -141,8 +149,8 @@ def validate(self):
self.visual('visual_val', visual_results=visual_results)

if i % self.log_interval == 0:
self.logger.info('val iter: [%d/%d]' %
(i, len(self.val_dataloader)))
self.logger.info(
'val iter: [%d/%d]' % (i, len(self.val_dataloader)))

for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset)
Expand All @@ -152,8 +160,8 @@ def validate(self):

def test(self):
if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False)
self.test_dataloader = build_dataloader(
self.cfg.dataset.test, is_train=False)

# data[0]: img, data[1]: img path index
# test batch size must be 1
Expand All @@ -177,8 +185,8 @@ def test(self):
self.visual('visual_test', visual_results=visual_results)

if i % self.log_interval == 0:
self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader)))
self.logger.info(
'Test iter: [%d/%d]' % (i, len(self.test_dataloader)))

def print_log(self):
losses = self.model.get_current_losses()
Expand Down
33 changes: 33 additions & 0 deletions ppgan/utils/timer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time


class TimeAverager(object):
def __init__(self):
self.reset()

def reset(self):
self._cnt = 0
self._total_time = 0

def record(self, usetime):
self._cnt += 1
self._total_time += usetime

def get_average(self):
if self._cnt == 0:
return 0
return self._total_time / self._cnt

0 comments on commit 5815414

Please sign in to comment.