Skip to content

Commit

Permalink
Manifest: has DB objects refactor (#476)
Browse files Browse the repository at this point in the history
Refactor logic of `Manifest.has_db_objects` to remove excess branching
and improve readability/maintainability.

[ committed by @MattToast ]
[ reviewed by @ankona ]
  • Loading branch information
MattToast authored Feb 6, 2024
1 parent 3a4e828 commit b84b49f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 45 deletions.
51 changes: 6 additions & 45 deletions smartsim/_core/control/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import itertools
import pathlib
import typing as t
from dataclasses import dataclass, field
Expand Down Expand Up @@ -178,52 +179,12 @@ def __str__(self) -> str:
@property
def has_db_objects(self) -> bool:
"""Check if any entity has DBObjects to set"""

def has_db_models(
entity: t.Union[EntitySequence[SmartSimEntity], Model]
) -> bool:
return len(list(entity.db_models)) > 0

def has_db_scripts(
entity: t.Union[EntitySequence[SmartSimEntity], Model]
) -> bool:
return len(list(entity.db_scripts)) > 0

has_db_objects = False

# Check if any model has either a DBModel or a DBScript
# we update has_db_objects so that as soon as one check
# returns True, we can exit
has_db_objects |= any(
has_db_models(model) | has_db_scripts(model) for model in self.models
ents: t.Iterable[t.Union[Model, Ensemble]] = itertools.chain(
self.models,
self.ensembles,
(member for ens in self.ensembles for member in ens.entities),
)
if has_db_objects:
return True

# If there are no ensembles, there can be no outstanding model
# to check for DBObjects, return current value of DBObjects, which
# should be False
ensembles = self.ensembles
if not ensembles:
return has_db_objects

# First check if there is any ensemble DBObject, if so, return True
has_db_objects |= any(
has_db_models(ensemble) | has_db_scripts(ensemble) for ensemble in ensembles
)
if has_db_objects:
return True
for ensemble in ensembles:
# Last case, check if any model within an ensemble has DBObjects attached
has_db_objects |= any(
has_db_models(model) | has_db_scripts(model)
for model in ensemble.models
)
if has_db_objects:
return True

# `has_db_objects` should be False here
return has_db_objects
return any(any(ent.db_models) or any(ent.db_scripts) for ent in ents)


class _LaunchedManifestMetadata(t.NamedTuple):
Expand Down
36 changes: 36 additions & 0 deletions tests/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
_LaunchedManifestMetadata as LaunchedManifestMetadata,
)
from smartsim.database import Orchestrator
from smartsim.entity.dbobject import DBModel, DBScript
from smartsim.error import SmartSimError
from smartsim.settings import RunSettings

Expand All @@ -61,6 +62,9 @@
orc_1.name = "orc2"
model_no_name = exp.create_model(name=None, run_settings=rs)

db_script = DBScript("some-script", "def main():\n print('hello world')\n")
db_model = DBModel("some-model", "TORCH", b"some-model-bytes")


def test_separate():
manifest = Manifest(model, ensemble, orc)
Expand Down Expand Up @@ -106,6 +110,38 @@ class Person:
_ = Manifest(p)


@pytest.mark.parametrize(
"patch, has_db_objects",
[
pytest.param((), False, id="No DB Objects"),
pytest.param((model, "_db_models", [db_model]), True, id="Model w/ DB Model"),
pytest.param(
(model, "_db_scripts", [db_script]), True, id="Model w/ DB Script"
),
pytest.param(
(ensemble, "_db_models", [db_model]), True, id="Ensemble w/ DB Model"
),
pytest.param(
(ensemble, "_db_scripts", [db_script]), True, id="Ensemble w/ DB Script"
),
pytest.param(
(ensemble.entities[0], "_db_models", [db_model]),
True,
id="Ensemble Member w/ DB Model",
),
pytest.param(
(ensemble.entities[0], "_db_scripts", [db_script]),
True,
id="Ensemble Member w/ DB Script",
),
],
)
def test_manifest_detects_db_objects(monkeypatch, patch, has_db_objects):
if patch:
monkeypatch.setattr(*patch)
assert Manifest(model, ensemble).has_db_objects == has_db_objects


def test_launched_manifest_transform_data():
models = [(model, 1), (model_2, 2)]
ensembles = [(ensemble, [(m, i) for i, m in enumerate(ensemble.entities)])]
Expand Down

0 comments on commit b84b49f

Please sign in to comment.