Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ssh sentry to golem service and node monitoring to ray service #217

Merged
merged 14 commits into from
Mar 29, 2024
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ python = "^3.8.1"
ray = {version="~2.9.3", extras=["default"]}
#golem-core = {path="../golem-core-python", develop=true}
#golem-core = {git="https://github.com/golemfactory/golem-core-python.git", branch="main"}
golem-core = "^0.6.0"
golem-core = "^0.6.2"
aiohttp = "^3"
requests = "^2"
click = "^8"
Expand All @@ -38,7 +38,7 @@ python = "^3.8.1"
ray = {version="==2.9.3", extras=["default"]}
#golem-core = {path="../golem-core-python", develop=true}
#golem-core = {git="https://github.com/golemfactory/golem-core-python.git", branch="main"}
golem-core = "^0.6.0"
golem-core = "^0.6.2"
aiohttp = "^3"
requests = "^2"
click = "^8"
Expand Down
184 changes: 133 additions & 51 deletions ray_on_golem/server/services/golem/golem.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from golem.node import GolemNode
from golem.payload import PaymentInfo
from golem.resources import Activity, Network, Proposal
from golem.utils.asyncio import create_task_with_logging, ensure_cancelled, ensure_cancelled_many
from yarl import URL

from ray_on_golem.reputation.plugins import ProviderBlacklistPlugin, ReputationScorer
Expand All @@ -41,10 +42,17 @@
DEFAULT_DEBIT_NOTE_INTERVAL = timedelta(minutes=3)
DEFAULT_DEBIT_NOTES_ACCEPT_TIMEOUT = timedelta(minutes=4)
DEFAULT_PROPOSAL_RESPONSE_TIMEOUT = timedelta(seconds=30)
DEFAULT_SSH_SENTRY_TIMEOUT = timedelta(minutes=2)
DEFAULT_MAX_SENTRY_FAILS_COUNT = 3


class GolemService:
def __init__(self, websocat_path: Path, registry_stats: bool):
def __init__(
self,
websocat_path: Path,
registry_stats: bool,
ssh_sentry_timeout: timedelta = DEFAULT_SSH_SENTRY_TIMEOUT,
):
self._websocat_path = websocat_path

self._demand_config_helper: DemandConfigHelper = DemandConfigHelper(registry_stats)
Expand All @@ -55,6 +63,9 @@ def __init__(self, websocat_path: Path, registry_stats: bool):
self._stacks_locks: DefaultDict[(str, bool), asyncio.Lock] = defaultdict(asyncio.Lock)
self._payment_manager: Optional[PaymentManager] = None

self._ssh_sentry_tasks: Dict[str, asyncio.Task] = {}
self._ssh_sentry_timeout: timedelta = ssh_sentry_timeout

async def init(self, yagna_appkey: str) -> None:
logger.info("Starting GolemService...")

Expand All @@ -81,6 +92,9 @@ async def shutdown(self) -> None:

# FIXME: Remove this method in case of multiple cluster support
async def clear(self) -> None:
await ensure_cancelled_many(self._ssh_sentry_tasks.values())
self._ssh_sentry_tasks.clear()

await asyncio.gather(
*[self._remove_stack(stack_hash) for stack_hash in self._stacks.keys()]
)
Expand Down Expand Up @@ -251,14 +265,15 @@ async def _get_proposal_expiration(self, proposal: Proposal) -> timedelta:
return await proposal.get_expiration_date() - datetime.now(timezone.utc)

@staticmethod
async def _get_provider_desc(context: WorkContext):
return f"{await context.get_provider_name()} ({await context.get_provider_id()})"
async def get_provider_desc(activity: Activity) -> str:
proposal = activity.agreement.proposal
return f"{await proposal.get_provider_name()} ({await proposal.get_provider_id()})"

async def _start_activity(
self, context: WorkContext, ip: str, *, add_state_log: Callable[[str], Awaitable[None]]
):
activity = context.activity
provider_desc = await self._get_provider_desc(context)
provider_desc = await self.get_provider_desc(activity)

logger.info(f"Deploying image on {provider_desc}, {ip=}, {activity=}")

Expand All @@ -275,12 +290,9 @@ async def _upload_node_configuration(
context: WorkContext,
ip: str,
ssh_public_key_data: str,
*,
add_state_log: Callable[[str], Awaitable[None]],
):
provider_desc = await self._get_provider_desc(context)
provider_desc = await self.get_provider_desc(context.activity)
logger.info(f"Running initial commands on {provider_desc}, {ip=}, {context.activity=}")
await add_state_log("[6/9] Running bootstrap commands...")
approxit marked this conversation as resolved.
Show resolved Hide resolved
hostname = ip.replace(".", "-")
await context.run("echo 'ON_GOLEM_NETWORK=1' >> /etc/environment")
await context.run(f"echo 'NODE_IP={ip}' >> /etc/environment")
Expand All @@ -290,60 +302,116 @@ async def _upload_node_configuration(
await context.run("mkdir -p /root/.ssh")
await context.run(f'echo "{ssh_public_key_data}" >> /root/.ssh/authorized_keys')

async def _start_ssh_server(
self, context: WorkContext, ip: str, *, add_state_log: Callable[[str], Awaitable[None]]
):
provider_desc = await self._get_provider_desc(context)
async def _start_ssh_server(self, context: WorkContext, ip: str):
provider_desc = await self.get_provider_desc(context.activity)
logger.info("Starting ssh service on " f"{provider_desc}, {ip=}, {context.activity=}")
await add_state_log("[7/9] Starting ssh service...")
await context.run("service ssh start")

async def _verify_ssh_connection(
self,
context: WorkContext,
async def _restart_ssh_server(self, context: WorkContext, ip: str):
provider_desc = await self.get_provider_desc(context.activity)
logger.debug(f"Restarting ssh service on {provider_desc}, {ip=}, {context.activity=}")
try:
await context.run("service ssh restart", timeout=120)
except Exception:
msg = f"Failed to restart SSH server {provider_desc}, {ip=}, {context.activity=}"
logger.warning(msg)
logger.debug(msg, exc_info=True)
else:
logger.debug(
f"Restarting ssh service on {provider_desc}, {ip=}, {context.activity=} done"
)

@staticmethod
async def _verify_ssh_connection_check(
activity_id: str,
provider_desc: str,
ip: str,
ssh_proxy_command: str,
ssh_user: str,
ssh_private_key_path: Path,
num_retries=3,
retry_interval=1,
*,
add_state_log: Callable[[str], Awaitable[None]],
) -> None:
activity = context.activity
):
ssh_command = (
f"{get_ssh_command(ip, ssh_proxy_command, ssh_user, ssh_private_key_path)} uptime"
)

logger.debug(
"SSH connection check started on "
f"{await self._get_provider_desc(context)}, {ip=}, {activity=}: cmd={ssh_command}."
f"{provider_desc}, {ip=}, {activity_id=}: cmd={ssh_command}."
)
process = await asyncio.create_subprocess_shell(
ssh_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await add_state_log("[8/9] Checking SSH connection...")

debug_data = ""
stdout, stderr = await process.communicate()

async def check():
nonlocal debug_data
debug_data = f"{provider_desc=}, exitcode={process.returncode}, {stdout=}, {stderr=}"
logger.debug(debug_data)

process = await asyncio.create_subprocess_shell(
ssh_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
if process.returncode != 0:
raise Exception(f"SSH connection check failed. {debug_data}")

stdout, stderr = await process.communicate()
async def _sentry_ssh_connection(
self,
context: WorkContext,
ip: str,
ssh_proxy_command: str,
ssh_user: str,
ssh_private_key_path: Path,
):
provider_desc = await self.get_provider_desc(context.activity)

debug_data = f"{activity=}, exitcode={process.returncode}, {stdout=}, {stderr=}"
fails_count = 0
while True:
try:
await self._verify_ssh_connection(
context,
ip,
ssh_proxy_command,
ssh_user,
ssh_private_key_path,
)
except Exception:
fails_count += 1
if fails_count >= DEFAULT_MAX_SENTRY_FAILS_COUNT:
msg = f"Destroying activity due to no SSH connection to {provider_desc}"
logger.warning(msg)
logger.debug(msg, exc_info=True)
create_task_with_logging(self.stop_activity(context.activity))
break

logger.debug(
f"SSH connection to {provider_desc} stopped working. Restarting SSH server",
exc_info=True,
)
await self._restart_ssh_server(context, ip)
lucekdudek marked this conversation as resolved.
Show resolved Hide resolved
await asyncio.sleep(self._ssh_sentry_timeout.total_seconds())

if process.returncode != 0:
raise Exception(f"SSH connection check failed. {debug_data}")
async def _verify_ssh_connection(
self,
context: WorkContext,
ip: str,
ssh_proxy_command: str,
ssh_user: str,
ssh_private_key_path: Path,
num_retries=3,
retry_interval=1,
) -> None:
activity = context.activity
provider_desc = await self.get_provider_desc(context.activity)

retry = num_retries

while retry > 0:
try:
await check()
await self._verify_ssh_connection_check(
activity.id,
provider_desc,
ip,
ssh_proxy_command,
ssh_user,
ssh_private_key_path,
)
break
except Exception as e:
retry -= 1
Expand All @@ -354,11 +422,7 @@ async def check():
else:
raise GolemException("SSH connection check failed!") from e

logger.info(
"SSH connection check successful on "
f"{await self._get_provider_desc(context)}, {ip=}, {activity=}."
)
logger.debug(debug_data)
logger.info(f"SSH connection check successful on {provider_desc}, {ip=}, {activity=}.")

async def create_activity(
self,
Expand Down Expand Up @@ -398,7 +462,18 @@ async def create_activity(
msg = "Failed to create activity, retrying."
error = f"{type(e).__module__}.{type(e).__name__}: {e}"
await add_state_log(f"{msg} {error=}")
logger.warning(msg, exc_info=True)
logger.warning(msg)
logger.debug(msg, exc_info=True)

async def stop_activity(self, activity: Activity):
if activity.id in self._ssh_sentry_tasks:
await ensure_cancelled(self._ssh_sentry_tasks[activity.id])

provider_desc = await self.get_provider_desc(activity)
try:
await activity.destroy()
except Exception:
logger.debug(f"Cannot destroy activity {provider_desc}", exc_info=True)

async def _create_activity(
self,
Expand Down Expand Up @@ -431,30 +506,37 @@ async def _create_activity(

work_context = WorkContext(activity)
await self._start_activity(work_context, ip, add_state_log=add_state_log)
await self._upload_node_configuration(
work_context, ip, public_ssh_key, add_state_log=add_state_log
)
await self._start_ssh_server(work_context, ip, add_state_log=add_state_log)

await add_state_log("[6/9] Running bootstrap commands...")
await self._upload_node_configuration(work_context, ip, public_ssh_key)
await add_state_log("[7/9] Starting ssh service...")
await self._start_ssh_server(work_context, ip)

await add_state_log("[8/9] Checking SSH connection...")
await self._verify_ssh_connection(
work_context,
ip,
ssh_proxy_command,
ssh_user,
ssh_private_key_path,
add_state_log=add_state_log,
)

self._ssh_sentry_tasks[activity.id] = create_task_with_logging(
self._sentry_ssh_connection(
work_context, ip, ssh_proxy_command, ssh_user, ssh_private_key_path
)
)

await self._network.refresh_nodes()
except Exception as e:
logger.error(f"Creating new activity failed with `{type(e).__name__}: {e}`")
await activity.destroy()
await self.stop_activity(activity)
raise

await add_state_log(f"[9/9] Activity ready on provider: {provider_desc}")
logger.info(
"Creating new activity done on "
f"{await self._get_provider_desc(work_context)}, {ip=}, {activity=}"
f"{await self.get_provider_desc(activity)}, {ip=}, {activity=}"
)

return activity, ip, ssh_proxy_command
Expand Down
Loading