@@ -44,6 +44,8 @@ def partition_ethosu_by_table(mod, pattern_table):
4444 want to add the operator's pattern to the pattern table so that the compiler
4545 wouldn't attempt to offload an operator without full stack support."""
4646 mod = relay .transform .InferType ()(mod )
47+ mod = relay .transform .replicate_pads (mod )
48+ mod = relay .transform .InferType ()(mod )
4749 mod = relay .transform .MergeComposite (pattern_table )(mod )
4850 mod = relay .transform .AnnotateTarget ("ethos-u" )(mod )
4951 mod = relay .transform .MergeCompilerRegions ()(mod )
@@ -3646,5 +3648,149 @@ def _visit(stmt):
36463648 verify (mod ["tvmgen_default_ethos_u_main_0" ])
36473649
36483650
@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)])
@pytest.mark.parametrize("kernel_shape", [(3, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))])
@pytest.mark.parametrize("op_padding", ["SAME", "VALID"])
@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)])
@pytest.mark.parametrize(
    "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")]
)
def test_tflite_shared_pad_legalize(
    ifm_shape,
    kernel_shape,
    strides,
    dilation,
    op_padding,
    sep_padding,
    op_pairs,
):
    """Check legalization when a single explicit pad is shared by two ops.

    Builds a TFLite graph where one ``tf.pad`` feeds two convolution-type
    operations (each either a conv2d or a depthwise conv2d, per ``op_pairs``),
    partitions it for the ethos-u target, legalizes both partitioned
    functions, and asserts each one was lowered to the expected
    ``contrib.ethosu.*`` operator.
    """
    dtype = "int8"

    def create_tflite_graph():
        class Model(tf.Module):
            @tf.function
            def tf_function(self, x):
                # Explicit pad shared by both downstream ops;
                # sep_padding is (top, left, bottom, right).
                x = tf.pad(
                    x,
                    [
                        [0, 0],
                        [sep_padding[0], sep_padding[2]],
                        [sep_padding[1], sep_padding[3]],
                        [0, 0],
                    ],
                    "CONSTANT",
                )

                # The input strides to the TensorFlow API needs to be of shape 1x4
                tf_strides = [1, strides[0], strides[1], 1]

                def _make_op(op_type):
                    # Build one consumer of the shared padded tensor: either a
                    # depthwise conv2d (channel multiplier 1) or a conv2d with
                    # 3 output channels. Weights are freshly drawn each call,
                    # matching the original per-op random weights.
                    if op_type == "depthwise":
                        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
                        weight = tf.constant(
                            np.random.uniform(size=weight_shape), dtype=tf.float32
                        )
                        return tf.nn.depthwise_conv2d(
                            x, weight, strides=tf_strides, padding=op_padding, dilations=dilation
                        )
                    weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
                    weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
                    return tf.nn.conv2d(
                        x,
                        weight,
                        strides=tf_strides,
                        padding=op_padding,
                        dilations=dilation,
                    )

                # Join the two branches so both consumers stay live in the graph.
                return tf.math.add(_make_op(op_pairs[0]), _make_op(op_pairs[1]))

        model = Model()
        concrete_func = model.tf_function.get_concrete_function(
            tf.TensorSpec(ifm_shape, dtype=tf.float32)
        )

        # Convert the model to fully-int8 TFLite using a random
        # representative dataset for quantization calibration.
        def representative_dataset():
            for _ in range(100):
                data = np.random.rand(*tuple(ifm_shape))
                yield [data.astype(np.float32)]

        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
        return converter.convert()

    # Only conv2d/depthwise patterns are registered, so the shared pad must be
    # handled by the partitioning (via pad replication), not offloaded alone.
    conv2d_pattern_table = [
        (
            ethosu.QnnConv2DParams.composite_name,
            ethosu.qnn_conv2d_pattern(),
            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
        ),
        (
            ethosu.QnnDepthwiseConv2DParams.composite_name,
            ethosu.qnn_depthwise_conv2d_pattern(),
            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
        ),
    ]

    tflite_graph = create_tflite_graph()
    tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"input": ifm_shape},
        dtype_dict={"input": dtype},
    )

    mod["main"] = bind_params_by_name(mod["main"], params)
    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)

    # Each op in op_pairs ends up in its own partitioned function; legalize
    # both and check the resulting operator matches the requested op type.
    for func_name, op_type in zip(
        ("tvmgen_default_ethos_u_main_0", "tvmgen_default_ethos_u_main_1"), op_pairs
    ):
        mod[func_name] = dataflow_pattern.rewrite(
            [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
            mod[func_name],
        )
        expected_op = (
            "contrib.ethosu.depthwise_conv2d"
            if op_type == "depthwise"
            else "contrib.ethosu.conv2d"
        )
        assert mod[func_name].body.op.name == expected_op
3793+
3794+
# Entry point: dispatch this file's tests through TVM's pytest wrapper.
if __name__ == "__main__":
    tvm.testing.main()
0 commit comments