Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions onnxruntime/core/providers/cpu/rnn/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ template <typename T>
void ApplyActivationToBatches(const Tensor* sequence_lens, const T* h_prev, T* Y_buffer_data_current_frame,
int64_t time_step, int64_t batch_size, int64_t hidden_size,
T alpha, T beta, T clip, std::function<T(T, T, T)> activation_func) {
const int* seq_len_data = sequence_lens ? sequence_lens->Data<int>() : nullptr;
const int32_t* seq_len_data = sequence_lens ? sequence_lens->Data<int32_t>() : nullptr;

for (int batch = 0; batch < batch_size; batch++) {
bool valid = true;
Expand Down Expand Up @@ -95,17 +95,22 @@ void Assign_Y_h(const T* Y_buffer_data, Tensor* Y_h, const Tensor* sequence_lens
}

for (int batch = 0; batch < batch_size; batch++) {
int64_t last_time_step = isReverse ? 0 : seq_length - 1;
if (nullptr != sequence_lens && !isReverse) {
last_time_step = sequence_lens->Data<int>()[batch] - 1;
if (last_time_step < 0) {
// sequence_lens[batch] == 0: no data was processed for this batch; zero out Y_h.
// Handle zero-length sequences for both forward and reverse directions consistently.
if (nullptr != sequence_lens) {
int32_t seq_len = sequence_lens->Data<int32_t>()[batch];
if (seq_len == 0) {
int64_t Y_h_offset = direction * batch_size * hidden_size + batch * hidden_size;
math::Set<T, CPUMathUtil>(narrow<size_t>(hidden_size), T{0},
Y_h->MutableData<T>() + Y_h_offset, &CPUMathUtil::Instance());
continue;
}
}

int64_t last_time_step = isReverse ? 0 : seq_length - 1;
if (nullptr != sequence_lens && !isReverse) {
last_time_step = sequence_lens->Data<int32_t>()[batch] - 1;
}

int64_t y_offset = last_time_step * num_directions * batch_size * hidden_size +
direction * batch_size * hidden_size +
batch * hidden_size;
Expand All @@ -121,8 +126,8 @@ void ClearMissingFrames(T* Y_buffer_data, const Tensor* sequence_lens,
int64_t num_directions, int64_t batch_size, int64_t seq_length, int64_t hidden_size) {
for (int direction = 0; direction < num_directions; direction++) {
for (int batch = 0; batch < batch_size; batch++) {
if (sequence_lens->Data<int>()[batch] < seq_length) {
for (int seq = sequence_lens->Data<int>()[batch]; seq < seq_length; seq++) {
if (sequence_lens->Data<int32_t>()[batch] < seq_length) {
for (int seq = sequence_lens->Data<int32_t>()[batch]; seq < seq_length; seq++) {
int64_t offset =
seq * num_directions * batch_size * hidden_size +
direction * batch_size * hidden_size +
Expand Down Expand Up @@ -169,6 +174,19 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
std::vector<int64_t> Y_h_dims({num_directions, batch_size, hidden_size_});
Tensor* Y_h = ctx->Output(1, Y_h_dims);

// Reset output and return if max sequence length is 0
if (sequence_lens != nullptr && sequence_lens->Shape().Size() > 0) {
int32_t max_sequence_length = *std::max_element(sequence_lens->Data<int32_t>(),
sequence_lens->Data<int32_t>() + sequence_lens->Shape().Size());
Comment thread
vraspar marked this conversation as resolved.
if (max_sequence_length == 0) {
if (Y != nullptr)
std::fill_n(Y->MutableData<float>(), Y->Shape().Size(), 0.f);
if (Y_h != nullptr)
std::fill_n(Y_h->MutableData<float>(), Y_h->Shape().Size(), 0.f);
return Status::OK();
}
}
Comment thread
yuslepukhin marked this conversation as resolved.

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));

Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,13 @@ Status ValidateCommonRnnInputs(const Tensor& X,
batch_size, "}. Actual:", sequence_lens_shape);
}

auto sequence_len_entries = sequence_lens->DataAsSpan<int>();
auto sequence_len_entries = sequence_lens->DataAsSpan<int32_t>();
if (std::any_of(sequence_len_entries.begin(),
sequence_len_entries.end(),
[seq_length](int len) { return len < 0 || len > seq_length; })) {
[seq_length](int32_t len) { return len < 0 || len > seq_length; })) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Invalid value/s in sequence_lens. All values must be > 0 and < seq_length. seq_length=", seq_length);
"Invalid value/s in sequence_lens. All values must be >= 0 and <= seq_length. seq_length=", seq_length);
}
}

Expand Down
117 changes: 115 additions & 2 deletions onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,9 @@ TEST(RNNTest, RNN_invalid_sequence_lens) {

run_test(invalid_num_seq_len_entries, "Input sequence_lens must have shape {2}. Actual:{1}");

// 0 is an invalid value
// 5 exceeds seq_length (3)
std::vector<int> bad_seq_len_entry{0, 5};
run_test(bad_seq_len_entry, "Invalid value/s in sequence_lens. All values must be > 0 and < seq_length.");
run_test(bad_seq_len_entry, "Invalid value/s in sequence_lens. All values must be >= 0 and <= seq_length.");
}

TEST(RNNTest, RNN_bidirectional_with_sequence_lens) {
Expand Down Expand Up @@ -986,5 +986,118 @@ TEST(RNNTest, RNN_forward_sequence_lens_with_zero) {
test.ConfigEp(std::move(cpu)).RunWithConfig();
}

// Test reverse RNN with all-zero sequence_lens and non-zero initial_h.
// The bug: reverse direction with sequence_lens=0 would return initial_h instead of zero-filling.
// Test reverse RNN with all-zero sequence_lens and non-zero initial_h.
// The bug: reverse direction with sequence_lens=0 would return initial_h instead of zero-filling.
TEST(RNNTest, RNN_reverse_sequence_lens_all_zero) {
  auto cpu = DefaultCpuExecutionProvider();
  if (!cpu) GTEST_SKIP() << "CPU EP not available in this build.";

  OpTester test("RNN");
  const int64_t num_directions = 1;
  const int64_t input_size = 2;
  const int64_t hidden_size = 3;
  const int64_t batch_size = 2;
  const int64_t seq_length = 2;

  test.AddAttribute("activations", vector<string>(num_directions, "Tanh"));
  test.AddAttribute("direction", "reverse");
  test.AddAttribute("hidden_size", hidden_size);

  // X: [seq_length, batch_size, input_size]
  const std::vector<int64_t> x_shape{seq_length, batch_size, input_size};
  const std::vector<float> x_vals{0.1f, 0.2f, 0.3f, 0.4f,
                                  0.5f, 0.6f, 0.7f, 0.8f};
  test.AddInput<float>("X", x_shape, x_vals);

  // W: [num_directions, hidden_size, input_size]
  const std::vector<int64_t> w_shape{num_directions, hidden_size, input_size};
  const std::vector<float> w_vals{-0.1f, 0.2f, 1.f, -2.f, -1.f, 3.f};
  test.AddInput<float>("W", w_shape, w_vals);

  // R: [num_directions, hidden_size, hidden_size], all zeros.
  const std::vector<int64_t> r_shape{num_directions, hidden_size, hidden_size};
  const std::vector<float> r_vals(hidden_size * hidden_size, 0.f);
  test.AddInput<float>("R", r_shape, r_vals);

  // B: [num_directions, 2 * hidden_size], all zeros.
  const std::vector<int64_t> b_shape{num_directions, 2 * hidden_size};
  const std::vector<float> b_vals(2 * hidden_size, 0.f);
  test.AddInput<float>("B", b_shape, b_vals);

  // Every batch entry has sequence length 0.
  const std::vector<int64_t> seq_lens_shape{batch_size};
  const std::vector<int> seq_lens_vals{0, 0};
  test.AddInput<int>("sequence_lens", seq_lens_shape, seq_lens_vals);

  // Non-zero initial_h to detect if the bug returns initial_h instead of zeros.
  const std::vector<int64_t> init_h_shape{num_directions, batch_size, hidden_size};
  const std::vector<float> init_h_vals{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
  test.AddInput<float>("initial_h", init_h_shape, init_h_vals);

  test.AddOptionalOutputEdge<float>();

  // Y_h must be all zeros despite non-zero initial_h (sequence_lens=0 means no output).
  const std::vector<int64_t> y_h_shape{num_directions, batch_size, hidden_size};
  const std::vector<float> y_h_expected(num_directions * batch_size * hidden_size, 0.f);
  test.AddOutput<float>("Y_h", y_h_shape, y_h_expected);
  test.ConfigEp(std::move(cpu)).RunWithConfig();
}

// Test reverse RNN with mixed sequence_lens (0 and non-zero) and non-zero initial_h.
// Test reverse RNN with mixed sequence_lens (0 and non-zero) and non-zero initial_h.
TEST(RNNTest, RNN_reverse_sequence_lens_mixed_zero) {
  auto cpu = DefaultCpuExecutionProvider();
  if (!cpu) GTEST_SKIP() << "CPU EP not available in this build.";

  OpTester test("RNN");
  const int64_t num_directions = 1;
  const int64_t input_size = 2;
  const int64_t hidden_size = 3;
  const int64_t batch_size = 2;
  const int64_t seq_length = 2;

  test.AddAttribute("activations", vector<string>(num_directions, "Tanh"));
  test.AddAttribute("direction", "reverse");
  test.AddAttribute("hidden_size", hidden_size);

  // X shape: [seq_length=2, batch_size=2, input_size=2]
  const std::vector<int64_t> x_shape{seq_length, batch_size, input_size};
  const std::vector<float> x_vals{0.1f, 0.2f,
                                  0.3f, 0.4f,
                                  0.5f, 0.6f,
                                  0.7f, 0.8f};
  test.AddInput<float>("X", x_shape, x_vals);

  // W: [num_directions, hidden_size, input_size]
  const std::vector<int64_t> w_shape{num_directions, hidden_size, input_size};
  const std::vector<float> w_vals{-0.1f, 0.2f, 1.f, -2.f, -1.f, 3.f};
  test.AddInput<float>("W", w_shape, w_vals);

  // R: all zeros, so the recurrence contributes nothing.
  const std::vector<int64_t> r_shape{num_directions, hidden_size, hidden_size};
  const std::vector<float> r_vals(hidden_size * hidden_size, 0.f);
  test.AddInput<float>("R", r_shape, r_vals);

  // B: all zeros.
  const std::vector<int64_t> b_shape{num_directions, 2 * hidden_size};
  const std::vector<float> b_vals(2 * hidden_size, 0.f);
  test.AddInput<float>("B", b_shape, b_vals);

  // batch 0 has sequence_lens=2 (full), batch 1 has sequence_lens=0
  const std::vector<int64_t> seq_lens_shape{batch_size};
  const std::vector<int> seq_lens_vals{2, 0};
  test.AddInput<int>("sequence_lens", seq_lens_shape, seq_lens_vals);

  // Non-zero initial_h so that the bug (returning initial_h for batch 1) is detectable.
  const std::vector<int64_t> init_h_shape{num_directions, batch_size, hidden_size};
  const std::vector<float> init_h_vals{0.5f, -0.5f, 0.1f, 1.f, 2.f, 3.f};
  test.AddInput<float>("initial_h", init_h_shape, init_h_vals);

  test.AddOptionalOutputEdge<float>();

  // Y_h: shape [1, 2, 3]
  // Reverse direction processes time steps from seq_length-1 down to 0.
  // For batch 0 (seq_len=2): Y_h = Y at time_step 0 (the last processed step in reverse).
  //   time_step 1 (first in reverse): Y = tanh(X[1,0]*W^T + initial_h*R^T)
  //     R=0, so Y = tanh([-0.1*0.5+0.2*0.6, 1*0.5-2*0.6, -1*0.5+3*0.6]) = tanh([0.07, -0.7, 1.3])
  //   time_step 0 (second in reverse): Y = tanh(X[0,0]*W^T + H_prev*R^T)
  //     R=0, so Y = tanh([-0.1*0.1+0.2*0.2, 1*0.1-2*0.2, -1*0.1+3*0.2]) = tanh([0.03, -0.3, 0.5])
  //   Y_h for batch 0 = Y at time_step 0 = tanh([0.03, -0.3, 0.5])
  // batch 1 has sequence_lens=0 so Y_h must be zero (not initial_h).
  const std::vector<int64_t> y_h_shape{num_directions, batch_size, hidden_size};
  const std::vector<float> y_h_expected{std::tanh(0.03f), std::tanh(-0.3f), std::tanh(0.5f),
                                        0.f, 0.f, 0.f};
  test.AddOutput<float>("Y_h", y_h_shape, y_h_expected);
  test.ConfigEp(std::move(cpu)).RunWithConfig();
}

} // namespace test
} // namespace onnxruntime
Loading