Commit ce844b8

[Auto-Parallel | Comm] fix communication hang issue on GPU-H (#70360)

1 parent 0452626 commit ce844b8

File tree: 4 files changed, +407 -7 lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 13 additions & 0 deletions
@@ -935,6 +935,19 @@ void BindOperation(py::module *m) {
                 phi::IntArray(val));
             self.set_attribute(attr_name, attr);
           })
+      .def("set_str_array_attr",
+           [](Operation &self,
+              std::string &attr_name,
+              const std::vector<std::string> &val) {
+             std::vector<Attribute> val_attr;
+             for (auto &str : val) {
+               val_attr.emplace_back(
+                   StrAttribute::get(pir::IrContext::Instance(), str));
+             }
+             auto attr =
+                 pir::ArrayAttribute::get(pir::IrContext::Instance(), val_attr);
+             self.set_attribute(attr_name, attr);
+           })
       .def("set_str_attr",
            [](Operation &self, std::string &attr_name, std::string &val) {
             self.set_attribute(
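
Note: the new binding mirrors the existing set_str_attr, but wraps each string in a StrAttribute and stores the list as a pir::ArrayAttribute, so the Python side can attach a list of event names to a single operation. A minimal sketch of the intended call shape (the helper name is hypothetical; `op` is assumed to be a pir.Operation fetched from a PIR program's block):

def mark_events_to_wait(op, event_names):
    # Attach a list of event names as an array-of-strings attribute.
    # This is the call shape ProgramSplitter._add_dependency (further below)
    # relies on when wiring cross-program event dependencies.
    op.set_str_array_attr("events_to_wait", list(event_names))

# e.g. mark_events_to_wait(recv_op, ["var_42_pd_op.matmul_grad_0"])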

python/paddle/distributed/passes/pass_utils.py

Lines changed: 1 addition & 1 deletion
@@ -356,7 +356,7 @@ def _set_skip_gc_vars_in_pir(num_micro_batches, job_types, sub_programs, jobs):
             f"Skip gc vars for {job_type}-({micro_batch_id}): {skip_gc_vars}"
         )
 
-        if job_type in ["backward", "backward_w"]:
+        if job_type in ["send_backward", "backward_w"]:
             assert (
                 len(skip_gc_vars) == 0
             ), f"When enabling pipeline parallelism strategy, the skip_gc_vars for {job_type} subprogram must be empty, but it is {skip_gc_vars}."

python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_1f1b.py

Lines changed: 263 additions & 6 deletions
@@ -14,32 +14,45 @@
 
 import logging
 
+import paddle
 from paddle.base import core
 from paddle.distributed.auto_parallel.static.cost import calc_time_by_cost_model
+from paddle.framework import (
+    _current_expected_place_ as _get_device,
+)
 
 from ...utils.log_utils import get_logger
 from ..pass_base import register_pass
 from ..pass_utils import (
     AutoParallelStreamType,
     _add_event_dependency,
     _program_for_fthenb_and_1f1b,
-    _split_program_into_forward_backward_optimize,
+    forward_complete_op_role,
     split_program,
 )
 from .pipeline_pass_base import PipelinePassBase
 
+RECV_FORWARD = "recv_forward"
 FORWARD = "forward"
 BACKWARD = "backward"
+SEND_BACKWARD = "send_backward"
 OPT = "optimizer"
 
 logger = get_logger(logging.INFO)
 
 
 @register_pass("pipeline_scheduler_1F1B")
 class Pipeline1F1BPass(PipelinePassBase):
+
     def __init__(self):
         super().__init__()
         self.jobs_in_stable_phase = [BACKWARD, FORWARD]
+        self.jobs_in_stable_phase_in_pir = [
+            BACKWARD,
+            RECV_FORWARD,
+            SEND_BACKWARD,
+            FORWARD,
+        ]
         self.set_attr("enable_backward_forward_overlap", 0)
 
     # Backward-forward overlapping splits and rearranges jobs for pattern Bi-Fj.
@@ -168,6 +181,9 @@ def _backward_forward_overlap(self, backward_program, forward_program):
         )
 
     def _create_job_list(self):
+        if self._in_pir_mode:
+            return self._create_job_list_in_pir()
+
         num_micro_batches = self.get_attr("num_micro_batches")
         pp_stage = self.get_attr("pp_stage")
         pp_degree = self.get_attr("pp_degree")
@@ -212,6 +228,61 @@ def _create_job_list(self):
             job_list.append(opt_job)
         return job_list
 
+    def _create_job_list_in_pir(self):
+        num_micro_batches = self.get_attr("num_micro_batches")
+        pp_stage = self.get_attr("pp_stage")
+        pp_degree = self.get_attr("pp_degree")
+
+        job_list = []
+        assert (
+            pp_degree <= num_micro_batches
+        ), "Num of micro batches should larger than or equal to pp degree."
+
+        micro_batch_in_warmup = pp_degree - pp_stage
+        micro_batch_in_1f1b = num_micro_batches - micro_batch_in_warmup
+
+        forward_micro_batch_id = 0
+        for i in range(micro_batch_in_warmup):
+            recv_fwd_job = core.Job(RECV_FORWARD)
+            recv_fwd_job.set_micro_batch_id(forward_micro_batch_id)
+            job_list.append(recv_fwd_job)
+
+            forward_job = core.Job(FORWARD)
+            forward_job.set_micro_batch_id(forward_micro_batch_id)
+            job_list.append(forward_job)
+            forward_micro_batch_id += 1
+
+        backward_micro_batch_id = 0
+        for i in range(micro_batch_in_1f1b):
+            for job_type in self.jobs_in_stable_phase_in_pir:
+                job = core.Job(job_type)
+                micro_batch_id = (
+                    forward_micro_batch_id
+                    if job_type.startswith(FORWARD)
+                    or job_type.startswith(RECV_FORWARD)
+                    else backward_micro_batch_id
+                )
+                job.set_micro_batch_id(micro_batch_id)
+                job_list.append(job)
+            forward_micro_batch_id += 1
+            backward_micro_batch_id += 1
+
+        for i in range(micro_batch_in_warmup):
+            backward_job = core.Job(BACKWARD)
+            backward_job.set_micro_batch_id(backward_micro_batch_id)
+            job_list.append(backward_job)
+
+            send_bwd_job = core.Job(SEND_BACKWARD)
+            send_bwd_job.set_micro_batch_id(backward_micro_batch_id)
+            job_list.append(send_bwd_job)
+
+            backward_micro_batch_id += 1
+
+        opt_job = core.Job(OPT)
+        opt_job.set_micro_batch_id(0)
+        job_list.append(opt_job)
+        return job_list
+
     def _multistreaming_for_overlapping(self, programs, job_type):
         num_programs = len(programs)
         higher_stream_priority = -1
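
Note: the schedule built by _create_job_list_in_pir can be reproduced with a small standalone sketch (plain (job_type, micro_batch_id) tuples stand in for core.Job; the configuration values below are illustrative):

def sketch_1f1b_pir_schedule(num_micro_batches, pp_stage, pp_degree):
    assert pp_degree <= num_micro_batches
    stable_phase = ["backward", "recv_forward", "send_backward", "forward"]
    jobs = []

    warmup = pp_degree - pp_stage
    steady = num_micro_batches - warmup

    fwd_mb = 0
    for _ in range(warmup):  # warm-up: receive and run forwards only
        jobs += [("recv_forward", fwd_mb), ("forward", fwd_mb)]
        fwd_mb += 1

    bwd_mb = 0
    for _ in range(steady):  # stable 1F1B phase
        for job_type in stable_phase:
            mb = fwd_mb if job_type in ("forward", "recv_forward") else bwd_mb
            jobs.append((job_type, mb))
        fwd_mb += 1
        bwd_mb += 1

    for _ in range(warmup):  # cool-down: drain remaining backwards
        jobs += [("backward", bwd_mb), ("send_backward", bwd_mb)]
        bwd_mb += 1

    jobs.append(("optimizer", 0))
    return jobs

# e.g. stage 1 of a 2-stage pipeline with 4 micro batches:
print(sketch_1f1b_pir_schedule(num_micro_batches=4, pp_stage=1, pp_degree=2))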
@@ -343,16 +414,17 @@ def _partial_pir_programs(self, program):
             not enable_send_recv_overlap
         ), "PIR does not support 1F1B with enable_send_recv_overlap yet."
 
-        types = [FORWARD, BACKWARD, OPT]
-        sub_program_list = _split_program_into_forward_backward_optimize(
-            program, enable_send_recv_overlap
-        )
+        types = [RECV_FORWARD, FORWARD, BACKWARD, SEND_BACKWARD, OPT]
+        prog_splitter = ProgramSplitter(program, types)
+        sub_program_list = prog_splitter.split_programs()
 
         for i in range(len(types)):
             logger.debug(
                 f"type = {types[i]}, sub_programs = {sub_program_list[i]}\n"
             )
-        logger.debug(f"jobs_in_stable_phase = {self.jobs_in_stable_phase}")
+        logger.debug(
+            f"jobs_in_stable_phase = {self.jobs_in_stable_phase_in_pir}"
+        )
         return types, sub_program_list
 
     def _split_program_for_overlapping(self, job_type, program, split_points):
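
Note: position i of sub_program_list corresponds to job type types[i], so each core.Job created above maps onto exactly one of the five sub-programs. A trivial sketch of that contract (program objects replaced by strings for illustration):

types = ["recv_forward", "forward", "backward", "send_backward", "optimizer"]
sub_program_list = [f"<sub-program for {t}>" for t in types]
type_to_program = dict(zip(types, sub_program_list))
assert type_to_program["send_backward"] == "<sub-program for send_backward>"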
@@ -376,3 +448,188 @@ def is_comm_op_valid_to_overlap(self, op):
             and op.dist_attr.execution_stream
             == AutoParallelStreamType.CALC_STREAM.value
         )
+
+
+class ProgramSplitter:
+    def __init__(self, main_program, job_types):
+        assert job_types == [
+            RECV_FORWARD,
+            FORWARD,
+            BACKWARD,
+            SEND_BACKWARD,
+            OPT,
+        ]
+        self._overlap_send_recv(main_program)
+        forward_complete_op_role(main_program)
+        self.job_types = job_types
+        self.complete_ops = main_program.global_block().ops
+        self.programs = self._clone_programs(main_program)
+        self.ops_dict = {
+            key: prog.global_block().ops for key, prog in self.programs.items()
+        }
+        self.blocks_dict = {
+            key: prog.global_block() for key, prog in self.programs.items()
+        }
+
+        self.cur_place = self._get_cur_place()
+
+    def _overlap_send_recv(self, program):
+        # TODO(liym27): This function should not be in ProgramSplitter, move it to pipeline_pass_base.py after vpp fixed.
+        for block in program.blocks:
+            for op in block.ops:
+                if op.name() == "pd_op.send_v2":
+                    op.set_bool_attr("dynamic_shape", False)
+                    op.set_bool_attr("use_calc_stream", True)
+                    ring_id = op.attrs()["ring_id"]
+                    op.set_execution_stream("send_recv_stream")
+                    op.set_scheduling_priority(0)
+                elif op.name() == "pd_op.recv_v2":
+                    op.set_bool_attr("dynamic_shape", False)
+                    op.set_bool_attr("use_calc_stream", True)
+                    op.set_execution_stream("send_recv_stream")
+                    op.set_scheduling_priority(0)
+
+    def _clone_programs(self, program):
+        prog_dict = {}
+        for job_type in self.job_types:
+            prog_dict[job_type] = program.clone()
+        return prog_dict
+
+    def _get_cur_place(self):
+        place = _get_device()
+        if isinstance(place, paddle.framework.CUDAPlace):
+            place = paddle.framework.CUDAPlace(
+                paddle.distributed.ParallelEnv().dev_id
+            )
+        cur_place = paddle.base.libpaddle.Place()
+        cur_place.set_place(place)
+        return cur_place
+
+    def split_programs(self):
+        region = "opt"
+        for op_idx in range(len(self.complete_ops) - 1, -1, -1):
+            op = self.complete_ops[op_idx]
+            if op.op_role != -1:
+                if op.op_role == 1:
+                    region = "bwd"
+                elif op.op_role == 0:
+                    region = "fwd"
+                elif op.op_role == 2:
+                    region = "opt"
+
+            if region == "opt":
+                self._erase_op_from_other_programs(op_idx, OPT)
+            elif region == "bwd" and op.name() == "pd_op.send_v2":
+                self._handle_func(op_idx, SEND_BACKWARD, self.job_types[4:])
+                self._erase_op_from_other_programs(op_idx, SEND_BACKWARD)
+            elif region == "bwd" and op.name() != "pd_op.send_v2":
+                self._handle_func(op_idx, BACKWARD, self.job_types[3:])
+                self._erase_op_from_other_programs(op_idx, BACKWARD)
+            elif region == "fwd" and op.name() != "pd_op.recv_v2":
+                self._handle_func(op_idx, FORWARD, self.job_types[2:])
+                self._erase_op_from_other_programs(op_idx, FORWARD)
+            elif region == "fwd" and op.name() == "pd_op.recv_v2":
+                self._handle_func(op_idx, RECV_FORWARD, self.job_types[1:])
+                self._erase_op_from_other_programs(op_idx, RECV_FORWARD)
+        progs = []
+        for job_type in self.job_types:
+            progs.append(self.programs[job_type])
+        return progs
+
+    def _erase_op_from_other_programs(self, op_idx, keep_job_type):
+        for job_type in self.job_types:
+            if job_type != keep_job_type:
+                self.ops_dict[job_type][op_idx].erase()
+
+    def _handle_func(self, op_idx, cur_job_type, suffixed_job_types):
+        for idx in range(self.complete_ops[op_idx].num_results()):
+            if self._result_is_used(suffixed_job_types, op_idx, idx):
+                var_name = self._get_or_create_var_name(
+                    self.ops_dict[cur_job_type], op_idx, idx
+                )
+                for job_type in suffixed_job_types:
+                    if self._result_is_used([job_type], op_idx, idx):
+                        self._add_dependency_if_necessary(
+                            cur_job_type, job_type, op_idx, idx, var_name
+                        )
+                        self._add_kwarg_and_replace(
+                            self.blocks_dict[job_type],
+                            self.ops_dict[job_type],
+                            op_idx,
+                            idx,
+                            var_name,
+                        )
+
+    def _add_dependency_if_necessary(
+        self, cur_job_type, next_job_type, op_idx, rst_idx, var_name
+    ):
+        if not (
+            cur_job_type == BACKWARD and next_job_type == SEND_BACKWARD
+        ) and not (cur_job_type == RECV_FORWARD and next_job_type == FORWARD):
+            return
+
+        first_used_idx = None
+        first_used_op = None
+        for used_op in (
+            self.ops_dict[next_job_type][op_idx].result(rst_idx).all_used_ops()
+        ):
+            used_idx = self.ops_dict[next_job_type].index(used_op)
+            if first_used_idx is None or used_idx < first_used_idx:
+                first_used_idx = used_idx
+                first_used_op = used_op
+        self._add_dependency(
+            self.ops_dict[cur_job_type][op_idx], first_used_op, var_name
+        )
+
+    def _add_dependency(self, recorder_op, waiter_op, name):
+        '''
+        Add the extra event dependency of the two operators.
+        This function mainly aims for the cross-programs in pipeline parallelism,
+        especial for the 'send_v2' 'recv_v2' etc.
+        '''
+        if not recorder_op.has_attr("force_record_event"):
+            recorder_op.set_bool_attr("force_record_event", True)
+        recorder_op.set_str_attr("event_to_record", name)
+        waiter_op.set_str_array_attr("events_to_wait", [name])
+
+    def _result_is_used(self, job_types, op_idx, rst_idx):
+        is_used = False
+        for job_type in job_types:
+            is_used = (
+                is_used
+                or self.ops_dict[job_type][op_idx].result(rst_idx).use_empty()
+                is False
+            )
+        return is_used
+
+    def _get_or_create_var_name(self, cur_sub_ops, op_idx, rst_idx):
+        var_name = None
+        # case1: get var_name in current sub-program
+        op = cur_sub_ops[op_idx]
+        if op.name() == "pd_op.data" or op.name() == "builtin.parameter":
+            var_name = op.result(rst_idx).name
+        else:
+            # case2: get var_name from shadow_output in complete program
+            result_var = self.complete_ops[op_idx].result(rst_idx)
+            shadow_output_op = None
+            for used_op in result_var.all_used_ops():
+                if used_op.name() == "builtin.shadow_output":
+                    shadow_output_op = used_op
+            if shadow_output_op is not None:
+                var_name = shadow_output_op.attrs()["output_name"]
+
+        if var_name is None:
+            # case3: create var_name in current sub-program
+            paddle.pir.set_insertion_point_after(op)
+            var_name = (
+                f"var_{op_idx}_{self.complete_ops[op_idx].name()}_{rst_idx}"
+            )
+            paddle._C_ops.set_persistable_value(op.result(rst_idx), var_name)
+        return var_name
+
+    def _add_kwarg_and_replace(self, block, ops, op_idx, rst_idx, var_name):
+        ori_result = ops[op_idx].result(rst_idx)
+        new_result_var = block.add_kwarg(var_name, ori_result.type())
+        new_result_var.place_attr = self.cur_place
+        new_result_var.persistable = ori_result.persistable
+        ops[op_idx].result(rst_idx).replace_all_uses_with(new_result_var)
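
Note: the cross-program ordering added here is expressed purely through attributes: the producing op in one sub-program records an event named after the shared variable, and the first consumer of that variable in the next sub-program waits on it. A toy model of the attribute pairing set up by _add_dependency (ops are modeled as plain dicts; executor semantics are assumed; the op names and variable name are illustrative):

# Toy model of the record/wait pairing between two sub-programs.
recorder_op = {"name": "pd_op.matmul_grad", "attrs": {}}  # producer in BACKWARD
waiter_op = {"name": "pd_op.send_v2", "attrs": {}}        # first user in SEND_BACKWARD
shared_var = "var_42_pd_op.matmul_grad_0"                 # illustrative name

# _add_dependency marks the producer to record an event named after the
# shared variable, and the consumer to wait on that same event.
recorder_op["attrs"]["force_record_event"] = True
recorder_op["attrs"]["event_to_record"] = shared_var
waiter_op["attrs"]["events_to_wait"] = [shared_var]

assert waiter_op["attrs"]["events_to_wait"] == [recorder_op["attrs"]["event_to_record"]]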
