Skip to content

Commit

Permalink
Fix restarting gateway connections (dstackai#1746)
Browse files Browse the repository at this point in the history
by deleting the remote unix socket before
reopening the connection
  • Loading branch information
jvstme authored Oct 1, 2024
1 parent f1f5fdc commit 55d09a7
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions src/dstack/_internal/server/services/gateways/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self, ip_address: str, id_rsa: str, server_port: int):
],
# reverse_forwarded_sockets are added later in .open()
)
self.tunnel_id = uuid.uuid4()
self._client = GatewayClient(uds=self.gateway_socket_path)

@staticmethod
Expand All @@ -79,7 +80,7 @@ async def check_or_restart(self) -> bool:
async with self._lock.writer_lock:
if not await self.tunnel.acheck():
logger.info("Connection to gateway %s is down, restarting", self.ip_address)
await self.tunnel.aopen()
await self._open_tunnel()
return True
return False

Expand All @@ -89,18 +90,26 @@ async def open(self, close_existing_tunnel: bool = False) -> None:
# Close remaining tunnel if previous server process died w/o graceful shutdown
if await self.tunnel.acheck():
await self.tunnel.aclose()

self.connection_dir.mkdir(parents=True, exist_ok=True)
await self.tunnel.aopen()
await self.tunnel.aexec(f"mkdir -p {CONNECTIONS_DIR_ON_GATEWAY}")

self.tunnel.reverse_forwarded_sockets = [
SocketPair(
local=IPSocket(host="localhost", port=self.server_port),
remote=UnixSocket(path=f"{CONNECTIONS_DIR_ON_GATEWAY}/{uuid.uuid4()}.sock"),
),
]
await self.tunnel.aopen() # apply reverse forwarding
await self._open_tunnel()

async def _open_tunnel(self) -> None:
    """Open the tunnel in two passes, finishing with reverse forwarding enabled.

    Pass one opens the tunnel with an empty reverse-forwarding list and runs
    two commands on the gateway: one creates the connections directory, the
    other force-removes the socket file at this connection's path. Pass two
    configures a single reverse-forwarded pair (local server port -> remote
    unix socket) and calls ``aopen()`` again.
    """
    self.connection_dir.mkdir(parents=True, exist_ok=True)
    socket_on_gateway = f"{CONNECTIONS_DIR_ON_GATEWAY}/{self.tunnel_id}.sock"

    # Pass 1: no reverse forwarding yet; prepare the gateway so that the
    # remote socket path is free before forwarding is requested.
    self.tunnel.reverse_forwarded_sockets = []
    await self.tunnel.aopen()
    preparation_commands = (
        f"mkdir -p {CONNECTIONS_DIR_ON_GATEWAY}",
        f"rm -f {socket_on_gateway}",
    )
    for command in preparation_commands:
        await self.tunnel.aexec(command)

    # Pass 2: register the forwarding pair, then open again to apply it.
    server_endpoint = IPSocket(host="localhost", port=self.server_port)
    gateway_endpoint = UnixSocket(path=socket_on_gateway)
    self.tunnel.reverse_forwarded_sockets = [
        SocketPair(local=server_endpoint, remote=gateway_endpoint),
    ]
    await self.tunnel.aopen()

async def close(self) -> None:
async with self._lock.writer_lock:
Expand Down

0 comments on commit 55d09a7

Please sign in to comment.