Commit 5c2ded5

[AutoParallel] support sharding tensor-fusion save&load (#69823)
1 parent 9499924 commit 5c2ded5
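
A rough end-to-end sketch of the save/load round trip this commit targets: a DistModel built by dist.to_static with sharding stage 1 and tensor-fusion buffers enabled. Only DistModel.state_dict / set_state_dict are touched by this change; the Strategy field names and the dist.save_state_dict / dist.load_state_dict checkpoint helpers shown below come from the surrounding auto-parallel API and should be treated as assumptions, not part of this commit.

import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, Dataset

# Toy dataset so the sketch is self-contained; run under a multi-card
# launcher, e.g. python -m paddle.distributed.launch --devices=0,1 demo.py
class RandomDataset(Dataset):
    def __getitem__(self, idx):
        return paddle.rand([64]), paddle.rand([1])

    def __len__(self):
        return 128

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
layer = paddle.nn.Linear(64, 1)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())
loader = dist.shard_dataloader(
    DataLoader(RandomDataset(), batch_size=8), meshes=[mesh]
)

strategy = dist.Strategy()
strategy.sharding.enable = True              # field names assumed, see constants.py
strategy.sharding.stage = 1
strategy.sharding.comm_buffer_size_MB = 256  # tensor-fusion buffer size

dist_model = dist.to_static(layer, loader, paddle.nn.MSELoss(), opt, strategy)
dist_model.train()
# ... run a few steps: loss = dist_model(inputs, labels) ...

# Save: state_dict() now converts the fused, unevenly sharded optimizer
# states back into per-parameter tensors before checkpointing.
state = dist_model.state_dict()
dist.save_state_dict(state, "./ckpt")

# Load: load into a template state dict, then set_state_dict() converts the
# restored states back to the fused layout used by ShardingOptimizerStage1.
template = dist_model.state_dict()
dist.load_state_dict(template, "./ckpt")
dist_model.set_state_dict(template)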

File tree

7 files changed: +920 -43 lines changed

python/paddle/distributed/auto_parallel/api.py

Lines changed: 38 additions & 28 deletions
@@ -79,7 +79,7 @@
     to_placements,
 )
 from .random import determinate_rng, rng_state
-from .sharding import ShardingOptimizerStage1
+from .sharding import ShardingOptimizerStage1, get_placement_with_sharding
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Sequence
@@ -992,30 +992,6 @@ def replicate_layer_params_and_buffers(
             )
 
 
-def get_placement_with_sharding(param, sharding_axis):
-    shard_axis = -1
-    for placement in param.placements:
-        if isinstance(placement, dist.Shard):
-            # the parameter can't be shard twice with sharding on different mesh now
-            # for example, [Shard(0), Shard(1)], assert here in case
-            assert (
-                shard_axis == -1
-            ), "The parameter can't be shard twice with sharding strategy even in different mesh now."
-            shard_axis = placement.get_dim()
-
-    placement_with_sharding = None
-    for dim in range(param.ndim):
-        if dim != shard_axis:
-            placement_with_sharding = dist.Shard(dim)
-            break
-
-    new_placements = param.placements
-    if placement_with_sharding is not None:
-        new_placements[sharding_axis] = placement_with_sharding
-
-    return new_placements
-
-
 class _ShardOptimizer(Optimizer):
     def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         assert (
@@ -2548,6 +2524,7 @@ def state_dict(
         self,
         mode: Literal['opt', 'param', 'all'] = "all",
         split_fusion: bool = True,
+        load_sharded_model: bool = True,
     ) -> dict[str, Tensor]:
         """
         Get the state dict of model and optimizer.
@@ -2559,7 +2536,6 @@ def state_dict(
                 'all' : The return value contains the variable in the network and optimizer.
                 Default: 'all'
         """
-
         if use_pir_api():
             scope = paddle.static.global_scope()
             local_state_dict = self.dist_main_program(
@@ -2626,6 +2602,21 @@ def state_dict
                 ] = dist_tensor
                 dist_state_dict.pop(param)
 
+        # when tensor-fusion is enabled, the optimizer parameters are unbalanced
+        # in their sharding. We need to process the optimizer parameters to make
+        # them evenly balanced
+        if self._engine._optimizer is not None and load_sharded_model:
+            optimizer = self._engine._optimizer
+            if isinstance(
+                optimizer,
+                paddle.static.amp.decorator.OptimizerWithMixedPrecision,
+            ):
+                optimizer = optimizer._optimizer
+            if isinstance(optimizer, ShardingOptimizerStage1):
+                optimizer.convert_state_dict_without_tensor_fusion_param(
+                    dist_state_dict
+                )
+
         mapping_names = [
             (
                 self._parameter_to_structured_name[k]
@@ -2669,7 +2660,8 @@ def build_distributed_tensor(local_tensor, dist_attr):
             mesh = ProcessMesh(
                 np.array(dist_attr["process_group"]).reshape(
                     dist_attr["process_shape"]
-                )
+                ),
+                dim_names=dist_attr["dim_names"],
             )
             placements = to_placements(dist_attr["dims_mapping"], mesh)
             dist_tensor = dtensor_from_local(local_tensor, mesh, placements)
@@ -2693,7 +2685,25 @@ def build_distributed_tensor(local_tensor, dist_attr):
     def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
         local_state_dict = {}
         dist_main_program = self.dist_main_program(mode=self._engine._mode)
-        cur_state_dict = self.state_dict(split_fusion=False)
+        cur_state_dict = self.state_dict(
+            split_fusion=False, load_sharded_model=False
+        )
+
+        # For sharding with tensor-fusion, we need to convert the state_dict
+        # to include tensor-fusion parameters before calling set_state_dict,
+        # as stored parameters are processed as if tensor-fusion is not applied
+        if self._engine._optimizer is not None:
+            optimizer = self._engine._optimizer
+            if isinstance(
+                optimizer,
+                paddle.static.amp.decorator.OptimizerWithMixedPrecision,
+            ):
+                optimizer = optimizer._optimizer
+            if isinstance(optimizer, ShardingOptimizerStage1):
+                optimizer.convert_state_dict_with_tensor_fusion_param(
+                    state_dict
+                )
+
         for k, v in state_dict.items():
             assert v.is_dist(), f"key {k} value:{v} is not a dist tensor."
             if k in cur_state_dict:
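
The get_placement_with_sharding helper removed above is moved rather than deleted: the import hunk at the top now pulls it from the .sharding module. To make what it computes concrete, the snippet below replays its core logic on a hypothetical stand-in parameter (_FakeParam and the concrete placements are invented for illustration): a parameter already shard on tensor dim 0 by model parallelism gets shard along its next free dim on the mesh axis owned by sharding.

import paddle.distributed as dist

class _FakeParam:  # stand-in for a dist parameter's placements/ndim
    ndim = 2
    placements = [dist.Shard(0), dist.Replicate()]

param = _FakeParam()
sharding_axis = 1  # mesh dimension owned by sharding

# find the tensor dim already shard by other parallel strategies
shard_axis = -1
for placement in param.placements:
    if isinstance(placement, dist.Shard):
        shard_axis = placement.get_dim()

# pick the first tensor dim not already shard and place it on the sharding axis
placement_with_sharding = None
for dim in range(param.ndim):
    if dim != shard_axis:
        placement_with_sharding = dist.Shard(dim)
        break

new_placements = param.placements
if placement_with_sharding is not None:
    new_placements[sharding_axis] = placement_with_sharding

print(new_placements)  # dim 0 stays model-parallel, dim 1 is now shard by sharding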

python/paddle/distributed/auto_parallel/constants.py

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ class _AMPConfig(TypedDict, total=False): # noqa: PYI049
 set_field_default_config(SHARDING, "enable_tuning", False)
 set_field_default_config(SHARDING, "tuning_range", [])
 set_field_default_config(SHARDING, "release_gradients", False)
-set_field_default_config(SHARDING, "comm_buffer_size_MB", -1)
+set_field_default_config(SHARDING, "comm_buffer_size_MB", 256)
 
 if TYPE_CHECKING:

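The change above flips the default sharding tensor-fusion communication buffer from -1 to 256 MB. A minimal sketch of overriding it through the sharding strategy, assuming paddle.distributed.Strategy exposes the SHARDING fields defined in this constants.py:

import paddle.distributed as dist

strategy = dist.Strategy()
strategy.sharding.enable = True   # field names assumed from constants.py
strategy.sharding.stage = 1
strategy.sharding.degree = 8
# comm_buffer_size_MB now defaults to 256; set it explicitly to change the
# fusion buffer granularity
strategy.sharding.comm_buffer_size_MB = 512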