Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ffi/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API)
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc"
)
endif()

Expand Down
33 changes: 33 additions & 0 deletions ffi/include/tvm/ffi/extra/c_env_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,39 @@
extern "C" {
#endif

// ----------------------------------------------------------------------------
// Stream context
// A minimalistic thread-local context that records the stream being used.
// We explicitly do not handle allocation/de-allocation of streams here.
// ----------------------------------------------------------------------------
typedef void* TVMFFIStreamHandle;

/*!
* \brief FFI function to set the current stream for a device
*
* The setting is recorded in a thread-local context, so it only affects
* lookups made from the calling thread.
*
* \param device_type The type of the device.
* \param device_id The id of the device.
* \param stream The stream to set.
* \param opt_out_original_stream Output original stream if the address is not nullptr.
*        Receives nullptr when no stream was previously set for the device.
* \note The stream is a weak reference that is cached/owned by the module.
*       This function does not allocate or free streams.
* \return 0 when success, nonzero when failure happens
*/
TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
TVMFFIStreamHandle stream,
TVMFFIStreamHandle* opt_out_original_stream);

/*!
* \brief FFI function to get the current stream for a device
*
* The lookup uses the calling thread's thread-local stream context.
*
* \param device_type The type of the device.
* \param device_id The id of the device.
* \return The current stream of the device, or nullptr when no stream has
*         been set for that device on the calling thread.
*/
TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id);

// ----------------------------------------------------------------------------
// Module symbol management
// ----------------------------------------------------------------------------
/*!
* \brief FFI function to lookup a function from a module's imports.
*
Expand Down
81 changes: 81 additions & 0 deletions ffi/src/ffi/extra/stream_context.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file src/ffi/extra/stream_context.cc
*
* \brief A minimalistic stream context based on ffi values.
*/

#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>

#include <stdexcept>
#include <vector>

namespace tvm {
namespace ffi {

class StreamContext {
public:
void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
TVMFFIStreamHandle* out_original_stream) {
if (static_cast<size_t>(device_type) >= stream_table_.size()) {
stream_table_.resize(device_type + 1);
}
if (static_cast<size_t>(device_id) >= stream_table_[device_type].size()) {
stream_table_[device_type].resize(device_id + 1, nullptr);
}
if (out_original_stream != nullptr) {
*out_original_stream = stream_table_[device_type][device_id];
}
stream_table_[device_type][device_id] = stream;
}

TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) {
if (static_cast<size_t>(device_type) < stream_table_.size() &&
static_cast<size_t>(device_id) < stream_table_[device_type].size()) {
return stream_table_[device_type][device_id];
}
return nullptr;
}

static StreamContext* ThreadLocal() {
static thread_local StreamContext inst;
return &inst;
}

private:
std::vector<std::vector<TVMFFIStreamHandle>> stream_table_;
};

} // namespace ffi
} // namespace tvm

int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
TVMFFIStreamHandle* out_original_stream) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream,
out_original_stream);
TVM_FFI_SAFE_CALL_END();
}

/*!
 * \brief C entry point: return the stream recorded for a device in the calling
 *        thread's StreamContext (nullptr when none has been recorded).
 *
 * The body is wrapped in the TVM_FFI_LOG_EXCEPTION_CALL_BEGIN/END macro pair.
 * NOTE(review): presumably the pair logs any escaping exception under this
 * function's name; confirm against the macro definitions.
 */
TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) {
  TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
  tvm::ffi::StreamContext* thread_ctx = tvm::ffi::StreamContext::ThreadLocal();
  return thread_ctx->GetStream(device_type, device_id);
  TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetCurrentStream);
}
2 changes: 1 addition & 1 deletion include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class TVM_DLL DeviceAPI {
* \param dev The device to set stream.
* \param stream The stream to be set.
*/
virtual void SetStream(Device dev, TVMStreamHandle stream) {}
virtual void SetStream(Device dev, TVMStreamHandle stream);
/*!
* \brief Get the current stream
* \param dev The device to get stream.
Expand Down
12 changes: 4 additions & 8 deletions python/tvm/contrib/cutlass/attention_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ def instantiate_attention_template(attrs):
}

CHECK(Attention::check_supported(p));
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));

kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);

Expand Down Expand Up @@ -186,8 +185,7 @@ def instantiate_flash_attention_template(attrs):
int v_batch_stride = v_row_stride * ${num_keys};
int o_batch_stride = o_row_stride * ${num_queries};

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));

flash_attn::flash_attention_forward(
static_cast<const cutlass::half_t*>(${query}->data),
Expand Down Expand Up @@ -237,8 +235,7 @@ def instantiate_flash_attention_template(attrs):
int v_batch_stride = v_row_stride * ${num_keys};
int o_batch_stride = o_row_stride * ${num_queries};

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));

flash_attn::flash_attention_forward(
static_cast<const cutlass::half_t*>(${qkv}->data),
Expand Down Expand Up @@ -294,8 +291,7 @@ def instantiate_flash_attention_var_len_template(attrs):
int v_row_stride = v_head_stride * ${num_kv_heads};
int o_row_stride = o_head_stride * ${num_q_heads};

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));

flash_attn::flash_attention_var_len_forward(
static_cast<const cutlass::half_t*>(${query}->data),
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,7 @@ def instantiate_conv2d_template(attrs):
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
${split_k_update}

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${data_arg}->device.device_id));

status = conv2d_op(stream);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
Expand Down
13 changes: 7 additions & 6 deletions python/tvm/contrib/cutlass/gemm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,7 @@ def instantiate_gemm_template(attrs):
status = gemm_op.initialize(arguments, workspace.get());
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id));

status = gemm_op(stream);
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
Expand Down Expand Up @@ -428,8 +427,8 @@ def emit_fp16A_intB_matmul(attrs):
int n = ${B_arg}->shape[1] * ${float_per_int};
int k = ${B_arg}->shape[0];

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(
TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id));
""",
attrs,
)
Expand All @@ -447,12 +446,14 @@ def emit_fp16A_intB_matmul(attrs):

template_residual = """
${template_common}
gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>(static_cast<cutlass::half_t*>(${A_arg}->data),
gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>(
static_cast<cutlass::half_t*>(${A_arg}->data),
static_cast<${weight_dtype}*>(${B_arg}->data),
static_cast<cutlass::half_t*>(${scales_arg}->data),
${bias},
static_cast<cutlass::half_t*>(${residual_arg}->data),
static_cast<cutlass::half_t*>(out0->data), "${activation}", "${binary_op}", "${unary_op}",
static_cast<cutlass::half_t*>(out0->data),
"${activation}", "${binary_op}", "${unary_op}",
m, n, k, ${group_size}, nullptr, 0, stream);
"""

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def instantiate_template(func_name, annotations, func_args):
if k in annotations:
attrs[k] = annotations[k]

headers = ["tvm/ffi/function.h"]
headers = ["tvm/ffi/function.h", "tvm/ffi/extra/c_env_api.h"]

if "relu" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h")
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/contrib/cutlass/layer_norm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def instantiate_layer_norm_template(attrs):
cutlass::TensorRef<data_type, RowMajor> _beta((data_type*)${beta}->data, layout_channels);
cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id));

cutlass::layernorm(size, _output, _input, _gamma, _beta, stream);
"""
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/contrib/cutlass/rms_norm_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def instantiate_rms_norm_template(attrs):
cutlass::TensorRef<data_type, RowMajor> _weight((data_type*)${weight}->data, layout_channels);
cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id));

cutlass::rmsnorm(size, _output, _input, _weight, stream, ${rms_eps});
"""
Expand Down
4 changes: 3 additions & 1 deletion src/contrib/msc/plugin/tvm_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ void TVMPluginCodeGen::CodeGenOpRuntime(const Plugin& plugin) {
const auto& attr_name = MetaAttrCls(plugin);
const auto& func_name = ComputeName(plugin);
String device_cond = "";
String device_index = "";
for (size_t i = 0; i < plugin->inputs.size(); i++) {
String device_type = "";
if (plugin->inputs[i]->device == "cuda" || plugin->inputs[i]->device == "default") {
Expand Down Expand Up @@ -381,7 +382,8 @@ void TVMPluginCodeGen::CodeGenCompute(const Plugin& plugin, const String& device
ICHECK(plugin->buffers.size() == 0) << "Plugin with buffers is not supported in tvm";
compute_args.push_back("meta_attr");
if (device == "cuda") {
stack_.assign("stream", "runtime::CUDAThreadEntry::ThreadLocal()->stream", "auto");
// TODO(tvm-team): update to support get stream from device id
stack_.assign("stream", "TVMFFIEnvGetCurrentStream(kDLCUDA, 0)", "auto");
compute_args.push_back("stream");
}
CodeGenSafeCall(plugin->externs[device + "_compute"], compute_args);
Expand Down
11 changes: 6 additions & 5 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
/*!
* \file Use external cblas library call.
*/
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/data_type.h>
Expand Down Expand Up @@ -522,7 +523,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
auto A = args[0].cast<DLTensor*>();
auto C = args[2].cast<DLTensor*>();

CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device);

CUBLASTryEnableTensorCore(entry_ptr->handle);

Expand All @@ -549,15 +550,15 @@ TVM_FFI_STATIC_INIT_BLOCK({
"tvm.contrib.cublaslt.matmul", [](ffi::PackedArgs args, ffi::Any* ret) {
auto A = args[0].cast<DLTensor*>();

CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device);

CUBLASTryEnableTensorCore(entry_ptr->handle);

ICHECK(TypeMatch(A->dtype, kDLInt, 8)) << "Expects dtype to be int8\n";
cublasLtHandle_t ltHandle;
CHECK_CUBLAS_ERROR(cublasLtCreate(&ltHandle));
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, A->device.device_id));
CallLtIgemm(args, ret, ltHandle, stream);
CHECK_CUBLAS_ERROR(cublasLtDestroy(ltHandle));
});
Expand All @@ -571,7 +572,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
auto A = args[0].cast<DLTensor*>();
auto C = args[2].cast<DLTensor*>();

CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal();
CuBlasThreadEntry* entry_ptr = CuBlasThreadEntry::ThreadLocal(A->device);

CUBLASTryEnableTensorCore(entry_ptr->handle);
if (TypeEqual(A->dtype, C->dtype)) {
Expand Down
16 changes: 10 additions & 6 deletions src/runtime/contrib/cublas/cublas_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* \brief A simple JSON runtime for CUBLAS.
*/

#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/ndarray.h>
Expand All @@ -30,6 +31,7 @@
#include <string>
#include <vector>

#include "../../cuda/cuda_common.h"
#include "../json/json_node.h"
#include "../json/json_runtime.h"
#include "cublas_utils.h"
Expand Down Expand Up @@ -67,13 +69,8 @@ class CublasJSONRuntime : public JSONRuntimeBase {
const char* kind() const override { return "cublas_json"; } // May be overridden

void Run(ffi::PackedArgs args) {
auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();

auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());

std::vector<const DLTensor*> dl_tensors(NumEntries());

int device_id = -1;
for (size_t i = 0; i < static_cast<size_t>(args.size()); i++) {
auto eid = i < input_var_eid_.size() ? input_var_eid_[i]
: EntryID(outputs_[i - input_var_eid_.size()]);
Expand All @@ -87,7 +84,14 @@ class CublasJSONRuntime : public JSONRuntimeBase {
}

dl_tensors[eid] = arg;
device_id = arg->device.device_id;
}

if (device_id == -1) {
CUDA_CALL(cudaGetDevice(&device_id));
}
auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(DLDevice{kDLCUDA, device_id});
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, device_id));

auto get_input = [this, &dl_tensors](const JSONGraphNode& node, int idx) {
ICHECK_LT(idx, node.GetInputs().size());
Expand Down
12 changes: 8 additions & 4 deletions src/runtime/contrib/cublas/cublas_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "cublas_utils.h"

#include <dmlc/thread_local.h>
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>

#include "../../cuda/cuda_common.h"
Expand All @@ -41,10 +42,11 @@ CuBlasThreadEntry::~CuBlasThreadEntry() {

typedef dmlc::ThreadLocalStore<CuBlasThreadEntry> CuBlasThreadStore;

CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
auto stream = runtime::CUDAThreadEntry::ThreadLocal()->stream;
CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal(DLDevice curr_device) {
CuBlasThreadEntry* retval = CuBlasThreadStore::Get();
CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, static_cast<cudaStream_t>(stream)));
cudaStream_t stream = static_cast<cudaStream_t>(
TVMFFIEnvGetCurrentStream(curr_device.device_type, curr_device.device_id));
CHECK_CUBLAS_ERROR(cublasSetStream(retval->handle, stream));
return retval;
}

Expand All @@ -71,7 +73,9 @@ CuBlasLtThreadEntry::~CuBlasLtThreadEntry() {

typedef dmlc::ThreadLocalStore<CuBlasLtThreadEntry> CuBlasLtThreadStore;

CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal() { return CuBlasLtThreadStore::Get(); }
// Returns the CuBlasLtThreadEntry held in CuBlasLtThreadStore.
// NOTE: curr_device is not used by this body; the entry returned does not
// depend on which device is passed in.
CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal(DLDevice curr_device) {
return CuBlasLtThreadStore::Get();
}

} // namespace contrib
} // namespace tvm
Loading
Loading