
TypeError: NoneType object is not callable occurs when using fsdp_auto_wrap_policy=NO_WRAP with fsdp_activation_checkpointing=true #3822

Description

System Info

- `Accelerate` version: 1.12.0.dev0
- Platform: Linux-6.5.0-18-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /home/yanzhen/miniconda3/envs/accelerate_test/bin/accelerate
- Python version: 3.10.19
- Numpy version: 2.2.6
- PyTorch version: 2.9.0+cu128
- PyTorch accelerator: CUDA
- System RAM: 1007.66 GB
- GPU type: NVIDIA GeForce RTX 4090

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

After modifying my configuration and running my training script, I encountered the following error: TypeError: 'NoneType' object is not callable.

After investigation, I found that the issue is caused by the combination of fsdp_auto_wrap_policy: NO_WRAP and fsdp_activation_checkpointing: true.
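To illustrate what I believe is happening, here is a standalone sketch that reproduces the same failure mode. It is my own reconstruction, not Accelerate's actual source: only the call `auto_wrap_policy_func(parent_module)` inside `fsdp2_apply_ac` comes from the traceback further below; the resolver function and its NO_WRAP handling are assumptions.

# Hypothetical reconstruction of the suspected failure path (not Accelerate's real code).
from typing import Callable, Optional

import torch.nn as nn


def resolve_auto_wrap_policy(policy_name: str) -> Optional[Callable[[nn.Module], bool]]:
    # Assumption: NO_WRAP means "no wrapping policy", so nothing callable is returned.
    if policy_name == "NO_WRAP":
        return None
    raise NotImplementedError("other policies are omitted from this sketch")


def apply_ac_like_fsdp2_apply_ac(model: nn.Module, policy_name: str) -> None:
    auto_wrap_policy_func = resolve_auto_wrap_policy(policy_name)
    for parent_module in model.modules():
        # With NO_WRAP the resolved policy is None, so this call raises
        # TypeError: 'NoneType' object is not callable
        if auto_wrap_policy_func(parent_module):
            pass  # activation checkpointing would be applied to parent_module here


if __name__ == "__main__":
    apply_ac_like_fsdp2_apply_ac(nn.Linear(4, 4), "NO_WRAP")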

Here is the problematic configuration file failed_config.yaml:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: true
  fsdp_auto_wrap_policy: NO_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: false
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
parallelism_config:
  parallelism_config_cp_size: 1
  parallelism_config_dp_replicate_size: 2
  parallelism_config_dp_shard_size: 2
  parallelism_config_tp_size: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Here is the training script main.py:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator

class RandomDataset(Dataset):
    def __init__(self, num_samples=100, input_dim=128, num_classes=10):
        self.num_samples = num_samples
        self.input_dim = input_dim
        self.num_classes = num_classes

    def __getitem__(self, idx):
        x = torch.randn(self.input_dim)
        y = torch.randint(0, self.num_classes, (1,)).item()
        return x, y

    def __len__(self):
        return self.num_samples


class MLP10(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=512, num_classes=10):
        super().__init__()
        layers = []
        in_dim = input_dim
        for _ in range(10):
            layers.append(nn.Linear(in_dim, hidden_dim))
            layers.append(nn.ReLU())
            in_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def main():
    accelerator = Accelerator()

    input_dim = 128
    hidden_dim = 512
    num_classes = 10
    batch_size = 64
    num_epochs = 1
    lr = 1e-3

    accelerator.print("🚀 Starting Accelerate MLP training")
    accelerator.print(f"Using device: {accelerator.device}")

    dataset = RandomDataset(num_samples=64 * 4 * 20, input_dim=input_dim, num_classes=num_classes)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model = MLP10(input_dim, hidden_dim, num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for step, (inputs, targets) in enumerate(dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()

            total_loss += loss.item()
            accelerator.print(
                f"[Epoch {epoch+1}/{num_epochs}] Step {step+1}/{len(dataloader)} "
                f"Loss: {total_loss / (step+1):.4f}"
            )


        accelerator.print(f"✅ Epoch {epoch+1} finished. Avg Loss: {total_loss / len(dataloader):.4f}")

    accelerator.print("🎉 Training completed successfully!")

if __name__ == "__main__":
    main()

You can reproduce the issue by running the following command:

accelerate launch --config_file /home/yanzhen/distributed_test/accelerate/test/bug2/failed_config.yaml main.py

The following exception stack trace is produced:

/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/launch.py:238: UserWarning: Port `29500` is already in use. Accelerate will attempt to launch in a standalone-like mode by finding an open port automatically for this session. If this current attempt fails, or for more control in future runs, please specify a different port (e.g., `--main_process_port <your_chosen_port>`) or use `--main_process_port 0` for automatic selection in your launch command or Accelerate config file.
  warnings.warn(
[Gloo] rank-connection messages from all four processes ("Rank N is connected to M peer ranks. Expected number of connected peer ranks is : M"), interleaved by concurrent writes and omitted here.
🚀 Starting Accelerate MLP training
Using device: cuda:0
[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/test/bug2/main.py", line 83, in <module>
[rank3]:     main()
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/test/bug2/main.py", line 59, in main
[rank3]:     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1555, in prepare
[rank3]:     result = self._prepare_fsdp2(*args)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/accelerator.py", line 1657, in _prepare_fsdp2
[rank3]:     model = fsdp2_apply_ac(self, model)
[rank3]:   File "/home/yanzhen/distributed_test/accelerate/src/accelerate/utils/fsdp_utils.py", line 595, in fsdp2_apply_ac
[rank3]:     if auto_wrap_policy_func(parent_module):
[rank3]: TypeError: 'NoneType' object is not callable
(ranks 0, 1, and 2 fail with the identical traceback)
[rank0]:[W1030 16:37:14.889713455 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W1030 16:37:14.924255190 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W1030 16:37:14.958025159 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W1030 16:37:14.877000 3983265 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3983811 closing signal SIGTERM
W1030 16:37:14.878000 3983265 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3983813 closing signal SIGTERM
W1030 16:37:14.878000 3983265 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 3983814 closing signal SIGTERM
E1030 16:37:15.093000 3983265 site-packages/torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 1 (pid: 3983812) of binary: /home/yanzhen/miniconda3/envs/accelerate_test/bin/python3.10
Traceback (most recent call last):
  File "/home/yanzhen/miniconda3/envs/accelerate_test/bin/accelerate", line 7, in <module>
    sys.exit(main())
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/accelerate_cli.py", line 50, in main
    args.func(args)
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/launch.py", line 1222, in launch_command
    multi_gpu_launcher(args)
  File "/home/yanzhen/distributed_test/accelerate/src/accelerate/commands/launch.py", line 853, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/run.py", line 927, in run
    elastic_launch(
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 156, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/yanzhen/miniconda3/envs/accelerate_test/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
main.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-10-30_16:37:14
  host      : ubuntu
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3983812)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

Expected behavior

Training should not fail. If fsdp_auto_wrap_policy: NO_WRAP and fsdp_activation_checkpointing: true cannot be enabled simultaneously, a warning or prompt should be raised during accelerate config setup (or when the configuration is validated at launch), rather than surfacing later as a bare TypeError inside fsdp2_apply_ac.
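For illustration, this is the kind of check I have in mind; the helper name and the dict-based config access are hypothetical, not Accelerate's actual API:

# Illustrative only: warn up front when activation checkpointing has no wrap policy to rely on.
import warnings


def check_fsdp_ac_compatibility(fsdp_config: dict) -> None:
    if fsdp_config.get("fsdp_activation_checkpointing") and fsdp_config.get("fsdp_auto_wrap_policy") == "NO_WRAP":
        warnings.warn(
            "fsdp_activation_checkpointing=true has no auto-wrap policy to decide which "
            "modules to checkpoint when fsdp_auto_wrap_policy=NO_WRAP; choose a different "
            "wrap policy or disable activation checkpointing.",
            UserWarning,
        )


check_fsdp_ac_compatibility(
    {"fsdp_activation_checkpointing": True, "fsdp_auto_wrap_policy": "NO_WRAP"}
)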
