2 changes: 1 addition & 1 deletion ffi/pyproject.toml
@@ -17,7 +17,7 @@

 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0a6"
+version = "0.1.0a7"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]
40 changes: 2 additions & 38 deletions ffi/python/tvm_ffi/cython/function.pxi
@@ -24,41 +24,6 @@ except ImportError:
     torch = None
 
 
-def load_torch_get_current_cuda_stream():
-    """Create a faster get_current_cuda_stream for torch through cpp extension.
-    """
-    source = """
-    #include <c10/cuda/CUDAStream.h>
-
-    int64_t get_current_cuda_stream(int device_id) {
-      at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
-      // fast invariant, default stream is always 0
-      if (stream.id() == 0) return 0;
-      // convert to cudaStream_t
-      return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
-    }
-    """
-    def fallback_get_current_cuda_stream(device_id):
-        """Fallback with python api"""
-        return torch.cuda.current_stream(device_id).cuda_stream
-    try:
-        from torch.utils import cpp_extension
-        result = cpp_extension.load_inline(
-            name="get_current_cuda_stream",
-            cpp_sources=[source],
-            cuda_sources=[],
-            extra_cflags=["-O3"],
-            extra_include_paths=cpp_extension.include_paths("cuda"),
-            functions=["get_current_cuda_stream"],
-        )
-        return result.get_current_cuda_stream
-    except Exception:
-        return fallback_get_current_cuda_stream
-
-
-torch_get_current_cuda_stream = None
-
-
 cdef inline object make_ret_small_str(TVMFFIAny result):
     """convert small string to return value."""
     cdef TVMFFIByteArray bytes
@@ -146,9 +111,8 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
         if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
             ctx_dev_type[0] = temp_dltensor.device.device_type
             ctx_dev_id[0] = temp_dltensor.device.device_id
-            if torch_get_current_cuda_stream is None:
-                torch_get_current_cuda_stream = load_torch_get_current_cuda_stream()
-            temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id)
+            # This is the API that dynamo and others use to get the raw stream from torch
+            temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
             ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
         temp_args.append(arg)
     elif hasattr(arg, "__dlpack__"):
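
Context on the replacement: torch._C._cuda_getCurrentRawStream is the private PyTorch hook that TorchDynamo-generated code uses to fetch the current CUDA stream as a raw integer handle, so calling it directly avoids both the Python-level stream object construction and the just-in-time C++ extension build that the deleted helper relied on. A minimal sketch (assuming a CUDA-enabled PyTorch build; not part of this PR) checking that it agrees with the public API:

    import torch

    device_id = 0  # hypothetical device index, for illustration only
    # Private fast path: returns the current stream as a raw cudaStream_t integer.
    raw = torch._C._cuda_getCurrentRawStream(device_id)
    # Public API used by the deleted fallback path; .cuda_stream is the same raw handle.
    assert raw == torch.cuda.current_stream(device_id).cuda_stream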