@@ -316,7 +316,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
316
316
train_loader .sampler .set_epoch (epoch )
317
317
pbar = enumerate (train_loader )
318
318
LOGGER .info (('\n ' + '%10s' * 7 ) % ('Epoch' , 'gpu_mem' , 'box' , 'obj' , 'cls' , 'labels' , 'img_size' ))
319
- if RANK in [ - 1 , 0 ] :
319
+ if RANK in ( - 1 , 0 ) :
320
320
pbar = tqdm (pbar , total = nb , bar_format = '{l_bar}{bar:10}{r_bar}{bar:-10b}' ) # progress bar
321
321
optimizer .zero_grad ()
322
322
for i , (imgs , targets , paths , _ ) in pbar : # batch -------------------------------------------------------------
@@ -365,7 +365,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
365
365
last_opt_step = ni
366
366
367
367
# Log
368
- if RANK in [ - 1 , 0 ] :
368
+ if RANK in ( - 1 , 0 ) :
369
369
mloss = (mloss * i + loss_items ) / (i + 1 ) # update mean losses
370
370
mem = f'{ torch .cuda .memory_reserved () / 1E9 if torch .cuda .is_available () else 0 :.3g} G' # (GB)
371
371
pbar .set_description (('%10s' * 2 + '%10.4g' * 5 ) %
@@ -379,7 +379,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
379
379
lr = [x ['lr' ] for x in optimizer .param_groups ] # for loggers
380
380
scheduler .step ()
381
381
382
- if RANK in [ - 1 , 0 ] :
382
+ if RANK in ( - 1 , 0 ) :
383
383
# mAP
384
384
callbacks .run ('on_train_epoch_end' , epoch = epoch )
385
385
ema .update_attr (model , include = ['yaml' , 'nc' , 'hyp' , 'names' , 'stride' , 'class_weights' ])
@@ -440,7 +440,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
440
440
441
441
# end epoch ----------------------------------------------------------------------------------------------------
442
442
# end training -----------------------------------------------------------------------------------------------------
443
- if RANK in [ - 1 , 0 ] :
443
+ if RANK in ( - 1 , 0 ) :
444
444
LOGGER .info (f'\n { epoch - start_epoch + 1 } epochs completed in { (time .time () - t0 ) / 3600 :.3f} hours.' )
445
445
for f in last , best :
446
446
if f .exists ():
@@ -518,7 +518,7 @@ def parse_opt(known=False):
518
518
519
519
def main (opt , callbacks = Callbacks ()):
520
520
# Checks
521
- if RANK in [ - 1 , 0 ] :
521
+ if RANK in ( - 1 , 0 ) :
522
522
print_args (vars (opt ))
523
523
check_git_status ()
524
524
check_requirements (exclude = ['thop' ])
0 commit comments