Skip to content

Commit 8665d55

Browse files
authored
Threaded TensorBoard graph logging (#9070)
* Log TensorBoard graph on pretrain_routine_end * fix
1 parent 0b8639a commit 8665d55

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
219219
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
220220
model.half().float() # pre-reduce anchor precision
221221

222-
callbacks.run('on_pretrain_routine_end', labels, names, plots)
222+
callbacks.run('on_pretrain_routine_end', labels, names)
223223

224224
# DDP mode
225225
if cuda and RANK != -1:
@@ -328,7 +328,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
328328
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
329329
pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
330330
(f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
331-
callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots)
331+
callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
332332
if callbacks.stop_training:
333333
return
334334
# end batch ------------------------------------------------------------------------------------------------
@@ -420,7 +420,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
420420
if is_coco:
421421
callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
422422

423-
callbacks.run('on_train_end', last, best, plots, epoch, results)
423+
callbacks.run('on_train_end', last, best, epoch, results)
424424

425425
torch.cuda.empty_cache()
426426
return results

utils/loggers/__init__.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None,
4949
self.weights = weights
5050
self.opt = opt
5151
self.hyp = hyp
52+
self.plots = not opt.noplots # plot results
5253
self.logger = logger # for printing results to console
5354
self.include = include
5455
self.keys = [
@@ -110,26 +111,26 @@ def on_train_start(self):
110111
# Callback runs on train start
111112
pass
112113

113-
def on_pretrain_routine_end(self, labels, names, plots):
114+
def on_pretrain_routine_end(self, labels, names):
114115
# Callback runs on pre-train routine end
115-
if plots:
116+
if self.plots:
116117
plot_labels(labels, names, self.save_dir)
117-
paths = self.save_dir.glob('*labels*.jpg') # training labels
118-
if self.wandb:
119-
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
120-
# if self.clearml:
121-
# pass # ClearML saves these images automatically using hooks
118+
paths = self.save_dir.glob('*labels*.jpg') # training labels
119+
if self.wandb:
120+
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
121+
# if self.clearml:
122+
# pass # ClearML saves these images automatically using hooks
122123

123-
def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
124+
def on_train_batch_end(self, model, ni, imgs, targets, paths):
124125
# Callback runs on train batch end
125126
# ni: number integrated batches (since train start)
126-
if plots:
127-
if ni == 0 and not self.opt.sync_bn and self.tb:
128-
log_tensorboard_graph(self.tb, model, imgsz=list(imgs.shape[2:4]))
127+
if self.plots:
129128
if ni < 3:
130129
f = self.save_dir / f'train_batch{ni}.jpg' # filename
131130
plot_images(imgs, targets, paths, f)
132-
if (self.wandb or self.clearml) and ni == 10:
131+
if ni == 0 and self.tb and not self.opt.sync_bn:
132+
log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
133+
if ni == 10 and (self.wandb or self.clearml):
133134
files = sorted(self.save_dir.glob('train*.jpg'))
134135
if self.wandb:
135136
self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]})
@@ -197,9 +198,9 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
197198
model_name='Latest Model',
198199
auto_delete_file=False)
199200

200-
def on_train_end(self, last, best, plots, epoch, results):
201+
def on_train_end(self, last, best, epoch, results):
201202
# Callback runs on training end, i.e. saving best model
202-
if plots:
203+
if self.plots:
203204
plot_results(file=self.save_dir / 'results.csv') # save results.png
204205
files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
205206
files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
@@ -291,6 +292,7 @@ def log_model(self, model_path, epoch=0, metadata={}):
291292
wandb.log_artifact(art)
292293

293294

295+
@threaded
294296
def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
295297
# Log model graph to TensorBoard
296298
try:
@@ -300,5 +302,5 @@ def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
300302
with warnings.catch_warnings():
301303
warnings.simplefilter('ignore') # suppress jit trace warning
302304
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
303-
except Exception:
304-
print('WARNING: TensorBoard graph visualization failure')
305+
except Exception as e:
306+
print(f'WARNING: TensorBoard graph visualization failure {e}')

0 commit comments

Comments
 (0)