Skip to content

Commit 145826c

Browse files
authored
add is_distributed_field in sharding reshard (#8875)
1 parent f0f1d9d commit 145826c

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

paddlenlp/trainer/utils/sharding_io.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,8 @@ def _gather_sharding_metas(self):
588588
param_meta = {}
589589
for k, v in model.state_dict().items():
590590
structure_name_mapping[k] = v.name
591-
param_meta[k] = (v.shape, int(v.dtype))
591+
is_distributed = getattr(v, "is_distributed", False)
592+
param_meta[k] = (v.shape, int(v.dtype), is_distributed)
592593

593594
sharding_metas = {}
594595
sharding_meta = {}

0 commit comments

Comments
 (0)