
Commit e33bc16

[Lint] Add check to prevent usage of #include <regex> (apache#16412)
Currently, the pytorch wheels available through `pip install` use the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to use the pre-C++11 ABI, this would cause breakages with dynamically-linked LLVM environments.

This commit adds a lint check that searches for use of `#include <regex>` in any C++ files. Use of this header should be avoided, as its implementation is not supported by gcc's dual ABI. This ABI incompatibility results in runtime errors either when `std::regex` is called from TVM, or when `std::regex` is called from pytorch, depending on which library was loaded first. This restriction can be removed when a version of pytorch compiled using `-DUSE_CXX11_ABI=1` is available from PyPI.

[0] pytorch/pytorch#51039
1 parent fda0bee · commit e33bc16
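The lint script itself is not among the diffs shown below. As a rough illustration of the check the commit message describes, a standalone scan along the following lines would flag offending includes. This is a hypothetical sketch, not the actual `tests/lint` implementation; the scanned directories and file extensions are assumptions:

```python
#!/usr/bin/env python3
"""Hypothetical sketch of a lint check rejecting '#include <regex>' in C++ sources."""
import pathlib
import re
import sys

CPP_SUFFIXES = {".h", ".hpp", ".cc", ".cpp"}  # assumed set of C++ extensions
FORBIDDEN = re.compile(r"^\s*#\s*include\s*<regex>")


def main() -> int:
    errors = []
    for root in ("include", "src"):  # assumed directories to scan
        base = pathlib.Path(root)
        if not base.is_dir():
            continue
        for path in base.rglob("*"):
            if path.suffix not in CPP_SUFFIXES:
                continue
            for num, line in enumerate(path.read_text(errors="ignore").splitlines(), 1):
                if FORBIDDEN.match(line):
                    errors.append(
                        f"{path}:{num}: found '#include <regex>'; "
                        "use tvm::runtime::regex_match instead"
                    )
    if errors:
        print("\n".join(errors))
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
```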

File tree

12 files changed (+230 -94 lines)


python/tvm/runtime/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -41,3 +41,5 @@

 from . import executor
 from . import disco
+
+from .support import _regex_match
```

python/tvm/runtime/support.py

Lines changed: 69 additions & 0 deletions
```diff
@@ -0,0 +1,69 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Runtime support infra of TVM."""
+
+import re
+
+import tvm._ffi
+
+
+@tvm._ffi.register_func("tvm.runtime.regex_match")
+def _regex_match(regex_pattern: str, match_against: str) -> bool:
+    """Check if a string matches a regular expression
+
+    This function should be used instead of `std::regex` within C++
+    call sites, to avoid ABI incompatibilities with pytorch.
+
+    Currently, the pytorch wheels available through pip install use
+    the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were
+    to use the pre-C++11 ABI, this would cause breakages with
+    dynamically-linked LLVM environments.
+
+    Use of the `<regex>` header in TVM should be avoided, as its
+    implementation is not supported by gcc's dual ABI. This ABI
+    incompatibility results in runtime errors either when `std::regex`
+    is called from TVM, or when `std::regex` is called from pytorch,
+    depending on which library was loaded first. This restriction can
+    be removed when a version of pytorch compiled using
+    `-DUSE_CXX11_ABI=1` is available from PyPI.
+
+    This is exposed as part of `libtvm_runtime.so` as it is used by
+    the DNNL runtime.
+
+    [0] https://github.com/pytorch/pytorch/issues/51039
+
+    Parameters
+    ----------
+    regex_pattern: str
+
+        The regular expression
+
+    match_against: str
+
+        The string against which to match the regular expression
+
+    Returns
+    -------
+    match_result: bool
+
+        True if `match_against` matches the pattern defined by
+        `regex_pattern`, and False otherwise.
+    """
+    match = re.match(regex_pattern, match_against)
+    return match is not None
```
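For reference, once `tvm` has been imported the registered function is also reachable through the global registry. A small usage sketch follows; the operator names are illustrative, and note that `re.match` anchors only at the start of the string, not the end:

```python
import tvm  # importing tvm runs the @register_func decorator above

regex_match = tvm.get_global_func("tvm.runtime.regex_match")

# The pattern comes first, then the string to match against.
assert regex_match(r".*_relu.*", "dnnl_conv2d_relu")
assert not regex_match(r"NC(D?)(H?)W", "NHWC")
assert regex_match(r"(\d*)(,(\d*))*", "1,3,224,224")
```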

python/tvm/support.py

Lines changed: 0 additions & 44 deletions
```diff
@@ -19,7 +19,6 @@
 import textwrap
 import ctypes
 import os
-import re
 import sys

 import tvm
@@ -88,46 +87,3 @@ def add_function(self, name, func):

     def __setitem__(self, key, value):
         self.add_function(key, value)
-
-
-@tvm._ffi.register_func("tvm.support.regex_match")
-def _regex_match(regex_pattern: str, match_against: str) -> bool:
-    """Check if a string matches a regular expression
-
-    This function should be used instead of `std::regex` within C++
-    call sites, to avoid ABI incompatibilities with pytorch.
-
-    Currently, the pytorch wheels available through pip install use
-    the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were
-    to use the pre-C++11 ABI, this would cause breakages with
-    dynamically-linked LLVM environments.
-
-    Use of the `<regex>` header in TVM should be avoided, as its
-    implementation is not supported by gcc's dual ABI. This ABI
-    incompatibility results in runtime errors either when `std::regex`
-    is called from TVM, or when `std::regex` is called from pytorch,
-    depending on which library was loaded first. This restriction can
-    be removed when a version of pytorch compiled using
-    `-DUSE_CXX11_ABI=1` is available from PyPI.
-
-    [0] https://github.com/pytorch/pytorch/issues/51039
-
-    Parameters
-    ----------
-    regex_pattern: str
-
-        The regular expression
-
-    match_against: str
-
-        The string against which to match the regular expression
-
-    Returns
-    -------
-    match_result: bool
-
-        True if `match_against` matches the pattern defined by
-        `regex_pattern`, and False otherwise.
-    """
-    match = re.match(regex_pattern, match_against)
-    return match is not None
```

src/ir/transform.cc

Lines changed: 2 additions & 7 deletions
```diff
@@ -35,6 +35,7 @@
 #include <unordered_set>

 #include "../runtime/object_internal.h"
+#include "../runtime/regex.h"

 namespace tvm {
 namespace transform {
@@ -538,17 +539,11 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex,
                              .str();

   auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> IRModule {
-    const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
-    CHECK(regex_match_func)
-        << "RuntimeError: "
-        << "The PackedFunc 'tvm.support.regex_match' has not been registered. "
-        << "This can occur if the TVM Python library has not yet been imported.";
-
     IRModule subset;

     for (const auto& [gvar, func] : mod->functions) {
       std::string name = gvar->name_hint;
-      if ((*regex_match_func)(func_name_regex, name)) {
+      if (tvm::runtime::regex_match(name, func_name_regex)) {
         subset->Add(gvar, func);
       }
     }
```
src/relax/transform/update_param_struct_info.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,10 +27,10 @@
 #include <tvm/relax/transform.h>

 #include <optional>
-#include <regex>
 #include <unordered_map>
 #include <vector>

+#include "../../runtime/regex.h"
 #include "utils.h"

 namespace tvm {
```

src/relay/backend/contrib/dnnl/codegen.cc

Lines changed: 0 additions & 1 deletion
```diff
@@ -31,7 +31,6 @@

 #include <fstream>
 #include <numeric>
-#include <regex>
 #include <sstream>

 #include "../../utils.h"
```

src/relay/backend/contrib/dnnl/query_layout.cc

Lines changed: 9 additions & 9 deletions
```diff
@@ -31,10 +31,10 @@

 #include <fstream>
 #include <numeric>
-#include <regex>
 #include <sstream>

 #include "../../../../runtime/contrib/dnnl/dnnl_utils.h"
+#include "../../../../runtime/regex.h"
 #include "../../utils.h"
 #include "dnnl.hpp"
 namespace tvm {
@@ -173,12 +173,12 @@ dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false,
 }

 void check_shapes(const std::vector<std::string> shapes) {
-  std::regex valid_pat("(\\d*)(,(\\d*))*");
-  bool checked = std::regex_match(shapes[0], valid_pat);
+  std::string valid_pat("(\\d*)(,(\\d*))*");
+  bool checked = tvm::runtime::regex_match(shapes[0], valid_pat);
   for (size_t i = 1; i < shapes.size() - 1; i++) {
-    checked &= std::regex_match(shapes[i], valid_pat);
+    checked &= tvm::runtime::regex_match(shapes[i], valid_pat);
   }
-  checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*"));
+  checked &= tvm::runtime::regex_match(shapes[shapes.size() - 1], "\\d*");
   if (!checked) {
     LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
   }
@@ -194,8 +194,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
                                         std::string weight_shape, std::string out_shape,
                                         std::string paddings, std::string strides,
                                         std::string dilates, std::string G, std::string dtype) {
-  check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
-  check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
+  check_layout(tvm::runtime::regex_match(data_layout, "NC(D?)(H?)W"), true);
+  check_layout(tvm::runtime::regex_match(kernel_layout, "(G?)OI(D?)(H?)W"), true);
   check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});

   dnnl::engine eng(dnnl::engine::kind::cpu, 0);
@@ -278,8 +278,8 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
                                                   std::string paddings, std::string output_paddings,
                                                   std::string strides, std::string dilates,
                                                   std::string G, std::string dtype) {
-  check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
-  check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
+  check_layout(tvm::runtime::regex_match(data_layout, "NC(D?)(H?)W"), true);
+  check_layout(tvm::runtime::regex_match(kernel_layout, "(G?)((IO)|(OI))(D?)(H?)W"), true);
   check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});

   dnnl::engine eng(dnnl::engine::kind::cpu, 0);
```

src/relay/backend/contrib/mrvl/codegen.cc

Lines changed: 0 additions & 1 deletion
```diff
@@ -31,7 +31,6 @@
 #include <iostream>
 #include <limits>
 #include <memory>
-#include <regex>
 #include <string>
 #include <unordered_map>
 #include <utility>
```

src/runtime/contrib/dnnl/dnnl_json_runtime.cc

Lines changed: 32 additions & 31 deletions
```diff
@@ -26,10 +26,10 @@
 #include <tvm/runtime/registry.h>

 #include <cstddef>
-#include <regex>
 #include <string>
 #include <vector>

+#include "../../../runtime/regex.h"
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"

@@ -194,53 +194,54 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;

     // Define RegExp.
-    std::regex bias_add_pat(".*_bias.*");
-    std::regex relu_pat(".*_relu.*");
-    std::regex tanh_pat(".*_tanh.*");
-    std::regex sigmoid_pat(".*_sigmoid.*");
-    std::regex clip_pat(".*_clip.*");
-    std::regex gelu_pat(".*_gelu.*");
-    std::regex swish_pat(".*_swish.*");
-    std::regex sum_pat(".*_sum.*");
-    std::regex mish_pat(".*_mish.*");
+    std::string bias_add_pat(".*_bias.*");
+    std::string relu_pat(".*_relu.*");
+    std::string tanh_pat(".*_tanh.*");
+    std::string sigmoid_pat(".*_sigmoid.*");
+    std::string clip_pat(".*_clip.*");
+    std::string gelu_pat(".*_gelu.*");
+    std::string swish_pat(".*_swish.*");
+    std::string sum_pat(".*_sum.*");
+    std::string mish_pat(".*_mish.*");

     // parsing of name to extract attributes
     auto op_name = nodes_[nid].GetOpName();

     // Parsing post-ops.
     dnnl::post_ops ops;
-    if (std::regex_match(op_name, sum_pat)) {
+    if (tvm::runtime::regex_match(op_name, sum_pat)) {
       ops.append_sum(1.f);
     }
-    if (std::regex_match(op_name, relu_pat)) {
+    if (tvm::runtime::regex_match(op_name, relu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, tanh_pat)) {
+    if (tvm::runtime::regex_match(op_name, tanh_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, clip_pat)) {
+    if (tvm::runtime::regex_match(op_name, clip_pat)) {
       float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
       float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
     }
-    if (std::regex_match(op_name, sigmoid_pat)) {
+    if (tvm::runtime::regex_match(op_name, sigmoid_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, swish_pat)) {
+    if (tvm::runtime::regex_match(op_name, swish_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
     }
-    if (std::regex_match(op_name, gelu_pat)) {
+    if (tvm::runtime::regex_match(op_name, gelu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, mish_pat)) {
+    if (tvm::runtime::regex_match(op_name, mish_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
     }
     if (ops.len() != 0) {
       attr.set_post_ops(ops);
     }

     // Parsing bias_add.
-    *bias_tr = std::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};
+    *bias_tr =
+        tvm::runtime::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};

     return attr;
   }
@@ -253,31 +254,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
     tensor_registry_ = TensorRegistry(engine_, io_eid_set);

-    std::regex conv_pat(".*conv[1-3]d.*");
-    std::regex deconv_pat(".*deconv[1-3]d.*");
-    std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*");
-    std::regex dense_pat(".*dense.*");
-    std::regex max_pool_pat(".*max_pool[1-3]d");
-    std::regex avg_pool_pat(".*avg_pool[1-3]d");
+    std::string conv_pat(".*conv[1-3]d.*");
+    std::string deconv_pat(".*deconv[1-3]d.*");
+    std::string conv_transpose_pat(".*conv[1-3]d_transpose.*");
+    std::string dense_pat(".*dense.*");
+    std::string max_pool_pat(".*max_pool[1-3]d");
+    std::string avg_pool_pat(".*avg_pool[1-3]d");

     // Build subgraph engine.
     for (size_t nid = 0; nid < nodes_.size(); ++nid) {
       const auto& node = nodes_[nid];
       if (node.GetOpType() == "kernel") {
         ICHECK_EQ(node.GetOpType(), "kernel");
         auto op_name = node.GetOpName();
-        if (std::regex_match(op_name, deconv_pat) ||
-            std::regex_match(op_name, conv_transpose_pat)) {
+        if (tvm::runtime::regex_match(op_name, deconv_pat) ||
+            tvm::runtime::regex_match(op_name, conv_transpose_pat)) {
           Deconvolution(nid);
-        } else if (std::regex_match(op_name, conv_pat)) {
+        } else if (tvm::runtime::regex_match(op_name, conv_pat)) {
           Convolution(nid);
-        } else if (std::regex_match(op_name, dense_pat)) {
+        } else if (tvm::runtime::regex_match(op_name, dense_pat)) {
           Dense(nid);
         } else if ("nn.batch_norm" == op_name) {
           BatchNorm(nid);
-        } else if (std::regex_match(op_name, max_pool_pat)) {
+        } else if (tvm::runtime::regex_match(op_name, max_pool_pat)) {
           Pooling(nid, dnnl::algorithm::pooling_max);
-        } else if (std::regex_match(op_name, avg_pool_pat)) {
+        } else if (tvm::runtime::regex_match(op_name, avg_pool_pat)) {
           Pooling(nid, dnnl::algorithm::pooling_avg);
         } else if (elt_name2algo.count(op_name)) {
           Eltwise(nid);
```