Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion onnxruntime/contrib_ops/cpu/word_conv_embedding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace contrib {
void WordConvEmbedding::CharEmbeddingLookup(
const int* seq_ptr,
const float* char_embedding_weight_p,
size_t num_chars,
size_t seq_len,
size_t word_len,
size_t char_embedding_size,
Expand All @@ -24,8 +25,11 @@ void WordConvEmbedding::CharEmbeddingLookup(
if (words_len_ptr[word_inx] > 0) {
const int* cur_seq_ptr = seq_ptr + word_inx * word_len;
float* cur_dst_ptr = dst + word_inx * word_len * char_embedding_size;
size_t char_length_to_lookup = std::max<size_t>(words_len_ptr[word_inx], filter_width);
size_t char_length_to_lookup = std::min(std::max<size_t>(words_len_ptr[word_inx], filter_width), word_len);
for (size_t char_inx = 0; char_inx < char_length_to_lookup; char_inx++) {
Comment thread
vraspar marked this conversation as resolved.
ORT_ENFORCE(*cur_seq_ptr >= 0 && static_cast<size_t>(*cur_seq_ptr) < num_chars,
"CharEmbeddingLookup: character index ", *cur_seq_ptr,
" is out of range [0, ", num_chars, ").");
memcpy(cur_dst_ptr, char_embedding_weight_p + (*cur_seq_ptr) * char_embedding_size, sizeof(float) * char_embedding_size);
Comment thread
vraspar marked this conversation as resolved.
cur_dst_ptr += char_embedding_size;
cur_seq_ptr++;
Expand Down Expand Up @@ -198,6 +202,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {

CharEmbeddingLookup(seq_ptr,
w_char_embedding.Data<float>(),
onnxruntime::narrow<size_t>(w_char_embedding_shape[0]),
onnxruntime::narrow<size_t>(seq_len),
onnxruntime::narrow<size_t>(word_len),
onnxruntime::narrow<size_t>(char_embedding_size),
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/word_conv_embedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class WordConvEmbedding final : public OpKernel {
void CharEmbeddingLookup(
const int* seq_ptr,
const float* char_embedding_weight_p,
size_t num_chars,
size_t seq_len,
size_t word_len,
size_t char_embedding_size,
Expand Down
72 changes: 72 additions & 0 deletions onnxruntime/test/contrib_ops/word_conv_embedding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,77 @@ TEST(ContribOpTest, WordConvEmbedding_char_embedding_shape_conv_shape_not_match)
test.Run(OpTester::ExpectResult::kExpectFailure);
}

// A Sequence input holding a negative character index (-1) must make the
// WordConvEmbedding run fail with the "CharEmbeddingLookup: character index"
// message rather than produce output.
TEST(ContribOpTest, WordConvEmbedding_negative_char_index) {
  // Input "C": character embedding table, 5 rows x 3 columns.
  std::vector<int64_t> char_table_dims{5, 3};
  std::vector<float> char_table{
      0.1f, 0.2f, 0.3f,
      0.2f, 0.3f, 0.1f,
      0.3f, 0.1f, 0.2f,
      0.4f, 0.5f, 0.6f,
      0.7f, 0.8f, 0.9f};

  // Input "W": convolution weights, shape {2, 1, 2, 3}.
  std::vector<int64_t> conv_weight_dims{2, 1, 2, 3};
  std::vector<float> conv_weight{
      0.1f, 0.2f, 0.3f,
      0.2f, 0.3f, 0.1f,
      0.3f, 0.1f, 0.2f,
      1.0f, 1.1f, 1.2f};

  // Input "B": one bias value per filter.
  std::vector<int64_t> conv_bias_dims{2};
  std::vector<float> conv_bias{0.1f, 0.2f};

  // Input "Sequence": 2 words x 5 characters; the second entry of the first
  // word is the invalid index -1.
  std::vector<int64_t> sequence_dims{2, 5};
  std::vector<int> sequence{
      1, -1, 3, 4, 0,
      4, 3, 2, 1, 0};

  // Output "Y": all-zero 2x2 values; presumably never compared because the
  // run is expected to fail.
  std::vector<int64_t> y_dims{2, 2};
  std::vector<float> y_values(4, 0.0f);

  OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain);
  test.AddInput<int>("Sequence", sequence_dims, sequence);
  test.AddInput<float>("W", conv_weight_dims, conv_weight);
  test.AddInput<float>("B", conv_bias_dims, conv_bias);
  test.AddInput<float>("C", char_table_dims, char_table);
  test.AddOutput<float>("Y", y_dims, y_values);

  test.Run(OpTester::ExpectResult::kExpectFailure, "CharEmbeddingLookup: character index");
}

// A Sequence input holding a character index (99) beyond the 5-row character
// embedding table must make the WordConvEmbedding run fail with the
// "CharEmbeddingLookup: character index" message.
TEST(ContribOpTest, WordConvEmbedding_oob_char_index) {
  // Input "C": character embedding table, 5 rows x 3 columns.
  std::vector<int64_t> char_table_dims{5, 3};
  std::vector<float> char_table{
      0.1f, 0.2f, 0.3f,
      0.2f, 0.3f, 0.1f,
      0.3f, 0.1f, 0.2f,
      0.4f, 0.5f, 0.6f,
      0.7f, 0.8f, 0.9f};

  // Input "W": convolution weights, shape {2, 1, 2, 3}.
  std::vector<int64_t> conv_weight_dims{2, 1, 2, 3};
  std::vector<float> conv_weight{
      0.1f, 0.2f, 0.3f,
      0.2f, 0.3f, 0.1f,
      0.3f, 0.1f, 0.2f,
      1.0f, 1.1f, 1.2f};

  // Input "B": one bias value per filter.
  std::vector<int64_t> conv_bias_dims{2};
  std::vector<float> conv_bias{0.1f, 0.2f};

  // Input "Sequence": 2 words x 5 characters; the third entry of the first
  // word is the out-of-range index 99.
  std::vector<int64_t> sequence_dims{2, 5};
  std::vector<int> sequence{
      1, 2, 99, 4, 0,
      4, 3, 2, 1, 0};

  // Output "Y": all-zero 2x2 values; presumably never compared because the
  // run is expected to fail.
  std::vector<int64_t> y_dims{2, 2};
  std::vector<float> y_values(4, 0.0f);

  OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain);
  test.AddInput<int>("Sequence", sequence_dims, sequence);
  test.AddInput<float>("W", conv_weight_dims, conv_weight);
  test.AddInput<float>("B", conv_bias_dims, conv_bias);
  test.AddInput<float>("C", char_table_dims, char_table);
  test.AddOutput<float>("Y", y_dims, y_values);

  test.Run(OpTester::ExpectResult::kExpectFailure, "CharEmbeddingLookup: character index");
}

} // namespace test
} // namespace onnxruntime
Loading