5252 # However, please note that `nn.matmul` is in experimental so it may have some performance
5353 # issues.
5454 "use_dense" : True ,
55+ # By default, TVM converts `tf.batch_matmul` to `transpose(weight) + nn.batch_matmul_NT`.
56+ # Change this flag to False to directly convert to `nn.batch_matmul`.
57+ # Note that `nn.batch_matmul` with format other than NT is in experimental, it may have some
58+ # performance issues.
59+ "use_nt_batch_matmul" : True ,
5560}
5661
5762# compatible operators that do NOT require any conversion.
@@ -1214,7 +1219,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
12141219 return func , self ._params
12151220
12161221
1217- def from_tensorflow (graph , layout = "NHWC" , shape = None , outputs = None , use_dense_op = True ):
1222+ def from_tensorflow (graph , layout = "NHWC" , shape = None , outputs = None , convert_config = None ):
12181223 """Load tensorflow graph which is a python tensorflow graph object into relay.
12191224 The companion parameters will be handled automatically.
12201225
@@ -1232,10 +1237,15 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op
12321237 outputs : List of output tensor names (Optional)
12331238 if not specified then the last node is assumed as graph output.
12341239
1235- use_dense_op : bool (Optional) = True
1236- Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
1237- The `nn.dense` op requires the data tensor to be non-transposed and weight tensor to be
1238- transposed, may insert extra `transpose` to the original graph.
1240+ convert_config : Optional[Dict[str, Any]]
1241+ Default config:
1242+ use_dense : bool = True
1243+ Ture to convert `tf.matmul` to `nn.dense`, else to `nn.matmul`.
1244+ The `nn.dense` op requires the data tensor to be non-transposed and weight tensor
1245+ to be transposed, may insert extra `transpose` to the original graph.
1246+ use_nt_batch_matmul : bool = True
1247+ True to convert `tf.batch_matmul` to `nn.batch_matmul` strict to NT format
1248+ (transpose_a=False, transpose_b=True).
12391249
12401250 Returns
12411251 -------
@@ -1246,7 +1256,8 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None, use_dense_op
12461256 Dict of converted parameters stored in tvm.nd.NDArray format
12471257 """
12481258 global TF_DEFAULT_CONFIGS
1249- TF_DEFAULT_CONFIGS ["use_dense" ] = use_dense_op
1259+ if convert_config is not None :
1260+ TF_DEFAULT_CONFIGS .update (convert_config )
12501261
12511262 g = GraphProto ()
12521263 mod , params = g .from_tensorflow (graph , layout , shape , outputs )
0 commit comments