Skip to content

Commit

Permalink
Add utility class to support background server
Browse files Browse the repository at this point in the history
Summary: A subclass of the aepsych server with methods specifically to run the server in a background process. This will be used to ensured that even within the same main script, the server will run like an actual server and does not do anything sneaky like bypassing the async queue.

Test Plan: New test
  • Loading branch information
JasonKChow committed Feb 10, 2025
1 parent c975e6a commit bb6a1de
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 13 deletions.
4 changes: 2 additions & 2 deletions aepsych/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .server import AEPsychServer
from .server import AEPsychBackgroundServer, AEPsychServer

__all__ = ["AEPsychServer"]
__all__ = ["AEPsychServer", "AEPsychBackgroundServer"]
68 changes: 60 additions & 8 deletions aepsych/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
replay,
)
from aepsych.strategy import SequentialStrategy, Strategy
from multiprocess import Process

logger = utils_logging.getLogger()

Expand All @@ -48,11 +49,9 @@ def __init__(
host: str = "0.0.0.0",
port: int = 5555,
database_path: str = "./databases/default.db",
max_workers: Optional[int] = None,
):
self.host = host
self.port = port
self.max_workers = max_workers
self.clients_connected = 0
self.db: db.Database = db.Database(database_path)
self.is_performing_replay = False
Expand Down Expand Up @@ -278,11 +277,6 @@ def start_blocking(self) -> None:
process or machine."""
asyncio.run(self.serve())

def start_background(self):
"""Starts the server in a background thread. Used for scripts where the
client and server are in the same process."""
raise NotImplementedError

async def serve(self) -> None:
"""Serves the server on the set IP and port. This creates a coroutine
for asyncio to handle requests asyncronously.
Expand All @@ -291,7 +285,7 @@ async def serve(self) -> None:
self.handle_client, self.host, self.port
)
self.loop = asyncio.get_running_loop()
pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers)
pool = concurrent.futures.ThreadPoolExecutor()
self.loop.set_default_executor(pool)

async with self.server:
Expand Down Expand Up @@ -427,6 +421,64 @@ def __getstate__(self):
return state


class AEPsychBackgroundServer(AEPsychServer):
"""A class to handle the server in a background thread. Unlike the normal
AEPsychServer, this does not create the db right away until the server is
started. When starting this server, it'll be sent to another process, a db
will be initialized, then the server will be served. This server should then
be interacted with by the main thread via a client."""

def __init__(
self,
host: str = "0.0.0.0",
port: int = 5555,
database_path: str = "./databases/default.db",
):
self.host = host
self.port = port
self.database_path = database_path
self.clients_connected = 0
self.is_performing_replay = False
self.exit_server_loop = False
self._db_raw_record = None
self.skip_computations = False
self.background_process = None
self.strat_names = None
self.extensions = None
self._strats = []
self._parnames = []
self._configs = []
self._master_records = []
self.strat_id = -1
self.outcome_names = []

def _start_server(self) -> None:
self.db: db.Database = db.Database(self.database_path)
if self.db.is_update_required():
self.db.perform_updates()

super().start_blocking()

def start(self):
"""Starts the server in a background thread. Used by the client to start
the server for a client in another process or machine."""
self.background_process = Process(target=self._start_server, daemon=True)
self.background_process.start()

def stop(self):
"""Stops the server and closes the background process."""
self.exit_server_loop = True
self.background_process.terminate()
self.background_process.join()
self.background_process.close()
self.background_process = None

def __getstate__(self):
# Override parent's __getstate__ to not worry about the db
state = self.__dict__.copy()
return state


def parse_argument():
parser = argparse.ArgumentParser(description="AEPsych Server")
parser.add_argument(
Expand Down
109 changes: 106 additions & 3 deletions tests/server/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@

import asyncio
import json
import logging
import time
import unittest
import uuid
from pathlib import Path
from typing import Any, Dict

import aepsych.server as server
import aepsych.utils_logging as utils_logging
from aepsych.server.sockets import BAD_REQUEST

dummy_config = """
[common]
Expand Down Expand Up @@ -87,7 +86,7 @@ async def asyncSetUp(self):
self.port = 5555

# setup logger
server.logger = utils_logging.getLogger("unittests")
self.logger = utils_logging.getLogger("unittests")

# random datebase path name without dashes
database_path = self.database_path
Expand Down Expand Up @@ -523,5 +522,109 @@ async def _mock_client2(request: Dict[str, Any]) -> Any:
self.assertTrue(self.s.clients_connected == 2)


class BackgroundServerTestCase(unittest.IsolatedAsyncioTestCase):
@property
def database_path(self):
return "./{}_test_server.db".format(str(uuid.uuid4().hex))

async def asyncSetUp(self):
self.ip = "0.0.0.0"
self.port = 5555

# setup logger
self.logger = utils_logging.getLogger("unittests")

# random datebase path name without dashes
database_path = self.database_path
self.s = server.AEPsychBackgroundServer(
database_path=database_path, host=self.ip, port=self.port
)
self.db_name = database_path.split("/")[1]
self.db_path = database_path

# Writer will be made in tests
self.writer = None

async def asyncTearDown(self):
# Stops the client
if self.writer is not None:
self.writer.close()

time.sleep(0.1)

# cleanup the db
db_path = Path(self.db_path)
try:
print(db_path)
db_path.unlink()
except PermissionError as e:
print("Failed to deleted database: ", e)

async def test_background_server(self):
self.assertIsNone(self.s.background_process)
self.s.start()
self.assertTrue(self.s.background_process.is_alive())

# Give time for the server to start
time.sleep(5)

# Create a client
reader, self.writer = await asyncio.open_connection(self.ip, self.port)

async def _mock_client(request: Dict[str, Any]) -> Any:
self.writer.write(json.dumps(request).encode())
await self.writer.drain()

response = await reader.read(1024 * 512)
return response.decode()

setup_request = {
"type": "setup",
"version": "0.01",
"message": {"config_str": dummy_config},
}
ask_request = {"type": "ask", "message": ""}
tell_request = {
"type": "tell",
"message": {"config": {"x": [0.5]}, "outcome": 1},
"extra_info": {},
}

await _mock_client(setup_request)

expected_x = [0, 1, 2, 3]
expected_z = list(reversed(expected_x))
expected_y = [x % 2 for x in expected_x]
i = 0
while True:
response = await _mock_client(ask_request)
response = json.loads(response)
tell_request["message"]["config"]["x"] = [expected_x[i]]
tell_request["message"]["config"]["z"] = [expected_z[i]]
tell_request["message"]["outcome"] = expected_y[i]
tell_request["extra_info"]["e1"] = 1
tell_request["extra_info"]["e2"] = 2
i = i + 1
await _mock_client(tell_request)

if response["is_finished"]:
break

self.s.stop()
self.assertIsNone(self.s.background_process)

# Create a synchronous server to check db contents
s = server.AEPsychServer(database_path=self.db_path)
unique_id = s.db.get_master_records()[-1].unique_id
out_df = s.get_dataframe_from_replay(unique_id)
self.assertTrue((out_df.x == expected_x).all())
self.assertTrue((out_df.z == expected_z).all())
self.assertTrue((out_df.response == expected_y).all())
self.assertTrue((out_df.e1 == [1] * 4).all())
self.assertTrue((out_df.e2 == [2] * 4).all())
self.assertTrue("post_mean" in out_df.columns)
self.assertTrue("post_var" in out_df.columns)


if __name__ == "__main__":
unittest.main()

0 comments on commit bb6a1de

Please sign in to comment.