diff --git a/onnxruntime/core/common/safeint.h b/onnxruntime/core/common/safeint.h
index 6aba5871ac62e..7062e99f6ce3b 100644
--- a/onnxruntime/core/common/safeint.h
+++ b/onnxruntime/core/common/safeint.h
@@ -36,3 +36,48 @@ class SafeIntExceptionHandler {
 #if defined(__GNUC__)
 #pragma GCC diagnostic pop
 #endif
+
+#include <type_traits>
+
+namespace onnxruntime {
+
+template <typename T>
+using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;
+
+template <typename T>
+inline constexpr bool is_supported_integer_v =
+    std::is_integral_v<remove_cvref_t<T>> && !std::is_same_v<remove_cvref_t<T>, bool>;
+
+//------------------------------------------------------------------------------
+// Safe multiplication of two or more integer values into an explicit result type R.
+// Throws OnnxRuntimeException on overflow.
+//------------------------------------------------------------------------------
+template <typename R, typename T, typename U, typename... Rest>
+[[nodiscard]] R SafeMul(T a, U b, Rest... rest) {
+  static_assert(is_supported_integer_v<R>,
+                "SafeMul requires an integral result type (excluding bool)");
+  static_assert(is_supported_integer_v<T> && is_supported_integer_v<U>,
+                "SafeMul requires integral operand types (excluding bool)");
+  static_assert((is_supported_integer_v<Rest> && ...),
+                "SafeMul requires integral operand types (excluding bool)");
+
+  // SafeMultiply(T, U, T&) requires the first argument and result to share
+  // the same type. Cast the first operand to R so the result is directly in R.
+  R cast_a{};
+  if (!SafeCast(a, cast_a)) {
+    SafeIntDefaultExceptionHandler::SafeIntOnOverflow();
+  }
+
+  R result{};
+  if (!SafeMultiply(cast_a, b, result)) {
+    SafeIntDefaultExceptionHandler::SafeIntOnOverflow();
+  }
+
+  if constexpr (sizeof...(rest) > 0) {
+    return SafeMul<R>(result, rest...);
+  } else {
+    return result;
+  }
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/rnn/rnn.cc b/onnxruntime/core/providers/cpu/rnn/rnn.cc
index 6865571eb7a13..7b27befeb3e66 100644
--- a/onnxruntime/core/providers/cpu/rnn/rnn.cc
+++ b/onnxruntime/core/providers/cpu/rnn/rnn.cc
@@ -3,6 +3,7 @@
 
 #include "core/providers/cpu/rnn/rnn.h"
 
+#include "core/common/narrow.h"
 #include "core/common/safeint.h"
 #include "core/framework/op_kernel_context_internal.h"
 #include "core/providers/cpu/rnn/rnn_activation_functors.h"
@@ -84,15 +85,32 @@ void ApplyActivationToBatches(const Tensor* sequence_lens, const T* h_prev, T* Y
 template <typename T>
 void Assign_Y_h(const T* Y_buffer_data, Tensor* Y_h, const Tensor* sequence_lens, int64_t num_directions,
                 int direction, bool isReverse, int64_t batch_size, int64_t seq_length, int64_t hidden_size) {
+  if (seq_length == 0) {
+    // No sequence data was processed; zero out Y_h for this direction.
+    const size_t y_h_direction_size = SafeMul<size_t>(batch_size, hidden_size);
+    const size_t Y_h_direction_offset = SafeMul<size_t>(direction, y_h_direction_size);
+    math::Set<T, CPUMathUtil>(y_h_direction_size, T{0},
+                              Y_h->MutableData<T>() + Y_h_direction_offset, &CPUMathUtil::Instance());
+    return;
+  }
+
   for (int batch = 0; batch < batch_size; batch++) {
     int64_t last_time_step = isReverse ? 0 : seq_length - 1;
-    if (nullptr != sequence_lens && !isReverse)
+    if (nullptr != sequence_lens && !isReverse) {
       last_time_step = sequence_lens->Data<int>()[batch] - 1;
+      if (last_time_step < 0) {
+        // sequence_lens[batch] == 0: no data was processed for this batch; zero out Y_h.
+        int64_t Y_h_offset = direction * batch_size * hidden_size + batch * hidden_size;
+        math::Set<T, CPUMathUtil>(narrow<size_t>(hidden_size), T{0},
+                                  Y_h->MutableData<T>() + Y_h_offset, &CPUMathUtil::Instance());
+        continue;
+      }
+    }
     int64_t y_offset = last_time_step * num_directions * batch_size * hidden_size +
                        direction * batch_size * hidden_size +
                        batch * hidden_size;
     int64_t Y_h_offset = direction * batch_size * hidden_size + batch * hidden_size;
-    math::CopyVector<T, CPUMathUtil>(static_cast<int>(hidden_size), Y_buffer_data + y_offset,
+    math::CopyVector<T, CPUMathUtil>(narrow<int>(hidden_size), Y_buffer_data + y_offset,
                                      Y_h->MutableData<T>() + Y_h_offset, &CPUMathUtil::Instance());
   }
@@ -109,7 +127,7 @@ void ClearMissingFrames(T* Y_buffer_data, const Tensor* sequence_lens,
                 seq * num_directions * batch_size * hidden_size +
                 direction * batch_size * hidden_size +
                 batch * hidden_size;
-      math::Set<T, CPUMathUtil>(onnxruntime::narrow<size_t>(hidden_size), 0, Y_buffer_data + offset, &CPUMathUtil::Instance());
+      math::Set<T, CPUMathUtil>(narrow<size_t>(hidden_size), 0, Y_buffer_data + offset, &CPUMathUtil::Instance());
     }
   }
 }
@@ -155,7 +173,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
   ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
 
   // X * W^t, each direction has shape of [seq_length, batch_size, hidden_size]
-  auto x_matmul_data = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * seq_length * batch_size * hidden_size_);
+  auto x_matmul_data = alloc->Alloc(SafeMul<size_t>(sizeof(float), seq_length, batch_size, hidden_size_));
   BufferUniquePtr x_matmul_buffer(x_matmul_data, BufferDeleter(alloc));
   auto* x_matmul_w_buffer_data = static_cast<float*>(x_matmul_buffer.get());
@@ -165,7 +183,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
   if (Y != nullptr)
     Y_buffer_data = Y->MutableData<float>();
   else {
-    Y_data = alloc->Alloc(SafeInt<size_t>(sizeof(float)) * seq_length * num_directions * batch_size * hidden_size_);
+    Y_data = alloc->Alloc(SafeMul<size_t>(sizeof(float), seq_length, num_directions, batch_size, hidden_size_));
     Y_matmul_buffer = BufferUniquePtr(Y_data, BufferDeleter(alloc));
     Y_buffer_data =
        static_cast<float*>(Y_matmul_buffer.get());
   }
@@ -177,20 +195,20 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
     bool isReverse = direction_ == "reverse" || direction == 1;
 
     if (B != nullptr) {
-      EigenMatrixMapRowMajor<float>(x_matmul_w_buffer_data, seq_length * SafeInt<size_t>(batch_size), onnxruntime::narrow<size_t>(hidden_size_)).rowwise() =
-          ConstEigenVectorMap<float>(B->Data<float>() + direction * 2 * hidden_size_, onnxruntime::narrow<size_t>(hidden_size_)).transpose() +
-          ConstEigenVectorMap<float>(B->Data<float>() + direction * 2 * hidden_size_ + hidden_size_, onnxruntime::narrow<size_t>(hidden_size_)).transpose();
+      EigenMatrixMapRowMajor<float>(x_matmul_w_buffer_data, SafeMul<size_t>(seq_length, batch_size), narrow<size_t>(hidden_size_)).rowwise() =
+          ConstEigenVectorMap<float>(B->Data<float>() + direction * 2 * hidden_size_, narrow<size_t>(hidden_size_)).transpose() +
+          ConstEigenVectorMap<float>(B->Data<float>() + direction * 2 * hidden_size_ + hidden_size_, narrow<size_t>(hidden_size_)).transpose();
     } else {
-      math::Set<float, CPUMathUtil>(seq_length * batch_size * SafeInt<size_t>(hidden_size_), 0, x_matmul_w_buffer_data, &CPUMathUtil::Instance());
+      math::Set<float, CPUMathUtil>(SafeMul<size_t>(seq_length, batch_size, hidden_size_), 0, x_matmul_w_buffer_data, &CPUMathUtil::Instance());
     }
 
     // X * W[direction]^t + B
     math::Gemm<float>(
         CblasNoTrans,
         CblasTrans,
-        static_cast<ptrdiff_t>(seq_length * batch_size),
-        static_cast<ptrdiff_t>(hidden_size_),
-        static_cast<ptrdiff_t>(input_size),
+        SafeMul<ptrdiff_t>(seq_length, batch_size),
+        narrow<ptrdiff_t>(hidden_size_),
+        narrow<ptrdiff_t>(input_size),
         1,
         X.Data<float>(),
         W.Data<float>() + direction * hidden_size_ * input_size,
@@ -202,7 +220,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
       int64_t time_step = isReverse ?
(seq_length - t - 1) : t;
       int64_t Y_frame_offset = (time_step * num_directions + direction) * Y_frame_size;
       float* Y_buffer_data_current_frame = Y_buffer_data + Y_frame_offset;
-      auto y_frame_mat = EigenMatrixMapRowMajor<float>(Y_buffer_data_current_frame, onnxruntime::narrow<size_t>(batch_size), onnxruntime::narrow<size_t>(hidden_size_));
+      auto y_frame_mat = EigenMatrixMapRowMajor<float>(Y_buffer_data_current_frame, narrow<size_t>(batch_size), narrow<size_t>(hidden_size_));
 
       const float* h_prev = nullptr;
       if (t == 0) {
@@ -224,9 +242,9 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
         math::Gemm<float>(
             CblasNoTrans,
             CblasTrans,
-            static_cast<ptrdiff_t>(batch_size),
-            static_cast<ptrdiff_t>(hidden_size_),
-            static_cast<ptrdiff_t>(hidden_size_),
+            narrow<ptrdiff_t>(batch_size),
+            narrow<ptrdiff_t>(hidden_size_),
+            narrow<ptrdiff_t>(hidden_size_),
             1,
             h_prev,
             R.Data<float>() + direction * hidden_size_ * hidden_size_,
@@ -234,11 +252,11 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
             Y_buffer_data_current_frame, tp, &mlas_backend_kernel_selector_config_);
       } else {
-        math::Set<float, CPUMathUtil>(batch_size * SafeInt<size_t>(hidden_size_), 0, Y_buffer_data_current_frame, &CPUMathUtil::Instance());
+        math::Set<float, CPUMathUtil>(SafeMul<size_t>(batch_size, hidden_size_), 0, Y_buffer_data_current_frame, &CPUMathUtil::Instance());
       }
 
       // X[time_step] * W^t + H_t_1 * R^t
-      y_frame_mat += EigenMatrixMapRowMajor<float>(&x_matmul_w_buffer_data[time_step * Y_frame_size], onnxruntime::narrow<size_t>(batch_size), onnxruntime::narrow<size_t>(hidden_size_));
+      y_frame_mat += EigenMatrixMapRowMajor<float>(&x_matmul_w_buffer_data[time_step * Y_frame_size], narrow<size_t>(batch_size), narrow<size_t>(hidden_size_));
 
       // apply activation
       ApplyActivationToBatches<float>(sequence_lens, h_prev, Y_buffer_data_current_frame,
@@ -258,10 +276,10 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
   }
 
   if (Y != nullptr)
-    DumpMatrix("Y", Y_buffer_data, (int)(seq_length * num_directions * batch_size), (int)hidden_size_);
+    DumpMatrix("Y", Y_buffer_data, SafeMul<int>(seq_length, num_directions, batch_size), narrow<int>(hidden_size_));
 
   if (Y_h != nullptr)
-    DumpMatrix("Y_h", Y_h->Data<float>(), (int)(num_directions * batch_size),
(int)hidden_size_);
+    DumpMatrix("Y_h", Y_h->Data<float>(), SafeMul<int>(num_directions, batch_size), narrow<int>(hidden_size_));
 
   return Status::OK();
 }
diff --git a/onnxruntime/test/common/safeint_test.cc b/onnxruntime/test/common/safeint_test.cc
new file mode 100644
index 0000000000000..ced9bd94975d2
--- /dev/null
+++ b/onnxruntime/test/common/safeint_test.cc
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/safeint.h"
+
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+#include "gtest/gtest.h"
+
+namespace onnxruntime::test {
+
+static_assert(is_supported_integer_v<int32_t>);
+static_assert(is_supported_integer_v<const uint64_t&>);
+static_assert(!is_supported_integer_v<bool>);
+
+TEST(SafeIntTest, SafeMulMultipliesOperands) {
+  EXPECT_EQ(SafeMul<size_t>(size_t{2}, 3U), size_t{6});
+  EXPECT_EQ(SafeMul<int>(-2, 3, 4), -24);
+}
+
+TEST(SafeIntTest, SafeMulHandlesSameVariableOperands) {
+  const int value = 7;
+  EXPECT_EQ(SafeMul<int>(value, value), 49);
+}
+
+#ifndef ORT_NO_EXCEPTIONS
+TEST(SafeIntTest, SafeMulThrowsOnInitialCastOverflow) {
+  EXPECT_THROW((void)SafeMul<uint32_t>(-1, 2), OnnxRuntimeException);
+}
+
+TEST(SafeIntTest, SafeMulThrowsOnMultiplyOverflow) {
+  EXPECT_THROW((void)SafeMul<int32_t>(std::numeric_limits<int32_t>::max(), 2), OnnxRuntimeException);
+}
+#endif
+
+}  // namespace onnxruntime::test
diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
index 38734ab9f668f..382d1869a02f6 100644
--- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
+++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
+#include <cmath>
+
 #include "core/providers/cpu/rnn/rnn.h"
 #include "gtest/gtest.h"
 #include "test/providers/provider_test_utils.h"
@@ -883,5 +885,106 @@ TEST(RNNTest, RNN_with_invalid_activation_load_failure) {
                 {kCudaExecutionProvider, kTensorrtExecutionProvider});
 }
 
+// Test that seq_length == 0 produces zero-filled Y and Y_h without crashing.
+TEST(RNNTest, RNN_seq_length_zero) {
+  auto cpu = DefaultCpuExecutionProvider();
+  if (!cpu) GTEST_SKIP() << "CPU EP not available in this build.";
+
+  OpTester test("RNN");
+  int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 2, seq_length = 0;
+
+  test.AddAttribute("activations", vector<string>(num_directions, "Tanh"));
+  test.AddAttribute("direction", "forward");
+  test.AddAttribute("hidden_size", hidden_size);
+
+  std::vector<int64_t> X_dims = {seq_length, batch_size, input_size};
+  std::vector<float> X_data{};
+  test.AddInput<float>("X", X_dims, X_data);
+
+  std::vector<int64_t> W_dims = {num_directions, hidden_size, input_size};
+  std::vector<float> W_data({-0.1f, 0.2f, 1.f, -2.f, -1.f, 3.f});
+  test.AddInput<float>("W", W_dims, W_data);
+
+  std::vector<int64_t> R_dims = {num_directions, hidden_size, hidden_size};
+  std::vector<float> R_data(hidden_size * hidden_size, 0.f);
+  test.AddInput<float>("R", R_dims, R_data);
+
+  // Y: shape [0, 1, 2, 3] -> empty
+  std::vector<int64_t> Y_dims = {seq_length, num_directions, batch_size, hidden_size};
+  std::vector<float> Y_data{};
+  test.AddOutput<float>("Y", Y_dims, Y_data);
+
+  // Y_h: shape [1, 2, 3] -> all zeros
+  std::vector<int64_t> Y_h_dims{num_directions, batch_size, hidden_size};
+  std::vector<float> Y_h_data(num_directions * batch_size * hidden_size, 0.f);
+  test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
+  test.ConfigEp(std::move(cpu)).RunWithConfig();
+}
+
+// Test that per-batch sequence_lens containing 0 produces zero-filled Y_h for those batches.
+TEST(RNNTest, RNN_forward_sequence_lens_with_zero) {
+  auto cpu = DefaultCpuExecutionProvider();
+  if (!cpu) GTEST_SKIP() << "CPU EP not available in this build.";
+
+  OpTester test("RNN");
+  int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 2, seq_length = 2;
+
+  test.AddAttribute("activations", vector<string>(num_directions, "Tanh"));
+  test.AddAttribute("direction", "forward");
+  test.AddAttribute("hidden_size", hidden_size);
+
+  // X shape: [seq_length=2, batch_size=2, input_size=2]
+  std::vector<int64_t> X_dims = {seq_length, batch_size, input_size};
+  std::vector<float> X_data({0.1f, 0.2f,
+                             0.3f, 0.4f,
+                             0.5f, 0.6f,
+                             0.7f, 0.8f});
+  test.AddInput<float>("X", X_dims, X_data);
+
+  std::vector<int64_t> W_dims = {num_directions, hidden_size, input_size};
+  std::vector<float> W_data({-0.1f, 0.2f, 1.f, -2.f, -1.f, 3.f});
+  test.AddInput<float>("W", W_dims, W_data);
+
+  std::vector<int64_t> R_dims = {num_directions, hidden_size, hidden_size};
+  std::vector<float> R_data(hidden_size * hidden_size, 0.f);
+  test.AddInput<float>("R", R_dims, R_data);
+
+  std::vector<int64_t> B_dims = {num_directions, 2 * hidden_size};
+  std::vector<float> B_data(2 * hidden_size, 0.f);
+  test.AddInput<float>("B", B_dims, B_data);
+
+  // batch 0 has sequence_lens=2, batch 1 has sequence_lens=0
+  std::vector<int64_t> sequence_lens_dims{batch_size};
+  std::vector<int> sequence_lens_data{2, 0};
+  test.AddInput<int>("sequence_lens", sequence_lens_dims, sequence_lens_data);
+
+  std::vector<int64_t> initial_h_dims = {num_directions, batch_size, hidden_size};
+  std::vector<float> initial_h_data(num_directions * batch_size * hidden_size, 0.f);
+  test.AddInput<float>("initial_h", initial_h_dims, initial_h_data);
+
+  // Y output is optional; skip it to keep test simple.
+  test.AddOptionalOutputEdge<float>();
+
+  // Y_h: shape [1, 2, 3]
+  // batch 0 gets the result of forward pass at last time step (seq_length-1=1).
+  // batch 1 has sequence_lens=0 so Y_h should be zero.
+  //
+  // For batch 0:
+  //   time_step 0: X=[0.1, 0.2], Y = tanh(X * W^T) = tanh([-0.1*0.1+0.2*0.2, 1*0.1-2*0.2, -1*0.1+3*0.2])
+  //                                                = tanh([0.03, -0.3, 0.5])
+  //   time_step 1: X=[0.5, 0.6], Y = tanh(X * W^T + H_prev * R^T)
+  //                R is zero, so Y = tanh([-0.1*0.5+0.2*0.6, 1*0.5-2*0.6, -1*0.5+3*0.6])
+  //                                = tanh([0.07, -0.7, 1.3])
+  float y_h_batch0_f0 = std::tanh(0.07f);
+  float y_h_batch0_f1 = std::tanh(-0.7f);
+  float y_h_batch0_f2 = std::tanh(1.3f);
+
+  std::vector<int64_t> Y_h_dims{num_directions, batch_size, hidden_size};
+  std::vector<float> Y_h_data{y_h_batch0_f0, y_h_batch0_f1, y_h_batch0_f2,
+                              0.f, 0.f, 0.f};
+  test.AddOutput<float>("Y_h", Y_h_dims, Y_h_data);
+  test.ConfigEp(std::move(cpu)).RunWithConfig();
+}
+
 }  // namespace test
 }  // namespace onnxruntime