1515# isort: skip_file
1616
1717from __future__ import annotations
18-
1918from typing import TYPE_CHECKING , Any
20-
2119import os
2220import copy
21+ import concurrent
2322import re
24-
2523import setuptools
2624from setuptools .command .easy_install import easy_install
2725from setuptools .command .build_ext import build_ext
2826from distutils .command .build import build
2927
28+
3029from .extension_utils import (
3130 add_compile_flag ,
3231 find_cuda_home ,
6463from .extension_utils import CLANG_COMPILE_FLAGS , CLANG_LINK_FLAGS
6564
6665from ...base import core
66+ from concurrent .futures import ThreadPoolExecutor
67+
6768
6869if TYPE_CHECKING :
6970 from collections .abc import Sequence
@@ -417,6 +418,7 @@ def build_extensions(self) -> None:
417418 self ._valid_clang_compiler ()
418419
419420 self ._check_abi ()
421+ current_extension_builder = self
420422
421423 # Note(Aurelius84): If already compiling source before, we should check whether
422424 # cflags have changed and delete the built shared library to re-compile the source
@@ -433,11 +435,9 @@ def build_extensions(self) -> None:
433435 self .compiler ._cpp_extensions += ['.cu' , '.cuh' ]
434436 original_compile = self .compiler .compile
435437 original_spawn = self .compiler .spawn
436- else :
437- original_compile = self .compiler ._compile
438438
439- def unix_custom_single_compiler (
440- obj , src , ext , cc_args , extra_postargs , pp_opts
439+ def unix_custom_compile_single_file (
440+ self , obj , src , ext , cc_args , extra_postargs , pp_opts
441441 ):
442442 """
443443 Monkey patch mechanism to replace inner compiler to custom compile process on Unix platform.
@@ -447,7 +447,7 @@ def unix_custom_single_compiler(
447447 src = os .path .abspath (src )
448448 cflags = copy .deepcopy (extra_postargs )
449449 try :
450- original_compiler = self .compiler . compiler_so
450+ original_compiler = self .compiler_so
451451 # nvcc or hipcc compile CUDA source
452452 if is_cuda_file (src ):
453453 if core .is_compiled_with_rocm ():
@@ -457,7 +457,7 @@ def unix_custom_single_compiler(
457457 please use `export ROCM_PATH= XXX` to specify it."
458458
459459 hipcc_cmd = os .path .join (ROCM_HOME , 'bin' , 'hipcc' )
460- self .compiler . set_executable ('compiler_so' , hipcc_cmd )
460+ self .set_executable ('compiler_so' , hipcc_cmd )
461461 # {'nvcc': {}, 'cxx: {}}
462462 if isinstance (cflags , dict ):
463463 cflags = cflags ['hipcc' ]
@@ -468,7 +468,7 @@ def unix_custom_single_compiler(
468468 please use `export CUDA_HOME= XXX` to specify it."
469469
470470 nvcc_cmd = os .path .join (CUDA_HOME , 'bin' , 'nvcc' )
471- self .compiler . set_executable ('compiler_so' , nvcc_cmd )
471+ self .set_executable ('compiler_so' , nvcc_cmd )
472472 # {'nvcc': {}, 'cxx: {}}
473473 if isinstance (cflags , dict ):
474474 cflags = cflags ['nvcc' ]
@@ -492,19 +492,77 @@ def unix_custom_single_compiler(
492492 # See https://stackoverflow.com/questions/34571583/understanding-gcc-5s-glibcxx-use-cxx11-abi-or-the-new-abi
493493 add_compile_flag (cflags , ['-D_GLIBCXX_USE_CXX11_ABI=1' ])
494494 # Append this macro only when jointly compiling .cc with .cu
495- if not is_cuda_file (src ) and self .contain_cuda_file :
495+ if (
496+ not is_cuda_file (src )
497+ and current_extension_builder .contain_cuda_file
498+ ):
496499 if core .is_compiled_with_rocm ():
497500 cflags .append ('-DPADDLE_WITH_HIP' )
498501 else :
499502 cflags .append ('-DPADDLE_WITH_CUDA' )
500503
501504 add_std_without_repeat (
502- cflags , self .compiler . compiler_type , use_std17 = True
505+ cflags , self .compiler_type , use_std17 = True
503506 )
504- original_compile (obj , src , ext , cc_args , cflags , pp_opts )
507+ self ._compile (obj , src , ext , cc_args , cflags , pp_opts )
508+ except Exception as e :
509+ print (f'{ src } compile failed, { e } ' )
505510 finally :
506511 # restore original_compiler
507- self .compiler .set_executable ('compiler_so' , original_compiler )
512+ self .set_executable ('compiler_so' , original_compiler )
513+
514+ def unix_custom_single_compiler (
515+ self ,
516+ sources ,
517+ output_dir = None ,
518+ macros = None ,
519+ include_dirs = None ,
520+ debug = False ,
521+ extra_preargs = None ,
522+ extra_postargs = None ,
523+ depends = None ,
524+ ):
525+ # A concrete compiler class can either override this method
526+ # entirely or implement _compile().
527+ macros , objects , extra_postargs , pp_opts , build = (
528+ self ._setup_compile (
529+ output_dir ,
530+ macros ,
531+ include_dirs ,
532+ sources ,
533+ depends ,
534+ extra_postargs ,
535+ )
536+ )
537+ cc_args = self ._get_cc_args (pp_opts , debug , extra_preargs )
538+ # Create a thread pool
539+ worke_number = min (os .cpu_count (), len (objects ))
540+ with ThreadPoolExecutor (max_workers = worke_number ) as executor :
541+ # Submit all compilation tasks to the thread pool.
542+ futures = {
543+ executor .submit (
544+ unix_custom_compile_single_file ,
545+ copy .copy (self ),
546+ obj ,
547+ build [obj ][0 ],
548+ build [obj ][1 ],
549+ cc_args ,
550+ extra_postargs ,
551+ pp_opts ,
552+ ): obj
553+ for obj in objects
554+ }
555+
556+ for future in concurrent .futures .as_completed (futures ):
557+ obj = futures [future ]
558+ try :
559+ future .result ()
560+ except Exception as exc :
561+ print (f'{ obj !r} generated an exception: { exc } ' )
562+ else :
563+ print (f'{ obj } is compiled' )
564+ # Return *all* object filenames, not just the ones we just built.
565+ return objects
508566
509567 def win_custom_single_compiler (
510568 sources ,
@@ -643,9 +701,11 @@ def wrapper(source_filenames, strip_dir=0, output_dir=''):
643701
644702 # customized compile process
645703 if self .compiler .compiler_type == 'msvc' :
704+ original_compile = self .compiler .compile
646705 self .compiler .compile = win_custom_single_compiler
647706 else :
648- self .compiler ._compile = unix_custom_single_compiler
707+ original_compile = self .compiler .__class__ .compile
708+ self .compiler .__class__ .compile = unix_custom_single_compiler
649709
650710 self .compiler .object_filenames = object_filenames_with_cuda (
651711 self .compiler .object_filenames , self .build_lib
@@ -655,6 +715,11 @@ def wrapper(source_filenames, strip_dir=0, output_dir=''):
655715 print ("Compiling user custom op, it will cost a few seconds....." )
656716 build_ext .build_extensions (self )
657717
718+ if self .compiler .compiler_type == 'msvc' :
719+ self .compiler .compile = original_compile
720+ else :
721+ self .compiler .__class__ .compile = original_compile
722+
658723 # Reset runtime library path on MacOS platform
659724 so_path = self .get_ext_fullpath (self .extensions [0 ]._full_name )
660725 _reset_so_rpath (so_path )
0 commit comments