2323from executorch .exir ._serialize ._serialize import serialize_for_executorch
2424from executorch .exir ._serialize .data_serializer import DataSerializer
2525from executorch .exir ._warnings import experimental
26- from executorch .exir .backend .backend_api import to_backend
26+ from executorch .exir .backend .backend_api import (
27+ MethodProgramsPartitionerSpec ,
28+ to_backend ,
29+ )
2730from executorch .exir .backend .partitioner import Partitioner
2831from executorch .exir .capture ._config import EdgeCompileConfig , ExecutorchBackendConfig
2932from executorch .exir .delegate import executorch_call_delegate , is_lowered_module
@@ -1239,10 +1242,16 @@ def to_edge_transform_and_lower(
12391242 if transform_passes is not None :
12401243 edge_manager = edge_manager .transform (transform_passes )
12411244
1242- if partitioner is not None :
1245+ max_num_partitioners = 0
1246+ for partitioner_list in partitioner .values ():
1247+ max_num_partitioners = max (max_num_partitioners , len (partitioner_list ))
1248+
1249+ for i in range (max_num_partitioners ):
1250+ method_to_partitioner = {}
12431251 for name , partitioner_list in partitioner .items ():
1244- for curr_partitioner in partitioner_list :
1245- edge_manager = edge_manager .to_backend ({name : curr_partitioner })
1252+ if i < len (partitioner_list ):
1253+ method_to_partitioner [name ] = partitioner_list [i ]
1254+ edge_manager = edge_manager .to_backend (method_to_partitioner )
12461255
12471256 for name , program in edge_manager ._edge_programs .items ():
12481257 ops_set_to_not_decompose : Set [torch ._ops .OpOverload ] = set ()
@@ -1475,7 +1484,8 @@ def transform(
14751484
14761485 @et_logger ("to_backend" )
14771486 def to_backend (
1478- self , partitioner : Union [Partitioner , Dict [str , Partitioner ]]
1487+ self ,
1488+ partitioner : Union [Partitioner , Dict [str , Partitioner ]],
14791489 ) -> "EdgeProgramManager" :
14801490 """
14811491 Returns a semantically-equivalent program to the one given as input,
@@ -1501,17 +1511,18 @@ def to_backend(
15011511 specified subgraphs lowered.
15021512 """
15031513 new_edge_programs : Dict [str , ExportedProgram ] = {}
1504- if isinstance (partitioner , dict ):
1505- for name , program in self ._edge_programs .items ():
1506- if name in partitioner .keys ():
1507- new_edge_programs [name ] = to_backend (program , partitioner [name ])
1508- else :
1509- new_edge_programs [name ] = program
1514+ method_to_partitioner : Dict [str , Partitioner ] = {}
1515+ if not isinstance (partitioner , dict ):
1516+ method_to_partitioner = {name : partitioner for name in self ._edge_programs }
1517+ else :
1518+ method_to_partitioner = partitioner
15101519
1511- else : # apply partitioner to every method
1512- for name , program in self ._edge_programs .items ():
1513- new_edge_programs [name ] = to_backend (program , partitioner )
1520+ method_to_programs_and_partitioners = MethodProgramsPartitionerSpec (
1521+ self ._edge_programs ,
1522+ method_to_partitioner ,
1523+ )
15141524
1525+ new_edge_programs = to_backend (method_to_programs_and_partitioners )
15151526 config = EdgeCompileConfig (_check_ir_validity = False )
15161527 return EdgeProgramManager (
15171528 new_edge_programs ,
0 commit comments