Skip to content

Commit b468ff7

Browse files
committed
Calculate Conv2d buffer size with respect to architecture extensions
This correctly calculates the buffer sizes for a variety of targets based on the `-mcpu` and `-mattr` flags passed to the `cmsis-nn` code generator.
1 parent 66eed5c commit b468ff7

File tree

10 files changed

+550
-5
lines changed

10 files changed

+550
-5
lines changed

python/tvm/driver/tvmc/composite_target.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
"pass_pipeline": partition_for_arm_compute_lib,
5353
},
5454
"cmsis-nn": {
55-
"config_key": None,
55+
"config_key": "relay.ext.cmsisnn.options",
5656
"pass_pipeline": partition_for_cmsisnn,
5757
},
5858
"ethos-n77": {
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include <tvm/ir/attrs.h>
21+
#include <tvm/ir/transform.h>
22+
23+
#include "compiler_attrs.h"
24+
25+
namespace tvm {
26+
namespace relay {
27+
namespace contrib {
28+
namespace cmsisnn {
29+
30+
/*!
 * \brief Size in bytes of the context buffer a CMSIS-NN Conv2D call needs.
 *
 * Three shapes of convolution are distinguished: a 1x1 convolution (no padding, unit stride,
 * 1x1 filter, input channels divisible by 4), a single-row 1xN convolution, and everything else.
 * Only targets with DSP enabled ever get a non-zero size; for the 1xN case MVE must also be off.
 */
int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
                     int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
                     int32_t stride_w, int32_t stride_h, int32_t filter_w, int32_t filter_h) {
  const bool unpadded = (padding_w == 0) && (padding_h == 0);
  const bool unit_stride = (stride_w == 1) && (stride_h == 1);
  const bool unit_filter = (filter_w == 1) && (filter_h == 1);
  if (unpadded && (input_c % 4 == 0) && unit_stride && unit_filter) {
    // 1x1 convolution: never needs a buffer.
    return 0;
  }

  const bool single_row = (output_h == 1) && (input_h == 1) && (filter_h == 1) &&
                          (output_w % 4 == 0) && (input_n == 1);

  // The 1xN case only uses the buffer on DSP-without-MVE; the general case on any DSP target.
  const bool wants_buffer = single_row ? (flags.dsp && !flags.mve) : flags.dsp;
  if (!wants_buffer) {
    return 0;
  }

  return (2 * input_c * filter_w * filter_h) * static_cast<int32_t>(sizeof(int16_t));
}
54+
55+
} // namespace cmsisnn
56+
} // namespace contrib
57+
} // namespace relay
58+
} // namespace tvm
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relay/backend/contrib/cmsisnn/buffer_size.h
22+
* \brief CMSIS-NN Buffer Size calculation functions
23+
*/
24+
25+
#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
26+
#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
27+
28+
#include <tvm/ir/transform.h>
29+
30+
#include "compiler_attrs.h"
31+
32+
namespace tvm {
33+
namespace relay {
34+
namespace contrib {
35+
namespace cmsisnn {
36+
37+
/*!
 * \brief Calculates the appropriate buffer size for CMSIS-NN Convolutions
 * See:
 * https://github.com/ARM-software/CMSIS_5/blob/8c60448c0e1e50e426180b26db9bc31ddf774361/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_wrapper_s8.c#L108-L127
 *
 * The result depends on the shape of the convolution (1x1, single-row 1xN, or general) and on
 * which architecture extensions are enabled in \p flags.
 *
 * \param flags - CMSIS-NN feature flags (DSP / MVE enabled)
 * \param padding_w - Width padding
 * \param padding_h - Height padding
 * \param input_n - Input batch size
 * \param input_h - Input height
 * \param input_c - Input channels
 * \param output_h - Output height
 * \param output_w - Output width
 * \param stride_w - Stride width
 * \param stride_h - Stride height
 * \param filter_w - Filter width
 * \param filter_h - Filter height
 *
 * \return Size in bytes of the buffer to allocate for the convolution, or 0 when no buffer is
 *         required
 */
57+
int Conv2dBufferSize(CMSISNNFlags flags, int32_t padding_w, int32_t padding_h, int32_t input_n,
58+
int32_t input_h, int32_t input_c, int32_t output_h, int32_t output_w,
59+
int32_t stride_w, int32_t stride_h, int32_t filter_w, int32_t filter_h);
60+
61+
} // namespace cmsisnn
62+
} // namespace contrib
63+
} // namespace relay
64+
} // namespace tvm
65+
66+
#endif // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_BUFFER_SIZE_H_
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#include "compiler_attrs.h"

#include <tvm/ir/attrs.h>
#include <tvm/ir/transform.h>

#include <algorithm>
#include <iterator>
#include <string>
24+
25+
namespace tvm {
26+
namespace relay {
27+
namespace contrib {
28+
namespace cmsisnn {
29+
30+
// CPU names for which MVE is assumed available. Entries are matched as prefixes of the -mcpu
// string, so suffixes such as "+nomve" do not prevent a match.
static const char* mveCPUs[] = {"cortex-m55"};
// CPU names for which the DSP extension is assumed available (also prefix-matched).
// "cortex-m55" is listed here as well so that disabling only MVE (e.g. "cortex-m55+nomve")
// still selects the DSP code path, and "cortex-m4" is included because it provides the DSP
// extension too.
static const char* dspCPUs[] = {"cortex-m55", "cortex-m4", "cortex-m7", "cortex-m33",
                                "cortex-m35p"};
32+
33+
// Register the config node type with TVM (presumably so it can be constructed and reflected
// by name - see TVM_REGISTER_NODE_TYPE for the exact semantics).
TVM_REGISTER_NODE_TYPE(CMSISNNCompilerConfigNode);
34+
// Declare "relay.ext.cmsisnn.options" as a pass-context config key holding a
// CMSISNNCompilerConfig; GetCompilerFlags below reads the option back under the same key.
TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.cmsisnn.options", CMSISNNCompilerConfig);
35+
36+
/*!
 * \brief Checks whether an -mcpu string names one of the given CPUs
 *
 * A CPU matches when its name is a prefix of \p mcpu, so an mcpu carrying extra attributes
 * (e.g. "cortex-m55+nomve") still matches "cortex-m55".
 *
 * \param mcpu - CPU string to test
 * \param cpus - Iterable collection of null-terminated CPU names
 *
 * \return true if any entry of cpus is a prefix of mcpu
 */
template <typename Container>
static inline bool MatchesCpu(const std::string& mcpu, const Container& cpus) {
  // rfind anchored at position 0 can only match at the very start of mcpu, so this is a pure
  // prefix test that does not scan the remainder of the string.
  return std::any_of(std::begin(cpus), std::end(cpus),
                     [&mcpu](const char* cpu) { return mcpu.rfind(cpu, 0) == 0; });
}
41+
42+
/*!
 * \brief Tests whether a flag occurs anywhere inside an attribute string
 *
 * \param attr - String to search
 * \param flag - Flag text to look for (e.g. "+nomve")
 *
 * \return true when flag is a substring of attr
 */
static inline bool HasFlag(std::string attr, std::string flag) {
  const std::string::size_type position = attr.find(flag);
  if (position == std::string::npos) {
    return false;
  }
  return true;
}
45+
46+
/*!
 * \brief Derives the CMSIS-NN feature flags from the "relay.ext.cmsisnn.options" pass config
 *
 * "+nomve" / "+nodsp" may appear in either the mcpu or the mattr string. MVE is only reported
 * when DSP is enabled as well.
 *
 * \param ctx - Pass context whose config is queried
 *
 * \return {dsp, mve} flags; both false when the option is not set or the CPU is not recognised
 */
CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx) {
  const CMSISNNFlags no_extensions = {false, false};

  auto options = ctx->GetConfig<CMSISNNCompilerConfig>("relay.ext.cmsisnn.options");
  if (!options.defined()) {
    return no_extensions;
  }

  const std::string mcpu = options.value()->mcpu;
  const std::string mattr = options.value()->mattr;

  const bool mve_disabled = HasFlag(mcpu, "+nomve") || HasFlag(mattr, "+nomve");
  const bool dsp_disabled = HasFlag(mcpu, "+nodsp") || HasFlag(mattr, "+nodsp");

  // With DSP switched off neither extension can be reported.
  if (dsp_disabled) {
    return no_extensions;
  }

  if (!mve_disabled && MatchesCpu(mcpu, mveCPUs)) {
    return {true, true};
  }

  if (MatchesCpu(mcpu, dspCPUs)) {
    return {true, false};
  }

  return no_extensions;
}
70+
71+
} // namespace cmsisnn
72+
} // namespace contrib
73+
} // namespace relay
74+
} // namespace tvm
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file src/relay/backend/contrib/cmsisnn/compiler_attrs.h
22+
* \brief CMSIS-NN Compiler Attribute functionality
23+
*/
24+
25+
#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_
26+
#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_
27+
28+
#include <tvm/ir/transform.h>
29+
30+
namespace tvm {
31+
namespace relay {
32+
namespace contrib {
33+
namespace cmsisnn {
34+
35+
/*!
 * \brief Attributes to store the compiler options for CMSIS-NN.
 *
 * Both fields are free-form strings defaulting to "". They are consumed by GetCompilerFlags,
 * which looks for a known CPU name at the start of mcpu and for "+nomve" / "+nodsp" in either
 * string.
 */
36+
struct CMSISNNCompilerConfigNode : public tvm::AttrsNode<CMSISNNCompilerConfigNode> {
37+
// Target CPU name, optionally followed by attribute suffixes
String mcpu;
38+
// Additional target attributes
String mattr;
39+
40+
// Declares the attribute fields together with their descriptions and defaults
TVM_DECLARE_ATTRS(CMSISNNCompilerConfigNode, "ext.attrs.CMSISNNCompilerConfigNode") {
41+
TVM_ATTR_FIELD(mcpu)
42+
.describe(
43+
"The CPU to configure CMSIS-NN for (i.e. cortex-m55, cortex-m4), can also include "
44+
"attributes (i.e. cortex-m55+nomve)")
45+
.set_default("");
46+
TVM_ATTR_FIELD(mattr)
47+
.describe("The attributes to configure CMSIS-NN (i.e. +nodsp, +nomve)")
48+
.set_default("");
49+
}
50+
};
51+
52+
/*!
 * \brief Attrs-derived handle for CMSISNNCompilerConfigNode.
 *
 * This is the type registered for, and read back from, the "relay.ext.cmsisnn.options" pass
 * config option. Its members are generated by the TVM object-reference macro below
 * (non-nullable variant, going by the macro name).
 */
class CMSISNNCompilerConfig : public Attrs {
53+
public:
54+
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CMSISNNCompilerConfig, Attrs,
55+
CMSISNNCompilerConfigNode);
56+
};
57+
58+
/*!
 * \brief Flags to configure the calculations for CMSIS-NN.
 *
 * Plain aggregate; GetCompilerFlags brace-initialises it in {dsp, mve} order, so the member
 * order must not change.
 */
59+
struct CMSISNNFlags {
60+
bool dsp; // True when buffers for the DSP extension should be accounted for
61+
bool mve; // True when buffers for the MVE extension should be accounted for
62+
};
63+
64+
/*!
 * \brief Derives the CMSIS-NN feature flags from the "relay.ext.cmsisnn.options" pass config
 *
 * \param ctx - Pass context whose config is queried
 *
 * \return DSP / MVE flags; both are false when the option is not set
 */
CMSISNNFlags GetCompilerFlags(const tvm::transform::PassContext& ctx);
65+
66+
} // namespace cmsisnn
67+
} // namespace contrib
68+
} // namespace relay
69+
} // namespace tvm
70+
71+
#endif // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPILER_ATTRS_H_

src/relay/backend/contrib/cmsisnn/relay_to_tir.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
#include "../../../qnn/utils.h"
3030
#include "../../../transforms/pattern_utils.h"
31+
#include "buffer_size.h"
32+
#include "compiler_attrs.h"
3133

3234
namespace tvm {
3335
namespace relay {
@@ -53,7 +55,6 @@ class RelayToTIRVisitor : public MixedModeVisitor {
5355
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));
5456

5557
if (context_buffer_size) {
56-
// TODO(@ashutosh-arm) while supporting MVE, we need to move allocation through TVMBAW
5758
tir::Var buffer_var("context_buffer", PointerType(PrimType(DataType::Int(8)), "global"));
5859
body = tir::Allocate(buffer_var, DataType::Int(8), {context_buffer_size}, tir::const_true(),
5960
body);
@@ -179,8 +180,10 @@ class RelayToTIRVisitor : public MixedModeVisitor {
179180
func_signature.push_back(const_var5);
180181
func_signature.push_back(out_var);
181182

182-
// https://github.com/ARM-software/CMSIS_5/blob/d788fd583984388553391de18afd8b4d2a146868/CMSIS/NN/Source/ConvolutionFunctions/arm_convolve_s8.c#L367
183-
size_t context_buffer_size = (2 * input_c * filter_w * filter_h) * (int32_t)sizeof(int16_t);
183+
CMSISNNFlags flags = GetCompilerFlags(transform::PassContext::Current());
184+
int context_buffer_size =
185+
Conv2dBufferSize(flags, padding_w, padding_h, input_n, input_h, input_c, output_h, output_w,
186+
stride_w, stride_h, filter_w, filter_h);
184187

185188
CreatePrimFuncForExtern(func_signature, call_ext_args, context_buffer_size);
186189
}

src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ class CodeGenCMSISNN : public CodeGenC {
208208
stream << "printf(\"Failed during execution of " << cmsis_func_name << "().\");\n";
209209
PrintIndent();
210210
stream << "}\n";
211+
212+
ResetBufferContext();
211213
}
212214

213215
/*! * \brief Creates a cplusplus guard prefix for extern "C" printing */
@@ -226,8 +228,13 @@ class CodeGenCMSISNN : public CodeGenC {
226228
ss << "#endif\n";
227229
}
228230

231+
void ResetBufferContext() {
232+
context_buffer_name_ = "NULL";
233+
context_buffer_size_ = 0;
234+
}
235+
229236
private:
230-
std::string context_buffer_name_ = "Empty";
237+
std::string context_buffer_name_ = "NULL";
231238
int context_buffer_size_ = 0;
232239
};
233240

0 commit comments

Comments
 (0)