diff --git a/docs/asyncio.rst b/docs/asyncio.rst
new file mode 100644
index 00000000..57c1a938
--- /dev/null
+++ b/docs/asyncio.rst
@@ -0,0 +1,132 @@
+Support for asyncio (experimental)
+==================================
+
+.. versionadded:: 5.8.0
+
+Karton v5.8.0 implements experimental support for asyncio. The intended use-case is to support:
+
+- "auto-scalable" Consumers that are waiting for an external job to be done for most of the time (e.g. sandbox executors)
+- Producers in asyncio-based projects
+
+.. warning::
+
+    ``karton.core.asyncio`` requires at least Python 3.11
+
+How to use it?
+--------------
+
+The basic usage is almost the same as in the sync version. If you want to write a consumer, just import the needed things
+from the ``karton.core.asyncio`` package and use the ``async def`` keyword in the ``process(...)`` method.
+
+.. code-block:: python
+
+    import asyncio
+    from karton.core.asyncio import Consumer, Task
+
+    class FooBarConsumer(Consumer):
+        identity = "foobar-consumer"
+        filters = [
+            {
+                "type": "foobar"
+            }
+        ]
+
+        async def process(self, task: Task) -> None:
+            num = task.get_payload("data")
+            self.log.info("Got number %d", num)
+            await asyncio.sleep(5)
+            if num % 3 == 0:
+                self.log.info("Foo")
+            if num % 5 == 0:
+                self.log.info("Bar")
+
+    if __name__ == "__main__":
+        # calls asyncio.run(FooBarConsumer().loop())
+        FooBarConsumer.main()
+
+
+Using a Producer is similar, but you need to remember to call ``await connect()`` in the initialization code before sending the first task.
+The synchronous version of KartonBackend connects to Redis/S3 in the Producer constructor, but in asyncio, the connection must be made explicitly.
+
+.. code-block:: python
+
+    import asyncio
+    from karton.core.asyncio import Producer, Task
+
+    foo_producer = Producer(identity="foobar-producer")
+
+    async def main():
+        await foo_producer.connect()
+
+        for i in range(5):
+            task = Task(headers={"type": "foobar"}, payload={"data": i})
+            await foo_producer.send_task(task)
+
+    if __name__ == "__main__":
+        asyncio.run(main())
+
+Limiting the Consumer concurrency
+---------------------------------
+
+asyncio Consumers are very greedy when it comes to consuming tasks. Each task is started as soon as possible and
+the proper ``process()`` coroutine is scheduled in the event loop. It's recommended to set a limit on concurrently running
+tasks via the ``concurrency_limit`` configuration argument.
+
+.. code-block:: python
+
+    import asyncio
+    from karton.core.asyncio import Consumer, Task
+
+    class FooBarConsumer(Consumer):
+        identity = "foobar-consumer"
+        filters = [
+            {
+                "type": "foobar"
+            }
+        ]
+
+        concurrency_limit = 16
+
+Choosing the appropriate limit depends on how many of the parallel connections/jobs can be handled by the service
+that is used by the Consumer.
+
+Asynchronous resources
+----------------------
+
+Resources provided in Tasks are deserialized into ``karton.core.asyncio.RemoteResource`` objects.
+
+There are a few differences in their API compared to the synchronous version:
+
+- all downloading methods need to be called with the ``await`` keyword (they're coroutines).
+- ``RemoteResource.content`` raises ``RuntimeError`` when the resource wasn't explicitly downloaded before.
+  You need to call ``await resource.download()`` first.
+
+It's also required to use ``karton.core.asyncio.LocalResource`` while creating a new task.
+
+Termination handling
+--------------------
+
+Asynchronous consumers must be aware of `task cancellation <https://docs.python.org/3/library/asyncio-task.html#task-cancellation>`_
+and handle the `asyncio.CancelledError <https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError>`_
+if they want to gracefully terminate their operations in case of ``SIGINT``/``SIGTERM`` or exceeded ``task_timeout``.
+ +Asynchronous Karton can't interrupt blocking/hanged operations. + +Known issues: reported number of replicas +----------------------------------------- + +When using asyncio-based Karton consumers, be aware that the reported number of replicas may not accurately reflect +the actual number of running consumer instances. + +This is due to how the Karton framework determines the replica count — it relies on counting active Redis connections. + +Missing features +---------------- + +``karton.core.asyncio`` implements only a subset of Karton API, required to run most common producers/consumers. + +Right now we don't support: + +- test suite (``karton.core.test``) +- Karton state inspection (``karton.core.inspect``) +- pre/post/signalling hooks diff --git a/docs/index.rst b/docs/index.rst index 3c6bfe28..a82fe220 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -70,6 +70,7 @@ Task routing and data exchange is achieved with the help of **Karton-System** - service_configuration advanced_concepts unit_tests + asyncio karton_api Indices and tables diff --git a/karton/core/asyncio/__init__.py b/karton/core/asyncio/__init__.py new file mode 100644 index 00000000..197a6588 --- /dev/null +++ b/karton/core/asyncio/__init__.py @@ -0,0 +1,21 @@ +import sys + +if sys.version_info < (3, 11, 0): + raise ImportError("karton.core.asyncio is only compatible with Python 3.11+") + +from karton.core.config import Config +from karton.core.task import Task + +from .karton import Consumer, Karton, Producer +from .resource import LocalResource, RemoteResource, Resource + +__all__ = [ + "Karton", + "Producer", + "Consumer", + "Task", + "Config", + "LocalResource", + "Resource", + "RemoteResource", +] diff --git a/karton/core/asyncio/backend.py b/karton/core/asyncio/backend.py new file mode 100644 index 00000000..c822b140 --- /dev/null +++ b/karton/core/asyncio/backend.py @@ -0,0 +1,370 @@ +import json +import logging +import time +from typing import IO, Any, Dict, List, Optional, Tuple, 
Union + +import aioboto3 +from aiobotocore.credentials import ContainerProvider, InstanceMetadataProvider +from aiobotocore.session import ClientCreatorContext, get_session +from aiobotocore.utils import InstanceMetadataFetcher +from redis.asyncio import Redis +from redis.asyncio.client import Pipeline +from redis.exceptions import AuthenticationError + +from karton.core import Config, Task +from karton.core.asyncio.resource import RemoteResource +from karton.core.backend import ( + KARTON_BINDS_HSET, + KARTON_TASK_NAMESPACE, + KARTON_TASKS_QUEUE, + KartonBackendBase, + KartonBind, + KartonMetrics, + KartonServiceInfo, +) +from karton.core.task import TaskState + +logger = logging.getLogger(__name__) + + +class KartonAsyncBackend(KartonBackendBase): + def __init__( + self, + config: Config, + identity: Optional[str] = None, + service_info: Optional[KartonServiceInfo] = None, + ) -> None: + super().__init__(config, identity, service_info) + self._redis: Optional[Redis] = None + self._s3_session: Optional[aioboto3.Session] = None + self._s3_iam_auth = False + + @property + def redis(self) -> Redis: + if not self._redis: + raise RuntimeError("Call connect() first before using KartonAsyncBackend") + return self._redis + + @property + def s3(self) -> ClientCreatorContext: + if not self._s3_session: + raise RuntimeError("Call connect() first before using KartonAsyncBackend") + endpoint = self.config.get("s3", "address") + if self._s3_iam_auth: + return self._s3_session.client( + "s3", + endpoint_url=endpoint, + ) + else: + access_key = self.config.get("s3", "access_key") + secret_key = self.config.get("s3", "secret_key") + return self._s3_session.client( + "s3", + endpoint_url=endpoint, + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + ) + + async def connect(self): + if self._redis is not None or self._s3_session is not None: + # Already connected + return + self._redis = await self.make_redis( + self.config, identity=self.identity, 
service_info=self.service_info + ) + + endpoint = self.config.get("s3", "address") + access_key = self.config.get("s3", "access_key") + secret_key = self.config.get("s3", "secret_key") + iam_auth = self.config.getboolean("s3", "iam_auth") + + if not endpoint: + raise RuntimeError("Attempting to get S3 client without an endpoint set") + + if access_key and secret_key and iam_auth: + logger.warning( + "Warning: iam is turned on and both S3 access key and secret key are" + " provided" + ) + + if iam_auth: + s3_client_creator = await self.iam_auth_s3() + if s3_client_creator: + self._s3_iam_auth = True + self._s3_session = s3_client_creator + return + + if access_key is None or secret_key is None: + raise RuntimeError( + "Attempting to get S3 client without an access_key/secret_key set" + ) + + session = aioboto3.Session() + self._s3_session = session + + async def iam_auth_s3(self): + boto_session = get_session() + iam_providers = [ + ContainerProvider(), + InstanceMetadataProvider( + iam_role_fetcher=InstanceMetadataFetcher(timeout=1000, num_attempts=2) + ), + ] + + for provider in iam_providers: + creds = await provider.load() + if creds: + boto_session._credentials = creds # type: ignore + return aioboto3.Session(botocore_session=boto_session) + + @staticmethod + async def make_redis( + config, + identity: Optional[str] = None, + service_info: Optional[KartonServiceInfo] = None, + ) -> Redis: + """ + Create and test a Redis connection. 
+ + :param config: The karton configuration + :param identity: Karton service identity + :param service_info: Additional service identity metadata + :return: Redis connection + """ + if service_info is not None: + client_name: Optional[str] = service_info.make_client_name() + else: + client_name = identity + + redis_args = { + "host": config["redis"]["host"], + "port": config.getint("redis", "port", 6379), + "db": config.getint("redis", "db", 0), + "username": config.get("redis", "username"), + "password": config.get("redis", "password"), + "client_name": client_name, + # set socket_timeout to None if set to 0 + "socket_timeout": config.getint("redis", "socket_timeout", 30) or None, + "decode_responses": True, + } + try: + rs = Redis(**redis_args) + await rs.ping() + except AuthenticationError: + # Maybe we've sent a wrong password. + # Or maybe the server is not (yet) password protected + # To make smooth transition possible, try to login insecurely + del redis_args["password"] + rs = Redis(**redis_args) + await rs.ping() + return rs + + def unserialize_resource(self, resource_spec: Dict[str, Any]) -> RemoteResource: + """ + Unserializes resource into a RemoteResource object bound with current backend + + :param resource_spec: Resource specification + :return: RemoteResource object + """ + return RemoteResource.from_dict(resource_spec, backend=self) + + async def register_task(self, task: Task, pipe: Optional[Pipeline] = None) -> None: + """ + Register or update task in Redis. 
+ + :param task: Task object + :param pipe: Optional pipeline object if operation is a part of pipeline + """ + rs = pipe or self.redis + await rs.set(f"{KARTON_TASK_NAMESPACE}:{task.uid}", task.serialize()) + + async def set_task_status( + self, task: Task, status: TaskState, pipe: Optional[Pipeline] = None + ) -> None: + """ + Request task status change to be applied by karton-system + + :param task: Task object + :param status: New task status (TaskState) + :param pipe: Optional pipeline object if operation is a part of pipeline + """ + if task.status == status: + return + task.status = status + task.last_update = time.time() + await self.register_task(task, pipe=pipe) + + async def register_bind(self, bind: KartonBind) -> Optional[KartonBind]: + """ + Register bind for Karton service and return the old one + + :param bind: KartonBind object with bind definition + :return: Old KartonBind that was registered under this identity + """ + async with self.redis.pipeline(transaction=True) as pipe: + await pipe.hget(KARTON_BINDS_HSET, bind.identity) + await pipe.hset(KARTON_BINDS_HSET, bind.identity, self.serialize_bind(bind)) + old_serialized_bind, _ = await pipe.execute() + + if old_serialized_bind: + return self.unserialize_bind(bind.identity, old_serialized_bind) + else: + return None + + async def get_bind(self, identity: str) -> KartonBind: + """ + Get bind object for given identity + + :param identity: Karton service identity + :return: KartonBind object + """ + return self.unserialize_bind( + identity, await self.redis.hget(KARTON_BINDS_HSET, identity) + ) + + async def produce_unrouted_task(self, task: Task) -> None: + """ + Add given task to unrouted task (``karton.tasks``) queue + + Task must be registered before with :py:meth:`register_task` + + :param task: Task object + """ + await self.redis.rpush(KARTON_TASKS_QUEUE, task.uid) + + async def consume_queues( + self, queues: Union[str, List[str]], timeout: int = 0 + ) -> Optional[Tuple[str, str]]: + """ + 
Get item from queues (ordered from the most to the least prioritized) + If there are no items, wait until one appear. + + :param queues: Redis queue name or list of names + :param timeout: Waiting for item timeout (default: 0 = wait forever) + :return: Tuple of [queue_name, item] objects or None if timeout has been reached + """ + return await self.redis.blpop(queues, timeout=timeout) + + async def get_task(self, task_uid: str) -> Optional[Task]: + """ + Get task object with given identifier + + :param task_uid: Task identifier + :return: Task object + """ + task_data = await self.redis.get(f"{KARTON_TASK_NAMESPACE}:{task_uid}") + if not task_data: + return None + return Task.unserialize( + task_data, resource_unserializer=self.unserialize_resource + ) + + async def consume_routed_task( + self, identity: str, timeout: int = 5 + ) -> Optional[Task]: + """ + Get routed task for given consumer identity. + + If there are no tasks, blocks until new one appears or timeout is reached. + + :param identity: Karton service identity + :param timeout: Waiting for task timeout (default: 5) + :return: Task object + """ + item = await self.consume_queues( + self.get_queue_names(identity), + timeout=timeout, + ) + if not item: + return None + queue, data = item + return await self.get_task(data) + + async def increment_metrics( + self, metric: KartonMetrics, identity: str, pipe: Optional[Pipeline] = None + ) -> None: + """ + Increments metrics for given operation type and identity + + :param metric: Operation metric type + :param identity: Related Karton service identity + :param pipe: Optional pipeline object if operation is a part of pipeline + """ + rs = pipe or self.redis + await rs.hincrby(metric.value, identity, 1) + + async def upload_object( + self, + bucket: str, + object_uid: str, + content: Union[bytes, IO[bytes]], + ) -> None: + """ + Upload resource object to underlying object storage (S3) + + :param bucket: Bucket name + :param object_uid: Object identifier + :param 
content: Object content as bytes or file-like stream + """ + async with self.s3 as client: + await client.put_object(Bucket=bucket, Key=object_uid, Body=content) + + async def upload_object_from_file( + self, bucket: str, object_uid: str, path: str + ) -> None: + """ + Upload resource object file to underlying object storage + + :param bucket: Bucket name + :param object_uid: Object identifier + :param path: Path to the object content + """ + async with self.s3 as client: + with open(path, "rb") as f: + await client.put_object(Bucket=bucket, Key=object_uid, Body=f) + + async def download_object(self, bucket: str, object_uid: str) -> bytes: + """ + Download resource object from object storage. + + :param bucket: Bucket name + :param object_uid: Object identifier + :return: Content bytes + """ + async with self.s3 as client: + obj = await client.get_object(Bucket=bucket, Key=object_uid) + return await obj["Body"].read() + + async def download_object_to_file( + self, bucket: str, object_uid: str, path: str + ) -> None: + """ + Download resource object from object storage to file + + :param bucket: Bucket name + :param object_uid: Object identifier + :param path: Target file path + """ + async with self.s3 as client: + await client.download_file(Bucket=bucket, Key=object_uid, Filename=path) + + async def produce_log( + self, + log_record: Dict[str, Any], + logger_name: str, + level: str, + ) -> bool: + """ + Push new log record to the logs channel + + :param log_record: Dict with log record + :param logger_name: Logger name + :param level: Log level + :return: True if any active log consumer received log record + """ + return ( + await self.redis.publish( + self._log_channel(logger_name, level), json.dumps(log_record) + ) + > 0 + ) diff --git a/karton/core/asyncio/base.py b/karton/core/asyncio/base.py new file mode 100644 index 00000000..4b43958d --- /dev/null +++ b/karton/core/asyncio/base.py @@ -0,0 +1,129 @@ +import abc +import asyncio +import signal +from asyncio 
import CancelledError +from typing import Optional + +from karton.core import Task +from karton.core.__version__ import __version__ +from karton.core.backend import KartonServiceInfo +from karton.core.base import ConfigMixin, LoggingMixin +from karton.core.config import Config +from karton.core.task import get_current_task, set_current_task +from karton.core.utils import StrictClassMethod + +from .backend import KartonAsyncBackend +from .logger import KartonAsyncLogHandler + + +class KartonAsyncBase(abc.ABC, ConfigMixin, LoggingMixin): + """ + Base class for all Karton services + + You can set an informative version information by setting the ``version`` class + attribute. + """ + + #: Karton service identity + identity: str = "" + #: Karton service version + version: Optional[str] = None + #: Include extended service information for non-consumer services + with_service_info: bool = False + + def __init__( + self, + config: Optional[Config] = None, + identity: Optional[str] = None, + backend: Optional[KartonAsyncBackend] = None, + ) -> None: + ConfigMixin.__init__(self, config, identity) + + self.service_info = None + if self.identity is not None and self.with_service_info: + self.service_info = KartonServiceInfo( + identity=self.identity, + karton_version=__version__, + service_version=self.version, + ) + + self.backend = backend or KartonAsyncBackend( + self.config, identity=self.identity, service_info=self.service_info + ) + + log_handler = KartonAsyncLogHandler(backend=self.backend, channel=self.identity) + LoggingMixin.__init__(self, log_handler) + + async def connect(self) -> None: + await self.backend.connect() + + @property + def current_task(self) -> Optional[Task]: + return get_current_task() + + @current_task.setter + def current_task(self, task: Optional[Task]): + set_current_task(task) + + +class KartonAsyncServiceBase(KartonAsyncBase): + """ + Karton base class for looping services. 
+ + You can set an informative version information by setting the ``version`` class + attribute + + :param config: Karton config to use for service configuration + :param identity: Karton service identity to use + :param backend: Karton backend to use + """ + + def __init__( + self, + config: Optional[Config] = None, + identity: Optional[str] = None, + backend: Optional[KartonAsyncBackend] = None, + ) -> None: + super().__init__( + config=config, + identity=identity, + backend=backend, + ) + self.setup_logger() + self._loop_coro: Optional[asyncio.Task] = None + + def _do_shutdown(self) -> None: + self.log.info("Got signal, shutting down...") + if self._loop_coro is not None: + self._loop_coro.cancel() + + @abc.abstractmethod + async def _loop(self) -> None: + raise NotImplementedError + + # Base class for Karton services + async def loop(self) -> None: + if self.enable_publish_log and hasattr(self.log_handler, "start_consuming"): + self.log_handler.start_consuming() + await self.connect() + event_loop = asyncio.get_event_loop() + for sig in (signal.SIGTERM, signal.SIGINT): + event_loop.add_signal_handler(sig, self._do_shutdown) + self._loop_coro = asyncio.create_task(self._loop()) + try: + await self._loop_coro + finally: + for sig in (signal.SIGTERM, signal.SIGINT): + event_loop.remove_signal_handler(sig) + if self.enable_publish_log and hasattr(self.log_handler, "stop_consuming"): + await self.log_handler.stop_consuming() + + @StrictClassMethod + def main(cls) -> None: + """Main method invoked from CLI.""" + service = cls.karton_from_args() + try: + asyncio.run(service.loop()) + except CancelledError: + # Swallow cancellation, we're done! 
+ pass diff --git a/karton/core/asyncio/karton.py b/karton/core/asyncio/karton.py new file mode 100644 index 00000000..1544c7f6 --- /dev/null +++ b/karton/core/asyncio/karton.py @@ -0,0 +1,355 @@ +import abc +import argparse +import asyncio +import sys +import time +import traceback +from asyncio import CancelledError +from typing import Any, Dict, List, Optional + +from karton.core import query +from karton.core.__version__ import __version__ +from karton.core.backend import KartonBind, KartonMetrics +from karton.core.config import Config +from karton.core.exceptions import TaskTimeoutError +from karton.core.resource import LocalResource as SyncLocalResource +from karton.core.task import Task, TaskState + +from .backend import KartonAsyncBackend +from .base import KartonAsyncBase, KartonAsyncServiceBase +from .resource import LocalResource + + +class Producer(KartonAsyncBase): + """ + Producer part of Karton. Used for dispatching initial tasks into karton. + + :param config: Karton configuration object (optional) + :type config: :class:`karton.Config` + :param identity: Producer name (optional) + :type identity: str + + Usage example: + + .. code-block:: python + + from karton.core.asyncio import Producer + + producer = Producer(identity="karton.mwdb") + await producer.connect() + task = Task( + headers={ + "type": "sample", + "kind": "raw" + }, + payload={ + "sample": Resource("sample.exe", b"put content here") + } + ) + await producer.send_task(task) + + :param config: Karton config to use for service configuration + :param identity: Karton producer identity + :param backend: Karton backend to use + """ + + def __init__( + self, + config: Optional[Config] = None, + identity: Optional[str] = None, + backend: Optional[KartonAsyncBackend] = None, + ) -> None: + super().__init__(config=config, identity=identity, backend=backend) + + async def send_task(self, task: Task) -> bool: + """ + Sends a task to the unrouted task queue. Takes care of logging. 
+ Given task will be child of task we are currently handling (if such exists). + + :param task: Task object to be sent + :return: Bool indicating if the task was delivered + """ + self.log.debug("Dispatched task %s", task.uid) + + # Complete information about task + if self.current_task is not None: + task.set_task_parent(self.current_task) + task.merge_persistent_payload(self.current_task) + task.merge_persistent_headers(self.current_task) + task.priority = self.current_task.priority + + task.last_update = time.time() + task.headers.update({"origin": self.identity}) + + # Ensure all local resources have good buckets + for resource in task.iterate_resources(): + if isinstance(resource, LocalResource) and not resource.bucket: + resource.bucket = self.backend.default_bucket_name + if isinstance(resource, SyncLocalResource): + raise RuntimeError( + "Synchronous resources are not supported. " + "Use karton.core.asyncio.resource module instead." + ) + + # Register new task + await self.backend.register_task(task) + + # Upload local resources + for resource in task.iterate_resources(): + if isinstance(resource, LocalResource): + await resource.upload(self.backend) + + # Add task to karton.tasks + await self.backend.produce_unrouted_task(task) + await self.backend.increment_metrics(KartonMetrics.TASK_PRODUCED, self.identity) + return True + + +class Consumer(KartonAsyncServiceBase): + """ + Base consumer class, this is the part of Karton responsible for processing + incoming tasks + + :param config: Karton config to use for service configuration + :param identity: Karton service identity + :param backend: Karton backend to use + :param task_timeout: The maximum time, in seconds, this consumer will wait for + a task to finish processing before being CRASHED on timeout. + Set 0 for unlimited, and None for using global value + :param concurrency_limit: The maximum number of concurrent tasks that may be + gathered from queue and processed asynchronously. 
+ """ + + filters: List[Dict[str, Any]] = [] + persistent: bool = True + version: Optional[str] = None + task_timeout = None + concurrency_limit: Optional[int] = None + + def __init__( + self, + config: Optional[Config] = None, + identity: Optional[str] = None, + backend: Optional[KartonAsyncBackend] = None, + ) -> None: + super().__init__(config=config, identity=identity, backend=backend) + + if self.filters is None: + raise ValueError("Cannot bind consumer on Empty binds") + + # Dummy conversion to make sure the filters are well-formed. + query.convert(self.filters) + + self.persistent = ( + self.config.getboolean("karton", "persistent", self.persistent) + and not self.debug + ) + if self.task_timeout is None: + self.task_timeout = self.config.getint("karton", "task_timeout") + + if self.concurrency_limit is None: + self.concurrency_limit = self.config.getint("karton", "concurrency_limit") + + self.concurrency_semaphore: Optional[asyncio.Semaphore] = None + if self.concurrency_limit is not None: + self.concurrency_semaphore = asyncio.BoundedSemaphore( + self.concurrency_limit + ) + + @abc.abstractmethod + async def process(self, task: Task) -> None: + """ + Task processing method. + + :param task: The incoming task object + + self.current_task contains task that triggered invocation of + :py:meth:`karton.Consumer.process` but you should only focus on the passed + task object and shouldn't interact with the field directly. 
+ """ + raise NotImplementedError() + + async def _internal_process(self, task: Task) -> None: + exception_str = None + try: + self.log.info("Received new task - %s", task.uid) + await self.backend.set_task_status(task, TaskState.STARTED) + + if self.task_timeout: + try: + # asyncio.timeout is Py3.11+ + async with asyncio.timeout(self.task_timeout): # type: ignore + await self.process(task) + except asyncio.TimeoutError as e: + raise TaskTimeoutError from e + else: + await self.process(task) + self.log.info("Task done - %s", task.uid) + except (Exception, TaskTimeoutError, CancelledError): + exc_info = sys.exc_info() + exception_str = traceback.format_exception(*exc_info) + + await self.backend.increment_metrics( + KartonMetrics.TASK_CRASHED, self.identity + ) + self.log.exception("Failed to process task - %s", task.uid) + finally: + await self.backend.increment_metrics( + KartonMetrics.TASK_CONSUMED, self.identity + ) + + task_state = TaskState.FINISHED + + # report the task status as crashed + # if an exception was caught while processing + if exception_str is not None: + task_state = TaskState.CRASHED + task.error = exception_str + + await self.backend.set_task_status(task, task_state) + + async def internal_process(self, task: Task) -> None: + """ + The internal side of :py:meth:`Consumer.process` function, takes care of + synchronizing the task state, handling errors and running task hooks. 
+ + :param task: Task object to process + + :meta private: + """ + try: + self.current_task = task + + if not task.matches_filters(self.filters): + self.log.info("Task rejected because binds are no longer valid.") + await self.backend.set_task_status(task, TaskState.FINISHED) + # Task rejected: end of processing + return + + await self._internal_process(task) + finally: + if self.concurrency_semaphore is not None: + self.concurrency_semaphore.release() + self.current_task = None + + @property + def _bind(self) -> KartonBind: + return KartonBind( + identity=self.identity, + info=self.__class__.__doc__, + version=__version__, + filters=self.filters, + persistent=self.persistent, + service_version=self.__class__.version, + is_async=True, + ) + + @classmethod + def args_parser(cls) -> argparse.ArgumentParser: + parser = super().args_parser() + # store_false defaults to True, we intentionally want None there + parser.add_argument( + "--non-persistent", + action="store_const", + const=False, + dest="persistent", + help="Run service with non-persistent queue", + ) + parser.add_argument( + "--task-timeout", + type=int, + help="Limit task execution time", + ) + parser.add_argument( + "--concurrency-limit", + type=int, + help="Limit number of concurrent tasks", + ) + return parser + + @classmethod + def config_from_args(cls, config: Config, args: argparse.Namespace) -> None: + super().config_from_args(config, args) + config.load_from_dict( + { + "karton": { + "persistent": args.persistent, + "task_timeout": args.task_timeout, + "concurrency_limit": args.concurrency_limit, + } + } + ) + + async def _loop(self) -> None: + """ + Blocking loop that consumes tasks and runs + :py:meth:`karton.Consumer.process` as a handler + + :meta private: + """ + self.log.info("Service %s started", self.identity) + + if self.task_timeout: + self.log.info(f"Task timeout is set to {self.task_timeout} seconds") + if self.concurrency_limit: + self.log.info(f"Concurrency limit is set to 
{self.concurrency_limit}") + + # Get the old binds and set the new ones atomically + old_bind = await self.backend.register_bind(self._bind) + + if not old_bind: + self.log.info("Service binds created.") + elif old_bind != self._bind: + self.log.info("Binds changed, old service instances should exit soon.") + + for task_filter in self.filters: + self.log.info("Binding on: %s", task_filter) + + concurrent_tasks: List[asyncio.Task] = [] + + try: + while True: + current_bind = await self.backend.get_bind(self.identity) + if current_bind != self._bind: + self.log.info("Binds changed, shutting down.") + break + if self.concurrency_semaphore is not None: + await self.concurrency_semaphore.acquire() + task = await self.backend.consume_routed_task(self.identity) + if task: + coro_task = asyncio.create_task(self.internal_process(task)) + concurrent_tasks.append(coro_task) + # Garbage collection and exception propagation + # for finished concurrent tasks + unfinished_tasks: List[asyncio.Task] = [] + for coro_task in concurrent_tasks: + if coro_task.done(): + # Propagate possible unhandled exception + coro_task.result() + else: + unfinished_tasks.append(coro_task) + concurrent_tasks = unfinished_tasks + finally: + # Finally handles shutdown events: + # - main loop cancellation (SIGINT/SIGTERM) + # - unhandled exception in internal_process + # First cancel all pending tasks + for coro_task in concurrent_tasks: + if not coro_task.done(): + coro_task.cancel() + # Then gather all tasks to finalize them + await asyncio.gather(*concurrent_tasks) + + +class Karton(Consumer, Producer): + """ + This glues together Consumer and Producer - which is the most common use case + """ + + def __init__( + self, + config: Optional[Config] = None, + identity: Optional[str] = None, + backend: Optional[KartonAsyncBackend] = None, + ) -> None: + super().__init__(config=config, identity=identity, backend=backend) diff --git a/karton/core/asyncio/logger.py b/karton/core/asyncio/logger.py new file 
mode 100644 index 00000000..5f24290e --- /dev/null +++ b/karton/core/asyncio/logger.py @@ -0,0 +1,57 @@ +""" +asyncio implementation of KartonLogHandler +""" + +import asyncio +import logging +import platform +from typing import Any, Dict, Optional, Tuple + +from karton.core.logger import LogLineFormatterMixin + +from .backend import KartonAsyncBackend + +HOSTNAME = platform.node() + +QueuedRecord = Optional[Tuple[Dict[str, Any], str]] + + +async def async_log_consumer( + queue: asyncio.Queue[QueuedRecord], backend: KartonAsyncBackend, channel: str +) -> None: + while True: + item = await queue.get() + if not item: + break + log_line, levelname = item + await backend.produce_log(log_line, logger_name=channel, level=levelname) + + +class KartonAsyncLogHandler(logging.Handler, LogLineFormatterMixin): + """ + logging.Handler that passes logs to the Karton backend. + """ + + def __init__(self, backend: KartonAsyncBackend, channel: str) -> None: + logging.Handler.__init__(self) + self._consumer: Optional[asyncio.Task] = None + self._queue: asyncio.Queue[QueuedRecord] = asyncio.Queue() + self._backend = backend + self._channel = channel + + def emit(self, record: logging.LogRecord) -> None: + log_line = self.prepare_log_line(record) + self._queue.put_nowait((log_line, record.levelname)) + + def start_consuming(self): + if self._consumer is not None: + raise RuntimeError("Consumer already started") + self._consumer = asyncio.create_task( + async_log_consumer(self._queue, self._backend, self._channel) + ) + + async def stop_consuming(self): + if self._consumer is None: + raise RuntimeError("Consumer is not started") + self._queue.put_nowait(None) # Signal that queue is finished + await self._consumer diff --git a/karton/core/asyncio/resource.py b/karton/core/asyncio/resource.py new file mode 100644 index 00000000..91284a31 --- /dev/null +++ b/karton/core/asyncio/resource.py @@ -0,0 +1,384 @@ +import contextlib +import os +import shutil +import tempfile +import zipfile 
+from io import BytesIO +from typing import IO, TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Union + +from karton.core.resource import LocalResourceBase, ResourceBase + +if TYPE_CHECKING: + from .backend import KartonAsyncBackend + + +class LocalResource(LocalResourceBase): + """ + Represents local resource with arbitrary binary data e.g. file contents. + + Local resources will be uploaded to object hub (S3) during + task dispatching. + + .. code-block:: python + + # Creating resource from bytes + sample = Resource("original_name.exe", content=b"X5O!P%@AP[4\\ + PZX54(P^)7CC)7}$EICAR-STANDARD-ANT...") + + # Creating resource from path + sample = Resource("original_name.exe", path="sample/original_name.exe") + + :param name: Name of the resource (e.g. name of file) + :param content: Resource content + :param path: Path of file with resource content + :param bucket: Alternative S3 bucket for resource + :param metadata: Resource metadata + :param uid: Alternative S3 resource id + :param sha256: Resource sha256 hash + :param fd: Seekable file descriptor + :param _flags: Resource flags + :param _close_fd: Close file descriptor after upload (default: False) + """ + + def __init__( + self, + name: str, + content: Optional[Union[str, bytes]] = None, + path: Optional[str] = None, + bucket: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + uid: Optional[str] = None, + sha256: Optional[str] = None, + fd: Optional[IO[bytes]] = None, + _flags: Optional[List[str]] = None, + _close_fd: bool = False, + ) -> None: + super().__init__( + name=name, + content=content, + path=path, + bucket=bucket, + metadata=metadata, + uid=uid, + sha256=sha256, + fd=fd, + _flags=_flags, + _close_fd=_close_fd, + ) + + async def _upload(self, backend: "KartonAsyncBackend") -> None: + """Internal function for uploading resources + + :param backend: KartonBackend to use while uploading the resource + + :meta private: + """ + + # Note: never transform resource into Remote + # 
Multiple task dispatching with same local, in that case resource + # can be deleted between tasks. + if self.bucket is None: + raise RuntimeError( + "Resource object can't be uploaded because its bucket is not set" + ) + + if self._content: + # Upload contents + await backend.upload_object(self.bucket, self.uid, self._content) + elif self.fd: + if self.fd.tell() != 0: + raise RuntimeError( + f"Resource object can't be uploaded: " + f"file descriptor must point at first byte " + f"(fd.tell = {self.fd.tell()})" + ) + # Upload contents from fd + await backend.upload_object(self.bucket, self.uid, self.fd) + # If file descriptor is managed by Resource, close it after upload + if self._close_fd: + self.fd.close() + elif self._path: + # Upload file provided by path + await backend.upload_object_from_file(self.bucket, self.uid, self._path) + + async def upload(self, backend: "KartonAsyncBackend") -> None: + """Internal function for uploading resources + + :param backend: KartonBackend to use while uploading the resource + + :meta private: + """ + if not self._content and not self._path and not self.fd: + raise RuntimeError("Can't upload resource without content") + await self._upload(backend) + + +Resource = LocalResource + + +class RemoteResource(ResourceBase): + """ + Keeps reference to remote resource object shared between subsystems + via object storage (S3) + + Should never be instantiated directly by subsystem, but can be directly passed to + outgoing payload. + + :param name: Name of the resource (e.g. 
name of file) + :param bucket: Alternative S3 bucket for resource + :param metadata: Resource metadata + :param uid: Alternative S3 resource id + :param size: Resource size + :param backend: :py:meth:`KartonBackend` to bind to this resource + :param sha256: Resource sha256 hash + :param _flags: Resource flags + """ + + def __init__( + self, + name: str, + bucket: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + uid: Optional[str] = None, + size: Optional[int] = None, + backend: Optional["KartonAsyncBackend"] = None, + sha256: Optional[str] = None, + _flags: Optional[List[str]] = None, + ) -> None: + super(RemoteResource, self).__init__( + name, + bucket=bucket, + metadata=metadata, + sha256=sha256, + _uid=uid, + _size=size, + _flags=_flags, + ) + self.backend = backend + + def loaded(self) -> bool: + """ + Checks whether resource is loaded into memory + + :return: Flag indicating if the resource is loaded or not + """ + return self._content is not None + + @property + def content(self) -> bytes: + """ + Resource content. Raises RuntimeError when resource was not downloaded before. 
+ + :return: Content bytes + """ + if self._content is None: + raise RuntimeError( + "Resource object needs to be explicitly downloaded first" + ) + return self._content + + @classmethod + def from_dict( + cls, dict: Dict[str, Any], backend: Optional["KartonAsyncBackend"] + ) -> "RemoteResource": + """ + Internal deserialization method for remote resources + + :param dict: Serialized information about resource + :param backend: KartonBackend object + :return: Deserialized :py:meth:`RemoteResource` object + + :meta private: + """ + # Backwards compatibility + metadata = dict.get("metadata", {}) + if "sha256" in dict: + metadata["sha256"] = dict["sha256"] + + return cls( + name=dict["name"], + metadata=metadata, + bucket=dict["bucket"], + uid=dict["uid"], + size=dict.get("size"), # Backwards compatibility (2.x.x) + backend=backend, + _flags=dict.get("flags"), # Backwards compatibility (3.x.x) + ) + + def unload(self) -> None: + """ + Unloads resource object from memory + """ + self._content = None + + async def download(self) -> bytes: + """ + Downloads remote resource content from object hub into memory. + + .. code-block:: python + + sample = self.current_task.get_resource("sample") + + # Ensure that resource will be downloaded before it will be + # passed to processing method + sample.download() + + self.process_sample(sample) + + :return: Downloaded content bytes + """ + if self.backend is None: + raise RuntimeError( + ( + "Resource object can't be downloaded because it's not bound to " + "the backend" + ) + ) + if self.bucket is None: + raise RuntimeError( + "Resource object can't be downloaded because its bucket is not set" + ) + + self._content = await self.backend.download_object(self.bucket, self.uid) + return self._content + + async def download_to_file(self, path: str) -> None: + """ + Downloads remote resource into file. + + .. 
code-block:: python + + sample = self.current_task.get_resource("sample") + + sample.download_to_file("sample/sample.exe") + + with open("sample/sample.exe", "rb") as f: + contents = f.read() + + :param path: Path to download the resource into + """ + if self.backend is None: + raise RuntimeError( + ( + "Resource object can't be downloaded because it's not bound to " + "the backend" + ) + ) + if self.bucket is None: + raise RuntimeError( + "Resource object can't be downloaded because its bucket is not set" + ) + + await self.backend.download_object_to_file(self.bucket, self.uid, path) + + @contextlib.asynccontextmanager + async def download_temporary_file(self, suffix=None) -> AsyncIterator[IO[bytes]]: + """ + Downloads remote resource into named temporary file. + + .. code-block:: python + + sample = self.current_task.get_resource("sample") + + with sample.download_temporary_file() as f: + contents = f.read() + path = f.name + + # Temporary file is deleted after exitting the "with" scope + + :return: ContextManager with the temporary file + """ + # That tempfile-fu is necessary because minio.fget_object removes file + # under provided path and renames its own part-file with downloaded content + # under previously deleted path + # Weird move, but ok... + tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix) + tmp.close() + try: + await self.download_to_file(tmp.name) + with open(tmp.name, "rb") as f: + yield f + finally: + os.remove(tmp.name) + + @contextlib.asynccontextmanager + async def zip_file(self) -> AsyncIterator[zipfile.ZipFile]: + """ + If resource contains a Zip file, downloads it to the temporary file + and wraps it with ZipFile object. + + .. code-block:: python + + dumps = self.current_task.get_resource("dumps") + + with dumps.zip_file() as zipf: + print("Fetched dumps: ", zipf.namelist()) + + By default: method downloads zip into temporary file, which is deleted after + leaving the context. 
If you want to load zip into memory, + call :py:meth:`RemoteResource.download` first. + + If you want to pre-download Zip under specified path and open it using + zipfile module, you need to do this manually: + + .. code-block:: python + + dumps = self.current_task.get_resource("dumps") + + # Download zip file + zip_path = "./dumps.zip" + dumps.download_to_file(zip_path) + + zipf = zipfile.Zipfile(zip_path) + + :return: ContextManager with zipfile + """ + if self._content: + yield zipfile.ZipFile(BytesIO(self._content)) + else: + async with self.download_temporary_file() as f: + yield zipfile.ZipFile(f) + + async def extract_to_directory(self, path: str) -> None: + """ + If resource contains a Zip file, extracts files contained in Zip into + provided path. + + By default: method downloads zip into temporary file, which is deleted + after extraction. If you want to load zip into memory, call + :py:meth:`RemoteResource.download` first. + + :param path: Directory path where the resource should be unpacked + """ + async with self.zip_file() as zf: + zf.extractall(path) + + @contextlib.asynccontextmanager + async def extract_temporary(self) -> AsyncIterator[str]: + """ + If resource contains a Zip file, extracts files contained in Zip + to the temporary directory. + + Returns path of directory with extracted files. Directory is recursively + deleted after leaving the context. + + .. code-block:: python + + dumps = self.current_task.get_resource("dumps") + + with dumps.extract_temporary() as dumps_path: + print("Fetched dumps:", os.listdir(dumps_path)) + + By default: method downloads zip into temporary file, which is deleted + after extraction. If you want to load zip into memory, call + :py:meth:`RemoteResource.download` first. 
+ + :return: ContextManager with the temporary directory + """ + tmpdir = tempfile.mkdtemp() + try: + await self.extract_to_directory(tmpdir) + yield tmpdir + finally: + shutil.rmtree(tmpdir) diff --git a/karton/core/backend.py b/karton/core/backend.py index ea6ccb82..09df6934 100644 --- a/karton/core/backend.py +++ b/karton/core/backend.py @@ -21,6 +21,7 @@ from .config import Config from .exceptions import InvalidIdentityError +from .resource import RemoteResource from .task import Task, TaskPriority, TaskState from .utils import chunks, chunks_iter @@ -33,12 +34,20 @@ KartonBind = namedtuple( "KartonBind", - ["identity", "info", "version", "persistent", "filters", "service_version"], + [ + "identity", + "info", + "version", + "persistent", + "filters", + "service_version", + "is_async", + ], ) KartonOutputs = namedtuple("KartonOutputs", ["identity", "outputs"]) -logger = logging.getLogger("karton.core.backend") +logger = logging.getLogger(__name__) @@ -103,13 +112,13 @@ def parse_client_name( ) -class KartonBackend: +class KartonBackendBase: def __init__( self, config: Config, identity: Optional[str] = None, service_info: Optional[KartonServiceInfo] = None, - ) -> None: + ): self.config = config if identity is not None: @@ -117,59 +126,6 @@ def __init__( self.identity = identity self.service_info = service_info - self.redis = self.make_redis( - config, identity=identity, service_info=service_info - ) - - endpoint = config.get("s3", "address") - access_key = config.get("s3", "access_key") - secret_key = config.get("s3", "secret_key") - iam_auth = config.getboolean("s3", "iam_auth") - - if not endpoint: - raise RuntimeError("Attempting to get S3 client without an endpoint set") - - if access_key and secret_key and iam_auth: - logger.warning( - "Warning: iam is turned on and both S3 access key and secret key are" - " provided" - ) - - if iam_auth: - s3_client = self.iam_auth_s3(endpoint) - if s3_client: - self.s3 = 
s3_client - return - - if access_key is None or secret_key is None: - raise RuntimeError( - "Attempting to get S3 client without an access_key/secret_key set" - ) - - self.s3 = boto3.client( - "s3", - endpoint_url=endpoint, - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - ) - - def iam_auth_s3(self, endpoint: str): - boto_session = get_session() - iam_providers = [ - ContainerProvider(), - InstanceMetadataProvider( - iam_role_fetcher=InstanceMetadataFetcher(timeout=1000, num_attempts=2) - ), - ] - - for provider in iam_providers: - creds = provider.load() - if creds: - boto_session._credentials = creds # type: ignore - return boto3.Session(botocore_session=boto_session).client( - "s3", - endpoint_url=endpoint, - ) @staticmethod def _validate_identity(identity: str): @@ -179,48 +135,6 @@ def _validate_identity(identity: str): f"Karton identity should not contain {disallowed_chars}" ) - @staticmethod - def make_redis( - config, - identity: Optional[str] = None, - service_info: Optional[KartonServiceInfo] = None, - ) -> StrictRedis: - """ - Create and test a Redis connection. - - :param config: The karton configuration - :param identity: Karton service identity - :param service_info: Additional service identity metadata - :return: Redis connection - """ - if service_info is not None: - client_name: Optional[str] = service_info.make_client_name() - else: - client_name = identity - - redis_args = { - "host": config["redis"]["host"], - "port": config.getint("redis", "port", 6379), - "db": config.getint("redis", "db", 0), - "username": config.get("redis", "username"), - "password": config.get("redis", "password"), - "client_name": client_name, - # set socket_timeout to None if set to 0 - "socket_timeout": config.getint("redis", "socket_timeout", 30) or None, - "decode_responses": True, - } - try: - redis = StrictRedis(**redis_args) - redis.ping() - except AuthenticationError: - # Maybe we've sent a wrong password. 
- # Or maybe the server is not (yet) password protected - # To make smooth transition possible, try to login insecurely - del redis_args["password"] - redis = StrictRedis(**redis_args) - redis.ping() - return redis - @property def default_bucket_name(self) -> str: bucket_name = self.config.get("s3", "bucket") @@ -270,6 +184,7 @@ def serialize_bind(bind: KartonBind) -> str: "filters": bind.filters, "persistent": bind.persistent, "service_version": bind.service_version, + "is_async": bind.is_async, }, sort_keys=True, ) @@ -294,6 +209,7 @@ def unserialize_bind(identity: str, bind_data: str) -> KartonBind: persistent=not identity.endswith(".test"), filters=bind, service_version=None, + is_async=False, ) return KartonBind( identity=identity, @@ -302,6 +218,7 @@ def unserialize_bind(identity: str, bind_data: str) -> KartonBind: persistent=bind["persistent"], filters=bind["filters"], service_version=bind.get("service_version"), + is_async=bind.get("is_async", False), ) @staticmethod @@ -316,6 +233,126 @@ def unserialize_output(identity: str, output_data: Set[str]) -> KartonOutputs: output = [json.loads(output_type) for output_type in output_data] return KartonOutputs(identity=identity, outputs=output) + @staticmethod + def _log_channel(logger_name: Optional[str], level: Optional[str]) -> str: + return ".".join( + [KARTON_LOG_CHANNEL, (level or "*").lower(), logger_name or "*"] + ) + + +class KartonBackend(KartonBackendBase): + def __init__( + self, + config: Config, + identity: Optional[str] = None, + service_info: Optional[KartonServiceInfo] = None, + ) -> None: + super().__init__(config, identity, service_info) + self.redis = self.make_redis( + config, identity=identity, service_info=service_info + ) + + endpoint = config.get("s3", "address") + access_key = config.get("s3", "access_key") + secret_key = config.get("s3", "secret_key") + iam_auth = config.getboolean("s3", "iam_auth") + + if not endpoint: + raise RuntimeError("Attempting to get S3 client without an endpoint 
set") + + if access_key and secret_key and iam_auth: + logger.warning( + "Warning: iam is turned on and both S3 access key and secret key are" + " provided" + ) + + if iam_auth: + s3_client = self.iam_auth_s3(endpoint) + if s3_client: + self.s3 = s3_client + return + + if access_key is None or secret_key is None: + raise RuntimeError( + "Attempting to get S3 client without an access_key/secret_key set" + ) + + self.s3 = boto3.client( + "s3", + endpoint_url=endpoint, + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + ) + + def iam_auth_s3(self, endpoint: str): + boto_session = get_session() + iam_providers = [ + ContainerProvider(), + InstanceMetadataProvider( + iam_role_fetcher=InstanceMetadataFetcher(timeout=1000, num_attempts=2) + ), + ] + + for provider in iam_providers: + creds = provider.load() + if creds: + boto_session._credentials = creds # type: ignore + return boto3.Session(botocore_session=boto_session).client( + "s3", + endpoint_url=endpoint, + ) + + @staticmethod + def make_redis( + config, + identity: Optional[str] = None, + service_info: Optional[KartonServiceInfo] = None, + ) -> StrictRedis: + """ + Create and test a Redis connection. 
+ + :param config: The karton configuration + :param identity: Karton service identity + :param service_info: Additional service identity metadata + :return: Redis connection + """ + if service_info is not None: + client_name: Optional[str] = service_info.make_client_name() + else: + client_name = identity + + redis_args = { + "host": config["redis"]["host"], + "port": config.getint("redis", "port", 6379), + "db": config.getint("redis", "db", 0), + "username": config.get("redis", "username"), + "password": config.get("redis", "password"), + "client_name": client_name, + # set socket_timeout to None if set to 0 + "socket_timeout": config.getint("redis", "socket_timeout", 30) or None, + "decode_responses": True, + } + try: + redis = StrictRedis(**redis_args) + redis.ping() + except AuthenticationError: + # Maybe we've sent a wrong password. + # Or maybe the server is not (yet) password protected + # To make smooth transition possible, try to login insecurely + del redis_args["password"] + redis = StrictRedis(**redis_args) + redis.ping() + return redis + + def unserialize_resource(self, resource_spec: Dict[str, Any]) -> RemoteResource: + """ + Unserializes resource into a RemoteResource object bound with current backend + + :param resource_spec: Resource specification + :return: RemoteResource object + """ + return RemoteResource.from_dict(resource_spec, backend=self) + def get_bind(self, identity: str) -> KartonBind: """ Get bind object for given identity @@ -426,7 +463,9 @@ def get_task(self, task_uid: str) -> Optional[Task]: task_data = self.redis.get(f"{KARTON_TASK_NAMESPACE}:{task_uid}") if not task_data: return None - return Task.unserialize(task_data, backend=self) + return Task.unserialize( + task_data, resource_unserializer=self.unserialize_resource + ) def get_tasks( self, @@ -449,7 +488,11 @@ def get_tasks( chunk_size, ) return [ - Task.unserialize(task_data, backend=self, parse_resources=parse_resources) + Task.unserialize( + task_data, + 
parse_resources=parse_resources, + resource_unserializer=self.unserialize_resource, + ) for chunk in keys for task_data in self.redis.mget(chunk) if task_data is not None @@ -464,7 +507,9 @@ def _iter_tasks( for chunk in chunks_iter(task_keys, chunk_size): yield from ( Task.unserialize( - task_data, backend=self, parse_resources=parse_resources + task_data, + parse_resources=parse_resources, + resource_unserializer=self.unserialize_resource, ) for task_data in self.redis.mget(chunk) if task_data is not None @@ -550,7 +595,9 @@ def _iter_legacy_task_tree( lambda task: task.root_uid == root_uid, ( Task.unserialize( - task_data, backend=self, parse_resources=parse_resources + task_data, + parse_resources=parse_resources, + resource_unserializer=self.unserialize_resource, ) for task_data in self.redis.mget(chunk) if task_data is not None @@ -798,12 +845,6 @@ def restart_task(self, task: Task) -> Task: p.execute() return new_task - @staticmethod - def _log_channel(logger_name: Optional[str], level: Optional[str]) -> str: - return ".".join( - [KARTON_LOG_CHANNEL, (level or "*").lower(), logger_name or "*"] - ) - def produce_log( self, log_record: Dict[str, Any], diff --git a/karton/core/base.py b/karton/core/base.py index ca0b8975..aa5b2df9 100644 --- a/karton/core/base.py +++ b/karton/core/base.py @@ -10,31 +10,15 @@ from .backend import KartonBackend, KartonServiceInfo from .config import Config from .logger import KartonLogHandler -from .task import Task +from .task import Task, get_current_task, set_current_task from .utils import HardShutdownInterrupt, StrictClassMethod, graceful_killer -class KartonBase(abc.ABC): - """ - Base class for all Karton services - - You can set an informative version information by setting the ``version`` class - attribute. 
- """ - - #: Karton service identity - identity: str = "" - #: Karton service version - version: Optional[str] = None - #: Include extended service information for non-consumer services - with_service_info: bool = False +class ConfigMixin: + identity: Optional[str] + version: Optional[str] - def __init__( - self, - config: Optional[Config] = None, - identity: Optional[str] = None, - backend: Optional[KartonBackend] = None, - ) -> None: + def __init__(self, config: Optional[Config] = None, identity: Optional[str] = None): self.config = config or Config() self.enable_publish_log = self.config.getboolean( "logging", "enable_publish", True @@ -50,25 +34,80 @@ def __init__( self.debug = self.config.getboolean("karton", "debug", False) - if self.debug: + if self.debug and self.identity: self.identity += "-" + os.urandom(4).hex() + "-dev" - self.service_info = None - if self.identity is not None and self.with_service_info: - self.service_info = KartonServiceInfo( - identity=self.identity, - karton_version=__version__, - service_version=self.version, - ) + @classmethod + def args_description(cls) -> str: + """Return short description for argument parser.""" + if not cls.__doc__: + return "" + return textwrap.dedent(cls.__doc__).strip().splitlines()[0] - self.backend = backend or KartonBackend( - self.config, identity=self.identity, service_info=self.service_info + @classmethod + def args_parser(cls) -> argparse.ArgumentParser: + """ + Return ArgumentParser for main() class method. + + This method should be overridden and call super methods + if you want to add more arguments. 
+ """ + parser = argparse.ArgumentParser(description=cls.args_description()) + parser.add_argument( + "--version", action="version", version=cast(str, cls.version) + ) + parser.add_argument("--config-file", help="Alternative configuration path") + parser.add_argument( + "--identity", help="Alternative identity for Karton service" ) + parser.add_argument("--log-level", help="Logging level of Karton logger") + parser.add_argument( + "--debug", help="Enable debugging mode", action="store_true", default=None + ) + return parser - self._log_handler = KartonLogHandler( - backend=self.backend, channel=self.identity + @classmethod + def config_from_args(cls, config: Config, args: argparse.Namespace) -> None: + """ + Updates configuration with settings from arguments + + This method should be overridden and call super methods + if you want to add more arguments. + """ + config.load_from_dict( + { + "karton": { + "identity": args.identity, + "debug": args.debug, + }, + "logging": {"level": args.log_level}, + } ) - self.current_task: Optional[Task] = None + + @classmethod + def karton_from_args(cls, args: Optional[argparse.Namespace] = None): + """ + Returns Karton instance configured using configuration files + and provided arguments + + Used by :py:meth:`KartonServiceBase.main` method + """ + if args is None: + parser = cls.args_parser() + args = parser.parse_args() + config = Config(path=args.config_file) + cls.config_from_args(config, args) + return cls(config=config) + + +class LoggingMixin: + config: Config + identity: Optional[str] + debug: bool + enable_publish_log: bool + + def __init__(self, log_handler: logging.Handler): + self._log_handler = log_handler def setup_logger(self, level: Optional[Union[str, int]] = None) -> None: """ @@ -115,7 +154,7 @@ def setup_logger(self, level: Optional[Union[str, int]] = None) -> None: logger.addHandler(self._log_handler) @property - def log_handler(self) -> KartonLogHandler: + def log_handler(self) -> logging.Handler: """ Return 
KartonLogHandler bound to this Karton service. @@ -141,67 +180,52 @@ def log(self) -> logging.Logger: """ return logging.getLogger(self.identity) - @classmethod - def args_description(cls) -> str: - """Return short description for argument parser.""" - if not cls.__doc__: - return "" - return textwrap.dedent(cls.__doc__).strip().splitlines()[0] - @classmethod - def args_parser(cls) -> argparse.ArgumentParser: - """ - Return ArgumentParser for main() class method. +class KartonBase(abc.ABC, ConfigMixin, LoggingMixin): + """ + Base class for all Karton services - This method should be overridden and call super methods - if you want to add more arguments. - """ - parser = argparse.ArgumentParser(description=cls.args_description()) - parser.add_argument( - "--version", action="version", version=cast(str, cls.version) - ) - parser.add_argument("--config-file", help="Alternative configuration path") - parser.add_argument( - "--identity", help="Alternative identity for Karton service" - ) - parser.add_argument("--log-level", help="Logging level of Karton logger") - parser.add_argument( - "--debug", help="Enable debugging mode", action="store_true", default=None - ) - return parser + You can set an informative version information by setting the ``version`` class + attribute. + """ - @classmethod - def config_from_args(cls, config: Config, args: argparse.Namespace) -> None: - """ - Updates configuration with settings from arguments + #: Karton service identity + identity: str = "" + #: Karton service version + version: Optional[str] = None + #: Include extended service information for non-consumer services + with_service_info: bool = False - This method should be overridden and call super methods - if you want to add more arguments. 
- """ - config.load_from_dict( - { - "karton": { - "identity": args.identity, - "debug": args.debug, - }, - "logging": {"level": args.log_level}, - } + def __init__( + self, + config: Optional[Config] = None, + identity: Optional[str] = None, + backend: Optional[KartonBackend] = None, + ) -> None: + ConfigMixin.__init__(self, config, identity) + + self.service_info = None + if self.identity is not None and self.with_service_info: + self.service_info = KartonServiceInfo( + identity=self.identity, + karton_version=__version__, + service_version=self.version, + ) + + self.backend = backend or KartonBackend( + self.config, identity=self.identity, service_info=self.service_info ) - @classmethod - def karton_from_args(cls, args: Optional[argparse.Namespace] = None): - """ - Returns Karton instance configured using configuration files - and provided arguments + log_handler = KartonLogHandler(backend=self.backend, channel=self.identity) + LoggingMixin.__init__(self, log_handler) - Used by :py:meth:`KartonServiceBase.main` method - """ - if args is None: - parser = cls.args_parser() - args = parser.parse_args() - config = Config(path=args.config_file) - cls.config_from_args(config, args) - return cls(config=config) + @property + def current_task(self) -> Optional[Task]: + return get_current_task() + + @current_task.setter + def current_task(self, task: Optional[Task]): + set_current_task(task) class KartonServiceBase(KartonBase): diff --git a/karton/core/karton.py b/karton/core/karton.py index e2bbece0..363026af 100644 --- a/karton/core/karton.py +++ b/karton/core/karton.py @@ -137,7 +137,7 @@ def __init__( ) if self.task_timeout is None: self.task_timeout = self.config.getint("karton", "task_timeout") - self.current_task: Optional[Task] = None + self._pre_hooks: List[Tuple[Optional[str], Callable[[Task], None]]] = [] self._post_hooks: List[ Tuple[ @@ -170,19 +170,18 @@ def internal_process(self, task: Task) -> None: """ self.current_task = task - 
self.log_handler.set_task(self.current_task) - if not self.current_task.matches_filters(self.filters): + if not task.matches_filters(self.filters): self.log.info("Task rejected because binds are no longer valid.") - self.backend.set_task_status(self.current_task, TaskState.FINISHED) + self.backend.set_task_status(task, TaskState.FINISHED) # Task rejected: end of processing return exception_str = None try: - self.log.info("Received new task - %s", self.current_task.uid) - self.backend.set_task_status(self.current_task, TaskState.STARTED) + self.log.info("Received new task - %s", task.uid) + self.backend.set_task_status(task, TaskState.STARTED) self._run_pre_hooks() @@ -190,22 +189,22 @@ def internal_process(self, task: Task) -> None: try: if self.task_timeout: with timeout(self.task_timeout): - self.process(self.current_task) + self.process(task) else: - self.process(self.current_task) + self.process(task) except (Exception, TaskTimeoutError) as exc: saved_exception = exc raise finally: self._run_post_hooks(saved_exception) - self.log.info("Task done - %s", self.current_task.uid) + self.log.info("Task done - %s", task.uid) except (Exception, TaskTimeoutError): exc_info = sys.exc_info() exception_str = traceback.format_exception(*exc_info) self.backend.increment_metrics(KartonMetrics.TASK_CRASHED, self.identity) - self.log.exception("Failed to process task - %s", self.current_task.uid) + self.log.exception("Failed to process task - %s", task.uid) finally: self.backend.increment_metrics(KartonMetrics.TASK_CONSUMED, self.identity) @@ -215,9 +214,10 @@ def internal_process(self, task: Task) -> None: # if an exception was caught while processing if exception_str is not None: task_state = TaskState.CRASHED - self.current_task.error = exception_str + task.error = exception_str - self.backend.set_task_status(self.current_task, task_state) + self.backend.set_task_status(task, task_state) + self.current_task = None @property def _bind(self) -> KartonBind: @@ -228,6 +228,7 @@ 
def _bind(self) -> KartonBind: filters=self.filters, persistent=self.persistent, service_version=self.__class__.version, + is_async=False, ) @classmethod diff --git a/karton/core/logger.py b/karton/core/logger.py index 7989563c..04be5815 100644 --- a/karton/core/logger.py +++ b/karton/core/logger.py @@ -2,30 +2,18 @@ import platform import traceback import warnings -from typing import Optional +from typing import Any, Callable, Dict from .backend import KartonBackend -from .task import Task +from .task import get_current_task HOSTNAME = platform.node() -class KartonLogHandler(logging.Handler): - """ - logging.Handler that passes logs to the Karton backend. - """ - - def __init__(self, backend: KartonBackend, channel: str) -> None: - logging.Handler.__init__(self) - self.backend = backend - self.task: Optional[Task] = None - self.is_consumer_active: bool = True - self.channel: str = channel +class LogLineFormatterMixin: + format: Callable[[logging.LogRecord], str] - def set_task(self, task: Task) -> None: - self.task = task - - def emit(self, record: logging.LogRecord) -> None: + def prepare_log_line(self, record: logging.LogRecord) -> Dict[str, Any]: ignore_fields = [ "args", "asctime", @@ -54,11 +42,27 @@ def emit(self, record: logging.LogRecord) -> None: log_line["type"] = "log" log_line["message"] = self.format(record) - if self.task is not None: - log_line["task"] = self.task.serialize() + current_task = get_current_task() + if current_task is not None: + log_line["task"] = current_task.serialize() log_line["hostname"] = HOSTNAME + return log_line + +class KartonLogHandler(logging.Handler, LogLineFormatterMixin): + """ + logging.Handler that passes logs to the Karton backend. 
+ """ + + def __init__(self, backend: KartonBackend, channel: str) -> None: + logging.Handler.__init__(self) + self.backend = backend + self.is_consumer_active: bool = True + self.channel: str = channel + + def emit(self, record: logging.LogRecord) -> None: + log_line = self.prepare_log_line(record) log_consumed = self.backend.produce_log( log_line, logger_name=self.channel, level=record.levelname ) diff --git a/karton/core/resource.py b/karton/core/resource.py index bf8866d9..426aecd7 100644 --- a/karton/core/resource.py +++ b/karton/core/resource.py @@ -150,34 +150,7 @@ def to_dict(self) -> Dict[str, Any]: } -class LocalResource(ResourceBase): - """ - Represents local resource with arbitrary binary data e.g. file contents. - - Local resources will be uploaded to object hub (S3) during - task dispatching. - - .. code-block:: python - - # Creating resource from bytes - sample = Resource("original_name.exe", content=b"X5O!P%@AP[4\\ - PZX54(P^)7CC)7}$EICAR-STANDARD-ANT...") - - # Creating resource from path - sample = Resource("original_name.exe", path="sample/original_name.exe") - - :param name: Name of the resource (e.g. 
name of file) - :param content: Resource content - :param path: Path of file with resource content - :param bucket: Alternative S3 bucket for resource - :param metadata: Resource metadata - :param uid: Alternative S3 resource id - :param sha256: Resource sha256 hash - :param fd: Seekable file descriptor - :param _flags: Resource flags - :param _close_fd: Close file descriptor after upload (default: False) - """ - +class LocalResourceBase(ResourceBase): def __init__( self, name: str, @@ -194,7 +167,7 @@ def __init__( if len(list(filter(None, [path, content, fd]))) != 1: raise ValueError("You must exclusively provide a path, content or fd") - super(LocalResource, self).__init__( + super().__init__( name, content=content, path=path, @@ -247,7 +220,7 @@ def from_directory( bucket: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, uid: Optional[str] = None, - ) -> "LocalResource": + ) -> "LocalResourceBase": """ Resource extension, allowing to pass whole directory as a zipped resource. @@ -305,6 +278,35 @@ def from_directory( _close_fd=True, ) + +class LocalResource(LocalResourceBase): + """ + Represents local resource with arbitrary binary data e.g. file contents. + + Local resources will be uploaded to object hub (S3) during + task dispatching. + + .. code-block:: python + + # Creating resource from bytes + sample = Resource("original_name.exe", content=b"X5O!P%@AP[4\\ + PZX54(P^)7CC)7}$EICAR-STANDARD-ANT...") + + # Creating resource from path + sample = Resource("original_name.exe", path="sample/original_name.exe") + + :param name: Name of the resource (e.g. 
name of file) + :param content: Resource content + :param path: Path of file with resource content + :param bucket: Alternative S3 bucket for resource + :param metadata: Resource metadata + :param uid: Alternative S3 resource id + :param sha256: Resource sha256 hash + :param fd: Seekable file descriptor + :param _flags: Resource flags + :param _close_fd: Close file descriptor after upload (default: False) + """ + def _upload(self, backend: "KartonBackend") -> None: """Internal function for uploading resources diff --git a/karton/core/task.py b/karton/core/task.py index 251e931a..bc92ddf2 100644 --- a/karton/core/task.py +++ b/karton/core/task.py @@ -3,6 +3,7 @@ import time import uuid import warnings +from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, @@ -24,6 +25,16 @@ import orjson +current_task: ContextVar[Optional["Task"]] = ContextVar("current_task") + + +def get_current_task() -> Optional["Task"]: + return current_task.get(None) + + +def set_current_task(task: Optional["Task"]): + current_task.set(task) + class TaskState(enum.Enum): DECLARED = "Declared" # Task declared in TASKS_QUEUE @@ -375,12 +386,15 @@ def unserialize( data: Union[str, bytes], backend: Optional["KartonBackend"] = None, parse_resources: bool = True, + resource_unserializer: Optional[Callable[[Dict], Any]] = None, ) -> "Task": """ Unserialize Task instance from JSON string :param data: JSON-serialized task - :param backend: Backend instance to be bound to RemoteResource objects + :param backend: | + Backend instance to be bound to RemoteResource objects. + Deprecated: pass resource_unserializer instead. :param parse_resources: | If set to False (default is True), method doesn't deserialize '__karton_resource__' entries, which speeds up deserialization @@ -388,6 +402,9 @@ def unserialize( filtering based on status. When resource deserialization is turned off, Task.unserialize will try to use faster 3rd-party JSON parser (orjson). 
+ :param resource_unserializer: | + Resource factory used for deserialization of __karton_resource__ + dictionary values. :return: Unserialized Task object :meta private: @@ -399,7 +416,12 @@ def unserialize_resources(value: Any) -> Any: RemoteResource object instances """ if isinstance(value, dict) and "__karton_resource__" in value: - return RemoteResource.from_dict(value["__karton_resource__"], backend) + if resource_unserializer is None: + return RemoteResource.from_dict( + value["__karton_resource__"], backend + ) + else: + return resource_unserializer(value["__karton_resource__"]) return value if not isinstance(data, str): diff --git a/requirements.txt b/requirements.txt index f52c6b2b..12074a84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ redis orjson -# https://github.com/boto/boto3/issues/4392 -boto3<1.36.0 +# Required by aioboto3 +boto3>=1.37.2,<1.37.4 +aioboto3==14.3.0 diff --git a/setup.cfg b/setup.cfg index 308072f2..1145b588 100644 --- a/setup.cfg +++ b/setup.cfg @@ -11,4 +11,7 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-boto3.*] ignore_missing_imports = True - +[mypy-aioboto3.*] +ignore_missing_imports = True +[mypy-aiobotocore.*] +ignore_missing_imports = True diff --git a/setup.py b/setup.py index d1e59002..a3ea2a6a 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ long_description=open("README.md").read(), long_description_content_type="text/markdown", namespace_packages=["karton"], - packages=["karton.core", "karton.system"], + packages=["karton.core", "karton.core.asyncio", "karton.system"], package_data={"karton.core": ["py.typed"]}, install_requires=open("requirements.txt").read().splitlines(), entry_points={