Skip to content

Commit 2959170

Browse files
committed
Address comments and fix mypy errors.
Signed-off-by: [email protected] <[email protected]>
1 parent fd7a25f commit 2959170

File tree

5 files changed

+31
-24
lines changed

5 files changed

+31
-24
lines changed

tests/fault_tolerance/deploy/client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import random
2020
import time
2121
from datetime import datetime
22+
from typing import Any, Dict
2223

2324
import requests
2425

@@ -154,12 +155,13 @@ def client(
154155
logger = logging.getLogger(f"CLIENT: {index}")
155156
logging.getLogger("httpx").setLevel(logging.WARNING)
156157

157-
managed_deployment = ManagedDeployment(None, deployment_spec, namespace)
158-
pod_ports = {}
158+
managed_deployment = ManagedDeployment(log_dir, deployment_spec, namespace)
159+
pod_ports: Dict[str, Any] = {}
159160

160161
min_elapsed_time = 1 / max_request_rate
161162

162163
try:
164+
os.makedirs(log_dir, exist_ok=True)
163165
log_path = os.path.join(log_dir, f"client_{index}.log.txt")
164166
with open(log_path, "w") as log:
165167
for i in range(requests_per_client):

tests/fault_tolerance/deploy/parse_results.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
import re
2020
from datetime import datetime
21-
from typing import Any
21+
from typing import Any, Dict, List, Tuple
2222

2323
import pandas as pd
2424
from tabulate import tabulate
@@ -38,7 +38,7 @@ def parse_test_log(file_path):
3838
start_time = datetime.fromisoformat(
3939
line.split(" ")[1].replace("T", " ")
4040
)
41-
start_cmd = []
41+
start_cmd: List[str] = []
4242
elif "Deployment fault-tolerance-test is ready" in line:
4343
ready_time = datetime.fromisoformat(
4444
line.split(" ")[1].replace("T", " ")
@@ -170,7 +170,7 @@ def parse_process_log(log_dir, process_name):
170170
}
171171
if not os.path.isdir(log_dir):
172172
return {}
173-
ready_times = {}
173+
ready_times: Dict[str, List[Tuple[datetime, str, float]]] = {}
174174

175175
for entry in os.listdir(log_dir):
176176
if entry.endswith(".log") and "metrics" not in entry:
@@ -317,7 +317,7 @@ def process_test_directory(test_dir, sla):
317317
}
318318

319319

320-
def main(logs_dir, tablefmt, log_paths=[], sla=None):
320+
def main(logs_dir, tablefmt, log_paths=None, sla=None):
321321
results = []
322322
if log_paths:
323323
for log_path in log_paths:

tests/fault_tolerance/deploy/scenarios.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
from dataclasses import dataclass
17+
from typing import Optional
1718

1819
from tests.utils.managed_deployment import DeploymentSpec
1920

@@ -26,7 +27,7 @@ class Load:
2627
output_token_length: int = 100
2728
max_retries: int = 1
2829
max_request_rate: float = 1
29-
sla: float = None
30+
sla: Optional[float] = None
3031

3132

3233
@dataclass
@@ -43,7 +44,7 @@ class Scenario:
4344
deployment: DeploymentSpec
4445
load: Load
4546
failures: list[Failure]
46-
model: str = None
47+
model: Optional[str] = None
4748

4849

4950
# Each Deployment Spec contains

tests/fault_tolerance/deploy/test_deployment.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import multiprocessing
66
import time
77
from contextlib import contextmanager
8-
from multiprocessing import Process
98

109
import pytest
1110

@@ -14,8 +13,6 @@
1413
from tests.fault_tolerance.deploy.scenarios import scenarios
1514
from tests.utils.managed_deployment import ManagedDeployment
1615

17-
multiprocessing.set_start_method("spawn")
18-
1916

2017
@pytest.fixture(params=scenarios.keys())
2118
def scenario(request):
@@ -37,9 +34,10 @@ def _clients(
3734
max_request_rate,
3835
):
3936
procs = []
37+
ctx = multiprocessing.get_context("spawn")
4038
for i in range(num_clients):
4139
procs.append(
42-
Process(
40+
ctx.Process(
4341
target=client,
4442
args=(
4543
deployment_spec,

tests/utils/managed_deployment.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import shlex
99
import time
1010
from dataclasses import dataclass
11-
from typing import Optional
11+
from typing import Any, Optional
1212

1313
import kr8s
1414
import kubernetes
@@ -179,6 +179,10 @@ def name(self) -> str:
179179
"""Deployment name"""
180180
return self._deployment_spec["metadata"]["name"]
181181

182+
@name.setter
183+
def name(self, value: str):
184+
self._deployment_spec["metadata"]["name"] = value
185+
182186
@property
183187
def port(self) -> int:
184188
"""Deployment port"""
@@ -193,10 +197,6 @@ def system_port(self) -> int:
193197
def endpoint(self) -> str:
194198
return self._endpoint
195199

196-
@name.setter
197-
def name(self, value: str):
198-
self._deployment_spec["metadata"]["name"] = value
199-
200200
@property
201201
def namespace(self) -> str:
202202
"""Deployment namespace"""
@@ -353,12 +353,13 @@ class ManagedDeployment:
353353
namespace: str
354354
frontend_service_name: Optional[str] = "Frontend"
355355

356-
_custom_api = None
357-
_core_api = None
358-
_in_cluster = False
359-
_logger = logging.getLogger()
360-
_port_forward = None
361-
_deployment_name = None
356+
_custom_api: Optional[Any] = None
357+
_core_api: Optional[Any] = None
358+
_in_cluster: bool = False
359+
_logger: logging.Logger = logging.getLogger()
360+
_port_forward: Optional[Any] = None
361+
_deployment_name: Optional[str] = None
362+
_apps_v1: Optional[Any] = None
362363

363364
def __post_init__(self):
364365
self._deployment_name = self.deployment_spec.name
@@ -379,6 +380,7 @@ async def _init_kubernetes(self):
379380

380381
async def _wait_for_pods(self, label, expected, timeout=300):
381382
for _ in range(timeout):
383+
assert self._core_api is not None, "Kubernetes API not initialized"
382384
pods = await self._core_api.list_namespaced_pod(
383385
self.namespace, label_selector=label
384386
)
@@ -397,6 +399,7 @@ async def _wait_for_pods(self, label, expected, timeout=300):
397399

398400
async def _scale_statfulset(self, name, label, replicas):
399401
body = {"spec": {"replicas": replicas}}
402+
assert self._apps_v1 is not None, "Kubernetes API not initialized"
400403
await self._apps_v1.patch_namespaced_stateful_set_scale(
401404
name, self.namespace, body
402405
)
@@ -406,6 +409,7 @@ async def _restart_stateful(self, name, label):
406409
self._logger.info(f"Restarting {name} {label}")
407410

408411
await self._scale_statfulset(name, label, 0)
412+
assert self._core_api is not None, "Kubernetes API not initialized"
409413
nats_pvc = await self._core_api.list_namespaced_persistent_volume_claim(
410414
self.namespace, label_selector=label
411415
)
@@ -434,6 +438,7 @@ async def _wait_for_ready(self, timeout: int = 1800, sleep=1, log_interval=60):
434438
while (time.time() - start_time) < timeout:
435439
try:
436440
attempt += 1
441+
assert self._custom_api is not None, "Kubernetes API not initialized"
437442
status = await self._custom_api.get_namespaced_custom_object(
438443
group="nvidia.com",
439444
version="v1alpha1",
@@ -520,6 +525,7 @@ async def _create_deployment(self):
520525
)
521526

522527
try:
528+
assert self._custom_api is not None, "Kubernetes API not initialized"
523529
await self._custom_api.create_namespaced_custom_object(
524530
group="nvidia.com",
525531
version="v1alpha1",
@@ -652,7 +658,7 @@ async def _delete_deployment(self):
652658
Delete the DynamoGraphDeployment CR.
653659
"""
654660
try:
655-
if self._deployment_name:
661+
if self._deployment_name and self._custom_api is not None:
656662
await self._custom_api.delete_namespaced_custom_object(
657663
group="nvidia.com",
658664
version="v1alpha1",

0 commit comments

Comments
 (0)