diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc index a1894d1e13dbc..4c0c86aa60729 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include #include "word_conv_embedding.h" @@ -14,6 +16,7 @@ namespace contrib { void WordConvEmbedding::CharEmbeddingLookup( const int* seq_ptr, const float* char_embedding_weight_p, + size_t char_embedding_table_size, size_t seq_len, size_t word_len, size_t char_embedding_size, @@ -26,7 +29,12 @@ void WordConvEmbedding::CharEmbeddingLookup( float* cur_dst_ptr = dst + word_inx * word_len * char_embedding_size; size_t char_length_to_lookup = std::max(words_len_ptr[word_inx], filter_width); for (size_t char_inx = 0; char_inx < char_length_to_lookup; char_inx++) { - memcpy(cur_dst_ptr, char_embedding_weight_p + (*cur_seq_ptr) * char_embedding_size, sizeof(float) * char_embedding_size); + const int char_index = *cur_seq_ptr; + if (char_index >= 0 && static_cast(char_index) < char_embedding_table_size) { + memcpy(cur_dst_ptr, + char_embedding_weight_p + static_cast(char_index) * char_embedding_size, + sizeof(float) * char_embedding_size); + } cur_dst_ptr += char_embedding_size; cur_seq_ptr++; } @@ -131,7 +139,23 @@ void WordConvEmbedding::CalculateLengthOfEachWordInSequence( } } -Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, const TensorShape& w_char_embedding_shape) const { +Status WordConvEmbedding::ValidateInputShape(const TensorShape& sequence_shape, const TensorShape& w_conv_shape, + const TensorShape& w_char_embedding_shape) const { + if (sequence_shape.NumDimensions() <= 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Sequence input must have rank greater than 1.", + " Sequence rank: ", sequence_shape.NumDimensions()); + } + + if (w_conv_shape.NumDimensions() <= 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Conv weight input must have rank greater than 3.", + " Conv weight rank: ", w_conv_shape.NumDimensions()); + } + + if (w_char_embedding_shape.NumDimensions() <= 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Char embedding input must have rank greater than 1.", + " Char embedding rank: ", w_char_embedding_shape.NumDimensions()); + } + if (embedding_size_ != -1 && w_conv_shape[0] != embedding_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Conv filter size does not match embedding_size attribute.", " embedding_size attribute: ", embedding_size_, @@ -156,6 +180,12 @@ Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, co " Conv kernal size 2 : ", w_conv_shape[3]); } + if (w_conv_shape[2] > sequence_shape[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Conv kernel width must not exceed word length.", + " Conv kernel width: ", w_conv_shape[2], + " Word length: ", sequence_shape[1]); + } + return Status::OK(); } @@ -170,7 +200,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { const TensorShape& w_conv_shape = w_conv.Shape(); const TensorShape& w_char_embedding_shape = w_char_embedding.Shape(); - ORT_RETURN_IF_ERROR(ValidateInputShape(w_conv_shape, w_char_embedding_shape)); + ORT_RETURN_IF_ERROR(ValidateInputShape(sequence_shape, w_conv_shape, w_char_embedding_shape)); int64_t seq_len = sequence_shape[0]; int64_t word_len = sequence_shape[1]; @@ -198,6 +228,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const { CharEmbeddingLookup(seq_ptr, w_char_embedding.Data(), + onnxruntime::narrow(w_char_embedding_shape[0]), onnxruntime::narrow(seq_len), onnxruntime::narrow(word_len), onnxruntime::narrow(char_embedding_size), diff --git a/onnxruntime/contrib_ops/cpu/word_conv_embedding.h b/onnxruntime/contrib_ops/cpu/word_conv_embedding.h index 89c3ba1f9964f..3dda4ec298a17 100644 --- a/onnxruntime/contrib_ops/cpu/word_conv_embedding.h +++ b/onnxruntime/contrib_ops/cpu/word_conv_embedding.h @@ -26,6 +26,7 @@ class WordConvEmbedding final : public OpKernel { void CharEmbeddingLookup( const int* seq_ptr, const float* char_embedding_weight_p, + size_t char_embedding_table_size, size_t seq_len, size_t word_len, size_t char_embedding_size, @@ -51,6 +52,7 @@ class WordConvEmbedding final : public OpKernel { size_t word_len) const; Status ValidateInputShape( + const TensorShape& sequence_shape, const TensorShape& w_conv_shape, const TensorShape& w_char_embedding_shape) const; diff --git a/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc b/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc index fb4189cf3de5c..3f50166438190 100644 --- a/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc +++ b/onnxruntime/test/contrib_ops/word_conv_embedding_test.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" @@ -126,5 +127,45 @@ TEST(ContribOpTest, WordConvEmbedding_char_embedding_shape_conv_shape_not_match) test.Run(OpTester::ExpectResult::kExpectFailure); } +TEST(ContribOpTest, WordConvEmbedding_out_of_range_char_index_treated_as_padding) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + test.AddAttribute("embedding_size", 1LL); + test.AddAttribute("conv_window_size", 2LL); + test.AddAttribute("char_embedding_size", 1LL); + + test.AddInput("Sequence", {1, 2}, {1, 99}); + test.AddInput("W", {1, 1, 2, 1}, {1.0f, 1.0f}); + test.AddInput("B", {1}, {0.0f}); + test.AddInput("C", {2, 1}, {123.0f, 2.0f}); + test.AddOutput("Y", {1, 1}, {std::tanh(2.0f)}); + + test.Run(); +} + +TEST(ContribOpTest, WordConvEmbedding_rejects_filter_width_larger_than_word_length) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + test.AddInput("Sequence", {1, 2}, {1, 2}); + test.AddInput("W", {1, 1, 3, 1}, {1.0f, 1.0f, 1.0f}); + test.AddInput("B", {1}, {0.0f}); + test.AddInput("C", {3, 1}, {0.0f, 1.0f, 2.0f}); + test.AddOutput("Y", {1, 1}, {0.0f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "Conv kernel width must not exceed word length"); +} + +TEST(ContribOpTest, WordConvEmbedding_rejects_sequence_rank_one) { + OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain); + + test.AddInput("Sequence", {2}, {1, 2}); + test.AddInput("W", {1, 1, 2, 1}, {1.0f, 1.0f}); + test.AddInput("B", {1}, {0.0f}); + test.AddInput("C", {3, 1}, {0.0f, 1.0f, 2.0f}); + test.AddOutput("Y", {1, 1}, {0.0f}); + + test.Run(OpTester::ExpectResult::kExpectFailure, "Sequence input must have rank greater than 1"); +} + } // namespace test } // namespace onnxruntime