Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,5 @@

from . import executor
from . import disco

from .support import _regex_match
69 changes: 69 additions & 0 deletions python/tvm/runtime/support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Runtime support infra of TVM."""

import re

import tvm._ffi


@tvm._ffi.register_func("tvm.runtime.regex_match")
def _regex_match(regex_pattern: str, match_against: str) -> bool:
"""Check if a pattern matches a regular expression

This function should be used instead of `std::regex` within C++
call sites, to avoid ABI incompatibilities with pytorch.

Currently, the pytorch wheels available through pip install use
the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to
user the pre-C++11 ABI, this would cause breakages with
dynamically-linked LLVM environments.

Use of the `<regex>` header in TVM should be avoided, as its
implementation is not supported by gcc's dual ABI. This ABI
incompatibility results in runtime errors either when `std::regex`
is called from TVM, or when `std::regex` is called from pytorch,
depending on which library was loaded first. This restriction can
be removed when a version of pytorch compiled using
`-DUSE_CXX11_ABI=1` is available from PyPI.

This is exposed as part of `libtvm_runtime.so` as it is used by
the DNNL runtime.

[0] https://github.com/pytorch/pytorch/issues/51039

Parameters
----------
regex_pattern: str

The regular expression

match_against: str

The string against which to match the regular expression

Returns
-------
match_result: bool

True if `match_against` matches the pattern defined by
`regex_pattern`, and False otherwise.

"""
match = re.match(regex_pattern, match_against)
return match is not None
44 changes: 0 additions & 44 deletions python/tvm/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import textwrap
import ctypes
import os
import re
import sys

import tvm
Expand Down Expand Up @@ -88,46 +87,3 @@ def add_function(self, name, func):

def __setitem__(self, key, value):
self.add_function(key, value)


@tvm._ffi.register_func("tvm.support.regex_match")
def _regex_match(regex_pattern: str, match_against: str) -> bool:
"""Check if a pattern matches a regular expression

This function should be used instead of `std::regex` within C++
call sites, to avoid ABI incompatibilities with pytorch.

Currently, the pytorch wheels available through pip install use
the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to
user the pre-C++11 ABI, this would cause breakages with
dynamically-linked LLVM environments.

Use of the `<regex>` header in TVM should be avoided, as its
implementation is not supported by gcc's dual ABI. This ABI
incompatibility results in runtime errors either when `std::regex`
is called from TVM, or when `std::regex` is called from pytorch,
depending on which library was loaded first. This restriction can
be removed when a version of pytorch compiled using
`-DUSE_CXX11_ABI=1` is available from PyPI.

[0] https://github.com/pytorch/pytorch/issues/51039

Parameters
----------
regex_pattern: str

The regular expression

match_against: str

The string against which to match the regular expression

Returns
-------
match_result: bool

True if `match_against` matches the pattern defined by
`regex_pattern`, and False otherwise.
"""
match = re.match(regex_pattern, match_against)
return match is not None
9 changes: 2 additions & 7 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <unordered_set>

#include "../runtime/object_internal.h"
#include "../runtime/regex.h"

namespace tvm {
namespace transform {
Expand Down Expand Up @@ -538,17 +539,11 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex,
.str();

auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> IRModule {
const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
CHECK(regex_match_func)
<< "RuntimeError: "
<< "The PackedFunc 'tvm.support.regex_match' has not been registered. "
<< "This can occur if the TVM Python library has not yet been imported.";

IRModule subset;

for (const auto& [gvar, func] : mod->functions) {
std::string name = gvar->name_hint;
if ((*regex_match_func)(func_name_regex, name)) {
if (tvm::runtime::regex_match(name, func_name_regex)) {
subset->Add(gvar, func);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/update_param_struct_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
#include <tvm/relax/transform.h>

#include <optional>
#include <regex>
#include <unordered_map>
#include <vector>

#include "../../runtime/regex.h"
#include "utils.h"

namespace tvm {
Expand Down
1 change: 0 additions & 1 deletion src/relay/backend/contrib/dnnl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

#include <fstream>
#include <numeric>
#include <regex>
#include <sstream>

#include "../../utils.h"
Expand Down
18 changes: 9 additions & 9 deletions src/relay/backend/contrib/dnnl/query_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@

#include <fstream>
#include <numeric>
#include <regex>
#include <sstream>

#include "../../../../runtime/contrib/dnnl/dnnl_utils.h"
#include "../../../../runtime/regex.h"
#include "../../utils.h"
#include "dnnl.hpp"
namespace tvm {
Expand Down Expand Up @@ -173,12 +173,12 @@ dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false,
}

void check_shapes(const std::vector<std::string> shapes) {
std::regex valid_pat("(\\d*)(,(\\d*))*");
bool checked = std::regex_match(shapes[0], valid_pat);
std::string valid_pat("(\\d*)(,(\\d*))*");
bool checked = tvm::runtime::regex_match(shapes[0], valid_pat);
for (size_t i = 1; i < shapes.size() - 1; i++) {
checked &= std::regex_match(shapes[i], valid_pat);
checked &= tvm::runtime::regex_match(shapes[i], valid_pat);
}
checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*"));
checked &= tvm::runtime::regex_match(shapes[shapes.size() - 1], "\\d*");
if (!checked) {
LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
}
Expand All @@ -194,8 +194,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
std::string weight_shape, std::string out_shape,
std::string paddings, std::string strides,
std::string dilates, std::string G, std::string dtype) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
check_layout(tvm::runtime::regex_match(data_layout, "NC(D?)(H?)W"), true);
check_layout(tvm::runtime::regex_match(kernel_layout, "(G?)OI(D?)(H?)W"), true);
check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});

dnnl::engine eng(dnnl::engine::kind::cpu, 0);
Expand Down Expand Up @@ -278,8 +278,8 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
std::string paddings, std::string output_paddings,
std::string strides, std::string dilates,
std::string G, std::string dtype) {
check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
check_layout(tvm::runtime::regex_match(data_layout, "NC(D?)(H?)W"), true);
check_layout(tvm::runtime::regex_match(kernel_layout, "(G?)((IO)|(OI))(D?)(H?)W"), true);
check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});

dnnl::engine eng(dnnl::engine::kind::cpu, 0);
Expand Down
1 change: 0 additions & 1 deletion src/relay/backend/contrib/mrvl/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include <iostream>
#include <limits>
#include <memory>
#include <regex>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down
63 changes: 32 additions & 31 deletions src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
#include <tvm/runtime/registry.h>

#include <cstddef>
#include <regex>
#include <string>
#include <vector>

#include "../../../runtime/regex.h"
#include "../json/json_node.h"
#include "../json/json_runtime.h"

Expand Down Expand Up @@ -194,53 +194,54 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;

// Define RegExp.
std::regex bias_add_pat(".*_bias.*");
std::regex relu_pat(".*_relu.*");
std::regex tanh_pat(".*_tanh.*");
std::regex sigmoid_pat(".*_sigmoid.*");
std::regex clip_pat(".*_clip.*");
std::regex gelu_pat(".*_gelu.*");
std::regex swish_pat(".*_swish.*");
std::regex sum_pat(".*_sum.*");
std::regex mish_pat(".*_mish.*");
std::string bias_add_pat(".*_bias.*");
std::string relu_pat(".*_relu.*");
std::string tanh_pat(".*_tanh.*");
std::string sigmoid_pat(".*_sigmoid.*");
std::string clip_pat(".*_clip.*");
std::string gelu_pat(".*_gelu.*");
std::string swish_pat(".*_swish.*");
std::string sum_pat(".*_sum.*");
std::string mish_pat(".*_mish.*");

// parsing of name to extract attributes
auto op_name = nodes_[nid].GetOpName();

// Parsing post-ops.
dnnl::post_ops ops;
if (std::regex_match(op_name, sum_pat)) {
if (tvm::runtime::regex_match(op_name, sum_pat)) {
ops.append_sum(1.f);
}
if (std::regex_match(op_name, relu_pat)) {
if (tvm::runtime::regex_match(op_name, relu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
}
if (std::regex_match(op_name, tanh_pat)) {
if (tvm::runtime::regex_match(op_name, tanh_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
}
if (std::regex_match(op_name, clip_pat)) {
if (tvm::runtime::regex_match(op_name, clip_pat)) {
float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
}
if (std::regex_match(op_name, sigmoid_pat)) {
if (tvm::runtime::regex_match(op_name, sigmoid_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
}
if (std::regex_match(op_name, swish_pat)) {
if (tvm::runtime::regex_match(op_name, swish_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
}
if (std::regex_match(op_name, gelu_pat)) {
if (tvm::runtime::regex_match(op_name, gelu_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
}
if (std::regex_match(op_name, mish_pat)) {
if (tvm::runtime::regex_match(op_name, mish_pat)) {
ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
}
if (ops.len() != 0) {
attr.set_post_ops(ops);
}

// Parsing bias_add.
*bias_tr = std::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};
*bias_tr =
tvm::runtime::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};

return attr;
}
Expand All @@ -253,31 +254,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
tensor_registry_ = TensorRegistry(engine_, io_eid_set);

std::regex conv_pat(".*conv[1-3]d.*");
std::regex deconv_pat(".*deconv[1-3]d.*");
std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*");
std::regex dense_pat(".*dense.*");
std::regex max_pool_pat(".*max_pool[1-3]d");
std::regex avg_pool_pat(".*avg_pool[1-3]d");
std::string conv_pat(".*conv[1-3]d.*");
std::string deconv_pat(".*deconv[1-3]d.*");
std::string conv_transpose_pat(".*conv[1-3]d_transpose.*");
std::string dense_pat(".*dense.*");
std::string max_pool_pat(".*max_pool[1-3]d");
std::string avg_pool_pat(".*avg_pool[1-3]d");

// Build subgraph engine.
for (size_t nid = 0; nid < nodes_.size(); ++nid) {
const auto& node = nodes_[nid];
if (node.GetOpType() == "kernel") {
ICHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
if (std::regex_match(op_name, deconv_pat) ||
std::regex_match(op_name, conv_transpose_pat)) {
if (tvm::runtime::regex_match(op_name, deconv_pat) ||
tvm::runtime::regex_match(op_name, conv_transpose_pat)) {
Deconvolution(nid);
} else if (std::regex_match(op_name, conv_pat)) {
} else if (tvm::runtime::regex_match(op_name, conv_pat)) {
Convolution(nid);
} else if (std::regex_match(op_name, dense_pat)) {
} else if (tvm::runtime::regex_match(op_name, dense_pat)) {
Dense(nid);
} else if ("nn.batch_norm" == op_name) {
BatchNorm(nid);
} else if (std::regex_match(op_name, max_pool_pat)) {
} else if (tvm::runtime::regex_match(op_name, max_pool_pat)) {
Pooling(nid, dnnl::algorithm::pooling_max);
} else if (std::regex_match(op_name, avg_pool_pat)) {
} else if (tvm::runtime::regex_match(op_name, avg_pool_pat)) {
Pooling(nid, dnnl::algorithm::pooling_avg);
} else if (elt_name2algo.count(op_name)) {
Eltwise(nid);
Expand Down
Loading