|
4 | 4 | import os |
5 | 5 | import random |
6 | 6 | import time |
| 7 | +import warnings |
7 | 8 | from copy import deepcopy |
8 | 9 | from pathlib import Path |
9 | 10 | from threading import Thread |
@@ -323,18 +324,19 @@ def train(hyp, opt, device, tb_writer=None): |
323 | 324 | mloss = (mloss * i + loss_items) / (i + 1) # update mean losses |
324 | 325 | mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) |
325 | 326 | s = ('%10s' * 2 + '%10.4g' * 6) % ( |
326 | | - '%g/%g' % (epoch, epochs - 1), mem, *mloss, targets.shape[0], imgs.shape[-1]) |
| 327 | + f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]) |
327 | 328 | pbar.set_description(s) |
328 | 329 |
|
329 | 330 | # Plot |
330 | 331 | if plots and ni < 3: |
331 | 332 | f = save_dir / f'train_batch{ni}.jpg' # filename |
332 | 333 | Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start() |
333 | | - if tb_writer: |
334 | | - tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # model graph |
335 | | - # tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch) |
| 334 | + if tb_writer and ni == 0: |
| 335 | + with warnings.catch_warnings(): |
| 336 | + warnings.simplefilter('ignore') # suppress jit trace warning |
| 337 | + tb_writer.add_graph(torch.jit.trace(de_parallel(model), imgs, strict=False), []) # graph |
336 | 338 | elif plots and ni == 10 and wandb_logger.wandb: |
337 | | - wandb_logger.log({"Mosaics": [wandb_logger.wandb.Image(str(x), caption=x.name) for x in |
| 339 | + wandb_logger.log({'Mosaics': [wandb_logger.wandb.Image(str(x), caption=x.name) for x in |
338 | 340 | save_dir.glob('train*.jpg') if x.exists()]}) |
339 | 341 |
|
340 | 342 | # end batch ------------------------------------------------------------------------------------------------ |
|
0 commit comments