29 changes: 2 additions & 27 deletions onnxruntime/core/providers/cuda/llm/attention.cc
@@ -382,9 +382,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// the MHA path below, where 2D masks follow ONNX broadcasting: [A, B] → [1, 1, A, B], so
// 2D = (q_seq_len, total_seq_len) with both batch and heads broadcast.
if (attn_mask != nullptr && attn_mask->IsDataType<bool>()) {
- // Allocate validation result buffer on GPU
- auto validation_buffer = GetScratchBuffer<int>(parameters.batch_size, context->GetComputeStream());
-
// Get mask dimensions for broadcasting
// attn_mask can be 2D, 3D, or 4D and broadcasts to (batch_size, num_heads, q_seq_len, total_seq_len)
const auto& mask_shape = attn_mask->Shape();
@@ -411,11 +408,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
"Boolean attn_mask must be 2D, 3D, or 4D. Got ", mask_dims, "D.");
}

- // Launch CUDA kernel to convert mask to seqlens_k and validate
+ // Launch CUDA kernel to convert mask to seqlens_k.
+ // Mask validity (right-padding, contiguous) is checked asynchronously via CUDA_KERNEL_ASSERT.
ORT_RETURN_IF_ERROR(LaunchConvertMaskToSeqlensK(
attn_mask->Data<bool>(),
seqlens_k_buffer.get(),
- validation_buffer.get(),
parameters.batch_size,
parameters.total_sequence_length,
mask_dims,
@@ -424,28 +421,6 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
mask_dim2,
cuda_stream,
device_prop.maxThreadsPerBlock));
-
- // Copy validation results to CPU and check for errors
- std::vector<int> validation_host(parameters.batch_size);
- CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(validation_host.data(), validation_buffer.get(),
- sizeof(int) * parameters.batch_size,
- cudaMemcpyDeviceToHost, cuda_stream));
- CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
-
- for (int b = 0; b < parameters.batch_size; ++b) {
- if (validation_host[b] == 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Boolean attn_mask for batch ", b,
- " does not start with True. "
- "GQA path only supports right-padding masks where valid tokens come first.");
- } else if (validation_host[b] == 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Boolean attn_mask for batch ", b,
- " is not contiguous. "
- "GQA path only supports right-padding masks with contiguous True values "
- "followed by contiguous False values (no interleaving).");
- }
- }
} else if (attn_mask != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Non-boolean attn_mask is not supported yet in GQA path of Attention op (CUDA).");
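The convention the new kernel enforces: a boolean mask row is valid only when it is right-padded (contiguous True values followed by contiguous False values), and its seqlens_k entry is the index of the first False. A minimal host-side sketch of that mapping, for illustration only (the helper name is hypothetical, not code from this PR):

// Illustration only: maps one boolean mask row to its sequence length under
// the right-padding rule. Returns -1 where the CUDA kernel would instead
// trip CUDA_KERNEL_ASSERT.
#include <vector>

int MaskRowToSeqlenK(const std::vector<bool>& mask_row) {
  if (mask_row.empty() || !mask_row[0]) return -1;   // must start with True
  int seq_len = static_cast<int>(mask_row.size());   // default: no padding found
  bool seen_false = false;
  for (int i = 0; i < static_cast<int>(mask_row.size()); ++i) {
    if (!mask_row[i] && !seen_false) {
      seq_len = i;        // sequence length = index of the first False
      seen_false = true;
    } else if (seen_false && mask_row[i]) {
      return -1;          // True after False: padding not contiguous
    }
  }
  return seq_len;
}
// Example: {true, true, true, false} -> 3; {true, false, true, false} -> -1.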
29 changes: 6 additions & 23 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu
@@ -7,19 +7,14 @@
namespace onnxruntime {
namespace cuda {

- // Validation error codes (stored in validation_result buffer)
- constexpr int kValidationOK = 0;
- constexpr int kValidationErrorNotStartWithTrue = 1;
- constexpr int kValidationErrorNotContiguous = 2;
-
// CUDA kernel to convert boolean attention mask to sequence lengths.
- // Also validates that the mask follows right-padding convention.
+ // Also validates that the mask follows right-padding convention via CUDA_KERNEL_ASSERT.
//
// The kernel processes one batch per thread.
// For each batch, it finds the first False in the mask row, which indicates
// where padding starts. The sequence length is the index of first False.
//
- // Validation:
+ // Validation (via CUDA_KERNEL_ASSERT, reported asynchronously):
// - The mask must start with True (first element must be True)
// - After the first False, all remaining elements must be False (contiguous padding)
//
@@ -31,7 +26,6 @@ constexpr int kValidationErrorNotContiguous = 2;
__global__ void ConvertMaskToSeqlensKernel(
const bool* __restrict__ attn_mask,
int* __restrict__ seqlens_k,
- int* __restrict__ validation_result,
const int batch_size,
const int total_seq_len,
const int mask_dims,
@@ -78,15 +72,8 @@ __global__ void ConvertMaskToSeqlensKernel(
mask_row = attn_mask + effective_batch * batch_stride + h_idx * head_stride + q_idx * q_stride;
}

- // Initialize validation result for this batch
- validation_result[batch_idx] = kValidationOK;
-
- // Check that mask starts with True
- if (!mask_row[0]) {
- validation_result[batch_idx] = kValidationErrorNotStartWithTrue;
- seqlens_k[batch_idx] = -1; // Invalid
- return;
- }
+ // Validate that mask starts with True (right-padding convention)
+ CUDA_KERNEL_ASSERT(mask_row[0]); // mask must start with True

// Find the first False (where padding starts)
// All elements before this should be True, all after should be False
@@ -101,10 +88,8 @@ __global__ void ConvertMaskToSeqlensKernel(
seq_len = i;
found_first_false = true;
} else if (found_first_false && current) {
- // Found True after False - this is invalid (not contiguous)
- validation_result[batch_idx] = kValidationErrorNotContiguous;
- seqlens_k[batch_idx] = -1; // Invalid
- return;
+ // Found True after False - mask is not contiguous (invalid)
+ CUDA_KERNEL_ASSERT(false); // mask must be contiguous (no True after False)
}
}

@@ -115,7 +100,6 @@
Status LaunchConvertMaskToSeqlensK(
const bool* attn_mask_bool,
int* seqlens_k,
- int* validation_result,
int batch_size,
int total_seq_len,
int mask_dims,
@@ -134,7 +118,6 @@ Status LaunchConvertMaskToSeqlensK(
ConvertMaskToSeqlensKernel<<<blocks, threads, 0, stream>>>(
attn_mask_bool,
seqlens_k,
- validation_result,
batch_size,
total_seq_len,
mask_dims,
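How a CUDA_KERNEL_ASSERT failure actually surfaces, assuming the macro reduces to the device-side assert() (an assumption worth checking against onnxruntime's definition): the launch call itself still returns success, the error shows up as cudaErrorAssert at the next synchronization point on the stream, and the CUDA context is unusable afterwards. A standalone sketch:

// Standalone sketch (assumes CUDA_KERNEL_ASSERT behaves like device assert()).
#include <cassert>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void AssertDemo(const bool* flags) {
  assert(flags[threadIdx.x]);  // thread 1 reads false and fails
}

int main() {
  bool host_flags[2] = {true, false};
  bool* dev_flags = nullptr;
  cudaMalloc(&dev_flags, sizeof(host_flags));
  cudaMemcpy(dev_flags, host_flags, sizeof(host_flags), cudaMemcpyHostToDevice);
  AssertDemo<<<1, 2>>>(dev_flags);            // launch itself reports no error
  cudaError_t err = cudaDeviceSynchronize();  // failure is reported here
  printf("%s\n", cudaGetErrorString(err));    // "device-side assert triggered"
  return 0;
}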
8 changes: 3 additions & 5 deletions onnxruntime/core/providers/cuda/llm/attention_mask_impl.h
@@ -34,15 +34,13 @@ namespace cuda {
//
// Returns:
// Status::OK() on success
- // Error status if mask is invalid (not right-padding, doesn't start with True, etc.)
//
- // Note: This function validates the mask on GPU and will return an error if:
- // - The mask doesn't start with True for any batch
- // - The True/False values are not contiguous (e.g., True, False, True pattern)
+ // Note: Mask validity (right-padding convention, starts with True, contiguous True/False)
+ // is checked asynchronously via CUDA_KERNEL_ASSERT inside the kernel. Invalid masks will
+ // trigger a device-side assertion failure.
Status LaunchConvertMaskToSeqlensK(
const bool* attn_mask_bool,
int* seqlens_k,
- int* validation_result, // GPU buffer for validation, size = batch_size
int batch_size,
int total_seq_len,
int mask_dims,
20 changes: 2 additions & 18 deletions onnxruntime/core/providers/cuda/llm/tensorscatter.cc
@@ -72,24 +72,8 @@ Status TensorScatter::ComputeInternal(OpKernelContext* context) const {
write_indices_tensor->Shape()[0] == batch_size,
"TensorScatter: write_indices must have shape [batch_size]");
write_indices = write_indices_tensor->Data<int64_t>();
-
- // Copy write_indices to host for validation (batch_size elements, negligible overhead).
- std::vector<int64_t> host_write_indices(static_cast<size_t>(batch_size));
- CUDA_RETURN_IF_ERROR(
- cudaMemcpyAsync(host_write_indices.data(), write_indices,
- static_cast<size_t>(batch_size) * sizeof(int64_t),
- cudaMemcpyDeviceToHost, Stream(context)));
- CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(Stream(context)));
-
- for (int64_t b = 0; b < batch_size; ++b) {
- int64_t wi = host_write_indices[static_cast<size_t>(b)];
- ORT_ENFORCE(wi >= 0, "TensorScatter: write_indices[", b, "] = ", wi, " is negative");
- if (!circular_) {
- ORT_ENFORCE(wi + sequence_length <= max_sequence_length,
- "TensorScatter linear mode: write_indices[", b, "] + sequence_length (",
- wi, " + ", sequence_length, ") exceeds max_sequence_length (", max_sequence_length, ")");
- }
- }
+ // write_indices values (non-negative, in-bounds) are validated asynchronously
+ // inside the CUDA kernel via CUDA_KERNEL_ASSERT to avoid cudaStreamSynchronize.
}

// Allocate output with the same shape as past_cache.
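The tradeoff in this file: the old path paid a device-to-host copy plus a cudaStreamSynchronize on every call just to validate batch_size integers, stalling the stream; the checks now run inside the kernel, trading that stall for assert-style, asynchronous error reporting. A tiny runnable sketch of the linear-mode bound the kernel asserts (values are made up for illustration):

// Mirrors the in-bounds condition for linear mode: writing sequence_length
// tokens starting at wi must fit within max_sequence_length.
#include <cstdio>

int main() {
  const long long max_sequence_length = 4, sequence_length = 2;
  for (long long wi = 2; wi <= 3; ++wi) {
    bool ok = wi >= 0 && wi + sequence_length <= max_sequence_length;
    printf("wi=%lld -> %s\n", wi, ok ? "ok" : "out of bounds (assert fires)");
  }
  return 0;  // prints: wi=2 -> ok, wi=3 -> out of bounds (assert fires)
}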
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu
@@ -28,12 +28,13 @@ __global__ void _TensorScatterKernel(

int64_t batch_idx = prefix_idx / prefix_stride_for_batch;
int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
- // write_indices are validated on the host before kernel launch.
+ CUDA_KERNEL_ASSERT(wi >= 0);
int64_t cache_pos;
if (circular) {
cache_pos = (wi + seq_idx) % max_seq_len;
} else {
cache_pos = wi + seq_idx;
+ CUDA_KERNEL_ASSERT(cache_pos < max_seq_len);
}

int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;
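Why circular mode asserts only wi >= 0: the modulo keeps cache_pos in range by construction, so no upper-bound check is needed there. Worked numbers for the indexing above (illustrative values, not from the PR):

// max_seq_len = 4, wi = 3, writing 2 tokens (seq_idx = 0, 1):
//   circular: cache_pos = (3 + 0) % 4 = 3
//             cache_pos = (3 + 1) % 4 = 0   (wraps; overwrites the oldest slot)
//   linear:   cache_pos = 3 + 1 = 4, which fails the
//             CUDA_KERNEL_ASSERT(cache_pos < max_seq_len) check.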
18 changes: 15 additions & 3 deletions onnxruntime/test/providers/cpu/llm/tensorscatter_op_test.cc
@@ -297,6 +297,7 @@ TEST(TensorScatterTest, InPlace_IOBinding) {
}

// Negative write_indices should fail validation.
+ // Run CPU-only: CUDA validates asynchronously via CUDA_KERNEL_ASSERT.
TEST(TensorScatterTest, Linear_NegativeWriteIndex) {
OpTester test("TensorScatter", 24);
test.AddAttribute<std::string>("mode", "linear");
@@ -308,10 +309,14 @@ TEST(TensorScatterTest, Linear_NegativeWriteIndex) {
test.AddOutput<float>("present_cache", {1, 4, 3},
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});

- test.Run(OpTester::ExpectResult::kExpectFailure, "is negative");
+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ test.Run(OpTester::ExpectResult::kExpectFailure, "is negative",
+ {}, nullptr, &execution_providers);
}

// Linear mode: write_indices + sequence_length > max_sequence_length should fail.
+ // Run CPU-only: CUDA validates asynchronously via CUDA_KERNEL_ASSERT.
TEST(TensorScatterTest, Linear_OutOfBoundsWriteIndex) {
OpTester test("TensorScatter", 24);
test.AddAttribute<std::string>("mode", "linear");
@@ -324,10 +329,14 @@ TEST(TensorScatterTest, Linear_OutOfBoundsWriteIndex) {
test.AddOutput<float>("present_cache", {1, 4, 3},
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});

- test.Run(OpTester::ExpectResult::kExpectFailure, "exceeds max_sequence_length");
+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ test.Run(OpTester::ExpectResult::kExpectFailure, "exceeds max_sequence_length",
+ {}, nullptr, &execution_providers);
}

// Circular mode: negative write_indices should still fail.
+ // Run CPU-only: CUDA validates asynchronously via CUDA_KERNEL_ASSERT.
TEST(TensorScatterTest, Circular_NegativeWriteIndex) {
OpTester test("TensorScatter", 24);
test.AddAttribute<std::string>("mode", "circular");
@@ -339,7 +348,10 @@ TEST(TensorScatterTest, Circular_NegativeWriteIndex) {
test.AddOutput<float>("present_cache", {1, 4, 2},
{0, 0, 0, 0, 0, 0, 0, 0});

- test.Run(OpTester::ExpectResult::kExpectFailure, "is negative");
+ std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+ execution_providers.push_back(DefaultCpuExecutionProvider());
+ test.Run(OpTester::ExpectResult::kExpectFailure, "is negative",
+ {}, nullptr, &execution_providers);
}

} // namespace test