diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py
index a0777b9e37..9b6314d04c 100644
--- a/flashinfer/jit/cpp_ext.py
+++ b/flashinfer/jit/cpp_ext.py
@@ -19,6 +19,7 @@
 )
 
 from . import env as jit_env
+from ..utils import use_paddle_compatible_api
 from ..compilation_context import CompilationContext
 
@@ -75,13 +76,26 @@ def generate_ninja_build_for_op(
 ) -> str:
     system_includes = [
         sysconfig.get_path("include"),
-        "$torch_home/include",
-        "$torch_home/include/torch/csrc/api/include",
         "$cuda_home/include",
         "$cuda_home/include/cccl",
         jit_env.FLASHINFER_INCLUDE_DIR.resolve(),
         jit_env.FLASHINFER_CSRC_DIR.resolve(),
     ]
+    if use_paddle_compatible_api():
+        system_includes.extend(
+            [
+                "$torch_home/include",
+                "$torch_home/include/paddle/phi/api/include/compat",
+                "$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include",
+            ]
+        )
+    else:
+        system_includes.extend(
+            [
+                "$torch_home/include",
+                "$torch_home/include/torch/csrc/api/include",
+            ]
+        )
     system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS]
     system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve())
 
@@ -144,15 +158,33 @@
     ldflags = [
         "-shared",
-        "-L$torch_home/lib",
-        "-L$cuda_home/lib64",
-        "-lc10",
-        "-lc10_cuda",
-        "-ltorch_cpu",
-        "-ltorch_cuda",
-        "-ltorch",
         "-lcudart",
     ]
+    if use_paddle_compatible_api():
+        ldflags.extend(
+            [
+                "-L$torch_home/libs",
+                "-L$torch_home/base",
+                "-L$cuda_home/lib64",
+                "-lpaddle",
+                "-lphi",
+                "-lphi_core",
+                "-lphi_gpu",
+                "-lcommon",
+            ]
+        )
+    else:
+        ldflags.extend(
+            [
+                "-L$torch_home/lib",
+                "-L$cuda_home/lib64",
+                "-lc10",
+                "-lc10_cuda",
+                "-ltorch_cpu",
+                "-ltorch_cuda",
+                "-ltorch",
+            ]
+        )
 
     env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS")
     if env_extra_ldflags:
diff --git a/flashinfer/utils.py b/flashinfer/utils.py
index ab1a1fa71e..ea30ac034b 100644
--- a/flashinfer/utils.py
+++ b/flashinfer/utils.py
@@ -25,7 +25,7 @@
 from torch.torch_version import TorchVersion
 from torch.torch_version import __version__ as torch_version
 
-from .jit import gen_jit_spec, env as jit_env
+import flashinfer
 
 IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"
 
@@ -240,7 +240,15 @@
     )
 
 
-if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"):
+def use_paddle_compatible_api() -> bool:
+    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]
+
+
+if (
+    use_paddle_compatible_api()
+    or IS_BUILDING_DOCS
+    or TorchVersion(torch_version) < TorchVersion("2.4")
+):
 
     def register_custom_op(
         name: str,
@@ -492,14 +500,14 @@ def check_shape_dtype_device(
 
 
 def gen_logging_module():
-    return gen_jit_spec(
+    return flashinfer.jit.gen_jit_spec(
         "logging",
         [
-            jit_env.FLASHINFER_CSRC_DIR / "logging.cc",
+            flashinfer.jit.env.FLASHINFER_CSRC_DIR / "logging.cc",
         ],
         extra_include_paths=[
-            jit_env.SPDLOG_INCLUDE_DIR,
-            jit_env.FLASHINFER_INCLUDE_DIR,
+            flashinfer.jit.env.SPDLOG_INCLUDE_DIR,
+            flashinfer.jit.env.FLASHINFER_INCLUDE_DIR,
         ],
     )
diff --git a/setup.py b/setup.py
index 3949ea8ce0..1f865ca155 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,10 @@
 enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir())
 
 
+def use_paddle_compatible_api() -> bool:
+    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in ["1", "on", "true"]
+
+
 def write_if_different(path: Path, content: str) -> None:
     if path.exists() and path.read_text() == content:
         return
@@ -55,7 +59,6 @@ def generate_build_meta(aot_build_meta: dict) -> None:
 cmdclass: Mapping[str, type[setuptools.Command]] = {}
 install_requires = [
     "numpy",
-    "torch",
     "ninja",
     "requests",
     "pynvml",
@@ -66,9 +69,17 @@ def generate_build_meta(aot_build_meta: dict) -> None:
     "packaging>=24.2",
     "nvidia-cudnn-frontend>=1.13.0",
 ]
+if not use_paddle_compatible_api():
+    install_requires.append("torch")
+
 
 generate_build_meta({})
 
 if enable_aot:
+    if use_paddle_compatible_api():
+        import paddle
+
+        paddle.compat.enable_torch_proxy()
+
     import torch
     import torch.utils.cpp_extension as torch_cpp_ext
     from packaging.version import Version