diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index de5e859662..7ffa762e87 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include #include #include diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index 09fac79d19..ea449ad040 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -1,18 +1,18 @@ import dataclasses import logging import os -import tvm_ffi from contextlib import nullcontext +from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Sequence, Union -from datetime import datetime +import tvm_ffi from filelock import FileLock +from ..compilation_context import CompilationContext from . import env as jit_env from .cpp_ext import generate_ninja_build_for_op, run_ninja from .utils import write_if_different -from ..compilation_context import CompilationContext os.makedirs(jit_env.FLASHINFER_WORKSPACE_DIR, exist_ok=True) os.makedirs(jit_env.FLASHINFER_CSRC_DIR, exist_ok=True) @@ -193,6 +193,16 @@ def get_library_path(self) -> Path: return self.aot_path return self.jit_library_path + def get_object_paths(self) -> List[Path]: + object_paths = [] + jit_dir = self.jit_library_path.parent + for source in self.sources: + is_cuda = source.suffix == ".cu" + object_suffix = ".cuda.o" if is_cuda else ".o" + obj_name = source.with_suffix(object_suffix).name + object_paths.append(jit_dir / obj_name) + return object_paths + @property def aot_path(self) -> Path: return jit_env.FLASHINFER_AOT_DIR / self.name / f"{self.name}.so"