2424
2525import tvm
2626from tvm import rpc as _rpc
27+ import tvm .contrib .hexagon as hexagon
28+ from tvm .relay .backend .executor_factory import (
29+ ExecutorFactoryModule ,
30+ AOTExecutorFactoryModule ,
31+ GraphExecutorFactoryModule ,
32+ )
2733
2834
2935class Session :
@@ -101,6 +107,9 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
101107 def load_module (self , module : Union [str , pathlib .Path , tvm .runtime .Module ]):
102108 """Load TVM module.
103109
110+ The session must be established (via __enter__) prior to
111+ calling this function.
112+
104113 Parameters
105114 ----------
106115 module : Union[str, pathlib.Path, tvm.runtime.Module]
@@ -115,16 +124,16 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
115124 the file must already have been uploaded to the remote,
116125 and be placed in the remote workspace.
117126
118- session : Session
119-
120- Remote session. The session must be established (via __enter__)
121- prior to calling this function.
122-
123127 Returns
124128 -------
125129 TVMModule :
126130 TVM module object.
127131 """
132+
133+ assert (
134+ self .device is not None
135+ ), "Hexagon session must be started using __enter__ prior to use"
136+
128137 if isinstance (module , tvm .runtime .Module ):
129138 with tempfile .TemporaryDirectory () as temp_dir :
130139 temp_dir = pathlib .Path (temp_dir )
@@ -136,3 +145,160 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
136145
137146 assert isinstance (module , (str , pathlib .Path )), "Invalid path type:" + str (type (module ))
138147 return self ._rpc .get_function ("tvm.hexagon.load_module" )(str (module ))
148+
149+ def get_graph_executor (
150+ self ,
151+ graph_json : str ,
152+ module_name : Union [str , pathlib .Path ],
153+ ):
154+ """Create a local GraphModule which consumes a remote libmod.
155+
156+ The session must be established (via __enter__) prior to
157+ calling this function.
158+
159+ Parameters
160+ ----------
161+
162+ module_name : Union[str, pathlib.Path]
163+
164+ The remote module filename, following the same restrictions
165+ as `load_module`.
166+
167+ graph_json : str
168+
169+ The string with the graph JSON.
170+
171+ Returns
172+ -------
173+ GraphModule :
174+ Runtime graph module that can be used to execute the graph.
175+
176+ """
177+
178+ graph_mod = self .load_module (module_name )
179+ return tvm .contrib .graph_executor .create (graph_json , graph_mod , self .device )
180+
181+ def get_aot_executor (
182+ self ,
183+ module_name : Union [str , pathlib .Path ],
184+ ):
185+ """Create a local GraphModule which consumes a remote libmod.
186+
187+ The session must be established (via __enter__) prior to
188+ calling this function.
189+
190+ Parameters
191+ ----------
192+
193+ module_name : Union[str, pathlib.Path]
194+
195+ The remote module filename, following the same restrictions
196+ as `load_module`.
197+
198+ Returns
199+ -------
200+ GraphModule :
201+ Runtime graph module that can be used to execute the graph.
202+
203+ """
204+
205+ aot_mod = self .load_module (module_name )
206+ return tvm .runtime .executor .AotModule (aot_mod ["default" ](self .device ))
207+
208+ def get_executor_from_factory (self , module : ExecutorFactoryModule ):
209+ """Create a local GraphModule which consumes a remote libmod.
210+
211+ Parameters
212+ ----------
213+
214+ module : ExecutorFactoryModule
215+
216+ The module to upload to the remote
217+ session and load.
218+ """
219+ if isinstance (module , AOTExecutorFactoryModule ):
220+ return self ._aot_executor_from_factory (module )
221+ if isinstance (module , GraphExecutorFactoryModule ):
222+ return self ._graph_executor_from_factory (module )
223+
224+ raise TypeError (f"Unsupported executor type: { type (module )} " )
225+
226+ def _graph_executor_from_factory (
227+ self ,
228+ module : Union [str , pathlib .Path , GraphExecutorFactoryModule ],
229+ ):
230+ """Create a local GraphModule which consumes a remote libmod.
231+
232+ The session must be established (via __enter__) prior to
233+ calling this function.
234+
235+ Parameters
236+ ----------
237+
238+ module : GraphExecutorFactoryModule
239+
240+ The graph executor module to upload to the remote and load.
241+ This will typically be the output of `tvm.relay.build`,
242+ when passing `executor=Executor("graph")`.
243+
244+ Returns
245+ -------
246+ GraphModule :
247+ Runtime graph module that can be used to execute the graph.
248+
249+ """
250+
251+ graph_json = module .get_graph_json ()
252+ graph_mod = self .load_module (module .get_lib ())
253+
254+ return tvm .contrib .graph_executor .create (graph_json , graph_mod , self .device )
255+
256+ def _aot_executor_from_factory (
257+ self ,
258+ module : Union [str , pathlib .Path , AOTExecutorFactoryModule ],
259+ ):
260+ """Create a local GraphModule which consumes a remote libmod.
261+
262+ The session must be established (via __enter__) prior to
263+ calling this function.
264+
265+ Parameters
266+ ----------
267+
268+ module : AOTExecutorFactoryModule
269+
270+ The graph executor module to upload to the remote and load.
271+ This will typically be the output of `tvm.relay.build`,
272+ when passing `executor=Executor("aot")`.
273+
274+ Returns
275+ -------
276+ GraphModule :
277+ Runtime graph module that can be used to execute the graph.
278+
279+ """
280+
281+ hexagon_arch = set (
282+ target .mcpu .replace ("hexagon" , "" )
283+ for target in module .target .values ()
284+ if "hexagon" in target .keys
285+ )
286+ assert hexagon_arch , "No hexagon target architecture found"
287+ assert len (hexagon_arch ) == 1 , f"Inconsistent hexagon architecture found, { hexagon_arch } "
288+ hexagon_arch = hexagon_arch .pop ()
289+
290+ with tempfile .TemporaryDirectory () as temp_dir :
291+ temp_dir = pathlib .Path (temp_dir )
292+ binary_name = "test_binary.so"
293+ binary_path = temp_dir / binary_name
294+
295+ module .export_library (
296+ str (binary_path ),
297+ fcompile = hexagon .create_aot_shared ,
298+ hexagon_arch = hexagon_arch ,
299+ )
300+
301+ self .upload (binary_path , binary_name )
302+
303+ aot_mod = self .load_module (binary_name )
304+ return tvm .runtime .executor .AotModule (aot_mod ["default" ](self .device ))
0 commit comments