Skip to content

Commit 5815414

Browse files
authored
Merge pull request #27 from Xreki/benchmark
Calculate the average time for benchmark.
2 parents e41decb + 12c78ee commit 5815414

File tree

2 files changed

+54
-13
lines changed

2 files changed

+54
-13
lines changed

ppgan/engine/trainer.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..models.builder import build_model
1212
from ..utils.visual import tensor2img, save_image
1313
from ..utils.filesystem import save, load, makedirs
14+
from ..utils.timer import TimeAverager
1415
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
1516

1617

@@ -61,30 +62,37 @@ def distributed_data_parallel(self):
6162
paddle.DataParallel(net, strategy))
6263

6364
def train(self):
65+
reader_cost_averager = TimeAverager()
66+
batch_cost_averager = TimeAverager()
6467

6568
for epoch in range(self.start_epoch, self.epochs):
6669
self.current_epoch = epoch
6770
start_time = step_start_time = time.time()
6871
for i, data in enumerate(self.train_dataloader):
69-
data_time = time.time()
72+
reader_cost_averager.record(time.time() - step_start_time)
73+
7074
self.batch_id = i
7175
# unpack data from dataset and apply preprocessing
7276
# data input should be dict
7377
self.model.set_input(data)
7478
self.model.optimize_parameters()
7579

76-
self.data_time = data_time - step_start_time
77-
self.step_time = time.time() - step_start_time
80+
batch_cost_averager.record(time.time() - step_start_time)
7881
if i % self.log_interval == 0:
82+
self.data_time = reader_cost_averager.get_average()
83+
self.step_time = batch_cost_averager.get_average()
7984
self.print_log()
8085

86+
reader_cost_averager.reset()
87+
batch_cost_averager.reset()
88+
8189
if i % self.visual_interval == 0:
8290
self.visual('visual_train')
8391

8492
step_start_time = time.time()
8593

86-
self.logger.info('train one epoch time: {}'.format(time.time() -
87-
start_time))
94+
self.logger.info(
95+
'train one epoch time: {}'.format(time.time() - start_time))
8896
if self.validate_interval > -1 and epoch % self.validate_interval:
8997
self.validate()
9098
self.model.lr_scheduler.step()
@@ -94,8 +102,8 @@ def train(self):
94102

95103
def validate(self):
96104
if not hasattr(self, 'val_dataloader'):
97-
self.val_dataloader = build_dataloader(self.cfg.dataset.val,
98-
is_train=False)
105+
self.val_dataloader = build_dataloader(
106+
self.cfg.dataset.val, is_train=False)
99107

100108
metric_result = {}
101109

@@ -141,8 +149,8 @@ def validate(self):
141149
self.visual('visual_val', visual_results=visual_results)
142150

143151
if i % self.log_interval == 0:
144-
self.logger.info('val iter: [%d/%d]' %
145-
(i, len(self.val_dataloader)))
152+
self.logger.info(
153+
'val iter: [%d/%d]' % (i, len(self.val_dataloader)))
146154

147155
for metric_name in metric_result.keys():
148156
metric_result[metric_name] /= len(self.val_dataloader.dataset)
@@ -152,8 +160,8 @@ def validate(self):
152160

153161
def test(self):
154162
if not hasattr(self, 'test_dataloader'):
155-
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
156-
is_train=False)
163+
self.test_dataloader = build_dataloader(
164+
self.cfg.dataset.test, is_train=False)
157165

158166
# data[0]: img, data[1]: img path index
159167
# test batch size must be 1
@@ -177,8 +185,8 @@ def test(self):
177185
self.visual('visual_test', visual_results=visual_results)
178186

179187
if i % self.log_interval == 0:
180-
self.logger.info('Test iter: [%d/%d]' %
181-
(i, len(self.test_dataloader)))
188+
self.logger.info(
189+
'Test iter: [%d/%d]' % (i, len(self.test_dataloader)))
182190

183191
def print_log(self):
184192
losses = self.model.get_current_losses()

ppgan/utils/timer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
17+
18+
class TimeAverager(object):
19+
def __init__(self):
20+
self.reset()
21+
22+
def reset(self):
23+
self._cnt = 0
24+
self._total_time = 0
25+
26+
def record(self, usetime):
27+
self._cnt += 1
28+
self._total_time += usetime
29+
30+
def get_average(self):
31+
if self._cnt == 0:
32+
return 0
33+
return self._total_time / self._cnt

0 commit comments

Comments
 (0)