5 changes: 2 additions & 3 deletions llm/gpt-3/run_pretrain.py
@@ -349,10 +349,9 @@ def main():
if data_args.data_cache is not None:
os.makedirs(data_args.data_cache, exist_ok=True)

set_seed(seed=training_args.seed, args=training_args)
paddle.set_device(training_args.device)
if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()

set_seed(seed=training_args.seed)

training_args.eval_iters = 10
training_args.test_iters = training_args.eval_iters * 10
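
The hunk above replaces the seed-only call with one that also passes the TrainingArguments object through to the helper. As a minimal sketch of the new-style call (constructing TrainingArguments directly is illustrative; run_pretrain.py obtains it from PdArgumentParser):

```python
from paddlenlp.trainer import TrainingArguments
from paddlenlp.trainer.trainer_utils import set_seed

# Hypothetical arguments for illustration; in run_pretrain.py these come from
# PdArgumentParser.parse_args_into_dataclasses().
training_args = TrainingArguments(output_dir="./checkpoints", seed=42)

# Old call: set_seed(seed=training_args.seed)
# New call: the TrainingArguments object is passed through as well.
set_seed(seed=training_args.seed, args=training_args)
```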
80 changes: 63 additions & 17 deletions paddlenlp/trainer/trainer_utils.py
@@ -34,6 +34,8 @@

import numpy as np
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.io import IterableDataset
from paddle.optimizer.lr import LambdaDecay

@@ -56,32 +58,76 @@


def set_seed(seed: int = 1234, args=None):
# NOTE(shenliang03): For parameter init seed:
# seed: dp/mp_undistributed_paramter/sharding is same; others is different
# For compute seed(dropout):
# global seed: only mp group is same.
# local seed: all groups are different
if args is None:
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)

if args is not None:
if args.use_hybrid_parallel:
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
else:
hcg = fleet.get_hybrid_communicate_group() if hasattr(fleet.fleet, "_hcg") else None
Contributor:

Is hasattr(fleet.fleet, "_hcg") only true after the distributed environment has been initialized? Consider three cases:

  1. CPU build of Paddle
  2. GPU build of Paddle running on CPU
  3. GPU build of Paddle running on GPU

In case 2, is fleet.fleet._hcg None?

Contributor Author:

Yes, the hcg only exists after init (or after init_dist_env) has been called.
In case 2 it is not None; it is only None when the distributed environment has not been initialized.

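To make the cases discussed above concrete, here is a minimal sketch of the guard from the added line, with the behavior described in the reply spelled out in comments. The get_hcg_or_none wrapper and the prints are illustrative only, not part of the patch.

```python
import paddle
from paddle.distributed import fleet


def get_hcg_or_none():
    # fleet.fleet only carries the _hcg attribute after fleet.init(...) (or the
    # equivalent distributed init path) has run, so guard with hasattr instead
    # of calling fleet.get_hybrid_communicate_group() unconditionally.
    return fleet.get_hybrid_communicate_group() if hasattr(fleet.fleet, "_hcg") else None


hcg = get_hcg_or_none()
if hcg is not None and paddle.distributed.get_world_size() > 1:
    # Hybrid parallel is initialized: per-rank seeds can be derived from hcg.
    print("mp rank:", hcg.get_model_parallel_rank())
else:
    # Distributed not initialized yet, or a single-process run: fall back to
    # the default (non-hybrid) seeding path.
    print("no hybrid communicate group available")
```
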
if hcg is not None and paddle.distributed.get_world_size() > 1:
# obtain rank message of hybrid parallel

random.seed(args.seed + args.dataset_rank)
np.random.seed(args.seed + args.dataset_rank)
paddle.seed(args.seed + args.dataset_rank)
mp_rank = hcg.get_model_parallel_rank()
mp_size = hcg.get_model_parallel_world_size()

# local_seed/ global_seed is used to control dropout in ModelParallel
local_seed = args.seed + 59999 + args.tensor_parallel_rank * 10 + args.pipeline_parallel_rank * 1000
global_seed = args.seed + 100003 + args.dataset_rank
tracker = get_rng_state_tracker()
pp_rank = hcg.get_stage_id()
pp_size = hcg.get_pipe_parallel_world_size()

if "global_seed" not in tracker.states_:
tracker.add("global_seed", global_seed)
if "local_seed" not in tracker.states_:
tracker.add("local_seed", local_seed)
dp_rank = hcg.get_data_parallel_rank()
dp_size = hcg.get_data_parallel_world_size()

sharding_rank = hcg.get_sharding_parallel_rank()
# sharding_size = hcg.get_sharding_parallel_world_size()
else:
random.seed(args.seed)
np.random.seed(args.seed)
paddle.seed(args.seed)
mp_rank, mp_size = 0, 1
pp_rank, pp_size = 0, 1
dp_rank, dp_size = 0, 1
sharding_rank, _ = 0, 1

# NOTE: the commented seeds are set only for precision validation
# seed += 100 * pp_rank
random.seed(seed + 100 * pp_rank)
np.random.seed(seed + 100 * pp_rank)

# seed = mp_rank +
# pp_rank * (mp_size) +
# dp_rank * (mp_size * pp_size) +
# sharding_rank * (mp_size * pp_size * dp_size)
# seed offset is order to avoid conflicts with the parameter initialization seed

seed_offset = seed + 1024 + paddle.distributed.get_world_size()
global_seed = (
seed_offset
+ pp_rank * (mp_size)
+ dp_rank * (mp_size * pp_size)
+ sharding_rank * (mp_size * pp_size * dp_size)
)

seed_offset += paddle.distributed.get_world_size()
local_seed = (
seed_offset
+ mp_rank
+ pp_rank * (mp_size)
+ dp_rank * (mp_size * pp_size)
+ sharding_rank * (mp_size * pp_size * dp_size)
)

tracker = get_rng_state_tracker()
if "global_seed" not in tracker.states_:
tracker.add("global_seed", global_seed)

if "local_seed" not in tracker.states_:
tracker.add("local_seed", local_seed)

paddle.seed(global_seed)

logger.info("The global seed is set to {} and local seed is set to {}.".format(global_seed, local_seed))


class ExplicitEnum(Enum):
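
As a worked example of the seed arithmetic added above, the following standalone sketch mirrors the global_seed / local_seed formulas for a hypothetical 8-rank layout (mp=2, pp=2, dp=2, no sharding). The compute_seeds helper and the example sizes are assumptions for illustration; only the arithmetic itself comes from the patch.

```python
# Standalone illustration of the seed arithmetic in the new set_seed (no paddle
# required); compute_seeds and the example layout are hypothetical.

def compute_seeds(seed, world_size, mp_rank, mp_size, pp_rank, pp_size, dp_rank, dp_size, sharding_rank):
    # The first offset feeds global_seed; the second (shifted again by
    # world_size) feeds local_seed. This mirrors the added lines above.
    seed_offset = seed + 1024 + world_size
    global_seed = (
        seed_offset
        + pp_rank * mp_size
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )
    seed_offset += world_size
    local_seed = (
        seed_offset
        + mp_rank  # only local_seed depends on the tensor-parallel rank
        + pp_rank * mp_size
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )
    return global_seed, local_seed


if __name__ == "__main__":
    seed, world_size = 1234, 8
    mp_size, pp_size, dp_size = 2, 2, 2
    for dp_rank in range(dp_size):
        for pp_rank in range(pp_size):
            for mp_rank in range(mp_size):
                g, l = compute_seeds(seed, world_size, mp_rank, mp_size, pp_rank, pp_size, dp_rank, dp_size, 0)
                print(f"dp={dp_rank} pp={pp_rank} mp={mp_rank} -> global_seed={g} local_seed={l}")
```

With this layout, the two ranks inside each tensor-parallel group print the same global_seed, while all eight local_seed values are distinct, which matches the dropout behavior described in the NOTE comment of the new code.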