From 0558be078f23ff69ad1b4c1bb8c4745b8f6073d6 Mon Sep 17 00:00:00 2001 From: Kilian Lieret Date: Mon, 11 Nov 2024 18:45:39 -0500 Subject: [PATCH] Ref: Move parameters to start to __init__ method --- src/swerex/deployment/docker.py | 15 ++++++++++++--- src/swerex/deployment/fargate.py | 6 +++--- src/swerex/deployment/modal.py | 23 +++++++++++------------ src/swerex/runtime/remote.py | 2 +- tests/test_modal_deployment.py | 4 ++-- 5 files changed, 29 insertions(+), 21 deletions(-) diff --git a/src/swerex/deployment/docker.py b/src/swerex/deployment/docker.py index 2634ac4..8d3e9e9 100644 --- a/src/swerex/deployment/docker.py +++ b/src/swerex/deployment/docker.py @@ -15,13 +15,21 @@ class DockerDeployment(AbstractDeployment): - def __init__(self, image: str, *, port: int | None = None, docker_args: list[str] | None = None): + def __init__( + self, + image: str, + *, + port: int | None = None, + docker_args: list[str] | None = None, + startup_timeout: float = 60.0, + ): """Deployment to local docker image. Args: image: The name of the docker image to use. port: The port that the docker container connects to. If None, a free port is found. docker_args: Additional arguments to pass to the docker run command. + startup_timeout: The time to wait for the runtime to start. """ self._image_name = image self._runtime: RemoteRuntime | None = None @@ -33,6 +41,7 @@ def __init__(self, image: str, *, port: int | None = None, docker_args: list[str self._container_name = None self.logger = get_logger("deploy") self._runtime_timeout = 0.15 + self._startup_timeout = startup_timeout def _get_container_name(self) -> str: """Returns a unique container name based on the image name.""" @@ -89,7 +98,7 @@ def _get_swerex_start_cmd(self, token: str) -> list[str]: f"{REMOTE_EXECUTABLE_NAME} {rex_args} || ({pipx_install} && pipx run {PACKAGE_NAME} {rex_args})", ] - async def start(self, *, timeout: float = 10.0): + async def start(self): """Starts the runtime.""" port = self._port or find_free_port() assert self._container_name is None @@ -117,7 +126,7 @@ async def start(self, *, timeout: float = 10.0): self.logger.info(f"Starting runtime at {self._port}") self._runtime = RemoteRuntime(port=port, timeout=self._runtime_timeout, auth_token=token) t0 = time.time() - await self._wait_until_alive(timeout=timeout) + await self._wait_until_alive(timeout=self._startup_timeout) self.logger.info(f"Runtime started in {time.time() - t0:.2f}s") async def stop(self): diff --git a/src/swerex/deployment/fargate.py b/src/swerex/deployment/fargate.py index f130cd3..305da51 100644 --- a/src/swerex/deployment/fargate.py +++ b/src/swerex/deployment/fargate.py @@ -34,6 +34,7 @@ def __init__( security_group_prefix: str = "swe-rex-deployment-sg", fargate_args: dict | None = None, container_timeout: float = 60 * 15, + runtime_timeout: float = 30, ): self._image = image self._runtime: RemoteRuntime | None = None @@ -57,6 +58,7 @@ def __init__( self._subnet_id = None self._task_arn = None self._security_group_id = None + self._runtime_timeout = runtime_timeout def _init_aws(self): self._cluster_arn = get_cluster_arn(self._cluster_name) @@ -124,8 +126,6 @@ def _get_token(self) -> str: async def start( self, - *, - timeout: float = 120, ): """Starts the runtime.""" self._init_aws() @@ -162,7 +162,7 @@ async def start( self.logger.info(f"Container public IP: {public_ip}") self._runtime = RemoteRuntime(host=public_ip, port=self._port, auth_token=token) t0 = time.time() - await self._wait_until_alive(timeout=timeout) + await self._wait_until_alive(timeout=self._runtime_timeout) self.logger.info(f"Runtime started in {time.time() - t0:.2f}s") async def stop(self): diff --git a/src/swerex/deployment/modal.py b/src/swerex/deployment/modal.py index 25d02ec..958abfc 100644 --- a/src/swerex/deployment/modal.py +++ b/src/swerex/deployment/modal.py @@ -103,7 +103,7 @@ class ModalDeployment(AbstractDeployment): def __init__( self, image: str | modal.Image | PurePath, - container_timeout: float = 1800, + startup_timeout: float = 1800, runtime_timeout: float = 0.4, modal_sandbox_kwargs: dict[str, Any] | None = None, ): @@ -116,13 +116,13 @@ def __init__( 2. Path to a Dockerfile 3. Dockerhub image name (e.g. `python:3.11-slim`) 4. ECR image name (e.g. `123456789012.dkr.ecr.us-east-1.amazonaws.com/my-image:tag`) - container_timeout: - runtime_timeout: + startup_timeout: The time to wait for the runtime to start. + runtime_timeout: The runtime timeout. modal_sandbox_kwargs: Additional arguments to pass to `modal.Sandbox.create` """ self._image = _ImageBuilder().auto(image) self._runtime: RemoteRuntime | None = None - self._container_timeout = container_timeout + self._startup_timeout = startup_timeout self._sandbox: modal.Sandbox | None = None self._port = 8880 self.logger = get_logger("deploy") @@ -161,7 +161,6 @@ def _start_swerex_cmd(self, token: str) -> str: """Start swerex-server on the remote. If swerex is not installed arelady, install pipx and then run swerex-server with pipx run """ - # todo: Change that to swe-rex after release rex_args = f"--port {self._port} --auth-token {token}" return f"{REMOTE_EXECUTABLE_NAME} {rex_args} || pipx run {PACKAGE_NAME} {rex_args}" @@ -175,8 +174,6 @@ def get_modal_log_url(self) -> str: async def start( self, - *, - timeout: float = 60, ): """Starts the runtime.""" self.logger.info("Starting modal sandbox") @@ -187,21 +184,23 @@ async def start( "-c", self._start_swerex_cmd(token), image=self._image, - timeout=int(self._container_timeout), + timeout=int(self._startup_timeout), unencrypted_ports=[self._port], app=self._app, **self._modal_kwargs, ) tunnel = self._sandbox.tunnels()[self._port] - self.logger.info(f"Sandbox ({self._sandbox.object_id}) created in {time.time() - t0:.2f}s") + elapsed_sandbox_creation = time.time() - t0 + self.logger.info(f"Sandbox ({self._sandbox.object_id}) created in {elapsed_sandbox_creation:.2f}s") self.logger.info(f"Check sandbox logs at {self.get_modal_log_url()}") self.logger.info(f"Sandbox created with id {self._sandbox.object_id}") await asyncio.sleep(1) self.logger.info(f"Starting runtime at {tunnel.url}") self._runtime = RemoteRuntime(host=tunnel.url, timeout=self._runtime_timeout, auth_token=token) - t0 = time.time() - await self._wait_until_alive(timeout=timeout) - self.logger.info(f"Runtime started in {time.time() - t0:.2f}s") + remaining_startup_timeout = max(0, self._startup_timeout - elapsed_sandbox_creation) + t1 = time.time() + await self._wait_until_alive(timeout=self._runtime_timeout + remaining_startup_timeout) + self.logger.info(f"Runtime started in {time.time() - t1:.2f}s") async def stop(self): """Stops the runtime.""" diff --git a/src/swerex/runtime/remote.py b/src/swerex/runtime/remote.py index fbbd0c8..295f66d 100644 --- a/src/swerex/runtime/remote.py +++ b/src/swerex/runtime/remote.py @@ -127,7 +127,7 @@ async def is_alive(self, *, timeout: float | None = None) -> IsAliveResponse: msg += traceback.format_exc() return IsAliveResponse(is_alive=False, message=msg) - async def wait_until_alive(self, *, timeout: float | None = None): + async def wait_until_alive(self, *, timeout: float = 60.0): return await _wait_until_alive(self.is_alive, timeout=timeout) def _request(self, endpoint: str, request: BaseModel | None, output_class: Any): diff --git a/tests/test_modal_deployment.py b/tests/test_modal_deployment.py index f44b76a..2ccee7d 100644 --- a/tests/test_modal_deployment.py +++ b/tests/test_modal_deployment.py @@ -9,7 +9,7 @@ async def test_modal_deployment_from_docker_with_swerex_installed(): dockerfile = Path(__file__).parent / "swe_rex_test.Dockerfile" image = _ImageBuilder().from_file(dockerfile, build_context=Path(__file__).resolve().parent.parent) - d = ModalDeployment(image=image, container_timeout=60) + d = ModalDeployment(image=image, startup_timeout=60) with pytest.raises(RuntimeError): await d.is_alive() await d.start() @@ -19,7 +19,7 @@ async def test_modal_deployment_from_docker_with_swerex_installed(): @pytest.mark.slow async def test_modal_deployment_from_docker_string(): - d = ModalDeployment(image="python:3.11-slim") + d = ModalDeployment(image="python:3.11-slim", startup_timeout=60) await d.start() assert await d.is_alive() await d.stop()