diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc index a1894d1e13dbc..6c48fa7c9c9af 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc @@ -11,9 +11,10 @@ namespace onnxruntime { namespace contrib { -void WordConvEmbedding::CharEmbeddingLookup( +Status WordConvEmbedding::CharEmbeddingLookup( const int* seq_ptr, const float* char_embedding_weight_p, + size_t num_chars, size_t seq_len, size_t word_len, size_t char_embedding_size, @@ -24,14 +25,18 @@ void WordConvEmbedding::CharEmbeddingLookup( if (words_len_ptr[word_inx] > 0) { const int* cur_seq_ptr = seq_ptr + word_inx * word_len; float* cur_dst_ptr = dst + word_inx * word_len * char_embedding_size; - size_t char_length_to_lookup = std::max(words_len_ptr[word_inx], filter_width); + size_t char_length_to_lookup = std::min(std::max(words_len_ptr[word_inx], filter_width), word_len); for (size_t char_inx = 0; char_inx < char_length_to_lookup; char_inx++) { + ORT_RETURN_IF_NOT(*cur_seq_ptr >= 0 && static_cast(*cur_seq_ptr) < num_chars, + "CharEmbeddingLookup: character index ", *cur_seq_ptr, + " is out of range [0, ", num_chars, ")."); memcpy(cur_dst_ptr, char_embedding_weight_p + (*cur_seq_ptr) * char_embedding_size, sizeof(float) * char_embedding_size); cur_dst_ptr += char_embedding_size; cur_seq_ptr++; } } } + return Status::OK(); } // input : [sequence_length, word_length, char_embedding_size] @@ -196,14 +201,18 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { CalculateLengthOfEachWordInSequence(seq_ptr, words_length_ptr.get(), onnxruntime::narrow(seq_len), onnxruntime::narrow(word_len)); - CharEmbeddingLookup(seq_ptr, - w_char_embedding.Data(), - onnxruntime::narrow(seq_len), - onnxruntime::narrow(word_len), - onnxruntime::narrow(char_embedding_size), - onnxruntime::narrow(filter_width), - words_length_ptr.get(), - chars_embeddings_ptr.get()); + ORT_RETURN_IF_ERROR(CharEmbeddingLookup(seq_ptr, + w_char_embedding.Data(), + onnxruntime::narrow(w_char_embedding_shape[0]), + onnxruntime::narrow(seq_len), + onnxruntime::narrow(word_len), + onnxruntime::narrow(char_embedding_size), + onnxruntime::narrow(filter_width), + words_length_ptr.get(), + chars_embeddings_ptr.get())); + + ORT_RETURN_IF_NOT(filter_width <= word_len, + "filter_width (", filter_width, ") must be <= word_len (", word_len, ")."); ComputeConvMaxPoolWithActivation( alloc, diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.h b/onnxruntime/contrib_ops/cpu/word_conv_embedding.h index 89c3ba1f9964f..8123bb1547952 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.h +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.h @@ -23,9 +23,10 @@ class WordConvEmbedding final : public OpKernel { Status Compute(OpKernelContext* context) const override; private: - void CharEmbeddingLookup( + Status CharEmbeddingLookup( const int* seq_ptr, const float* char_embedding_weight_p, + size_t num_chars, size_t seq_len, size_t word_len, size_t char_embedding_size, diff --git a/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc b/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc index fb4189cf3de5c..fd145273f9c26 100644 --- a/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc +++ b/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc @@ -126,5 +126,116 @@ TEST(ContribOpTest, WordConvEmbedding_char_embedding_shape_conv_shape_not_match) test.Run(OpTester::ExpectResult::kExpectFailure); } +TEST(ContribOpTest, WordConvEmbedding_negative_char_index) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + // Sequence contains a negative index (-1) which should be rejected + std::vector seq_words_shape = {2, 5}; + std::vector seq_words{1, -1, 3, 4, 0, + 4, 3, 2, 1, 0}; + + std::vector W_char_embedding_shape = {5, 3}; + std::vector W_char_embedding{0.1f, 0.2f, 0.3f, + 0.2f, 0.3f, 0.1f, + 0.3f, 0.1f, 0.2f, + 0.4f, 0.5f, 0.6f, + 0.7f, 0.8f, 0.9f}; + + std::vector W_conv_shape = {2, 1, 2, 3}; + std::vector W_conv{0.1f, 0.2f, 0.3f, + 0.2f, 0.3f, 0.1f, + 0.3f, 0.1f, 0.2f, + 1.0f, 1.1f, 1.2f}; + + std::vector B_conv_shape = {2}; + std::vector B_conv{0.1f, 0.2f}; + + std::vector output_shape = {2, 2}; + std::vector output{0.0f, 0.0f, 0.0f, 0.0f}; + + test.AddInput("Sequence", seq_words_shape, seq_words); + test.AddInput("W", W_conv_shape, W_conv); + test.AddInput("B", B_conv_shape, B_conv); + test.AddInput("C", W_char_embedding_shape, W_char_embedding); + test.AddOutput("Y", output_shape, output); + + test.Run(OpTester::ExpectResult::kExpectFailure, "CharEmbeddingLookup: character index"); +} + +TEST(ContribOpTest, WordConvEmbedding_oob_char_index) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + // Sequence contains an out-of-range index (99) exceeding char embedding table size (5) + std::vector seq_words_shape = {2, 5}; + std::vector seq_words{1, 2, 99, 4, 0, + 4, 3, 2, 1, 0}; + + std::vector W_char_embedding_shape = {5, 3}; + std::vector W_char_embedding{0.1f, 0.2f, 0.3f, + 0.2f, 0.3f, 0.1f, + 0.3f, 0.1f, 0.2f, + 0.4f, 0.5f, 0.6f, + 0.7f, 0.8f, 0.9f}; + + std::vector W_conv_shape = {2, 1, 2, 3}; + std::vector W_conv{0.1f, 0.2f, 0.3f, + 0.2f, 0.3f, 0.1f, + 0.3f, 0.1f, 0.2f, + 1.0f, 1.1f, 1.2f}; + + std::vector B_conv_shape = {2}; + std::vector B_conv{0.1f, 0.2f}; + + std::vector output_shape = {2, 2}; + std::vector output{0.0f, 0.0f, 0.0f, 0.0f}; + + test.AddInput("Sequence", seq_words_shape, seq_words); + test.AddInput("W", W_conv_shape, W_conv); + test.AddInput("B", B_conv_shape, B_conv); + test.AddInput("C", W_char_embedding_shape, W_char_embedding); + test.AddOutput("Y", output_shape, output); + + test.Run(OpTester::ExpectResult::kExpectFailure, "CharEmbeddingLookup: character index"); +} + +TEST(ContribOpTest, WordConvEmbedding_filter_width_exceeds_word_len) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + // word_len=2 but filter_width=3 (from W_conv_shape[2]) → should fail + std::vector seq_words_shape = {2, 2}; + std::vector seq_words{1, 2, + 3, 4}; + + std::vector W_char_embedding_shape = {5, 3}; + std::vector W_char_embedding{0.1f, 0.2f, 0.3f, + 0.2f, 0.3f, 0.1f, + 0.3f, 0.1f, 0.2f, + 0.4f, 0.5f, 0.6f, + 0.7f, 0.8f, 0.9f}; + + // filter_width = W_conv_shape[2] = 3, which exceeds word_len = 2 + std::vector W_conv_shape = {2, 1, 3, 3}; + std::vector W_conv{0.1f, 0.2f, 0.3f, + 0.2f, 0.3f, 0.1f, + 0.3f, 0.1f, 0.2f, + 1.0f, 1.1f, 1.2f, + 1.3f, 1.4f, 1.5f, + 1.6f, 1.7f, 1.8f}; + + std::vector B_conv_shape = {2}; + std::vector B_conv{0.1f, 0.2f}; + + std::vector output_shape = {2, 2}; + std::vector output{0.0f, 0.0f, 0.0f, 0.0f}; + + test.AddInput("Sequence", seq_words_shape, seq_words); + test.AddInput("W", W_conv_shape, W_conv); + test.AddInput("B", B_conv_shape, B_conv); + test.AddInput("C", W_char_embedding_shape, W_char_embedding); + test.AddOutput("Y", output_shape, output); + + test.Run(OpTester::ExpectResult::kExpectFailure, "filter_width"); +} + } // namespace test } // namespace onnxruntime