diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 4b1dec78c6..8febb38ad1 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -9,6 +9,7 @@ import sys import tempfile import traceback +import warnings from sys import exit from typing import Callable, List, Optional @@ -68,6 +69,23 @@ def _compute_array_job_index(): return offset +def _get_working_loop(): + """Returns a running event loop.""" + try: + return asyncio.get_running_loop() + except RuntimeError: + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + try: + return asyncio.get_event_loop_policy().get_event_loop() + # Since version 3.12, DeprecationWarning is emitted if there is no + # current event loop. + except DeprecationWarning: + loop = asyncio.get_event_loop_policy().new_event_loop() + asyncio.set_event_loop(loop) + return loop + + def _dispatch_execute( ctx: FlyteContext, load_task: Callable[[], PythonTask], @@ -107,7 +125,7 @@ def _dispatch_execute( if inspect.iscoroutine(outputs): # Handle eager-mode (async) tasks logger.info("Output is a coroutine") - outputs = asyncio.run(outputs) + outputs = _get_working_loop().run_until_complete(outputs) # Step3a if isinstance(outputs, VoidPromise): diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py index c25e2ae762..898d11a5ba 100644 --- a/tests/flytekit/unit/experimental/test_eager_workflows.py +++ b/tests/flytekit/unit/experimental/test_eager_workflows.py @@ -1,4 +1,5 @@ import asyncio +import mock import os import sys import typing @@ -9,8 +10,13 @@ from hypothesis import given from flytekit import dynamic, task, workflow + +from flytekit.bin.entrypoint import _get_working_loop, _dispatch_execute +from flytekit.core import context_manager +from flytekit.core.promise import VoidPromise from flytekit.exceptions.user import FlyteValidationException from flytekit.experimental import EagerException, eager +from flytekit.models import literals as _literal_models from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile from flytekit.types.structured import StructuredDataset @@ -275,3 +281,26 @@ async def eager_wf_flyte_directory() -> str: result = asyncio.run(eager_wf_flyte_directory()) assert result == "some data" + + +@mock.patch("flytekit.core.utils.load_proto_from_file") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data") +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +@mock.patch("flytekit.core.utils.write_proto_to_file") +def test_eager_workflow_dispatch(mock_write_to_file, mock_put_data, mock_get_data, mock_load_proto, event_loop): + """Test that event loop is preserved after executing eager workflow via dispatch.""" + + @eager + async def eager_wf(): + await asyncio.sleep(0.1) + return + + ctx = context_manager.FlyteContext.current_context() + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) + ) + ) as ctx: + _dispatch_execute(ctx, lambda: eager_wf, "inputs path", "outputs prefix") + loop_after_execute = asyncio.get_event_loop_policy().get_event_loop() + assert event_loop == loop_after_execute