Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,34 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
const int seq_stride = parameters.seq_stride;
const int batch_stride = parameters.batch_stride;
const int position_ids_format = parameters.position_ids_format;
const int max_sequence_length = parameters.max_sequence_length;
const int rotary_emb_dim = parameters.rotary_embedding_dim;
const int half_rotary_emb_dim = rotary_emb_dim / 2;

// Validate position_ids values are within cos/sin cache bounds
if (position_ids_format == 0) {
// Format 0: single offset, effective positions are [base_pos, base_pos + sequence_length - 1].
// Check without overflow: base_pos must be in [0, max_sequence_length - sequence_length].
int64_t base_pos = position_ids[0];
int64_t max_valid_base = static_cast<int64_t>(max_sequence_length) - static_cast<int64_t>(sequence_length);
if (base_pos < 0 || base_pos > max_valid_base) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"position_ids base value ", base_pos,
" with sequence_length ", sequence_length,
" exceeds cos/sin cache range [0, ", max_sequence_length, ")");
}
} else if (position_ids_format == 1) {
// Format 1: 2D array (batch_size, sequence_length)
for (int i = 0; i < batch_size * sequence_length; ++i) {
int64_t pos = position_ids[i];
if (pos < 0 || pos >= static_cast<int64_t>(max_sequence_length)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"position_ids value ", pos, " at index ", i,
" is out of range [0, ", max_sequence_length, ")");
}
}
}

// Parallel to calculate based on head_size
const int loop_len = batch_size * sequence_length * n_heads;
// The cost is calculated as:
Expand Down
25 changes: 19 additions & 6 deletions onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNx
const int64_t* position_ids, // (1) or BxS
const int* past_sequence_lengths, // (B) for format 2
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int position_ids_format,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous
) {
Expand Down Expand Up @@ -69,9 +70,21 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNx
const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
int position_id = 0;
if (position_ids_format == 0) {
position_id = static_cast<int>(position_ids[0]) + s;
// Validate base without overflow: base must be in [0, max_sequence_length - sequence_length].
int64_t base_pos = position_ids[0];
int64_t max_valid_base = static_cast<int64_t>(max_sequence_length) - static_cast<int64_t>(sequence_length);
if (base_pos < 0 || base_pos > max_valid_base) {
output_data[i] = use_smem ? smem[i] : input_data[i];
return;
}
position_id = static_cast<int>(base_pos) + s;
} else if (position_ids_format == 1) {
position_id = static_cast<int>(position_ids[b * sequence_length + s]);
int64_t pos = position_ids[b * sequence_length + s];
if (pos < 0 || pos >= static_cast<int64_t>(max_sequence_length)) {
output_data[i] = use_smem ? smem[i] : input_data[i];
return;
}
position_id = static_cast<int>(pos);
} else if (position_ids_format == 2) {
// format 2: past_sequence_length + s
// used for Decoding (past_sequence_length = seqlens_k[b]) or First Prompt (past=0 if nullptr)
Expand Down Expand Up @@ -139,7 +152,7 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu
const int* past_sequence_lengths,
const T* cos_cache, const T* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int /*max_sequence_length*/,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block,
int4 in_strides, int4 out_strides // strides in bnsh coord
Expand All @@ -164,13 +177,13 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu
size_t smem_size = head_size * sizeof(T);
RotaryEmbeddingBSNH<T, true><<<grid, block, smem_size, stream>>>(
output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length,
num_heads, head_size, rotary_embedding_dim, position_ids_format,
num_heads, head_size, rotary_embedding_dim, max_sequence_length, position_ids_format,
interleaved, in_strides, out_strides);
} else {
// Separate buffers: no shared memory needed
RotaryEmbeddingBSNH<T, false><<<grid, block, 0, stream>>>(
output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length,
num_heads, head_size, rotary_embedding_dim, position_ids_format,
num_heads, head_size, rotary_embedding_dim, max_sequence_length, position_ids_format,
interleaved, in_strides, out_strides);
}

Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/providers/cpu/llm/rotary_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,22 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
const int seq_stride = parameters.seq_stride;
const int batch_stride = parameters.batch_stride;
const int position_ids_format = parameters.position_ids_format;
const int max_sequence_length = parameters.max_sequence_length;
const int rotary_emb_dim = parameters.rotary_embedding_dim;
const int half_rotary_emb_dim = rotary_emb_dim / 2;

// Validate position_ids values are within cos/sin cache bounds
if (position_ids_format != 0) {
for (int i = 0; i < batch_size * sequence_length; ++i) {
int64_t pos = position_ids[i];
if (pos < 0 || pos >= static_cast<int64_t>(max_sequence_length)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"position_ids value ", pos, " at index ", i,
" is out of range [0, ", max_sequence_length, ")");
}
Comment thread
tianleiwu marked this conversation as resolved.
}
}

// Parallel to calculate based on head_size
const int loop_len = batch_size * sequence_length * n_heads;
// The cost is calculated as:
Expand Down
19 changes: 13 additions & 6 deletions onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const T* sin_cache, // BxSx(H/2) or Mx(H/2)
const int64_t* position_ids, // (0) or BxS
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int position_ids_format,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous
) {
Expand Down Expand Up @@ -52,11 +53,17 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
int cache_offset;

// position_ids_format == 0 means position_ids is nullptr
// position_ids_format == 0 means position_ids is nullptr; cache is (B*S, H/2) and index is always valid.
// position_ids_format == 1 means position_ids is a 2D array of size (batch_size, sequence_length)
int b_s_index = b * sequence_length + s;
if (position_ids_format != 0) {
b_s_index = static_cast<int>(position_ids[b_s_index]);
int64_t pos = position_ids[b_s_index];
if (pos < 0 || pos >= static_cast<int64_t>(max_sequence_length)) {
// OOB position id — can't propagate error from GPU, so pass through input unchanged.
output_data[i] = input_data[i];
return;
}
b_s_index = static_cast<int>(pos);
}
Comment thread
tianleiwu marked this conversation as resolved.
cache_offset = b_s_index * half_rotary_embedding_dim;
const T* cos_data = cos_cache + cache_offset;
Expand Down Expand Up @@ -117,7 +124,7 @@ template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
const T* cos_cache, const T* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int /*max_sequence_length*/,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block,
int4 in_strides, int4 out_strides // strides in bnsh coord
Expand All @@ -137,8 +144,8 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu

assert(head_size <= max_threads_per_block);
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
num_heads, head_size, rotary_embedding_dim, position_ids_format,
interleaved, in_strides, out_strides);
num_heads, head_size, rotary_embedding_dim, max_sequence_length,
position_ids_format, interleaved, in_strides, out_strides);
return CUDA_CALL(cudaGetLastError());
}

Expand Down
187 changes: 187 additions & 0 deletions onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -759,5 +759,192 @@ TEST(ContribOpRotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi
true /*use_fp16*/);
}

// Test that position_ids (format 1) exceeding max_sequence_length is rejected (CPU).
//
// The cos/sin caches are built with max_sequence_length (8) rows while the single position id is 2048,
// so the run is expected to fail with an "out of range" status instead of reading past the caches.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_ExceedsMaxSeqLen) {
  int batch_size = 1;
  int sequence_length = 1;
  int num_heads = 2;
  int head_size = 4;
  int max_sequence_length = 8;
  int hidden_size = num_heads * head_size;

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size},
                       std::vector<float>(hidden_size, 1.0f));
  // Format 1: position_ids shape is {B, S}
  test.AddInput<int64_t>("position_ids", {batch_size, sequence_length}, {2048});
  // Caches have shape {max_sequence_length, head_size / 2}.
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2},
                       std::vector<float>(max_sequence_length * head_size / 2, 1.0f));
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2},
                       std::vector<float>(max_sequence_length * head_size / 2, 0.0f));

  // Placeholder output values: the run is expected to fail, so only the declared shape matters here.
  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
                        std::vector<float>(hidden_size, 0.0f));

  // Restrict the run to the CPU execution provider, which is the one expected to report the error status.
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultCpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectFailure, "position_ids value 2048 at index 0 is out of range",
           {}, nullptr, &execution_providers);
}

// A negative format-1 position id (-1) must make the CPU run fail with an "out of range" status.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Negative) {
  constexpr int batch_size = 1;
  constexpr int sequence_length = 1;
  constexpr int num_heads = 2;
  constexpr int head_size = 4;
  constexpr int max_sequence_length = 8;
  constexpr int hidden_size = num_heads * head_size;
  constexpr int cache_elements = max_sequence_length * head_size / 2;

  // Tensor payloads, built up front so the AddInput/AddOutput calls below stay short.
  const std::vector<float> input_values(hidden_size, 1.0f);
  const std::vector<float> cos_values(cache_elements, 1.0f);
  const std::vector<float> sin_values(cache_elements, 0.0f);
  const std::vector<float> placeholder_output(hidden_size, 0.0f);

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_values);
  // Format 1: one position id per (batch, sequence) element; the only one is negative.
  test.AddInput<int64_t>("position_ids", {batch_size, sequence_length}, {-1});
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2}, cos_values);
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2}, sin_values);

  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, placeholder_output);

  // Run on the CPU execution provider only.
  std::vector<std::unique_ptr<IExecutionProvider>> cpu_only;
  cpu_only.push_back(DefaultCpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectFailure, "position_ids value -1 at index 0 is out of range",
           {}, nullptr, &cpu_only);
}

// With a 2x2 format-1 position_ids tensor, a single out-of-range entry (the last one) must make the
// CPU run fail, and the error text must name the offending value and its flat index.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_OOB_InBatch) {
  constexpr int batch_size = 2;
  constexpr int sequence_length = 2;
  constexpr int num_heads = 2;
  constexpr int head_size = 4;
  constexpr int max_sequence_length = 8;
  constexpr int hidden_size = num_heads * head_size;
  constexpr int total_elements = batch_size * sequence_length * hidden_size;
  constexpr int cache_elements = max_sequence_length * head_size / 2;

  // Tensor payloads, built up front so the AddInput/AddOutput calls below stay short.
  const std::vector<float> input_values(total_elements, 1.0f);
  const std::vector<float> cos_values(cache_elements, 1.0f);
  const std::vector<float> sin_values(cache_elements, 0.0f);
  const std::vector<float> placeholder_output(total_elements, 0.0f);

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_values);
  // Flat index 3 (second batch, second token) holds 100, which is >= max_sequence_length (8).
  test.AddInput<int64_t>("position_ids", {batch_size, sequence_length}, {0, 1, 2, 100});
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2}, cos_values);
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2}, sin_values);

  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, placeholder_output);

  // Run on the CPU execution provider only.
  std::vector<std::unique_ptr<IExecutionProvider>> cpu_only;
  cpu_only.push_back(DefaultCpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectFailure, "position_ids value 100 at index 3 is out of range",
           {}, nullptr, &cpu_only);
}

// Format 0 supplies a single base offset. With base 7 and sequence_length 2 the last effective position
// is 8, which is outside the cache range [0, 8), so the CPU run must fail.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Format0_OOB) {
  constexpr int batch_size = 1;
  constexpr int sequence_length = 2;
  constexpr int num_heads = 2;
  constexpr int head_size = 4;
  constexpr int max_sequence_length = 8;
  constexpr int hidden_size = num_heads * head_size;
  constexpr int total_elements = batch_size * sequence_length * hidden_size;
  constexpr int cache_elements = max_sequence_length * head_size / 2;

  // Tensor payloads, built up front so the AddInput/AddOutput calls below stay short.
  const std::vector<float> input_values(total_elements, 1.0f);
  const std::vector<float> cos_values(cache_elements, 1.0f);
  const std::vector<float> sin_values(cache_elements, 0.0f);
  const std::vector<float> placeholder_output(total_elements, 0.0f);

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_values);
  // Single-element position_ids: the base offset shared by the whole sequence.
  test.AddInput<int64_t>("position_ids", {1}, {7});
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2}, cos_values);
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2}, sin_values);

  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, placeholder_output);

  // Run on the CPU execution provider only.
  std::vector<std::unique_ptr<IExecutionProvider>> cpu_only;
  cpu_only.push_back(DefaultCpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectFailure, "position_ids base value 7 with sequence_length 2 exceeds cos/sin cache range",
           {}, nullptr, &cpu_only);
}

// Format 0 supplies a single base offset; a negative base (-5) must make the CPU run fail.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Format0_Negative) {
  constexpr int batch_size = 1;
  constexpr int sequence_length = 1;
  constexpr int num_heads = 2;
  constexpr int head_size = 4;
  constexpr int max_sequence_length = 8;
  constexpr int hidden_size = num_heads * head_size;
  constexpr int cache_elements = max_sequence_length * head_size / 2;

  // Tensor payloads, built up front so the AddInput/AddOutput calls below stay short.
  const std::vector<float> input_values(hidden_size, 1.0f);
  const std::vector<float> cos_values(cache_elements, 1.0f);
  const std::vector<float> sin_values(cache_elements, 0.0f);
  const std::vector<float> placeholder_output(hidden_size, 0.0f);

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_values);
  // Single-element position_ids: a negative base offset.
  test.AddInput<int64_t>("position_ids", {1}, {-5});
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2}, cos_values);
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2}, sin_values);

  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, placeholder_output);

  // Run on the CPU execution provider only.
  std::vector<std::unique_ptr<IExecutionProvider>> cpu_only;
  cpu_only.push_back(DefaultCpuExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectFailure, "position_ids base value -5 with sequence_length 1 exceeds cos/sin cache range",
           {}, nullptr, &cpu_only);
}

// Test that OOB position_ids on CUDA (format 1) pass through input unchanged.
//
// This test expects the CUDA run to succeed with an output identical to the input, rather than to fail
// with an error status. It is reported as skipped (not passed) when no CUDA environment is available.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_OOB_CUDA_Passthrough) {
  int batch_size = 1;
  int sequence_length = 1;
  int num_heads = 2;
  int head_size = 4;
  int max_sequence_length = 8;
  int hidden_size = num_heads * head_size;

  if (!HasCudaEnvironment(0)) {
    // Mark the test as skipped so a machine without CUDA does not report it as a pass.
    GTEST_SKIP() << "CUDA is not available; skipping CUDA pass-through test.";
  }

  OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
  test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));

  // Distinct values (1, 2, ..., hidden_size) so an unintended rotation or shuffle would be visible.
  std::vector<float> input_data(hidden_size);
  for (int i = 0; i < hidden_size; ++i) {
    input_data[i] = static_cast<float>(i + 1);
  }

  test.AddInput<float>("input", {batch_size, sequence_length, hidden_size}, input_data);
  // position_id = 2048 exceeds max_sequence_length = 8 — CUDA should pass through input unchanged.
  test.AddInput<int64_t>("position_ids", {batch_size, sequence_length}, {2048});
  test.AddInput<float>("cos_cache", {max_sequence_length, head_size / 2},
                       std::vector<float>(max_sequence_length * head_size / 2, 1.0f));
  test.AddInput<float>("sin_cache", {max_sequence_length, head_size / 2},
                       std::vector<float>(max_sequence_length * head_size / 2, 0.0f));

  // Output should equal input when position_id is OOB (pass-through); zero absolute error is allowed.
  test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size}, input_data);
  test.SetOutputAbsErr("output", 0.0f);

  // Run on the CUDA execution provider only.
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultCudaExecutionProvider());
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

} // namespace test
} // namespace onnxruntime
Loading
Loading