diff --git a/onnxruntime/core/providers/cpu/rnn/rnn.cc b/onnxruntime/core/providers/cpu/rnn/rnn.cc
index 7b27befeb3e66..fc254faaf304b 100644
--- a/onnxruntime/core/providers/cpu/rnn/rnn.cc
+++ b/onnxruntime/core/providers/cpu/rnn/rnn.cc
@@ -60,7 +60,7 @@ template <typename T>
 void ApplyActivationToBatches(const Tensor* sequence_lens, const T* h_prev, T* Y_buffer_data_current_frame,
                               int64_t time_step, int64_t batch_size, int64_t hidden_size,
                               T alpha, T beta, T clip, std::function<T(T, T, T)> activation_func) {
-  const int* seq_len_data = sequence_lens ? sequence_lens->Data<int>() : nullptr;
+  const int32_t* seq_len_data = sequence_lens ? sequence_lens->Data<int32_t>() : nullptr;
 
   for (int batch = 0; batch < batch_size; batch++) {
     bool valid = true;
@@ -95,17 +95,22 @@ void Assign_Y_h(const T* Y_buffer_data, Tensor* Y_h, const Tensor* sequence_lens
   }
 
   for (int batch = 0; batch < batch_size; batch++) {
-    int64_t last_time_step = isReverse ? 0 : seq_length - 1;
-    if (nullptr != sequence_lens && !isReverse) {
-      last_time_step = sequence_lens->Data<int32_t>()[batch] - 1;
-      if (last_time_step < 0) {
-        // sequence_lens[batch] == 0: no data was processed for this batch; zero out Y_h.
+    // Handle zero-length sequences for both forward and reverse directions consistently.
+    if (nullptr != sequence_lens) {
+      int32_t seq_len = sequence_lens->Data<int32_t>()[batch];
+      if (seq_len == 0) {
         int64_t Y_h_offset = direction * batch_size * hidden_size + batch * hidden_size;
         math::Set<T, CPUMathUtil>(narrow<size_t>(hidden_size), T{0}, Y_h->MutableData<T>() + Y_h_offset,
                                   &CPUMathUtil::Instance());
         continue;
       }
     }
+
+    int64_t last_time_step = isReverse ? 0 : seq_length - 1;
+    if (nullptr != sequence_lens && !isReverse) {
+      last_time_step = sequence_lens->Data<int32_t>()[batch] - 1;
+    }
+
     int64_t y_offset = last_time_step * num_directions * batch_size * hidden_size +
                        direction * batch_size * hidden_size +
                        batch * hidden_size;
@@ -121,8 +126,8 @@ void ClearMissingFrames(T* Y_buffer_data, const Tensor* sequence_lens, int64_t num_directions,
                         int64_t batch_size, int64_t seq_length, int64_t hidden_size) {
   for (int direction = 0; direction < num_directions; direction++) {
     for (int batch = 0; batch < batch_size; batch++) {
-      if (sequence_lens->Data<int>()[batch] < seq_length) {
-        for (int seq = sequence_lens->Data<int>()[batch]; seq < seq_length; seq++) {
+      if (sequence_lens->Data<int32_t>()[batch] < seq_length) {
+        for (int seq = sequence_lens->Data<int32_t>()[batch]; seq < seq_length; seq++) {
           int64_t offset = seq * num_directions * batch_size * hidden_size +
                            direction * batch_size * hidden_size +
@@ -169,6 +174,19 @@ Status RNN<T>::Compute(OpKernelContext* ctx) const {
   std::vector<int64_t> Y_h_dims({num_directions, batch_size, hidden_size_});
   Tensor* Y_h = ctx->Output(1, Y_h_dims);
 
+  // Reset output and return if max sequence length is 0
+  if (sequence_lens != nullptr && sequence_lens->Shape().Size() > 0) {
+    int32_t max_sequence_length = *std::max_element(sequence_lens->Data<int32_t>(),
+                                                    sequence_lens->Data<int32_t>() + sequence_lens->Shape().Size());
+    if (max_sequence_length == 0) {
+      if (Y != nullptr)
+        std::fill_n(Y->MutableData<T>(), Y->Shape().Size(), 0.f);
+      if (Y_h != nullptr)
+        std::fill_n(Y_h->MutableData<T>(), Y_h->Shape().Size(), 0.f);
+      return Status::OK();
+    }
+  }
+
   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
diff --git a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc
index 683415bca8f16..f522b801fb4c8 100644
--- a/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc
+++ b/onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc
@@ -78,13 +78,13 @@ Status ValidateCommonRnnInputs(const Tensor& X,
                              batch_size, "}. Actual:", sequence_lens_shape);
     }
 
-    auto sequence_len_entries = sequence_lens->DataAsSpan<int>();
+    auto sequence_len_entries = sequence_lens->DataAsSpan<int32_t>();
     if (std::any_of(sequence_len_entries.begin(), sequence_len_entries.end(),
-                    [seq_length](int len) { return len < 0 || len > seq_length; })) {
+                    [seq_length](int32_t len) { return len < 0 || len > seq_length; })) {
       return ORT_MAKE_STATUS(
           ONNXRUNTIME, INVALID_ARGUMENT,
-          "Invalid value/s in sequence_lens. All values must be > 0 and < seq_length. seq_length=", seq_length);
+          "Invalid value/s in sequence_lens. All values must be >= 0 and <= seq_length. seq_length=", seq_length);
     }
   }
diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
index 382d1869a02f6..0dcf4f597d9c8 100644
--- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
+++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
@@ -756,9 +756,9 @@ TEST(RNNTest, RNN_invalid_sequence_lens) {
   run_test(invalid_num_seq_len_entries,
            "Input sequence_lens must have shape {2}. Actual:{1}");
 
-  // 0 is an invalid value
+  // 5 exceeds seq_length (3)
   std::vector<int> bad_seq_len_entry{0, 5};
-  run_test(bad_seq_len_entry, "Invalid value/s in sequence_lens. All values must be > 0 and < seq_length.");
+  run_test(bad_seq_len_entry, "Invalid value/s in sequence_lens. All values must be >= 0 and <= seq_length.");
 }
 
 TEST(RNNTest, RNN_bidirectional_with_sequence_lens) {
@@ -986,5 +986,118 @@ TEST(RNNTest, RNN_forward_sequence_lens_with_zero) {
   test.ConfigEp(std::move(cpu)).RunWithConfig();
 }
 
+// Test reverse RNN with all-zero sequence_lens and non-zero initial_h.
+// The bug: reverse direction with sequence_lens=0 would return initial_h instead of zero-filling.
+TEST(RNNTest, RNN_reverse_sequence_lens_all_zero) {
+  auto cpu = DefaultCpuExecutionProvider();
+  if (!cpu) GTEST_SKIP() << "CPU EP not available in this build.";
+
+  OpTester test("RNN");
+  int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 2, seq_length = 2;
+
+  test.AddAttribute("activations", vector<string>(num_directions, "Tanh"));
+  test.AddAttribute("direction", "reverse");
+  test.AddAttribute("hidden_size", hidden_size);
+
+  std::vector<int64_t> X_dims = {seq_length, batch_size, input_size};
+  std::vector<float> X_data({0.1f, 0.2f, 0.3f, 0.4f,
+                             0.5f, 0.6f, 0.7f, 0.8f});
+  test.AddInput<float>("X", X_dims, X_data);
+
+  std::vector<int64_t> W_dims = {num_directions, hidden_size, input_size};
+  std::vector<float> W_data({-0.1f, 0.2f, 1.f, -2.f, -1.f, 3.f});
+  test.AddInput<float>("W", W_dims, W_data);
+
+  std::vector<int64_t> R_dims = {num_directions, hidden_size, hidden_size};
+  std::vector<float> R_data(hidden_size * hidden_size, 0.f);
+  test.AddInput<float>("R", R_dims, R_data);
+
+  std::vector<int64_t> B_dims = {num_directions, 2 * hidden_size};
+  std::vector<float> B_data(2 * hidden_size, 0.f);
+  test.AddInput<float>("B", B_dims, B_data);
+
+  // All batches have sequence_lens=0
+  std::vector<int64_t> sequence_lens_dims{batch_size};
+  std::vector<int> sequence_lens_data{0, 0};
+  test.AddInput<int>("sequence_lens", sequence_lens_dims, sequence_lens_data);
+
+  // Non-zero initial_h to detect if the bug returns initial_h instead of zeros.
+  std::vector<int64_t> initial_h_dims = {num_directions, batch_size, hidden_size};
+  std::vector<float> initial_h_data{1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  test.AddInput<float>("initial_h", initial_h_dims, initial_h_data);
+
+  test.AddOptionalOutputEdge<float>();
+
+  // Y_h must be all zeros despite non-zero initial_h (sequence_lens=0 means no output).
+  std::vector<int64_t> Y_h_dims{num_directions, batch_size, hidden_size};
+  std::vector<float> Y_h_data(num_directions * batch_size * hidden_size, 0.f);
+  test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
+  test.ConfigEp(std::move(cpu)).RunWithConfig();
+}
+
+// Test reverse RNN with mixed sequence_lens (0 and non-zero) and non-zero initial_h.
+TEST(RNNTest, RNN_reverse_sequence_lens_mixed_zero) {
+  auto cpu = DefaultCpuExecutionProvider();
+  if (!cpu) GTEST_SKIP() << "CPU EP not available in this build.";
+
+  OpTester test("RNN");
+  int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 2, seq_length = 2;
+
+  test.AddAttribute("activations", vector<string>(num_directions, "Tanh"));
+  test.AddAttribute("direction", "reverse");
+  test.AddAttribute("hidden_size", hidden_size);
+
+  // X shape: [seq_length=2, batch_size=2, input_size=2]
+  std::vector<int64_t> X_dims = {seq_length, batch_size, input_size};
+  std::vector<float> X_data({0.1f, 0.2f,
+                             0.3f, 0.4f,
+                             0.5f, 0.6f,
+                             0.7f, 0.8f});
+  test.AddInput<float>("X", X_dims, X_data);
+
+  std::vector<int64_t> W_dims = {num_directions, hidden_size, input_size};
+  std::vector<float> W_data({-0.1f, 0.2f, 1.f, -2.f, -1.f, 3.f});
+  test.AddInput<float>("W", W_dims, W_data);
+
+  std::vector<int64_t> R_dims = {num_directions, hidden_size, hidden_size};
+  std::vector<float> R_data(hidden_size * hidden_size, 0.f);
+  test.AddInput<float>("R", R_dims, R_data);
+
+  std::vector<int64_t> B_dims = {num_directions, 2 * hidden_size};
+  std::vector<float> B_data(2 * hidden_size, 0.f);
+  test.AddInput<float>("B", B_dims, B_data);
+
+  // batch 0 has sequence_lens=2 (full), batch 1 has sequence_lens=0
+  std::vector<int64_t> sequence_lens_dims{batch_size};
+  std::vector<int> sequence_lens_data{2, 0};
+  test.AddInput<int>("sequence_lens", sequence_lens_dims, sequence_lens_data);
+
+  // Non-zero initial_h so that the bug (returning initial_h for batch 1) is detectable.
+  std::vector<int64_t> initial_h_dims = {num_directions, batch_size, hidden_size};
+  std::vector<float> initial_h_data{0.5f, -0.5f, 0.1f, 1.f, 2.f, 3.f};
+  test.AddInput<float>("initial_h", initial_h_dims, initial_h_data);
+
+  test.AddOptionalOutputEdge<float>();
+
+  // Y_h: shape [1, 2, 3]
+  // Reverse direction processes time steps from seq_length-1 down to 0.
+  // For batch 0 (seq_len=2): Y_h = Y at time_step 0 (the last processed step in reverse).
+  //   time_step 1 (first in reverse): Y = tanh(X[1,0]*W^T + initial_h*R^T)
+  //     R=0, so Y = tanh([-0.1*0.5+0.2*0.6, 1*0.5-2*0.6, -1*0.5+3*0.6]) = tanh([0.07, -0.7, 1.3])
+  //   time_step 0 (second in reverse): Y = tanh(X[0,0]*W^T + H_prev*R^T)
+  //     R=0, so Y = tanh([-0.1*0.1+0.2*0.2, 1*0.1-2*0.2, -1*0.1+3*0.2]) = tanh([0.03, -0.3, 0.5])
+  //   Y_h for batch 0 = Y at time_step 0 = tanh([0.03, -0.3, 0.5])
+  float y_h_batch0_f0 = std::tanh(0.03f);
+  float y_h_batch0_f1 = std::tanh(-0.3f);
+  float y_h_batch0_f2 = std::tanh(0.5f);
+
+  // batch 1 has sequence_lens=0 so Y_h must be zero (not initial_h).
+  std::vector<int64_t> Y_h_dims{num_directions, batch_size, hidden_size};
+  std::vector<float> Y_h_data{y_h_batch0_f0, y_h_batch0_f1, y_h_batch0_f2,
+                              0.f, 0.f, 0.f};
+  test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
+  test.ConfigEp(std::move(cpu)).RunWithConfig();
+}
+
 } // namespace test
 } // namespace onnxruntime