diff --git a/requirements.txt b/requirements.txt index cb50cf8f32e..51de7735d30 100755 --- a/requirements.txt +++ b/requirements.txt @@ -8,12 +8,12 @@ opencv-python>=4.1.2 Pillow PyYAML>=5.3.1 scipy>=1.4.1 -tensorboard>=2.2 torch>=1.7.0 torchvision>=0.8.1 tqdm>=4.41.0 # logging ------------------------------------- +tensorboard>=2.4.1 # wandb # plotting ------------------------------------ diff --git a/train.py b/train.py index fd2d6745ab4..b9e4eea613d 100644 --- a/train.py +++ b/train.py @@ -1,4 +1,3 @@ - import argparse import logging import math @@ -34,7 +33,7 @@ from utils.loss import ComputeLoss from utils.plots import plot_images, plot_labels, plot_results, plot_evolution from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel -from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id, check_wandb_config_file +from utils.wandb_logging.wandb_utils import WandbLogger, resume_and_get_id logger = logging.getLogger(__name__) @@ -75,7 +74,7 @@ def train(hyp, opt, device, tb_writer=None): data_dict = wandb_logger.data_dict if wandb_logger.wandb: weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming - + nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check @@ -405,7 +404,7 @@ def train(hyp, opt, device, tb_writer=None): wandb_logger.log_model( last.parent, opt, epoch, fi, best_model=best_fitness == fi) del ckpt - + # end epoch ---------------------------------------------------------------------------------------------------- # end training if rank in [-1, 0]: @@ -534,7 +533,8 @@ def train(hyp, opt, device, tb_writer=None): if not opt.evolve: tb_writer = None # init loggers if opt.global_rank in [-1, 0]: - logger.info(f'Start Tensorboard with "tensorboard --logdir {opt.project}", view at http://localhost:6006/') + prefix = colorstr('tensorboard: ') + logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/") tb_writer = SummaryWriter(opt.save_dir) # Tensorboard train(hyp, opt, device, tb_writer)