Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
83946f3
Cuda Plug EP Core
tianleiwu Mar 20, 2026
1aab300
Add doc
tianleiwu Mar 21, 2026
f97bbe4
ops
tianleiwu Mar 23, 2026
267bb06
remove cuda graph
tianleiwu Mar 23, 2026
e61f27a
review feedback
tianleiwu Mar 23, 2026
44cf955
refactoring
tianleiwu Mar 23, 2026
eb329a4
Copilot feedback
tianleiwu Mar 23, 2026
59fda40
fix test
tianleiwu Mar 24, 2026
87edf0e
add more ops
tianleiwu Mar 24, 2026
5dbba29
add fused conv
tianleiwu Mar 24, 2026
ede1493
Add group norm and qordered ops.
tianleiwu Mar 24, 2026
f4b1881
add ort stream adapter
tianleiwu Mar 24, 2026
2869301
remove duplicated link cudnn; use adapter in llm attention
tianleiwu Mar 24, 2026
dbefbf1
Merge remote-tracking branch 'origin/main' into tlwu/20260320/cuda_pl…
tianleiwu Mar 24, 2026
0bb6422
onnx attention op
tianleiwu Mar 24, 2026
7e17d0a
Merge remote-tracking branch 'origin/main' into tlwu/20260320/cuda_pl…
tianleiwu Mar 24, 2026
ced3fdb
redesign CudaKernelAdapterRuntimeConfig map
tianleiwu Mar 24, 2026
6307a15
refactor ConstantOfShape and other feedbacks
tianleiwu Mar 25, 2026
b3fdf25
update doc; fix test
tianleiwu Mar 25, 2026
8bb2a98
Merge remote-tracking branch 'origin/main' into tlwu/20260320/cuda_pl…
tianleiwu Mar 25, 2026
876c3b6
comments
tianleiwu Mar 25, 2026
45ed1c1
refine
tianleiwu Mar 25, 2026
b35226c
fix Windows build
tianleiwu Mar 25, 2026
c8341c3
refine CleanupDeferredCPUBuffers etc.
tianleiwu Mar 25, 2026
3c7e3e0
CUDA Plugin EP: Test Coverage & Bug Fixes (#27817)
tianleiwu Mar 26, 2026
34a5416
fill controlflow opset gap
tianleiwu Mar 26, 2026
5ac5b9f
ep config
tianleiwu Mar 26, 2026
0c311f9
Add fuse_conv_bias and sdpa_kernel option
tianleiwu Mar 26, 2026
6175224
RestoreDeviceIfKnown and GetCublasHandleOrDefault
tianleiwu Mar 26, 2026
d7f5205
update import style
tianleiwu Mar 26, 2026
827afa7
update tests
tianleiwu Mar 27, 2026
ef4dcc0
review feedback
tianleiwu Mar 27, 2026
5328c53
refactoring
tianleiwu Mar 27, 2026
8da28f5
update doc about arena and resource accounting
tianleiwu Mar 27, 2026
0e105fe
update doc for OpSchema API
tianleiwu Mar 27, 2026
18e4a2e
refine design of config storage
tianleiwu Mar 27, 2026
102b6b2
fx build
tianleiwu Mar 27, 2026
94e6d7a
refine
tianleiwu Mar 27, 2026
ab0d23e
add droput, identity, crop, synamicslice and fft ops
tianleiwu Mar 28, 2026
d201f35
doc: quick start
tianleiwu Mar 28, 2026
8ffb3c2
add script for parity report
tianleiwu Mar 28, 2026
481b895
Add test cases
tianleiwu Mar 29, 2026
c14fb1e
MemcpyToHost and MemcpyFromHosst
tianleiwu Mar 29, 2026
2b994c9
lintrunner
tianleiwu Mar 29, 2026
47f7674
add second gpu test
tianleiwu Mar 29, 2026
5db09d7
Merge remote-tracking branch 'origin/main' into tlwu/20260320/cuda_pl…
tianleiwu Mar 30, 2026
8471fec
Fix dropout and webgpu plugin test failures
tianleiwu Mar 30, 2026
20b9609
review feedback
tianleiwu Mar 30, 2026
d5ff49f
update design doc
tianleiwu Mar 30, 2026
6377dfd
address remaining feedbacks
tianleiwu Mar 30, 2026
3f7ec37
refine code and doc
tianleiwu Mar 30, 2026
a7fb8f1
fix webgpu
tianleiwu Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF)

cmake_dependent_option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" ON "onnxruntime_USE_CUDA" OFF)
cmake_dependent_option(onnxruntime_BUILD_CUDA_EP_AS_PLUGIN "Build CUDA EP as a separate plugin shared library" OFF "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF)
option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF)
option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)
Expand Down Expand Up @@ -1431,6 +1432,9 @@ if (Git_FOUND)
if (onnxruntime_USE_FP8_KV_CACHE)
string(APPEND ORT_BUILD_INFO "fp8-kv-cache=1, ")
endif()
if (onnxruntime_BUILD_CUDA_EP_AS_PLUGIN)
string(APPEND ORT_BUILD_INFO "cuda-plugin-ep=1, ")
endif()
if (onnxruntime_DUMP_TENSOR)
string(APPEND ORT_BUILD_INFO "dump-tensor=1, ")
endif()
Expand Down Expand Up @@ -1763,6 +1767,11 @@ endif()
foreach(onnxruntime_cmake_file ${ONNXRUNTIME_CMAKE_FILES})
include(${onnxruntime_cmake_file}.cmake)
endforeach()

# CUDA EP Plugin build (independent shared library)
if (onnxruntime_BUILD_CUDA_EP_AS_PLUGIN)
include(onnxruntime_providers_cuda_plugin.cmake)
endif()
if (UNIX)
option(BUILD_PKGCONFIG_FILES "Build and install pkg-config files" ON)
else()
Expand Down
5 changes: 5 additions & 0 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
)
endif()
# Exclude plugin directory if it was picked up by GLOB_RECURSE
list(FILTER onnxruntime_providers_cuda_cc_srcs EXCLUDE REGEX "core/providers/cuda/plugin/.*")

# Remove pch files
list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h"
Expand All @@ -43,6 +46,8 @@
"${ONNXRUNTIME_ROOT}/core/providers/cuda/math/unary_elementwise_ops_impl.cu"
)
endif()
# Exclude plugin directory if it was picked up by GLOB_RECURSE
list(FILTER onnxruntime_providers_cuda_cu_srcs EXCLUDE REGEX "core/providers/cuda/plugin/.*")
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})

Expand Down
358 changes: 358 additions & 0 deletions cmake/onnxruntime_providers_cuda_plugin.cmake

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,11 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
endif()
else()
target_link_libraries(onnxruntime_perf_test PRIVATE onnx_test_runner_common absl::flags absl::flags_parse ${onnx_test_libs})
# When onnxruntime_BUILD_SHARED_LIB is OFF (the plugin build path), the perf test does not otherwise get the CUDA include directories or CUDA::cudart linkage, so add them explicitly here.
if (onnxruntime_USE_CUDA OR onnxruntime_USE_NV OR onnxruntime_USE_TENSORRT)
target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_link_libraries(onnxruntime_perf_test PRIVATE CUDA::cudart)
endif()
endif()
set_target_properties(onnxruntime_perf_test PROPERTIES FOLDER "ONNXRuntimeTest")

Expand Down
269 changes: 269 additions & 0 deletions docs/cuda_plugin_ep/cuda_ops_for_plugin_ep.md

Large diffs are not rendered by default.

709 changes: 709 additions & 0 deletions docs/cuda_plugin_ep/cuda_plugin_ep_design.md

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions include/onnxruntime/core/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@
// Strive not to allocate Tensor with new/delete as it is a shallow class and using it by value is just fine.
// Use InitOrtValue() methods to allocate for OrtValue.

#ifdef BUILD_CUDA_EP_AS_PLUGIN
/// Static factory kept for plugin EP kernels that still call Tensor::Create().
/// The main tree deprecated these in favor of constructors, but dynamically-linked
/// plugin code relies on the static method.
static std::unique_ptr<Tensor> Create(MLDataType elt_type, const TensorShape& shape,
std::shared_ptr<IAllocator> allocator) {
return std::make_unique<Tensor>(elt_type, shape, std::move(allocator));

Check warning on line 52 in include/onnxruntime/core/framework/tensor.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/framework/tensor.h:52: Add #include <utility> for move [build/include_what_you_use] [4]
}
#endif

Tensor() = default; // to allow creating vector<Tensor> to support seq(tensor)

/**
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/ep/adapter/kernel_def_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ struct KernelDefBuilder {
return *this;
}

// ExecQueueId is intentionally a no-op. The plugin EP manages stream
// assignment externally; the queue id hint is not needed.
KernelDefBuilder& ExecQueueId(int /*queue_id*/) { return *this; }

Ort::KernelDef Build() { return builder_.Build(); }
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/ep/adapter/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
return kernel_info_.GetOperatorType();
}

/** Gets the Node's domain. */
std::string Domain() const {

Check warning on line 30 in include/onnxruntime/ep/adapter/node.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: include/onnxruntime/ep/adapter/node.h:30: Add #include <string> for string [build/include_what_you_use] [4]
return kernel_info_.GetOperatorDomain();
}

/** Gets the since version of the operator. */
int SinceVersion() const noexcept {
return kernel_info_.GetOperatorSinceVersion();
Expand Down
17 changes: 14 additions & 3 deletions include/onnxruntime/ep/adapter/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct OpKernel {
explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_{info} {}
virtual ~OpKernel() {}

Node Node() const {
adapter::Node Node() const {
return op_kernel_info_.node();
}
const OpKernelInfo& Info() const {
Expand Down Expand Up @@ -93,6 +93,13 @@ struct OpKernelContext {
input_tensors_[index] = CreateTensorFromApiValue(const_cast<OrtValue*>(static_cast<const OrtValue*>(input)));
return &input_tensors_[index];
}
template <typename T,
typename = std::enable_if_t<std::is_same_v<T, Tensor>>>
const T& RequiredInput(int index) const {
auto* input = Input<T>(index);
ORT_ENFORCE(input != nullptr, "Required input ", index, " is null");
return *input;
}
Comment thread
tianleiwu marked this conversation as resolved.
Tensor* Output(int index, const TensorShape& shape) {
if (index < 0 || static_cast<size_t>(index) >= output_tensors_.size()) {
return nullptr;
Expand All @@ -109,6 +116,11 @@ struct OpKernelContext {
output_tensors_[index] = CreateTensorFromApiValue(output);
return &output_tensors_[index];
}
/// Returns output `index` (requested with `shape`) by reference instead of by pointer.
/// ORT_ENFORCE rejects a null result from Output(), e.g. when `index` is out of range.
Tensor& RequiredOutput(int index, const TensorShape& shape) {
  Tensor* const tensor = Output(index, shape);
  ORT_ENFORCE(tensor != nullptr, "Required output ", index, " is null");
  return *tensor;
}
/// Convenience overload: wraps the raw dimension vector in a TensorShape and
/// forwards to Output(int, const TensorShape&).
Tensor* Output(int index, const std::vector<int64_t>& shape) {
  const TensorShape tensor_shape{shape};
  return Output(index, tensor_shape);
}
Expand All @@ -131,7 +143,6 @@ struct OpKernelContext {
// TODO(fs-eire): Implement GetUseDeterministicCompute().
return false;
}

void* GetGPUComputeStream() const {
return context_.GetGPUComputeStream();
}
Expand All @@ -146,7 +157,7 @@ struct OpKernelContext {
};

/// <summary>
/// A bridge class between `onnxruntime::ep::adapter::OpKernel` and `::OrtKernelImpl`.
/// A bridge class between `onnxruntime::ep::adapter::OpKernel` and `onnxruntime::OrtKernelImpl`.
/// </summary>
struct KernelImpl : OrtKernelImpl {
explicit KernelImpl(std::unique_ptr<OpKernel> impl)
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/ep/api.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ namespace onnxruntime {
namespace ep {

struct ApiPtrs {
// Binds the three ORT API tables by reference. The members of this struct are
// references, so the API structs passed in must outlive this ApiPtrs instance.
ApiPtrs(const OrtApi& ort_, const OrtEpApi& ep_, const OrtModelEditorApi& model_editor_)
: ort(ort_), ep(ep_), model_editor(model_editor_) {}
const OrtApi& ort;
const OrtEpApi& ep;
const OrtModelEditorApi& model_editor;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ Status ComputeSoftmaxWithRawMask(Stream* ort_stream,

if (use_persistent_softmax) {
return onnxruntime::cuda::dispatch_warpwise_softmax_forward<T, T, float, false>(
ort_stream,
stream,
output,
persistent_softmax_workspace,
total_sequence_length,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
int m = batch_size * sequence_length;
int n = (parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size);
int k = parameters.input_hidden_size;
gemm_buffer = GetScratchBuffer<T1>(static_cast<size_t>(m) * n, context->GetComputeStream());
gemm_buffer = GetScratchBuffer<T1>(static_cast<size_t>(m) * n, GetComputeStream(context));

CudaT one = ToCudaType<T1>::FromFloat(1.0f);
CudaT zero = ToCudaType<T1>::FromFloat(0.0f);
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ REGISTER_KERNEL_TYPED(double)

using namespace ONNX_NAMESPACE;

#ifdef BUILD_CUDA_EP_AS_PLUGIN
// Validates the FastGelu inputs for the plugin EP build (used in place of
// bias_gelu_helper::CheckInputs when BUILD_CUDA_EP_AS_PLUGIN is defined).
//
// Constraints checked, in order:
//   - input 0 must be present;
//   - input 0 must have rank >= 1;
//   - input 1 (bias) is optional; when present it must be 1-D and its length must
//     equal the last dimension of input 0.
//
// Returns Status::OK() when all constraints hold, otherwise an INVALID_ARGUMENT
// status describing the first violated constraint.
static Status CheckInputsForPlugin(const OpKernelContext* context) {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* bias = context->Input<Tensor>(1);

  // Input<Tensor>() can return nullptr; report that as an error instead of
  // dereferencing a null pointer below.
  if (nullptr == input) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 0 is required but was not provided");
  }

  const auto& input_dims = input->Shape().GetDims();
  if (input_dims.size() < 1) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size());
  }

  if (nullptr != bias) {
    const auto& bias_dims = bias->Shape().GetDims();
    if (bias_dims.size() != 1) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 1 is expected to have 1 dimensions, got ", bias_dims.size());
    }
    // The bias is added along the innermost axis, so its length must match it.
    if (bias_dims[0] != input_dims[input_dims.size() - 1]) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                             "Input 1 dimension 0 should have same length as the last dimension of input 0");
    }
  }

  return Status::OK();
}
#endif

template <typename T>
FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) {
const TransformerOptions* options = TransformerOptions::GetInstance();
Expand All @@ -38,7 +65,11 @@ FastGelu<T>::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel

template <typename T>
Status FastGelu<T>::ComputeInternal(OpKernelContext* context) const {
#ifdef BUILD_CUDA_EP_AS_PLUGIN
ORT_RETURN_IF_ERROR(CheckInputsForPlugin(context));
#else
ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context));
#endif

const Tensor* input = context->Input<Tensor>(0);
const Tensor* bias = context->Input<Tensor>(1);
Expand Down
40 changes: 24 additions & 16 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ GroupQueryAttention<T, U>::GroupQueryAttention(const OpKernelInfo& info)
// 11. head_sink (Tensor) - Attention sink for GPT-OSS
template <typename T, typename U>
Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) const {
// Stream access: void* for GetScratchBuffer, Stream* for QkvToContext.
#ifdef BUILD_CUDA_EP_AS_PLUGIN
onnxruntime::PluginStreamShim __stream_shim(GetComputeStream(context));
auto* ort_stream = static_cast<onnxruntime::Stream*>(&__stream_shim);
#else
auto* ort_stream = context->GetComputeStream();
#endif

const Tensor* query = context->Input<Tensor>(0);
const Tensor* key = context->Input<Tensor>(1);
const Tensor* value = context->Input<Tensor>(2);
Expand Down Expand Up @@ -259,8 +267,8 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, dense_head_size};

TensorShape present_shape(present_dims);
Tensor* present_key_tensor = context->Output(1, present_shape);
Tensor* present_value_tensor = context->Output(2, present_shape);
Tensor* present_key_output = context->Output(1, present_shape); // present_key
Tensor* present_value_output = context->Output(2, present_shape); // present_value

IAllocatorUniquePtr<void> k_buffer;
IAllocatorUniquePtr<void> v_buffer;
Expand Down Expand Up @@ -288,8 +296,8 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast<const CudaU*>(past_key->Data<U>());
data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast<const CudaU*>(past_value->Data<U>());

data.present_key = reinterpret_cast<CudaU*>(present_key_tensor->MutableData<U>());
data.present_value = reinterpret_cast<CudaU*>(present_value_tensor->MutableData<U>());
data.present_key = reinterpret_cast<CudaU*>(present_key_output->MutableData<U>());
data.present_value = reinterpret_cast<CudaU*>(present_value_output->MutableData<U>());

// Compute past_present_share_buffer early since it's needed for flash attention path selection.
// This compares the final pointer values after quantization handling.
Expand Down Expand Up @@ -370,7 +378,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
xqa_total_bytes += q_bytes + k_bytes;
}

xqa_scratch_buffer = this->GetScratchBuffer<void>(xqa_total_bytes, context->GetComputeStream());
xqa_scratch_buffer = this->GetScratchBuffer<void>(xqa_total_bytes, GetComputeStream(context));
data.xqa_buffer = xqa_scratch_buffer.get();
data.xqa_buffer_bytes = xqa_internal_bytes;

Expand Down Expand Up @@ -413,11 +421,11 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
out_accum_bytes = onnxruntime::flash::get_out_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, round_multiple(parameters.head_size, 32));
}

softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, GetComputeStream(context));
softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, GetComputeStream(context));
out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, GetComputeStream(context));

auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
auto cuda_stream = Stream(context);
if (softmax_lse_accum_bytes > 0) {
// Initialize to 0 is fine because Flash kernel will write -inf to it if needed.
// However, the standard Flash kernel often doesn't zero it globally.
Expand All @@ -442,8 +450,8 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
} else {
// Compute sequence length buffers (past_seq_lens, total_seq_lens and padded_seq_lens).
// Allocate one buffer for all three: each occupies a contiguous batch_size-sized third, in that order.
seq_lens_buffer = GetScratchBuffer<int>(3 * parameters.batch_size, context->GetComputeStream());
auto cuda_stream = static_cast<cudaStream_t>(context->GetComputeStream()->GetHandle());
seq_lens_buffer = GetScratchBuffer<int>(3 * parameters.batch_size, GetComputeStream(context));
auto cuda_stream = Stream(context);
data.past_seq_lens = seq_lens_buffer.get();
data.total_seq_lens = seq_lens_buffer.get() + parameters.batch_size;
data.padded_seq_lens = data.total_seq_lens + parameters.batch_size;
Expand Down Expand Up @@ -480,9 +488,9 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
? (sizeof(float) * parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size)
: 0;

k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, context->GetComputeStream());
fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, context->GetComputeStream());
k_buffer = GetScratchBuffer<void>(kv_buffer_bytes, GetComputeStream(context));
v_buffer = GetScratchBuffer<void>(kv_buffer_bytes, GetComputeStream(context));
fmha_buffer = GetScratchBuffer<void>(fmha_buffer_bytes, GetComputeStream(context));

data.k = reinterpret_cast<CudaT*>(k_buffer.get());
data.v = reinterpret_cast<CudaT*>(v_buffer.get());
Expand All @@ -501,7 +509,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
data.use_memory_efficient_attention);

if (buffer_req.qkv_buffer_bytes > 0) {
unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.qkv_buffer_bytes, context->GetComputeStream());
unpacked_qkv_buffer = GetScratchBuffer<void>(buffer_req.qkv_buffer_bytes, GetComputeStream(context));
data.qkv_buffer = reinterpret_cast<CudaT*>(unpacked_qkv_buffer.get());
}

Expand Down Expand Up @@ -556,7 +564,7 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
cublasHandle_t cublas = GetCublasHandle(context);

ORT_RETURN_IF_ERROR((QkvToContext<CudaT, CudaU>(
device_prop, cublas, context->GetComputeStream(), parameters, data)));
device_prop, cublas, ort_stream, parameters, data)));
return Status::OK();
}

Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
int m = parameters.token_count;
int n = parameters.hidden_size + parameters.hidden_size + parameters.v_hidden_size;
int k = parameters.input_hidden_size;
gemm_buffer = this->template GetScratchBuffer<T>(static_cast<size_t>(m) * n, context->GetComputeStream());
gemm_buffer = this->template GetScratchBuffer<T>(static_cast<size_t>(m) * n, this->GetComputeStream(context));

cublasHandle_t cublas = this->GetCublasHandle(context);

Expand All @@ -310,7 +310,7 @@ Status PackedAttention<T>::ComputeInternal(OpKernelContext* context) const {
false,
use_memory_efficient_attention,
no_qkv_workspace);
auto work_space = this->template GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
auto work_space = this->template GetScratchBuffer<void>(workSpaceSize, this->GetComputeStream(context));

typedef typename ToCudaType<T>::MappedType CudaT;
PackedAttentionData<CudaT> data;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ Status PackedMultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) co
use_flash_attention,
use_memory_efficient_attention,
no_qkv_workspace);
auto work_space = this->template GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
auto work_space = this->template GetScratchBuffer<void>(workSpaceSize, this->GetComputeStream(context));

PackedMultiHeadAttentionData<CudaT> data;
data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
Expand Down
Loading
Loading