|
47 | 47 | from utils.datasets import create_dataloader
|
48 | 48 | from utils.downloads import attempt_download
|
49 | 49 | from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
|
50 |
| - check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds, |
51 |
| - intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods, |
52 |
| - one_cycle, print_args, print_mutation, strip_optimizer) |
| 50 | + check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path, |
| 51 | + init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, |
| 52 | + methods, one_cycle, print_args, print_mutation, strip_optimizer) |
53 | 53 | from utils.loggers import Loggers
|
54 | 54 | from utils.loggers.wandb.wandb_utils import check_wandb_resume
|
55 | 55 | from utils.loss import ComputeLoss
|
@@ -269,7 +269,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
|
269 | 269 |
|
270 | 270 | # DDP mode
|
271 | 271 | if cuda and RANK != -1:
|
272 |
| - model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) |
| 272 | + if check_version(torch.__version__, '1.11.0'): |
| 273 | + model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True) |
| 274 | + else: |
| 275 | + model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) |
273 | 276 |
|
274 | 277 | # Model attributes
|
275 | 278 | nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
|
|
0 commit comments