Status: Open
Labels: bug, community-backlog, performance, question, train, triage
Description
What happened + What you expected to happen
When running Ray's TorchTrainer on some server nodes, there is a long delay between setting up the process group and starting the distributed worker processes, while the same job starts promptly on other nodes. What could be the cause?
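A diagnostic sketch (not part of the original report): turning on NCCL and torch.distributed debug logging for every worker shows whether the stalled time is spent in the NCCL rendezvous, in network interface selection, or before the process group forms at all. NCCL_DEBUG, NCCL_DEBUG_SUBSYS, and TORCH_DISTRIBUTED_DEBUG are standard NCCL/PyTorch environment variables; passing them through ray.init's runtime_env is one possible way to reach all workers.

import ray

ray.init(
    address="auto",
    runtime_env={
        "env_vars": {
            "NCCL_DEBUG": "INFO",                 # NCCL init/rendezvous logs
            "NCCL_DEBUG_SUBSYS": "INIT,NET",      # focus on init and transport setup
            "TORCH_DISTRIBUTED_DEBUG": "DETAIL",  # extra process-group diagnostics
        }
    },
)

Comparing these logs between a fast node and a slow node should show which phase diverges (for example, NCCL choosing a different network interface on the affected hosts).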
Versions / Dependencies
Ray: 2.50.0
torch: 2.8.0
Reproduction script
from accelerate import Accelerator
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
from torchvision.transforms import ToTensor

import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

try:
    from ray.train import report as ray_report

    def safe_report(*args, **kwargs):
        # Swallow the RuntimeError raised when report() is called
        # outside of a Ray Train session.
        try:
            ray_report(*args, **kwargs)
        except RuntimeError:
            pass

    RAY_AVAILABLE = True
except ImportError:
    RAY_AVAILABLE = False

    def safe_report(*args, **kwargs):
        pass


def train_func(config=None):
    import time

    t_start = time.time()
    accelerator = Accelerator(
        mixed_precision="no",
        gradient_accumulation_steps=1,
    )
    t_accelerator = time.time()
    accelerator.print(
        f"[init] world={accelerator.num_processes} "
        f"local_rank={accelerator.local_process_index} "
        f"main={accelerator.is_main_process} "
        f"[time] Accelerator init: {t_accelerator - t_start:.2f} s"
    )

    t_model_start = time.time()
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(3 * 64 * 64, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
    opt = optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()
    t_model_end = time.time()

    t_data_start = time.time()
    train_set = FakeData(size=1024, image_size=(3, 64, 64), num_classes=10, transform=ToTensor())
    valid_set = FakeData(size=256, image_size=(3, 64, 64), num_classes=10, transform=ToTensor())
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
    valid_loader = DataLoader(valid_set, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)
    t_data_end = time.time()

    t_prepare_start = time.time()
    model, opt, train_loader, valid_loader = accelerator.prepare(model, opt, train_loader, valid_loader)
    t_prepare_end = time.time()
    accelerator.print(
        f"model: {t_model_end - t_model_start:.2f} s, "
        f"loader: {t_data_end - t_data_start:.2f} s, "
        f"prepare: {t_prepare_end - t_prepare_start:.2f} s, "
        f"total: {t_prepare_end - t_start:.2f} s"
    )

    for epoch in range(10):
        model.train()
        for x, y in train_loader:
            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = loss_fn(logits, y)
            accelerator.backward(loss)
            opt.step()

        model.eval()
        tot, cnt = 0.0, 0
        with torch.no_grad():
            for x, y in valid_loader:
                logits = model(x)
                loss = loss_fn(logits, y)
                (loss_all,) = accelerator.gather_for_metrics((loss.detach(),))
                tot += loss_all.sum().item()
                cnt += loss_all.numel()
        avg_val = tot / max(cnt, 1)
        accelerator.print(f"[epoch {epoch}] val_loss={avg_val:.4f}")

        if accelerator.is_main_process:
            accelerator.save(
                accelerator.unwrap_model(model).state_dict(),
                f"/data/shenjn/MultiTrain/Ray/test/ckpt_epoch{epoch}.pth",
            )
        safe_report({"epoch": epoch, "val_loss": float(avg_val)})


if __name__ == "__main__":
    ray.init(address="auto")
    trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(
            num_workers=8,
            use_gpu=True,
            resources_per_worker={"GPU": 1},
        ),
    )
    trainer.fit()
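One way to narrow the stall down further, sketched under the assumption that the delay is rank skew in the collective rendezvous: Ray Train initializes the torch process group before train_func runs, so timing an explicit barrier as the first statement of train_func separates worker startup from the time spent waiting for slow ranks. timed_barrier is a hypothetical helper added here for illustration.

import time
import torch.distributed as dist

def timed_barrier(tag: str) -> None:
    # Blocks until every rank arrives; a long wait on the fast nodes
    # means some other rank is still starting up.
    t0 = time.time()
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    print(f"[{tag}] barrier took {time.time() - t0:.2f} s", flush=True)

Calling timed_barrier("enter train_func") at the top of train_func, before constructing the Accelerator, would attribute the delay either to worker startup (the barrier is slow because some ranks arrive late) or to the Accelerator/prepare phase (the barrier is fast but the later timers are slow).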
Issue Severity
None