Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions utils/loggers/wandb/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import sys
from contextlib import contextmanager
from pathlib import Path
import pkg_resources as pkg

import pkg_resources as pkg
import yaml
from tqdm import tqdm

Expand All @@ -19,16 +19,17 @@
from utils.datasets import img2label_paths
from utils.general import check_dataset, check_file

RANK = int(os.getenv('RANK', -1))

try:
import wandb

assert hasattr(wandb, '__version__') # verify package import not local dir
if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2'):
if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in [0, -1]:
wandb.login(timeout=30)
except (ImportError, AssertionError):
wandb = None

RANK = int(os.getenv('RANK', -1))
WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'


Expand All @@ -48,9 +49,11 @@ def check_wandb_dataset(data_file):
if check_file(data_file) and data_file.endswith('.yaml'):
with open(data_file, errors='ignore') as f:
data_dict = yaml.safe_load(f)
is_wandb_artifact = (data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX) or
data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX))
if is_wandb_artifact:
is_trainset_wandb_artifact = (isinstance(data_dict['train'], str) and
data_dict['train'].startswith(WANDB_ARTIFACT_PREFIX))
is_valset_wandb_artifact = (isinstance(data_dict['val'], str) and
data_dict['val'].startswith(WANDB_ARTIFACT_PREFIX))
if is_trainset_wandb_artifact or is_valset_wandb_artifact:
return data_dict
else:
return check_dataset(data_file)
Expand Down