Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion onnxruntime/contrib_ops/contrib_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
namespace onnxruntime {
namespace contrib {
using ::ONNX_NAMESPACE::AttributeProto;
using ::ONNX_NAMESPACE::OPTIONAL;
using ::ONNX_NAMESPACE::OpSchema;
using ::ONNX_NAMESPACE::OPTIONAL;

void RegisterContribSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp)
Expand Down Expand Up @@ -452,6 +452,39 @@ The bounding box coordinates corresponding to the selected indices can then be o
->set_dim_value(1);
}
});

// Schema for the contrib StringNormalizer operator (kMSDomain, since version 1).
// One string-tensor input X and one string-tensor output Y share the type
// constraint "T". "casechangeaction" and "is_case_sensitive" are declared
// without OPTIONAL; "stopwords" and "locale" are declared OPTIONAL.
ONNX_CONTRIB_OPERATOR_SCHEMA(StringNormalizer)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .Input(0, "X", "Strings to normalize", "T")
    .Output(0, "Y", "Normalized strings", "T")
    .TypeConstraint(
        "T",
        {"tensor(string)"},
        "Input/Output is a string tensor")
    .Attr(
        "casechangeaction",
        "string enum that causes output to be lowercased/uppercased/unchanged. Valid values are \"LOWER\", \"UPPER\", \"NONE\"",
        AttributeProto::STRING)
    .Attr(
        "is_case_sensitive",
        "Boolean. Whether the identification of stop words in X is case-sensitive.",
        AttributeProto::INT)
    .Attr(
        "stopwords",
        "List of stop words",
        AttributeProto::STRINGS,
        OPTIONAL)
    .Attr(
        "locale",
        // The documented default matches the value the CPU kernel falls back to.
        "Environment dependent string that denotes the locale according to which output strings need to be upper/lowercased. Default en_US.UTF-8",
        AttributeProto::STRING,
        OPTIONAL)
    .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
      // Only the element type is inferred here; no output shape is set.
      auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type();
      output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING);
    })
    .SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC");
}

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
Expand All @@ -461,6 +494,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression);

void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
Expand All @@ -474,6 +508,7 @@ void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression)>());
}
} // namespace contrib
Expand Down
244 changes: 244 additions & 0 deletions onnxruntime/contrib_ops/cpu/string_normalizer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_normalizer.h"
#include "onnx/defs/schema.h"
#include "core/common/common.h"
#include "core/framework/tensor.h"

#include <algorithm>
#include <cassert>
#include <codecvt>
#include <functional>
#include <locale>
#include <new>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>

namespace onnxruntime {
namespace contrib {

// Kernel registration for contrib::StringNormalizer: operator name
// StringNormalizer, version 1, typed on `string`, with type constraint "T"
// bound to tensors of std::string.
// NOTE(review): presumably this targets the CPU execution provider in the MS
// domain, judging by the macro name and the matching kernel declaration in
// contrib_ops.cc -- confirm against the macro definition.
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
StringNormalizer,
1,
string,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<std::string>()),
contrib::StringNormalizer);

namespace string_normalizer {
// Sentinel strings passed to the std::wstring_convert instances built in this
// file. With these supplied, a failed conversion returns the sentinel instead
// of throwing; callers detect failure by comparing the result against it.
// conv_error is the narrow (to_bytes) sentinel, wconv_error the wide
// (from_bytes) one.
const std::string conv_error("Conversion Error");
const std::wstring wconv_error(L"Conversion Error");
// Converts every character of `wstr` in place using the ctype facet of `loc`:
// to lower case when `caseaction` is LOWER, to upper case otherwise.
// `caseaction` must not be NONE (asserted in debug builds).
inline void ChangeCase(const std::locale& loc, StringNormalizer::CaseAction caseaction,
                       std::wstring& wstr) {
  assert(caseaction != StringNormalizer::NONE);
  const bool to_lower = (caseaction == StringNormalizer::LOWER);
  for (wchar_t& ch : wstr) {
    ch = to_lower ? std::tolower(ch, loc) : std::toupper(ch, loc);
  }
}

// Writes the strings in [first, end) to output 0 of `ctx`, optionally changing
// their case on the way.
//
//   first, end  - range of input strings; elements may be std::string or
//                 std::reference_wrapper<const std::string>.
//   loc         - locale used for the case conversion.
//   converter   - UTF-8 <-> wide converter; expected to have been constructed
//                 with (conv_error, wconv_error) as its error strings, as the
//                 callers in this file do.
//   N           - 1 when the output needs a leading dimension of size 1
//                 (shape [1, C]); otherwise the output is one-dimensional.
//   C           - number of strings to emit; callers in this file pass the
//                 element count of [first, end).
//   caseaction  - LOWER/UPPER to convert each string, NONE to copy (or move,
//                 when the iterator yields non-const strings) unchanged.
//
// When C == 0 the output has shape [1] (or [1, 1]) and holds one empty string.
// Returns INVALID_ARGUMENT if a string cannot be converted from UTF-8 while a
// case change is requested; otherwise Status::OK().
//
// NOTE(review): output elements are created with placement new, which assumes
// the buffer returned by MutableData holds raw std::string storage -- confirm
// against how Tensor allocates string data.
template <class ForwardIter>
Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx,
                      const std::locale& loc,
                      std::wstring_convert<std::codecvt_utf8<wchar_t>>& converter,
                      size_t N, size_t C,
                      StringNormalizer::CaseAction caseaction) {
  std::vector<int64_t> output_dims;
  if (N == 1) {
    output_dims.push_back(1);
  }

  // Empty output case
  if (C == 0) {
    output_dims.push_back(1);
    TensorShape output_shape(output_dims);
    auto output_ten = ctx->Output(0, output_shape);
    auto output_default = output_ten->template MutableData<std::string>();
    new (output_default) std::string();
    return Status::OK();
  }

  output_dims.push_back(C);

  TensorShape output_shape(output_dims);
  auto output_tensor = ctx->Output(0, output_shape);
  auto const output_data = output_tensor->template MutableData<std::string>();

  size_t output_idx = 0;
  while (first != end) {
    auto& s = *first;
    if (caseaction == StringNormalizer::LOWER || caseaction == StringNormalizer::UPPER) {
      // Bind once as a plain string; also unwraps std::reference_wrapper.
      const std::string& bytes = s;
      std::wstring wstr = converter.from_bytes(bytes);
      // The converter signals failure by returning wconv_error. An input whose
      // bytes are literally conv_error legitimately converts to that same wide
      // string, so it must not be reported as invalid.
      if (wstr == wconv_error && bytes != conv_error) {
        return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                      "Input contains invalid utf8 chars at: " + bytes);
      }
      // In place transform
      ChangeCase(loc, caseaction, wstr);
      new (output_data + output_idx) std::string(converter.to_bytes(wstr));
    } else {
      assert(caseaction == StringNormalizer::NONE);
      // Simple copy or move if the iterator points to a non-const string
      new (output_data + output_idx) std::string(std::move(s));
    }
    ++output_idx;
    ++first;
  }
  return Status::OK();
}

// Constructs the std::locale named `locale_name`.
// If the std::locale constructor throws std::runtime_error, the failure is
// re-raised through ONNXRUNTIME_THROW with the locale name, the original
// message and an installation hint.
inline std::locale GetLocale(const std::string& locale_name) {
  try {
    return std::locale(locale_name);
  } catch (const std::runtime_error& ex) {
    ONNXRUNTIME_THROW("Failed to construct locale with name:",
                      locale_name, ":", ex.what(), ":Please, install necessary language-pack-XX and configure locales");
  }
}
} // namespace string_normalizer

using namespace string_normalizer;

// Reads and validates the operator attributes.
//
//   is_case_sensitive - required; non-zero means stop words are matched
//                       byte-for-byte.
//   casechangeaction  - required; one of "LOWER", "UPPER", "NONE".
//   locale            - optional; defaults to "en_US.UTF-8".
//   stopwords         - optional; entries must be non-empty and unique.
//
// For case-sensitive matching the stop words are stored as given in
// stopwords_. Otherwise each stop word is converted to a wide string, folded
// to compare_caseaction_ and stored in wstopwords_, so uniqueness is checked
// after case folding.
//
// Any violated requirement is reported through ONNXRUNTIME_ENFORCE; a locale
// that cannot be constructed is reported by GetLocale.
StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info),
                                                               is_case_sensitive_(true),
                                                               casechangeaction_(NONE),
                                                               compare_caseaction_(NONE) {
  int64_t iscasesensitive = 0;
  Status status = info.GetAttr("is_case_sensitive", &iscasesensitive);
  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute is_case_sensitive is not set");
  is_case_sensitive_ = iscasesensitive != 0;

  std::string casechangeaction;
  status = info.GetAttr("casechangeaction", &casechangeaction);
  // The message names the attribute exactly as it is spelled in the schema.
  ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute casechangeaction is not set");
  if (casechangeaction == "LOWER") {
    casechangeaction_ = LOWER;
  } else if (casechangeaction == "UPPER") {
    casechangeaction_ = UPPER;
  } else if (casechangeaction == "NONE") {
    casechangeaction_ = NONE;
  } else {
    ONNXRUNTIME_ENFORCE(false, "attribute casechangeaction has invalid value");
  }

  if (!is_case_sensitive_) {
    // Convert stop words to a case which can help us preserve the case of filtered strings
    compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER;
  }

  locale_name_ = info.GetAttrOrDefault("locale", std::string("en_US.UTF-8"));
  std::locale locale = GetLocale(locale_name_);
  std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);

  std::vector<std::string> swords = info.GetAttrsOrDefault<std::string>("stopwords");
  for (const auto& sw : swords) {
    ONNXRUNTIME_ENFORCE(!sw.empty(), "Empty stopwords not allowed");
    if (is_case_sensitive_) {
      auto p = stopwords_.insert(sw);
      ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed");
    } else {
      std::wstring wstr = converter.from_bytes(sw);
      // The converter returns wconv_error on failure. A stop word that is
      // literally conv_error converts to that same wide string legitimately
      // and must not be rejected.
      ONNXRUNTIME_ENFORCE(wstr != wconv_error || sw == conv_error,
                          "Stopword contains invalid utf8 chars");
      ChangeCase(locale, compare_caseaction_, wstr);
      auto p = wstopwords_.insert(wstr);
      ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed");
    }
  }
}

// Filters stop words out of input 0 and writes the remaining strings, with the
// configured case change applied, to output 0.
//
// Accepted input shapes are [C] and [1, C] with C > 0; anything else yields
// INVALID_ARGUMENT. The output keeps the leading 1 dimension when the input
// had one. If every string is filtered out, CopyCaseAction emits a single
// empty string.
//
// Case-sensitive mode compares the raw input strings against stopwords_.
// Case-insensitive mode converts each input to a wide string, folds it to
// compare_caseaction_ and looks it up in wstopwords_. Because
// compare_caseaction_ equals casechangeaction_ whenever the latter is not
// NONE, the folded string is reused as the output in that case.
//
// Returns INVALID_ARGUMENT if an input string cannot be converted from UTF-8
// where a conversion is needed.
Status StringNormalizer::Compute(OpKernelContext* ctx) const {
  using namespace string_normalizer;

  auto X = ctx->Input<Tensor>(0);
  auto& input_dims = X->Shape().GetDims();

  size_t N = 0;
  size_t C = 0;
  if (input_dims.size() == 1) {
    if (input_dims[0] < 1) {
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                    "Single dimension value must be greater than 0");
    }
    C = input_dims[0];
  } else if (input_dims.size() == 2) {
    if (input_dims[0] != 1 || input_dims[1] < 1) {
      return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                    "Input dimensions are either[C > 0] or [1][C > 0] allowed");
    }
    N = 1;
    C = input_dims[1];
  } else {
    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                  "Input dimensions are either[C > 0] or [1][C > 0] allowed");
  }

  Status status;
  std::locale locale = GetLocale(locale_name_);
  std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
  auto const input_data = X->template Data<std::string>();
  using StrRef = std::reference_wrapper<const std::string>;
  if (is_case_sensitive_) {
    if (!stopwords_.empty()) {
      std::vector<StrRef> filtered_strings;
      filtered_strings.reserve(C);
      auto first = input_data;
      auto const last = input_data + C;
      while (first != last) {
        const std::string& s = *first;
        if (0 == stopwords_.count(s)) {
          filtered_strings.push_back(std::cref(s));
        }
        ++first;
      }
      status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, locale, converter,
                              N, filtered_strings.size(), casechangeaction_);
    } else {
      // Nothing to filter. Copy input to output and change case if needed
      status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_);
    }
  } else {
    if (!wstopwords_.empty()) {
      // Filter input. When no case action is required
      // we simply store original string references.
      // Otherwise, we store converted strings.
      const bool keep_original = (casechangeaction_ == NONE);
      std::vector<StrRef> filtered_original_strings;
      std::vector<std::string> filtered_cased_strings;
      // Only one of the two vectors is ever filled; reserve just that one.
      if (keep_original) {
        filtered_original_strings.reserve(C);
      } else {
        filtered_cased_strings.reserve(C);
      }
      auto first = input_data;
      auto const last = input_data + C;
      while (first != last) {
        const std::string& s = *first;
        std::wstring wstr = converter.from_bytes(s);
        // The converter returns wconv_error on failure. An input that is
        // literally conv_error converts to that same wide string legitimately
        // and must not be reported as invalid.
        if (wstr == wconv_error && s != conv_error) {
          return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
                        "Input contains invalid utf8 chars at: " + s);
        }
        ChangeCase(locale, compare_caseaction_, wstr);
        if (0 == wstopwords_.count(wstr)) {
          if (keep_original) {
            filtered_original_strings.push_back(std::cref(s));
          } else {
            filtered_cased_strings.push_back(converter.to_bytes(wstr));
          }
        }
        ++first;
      }
      if (keep_original) {
        status = CopyCaseAction(filtered_original_strings.cbegin(), filtered_original_strings.cend(), ctx, locale, converter,
                                N, filtered_original_strings.size(), NONE);
      } else {
        // Strings are already in their final case, so they are moved as-is.
        status = CopyCaseAction(filtered_cased_strings.begin(), filtered_cased_strings.end(), ctx, locale, converter,
                                N, filtered_cased_strings.size(), NONE);
      }
    } else {
      // Nothing to filter. Copy input to output and change case if needed
      status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_);
    }
  }
  return status;
}
} // namespace contrib
} // namespace onnxruntime
39 changes: 39 additions & 0 deletions onnxruntime/contrib_ops/cpu/string_normalizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/op_kernel.h"

#include <locale>
#include <string>
#include <unordered_set>

namespace onnxruntime {
namespace contrib {

// CPU kernel for the contrib StringNormalizer operator: removes stop words
// from a string tensor and optionally changes the case of the remaining
// strings. Attributes are parsed once in the constructor; Compute is const.
class StringNormalizer : public OpKernel {
public:
// Case transformation applied to strings.
enum CaseAction {
NONE = 0,
LOWER = 1,
UPPER = 2,
};

// Reads the operator attributes from `info` and builds the stop-word set.
explicit StringNormalizer(const OpKernelInfo& info);
~StringNormalizer() = default;

// Filters and case-converts input 0 of `ctx` into output 0.
Status Compute(OpKernelContext* ctx) const override;

private:
// True when stop words are matched against the raw input strings.
bool is_case_sensitive_;
// Case change applied to the strings written to the output.
CaseAction casechangeaction_;
CaseAction compare_caseaction_; // used for case-insensitive compare
// Name used to construct the std::locale for case conversion.
std::string locale_name_;
// Only one of these is populated, never both: stopwords_ for case-sensitive
// matching, wstopwords_ (case-folded wide strings) for case-insensitive
// matching.
std::unordered_set<std::string> stopwords_;
std::unordered_set<std::wstring> wstopwords_;
};

} // namespace contrib
} // namespace onnxruntime
Loading