Move "fit" ProcessPool out of module top-level #647

Open · wants to merge 5 commits into main
3 changes: 3 additions & 0 deletions httpstan/app.py
@@ -7,6 +7,7 @@

import aiohttp.web

import httpstan.pools
import httpstan.routes

try:
@@ -41,5 +42,7 @@ def make_app() -> aiohttp.web.Application:
httpstan.routes.setup_routes(app)
# startup and shutdown tasks
app["operations"] = {}
httpstan.pools.setup_pools(app)
app.on_cleanup.append(_warn_unfinished_operations)
app.on_cleanup.append(httpstan.pools.shutdown_pools)
return app
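For context, a minimal sketch of when the new hooks run, using the standard aiohttp runner API (the `main` helper below is hypothetical; only `make_app` comes from httpstan): `setup_pools` is called inside `make_app`, and `shutdown_pools` is awaited when the application is cleaned up.

```python
import asyncio

import aiohttp.web

from httpstan.app import make_app


async def main() -> None:
    app = make_app()  # setup_pools has already stored the factory in app["create_fit_executor"]
    runner = aiohttp.web.AppRunner(app)
    await runner.setup()
    site = aiohttp.web.TCPSite(runner, "127.0.0.1", 8080)
    await site.start()
    ...  # handle requests here
    # cleanup() fires app.on_cleanup, which awaits httpstan.pools.shutdown_pools
    await runner.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
```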
47 changes: 47 additions & 0 deletions httpstan/pools.py
@@ -0,0 +1,47 @@
import concurrent.futures
import multiprocessing as mp
import signal

import aiohttp.web


def init_call_worker() -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN) # ignore KeyboardInterrupt


def setup_pools(app: aiohttp.web.Application) -> None:
"""Create any Process or Thread Pools needed by the application

The pools are not created immediately, in case a feature that needs them is
never used; they are created lazily on first request. That's why each pool is
represented by a factory function instead of the pool executor object itself.

"""
fit_executor = None

def create_fit_executor(shutdown=False):
nonlocal fit_executor

if shutdown:
if fit_executor is None:
return

fit_executor.shutdown()
return

if fit_executor is not None:
return fit_executor

# Use `get_context` to get a package-specific multiprocessing context.
# See "Contexts and start methods" in the `multiprocessing` docs for details.
fit_executor = concurrent.futures.ProcessPoolExecutor(
mp_context=mp.get_context("fork"), initializer=init_call_worker
)

return fit_executor

app["create_fit_executor"] = create_fit_executor


async def shutdown_pools(app: aiohttp.web.Application) -> None:
app["create_fit_executor"](shutdown=True)
14 changes: 2 additions & 12 deletions httpstan/services_stub.py
@@ -8,32 +8,21 @@
"""
import asyncio
import collections
import concurrent.futures
import functools
import io
import logging
import multiprocessing as mp
import os
import select
import signal
import socket
import tempfile
import typing
import zlib

import httpstan.cache
import httpstan.models
import httpstan.services.arguments as arguments
from httpstan.config import HTTPSTAN_DEBUG
from httpstan.services import arguments


# Use `get_context` to get a package-specific multiprocessing context.
# See "Contexts and start methods" in the `multiprocessing` docs for details.
def init_worker() -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN) # ignore KeyboardInterrupt


executor = concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("fork"), initializer=init_worker)
logger = logging.getLogger("httpstan")


@@ -59,6 +48,7 @@ async def call(
function_name: str,
model_name: str,
fit_name: str,
executor,
logger_callback: typing.Optional[typing.Callable] = None,
**kwargs: dict,
) -> None:
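The remainder of `call` is not shown in this diff. Assuming it still hands the blocking services function to a process pool, the injected `executor` replaces the old module-level one roughly as follows (everything except the `executor` parameter is illustrative):

```python
import asyncio
import concurrent.futures
import functools


def run_blocking_services_call(**kwargs: dict) -> None:
    """Placeholder for the CPU-bound work executed in a worker process."""


async def call_sketch(executor: concurrent.futures.ProcessPoolExecutor, **kwargs: dict) -> None:
    loop = asyncio.get_running_loop()
    # The caller-supplied executor is used instead of a pool created at import time.
    await loop.run_in_executor(executor, functools.partial(run_blocking_services_call, **kwargs))
```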
33 changes: 19 additions & 14 deletions httpstan/views.py
@@ -3,13 +3,13 @@
Handlers are separated from the endpoint names. Endpoints are defined in
`httpstan.routes`.
"""
import asyncio
import functools
import gzip
import http
import logging
import re
import traceback
from types import CoroutineType
from typing import Optional, Sequence, cast

import aiohttp.web
@@ -364,7 +364,7 @@ async def handle_create_fit(request: aiohttp.web.Request) -> aiohttp.web.Response:
request.app["operations"][operation_name] = operation_dict
return aiohttp.web.json_response(operation_dict, status=201)

def _services_call_done(operation: dict, future: asyncio.Future) -> None:
async def _services_call_done(operation: dict, coroutine: CoroutineType) -> None:
"""Called when services call (i.e., an operation) is done.

This needs to handle both successful and exception-raising calls.
@@ -374,11 +374,12 @@ def _services_call_done(operation: dict, future: asyncio.Future) -> None:
future: Finished future

"""
# either the call succeeded or it raised an exception.
operation["done"] = True

exc = future.exception()
if exc:
try:
await coroutine
logger.info("Operation `%s` finished.", operation["name"])
operation["result"] = schemas.Fit().load(operation["metadata"]["fit"])
except Exception as exc:
# e.g., "hmc_nuts_diag_e_adapt_wrapper() got an unexpected keyword argument, ..."
# e.g., dimension errors in variable declarations
# e.g., initialization failed
@@ -394,9 +395,9 @@ def _services_call_done(operation: dict, future: asyncio.Future) -> None:
httpstan.cache.delete_fit(operation["metadata"]["fit"]["name"])
except KeyError:
pass
else:
logger.info(f"Operation `{operation['name']}` finished.")
operation["result"] = schemas.Fit().load(operation["metadata"]["fit"])
finally:
# either the call succeeded or it raised an exception.
operation["done"] = True

operation_name = f'operations/{name.split("/")[-1]}'
operation_dict = schemas.Operation().load(
@@ -414,12 +415,16 @@ def logger_callback(operation: dict, message: bytes) -> None:
operation["metadata"]["progress"] = iteration_info_re.findall(message).pop().decode()

logger_callback_partial = functools.partial(logger_callback, operation_dict)
task = asyncio.create_task(
services_stub.call(
function, model_name, operation_dict["metadata"]["fit"]["name"], logger_callback_partial, **args
)

call = services_stub.call(
function,
model_name,
operation_dict["metadata"]["fit"]["name"],
request.app["create_fit_executor"](),
logger_callback_partial,
**args,
)
task.add_done_callback(functools.partial(_services_call_done, operation_dict))
await _services_call_done(operation_dict, call)
request.app["operations"][operation_name] = operation_dict
return aiohttp.web.json_response(operation_dict, status=201)
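As written in this diff, the handler awaits `_services_call_done` before returning, so the 201 response is sent only after the coroutine completes. If the earlier fire-and-forget behaviour were wanted, a hypothetical variant (not part of this pull request) could still schedule the new coroutine-based wrapper as a background task, reusing the names from the handler above:

```python
# Hypothetical alternative: run the fit in the background and return the operation immediately.
task = asyncio.create_task(_services_call_done(operation_dict, call))  # keep a reference so the task is not garbage-collected
request.app["operations"][operation_name] = operation_dict
return aiohttp.web.json_response(operation_dict, status=201)
```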
