Skip to content

Commit 4aa2959

Browse files
authored
Suppress jit trace warning + graph once (#3454)
* Suppress jit trace warning + graph once Suppress harmless jit trace warning on TensorBoard add_graph call. Also fix multiple add_graph() calls bug, now only on batch 0. * Update train.py
1 parent af2bc3a commit 4aa2959

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

train.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import random
66
import time
7+
import warnings
78
from copy import deepcopy
89
from pathlib import Path
910
from threading import Thread
@@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None):
323324
mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
324325
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
325326
s = ('%10s' * 2 + '%10.4g' * 6) % (
326-
'%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1])
327+
f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])
327328
pbar.set_description(s)
328329

329330
# Plot
330331
if plots and ni < 3:
331332
f = save_dir / f'train_batch{ni}.jpg' # filename
332333
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
333-
if tb_writer:
334-
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # model graph
335-
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
334+
if tb_writer and ni == 0:
335+
with warnings.catch_warnings():
336+
warnings.simplefilter('ignore') # suppress jit trace warning
337+
tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # graph
336338
elif plots and ni == 10 and wandb_logger.wandb:
337-
wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
339+
wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in
338340
save_dir.glob('train*.jpg') if x.exists()]})
339341

340342
# end batch ------------------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)