     to_placements,
 )
 from .random import determinate_rng, rng_state
-from .sharding import ShardingOptimizerStage1
+from .sharding import ShardingOptimizerStage1, get_placement_with_sharding
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Sequence
@@ -992,30 +992,6 @@ def replicate_layer_params_and_buffers(
     )
 
 
-def get_placement_with_sharding(param, sharding_axis):
-    shard_axis = -1
-    for placement in param.placements:
-        if isinstance(placement, dist.Shard):
-            # the parameter can't be shard twice with sharding on different mesh now
-            # for example, [Shard(0), Shard(1)], assert here in case
-            assert (
-                shard_axis == -1
-            ), "The parameter can't be shard twice with sharding strategy even in different mesh now."
-            shard_axis = placement.get_dim()
-
-    placement_with_sharding = None
-    for dim in range(param.ndim):
-        if dim != shard_axis:
-            placement_with_sharding = dist.Shard(dim)
-            break
-
-    new_placements = param.placements
-    if placement_with_sharding is not None:
-        new_placements[sharding_axis] = placement_with_sharding
-
-    return new_placements
-
-
 class _ShardOptimizer(Optimizer):
     def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         assert (
@@ -2548,6 +2524,7 @@ def state_dict(
         self,
         mode: Literal['opt', 'param', 'all'] = "all",
         split_fusion: bool = True,
+        load_sharded_model: bool = True,
     ) -> dict[str, Tensor]:
         """
         Get the state dict of model and optimizer.
@@ -2559,7 +2536,6 @@ def state_dict(
                 'all' : The return value contains the variable in the network and optimizer.
                 Default: 'all'
         """
-
         if use_pir_api():
             scope = paddle.static.global_scope()
             local_state_dict = self.dist_main_program(
@@ -2626,6 +2602,21 @@ def state_dict(
                 ] = dist_tensor
                 dist_state_dict.pop(param)
 
+        # when tensor-fusion is enabled, the optimizer parameters are unbalanced
+        # in their sharding. We need to process the optimizer parameters to make
+        # them evenly balanced
+        if self._engine._optimizer is not None and load_sharded_model:
+            optimizer = self._engine._optimizer
+            if isinstance(
+                optimizer,
+                paddle.static.amp.decorator.OptimizerWithMixedPrecision,
+            ):
+                optimizer = optimizer._optimizer
+            if isinstance(optimizer, ShardingOptimizerStage1):
+                optimizer.convert_state_dict_without_tensor_fusion_param(
+                    dist_state_dict
+                )
+
         mapping_names = [
             (
                 self._parameter_to_structured_name[k]
@@ -2669,7 +2660,8 @@ def build_distributed_tensor(local_tensor, dist_attr):
             mesh = ProcessMesh(
                 np.array(dist_attr["process_group"]).reshape(
                     dist_attr["process_shape"]
-                )
+                ),
+                dim_names=dist_attr["dim_names"],
             )
             placements = to_placements(dist_attr["dims_mapping"], mesh)
             dist_tensor = dtensor_from_local(local_tensor, mesh, placements)
@@ -2693,7 +2685,25 @@ def build_distributed_tensor(local_tensor, dist_attr):
     def set_state_dict(self, state_dict: dict[str, Tensor]) -> None:
         local_state_dict = {}
         dist_main_program = self.dist_main_program(mode=self._engine._mode)
-        cur_state_dict = self.state_dict(split_fusion=False)
+        cur_state_dict = self.state_dict(
+            split_fusion=False, load_sharded_model=False
+        )
+
+        # For sharding with tensor-fusion, we need to convert the state_dict
+        # to include tensor-fusion parameters before calling set_state_dict,
+        # as stored parameters are processed as if tensor-fusion is not applied
+        if self._engine._optimizer is not None:
+            optimizer = self._engine._optimizer
+            if isinstance(
+                optimizer,
+                paddle.static.amp.decorator.OptimizerWithMixedPrecision,
+            ):
+                optimizer = optimizer._optimizer
+            if isinstance(optimizer, ShardingOptimizerStage1):
+                optimizer.convert_state_dict_with_tensor_fusion_param(
+                    state_dict
+                )
+
         for k, v in state_dict.items():
             assert v.is_dist(), f"key {k} value:{v} is not a dist tensor."
             if k in cur_state_dict:
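
For reference, below is a minimal, framework-free sketch of the placement-selection logic implemented by get_placement_with_sharding (removed above and now imported from .sharding). The Shard/Replicate classes and the pick_sharding_placement name are simplified stand-ins for illustration only, not the paddle.distributed placement types or the helper's real signature.

from dataclasses import dataclass


@dataclass
class Shard:
    dim: int

    def get_dim(self):
        return self.dim


class Replicate:
    pass


def pick_sharding_placement(placements, ndim, sharding_axis):
    # Find the tensor dim that is already sharded; the helper assumes at most one.
    shard_axis = -1
    for p in placements:
        if isinstance(p, Shard):
            assert shard_axis == -1, "parameter already sharded on more than one dim"
            shard_axis = p.get_dim()
    # Put the sharding mesh axis on the first tensor dim that is not already sharded.
    new_placements = list(placements)
    for dim in range(ndim):
        if dim != shard_axis:
            new_placements[sharding_axis] = Shard(dim)
            break
    return new_placements


# A 2-D parameter already sharded on dim 0 by mesh axis 0: the sharding mesh axis
# (axis 1) is placed on dim 1 instead of double-sharding dim 0.
print(pick_sharding_placement([Shard(0), Replicate()], ndim=2, sharding_axis=1))
# -> [Shard(dim=0), Shard(dim=1)]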