Skip to content

Commit d95a728

Browse files
Implement DDP static_graph=True (#6940)
* Implement DDP `static_graph=True` Experimental implementation of new PyTorch 1.11.0 DDP feature. * Add 1.11.0 check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f3fecf9 commit d95a728

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

train.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@
4747
from utils.datasets import create_dataloader
4848
from utils.downloads import attempt_download
4949
from utils.general import (LOGGER, check_dataset, check_file, check_git_status, check_img_size, check_requirements,
50-
check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
51-
intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights, methods,
52-
one_cycle, print_args, print_mutation, strip_optimizer)
50+
check_suffix, check_version, check_yaml, colorstr, get_latest_run, increment_path,
51+
init_seeds, intersect_dicts, is_ascii, labels_to_class_weights, labels_to_image_weights,
52+
methods, one_cycle, print_args, print_mutation, strip_optimizer)
5353
from utils.loggers import Loggers
5454
from utils.loggers.wandb.wandb_utils import check_wandb_resume
5555
from utils.loss import ComputeLoss
@@ -269,7 +269,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
269269

270270
# DDP mode
271271
if cuda and RANK != -1:
272-
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
272+
if check_version(torch.__version__, '1.11.0'):
273+
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
274+
else:
275+
model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
273276

274277
# Model attributes
275278
nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)

0 commit comments

Comments
 (0)