Skip to content

Commit

Permalink
Merge pull request #273 from ncilfone/bugs_v3.0.0
Browse files Browse the repository at this point in the history
Bug Fixes for v3.0.0
  • Loading branch information
mmalouane authored Jan 19, 2023
2 parents 8bd55a4 + f11f3ce commit af305f4
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 60 deletions.
2 changes: 2 additions & 0 deletions spock/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
"config",
"directory",
"file",
"helpers",
"SavePath",
"spock",
"SpockBuilder",
"utils",
]

__version__ = get_versions()["version"]
Expand Down
2 changes: 2 additions & 0 deletions spock/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ __all__ = [
"config",
"directory",
"file",
"helpers",
"SavePath",
"spock",
"SpockBuilder",
"utils",
]

_T = TypeVar("_T")
Expand Down
36 changes: 31 additions & 5 deletions spock/backend/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import argparse
from abc import ABC, abstractmethod
from enum import EnumMeta
from typing import ByteString, Dict, List, Optional, Tuple
from typing import ByteString, Dict, List, Optional, Set, Tuple

import attr

from spock.args import SpockArguments
from spock.backend.field_handlers import RegisterSpockCls
from spock.backend.help import attrs_help
from spock.backend.resolvers import VarResolver
from spock.backend.spaces import BuilderSpace
from spock.backend.wrappers import Spockspace
from spock.exceptions import _SpockInstantiationError
Expand Down Expand Up @@ -196,17 +197,23 @@ def resolve_spock_space_kwargs(self, graph: Graph, dict_args: Dict) -> Dict:
# dependencies in the correct order
for spock_name in merged_graph.topological_order:
# First we check for any needed cls dependent variable resolution
cls_fields = var_graph.resolve(spock_name, builder_space.spock_space)
cls_fields, cls_changed_vars = var_graph.resolve(
spock_name, builder_space.spock_space
)
# Then we map cls references to their instantiated version
cls_fields = self._clean_up_cls_refs(cls_fields, builder_space.spock_space)
# Lastly we have to check for self-resolution -- we do this w/ yet another
# graph -- graphs FTW! -- this maps back to the fields dict in the tuple
cls_fields = SelfGraph(
cls_fields, var_changed_vars = SelfGraph(
cls_fields_dict[spock_name]["cls"], cls_fields
).resolve()

# Once all resolution occurs we attempt to instantiate the cls
# Get the actual underlying class
spock_cls = merged_graph.node_map[spock_name]
# Merge the changed sets -- then attempt to cast them all post resolution
self._cast_all_maps(
spock_cls, cls_fields, cls_changed_vars | var_changed_vars
)
# Once all resolution occurs we attempt to instantiate the cls
try:
spock_instance = spock_cls(**cls_fields)
except Exception as e:
Expand All @@ -218,6 +225,25 @@ def resolve_spock_space_kwargs(self, graph: Graph, dict_args: Dict) -> Dict:
builder_space.spock_space[spock_cls.__name__] = spock_instance
return builder_space.spock_space

@staticmethod
def _cast_all_maps(cls, cls_fields: Dict, changed_vars: Set) -> None:
"""Casts all the resolved references to the requested type
Args:
cls: current spock class
cls_fields: current fields dictionary to attempt cast within
changed_vars: set of resolved variables that need to be cast
Returns:
"""
for val in changed_vars:
cls_fields[val] = VarResolver._attempt_cast(
maybe_env=cls_fields[val],
value_type=getattr(cls.__attrs_attrs__, val).type,
ref_value=val,
)

@staticmethod
def _clean_up_cls_refs(fields: Dict, spock_space: Dict) -> Dict:
"""Swaps in the newly created cls if it hasn't been instantiated yet
Expand Down
6 changes: 3 additions & 3 deletions spock/backend/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _is_file(type: _T, check_access: bool, attr: attr.Attribute, value: str) ->
_check_instance(value, attr.name, str)
# # If so then cast to underlying type
# value = file(value)
if not Path(value).is_file():
if not Path(value).resolve().is_file():
raise ValueError(f"{attr.name} must be a file: {value} is not a valid file")
r = os.access(value, os.R_OK)
w = os.access(value, os.W_OK)
Expand Down Expand Up @@ -141,13 +141,13 @@ def _is_directory(
# Check the instance type first
_check_instance(value, attr.name, str)
# If it's not a path and not flagged to create then raise exception
if not Path(value).is_dir() and not create:
if not Path(value).resolve().is_dir() and not create:
raise ValueError(
f"{attr.name} must be a directory: {value} is not a " f"valid directory"
)
# Else just try and create the path -- exist_ok means if the path already exists
# it won't throw an exception
elif not Path(value).is_dir() and create:
elif not Path(value).resolve().is_dir() and create:
try:
os.makedirs(value, exist_ok=True)
print(
Expand Down
35 changes: 20 additions & 15 deletions spock/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,21 +640,26 @@ def obj_2_dict(self, obj: Union[_C, List[_C], Tuple[_C, ...]]) -> Dict[str, Dict
Returns:
dictionary where the class names are keys and the values are the dictionary representations
"""
if isinstance(obj, (List, Tuple)):
obj_dict = {}
for val in obj:
if not _is_spock_instance(val):
raise _SpockValueError(
f"Object is not a @spock decorated class object -- currently `{type(val)}`"
)
obj_dict.update({type(val).__name__: val})
elif _is_spock_instance(obj):
obj_dict = {type(obj).__name__: obj}
else:
raise _SpockValueError(
f"Object is not a @spock decorated class object -- currently `{type(obj)}`"
)
return self.spockspace_2_dict(Spockspace(**obj_dict))

from spock.helpers import to_dict

return to_dict(obj, self._saver_obj)

# if isinstance(obj, (List, Tuple)):
# obj_dict = {}
# for val in obj:
# if not _is_spock_instance(val):
# raise _SpockValueError(
# f"Object is not a @spock decorated class object -- currently `{type(val)}`"
# )
# obj_dict.update({type(val).__name__: val})
# elif _is_spock_instance(obj):
# obj_dict = {type(obj).__name__: obj}
# else:
# raise _SpockValueError(
# f"Object is not a @spock decorated class object -- currently `{type(obj)}`"
# )
# return self.spockspace_2_dict(Spockspace(**obj_dict))

def evolve(self, *args: _C) -> Spockspace:
"""Function that allows a user to evolve the underlying spock classes with
Expand Down
48 changes: 17 additions & 31 deletions spock/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,16 @@ def reverse_map(self):
def nodes(self):
return [k.name for k in self._cls.__attrs_attrs__]

def _cast_all_maps(self, changed_vars: Set):
for val in changed_vars:
self._fields[val] = self.var_resolver.attempt_cast(
self._fields[val], getattr(self._cls.__attrs_attrs__, val).type, val
)
def resolve(self) -> Tuple[Dict, Set]:
"""Resolves variable references by searching thorough the current spock_space
Args:
def resolve(self) -> Dict:
Returns:
field dictionary containing the resolved values and a set containing all
changed variables to delay casting post resolution
"""
# Iterate in topo order
for k in self.topological_order:
# get the self dependent values and swap within the fields dict
Expand All @@ -275,9 +278,8 @@ def resolve(self) -> Dict:
name=v,
)
self._fields[v] = typed_val
# Get a set of all changed variables and attempt to cast them
self._cast_all_maps(set(self._ref_map.keys()))
return self._fields
# Get a set of all changed variables
return self._fields, set(self._ref_map.keys())

def _build(self) -> Tuple[Dict, Dict]:
"""Builds a dictionary of nodes and their edges (essentially builds the DAG)
Expand Down Expand Up @@ -363,35 +365,20 @@ def ref_2_resolve(self) -> Set:
"""Returns the values that need to be resolved"""
return set(self.ref_map.keys())

def _cast_all_maps(self, cls_name: str, changed_vars: Set) -> None:
"""Casts all the resolved references to the requested type
Args:
cls_name: name of the underlying class
changed_vars: set of resolved variables that need to be cast
Returns:
"""
for val in changed_vars:
self.cls_map[cls_name][val] = self.var_resolver.attempt_cast(
self.cls_map[cls_name][val],
getattr(self.node_map[cls_name].__attrs_attrs__, val).type,
val,
)

def resolve(self, spock_cls: str, spock_space: Dict) -> Dict:
def resolve(self, spock_cls: str, spock_space: Dict) -> Tuple[Dict, Set]:
"""Resolves variable references by searching thorough the current spock_space
Args:
spock_cls: name of the spock class
spock_space: current spock_space to look for the underlying value
Returns:
field dictionary containing the resolved values
field dictionary containing the resolved values and a set containing all
changed variables to delay casting post resolution
"""
# First we check for any needed variable resolution
changed_vars = set()
if spock_cls in self.ref_2_resolve:
# iterate over the mapped refs to swap values -- using the var resolver
# to get the correct values
Expand All @@ -406,11 +393,10 @@ def resolve(self, spock_cls: str, spock_space: Dict) -> Dict:
)
# Swap the value to the replaced version
self.cls_map[spock_cls][ref["val"]] = typed_val
# Get a set of all changed variables and attempt to cast them
# Get a set of all changed variables
changed_vars = {n["val"] for n in self.ref_map[spock_cls]}
self._cast_all_maps(spock_cls, changed_vars)
# Return the field dict
return self.cls_map[spock_cls]
return self.cls_map[spock_cls], changed_vars

def _build(self) -> Tuple[Dict, Dict]:
"""Builds a dictionary of nodes and their edges (essentially builds the DAG)
Expand Down
44 changes: 44 additions & 0 deletions spock/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-

# SPDX-License-Identifier: Apache-2.0

"""Helper functions for Spock"""

from typing import Dict, List, Optional, Tuple, Union

from spock.backend.saver import AttrSaver
from spock.backend.wrappers import Spockspace
from spock.exceptions import _SpockValueError
from spock.utils import _C, _is_spock_instance


def to_dict(
objs: Union[_C, List[_C], Tuple[_C, ...]], saver: Optional[AttrSaver] = AttrSaver()
) -> Dict[str, Dict]:
"""Converts spock classes from a Spockspace into their dictionary representations
Args:
objs: single spock class or an iterable of spock classes
saver: optional saver class object
Returns:
dictionary where the class names are keys and the values are the dictionary
representations
"""
if isinstance(objs, (List, Tuple)):
obj_dict = {}
for val in objs:
if not _is_spock_instance(val):
raise _SpockValueError(
f"Object is not a @spock decorated class object -- "
f"currently `{type(val)}`"
)
obj_dict.update({type(val).__name__: val})
elif _is_spock_instance(objs):
obj_dict = {type(objs).__name__: objs}
else:
raise _SpockValueError(
f"Object is not a @spock decorated class object -- "
f"currently `{type(objs)}`"
)
return saver.dict_payload(Spockspace(**obj_dict))
5 changes: 0 additions & 5 deletions spock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,6 @@ def _get_callable_type():
return _VariadicGenericAlias


def _get_new_type():

pass


_SpockGenericAlias = _get_alias_type()
_SpockVariadicGenericAlias = _get_callable_type()
_T = TypeVar("_T")
Expand Down
11 changes: 10 additions & 1 deletion tests/base/test_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ class FooBar:
val: int = 12


@spock
class OtherRef:
other_str: str = "yes"


@spock
class RefClass:
a_float: float = 12.1
Expand Down Expand Up @@ -113,6 +118,9 @@ class RefClassDefault:
ref_nested_to_str: str = "${spock.var:FooBar.val}.${spock.var:Lastly.tester}"
ref_nested_to_float: float = "${spock.var:FooBar.val}.${spock.var:Lastly.tester}"
ref_self: float = "${spock.var:RefClassDefault.ref_float}"
ref_self_nested: str = (
"${spock.var:RefClassDefault.ref_string}-${spock.var:OtherRef.other_str}"
)


class TestRefResolver:
Expand Down Expand Up @@ -162,7 +170,7 @@ def test_from_def(self, monkeypatch):
with monkeypatch.context() as m:
m.setattr(sys, "argv", [""])
config = SpockBuilder(
RefClassDefault, RefClass, Lastly, BarFoo, FooBar
RefClassDefault, RefClass, Lastly, BarFoo, FooBar, OtherRef
).generate()

assert config.RefClassDefault.ref_float == 12.1
Expand All @@ -172,6 +180,7 @@ def test_from_def(self, monkeypatch):
assert config.RefClassDefault.ref_nested_to_str == "12.1"
assert config.RefClassDefault.ref_nested_to_float == 12.1
assert config.RefClassDefault.ref_self == config.RefClassDefault.ref_float
assert config.RefClassDefault.ref_self_nested == "helloo-yes"


@spock
Expand Down

0 comments on commit af305f4

Please sign in to comment.