
Commit 4a97ba5

[KUNLUN]Revert "revert p2p communication for xpu (#53496)" (#53633)
* Revert "revert p2p communication for xpu (#53496)"
  This reverts commit eda0c58.
* update
1 parent 8075752 commit 4a97ba5

File tree: 1 file changed, +27 -3 lines changed

python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py

Lines changed: 27 additions & 3 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 import numpy as np
@@ -27,10 +26,30 @@
 _use_cache = False
 _enable_partial_send_recv = True
 
+_xpu_comm_group_started = False
+
 _sync_send = os.environ.get("PADDLE_P2P_SYNC_SEND", "0")
 _sync_send = _sync_send.lower() in ['1', 'true']
 
 
+def _xpu_comm_group_start():
+    if not paddle.is_compiled_with_xpu():
+        return
+    global _xpu_comm_group_started
+    assert not _xpu_comm_group_started
+    framework.core.ProcessGroupBKCL.group_start()
+    _xpu_comm_group_started = True
+
+
+def _xpu_comm_group_end():
+    if not paddle.is_compiled_with_xpu():
+        return
+    global _xpu_comm_group_started
+    if _xpu_comm_group_started:
+        framework.core.ProcessGroupBKCL.group_end()
+        _xpu_comm_group_started = False
+
+
 def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
     global _hcg, _use_cache, _enable_partial_send_recv
     _hcg = hcg
@@ -357,6 +376,7 @@ def _p2p_helper(
     # TODO(Yuang Liu): use batch_isend_irecv replace all these comm ops
     tasks = []
     # start to p2p communicate
+
     if _sync_send:
         # Some devices(NPU for example) do not support asynchronized send op, So the order is
         # recv_prev -> send_next -> recv_next -> send_prev
@@ -492,8 +512,8 @@ def _p2p_helper(
                     group=_hcg.send_prev_group,
                     use_calc_stream=False,
                 )
-
     else:
+        _xpu_comm_group_start()
         if tensor_send_prev is not None:
             if isinstance(tensor_send_prev, tuple):
                 for d in tensor_send_prev:
@@ -529,6 +549,7 @@
                         use_calc_stream=sync_recv,
                     )
                     if sync_recv:
+                        _xpu_comm_group_end()
                         allgather_partial(
                             d,
                             nranks=mp_degree,
@@ -549,6 +570,7 @@
                 )
 
                 if sync_recv:
+                    _xpu_comm_group_end()
                     allgather_partial(
                         tensor_recv_prev,
                         nranks=mp_degree,
@@ -595,6 +617,7 @@
                     )
 
                     if sync_recv:
+                        _xpu_comm_group_end()
                         allgather_partial(
                             d,
                             nranks=mp_degree,
@@ -615,6 +638,7 @@
                     use_calc_stream=sync_recv,
                 )
                 if sync_recv:
+                    _xpu_comm_group_end()
                     allgather_partial(
                         tensor_recv_next,
                         nranks=mp_degree,
@@ -624,7 +648,7 @@
                     )
             else:
                 tasks.append(task)
-
+    _xpu_comm_group_end()
     if not sync_recv:
         if framework.in_dygraph_mode():
             # wait irecv tasks in eager dygraph mode with new comm library

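For context, the new _xpu_comm_group_start() / _xpu_comm_group_end() helpers bracket the batched point-to-point ops with BKCL grouped-call markers: start opens a group (asserting one is not already open), the send/recv ops issued afterwards are queued together, and end flushes the group before any op that consumes the received data (the allgather_partial calls). The guard flag makes end idempotent, and both helpers return early on non-XPU builds, so the CUDA/NPU paths are unchanged. Below is a minimal, self-contained sketch of that control flow only; queue_send, queue_recv, and the print statements are hypothetical stand-ins for the send_partial / recv_partial / allgather_partial calls in _p2p_helper, not Paddle APIs.

# Minimal sketch (assumption: illustrative only; nothing here is Paddle's API
# beyond the two helper names taken from the diff). The _group_open flag mirrors
# _xpu_comm_group_started: start() refuses to nest, end() only closes an open
# group, so calling end() more than once is harmless.

_group_open = False


def group_start():
    global _group_open
    assert not _group_open       # mirrors: assert not _xpu_comm_group_started
    _group_open = True
    print("BKCL group_start")    # stands in for ProcessGroupBKCL.group_start()


def group_end():
    global _group_open
    if _group_open:              # idempotent close, as in _xpu_comm_group_end
        _group_open = False
        print("BKCL group_end")  # stands in for ProcessGroupBKCL.group_end()


def queue_send(name):
    print("queued send:", name)  # hypothetical stand-in for send_partial(...)


def queue_recv(name, sync_recv):
    print("queued recv:", name)  # hypothetical stand-in for recv_partial(...)
    if sync_recv:
        group_end()              # flush the group before consuming the data
        print("consume:", name)  # hypothetical stand-in for allgather_partial(...)


def p2p_exchange(sync_recv=True):
    group_start()                # open the group before queuing p2p ops
    queue_send("tensor_send_prev")
    queue_recv("tensor_recv_prev", sync_recv)
    queue_send("tensor_send_next")
    queue_recv("tensor_recv_next", sync_recv)
    group_end()                  # final close; no-op if already closed


p2p_exchange()

The sketch only shows why the final _xpu_comm_group_end() at the end of _p2p_helper is safe even when the group was already closed (or never opened, as in the _sync_send branch): the flag-guarded close simply becomes a no-op.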