Skip to content

Commit

Permalink
Minor fixes (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
approxit authored Mar 15, 2024
1 parent bf6cadc commit 6a7511c
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 42 deletions.
12 changes: 8 additions & 4 deletions .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ name: Integration tests

on:
workflow_call:
schedule:
# run this workflow every day at 2:00 AM UTC
- cron: '0 2 * * *'
inputs:
BRANCH:
type: string
description: Git branch to be used in run
default: main

jobs:
examples:
Expand Down Expand Up @@ -38,6 +40,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ inputs.BRANCH }}

- name: Start Goth
env:
Expand Down Expand Up @@ -88,7 +92,7 @@ jobs:
if: always()
uses: actions/upload-artifact@v4
with:
name: logs-example-${{ matrix.example_name }}
name: logs-example-${{ inputs.BRANCH }}-${{ matrix.example_name }}
path: |
/root/.local/share/ray_on_golem/webserver_debug.log
/root/.local/share/ray_on_golem/yagna.log
Expand Down
17 changes: 17 additions & 0 deletions .github/workflows/on_schedule.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: On schedule

on:
schedule:
# run this workflow every day at 2:00 AM UTC
- cron: '0 2 * * *'

jobs:
nightly_tests:
name: Nightly tests
strategy:
fail-fast: false
matrix:
branch: [main, develop]
uses: ./.github/workflows/integration_tests.yml
with:
BRANCH: ${{ matrix.branch }}
24 changes: 15 additions & 9 deletions ray_on_golem/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ray_on_golem.client.exceptions import RayOnGolemClientError, RayOnGolemClientValidationError
from ray_on_golem.server import models, settings
from ray_on_golem.server.models import CreateClusterResponseData
from ray_on_golem.server.models import BootstrapClusterResponseData

TResponseModel = TypeVar("TResponseModel")

Expand All @@ -22,15 +22,19 @@ def __init__(self, port: int) -> None:
self.base_url = URL("http://127.0.0.1").with_port(self.port)
self._session = requests.Session()

def create_cluster(
def bootstrap_cluster(
self,
cluster_config: Dict[str, Any],
) -> CreateClusterResponseData:
provider_config: Dict[str, Any],
cluster_name: str,
) -> BootstrapClusterResponseData:
return self._make_request(
url=settings.URL_CREATE_CLUSTER,
request_data=models.CreateClusterRequestData(**cluster_config),
response_model=models.CreateClusterResponseData,
error_message="Couldn't create cluster",
url=settings.URL_BOOTSTRAP_CLUSTER,
request_data=models.BootstrapClusterRequestData(
provider_config=provider_config,
cluster_name=cluster_name,
),
response_model=models.BootstrapClusterResponseData,
error_message="Couldn't bootstrap cluster",
)

def request_nodes(
Expand Down Expand Up @@ -222,7 +226,9 @@ def _make_request(
data=request_data.json() if request_data else None,
)
except requests.ConnectionError as e:
raise RayOnGolemClientError(f"{error_message or f'Connection failed: {url}'}: {e}")
raise RayOnGolemClientError(
"{}: {}".format(error_message or f"Connection failed: {url}", e)
)

if response.status_code != 200:
raise RayOnGolemClientError(
Expand Down
14 changes: 9 additions & 5 deletions ray_on_golem/provider/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,20 @@ def __init__(self, provider_config: Dict[str, Any], cluster_name: str):
provider_parameters = self._map_ssh_config(provider_parameters)
self._payment_network = provider_parameters["payment_network"].lower().strip()

cluster_creation_response = self._ray_on_golem_client.create_cluster(provider_parameters)
cluster_bootstrap_response = self._ray_on_golem_client.bootstrap_cluster(
provider_parameters, cluster_name
)

self._wallet_address = cluster_creation_response.wallet_address
self._is_cluster_just_created = cluster_creation_response.is_cluster_just_created
self._wallet_address = cluster_bootstrap_response.wallet_address
self._is_cluster_just_created = cluster_bootstrap_response.is_cluster_just_created

self._print_mainnet_onboarding_message(
cluster_creation_response.yagna_payment_status_output
cluster_bootstrap_response.yagna_payment_status_output
)

wallet_glm_amount = float(cluster_creation_response.yagna_payment_status.get("amount", "0"))
wallet_glm_amount = float(
cluster_bootstrap_response.yagna_payment_status.get("amount", "0")
)
if not wallet_glm_amount:
cli_logger.abort("You don't seem to have any GLM tokens on your Golem wallet.")

Expand Down
7 changes: 4 additions & 3 deletions ray_on_golem/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ray_on_golem.server.services import GolemService, RayService, YagnaService
from ray_on_golem.server.settings import (
DEFAULT_DATADIR,
RAY_ON_GOLEM_SHUTDOWN_TIMEOUT,
RAY_ON_GOLEM_SHUTDOWN_CONNECTIONS_TIMEOUT,
WEBSOCAT_PATH,
YAGNA_PATH,
get_datadir,
Expand Down Expand Up @@ -65,7 +65,7 @@ def main(port: int, self_shutdown: bool, registry_stats: bool, datadir: Path):
app,
port=app["port"],
print=None,
shutdown_timeout=RAY_ON_GOLEM_SHUTDOWN_TIMEOUT.total_seconds(),
shutdown_timeout=RAY_ON_GOLEM_SHUTDOWN_CONNECTIONS_TIMEOUT.total_seconds(),
)
except Exception:
logger.info("Server unexpectedly died, bye!")
Expand Down Expand Up @@ -124,7 +124,8 @@ async def startup_print(app: web.Application) -> None:
async def shutdown_print(app: web.Application) -> None:
print("") # explicit new line to console to visually better handle ^C
logger.info(
"Waiting up to `%s` for current connections to close...", RAY_ON_GOLEM_SHUTDOWN_TIMEOUT
"Waiting up to `%s` for current connections to close...",
RAY_ON_GOLEM_SHUTDOWN_CONNECTIONS_TIMEOUT,
)


Expand Down
7 changes: 4 additions & 3 deletions ray_on_golem/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,12 @@ class ProviderConfigData(BaseModel):
ssh_user: str


class CreateClusterRequestData(ProviderConfigData):
pass
class BootstrapClusterRequestData(BaseModel):
provider_config: ProviderConfigData
cluster_name: str


class CreateClusterResponseData(BaseModel):
class BootstrapClusterResponseData(BaseModel):
is_cluster_just_created: bool
wallet_address: str
yagna_payment_status_output: str
Expand Down
9 changes: 8 additions & 1 deletion ray_on_golem/server/services/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
self._datadir = datadir

self._provider_config: Optional[ProviderConfigData] = None
self._cluster_name: Optional[str] = None
self._wallet_address: Optional[str] = None

self._nodes: Dict[NodeId, Node] = {}
Expand All @@ -78,11 +79,17 @@ async def shutdown(self) -> None:
logger.info("Stopping RayService done")

async def create_cluster(
self, provider_config: ProviderConfigData
self, provider_config: ProviderConfigData, cluster_name: str
) -> Tuple[bool, str, str, Dict]:
is_cluster_just_created = self._provider_config is None

if not is_cluster_just_created and self._cluster_name != cluster_name:
raise RayServiceError(
f"Webserver is running only for `{self._cluster_name}` cluster, not for `{cluster_name}`!"
)

self._provider_config = provider_config
self._cluster_name = cluster_name

self._ssh_private_key_path = Path(provider_config.ssh_private_key)
self._ssh_public_key_path = self._ssh_private_key_path.with_suffix(".pub")
Expand Down
12 changes: 10 additions & 2 deletions ray_on_golem/server/services/yagna.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
from asyncio.subprocess import Process
from datetime import datetime
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Optional

Expand Down Expand Up @@ -197,6 +197,7 @@ async def run_payment_fund(self, network: str, driver: str) -> Dict:
"--driver",
driver,
"--json",
timeout=timedelta(seconds=30),
)
)

Expand Down Expand Up @@ -224,7 +225,14 @@ async def run_payment_fund(self, network: str, driver: str) -> Dict:

async def fetch_payment_status(self, network: str, driver: str) -> str:
output = await run_subprocess_output(
self._yagna_path, "payment", "status", "--network", network, "--driver", driver
self._yagna_path,
"payment",
"status",
"--network",
network,
"--driver",
driver,
timeout=timedelta(seconds=30),
)
return output.decode()

Expand Down
5 changes: 4 additions & 1 deletion ray_on_golem/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
# how long a shutdown request will wait until the webserver shutdown is initiated
RAY_ON_GOLEM_SHUTDOWN_DELAY = timedelta(seconds=60)

# how long we wait for the webserver shutdown pending connection to complete
RAY_ON_GOLEM_SHUTDOWN_CONNECTIONS_TIMEOUT = timedelta(seconds=5)

# how long we wait for the webserver shutdown to complete
RAY_ON_GOLEM_SHUTDOWN_TIMEOUT = timedelta(seconds=60)

Expand All @@ -31,7 +34,7 @@
RAY_ON_GOLEM_PID_FILENAME = "ray_on_golem.pid"

URL_STATUS = "/"
URL_CREATE_CLUSTER = "/create_cluster"
URL_BOOTSTRAP_CLUSTER = "/bootstrap_cluster"
URL_NON_TERMINATED_NODES = "/non_terminated_nodes"
URL_IS_RUNNING = "/is_running"
URL_IS_TERMINATED = "/is_terminated"
Expand Down
28 changes: 17 additions & 11 deletions ray_on_golem/server/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ray_on_golem.server import models, settings
from ray_on_golem.server.models import ShutdownState
from ray_on_golem.server.services import RayService
from ray_on_golem.server.services.ray import RayServiceError
from ray_on_golem.utils import raise_graceful_exit
from ray_on_golem.version import get_version

Expand All @@ -18,7 +19,7 @@
def reject_if_shutting_down(func):
async def wrapper(request: web.Request) -> web.Response:
if request.app.get("shutting_down"):
return web.HTTPBadRequest(text="Action not allowed while server is shutting down!")
return web.HTTPBadRequest(reason="Action not allowed while server is shutting down!")

return await func(request)

Expand All @@ -43,21 +44,26 @@ async def status(request: web.Request) -> web.Response:
)


@routes.post(settings.URL_CREATE_CLUSTER)
async def create_cluster(request: web.Request) -> web.Response:
@routes.post(settings.URL_BOOTSTRAP_CLUSTER)
async def bootstrap_cluster(request: web.Request) -> web.Response:
ray_service: RayService = request.app["ray_service"]

request_data = models.CreateClusterRequestData.parse_raw(await request.text())
request_data = models.BootstrapClusterRequestData.parse_raw(await request.text())

(
is_cluster_just_created,
wallet_address,
yagna_payment_status_output,
yagna_payment_status,
) = await ray_service.create_cluster(provider_config=request_data)
try:
(
is_cluster_just_created,
wallet_address,
yagna_payment_status_output,
yagna_payment_status,
) = await ray_service.create_cluster(
request_data.provider_config, request_data.cluster_name
)
except RayServiceError as e:
raise web.HTTPBadRequest(reason=str(e))

return json_response(
models.CreateClusterResponseData(
models.BootstrapClusterResponseData(
is_cluster_just_created=is_cluster_just_created,
wallet_address=wallet_address,
yagna_payment_status_output=yagna_payment_status_output,
Expand Down
17 changes: 14 additions & 3 deletions ray_on_golem/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import os
from asyncio.subprocess import Process
from collections import deque
from datetime import timedelta
from pathlib import Path
from typing import Dict
from typing import Dict, Optional

from aiohttp.web_runner import GracefulExit

Expand All @@ -27,14 +28,24 @@ async def run_subprocess(
return process


async def run_subprocess_output(*args) -> bytes:
async def run_subprocess_output(*args, timeout: Optional[timedelta] = None) -> bytes:
process = await run_subprocess(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)

stdout, stderr = await process.communicate()
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout.total_seconds() if timeout else None,
)
except asyncio.TimeoutError as e:
if process.returncode is None:
process.kill()
await process.wait()

raise RayOnGolemError(f"Process could not finish in timeout of {timeout}!") from e

if process.returncode != 0:
raise RayOnGolemError(
Expand Down

0 comments on commit 6a7511c

Please sign in to comment.