diff --git a/python/paddle/utils/cpp_extension/cpp_extension.py b/python/paddle/utils/cpp_extension/cpp_extension.py index d230c4875ad150..844b56003fc145 100644 --- a/python/paddle/utils/cpp_extension/cpp_extension.py +++ b/python/paddle/utils/cpp_extension/cpp_extension.py @@ -25,9 +25,13 @@ import setuptools import sys import paddle +import site + from setuptools.command.easy_install import easy_install from setuptools.command.build_ext import build_ext from distutils.command.build import build +from setuptools.command.install import install + from .extension_utils import ( add_compile_flag, @@ -55,9 +59,9 @@ ) from .extension_utils import _reset_so_rpath, clean_object_if_change_cflags from .extension_utils import ( - bootstrap_context, get_build_directory, add_std_without_repeat, + custom_write_stub, ) from .extension_utils import ( @@ -235,6 +239,11 @@ def setup(**attr: Any) -> None: assert 'easy_install' not in cmdclass cmdclass['easy_install'] = EasyInstallCommand + # Compatible with wheel installation via `pip install .` + # Note: This is rarely used with modern pip, which uses bdist_wheel instead + assert 'install' not in cmdclass + cmdclass['install'] = InstallCommand + # Note(Aurelius84): Add rename build_base directory hook in build command. # To avoid using same build directory that will lead to remove the directory # by mistake while parallelling execute setup.py, for example on CI. @@ -246,9 +255,7 @@ def setup(**attr: Any) -> None: # See http://peak.telecommunity.com/DevCenter/setuptools#setting-the-zip-safe-flag attr['zip_safe'] = False - # switch `write_stub` to inject paddle api in .egg - with bootstrap_context(): - setuptools.setup(**attr) + setuptools.setup(**attr) def CppExtension( @@ -849,8 +856,43 @@ def _clean_intermediate_files(self): os.remove(os.path.join(root, file)) print(f"Removed: {os.path.join(root, file)}") + def _generate_python_api_file(self) -> None: + """ + Generate the top-level python api file (package stub) alongside the + built shared library in build_lib. This replaces the legacy bdist_egg + write_stub mechanism that is no longer triggered in setuptools >= 80. + """ + try: + outputs = self.get_outputs() + if not outputs: + return + # We only support a single extension per setup() + so_path = os.path.abspath(outputs[0]) + so_name = os.path.basename(so_path) + build_dir = os.path.dirname(so_path) + + # Get the extension name from the extension module, not the distribution name + # This ensures we use the correct package name from setup.py + ext_name = self.extensions[0].name + + # Extract the last part of the extension name for the Python file + # For example, from "custom_setup_ops.my_ops.custom_relu" we get "custom_relu" + lib_name = ext_name.split('.')[-1] if '.' in ext_name else ext_name + + pyfile = os.path.join(build_dir, f"{lib_name}.py") + # Write stub; it will reference the _pd_ renamed resource at import time + custom_write_stub(so_name, pyfile) + except Exception as e: + raise RuntimeError( + f"Failed to generate python api file: {e}" + ) from e + def run(self): super().run() + + # Compatible with wheel installation via `pip install .` + self._generate_python_api_file() + self._clean_intermediate_files() @@ -926,6 +968,222 @@ def initialize_options(self) -> None: self.build_base = self._specified_build_base +class InstallCommand(install): + """ + Extend install Command to: + 1) choose an install dir that is actually importable (on sys.path) + 2) ensure a single top-level entry for the package in site/dist-packages so + legacy tests that expect a sole artifact (egg/package) keep working + 3) rename the compiled library to *_pd_.so to avoid shadowing the python stub + """ + + def _get_extension_name(self) -> str: + """ + Get the extension name from the extension module, not the distribution name. + This ensures we use the correct package name from setup.py. + + Note: This assumes there is only one extension module (len(ext_modules) == 1). + + Returns: + str: The extension name + """ + return self.distribution.ext_modules[0].name + + def finalize_options(self) -> None: + super().finalize_options() + + install_dir = ( + getattr(self, 'install_lib', None) + or getattr(self, 'install_purelib', None) + or getattr(self, 'install_platlib', None) + ) + if not install_dir or not os.path.isdir(install_dir): + return + + # Get the extension name + ext_name = self._get_extension_name() + + # Extract the first part of the extension name for the shared library + # For example, from "custom_setup_ops.my_ops.custom_relu" we get "custom_setup_ops" + pkg_name = ext_name.split('.')[0] if '.' in ext_name else ext_name + + # Check if dist-info exists + has_dist_info = any( + name.endswith('.dist-info') and name.startswith(pkg_name) + for name in os.listdir(install_dir) + ) + # If dist-info exists, we are installing a wheel, so we are done + if has_dist_info: + return + + # Build candidate site dirs: global + user + entries already on sys.path + candidates = [] + candidates.extend(site.getsitepackages()) + usp = site.getusersitepackages() + if usp: + candidates.append(usp) + for sp in sys.path: + if isinstance(sp, str) and sp.endswith( + ('site-packages', 'dist-packages') + ): + candidates.append(sp) + # De-dup while preserving order + seen = set() + ordered = [] + for c in candidates: + if c and c not in seen: + seen.add(c) + ordered.append(c) + # Prefer a candidate that is actually on sys.path + target = None + for c in ordered: + if c in sys.path and os.path.isdir(c): + target = c + break + # Fallback: pick the first existing candidate + if target is None: + for c in ordered: + if os.path.isdir(c): + target = c + break + if target: + option_dict = self.distribution.get_option_dict('install') + + if 'install_lib' not in option_dict: + self.install_lib = target + + if 'install_purelib' not in option_dict: + self.install_purelib = target + + if 'install_platlib' not in option_dict: + self.install_platlib = target + + def run(self, *args: Any, **kwargs: Any) -> None: + super().run(*args, **kwargs) + + install_dir = ( + getattr(self, 'install_lib', None) + or getattr(self, 'install_purelib', None) + or getattr(self, 'install_platlib', None) + ) + if not install_dir or not os.path.isdir(install_dir): + return + + # Get the extension name + ext_name = self._get_extension_name() + + # Extract the first part of the extension name for the shared library + # For example, from "custom_setup_ops.my_ops.custom_relu" we get "custom_setup_ops" + pkg_name = ext_name.split('.')[0] if '.' in ext_name else ext_name + + # Check if dist-info exists + has_egg_info = any( + name.endswith('.egg-info') and name.startswith(pkg_name) + for name in os.listdir(install_dir) + ) + # If egg-info exists, we are installing a source distribution, we need to + # reorganize the files + if has_egg_info: + # First rename the shared library if present at top-level + self._rename_shared_library() + # Then canonicalize layout to a single top-level entry for this package + self._single_entry_layout() + + def _rename_shared_library(self) -> None: + install_dir = ( + getattr(self, 'install_lib', None) + or getattr(self, 'install_purelib', None) + or getattr(self, 'install_platlib', None) + ) + if not install_dir or not os.path.isdir(install_dir): + return + + # Get the extension name + ext_name = self._get_extension_name() + + # Extract the last part of the extension name for the shared library + # For example, from "custom_setup_ops.my_ops.custom_relu" we get "custom_relu" + names = ext_name.split('.') if '.' in ext_name else [ext_name] + lib_name = names[-1] + + suffix = ( + '.pyd' + if IS_WINDOWS + else ('.dylib' if OS_NAME.startswith('darwin') else '.so') + ) + + # Build the directory path for the shared library + # For single-level: names[:-1] is empty, so dir_path = install_dir + # For multi-level: names[:-1] contains the package path + dir_path = os.path.join(install_dir, *names[:-1]) + old = os.path.join(dir_path, f"{lib_name}{suffix}") + new = os.path.join(dir_path, f"{lib_name}_pd_{suffix}") + if os.path.exists(old): + if os.path.exists(new): + os.remove(new) + os.rename(old, new) + + def _single_entry_layout(self) -> None: + """ + Ensure only one top-level item in install_dir contains the package name by: + - moving {pkg}.py -> {pkg}/__init__.py + - moving {pkg}_pd_.so -> {pkg}/{pkg}_pd_.so + - removing any {pkg}-*.egg-info left by setuptools install (only if dist-info exists) + This keeps legacy tests that scan os.listdir(site_dir) happy. + """ + install_dir = ( + getattr(self, 'install_lib', None) + or getattr(self, 'install_purelib', None) + or getattr(self, 'install_platlib', None) + ) + if not install_dir or not os.path.isdir(install_dir): + return + + # Get the extension name + ext_name = self._get_extension_name() + + # Extract the package path from the extension name + # For example, from "custom_setup_ops.my_ops.custom_relu" we get "custom_setup_ops/my_ops" + pkg_path_parts = ( + ext_name.split('.')[:-1] if '.' in ext_name else [ext_name] + ) + pkg_path = os.path.join(*pkg_path_parts) + + # Extract the last part of the extension name for the Python file and shared library + # For example, from "custom_setup_ops.my_ops.custom_relu" we get "custom_relu" + lib_name = ext_name.split('.')[-1] if '.' in ext_name else ext_name + + # Prepare paths + pkg_dir = os.path.join(install_dir, pkg_path) + py_src = os.path.join(install_dir, f"{lib_name}.py") + # Find compiled lib (renamed or not) + suf_so = ( + '.pyd' + if IS_WINDOWS + else ('.dylib' if OS_NAME.startswith('darwin') else '.so') + ) + so_candidates = [ + os.path.join(install_dir, f"{lib_name}_pd_{suf_so}"), + os.path.join(install_dir, f"{lib_name}{suf_so}"), + ] + so_src = next((p for p in so_candidates if os.path.exists(p)), None) + # Create package dir + if not os.path.isdir(pkg_dir): + os.makedirs(pkg_dir, exist_ok=True) + # Move python stub to package/__init__.py if exists + if os.path.exists(py_src): + py_dst = os.path.join(pkg_dir, "__init__.py") + if os.path.exists(py_dst): + os.remove(py_dst) + os.replace(py_src, py_dst) + # Move shared lib into the package dir if exists + if so_src and os.path.exists(so_src): + so_dst = os.path.join(pkg_dir, os.path.basename(so_src)) + if os.path.exists(so_dst): + os.remove(so_dst) + os.replace(so_src, so_dst) + + def load( name: str, sources: Sequence[str], diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index c1ecd343b245d0..2918e588df350d 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -31,11 +31,8 @@ import textwrap import threading import warnings -from contextlib import contextmanager from importlib import machinery -from setuptools.command import bdist_egg - try: from subprocess import DEVNULL # py3 except ImportError: @@ -151,18 +148,6 @@ ] -@contextmanager -def bootstrap_context(): - """ - Context to manage how to write `__bootstrap__` code in .egg - """ - origin_write_stub = bdist_egg.write_stub - bdist_egg.write_stub = custom_write_stub - yield - - bdist_egg.write_stub = origin_write_stub - - def load_op_meta_info_and_register_op(lib_filename: str) -> list[str]: new_list = core.load_op_meta_info_and_register_op(lib_filename) proto_sync_ops = OpProtoHolder.instance().update_op_proto(new_list) @@ -237,7 +222,8 @@ def __bootstrap__(): with open(pyfile, 'w') as f: f.write( _stub_template.format( - resource=resource, custom_api='\n\n'.join(api_content) + resource=os.path.basename(resource), + custom_api='\n\n'.join(api_content), ) ) diff --git a/test/cpp_extension/test_cpp_extension_setup.py b/test/cpp_extension/test_cpp_extension_setup.py index 53e39fc2993c32..dbf60d83a0dfd9 100644 --- a/test/cpp_extension/test_cpp_extension_setup.py +++ b/test/cpp_extension/test_cpp_extension_setup.py @@ -39,13 +39,15 @@ def setUp(self): cmd += f' --install-lib={site_dir}' run_cmd(cmd) - custom_egg_path = [ + custom_install_path = [ x for x in os.listdir(site_dir) if 'custom_cpp_extension' in x ] - assert len(custom_egg_path) == 1, ( - f"Matched egg number is {len(custom_egg_path)}." + + assert len(custom_install_path) == 2, ( + f"Matched egg number is {len(custom_install_path)}." ) - sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + sys.path.append(os.path.join(site_dir, custom_install_path[0])) ################################# # config seed diff --git a/test/cpp_extension/test_mixed_extension_setup.py b/test/cpp_extension/test_mixed_extension_setup.py index b064aaeb2099e3..61e6991b686188 100644 --- a/test/cpp_extension/test_mixed_extension_setup.py +++ b/test/cpp_extension/test_mixed_extension_setup.py @@ -111,13 +111,15 @@ def setUp(self): cmd += f' --install-lib={site_dir}' run_cmd(cmd) - custom_egg_path = [ + custom_install_path = [ x for x in os.listdir(site_dir) if 'mix_relu_extension' in x ] - assert len(custom_egg_path) == 1, ( - f"Matched egg number is {len(custom_egg_path)}." + + assert len(custom_install_path) == 2, ( + f"Matched egg number is {len(custom_install_path)}." ) - sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + sys.path.append(os.path.join(site_dir, custom_install_path[0])) ################################# # config seed diff --git a/test/custom_op/test_custom_relu_op_setup.py b/test/custom_op/test_custom_relu_op_setup.py index c13c2890a0eb65..440c0a4afcec48 100644 --- a/test/custom_op/test_custom_relu_op_setup.py +++ b/test/custom_op/test_custom_relu_op_setup.py @@ -167,13 +167,15 @@ def setUp(self): site_dir = site.getsitepackages()[1] else: site_dir = site.getsitepackages()[0] - custom_egg_path = [ + custom_install_path = [ x for x in os.listdir(site_dir) if 'custom_relu_module_setup' in x ] - assert len(custom_egg_path) == 2, ( - f"Matched egg number is {len(custom_egg_path)}." + + assert len(custom_install_path) == 2, ( + f"Matched egg number is {len(custom_install_path)}." ) - sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + sys.path.append(os.path.join(site_dir, custom_install_path[0])) # usage: import the package directly import custom_relu_module_setup diff --git a/test/custom_op/test_custom_relu_op_xpu_setup.py b/test/custom_op/test_custom_relu_op_xpu_setup.py index 84cb45a30f4223..122d8e31abda78 100644 --- a/test/custom_op/test_custom_relu_op_xpu_setup.py +++ b/test/custom_op/test_custom_relu_op_xpu_setup.py @@ -72,15 +72,17 @@ def setUp(self): run_cmd(cmd) site_dir = site.getsitepackages()[0] - custom_egg_path = [ + custom_install_path = [ x for x in os.listdir(site_dir) if 'custom_relu_xpu_module_setup' in x ] - assert len(custom_egg_path) == 1, ( - f"Matched egg number is {len(custom_egg_path)}." + + assert len(custom_install_path) == 2, ( + f"Matched egg number is {len(custom_install_path)}." ) - sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + sys.path.append(os.path.join(site_dir, custom_install_path[0])) # usage: import the package directly import custom_relu_xpu_module_setup diff --git a/test/custom_op/test_inference_gap_setup.py b/test/custom_op/test_inference_gap_setup.py index 697e5dc36dcc39..a8ea70614c003d 100644 --- a/test/custom_op/test_inference_gap_setup.py +++ b/test/custom_op/test_inference_gap_setup.py @@ -54,13 +54,15 @@ def setUp(self): site_dir = site.getsitepackages()[0] - custom_egg_path = [ + custom_install_path = [ x for x in os.listdir(site_dir) if 'gap_op_setup' in x ] - assert len(custom_egg_path) == 1, ( - f"Matched egg number is {len(custom_egg_path)}." + + assert len(custom_install_path) == 2, ( + f"Matched egg number is {len(custom_install_path)}." ) - sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + sys.path.append(os.path.join(site_dir, custom_install_path[0])) # usage: import the package directly import gap_op_setup