
Commit 299f2c4

Authors: shuyingsunshine21, carmocca, pre-commit-ci[bot], SeanNaren
FSDP with full state dict (#7487)
* Fix some test errors Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
* checkpoint consolidation
* Update ddp_spawn.py
* Update test_metric_result_integration.py
* Update test_results.py
* Update utils.py
* Update utils.py
* Update test_all_gather_grad.py
* Update test_all_gather_grad.py
* Update test_results.py
* Revert "Update test_results.py" (reverts commit 9d4a2b8)
* Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate" (reverts commit c5053da, reversing changes made to 0d23d75)
* Revert "Update test_all_gather_grad.py" (reverts commit 0d23d75)
* Revert "Update utils.py" (reverts commit 70fe5da)
* Revert "Update utils.py" (reverts commit a9aae99)
* Revert "Update test_results.py" (reverts commit ea74906)
* Revert "Update test_metric_result_integration.py" (reverts commit bf70e43)
* Revert "Update ddp_spawn.py" (reverts commit f172101)
* Revert "checkpoint consolidation" (reverts commit 536c132)
* Revert "Revert "checkpoint consolidation"" (reverts commit 3a9fde9)
* Revert "Revert "Revert "checkpoint consolidation""" (reverts commit 7a369f4)
* Revert "Revert "Update ddp_spawn.py"" (reverts commit 8222dc9)
* Revert "Revert "Update test_metric_result_integration.py"" (reverts commit 6c095b2)
* Revert "Revert "Update test_results.py"" (reverts commit 250d0aa)
* Revert "Revert "Update utils.py"" (reverts commit 8651d54)
* Revert "Revert "Update test_all_gather_grad.py"" (reverts commit dcdcd29)
* modify distributed environment to make test pass
* fix version for ddp plugin test
* fix
* fix
* changelog
* Update CHANGELOG.md
* fsdp with full state dict
* fix missing import
* modify unitest
* fix
* fix
* fix typo
* modify test and add changelog
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* limit max_epoch to 1 for testing
* test
* fix
* update
* testing remove special for multi gpu
* assert gpu
* add assertion for gpu
* fix
* Re-enable special test, use ModelCheckpoint
* Fix paths
* Fix path passing
* test
* test
* fix test
* fix
* pre-commit format
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SeanNaren <[email protected]>
1 parent 01109cd commit 299f2c4

File tree: 12 files changed (+522 / -13 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added correct `dataloader_idx` to batch transfer hooks ([#6241](https://github.com/PyTorchLightning/pytorch-lightning/pull/6241))


+- Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))
+
+
 ### Changed

 - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)

pytorch_lightning/plugins/__init__.py

Lines changed: 13 additions & 7 deletions
@@ -6,6 +6,9 @@
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
+    FullyShardedNativeMixedPrecisionPlugin,
+)
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
@@ -15,6 +18,7 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin  # noqa: F401
@@ -32,24 +36,26 @@
     "DDP2Plugin",
     "DDPPlugin",
     "DDPSpawnPlugin",
+    "DDPFullyShardedPlugin",
     "DeepSpeedPlugin",
     "DeepSpeedPrecisionPlugin",
     "DoublePrecisionPlugin",
     "HorovodPlugin",
     "NativeMixedPrecisionPlugin",
     "PrecisionPlugin",
     "ShardedNativeMixedPrecisionPlugin",
+    "FullyShardedNativeMixedPrecisionPlugin",
     "SingleDevicePlugin",
     "SingleTPUPlugin",
     "TPUHalfPrecisionPlugin",
     "TPUSpawnPlugin",
-    'RPCPlugin',
-    'RPCSequentialPlugin',
-    'TrainingTypePlugin',
-    'ParallelPlugin',
-    'Plugin',
-    'DDPShardedPlugin',
-    'DDPSpawnShardedPlugin',
+    "RPCPlugin",
+    "RPCSequentialPlugin",
+    "TrainingTypePlugin",
+    "ParallelPlugin",
+    "Plugin",
+    "DDPShardedPlugin",
+    "DDPSpawnShardedPlugin",
 ]

 from pathlib import Path
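
To illustrate what these re-exports provide: both new classes become importable from the top-level `pytorch_lightning.plugins` namespace, and their inheritance mirrors the existing DDP and sharded plugins. A small illustrative check (not part of this diff):

from pytorch_lightning.plugins import (
    DDPFullyShardedPlugin,
    DDPPlugin,
    FullyShardedNativeMixedPrecisionPlugin,
    ShardedNativeMixedPrecisionPlugin,
)

# The new training-type plugin extends DDPPlugin, and the new precision plugin
# extends the sharded native-AMP plugin, as defined later in this commit.
assert issubclass(DDPFullyShardedPlugin, DDPPlugin)
assert issubclass(FullyShardedNativeMixedPrecisionPlugin, ShardedNativeMixedPrecisionPlugin)
assert FullyShardedNativeMixedPrecisionPlugin.precision == "mixed"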

pytorch_lightning/plugins/precision/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,9 @@
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
+    FullyShardedNativeMixedPrecisionPlugin,
+)
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401

pytorch_lightning/plugins/precision/fully_sharded_native_amp.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Mixed Precision for Full Sharded Training"""

    precision = "mixed"

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.VALUE,
        model: Optional[Module] = None,
    ) -> None:
        clip_val = float(clip_val)
        if clip_val <= 0:
            return
        # see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html
        # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
        # for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val)
        # however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to
        # trace back the root FSDP. Now we only support clip by value.
        assert (
            gradient_clip_algorithm == GradClipAlgorithmType.VALUE
        ), "`gradient_clip_algorithm`: `norm` is currently not supported for `FullyShardedNativeMixedPrecisionPlugin`"
        self.clip_grad_by_value(optimizer, clip_val)
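
Since `clip_gradients` above rejects the `norm` algorithm, users of this precision plugin are limited to clipping by value. As a point of reference, clip-by-value clamps each gradient element into `[-clip_val, clip_val]`; the small standalone sketch below uses `torch.nn.utils.clip_grad_value_` as a stand-in for what `clip_grad_by_value` does in the base precision plugin (an assumption about the base implementation, not part of this diff):

import torch

# One parameter with a hand-crafted gradient, purely for illustration.
param = torch.nn.Parameter(torch.zeros(3))
param.grad = torch.tensor([-2.0, 0.1, 5.0])

clip_val = 0.5
# Clip-by-value clamps each gradient element into [-clip_val, clip_val].
torch.nn.utils.clip_grad_value_([param], clip_val)
assert torch.equal(param.grad, torch.tensor([-0.5, 0.1, 0.5]))

# Clip-by-norm (`torch.nn.utils.clip_grad_norm_`) is the variant the FairScale FSDP
# docs warn about for sharded modules, which is why the plugin asserts on
# GradClipAlgorithmType.VALUE instead of supporting `norm`.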

pytorch_lightning/plugins/training_type/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin  # noqa: F401
+from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin  # noqa: F401
 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin  # noqa: F401

pytorch_lightning/plugins/training_type/fully_sharded.py

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Dict, Generator, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
    from fairscale.nn import default_auto_wrap_policy, enable_wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel


class DDPFullyShardedPlugin(DDPPlugin):

    def __init__(
        self,
        cpu_offload: bool = False,
        flatten_parameters: bool = True,
        reshard_after_forward: bool = True,
        move_grads_to_cpu: Optional[bool] = None,
        fp32_reduce_scatter: Optional[bool] = None,
        compute_dtype: Optional[torch.dtype] = None,
        bucket_cap_mb: int = 25,
        min_num_params: int = 1e8,
        state_dict_to_cpu: bool = True,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: ClusterEnvironment = None,
    ):
        """
        Plugin for Fully Sharded Data Parallel provided by FairScale.

        Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model
        size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
        at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
        to ZeRO-Stage 3 but has been built for upstreaming to PyTorch.
        `For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`.
        .. warning:: ``FullyShardedPlugin`` is in beta and subject to change.

        Defaults have been set and options have been exposed, but may require configuration
        based on your level of memory/speed efficiency. We suggest having a look at this PR for more information.
        `https://github.com/facebookresearch/fairscale/pull/413`

        Many of the helpful doc strings below came from the original FairScale documentation:
        `https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`

        Arguments:
            cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode.
                (Default: False).
            move_grads_to_cpu: Moves gradient shards to CPU after reduction.
                Only disable if using CPU based optimizers
                (Default to ``cpu_offload``).
            flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency
                (Default: True).
            reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
                down training. This is only relevant when resharding individual layers.
                (Default: True).
            fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision
                (Default: None).
            compute_dtype: dtype for full parameters for computation. Default to torch.float32,
                unless using mixed precision, in which case defaults to torch.float16.
                (Default: None).
            bucket_cap_mb: bucket parameters so that gradient reduction
                can potentially overlap with backward computation.
                bucket_cap_mb controls the bucket size in MegaBytes (MB).
                Buckets are sub-divided based on world_size,
                so the max shard size is roughly bucket_cap_mb / world_size.
                Values <= 0 disable bucketing.
                (Default: 25).
            min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``.
                (Default: 1e8)
            state_dict_to_cpu: Whether to return parameters (returned by :func:`state_dict`) on CPU device.
                If ``False``, this will default to ``compute_device``.
                (Default: True).
        """

        super().__init__(
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
        )
        self.cpu_offload = cpu_offload
        self.move_grads_to_cpu = move_grads_to_cpu
        self.flatten_parameters = flatten_parameters
        self.reshard_after_forward = reshard_after_forward
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.compute_dtype = compute_dtype
        self.bucket_cap_mb = bucket_cap_mb
        self.min_num_params = min_num_params
        self.state_dict_device = torch.device("cpu") if state_dict_to_cpu else None
        self._process_group = None

    @property
    def process_group(self):
        if self._process_group is None:
            self._process_group = torch.distributed.new_group()
        return self._process_group

    def setup_distributed(self) -> None:
        if not self.on_gpu:
            raise MisconfigurationException(
                "You selected accelerator to be `ddp_fully_sharded`, but GPU is not available."
            )
        super().setup_distributed()
        torch.cuda.set_device(self.root_device)

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        precision = self.lightning_module.trainer.precision

        def wrap_policy(*args, **kwargs):
            return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            auto_wrap_policy=wrap_policy,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            move_grads_to_cpu=self.move_grads_to_cpu,
            flatten_parameters=self.flatten_parameters,
            mixed_precision=precision == "mixed",
            reshard_after_forward=self.reshard_after_forward,
            fp32_reduce_scatter=self.fp32_reduce_scatter,
            compute_dtype=self.compute_dtype,
            bucket_cap_mb=self.bucket_cap_mb,
            state_dict_device=self.state_dict_device,
        ):
            yield

    def connect(self, model: Module) -> None:
        super().connect(model)
        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
        if not model_call_configure_sharded_model_hook:
            # if model has not called configure sharded model, we reset
            # the training type plugin's call_configure_sharded_model_hook
            # to give trainer a chance to configure.
            self.call_configure_sharded_model_hook = True

    def configure_ddp(self) -> None:
        if not self.cpu_offload:
            # When using CPU Offload, FSDP will manage the CUDA movement for us.
            # Note: this would be problematic for large model (which could not fit in one GPU)
            # as FSDP module.to(device) would first summon all parameters
            # (TODO: need to figure out solution)
            self.model_to_device()

        # setup optimizers after fully sharded has wrapped the lightning module
        self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer)

    def pre_dispatch(self) -> None:
        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)
        self.configure_ddp()
        self.barrier()

    def model_to_device(self) -> None:
        # ensure we update the device type in the lightning module
        self.lightning_module.to(self.root_device)

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        # Currently it is same as default TrainingTypePlugin, i.e. return
        # the full state dict for FSDP, in the future, we will provide sharded
        # state dict.
        return super().lightning_module_state_dict()

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Setup optimizers after the Fully Sharded Model has been made
        return True

    def training_step(self, *args, **kwargs):
        return self.model.training_step(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.model.validation_step(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.model.test_step(*args, **kwargs)

    def predict_step(self, *args, **kwargs):
        return self.model.predict_step(*args, **kwargs)

    def post_training_step(self):
        pass

    @classmethod
    def register_plugins(cls, plugin_registry: Dict):
        plugin_registry.register(
            "fsdp",
            cls,
            description="Fully sharded training with checkpointing the full state dict.",
        )
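
Putting the plugin to work: the intended flow is that `configure_sharded_model` runs inside `model_sharded_context()` above, so FairScale's wrapping helpers see the `enable_wrap` context, and optimizers are then built in `pre_dispatch` after wrapping. The hedged end-to-end sketch below assumes at least two GPUs, FairScale installed, and that the `plugins="fsdp"` alias registered above resolves through the plugin registry; the layer sizes and dataset are illustrative, not taken from this diff.

import torch
from fairscale.nn import auto_wrap  # FairScale helper, effective inside the plugin's `enable_wrap` context
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class ShardedModel(pl.LightningModule):

    def configure_sharded_model(self):
        # Called inside `model_sharded_context()`, so eligible submodules are wrapped
        # in FullyShardedDataParallel according to `min_num_params`.
        self.backbone = auto_wrap(
            torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
        )

    def training_step(self, batch, batch_idx):
        return self.backbone(batch).sum()

    def configure_optimizers(self):
        # Optimizers are (re)built in `pre_dispatch`, after FSDP wrapping, because
        # `setup_optimizers_in_pre_dispatch` returns True above.
        return torch.optim.AdamW(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    trainer = pl.Trainer(
        gpus=2,
        precision=16,    # intended to pair with FullyShardedNativeMixedPrecisionPlugin
        plugins="fsdp",  # assumed to resolve via the `register_plugins` alias added above
        max_epochs=1,
    )
    trainer.fit(ShardedModel(), DataLoader(RandomDataset(), batch_size=8))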
