Lightning App: Use the batch get endpoint #19180

Merged · 13 commits · Dec 18, 2023
17 changes: 12 additions & 5 deletions src/lightning/app/core/app.py
@@ -29,6 +29,7 @@
from lightning.app import _console
from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest
from lightning.app.core.constants import (
BATCH_DELTA_COUNT,
DEBUG_ENABLED,
FLOW_DURATION_SAMPLES,
FLOW_DURATION_THRESHOLD,
@@ -308,6 +309,14 @@ def get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None)
except queue.Empty:
return None

@staticmethod
def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> List[dict]:
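# Unlike get_state_changed_from_queue, this returns a (possibly empty) list and never None.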
try:
timeout = timeout or q.default_timeout
return q.batch_get(timeout=timeout, count=BATCH_DELTA_COUNT)
except queue.Empty:
return []

def check_error_queue(self) -> None:
exception: Exception = self.get_state_changed_from_queue(self.error_queue) # type: ignore[assignment,arg-type]
if isinstance(exception, Exception):
@@ -341,12 +350,12 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque

while (time() - t0) < self.state_accumulate_wait:
# TODO: Fetch all available deltas at once to reduce queue calls.
- delta: Optional[
+ received_deltas: List[
Union[_DeltaRequest, _APIRequest, _CommandRequest, ComponentDelta]
- ] = self.get_state_changed_from_queue(
+ ] = self.batch_get_state_changed_from_queue(
self.delta_queue # type: ignore[assignment,arg-type]
)
- if delta:
+ for delta in received_deltas:
if isinstance(delta, _DeltaRequest):
deltas.append(delta.delta)
elif isinstance(delta, ComponentDelta):
@@ -364,8 +373,6 @@ def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIReque
deltas.append(delta)
else:
api_or_command_request_deltas.append(delta)
- else:
- break

if api_or_command_request_deltas:
_process_requests(self, api_or_command_request_deltas)
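The loop above now drains up to BATCH_DELTA_COUNT deltas per queue call instead of one. A rough sketch of the idea (not part of the diff, using a plain queue.Queue stand-in): block briefly for the first item, then take whatever else is already waiting, so the number of round trips scales with batches rather than with individual deltas.

import queue

def drain_batch(q: queue.Queue, count: int = 128, timeout: float = 0.1) -> list:
    # Block up to `timeout` for the first item, then collect what is already queued.
    items = []
    try:
        items.append(q.get(timeout=timeout))
        for _ in range(count - 1):
            items.append(q.get_nowait())
    except queue.Empty:
        pass
    return items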
2 changes: 2 additions & 0 deletions src/lightning/app/core/constants.py
@@ -98,6 +98,8 @@ def get_lightning_cloud_url() -> str:
# directory where system customization sync files will be copied to be packed into app tarball
SYS_CUSTOMIZATIONS_SYNC_PATH = ".sys-customizations-sync"

BATCH_DELTA_COUNT = int(os.getenv("BATCH_DELTA_COUNT", "128"))


def enable_multiple_works_in_default_container() -> bool:
return bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0")))
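BATCH_DELTA_COUNT follows the same pattern as the other constants in this module: it is read from the environment once at import time. A small sketch of tuning it for a deployment (assuming the variable is set before lightning.app.core.constants is first imported):

import os

# Must happen before the first import of lightning.app.core.constants,
# because the value is captured at import time.
os.environ["BATCH_DELTA_COUNT"] = "256"

from lightning.app.core.constants import BATCH_DELTA_COUNT

assert BATCH_DELTA_COUNT == 256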
49 changes: 45 additions & 4 deletions src/lightning/app/core/queues.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import multiprocessing
import pickle
import queue # needed as import instead from/import for mocking in tests
@@ -20,14 +21,15 @@
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
- from typing import Any, Optional, Tuple
+ from typing import Any, List, Optional, Tuple
from urllib.parse import urljoin

import backoff
import requests
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout

from lightning.app.core.constants import (
BATCH_DELTA_COUNT,
HTTP_QUEUE_REFRESH_INTERVAL,
HTTP_QUEUE_REQUESTS_PER_SECOND,
HTTP_QUEUE_TOKEN,
@@ -189,6 +191,20 @@ def get(self, timeout: Optional[float] = None) -> Any:
"""
pass

@abstractmethod
def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
"""Returns the left most elements of the queue.

Parameters
----------
timeout:
Read timeout in seconds. If the given timeout is 0, `self.default_timeout` is used.
A timeout of None can be used to block indefinitely.
count:
The number of elements to get from the queue

"""

@property
def is_running(self) -> bool:
"""Returns True if the queue is running, False otherwise.
@@ -214,6 +230,12 @@ def get(self, timeout: Optional[float] = None) -> Any:
timeout = self.default_timeout
return self.queue.get(timeout=timeout, block=(timeout is None))

def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
if timeout == 0:
timeout = self.default_timeout
# For multiprocessing, we simply collect the single leftmost element
return [self.queue.get(timeout=timeout, block=(timeout is None))]
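A usage sketch for the multiprocessing case (assuming the usual MultiProcessQueue(name, default_timeout) constructor): batch_get intentionally returns at most one element per call, so callers always receive a list, but real batching only pays off for the HTTP-backed queue.

q = MultiProcessQueue("delta_queue", default_timeout=1.0)  # hypothetical instantiation
q.put("delta-1")
q.put("delta-2")
assert q.batch_get() == ["delta-1"]  # blocks until the first element is available
assert q.batch_get() == ["delta-2"]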


class RedisQueue(BaseQueue):
@requires("redis")
@@ -312,6 +334,9 @@ def get(self, timeout: Optional[float] = None) -> Any:
raise queue.Empty
return pickle.loads(out[1])

def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
raise NotImplementedError("The batch_get method isn't implemented.")

def clear(self) -> None:
"""Clear all elements in the queue."""
self.redis.delete(self.name)
@@ -366,7 +391,6 @@ def __init__(self, queue: BaseQueue, requests_per_second: float):
self._seconds_per_request = 1 / requests_per_second

self._last_get = 0.0
- self._last_put = 0.0

@property
def is_running(self) -> bool:
@@ -383,9 +407,12 @@ def get(self, timeout: Optional[float] = None) -> Any:
self._last_get = time.time()
return self._queue.get(timeout=timeout)

+ def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> Any:
+ self._wait_until_allowed(self._last_get)
+ self._last_get = time.time()
+ return self._queue.batch_get(timeout=timeout)

def put(self, item: Any) -> None:
- self._wait_until_allowed(self._last_put)
- self._last_put = time.time()
return self._queue.put(item)


@@ -477,6 +504,20 @@ def _get(self) -> Any:
# we consider the queue to be empty to avoid failing the app.
raise queue.Empty

def batch_get(self, timeout: Optional[float] = None, count: Optional[int] = None) -> List[Any]:
try:
resp = self.client.post(
f"v1/{self.app_id}/{self._name_suffix}",
query_params={"action": "popCount", "count": str(count or BATCH_DELTA_COUNT)},
)
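# A 204 (No Content) response is treated as an empty queue; otherwise the body is
# expected to be a JSON list of base64-encoded, pickled items.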
if resp.status_code == 204:
raise queue.Empty
return [pickle.loads(base64.b64decode(data)) for data in resp.json()]
except ConnectionError:
# Note: If the HTTP queue service isn't available,
# we consider the queue to be empty to avoid failing the app.
raise queue.Empty

@backoff.on_exception(backoff.expo, (RuntimeError, requests.exceptions.HTTPError))
def put(self, item: Any) -> None:
if not self.app_id:
10 changes: 10 additions & 0 deletions src/lightning/app/utilities/packaging/lightning_utils.py
@@ -150,6 +150,16 @@ def _prepare_lightning_wheels_and_requirements(root: Path, package_name: str = "
tar_name = _copy_tar(lightning_cloud_project_path, root)
tar_files.append(os.path.join(root, tar_name))

lightning_launcher_project_path = get_dist_path_if_editable_install("lightning_launcher")
if lightning_launcher_project_path:
from lightning_launcher.__version__ import __version__ as cloud_version

# todo: check why logging.info is missing in outputs
print(f"Packaged Lightning Launcher with your application. Version: {cloud_version}")
_prepare_wheel(lightning_launcher_project_path)
tar_name = _copy_tar(lightning_launcher_project_path, root)
tar_files.append(os.path.join(root, tar_name))

return functools.partial(_cleanup, *tar_files)


19 changes: 19 additions & 0 deletions tests/tests_app/core/test_queues.py
@@ -1,3 +1,4 @@
import base64
import multiprocessing
import pickle
import queue
@@ -220,6 +221,24 @@ def test_http_queue_get(self, monkeypatch):
)
assert test_queue.get() == "test"

def test_http_queue_batch_get(self, monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
adapter = requests_mock.Adapter()
test_queue.client.session.mount("http://", adapter)

adapter.register_uri(
"POST",
f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=popCount",
request_headers={"Authorization": "Bearer test-token"},
status_code=200,
json=[
base64.b64encode(pickle.dumps("test")).decode("utf-8"),
base64.b64encode(pickle.dumps("test2")).decode("utf-8"),
],
)
assert test_queue.batch_get() == ["test", "test2"]
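A possible companion case (a sketch, not included in the PR, assuming pytest is imported in this module): the 204 path of batch_get should surface as queue.Empty, mirroring how an empty queue is handled elsewhere.

def test_http_queue_batch_get_empty(self, monkeypatch):
    monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")
    test_queue = HTTPQueue("test_http_queue", STATE_UPDATE_TIMEOUT)
    adapter = requests_mock.Adapter()
    test_queue.client.session.mount("http://", adapter)

    adapter.register_uri(
        "POST",
        f"{HTTP_QUEUE_URL}/v1/test/http_queue?action=popCount",
        request_headers={"Authorization": "Bearer test-token"},
        status_code=204,
    )
    with pytest.raises(queue.Empty):
        test_queue.batch_get()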


def test_unreachable_queue(monkeypatch):
monkeypatch.setattr(queues, "HTTP_QUEUE_TOKEN", "test-token")