Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing unittest warnings #820

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torchx/apps/utils/booth_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
)
parser.add_argument(
"--trial_idx",
type=int,
type=str,
help="trial index (ignore if not running hpo)",
default=0,
default="0",
)
return parser.parse_args(argv)

Expand Down
2 changes: 1 addition & 1 deletion torchx/apps/utils/test/booth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ def test_booth(self) -> None:
booth.main(["--x1", "1", "--x2", "3", "--tracker_base", self.test_dir])

tracker = FsspecResultTracker(self.test_dir)
self.assertEqual(0.0, tracker[0]["booth_eval"])
self.assertEqual(0.0, tracker["0"]["booth_eval"])
3 changes: 2 additions & 1 deletion torchx/cli/test/cmd_log_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Iterator, Optional
from unittest.mock import MagicMock, patch

import pytest
from torchx.cli.cmd_log import _prefix_line, ENDC, get_logs, GREEN, validate
from torchx.runner.api import Runner
from torchx.schedulers.api import Stream
Expand Down Expand Up @@ -191,7 +192,7 @@ def test_print_log_lines_throws(self, mock_runner: MagicMock) -> None:
# errors out; we raise the exception all the way through
with patch.object(mock_runner, "log_lines") as log_lines_mock:
log_lines_mock.side_effect = RuntimeError
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
get_logs(
sys.stdout,
"local://test-session/SparseNNAppDef/trainer/0,1",
Expand Down
3 changes: 2 additions & 1 deletion torchx/components/structured_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def parse_from(h: str, j: str) -> "StructuredJArgument":
f" This may lead to under-utilization or an error. "
f" If this was intentional, ignore this warning."
f" Otherwise set `-j {nnodes}` to auto-set nproc_per_node"
f" to the number of GPUs on the host."
f" to the number of GPUs on the host.",
ResourceWarning,
)
else:
raise ValueError(
Expand Down
28 changes: 16 additions & 12 deletions torchx/components/test/structured_arg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import unittest
from unittest import mock

import pytest
from torchx.components.structured_arg import StructuredJArgument, StructuredNameArgument

WARNINGS_WARN = "torchx.components.structured_arg.warnings.warn"
Expand Down Expand Up @@ -71,18 +72,21 @@ def test_create(self) -> None:
StructuredJArgument(nnodes=2, nproc_per_node=8),
StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2"),
)
self.assertEqual(
StructuredJArgument(nnodes=2, nproc_per_node=4),
StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2x4"),
)
self.assertEqual(
StructuredJArgument(nnodes=2, nproc_per_node=16),
StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2x16"),
)
self.assertEqual(
StructuredJArgument(nnodes=2, nproc_per_node=8),
StructuredJArgument.parse_from(h="aws_trn1.2xlarge", j="2x8"),
)
with pytest.warns(ResourceWarning):
self.assertEqual(
StructuredJArgument(nnodes=2, nproc_per_node=4),
StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2x4"),
)
with pytest.warns(ResourceWarning):
self.assertEqual(
StructuredJArgument(nnodes=2, nproc_per_node=16),
StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2x16"),
)
with pytest.warns(ResourceWarning):
self.assertEqual(
StructuredJArgument(nnodes=2, nproc_per_node=8),
StructuredJArgument.parse_from(h="aws_trn1.2xlarge", j="2x8"),
)

with self.assertRaisesRegex(
ValueError,
Expand Down
8 changes: 4 additions & 4 deletions torchx/examples/apps/lightning/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch
import torch.jit
from torch.nn import functional as F
from torchmetrics import Accuracy
from torchmetrics.classification import MulticlassAccuracy
from torchvision.models.resnet import BasicBlock, ResNet


Expand All @@ -47,8 +47,8 @@ def __init__(
m.fc.out_features = 200
self.model: ResNet = m

self.train_acc = Accuracy()
self.val_acc = Accuracy()
self.train_acc = MulticlassAccuracy(m.fc.out_features)
self.val_acc = MulticlassAccuracy(m.fc.out_features)

# pyre-fixme[14]
def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -69,7 +69,7 @@ def validation_step(
def _step(
self,
step_name: str,
acc_metric: Accuracy,
acc_metric: MulticlassAccuracy,
batch: Tuple[torch.Tensor, torch.Tensor],
batch_idx: int,
) -> torch.Tensor:
Expand Down
12 changes: 8 additions & 4 deletions torchx/runner/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Generator, List, Mapping, Optional
from unittest.mock import MagicMock, patch

import pytest
from torchx.runner import get_runner, Runner
from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Scheduler
from torchx.schedulers.local_scheduler import (
Expand Down Expand Up @@ -415,7 +416,7 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None:
)
app_handle = runner.run(AppDef(app_id, roles=[role]), scheduler="local_dir")
status = none_throws(runner.status(app_handle))
self.assertEquals(resp.ui_url, status.ui_url)
self.assertEqual(resp.ui_url, status.ui_url)

@patch("json.dumps")
def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:
Expand All @@ -439,7 +440,7 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:
)
app_handle = runner.run(AppDef(app_id, roles=[role]), scheduler="local_dir")
status = none_throws(runner.status(app_handle))
self.assertEquals(resp.structured_error_msg, status.structured_error_msg)
self.assertEqual(resp.structured_error_msg, status.structured_error_msg)

def test_wait_unknown_app(self, _) -> None:
with self.get_runner() as runner:
Expand All @@ -458,7 +459,10 @@ def test_cancel(self, _) -> None:

def test_stop(self, _) -> None:
with self.get_runner() as runner:
self.assertIsNone(runner.stop("local_dir://test_session/unknown_app_id"))
with pytest.warns(PendingDeprecationWarning):
self.assertIsNone(
runner.stop("local_dir://test_session/unknown_app_id")
)

def test_log_lines_unknown_app(self, _) -> None:
with self.get_runner() as runner:
Expand Down Expand Up @@ -550,7 +554,7 @@ def test_get_schedulers(self, json_dumps_mock: MagicMock, _) -> None:
local_sched_mock.submit.called_once_with(app, {})

def test_run_from_module(self, _: str) -> None:
runner = get_runner(name="test_session")
runner = get_runner()
app_args = ["--image", "dummy_image", "--script", "test.py"]
with patch.object(runner, "schedule"), patch.object(
runner, "dryrun"
Expand Down
16 changes: 8 additions & 8 deletions torchx/runner/test/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from torchx.test.fixtures import TestWithTmpDir


class TestScheduler(Scheduler):
class SchedulerTester(Scheduler):
def __init__(self, session_name: str) -> None:
super().__init__("test", session_name)

Expand Down Expand Up @@ -341,7 +341,7 @@ def test_no_override_load(self) -> None:

@patch(
TORCHX_GET_SCHEDULER_FACTORIES,
return_value={"test": TestScheduler},
return_value={"test": SchedulerTester},
)
def test_apply_default(self, _) -> None:
with patch(
Expand All @@ -357,7 +357,7 @@ def test_apply_default(self, _) -> None:

@patch(
TORCHX_GET_SCHEDULER_FACTORIES,
return_value={"test": TestScheduler},
return_value={"test": SchedulerTester},
)
def test_apply_dirs(self, _) -> None:
cfg: Dict[str, CfgVal] = {"s": "runtime_value"}
Expand All @@ -376,7 +376,7 @@ def test_dump_invalid_scheduler(self) -> None:

@patch(
TORCHX_GET_SCHEDULER_FACTORIES,
return_value={"test": TestScheduler},
return_value={"test": SchedulerTester},
)
def test_dump_only_required(self, _) -> None:
sfile = StringIO()
Expand All @@ -393,7 +393,7 @@ def test_dump_only_required(self, _) -> None:

@patch(
TORCHX_GET_SCHEDULER_FACTORIES,
return_value={"test": TestScheduler},
return_value={"test": SchedulerTester},
)
def test_load_invalid_runopt(self, _) -> None:
cfg = {}
Expand All @@ -407,7 +407,7 @@ def test_load_invalid_runopt(self, _) -> None:
# this makes things super hard to guarantee BC - stale config file will fail
# to run, we don't want that)

self.assertEquals("option_that_exists", cfg.get("s"))
self.assertEqual("option_that_exists", cfg.get("s"))

def test_load_no_section(self) -> None:
cfg = {}
Expand All @@ -429,7 +429,7 @@ def test_load_no_section(self) -> None:

@patch(
TORCHX_GET_SCHEDULER_FACTORIES,
return_value={"test": TestScheduler},
return_value={"test": SchedulerTester},
)
def test_dump_and_load_all_runopt_types(self, _) -> None:
sfile = StringIO()
Expand All @@ -441,7 +441,7 @@ def test_dump_and_load_all_runopt_types(self, _) -> None:
load(scheduler="test", f=sfile, cfg=cfg)

# all runopts in the TestScheduler have defaults, just check against those
for opt_name, opt in TestScheduler("test").run_opts():
for opt_name, opt in SchedulerTester("test").run_opts():
self.assertEqual(cfg.get(opt_name), opt.default)

def test_dump_and_load_all_registered_schedulers(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions torchx/runtime/tracking/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def test_put_get(self) -> None:

def test_get_missing_key(self) -> None:
tracker = FsspecResultTracker(self.test_dir)
res = tracker[1]
res = tracker["1"]
self.assertFalse(res)

def test_put_get_x2(self) -> None:
tracker = FsspecResultTracker(self.test_dir)
tracker[1] = {"l2norm": 1}
tracker[1] = {"l2norm": 2}
tracker["1"] = {"l2norm": 1}
tracker["1"] = {"l2norm": 2}

self.assertEqual(2, tracker["1"]["l2norm"])
self.assertEqual(2, tracker["1"]["l2norm"])
2 changes: 1 addition & 1 deletion torchx/schedulers/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_device_mounts(devices: Dict[str, int]) -> List[DeviceMount]:
device_mounts = []
for device_name, num_devices in devices.items():
if device_name not in DEVICES:
warnings.warn(f"Could not find named device: {device_name}")
warnings.warn(f"Could not find named device: {device_name}", RuntimeWarning)
continue
device_mounts += DEVICES[device_name](num_devices)
return device_mounts
6 changes: 3 additions & 3 deletions torchx/schedulers/kubernetes_mcad_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def get_unique_truncated_appid(app: AppDef) -> str:
if unique_id_size <= 3:
msg = "Name size has too many characters for some Kubernetes objects. Truncating \
application name."
warnings.warn(msg)
warnings.warn(msg, RuntimeWarning)
end = 63 - uid_chars - pg_chars
substring = app.name[0:end]
app.name = substring
Expand All @@ -501,7 +501,7 @@ def get_port_for_service(app: AppDef) -> str:
if not (0 < int(port) <= 65535):
msg = """Warning: port_map set to invalid port number. Value must be between 1-65535, with torchx default = 29500. Setting port to default = 29500"""
port = "29500"
warnings.warn(msg)
warnings.warn(msg, RuntimeWarning)

return port

Expand Down Expand Up @@ -1147,7 +1147,7 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
task_count = 0
for role in roles:
msg = "Warning - MCAD does not report individual replica statuses, but overall task status. Replica id may not match status"
warnings.warn(msg)
warnings.warn(msg, RuntimeWarning)

roles_statuses[role] = RoleStatus(role, [])
for idx in range(0, roles[role].num_replicas):
Expand Down
3 changes: 2 additions & 1 deletion torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,8 @@ def log_iter(
if since or until:
warnings.warn(
"Since and/or until times specified for LocalScheduler.log_iter."
" These will be ignored and all log lines will be returned"
" These will be ignored and all log lines will be returned",
RuntimeWarning,
)

app = self._apps[app_id]
Expand Down
Loading
Loading