2424 process_output ,
2525 process_placeholder ,
2626)
27- from executorch .backends .arm .tosa .specification import get_tosa_spec
27+ from executorch .backends .arm .tosa .compile_spec import TosaCompileSpec
2828from executorch .exir .backend .backend_details import BackendDetails , PreprocessResult
2929from executorch .exir .backend .compile_spec_schema import CompileSpec
3030from torch .export .exported_program import ExportedProgram
@@ -80,38 +80,24 @@ class TOSABackend(BackendDetails):
8080 """
8181
8282 @staticmethod
83- def preprocess ( # noqa: C901
83+ def preprocess (edge_program : ExportedProgram , compile_specs : List [CompileSpec ]):
84+ return TOSABackend ._preprocess (
85+ edge_program , TosaCompileSpec .from_list (compile_specs )
86+ )
87+
88+ @staticmethod
89+ def _preprocess ( # noqa: C901
8490 edge_program : ExportedProgram ,
85- compile_spec : List [ CompileSpec ] ,
91+ compile_spec : TosaCompileSpec ,
8692 ) -> PreprocessResult :
8793 # if a debug/test build capture output files from TOSA stage
88- artifact_path = None
89- output_format = ""
90- compile_flags = []
91- dump_debug_info = None
92- for spec in compile_spec :
93- if spec .key == "debug_artifact_path" :
94- artifact_path = spec .value .decode ()
95- if spec .key == "output_format" :
96- output_format = spec .value .decode ()
97- if spec .key == "compile_flags" :
98- compile_flags .append (spec .value .decode ())
99- if spec .key == "dump_debug_info" :
100- dump_debug_info = spec .value .decode ()
101-
102- # Check that the output format is set correctly in the compile spec
103- if output_format != "tosa" :
104- raise ValueError (f'Invalid output format { output_format } , must be "tosa"' )
94+ artifact_path = compile_spec .get_intermediate_path ()
95+ tosa_spec = compile_spec .tosa_spec
96+ dump_debug_info = compile_spec .tosa_debug_mode
10597
10698 # Assign to every node external id
10799 node_2_id = _annotate_external_ids (edge_program .graph )
108100
109- tosa_spec = get_tosa_spec (compile_spec )
110- if tosa_spec is None :
111- raise ValueError (
112- "TOSA backend needs a TOSA version specified in the CompileSpec"
113- )
114-
115101 logger .info (f"Converting ExportedProgram to TOSA: { tosa_spec } " )
116102
117103 # Converted output for this subgraph, serializer needs path early as it emits
@@ -132,7 +118,7 @@ def preprocess( # noqa: C901
132118
133119 debug_hook = None
134120 if dump_debug_info is not None :
135- debug_hook = DebugHook (ArmCompileSpec . DebugMode [ dump_debug_info ] )
121+ debug_hook = DebugHook (dump_debug_info )
136122
137123 # TODO: Fix the need to lazily import this.
138124 from executorch .backends .arm .operators .node_visitor import get_node_visitors
@@ -204,8 +190,8 @@ def _sort_key(t: Node) -> int:
204190
205191 @staticmethod
206192 def filter_tosa_compile_specs (
207- compile_spec : List [ CompileSpec ] ,
208- ) -> List [ CompileSpec ] :
193+ compile_spec : ArmCompileSpec ,
194+ ) -> TosaCompileSpec :
209195 """
210196 Filter out the CompileSpec elements relevant for the TOSA backend.
211197 This is needed to compose a backend targetting hardware IP with the
@@ -214,17 +200,9 @@ def filter_tosa_compile_specs(
214200 flatbuffer can then be consumed by the backend targetting specific
215201 hardware.
216202 """
217- tosa_compile_spec = []
218- tosa_compile_spec .append (CompileSpec ("output_format" , "tosa" .encode ()))
219-
220- # Copy everything that's TOSA generic
221- tosa_backend_compile_spec_keys = [
222- "tosa_spec" ,
223- "debug_artifact_path" ,
224- ]
225203
226- for spec in compile_spec :
227- if spec . key in tosa_backend_compile_spec_keys :
228- tosa_compile_spec . append ( CompileSpec ( spec . key , spec . value ) )
229-
230- return tosa_compile_spec
204+ new_compile_spec = TosaCompileSpec . __new__ ( TosaCompileSpec )
205+ new_compile_spec . _set_compile_specs (
206+ compile_spec . tosa_spec , [], compile_spec . get_intermediate_path ( )
207+ )
208+ return new_compile_spec
0 commit comments