Skip to content

Commit

Permalink
Close SQLA session after every REST API request
Browse files Browse the repository at this point in the history
This is needed due to the server running in threaded mode, i.e.,
creating a new thread for each incoming request. This concept is great
for handling many requests, but crashes when used together with AiiDA's
global singleton SQLA session used, no matter the backend of the profile
by the `QueryBuilder`.

Specifically, this leads to issues with the SQLA QueuePool, since the
connections are not properly released when a thread is closed. This
leads to unintended QueuePool overflow.

This fix wraps all HTTP method requests and makes sure to close the
current thread's SQLA session after the request as been completely
handled.

Use Flask-RESTful's integrated `Resource` attribute `method_decorators`
to apply `close_session` wrapper to all and any HTTP request that may
be requested of AiiDA's `BaseResource` (and its sub-classes).

Additionally, remove the `__init__` function overwritten in
`Node(BaseResource)`, since it is redundant, and the attributes `tclass`
is not relevant with v4 (AiiDA v1.0.0 and above), but was never removed.
It should have been removed when moving to v4 in 4ff2829.

Concerning the added tests: the timeout needs to be set for Python 3.5
in order to stop the http socket and properly raise (and escape out of
an infinite loop). The `capfd` fixture must be used, otherwise the
exception cannot be properly captured.

The tests were simplified into the pytest scheme with ideas from
@sphuber and @greschd.
  • Loading branch information
CasperWA authored and sphuber committed May 11, 2020
1 parent 924f02e commit b9d4bbe
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 33 deletions.
2 changes: 1 addition & 1 deletion aiida/manage/tests/unittest_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def run(self, suite, backend=None, profile_name=None):
import warnings
from aiida.common.warnings import AiidaDeprecationWarning
warnings.warn( # pylint: disable=no-member
'Please use "pytest" for testing AiiDA plugins. Support for "unittest" be removed in `v2.0.0`',
'Please use "pytest" for testing AiiDA plugins. Support for "unittest" will be removed in `v2.0.0`',
AiidaDeprecationWarning
)

Expand Down
5 changes: 2 additions & 3 deletions aiida/restapi/common/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
'process.calculation%.calcfunction.%|aiida.calculations:arithmetic.add' # More than one operator in segment
"""

import collections
from collections.abc import MutableMapping

from aiida.common.escaping import escape_for_sql_like

Expand Down Expand Up @@ -163,7 +162,7 @@ def load_entry_point_from_full_type(full_type):
raise EntryPointError('entry point of the given full type cannot be loaded')


class Namespace(collections.MutableMapping):
class Namespace(MutableMapping):
"""Namespace that can be used to map the node class hierarchy."""

namespace_separator = '.'
Expand Down
18 changes: 17 additions & 1 deletion aiida/restapi/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
# For further information please visit http://www.aiida.net #
###########################################################################
""" Util methods """
import urllib.parse
from datetime import datetime, timedelta
import urllib.parse

from flask import jsonify
from flask.json import JSONEncoder
from wrapt import decorator

from aiida.common.exceptions import InputValidationError, ValidationError
from aiida.manage.manager import get_manager
from aiida.restapi.common.exceptions import RestValidationError, \
RestInputValidationError

Expand Down Expand Up @@ -845,3 +847,17 @@ def list_routes():
output.append(line)

return sorted(set(output))


@decorator
def close_session(wrapped, _, args, kwargs):
"""Close AiiDA SQLAlchemy (QueryBuilder) session
This decorator can be used for router endpoints to close the SQLAlchemy global scoped session after the response
has been created. This is needed, since the QueryBuilder uses a SQLAlchemy global scoped session no matter the
profile's database backend.
"""
try:
return wrapped(*args, **kwargs)
finally:
get_manager().get_backend().get_session().close()
26 changes: 9 additions & 17 deletions aiida/restapi/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@

from aiida.common.lang import classproperty
from aiida.restapi.common.exceptions import RestInputValidationError
from aiida.restapi.common.utils import Utils
from aiida.restapi.common.utils import Utils, close_session


class ServerInfo(Resource):
# pylint: disable=fixme
"""Endpointd to return general server info"""
"""Endpoint to return general server info"""

def __init__(self, **kwargs):
# Configure utils
Expand Down Expand Up @@ -97,6 +96,8 @@ class BaseResource(Resource):
_translator_class = BaseTranslator
_parse_pk_uuid = None # Flag to tell the path parser whether to expect a pk or a uuid pattern

method_decorators = [close_session] # Close SQLA session after any method call

## TODO add the caching support. I cache total count, results, and possibly

def __init__(self, **kwargs):
Expand All @@ -106,11 +107,13 @@ def __init__(self, **kwargs):
utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT')
self.utils_confs = {k: kwargs[k] for k in utils_conf_keys if k in kwargs}
self.utils = Utils(**self.utils_confs)
self.method_decorators = {'get': kwargs.get('get_decorators', [])}

# HTTP Request method decorators
if 'get_decorators' in kwargs and isinstance(kwargs['get_decorators'], (tuple, list, set)):
self.method_decorators = {'get': list(kwargs['get_decorators'])}

@classproperty
def parse_pk_uuid(cls):
# pylint: disable=no-self-argument
def parse_pk_uuid(cls): # pylint: disable=no-self-argument
return cls._parse_pk_uuid

def _load_and_verify(self, node_id=None):
Expand Down Expand Up @@ -212,17 +215,6 @@ class Node(BaseResource):
_translator_class = NodeTranslator
_parse_pk_uuid = 'uuid' # Parse a uuid pattern in the URL path (not a pk)

def __init__(self, **kwargs):
super().__init__(**kwargs)
from aiida.orm import Node as tNode
self.tclass = tNode

# Configure utils
utils_conf_keys = ('PREFIX', 'PERPAGE_DEFAULT', 'LIMIT_DEFAULT')
self.utils_confs = {k: kwargs[k] for k in utils_conf_keys if k in kwargs}
self.utils = Utils(**self.utils_confs)
self.method_decorators = {'get': kwargs.get('get_decorators', [])}

def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid-name,unused-argument
# pylint: disable=too-many-locals,too-many-statements,too-many-branches,fixme,unused-variable
"""
Expand Down
12 changes: 6 additions & 6 deletions aiida/restapi/run_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs)
port = kwargs.pop('port', CLI_DEFAULTS['PORT'])
debug = kwargs.pop('debug', APP_CONFIG['DEBUG'])

app, api = configure_api(flask_app, flask_api, **kwargs)
api = configure_api(flask_app, flask_api, **kwargs)

if hookup:
# Run app through built-in werkzeug server
Expand All @@ -66,7 +66,7 @@ def run_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs)
else:
# Return the app & api without specifying port/host to be handled by an external server (e.g. apache).
# Some of the user-defined configuration of the app is ineffective (only affects built-in server).
return (app, api)
return api.app, api


def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **kwargs):
Expand All @@ -81,7 +81,8 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k
:param catch_internal_server: If true, catch and print all inter server errors
:param wsgi_profile: use WSGI profiler middleware for finding bottlenecks in web application
:returns: tuple (app, api)
:returns: Flask RESTful API
:rtype: :py:class:`flask_restful.Api`
"""

# Unpack parameters
Expand Down Expand Up @@ -119,6 +120,5 @@ def configure_api(flask_app=api_classes.App, flask_api=api_classes.AiidaApi, **k
app.config['PROFILE'] = True
app.wsgi_app = ProfilerMiddleware(app.wsgi_app, restrictions=[30])

# Instantiate an Api by associating its app
api = flask_api(app, **API_CONFIG)
return (app, api)
# Instantiate and return a Flask RESTful API by associating its app
return flask_api(app, **API_CONFIG)
5 changes: 3 additions & 2 deletions docs/source/developer_guide/core/extend_restapi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,14 @@ as confirmed by the response to the GET request.

As a final remark, there might be circumstances in which you do not want to use the internal werkzeug-based server.
For example, you might want to run the app through Apache using a wsgi script.
In this case, simply use ``configure_api`` to return two custom objects ``app`` and ``api``:
In this case, simply use ``configure_api`` to return a custom object ``api``:

.. code-block:: python
(app, api) = configure_api(App, MycloudApi, **kwargs)
api = configure_api(App, MycloudApi, **kwargs)
The ``app`` can be retrieved by ``api.app``.
This snippet of code becomes the fundamental block of a *wsgi* file used by Apache as documented in :ref:`restapi_apache`.
Moreover, we recommend to consult the documentation of `mod_wsgi <https://modwsgi.readthedocs.io/>`_.

Expand Down
50 changes: 50 additions & 0 deletions tests/restapi/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""pytest fixtures for use with the aiida.restapi tests"""
import pytest


@pytest.fixture(scope='function')
def restapi_server():
"""Make REST API server"""
from werkzeug.serving import make_server

from aiida.restapi.common.config import CLI_DEFAULTS
from aiida.restapi.run_api import configure_api

def _restapi_server(restapi=None):
if restapi is None:
flask_restapi = configure_api()
else:
flask_restapi = configure_api(flask_api=restapi)

return make_server(
host=CLI_DEFAULTS['HOST_NAME'],
port=int(CLI_DEFAULTS['PORT']),
app=flask_restapi.app,
threaded=True,
processes=1,
request_handler=None,
passthrough_errors=True,
ssl_context=None,
fd=None
)

return _restapi_server


@pytest.fixture
def server_url():
from aiida.restapi.common.config import CLI_DEFAULTS, API_CONFIG

return 'http://{hostname}:{port}{api}'.format(
hostname=CLI_DEFAULTS['HOST_NAME'], port=CLI_DEFAULTS['PORT'], api=API_CONFIG['PREFIX']
)


@pytest.fixture
def restrict_sqlalchemy_queuepool(aiida_profile):
"""Create special SQLAlchemy engine for use with QueryBuilder - backend-agnostic"""
from aiida.manage.manager import get_manager

backend_manager = get_manager().get_backend_manager()
backend_manager.reset_backend_environment()
backend_manager.load_backend_environment(aiida_profile, pool_timeout=1, max_overflow=0)
6 changes: 3 additions & 3 deletions tests/restapi/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def setUpClass(cls, *args, **kwargs): # pylint: disable=too-many-locals, too-ma
# order, api.__init__)
kwargs = dict(PREFIX=cls._url_prefix, PERPAGE_DEFAULT=cls._PERPAGE_DEFAULT, LIMIT_DEFAULT=cls._LIMIT_DEFAULT)

app, _api = configure_api(catch_internal_server=True)
api = configure_api(catch_internal_server=True)

cls.app = app
cls.app = api.app
cls.app.config['TESTING'] = True

# create test inputs
Expand Down Expand Up @@ -286,7 +286,7 @@ def process_test(
"""
Check whether response matches expected values.
:param entity_type: url requested fot the type of the node
:param entity_type: url requested for the type of the node
:param url: web url
:param full_list: if url is requested to get full list
:param empty_list: if the response list is empty
Expand Down
121 changes: 121 additions & 0 deletions tests/restapi/test_threaded_restapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Tests for the `aiida.restapi` module, using it in threaded mode.
Threaded mode is the default (and only) way to run the AiiDA REST API (see `aiida.restapi.run_api:run_api()`).
This test file's layout is inspired by https://gist.github.com/prschmid/4643738
"""
import time
from threading import Thread

import requests
import pytest

NO_OF_REQUESTS = 100


@pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool')
def test_run_threaded_server(restapi_server, server_url, aiida_localhost):
"""Run AiiDA REST API threaded in a separate thread and perform many sequential requests"""

server = restapi_server()
computer_id = aiida_localhost.uuid

# Create a thread that will contain the running server,
# since we do not wish to block the main thread
server_thread = Thread(target=server.serve_forever)

try:
server_thread.start()

for _ in range(NO_OF_REQUESTS):
response = requests.get(server_url + '/computers/{}'.format(computer_id), timeout=10)

assert response.status_code == 200

try:
response_json = response.json()
except ValueError:
pytest.fail('Could not turn response into JSON. Response: {}'.format(response.raw))
else:
assert 'data' in response_json

except Exception as exc: # pylint: disable=broad-except
pytest.fail('Something went terribly wrong! Exception: {}'.format(repr(exc)))
finally:
server.shutdown()

# Wait a total of 1 min (100 x 0.6 s) for the Thread to close/join, else fail
for _ in range(100):
if server_thread.is_alive():
time.sleep(0.6)
else:
break
else:
pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown')


@pytest.mark.usefixtures('clear_database_before_test', 'restrict_sqlalchemy_queuepool')
def test_run_without_close_session(restapi_server, server_url, aiida_localhost, capfd):
"""Run AiiDA REST API threaded in a separate thread and perform many sequential requests"""
from aiida.restapi.api import AiidaApi
from aiida.restapi.resources import Computer

class NoCloseSessionApi(AiidaApi):
"""Add Computer to this API (again) with a new endpoint, but pass an empty list for `get_decorators`"""

def __init__(self, app=None, **kwargs):
super().__init__(app=app, **kwargs)

# This is a copy of adding the `Computer` resource,
# but only a few URLs are added, and `get_decorators` is passed with an empty list.
extra_kwargs = kwargs.copy()
extra_kwargs.update({'get_decorators': []})
self.add_resource(
Computer,
'/computers_no_close_session/',
'/computers_no_close_session/<id>/',
endpoint='computers_no_close_session',
strict_slashes=False,
resource_class_kwargs=extra_kwargs,
)

server = restapi_server(NoCloseSessionApi)
computer_id = aiida_localhost.uuid

# Create a thread that will contain the running server,
# since we do not wish to block the main thread
server_thread = Thread(target=server.serve_forever)

try:
server_thread.start()

for _ in range(NO_OF_REQUESTS):
requests.get(server_url + '/computers_no_close_session/{}'.format(computer_id), timeout=10)
pytest.fail('{} requests were not enough to raise a SQLAlchemy TimeoutError!'.format(NO_OF_REQUESTS))

except (requests.exceptions.ConnectionError, OSError):
pass
except Exception as exc: # pylint: disable=broad-except
pytest.fail('Something went terribly wrong! Exception: {}'.format(repr(exc)))
finally:
server.shutdown()

# Wait a total of 1 min (100 x 0.6 s) for the Thread to close/join, else fail
for _ in range(100):
if server_thread.is_alive():
time.sleep(0.6)
else:
break
else:
pytest.fail('Thread did not close/join within 1 min after REST API server was called to shutdown')

captured = capfd.readouterr()
assert 'sqlalchemy.exc.TimeoutError: QueuePool limit of size ' in captured.err

0 comments on commit b9d4bbe

Please sign in to comment.