# See the License for the specific language governing permissions and
# limitations under the License.

-
import os

import numpy as np
_use_cache = False
_enable_partial_send_recv = True

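+# Tracks whether a BKCL communication group is currently open, so that
+# _xpu_comm_group_start()/_xpu_comm_group_end() stay paired and never nest.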
+_xpu_comm_group_started = False
+
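+# PADDLE_P2P_SYNC_SEND forces the synchronous send/recv ordering used in
+# _p2p_helper below, for devices without asynchronous send support (e.g. NPU).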
_sync_send = os.environ.get("PADDLE_P2P_SYNC_SEND", "0")
_sync_send = _sync_send.lower() in ['1', 'true']


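+# The two helpers below wrap BKCL's group_start()/group_end(), which batch
+# the communication ops issued between them (similar to NCCL's
+# ncclGroupStart/ncclGroupEnd). Both helpers are no-ops on non-XPU builds.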
+def _xpu_comm_group_start():
+    if not paddle.is_compiled_with_xpu():
+        return
+    global _xpu_comm_group_started
+    assert not _xpu_comm_group_started
+    framework.core.ProcessGroupBKCL.group_start()
+    _xpu_comm_group_started = True
+
+
+def _xpu_comm_group_end():
+    if not paddle.is_compiled_with_xpu():
+        return
+    global _xpu_comm_group_started
+    if _xpu_comm_group_started:
+        framework.core.ProcessGroupBKCL.group_end()
+        _xpu_comm_group_started = False
+
+
def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
    global _hcg, _use_cache, _enable_partial_send_recv
    _hcg = hcg
@@ -357,6 +376,7 @@ def _p2p_helper(
    # TODO(Yuang Liu): use batch_isend_irecv to replace all these comm ops
    tasks = []
    # start to p2p communicate
+
    if _sync_send:
        # Some devices (NPU, for example) do not support asynchronous send ops,
        # so the order is recv_prev -> send_next -> recv_next -> send_prev
@@ -492,8 +512,8 @@ def _p2p_helper(
                    group=_hcg.send_prev_group,
                    use_calc_stream=False,
                )
-
    else:
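+        # Open a BKCL batch so the send/recv ops posted below are grouped
+        # and issued together (no-op on non-XPU builds).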
+        _xpu_comm_group_start()
        if tensor_send_prev is not None:
            if isinstance(tensor_send_prev, tuple):
                for d in tensor_send_prev:
@@ -529,6 +549,7 @@ def _p2p_helper(
                        use_calc_stream=sync_recv,
                    )
                    if sync_recv:
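+                        # End the BKCL batch before the partial allgather
+                        # runs on the received tensor.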
+                        _xpu_comm_group_end()
                        allgather_partial(
                            d,
                            nranks=mp_degree,
@@ -549,6 +570,7 @@ def _p2p_helper(
                )

                if sync_recv:
+                    _xpu_comm_group_end()
                    allgather_partial(
                        tensor_recv_prev,
                        nranks=mp_degree,
@@ -595,6 +617,7 @@ def _p2p_helper(
                    )

                    if sync_recv:
+                        _xpu_comm_group_end()
                        allgather_partial(
                            d,
                            nranks=mp_degree,
@@ -615,6 +638,7 @@ def _p2p_helper(
                    use_calc_stream=sync_recv,
                )
                if sync_recv:
+                    _xpu_comm_group_end()
                    allgather_partial(
                        tensor_recv_next,
                        nranks=mp_degree,
@@ -624,7 +648,7 @@ def _p2p_helper(
                    )
                else:
                    tasks.append(task)
-
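+    # Close any BKCL batch opened above before waiting on the async recv
+    # tasks; this is a no-op if the _sync_send branch was taken.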
+    _xpu_comm_group_end()
    if not sync_recv:
        if framework.in_dygraph_mode():
            # wait irecv tasks in eager dygraph mode with new comm library