Skip to content

Commit 29762b9

Browse files
committed
adding pytest_plugin to python so other repos can access
1 parent fe1090e commit 29762b9

File tree

2 files changed

+240
-208
lines changed

2 files changed

+240
-208
lines changed
Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# pylint: disable=invalid-name,redefined-outer-name
19+
""" Hexagon testing fixtures used to deduce testing argument
20+
values from testing parameters """
21+
22+
import os
23+
import random
24+
from typing import Optional, Union
25+
26+
import pytest
27+
28+
import tvm
29+
import tvm.rpc.tracker
30+
from tvm.contrib.hexagon.build import HexagonLauncher, HexagonLauncherRPC
31+
from tvm.contrib.hexagon.session import Session
32+
33+
HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
34+
TVM_TRACKER_HOST = "TVM_TRACKER_HOST"
35+
TVM_TRACKER_PORT = "TVM_TRACKER_PORT"
36+
ANDROID_REMOTE_DIR = "ANDROID_REMOTE_DIR"
37+
ANDROID_SERIAL_NUMBER = "ANDROID_SERIAL_NUMBER"
38+
ADB_SERVER_SOCKET = "ADB_SERVER_SOCKET"
39+
40+
41+
@tvm.testing.fixture
42+
def shape_nhwc(batch, in_channel, in_size):
43+
return (batch, in_size, in_size, in_channel)
44+
45+
46+
def _compose(args, decs):
47+
"""Helper to apply multiple markers"""
48+
if len(args) > 0:
49+
func = args[0]
50+
for dec in reversed(decs):
51+
func = dec(func)
52+
return func
53+
return decs
54+
55+
56+
def requires_hexagon_toolchain(*args):
57+
_requires_hexagon_toolchain = [
58+
pytest.mark.skipif(
59+
os.environ.get(HEXAGON_TOOLCHAIN) is None,
60+
reason=f"Missing environment variable {HEXAGON_TOOLCHAIN}.",
61+
),
62+
]
63+
64+
return _compose(args, _requires_hexagon_toolchain)
65+
66+
67+
@tvm.testing.fixture
68+
def android_serial_number() -> Optional[str]:
69+
serial = os.getenv(ANDROID_SERIAL_NUMBER, default="")
70+
# Setting ANDROID_SERIAL_NUMBER to an empty string should be
71+
# equivalent to having it unset.
72+
if not serial.strip():
73+
serial = None
74+
return serial
75+
76+
77+
# NOTE on server ports:
78+
# These tests use different port numbers for the RPC server (7070 + ...).
79+
# The reason is that an RPC session cannot be gracefully closed without
80+
# triggering TIME_WAIT state on the server socket. This prevents another
81+
# server to bind to the same port until the wait time elapses.
82+
83+
LISTEN_PORT_MIN = 2000 # Well above the privileged ports (1024 or lower)
84+
LISTEN_PORT_MAX = 9000 # Below the search range end (port_end=9199) of RPC server
85+
PREVIOUS_PORT = None
86+
87+
88+
def get_free_port() -> int:
89+
"""Return the next port that is available to listen on"""
90+
global PREVIOUS_PORT
91+
if PREVIOUS_PORT is None:
92+
port = random.randint(LISTEN_PORT_MIN, LISTEN_PORT_MAX)
93+
else:
94+
port = PREVIOUS_PORT + 1
95+
96+
while tvm.contrib.hexagon.build._is_port_in_use(port):
97+
port = port + 1 if port < LISTEN_PORT_MAX else LISTEN_PORT_MIN
98+
99+
PREVIOUS_PORT = port
100+
return port
101+
102+
103+
@pytest.fixture(scope="session")
104+
def _tracker_info() -> Union[str, int]:
105+
env_tracker_host = os.getenv(TVM_TRACKER_HOST, default="")
106+
env_tracker_port = os.getenv(TVM_TRACKER_PORT, default="")
107+
108+
if env_tracker_host or env_tracker_port:
109+
# A tracker is already running, and we should connect to it
110+
# when running tests.
111+
assert env_tracker_host, "TVM_TRACKER_PORT is defined, but TVM_TRACKER_HOST is not"
112+
assert env_tracker_port, "TVM_TRACKER_HOST is defined, but TVM_TRACKER_PORT is not"
113+
env_tracker_port = int(env_tracker_port)
114+
115+
try:
116+
tvm.rpc.connect_tracker(env_tracker_host, env_tracker_port)
117+
except RuntimeError as exc:
118+
message = (
119+
"Could not connect to external tracker "
120+
"specified by $TVM_TRACKER_HOST and $TVM_TRACKER_PORT "
121+
f"({env_tracker_host}:{env_tracker_port})"
122+
)
123+
raise RuntimeError(message) from exc
124+
125+
yield (env_tracker_host, env_tracker_port)
126+
127+
else:
128+
# No tracker is provided to the tests, so we should start one
129+
# for the tests to use.
130+
tracker = tvm.rpc.tracker.Tracker("127.0.0.1", get_free_port())
131+
try:
132+
yield (tracker.host, tracker.port)
133+
finally:
134+
tracker.terminate()
135+
136+
137+
@pytest.fixture(scope="session")
138+
def tvm_tracker_host(_tracker_info) -> str:
139+
host, _ = _tracker_info
140+
return host
141+
142+
143+
@pytest.fixture(scope="session")
144+
def tvm_tracker_port(_tracker_info) -> int:
145+
_, port = _tracker_info
146+
return port
147+
148+
149+
@tvm.testing.fixture
150+
def rpc_server_port() -> int:
151+
return get_free_port()
152+
153+
154+
@tvm.testing.fixture
155+
def adb_server_socket() -> str:
156+
return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037")
157+
158+
159+
@tvm.testing.fixture
160+
def hexagon_launcher(
161+
request, android_serial_number, rpc_server_port, adb_server_socket
162+
) -> HexagonLauncherRPC:
163+
"""Initials and returns hexagon launcher if ANDROID_SERIAL_NUMBER is defined"""
164+
if android_serial_number is None:
165+
yield None
166+
else:
167+
# Requesting these fixtures sets up a local tracker, if one
168+
# hasn't been provided to us. Delaying the evaluation of
169+
# these fixtures avoids starting a tracker unless necessary.
170+
tvm_tracker_host = request.getfixturevalue("tvm_tracker_host")
171+
tvm_tracker_port = request.getfixturevalue("tvm_tracker_port")
172+
173+
rpc_info = {
174+
"rpc_tracker_host": tvm_tracker_host,
175+
"rpc_tracker_port": tvm_tracker_port,
176+
"rpc_server_port": rpc_server_port,
177+
"adb_server_socket": adb_server_socket,
178+
}
179+
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
180+
launcher.start_server()
181+
try:
182+
yield launcher
183+
finally:
184+
launcher.stop_server()
185+
186+
187+
@tvm.testing.fixture
188+
def hexagon_session(hexagon_launcher) -> Session:
189+
if hexagon_launcher is None:
190+
yield None
191+
else:
192+
with hexagon_launcher.start_session() as session:
193+
yield session
194+
195+
196+
# If the execution aborts while an RPC server is running, the python
197+
# code that is supposed to shut it down will never execute. This will
198+
# keep pytest from terminating (indefinitely), so add a cleanup
199+
# fixture to terminate any still-running servers.
200+
@pytest.fixture(scope="session", autouse=True)
201+
def terminate_rpc_servers():
202+
# Since this is a fixture that runs regardless of whether the
203+
# execution happens on simulator or on target, make sure the
204+
# yield happens every time.
205+
serial = os.environ.get(ANDROID_SERIAL_NUMBER)
206+
yield []
207+
if serial == "simulator":
208+
os.system("ps ax | grep tvm_rpc_x86 | awk '{print $1}' | xargs kill")
209+
210+
211+
aot_host_target = tvm.testing.parameter(
212+
"c",
213+
"llvm -keys=hexagon -link-params=0 "
214+
"-mattr=+hvxv68,+hvx-length128b,+hvx-qfloat,-hvx-ieee-fp "
215+
"-mcpu=hexagonv68 -mtriple=hexagon",
216+
)
217+
218+
219+
@tvm.testing.fixture
220+
def aot_target(aot_host_target):
221+
if aot_host_target == "c":
222+
yield tvm.target.hexagon("v68")
223+
elif aot_host_target.startswith("llvm"):
224+
yield aot_host_target
225+
else:
226+
assert False, "Incorrect AoT host target: {aot_host_target}. Options are [c, llvm]."
227+
228+
229+
def pytest_addoption(parser):
230+
parser.addoption("--gtest_args", action="store", default="")
231+
232+
233+
def pytest_generate_tests(metafunc):
234+
option_value = metafunc.config.option.gtest_args
235+
if "gtest_args" in metafunc.fixturenames and option_value is not None:
236+
metafunc.parametrize("gtest_args", [option_value])

0 commit comments

Comments
 (0)