Skip to content

Commit 093bfe5

Browse files
authored
[cherry-pick]【Auto-Parallel|Comm】fix set_skip_gc_vars to adapt to all scenarios and fix communication hang issue on GPU-H (#70687)
* 【Auto-Parallel】Refactor set_skip_gc_vars to adapt to all scenarios (#70615)
* 【Auto-Parallel | Comm】fix communication hang issue on GPU-H (#70360)
1 parent 061bc35 commit 093bfe5

File tree

6 files changed

+466
-30
lines changed

6 files changed

+466
-30
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,19 @@ void BindOperation(py::module *m) {
946946
phi::IntArray(val));
947947
self.set_attribute(attr_name, attr);
948948
})
949+
.def("set_str_array_attr",
950+
[](Operation &self,
951+
std::string &attr_name,
952+
const std::vector<std::string> &val) {
953+
std::vector<Attribute> val_attr;
954+
for (auto &str : val) {
955+
val_attr.emplace_back(
956+
StrAttribute::get(pir::IrContext::Instance(), str));
957+
}
958+
auto attr =
959+
pir::ArrayAttribute::get(pir::IrContext::Instance(), val_attr);
960+
self.set_attribute(attr_name, attr);
961+
})
949962
.def("set_str_attr",
950963
[](Operation &self, std::string &attr_name, std::string &val) {
951964
self.set_attribute(

python/paddle/distributed/auto_parallel/static/engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from paddle.distributed.passes.pass_base import new_pass
3535
from paddle.distributed.passes.pass_utils import (
3636
_split_program_into_forward_backward_optimize,
37-
set_pir_skip_gc_vars,
37+
set_skip_gc_vars,
3838
)
3939
from paddle.framework import (
4040
IrGraph,
@@ -943,7 +943,7 @@ def _parallel_pir(self, mode):
943943
opt_job.set_micro_batch_id(0)
944944
jobs.append(opt_job)
945945

946-
type_to_program = set_pir_skip_gc_vars(
946+
type_to_program = set_skip_gc_vars(
947947
self._strategy.gradient_merge.k_steps,
948948
job_types,
949949
sub_programs,

python/paddle/distributed/passes/pass_utils.py

Lines changed: 57 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,21 @@ def set_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs):
268268
thus a sub_program's vars might be used as the op's input of the later sub_program,
269269
and these vars cannot be gc after executing current sub_program.
270270
"""
271+
if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
272+
"FLAGS_enable_pir_api"
273+
]:
274+
return _set_skip_gc_vars_in_pir(
275+
num_micro_batches, job_types, sub_programs, jobs
276+
)
277+
else:
278+
return _set_skip_gc_vars_in_old_ir(
279+
num_micro_batches, job_types, sub_programs, jobs
280+
)
281+
282+
283+
def _set_skip_gc_vars_in_old_ir(
284+
num_micro_batches, job_types, sub_programs, jobs
285+
):
271286
assert num_micro_batches >= 1, "num_micro_batches needs to be >= 1"
272287
type_to_program = dict(zip(job_types, sub_programs))
273288

@@ -300,32 +315,54 @@ def set_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs):
300315
return type_to_program
301316

302317

303-
def set_pir_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs):
318+
def _set_skip_gc_vars_in_pir(num_micro_batches, job_types, sub_programs, jobs):
304319
assert num_micro_batches >= 1, "num_micro_batches needs to be >= 1"
305-
type_to_var_names = {}
306320
type_to_program = dict(zip(job_types, sub_programs))
307-
for job_type, program in type_to_program.items():
308-
type_to_var_names[job_type] = set()
309-
ops = program.global_block().ops
310-
for op in ops:
311-
if op.name() == "builtin.shadow_output":
312-
# if a value is renamed by shadow_output,
313-
# it will be used by other sub_programs
314-
type_to_var_names[job_type].add(op.attrs()["output_name"])
315-
if job_type in ["backward", "backward_w"]:
316-
assert (
317-
len(type_to_var_names[job_type]) == 0
318-
), f"The {job_type} sub_program can't have skip_gc_vars. But it is {type_to_var_names[job_type]}."
319321

322+
# step1: Get all required vars of every sub_program that are non-persistable and not in op's no_need_buffer.
323+
type_to_required_vars = {}
320324
no_need_buffer_vars = core.get_no_need_buffer_values(type_to_program)
325+
for job_type, program in type_to_program.items():
326+
required_vars = set()
327+
persistable_vars = set()
328+
for key in program.global_block().kwargs():
329+
required_vars.add(key)
330+
for op in program.global_block().ops:
331+
for var in op.operands_source():
332+
if var.has_name:
333+
required_vars.add(var.name)
334+
if var.persistable:
335+
persistable_vars.add(var.name)
336+
for var in op.results():
337+
if var.has_name:
338+
required_vars.add(var.name)
339+
if var.persistable:
340+
persistable_vars.add(var.name)
341+
if job_type in no_need_buffer_vars:
342+
required_vars -= no_need_buffer_vars[job_type]
343+
required_vars -= persistable_vars
344+
type_to_required_vars[job_type] = required_vars
321345

322-
for job_type, var_set in no_need_buffer_vars.items():
323-
if len(var_set) > 0:
324-
type_to_var_names[job_type] = type_to_var_names[job_type] - var_set
325-
326-
for job in jobs:
346+
# step2: Set `skip_gc_vars` for each job
347+
suffixed_required_vars = [set() for i in range(num_micro_batches)]
348+
num_jobs = len(jobs)
349+
for job_id in reversed(range(num_jobs)):
350+
job = jobs[job_id]
327351
job_type = job.type()
328-
job.set_skip_gc_vars(type_to_var_names[job_type])
352+
required_vars = type_to_required_vars[job_type]
353+
micro_batch_id = job.micro_batch_id()
354+
skip_gc_vars = required_vars & suffixed_required_vars[micro_batch_id]
355+
logger.debug(
356+
f"Skip gc vars for {job_type}-({micro_batch_id}): {skip_gc_vars}"
357+
)
358+
359+
if job_type in ["send_backward", "backward_w"]:
360+
assert (
361+
len(skip_gc_vars) == 0
362+
), f"When enabling pipeline parallelism strategy, the skip_gc_vars for {job_type} subprogram must be empty, but it is {skip_gc_vars}."
363+
364+
job.set_skip_gc_vars(skip_gc_vars)
365+
suffixed_required_vars[micro_batch_id] |= required_vars
329366

330367
return type_to_program
331368

0 commit comments

Comments
 (0)