Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions python/cudf/cudf/core/udf/udf_kernel_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from cudf.core.udf.nrt_utils import CaptureNRTUsage, nrt_enabled
from cudf.core.udf.strings_typing import str_view_arg_handler
from cudf.core.udf.utils import (
UDF_SHIM_FILE,
_generate_cache_key,
_masked_array_type_from_col,
_ptx_file,
_supported_cols_from_frame,
compile_udf,
precompiled as kernel_cache,
Expand Down Expand Up @@ -157,7 +157,9 @@ def compile_kernel_string(self, kernel_string, nrt=False):
ctx = nrt_enabled() if nrt else nullcontext()
with ctx:
kernel = cuda.jit(
self.sig, link=[_ptx_file()], extensions=[str_view_arg_handler]
self.sig,
link=[UDF_SHIM_FILE],
extensions=[str_view_arg_handler],
)(_kernel)
return kernel

Expand Down
63 changes: 3 additions & 60 deletions python/cudf/cudf/core/udf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import functools
import glob
import os
from pickle import dumps
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -68,65 +67,9 @@
_udf_code_cache: cachetools.LRUCache = cachetools.LRUCache(maxsize=32)


def _get_best_ptx_file(archs, max_compute_capability):
    """
    Select the most recent PTX entry usable on a device.

    Of the available entries, return the one with the highest compute
    capability that is less than or equal to ``max_compute_capability``.

    Parameters
    ----------
    archs : iterable of tuple
        Candidate entries. The first element of each entry is its compute
        capability as a number (e.g. ``70``); any further elements are
        carried through untouched.
    max_compute_capability : int
        Compute capability that the selected entry must not exceed.

    Returns
    -------
    tuple or None
        The best matching entry, or ``None`` when no entry qualifies.
        If several entries share the highest capability, the first one
        encountered is returned.
    """
    return max(
        (arch for arch in archs if arch[0] <= max_compute_capability),
        key=lambda arch: arch[0],
        default=None,
    )


def _get_ptx_file(path, prefix):
    """
    Find the PTX file in ``path`` to use for the target compute capability.

    Candidate files are those matching ``<prefix>*.ptx``. The text between
    the prefix and the ``.ptx`` extension is read as a compute capability
    number, optionally followed by ``a``. A file whose name ends in ``a``
    is used only when its number equals the target exactly; otherwise the
    file with the highest number not exceeding the target is chosen.

    The target compute capability is taken from the current CUDA device,
    unless ``RAPIDS_NO_INITIALIZE`` is set in the environment, in which
    case it is read from ``STRINGS_UDF_CC`` (default ``70``) without
    touching the device.

    Parameters
    ----------
    path : str
        Directory to search.
    prefix : str
        File name prefix preceding the compute capability number.

    Returns
    -------
    str
        Path of the selected PTX file.

    Raises
    ------
    RuntimeError
        If no ``<prefix>*.ptx`` file exists in ``path``, or if none of the
        files found targets a compute capability <= the target.
    """
    if "RAPIDS_NO_INITIALIZE" in os.environ:
        # cc=70 ptx is always built
        cc = int(os.environ.get("STRINGS_UDF_CC", "70"))
    else:
        dev = cuda.get_current_device()
        # Load the highest compute capability file available that is no
        # greater than the current device's.
        cc = int("".join(str(x) for x in dev.compute_capability))

    files = glob.glob(os.path.join(path, f"{prefix}*.ptx"))
    if not files:
        raise RuntimeError(f"Missing PTX files for cc={cc}")

    regular_sms = []
    for f in files:
        # removeprefix/removesuffix drop the exact strings. str.lstrip and
        # str.rstrip take a *set of characters* and would also eat leading
        # digits of the number whenever the prefix itself contains digits.
        sm_number = (
            os.path.basename(f).removeprefix(prefix).removesuffix(".ptx")
        )
        if sm_number.endswith("a"):
            # "a" variants are only valid for an exact match.
            if int(sm_number.rstrip("a")) == cc:
                return f
        else:
            regular_sms.append((int(sm_number), f))

    # Highest regular compute capability that does not exceed the target.
    best = max(
        (entry for entry in regular_sms if entry[0] <= cc),
        key=lambda entry: entry[0],
        default=None,
    )
    if best is None:
        raise RuntimeError(
            "This cuDF installation is missing the necessary PTX "
            f"files that are <={cc}."
        )
    return best[1]


@functools.cache
def _ptx_file():
    """Return the path of the shim PTX file; the lookup result is cached."""
    udf_dir = os.path.join(
        os.path.dirname(strings_udf.__file__), "..", "core", "udf"
    )
    return _get_ptx_file(udf_dir, "shim_")
# Path to the shim fatbin, located in the ``core/udf`` directory resolved
# relative to the ``strings_udf`` package.
UDF_SHIM_FILE = os.path.join(
os.path.dirname(strings_udf.__file__), "..", "core", "udf", "shim.fatbin"
)


def _all_dtypes_from_frame(frame, supported_types=JIT_SUPPORTED_TYPES):
Expand Down
36 changes: 19 additions & 17 deletions python/cudf/udf_cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,24 @@ list(TRANSFORM CMAKE_CUDA_ARCHITECTURES REPLACE "-virtual" "")
list(SORT CMAKE_CUDA_ARCHITECTURES)
list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES)

foreach(arch IN LISTS CMAKE_CUDA_ARCHITECTURES)
set(tgt shim_${arch})

add_library(${tgt} OBJECT shim.cu)

set_target_properties(${tgt} PROPERTIES CUDA_ARCHITECTURES ${arch} CUDA_PTX_COMPILATION ON)
# Build shim.cu as a single object library covering every architecture in
# CMAKE_CUDA_ARCHITECTURES, with fatbin output enabled.
add_library(shim OBJECT shim.cu)
set_target_properties(
shim
PROPERTIES CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}"
CUDA_SEPARABLE_COMPILATION ON
CUDA_FATBIN_COMPILATION ON
POSITION_INDEPENDENT_CODE ON
INTERPROCEDURAL_OPTIMIZATION ON
)

target_include_directories(
${tgt} PUBLIC "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/strings/include>"
)
target_compile_options(${tgt} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${SHIM_CUDA_FLAGS}>")
target_link_libraries(${tgt} PUBLIC cudf::cudf)
target_compile_options(shim PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${SHIM_CUDA_FLAGS}>")
target_link_libraries(shim PUBLIC cudf::cudf)
target_include_directories(
shim PUBLIC "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/strings/include>"
)

install(
FILES $<TARGET_OBJECTS:${tgt}>
DESTINATION cudf/core/udf/
RENAME ${tgt}.ptx
)
endforeach()
# Install the shim target's object output into cudf/core/udf/ under the fixed
# name shim.fatbin.
install(
FILES $<TARGET_OBJECTS:shim>
DESTINATION cudf/core/udf/
RENAME shim.fatbin
)