File tree Expand file tree Collapse file tree 1 file changed +14
-4
lines changed
python/paddle/distributed/auto_parallel Expand file tree Collapse file tree 1 file changed +14
-4
lines changed Original file line number Diff line number Diff line change @@ -1140,10 +1140,20 @@ def _set_and_check_sharding_prop_from_param(self):
11401140 placements [self ._sharding_axis ], dist .Replicate
11411141 ), "The placement on sharding_axis should be Replicate"
11421142
1143- # check the sharding degree since it has already been set
1144- assert (
1145- mesh .dim_size (self ._sharding_axis ) == self ._sharding_degree
1146- ), "The sharding degree of all parameters must be equal currently."
1143+ # check the sharding degree since it has already been set,
1144+ # skip check when mesh is true subset of global_mesh
1145+ if global_mesh :
1146+ if set (mesh .process_ids ) < set (global_mesh .process_ids ):
1147+ continue
1148+ elif self ._shard_fn ._mesh :
1149+ if set (mesh .process_ids ) < set (
1150+ self ._shard_fn ._mesh .process_ids
1151+ ):
1152+ continue
1153+ else :
1154+ assert (
1155+ mesh .dim_size (self ._sharding_axis ) == self ._sharding_degree
1156+ ), "The sharding degree of all parameters must be equal currently."
11471157
11481158 def _shard_accumulator (self , param ):
11491159 target_name = param .name
You can’t perform that action at this time.
0 commit comments