diff --git a/Makefile b/Makefile index 480311810..1e7216e6a 100644 --- a/Makefile +++ b/Makefile @@ -24,6 +24,9 @@ config.yml: install: .venv/bin/python .venv/bin/pip-licenses .venv/bin/elastic-ingest .venv/bin/pip-licenses --format=plain-vertical --with-license-file --no-license-path > NOTICE.txt +install-agent: install + .venv/bin/pip install -r requirements/agent.txt + .venv/bin/elastic-ingest: .venv/bin/python .venv/bin/pip install -r requirements/$(ARCH).txt .venv/bin/python setup.py develop @@ -64,7 +67,10 @@ autoformat: .venv/bin/python .venv/bin/ruff .venv/bin/elastic-ingest .venv/bin/ruff format setup.py test: .venv/bin/pytest .venv/bin/elastic-ingest - .venv/bin/pytest --cov-report term-missing --cov-fail-under 92 --cov-report html --cov=connectors --fail-slow=$(SLOW_TEST_THRESHOLD) -sv tests + .venv/bin/pytest --cov-report term-missing --cov-fail-under 92 --cov-report html --cov=connectors --fail-slow=$(SLOW_TEST_THRESHOLD) -sv tests --ignore tests/agent + +test-agent: .venv/bin/pytest .venv/bin/elastic-ingest install-agent + .venv/bin/pytest --cov-report term-missing --cov-fail-under 92 --cov-report html --cov=connectors/agent --fail-slow=$(SLOW_TEST_THRESHOLD) -sv tests/agent release: install .venv/bin/python setup.py sdist diff --git a/connectors/agent/README.md b/connectors/agent/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/connectors/agent/cli.py b/connectors/agent/cli.py new file mode 100644 index 000000000..d3e10664d --- /dev/null +++ b/connectors/agent/cli.py @@ -0,0 +1,42 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +import asyncio +import functools +import signal + +from elastic_agent_client.util.async_tools import ( + sleeps_for_retryable, +) +from elastic_agent_client.util.logger import logger + +from connectors.agent.component import ConnectorsAgentComponent + + +def main(args=None): + """Script entry point into running Connectors Service on Agent. + + It initialises an event loop, creates a component and runs the component. + Additionally, signals are handled for graceful termination of the component. + """ + loop = asyncio.get_event_loop() + logger.info("Running agent") + component = ConnectorsAgentComponent() + + def _shutdown(signal_name): + sleeps_for_retryable.cancel(signal_name) + component.stop(signal_name) + + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, functools.partial(_shutdown, sig.name)) + + return loop.run_until_complete(component.run()) + + +if __name__ == "__main__": + try: + main() + finally: + logger.info("Bye") diff --git a/connectors/agent/component.py b/connectors/agent/component.py new file mode 100644 index 000000000..4d5d7e0e7 --- /dev/null +++ b/connectors/agent/component.py @@ -0,0 +1,74 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +import sys + +from elastic_agent_client.client import V2Options, VersionInfo +from elastic_agent_client.reader import new_v2_from_reader +from elastic_agent_client.service.actions import ActionsService +from elastic_agent_client.service.checkin import CheckinV2Service + +from connectors.agent.config import ConnectorsAgentConfigurationWrapper +from connectors.agent.protocol import ConnectorActionHandler, ConnectorCheckinHandler +from connectors.agent.service_manager import ConnectorServiceManager +from connectors.services.base import MultiService + +CONNECTOR_SERVICE = "connector-service" + + +class ConnectorsAgentComponent: + """Entry point into running connectors service in Agent. + + This class provides a simple abstraction over Agent components and Connectors Service manager. + + It instantiates everything needed to read from Agent protocol, creates a wrapper around Connectors Service + and provides applied interface to be able to run it in 2 simple methods: run and stop. + """ + + def __init__(self): + """Inits the class. + + Init should be safe to call without expectations of side effects (connections to Agent, blocking or anything). + """ + self.ver = VersionInfo( + name=CONNECTOR_SERVICE, meta={"input": CONNECTOR_SERVICE} + ) + self.opts = V2Options() + self.buffer = sys.stdin.buffer + self.config_wrapper = ConnectorsAgentConfigurationWrapper() + + async def run(self): + """Start reading from Agent protocol and run Connectors Service with settings reported by agent. + + This method can block if it's not running from Agent - it expects the client to be able to read messages + on initialisation. These messages are a handshake and it's a sync handshake. + If no messages are sent, this method can and will hang. + + However, if ran under agent, this method will read configuration from Agent and attempt to start an + instance of Connectors Service with this configuration. + + Additionally services for handling Check-in and Actions will be started to implement the protocol correctly. + """ + client = new_v2_from_reader(self.buffer, self.ver, self.opts) + action_handler = ConnectorActionHandler() + self.connector_service_manager = ConnectorServiceManager(self.config_wrapper) + checkin_handler = ConnectorCheckinHandler( + client, self.config_wrapper, self.connector_service_manager + ) + + self.multi_service = MultiService( + CheckinV2Service(client, checkin_handler), + ActionsService(client, action_handler), + self.connector_service_manager, + ) + + await self.multi_service.run() + + def stop(self, sig): + """Shutdown everything running in the component. + + Attempts to gracefully shutdown the services that are running under the component. + """ + self.multi_service.shutdown(sig) diff --git a/connectors/agent/config.py b/connectors/agent/config.py new file mode 100644 index 000000000..da7698e36 --- /dev/null +++ b/connectors/agent/config.py @@ -0,0 +1,110 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +from connectors.config import add_defaults + + +class ConnectorsAgentConfigurationWrapper: + """A wrapper that facilitates passing configuration from Agent to Connectors Service. + + This class is responsible for: + - Storing in-memory configuration of Connectors Service running on Agent + - Transforming configuration reported by Agent to valid Connectors Service configuration + - Indicating that configuration has changed so that the user of the class can trigger the restart + """ + + def __init__(self): + """Inits the class. + + There's default config that allows us to run connectors natively (see _force_allow_native flag), + when final configuration is reported these defaults will be merged with defaults from Connectors + Service config and specific config coming from Agent. + """ + self._default_config = { + "_force_allow_native": True, + "native_service_types": [ + "azure_blob_storage", + "box", + "confluence", + "dropbox", + "github", + "gmail", + "google_cloud_storage", + "google_drive", + "jira", + "mongodb", + "mssql", + "mysql", + "notion", + "onedrive", + "oracle", + "outlook", + "network_drive", + "postgresql", + "s3", + "salesforce", + "servicenow", + "sharepoint_online", + "slack", + "microsoft_teams", + "zoom", + ], + } + + self.specific_config = {} + + def try_update(self, source): + """Try update the configuration and see if it changed. + + This method takes the check-in event coming from Agent and checks if config needs an update. + + If update is needed, configuration is updated and method returns True. If no update is needed + the method returns False. + """ + + # TODO: find a good link to what this object is. + has_hosts = source.fields.get("hosts") + has_api_key = source.fields.get("api_key") + has_basic_auth = source.fields.get("username") and source.fields.get("password") + if has_hosts and (has_api_key or has_basic_auth): + es_creds = { + "host": source["hosts"][0], + } + + if source.fields.get("api_key"): + es_creds["api_key"] = source["api_key"] + elif source.fields.get("username") and source.fields.get("password"): + es_creds["username"] = source["username"] + es_creds["password"] = source["password"] + else: + msg = "Invalid Elasticsearch credentials" + raise ValueError(msg) + + new_config = { + "elasticsearch": es_creds, + } + self.specific_config = new_config + return True + + return False + + def get(self): + """Get current Connectors Service configuration. + + This method combines three configs with higher ones taking precedence: + - Config reported from Agent + - Default config stored in this class + - Default config of Connectors Service + + Resulting config should be sufficient to run Connectors Service with. + """ + # First take "default config" + config = self._default_config.copy() + # Then override with what we get from Agent + config.update(self.specific_config) + # Then merge with default connectors config + configuration = dict(add_defaults(config)) + + return configuration diff --git a/connectors/agent/protocol.py b/connectors/agent/protocol.py new file mode 100644 index 000000000..b45a4dc1f --- /dev/null +++ b/connectors/agent/protocol.py @@ -0,0 +1,77 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +from elastic_agent_client.generated import elastic_agent_client_pb2 as proto +from elastic_agent_client.handler.action import BaseActionHandler +from elastic_agent_client.handler.checkin import BaseCheckinHandler +from elastic_agent_client.util.logger import logger + + +class ConnectorActionHandler(BaseActionHandler): + """Class handling Agent actions. + + As there are no actions that we can respond to, we don't actually do anything here. + """ + + async def handle_action(self, action: proto.ActionRequest): + """Implementation of BaseActionHandler.handle_action + + Right now does nothing as Connectors Service has no actions to respond to. + """ + msg = ( + f"This connector component can't handle action requests. Received: {action}" + ) + raise NotImplementedError(msg) + + +class ConnectorCheckinHandler(BaseCheckinHandler): + """Class handling to Agent check-in events. + + Agent sends check-in events from time to time that might contain + information that's needed to run Connectors Service. + + This class reads the events, sees if there's a reported change to connector-specific settings, + tries to update the configuration and, if the configuration is updated, restarts the Connectors Service. + """ + + def __init__(self, client, agent_connectors_config_wrapper, service_manager): + """Inits the class. + + Initing this class should not produce side-effects. + """ + super().__init__(client) + self.agent_connectors_config_wrapper = agent_connectors_config_wrapper + self.service_manager = service_manager + + async def apply_from_client(self): + """Implementation of BaseCheckinHandler.apply_from_client + + This method is called by the Agent Protocol handlers when there's a check-in event + coming from Agent. This class reads the event and runs business logic based on the + content of the event. + + If this class blocks for too long, the component will mark the agent as failed: + agent expects the components to respond within 30 seconds. + See comment in https://github.com/elastic/elastic-agent-client/blob/main/elastic-agent-client.proto#L29 + """ + logger.info("There's new information for the components/units!") + if self.client.units: + logger.debug("Client reported units") + outputs = [ + unit + for unit in self.client.units + if unit.unit_type == proto.UnitType.OUTPUT + ] + + if len(outputs) > 0 and outputs[0].config: + logger.debug("Outputs were found") + source = outputs[0].config.source + + changed = self.agent_connectors_config_wrapper.try_update(source) + if changed: + logger.info("Updating connector service manager config") + self.service_manager.restart() + else: + logger.debug("No changes to connectors config") diff --git a/connectors/agent/service_manager.py b/connectors/agent/service_manager.py new file mode 100644 index 000000000..a9e59e11d --- /dev/null +++ b/connectors/agent/service_manager.py @@ -0,0 +1,96 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +from connectors.logger import DocumentLogger +from connectors.services.base import ( + ServiceAlreadyRunningError, + get_services, +) +from connectors.utils import CancellableSleeps + + +class ConnectorServiceManager: + """Class responsible for properly configuring and running Connectors Service in Elastic Agent + + ConnectorServiceManager is a middle man between Elastic Agent and ConnectorsService. + + This class is taking care of starting Connectors Service subservices with correct configuration. + If configuration changes, as reported by Agent, then this Manager class gracefully shuts down the + subservices and starts them again with new configuration. + + """ + + name = "connector-service-manager" + + def __init__(self, configuration): + """Inits ConnectorServiceManager with shared ConnectorsAgentConfigurationWrapper. + + This service is supposed to be ran once, and after it's stopped or finished running it's not + supposed to be started again. + + There is nothing enforcing it, but expect problems if that happens. + """ + service_name = self.__class__.name + self._logger = DocumentLogger( + f"[{service_name}]", {"service_name": service_name} + ) + self._agent_config = configuration + self._multi_service = None + self._running = False + self._sleeps = CancellableSleeps() + + async def run(self): + """Starts the running loop of the service. + + Once started, the service attempts to run all needed connector subservices + in parallel via MultiService. + + Service can be restarted - it will keep running but with refreshed config, + or it can be stopped - it will just gracefully shut down. + """ + if self._running: + msg = f"{self.__class__.__name__} is already running." + raise ServiceAlreadyRunningError(msg) + + self._running = True + + try: + while self._running: + try: + self._logger.info("Starting connector services") + self._multi_service = get_services( + ["schedule", "sync_content", "sync_access_control", "cleanup"], + self._agent_config.get(), + ) + await self._multi_service.run() + except Exception as e: + self._logger.exception( + f"Error while running services in ConnectorServiceManager: {e}" + ) + raise + finally: + self._logger.info("Finished running, exiting") + + def stop(self): + """Stop the service manager and all running subservices. + + Running stop attempts to gracefully shutdown all subservices currently running. + """ + self._logger.info("Stopping connector services.") + self._running = False + self._done = True + if self._multi_service: + self._multi_service.shutdown(None) + + def restart(self): + """Restart the service manager and all running subservices. + + Running restart attempts to gracefully shutdown all subservices currently running. + After services are gracefully stopped, they will be started again with fresh configuration + that comes from ConnectorsAgentConfigurationWrapper. + """ + self._logger.info("Restarting connector services") + if self._multi_service: + self._multi_service.shutdown(None) diff --git a/connectors/config.py b/connectors/config.py index d74a01d50..d635c912d 100644 --- a/connectors/config.py +++ b/connectors/config.py @@ -28,6 +28,11 @@ def load_config(config_file): return configuration +def add_defaults(config): + configuration = dict(_merge_dicts(_default_config(), config)) + return configuration + + # Left - in Enterprise Search; Right - in Connectors config_mappings = { "elasticsearch.host": "elasticsearch.host", diff --git a/requirements/agent.txt b/requirements/agent.txt new file mode 100644 index 000000000..5f757e987 --- /dev/null +++ b/requirements/agent.txt @@ -0,0 +1 @@ +elastic-agent-client@git+https://github.com/elastic/python-elastic-agent-client@main diff --git a/requirements/framework.txt b/requirements/framework.txt index 504cba9f5..5bb91d8df 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -15,7 +15,7 @@ tzcron==1.0.0 pytz==2019.3 python-dateutil==2.8.2 aiogoogle==5.3.0 -uvloop==0.18.0; sys_platform != 'win32' +uvloop==0.20.0; sys_platform != 'win32' fastjsonschema==2.16.2 base64io==1.0.3 azure-storage-blob==12.19.1 diff --git a/requirements/ftest.txt b/requirements/ftest.txt index 7bd3a98ca..7350ee663 100644 --- a/requirements/ftest.txt +++ b/requirements/ftest.txt @@ -1,7 +1,7 @@ flask==2.3.2 flask_limiter==3.3.1 faker==18.11.2 -mysql-connector-python==8.1.0 +mysql-connector-python==9.0.0 google-auth==2.22.0 google-cloud-storage==2.10.0 asyncpg==0.29.0 diff --git a/setup.py b/setup.py index 9802e672b..e70c03d79 100644 --- a/setup.py +++ b/setup.py @@ -103,5 +103,6 @@ def read_reqs(req_file): fake-kibana = connectors.kibana:main connectors = connectors.connectors_cli:main test-connectors = scripts.testing.cli:main + elastic-agent-connectors = connectors.agent.cli:main """, ) diff --git a/tests/agent/test_agent_config.py b/tests/agent/test_agent_config.py new file mode 100644 index 000000000..577832734 --- /dev/null +++ b/tests/agent/test_agent_config.py @@ -0,0 +1,55 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +from unittest.mock import MagicMock + +import pytest + +from connectors.agent.config import ConnectorsAgentConfigurationWrapper + + +@pytest.mark.asyncio +async def test_try_update_without_auth_data(): + config_wrapper = ConnectorsAgentConfigurationWrapper() + source_mock = MagicMock() + source_mock.fields = {} + + assert config_wrapper.try_update(source_mock) is False + + +@pytest.mark.asyncio +async def test_try_update_with_api_key_auth_data(): + hosts = ["https://localhost:9200"] + api_key = "lemme_in" + + config_wrapper = ConnectorsAgentConfigurationWrapper() + source_mock = MagicMock() + fields_container = {"hosts": hosts, "api_key": api_key} + + source_mock.fields = fields_container + source_mock.__getitem__.side_effect = fields_container.__getitem__ + + assert config_wrapper.try_update(source_mock) is True + assert config_wrapper.get()["elasticsearch"]["host"] == hosts[0] + assert config_wrapper.get()["elasticsearch"]["api_key"] == api_key + + +@pytest.mark.asyncio +async def test_try_update_with_basic_auth_auth_data(): + hosts = ["https://localhost:9200"] + username = "elastic" + password = "hold the door" + + config_wrapper = ConnectorsAgentConfigurationWrapper() + source_mock = MagicMock() + fields_container = {"hosts": hosts, "username": username, "password": password} + + source_mock.fields = fields_container + source_mock.__getitem__.side_effect = fields_container.__getitem__ + + assert config_wrapper.try_update(source_mock) is True + assert config_wrapper.get()["elasticsearch"]["host"] == hosts[0] + assert config_wrapper.get()["elasticsearch"]["username"] == username + assert config_wrapper.get()["elasticsearch"]["password"] == password diff --git a/tests/agent/test_cli.py b/tests/agent/test_cli.py new file mode 100644 index 000000000..ba8b1c6a4 --- /dev/null +++ b/tests/agent/test_cli.py @@ -0,0 +1,25 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +import asyncio +import os +import signal +from unittest.mock import AsyncMock, patch + +from connectors.agent.cli import main + + +@patch("connectors.agent.cli.ConnectorsAgentComponent", return_value=AsyncMock()) +def test_main(patch_component): + async def kill(): + await asyncio.sleep(0.2) + os.kill(os.getpid(), signal.SIGTERM) + + loop = asyncio.new_event_loop() + loop.create_task(kill()) + + main() + + loop.close() diff --git a/tests/agent/test_component.py b/tests/agent/test_component.py new file mode 100644 index 000000000..9bdfe9e17 --- /dev/null +++ b/tests/agent/test_component.py @@ -0,0 +1,45 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from connectors.agent.component import ConnectorsAgentComponent + + +class StubMultiService: + def __init__(self): + self.running_stop = asyncio.Event() + self.has_ran = False + self.has_shutdown = False + + async def run(self): + self.has_ran = True + self.running_stop.clear() + await self.running_stop.wait() + + def shutdown(self, sig): + self.has_shutdown = True + self.running_stop.set() + + +@pytest.mark.asyncio +@patch("connectors.agent.component.MultiService", return_value=StubMultiService()) +@patch("connectors.agent.component.new_v2_from_reader", return_value=MagicMock()) +async def test_try_update_without_auth_data( + stub_multi_service, patch_new_v2_from_reader +): + component = ConnectorsAgentComponent() + + async def stop_after_timeout(): + await asyncio.sleep(0.1) + component.stop("SIGINT") + + await asyncio.gather(component.run(), stop_after_timeout()) + + assert stub_multi_service.has_ran + assert stub_multi_service.has_shutdown diff --git a/tests/agent/test_protocol.py b/tests/agent/test_protocol.py new file mode 100644 index 000000000..f01c075e4 --- /dev/null +++ b/tests/agent/test_protocol.py @@ -0,0 +1,106 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +from unittest.mock import Mock + +import pytest +from elastic_agent_client.generated import elastic_agent_client_pb2 as proto + +from connectors.agent.protocol import ConnectorActionHandler, ConnectorCheckinHandler + + +class TestConnectorActionHandler: + @pytest.mark.asyncio + async def test_handle_action(self): + action_handler = ConnectorActionHandler() + + with pytest.raises(NotImplementedError): + await action_handler.handle_action(None) + + +class TestConnectorCheckingHandler: + @pytest.mark.asyncio + async def test_apply_from_client_when_no_units_received(self): + client_mock = Mock() + config_wrapper_mock = Mock() + service_manager_mock = Mock() + + client_mock.units = [] + + checkin_handler = ConnectorCheckinHandler( + client_mock, config_wrapper_mock, service_manager_mock + ) + + await checkin_handler.apply_from_client() + + assert not config_wrapper_mock.try_update.called + assert not service_manager_mock.restart.called + + @pytest.mark.asyncio + async def test_apply_from_client_when_units_with_no_output(self): + client_mock = Mock() + config_wrapper_mock = Mock() + service_manager_mock = Mock() + unit_mock = Mock() + unit_mock.unit_type = "Something else" + + client_mock.units = [unit_mock] + + checkin_handler = ConnectorCheckinHandler( + client_mock, config_wrapper_mock, service_manager_mock + ) + + await checkin_handler.apply_from_client() + + assert not config_wrapper_mock.try_update.called + assert not service_manager_mock.restart.called + + @pytest.mark.asyncio + async def test_apply_from_client_when_units_with_output_and_non_updating_config( + self, + ): + client_mock = Mock() + config_wrapper_mock = Mock() + + config_wrapper_mock.try_update.return_value = False + + service_manager_mock = Mock() + unit_mock = Mock() + unit_mock.unit_type = proto.UnitType.OUTPUT + unit_mock.config.source = {"elasticsearch": {"api_key": 123}} + + client_mock.units = [unit_mock] + + checkin_handler = ConnectorCheckinHandler( + client_mock, config_wrapper_mock, service_manager_mock + ) + + await checkin_handler.apply_from_client() + + assert config_wrapper_mock.try_update.called_once_with(unit_mock.config.source) + assert not service_manager_mock.restart.called + + @pytest.mark.asyncio + async def test_apply_from_client_when_units_with_output_and_updating_config(self): + client_mock = Mock() + config_wrapper_mock = Mock() + + config_wrapper_mock.try_update.return_value = True + + service_manager_mock = Mock() + unit_mock = Mock() + unit_mock.unit_type = proto.UnitType.OUTPUT + unit_mock.config.source = {"elasticsearch": {"api_key": 123}} + + client_mock.units = [unit_mock] + + checkin_handler = ConnectorCheckinHandler( + client_mock, config_wrapper_mock, service_manager_mock + ) + + await checkin_handler.apply_from_client() + + assert config_wrapper_mock.try_update.called_once_with(unit_mock.config.source) + assert service_manager_mock.restart.called diff --git a/tests/agent/test_service_manager.py b/tests/agent/test_service_manager.py new file mode 100644 index 000000000..6e4a528ba --- /dev/null +++ b/tests/agent/test_service_manager.py @@ -0,0 +1,91 @@ +# +# Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +# or more contributor license agreements. Licensed under the Elastic License 2.0; +# you may not use this file except in compliance with the Elastic License 2.0. +# +import asyncio +from unittest.mock import Mock, patch + +import pytest + +from connectors.agent.service_manager import ConnectorServiceManager +from connectors.services.base import ServiceAlreadyRunningError + + +@pytest.fixture(autouse=True) +def config_mock(): + config = Mock() + + config.get.return_value = { + "service": {"idling": 123, "heartbeat": 5}, + "elasticsearch": {}, + "sources": [], + } + + return config + + +class StubMultiService: + def __init__(self): + self.running_stop = asyncio.Event() + self.has_ran = False + self.has_shutdown = False + + async def run(self): + self.has_ran = True + self.running_stop.clear() + await self.running_stop.wait() + + def shutdown(self, sig): + self.has_shutdown = True + self.running_stop.set() + + +@pytest.mark.asyncio +@patch("connectors.agent.service_manager.get_services", return_value=StubMultiService()) +async def test_run_and_stop_work_as_intended(patch_get_services, config_mock): + service_manager = ConnectorServiceManager(config_mock) + + async def stop_service_after_timeout(): + await asyncio.sleep(0.1) + service_manager.stop() + + await asyncio.gather(service_manager.run(), stop_service_after_timeout()) + + assert patch_get_services.return_value.has_ran + assert patch_get_services.return_value.has_shutdown + + +@pytest.mark.asyncio +@patch("connectors.agent.service_manager.get_services", return_value=StubMultiService()) +async def test_restart_starts_another_multiservice(patch_get_services, config_mock): + service_manager = ConnectorServiceManager(config_mock) + + async def stop_service_after_timeout(): + await asyncio.sleep(0.1) + service_manager.restart() + await asyncio.sleep(0.1) + service_manager.stop() + + await asyncio.gather(service_manager.run(), stop_service_after_timeout()) + + assert patch_get_services.called + assert patch_get_services.call_count == 2 + + +@pytest.mark.asyncio +@patch("connectors.agent.service_manager.get_services", return_value=StubMultiService()) +async def test_cannot_run_same_service_manager_twice(patch_get_services, config_mock): + service_manager = ConnectorServiceManager(config_mock) + + with pytest.raises(ServiceAlreadyRunningError): + tasks = [asyncio.create_task(service_manager.run()) for _ in range(2)] + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) + + # Cancel pending to clean up the tasks + for task in pending: + task.cancel() + + # Execute task results to cause an exception to be raised if any + for task in done: + task.result()