1111from ..models .builder import build_model
1212from ..utils .visual import tensor2img , save_image
1313from ..utils .filesystem import save , load , makedirs
14+ from ..utils .timer import TimeAverager
1415from ..metric .psnr_ssim import calculate_psnr , calculate_ssim
1516
1617
@@ -61,30 +62,37 @@ def distributed_data_parallel(self):
6162 paddle .DataParallel (net , strategy ))
6263
6364 def train (self ):
65+ reader_cost_averager = TimeAverager ()
66+ batch_cost_averager = TimeAverager ()
6467
6568 for epoch in range (self .start_epoch , self .epochs ):
6669 self .current_epoch = epoch
6770 start_time = step_start_time = time .time ()
6871 for i , data in enumerate (self .train_dataloader ):
69- data_time = time .time ()
72+ reader_cost_averager .record (time .time () - step_start_time )
73+
7074 self .batch_id = i
7175 # unpack data from dataset and apply preprocessing
7276 # data input should be dict
7377 self .model .set_input (data )
7478 self .model .optimize_parameters ()
7579
76- self .data_time = data_time - step_start_time
77- self .step_time = time .time () - step_start_time
80+ batch_cost_averager .record (time .time () - step_start_time )
7881 if i % self .log_interval == 0 :
82+ self .data_time = reader_cost_averager .get_average ()
83+ self .step_time = batch_cost_averager .get_average ()
7984 self .print_log ()
8085
86+ reader_cost_averager .reset ()
87+ batch_cost_averager .reset ()
88+
8189 if i % self .visual_interval == 0 :
8290 self .visual ('visual_train' )
8391
8492 step_start_time = time .time ()
8593
86- self .logger .info ('train one epoch time: {}' . format ( time . time () -
87- start_time ))
94+ self .logger .info (
95+ 'train one epoch time: {}' . format ( time . time () - start_time ))
8896 if self .validate_interval > - 1 and epoch % self .validate_interval :
8997 self .validate ()
9098 self .model .lr_scheduler .step ()
@@ -94,8 +102,8 @@ def train(self):
94102
95103 def validate (self ):
96104 if not hasattr (self , 'val_dataloader' ):
97- self .val_dataloader = build_dataloader (self . cfg . dataset . val ,
98- is_train = False )
105+ self .val_dataloader = build_dataloader (
106+ self . cfg . dataset . val , is_train = False )
99107
100108 metric_result = {}
101109
@@ -141,8 +149,8 @@ def validate(self):
141149 self .visual ('visual_val' , visual_results = visual_results )
142150
143151 if i % self .log_interval == 0 :
144- self .logger .info ('val iter: [%d/%d]' %
145- (i , len (self .val_dataloader )))
152+ self .logger .info (
153+ 'val iter: [%d/%d]' % (i , len (self .val_dataloader )))
146154
147155 for metric_name in metric_result .keys ():
148156 metric_result [metric_name ] /= len (self .val_dataloader .dataset )
@@ -152,8 +160,8 @@ def validate(self):
152160
153161 def test (self ):
154162 if not hasattr (self , 'test_dataloader' ):
155- self .test_dataloader = build_dataloader (self . cfg . dataset . test ,
156- is_train = False )
163+ self .test_dataloader = build_dataloader (
164+ self . cfg . dataset . test , is_train = False )
157165
158166 # data[0]: img, data[1]: img path index
159167 # test batch size must be 1
@@ -177,8 +185,8 @@ def test(self):
177185 self .visual ('visual_test' , visual_results = visual_results )
178186
179187 if i % self .log_interval == 0 :
180- self .logger .info ('Test iter: [%d/%d]' %
181- (i , len (self .test_dataloader )))
188+ self .logger .info (
189+ 'Test iter: [%d/%d]' % (i , len (self .test_dataloader )))
182190
183191 def print_log (self ):
184192 losses = self .model .get_current_losses ()
0 commit comments