
Commit c5cc427

Authored by aoyulong, zhaoyinglia, JZ-LIANG, zhaoyingli, and Caozhou1995
[Cherry-pick][Auto Parallel] Improve the APIs (#46164)
* [AutoParallel] adapt gradient merge pass (#45915)
* adapt gradient merge
* fix op_role
* fix strategy
* [Auto Parallel] Gradient Fuse Allreduce (#45643)
* bugfix (#45332)
* dist embedding support lookup table v1
* add unitest
* customize wait_comm
* group gradients
* bugfix
* update program
* [Auto Parallel] Improve the APIs (#45776)
* [Auto Parallel] Use c++ dist attr in the completion process
* [Auto Parallel] Add minor changes
* [Auto Parallel] Use c++ dist attr in the completion process
* [Auto Parallel] Add minor changes
* [Auto Parallel] Add the serialization process for dist attrs
* [Auto Parallel] Remove unnecessary comments
* [Auto Parallel] Fix some bugs
* [Auto Parallel] Fix the code style
* [Auto Parallel] Remove unnecessary impls
* [Auto Parallel] Fix the importing error
* [Auto Parallel] Fix the copy from bugs of op dist attr
* [Auto Parallel] Replace the use of constexpr if
* [Auto Parallel] Redesign the shard_tensor, shard_op and ProcessMesh
* [Auto Parallel] Change API of the completion unittest
* [Auto Parallel] Fix the bug when set_attr an int
* [Auto Parallel] Add the unittest for the serialization
* [Auto Parallel] Add some unit tests
* [Auto Paralle] Unify the strategy
* [Auto Parallel] Improve the engine api
* [Auto Parallel] Reset the changes made to the framework
* [Auto Parallel] Change the engine unittest
* [Auto Parallel] Update API of the completion and partitioner
* [Auto Parallel] Update unit tests using engine api
* update shard annotation
* [Auto Parallel] Remove the modifications of other modules
* [Auto Parallel] Add docs for APIs
* add new strategy
* [Auto Parallel] Replace the logger
* [Auto Parallel] Restore the test_program.py
* [Auto Parallel] Change the import rules
* [Auto Parallel] Add the examples for Engine
* [Auto Parallel] Do some minor changes
* [Auto Parallel] Remove yaml dependency
* [Auto Parallel] Fix the unittests
* add valid after train
* bug fix

Co-authored-by: zhaoyingli <[email protected]>
Co-authored-by: caozhou <[email protected]>
Co-authored-by: caozhou <[email protected]>

* [Auto Parallel] Bugfix allreduce fuse for MP (#46086)
* bugfix
* bugfix
* typos fixed
* update strategy (#46138)

Co-authored-by: zhaoyingli <[email protected]>
Co-authored-by: JZ-LIANG <[email protected]>
Co-authored-by: zhaoyingli <[email protected]>
Co-authored-by: caozhou <[email protected]>
Co-authored-by: caozhou <[email protected]>
1 parent 860f607 commit c5cc427

File tree

82 files changed (+4235, -2901 lines)


python/paddle/distributed/auto_parallel/__init__.py

Lines changed: 6 additions & 4 deletions
@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .interface import shard_tensor  # noqa: F401
-from .interface import shard_op  # noqa: F401
+from .strategy import Strategy
 from .process_mesh import ProcessMesh
-from .reshard import Resharder  # noqa: F401
-from .cost_model import estimate_cost
+from .engine import Engine
+from .interface import shard_tensor
+from .interface import shard_op
+from .interface import recompute
+from .interface import fetch

 __all__ = []
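For users, the practical effect of this __init__.py change is that the new entry points become importable directly from the auto_parallel package. The snippet below only names symbols the diff above re-exports; it assumes a Paddle build that includes this commit.

# Names re-exported by python/paddle/distributed/auto_parallel/__init__.py
# after this change (importing them requires a build with this commit).
from paddle.distributed.auto_parallel import (
    Strategy,      # unified configuration added by this PR
    Engine,        # high-level execution API reworked by this PR
    ProcessMesh,   # still re-exported, as before
    shard_tensor,  # annotation APIs, now exported without "# noqa: F401"
    shard_op,
    recompute,     # newly exported
    fetch,         # newly exported
)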
python/paddle/distributed/auto_parallel/constants.py (new file)

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from collections import defaultdict
+
+# _g_default_config[category][field] = default_value
+_g_default_config = defaultdict(dict)
+
+
+def get_category_default_config(category):
+    return _g_default_config[category]
+
+
+def set_category_default_config(category, default_value):
+    _g_default_config[category] = default_value
+
+
+def get_field_default_config(category, field):
+    return _g_default_config[category][field]
+
+
+def set_field_default_config(category, field, default_value):
+    _g_default_config[category][field] = default_value
+
+
+NOT_FOUND = "not_found"
+
+#########################################
+# base configuration
+#########################################
+BASE = "base"
+set_field_default_config(BASE, "auto_mode", "semi")
+set_field_default_config(BASE, "gradient_scale", True)
+set_field_default_config(BASE, "use_cache", True)
+set_field_default_config(BASE, "return_numpy", True)
+set_field_default_config(BASE, "all_ranks", False)
+set_field_default_config(BASE, "split_data", False)
+set_field_default_config(BASE, "seed", None)
+set_field_default_config(BASE, "reinit", False)  # Only for debug
+
+#########################################
+# recompute configuration
+#########################################
+RECOMPUTE = "recompute"
+set_field_default_config(RECOMPUTE, "enable", False)
+set_field_default_config(RECOMPUTE, "checkpoints", None)
+set_field_default_config(RECOMPUTE, "enable_tuning", False)
+
+#########################################
+# AMP configuration
+#########################################
+AMP = "amp"
+set_field_default_config(AMP, "enable", False)
+set_field_default_config(AMP, "init_loss_scaling", 32768.0)
+set_field_default_config(AMP, "incr_every_n_steps", 1000)
+set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
+set_field_default_config(AMP, "incr_ratio", 2.0)
+set_field_default_config(AMP, "decr_ratio", 0.8)
+set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
+set_field_default_config(AMP, "custom_white_list", [])
+set_field_default_config(AMP, "custom_black_list", [])
+set_field_default_config(AMP, "custom_black_varnames", [])
+set_field_default_config(AMP, "use_pure_fp16", False)
+set_field_default_config(AMP, "use_fp16_guard", True)
+set_field_default_config(AMP, "use_optimizer_fp16", False)
+
+#########################################
+# sharding configuration
+#########################################
+SHARDING = "sharding"
+set_field_default_config(SHARDING, "enable", False)
+set_field_default_config(SHARDING, "stage", 1)
+set_field_default_config(SHARDING, "degree", 8)
+set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
+set_field_default_config(SHARDING, "enable_tuning", False)
+set_field_default_config(SHARDING, "tuning_range", [])
+
+#########################################
+# gradient merge configuration
+#########################################
+GRADIENT_MERGE = "gradient_merge"
+set_field_default_config(GRADIENT_MERGE, "enable", False)
+set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
+set_field_default_config(GRADIENT_MERGE, "avg", True)
+
+#########################################
+# quantization configuration
+#########################################
+QAT = "qat"
+set_field_default_config(QAT, "enable", False)
+set_field_default_config(QAT, "channel_wise_abs_max", True)
+set_field_default_config(QAT, "weight_bits", 8)
+set_field_default_config(QAT, "activation_bits", 8)
+set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
+set_field_default_config(QAT, "algo", None)
+
+# #########################################
+# auto tuning configuration
+# #########################################
+TUNING = "tuning"
+set_field_default_config(TUNING, "enable", False)
+set_field_default_config(TUNING, "batch_size", 1)
+set_field_default_config(TUNING, "dataset", None)
+set_field_default_config(TUNING, "profile_start_step", 1)
+set_field_default_config(TUNING, "profile_end_step", 1)
+set_field_default_config(TUNING, "run_after_tuning", True)
+set_field_default_config(TUNING, "verbose", True)
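The new defaults module is just a two-level registry: category, then field, then default value. Below is a minimal, self-contained sketch of how the helpers defined above are used; the helper functions are re-declared so the snippet runs without Paddle, and the "amp" values simply mirror the defaults registered in the diff.

from collections import defaultdict

# Re-declaration of the registry helpers from the diff above, so this sketch
# runs on its own.
_g_default_config = defaultdict(dict)


def set_field_default_config(category, field, default_value):
    _g_default_config[category][field] = default_value


def get_field_default_config(category, field):
    return _g_default_config[category][field]


def get_category_default_config(category):
    return _g_default_config[category]


# Register defaults exactly the way the new module does.
AMP = "amp"
set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "init_loss_scaling", 32768.0)

# A Strategy-like consumer can then look up a single field or materialize a
# whole category's defaults at once.
assert get_field_default_config(AMP, "enable") is False
print(get_category_default_config(AMP))
# {'enable': False, 'init_loss_scaling': 32768.0}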

python/paddle/distributed/auto_parallel/converter.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 import warnings
 import logging
 import numpy as np
-from ..utils import get_logger
+from .utils import get_logger


 class Converter(object):

python/paddle/distributed/auto_parallel/dist_attribute.py

Lines changed: 32 additions & 0 deletions
@@ -173,6 +173,17 @@ def mark_annotated_as(self, dist_attr):
     def clear_annotated(self):
         self._is_annotated.clear()

+    def __eq__(self, other):
+        if not isinstance(other, TensorDistributedAttribute):
+            return False
+        if self.process_mesh != other.process_mesh:
+            return False
+        if self.dims_mapping != other.dims_mapping:
+            return False
+        if self._is_annotated != other._is_annotated:
+            return False
+        return True
+
     def __str__(self):
         str = "\n\ttensor_dist_attr = {"
         if self.is_annotated("process_mesh"):
@@ -486,6 +497,27 @@ def is_annotated_output_dims_mapping(self, name):
         else:
             return False

+    def __eq__(self, other):
+        if not isinstance(other, OperatorDistributedAttribute):
+            return False
+        if self.process_mesh != other.process_mesh:
+            return False
+        if self.op_type != other.op_type:
+            return False
+        if self.impl_type != other.impl_type:
+            return False
+        if self.impl_idx != other.impl_idx:
+            return False
+        if self._is_annotated != other._is_annotated:
+            return False
+        if self._is_recompute != other._is_recompute:
+            return False
+        if self.inputs_dist_attrs != other.inputs_dist_attrs:
+            return False
+        if self.outputs_dist_attrs != other.outputs_dist_attrs:
+            return False
+        return True
+
     def __str__(self):
         str = "\n\top_dist_attr = {"
         if self.is_annotated("process_mesh"):

python/paddle/distributed/auto_parallel/dist_context.py

Lines changed: 0 additions & 3 deletions
@@ -126,9 +126,6 @@ def __init__(self,
         # A flag indicates whether the used parallelism is data parallel
         self._data_parallel = False

-        # flag whether using `to_static`
-        self._dygraph_mode = False
-
     @property
     def serial_main_program(self):
         return self._serial_main_program

python/paddle/distributed/auto_parallel/dist_op.py

Lines changed: 92 additions & 8 deletions
@@ -23,6 +23,7 @@
 from .dist_attribute import append_op_output_suffix
 from .dist_attribute import get_tensor_dist_attr_field_keys
 from .dist_attribute import get_op_dist_attr_field_keys
+from .utils import convert_to_shard_spec, verify_shard_spec


 class DistributedOperator:
@@ -248,23 +249,106 @@ def __deepcopy__(self, memo):
         return result


-class DistributedModule:
+class DistributedOperatorHelper:

-    def __init__(self, serial_module, dist_attr=None):
-        self._serial_module = serial_module
-        self._dist_attr = dist_attr
+    def __init__(self, serial_op, process_mesh, in_dims_mappings,
+                 out_dims_mappings):
+        self._serial_op = serial_op
+        self._process_mesh = process_mesh
+        self._in_dims_mappings = in_dims_mappings
+        self._out_dims_mappings = out_dims_mappings

     def __call__(self, *args, **kwargs):
-        from .dist_context import get_default_distributed_context
+        tensor_to_dims_mapping = {}
+        index = 0
+        if self._in_dims_mappings:
+            assert len(args) + len(kwargs) == len(self._in_dims_mappings), \
+                "The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs))
+        for arg in args:
+            if isinstance(arg, Variable) and self._in_dims_mappings:
+                tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
+            index += 1
+        for arg in kwargs.values() and self._in_dims_mappings:
+            if isinstance(arg, Variable):
+                tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
+            index += 1
+
         default_prog = paddle.fluid.default_main_program()
         cur_block = default_prog.current_block()
         op_size = len(cur_block.ops)
-        output = self._serial_module(*args, **kwargs)
+        output = self._serial_op(*args, **kwargs)
         new_op_size = len(cur_block.ops)
+
+        if isinstance(output, tuple) or isinstance(output, list):
+            new_output = list(output)
+        elif isinstance(output, Variable):
+            new_output = [output]
+        else:
+            raise ValueError("Unrecognized outpout.")
+
+        if self._out_dims_mappings:
+            assert len(new_output) == len(self._out_dims_mappings), \
+                "The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output))
+        for i, item in enumerate(new_output):
+            if isinstance(item, Variable) and self._out_dims_mappings:
+                tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]
+
+        from .dist_context import get_default_distributed_context
         default_dist_ctx = get_default_distributed_context()
         for idx in range(op_size, new_op_size):
             op = cur_block.ops[idx]
-            dist_op = DistributedOperator(op, self._dist_attr)
-            dist_op.dist_attr.mark_annotated_as(self._dist_attr)
+            dist_op = DistributedOperator(op)
+            for name in dist_op.serial_op.input_arg_names:
+                if name in tensor_to_dims_mapping.keys():
+                    tensor = dist_op.get_serial_input(name)
+                    tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
+                        name)
+                    dims_mapping = tensor_to_dims_mapping[name]
+                    if tensor is None:
+                        tensor_shape = []
+                    else:
+                        if tensor.type == core.VarDesc.VarType.READER \
+                                or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
+                                or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
+                            tensor_shape = []
+                        else:
+                            tensor_shape = tensor.shape
+                    if dims_mapping is not None:
+                        dims_mapping = tensor_to_dims_mapping[name]
+                        shard_spec = convert_to_shard_spec(
+                            dims_mapping, self._process_mesh)
+                        assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
+                            "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
+                                name, shard_spec, tensor_shape, self._process_mesh)
+                        tensor_dist_attr.dims_mapping = dims_mapping
+                        tensor_dist_attr.mark_annotated("dims_mapping")
+            for name in dist_op.serial_op.output_arg_names:
+                if name in tensor_to_dims_mapping.keys():
+                    tensor = dist_op.get_serial_output(name)
+                    tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
+                        name)
+                    dims_mapping = tensor_to_dims_mapping[name]
+                    if tensor is None:
+                        tensor_shape = []
+                    else:
+                        if tensor.type == core.VarDesc.VarType.READER \
+                                or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
+                                or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
+                            tensor_shape = []
+                        else:
+                            tensor_shape = tensor.shape
+                    if dims_mapping is not None:
+                        dims_mapping = tensor_to_dims_mapping[name]
+                        shard_spec = convert_to_shard_spec(
+                            dims_mapping, self._process_mesh)
+                        assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
+                            "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
+                                name, shard_spec, tensor_shape, self._process_mesh)
+                        tensor_dist_attr.dims_mapping = dims_mapping
+                        tensor_dist_attr.mark_annotated("dims_mapping")
+            dist_op.dist_attr.process_mesh = self._process_mesh
+            if self._process_mesh is not None:
+                dist_op.dist_attr.mark_annotated("process_mesh")
             default_dist_ctx.add_dist_op_for_program(dist_op)
+
         return output
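The central mechanism in DistributedOperatorHelper.__call__ is to record the block's op count, invoke the wrapped serial op (which appends new ops to the current block), and then annotate only the ops created in between with the captured process_mesh and dims mappings. The sketch below is a framework-free analogy of that bookkeeping pattern; Block, Op, and OperatorHelper here are made-up stand-ins, not Paddle classes.

# Framework-free analogy of the "count ops before/after, annotate the new ones"
# pattern used by DistributedOperatorHelper.__call__. All types here are
# hypothetical stand-ins, not Paddle APIs.
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    annotations: dict = field(default_factory=dict)


@dataclass
class Block:
    ops: list = field(default_factory=list)


class OperatorHelper:
    def __init__(self, serial_fn, process_mesh):
        # Capture the annotation targets at wrap time.
        self._serial_fn = serial_fn
        self._process_mesh = process_mesh

    def __call__(self, block, *args, **kwargs):
        op_size = len(block.ops)                  # snapshot before the call
        output = self._serial_fn(block, *args, **kwargs)
        new_op_size = len(block.ops)              # ops appended by the call
        for op in block.ops[op_size:new_op_size]:
            # Only the freshly created ops receive the captured annotation.
            op.annotations["process_mesh"] = self._process_mesh
        return output


def matmul_like(block, x, y):
    # Pretend the serial op lowers to two ops appended to the current block.
    block.ops.append(Op("matmul"))
    block.ops.append(Op("elementwise_add"))
    return "out"


block = Block()
helper = OperatorHelper(matmul_like, process_mesh=[[0, 1], [2, 3]])
helper(block, "x", "y")
for op in block.ops:
    print(op.name, op.annotations)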
