Skip to content

Commit 43703ac

Browse files
committed
[FFI][REFACTOR] Establish Stream Context in ffi
This PR sets up the stream context in ffi and migrates the existing per-device-API stream context management to the ffi env API. The new API will help us streamline stream-related integration for most libraries.
1 parent 6bc94d0 commit 43703ac

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+305
-157
lines changed

ffi/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API)
7373
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module.cc"
7474
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_system_lib.cc"
7575
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/library_module_dynamic_lib.cc"
76+
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/extra/stream_context.cc"
7677
)
7778
endif()
7879

ffi/include/tvm/ffi/extra/c_env_api.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,39 @@
2929
extern "C" {
3030
#endif
3131

32+
// ----------------------------------------------------------------------------
33+
// Stream context
34+
// A minimalistic thread-local context that records the stream currently in use.
35+
// We explicitly do not handle allocation/de-allocation of streams here.
36+
// ----------------------------------------------------------------------------
37+
typedef void* TVMFFIStreamHandle;
38+
39+
/*!
40+
* \brief FFI function to set the current stream for a device
41+
*
42+
* \param device_type The type of the device.
43+
* \param device_id The id of the device.
44+
* \param stream The stream to set.
45+
* \param opt_out_original_stream Output original stream if the address is not nullptr.
46+
* \note The stream is stored as a weak reference; this API does not own,
*       allocate, or release the stream.
47+
* \return 0 when success, nonzero when failure happens
48+
*/
49+
TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id,
50+
TVMFFIStreamHandle stream,
51+
TVMFFIStreamHandle* opt_out_original_stream);
52+
53+
/*!
54+
* \brief FFI function to get the current stream for a device
55+
*
56+
* \param device_type The type of the device.
57+
* \param device_id The id of the device.
58+
* \return The current stream of the device.
59+
*/
60+
TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id);
61+
62+
// ----------------------------------------------------------------------------
63+
// Module symbol management
64+
// ----------------------------------------------------------------------------
3265
/*!
3366
* \brief FFI function to lookup a function from a module's imports.
3467
*
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
/*
20+
* \file src/ffi/extra/stream_context.cc
21+
*
22+
* \brief A minimalistic stream context based on ffi values.
23+
*/
24+
25+
#include <tvm/ffi/extra/c_env_api.h>
26+
#include <tvm/ffi/function.h>
27+
28+
#include <vector>
29+
30+
namespace tvm {
31+
namespace ffi {
32+
33+
class StreamContext {
34+
public:
35+
void SetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
36+
TVMFFIStreamHandle* out_original_stream) {
37+
if (static_cast<size_t>(device_type) >= stream_table_.size()) {
38+
stream_table_.resize(device_type + 1);
39+
}
40+
if (static_cast<size_t>(device_id) >= stream_table_[device_type].size()) {
41+
stream_table_[device_type].resize(device_id + 1, nullptr);
42+
}
43+
if (out_original_stream != nullptr) {
44+
*out_original_stream = stream_table_[device_type][device_id];
45+
}
46+
stream_table_[device_type][device_id] = stream;
47+
}
48+
49+
TVMFFIStreamHandle GetStream(int32_t device_type, int32_t device_id) {
50+
if (static_cast<size_t>(device_type) < stream_table_.size() &&
51+
static_cast<size_t>(device_id) < stream_table_[device_type].size()) {
52+
return stream_table_[device_type][device_id];
53+
}
54+
return nullptr;
55+
}
56+
57+
static StreamContext* ThreadLocal() {
58+
static thread_local StreamContext inst;
59+
return &inst;
60+
}
61+
62+
private:
63+
std::vector<std::vector<TVMFFIStreamHandle>> stream_table_;
64+
};
65+
66+
} // namespace ffi
67+
} // namespace tvm
68+
69+
int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, TVMFFIStreamHandle stream,
70+
TVMFFIStreamHandle* out_original_stream) {
71+
TVM_FFI_SAFE_CALL_BEGIN();
72+
tvm::ffi::StreamContext::ThreadLocal()->SetStream(device_type, device_id, stream,
73+
out_original_stream);
74+
TVM_FFI_SAFE_CALL_END();
75+
}
76+
77+
TVMFFIStreamHandle TVMFFIEnvGetCurrentStream(int32_t device_type, int32_t device_id) {
  TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
  // Read-only lookup on the calling thread's table; yields nullptr when no
  // stream has been recorded for this device.
  tvm::ffi::StreamContext* ctx = tvm::ffi::StreamContext::ThreadLocal();
  return ctx->GetStream(device_type, device_id);
  TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetCurrentStream);
}

include/tvm/runtime/device_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ class TVM_DLL DeviceAPI {
225225
* \param dev The device to set stream.
226226
* \param stream The stream to be set.
227227
*/
228-
virtual void SetStream(Device dev, TVMStreamHandle stream) {}
228+
virtual void SetStream(Device dev, TVMStreamHandle stream);
229229
/*!
230230
* \brief Get the current stream
231231
* \param dev The device to get stream.

python/tvm/contrib/cutlass/attention_operation.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,7 @@ def instantiate_attention_template(attrs):
147147
}
148148
149149
CHECK(Attention::check_supported(p));
150-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
151-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
150+
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));
152151
153152
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
154153
@@ -186,8 +185,7 @@ def instantiate_flash_attention_template(attrs):
186185
int v_batch_stride = v_row_stride * ${num_keys};
187186
int o_batch_stride = o_row_stride * ${num_queries};
188187
189-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
190-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
188+
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));
191189
192190
flash_attn::flash_attention_forward(
193191
static_cast<const cutlass::half_t*>(${query}->data),
@@ -237,8 +235,7 @@ def instantiate_flash_attention_template(attrs):
237235
int v_batch_stride = v_row_stride * ${num_keys};
238236
int o_batch_stride = o_row_stride * ${num_queries};
239237
240-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
241-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
238+
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));
242239
243240
flash_attn::flash_attention_forward(
244241
static_cast<const cutlass::half_t*>(${qkv}->data),
@@ -294,8 +291,7 @@ def instantiate_flash_attention_var_len_template(attrs):
294291
int v_row_stride = v_head_stride * ${num_kv_heads};
295292
int o_row_stride = o_head_stride * ${num_q_heads};
296293
297-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
298-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
294+
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${query}->device.device_id));
299295
300296
flash_attn::flash_attention_var_len_forward(
301297
static_cast<const cutlass::half_t*>(${query}->data),

python/tvm/contrib/cutlass/conv2d_operation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,7 @@ def instantiate_conv2d_template(attrs):
424424
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
425425
${split_k_update}
426426
427-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
428-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
427+
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${data_arg}->device.device_id));
429428
430429
status = conv2d_op(stream);
431430
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);

python/tvm/contrib/cutlass/gemm_operation.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,7 @@ def instantiate_gemm_template(attrs):
345345
status = gemm_op.initialize(arguments, workspace.get());
346346
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
347347
348-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
349-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
348+
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id));
350349
351350
status = gemm_op(stream);
352351
TVM_FFI_ICHECK(status == cutlass::Status::kSuccess);
@@ -428,8 +427,8 @@ def emit_fp16A_intB_matmul(attrs):
428427
int n = ${B_arg}->shape[1] * ${float_per_int};
429428
int k = ${B_arg}->shape[0];
430429
431-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
432-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
430+
cudaStream_t stream = static_cast<cudaStream_t>(
431+
TVMFFIEnvGetCurrentStream(kDLCUDA, ${A_arg}->device.device_id));
433432
""",
434433
attrs,
435434
)
@@ -447,12 +446,14 @@ def emit_fp16A_intB_matmul(attrs):
447446

448447
template_residual = """
449448
${template_common}
450-
gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>(static_cast<cutlass::half_t*>(${A_arg}->data),
449+
gemm_fp16_int_bias_act_residual<${weight_dtype}, QuantOp>(
450+
static_cast<cutlass::half_t*>(${A_arg}->data),
451451
static_cast<${weight_dtype}*>(${B_arg}->data),
452452
static_cast<cutlass::half_t*>(${scales_arg}->data),
453453
${bias},
454454
static_cast<cutlass::half_t*>(${residual_arg}->data),
455-
static_cast<cutlass::half_t*>(out0->data), "${activation}", "${binary_op}", "${unary_op}",
455+
static_cast<cutlass::half_t*>(out0->data),
456+
"${activation}", "${binary_op}", "${unary_op}",
456457
m, n, k, ${group_size}, nullptr, 0, stream);
457458
"""
458459

python/tvm/contrib/cutlass/gen_tensor_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def instantiate_template(func_name, annotations, func_args):
487487
if k in annotations:
488488
attrs[k] = annotations[k]
489489

490-
headers = ["tvm/ffi/function.h"]
490+
headers = ["tvm/ffi/function.h", "tvm/ffi/extra/c_env_api.h"]
491491

492492
if "relu" in func_name:
493493
headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h")

python/tvm/contrib/cutlass/layer_norm_operation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def instantiate_layer_norm_template(attrs):
3939
cutlass::TensorRef<data_type, RowMajor> _beta((data_type*)${beta}->data, layout_channels);
4040
cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);
4141
42-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
43-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
42+
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id));
4443
4544
cutlass::layernorm(size, _output, _input, _gamma, _beta, stream);
4645
"""

python/tvm/contrib/cutlass/rms_norm_operation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ def instantiate_rms_norm_template(attrs):
3838
cutlass::TensorRef<data_type, RowMajor> _weight((data_type*)${weight}->data, layout_channels);
3939
cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, layout_2D);
4040
41-
auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
42-
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
41+
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, ${input}->device.device_id));
4342
4443
cutlass::rmsnorm(size, _output, _input, _weight, stream, ${rms_eps});
4544
"""

0 commit comments

Comments
 (0)