Skip to content

Commit 884dad1

Browse files
committed
Fix loading weights to buffers for VM
1 parent e421068 commit 884dad1

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/relay/transforms/annotate_texture_storage.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,11 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
174174
for (const auto& ttype : FlattenTupleType(fn->params[i]->checked_type())) {
175175
std::string scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
176176
if (expr_attrib.as<Conv2DAttrs>() || expr_attrib.as<Conv2DWinogradAttrs>()) {
177+
String kernel_layout = (expr_attrib.as<Conv2DAttrs>())
178+
? expr_attrib.as<Conv2DAttrs>()->kernel_layout
179+
: expr_attrib.as<Conv2DWinogradAttrs>()->kernel_layout;
177180
if ((i == weights_pos) && !ttype->dtype.is_float16() &&
178-
CanUseBuffers(call->args[i], ttype->shape, fn->attrs)) {
181+
CanUseBuffers(call->args[i], ttype->shape, kernel_layout)) {
179182
buffers_params.insert(fn->params[i]);
180183
buffers_args.insert(call->args[i]);
181184
scope = "global";
@@ -426,10 +429,9 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
426429
}
427430

428431
bool CanUseBuffers(const Expr param, const Array<PrimExpr> shape,
429-
const tvm::DictAttrs param_attrs) const {
432+
const String kernel_layout) const {
430433
bool use_buffer = false;
431434
if (param.as<ConstantNode>() && shape.size() == 5) {
432-
auto kernel_layout = param_attrs.GetAttr<String>("kernel_layout");
433435
if (kernel_layout == "HWOI4o" || kernel_layout == "HWIO4o") {
434436
int a0 = shape[0].as<IntImmNode>()->value;
435437
int a1 = shape[1].as<IntImmNode>()->value;

0 commit comments

Comments
 (0)