Skip to content

Commit 962c8f6

Browse files
authored
optimize_custom_op_compile (#67615)
* optimize_custom_op_compile * fix * fix * fix * fix * fix * fix * fix
1 parent a9470b6 commit 962c8f6

File tree

1 file changed

+80
-15
lines changed

1 file changed

+80
-15
lines changed

python/paddle/utils/cpp_extension/cpp_extension.py

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,17 @@
1515
# isort: skip_file
1616

1717
from __future__ import annotations
18-
1918
from typing import TYPE_CHECKING, Any
20-
2119
import os
2220
import copy
21+
import concurrent
2322
import re
24-
2523
import setuptools
2624
from setuptools.command.easy_install import easy_install
2725
from setuptools.command.build_ext import build_ext
2826
from distutils.command.build import build
2927

28+
3029
from .extension_utils import (
3130
add_compile_flag,
3231
find_cuda_home,
@@ -64,6 +63,8 @@
6463
from .extension_utils import CLANG_COMPILE_FLAGS, CLANG_LINK_FLAGS
6564

6665
from ...base import core
66+
from concurrent.futures import ThreadPoolExecutor
67+
6768

6869
if TYPE_CHECKING:
6970
from collections.abc import Sequence
@@ -417,6 +418,7 @@ def build_extensions(self) -> None:
417418
self._valid_clang_compiler()
418419

419420
self._check_abi()
421+
current_extension_builder = self
420422

421423
# Note(Aurelius84): If already compiling source before, we should check whether
422424
# cflags have changed and delete the built shared library to re-compile the source
@@ -433,11 +435,9 @@ def build_extensions(self) -> None:
433435
self.compiler._cpp_extensions += ['.cu', '.cuh']
434436
original_compile = self.compiler.compile
435437
original_spawn = self.compiler.spawn
436-
else:
437-
original_compile = self.compiler._compile
438438

439-
def unix_custom_single_compiler(
440-
obj, src, ext, cc_args, extra_postargs, pp_opts
439+
def unix_custom_compile_single_file(
440+
self, obj, src, ext, cc_args, extra_postargs, pp_opts
441441
):
442442
"""
443443
Monkey patch mechanism to replace inner compiler to custom compile process on Unix platform.
@@ -447,7 +447,7 @@ def unix_custom_single_compiler(
447447
src = os.path.abspath(src)
448448
cflags = copy.deepcopy(extra_postargs)
449449
try:
450-
original_compiler = self.compiler.compiler_so
450+
original_compiler = self.compiler_so
451451
# nvcc or hipcc compile CUDA source
452452
if is_cuda_file(src):
453453
if core.is_compiled_with_rocm():
@@ -457,7 +457,7 @@ def unix_custom_single_compiler(
457457
please use `export ROCM_PATH= XXX` to specify it."
458458

459459
hipcc_cmd = os.path.join(ROCM_HOME, 'bin', 'hipcc')
460-
self.compiler.set_executable('compiler_so', hipcc_cmd)
460+
self.set_executable('compiler_so', hipcc_cmd)
461461
# {'nvcc': {}, 'cxx: {}}
462462
if isinstance(cflags, dict):
463463
cflags = cflags['hipcc']
@@ -468,7 +468,7 @@ def unix_custom_single_compiler(
468468
please use `export CUDA_HOME= XXX` to specify it."
469469

470470
nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc')
471-
self.compiler.set_executable('compiler_so', nvcc_cmd)
471+
self.set_executable('compiler_so', nvcc_cmd)
472472
# {'nvcc': {}, 'cxx: {}}
473473
if isinstance(cflags, dict):
474474
cflags = cflags['nvcc']
@@ -492,19 +492,77 @@ def unix_custom_single_compiler(
492492
# See https://stackoverflow.com/questions/34571583/understanding-gcc-5s-glibcxx-use-cxx11-abi-or-the-new-abi
493493
add_compile_flag(cflags, ['-D_GLIBCXX_USE_CXX11_ABI=1'])
494494
# Append this macro only when jointly compiling .cc with .cu
495-
if not is_cuda_file(src) and self.contain_cuda_file:
495+
if (
496+
not is_cuda_file(src)
497+
and current_extension_builder.contain_cuda_file
498+
):
496499
if core.is_compiled_with_rocm():
497500
cflags.append('-DPADDLE_WITH_HIP')
498501
else:
499502
cflags.append('-DPADDLE_WITH_CUDA')
500503

501504
add_std_without_repeat(
502-
cflags, self.compiler.compiler_type, use_std17=True
505+
cflags, self.compiler_type, use_std17=True
503506
)
504-
original_compile(obj, src, ext, cc_args, cflags, pp_opts)
507+
self._compile(obj, src, ext, cc_args, cflags, pp_opts)
508+
except Exception as e:
509+
print(f'{src} compile failed, {e}')
505510
finally:
506511
# restore original_compiler
507-
self.compiler.set_executable('compiler_so', original_compiler)
512+
self.set_executable('compiler_so', original_compiler)
513+
514+
def unix_custom_single_compiler(
515+
self,
516+
sources,
517+
output_dir=None,
518+
macros=None,
519+
include_dirs=None,
520+
debug=False,
521+
extra_preargs=None,
522+
extra_postargs=None,
523+
depends=None,
524+
):
525+
# A concrete compiler class can either override this method
526+
# entirely or implement _compile().
527+
macros, objects, extra_postargs, pp_opts, build = (
528+
self._setup_compile(
529+
output_dir,
530+
macros,
531+
include_dirs,
532+
sources,
533+
depends,
534+
extra_postargs,
535+
)
536+
)
537+
cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)
538+
# Create a thread pool
539+
worke_number = min(os.cpu_count(), len(objects))
540+
with ThreadPoolExecutor(max_workers=worke_number) as executor:
541+
# Submit all compilation tasks to the thread pool.
542+
futures = {
543+
executor.submit(
544+
unix_custom_compile_single_file,
545+
copy.copy(self),
546+
obj,
547+
build[obj][0],
548+
build[obj][1],
549+
cc_args,
550+
extra_postargs,
551+
pp_opts,
552+
): obj
553+
for obj in objects
554+
}
555+
556+
for future in concurrent.futures.as_completed(futures):
557+
obj = futures[future]
558+
try:
559+
future.result()
560+
except Exception as exc:
561+
print(f'{obj!r} generated an exception: {exc}')
562+
else:
563+
print(f'{obj} is compiled')
564+
# Return *all* object filenames, not just the ones we just built.
565+
return objects
508566

509567
def win_custom_single_compiler(
510568
sources,
@@ -643,9 +701,11 @@ def wrapper(source_filenames, strip_dir=0, output_dir=''):
643701

644702
# customized compile process
645703
if self.compiler.compiler_type == 'msvc':
704+
original_compile = self.compiler.compile
646705
self.compiler.compile = win_custom_single_compiler
647706
else:
648-
self.compiler._compile = unix_custom_single_compiler
707+
original_compile = self.compiler.__class__.compile
708+
self.compiler.__class__.compile = unix_custom_single_compiler
649709

650710
self.compiler.object_filenames = object_filenames_with_cuda(
651711
self.compiler.object_filenames, self.build_lib
@@ -655,6 +715,11 @@ def wrapper(source_filenames, strip_dir=0, output_dir=''):
655715
print("Compiling user custom op, it will cost a few seconds.....")
656716
build_ext.build_extensions(self)
657717

718+
if self.compiler.compiler_type == 'msvc':
719+
self.compiler.compile = original_compile
720+
else:
721+
self.compiler.__class__.compile = original_compile
722+
658723
# Reset runtime library path on MacOS platform
659724
so_path = self.get_ext_fullpath(self.extensions[0]._full_name)
660725
_reset_so_rpath(so_path)

0 commit comments

Comments
 (0)