Skip to content

Commit cd6aa7b

Browse files
authored
[Hexagon] Move aot/graph_executor interactions into launcher (#10907)
* [Hexagon] Move aot/graph_executor interactions into launcher Follow-up from #10581, applying similar changes to the AOT and graph executor interactions. This moves the file management and upload/download from the unit tests into the launcher. * Added Session.test_executor to avoid duplication in graph/aot test. * Resolve lint errors * Moved link flags workaround out of session, into create_aot_shared * Separated Session.get_*_executor and Session.get_executor_from_factory * Updated to resolve lint error
1 parent 03bbf14 commit cd6aa7b

File tree

4 files changed

+193
-57
lines changed

4 files changed

+193
-57
lines changed

python/tvm/contrib/hexagon/build.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,7 @@ def get_aot_executor(self, module_name: Union[str, pathlib.Path], session: Sessi
266266
aot_module : AotModule
267267
Runtime AOT module that can be used to execute.
268268
"""
269-
aot_mod = self.load_module(module_name, session)
270-
return tvm.runtime.executor.AotModule(aot_mod["default"](session.device))
269+
return session.get_aot_executor(module_name)
271270

272271

273272
class HexagonLauncherAndroid(HexagonLauncherRPC):

python/tvm/contrib/hexagon/session.py

Lines changed: 171 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424

2525
import tvm
2626
from tvm import rpc as _rpc
27+
import tvm.contrib.hexagon as hexagon
28+
from tvm.relay.backend.executor_factory import (
29+
ExecutorFactoryModule,
30+
AOTExecutorFactoryModule,
31+
GraphExecutorFactoryModule,
32+
)
2733

2834

2935
class Session:
@@ -101,6 +107,9 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
101107
def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
102108
"""Load TVM module.
103109
110+
The session must be established (via __enter__) prior to
111+
calling this function.
112+
104113
Parameters
105114
----------
106115
module : Union[str, pathlib.Path, tvm.runtime.Module]
@@ -115,16 +124,16 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
115124
the file must already have been uploaded to the remote,
116125
and be placed in the remote workspace.
117126
118-
session : Session
119-
120-
Remote session. The session must be established (via __enter__)
121-
prior to calling this function.
122-
123127
Returns
124128
-------
125129
TVMModule :
126130
TVM module object.
127131
"""
132+
133+
assert (
134+
self.device is not None
135+
), "Hexagon session must be started using __enter__ prior to use"
136+
128137
if isinstance(module, tvm.runtime.Module):
129138
with tempfile.TemporaryDirectory() as temp_dir:
130139
temp_dir = pathlib.Path(temp_dir)
@@ -136,3 +145,160 @@ def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
136145

137146
assert isinstance(module, (str, pathlib.Path)), "Invalid path type:" + str(type(module))
138147
return self._rpc.get_function("tvm.hexagon.load_module")(str(module))
148+
149+
def get_graph_executor(
150+
self,
151+
graph_json: str,
152+
module_name: Union[str, pathlib.Path],
153+
):
154+
"""Create a local GraphModule which consumes a remote libmod.
155+
156+
The session must be established (via __enter__) prior to
157+
calling this function.
158+
159+
Parameters
160+
----------
161+
162+
module_name : Union[str, pathlib.Path]
163+
164+
The remote module filename, following the same restrictions
165+
as `load_module`.
166+
167+
graph_json : str
168+
169+
The string with the graph JSON.
170+
171+
Returns
172+
-------
173+
GraphModule :
174+
Runtime graph module that can be used to execute the graph.
175+
176+
"""
177+
178+
graph_mod = self.load_module(module_name)
179+
return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device)
180+
181+
def get_aot_executor(
182+
self,
183+
module_name: Union[str, pathlib.Path],
184+
):
185+
"""Create a local GraphModule which consumes a remote libmod.
186+
187+
The session must be established (via __enter__) prior to
188+
calling this function.
189+
190+
Parameters
191+
----------
192+
193+
module_name : Union[str, pathlib.Path]
194+
195+
The remote module filename, following the same restrictions
196+
as `load_module`.
197+
198+
Returns
199+
-------
200+
GraphModule :
201+
Runtime graph module that can be used to execute the graph.
202+
203+
"""
204+
205+
aot_mod = self.load_module(module_name)
206+
return tvm.runtime.executor.AotModule(aot_mod["default"](self.device))
207+
208+
def get_executor_from_factory(self, module: ExecutorFactoryModule):
209+
"""Create a local GraphModule which consumes a remote libmod.
210+
211+
Parameters
212+
----------
213+
214+
module : ExecutorFactoryModule
215+
216+
The module to upload to the remote
217+
session and load.
218+
"""
219+
if isinstance(module, AOTExecutorFactoryModule):
220+
return self._aot_executor_from_factory(module)
221+
if isinstance(module, GraphExecutorFactoryModule):
222+
return self._graph_executor_from_factory(module)
223+
224+
raise TypeError(f"Unsupported executor type: {type(module)}")
225+
226+
def _graph_executor_from_factory(
227+
self,
228+
module: Union[str, pathlib.Path, GraphExecutorFactoryModule],
229+
):
230+
"""Create a local GraphModule which consumes a remote libmod.
231+
232+
The session must be established (via __enter__) prior to
233+
calling this function.
234+
235+
Parameters
236+
----------
237+
238+
module : GraphExecutorFactoryModule
239+
240+
The graph executor module to upload to the remote and load.
241+
This will typically be the output of `tvm.relay.build`,
242+
when passing `executor=Executor("graph")`.
243+
244+
Returns
245+
-------
246+
GraphModule :
247+
Runtime graph module that can be used to execute the graph.
248+
249+
"""
250+
251+
graph_json = module.get_graph_json()
252+
graph_mod = self.load_module(module.get_lib())
253+
254+
return tvm.contrib.graph_executor.create(graph_json, graph_mod, self.device)
255+
256+
def _aot_executor_from_factory(
257+
self,
258+
module: Union[str, pathlib.Path, AOTExecutorFactoryModule],
259+
):
260+
"""Create a local GraphModule which consumes a remote libmod.
261+
262+
The session must be established (via __enter__) prior to
263+
calling this function.
264+
265+
Parameters
266+
----------
267+
268+
module : AOTExecutorFactoryModule
269+
270+
The graph executor module to upload to the remote and load.
271+
This will typically be the output of `tvm.relay.build`,
272+
when passing `executor=Executor("aot")`.
273+
274+
Returns
275+
-------
276+
GraphModule :
277+
Runtime graph module that can be used to execute the graph.
278+
279+
"""
280+
281+
hexagon_arch = set(
282+
target.mcpu.replace("hexagon", "")
283+
for target in module.target.values()
284+
if "hexagon" in target.keys
285+
)
286+
assert hexagon_arch, "No hexagon target architecture found"
287+
assert len(hexagon_arch) == 1, f"Inconsistent hexagon architecture found, {hexagon_arch}"
288+
hexagon_arch = hexagon_arch.pop()
289+
290+
with tempfile.TemporaryDirectory() as temp_dir:
291+
temp_dir = pathlib.Path(temp_dir)
292+
binary_name = "test_binary.so"
293+
binary_path = temp_dir / binary_name
294+
295+
module.export_library(
296+
str(binary_path),
297+
fcompile=hexagon.create_aot_shared,
298+
hexagon_arch=hexagon_arch,
299+
)
300+
301+
self.upload(binary_path, binary_name)
302+
303+
aot_mod = self.load_module(binary_name)
304+
return tvm.runtime.executor.AotModule(aot_mod["default"](self.device))

python/tvm/contrib/hexagon/tools.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,19 @@ def create_aot_shared(so_name: Union[str, pathlib.Path], files, hexagon_arch: st
160160
+ "HEXAGON_SDK_PATH in your environment."
161161
)
162162

163+
# The AOT C codegen uses TVM runtime functions
164+
# (e.g. TVMBackendAllocWorkspace) directly. On Hexagon these calls
165+
# should be made using functions pointers provided as __TVM*
166+
# variables in the provided context. This workaround allows the
167+
# the TVM runtime symbols to be visible to the compiled shared
168+
# library.
169+
#
170+
# This workaround can be removed when AOT codegen can be done with
171+
# LLVM codegen.
172+
workaround_link_flags = os.environ.get("HEXAGON_SHARED_LINK_FLAGS")
173+
if workaround_link_flags:
174+
options.extend(workaround_link_flags.split())
175+
163176
tvm_dir = pathlib.Path(os.path.dirname(os.path.realpath(__file__))) / ".." / ".." / ".." / ".."
164177
compute_arch = f"compute{hexagon_arch}"
165178
compile_options = [

0 commit comments

Comments
 (0)