diff --git a/src/lightning/app/core/app.py b/src/lightning/app/core/app.py index ffedad588e3c8..ed4c1d5c76f3f 100644 --- a/src/lightning/app/core/app.py +++ b/src/lightning/app/core/app.py @@ -29,6 +29,7 @@ from lightning.app import _console from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest from lightning.app.core.constants import ( + BATCH_DELTA_COUNT, DEBUG_ENABLED, FLOW_DURATION_SAMPLES, FLOW_DURATION_THRESHOLD, @@ -308,6 +309,14 @@ def get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) except queue.Empty: return None + @staticmethod + def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> List[dict]: + try: + timeout = timeout or q.default_timeout + return q.batch_get(timeout=timeout, count=BATCH_DELTA_COUNT) + except queue.Empty: + return [] + def check_error_queue(self) -> None: exception: Exception = self.get_state_changed_from_queue(self.error_queue) # type: ignore[assignment,arg-type] if isinstance(exception, Exception): @@ -341,12 +350,15 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque while (time() - t0) < self.state_accumulate_wait: # TODO: Fetch all available deltas at once to reduce queue calls. 
- delta: Optional[ + received_deltas: List[ Union[_DeltaRequest, _APIRequest, _CommandRequest, ComponentDelta] - ] = self.get_state_changed_from_queue( + ] = self.batch_get_state_changed_from_queue( self.delta_queue # type: ignore[assignment,arg-type] ) - if delta: + if len(received_deltas) == 0: + break + + for delta in received_deltas: if isinstance(delta, _DeltaRequest): deltas.append(delta.delta) elif isinstance(delta, ComponentDelta): @@ -364,8 +376,6 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque deltas.append(delta) else: api_or_command_request_deltas.append(delta) - else: - break if api_or_command_request_deltas: _process_requests(self, api_or_command_request_deltas) diff --git a/src/lightning/app/core/queues.py b/src/lightning/app/core/queues.py index da941ae72503e..d37251c824616 100644 --- a/src/lightning/app/core/queues.py +++ b/src/lightning/app/core/queues.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import base64 import multiprocessing import pickle import queue # needed as import instead from/import for mocking in tests @@ -20,7 +21,7 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Optional, Tuple +from typing import Any, List, Optional, Tuple from urllib.parse import urljoin import backoff @@ -28,6 +29,7 @@ from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout from lightning.app.core.constants import ( + BATCH_DELTA_COUNT, HTTP_QUEUE_REFRESH_INTERVAL, HTTP_QUEUE_REQUESTS_PER_SECOND, HTTP_QUEUE_TOKEN, @@ -189,6 +191,20 @@ def get(self, timeout: Optional[float] = None) -> Any: """ pass + @abstractmethod + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]: + """Returns the left most elements of the queue. + + Parameters + ---------- + timeout: + Read timeout in seconds, in case of input timeout is 0, the `self.default_timeout` is used. + A timeout of None can be used to block indefinitely. + count: + The number of elements to get from the queue + + """ + @property def is_running(self) -> bool: """Returns True if the queue is running, False otherwise. 
@@ -214,6 +230,12 @@ def get(self, timeout: Optional[float] = None) -> Any: timeout = self.default_timeout return self.queue.get(timeout=timeout, block=(timeout is None)) + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]: + if timeout == 0: + timeout = self.default_timeout + # For multiprocessing, we can simply collect the latest upmost element + return [self.queue.get(timeout=timeout, block=(timeout is None))] + class RedisQueue(BaseQueue): @requires("redis") @@ -312,6 +334,9 @@ def get(self, timeout: Optional[float] = None) -> Any: raise queue.Empty return pickle.loads(out[1]) + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any: + return [self.get(timeout=timeout)] + def clear(self) -> None: """Clear all elements in the queue.""" self.redis.delete(self.name) @@ -366,7 +391,6 @@ def __init__(self, queue: BaseQueue, requests_per_second: float): self._seconds_per_request = 1 / requests_per_second self._last_get = 0.0 - self._last_put = 0.0 @property def is_running(self) -> bool: @@ -383,9 +407,12 @@ def get(self, timeout: Optional[float] = None) -> Any: self._last_get = time.time() return self._queue.get(timeout=timeout) + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any: + self._wait_until_allowed(self._last_get) + self._last_get = time.time() + return self._queue.batch_get(timeout=timeout, count=count) + def put(self, item: Any) -> None: - self._wait_until_allowed(self._last_put) - self._last_put = time.time() return self._queue.put(item) @@ -477,6 +504,20 @@ def _get(self) -> Any: # we consider the queue is empty to avoid failing the app. 
raise queue.Empty + def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]: + try: + resp = self.client.post( + f"v1/{self.app_id}/{self._name_suffix}", + query_params={"action": "popCount", "count": str(count or BATCH_DELTA_COUNT)}, + ) + if resp.status_code == 204: + raise queue.Empty + return [pickle.loads(base64.b64decode(data)) for data in resp.json()] + except ConnectionError: + # Note: If the Http Queue service isn't available, + # we consider the queue is empty to avoid failing the app. + raise queue.Empty + @backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError)) def put(self, item: Any) -> None: if not self.app_id: diff --git a/src/lightning/app/testing/helpers.py b/src/lightning/app/testing/helpers.py index 7f87180c959f3..61a00f957299e 100644 --- a/src/lightning/app/testing/helpers.py +++ b/src/lightning/app/testing/helpers.py @@ -142,6 +142,11 @@ def get(self, timeout: int = 0): raise Empty() return self._queue.pop(0) + def batch_get(self, timeout: int = 0, count: int = None): + if not self._queue: + raise Empty() + return [self._queue.pop(0)] + def __contains__(self, item): return item in self._queue diff --git a/src/lightning/app/utilities/packaging/lightning_utils.py b/src/lightning/app/utilities/packaging/lightning_utils.py index 9e5493f332e3e..e8846b382e49e 100644 --- a/src/lightning/app/utilities/packaging/lightning_utils.py +++ b/src/lightning/app/utilities/packaging/lightning_utils.py @@ -150,6 +150,16 @@ def _prepare_lightning_wheels_and_requirements(root: Path, package_name: str = " tar_name = _copy_tar(lightning_cloud_project_path, root) tar_files.append(os.path.join(root, tar_name)) + lightning_launcher_project_path = get_dist_path_if_editable_install("lightning_launcher") + if lightning_launcher_project_path: + from lightning_launcher.__version__ import __version__ as cloud_version + + # todo: check why logging.info is missing in outputs + print(f"Packaged Lightning Launcher 
with your application. Version: {cloud_version}") + _prepare_wheel(lightning_launcher_project_path) + tar_name = _copy_tar(lightning_launcher_project_path, root) + tar_files.append(os.path.join(root, tar_name)) + return functools.partial(_cleanup, *tar_files) diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index 86f71d6f09154..d529b373e4df0 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -446,8 +446,8 @@ def run(self): @pytest.mark.parametrize( ("sleep_time", "expect"), [ - (1, 0), - pytest.param(0, 10.0, marks=pytest.mark.xfail(strict=False, reason="failing...")), # fixme + (0, 9), + pytest.param(9, 10.0, marks=pytest.mark.xfail(strict=False, reason="failing...")), # fixme ], ) @pytest.mark.flaky(reruns=5) @@ -456,10 +456,10 @@ def test_lightning_app_aggregation_speed(default_timeout, queue_type_cls: BaseQu time window.""" class SlowQueue(queue_type_cls): - def get(self, timeout): + def batch_get(self, timeout, count): out = super().get(timeout) sleep(sleep_time) - return out + return [out] app = LightningApp(EmptyFlow()) @@ -480,7 +480,7 @@ def make_delta(i): delta = app._collect_deltas_from_ui_and_work_queues()[-1] generated = delta.to_dict()["values_changed"]["root['vars']['counter']"]["new_value"] if sleep_time: - assert generated == expect + assert generated == expect, (generated, expect) else: # validate the flow should have aggregated at least expect. 
assert generated > expect @@ -497,7 +497,8 @@ def get(self, timeout): app.delta_queue = SlowQueue("api_delta_queue", 0) t0 = time() assert app._collect_deltas_from_ui_and_work_queues() == [] - assert (time() - t0) < app.state_accumulate_wait + delta = time() - t0 + assert delta < app.state_accumulate_wait + 0.01, delta class SimpleFlow2(LightningFlow): diff --git a/tests/tests_app/core/test_queues.py b/tests/tests_app/core/test_queues.py index 583e828b12430..0f68d8aa1ff98 100644 --- a/tests/tests_app/core/test_queues.py +++ b/tests/tests_app/core/test_queues.py @@ -1,3 +1,4 @@ +import base64 import multiprocessing import pickle import queue @@ -220,6 +221,24 @@ def test_http_queue_get(self, monkeypatch): ) assert test_queue.get() == "test" + def test_http_queue_batch_get(self, monkeypatch): + monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token") + test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT) + adapter = requests_mock.Adapter() + test_queue.client.session.mount("http://", adapter) + + adapter.register_uri( + "POST", + f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=popCount", + request_headers={"Authorization": "Bearer test-token"}, + status_code=200, + json=[ + base64.b64encode(pickle.dumps("test")).decode("utf-8"), + base64.b64encode(pickle.dumps("test2")).decode("utf-8"), + ], + ) + assert test_queue.batch_get() == ["test", "test2"] + def test_unreachable_queue(monkeypatch): monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")