
[RayTrainer] Severe stall when setting up process group for: env:// #58878

@karriganastA

Description


What happened + What you expected to happen


When running Ray's TorchTrainer on some server nodes, there is a long delay between the "Setting up process group" step and the start of the distributed worker processes, while the same job runs smoothly on other nodes. What could be the cause?
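One way to narrow down where the delay occurs (a sketch, not part of the original report; it assumes the stall happens during the NCCL/Gloo rendezvous rather than in Ray's worker placement) is to propagate PyTorch's distributed debug logging and NCCL's init logs to every Ray Train worker through the job's runtime environment:

import ray

# Diagnostic sketch: NCCL_DEBUG, NCCL_DEBUG_SUBSYS, TORCH_DISTRIBUTED_DEBUG and
# TORCH_CPP_LOG_LEVEL are standard PyTorch/NCCL variables; passing them via
# runtime_env applies them to all Ray workers started for the job.
ray.init(
    address="auto",
    runtime_env={
        "env_vars": {
            "NCCL_DEBUG": "INFO",             # log NCCL transport/topology detection
            "NCCL_DEBUG_SUBSYS": "INIT,NET",  # focus on init and network subsystems
            "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
            "TORCH_CPP_LOG_LEVEL": "INFO",
        }
    },
)

Comparing the resulting worker logs between a slow node and a healthy node should show whether the time is spent inside NCCL initialization, in hostname/interface resolution, or before the process-group call is even reached.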

Versions / Dependencies

Ray: 2.50.0
torch: 2.8.0


Reproduction script

from accelerate import Accelerator
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor
import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

try:
    from ray.train import report as ray_report
    def safe_report(*args, **kwargs):
        try:
            ray_report(*args, **kwargs)
        except RuntimeError:
            pass
    RAY_AVAILABLE = True
except ImportError:
    RAY_AVAILABLE = False
    def safe_report(*args, **kwargs):
        pass

def train_func(config=None):
    import time
    t_start = time.time()
    accelerator = Accelerator(
        mixed_precision="no",
        gradient_accumulation_steps=1,
    )
    t_accelerator = time.time()
    accelerator.print(
        f"[init] world={accelerator.num_processes} "
        f"local_rank={accelerator.local_process_index} "
        f"main={accelerator.is_main_process} "
        f"[时间] Accelerator init: {t_accelerator-t_start:.2f} seconds"
    )

    t_model_start = time.time()
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(3 * 64 * 64, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
    opt = optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    t_model_end = time.time()

    t_data_start = time.time()
    train_set = FakeData(size=1024, image_size=(3, 64, 64), num_classes=10, transform=ToTensor())
    valid_set = FakeData(size=256, image_size=(3, 64, 64), num_classes=10, transform=ToTensor())

    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
    valid_loader = DataLoader(valid_set, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)
    t_data_end = time.time()

    t_prepare_start = time.time()
    model, opt, train_loader, valid_loader = accelerator.prepare(model, opt, train_loader, valid_loader)
    t_prepare_end = time.time()
    
    accelerator.print(
        f"model: {t_model_end-t_model_start:.2f}秒, "
        f"loader: {t_data_end-t_data_start:.2f}秒, "
        f"prepare: {t_prepare_end-t_prepare_start:.2f}秒, "
        f"total: {t_prepare_end-t_start:.2f}秒"
    )

    for epoch in range(10):
        model.train()
        for x, y in train_loader:
            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = loss_fn(logits, y)
            accelerator.backward(loss)
            opt.step()

        model.eval()
        tot, cnt = 0.0, 0
        with torch.no_grad():
            for x, y in valid_loader:
                logits = model(x)
                loss = loss_fn(logits, y)
                (loss_all, ) = accelerator.gather_for_metrics((loss.detach(), ))
                tot += loss_all.sum().item()
                cnt += loss_all.numel()
        
        avg_val = tot / max(cnt, 1)
        accelerator.print(f"[epoch {epoch}] val_loss={avg_val:.4f}")

        if accelerator.is_main_process:
            accelerator.save(
                accelerator.unwrap_model(model).state_dict(),
                f"/data/shenjn/MultiTrain/Ray/test/ckpt_epoch{epoch}.pth",
            )
        safe_report({"epoch": epoch, "val_loss": float(avg_val)})

if __name__ == "__main__":
    ray.init(address="auto")
    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(
            num_workers=8,
            use_gpu=True,
            resources_per_worker={"GPU": 1},
        ),
    )
    trainer.fit()
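To check whether the stall is tied to process-group setup itself rather than to the training body above, a stripped-down train loop (illustrative only, not from the original report) can be launched with the same scaling configuration; if the delay still appears, the data pipeline and model code can be ruled out:

import time
import torch.distributed as dist
import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def empty_train_func(config=None):
    # Ray Train initializes the process group before this function runs, so a
    # barrier here only measures any residual startup skew between workers.
    t0 = time.time()
    dist.barrier()
    print(f"rank={dist.get_rank()} passed barrier after {time.time() - t0:.2f}s")

if __name__ == "__main__":
    ray.init(address="auto")
    TorchTrainer(
        empty_train_func,
        scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    ).fit()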

Issue Severity

None

Labels

bug (Something that is supposed to be working; but isn't), community-backlog, performance, question (Just a question :)), train (Ray Train Related Issue), triage (Needs triage: priority, bug/not-bug, and owning component)
