Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 80 additions & 15 deletions python/paddle/utils/cpp_extension/cpp_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@
# isort: skip_file

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import os
import copy
import concurrent
import re

import setuptools
from setuptools.command.easy_install import easy_install
from setuptools.command.build_ext import build_ext
from distutils.command.build import build


from .extension_utils import (
add_compile_flag,
find_cuda_home,
Expand Down Expand Up @@ -64,6 +63,8 @@
from .extension_utils import CLANG_COMPILE_FLAGS, CLANG_LINK_FLAGS

from ...base import core
from concurrent.futures import ThreadPoolExecutor


if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -417,6 +418,7 @@ def build_extensions(self) -> None:
self._valid_clang_compiler()

self._check_abi()
current_extension_builder = self

# Note(Aurelius84): If already compiling source before, we should check whether
# cflags have changed and delete the built shared library to re-compile the source
Expand All @@ -433,11 +435,9 @@ def build_extensions(self) -> None:
self.compiler._cpp_extensions += ['.cu', '.cuh']
original_compile = self.compiler.compile
original_spawn = self.compiler.spawn
else:
original_compile = self.compiler._compile

def unix_custom_single_compiler(
obj, src, ext, cc_args, extra_postargs, pp_opts
def unix_custom_compile_single_file(
self, obj, src, ext, cc_args, extra_postargs, pp_opts
):
"""
Monkey patch mechanism to replace inner compiler to custom compile process on Unix platform.
Expand All @@ -447,7 +447,7 @@ def unix_custom_single_compiler(
src = os.path.abspath(src)
cflags = copy.deepcopy(extra_postargs)
try:
original_compiler = self.compiler.compiler_so
original_compiler = self.compiler_so
# nvcc or hipcc compile CUDA source
if is_cuda_file(src):
if core.is_compiled_with_rocm():
Expand All @@ -457,7 +457,7 @@ def unix_custom_single_compiler(
please use `export ROCM_PATH= XXX` to specify it."

hipcc_cmd = os.path.join(ROCM_HOME, 'bin', 'hipcc')
self.compiler.set_executable('compiler_so', hipcc_cmd)
self.set_executable('compiler_so', hipcc_cmd)
# {'nvcc': {}, 'cxx: {}}
if isinstance(cflags, dict):
cflags = cflags['hipcc']
Expand All @@ -468,7 +468,7 @@ def unix_custom_single_compiler(
please use `export CUDA_HOME= XXX` to specify it."

nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc')
self.compiler.set_executable('compiler_so', nvcc_cmd)
self.set_executable('compiler_so', nvcc_cmd)
# {'nvcc': {}, 'cxx: {}}
if isinstance(cflags, dict):
cflags = cflags['nvcc']
Expand All @@ -492,19 +492,77 @@ def unix_custom_single_compiler(
# See https://stackoverflow.com/questions/34571583/understanding-gcc-5s-glibcxx-use-cxx11-abi-or-the-new-abi
add_compile_flag(cflags, ['-D_GLIBCXX_USE_CXX11_ABI=1'])
# Append this macro only when jointly compiling .cc with .cu
if not is_cuda_file(src) and self.contain_cuda_file:
if (
not is_cuda_file(src)
and current_extension_builder.contain_cuda_file
):
if core.is_compiled_with_rocm():
cflags.append('-DPADDLE_WITH_HIP')
else:
cflags.append('-DPADDLE_WITH_CUDA')

add_std_without_repeat(
cflags, self.compiler.compiler_type, use_std17=True
cflags, self.compiler_type, use_std17=True
)
original_compile(obj, src, ext, cc_args, cflags, pp_opts)
self._compile(obj, src, ext, cc_args, cflags, pp_opts)
except Exception as e:
print(f'{src} compile failed, {e}')
finally:
# restore original_compiler
self.compiler.set_executable('compiler_so', original_compiler)
self.set_executable('compiler_so', original_compiler)

def unix_custom_single_compiler(
    self,
    sources,
    output_dir=None,
    macros=None,
    include_dirs=None,
    debug=False,
    extra_preargs=None,
    extra_postargs=None,
    depends=None,
):
    """
    Compile ``sources`` concurrently, one thread-pool task per object file.

    This function is installed in place of the compiler class's ``compile``
    method on non-MSVC platforms, so ``self`` is the compiler instance and
    the signature mirrors the method it replaces.

    Args:
        self: The compiler instance whose ``compile`` was patched.
        sources: Source file names to compile.
        output_dir: Forwarded to ``self._setup_compile``.
        macros: Forwarded to ``self._setup_compile``.
        include_dirs: Forwarded to ``self._setup_compile``.
        debug: Forwarded to ``self._get_cc_args``.
        extra_preargs: Forwarded to ``self._get_cc_args``.
        extra_postargs: Forwarded to ``self._setup_compile``; the value it
            returns is handed to every per-file compile task.
        depends: Forwarded to ``self._setup_compile``.

    Returns:
        All object file names reported by ``self._setup_compile``, not just
        the ones compiled by this call.
    """
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir,
        macros,
        include_dirs,
        sources,
        depends,
        extra_postargs,
    )
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

    # Only objects that have a (src, ext) entry in ``build`` need compiling;
    # indexing ``build`` blindly would raise KeyError for any other object.
    pending = [obj for obj in objects if obj in build]
    if not pending:
        # ThreadPoolExecutor rejects max_workers=0, so bail out early when
        # there is nothing to compile.
        return objects

    # os.cpu_count() may return None when the CPU count is undetermined.
    worker_number = min(os.cpu_count() or 1, len(pending))
    with ThreadPoolExecutor(max_workers=worker_number) as executor:
        # Submit all compilation tasks to the thread pool. Every task gets
        # its own shallow copy of the compiler -- presumably so that the
        # per-file `set_executable` swaps do not race across threads.
        futures = {}
        for obj in pending:
            src, ext = build[obj]
            future = executor.submit(
                unix_custom_compile_single_file,
                copy.copy(self),
                obj,
                src,
                ext,
                cc_args,
                extra_postargs,
                pp_opts,
            )
            futures[future] = obj

        for future in concurrent.futures.as_completed(futures):
            obj = futures[future]
            try:
                future.result()
            except Exception as exc:
                # NOTE(review): failures are only reported here, not
                # re-raised; a broken object surfaces later at link time.
                print(f'{obj!r} generated an exception: {exc}')
            else:
                print(f'{obj} is compiled')
    # Return *all* object filenames, not just the ones we just built.
    return objects

def win_custom_single_compiler(
sources,
Expand Down Expand Up @@ -643,9 +701,11 @@ def wrapper(source_filenames, strip_dir=0, output_dir=''):

# customized compile process
if self.compiler.compiler_type == 'msvc':
original_compile = self.compiler.compile
self.compiler.compile = win_custom_single_compiler
else:
self.compiler._compile = unix_custom_single_compiler
original_compile = self.compiler.__class__.compile
self.compiler.__class__.compile = unix_custom_single_compiler

self.compiler.object_filenames = object_filenames_with_cuda(
self.compiler.object_filenames, self.build_lib
Expand All @@ -655,6 +715,11 @@ def wrapper(source_filenames, strip_dir=0, output_dir=''):
print("Compiling user custom op, it will cost a few seconds.....")
build_ext.build_extensions(self)

if self.compiler.compiler_type == 'msvc':
self.compiler.compile = original_compile
else:
self.compiler.__class__.compile = original_compile

# Reset runtime library path on MacOS platform
so_path = self.get_ext_fullpath(self.extensions[0]._full_name)
_reset_so_rpath(so_path)
Expand Down