Skip to content

Commit

Permalink
Ref: Move parameters of start to __init__ method (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
klieret authored Nov 12, 2024
1 parent 40498e1 commit 608270d
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 21 deletions.
15 changes: 12 additions & 3 deletions src/swerex/deployment/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@


class DockerDeployment(AbstractDeployment):
def __init__(self, image: str, *, port: int | None = None, docker_args: list[str] | None = None):
def __init__(
self,
image: str,
*,
port: int | None = None,
docker_args: list[str] | None = None,
startup_timeout: float = 60.0,
):
"""Deployment to local docker image.
Args:
image: The name of the docker image to use.
port: The port that the docker container connects to. If None, a free port is found.
docker_args: Additional arguments to pass to the docker run command.
startup_timeout: The time to wait for the runtime to start.
"""
self._image_name = image
self._runtime: RemoteRuntime | None = None
Expand All @@ -33,6 +41,7 @@ def __init__(self, image: str, *, port: int | None = None, docker_args: list[str
self._container_name = None
self.logger = get_logger("deploy")
self._runtime_timeout = 0.15
self._startup_timeout = startup_timeout

def _get_container_name(self) -> str:
"""Returns a unique container name based on the image name."""
Expand Down Expand Up @@ -89,7 +98,7 @@ def _get_swerex_start_cmd(self, token: str) -> list[str]:
f"{REMOTE_EXECUTABLE_NAME} {rex_args} || ({pipx_install} && pipx run {PACKAGE_NAME} {rex_args})",
]

async def start(self, *, timeout: float = 10.0):
async def start(self):
"""Starts the runtime."""
port = self._port or find_free_port()
assert self._container_name is None
Expand Down Expand Up @@ -117,7 +126,7 @@ async def start(self, *, timeout: float = 10.0):
self.logger.info(f"Starting runtime at {self._port}")
self._runtime = RemoteRuntime(port=port, timeout=self._runtime_timeout, auth_token=token)
t0 = time.time()
await self._wait_until_alive(timeout=timeout)
await self._wait_until_alive(timeout=self._startup_timeout)
self.logger.info(f"Runtime started in {time.time() - t0:.2f}s")

async def stop(self):
Expand Down
6 changes: 3 additions & 3 deletions src/swerex/deployment/fargate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
security_group_prefix: str = "swe-rex-deployment-sg",
fargate_args: dict | None = None,
container_timeout: float = 60 * 15,
runtime_timeout: float = 30,
):
self._image = image
self._runtime: RemoteRuntime | None = None
Expand All @@ -57,6 +58,7 @@ def __init__(
self._subnet_id = None
self._task_arn = None
self._security_group_id = None
self._runtime_timeout = runtime_timeout

def _init_aws(self):
self._cluster_arn = get_cluster_arn(self._cluster_name)
Expand Down Expand Up @@ -124,8 +126,6 @@ def _get_token(self) -> str:

async def start(
self,
*,
timeout: float = 120,
):
"""Starts the runtime."""
self._init_aws()
Expand Down Expand Up @@ -162,7 +162,7 @@ async def start(
self.logger.info(f"Container public IP: {public_ip}")
self._runtime = RemoteRuntime(host=public_ip, port=self._port, auth_token=token)
t0 = time.time()
await self._wait_until_alive(timeout=timeout)
await self._wait_until_alive(timeout=self._runtime_timeout)
self.logger.info(f"Runtime started in {time.time() - t0:.2f}s")

async def stop(self):
Expand Down
23 changes: 11 additions & 12 deletions src/swerex/deployment/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class ModalDeployment(AbstractDeployment):
def __init__(
self,
image: str | modal.Image | PurePath,
container_timeout: float = 1800,
startup_timeout: float = 1800,
runtime_timeout: float = 0.4,
modal_sandbox_kwargs: dict[str, Any] | None = None,
):
Expand All @@ -116,13 +116,13 @@ def __init__(
2. Path to a Dockerfile
3. Dockerhub image name (e.g. `python:3.11-slim`)
4. ECR image name (e.g. `123456789012.dkr.ecr.us-east-1.amazonaws.com/my-image:tag`)
container_timeout:
runtime_timeout:
startup_timeout: The time to wait for the runtime to start.
runtime_timeout: The timeout passed to the `RemoteRuntime` instance.
modal_sandbox_kwargs: Additional arguments to pass to `modal.Sandbox.create`
"""
self._image = _ImageBuilder().auto(image)
self._runtime: RemoteRuntime | None = None
self._container_timeout = container_timeout
self._startup_timeout = startup_timeout
self._sandbox: modal.Sandbox | None = None
self._port = 8880
self.logger = get_logger("deploy")
Expand Down Expand Up @@ -161,7 +161,6 @@ def _start_swerex_cmd(self, token: str) -> str:
"""Start swerex-server on the remote. If swerex is not installed already,
install pipx and then run swerex-server with pipx run
"""
# todo: Change that to swe-rex after release
rex_args = f"--port {self._port} --auth-token {token}"
return f"{REMOTE_EXECUTABLE_NAME} {rex_args} || pipx run {PACKAGE_NAME} {rex_args}"

Expand All @@ -175,8 +174,6 @@ def get_modal_log_url(self) -> str:

async def start(
self,
*,
timeout: float = 60,
):
"""Starts the runtime."""
self.logger.info("Starting modal sandbox")
Expand All @@ -187,21 +184,23 @@ async def start(
"-c",
self._start_swerex_cmd(token),
image=self._image,
timeout=int(self._container_timeout),
timeout=int(self._startup_timeout),
unencrypted_ports=[self._port],
app=self._app,
**self._modal_kwargs,
)
tunnel = self._sandbox.tunnels()[self._port]
self.logger.info(f"Sandbox ({self._sandbox.object_id}) created in {time.time() - t0:.2f}s")
elapsed_sandbox_creation = time.time() - t0
self.logger.info(f"Sandbox ({self._sandbox.object_id}) created in {elapsed_sandbox_creation:.2f}s")
self.logger.info(f"Check sandbox logs at {self.get_modal_log_url()}")
self.logger.info(f"Sandbox created with id {self._sandbox.object_id}")
await asyncio.sleep(1)
self.logger.info(f"Starting runtime at {tunnel.url}")
self._runtime = RemoteRuntime(host=tunnel.url, timeout=self._runtime_timeout, auth_token=token)
t0 = time.time()
await self._wait_until_alive(timeout=timeout)
self.logger.info(f"Runtime started in {time.time() - t0:.2f}s")
remaining_startup_timeout = max(0, self._startup_timeout - elapsed_sandbox_creation)
t1 = time.time()
await self._wait_until_alive(timeout=self._runtime_timeout + remaining_startup_timeout)
self.logger.info(f"Runtime started in {time.time() - t1:.2f}s")

async def stop(self):
"""Stops the runtime."""
Expand Down
2 changes: 1 addition & 1 deletion src/swerex/runtime/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def is_alive(self, *, timeout: float | None = None) -> IsAliveResponse:
msg += traceback.format_exc()
return IsAliveResponse(is_alive=False, message=msg)

async def wait_until_alive(self, *, timeout: float | None = None):
async def wait_until_alive(self, *, timeout: float = 60.0):
return await _wait_until_alive(self.is_alive, timeout=timeout)

def _request(self, endpoint: str, request: BaseModel | None, output_class: Any):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_modal_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
async def test_modal_deployment_from_docker_with_swerex_installed():
dockerfile = Path(__file__).parent / "swe_rex_test.Dockerfile"
image = _ImageBuilder().from_file(dockerfile, build_context=Path(__file__).resolve().parent.parent)
d = ModalDeployment(image=image, container_timeout=60)
d = ModalDeployment(image=image, startup_timeout=60)
with pytest.raises(RuntimeError):
await d.is_alive()
await d.start()
Expand All @@ -19,7 +19,7 @@ async def test_modal_deployment_from_docker_with_swerex_installed():

@pytest.mark.slow
async def test_modal_deployment_from_docker_string():
d = ModalDeployment(image="python:3.11-slim")
d = ModalDeployment(image="python:3.11-slim", startup_timeout=60)
await d.start()
assert await d.is_alive()
await d.stop()

0 comments on commit 608270d

Please sign in to comment.