
Commit d28c5ac

Commit message: update
Parent: e52eba7

2 files changed: 161 additions & 43 deletions

llm/gpt-3/run_pretrain.py

Lines changed: 14 additions & 2 deletions
@@ -31,6 +31,7 @@
     Trainer,
     TrainingArguments,
     get_last_checkpoint,
+    init_dist_env,
     speed_metrics,
 )
 from paddlenlp.transformers import (
@@ -325,10 +326,21 @@ def main():
     if model_args.tokenizer_name_or_path is None:
         model_args.tokenizer_name_or_path = model_args.model_name_or_path

-    set_seed(training_args)
+    if data_args.data_cache is not None:
+        os.makedirs(data_args.data_cache, exist_ok=True)
+
     paddle.set_device(training_args.device)
+
     if paddle.distributed.get_world_size() > 1:
-        paddle.distributed.init_parallel_env()
+        init_dist_env(
+            training_args.tensor_parallel_degree,
+            training_args.sharding_parallel_degree,
+            training_args.pipeline_parallel_degree,
+            training_args.data_parallel_degree,
+            training_args.seed,
+        )
+
+    set_seed(seed=training_args.seed)

     training_args.eval_iters = 10
     training_args.test_iters = training_args.eval_iters * 10
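
Put together, the revised startup path in main() now builds the hybrid-parallel groups before seeding. The minimal sketch below restates that call order; it assumes the same training_args/data_args objects used above, and the import path paddlenlp.trainer.trainer_utils is one plausible choice (the diff only shows the names being added to an existing import block).

    import os

    import paddle
    from paddlenlp.trainer.trainer_utils import init_dist_env, set_seed

    def setup_runtime(training_args, data_args):
        # Hypothetical helper mirroring the revised main(): cache dir, device,
        # hybrid-parallel groups, then RNG seeding, in that order.
        if data_args.data_cache is not None:
            os.makedirs(data_args.data_cache, exist_ok=True)

        paddle.set_device(training_args.device)

        if paddle.distributed.get_world_size() > 1:
            # Build and store the fleet HybridCommunicateGroup first, so later
            # calls (e.g. set_seed(args=...)) can read it back via get_hcg().
            init_dist_env(
                training_args.tensor_parallel_degree,
                training_args.sharding_parallel_degree,
                training_args.pipeline_parallel_degree,
                training_args.data_parallel_degree,
                training_args.seed,
            )

        # set_seed now receives the integer seed instead of the TrainingArguments object.
        set_seed(seed=training_args.seed)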

paddlenlp/trainer/trainer_utils.py

Lines changed: 147 additions & 41 deletions
@@ -34,6 +34,9 @@

 import numpy as np
 import paddle
+import paddle.distributed as dist
+from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.io import IterableDataset
 from paddle.optimizer.lr import LambdaDecay

@@ -52,42 +55,160 @@
     "get_last_checkpoint",
     "get_scheduler",
     "set_hyrbid_parallel_seed",
+    "init_dist_env",
 ]


+_hcg = None
+
+
+def set_hcg(hcg):
+    global _hcg
+    _hcg = hcg
+
+
+def get_hcg():
+    global _hcg
+    return _hcg
+
+
 def set_seed(seed: int = 1234, args=None):
+    # NOTE(shenliang03): For parameter init seed:
+    # seed: dp/mp_undistributed_paramter/sharding is same; others is different
+    # For compute seed(dropout):
+    # global seed: only mp group is same.
+    # local seed: all groups are different
     if args is None:
         random.seed(seed)
         np.random.seed(seed)
         paddle.seed(seed)
+    else:
+        hcg = get_hcg()
+        if paddle.distributed.get_world_size() > 1:
+            # obtain rank message of hybrid parallel
+            if hcg is None:
+                assert False
+
+            mp_rank = hcg.get_model_parallel_rank()
+            mp_size = hcg.get_model_parallel_world_size()
+
+            pp_rank = hcg.get_stage_id()
+            pp_size = hcg.get_pipe_parallel_world_size()
+
+            dp_rank = hcg.get_data_parallel_rank()
+            dp_size = hcg.get_data_parallel_world_size()
+
+            sharding_rank = hcg.get_sharding_parallel_rank()
+            # sharding_size = hcg.get_sharding_parallel_world_size()
+        else:
+            mp_rank, mp_size = 0, 1
+            pp_rank, pp_size = 0, 1
+            dp_rank, dp_size = 0, 1
+            sharding_rank, _ = 0, 1
+
+        # NOTE: the commented seeds are set only for precision validation
+        # seed += 100 * pp_rank
+        random.seed(seed + 100 * pp_rank)
+        np.random.seed(seed + 100 * pp_rank)
+
+        # seed = mp_rank +
+        #        pp_rank * (mp_size) +
+        #        dp_rank * (mp_size * pp_size) +
+        #        sharding_rank * (mp_size * pp_size * dp_size)
+        # seed offset is order to avoid conflicts with the parameter initialization seed
+
+        seed_offset = seed + 1024 + paddle.distributed.get_world_size()
+        global_seed = (
+            seed_offset
+            + pp_rank * (mp_size)
+            + dp_rank * (mp_size * pp_size)
+            + sharding_rank * (mp_size * pp_size * dp_size)
+        )
+
+        seed_offset += paddle.distributed.get_world_size()
+        local_seed = (
+            seed_offset
+            + mp_rank
+            + pp_rank * (mp_size)
+            + dp_rank * (mp_size * pp_size)
+            + sharding_rank * (mp_size * pp_size * dp_size)
+        )
+
+        tracker = get_rng_state_tracker()
+        tracker.add("global_seed", global_seed)
+        tracker.add("local_seed", local_seed)
+
+        paddle.seed(global_seed)
+
+        logger.info("The global seed is set to {} and local seed is set to {}.".format(global_seed, local_seed))
+
+
+def create_hcg(strategy, hcg_name="HybridCommunicateGroup"):
+    if hcg_name == "HybridCommunicateGroup":
+        fleet.init(is_collective=True, strategy=strategy)
+        hcg = fleet.get_hybrid_communicate_group()
+    else:
+        dist.init_parallel_env()
+        hcg = eval("{}".format(hcg_name))(strategy)
+    print("asdfasdf hcg", hcg)
+    return hcg

-    if args is not None:
-        if args.use_hybrid_parallel:
-            from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
-
-            random.seed(args.seed + args.dataset_rank)
-            np.random.seed(args.seed + args.dataset_rank)
-            paddle.seed(args.seed + args.dataset_rank)
-
-            # local_seed/ global_seed is used to control dropout in ModelParallel
-            local_seed = args.seed + 59999 + args.tensor_parallel_rank * 10 + args.pipeline_parallel_rank * 1000
-            global_seed = (
-                args.seed
-                + 100003
-                + args.dataset_rank
-                + args.tensor_parallel_rank * 10
-                + args.pipeline_parallel_rank * 1000
-            )
-            tracker = get_rng_state_tracker()

-            if "global_seed" not in tracker.states_:
-                tracker.add("global_seed", global_seed)
-            if "local_seed" not in tracker.states_:
-                tracker.add("local_seed", local_seed)
+def init_dist_env(
+    tensor_parallel_degree=1, sharding_parallel_degree=1, pipeline_parallel_degree=1, data_parallel_degree=1, seed=1
+):
+
+    strategy = fleet.DistributedStrategy()
+
+    def is_segment_parallel_supported():
+        import inspect
+
+        members = [name for (name, date) in inspect.getmembers(fleet.HybridCommunicateGroup)]
+        return "get_sep_parallel_world_size" in members
+
+    if tensor_parallel_degree == 1 and sharding_parallel_degree == 1:
+        if is_segment_parallel_supported():
+            order = ["pp", "dp", "sharding", "sep", "mp"]
+        else:
+            order = ["pp", "dp", "sharding", "mp"]
+    else:
+        if is_segment_parallel_supported():
+            order = ["dp", "pp", "sharding", "sep", "mp"]
         else:
-            random.seed(args.seed)
-            np.random.seed(args.seed)
-            paddle.seed(args.seed)
+            order = ["dp", "pp", "sharding", "mp"]
+
+    strategy.hybrid_configs = {
+        "dp_degree": data_parallel_degree,
+        "mp_degree": tensor_parallel_degree,
+        "pp_degree": pipeline_parallel_degree,
+        "sharding_degree": sharding_parallel_degree,
+        "order": order,
+    }
+
+    # TODO(wawltor) The inference parallel do not support the pipeline mode
+
+    """
+    if pipeline_parallel_degree > 1:
+        if "sequence_parallel" in config.Model:
+            if config.Model.sequence_parallel:
+                assert config.Global.enable_partial_send_recv is False, (
+                    "if config.Distributed.pp_degree > 1 and config.Model.sequence_parallel is True, "
+                    "config.Global.enable_partial_send_recv should be set False."
+                )
+
+    strategy.pipeline_configs = {
+        "accumulate_steps": config.Global.local_batch_size // config.Global.micro_batch_size,
+        "micro_batch_size": config.Global.micro_batch_size,
+        "enable_partial_send_recv": config.Global.enable_partial_send_recv,
+    }
+    """
+
+    # set control in tensor parallel
+    print("init_dist_env asdfasdfasdf niuliling")
+    strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
+
+    hcg = create_hcg(strategy)
+    set_hcg(hcg)


 class ExplicitEnum(Enum):
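
To make the seeding scheme above concrete, the short sketch below replays the arithmetic from set_seed() outside of Paddle for a hypothetical 8-GPU topology (mp=2, pp=2, dp=2, no sharding, seed=1234). The topology and rank values are made up for illustration; only the formulas come from the diff.

    # Standalone replay of the seed arithmetic in set_seed(); no Paddle required.
    def derive_seeds(seed, mp_rank, mp_size, pp_rank, pp_size, dp_rank, dp_size, sharding_rank, world_size):
        seed_offset = seed + 1024 + world_size
        global_seed = (
            seed_offset
            + pp_rank * mp_size
            + dp_rank * (mp_size * pp_size)
            + sharding_rank * (mp_size * pp_size * dp_size)
        )
        seed_offset += world_size
        local_seed = (
            seed_offset
            + mp_rank
            + pp_rank * mp_size
            + dp_rank * (mp_size * pp_size)
            + sharding_rank * (mp_size * pp_size * dp_size)
        )
        return global_seed, local_seed

    # Two ranks that differ only in mp_rank (same pp/dp/sharding coordinates)
    # share global_seed but get distinct local_seed, matching the NOTE in
    # set_seed(): the global RNG state agrees only within an mp group, while
    # the local state differs everywhere.
    for mp_rank in (0, 1):
        g, l = derive_seeds(
            seed=1234, mp_rank=mp_rank, mp_size=2, pp_rank=1, pp_size=2,
            dp_rank=0, dp_size=2, sharding_rank=0, world_size=8,
        )
        print(mp_rank, g, l)  # -> 0 2268 2276 and 1 2268 2277
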
@@ -940,19 +1061,4 @@ def __call__(self, features: List[dict]):


 def set_hyrbid_parallel_seed(basic_seed, dataset_rank, tp_rank, pp_rank=0):
-    from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
-
-    random.seed(basic_seed + dataset_rank)
-    np.random.seed(basic_seed + dataset_rank)
-    paddle.seed(basic_seed + dataset_rank)
-
-    # local_seed/ global_seed is used to control dropout in ModelParallel
-    local_seed = basic_seed + 59999 + tp_rank * 10 + pp_rank * 1000
-    global_seed = basic_seed + 100003 + dataset_rank + tp_rank * 10 + pp_rank * 1000
-
-    tracker = get_rng_state_tracker()
-
-    if "global_seed" not in tracker.states_:
-        tracker.add("global_seed", global_seed)
-    if "local_seed" not in tracker.states_:
-        tracker.add("local_seed", local_seed)
+    set_seed(basic_seed)
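
With set_hyrbid_parallel_seed reduced to a thin wrapper over set_seed, the new entry points compose roughly as in the sketch below. This is a hypothetical standalone snippet, not code from the commit: it assumes the script is launched with paddle.distributed.launch across 8 GPUs, and it only calls functions and hcg accessors that appear in the diff above.

    from paddlenlp.trainer.trainer_utils import get_hcg, init_dist_env, set_seed

    # Hypothetical degrees for an 8-GPU run: 2-way tensor, 2-way pipeline,
    # 2-way data parallelism, no sharding.
    init_dist_env(
        tensor_parallel_degree=2,
        sharding_parallel_degree=1,
        pipeline_parallel_degree=2,
        data_parallel_degree=2,
        seed=1234,
    )

    # init_dist_env() stored the HybridCommunicateGroup via set_hcg(); read it back.
    hcg = get_hcg()
    print(
        "mp rank:", hcg.get_model_parallel_rank(),
        "pp stage:", hcg.get_stage_id(),
        "dp rank:", hcg.get_data_parallel_rank(),
    )

    # With a bare integer, set_seed() seeds random/numpy/paddle uniformly (the path
    # run_pretrain.py now uses); passing args instead would derive the per-rank
    # global/local seeds from the stored hcg, as shown in set_seed() above.
    set_seed(seed=1234)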
