@@ -174,8 +174,11 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
       for (const auto& ttype : FlattenTupleType(fn->params[i]->checked_type())) {
         std::string scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
         if (expr_attrib.as<Conv2DAttrs>() || expr_attrib.as<Conv2DWinogradAttrs>()) {
+          String kernel_layout = (expr_attrib.as<Conv2DAttrs>())
+                                     ? expr_attrib.as<Conv2DAttrs>()->kernel_layout
+                                     : expr_attrib.as<Conv2DWinogradAttrs>()->kernel_layout;
           if ((i == weights_pos) && !ttype->dtype.is_float16() &&
-              CanUseBuffers(call->args[i], ttype->shape, fn->attrs)) {
+              CanUseBuffers(call->args[i], ttype->shape, kernel_layout)) {
             buffers_params.insert(fn->params[i]);
             buffers_args.insert(call->args[i]);
             scope = "global";
@@ -426,10 +429,9 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
   }
 
   bool CanUseBuffers(const Expr param, const Array<PrimExpr> shape,
-                     const tvm::DictAttrs param_attrs) const {
+                     const String kernel_layout) const {
     bool use_buffer = false;
     if (param.as<ConstantNode>() && shape.size() == 5) {
-      auto kernel_layout = param_attrs.GetAttr<String>("kernel_layout");
       if (kernel_layout == "HWOI4o" || kernel_layout == "HWIO4o") {
         int a0 = shape[0].as<IntImmNode>()->value;
         int a1 = shape[1].as<IntImmNode>()->value;
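The change narrows the helper's contract: instead of threading the whole `DictAttrs` through and looking up `"kernel_layout"` inside `CanUseBuffers` (where `GetAttr` returns an optional), the caller resolves the layout string up front from whichever conv attribute node is actually present. A minimal standalone sketch of that pattern follows; the TVM types (`Conv2DAttrs`, `Conv2DWinogradAttrs`, `String`, `Array<PrimExpr>`) are stubbed with standard-library equivalents for illustration, and the shape comparisons are simplified since the hunk above cuts off before them.

```cpp
// Sketch only: TVM's real object system is stubbed here to illustrate the
// call-site pattern from the diff, i.e. resolving kernel_layout once and
// passing a plain string instead of a full attrs dict.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Conv2DAttrs { std::string kernel_layout; };
struct Conv2DWinogradAttrs { std::string kernel_layout; };

// Mirrors the new CanUseBuffers signature: the layout arrives as a string,
// so the helper no longer performs an attribute lookup itself.
bool CanUseBuffers(bool is_constant, const std::vector<int64_t>& shape,
                   const std::string& kernel_layout) {
  bool use_buffer = false;
  if (is_constant && shape.size() == 5) {
    if (kernel_layout == "HWOI4o" || kernel_layout == "HWIO4o") {
      // The real implementation inspects shape extents here (cut off in the
      // hunk above); for this sketch, any 5-D constant in these layouts
      // is assumed to qualify.
      use_buffer = true;
    }
  }
  return use_buffer;
}

int main() {
  Conv2DWinogradAttrs wino{"HWIO4o"};
  const Conv2DAttrs* conv = nullptr;          // pretend as<Conv2DAttrs>() failed
  const Conv2DWinogradAttrs* winograd = &wino;

  // Call-site pattern from the diff: pick the layout from whichever
  // attribute node is present, then hand the string to the helper.
  std::string kernel_layout = conv ? conv->kernel_layout : winograd->kernel_layout;

  std::vector<int64_t> shape = {4, 4, 8, 3, 4};  // hypothetical 5-D weight shape
  std::cout << std::boolalpha
            << CanUseBuffers(/*is_constant=*/true, shape, kernel_layout) << "\n";
  return 0;
}
```

A likely motivation for the refactor: at the call site the conv attribute node is already known to exist, so the layout can be read directly, whereas inside the old helper the `GetAttr<String>("kernel_layout")` lookup produced an optional that was compared without an explicit presence check.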