Commit f3607d5

[Trainer] Update set_seed in trainer_utils.py (#7528)
1 parent: ca0e45b

2 files changed: +65 / -20 lines

llm/gpt-3/run_pretrain.py

Lines changed: 2 additions & 3 deletions
@@ -349,10 +349,9 @@ def main():
     if data_args.data_cache is not None:
         os.makedirs(data_args.data_cache, exist_ok=True)
 
-    set_seed(seed=training_args.seed, args=training_args)
     paddle.set_device(training_args.device)
-    if paddle.distributed.get_world_size() > 1:
-        paddle.distributed.init_parallel_env()
+
+    set_seed(seed=training_args.seed)
 
     training_args.eval_iters = 10
     training_args.test_iters = training_args.eval_iters * 10
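Note on the call-site change: run_pretrain.py now passes only the seed, so the updated set_seed (see the trainer_utils.py diff below) takes its args-is-None branch and simply seeds Python's random, NumPy, and Paddle; the topology-aware path runs only when a training-arguments object is supplied. A minimal sketch of the two call forms (the device string and seed value are illustrative):

import paddle
from paddlenlp.trainer.trainer_utils import set_seed

paddle.set_device("gpu")  # select the device first, as run_pretrain.py does

# Plain single-stream seeding -- the form the updated run_pretrain.py uses:
set_seed(seed=1234)

# Topology-aware seeding: taken only when an args object is passed, in which
# case "global_seed"/"local_seed" dropout streams are registered per rank.
# set_seed(seed=1234, args=training_args)  # training_args: a TrainingArguments instance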

paddlenlp/trainer/trainer_utils.py

Lines changed: 63 additions & 17 deletions
@@ -34,6 +34,8 @@
 
 import numpy as np
 import paddle
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.io import IterableDataset
 from paddle.optimizer.lr import LambdaDecay
 
@@ -56,32 +58,76 @@
 
 
 def set_seed(seed: int = 1234, args=None):
+    # NOTE(shenliang03): For the parameter-initialization seed:
+    #   seed: the same across dp / mp_undistributed_parameter / sharding; different elsewhere.
+    # For the compute (dropout) seed:
+    #   global seed: the same only within an mp group.
+    #   local seed: different across all groups.
     if args is None:
         random.seed(seed)
         np.random.seed(seed)
         paddle.seed(seed)
 
-    if args is not None:
-        if args.use_hybrid_parallel:
-            from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
+    else:
+        hcg = fleet.get_hybrid_communicate_group() if hasattr(fleet.fleet, "_hcg") else None
+        if hcg is not None and paddle.distributed.get_world_size() > 1:
+            # obtain the rank information of the hybrid-parallel topology
 
-            random.seed(args.seed + args.dataset_rank)
-            np.random.seed(args.seed + args.dataset_rank)
-            paddle.seed(args.seed + args.dataset_rank)
+            mp_rank = hcg.get_model_parallel_rank()
+            mp_size = hcg.get_model_parallel_world_size()
 
-            # local_seed/ global_seed is used to control dropout in ModelParallel
-            local_seed = args.seed + 59999 + args.tensor_parallel_rank * 10 + args.pipeline_parallel_rank * 1000
-            global_seed = args.seed + 100003 + args.dataset_rank
-            tracker = get_rng_state_tracker()
+            pp_rank = hcg.get_stage_id()
+            pp_size = hcg.get_pipe_parallel_world_size()
 
-            if "global_seed" not in tracker.states_:
-                tracker.add("global_seed", global_seed)
-            if "local_seed" not in tracker.states_:
-                tracker.add("local_seed", local_seed)
+            dp_rank = hcg.get_data_parallel_rank()
+            dp_size = hcg.get_data_parallel_world_size()
+
+            sharding_rank = hcg.get_sharding_parallel_rank()
+            # sharding_size = hcg.get_sharding_parallel_world_size()
         else:
-            random.seed(args.seed)
-            np.random.seed(args.seed)
-            paddle.seed(args.seed)
+            mp_rank, mp_size = 0, 1
+            pp_rank, pp_size = 0, 1
+            dp_rank, dp_size = 0, 1
+            sharding_rank, _ = 0, 1
+
+        # NOTE: the commented seeds are set only for precision validation
+        # seed += 100 * pp_rank
+        random.seed(seed + 100 * pp_rank)
+        np.random.seed(seed + 100 * pp_rank)
+
+        # seed = mp_rank +
+        #        pp_rank * (mp_size) +
+        #        dp_rank * (mp_size * pp_size) +
+        #        sharding_rank * (mp_size * pp_size * dp_size)
+        # the seed offset is meant to avoid conflicts with the parameter-initialization seed
+
+        seed_offset = seed + 1024 + paddle.distributed.get_world_size()
+        global_seed = (
+            seed_offset
+            + pp_rank * (mp_size)
+            + dp_rank * (mp_size * pp_size)
+            + sharding_rank * (mp_size * pp_size * dp_size)
+        )
+
+        seed_offset += paddle.distributed.get_world_size()
+        local_seed = (
+            seed_offset
+            + mp_rank
+            + pp_rank * (mp_size)
+            + dp_rank * (mp_size * pp_size)
+            + sharding_rank * (mp_size * pp_size * dp_size)
+        )
+
+        tracker = get_rng_state_tracker()
+        if "global_seed" not in tracker.states_:
+            tracker.add("global_seed", global_seed)
+
+        if "local_seed" not in tracker.states_:
+            tracker.add("local_seed", local_seed)
+
+        paddle.seed(global_seed)
+
+        logger.info("The global seed is set to {} and local seed is set to {}.".format(global_seed, local_seed))
 
 
 class ExplicitEnum(Enum):
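For context on how the two states registered via tracker.add are consumed: Paddle's RNG state tracker exposes each named stream as a context manager, so model code can draw dropout masks from whichever stream has the right sharing behavior. A hedged sketch, assuming the rng_state(name) context-manager API of Paddle's meta-parallel RNG tracker (the helper functions are hypothetical):

import paddle
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

def dropout_shared_across_mp(x, p=0.1):
    # "global_seed" is identical within a model-parallel group, so the mask
    # drawn here agrees across the tensor-parallel shards of one layer.
    with get_rng_state_tracker().rng_state("global_seed"):
        return paddle.nn.functional.dropout(x, p=p)

def dropout_private_per_rank(x, p=0.1):
    # "local_seed" differs on every rank, so each rank draws an
    # independent dropout mask.
    with get_rng_state_tracker().rng_state("local_seed"):
        return paddle.nn.functional.dropout(x, p=p)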

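To make the seed formulas concrete, the following standalone sketch (plain Python, no Paddle required) replays the arithmetic above for a hypothetical 16-rank topology with mp = pp = dp = sharding = 2, and checks the invariants the NOTE describes: local_seed is distinct on every rank, while global_seed is shared exactly within each model-parallel group.

from itertools import product

seed = 1234
mp_size, pp_size, dp_size, sharding_size = 2, 2, 2, 2
world_size = mp_size * pp_size * dp_size * sharding_size  # stands in for paddle.distributed.get_world_size()

local_seeds, global_seeds = {}, {}
for sharding_rank, dp_rank, pp_rank, mp_rank in product(
    range(sharding_size), range(dp_size), range(pp_size), range(mp_size)
):
    seed_offset = seed + 1024 + world_size
    global_seed = (
        seed_offset
        + pp_rank * (mp_size)
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )
    seed_offset += world_size
    local_seed = (
        seed_offset
        + mp_rank
        + pp_rank * (mp_size)
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )
    local_seeds[(mp_rank, pp_rank, dp_rank, sharding_rank)] = local_seed
    global_seeds[(mp_rank, pp_rank, dp_rank, sharding_rank)] = global_seed

# Every rank draws a distinct local seed, so "local_seed" dropout differs everywhere ...
assert len(set(local_seeds.values())) == world_size
# ... while ranks that differ only in mp_rank share one global seed.
for pp_rank, dp_rank, sharding_rank in product(range(pp_size), range(dp_size), range(sharding_size)):
    group = {global_seeds[(mp, pp_rank, dp_rank, sharding_rank)] for mp in range(mp_size)}
    assert len(group) == 1
print("seed invariants hold for a {}-rank topology".format(world_size))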