System Info
- `Accelerate` version: 1.12.0.dev0
- Platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /home/yanzhen/miniconda3/envs/accelerate_test/bin/accelerate
- Python version: 3.10.19
- Numpy version: 2.2.6
- PyTorch version: 2.9.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 1007.66 GB
- GPU type: NVIDIA GeForce RTX 4090
Information
- The official example scripts
- My own modified scripts
Tasks
- One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- My own task or dataset (give details below)
Reproduction
After modifying my configuration and running my training script, I encountered the following error: `TypeError: 'NoneType' object is not callable`.
After investigation, I found that the issue is caused by the combination of `fsdp_auto_wrap_policy: NO_WRAP` and `fsdp_activation_checkpointing: true`.
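Judging from the traceback further below, `fsdp2_apply_ac` calls an auto-wrap-policy callable on each candidate module, and with `NO_WRAP` that callable ends up being `None`. The following is only a minimal illustration of that call pattern and of the resulting error, not the actual Accelerate code (`apply_ac_sketch` is a made-up name):

# Illustration only -- not Accelerate source. Mimics the call seen at
# accelerate/utils/fsdp_utils.py:595 when NO_WRAP leaves no policy behind.
import torch.nn as nn

def apply_ac_sketch(model, auto_wrap_policy_func):
    for name, module in model.named_modules():
        # With NO_WRAP, auto_wrap_policy_func is None, so this call raises
        # TypeError: 'NoneType' object is not callable.
        if auto_wrap_policy_func(module):
            print(f"would wrap {name} with activation checkpointing")

apply_ac_sketch(nn.Sequential(nn.Linear(8, 8)), auto_wrap_policy_func=None)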
Here is the problematic configuration file `failed_config.yaml`:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: NO_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: false
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
parallelism_config:
  parallelism_config_cp_size: 1
  parallelism_config_dp_replicate_size: 2
  parallelism_config_dp_shard_size: 2
  parallelism_config_tp_size: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Here is the training script `main.py`:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator


class RandomDataset(Dataset):
    def __init__(self, num_samples=100, input_dim=128, num_classes=10):
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.num_classes = num_classes

    def __getitem__(self, idx):
        x = torch.randn(self.input_dim)
        y = torch.randint(0, self.num_classes, (1,)).item()
        return x, y

    def __len__(self):
        return self.num_samples


class MLP10(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=512, num_classes=10):
        super().__init__()
        layers = []
        in_dim = input_dim
        for _ in range(10):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
            in_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def main():
    accelerator = Accelerator()
    input_dim = 128
    hidden_dim = 512
    num_classes = 10
    batch_size = 64
    num_epochs = 1
    lr = 1e-3
    accelerator.print("🚀 Starting Accelerate MLP training")
    accelerator.print(f"Using device: {accelerator.device}")
    dataset = RandomDataset(num_samples=64 * 4 * 20, input_dim=input_dim, num_classes=num_classes)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    model = MLP10(input_dim, hidden_dim, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for step, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()
            total_loss += loss.item()
            accelerator.print(
                f"[Epoch {epoch+1}/{num_epochs}] Step {step+1}/{len(dataloader)} "
                f"Loss: {total_loss / (step+1):.4f}"
            )
        accelerator.print(f"✅ Epoch {epoch+1} finished. Avg Loss: {total_loss / len(dataloader):.4f}")
    accelerator.print("🎉 Training completed successfully!")


if __name__ == "__main__":
    main()
You can reproduce the issue by running the following command:
accelerate launch --config_file /home/yanzhen/distributed_test/accelerate/test/bug2/failed_config.yaml main.py
The following exception stack trace is produced:
/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/launch.py:238: UserWarning: Port `29500` is already in use. Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. If this current attempt fails, or for more control in future runs, please specify a different port (e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection in your launch command or Accelerate config file.
warnings.warn(
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[... repeated, interleaved [Gloo] Rank 0 / Rank 1 messages for the remaining two-rank process groups ...]
🚀 Starting Accelerate MLP training
Using device: cuda:0
[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/test/bug2/main.py", line 83, in <module>
[rank3]:     main()
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/test/bug2/main.py", line 59, in main
[rank3]:     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank3]:     result = self._prepare_fsdp2(*args)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1657, in _prepare_fsdp2
[rank3]:     model = fsdp2_apply_ac(self, model)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 595, in fsdp2_apply_ac
[rank3]:     if auto_wrap_policy_func(parent_module):
[rank3]: TypeError: 'NoneType' object is not callable
[rank0], [rank1], and [rank2] fail with the identical traceback.
[rank0]:[W1030 16:37:14.889713455 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W1030 16:37:14.924255190 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W1030 16:37:14.958025159 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W1030 16:37:14.877000 3983265 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3983811 closing signal SIGTERM
W1030 16:37:14.878000 3983265 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3983813 closing signal SIGTERM
W1030 16:37:14.878000 3983265 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3983814 closing signal SIGTERM
E1030 16:37:15.093000 3983265 site-packages/torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 1 (pid: 3983812) of binary: /home/yanzhen/miniconda3/envs/accelerate_test/bin/python3.10
Traceback (most recent call last):
  File "/home/yanzhen/miniconda3/envs/accelerate_test/bin/accelerate", line 7, in <module>
    sys.exit(main())
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/accelerate_cli.py", line 50, in main
    args.func(args)
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/launch.py", line 1222, in launch_command
    multi_gpu_launcher(args)
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/launch.py", line 853, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/run.py", line 927, in run
    elastic_launch(
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 156, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
main.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2025-10-30_16:37:14
host : ubuntu
rank : 1 (local_rank: 1)
exitcode : 1 (pid: 3983812)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
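A possible interim workaround (untested, and not taken from the Accelerate docs) would be to set `fsdp_activation_checkpointing: false` in the config so the failing `fsdp2_apply_ac` path is never reached, and to apply activation checkpointing manually with PyTorch's `apply_activation_checkpointing` before calling `accelerator.prepare`:

# Untested workaround sketch: with fsdp_activation_checkpointing: false in the
# config, checkpoint the Linear layers manually inside main() from the script
# above, right before accelerator.prepare(...).
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)

apply_activation_checkpointing(model, check_fn=lambda m: isinstance(m, nn.Linear))
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)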
Expected behavior
Training should not fail. If `fsdp_auto_wrap_policy: NO_WRAP` and `fsdp_activation_checkpointing: true` are not allowed to be enabled simultaneously, then `accelerate config` should warn about (or reject) the combination up front instead of failing later in `accelerator.prepare()` with an opaque `TypeError`.
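If the combination is indeed unsupported, a guard along the following lines, either in `accelerate config` or at the top of `fsdp2_apply_ac`, would turn the opaque `TypeError` into an actionable message. This is only a hypothetical sketch (`validate_fsdp_ac_config` does not exist in Accelerate):

# Hypothetical validation sketch (not existing Accelerate code): surface the
# conflict at config time instead of failing inside fsdp2_apply_ac.
def validate_fsdp_ac_config(fsdp_config: dict) -> None:
    policy = str(fsdp_config.get("fsdp_auto_wrap_policy", "NO_WRAP")).upper()
    if fsdp_config.get("fsdp_activation_checkpointing", False) and policy == "NO_WRAP":
        raise ValueError(
            "fsdp_activation_checkpointing=true needs an auto wrap policy to decide "
            "which submodules to checkpoint, but fsdp_auto_wrap_policy is NO_WRAP. "
            "Use TRANSFORMER_BASED_WRAP / SIZE_BASED_WRAP, or disable activation checkpointing."
        )

# validate_fsdp_ac_config({"fsdp_activation_checkpointing": True,
#                          "fsdp_auto_wrap_policy": "NO_WRAP"})  # -> ValueError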