[FSDP] Support with AMP Grad scaler #421

@SeanNaren


🐛 Bug

If you modify the FSDP test here (tests/nn/data_parallel/test_fsdp.py) to include a grad scaler, the cpu_offload test fails:

Modification:

from fairscale.optim.grad_scaler import ShardedGradScaler

    @staticmethod
    def _train_for_several_steps(model, num_steps, autocast, lr=0.01, norm_type=None):
        model_device = next(model.parameters()).device
        # use SGD with momentum instead of Adam, since Adam is scale invariant
        # and this makes it bad for tests
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
        scaler = ShardedGradScaler()
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs always cuda regardless of move_grads_cpu, or model.device
                input = model.module.get_input(torch.device("cuda"))
                output = model(*input)
                loss = model.module.get_loss(input, output).to(model_device)
            loss = scaler.scale(loss)
            assert loss.dtype == torch.float32
            model.module.run_backward(loss)
            if norm_type is not None:
                clip_norm = 0.3
                if isinstance(model, FullyShardedDataParallel):
                    model.clip_grad_norm_(clip_norm, norm_type)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
            scaler.step(optim)
            scaler.update()
        if hasattr(model, "assert_idle"):
            model.assert_idle()
        return loss.detach()

Command:

pytest tests/nn/data_parallel/test_fsdp.py::TestComparisonToPyTorchDDP::test_cpu_offload_and_cpu_grads

Error:

E       torch.multiprocessing.spawn.ProcessRaisedException:
E
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
E           fn(i, *args)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 812, in init_and_run
E           fn(rank, group, *args)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 292, in _test_identical_outputs
E           shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 63, in _train_for_several_steps
E           loss = scaler.scale(loss)
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 161, in scale
E           assert outputs.is_cuda
E       AssertionError

If I remove the .to(model_device) call, I get a different error, but it is probably still device-related.
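That is, the loss line becomes (shown in isolation, everything else unchanged):

loss = model.module.get_loss(input, output)  # no .to(model_device)

The resulting error: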

E       torch.multiprocessing.spawn.ProcessRaisedException:
E
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
E           fn(i, *args)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 812, in init_and_run
E           fn(rank, group, *args)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 292, in _test_identical_outputs
E           shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
E         File "/home/sean/fairscale/tests/nn/data_parallel/test_fsdp.py", line 74, in _train_for_several_steps
E           scaler.step(optim)
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 324, in step
E           self.unscale_(optimizer)
E         File "/home/sean/fairscale/fairscale/optim/grad_scaler.py", line 48, in unscale_
E           super().unscale_(optimizer)
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 275, in unscale_
E           optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False)
E         File "/home/sean/miniconda3/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 223, in _unscale_grads_
E           torch._amp_foreach_non_finite_check_and_unscale_(grads,
E       RuntimeError: Could not run 'aten::_amp_foreach_non_finite_check_and_unscale_' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_amp_foreach_non_finite_check_and_unscale_' is only available for these backends: [CUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].
E
E       CUDA: registered at /pytorch/build/aten/src/ATen/RegisterCUDA.cpp:7100 [kernel]
E       BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
E       Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
E       AutogradOther: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradCPU: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradCUDA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradXLA: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradNestedTensor: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       UNKNOWN_TENSOR_TYPE_ID: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradPrivateUse1: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradPrivateUse2: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       AutogradPrivateUse3: registered at /pytorch/torch/csrc/autograd/generated/VariableType_0.cpp:9273 [autograd kernel]
E       Tracer: registered at /pytorch/torch/csrc/autograd/generated/TraceType_0.cpp:10499 [kernel]
E       Autocast: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:250 [backend fallback]
E       Batched: registered at /pytorch/aten/src/ATen/BatchingRegistrations.cpp:1016 [backend fallback]
E       VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

Proposed fix

Modify the ShardedGradScaler to work with cpu_offload. The issue seems to boil down to dealing with gradients that live on the CPU, so the question becomes: can we adapt the inf check and unscale operation to handle CPU tensors? A rough sketch of that idea is below.
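As a minimal sketch (the helper name and calling convention are hypothetical, not the actual ShardedGradScaler internals), the unscale step could special-case non-CUDA gradients, assuming found_inf and inv_scale are GPU tensors as in GradScaler._unscale_grads_:

import torch

def unscale_grads_with_cpu_fallback(grads, inv_scale, found_inf):
    # Route CUDA grads through the fused kernel; handle CPU grads
    # (e.g. when cpu_offload=True keeps them on the host) in plain Python.
    cuda_grads = [g for g in grads if g.is_cuda]
    cpu_grads = [g for g in grads if not g.is_cuda]

    if cuda_grads:
        # The fused op is only registered for the CUDA backend, which is
        # exactly what the RuntimeError above complains about.
        torch._amp_foreach_non_finite_check_and_unscale_(cuda_grads, found_inf, inv_scale)

    inv_scale_cpu = float(inv_scale)
    for grad in cpu_grads:
        # Manual inf/nan check and unscale for CPU tensors.
        if not torch.isfinite(grad).all():
            found_inf.fill_(1.0)
        grad.mul_(inv_scale_cpu)

Since GradScaler._unscale_grads_ already groups gradients per device, a CPU branch along these lines could presumably slot in next to the existing per-device call.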

Another direction is motivated by the PyTorch GradScaler class being quite confusing; after finding this in fairseq, maybe it's better to define our own grad scaler logic for FSDP? A rough sketch of what that could look like is below. If that's the preferred route, I can work on it!
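For reference, a stripped-down scaler along those lines might look roughly like this (a hypothetical, minimal sketch, not the fairseq implementation); it unscales and inf-checks gradients in plain Python, so it works whether cpu_offload has placed them on the CPU or they stay on the GPU:

import torch

class SimpleGradScaler:
    # Minimal dynamic loss scaling: scale the loss, unscale grads before
    # the step, skip the step on inf/nan, then adjust the scale.
    def __init__(self, init_scale=2.0 ** 15, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self._scale = init_scale
        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        self._good_steps = 0
        self._found_inf = False

    def scale(self, loss):
        return loss * self._scale

    def step(self, optimizer):
        # Plain-Python unscale + inf check, so the gradient device does not
        # matter. In a real FSDP setting the found_inf flag would additionally
        # need to be all-reduced across ranks so every shard skips together.
        self._found_inf = False
        inv_scale = 1.0 / self._scale
        for group in optimizer.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if not torch.isfinite(p.grad).all():
                    self._found_inf = True
                p.grad.mul_(inv_scale)
        if not self._found_inf:
            optimizer.step()

    def update(self):
        if self._found_inf:
            self._scale *= self._backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self._growth_interval == 0:
                self._scale *= self._growth_factor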

Environment

PyTorch version: 1.8.0+cu112
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: A100-SXM4-40GB
GPU 1: A100-SXM4-40GB
GPU 2: A100-SXM4-40GB
GPU 3: A100-SXM4-40GB
GPU 4: A100-SXM4-40GB
GPU 5: A100-SXM4-40GB
GPU 6: A100-SXM4-40GB
GPU 7: A100-SXM4-40GB

Nvidia driver version: 460.32.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.1
[pip3] pytorch-lightning==1.3.0.dev0
[pip3] torch==1.8.0+cu112
[pip3] torchvision==0.8.2
[conda] numpy                     1.20.1                   pypi_0    pypi
[conda] pytorch-lightning         1.3.0.dev0                dev_0    <develop>
[conda] torch                     1.8.0+cu112              pypi_0    pypi
[conda] torchvision               0.8.2                    pypi_0    pypi

cc @myleott @sshleifer @blefaudeux @min-xu-ai

Labels: FSDP FullyShardedDataParallel (zero-3)
