|
89 | 89 | from ..transformers.model_utils import ( |
90 | 90 | PretrainedModel, |
91 | 91 | _add_variant, |
| 92 | + _load_state_dict_into_model, |
92 | 93 | load_sharded_checkpoint, |
93 | 94 | unwrap_model, |
94 | 95 | ) |
|
149 | 150 | from .utils import reshard as reshard_util |
150 | 151 | from .utils.async_save import AsyncSaver |
151 | 152 | from .utils.helper import ( # nested_truncate, |
| 153 | + broadcast_dataset_rank0_model, |
152 | 154 | broadcast_dp_optimizer, |
153 | 155 | broadcast_moe_optimizer, |
154 | 156 | distributed_concat, |
@@ -1161,6 +1163,10 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg): |
1161 | 1163 | self.state.best_model_checkpoint, |
1162 | 1164 | safe_serialization=True, |
1163 | 1165 | ) |
| 1166 | + if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1: |
| 1167 | + state_dict = broadcast_dataset_rank0_model(self.model.state_dict()) |
| 1168 | + if self.args.dataset_rank > 0: |
| 1169 | + _load_state_dict_into_model(self.model, state_dict, "") |
1164 | 1170 | else: |
1165 | 1171 | weight_name = PADDLE_WEIGHTS_NAME |
1166 | 1172 | best_model_path = os.path.join( |
@@ -1203,6 +1209,10 @@ def _load_best_model_from_peft_checkpoint(self): |
1203 | 1209 | self.state.best_model_checkpoint, |
1204 | 1210 | safe_serialization=True, |
1205 | 1211 | ) |
| 1212 | + if self.args.sharding_parallel_degree > 1 or self.args.data_parallel_degree > 1: |
| 1213 | + state_dict = broadcast_dataset_rank0_model(self.model.get_trainable_state_dict()) |
| 1214 | + if self.args.dataset_rank > 0: |
| 1215 | + _load_state_dict_into_model(self.model, state_dict, "") |
1206 | 1216 | return |
1207 | 1217 |
|
1208 | 1218 | convert_tp = False |
|
0 commit comments