Commit 67be0c1

Remove current usage of #include <regex>

1 parent 75a546d · commit 67be0c1

7 files changed: +149 -49 lines changed

File tree:
  src/ir/transform.cc
  src/relax/transform/update_param_struct_info.cc
  src/relay/backend/contrib/dnnl/codegen.cc
  src/relay/backend/contrib/dnnl/query_layout.cc
  src/runtime/contrib/dnnl/dnnl_json_runtime.cc
  src/support/regex.cc
  src/support/regex.h

src/ir/transform.cc
Lines changed: 2 additions & 7 deletions

@@ -35,6 +35,7 @@
 #include <unordered_set>

 #include "../runtime/object_internal.h"
+#include "../support/regex.h"

 namespace tvm {
 namespace transform {
@@ -538,17 +539,11 @@ Pass ApplyPassToFunction(Pass pass, String func_name_regex,
           .str();

   auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> IRModule {
-    const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
-    CHECK(regex_match_func)
-        << "RuntimeError: "
-        << "The PackedFunc 'tvm.support.regex_match' has not been registered. "
-        << "This can occur if the TVM Python library has not yet been imported.";
-
     IRModule subset;

     for (const auto& [gvar, func] : mod->functions) {
       std::string name = gvar->name_hint;
-      if ((*regex_match_func)(func_name_regex, name)) {
+      if (tvm::support::regex_match(name, func_name_regex)) {
         subset->Add(gvar, func);
       }
     }

src/relax/transform/update_param_struct_info.cc
Lines changed: 1 addition & 1 deletion

@@ -27,10 +27,10 @@
 #include <tvm/relax/transform.h>

 #include <optional>
-#include <regex>
 #include <unordered_map>
 #include <vector>

+#include "../../support/regex.h"
 #include "utils.h"

 namespace tvm {

src/relay/backend/contrib/dnnl/codegen.cc
Lines changed: 0 additions & 1 deletion

@@ -31,7 +31,6 @@

 #include <fstream>
 #include <numeric>
-#include <regex>
 #include <sstream>

 #include "../../utils.h"

src/relay/backend/contrib/dnnl/query_layout.cc
Lines changed: 9 additions & 9 deletions

@@ -31,10 +31,10 @@

 #include <fstream>
 #include <numeric>
-#include <regex>
 #include <sstream>

 #include "../../../../runtime/contrib/dnnl/dnnl_utils.h"
+#include "../../../../support/regex.h"
 #include "../../utils.h"
 #include "dnnl.hpp"
 namespace tvm {
@@ -173,12 +173,12 @@ dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false,
 }

 void check_shapes(const std::vector<std::string> shapes) {
-  std::regex valid_pat("(\\d*)(,(\\d*))*");
-  bool checked = std::regex_match(shapes[0], valid_pat);
+  std::string valid_pat("(\\d*)(,(\\d*))*");
+  bool checked = tvm::support::regex_match(shapes[0], valid_pat);
   for (size_t i = 1; i < shapes.size() - 1; i++) {
-    checked &= std::regex_match(shapes[i], valid_pat);
+    checked &= tvm::support::regex_match(shapes[i], valid_pat);
   }
-  checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*"));
+  checked &= tvm::support::regex_match(shapes[shapes.size() - 1], "\\d*");
   if (!checked) {
     LOG(FATAL) << "Invalid input args for query dnnl optimal layout.";
   }
@@ -194,8 +194,8 @@ std::string get_optimal_layout_for_conv(std::string data_layout, std::string ker
                                         std::string weight_shape, std::string out_shape,
                                         std::string paddings, std::string strides,
                                         std::string dilates, std::string G, std::string dtype) {
-  check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
-  check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true);
+  check_layout(tvm::support::regex_match(data_layout, "NC(D?)(H?)W"), true);
+  check_layout(tvm::support::regex_match(kernel_layout, "(G?)OI(D?)(H?)W"), true);
   check_shapes({weight_shape, out_shape, paddings, strides, dilates, G});

   dnnl::engine eng(dnnl::engine::kind::cpu, 0);
@@ -278,8 +278,8 @@ std::string get_optimal_layout_for_conv_transpose(std::string data_layout,
                                                   std::string paddings, std::string output_paddings,
                                                   std::string strides, std::string dilates,
                                                   std::string G, std::string dtype) {
-  check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true);
-  check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true);
+  check_layout(tvm::support::regex_match(data_layout, "NC(D?)(H?)W"), true);
+  check_layout(tvm::support::regex_match(kernel_layout, "(G?)((IO)|(OI))(D?)(H?)W"), true);
   check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G});

   dnnl::engine eng(dnnl::engine::kind::cpu, 0);

src/runtime/contrib/dnnl/dnnl_json_runtime.cc
Lines changed: 32 additions & 31 deletions

@@ -26,10 +26,10 @@
 #include <tvm/runtime/registry.h>

 #include <cstddef>
-#include <regex>
 #include <string>
 #include <vector>

+#include "../../../support/regex.h"
 #include "../json/json_node.h"
 #include "../json/json_runtime.h"

@@ -194,53 +194,54 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     if (o_scl_tr || activation[0] != "none" || sum_scl_tr || dst_zp_tr) return attr;

     // Define RegExp.
-    std::regex bias_add_pat(".*_bias.*");
-    std::regex relu_pat(".*_relu.*");
-    std::regex tanh_pat(".*_tanh.*");
-    std::regex sigmoid_pat(".*_sigmoid.*");
-    std::regex clip_pat(".*_clip.*");
-    std::regex gelu_pat(".*_gelu.*");
-    std::regex swish_pat(".*_swish.*");
-    std::regex sum_pat(".*_sum.*");
-    std::regex mish_pat(".*_mish.*");
+    std::string bias_add_pat(".*_bias.*");
+    std::string relu_pat(".*_relu.*");
+    std::string tanh_pat(".*_tanh.*");
+    std::string sigmoid_pat(".*_sigmoid.*");
+    std::string clip_pat(".*_clip.*");
+    std::string gelu_pat(".*_gelu.*");
+    std::string swish_pat(".*_swish.*");
+    std::string sum_pat(".*_sum.*");
+    std::string mish_pat(".*_mish.*");

     // parsing of name to extract attributes
     auto op_name = nodes_[nid].GetOpName();

     // Parsing post-ops.
     dnnl::post_ops ops;
-    if (std::regex_match(op_name, sum_pat)) {
+    if (tvm::support::regex_match(op_name, sum_pat)) {
       ops.append_sum(1.f);
     }
-    if (std::regex_match(op_name, relu_pat)) {
+    if (tvm::support::regex_match(op_name, relu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, tanh_pat)) {
+    if (tvm::support::regex_match(op_name, tanh_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_tanh, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, clip_pat)) {
+    if (tvm::support::regex_match(op_name, clip_pat)) {
       float a_min = GetNodeAttr<float>(nodes_[nid], "a_min");
       float a_max = GetNodeAttr<float>(nodes_[nid], "a_max");
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_clip, a_min, a_max);
     }
-    if (std::regex_match(op_name, sigmoid_pat)) {
+    if (tvm::support::regex_match(op_name, sigmoid_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_logistic, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, swish_pat)) {
+    if (tvm::support::regex_match(op_name, swish_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_swish, 1.f, 1.f);
     }
-    if (std::regex_match(op_name, gelu_pat)) {
+    if (tvm::support::regex_match(op_name, gelu_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f);
     }
-    if (std::regex_match(op_name, mish_pat)) {
+    if (tvm::support::regex_match(op_name, mish_pat)) {
       ops.append_eltwise(1.f, dnnl::algorithm::eltwise_mish, 1.f, 0.f);
     }
     if (ops.len() != 0) {
       attr.set_post_ops(ops);
     }

     // Parsing bias_add.
-    *bias_tr = std::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};
+    *bias_tr =
+        tvm::support::regex_match(op_name, bias_add_pat) ? GetInput(nid, 2) : TensorRequisite{};

     return attr;
   }
@@ -253,31 +254,31 @@ class DNNLJSONRuntime : public JSONRuntimeBase {
     std::set<uint32_t> io_eid_set(run_arg_eid_.begin(), run_arg_eid_.end());
     tensor_registry_ = TensorRegistry(engine_, io_eid_set);

-    std::regex conv_pat(".*conv[1-3]d.*");
-    std::regex deconv_pat(".*deconv[1-3]d.*");
-    std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*");
-    std::regex dense_pat(".*dense.*");
-    std::regex max_pool_pat(".*max_pool[1-3]d");
-    std::regex avg_pool_pat(".*avg_pool[1-3]d");
+    std::string conv_pat(".*conv[1-3]d.*");
+    std::string deconv_pat(".*deconv[1-3]d.*");
+    std::string conv_transpose_pat(".*conv[1-3]d_transpose.*");
+    std::string dense_pat(".*dense.*");
+    std::string max_pool_pat(".*max_pool[1-3]d");
+    std::string avg_pool_pat(".*avg_pool[1-3]d");

     // Build subgraph engine.
     for (size_t nid = 0; nid < nodes_.size(); ++nid) {
       const auto& node = nodes_[nid];
       if (node.GetOpType() == "kernel") {
         ICHECK_EQ(node.GetOpType(), "kernel");
         auto op_name = node.GetOpName();
-        if (std::regex_match(op_name, deconv_pat) ||
-            std::regex_match(op_name, conv_transpose_pat)) {
+        if (tvm::support::regex_match(op_name, deconv_pat) ||
+            tvm::support::regex_match(op_name, conv_transpose_pat)) {
           Deconvolution(nid);
-        } else if (std::regex_match(op_name, conv_pat)) {
+        } else if (tvm::support::regex_match(op_name, conv_pat)) {
           Convolution(nid);
-        } else if (std::regex_match(op_name, dense_pat)) {
+        } else if (tvm::support::regex_match(op_name, dense_pat)) {
           Dense(nid);
         } else if ("nn.batch_norm" == op_name) {
           BatchNorm(nid);
-        } else if (std::regex_match(op_name, max_pool_pat)) {
+        } else if (tvm::support::regex_match(op_name, max_pool_pat)) {
           Pooling(nid, dnnl::algorithm::pooling_max);
-        } else if (std::regex_match(op_name, avg_pool_pat)) {
+        } else if (tvm::support::regex_match(op_name, avg_pool_pat)) {
           Pooling(nid, dnnl::algorithm::pooling_avg);
         } else if (elt_name2algo.count(op_name)) {
           Eltwise(nid);

src/support/regex.cc
Lines changed: 41 additions & 0 deletions (new file)

@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/support/regex.cc
+ * \brief Exposes calls to python's `re` library.
+ */
+
+#include "./regex.h"
+
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace support {
+
+bool regex_match(const std::string& match_against, const std::string& regex_pattern) {
+  const auto* regex_match_func = tvm::runtime::Registry::Get("tvm.support.regex_match");
+  CHECK(regex_match_func) << "RuntimeError: "
+                          << "The PackedFunc 'tvm.support.regex_match' has not been registered. "
+                          << "This can occur if the TVM Python library has not yet been imported.";
+  return (*regex_match_func)(regex_pattern, match_against);
+}
+
+}  // namespace support
+}  // namespace tvm
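The new regex.cc is only a thin shim: it looks up a PackedFunc named "tvm.support.regex_match" and forwards the pattern and the string to it, so the actual matching happens in Python's `re` module once the TVM Python library has been imported. As a minimal, hypothetical sketch (not part of this commit; the function name `_regex_match`, the registration location, and the use of `re.fullmatch` are assumptions), the Python-side registration this shim expects could look like:

    # Hypothetical sketch only -- not part of this commit. Registers the
    # PackedFunc that regex.cc looks up via Registry::Get("tvm.support.regex_match").
    import re

    import tvm

    @tvm._ffi.register_func("tvm.support.regex_match")
    def _regex_match(regex_pattern: str, match_against: str) -> bool:
        # Whole-string matching, mirroring the std::regex_match semantics
        # that the replaced C++ call sites relied on. Argument order matches
        # the C++ call (*regex_match_func)(regex_pattern, match_against).
        return re.fullmatch(regex_pattern, match_against) is not None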

src/support/regex.h
Lines changed: 64 additions & 0 deletions (new file)

@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file regex.h
+ * \brief Exposes calls to python's `re` library.
+ */
+#ifndef TVM_SUPPORT_REGEX_H_
+#define TVM_SUPPORT_REGEX_H_
+
+#include <string>
+
+namespace tvm {
+namespace support {
+
+/* \brief Check if a pattern matches a regular expression
+ *
+ * This function should be used instead of `std::regex` within C++
+ * call sites, to avoid ABI incompatibilities with pytorch.
+ *
+ * Currently, the pytorch wheels available through pip install use
+ * the pre-C++11 ABI by setting `-DUSE_CXX11_ABI=0` [0]. If TVM were to
+ * use the pre-C++11 ABI, this would cause breakages with
+ * dynamically-linked LLVM environments.
+ *
+ * Use of the `<regex>` header in TVM should be avoided, as its
+ * implementation is not supported by gcc's dual ABI. This ABI
+ * incompatibility results in runtime errors either when `std::regex`
+ * is called from TVM, or when `std::regex` is called from pytorch,
+ * depending on which library was loaded first. This restriction can
+ * be removed when a version of pytorch compiled using
+ * `-DUSE_CXX11_ABI=1` is available from PyPI.
+ *
+ * [0] https://github.com/pytorch/pytorch/issues/51039
+ *
+ * \param match_against The string against which to match the regular expression
+ *
+ * \param regex_pattern The regular expression
+ *
+ * \returns match_result True if `match_against` matches the pattern
+ *     defined by `regex_pattern`, and False otherwise.
+ */
+
+bool regex_match(const std::string& match_against, const std::string& regex_pattern);
+
+}  // namespace support
+}  // namespace tvm
+#endif  // TVM_SUPPORT_REGEX_H_
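Note that the header's contract, like the replaced std::regex_match call sites, is whole-string matching rather than substring search, which is why the DNNL patterns above are written with leading and trailing `.*`. A purely illustrative demonstration of that distinction in Python's `re` module (the op name below is a made-up example string, and whole-string semantics on the registered PackedFunc is assumed):

    # Whole-string matching vs. substring search, using Python's re module.
    import re

    op_name = "dnnl.conv2d_relu"  # made-up example op name
    print(re.fullmatch(".*_relu.*", op_name) is not None)   # True: pattern covers the whole string
    print(re.fullmatch("conv[1-3]d", op_name) is not None)  # False: only a substring would match
    print(re.search("conv[1-3]d", op_name) is not None)     # True: substring search succeeds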
