Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions python/paddle/distributed/fleet/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,20 @@ def launch_ps(args, distribute_mode):
return


def infer_backend(args):
if args.backend != "auto": return
if fluid.core.is_compiled_with_cuda():
args.backend = 'nccl'
elif fluid.core.is_compiled_with_npu():
args.backend = 'unknown'
elif fluid.core.is_compiled_with_xpu():
args.backend = 'bkcl'
else:
args.backend = 'gloo'


def which_distributed_mode(args):
infer_backend(args) # modify the args.backend
if args.run_mode is not None:
assert args.run_mode in ["collective", "ps", "ps-heter"]

Expand Down Expand Up @@ -368,12 +381,9 @@ def which_distributed_mode(args):

if fluid.core.is_compiled_with_cuda():
accelerators = fluid.core.get_cuda_device_count()
args.backend = 'nccl'
elif fluid.core.is_compiled_with_npu():
args.backend = 'unknown'
accelerators = fluid.core.get_npu_device_count()
elif fluid.core.is_compiled_with_xpu():
args.backend = 'bkcl'
accelerators = fluid.core.get_xpu_device_count()
else:
accelerators = 0
Expand All @@ -400,7 +410,6 @@ def which_distributed_mode(args):
But found args.servers not empty, default use ps mode")
return DistributeMode.PS
else:
args.backend = "gloo"
return DistributeMode.COLLECTIVE
else:
logger.warning(
Expand Down Expand Up @@ -583,20 +592,21 @@ def launch():
_print_arguments(args)

if args.backend == 'auto':
distribute_mode = which_distributed_mode(args)
assert args.backend in [
'gloo', 'nccl', 'bkcl', 'unknown'
] # which_distributed_mode must modify args.backend
distribute_mode = which_distributed_mode(
args) # which_distributed_mode must modify args.backend
else:
assert args.run_mode == 'collective' or args.run_mode == None, "When backend is not 'auto', run mode must be collective"
check_backend(args.backend)
distribute_mode = DistributeMode.COLLECTIVE

block_windows_and_macos(
args.backend) # raise error when using gloo on windows or macos
assert args.backend in ['gloo', 'nccl', 'bkcl', 'unknown']

if args.backend == 'gloo':
logger.warning("launch start with CPUONLY mode")

block_windows_and_macos(
args.backend) # raise error when using gloo on windows or macos

if enable_elastic(args, distribute_mode):
launch_elastic(args, distribute_mode)
return
Expand Down