5050from executorch .backends .arm .tosa_partitioner import TOSAPartitioner
5151from executorch .backends .arm .tosa_specification import TosaSpecification
5252
53+ from executorch .backends .test .harness .stages import StageType
5354from executorch .backends .xnnpack .test .tester import Tester
5455from executorch .devtools .backend_debug import get_delegation_info
5556
@@ -284,13 +285,13 @@ def __init__(
284285 self .constant_methods = constant_methods
285286 self .compile_spec = compile_spec
286287 super ().__init__ (model , example_inputs , dynamic_shapes )
287- self .pipeline [self . stage_name ( InitialModel ) ] = [
288- self . stage_name ( tester . Quantize ) ,
289- self . stage_name ( tester . Export ) ,
288+ self .pipeline [StageType . INITIAL_MODEL ] = [
289+ StageType . QUANTIZE ,
290+ StageType . EXPORT ,
290291 ]
291292
292293 # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
293- self .stages [self . stage_name ( InitialModel ) ] = None
294+ self .stages [StageType . INTIIAL_MODEL ] = None
294295 self ._run_stage (InitialModel (self .original_module ))
295296
296297 def quantize (self , quantize_stage : Optional [tester .Quantize ] = None ):
@@ -385,7 +386,7 @@ def serialize(
385386 return super ().serialize (serialize_stage )
386387
387388 def is_quantized (self ) -> bool :
388- return self .stages [self . stage_name ( tester . Quantize ) ] is not None
389+ return self .stages [StageType . QUANTIZE ] is not None
389390
390391 def run_method_and_compare_outputs (
391392 self ,
@@ -414,18 +415,16 @@ def run_method_and_compare_outputs(
414415 """
415416
416417 if not run_eager_mode :
417- edge_stage = self .stages [self . stage_name ( tester . ToEdge ) ]
418+ edge_stage = self .stages [StageType . TO_EDGE ]
418419 if edge_stage is None :
419- edge_stage = self .stages [
420- self .stage_name (tester .ToEdgeTransformAndLower )
421- ]
420+ edge_stage = self .stages [StageType .TO_EDGE_TRANSFORM_AND_LOWER ]
422421 assert (
423422 edge_stage is not None
424423 ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
425424 else :
426425 # Run models in eager mode. We do this when we want to check that the passes
427426 # are numerically accurate and the exported graph is correct.
428- export_stage = self .stages [self . stage_name ( tester . Export ) ]
427+ export_stage = self .stages [StageType . EXPORT ]
429428 assert (
430429 export_stage is not None
431430 ), "To compare outputs in eager mode, the model must be at Export stage"
@@ -435,11 +434,11 @@ def run_method_and_compare_outputs(
435434 is_quantized = self .is_quantized ()
436435
437436 if is_quantized :
438- reference_stage = self .stages [self . stage_name ( tester . Quantize ) ]
437+ reference_stage = self .stages [StageType . QUANTIZE ]
439438 else :
440- reference_stage = self .stages [self . stage_name ( InitialModel ) ]
439+ reference_stage = self .stages [StageType . INITIAL_MODEL ]
441440
442- exported_program = self .stages [self . stage_name ( tester . Export ) ].artifact
441+ exported_program = self .stages [StageType . EXPORT ].artifact
443442 output_nodes = get_output_nodes (exported_program )
444443
445444 output_qparams = get_output_quantization_params (output_nodes )
@@ -449,7 +448,7 @@ def run_method_and_compare_outputs(
449448 quantization_scales .append (getattr (output_qparams [node ], "scale" , None ))
450449
451450 logger .info (
452- f"Comparing Stage '{ self . stage_name ( test_stage )} ' with Stage '{ self . stage_name ( reference_stage )} '"
451+ f"Comparing Stage '{ test_stage . stage_type ( )} ' with Stage '{ reference_stage . stage_type ( )} '"
453452 )
454453
455454 # Loop inputs and compare reference stage with the compared stage.
@@ -500,14 +499,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
500499 stage = self .cur
501500 artifact = self .get_artifact (stage )
502501 if (
503- self .cur == self . stage_name ( tester . ToEdge )
504- or self .cur == self . stage_name ( Partition )
505- or self .cur == self . stage_name ( ToEdgeTransformAndLower )
502+ self .cur == StageType . TO_EDGE
503+ or self .cur == StageType . PARTITION
504+ or self .cur == StageType . TO_EDGE_TRANSFORM_AND_LOWER
506505 ):
507506 graph = artifact .exported_program ().graph
508- elif self .cur == self .stage_name (tester .Export ) or self .cur == self .stage_name (
509- tester .Quantize
510- ):
507+ elif self .cur == StageType .EXPORT or self .cur == StageType .QUANTIZE :
511508 graph = artifact .graph
512509 else :
513510 raise RuntimeError (
@@ -533,8 +530,8 @@ def dump_operator_distribution(
533530 if (
534531 self .cur
535532 in (
536- self . stage_name ( tester . Partition ) ,
537- self . stage_name ( ToEdgeTransformAndLower ) ,
533+ StageType . PARTITION ,
534+ StageType . TO_EDGE_TRANSFORM_AND_LOWER ,
538535 )
539536 and print_table
540537 ):
@@ -625,7 +622,7 @@ def run_transform_for_annotation_pipeline(
625622 stage = self .cur
626623 # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
627624 artifact = self .get_artifact (stage )
628- if self .cur == self . stage_name ( tester . Export ) :
625+ if self .cur == StageType . EXPORT :
629626 new_gm = ArmPassManager (get_tosa_spec (self .compile_spec )).transform_for_annotation_pipeline ( # type: ignore[arg-type]
630627 graph_module = artifact .graph_module
631628 )
0 commit comments