Skip to content

Commit e8e23ed

Browse files
committed
Fix: Reboot endpoint was not implemented
1 parent 73f5147 commit e8e23ed

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

src/aleph/vm/orchestrator/supervisor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232
status_public_config,
3333
update_allocations,
3434
)
35-
from .views.operator import operate_erase, operate_expire, operate_stop, stream_logs
35+
from .views.operator import (
36+
operate_erase,
37+
operate_expire,
38+
operate_reboot,
39+
operate_stop,
40+
stream_logs,
41+
)
3642

3743
logger = logging.getLogger(__name__)
3844

@@ -78,6 +84,7 @@ async def allow_cors_on_endpoint(request: web.Request):
7884
web.post("/control/machine/{ref}/expire", operate_expire),
7985
web.post("/control/machine/{ref}/stop", operate_stop),
8086
web.post("/control/machine/{ref}/erase", operate_erase),
87+
web.post("/control/machine/{ref}/reboot", operate_reboot),
8188
web.options(
8289
"/control/machine/{ref}/{view:.*}",
8390
allow_cors_on_endpoint,

src/aleph/vm/orchestrator/views/operator.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Callable, Coroutine, Literal, Union
88

99
import aiohttp.web_exceptions
10+
import pydantic
1011
from aiohttp import web
1112
from aiohttp.web_urldispatcher import UrlMappingMatchInfo
1213
from aleph_message.exceptions import UnknownHashError
@@ -18,6 +19,7 @@
1819
from pydantic.main import BaseModel
1920

2021
from aleph.vm.models import VmExecution
22+
from aleph.vm.orchestrator.run import create_vm_execution
2123
from aleph.vm.pool import VmPool
2224

2325
logger = logging.getLogger(__name__)
@@ -28,7 +30,7 @@ def is_token_still_valid(timestamp):
2830
Checks if a token has exprired based on its timestamp
2931
"""
3032
current_datetime = datetime.now(tz=timezone.utc)
31-
target_datetime = datetime.fromisoformat(timestamp, tz=timezone.utc)
33+
target_datetime = datetime.fromisoformat(timestamp)
3234

3335
return target_datetime > current_datetime
3436

@@ -51,7 +53,7 @@ class SignedPubKeyPayload(BaseModel):
5153
# alg: Literal["ECDSA"]
5254
domain: str
5355
address: str
54-
expires: str
56+
expires: str
5557

5658
@property
5759
def json_web_key(self) -> Jwk:
@@ -219,7 +221,6 @@ def get_execution_or_404(ref: ItemHash, pool: VmPool) -> VmExecution:
219221

220222
@require_jwk_authentication
221223
async def stream_logs(request: web.Request, authenticated_sender: str) -> web.StreamResponse:
222-
# TODO: Add user authentication
223224
vm_hash = get_itemhash_or_400(request.match_info)
224225
pool: VmPool = request.app["vm_pool"]
225226
execution = get_execution_or_404(vm_hash, pool=pool)
@@ -260,7 +261,6 @@ async def operate_expire(request: web.Request, authenticated_sender: str) -> web
260261
"""Stop the virtual machine, smoothly if possible.
261262
262263
A timeout may be specified to delay the action."""
263-
# TODO: Add user authentication
264264
vm_hash = get_itemhash_or_400(request.match_info)
265265
timeout = float(ItemHash(request.match_info["timeout"]))
266266
if not 0 < timeout < timedelta(days=10).total_seconds():
@@ -317,9 +317,14 @@ async def operate_reboot(request: web.Request, authenticated_sender: str) -> web
317317
logger.debug(f"Unauthorized sender {authenticated_sender} for {vm_hash}")
318318
return web.Response(status=401, body="Unauthorized sender")
319319

320-
# TODO: implement this endpoint
321-
logger.info(f"Rebooting {execution.vm_hash}")
322-
return web.Response(status=200, body=f"Rebooted {execution.vm_hash}")
320+
if execution.is_running:
321+
logger.info(f"Rebooting {execution.vm_hash}")
322+
await pool.stop_vm(vm_hash)
323+
pool.forget_vm(vm_hash)
324+
await create_vm_execution(vm_hash=vm_hash, pool=pool)
325+
return web.Response(status=200, body=f"Rebooted VM with ref {vm_hash}")
326+
else:
327+
return web.Response(status=200, body="Starting VM (was not running) with ref {vm_hash}")
323328

324329

325330
@require_jwk_authentication

src/aleph/vm/pool.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ async def get_running_vm(self, vm_hash: ItemHash) -> Optional[VmExecution]:
136136
else:
137137
return None
138138

139+
async def stop_vm(self, vm_hash: ItemHash) -> Optional[VmExecution]:
140+
"""Stop a VM."""
141+
execution = self.executions.get(vm_hash)
142+
if execution:
143+
await execution.stop()
144+
return execution
145+
else:
146+
return None
147+
139148
def forget_vm(self, vm_hash: ItemHash) -> None:
140149
"""Remove a VM from the executions pool.
141150

0 commit comments

Comments
 (0)