Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed Typehint for RunSettings.colocated_db_settings #462

Merged
merged 19 commits into from
Jan 29, 2024
Merged
2 changes: 1 addition & 1 deletion smartsim/_core/_cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def check_py_torch_version(versions: Versioner, device_in: _TDeviceStr = "cpu")
"Torch version not found in python environment. "
"Attempting to install via `pip`"
)
wheel_device = device if device == "cpu" else device_suffix.replace("+","")
wheel_device = device if device == "cpu" else device_suffix.replace("+", "")
pip(
"install",
"--extra-index-url",
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def _prep_entity_client_env(self, entity: Model) -> None:
# Set address to local if it's a colocated model
if entity.colocated and entity.run_settings.colocated_db_settings is not None:
db_name_colo = entity.run_settings.colocated_db_settings["db_identifier"]

assert isinstance(db_name_colo, str)
for key in address_dict:
_, db_id = unpack_db_identifier(key, "_")
if db_name_colo == db_id:
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/launcher/step/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
)
makedirs(osp.dirname(script_path), exist_ok=True)

db_settings: t.Dict[str, str] = {}
db_settings = {}

Check warning on line 96 in smartsim/_core/launcher/step/step.py

View check run for this annotation

Codecov / codecov/patch

smartsim/_core/launcher/step/step.py#L96

Added line #L96 was not covered by tests
if isinstance(self.step_settings, RunSettings):
db_settings = self.step_settings.colocated_db_settings or {}

Expand Down
52 changes: 41 additions & 11 deletions smartsim/entity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ def colocate_db_uds(
f"Invalid name for unix socket: {unix_socket}. Must only "
"contain alphanumeric characters or . : _ - /"
)

uds_options = {
uds_options: t.Dict[str, t.Union[int, str]] = {
"unix_socket": unix_socket,
"socket_permissions": socket_permissions,
"port": 0, # This is hardcoded to 0 as recommended by redis for UDS
# This is hardcoded to 0 as recommended by redis for UDS
"port": 0,
}

common_options = {
Expand Down Expand Up @@ -332,9 +332,18 @@ def colocate_db_tcp(

def _set_colocated_db_settings(
self,
connection_options: t.Dict[str, t.Any],
common_options: t.Dict[str, t.Any],
**kwargs: t.Any,
connection_options: t.Mapping[str, t.Union[int, t.List[str], str]],
common_options: t.Dict[
str,
t.Union[
t.Union[t.Iterable[t.Union[int, t.Iterable[int]]], None],
bool,
int,
str,
None,
],
],
**kwargs: t.Union[int, None],
) -> None:
"""
Ingest the connection-specific options (UDS/TCP) and set the final settings
Expand All @@ -357,21 +366,42 @@ def _set_colocated_db_settings(
)

# TODO list which db settings can be extras
custom_pinning_ = t.cast(
t.Optional[t.Iterable[t.Union[int, t.Iterable[int]]]],
common_options.get("custom_pinning"),
)
cpus_ = t.cast(int, common_options.get("cpus"))
common_options["custom_pinning"] = self._create_pinning_string(
common_options["custom_pinning"], common_options["cpus"]
custom_pinning_, cpus_
)

colo_db_config = {}
colo_db_config: t.Dict[
str,
t.Union[
bool,
int,
str,
None,
t.List[str],
t.Iterable[t.Union[int, t.Iterable[int]]],
t.List[DBModel],
t.List[DBScript],
t.Dict[str, t.Union[int, None]],
t.Dict[str, str],
],
] = {}
colo_db_config.update(connection_options)
colo_db_config.update(common_options)
# redisai arguments for inference settings
colo_db_config["rai_args"] = {

redis_ai_temp = {
"threads_per_queue": kwargs.get("threads_per_queue", None),
"inter_op_parallelism": kwargs.get("inter_op_parallelism", None),
"intra_op_parallelism": kwargs.get("intra_op_parallelism", None),
}
# redisai arguments for inference settings
colo_db_config["rai_args"] = redis_ai_temp
colo_db_config["extra_db_args"] = {
k: str(v) for k, v in kwargs.items() if k not in colo_db_config["rai_args"]
k: str(v) for k, v in kwargs.items() if k not in redis_ai_temp
}

self._check_db_objects_colo()
Expand Down
19 changes: 18 additions & 1 deletion smartsim/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from smartsim.settings.containers import Container

from .._core.utils.helpers import expand_exe_path, fmt_dict, is_valid_cmd
from ..entity.dbobject import DBModel, DBScript
from ..log import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -96,7 +97,23 @@ def __init__(
self.container = container
self._run_command = run_command
self.in_batch = False
self.colocated_db_settings: t.Optional[t.Dict[str, str]] = None
self.colocated_db_settings: t.Optional[
t.Dict[
str,
t.Union[
bool,
int,
str,
None,
t.List[str],
t.Iterable[t.Union[int, t.Iterable[int]]],
t.List[DBModel],
t.List[DBScript],
t.Dict[str, t.Union[int, None]],
t.Dict[str, str],
],
]
] = None

@property
def exe_args(self) -> t.Union[str, t.List[str]]:
Expand Down
2 changes: 1 addition & 1 deletion smartsim/settings/lsfSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def set_cpus_per_rs(self, cpus_per_rs: int) -> None:
:type cpus_per_rs: int or str
"""
if self.colocated_db_settings:
db_cpus = int(self.colocated_db_settings.get("db_cpus", 0))
db_cpus = int(t.cast(int, self.colocated_db_settings.get("db_cpus", 0)))
if not db_cpus:
raise ValueError("db_cpus must be configured on colocated_db_settings")

Expand Down
4 changes: 2 additions & 2 deletions tests/install/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import pytest

import functools
import pathlib
import platform
import threading
import time

import pytest

import smartsim._core._install.builder as build

# The tests in this file belong to the group_a group
Expand Down
Loading