6161from executorch .backends .arm .tosa_partitioner import TOSAPartitioner
6262from executorch .backends .arm .tosa_specification import TosaSpecification
6363
64+ from executorch .backends .test .harness .stages import Stage , StageType
6465from executorch .backends .xnnpack .test .tester import Tester
6566from executorch .devtools .backend_debug import get_delegation_info
6667
@@ -259,10 +260,13 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
259260 super ().run (artifact , inputs )
260261
261262
262- class InitialModel (tester . Stage ):
263+ class InitialModel (Stage ):
263264 def __init__ (self , model : torch .nn .Module ):
264265 self .model = model
265266
267+ def stage_type (self ) -> StageType :
268+ return StageType .INITIAL_MODEL
269+
266270 def run (self , artifact , inputs = None ) -> None :
267271 pass
268272
@@ -305,13 +309,13 @@ def __init__(
305309 self .constant_methods = constant_methods
306310 self .compile_spec = compile_spec
307311 super ().__init__ (model , example_inputs , dynamic_shapes )
308- self .pipeline [self . stage_name ( InitialModel ) ] = [
309- self . stage_name ( tester . Quantize ) ,
310- self . stage_name ( tester . Export ) ,
312+ self .pipeline [StageType . INITIAL_MODEL ] = [
313+ StageType . QUANTIZE ,
314+ StageType . EXPORT ,
311315 ]
312316
313317 # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
314- self .stages [self . stage_name ( InitialModel ) ] = None
318+ self .stages [StageType . INITIAL_MODEL ] = None
315319 self ._run_stage (InitialModel (self .original_module ))
316320
317321 def quantize (self , quantize_stage : Optional [tester .Quantize ] = None ):
@@ -410,7 +414,7 @@ def serialize(
410414 return super ().serialize (serialize_stage )
411415
412416 def is_quantized (self ) -> bool :
413- return self .stages [self . stage_name ( tester . Quantize ) ] is not None
417+ return self .stages [StageType . QUANTIZE ] is not None
414418
415419 def run_method_and_compare_outputs (
416420 self ,
@@ -439,18 +443,16 @@ def run_method_and_compare_outputs(
439443 """
440444
441445 if not run_eager_mode :
442- edge_stage = self .stages [self . stage_name ( tester . ToEdge ) ]
446+ edge_stage = self .stages [StageType . TO_EDGE ]
443447 if edge_stage is None :
444- edge_stage = self .stages [
445- self .stage_name (tester .ToEdgeTransformAndLower )
446- ]
448+ edge_stage = self .stages [StageType .TO_EDGE_TRANSFORM_AND_LOWER ]
447449 assert (
448450 edge_stage is not None
449451 ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
450452 else :
451453 # Run models in eager mode. We do this when we want to check that the passes
452454 # are numerically accurate and the exported graph is correct.
453- export_stage = self .stages [self . stage_name ( tester . Export ) ]
455+ export_stage = self .stages [StageType . EXPORT ]
454456 assert (
455457 export_stage is not None
456458 ), "To compare outputs in eager mode, the model must be at Export stage"
@@ -460,11 +462,11 @@ def run_method_and_compare_outputs(
460462 is_quantized = self .is_quantized ()
461463
462464 if is_quantized :
463- reference_stage = self .stages [self . stage_name ( tester . Quantize ) ]
465+ reference_stage = self .stages [StageType . QUANTIZE ]
464466 else :
465- reference_stage = self .stages [self . stage_name ( InitialModel ) ]
467+ reference_stage = self .stages [StageType . INITIAL_MODEL ]
466468
467- exported_program = self .stages [self . stage_name ( tester . Export ) ].artifact
469+ exported_program = self .stages [StageType . EXPORT ].artifact
468470 output_nodes = get_output_nodes (exported_program )
469471
470472 output_qparams = get_output_quantization_params (output_nodes )
@@ -474,7 +476,7 @@ def run_method_and_compare_outputs(
474476 quantization_scales .append (getattr (output_qparams [node ], "scale" , None ))
475477
476478 logger .info (
477- f"Comparing Stage '{ self . stage_name ( test_stage )} ' with Stage '{ self . stage_name ( reference_stage )} '"
479+ f"Comparing Stage '{ test_stage . stage_type ( )} ' with Stage '{ reference_stage . stage_type ( )} '"
478480 )
479481
480482 # Loop inputs and compare reference stage with the compared stage.
@@ -525,14 +527,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
525527 stage = self .cur
526528 artifact = self .get_artifact (stage )
527529 if (
528- self .cur == self . stage_name ( tester . ToEdge )
529- or self .cur == self . stage_name ( Partition )
530- or self .cur == self . stage_name ( ToEdgeTransformAndLower )
530+ self .cur == StageType . TO_EDGE
531+ or self .cur == StageType . PARTITION
532+ or self .cur == StageType . TO_EDGE_TRANSFORM_AND_LOWER
531533 ):
532534 graph = artifact .exported_program ().graph
533- elif self .cur == self .stage_name (tester .Export ) or self .cur == self .stage_name (
534- tester .Quantize
535- ):
535+ elif self .cur == StageType .EXPORT or self .cur == StageType .QUANTIZE :
536536 graph = artifact .graph
537537 else :
538538 raise RuntimeError (
@@ -553,13 +553,13 @@ def dump_operator_distribution(
553553 Returns self for daisy-chaining.
554554 """
555555 line = "#" * 10
556- to_print = f"{ line } { self .cur . capitalize () } Operator Distribution { line } \n "
556+ to_print = f"{ line } { self .cur } Operator Distribution { line } \n "
557557
558558 if (
559559 self .cur
560560 in (
561- self . stage_name ( tester . Partition ) ,
562- self . stage_name ( ToEdgeTransformAndLower ) ,
561+ StageType . PARTITION ,
562+ StageType . TO_EDGE_TRANSFORM_AND_LOWER ,
563563 )
564564 and print_table
565565 ):
@@ -599,9 +599,7 @@ def dump_dtype_distribution(
599599 """
600600
601601 line = "#" * 10
602- to_print = (
603- f"{ line } { self .cur .capitalize ()} Placeholder Dtype Distribution { line } \n "
604- )
602+ to_print = f"{ line } { self .cur } Placeholder Dtype Distribution { line } \n "
605603
606604 graph = self .get_graph (self .cur )
607605 tosa_spec = get_tosa_spec (self .compile_spec )
@@ -650,7 +648,7 @@ def run_transform_for_annotation_pipeline(
650648 stage = self .cur
651649 # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
652650 artifact = self .get_artifact (stage )
653- if self .cur == self . stage_name ( tester . Export ) :
651+ if self .cur == StageType . EXPORT :
654652 new_gm = ArmPassManager (get_tosa_spec (self .compile_spec )).transform_for_annotation_pipeline ( # type: ignore[arg-type]
655653 graph_module = artifact .graph_module
656654 )
0 commit comments