1 file changed: +8, −0 lines

@@ -1341,13 +1341,21 @@ def is_segment_parallel_supported():
1341 1341                strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(
1342 1342                    self.sharding_comm_buffer_size_MB
1343 1343                )
     1344 +              # The `comm_buffer_size_MB` is added directly to sharding properties
     1345 +              # for semi-auto mode, avoiding potential confusion with strategy config,
     1346 +              # as parameters in semi-auto mode are managed via strategy.
     1347 +              strategy.sharding.comm_buffer_size_MB = int(self.sharding_comm_buffer_size_MB)
1344 1348
1345 1349            if "split_param" in sharding_parallel_config:
1346 1350                strategy.hybrid_configs["sharding_configs"].split_param = True
1347 1351                assert self.amp_master_grad, "Currently sharding stage1 v2 only support amp_master_grad"
1348 1352
1349 1353            if "enable_release_grads" in sharding_parallel_config:
1350 1354                strategy.hybrid_configs["sharding_configs"].release_gradients = True
     1355 +              # `release_gradients` is set directly in sharding properties for the same
     1356 +              # reason as `comm_buffer_size_MB`, to avoid confusion with centralized
     1357 +              # strategy management in semi-auto mode.
     1358 +              strategy.sharding.release_gradients = True
1351 1359
1352 1360        if self.pipeline_parallel_degree == 1:
1353 1361            strategy.hybrid_configs["sharding_configs"].tensor_fusion = (
0 commit comments