Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 34 additions & 50 deletions python/paddle/distributed/auto_parallel/static/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,58 +160,42 @@ def instantiate(self):
strategy.nrings = 1
if core.is_compiled_with_cuda():
place = core.CUDAPlace(genv.device_id)
use_new_comm = paddle.get_flags(
"FLAGS_dynamic_static_unified_comm"
)["FLAGS_dynamic_static_unified_comm"]
if use_new_comm:
store = core.create_or_get_global_tcp_store()
endpoints_str = ""
for endpoint in strategy.trainer_endpoints:
endpoints_str += endpoint
endpoints_str += f"ring_id:{ring_id}"
endpoints_str_hash = hashlib.md5(
endpoints_str.encode(encoding='UTF-8')
).hexdigest()

core.CommContextManager.set_device_id(genv.device_id)
core.CommContextManager.create_nccl_comm_context(
store,
str(ring_id),
strategy.local_rank,
strategy.nranks,
endpoints_str_hash,
)
else:
core.NCCLParallelContext(strategy, place).init_with_ring_id(
ring_id
)
store = core.create_or_get_global_tcp_store()
endpoints_str = ""
for endpoint in strategy.trainer_endpoints:
endpoints_str += endpoint
endpoints_str += f"ring_id:{ring_id}"
endpoints_str_hash = hashlib.md5(
endpoints_str.encode(encoding='UTF-8')
).hexdigest()

core.CommContextManager.set_device_id(genv.device_id)
core.CommContextManager.create_nccl_comm_context(
store,
str(ring_id),
strategy.local_rank,
strategy.nranks,
endpoints_str_hash,
)
elif core.is_compiled_with_xpu():
place = core.XPUPlace(genv.device_id)
use_new_comm = paddle.get_flags(
"FLAGS_dynamic_static_unified_comm"
)["FLAGS_dynamic_static_unified_comm"]
if use_new_comm:
store = core.create_or_get_global_tcp_store()
endpoints_str = ""
for endpoint in strategy.trainer_endpoints:
endpoints_str += endpoint
endpoints_str += f"ring_id:{ring_id}"
endpoints_str_hash = hashlib.md5(
endpoints_str.encode(encoding='UTF-8')
).hexdigest()

core.CommContextManager.set_device_id(genv.device_id)
core.CommContextManager.create_bkcl_comm_context(
store,
str(ring_id),
strategy.local_rank,
strategy.nranks,
endpoints_str_hash,
)
else:
core.BKCLParallelContext(strategy, place).init_with_ring_id(
ring_id
)
store = core.create_or_get_global_tcp_store()
endpoints_str = ""
for endpoint in strategy.trainer_endpoints:
endpoints_str += endpoint
endpoints_str += f"ring_id:{ring_id}"
endpoints_str_hash = hashlib.md5(
endpoints_str.encode(encoding='UTF-8')
).hexdigest()

core.CommContextManager.set_device_id(genv.device_id)
core.CommContextManager.create_bkcl_comm_context(
store,
str(ring_id),
strategy.local_rank,
strategy.nranks,
endpoints_str_hash,
)
elif genv.device_type in core.get_all_custom_device_type():
place = core.CustomPlace(genv.device_type, genv.device_id)
core.XCCLParallelContext(strategy, place).init_with_ring_id(
Expand Down
42 changes: 1 addition & 41 deletions python/paddle/distributed/fleet/base/private_helper_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import sys
import time
from contextlib import closing

import paddle

__all__ = []

Expand All @@ -35,39 +30,4 @@ def wait_server_ready(endpoints):

>>> wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"])
"""
try:
use_new_comm = paddle.get_flags("FLAGS_dynamic_static_unified_comm")[
"FLAGS_dynamic_static_unified_comm"
]
except:
use_new_comm = False

if use_new_comm:
return
assert not isinstance(endpoints, str)
while True:
all_ok = True
not_ready_endpoints = []
for ep in endpoints:
ip_port = ep.split(":")
with closing(
socket.socket(socket.AF_INET, socket.SOCK_STREAM)
) as sock:
sock.settimeout(2)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if hasattr(socket, 'SO_REUSEPORT'):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write(
"not ready endpoints:" + str(not_ready_endpoints) + "\n"
)
sys.stderr.flush()
time.sleep(3)
else:
break
return
7 changes: 0 additions & 7 deletions python/paddle/distributed/fleet/meta_optimizers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,6 @@ def _init_communicator(
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)

if rank == 0 and wait_port:
use_new_comm = paddle.get_flags(
"FLAGS_dynamic_static_unified_comm"
)["FLAGS_dynamic_static_unified_comm"]
if not use_new_comm:
wait_server_ready(other_endpoints)

def _add_sync_by_allreduce(block):
sync_var = block.create_var(
name=unique_name.generate('sync_var'),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import os

import paddle
from paddle.base import core
from paddle.incubate.optimizer import PipelineOptimizer
from paddle.static import (
Expand Down Expand Up @@ -705,11 +704,6 @@ def minimize_impl(
self._recreate_not_persist_param_as_var()

self._dump_program_for_debug()
use_new_comm = paddle.get_flags("FLAGS_dynamic_static_unified_comm")[
"FLAGS_dynamic_static_unified_comm"
]
if not use_new_comm:
self._wait()
return optimize_ops, params_grads

def _init_pair_comm(self, pair, ring_id):
Expand Down
Loading