Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions onnxruntime/contrib_ops/cpu/murmur_hash3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ ONNX_OPERATOR_KERNEL_EX(
KernelDefBuilder()
.TypeConstraint("T1", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>(),
DataTypeImpl::GetTensorType<int64_t>(),
DataTypeImpl::GetTensorType<uint64_t>(),
DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<double>(),
DataTypeImpl::GetTensorType<std::string>()})
.TypeConstraint("T2", std::vector<MLDataType>{DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<uint32_t>()}),
Expand Down Expand Up @@ -193,14 +197,17 @@ Status MurmurHash3::Compute(OpKernelContext* ctx) const {
++output;
}
} else {
auto input = reinterpret_cast<const uint32_t*>(keys->DataRaw());
const auto input_end = input + input_count;
auto input = reinterpret_cast<const unsigned char*>(keys->DataRaw());
//input_element_bytes is 4, 8,.. less than 4 bytes is not allowed
int input_num_bytes = static_cast<int>(input_element_bytes);
ORT_ENFORCE(input_num_bytes % 4 == 0);
const auto input_end = input + input_count * input_num_bytes;
while (input != input_end) {
MurmurHash3_x86_32(input,
static_cast<int>(input_element_bytes),
input_num_bytes,
seed_,
output);
++input;
input += input_num_bytes;
++output;
}
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,7 @@ Output = Dequantize(Input) -> AveragePool on fp32 data -> Quantize(output)
.SetDoc(R"DOC(The underlying implementation is MurmurHash3_x86_32 generating low latency 32bits hash suitable for implementing lookup tables, Bloom filters, count min sketch or feature hashing.)DOC")
.Input(0, "X", "An input tensor to hash.", "T1")
.Output(0, "Y", "32-bit hash value.", "T2")
.TypeConstraint("T1", {"tensor(uint32)", "tensor(int32)", "tensor(string)"}, "Constrain input type to unsigned or signed 32-bit integer tensor, or string tensor. It should be utf-8 encoded if using unicode.")
.TypeConstraint("T1", {"tensor(uint32)", "tensor(int32)", "tensor(uint64)", "tensor(int64)", "tensor(float)", "tensor(double)", "tensor(string)"}, "Constrain input type to unsigned or signed 32-bit integer tensor, or string tensor. It should be utf-8 encoded if using unicode.")
.TypeConstraint("T2", {"tensor(uint32)", "tensor(int32)"}, "Constrain output type to unsigned and signed 32-bit integer tensor.")
.Attr(
"seed",
Expand Down
36 changes: 34 additions & 2 deletions onnxruntime/test/contrib_ops/murmur_hash3_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace test {

TEST(MurmurHash3OpTest, UnsupportedInputType) {
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
test.AddInput<double>("X", {1}, {3.});
test.AddInput<int8_t>("X", {1}, {3});
test.AddAttribute<int64_t>("positive", 0);
test.AddOutput<int32_t>("Y", {1}, {847579505L});
// Unsupported input type
Expand Down Expand Up @@ -49,14 +49,46 @@ TEST(MurmurHash3OpTest, ZeroSeedUIntResult2) {
test.Run();
}

TEST(MurmurHash3OpTest, MoreData) {
TEST(MurmurHash3OpTest, ZeroSeedUIntResult3) {
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
test.AddInput<int64_t>("X", {1}, {4LL});
test.AddAttribute<int64_t>("seed", 0LL);
test.AddOutput<uint32_t>("Y", {1}, {3491892518L});
test.Run();
}

TEST(MurmurHash3OpTest, ZeroSeedFloatResult) {
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
test.AddInput<float>("X", {1}, {3.});
test.AddAttribute<int64_t>("seed", 0LL);
test.AddOutput<uint32_t>("Y", {1}, {6814352L});
test.Run();
}

TEST(MurmurHash3OpTest, ZeroSeedDoubleResult) {
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
test.AddInput<double>("X", {1}, {3.});
test.AddAttribute<int64_t>("seed", 0LL);
test.AddOutput<uint32_t>("Y", {1}, {3554953595L});
test.Run();
}

TEST(MurmurHash3OpTest, MoreDataInt) {
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
test.AddInput<int32_t>("X", {2}, {3L, 4L});
test.AddAttribute<int64_t>("seed", 0LL);
test.AddOutput<uint32_t>("Y", {2}, {847579505L, 1889779975L});
test.Run();
}

TEST(MurmurHash3OpTest, MoreDataFloat) {
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
test.AddInput<float>("X", {2}, {3., 4.});
test.AddAttribute<int64_t>("seed", 0LL);
test.AddOutput<uint32_t>("Y", {2}, {6814352L, 313312394L});
test.Run();
}

TEST(MurmurHash3OpTest, NonZeroSeed) {
OpTester test("MurmurHash3", 1, onnxruntime::kMSDomain);
test.AddInput<int32_t>("X", {1}, {3L});
Expand Down