Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ opencv-python>=4.1.2
Pillow
PyYAML>=5.3.1
scipy>=1.4.1
tensorboard>=2.2
torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.41.0

# logging -------------------------------------
tensorboard>=2.4.1
# wandb

# plotting ------------------------------------
Expand Down
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import argparse
import logging
import math
Expand Down Expand Up @@ -34,7 +33,7 @@
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file
from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,7 +74,7 @@ def train(hyp, opt, device, tb_writer=None):
data_dict = wandb_logger.data_dict
if wandb_logger.wandb:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming

nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
Expand Down Expand Up @@ -405,7 +404,7 @@ def train(hyp, opt, device, tb_writer=None):
wandb_logger.log_model(
last.parent, opt, epoch, fi, best_model=best_fitness == fi)
del ckpt

# end epoch ----------------------------------------------------------------------------------------------------
# end training
if rank in [-1, 0]:
Expand Down Expand Up @@ -534,7 +533,8 @@ def train(hyp, opt, device, tb_writer=None):
if not opt.evolve:
tb_writer = None # init loggers
if opt.global_rank in [-1, 0]:
logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/')
prefix = colorstr('tensorboard: ')
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
train(hyp, opt, device, tb_writer)

Expand Down