
Commit 135d24c

set pg names
Summary:
- we need to pass the global rank information to PyTorch so that the pg name can include the rank information
- this is necessary to differentiate the default pgs on different replicas
- these need to be different because flight recorder matches collectives based on pg name as well
1 parent: ad78ed8

File tree

2 files changed: +16 -1 lines changed


torchtitan/distributed/utils.py

Lines changed: 5 additions & 1 deletion
@@ -259,7 +259,10 @@ def maybe_enable_amp(
 
 
 def init_distributed(
-    comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
+    comm_config: CommConfig,
+    enable_cpu_backend: bool = False,
+    base_folder: str = "",
+    ranks: list[int] | None = None,
 ):
     def _warn_overwrite_env(env, val):
         if env in os.environ:
@@ -303,6 +306,7 @@ def _get_distributed_backend(enable_cpu_backend):
     torch.distributed.init_process_group(
         backend=_get_distributed_backend(enable_cpu_backend),
         timeout=timedelta(seconds=comm_config.init_timeout_seconds),
+        _ranks=ranks if ranks is not None else [],
     )
 
 
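For illustration, a minimal sketch of how a caller might use the updated signature; the config values and rank list below are hypothetical. When ranks is left as None (the default), an empty list is forwarded to init_process_group.

    # Hypothetical call site: pass this replica's global ranks explicitly so its
    # default process group can be told apart from the other replicas' groups.
    init_distributed(
        comm_config,                               # a CommConfig instance
        enable_cpu_backend=False,
        base_folder="./outputs",                   # illustrative path
        ranks=[8, 9, 10, 11, 12, 13, 14, 15],      # global ranks of this replica
    )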

torchtitan/train.py

Lines changed: 11 additions & 0 deletions
@@ -84,11 +84,22 @@ def __init__(self, job_config: JobConfig):
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)
 
+        # determine the global ranks when fault tolerance is enabled
+        global_ranks = []
+        ft_config = job_config.fault_tolerance
+        if ft_config.enable:
+            group_size = ft_config.group_size
+            replica_id = ft_config.replica_id
+            first_rank = replica_id * group_size
+            last_rank = first_rank + group_size - 1
+            global_ranks = list(range(first_rank, last_rank + 1))
+
         # init distributed and build meshes
         dist_utils.init_distributed(
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
+            ranks=global_ranks,
         )
 
         job_config.maybe_log()
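As a small worked example of the rank computation above (values are illustrative, not from the commit): with group_size = 8 and replica_id = 1, the replica spans global ranks 8 through 15, distinct from replica 0's ranks 0 through 7, so each replica's default pg can be named differently.

    # Illustrative values: the second replica (replica_id = 1) of 8 ranks each.
    group_size = 8
    replica_id = 1
    first_rank = replica_id * group_size         # 8
    last_rank = first_rank + group_size - 1      # 15
    global_ranks = list(range(first_rank, last_rank + 1))
    assert global_ranks == [8, 9, 10, 11, 12, 13, 14, 15]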
