Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 43 additions & 9 deletions flashinfer/jit/cpp_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

from . import env as jit_env
from ..utils import use_paddle_compatible_api
from ..compilation_context import CompilationContext


Expand Down Expand Up @@ -75,13 +76,26 @@ def generate_ninja_build_for_op(
) -> str:
system_includes = [
sysconfig.get_path("include"),
"$torch_home/include",
"$torch_home/include/torch/csrc/api/include",
"$cuda_home/include",
"$cuda_home/include/cccl",
jit_env.FLASHINFER_INCLUDE_DIR.resolve(),
jit_env.FLASHINFER_CSRC_DIR.resolve(),
]
if use_paddle_compatible_api():
system_includes.extend(
[
"$torch_home/include",
"$torch_home/include/torch/csrc/api/include",
]
)
Comment on lines +84 to +90
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The system_includes.extend method is called within the if block, but the system_includes list already contains some default paths. This could lead to duplicate include paths if use_paddle_compatible_api() returns True, potentially causing issues during compilation. Consider adding the default paths within the else block to avoid duplication.

Alternatively, you can initialize system_includes as an empty list and populate it entirely within the if and else blocks to ensure no overlap.

        system_includes = [
            sysconfig.get_path("include"),
            "$cuda_home/include",
            jit_env.FLASHINFER_INCLUDE_DIR.resolve(),
            jit_env.FLASHINFER_CSRC_DIR.resolve(),
        ]
        if use_paddle_compatible_api():
            system_includes.extend(
                [
                    "$torch_home/include",
                    "$torch_home/include/torch/csrc/api/include",
                ]
            )
        else:
            system_includes.extend(
                [
                    "$torch_home/include",
                    "$torch_home/include/paddle/phi/api/include/compat",
                    "$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include",
                ]
            )

else:
system_includes.extend(
[
"$torch_home/include",
"$torch_home/include/paddle/phi/api/include/compat",
"$torch_home/include/paddle/phi/api/include/compat/torch/csrc/api/include",
]
)
system_includes += [p.resolve() for p in jit_env.CUTLASS_INCLUDE_DIRS]
system_includes.append(jit_env.SPDLOG_INCLUDE_DIR.resolve())

Expand Down Expand Up @@ -144,15 +158,35 @@ def generate_ninja_build_for_op(

ldflags = [
"-shared",
"-L$torch_home/lib",
"-L$cuda_home/lib64",
"-lc10",
"-lc10_cuda",
"-ltorch_cpu",
"-ltorch_cuda",
"-ltorch",
"-lcudart",
]
Comment on lines 159 to 162
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The ldflags list is initialized with "-shared" and "-lcudart" regardless of the use_paddle_compatible_api() condition. This could lead to redundancy or conflicts if the subsequent extend calls also include -shared or -lcudart. Consider initializing ldflags as an empty list and adding these flags conditionally within the if and else blocks to avoid potential issues.

    ldflags = []
    if use_paddle_compatible_api():
        ldflags.extend(
            [
                "-shared",
                "-L$torch_home/lib",
                "-L$cuda_home/lib64",
                "-lc10",
                "-lc10_cuda",
                "-ltorch_cpu",
                "-ltorch_cuda",
                "-ltorch",
                "-lcudart",
            ]
        )
    else:
        ldflags.extend(
            [
                "-shared",
                "-L$torch_home/libs",
                "-L$torch_home/base",
                "-L$cuda_home/lib64",
                "-lpaddle",
                "-lphi",
                "-lphi_core",
                "-lphi_gpu",
                "-lcommon",
                "-lcudart",
            ]
        )

if use_paddle_compatible_api():
ldflags.extend(
[
"-L$torch_home/lib",
"-L$cuda_home/lib64",
"-lc10",
"-lc10_cuda",
"-ltorch_cpu",
"-ltorch_cuda",
"-ltorch",
]
)
else:
ldflags.extend(
[
"-shared",
"-L$torch_home/libs",
"-L$torch_home/base",
"-L$cuda_home/lib64",
"-lpaddle",
"-lphi",
"-lphi_core",
"-lphi_gpu",
"-lcommon",
"-lcudart",
]
)

env_extra_ldflags = os.environ.get("FLASHINFER_EXTRA_LDFLAGS")
if env_extra_ldflags:
Expand Down
20 changes: 14 additions & 6 deletions flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch.torch_version import TorchVersion
from torch.torch_version import __version__ as torch_version

from .jit import gen_jit_spec, env as jit_env
import flashinfer

IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"

Expand Down Expand Up @@ -240,7 +240,15 @@ def _check_cached_qkv_data_type(
)


if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"):
def use_paddle_compatible_api() -> bool:
    """Return True when paddle-compatible mode is requested via the environment.

    Reads the ``PADDLE_COMPATIBLE_API`` environment variable and treats
    "1", "on", or "true" (case-insensitive) as enabled. Unset, or any other
    value (e.g. "0", "off"), disables paddle-compatible mode.
    """
    # Set membership instead of a list literal: idiomatic for a fixed
    # collection of accepted flag values.
    return os.environ.get("PADDLE_COMPATIBLE_API", "0").lower() in {"1", "on", "true"}


if (
use_paddle_compatible_api()
or IS_BUILDING_DOCS
or TorchVersion(torch_version) < TorchVersion("2.4")
):

def register_custom_op(
name: str,
Expand Down Expand Up @@ -492,14 +500,14 @@ def check_shape_dtype_device(


def gen_logging_module():
    """Build the JIT spec for the spdlog-backed logging extension module.

    Compiles ``logging.cc`` from the FlashInfer csrc directory, with the
    spdlog and FlashInfer include directories on the include path.

    NOTE(review): this span was a garbled diff with old and new lines
    interleaved; the post-change ("+") side is reconstructed here, which
    routes through fully-qualified ``flashinfer.jit`` accessors instead of
    module-level ``from .jit import ...`` names (avoids a circular import).
    """
    return flashinfer.jit.gen_jit_spec(
        "logging",
        [
            flashinfer.jit.env.FLASHINFER_CSRC_DIR / "logging.cc",
        ],
        extra_include_paths=[
            flashinfer.jit.env.SPDLOG_INCLUDE_DIR,
            flashinfer.jit.env.FLASHINFER_INCLUDE_DIR,
        ],
    )

Expand Down
13 changes: 12 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
enable_aot = aot_ops_package_dir.is_dir() and any(aot_ops_package_dir.iterdir())


def use_paddle_compatible_api() -> bool:
    """Whether the PADDLE_COMPATIBLE_API env var requests paddle-compatible mode."""
    # Normalise once, then compare against the accepted truthy spellings.
    flag = os.environ.get("PADDLE_COMPATIBLE_API", "0").lower()
    return flag in ("1", "on", "true")


def write_if_different(path: Path, content: str) -> None:
if path.exists() and path.read_text() == content:
return
Expand All @@ -55,7 +59,6 @@ def generate_build_meta(aot_build_meta: dict) -> None:
cmdclass: Mapping[str, type[setuptools.Command]] = {}
install_requires = [
"numpy",
"torch",
"ninja",
"requests",
"pynvml",
Expand All @@ -66,9 +69,17 @@ def generate_build_meta(aot_build_meta: dict) -> None:
"packaging>=24.2",
"nvidia-cudnn-frontend>=1.13.0",
]
if not use_paddle_compatible_api():
install_requires.append("torch")

generate_build_meta({})

if enable_aot:
if use_paddle_compatible_api():
import paddle
Comment on lines 77 to +79
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Consider adding a check to ensure paddle is importable before calling paddle.compat.enable_torch_proxy(). If paddle is not installed or available in the environment, this could lead to an ImportError and break the build process. A try...except block can be used to handle this scenario gracefully.

Suggested change
if enable_aot:
if use_paddle_compatible_api():
import paddle
if use_paddle_compatible_api():
try:
import paddle
paddle.compat.enable_torch_proxy()
except ImportError:
print("PaddlePaddle is not installed. Skipping paddle.compat.enable_torch_proxy().")


paddle.compat.enable_torch_proxy()

import torch
import torch.utils.cpp_extension as torch_cpp_ext
from packaging.version import Version
Expand Down