diff --git a/onnxruntime/contrib_ops/cpu/murmur_hash3.cc b/onnxruntime/contrib_ops/cpu/murmur_hash3.cc index 56683f548382a..d23b112547a79 100644 --- a/onnxruntime/contrib_ops/cpu/murmur_hash3.cc +++ b/onnxruntime/contrib_ops/cpu/murmur_hash3.cc @@ -103,6 +103,10 @@ ONNX_OPERATOR_KERNEL_EX( KernelDefBuilder() .TypeConstraint("T1", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T2", std::vector{DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), @@ -193,14 +197,17 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const { ++output; } } else { - auto input = reinterpret_cast(keys->DataRaw()); - const auto input_end = input + input_count; + auto input = reinterpret_cast(keys->DataRaw()); + //input_element_bytes is 4, 8,.. less than 4 bytes is not allowed + int input_num_bytes = static_cast(input_element_bytes); + ORT_ENFORCE(input_num_bytes % 4 == 0); + const auto input_end = input + input_count * input_num_bytes; while (input != input_end) { MurmurHash3_x86_32(input, - static_cast(input_element_bytes), + input_num_bytes, seed_, output); - ++input; + input += input_num_bytes; ++output; } } diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 76a1ab4d79e0e..047dec091e474 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2084,7 +2084,7 @@ Output = Dequantize(Input) -> AveragePool on fp32 data -> Quantize(output) .SetDoc(R"DOC(The underlying implementation is MurmurHash3_x86_32 generating low latency 32bits hash suitable for implementing lookup tables, Bloom filters, count min sketch or feature hashing.)DOC") .Input(0, "X", "An input tensor to hash.", "T1") .Output(0, "Y", "32-bit hash value.", "T2") - .TypeConstraint("T1", {"tensor(uint32)", "tensor(int32)", "tensor(string)"}, "Constrain input type to unsigned or signed 32-bit integer tensor, or string tensor. It should be utf-8 encoded if using unicode.") + .TypeConstraint("T1", {"tensor(uint32)", "tensor(int32)", "tensor(uint64)", "tensor(int64)", "tensor(float)", "tensor(double)", "tensor(string)"}, "Constrain input type to unsigned or signed 32-bit integer tensor, or string tensor. It should be utf-8 encoded if using unicode.") .TypeConstraint("T2", {"tensor(uint32)", "tensor(int32)"}, "Constrain output type to unsigned and signed 32-bit integer tensor.") .Attr( "seed", diff --git a/onnxruntime/test/contrib_ops/murmur_hash3_test.cc b/onnxruntime/test/contrib_ops/murmur_hash3_test.cc index e6f1774a7ee7a..6de5780208edf 100644 --- a/onnxruntime/test/contrib_ops/murmur_hash3_test.cc +++ b/onnxruntime/test/contrib_ops/murmur_hash3_test.cc @@ -9,7 +9,7 @@ namespace test { TEST(MurmurHash3OpTest, UnsupportedInputType) { OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); - test.AddInput("X", {1}, {3.}); + test.AddInput("X", {1}, {3}); test.AddAttribute("positive", 0); test.AddOutput("Y", {1}, {847579505L}); // Unsupported input type @@ -49,7 +49,31 @@ TEST(MurmurHash3OpTest, ZeroSeedUIntResult2) { test.Run(); } -TEST(MurmurHash3OpTest, MoreData) { +TEST(MurmurHash3OpTest, ZeroSeedUIntResult3) { + OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); + test.AddInput("X", {1}, {4LL}); + test.AddAttribute("seed", 0LL); + test.AddOutput("Y", {1}, {3491892518L}); + test.Run(); +} + +TEST(MurmurHash3OpTest, ZeroSeedFloatResult) { + OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); + test.AddInput("X", {1}, {3.}); + test.AddAttribute("seed", 0LL); + test.AddOutput("Y", {1}, {6814352L}); + test.Run(); +} + +TEST(MurmurHash3OpTest, ZeroSeedDoubleResult) { + OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); + test.AddInput("X", {1}, {3.}); + test.AddAttribute("seed", 0LL); + test.AddOutput("Y", {1}, {3554953595L}); + test.Run(); +} + +TEST(MurmurHash3OpTest, MoreDataInt) { OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); test.AddInput("X", {2}, {3L, 4L}); test.AddAttribute("seed", 0LL); @@ -57,6 +81,14 @@ TEST(MurmurHash3OpTest, MoreData) { test.Run(); } +TEST(MurmurHash3OpTest, MoreDataFloat) { + OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); + test.AddInput("X", {2}, {3., 4.}); + test.AddAttribute("seed", 0LL); + test.AddOutput("Y", {2}, {6814352L, 313312394L}); + test.Run(); +} + TEST(MurmurHash3OpTest, NonZeroSeed) { OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain); test.AddInput("X", {1}, {3L});