@@ -203,12 +203,23 @@ def convert_batch_norm(g, op, block):
203203 mean_name = op .input ("Mean" )[0 ]
204204 variance_name = op .input ("Variance" )[0 ]
205205 epsilon = op .attr ("epsilon" )
206+ data_layout = op .attr ("data_layout" )
207+
208+ if data_layout == "NCHW" :
209+ axis = 1
210+ elif data_layout == "NHWC" :
211+ axis = 3
212+ else :
213+ msg = f'Value { data_layout } in attribute "batch_norm" of operator Conv is not "valid."'
214+ raise tvm .error .OpAttributeInvalid (msg )
215+
206216 out = _op .nn .batch_norm (
207- g .get_node (ipt_name ),
208- g .get_node (scale_name ),
209- g .get_node (bias_name ),
210- g .get_node (mean_name ),
211- g .get_node (variance_name ),
217+ g .get_node (ipt_name ), # data
218+ g .get_node (scale_name ), # gamma
219+ g .get_node (bias_name ), # beta
220+ g .get_node (mean_name ), # moving_mean
221+ g .get_node (variance_name ), # moving_var
222+ axis = axis ,
212223 epsilon = epsilon ,
213224 )
214225 g .add_node (op .output ("Y" )[0 ], out [0 ])
@@ -1208,12 +1219,12 @@ def convert_matmul(g, op, block):
12081219
12091220 # This implemention almost keeps same with ONNX
12101221 # Need to check input shape as batch matmul must be supported.
1211- a_shape = shape_of (inputs [0 ], dtype = "int32" )
1212- a_rank = infer_shape (a_shape )[0 ]
1213- b_shape = shape_of (inputs [1 ], dtype = "int32" )
1214- b_rank = infer_shape (b_shape )[0 ]
1222+ a_rank = len (a_shape )
1223+ b_rank = len (b_shape )
12151224 # When performing a batch matmul, we need to properly handle N-dim shapes.
12161225 if a_rank > 2 or b_rank > 2 :
1226+ a_shape = shape_of (inputs [0 ], dtype = "int32" )
1227+ b_shape = shape_of (inputs [1 ], dtype = "int32" )
12171228
12181229 def flatten_to_nd (x , x_shape , nd = 3 ):
12191230 ndims = infer_shape (x_shape )[0 ]
@@ -1524,10 +1535,16 @@ def convert_pool2d(g, op, block):
15241535 padding = paddings ,
15251536 ceil_mode = ceil_mode ,
15261537 count_include_pad = not exclusive ,
1538+ layout = data_format ,
15271539 )
15281540 else :
15291541 out = getattr (_op .nn , op_map [pooling_type ])(
1530- input_x , pool_size = ksize , strides = strides , padding = paddings , ceil_mode = ceil_mode
1542+ input_x ,
1543+ pool_size = ksize ,
1544+ strides = strides ,
1545+ padding = paddings ,
1546+ ceil_mode = ceil_mode ,
1547+ layout = data_format ,
15311548 )
15321549 else :
15331550 out = getattr (_op .nn , "adaptive_" + op_map [pooling_type ])(
@@ -2973,7 +2990,7 @@ def from_program(self, program, shape_dict, scope):
29732990 if scope is None :
29742991 import paddle
29752992
2976- scope = paddle .fluid .global_scope ()
2993+ scope = paddle .static .global_scope ()
29772994 self .check_unsupported_ops (program )
29782995 self .extract_parameters (program , scope )
29792996 self .ops_to_relay (program )
0 commit comments