34 | 34 |
35 | 35 | import numpy as np |
36 | 36 | import paddle |
| 37 | +from paddle.distributed import fleet |
| 38 | +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker |
37 | 39 | from paddle.io import IterableDataset |
38 | 40 | from paddle.optimizer.lr import LambdaDecay |
39 | 41 |
56 | 58 |
57 | 59 |
58 | 60 | def set_seed(seed: int = 1234, args=None): |
| 61 | + # NOTE(shenliang03): For the parameter init seed: |
| 62 | + # seed: dp/mp_undistributed_parameter/sharding are the same; others are different |
| 63 | + # For the compute seed (dropout): |
| 64 | + # global seed: only the mp group shares the same seed. |
| 65 | + # local seed: every group gets a different seed. |
59 | 66 | if args is None: |
60 | 67 | random.seed(seed) |
61 | 68 | np.random.seed(seed) |
62 | 69 | paddle.seed(seed) |
63 | 70 |
64 | | - if args is not None: |
65 | | - if args.use_hybrid_parallel: |
66 | | - from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker |
| 71 | + else: |
| 72 | + hcg = fleet.get_hybrid_communicate_group() if hasattr(fleet.fleet, "_hcg") else None |
| 73 | + if hcg is not None and paddle.distributed.get_world_size() > 1: |
| 74 | + # obtain the rank information of each hybrid parallel group |
67 | 75 |
68 | | - random.seed(args.seed + args.dataset_rank) |
69 | | - np.random.seed(args.seed + args.dataset_rank) |
70 | | - paddle.seed(args.seed + args.dataset_rank) |
| 76 | + mp_rank = hcg.get_model_parallel_rank() |
| 77 | + mp_size = hcg.get_model_parallel_world_size() |
71 | 78 |
72 | | - # local_seed/ global_seed is used to control dropout in ModelParallel |
73 | | - local_seed = args.seed + 59999 + args.tensor_parallel_rank * 10 + args.pipeline_parallel_rank * 1000 |
74 | | - global_seed = args.seed + 100003 + args.dataset_rank |
75 | | - tracker = get_rng_state_tracker() |
| 79 | + pp_rank = hcg.get_stage_id() |
| 80 | + pp_size = hcg.get_pipe_parallel_world_size() |
76 | 81 |
77 | | - if "global_seed" not in tracker.states_: |
78 | | - tracker.add("global_seed", global_seed) |
79 | | - if "local_seed" not in tracker.states_: |
80 | | - tracker.add("local_seed", local_seed) |
| 82 | + dp_rank = hcg.get_data_parallel_rank() |
| 83 | + dp_size = hcg.get_data_parallel_world_size() |
| 84 | + |
| 85 | + sharding_rank = hcg.get_sharding_parallel_rank() |
| 86 | + # sharding_size = hcg.get_sharding_parallel_world_size() |
81 | 87 | else: |
82 | | - random.seed(args.seed) |
83 | | - np.random.seed(args.seed) |
84 | | - paddle.seed(args.seed) |
| 88 | + mp_rank, mp_size = 0, 1 |
| 89 | + pp_rank, pp_size = 0, 1 |
| 90 | + dp_rank, dp_size = 0, 1 |
| 91 | + sharding_rank, _ = 0, 1 |
| 92 | + |
| 93 | + # NOTE: the commented-out seed below is kept only for precision validation |
| 94 | + # seed += 100 * pp_rank |
| 95 | + random.seed(seed + 100 * pp_rank) |
| 96 | + np.random.seed(seed + 100 * pp_rank) |
| 97 | + |
| 98 | + # seed = mp_rank + |
| 99 | + # pp_rank * (mp_size) + |
| 100 | + # dp_rank * (mp_size * pp_size) + |
| 101 | + # sharding_rank * (mp_size * pp_size * dp_size) |
| 102 | + # the seed offset is used to avoid conflicts with the parameter initialization seed |
| 103 | + |
| 104 | + seed_offset = seed + 1024 + paddle.distributed.get_world_size() |
| 105 | + global_seed = ( |
| 106 | + seed_offset |
| 107 | + + pp_rank * (mp_size) |
| 108 | + + dp_rank * (mp_size * pp_size) |
| 109 | + + sharding_rank * (mp_size * pp_size * dp_size) |
| 110 | + ) |
| 111 | + |
| 112 | + seed_offset += paddle.distributed.get_world_size() |
| 113 | + local_seed = ( |
| 114 | + seed_offset |
| 115 | + + mp_rank |
| 116 | + + pp_rank * (mp_size) |
| 117 | + + dp_rank * (mp_size * pp_size) |
| 118 | + + sharding_rank * (mp_size * pp_size * dp_size) |
| 119 | + ) |
| 120 | + |
| 121 | + tracker = get_rng_state_tracker() |
| 122 | + if "global_seed" not in tracker.states_: |
| 123 | + tracker.add("global_seed", global_seed) |
| 124 | + |
| 125 | + if "local_seed" not in tracker.states_: |
| 126 | + tracker.add("local_seed", local_seed) |
| 127 | + |
| 128 | + paddle.seed(global_seed) |
| 129 | + |
| 130 | + logger.info("The global seed is set to {} and the local seed is set to {}.".format(global_seed, local_seed)) |
85 | 131 |
86 | 132 |
87 | 133 | class ExplicitEnum(Enum): |
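To make the rank-to-seed mapping in the new branch easier to verify, here is a minimal, standalone sketch that mirrors the `global_seed` / `local_seed` formulas from this diff for a hypothetical 8-GPU topology (mp=2, pp=2, dp=2, sharding degree 1). The helper name `hybrid_seeds` and the concrete numbers are illustrative only, not part of the change.

```python
# Illustrative only: re-implements the seed formulas from the diff to show that
# ranks in the same mp group share global_seed but get distinct local_seed values.
def hybrid_seeds(seed, mp_rank, mp_size, pp_rank, pp_size, dp_rank, dp_size, sharding_rank, world_size):
    # offset keeps these seeds away from the parameter-initialization seed
    seed_offset = seed + 1024 + world_size
    global_seed = (
        seed_offset
        + pp_rank * mp_size
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )
    seed_offset += world_size
    local_seed = (
        seed_offset
        + mp_rank
        + pp_rank * mp_size
        + dp_rank * (mp_size * pp_size)
        + sharding_rank * (mp_size * pp_size * dp_size)
    )
    return global_seed, local_seed


if __name__ == "__main__":
    # hypothetical topology: 8 GPUs, mp=2, pp=2, dp=2, sharding degree 1
    seed, world_size = 1234, 8
    g0, l0 = hybrid_seeds(seed, 0, 2, 0, 2, 0, 2, 0, world_size)  # mp rank 0
    g1, l1 = hybrid_seeds(seed, 1, 2, 0, 2, 0, 2, 0, world_size)  # mp rank 1, same mp group
    assert g0 == g1 == 2266          # "global_seed" is identical inside the mp group
    assert (l0, l1) == (2274, 2275)  # "local_seed" differs on every rank
```

Running it shows both mp ranks of the same group sharing `global_seed` = 2266 while their `local_seed` values (2274 and 2275) differ, which matches the comment at the top of the new branch.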
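As background on why both seeds are registered with `get_rng_state_tracker()`: Paddle's hybrid-parallel layers switch between named RNG states through the tracker's `rng_state(...)` context manager, so dropout drawn under `"global_seed"` is identical across a model-parallel group while dropout drawn under `"local_seed"` differs per rank. The snippet below is a rough usage sketch, assuming `set_seed` from this diff has already registered both states and that it runs on a GPU build (the tracker manages the CUDA RNG state); the tensor shape is made up.

```python
import paddle
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

tracker = get_rng_state_tracker()  # "global_seed" / "local_seed" were added by set_seed(...)

x = paddle.ones([4, 8])

# "global_seed" state: every rank in the same mp group draws the same mask,
# suitable for dropout on replicated (non-sliced) activations.
with tracker.rng_state("global_seed"):
    y_replicated = F.dropout(x, p=0.1, training=True)

# "local_seed" state: every rank draws a different mask, suitable for dropout
# on activations that are already split across mp ranks.
with tracker.rng_state("local_seed"):
    y_sliced = F.dropout(x, p=0.1, training=True)
```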