|
34 | 34 |
|
35 | 35 | import numpy as np |
36 | 36 | import paddle |
| 37 | +import paddle.distributed as dist |
| 38 | +from paddle.distributed import fleet |
| 39 | +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker |
37 | 40 | from paddle.io import IterableDataset |
38 | 41 | from paddle.optimizer.lr import LambdaDecay |
39 | 42 |
|
|
52 | 55 | "get_last_checkpoint", |
53 | 56 | "get_scheduler", |
54 | 57 | "set_hyrbid_parallel_seed", |
| 58 | + "init_dist_env", |
55 | 59 | ] |
56 | 60 |
|
57 | 61 |
|
| 62 | +_hcg = None |
| 63 | + |
| 64 | + |
| 65 | +def set_hcg(hcg): |
| 66 | + global _hcg |
| 67 | + _hcg = hcg |
| 68 | + |
| 69 | + |
| 70 | +def get_hcg(): |
| 71 | + global _hcg |
| 72 | + return _hcg |
| 73 | + |
| 74 | + |
58 | 75 | def set_seed(seed: int = 1234, args=None): |
| 76 | + # NOTE(shenliang03): For the parameter-init seed: |
| 77 | + # seed: dp / undistributed mp parameters / sharding share the same seed; other groups differ. |
| 78 | + # For the compute seed (dropout): |
| 79 | + # global seed: identical only within an mp group. |
| 80 | + # local seed: different on every rank. |
59 | 81 | if args is None: |
60 | 82 | random.seed(seed) |
61 | 83 | np.random.seed(seed) |
62 | 84 | paddle.seed(seed) |
| 85 | + else: |
| 86 | + hcg = get_hcg() |
| 87 | + if paddle.distributed.get_world_size() > 1: |
| 88 | + # obtain the rank information of the hybrid parallel groups |
| 89 | + if hcg is None: |
| 90 | + raise ValueError("hcg is not initialized; call init_dist_env() before set_seed(args=...)") |
| 91 | + |
| 92 | + mp_rank = hcg.get_model_parallel_rank() |
| 93 | + mp_size = hcg.get_model_parallel_world_size() |
| 94 | + |
| 95 | + pp_rank = hcg.get_stage_id() |
| 96 | + pp_size = hcg.get_pipe_parallel_world_size() |
| 97 | + |
| 98 | + dp_rank = hcg.get_data_parallel_rank() |
| 99 | + dp_size = hcg.get_data_parallel_world_size() |
| 100 | + |
| 101 | + sharding_rank = hcg.get_sharding_parallel_rank() |
| 102 | + # sharding_size = hcg.get_sharding_parallel_world_size() |
| 103 | + else: |
| 104 | + mp_rank, mp_size = 0, 1 |
| 105 | + pp_rank, pp_size = 0, 1 |
| 106 | + dp_rank, dp_size = 0, 1 |
| 107 | + sharding_rank, _ = 0, 1 |
| 108 | + |
| 109 | + # NOTE: the commented-out seed expressions below are kept only for precision validation |
| 110 | + # seed += 100 * pp_rank |
| 111 | + random.seed(seed + 100 * pp_rank) |
| 112 | + np.random.seed(seed + 100 * pp_rank) |
| 113 | + |
| 114 | + # seed = mp_rank + |
| 115 | + # pp_rank * (mp_size) + |
| 116 | + # dp_rank * (mp_size * pp_size) + |
| 117 | + # sharding_rank * (mp_size * pp_size * dp_size) |
| 118 | + # The seed offset avoids conflicts with the parameter-initialization seed. |
| 119 | + |
| 120 | + seed_offset = seed + 1024 + paddle.distributed.get_world_size() |
| 121 | + global_seed = ( |
| 122 | + seed_offset |
| 123 | + + pp_rank * (mp_size) |
| 124 | + + dp_rank * (mp_size * pp_size) |
| 125 | + + sharding_rank * (mp_size * pp_size * dp_size) |
| 126 | + ) |
| 127 | + |
| 128 | + seed_offset += paddle.distributed.get_world_size() |
| 129 | + local_seed = ( |
| 130 | + seed_offset |
| 131 | + + mp_rank |
| 132 | + + pp_rank * (mp_size) |
| 133 | + + dp_rank * (mp_size * pp_size) |
| 134 | + + sharding_rank * (mp_size * pp_size * dp_size) |
| 135 | + ) |
| 136 | + |
| 137 | + tracker = get_rng_state_tracker() |
| 138 | + tracker.add("global_seed", global_seed) |
| 139 | + tracker.add("local_seed", local_seed) |
| 140 | + |
| 141 | + paddle.seed(global_seed) |
| 142 | + |
| 143 | + logger.info("The global seed is set to {} and local seed is set to {}.".format(global_seed, local_seed)) |
| 144 | + |
| 145 | + |
| 146 | +def create_hcg(strategy, hcg_name="HybridCommunicateGroup"): |
| 147 | + if hcg_name == "HybridCommunicateGroup": |
| 148 | + fleet.init(is_collective=True, strategy=strategy) |
| 149 | + hcg = fleet.get_hybrid_communicate_group() |
| 150 | + else: |
| 151 | + dist.init_parallel_env() |
| 152 | + hcg = eval(hcg_name)(strategy) |
| 153 | + logger.debug("Created hcg {} via {}".format(hcg, hcg_name)) |
| 154 | + return hcg |
63 | 155 |
|
64 | | - if args is not None: |
65 | | - if args.use_hybrid_parallel: |
66 | | - from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker |
67 | | - |
68 | | - random.seed(args.seed + args.dataset_rank) |
69 | | - np.random.seed(args.seed + args.dataset_rank) |
70 | | - paddle.seed(args.seed + args.dataset_rank) |
71 | | - |
72 | | - # local_seed/ global_seed is used to control dropout in ModelParallel |
73 | | - local_seed = args.seed + 59999 + args.tensor_parallel_rank * 10 + args.pipeline_parallel_rank * 1000 |
74 | | - global_seed = ( |
75 | | - args.seed |
76 | | - + 100003 |
77 | | - + args.dataset_rank |
78 | | - + args.tensor_parallel_rank * 10 |
79 | | - + args.pipeline_parallel_rank * 1000 |
80 | | - ) |
81 | | - tracker = get_rng_state_tracker() |
82 | 156 |
|
83 | | - if "global_seed" not in tracker.states_: |
84 | | - tracker.add("global_seed", global_seed) |
85 | | - if "local_seed" not in tracker.states_: |
86 | | - tracker.add("local_seed", local_seed) |
| 157 | +def init_dist_env( |
| 158 | + tensor_parallel_degree=1, sharding_parallel_degree=1, pipeline_parallel_degree=1, data_parallel_degree=1, seed=1 |
| 159 | +): |
| 160 | + |
| 161 | + strategy = fleet.DistributedStrategy() |
| 162 | + |
| 163 | + def is_segment_parallel_supported(): |
| 164 | + import inspect |
| 165 | + |
| 166 | + members = [name for (name, _) in inspect.getmembers(fleet.HybridCommunicateGroup)] |
| 167 | + return "get_sep_parallel_world_size" in members |
| 168 | + |
| 169 | + if tensor_parallel_degree == 1 and sharding_parallel_degree == 1: |
| 170 | + if is_segment_parallel_supported(): |
| 171 | + order = ["pp", "dp", "sharding", "sep", "mp"] |
| 172 | + else: |
| 173 | + order = ["pp", "dp", "sharding", "mp"] |
| 174 | + else: |
| 175 | + if is_segment_parallel_supported(): |
| 176 | + order = ["dp", "pp", "sharding", "sep", "mp"] |
87 | 177 | else: |
88 | | - random.seed(args.seed) |
89 | | - np.random.seed(args.seed) |
90 | | - paddle.seed(args.seed) |
| 178 | + order = ["dp", "pp", "sharding", "mp"] |
| 179 | + |
| 180 | + strategy.hybrid_configs = { |
| 181 | + "dp_degree": data_parallel_degree, |
| 182 | + "mp_degree": tensor_parallel_degree, |
| 183 | + "pp_degree": pipeline_parallel_degree, |
| 184 | + "sharding_degree": sharding_parallel_degree, |
| 185 | + "order": order, |
| 186 | + } |
| 187 | + |
| 188 | + # TODO(wawltor): inference parallelism does not support pipeline mode yet |
| 189 | + |
| 190 | + """ |
| 191 | + if pipeline_parallel_degree > 1: |
| 192 | + if "sequence_parallel" in config.Model: |
| 193 | + if config.Model.sequence_parallel: |
| 194 | + assert config.Global.enable_partial_send_recv is False, ( |
| 195 | + "if config.Distributed.pp_degree > 1 and config.Model.sequence_parallel is True, " |
| 196 | + "config.Global.enable_partial_send_recv should be set False." |
| 197 | + ) |
| 198 | +
|
| 199 | + strategy.pipeline_configs = { |
| 200 | + "accumulate_steps": config.Global.local_batch_size // config.Global.micro_batch_size, |
| 201 | + "micro_batch_size": config.Global.micro_batch_size, |
| 202 | + "enable_partial_send_recv": config.Global.enable_partial_send_recv, |
| 203 | + } |
| 204 | + """ |
| 205 | + |
| 206 | + # control the RNG seed used for tensor-parallel parameter initialization |
| 207 | + strategy.tensor_parallel_configs = {"tensor_init_seed": seed} |
| 209 | + |
| 210 | + hcg = create_hcg(strategy) |
| 211 | + set_hcg(hcg) |
91 | 212 |
|
92 | 213 |
|
93 | 214 | class ExplicitEnum(Enum): |
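For reference, a minimal usage sketch of the new init_dist_env/set_seed pair. The degree values, the SimpleNamespace placeholder for args, and the launch command are illustrative assumptions, not part of this change; init_dist_env builds the fleet strategy, initializes the hybrid groups, and caches the hcg that set_seed later reads.

from types import SimpleNamespace

import paddle.distributed as dist

# Launched with e.g. `python -m paddle.distributed.launch --devices "0,1,2,3" train.py`.
# The product of the four degrees must equal the launched world size.
init_dist_env(
    tensor_parallel_degree=2,
    pipeline_parallel_degree=1,
    sharding_parallel_degree=1,
    data_parallel_degree=max(dist.get_world_size() // 2, 1),
    seed=1234,
)

# Any non-None `args` switches set_seed to the hybrid-parallel branch,
# which derives per-rank local/global seeds from the cached hcg.
set_seed(seed=1234, args=SimpleNamespace())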
@@ -940,19 +1061,4 @@ def __call__(self, features: List[dict]): |
940 | 1061 |
|
941 | 1062 |
|
942 | 1063 | def set_hyrbid_parallel_seed(basic_seed, dataset_rank, tp_rank, pp_rank=0): |
943 | | - from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker |
944 | | - |
945 | | - random.seed(basic_seed + dataset_rank) |
946 | | - np.random.seed(basic_seed + dataset_rank) |
947 | | - paddle.seed(basic_seed + dataset_rank) |
948 | | - |
949 | | - # local_seed/ global_seed is used to control dropout in ModelParallel |
950 | | - local_seed = basic_seed + 59999 + tp_rank * 10 + pp_rank * 1000 |
951 | | - global_seed = basic_seed + 100003 + dataset_rank + tp_rank * 10 + pp_rank * 1000 |
952 | | - |
953 | | - tracker = get_rng_state_tracker() |
954 | | - |
955 | | - if "global_seed" not in tracker.states_: |
956 | | - tracker.add("global_seed", global_seed) |
957 | | - if "local_seed" not in tracker.states_: |
958 | | - tracker.add("local_seed", local_seed) |
| 1064 | + set_seed(basic_seed) |
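For context, a sketch of how the two RNG states registered by set_seed are typically consumed for dropout under model parallelism. The helper functions below are illustrative only; get_rng_state_tracker and its rng_state context manager come from paddle.distributed.fleet.meta_parallel and are assumed unchanged by this patch.

import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker

def dropout_inside_mp_layer(x, p=0.1):
    # "local_seed" differs on every rank, so dropout masks applied to
    # tensor-parallel-partitioned activations are independent across mp ranks.
    with get_rng_state_tracker().rng_state("local_seed"):
        return F.dropout(x, p=p)

def dropout_on_replicated_activations(x, p=0.1):
    # "global_seed" is shared within an mp group, so replicated activations
    # receive identical dropout masks on every mp rank.
    with get_rng_state_tracker().rng_state("global_seed"):
        return F.dropout(x, p=p)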