@@ -49,6 +49,7 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None,
4949        self .weights  =  weights 
5050        self .opt  =  opt 
5151        self .hyp  =  hyp 
52+         self .plots  =  not  opt .noplots   # plot results 
5253        self .logger  =  logger   # for printing results to console 
5354        self .include  =  include 
5455        self .keys  =  [
@@ -110,26 +111,26 @@ def on_train_start(self):
110111        # Callback runs on train start 
111112        pass 
112113
113-     def  on_pretrain_routine_end (self , labels , names ,  plots ):
114+     def  on_pretrain_routine_end (self , labels , names ):
114115        # Callback runs on pre-train routine end 
115-         if  plots :
116+         if  self . plots :
116117            plot_labels (labels , names , self .save_dir )
117-         paths  =  self .save_dir .glob ('*labels*.jpg' )  # training labels 
118-         if  self .wandb :
119-             self .wandb .log ({"Labels" : [wandb .Image (str (x ), caption = x .name ) for  x  in  paths ]})
120-         # if self.clearml: 
121-         #    pass  # ClearML saves these images automatically using hooks 
118+              paths  =  self .save_dir .glob ('*labels*.jpg' )  # training labels 
119+              if  self .wandb :
120+                  self .wandb .log ({"Labels" : [wandb .Image (str (x ), caption = x .name ) for  x  in  paths ]})
121+              # if self.clearml: 
122+              #    pass  # ClearML saves these images automatically using hooks 
122123
123-     def  on_train_batch_end (self , ni ,  model , imgs , targets , paths ,  plots ):
124+     def  on_train_batch_end (self , model ,  ni , imgs , targets , paths ):
124125        # Callback runs on train batch end 
125126        # ni: number integrated batches (since train start) 
126-         if  plots :
127-             if  ni  ==  0  and  not  self .opt .sync_bn  and  self .tb :
128-                 log_tensorboard_graph (self .tb , model , imgsz = list (imgs .shape [2 :4 ]))
127+         if  self .plots :
129128            if  ni  <  3 :
130129                f  =  self .save_dir  /  f'train_batch{ ni }    # filename 
131130                plot_images (imgs , targets , paths , f )
132-             if  (self .wandb  or  self .clearml ) and  ni  ==  10 :
131+                 if  ni  ==  0  and  self .tb  and  not  self .opt .sync_bn :
132+                     log_tensorboard_graph (self .tb , model , imgsz = (self .opt .imgsz , self .opt .imgsz ))
133+             if  ni  ==  10  and  (self .wandb  or  self .clearml ):
133134                files  =  sorted (self .save_dir .glob ('train*.jpg' ))
134135                if  self .wandb :
135136                    self .wandb .log ({'Mosaics' : [wandb .Image (str (f ), caption = f .name ) for  f  in  files  if  f .exists ()]})
@@ -197,9 +198,9 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
197198                                                      model_name = 'Latest Model' ,
198199                                                      auto_delete_file = False )
199200
200-     def  on_train_end (self , last , best , plots ,  epoch , results ):
201+     def  on_train_end (self , last , best , epoch , results ):
201202        # Callback runs on training end, i.e. saving best model 
202-         if  plots :
203+         if  self . plots :
203204            plot_results (file = self .save_dir  /  'results.csv' )  # save results.png 
204205        files  =  ['results.png' , 'confusion_matrix.png' , * (f'{ x }   for  x  in  ('F1' , 'PR' , 'P' , 'R' ))]
205206        files  =  [(self .save_dir  /  f ) for  f  in  files  if  (self .save_dir  /  f ).exists ()]  # filter 
@@ -291,6 +292,7 @@ def log_model(self, model_path, epoch=0, metadata={}):
291292            wandb .log_artifact (art )
292293
293294
295+ @threaded  
294296def  log_tensorboard_graph (tb , model , imgsz = (640 , 640 )):
295297    # Log model graph to TensorBoard 
296298    try :
@@ -300,5 +302,5 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
300302        with  warnings .catch_warnings ():
301303            warnings .simplefilter ('ignore' )  # suppress jit trace warning 
302304            tb .add_graph (torch .jit .trace (de_parallel (model ), im , strict = False ), [])
303-     except  Exception :
304-         print ('WARNING: TensorBoard graph visualization failure' )
305+     except  Exception   as   e :
306+         print (f 'WARNING: TensorBoard graph visualization failure  { e } 
0 commit comments