diff --git a/docs/api/docgen.py b/docs/api/docgen.py index 7a4408b..b3a9336 100644 --- a/docs/api/docgen.py +++ b/docs/api/docgen.py @@ -59,12 +59,11 @@ 'hyper', 'io', 'patching', - 'object_utils', + 'utils', 'symbolic', 'views', 'tuning', 'typing', - # Ext. 'early_stopping', 'evolution', @@ -73,7 +72,6 @@ 'evolution.recombinators', 'mutfun', 'scalars', - # generators. 'generators', }) diff --git a/docs/conf.py b/docs/conf.py index 3809d3a..ebcdbe6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -232,7 +232,7 @@ def noramlized_title_and_target(api, attr_name=None): exclude_patterns = [ # Temporarily disable sources for faster sphinx-build # 'api/core/symbolic', - # 'api/core/object_utils', + # 'api/core/utils', # 'api/core/typing', # 'api/core/detouring', # 'api/core/wrapping', diff --git a/docs/learn/soop/som/definition.rst b/docs/learn/soop/som/definition.rst index c9f0216..0c98ccf 100644 --- a/docs/learn/soop/som/definition.rst +++ b/docs/learn/soop/som/definition.rst @@ -73,7 +73,7 @@ Symbolic Tree A symbolic attribute is usually a value of simple type (e.g. :class:`int`, :class:`str``) or a reference to another symbolic object. When multiple symbolic objects are linked together in this way, it forms a tree-like structure of symbols. Each object in the tree is identified by a unique -combination of keys, which we call a *key path* (:class:`pg.KeyPath `). +combination of keys, which we call a *key path* (:class:`pg.KeyPath `). Key paths can be used to navigate and manipulate specific nodes in the symbolic tree:: @pg.symbolize diff --git a/docs/learn/soop/som/events.rst b/docs/learn/soop/som/events.rst index 6b11d0e..cb5672d 100644 --- a/docs/learn/soop/som/events.rst +++ b/docs/learn/soop/som/events.rst @@ -93,7 +93,7 @@ which is triggered when any of the symbolic arguments are updated via self._z = self.x ** 2 The ``_on_change`` event takes a ``updates`` argument, which is a dict of -:class:`pg.KeyPath ` to +:class:`pg.KeyPath ` to :class:`pg.FieldUpdate ` objects in case the user want to cherrypick the internal states to recompute based on the updates. For example:: diff --git a/docs/learn/soop/som/operations.rst b/docs/learn/soop/som/operations.rst index cf40333..0fd4cdd 100644 --- a/docs/learn/soop/som/operations.rst +++ b/docs/learn/soop/som/operations.rst @@ -195,7 +195,7 @@ Location ======== Each symbolic object has a unique location within a symbolic tree, represented a key path -(:class:`pg.KeyPath `), which is a path consists of the keys +(:class:`pg.KeyPath `), which is a path consists of the keys from the root node to the current node. For example, ``a.b[0].c`` is a path with height 4: @@ -331,7 +331,7 @@ human-readable format can be shown during debugging: * ``__str__`` formats a symbolic tree into a multi-line string representation, which is usually used in debugging purposes. -Both of these methods are based on :func:`pg.format `, which provides a +Both of these methods are based on :func:`pg.format `, which provides a rich set of features for formatting symbolic trees. For example, exclude the keys that have the default values from the string representation:: @@ -506,7 +506,7 @@ objects that are merely representations. Here is a summary of operations that de - :meth:`~pyglove.symbolic.Symbolic.sym_abstract` - Test whether an object is abstract or not. - * - :func:`pg.is_partial ` + * - :func:`pg.is_partial ` - :meth:`~pyglove.symbolic.Symbolic.sym_partial` - Test whether an object is partial or not. diff --git a/docs/learn/soop/som/validation.rst b/docs/learn/soop/som/validation.rst index 2c8f698..789f580 100644 --- a/docs/learn/soop/som/validation.rst +++ b/docs/learn/soop/som/validation.rst @@ -220,4 +220,4 @@ By default, PyGlove registered converters between the following pairs: - :class:`datetime.datetime` * - :class:`str` - - :class:`pg.KeyPath ` + - :class:`pg.KeyPath ` diff --git a/pyglove/core/__init__.py b/pyglove/core/__init__.py index 71132c8..89ce7bd 100644 --- a/pyglove/core/__init__.py +++ b/pyglove/core/__init__.py @@ -37,8 +37,7 @@ |__ tuning : Interface for program tuning with a local backend. |__ detouring : Detouring classes creation without symbolic types. |__ patching : Patching a program with URL-like strings. - |__ object_utils : Utility libary on operating with Python objects. - + |__ utils : Utility libary on operating with Python objects. """ # NOTE(daiyip): We disable bad-import-order to preserve the relation of @@ -273,31 +272,35 @@ # -# Symbols from 'object_utils' sub-module. +# Symbols from 'utils' sub-module. # -from pyglove.core import object_utils -KeyPath = object_utils.KeyPath -KeyPathSet = object_utils.KeyPathSet -MISSING_VALUE = object_utils.MISSING_VALUE +from pyglove.core import utils + +# For backward compatibility. +object_utils = utils + +KeyPath = utils.KeyPath +KeyPathSet = utils.KeyPathSet +MISSING_VALUE = utils.MISSING_VALUE -Formattable = object_utils.Formattable -repr_format = object_utils.repr_format -str_format = object_utils.str_format +Formattable = utils.Formattable +repr_format = utils.repr_format +str_format = utils.str_format -MaybePartial = object_utils.MaybePartial -JSONConvertible = object_utils.JSONConvertible -DocStr = object_utils.DocStr +MaybePartial = utils.MaybePartial +JSONConvertible = utils.JSONConvertible +DocStr = utils.DocStr -registered_types = object_utils.registered_types -explicit_method_override = object_utils.explicit_method_override +registered_types = utils.registered_types +explicit_method_override = utils.explicit_method_override -is_partial = object_utils.is_partial -format = object_utils.format # pylint: disable=redefined-builtin -print = object_utils.print # pylint: disable=redefined-builtin -docstr = object_utils.docstr -catch_errors = object_utils.catch_errors -timeit = object_utils.timeit +is_partial = utils.is_partial +format = utils.format # pylint: disable=redefined-builtin +print = utils.print # pylint: disable=redefined-builtin +docstr = utils.docstr +catch_errors = utils.catch_errors +timeit = utils.timeit # Symbols from 'views' sub-module. diff --git a/pyglove/core/geno/base.py b/pyglove/core/geno/base.py index 54004ef..f1f9c55 100644 --- a/pyglove/core/geno/base.py +++ b/pyglove/core/geno/base.py @@ -19,9 +19,9 @@ import types from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils class AttributeDict(dict): @@ -37,12 +37,20 @@ def __setattr__(self, key: Any, value: Any) -> None: @symbolic.members([ - ('location', - pg_typing.Object(object_utils.KeyPath, default=object_utils.KeyPath()), - ('KeyPath of associated genetic encoder relative to parent object ' - 'template. This allows DNA generator to apply rule based on locations.')), - ('hints', - pg_typing.Any(default=None), 'Hints for DNA generator to consume.') + ( + 'location', + pg_typing.Object(utils.KeyPath, default=utils.KeyPath()), + ( + 'KeyPath of associated genetic encoder relative to parent object' + ' template. This allows DNA generator to apply rule based on' + ' locations.' + ), + ), + ( + 'hints', + pg_typing.Any(default=None), + 'Hints for DNA generator to consume.', + ), ]) class DNASpec(symbolic.Object): """Base class for DNA specifications (genotypes). @@ -175,7 +183,7 @@ def decision_points(self) -> List['DecisionPoint']: """Returns all decision points in their declaration order.""" @property - def decision_ids(self) -> List[object_utils.KeyPath]: + def decision_ids(self) -> List[utils.KeyPath]: """Returns decision IDs.""" return list(self._decision_point_by_id.keys()) @@ -286,7 +294,7 @@ def parent_choice(self) -> Optional['DecisionPoint']: return self.parent_spec if self.is_space else self.parent_spec.parent_choice @property - def id(self) -> object_utils.KeyPath: + def id(self) -> utils.KeyPath: """Returns a path of locations from the root as the ID for current node.""" if self._id is None: parent = self.parent_spec @@ -295,18 +303,20 @@ def id(self) -> object_utils.KeyPath: elif self.is_space: assert parent.is_categorical, parent assert self.index is not None - self._id = object_utils.KeyPath( - ConditionalKey(self.index, len(parent.candidates)), - parent.id) + self.location + self._id = ( + utils.KeyPath( + ConditionalKey(self.index, len(parent.candidates)), parent.id + ) + + self.location + ) else: # Float() or a multi-choice spec of a parent Choice. self._id = parent.id + self.location return self._id - def get(self, - name_or_id: Union[object_utils.KeyPath, str], - default: Any = None - ) -> Union['DecisionPoint', List['DecisionPoint']]: + def get( + self, name_or_id: Union[utils.KeyPath, str], default: Any = None + ) -> Union['DecisionPoint', List['DecisionPoint']]: """Get decision point(s) by name or ID.""" try: return self[name_or_id] @@ -314,9 +324,8 @@ def get(self, return default def __getitem__( - self, - name_or_id: Union[object_utils.KeyPath, str] - ) -> Union['DecisionPoint', List['DecisionPoint']]: + self, name_or_id: Union[utils.KeyPath, str] + ) -> Union['DecisionPoint', List['DecisionPoint']]: """Get decision point(s) by name or ID .""" v = self._named_decision_points.get(name_or_id, None) if v is None: @@ -475,7 +484,7 @@ class DNA(symbolic.Object): # Allow assignment on symbolic attributes. allow_symbolic_assignment = True - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__( self, value: Union[None, int, float, str, List[Any], Tuple[Any]] = None, @@ -727,7 +736,7 @@ def _decision_by_id(self): return self._decision_by_id_cache @property - def decision_ids(self) -> List[object_utils.KeyPath]: + def decision_ids(self) -> List[utils.KeyPath]: """Returns decision IDs.""" self._ensure_dna_spec() return self._spec.decision_ids @@ -1249,9 +1258,11 @@ def _bind_decisions(dna_spec): return dna def to_numbers( - self, flatten: bool = True, - ) -> Union[List[Union[int, float, str]], - object_utils.Nestable[Union[int, float, str]]]: + self, + flatten: bool = True, + ) -> Union[ + List[Union[int, float, str]], utils.Nestable[Union[int, float, str]] + ]: """Returns a (maybe) nested structure of numbers as decisions. Args: @@ -1338,7 +1349,7 @@ def from_fn( f'Location: {dna_spec.location.path}.') children = [] for i, choice in enumerate(decision): - choice_location = object_utils.KeyPath(i, dna_spec.location) + choice_location = utils.KeyPath(i, dna_spec.location) if not isinstance(choice, int): raise ValueError( f'Choice value should be int. Encountered: {choice}, ' @@ -1410,7 +1421,7 @@ def sym_jsonify( if type_info: json_value = { - object_utils.JSONConvertible.TYPE_NAME_KEY: ( + utils.JSONConvertible.TYPE_NAME_KEY: ( self.__class__.__serialization_key__ ), 'format': 'compact', @@ -1435,7 +1446,8 @@ def from_json( json_value: Dict[str, Any], *, allow_partial: bool = False, - root_path: Optional[object_utils.KeyPath] = None) -> 'DNA': + root_path: Optional[utils.KeyPath] = None, + ) -> 'DNA': """Class method that load a DNA from a JSON value. Args: @@ -1472,8 +1484,8 @@ def is_leaf(self) -> bool: return not self.children def __getitem__( - self, key: Union[int, slice, str, object_utils.KeyPath, 'DecisionPoint'] - ) -> Union[None, 'DNA', List[Optional['DNA']]]: + self, key: Union[int, slice, str, utils.KeyPath, 'DecisionPoint'] + ) -> Union[None, 'DNA', List[Optional['DNA']]]: """Get an immediate child DNA or DNA in the sub-tree. Args: @@ -1504,10 +1516,11 @@ def __getitem__( v = self._decision_by_id[key] return v - def get(self, - key: Union[int, slice, str, object_utils.KeyPath, 'DecisionPoint'], - default: Any = None - ) -> Union[Any, None, 'DNA', List[Optional['DNA']]]: + def get( + self, + key: Union[int, slice, str, utils.KeyPath, 'DecisionPoint'], + default: Any = None, + ) -> Union[Any, None, 'DNA', List[Optional['DNA']]]: """Get an immediate child DNA or DNA in the sub-tree.""" try: return self[key] @@ -1529,8 +1542,9 @@ def __contains__(self, dna_or_value: Union[int, 'DNA']) -> bool: return True else: raise ValueError( - f'DNA.__contains__ does not accept ' - f'{object_utils.quote_if_str(dna_or_value)!r}.') + 'DNA.__contains__ does not accept ' + f'{utils.quote_if_str(dna_or_value)!r}.' + ) return False def __hash__(self): @@ -1605,12 +1619,13 @@ def format( ): """Customize format method for DNA for more compact representation.""" if as_dict and self.spec: - details = object_utils.format( + details = utils.format( self.to_dict(value_type='choice_and_literal'), False, verbose, root_indent, - **kwargs) + **kwargs, + ) s = f'DNA({details})' else: if 'list_wrap_threshold' not in kwargs: @@ -1621,7 +1636,7 @@ def format( elif self.is_leaf: s = f'DNA({self.value!r})' else: - rep = object_utils.format( + rep = utils.format( self.to_json(compact=True, type_info=False), compact, verbose, diff --git a/pyglove/core/geno/base_test.py b/pyglove/core/geno/base_test.py index 0e18253..13f1a97 100644 --- a/pyglove/core/geno/base_test.py +++ b/pyglove/core/geno/base_test.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.geno.DNA.""" - import unittest -from pyglove.core import object_utils from pyglove.core import symbolic +from pyglove.core import utils from pyglove.core.geno.base import ConditionalKey from pyglove.core.geno.base import DNA from pyglove.core.geno.categorical import manyof @@ -1144,7 +1142,7 @@ def test_basics(self): self.assertEqual(key.num_choices, 5) def test_to_str(self): - key = object_utils.KeyPath(['a', ConditionalKey(1, 5), 'b']) + key = utils.KeyPath(['a', ConditionalKey(1, 5), 'b']) self.assertEqual(str(key), 'a[=1/5].b') diff --git a/pyglove/core/geno/categorical.py b/pyglove/core/geno/categorical.py index 5363084..d2166b3 100644 --- a/pyglove/core/geno/categorical.py +++ b/pyglove/core/geno/categorical.py @@ -18,10 +18,9 @@ import types from typing import Any, List, Optional, Union -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing - +from pyglove.core import utils from pyglove.core.geno.base import DecisionPoint from pyglove.core.geno.base import DNA from pyglove.core.geno.base import DNASpec @@ -139,14 +138,15 @@ def _on_bound(self): for i in range(self.num_choices): subchoice_spec = Choices( subchoice_index=i, - location=object_utils.KeyPath(i), + location=utils.KeyPath(i), num_choices=1, candidates=self.candidates, literal_values=self.literal_values, distinct=self.distinct, sorted=self.sorted, name=self.name, - hints=self.hints) + hints=self.hints, + ) self._decision_points.extend(subchoice_spec.decision_points) subchoice_specs.append(subchoice_spec) self._subchoice_specs = symbolic.List(subchoice_specs) @@ -158,7 +158,8 @@ def _on_bound(self): self._decision_points.extend(c.decision_points) def _update_children_paths( - self, old_path: object_utils.KeyPath, new_path: object_utils.KeyPath): + self, old_path: utils.KeyPath, new_path: utils.KeyPath + ): """Trigger path change for subchoices so their IDs can be invalidated.""" super()._update_children_paths(old_path, new_path) if self._subchoice_specs: @@ -337,7 +338,7 @@ def validate(self, dna: DNA) -> None: f'DNA child values should be sorted. ' f'Encountered: {sub_dna_values}, Location: {self.location.path}.') for i, sub_dna in enumerate(dna): - sub_location = object_utils.KeyPath(i, self.location) + sub_location = utils.KeyPath(i, self.location) if not isinstance(sub_dna.value, int): raise ValueError( f'Choice value should be int. Encountered: {sub_dna.value}, ' @@ -601,13 +602,18 @@ def _indent(text, indent): kvlist = [('id', str(self.id), '\'\'')] else: kvlist = [] - additionl_properties = object_utils.kvlist_str(kvlist + [ - ('name', self.name, None), - ('distinct', self.distinct, True), - ('sorted', self.sorted, False), - ('hints', self.hints, None), - ('subchoice_index', self.subchoice_index, None) - ], compact=False, root_indent=root_indent) + additionl_properties = utils.kvlist_str( + kvlist + + [ + ('name', self.name, None), + ('distinct', self.distinct, True), + ('sorted', self.sorted, False), + ('hints', self.hints, None), + ('subchoice_index', self.subchoice_index, None), + ], + compact=False, + root_indent=root_indent, + ) if additionl_properties: s.append(', ') s.append(additionl_properties) @@ -615,14 +621,16 @@ def _indent(text, indent): return ''.join(s) -def manyof(num_choices: int, - candidates: List[DNASpec], - distinct: bool = True, - sorted: bool = False, # pylint: disable=redefined-builtin - literal_values: Optional[List[Union[str, int, float]]] = None, - hints: Any = None, - location: Union[str, object_utils.KeyPath] = object_utils.KeyPath(), - name: Optional[str] = None) -> Choices: +def manyof( + num_choices: int, + candidates: List[DNASpec], + distinct: bool = True, + sorted: bool = False, # pylint: disable=redefined-builtin + literal_values: Optional[List[Union[str, int, float]]] = None, + hints: Any = None, + location: Union[str, utils.KeyPath] = utils.KeyPath(), + name: Optional[str] = None, +) -> Choices: """Returns a multi-choice specification. It creates the genotype for :func:`pyglove.manyof`. @@ -674,11 +682,13 @@ def manyof(num_choices: int, hints=hints, location=location, name=name) -def oneof(candidates: List[DNASpec], - literal_values: Optional[List[Union[str, int, float]]] = None, - hints: Any = None, - location: Union[str, object_utils.KeyPath] = object_utils.KeyPath(), - name: Optional[str] = None) -> Choices: +def oneof( + candidates: List[DNASpec], + literal_values: Optional[List[Union[str, int, float]]] = None, + hints: Any = None, + location: Union[str, utils.KeyPath] = utils.KeyPath(), + name: Optional[str] = None, +) -> Choices: """Returns a single choice specification. It creates the genotype for :func:`pyglove.oneof`. @@ -716,4 +726,3 @@ def oneof(candidates: List[DNASpec], """ return manyof(1, candidates, literal_values=literal_values, hints=hints, location=location, name=name) - diff --git a/pyglove/core/geno/custom.py b/pyglove/core/geno/custom.py index 238a692..6016031 100644 --- a/pyglove/core/geno/custom.py +++ b/pyglove/core/geno/custom.py @@ -17,10 +17,9 @@ import types from typing import Any, Callable, List, Optional, Union -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing - +from pyglove.core import utils from pyglove.core.geno.base import DecisionPoint from pyglove.core.geno.base import DNA @@ -125,7 +124,7 @@ def validate(self, dna: DNA) -> None: f'CustomDecisionPoint expects string type DNA. ' f'Encountered: {dna!r}, Location: {self.location.path}.') - def sym_jsonify(self, **kwargs: Any) -> object_utils.JSONValueType: + def sym_jsonify(self, **kwargs: Any) -> utils.JSONValueType: """Overrides sym_jsonify to exclude non-serializable fields.""" exclude_keys = kwargs.pop('exclude_keys', []) exclude_keys.extend(['random_dna_fn', 'next_dna_fn']) @@ -153,21 +152,25 @@ def format(self, kvlist = [('id', str(self.id), '\'\'')] else: kvlist = [] - details = object_utils.kvlist_str(kvlist + [ - ('hyper_type', self.hyper_type, None), - ('name', self.name, None), - ('hints', self.hints, None), - ]) + details = utils.kvlist_str( + kvlist + + [ + ('hyper_type', self.hyper_type, None), + ('name', self.name, None), + ('hints', self.hints, None), + ] + ) return f'{self.__class__.__name__}({details})' -def custom(hyper_type: Optional[str] = None, - next_dna_fn: Optional[ - Callable[[Optional[DNA]], Optional[DNA]]] = None, - random_dna_fn: Optional[Callable[[Any], DNA]] = None, - hints: Any = None, - location: object_utils.KeyPath = object_utils.KeyPath(), - name: Optional[str] = None) -> CustomDecisionPoint: +def custom( + hyper_type: Optional[str] = None, + next_dna_fn: Optional[Callable[[Optional[DNA]], Optional[DNA]]] = None, + random_dna_fn: Optional[Callable[[Any], DNA]] = None, + hints: Any = None, + location: utils.KeyPath = utils.KeyPath(), + name: Optional[str] = None, +) -> CustomDecisionPoint: """Returns a custom decision point. It creates the genotype for subclasses of :func:`pyglove.hyper.CustomHyper`. diff --git a/pyglove/core/geno/numerical.py b/pyglove/core/geno/numerical.py index 270b7f3..eff345d 100644 --- a/pyglove/core/geno/numerical.py +++ b/pyglove/core/geno/numerical.py @@ -17,10 +17,9 @@ import types from typing import Any, List, Optional, Union -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing - +from pyglove.core import utils from pyglove.core.geno.base import DecisionPoint from pyglove.core.geno.base import DNA @@ -167,22 +166,27 @@ def format(self, kvlist = [('id', str(self.id), '\'\'')] else: kvlist = [] - details = object_utils.kvlist_str(kvlist + [ - ('name', self.name, None), - ('min_value', self.min_value, None), - ('max_value', self.max_value, None), - ('scale', self.scale, None), - ('hints', self.hints, None), - ]) + details = utils.kvlist_str( + kvlist + + [ + ('name', self.name, None), + ('min_value', self.min_value, None), + ('max_value', self.max_value, None), + ('scale', self.scale, None), + ('hints', self.hints, None), + ] + ) return f'{self.__class__.__name__}({details})' -def floatv(min_value: float, - max_value: float, - scale: Optional[str] = None, - hints: Any = None, - location: object_utils.KeyPath = object_utils.KeyPath(), - name: Optional[str] = None) -> Float: +def floatv( + min_value: float, + max_value: float, + scale: Optional[str] = None, + hints: Any = None, + location: utils.KeyPath = utils.KeyPath(), + name: Optional[str] = None, +) -> Float: """Returns a Float specification. It creates the genotype for :func:`pyglove.floatv`. @@ -226,4 +230,3 @@ def floatv(min_value: float, """ return Float(min_value, max_value, scale, hints=hints, location=location, name=name) - diff --git a/pyglove/core/geno/space.py b/pyglove/core/geno/space.py index 559ff48..5bd8a03 100644 --- a/pyglove/core/geno/space.py +++ b/pyglove/core/geno/space.py @@ -18,10 +18,9 @@ import types from typing import List, Optional, Union -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing - +from pyglove.core import utils from pyglove.core.geno.base import DecisionPoint from pyglove.core.geno.base import DNA from pyglove.core.geno.base import DNASpec @@ -231,8 +230,8 @@ def __len__(self) -> int: return sum([len(elem) for elem in self.elements]) def __getitem__( - self, index: Union[int, slice, str, object_utils.KeyPath] - ) -> Union[DecisionPoint, List[DecisionPoint]]: + self, index: Union[int, slice, str, utils.KeyPath] + ) -> Union[DecisionPoint, List[DecisionPoint]]: """Operator [] to return element by index or sub-DNASpec by name.""" if isinstance(index, (int, slice)): return self.elements[index] diff --git a/pyglove/core/hyper/base.py b/pyglove/core/hyper/base.py index 2522676..4a0981f 100644 --- a/pyglove/core/hyper/base.py +++ b/pyglove/core/hyper/base.py @@ -18,9 +18,9 @@ from typing import Any, Callable, Optional from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils class HyperValue(symbolic.NonDeterministic): # pytype: disable=ignored-metaclass @@ -110,8 +110,7 @@ def encode(self, value: Any) -> geno.DNA: """ @abc.abstractmethod - def dna_spec(self, - location: Optional[object_utils.KeyPath] = None) -> geno.DNASpec: + def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.DNASpec: """Get DNA spec of DNA that is decodable/encodable by this hyper value.""" @@ -186,12 +185,13 @@ def set_dynamic_evaluate_fn( global _global_dynamic_evaluate_fn if per_thread: assert _global_dynamic_evaluate_fn is None, _global_dynamic_evaluate_fn - object_utils.thread_local_set(_TLS_KEY_DYNAMIC_EVALUATE_FN, fn) + utils.thread_local_set(_TLS_KEY_DYNAMIC_EVALUATE_FN, fn) else: _global_dynamic_evaluate_fn = fn def get_dynamic_evaluate_fn() -> Optional[Callable[[HyperValue], Any]]: """Gets current dynamic evaluate function.""" - return object_utils.thread_local_get( - _TLS_KEY_DYNAMIC_EVALUATE_FN, _global_dynamic_evaluate_fn) + return utils.thread_local_get( + _TLS_KEY_DYNAMIC_EVALUATE_FN, _global_dynamic_evaluate_fn + ) diff --git a/pyglove/core/hyper/categorical.py b/pyglove/core/hyper/categorical.py index 3deead9..3b2295e 100644 --- a/pyglove/core/hyper/categorical.py +++ b/pyglove/core/hyper/categorical.py @@ -18,9 +18,9 @@ from typing import Any, Callable, Iterable, List, Optional, Tuple, Union from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.hyper import base from pyglove.core.hyper import object_template @@ -85,7 +85,8 @@ def _on_bound(self): self._value_spec = None def _update_children_paths( - self, old_path: object_utils.KeyPath, new_path: object_utils.KeyPath): + self, old_path: utils.KeyPath, new_path: utils.KeyPath + ): """Customized logic to update children paths.""" super()._update_children_paths(old_path, new_path) for t in self._candidate_templates: @@ -104,19 +105,20 @@ def is_leaf(self) -> bool: return False return True - def dna_spec(self, - location: Optional[object_utils.KeyPath] = None) -> geno.Choices: + def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.Choices: """Returns corresponding DNASpec.""" return geno.Choices( num_choices=self.num_choices, candidates=[ct.dna_spec() for ct in self._candidate_templates], - literal_values=[self._literal_value(c) - for i, c in enumerate(self.candidates)], + literal_values=[ + self._literal_value(c) for i, c in enumerate(self.candidates) + ], distinct=self.choices_distinct, sorted=self.choices_sorted, hints=self.hints, name=self.name, - location=location or object_utils.KeyPath()) + location=location or utils.KeyPath(), + ) def _literal_value( self, candidate: Any, max_len: int = 120) -> Union[int, float, str]: @@ -124,10 +126,13 @@ def _literal_value( if isinstance(candidate, numbers.Number): return candidate - literal = object_utils.format(candidate, compact=True, - hide_default_values=True, - hide_missing_values=True, - strip_object_id=True) + literal = utils.format( + candidate, + compact=True, + hide_default_values=True, + hide_missing_values=True, + strip_object_id=True, + ) if len(literal) > max_len: literal = literal[:max_len - 3] + '...' return literal @@ -139,52 +144,70 @@ def _decode(self) -> List[Any]: # Single choice. if not isinstance(dna.value, int): raise ValueError( - object_utils.message_on_path( - f'Did you forget to specify values for conditional choices?\n' + utils.message_on_path( + 'Did you forget to specify values for conditional choices?\n' f'Expect integer for {self.__class__.__name__}. ' - f'Encountered: {dna!r}.', self.sym_path)) + f'Encountered: {dna!r}.', + self.sym_path, + ) + ) if dna.value >= len(self.candidates): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Choice out of range. Value: {dna.value!r}, ' - f'Candidates: {len(self.candidates)}.', self.sym_path)) + f'Candidates: {len(self.candidates)}.', + self.sym_path, + ) + ) choices = [self._candidate_templates[dna.value].decode( geno.DNA(None, dna.children))] else: # Multi choices. if len(dna.children) != self.num_choices: raise ValueError( - object_utils.message_on_path( - f'Number of DNA child values does not match the number of ' + utils.message_on_path( + 'Number of DNA child values does not match the number of ' f'choices. Child values: {dna.children!r}, ' - f'Choices: {self.num_choices}.', self.sym_path)) + f'Choices: {self.num_choices}.', + self.sym_path, + ) + ) if self.choices_distinct or self.choices_sorted: sub_dna_values = [s.value for s in dna] if (self.choices_distinct and len(set(sub_dna_values)) != len(dna.children)): raise ValueError( - object_utils.message_on_path( - f'DNA child values should be distinct. ' - f'Encountered: {sub_dna_values}.', self.sym_path)) + utils.message_on_path( + 'DNA child values should be distinct. ' + f'Encountered: {sub_dna_values}.', + self.sym_path, + ) + ) if self.choices_sorted and sorted(sub_dna_values) != sub_dna_values: raise ValueError( - object_utils.message_on_path( - f'DNA child values should be sorted. ' - f'Encountered: {sub_dna_values}.', self.sym_path)) + utils.message_on_path( + 'DNA child values should be sorted. ' + f'Encountered: {sub_dna_values}.', + self.sym_path, + ) + ) choices = [] for i, sub_dna in enumerate(dna): if not isinstance(sub_dna.value, int): raise ValueError( - object_utils.message_on_path( - f'Choice value should be int. ' - f'Encountered: {sub_dna.value}.', - object_utils.KeyPath(i, self.sym_path))) + utils.message_on_path( + f'Choice value should be int. Encountered: {sub_dna.value}.', + utils.KeyPath(i, self.sym_path), + ) + ) if sub_dna.value >= len(self.candidates): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Choice out of range. Value: {sub_dna.value}, ' f'Candidates: {len(self.candidates)}.', - object_utils.KeyPath(i, self.sym_path))) + utils.KeyPath(i, self.sym_path), + ) + ) choices.append(self._candidate_templates[sub_dna.value].decode( geno.DNA(None, sub_dna.children))) return choices @@ -240,15 +263,21 @@ def encode(self, value: List[Any]) -> geno.DNA: """ if not isinstance(value, list): raise ValueError( - object_utils.message_on_path( - f'Cannot encode value: value should be a list type. ' - f'Encountered: {value!r}.', self.sym_path)) + utils.message_on_path( + 'Cannot encode value: value should be a list type. ' + f'Encountered: {value!r}.', + self.sym_path, + ) + ) choices = [] if self.num_choices is not None and len(value) != self.num_choices: raise ValueError( - object_utils.message_on_path( - f'Length of input list is different from the number of choices ' - f'({self.num_choices}). Encountered: {value}.', self.sym_path)) + utils.message_on_path( + 'Length of input list is different from the number of choices ' + f'({self.num_choices}). Encountered: {value}.', + self.sym_path, + ) + ) for v in value: choice_id = None child_dna = None @@ -259,10 +288,12 @@ def encode(self, value: List[Any]) -> geno.DNA: break if child_dna is None: raise ValueError( - object_utils.message_on_path( - f'Cannot encode value: no candidates matches with ' + utils.message_on_path( + 'Cannot encode value: no candidates matches with ' f'the value. Value: {v!r}, Candidates: {self.candidates}.', - self.sym_path)) + self.sym_path, + ) + ) choices.append(geno.DNA(choice_id, [child_dna])) return geno.DNA(None, choices) @@ -313,12 +344,13 @@ class ManyOf(Choices): def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None - ) -> Tuple[bool, 'Choices']: + child_transform: Optional[ + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, + ) -> Tuple[bool, 'Choices']: """Validate candidates during value_spec binding time.""" # Check if value_spec directly accepts `self`. if value_spec.value_type and isinstance(self, value_spec.value_type): @@ -329,10 +361,12 @@ def custom_apply( dest_spec = value_spec if not dest_spec.is_compatible(src_spec): raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Cannot bind an incompatible value spec {dest_spec!r} ' f'to {self.__class__.__name__} with bound spec {src_spec!r}.', - path)) + path, + ) + ) return (False, self) list_spec = typing.cast( @@ -399,12 +433,13 @@ def encode(self, value: Any) -> geno.DNA: def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None - ) -> Tuple[bool, 'OneOf']: + child_transform: Optional[ + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, + ) -> Tuple[bool, 'OneOf']: """Validate candidates during value_spec binding time.""" # Check if value_spec directly accepts `self`. if value_spec.value_type and isinstance(self, value_spec.value_type): @@ -413,10 +448,13 @@ def custom_apply( if self._value_spec: if not value_spec.is_compatible(self._value_spec): raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Cannot bind an incompatible value spec {value_spec!r} ' f'to {self.__class__.__name__} with bound ' - f'spec {self._value_spec!r}.', path)) + f'spec {self._value_spec!r}.', + path, + ) + ) return (False, self) for i, c in enumerate(self.candidates): @@ -427,6 +465,7 @@ def custom_apply( self._value_spec = value_spec return (False, self) + # # Helper methods for creating hyper values. # diff --git a/pyglove/core/hyper/custom.py b/pyglove/core/hyper/custom.py index acf1543..4301ed3 100644 --- a/pyglove/core/hyper/custom.py +++ b/pyglove/core/hyper/custom.py @@ -19,8 +19,8 @@ from typing import Any, Callable, Optional, Tuple, Union from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.hyper import base @@ -111,8 +111,7 @@ def custom_encode(self, value: Any) -> geno.DNA: raise NotImplementedError( f'\'custom_encode\' is not supported by {self.__class__.__name__!r}.') - def dna_spec( - self, location: Optional[object_utils.KeyPath] = None) -> geno.DNASpec: + def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.DNASpec: """Always returns CustomDecisionPoint for CustomHyper.""" return geno.CustomDecisionPoint( hyper_type=self.__class__.__name__, @@ -147,12 +146,13 @@ def random_dna( def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None - ) -> Tuple[bool, 'CustomHyper']: + child_transform: Optional[ + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, + ) -> Tuple[bool, 'CustomHyper']: """Validate candidates during value_spec binding time.""" del path, value_spec, allow_partial, child_transform # Allow custom hyper to be assigned to any type. diff --git a/pyglove/core/hyper/custom_test.py b/pyglove/core/hyper/custom_test.py index fd32fbb..1a80e69 100644 --- a/pyglove/core/hyper/custom_test.py +++ b/pyglove/core/hyper/custom_test.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.hyper.CustomHyper.""" - import random import unittest from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import symbolic - +from pyglove.core import utils from pyglove.core.hyper.categorical import oneof from pyglove.core.hyper.custom import CustomHyper from pyglove.core.hyper.iter import iterate @@ -58,12 +55,14 @@ class CustomHyperTest(unittest.TestCase): """Test for CustomHyper.""" def test_dna_spec(self): - self.assertTrue(symbolic.eq( - IntSequence(hints='x').dna_spec('a'), - geno.CustomDecisionPoint( - hyper_type='IntSequence', - location=object_utils.KeyPath('a'), - hints='x'))) + self.assertTrue( + symbolic.eq( + IntSequence(hints='x').dna_spec('a'), + geno.CustomDecisionPoint( + hyper_type='IntSequence', location=utils.KeyPath('a'), hints='x' + ), + ) + ) def test_decode(self): self.assertEqual(IntSequence().decode(geno.DNA('0,1,2')), [0, 1, 2]) diff --git a/pyglove/core/hyper/derived.py b/pyglove/core/hyper/derived.py index 277dc32..4234eab 100644 --- a/pyglove/core/hyper/derived.py +++ b/pyglove/core/hyper/derived.py @@ -17,16 +17,19 @@ import copy from typing import Any, Callable, List, Optional, Tuple, Union -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils -@symbolic.members([ - ('reference_paths', pg_typing.List(pg_typing.Object(object_utils.KeyPath)), - ('Paths of referenced values, which are relative paths searched from ' - 'current node to root.')) -]) +@symbolic.members([( + 'reference_paths', + pg_typing.List(pg_typing.Object(utils.KeyPath)), + ( + 'Paths of referenced values, which are relative paths searched from ' + 'current node to root.' + ), +)]) class DerivedValue(symbolic.Object, pg_typing.CustomTyping): """Base class of value that references to other values in object tree.""" @@ -36,8 +39,10 @@ def derive(self, *args: Any) -> Any: def resolve( self, reference_path_or_paths: Optional[Union[str, List[str]]] = None - ) -> Union[Tuple[symbolic.Symbolic, object_utils.KeyPath], - List[Tuple[symbolic.Symbolic, object_utils.KeyPath]]]: + ) -> Union[ + Tuple[symbolic.Symbolic, utils.KeyPath], + List[Tuple[symbolic.Symbolic, utils.KeyPath]], + ]: """Resolve reference paths based on the location of this node. Args: @@ -54,17 +59,17 @@ def resolve( if reference_path_or_paths is None: reference_paths = self.reference_paths elif isinstance(reference_path_or_paths, str): - reference_paths = [object_utils.KeyPath.parse(reference_path_or_paths)] + reference_paths = [utils.KeyPath.parse(reference_path_or_paths)] single_input = True - elif isinstance(reference_path_or_paths, object_utils.KeyPath): + elif isinstance(reference_path_or_paths, utils.KeyPath): reference_paths = [reference_path_or_paths] single_input = True elif isinstance(reference_path_or_paths, list): paths = [] for path in reference_path_or_paths: if isinstance(path, str): - path = object_utils.KeyPath.parse(path) - elif not isinstance(path, object_utils.KeyPath): + path = utils.KeyPath.parse(path) + elif not isinstance(path, utils.KeyPath): raise ValueError('Argument \'reference_path_or_paths\' must be None, ' 'a string, KeyPath object, a list of strings, or a ' 'list of KeyPath objects.') @@ -96,8 +101,7 @@ def __call__(self): # Make sure referenced value does not have referenced value. # NOTE(daiyip): We can support dependencies between derived values # in future if needed. - if not object_utils.traverse( - referenced_value, self._contains_not_derived_value): + if not utils.traverse(referenced_value, self._contains_not_derived_value): raise ValueError( f'Derived value (path={referenced_value.sym_path}) should not ' f'reference derived values. ' @@ -107,15 +111,18 @@ def __call__(self): return self.derive(*referenced_values) def _contains_not_derived_value( - self, path: object_utils.KeyPath, value: Any) -> bool: + self, path: utils.KeyPath, value: Any + ) -> bool: """Returns whether a value contains derived value.""" if isinstance(value, DerivedValue): return False elif isinstance(value, symbolic.Object): for k, v in value.sym_items(): - if not object_utils.traverse( - v, self._contains_not_derived_value, - root_path=object_utils.KeyPath(k, path)): + if not utils.traverse( + v, + self._contains_not_derived_value, + root_path=utils.KeyPath(k, path), + ): return False return True @@ -137,12 +144,13 @@ def derive(self, referenced_value: Any) -> Any: def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None - ) -> Tuple[bool, 'DerivedValue']: + child_transform: Optional[ + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, + ) -> Tuple[bool, 'DerivedValue']: """Implement pg_typing.CustomTyping interface.""" # TODO(daiyip): perform possible static analysis on referenced paths. del path, value_spec, allow_partial, child_transform diff --git a/pyglove/core/hyper/derived_test.py b/pyglove/core/hyper/derived_test.py index eedd233..90f8af2 100644 --- a/pyglove/core/hyper/derived_test.py +++ b/pyglove/core/hyper/derived_test.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.hyper.ValueReference.""" - import unittest -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.hyper.derived import ValueReference @@ -47,7 +45,7 @@ def test_resolve(self): self.assertEqual(sd.c[0].y.resolve(), [(sd.c[0], 'c[0].x[0].z')]) self.assertEqual(sd.c[1].y.resolve(), [(sd.c[1], 'c[1].x[0].z')]) # Resolve references from this point. - self.assertEqual(sd.c[0].y.resolve(object_utils.KeyPath(0)), (sd.c, 'c[0]')) + self.assertEqual(sd.c[0].y.resolve(utils.KeyPath(0)), (sd.c, 'c[0]')) self.assertEqual(sd.c[0].y.resolve('[0]'), (sd.c, 'c[0]')) self.assertEqual( sd.c[0].y.resolve(['[0]', '[1]']), [(sd.c, 'c[0]'), (sd.c, 'c[1]')]) diff --git a/pyglove/core/hyper/dynamic_evaluation.py b/pyglove/core/hyper/dynamic_evaluation.py index 6cbf33a..25994f8 100644 --- a/pyglove/core/hyper/dynamic_evaluation.py +++ b/pyglove/core/hyper/dynamic_evaluation.py @@ -18,9 +18,9 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Union from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.hyper import base from pyglove.core.hyper import categorical from pyglove.core.hyper import custom @@ -520,10 +520,10 @@ def ensure_thread_safety(self, context: DynamicEvaluationContext): @property def _local_stack(self): """Returns thread-local stack.""" - stack = object_utils.thread_local_get(self._TLS_KEY, None) + stack = utils.thread_local_get(self._TLS_KEY, None) if stack is None: stack = [] - object_utils.thread_local_set(self._TLS_KEY, stack) + utils.thread_local_set(self._TLS_KEY, stack) return stack def push(self, context: DynamicEvaluationContext): @@ -585,4 +585,3 @@ def trace( with context.collect(): fun() return context - diff --git a/pyglove/core/hyper/evolvable.py b/pyglove/core/hyper/evolvable.py index fca6aa1..2798ed2 100644 --- a/pyglove/core/hyper/evolvable.py +++ b/pyglove/core/hyper/evolvable.py @@ -20,9 +20,9 @@ from typing import Any, Callable, List, Optional, Tuple, Union from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.hyper import custom @@ -44,7 +44,7 @@ class MutationPoint: parent: The parent node of the mutation point. """ mutation_type: 'MutationType' - location: object_utils.KeyPath + location: utils.KeyPath old_value: Any parent: Optional[symbolic.Symbolic] @@ -71,9 +71,9 @@ def mutation_points_and_weights( mutation_points: List[MutationPoint] = [] mutation_weights: List[float] = [] - def _choose_mutation_point(k: object_utils.KeyPath, - v: Any, - p: Optional[symbolic.Symbolic]): + def _choose_mutation_point( + k: utils.KeyPath, v: Any, p: Optional[symbolic.Symbolic] + ): """Visiting function for a symbolic node.""" def _add_point(mt: MutationType, k=k, v=v, p=p): mutation_points.append(MutationPoint(mt, k, v, p)) @@ -98,10 +98,9 @@ def _add_point(mt: MutationType, k=k, v=v, p=p): reached_min_size = False for i, cv in enumerate(v): - ck = object_utils.KeyPath(i, parent=k) + ck = utils.KeyPath(i, parent=k) if not reached_max_size: - _add_point(MutationType.INSERT, - k=ck, v=object_utils.MISSING_VALUE, p=v) + _add_point(MutationType.INSERT, k=ck, v=utils.MISSING_VALUE, p=v) if not reached_min_size: _add_point(MutationType.DELETE, k=ck, v=cv, p=v) @@ -109,10 +108,12 @@ def _add_point(mt: MutationType, k=k, v=v, p=p): # Replace type and value will be added in traverse. symbolic.traverse(cv, _choose_mutation_point, root_path=ck, parent=v) if not reached_max_size and i == len(v) - 1: - _add_point(MutationType.INSERT, - k=object_utils.KeyPath(i + 1, parent=k), - v=object_utils.MISSING_VALUE, - p=v) + _add_point( + MutationType.INSERT, + k=utils.KeyPath(i + 1, parent=k), + v=utils.MISSING_VALUE, + p=v, + ) return symbolic.TraverseAction.CONTINUE return symbolic.TraverseAction.ENTER @@ -157,7 +158,7 @@ def mutate( point.location, point.old_value, point.parent) elif point.mutation_type == MutationType.INSERT: assert isinstance(point.parent, symbolic.List), point - assert point.old_value == object_utils.MISSING_VALUE, point + assert point.old_value == utils.MISSING_VALUE, point assert isinstance(point.location.key, int), point with symbolic.allow_writable_accessors(): point.parent.insert( @@ -175,24 +176,31 @@ def mutate( # We defer members declaration for Evolvable since the weights will reference # the definition of MutationType. symbolic.members([ - ('initial_value', pg_typing.Object(symbolic.Symbolic), - 'Symbolic value to involve.'), - ('node_transform', pg_typing.Callable( - [], - returns=pg_typing.Any()), - ''), - ('weights', pg_typing.Callable( - [ - pg_typing.Object(MutationType), - pg_typing.Object(object_utils.KeyPath), - pg_typing.Any().noneable(), - pg_typing.Object(symbolic.Symbolic) - ], returns=pg_typing.Float(min_value=0.0)).noneable(), - ('An optional callable object that returns the unnormalized (e.g. ' - 'the sum of all probabilities do not have to sum to 1.0) mutation ' - 'probabilities for all the nodes in the symbolic tree, based on ' - '(mutation type, location, old value, parent node). If None, all the ' - 'locations and mutation types will be sampled uniformly.')), + ( + 'initial_value', + pg_typing.Object(symbolic.Symbolic), + 'Symbolic value to involve.', + ), + ('node_transform', pg_typing.Callable([], returns=pg_typing.Any()), ''), + ( + 'weights', + pg_typing.Callable( + [ + pg_typing.Object(MutationType), + pg_typing.Object(utils.KeyPath), + pg_typing.Any().noneable(), + pg_typing.Object(symbolic.Symbolic), + ], + returns=pg_typing.Float(min_value=0.0), + ).noneable(), + ( + 'An optional callable object that returns the unnormalized (e.g.' + ' the sum of all probabilities do not have to sum to 1.0) mutation' + ' probabilities for all the nodes in the symbolic tree, based on' + ' (mutation type, location, old value, parent node). If None, all' + ' the locations and mutation types will be sampled uniformly.' + ), + ), ])(Evolvable) @@ -200,25 +208,28 @@ def evolve( initial_value: symbolic.Symbolic, node_transform: Callable[ [ - object_utils.KeyPath, # Location. - Any, # Old value. - # pg.MISSING_VALUE for insertion. - symbolic.Symbolic, # Parent node. + utils.KeyPath, # Location. + Any, # Old value. + # pg.MISSING_VALUE for insertion. + symbolic.Symbolic, # Parent node. ], - Any # Replacement. + Any, # Replacement. ], *, - weights: Optional[Callable[ - [ - MutationType, # Mutation type. - object_utils.KeyPath, # Location. - Any, # Value. - symbolic.Symbolic, # Parent. - ], - float # Mutation weight. - ]] = None, # pylint: disable=bad-whitespace + weights: Optional[ + Callable[ + [ + MutationType, # Mutation type. + utils.KeyPath, # Location. + Any, # Value. + symbolic.Symbolic, # Parent. + ], + float, # Mutation weight. + ] + ] = None, # pylint: disable=bad-whitespace name: Optional[str] = None, - hints: Optional[Any] = None) -> Evolvable: + hints: Optional[Any] = None +) -> Evolvable: """An evolvable symbolic value. Example:: diff --git a/pyglove/core/hyper/numerical.py b/pyglove/core/hyper/numerical.py index 4c3c426..0877eae 100644 --- a/pyglove/core/hyper/numerical.py +++ b/pyglove/core/hyper/numerical.py @@ -17,9 +17,9 @@ from typing import Any, Callable, Optional, Tuple from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.hyper import base @@ -62,8 +62,7 @@ def _on_bound(self): f'\'min_value\' must be positive when `scale` is {self.scale!r}. ' f'encountered: {self.min_value}.') - def dna_spec(self, - location: Optional[object_utils.KeyPath] = None) -> geno.Float: + def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.Float: """Returns corresponding DNASpec.""" return geno.Float( min_value=self.min_value, @@ -71,55 +70,74 @@ def dna_spec(self, scale=self.scale, hints=self.hints, name=self.name, - location=location or object_utils.KeyPath()) + location=location or utils.KeyPath(), + ) def _decode(self) -> float: """Decode a DNA into a float value.""" dna = self._dna if not isinstance(dna.value, float): raise ValueError( - object_utils.message_on_path( - f'Expect float value. Encountered: {dna.value}.', self.sym_path)) + utils.message_on_path( + f'Expect float value. Encountered: {dna.value}.', self.sym_path + ) + ) if dna.value < self.min_value: raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'DNA value should be no less than {self.min_value}. ' - f'Encountered {dna.value}.', self.sym_path)) + f'Encountered {dna.value}.', + self.sym_path, + ) + ) if dna.value > self.max_value: raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'DNA value should be no greater than {self.max_value}. ' - f'Encountered {dna.value}.', self.sym_path)) + f'Encountered {dna.value}.', + self.sym_path, + ) + ) return dna.value def encode(self, value: float) -> geno.DNA: """Encode a float value into a DNA.""" if not isinstance(value, float): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Value should be float to be encoded for {self!r}. ' - f'Encountered {value}.', self.sym_path)) + f'Encountered {value}.', + self.sym_path, + ) + ) if value < self.min_value: raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Value should be no less than {self.min_value}. ' - f'Encountered {value}.', self.sym_path)) + f'Encountered {value}.', + self.sym_path, + ) + ) if value > self.max_value: raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Value should be no greater than {self.max_value}. ' - f'Encountered {value}.', self.sym_path)) + f'Encountered {value}.', + self.sym_path, + ) + ) return geno.DNA(value) def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool = False, - child_transform: Optional[Callable[ - [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None - ) -> Tuple[bool, 'Float']: + child_transform: Optional[ + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, + ) -> Tuple[bool, 'Float']: """Validate candidates during value_spec binding time.""" del allow_partial del child_transform @@ -134,17 +152,23 @@ def custom_apply( if (float_spec.min_value is not None and self.min_value < float_spec.min_value): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Float.min_value ({self.min_value}) should be no less than ' f'the min value ({float_spec.min_value}) of value spec: ' - f'{float_spec}.', path)) + f'{float_spec}.', + path, + ) + ) if (float_spec.max_value is not None and self.max_value > float_spec.max_value): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Float.max_value ({self.max_value}) should be no greater than ' f'the max value ({float_spec.max_value}) of value spec: ' - f'{float_spec}.', path)) + f'{float_spec}.', + path, + ) + ) return (False, self) def is_leaf(self) -> bool: diff --git a/pyglove/core/hyper/numerical_test.py b/pyglove/core/hyper/numerical_test.py index fc1c78d..aee7044 100644 --- a/pyglove/core/hyper/numerical_test.py +++ b/pyglove/core/hyper/numerical_test.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.hyper.Float.""" - import unittest from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.hyper.numerical import Float from pyglove.core.hyper.numerical import floatv @@ -44,12 +42,14 @@ def test_scale(self): floatv(-1.0, 1.0, 'log') def test_dna_spec(self): - self.assertTrue(symbolic.eq( - floatv(0.0, 1.0).dna_spec('a'), - geno.Float( - location=object_utils.KeyPath('a'), - min_value=0.0, - max_value=1.0))) + self.assertTrue( + symbolic.eq( + floatv(0.0, 1.0).dna_spec('a'), + geno.Float( + location=utils.KeyPath('a'), min_value=0.0, max_value=1.0 + ), + ) + ) def test_decode(self): v = floatv(0.0, 1.0) diff --git a/pyglove/core/hyper/object_template.py b/pyglove/core/hyper/object_template.py index 9f837c4..2365500 100644 --- a/pyglove/core/hyper/object_template.py +++ b/pyglove/core/hyper/object_template.py @@ -16,14 +16,14 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from pyglove.core import geno -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.hyper import base from pyglove.core.hyper import derived -class ObjectTemplate(base.HyperValue, object_utils.Formattable): +class ObjectTemplate(base.HyperValue, utils.Formattable): """Object template that encodes and decodes symbolic values. An object template can be created from a hyper value, which is a symbolic @@ -131,18 +131,18 @@ def __init__(self, """ super().__init__() self._value = value - self._root_path = object_utils.KeyPath() + self._root_path = utils.KeyPath() self._compute_derived = compute_derived self._where = where self._parse_generators() @property - def root_path(self) -> object_utils.KeyPath: + def root_path(self) -> utils.KeyPath: """Returns root path.""" return self._root_path @root_path.setter - def root_path(self, path: object_utils.KeyPath): + def root_path(self, path: utils.KeyPath): """Set root path.""" self._root_path = path @@ -150,7 +150,8 @@ def _parse_generators(self) -> None: """Parse generators from its templated value.""" hyper_primitives = [] def _extract_immediate_child_hyper_primitives( - path: object_utils.KeyPath, value: Any) -> bool: + path: utils.KeyPath, value: Any + ) -> bool: """Extract top-level hyper primitives.""" if (isinstance(value, base.HyperValue) and (not self._where or self._where(value))): @@ -162,13 +163,14 @@ def _extract_immediate_child_hyper_primitives( hyper_primitives.append((path, value)) elif isinstance(value, symbolic.Object): for k, v in value.sym_items(): - object_utils.traverse( - v, _extract_immediate_child_hyper_primitives, - root_path=object_utils.KeyPath(k, path)) + utils.traverse( + v, + _extract_immediate_child_hyper_primitives, + root_path=utils.KeyPath(k, path), + ) return True - object_utils.traverse( - self._value, _extract_immediate_child_hyper_primitives) + utils.traverse(self._value, _extract_immediate_child_hyper_primitives) self._hyper_primitives = hyper_primitives @property @@ -186,15 +188,15 @@ def is_constant(self) -> bool: """Returns whether current template is constant value.""" return not self._hyper_primitives - def dna_spec( - self, location: Optional[object_utils.KeyPath] = None) -> geno.Space: + def dna_spec(self, location: Optional[utils.KeyPath] = None) -> geno.Space: """Return DNA spec (geno.Space) from this template.""" return geno.Space( elements=[ primitive.dna_spec(primitive_location) for primitive_location, primitive in self._hyper_primitives ], - location=location or object_utils.KeyPath()) + location=location or utils.KeyPath(), + ) def _decode(self) -> Any: """Decode DNA into a value.""" @@ -202,9 +204,10 @@ def _decode(self) -> Any: assert dna is not None if not self._hyper_primitives and (dna.value is not None or dna.children): raise ValueError( - object_utils.message_on_path( - f'Encountered extra DNA value to decode: {dna!r}', - self._root_path)) + utils.message_on_path( + f'Encountered extra DNA value to decode: {dna!r}', self._root_path + ) + ) # Compute hyper primitive values first. rebind_dict = {} @@ -214,11 +217,14 @@ def _decode(self) -> Any: else: if len(dna.children) != len(self._hyper_primitives): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'The length of child values ({len(dna.children)}) is ' - f'different from the number of hyper primitives ' + 'different from the number of hyper primitives ' f'({len(self._hyper_primitives)}) in ObjectTemplate. ' - f'DNA={dna!r}, ObjectTemplate={self!r}.', self._root_path)) + f'DNA={dna!r}, ObjectTemplate={self!r}.', + self._root_path, + ) + ) for i, (primitive_location, primitive) in enumerate( self._hyper_primitives): rebind_dict[primitive_location.path] = ( @@ -247,18 +253,18 @@ def _decode(self) -> Any: # TODO(daiyip): Currently derived value parsing is done at decode time, # which can be optimized by moving to template creation time. derived_values = [] - def _extract_derived_values( - path: object_utils.KeyPath, value: Any) -> bool: + def _extract_derived_values(path: utils.KeyPath, value: Any) -> bool: """Extract top-level primitives.""" if isinstance(value, derived.DerivedValue): derived_values.append((path, value)) elif isinstance(value, symbolic.Object): for k, v in value.sym_items(): - object_utils.traverse( - v, _extract_derived_values, - root_path=object_utils.KeyPath(k, path)) + utils.traverse( + v, _extract_derived_values, root_path=utils.KeyPath(k, path) + ) return True - object_utils.traverse(value, _extract_derived_values) + + utils.traverse(value, _extract_derived_values) if derived_values: if not copied: @@ -299,9 +305,9 @@ def encode(self, value: Any) -> geno.DNA: ValueError if value cannot be encoded by this template. """ children = [] - def _encode(path: object_utils.KeyPath, - template_value: Any, - input_value: Any) -> Any: + def _encode( + path: utils.KeyPath, template_value: Any, input_value: Any + ) -> Any: """Encode input value according to template value.""" if (pg_typing.MISSING_VALUE == input_value and pg_typing.MISSING_VALUE != template_value): @@ -339,10 +345,12 @@ def _encode(path: object_utils.KeyPath, f'TemplateOnlyKeys={template_keys - value_keys}, ' f'InputOnlyKeys={value_keys - template_keys})') for key in template_value.sym_keys(): - object_utils.merge_tree( + utils.merge_tree( template_value.sym_getattr(key), input_value.sym_getattr(key), - _encode, root_path=object_utils.KeyPath(key, path)) + _encode, + root_path=utils.KeyPath(key, path), + ) elif isinstance(template_value, symbolic.Dict): # Do nothing since merge will iterate all elements in dict and list. if not isinstance(input_value, dict): @@ -358,19 +366,23 @@ def _encode(path: object_utils.KeyPath, f'value. (Path=\'{path}\', Template={template_value!r}, ' f'Input={input_value!r})') for i, template_item in enumerate(template_value): - object_utils.merge_tree( - template_item, input_value[i], _encode, - root_path=object_utils.KeyPath(i, path)) + utils.merge_tree( + template_item, + input_value[i], + _encode, + root_path=utils.KeyPath(i, path), + ) else: if template_value != input_value: raise ValueError( - f'Unmatched value between template and input. ' - f'(Path=\'{path}\', ' - f'Template={object_utils.quote_if_str(template_value)}, ' - f'Input={object_utils.quote_if_str(input_value)})') + 'Unmatched value between template and input. ' + f"(Path='{path}', " + f'Template={utils.quote_if_str(template_value)}, ' + f'Input={utils.quote_if_str(input_value)})' + ) return template_value - object_utils.merge_tree( - self._value, value, _encode, root_path=self._root_path) + + utils.merge_tree(self._value, value, _encode, root_path=self._root_path) return geno.DNA(None, children) def try_encode(self, value: Any) -> Tuple[bool, geno.DNA]: @@ -399,18 +411,18 @@ def format(self, root_indent: int = 0, **kwargs) -> str: """Format this object.""" - details = object_utils.format( - self._value, compact, verbose, root_indent, **kwargs) + details = utils.format(self._value, compact, verbose, root_indent, **kwargs) return f'{self.__class__.__name__}(value={details})' def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None - ) -> Tuple[bool, 'ObjectTemplate']: + child_transform: Optional[ + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, + ) -> Tuple[bool, 'ObjectTemplate']: """Validate candidates during value_spec binding time.""" # Check if value_spec directly accepts `self`. if not value_spec.value_type or not isinstance(self, value_spec.value_type): diff --git a/pyglove/core/logging_test.py b/pyglove/core/logging_test.py index 932c126..d56695d 100644 --- a/pyglove/core/logging_test.py +++ b/pyglove/core/logging_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.""" - import io import logging import unittest diff --git a/pyglove/core/object_utils/__init__.py b/pyglove/core/object_utils/__init__.py deleted file mode 100644 index fdc8bec..0000000 --- a/pyglove/core/object_utils/__init__.py +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright 2022 The PyGlove Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# pylint: disable=line-too-long -"""Utility library that provides common traits for objects in Python. - -Overview --------- - -``pg.object_utils`` sits at the bottom of all PyGlove modules and empowers other -modules with the following features: - - +---------------------+--------------------------------------------+ - | Functionality | API | - +=====================+============================================+ - | Formatting | :class:`pg.Formattable`, | - | | | - | | :func:`pg.format`, | - | | | - | | :func:`pg.print`, | - | | | - | | :func:`pg.object_utils.kvlist_str`, | - | | | - | | :func:`pg.object_utils.quote_if_str`, | - | | | - | | :func:`pg.object_utils.message_on_path` | - +---------------------+--------------------------------------------+ - | Serialization | :class:`pg.JSONConvertible`, | - | | | - | | :func:`pg.registered_types`, | - | | | - | | :func:`pg.object_utils.to_json`, | - | | | - | | :func:`pg.object_utils.from_json`, | - +---------------------+--------------------------------------------+ - | Partial construction| :class:`pg.MaybePartial`, | - | | | - | | :const:`pg.MISSING_VALUE`. | - +---------------------+--------------------------------------------+ - | Hierarchical key | :class:`pg.KeyPath` | - | representation | | - +---------------------+--------------------------------------------+ - | Hierarchical object | :func:`pg.object_utils.traverse` | - | traversal | | - +---------------------+--------------------------------------------+ - | Hierarchical object | :func:`pg.object_utils.transform`, | - | transformation | | - | | :func:`pg.object_utils.merge`, | - | | | - | | :func:`pg.object_utils.canonicalize`, | - | | | - | | :func:`pg.object_utils.flatten` | - +---------------------+--------------------------------------------+ - | Code generation | :class:`pg.object_utils.make_function` | - +---------------------+--------------------------------------------+ - | Docstr handling | :class:`pg.docstr`, | - +---------------------+--------------------------------------------+ - | Error handling | :class:`pg.catch_errors`, | - +---------------------+--------------------------------------------+ -""" -# pylint: enable=line-too-long -# pylint: disable=g-bad-import-order -# pylint: disable=g-importing-member - -# Handling JSON conversion. -from pyglove.core.object_utils.json_conversion import Nestable -from pyglove.core.object_utils.json_conversion import JSONValueType - -from pyglove.core.object_utils.json_conversion import JSONConvertible -from pyglove.core.object_utils.json_conversion import from_json -from pyglove.core.object_utils.json_conversion import to_json -from pyglove.core.object_utils.json_conversion import registered_types - -# Handling formatting. -from pyglove.core.object_utils.formatting import Formattable -from pyglove.core.object_utils.formatting import format # pylint: disable=redefined-builtin -from pyglove.core.object_utils.formatting import printv as print # pylint: disable=redefined-builtin -from pyglove.core.object_utils.formatting import kvlist_str -from pyglove.core.object_utils.formatting import quote_if_str -from pyglove.core.object_utils.formatting import maybe_markdown_quote -from pyglove.core.object_utils.formatting import comma_delimited_str -from pyglove.core.object_utils.formatting import camel_to_snake -from pyglove.core.object_utils.formatting import auto_plural -from pyglove.core.object_utils.formatting import BracketType -from pyglove.core.object_utils.formatting import bracket_chars -from pyglove.core.object_utils.formatting import RawText - -# Context managers for defining the default format for __str__ and __repr__. -from pyglove.core.object_utils.formatting import str_format -from pyglove.core.object_utils.formatting import repr_format - -# Value location. -from pyglove.core.object_utils.value_location import KeyPath -from pyglove.core.object_utils.value_location import KeyPathSet -from pyglove.core.object_utils.value_location import StrKey -from pyglove.core.object_utils.value_location import message_on_path - -# Value markers. -from pyglove.core.object_utils.missing import MissingValue -from pyglove.core.object_utils.missing import MISSING_VALUE - -# Handling hierarchical. -from pyglove.core.object_utils.hierarchical import traverse -from pyglove.core.object_utils.hierarchical import transform -from pyglove.core.object_utils.hierarchical import flatten -from pyglove.core.object_utils.hierarchical import canonicalize -from pyglove.core.object_utils.hierarchical import merge -from pyglove.core.object_utils.hierarchical import merge_tree -from pyglove.core.object_utils.hierarchical import is_partial -from pyglove.core.object_utils.hierarchical import try_listify_dict_with_int_keys - -# Common traits. -from pyglove.core.object_utils.common_traits import MaybePartial -from pyglove.core.object_utils.common_traits import Functor - -from pyglove.core.object_utils.common_traits import explicit_method_override -from pyglove.core.object_utils.common_traits import ensure_explicit_method_override - -# Handling thread local values. -from pyglove.core.object_utils.thread_local import thread_local_value_scope -from pyglove.core.object_utils.thread_local import thread_local_has -from pyglove.core.object_utils.thread_local import thread_local_set -from pyglove.core.object_utils.thread_local import thread_local_get -from pyglove.core.object_utils.thread_local import thread_local_del -from pyglove.core.object_utils.thread_local import thread_local_increment -from pyglove.core.object_utils.thread_local import thread_local_decrement -from pyglove.core.object_utils.thread_local import thread_local_push -from pyglove.core.object_utils.thread_local import thread_local_pop -from pyglove.core.object_utils.thread_local import thread_local_peek - -# Handling docstrings. -from pyglove.core.object_utils.docstr_utils import DocStr -from pyglove.core.object_utils.docstr_utils import DocStrStyle -from pyglove.core.object_utils.docstr_utils import DocStrEntry -from pyglove.core.object_utils.docstr_utils import DocStrExample -from pyglove.core.object_utils.docstr_utils import DocStrArgument -from pyglove.core.object_utils.docstr_utils import DocStrReturns -from pyglove.core.object_utils.docstr_utils import DocStrRaises -from pyglove.core.object_utils.docstr_utils import docstr - -# Handling exceptions. -from pyglove.core.object_utils.error_utils import catch_errors -from pyglove.core.object_utils.error_utils import CatchErrorsContext -from pyglove.core.object_utils.error_utils import ErrorInfo - -# Timing. -from pyglove.core.object_utils.timing import timeit -from pyglove.core.object_utils.timing import TimeIt - -# pylint: enable=g-importing-member -# pylint: enable=g-bad-import-order diff --git a/pyglove/core/patching/object_factory.py b/pyglove/core/patching/object_factory.py index 4102615..9d7fed6 100644 --- a/pyglove/core/patching/object_factory.py +++ b/pyglove/core/patching/object_factory.py @@ -14,8 +14,8 @@ """Object factory based on patchers.""" from typing import Any, Callable, Dict, Optional, Type, Union -from pyglove.core import object_utils from pyglove.core import symbolic +from pyglove.core import utils from pyglove.core.patching import rule_based @@ -88,7 +88,7 @@ def ObjectFactory( # pylint: disable=invalid-name # Step 3: Patch with additional parameter override dict if available. if params_override: value = value.rebind( - object_utils.flatten(from_maybe_serialized(params_override, dict)), - raise_on_no_change=False) + utils.flatten(from_maybe_serialized(params_override, dict)), + raise_on_no_change=False, + ) return value - diff --git a/pyglove/core/patching/pattern_based.py b/pyglove/core/patching/pattern_based.py index 22c6810..7be0e49 100644 --- a/pyglove/core/patching/pattern_based.py +++ b/pyglove/core/patching/pattern_based.py @@ -15,8 +15,8 @@ import re from typing import Any, Callable, Optional, Tuple, Type, Union -from pyglove.core import object_utils from pyglove.core import symbolic +from pyglove.core import utils def patch_on_key( @@ -214,11 +214,11 @@ def patch_on_member( def _conditional_patch( src: symbolic.Symbolic, - condition: Callable[ - [object_utils.KeyPath, Any, symbolic.Symbolic], bool], + condition: Callable[[utils.KeyPath, Any, symbolic.Symbolic], bool], value: Any = None, value_fn: Optional[Callable[[Any], Any]] = None, - skip_notification: Optional[bool] = None) -> Any: + skip_notification: Optional[bool] = None, +) -> Any: """Recursive patch values on condition. Args: diff --git a/pyglove/core/patching/rule_based.py b/pyglove/core/patching/rule_based.py index 56f9fbc..1fdbd47 100644 --- a/pyglove/core/patching/rule_based.py +++ b/pyglove/core/patching/rule_based.py @@ -16,9 +16,9 @@ import re import typing from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils class Patcher(symbolic.Functor): @@ -350,7 +350,7 @@ def from_uri(uri: str) -> Patcher: name, args, kwargs = parse_uri(uri) patcher_cls = typing.cast(Type[Any], _PATCHER_REGISTRY.get(name)) args, kwargs = parse_args(patcher_cls.__signature__, args, kwargs) - return patcher_cls(object_utils.MISSING_VALUE, *args, **kwargs) + return patcher_cls(utils.MISSING_VALUE, *args, **kwargs) def parse_uri(uri: str) -> Tuple[str, List[str], Dict[str, str]]: @@ -467,7 +467,8 @@ def _value_error(msg): f'{value_spec!r} cannot be used for Patcher argument.\n' f'Consider to treat this argument as string and parse it yourself.') return value_spec.apply( - arg, root_path=object_utils.KeyPath.parse(f'{patcher_id}.{arg_name}')) + arg, root_path=utils.KeyPath.parse(f'{patcher_id}.{arg_name}') + ) def parse_list(string: str, diff --git a/pyglove/core/symbolic/base.py b/pyglove/core/symbolic/base.py index 3bc5365..3374fe8 100644 --- a/pyglove/core/symbolic/base.py +++ b/pyglove/core/symbolic/base.py @@ -25,8 +25,8 @@ from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union from pyglove.core import io as pg_io -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import flags from pyglove.core.symbolic.origin import Origin from pyglove.core.symbolic.pure_symbolic import NonDeterministic @@ -38,15 +38,17 @@ class WritePermissionError(Exception): """Exception raisen when write access to object fields is not allowed.""" -class FieldUpdate(object_utils.Formattable): +class FieldUpdate(utils.Formattable): """Class that describes an update to a field in an object tree.""" - def __init__(self, - path: object_utils.KeyPath, - target: 'Symbolic', - field: Optional[pg_typing.Field], - old_value: Any, - new_value: Any): + def __init__( + self, + path: utils.KeyPath, + target: 'Symbolic', + field: Optional[pg_typing.Field], + old_value: Any, + new_value: Any, + ): """Constructor. Args: @@ -70,18 +72,18 @@ def format( **kwargs, ) -> str: """Formats this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('parent_path', self.target.sym_path, None), ('path', self.path, None), - ('old_value', self.old_value, object_utils.MISSING_VALUE), - ('new_value', self.new_value, object_utils.MISSING_VALUE), + ('old_value', self.old_value, utils.MISSING_VALUE), + ('new_value', self.new_value, utils.MISSING_VALUE), ], label=self.__class__.__name__, compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def __eq__(self, other: Any) -> bool: @@ -124,11 +126,11 @@ def sym_setparent(self, parent: Optional['TopologyAware']) -> None: @property @abc.abstractmethod - def sym_path(self) -> object_utils.KeyPath: + def sym_path(self) -> utils.KeyPath: """Returns the path of this object under its topology.""" @abc.abstractmethod - def sym_setpath(self, path: object_utils.KeyPath) -> None: + def sym_setpath(self, path: utils.KeyPath) -> None: """Sets the path of this object under its topology.""" @@ -172,9 +174,9 @@ def infer(self, **kwargs) -> Any: class Symbolic( TopologyAware, - object_utils.Formattable, - object_utils.JSONConvertible, - object_utils.MaybePartial, + utils.Formattable, + utils.JSONConvertible, + utils.MaybePartial, HtmlConvertible, ): """Base for all symbolic types. @@ -203,13 +205,15 @@ class Symbolic( # pylint: enable=invalid-name - def __init__(self, - *, - allow_partial: bool, - accessor_writable: bool, - sealed: bool, - root_path: Optional[object_utils.KeyPath], - init_super: bool = True): + def __init__( + self, + *, + allow_partial: bool, + accessor_writable: bool, + sealed: bool, + root_path: Optional[utils.KeyPath], + init_super: bool = True, + ): """Constructor. Args: @@ -237,7 +241,7 @@ def __init__(self, # NOTE(daiyip): parent is used for rebind call to notify their ancestors # for updates, not for external usage. self._set_raw_attr('_sym_parent', None) - self._set_raw_attr('_sym_path', root_path or object_utils.KeyPath()) + self._set_raw_attr('_sym_path', root_path or utils.KeyPath()) self._set_raw_attr('_sym_puresymbolic', None) self._set_raw_attr('_sym_missing_values', None) self._set_raw_attr('_sym_nondefault_values', None) @@ -317,7 +321,7 @@ def sym_missing(self, flatten: bool = True) -> Dict[Union[str, int], Any]: missing = self._sym_missing() self._set_raw_attr('_sym_missing_values', missing) if flatten: - missing = object_utils.flatten(missing) + missing = utils.flatten(missing) return missing def sym_nondefault(self, flatten: bool = True) -> Dict[Union[int, str], Any]: @@ -327,7 +331,7 @@ def sym_nondefault(self, flatten: bool = True) -> Dict[Union[int, str], Any]: nondefault = self._sym_nondefault() self._set_raw_attr('_sym_nondefault_values', nondefault) if flatten: - nondefault = object_utils.flatten(nondefault) + nondefault = utils.flatten(nondefault) return nondefault @property @@ -405,7 +409,7 @@ def visit(k, v, p): def sym_attr_field(self, key: Union[str, int]) -> Optional[pg_typing.Field]: """Returns the field definition for a symbolic attribute.""" - def sym_has(self, path: Union[object_utils.KeyPath, str, int]) -> bool: + def sym_has(self, path: Union[utils.KeyPath, str, int]) -> bool: """Returns True if a path exists in the sub-tree. Args: @@ -414,13 +418,14 @@ def sym_has(self, path: Union[object_utils.KeyPath, str, int]) -> bool: Returns: True if the path exists in current sub-tree, otherwise False. """ - return object_utils.KeyPath.from_value(path).exists(self) + return utils.KeyPath.from_value(path).exists(self) def sym_get( self, - path: Union[object_utils.KeyPath, str, int], + path: Union[utils.KeyPath, str, int], default: Any = RAISE_IF_NOT_FOUND, - use_inferred: bool = False) -> Any: + use_inferred: bool = False, + ) -> Any: """Returns a sub-node by path. NOTE: there is no `sym_set`, use `sym_rebind`. @@ -439,7 +444,7 @@ def sym_get( Raises: KeyError if `path` does not exist and `default` is not specified. """ - path = object_utils.KeyPath.from_value(path) + path = utils.KeyPath.from_value(path) if default is RAISE_IF_NOT_FOUND: return path.query(self, use_inferred=use_inferred) else: @@ -533,12 +538,11 @@ def sym_contains( return contains(self, value, type) @property - def sym_path(self) -> object_utils.KeyPath: + def sym_path(self) -> utils.KeyPath: """Returns the path of current object from the root of its symbolic tree.""" return getattr(self, '_sym_path') - def sym_setpath( - self, path: Optional[Union[str, object_utils.KeyPath]]) -> None: + def sym_setpath(self, path: Optional[Union[str, utils.KeyPath]]) -> None: """Sets the path of current node in its symbolic tree.""" if self.sym_path != path: old_path = self.sym_path @@ -547,11 +551,9 @@ def sym_setpath( def sym_rebind( self, - path_value_pairs: Optional[Union[ - Dict[ - Union[object_utils.KeyPath, str, int], - Any], - Callable]] = None, # pylint: disable=g-bare-generic + path_value_pairs: Optional[ + Union[Dict[Union[utils.KeyPath, str, int], Any], Callable[..., Any]] + ] = None, # pylint: disable=g-bare-generic *, raise_on_no_change: bool = True, notify_parents: bool = True, @@ -575,8 +577,9 @@ def sym_rebind( f'Argument \'path_value_pairs\' should be a dict. ' f'Encountered {path_value_pairs}')) path_value_pairs.update(kwargs) - path_value_pairs = {object_utils.KeyPath.from_value(k): v - for k, v in path_value_pairs.items()} + path_value_pairs = { + utils.KeyPath.from_value(k): v for k, v in path_value_pairs.items() + } if not path_value_pairs and raise_on_no_change: raise ValueError(self._error_message('There are no values to rebind.')) @@ -601,10 +604,9 @@ def sym_clone(self, return new_value @abc.abstractmethod - def sym_jsonify(self, - *, - hide_default_values: bool = False, - **kwargs) -> object_utils.JSONValueType: + def sym_jsonify( + self, *, hide_default_values: bool = False, **kwargs + ) -> utils.JSONValueType: """Converts representation of current object to a plain Python object.""" def sym_ne(self, other: Any) -> bool: @@ -749,16 +751,15 @@ def is_sealed(self) -> bool: def rebind( self, - path_value_pairs: Optional[Union[ - Dict[ - Union[object_utils.KeyPath, str, int], - Any], - Callable]] = None, # pylint: disable=g-bare-generic + path_value_pairs: Optional[ + Union[Dict[Union[utils.KeyPath, str, int], Any], Callable[..., Any]] + ] = None, # pylint: disable=g-bare-generic *, raise_on_no_change: bool = True, notify_parents: bool = True, skip_notification: Optional[bool] = None, - **kwargs) -> 'Symbolic': + **kwargs, + ) -> 'Symbolic': """Alias for `sym_rebind`. Alias for `sym_rebind`. `rebind` is the recommended way for mutating @@ -941,7 +942,7 @@ def clone( """ return self.sym_clone(deep, memo, override) - def to_json(self, **kwargs) -> object_utils.JSONValueType: + def to_json(self, **kwargs) -> utils.JSONValueType: """Alias for `sym_jsonify`.""" return to_json(self, **kwargs) @@ -964,13 +965,18 @@ def save(self, *args, **kwargs) -> Any: def inspect( self, path_regex: Optional[str] = None, - where: Optional[Union[Callable[[Any], bool], - Callable[[Any, Any], bool]]] = None, - custom_selector: Optional[Union[ - Callable[[object_utils.KeyPath, Any], bool], - Callable[[object_utils.KeyPath, Any, Any], bool]]] = None, + where: Optional[ + Union[Callable[[Any], bool], Callable[[Any, Any], bool]] + ] = None, + custom_selector: Optional[ + Union[ + Callable[[utils.KeyPath, Any], bool], + Callable[[utils.KeyPath, Any, Any], bool], + ] + ] = None, file=sys.stdout, # pylint: disable=redefined-builtin - **kwargs) -> None: + **kwargs, + ) -> None: """Inspects current object by printing out selected values. Example:: @@ -1058,7 +1064,7 @@ class A(pg.Object): v = self else: v = query(self, path_regex, where, False, custom_selector) - object_utils.print(v, file=file, **kwargs) + utils.print(v, file=file, **kwargs) def __copy__(self) -> 'Symbolic': """Overridden shallow copy.""" @@ -1074,8 +1080,8 @@ def __deepcopy__(self, memo) -> 'Symbolic': @abc.abstractmethod def _sym_rebind( - self, path_value_pairs: Dict[object_utils.KeyPath, Any] - ) -> List[FieldUpdate]: + self, path_value_pairs: Dict[utils.KeyPath, Any] + ) -> List[FieldUpdate]: """Subclass specific rebind implementation. Args: @@ -1111,9 +1117,8 @@ def _sym_clone(self, deep: bool, memo=None) -> 'Symbolic': @abc.abstractmethod def _update_children_paths( - self, - old_path: object_utils.KeyPath, - new_path: object_utils.KeyPath) -> None: + self, old_path: utils.KeyPath, new_path: utils.KeyPath + ) -> None: """Update children paths according to root_path of current node.""" @abc.abstractmethod @@ -1122,7 +1127,7 @@ def _set_item_without_permission_check( """Child should implement: set an item without permission check.""" @abc.abstractmethod - def _on_change(self, field_updates: Dict[object_utils.KeyPath, FieldUpdate]): + def _on_change(self, field_updates: Dict[utils.KeyPath, FieldUpdate]): """Event that is triggered when field values in the subtree are updated. This event will be called @@ -1175,14 +1180,14 @@ def _relocate_if_symbolic(self, key: Union[str, int], value: Any) -> Any: # NOTE(daiyip): make a copy of symbolic object if it belongs to another # object tree, this prevents it from having multiple parents. See # List._formalized_value for similar logic. - root_path = object_utils.KeyPath(key, self.sym_path) + root_path = utils.KeyPath(key, self.sym_path) if (value.sym_parent is not None and (value.sym_parent is not self or root_path != value.sym_path)): value = value.clone() if isinstance(value, TopologyAware): - value.sym_setpath(object_utils.KeyPath(key, self.sym_path)) + value.sym_setpath(utils.KeyPath(key, self.sym_path)) value.sym_setparent(self._sym_parent_for_children()) return value @@ -1191,9 +1196,10 @@ def _sym_parent_for_children(self) -> Optional['Symbolic']: return self def _set_item_of_current_tree( - self, path: object_utils.KeyPath, value: Any) -> Optional[FieldUpdate]: + self, path: utils.KeyPath, value: Any + ) -> Optional[FieldUpdate]: """Set a field of current tree by key path and return its parent.""" - assert isinstance(path, object_utils.KeyPath), path + assert isinstance(path, utils.KeyPath), path if not path: raise KeyError( self._error_message( @@ -1222,8 +1228,8 @@ def _notify_field_updates( per_target_updates = dict() def _get_target_updates( - target: 'Symbolic' - ) -> Dict[object_utils.KeyPath, FieldUpdate]: + target: 'Symbolic', + ) -> Dict[utils.KeyPath, FieldUpdate]: target_id = id(target) if target_id not in per_target_updates: per_target_updates[target_id] = (target, dict()) @@ -1256,7 +1262,7 @@ def _get_target_updates( def _error_message(self, message: str) -> str: """Create error message to include path information.""" - return object_utils.message_on_path(message, self.sym_path) + return utils.message_on_path(message, self.sym_path) # @@ -1271,12 +1277,11 @@ def get_rebind_dict( """Generate rebind dict using rebinder on target value. Args: - rebinder: A callable object with signature: - (key_path: object_utils.KeyPath, value: Any) -> Any or - (key_path: object_utils.KeyPath, value: Any, parent: Any) -> Any. If - rebinder returns the same value from input, the value is considered - unchanged. Otherwise it will be put into the returning rebind dict. See - `Symbolic.rebind` for more details. + rebinder: A callable object with signature: (key_path: utils.KeyPath, value: + Any) -> Any or (key_path: utils.KeyPath, value: Any, parent: Any) -> Any. + If rebinder returns the same value from input, the value is considered + unchanged. Otherwise it will be put into the returning rebind dict. See + `Symbolic.rebind` for more details. target: Upon which value the rebind dict is computed. Returns: @@ -1329,15 +1334,17 @@ class TraverseAction(enum.Enum): CONTINUE = 2 -def traverse(x: Any, - preorder_visitor_fn: Optional[ - Callable[[object_utils.KeyPath, Any, Any], - Optional[TraverseAction]]] = None, - postorder_visitor_fn: Optional[ - Callable[[object_utils.KeyPath, Any, Any], - Optional[TraverseAction]]] = None, - root_path: Optional[object_utils.KeyPath] = None, - parent: Optional[Any] = None) -> bool: +def traverse( + x: Any, + preorder_visitor_fn: Optional[ + Callable[[utils.KeyPath, Any, Any], Optional[TraverseAction]] + ] = None, + postorder_visitor_fn: Optional[ + Callable[[utils.KeyPath, Any, Any], Optional[TraverseAction]] + ] = None, + root_path: Optional[utils.KeyPath] = None, + parent: Optional[Any] = None, +) -> bool: """Traverse a (maybe) symbolic value using visitor functions. Example:: @@ -1372,7 +1379,7 @@ def track_integers(k, v, p): either `TraverseAction.ENTER` or `TraverseAction.CONTINUE` for all nodes. Otherwise False. """ - root_path = root_path or object_utils.KeyPath() + root_path = root_path or utils.KeyPath() def no_op_visitor(path, value, parent): del path, value, parent @@ -1387,20 +1394,35 @@ def no_op_visitor(path, value, parent): if preorder_action is None or preorder_action == TraverseAction.ENTER: if isinstance(x, dict): for k, v in x.items(): - if not traverse(v, preorder_visitor_fn, postorder_visitor_fn, - object_utils.KeyPath(k, root_path), x): + if not traverse( + v, + preorder_visitor_fn, + postorder_visitor_fn, + utils.KeyPath(k, root_path), + x, + ): preorder_action = TraverseAction.STOP break elif isinstance(x, list): for i, v in enumerate(x): - if not traverse(v, preorder_visitor_fn, postorder_visitor_fn, - object_utils.KeyPath(i, root_path), x): + if not traverse( + v, + preorder_visitor_fn, + postorder_visitor_fn, + utils.KeyPath(i, root_path), + x, + ): preorder_action = TraverseAction.STOP break elif isinstance(x, Symbolic.ObjectType): # pytype: disable=wrong-arg-types for k, v in x.sym_items(): - if not traverse(v, preorder_visitor_fn, postorder_visitor_fn, - object_utils.KeyPath(k, root_path), x): + if not traverse( + v, + preorder_visitor_fn, + postorder_visitor_fn, + utils.KeyPath(k, root_path), + x, + ): preorder_action = TraverseAction.STOP break postorder_action = postorder_visitor_fn(root_path, x, parent) @@ -1413,12 +1435,16 @@ def no_op_visitor(path, value, parent): def query( x: Any, path_regex: Optional[str] = None, - where: Optional[Union[Callable[[Any], bool], - Callable[[Any, Any], bool]]] = None, + where: Optional[ + Union[Callable[[Any], bool], Callable[[Any, Any], bool]] + ] = None, enter_selected: bool = False, - custom_selector: Optional[Union[ - Callable[[object_utils.KeyPath, Any], bool], - Callable[[object_utils.KeyPath, Any, Any], bool]]] = None + custom_selector: Optional[ + Union[ + Callable[[utils.KeyPath, Any], bool], + Callable[[utils.KeyPath, Any, Any], bool], + ] + ] = None, ) -> Dict[str, Any]: """Queries a (maybe) symbolic value. @@ -1521,8 +1547,9 @@ def select_fn(k, v, p): results = {} - def _preorder_visitor(path: object_utils.KeyPath, v: Any, - parent: Any) -> TraverseAction: + def _preorder_visitor( + path: utils.KeyPath, v: Any, parent: Any + ) -> TraverseAction: if select_fn(path, v, parent): # pytype: disable=wrong-arg-count results[str(path)] = v return TraverseAction.ENTER if enter_selected else TraverseAction.CONTINUE @@ -1752,7 +1779,7 @@ def gt(left: Any, right: Any) -> bool: def _type_order(value: Any) -> str: """Returns the ordering string of value's type.""" - if isinstance(value, object_utils.MissingValue): + if isinstance(value, utils.MissingValue): type_order = 0 elif value is None: type_order = 1 @@ -1950,7 +1977,7 @@ class Bar(pg.PureSymbolic): True if value itself is partial/PureSymbolic or its child and nested child fields contain partial/PureSymbolic values. """ - return object_utils.is_partial(x) or is_pure_symbolic(x) + return utils.is_partial(x) or is_pure_symbolic(x) def contains( @@ -2009,11 +2036,11 @@ def from_json( json_value: Any, *, allow_partial: bool = False, - root_path: Optional[object_utils.KeyPath] = None, + root_path: Optional[utils.KeyPath] = None, auto_import: bool = True, auto_dict: bool = False, value_spec: Optional[pg_typing.ValueSpec] = None, - **kwargs + **kwargs, ) -> Any: """Deserializes a (maybe) symbolic value from JSON value. @@ -2057,28 +2084,30 @@ class A(pg.Object): typename_resolved = kwargs.pop('_typename_resolved', False) if not typename_resolved: - json_value = object_utils.json_conversion.resolve_typenames( + json_value = utils.json_conversion.resolve_typenames( json_value, auto_import=auto_import, auto_dict=auto_dict ) def _load_child(k, v): return from_json( v, - root_path=object_utils.KeyPath(k, root_path), + root_path=utils.KeyPath(k, root_path), _typename_resolved=True, allow_partial=allow_partial, - **kwargs + **kwargs, ) if isinstance(json_value, list): - if (json_value - and json_value[0] == object_utils.JSONConvertible.TUPLE_MARKER): + if json_value and json_value[0] == utils.JSONConvertible.TUPLE_MARKER: if len(json_value) < 2: raise ValueError( - object_utils.message_on_path( - f'Tuple should have at least one element ' - f'besides \'{object_utils.JSONConvertible.TUPLE_MARKER}\'. ' - f'Encountered: {json_value}', root_path)) + utils.message_on_path( + 'Tuple should have at least one element ' + f"besides '{utils.JSONConvertible.TUPLE_MARKER}'. " + f'Encountered: {json_value}', + root_path, + ) + ) return tuple(_load_child(i, v) for i, v in enumerate(json_value[1:])) return Symbolic.ListType.from_json( # pytype: disable=attribute-error json_value, @@ -2088,7 +2117,7 @@ def _load_child(k, v): **kwargs, ) elif isinstance(json_value, dict): - if object_utils.JSONConvertible.TYPE_NAME_KEY not in json_value: + if utils.JSONConvertible.TYPE_NAME_KEY not in json_value: return Symbolic.DictType.from_json( # pytype: disable=attribute-error json_value, value_spec=value_spec, @@ -2096,20 +2125,25 @@ def _load_child(k, v): allow_partial=allow_partial, **kwargs, ) - return object_utils.from_json( - json_value, _typename_resolved=True, - root_path=root_path, allow_partial=allow_partial, **kwargs + return utils.from_json( + json_value, + _typename_resolved=True, + root_path=root_path, + allow_partial=allow_partial, + **kwargs, ) return json_value -def from_json_str(json_str: str, - *, - allow_partial: bool = False, - root_path: Optional[object_utils.KeyPath] = None, - auto_import: bool = True, - auto_dict: bool = False, - **kwargs) -> Any: +def from_json_str( + json_str: str, + *, + allow_partial: bool = False, + root_path: Optional[utils.KeyPath] = None, + auto_import: bool = True, + auto_dict: bool = False, + **kwargs, +) -> Any: """Deserialize (maybe) symbolic object from JSON string. Example:: @@ -2202,7 +2236,7 @@ class A(pg.Object): # classes may have conflicting `to_json` method in their existing classes. if isinstance(value, Symbolic): return value.sym_jsonify(**kwargs) - return object_utils.to_json(value, **kwargs) + return utils.to_json(value, **kwargs) def to_json_str(value: Any, @@ -2378,8 +2412,11 @@ def default_save_handler( if file_format == 'json': content = to_json_str(value, json_indent=indent, **kwargs) elif file_format == 'txt': - content = value if isinstance(value, str) else object_utils.format( - value, compact=False, verbose=True) + content = ( + value + if isinstance(value, str) + else utils.format(value, compact=False, verbose=True) + ) else: raise ValueError(f'Unsupported `file_format`: {file_format!r}.') @@ -2415,8 +2452,7 @@ def treats_as_sealed(value: Symbolic) -> bool: def symbolic_transform_fn(allow_partial: bool): """Symbolic object transform function builder.""" - def _fn( - path: object_utils.KeyPath, field: pg_typing.Field, value: Any) -> Any: + def _fn(path: utils.KeyPath, field: pg_typing.Field, value: Any) -> Any: """Transform schema-less List and Dict to symbolic.""" if isinstance(value, Symbolic): return value diff --git a/pyglove/core/symbolic/base_test.py b/pyglove/core/symbolic/base_test.py index 5b936f8..7c4f6c6 100644 --- a/pyglove/core/symbolic/base_test.py +++ b/pyglove/core/symbolic/base_test.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.symbolic.base.""" - import copy import inspect from typing import Any import unittest -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core import views from pyglove.core.symbolic import base from pyglove.core.symbolic.dict import Dict @@ -33,7 +31,7 @@ class FieldUpdateTest(unittest.TestCase): def test_basics(self): x = Dict(x=1) f = pg_typing.Field('x', pg_typing.Int()) - update = base.FieldUpdate(object_utils.KeyPath('x'), x, f, 1, 2) + update = base.FieldUpdate(utils.KeyPath('x'), x, f, 1, 2) self.assertEqual(update.path, 'x') self.assertIs(update.target, x) self.assertIs(update.field, f) @@ -42,15 +40,15 @@ def test_basics(self): def test_format(self): self.assertEqual( - base.FieldUpdate( - object_utils.KeyPath('x'), Dict(x=1), None, 1, 2 - ).format(compact=True), + base.FieldUpdate(utils.KeyPath('x'), Dict(x=1), None, 1, 2).format( + compact=True + ), 'FieldUpdate(parent_path=, path=x, old_value=1, new_value=2)', ) self.assertEqual( base.FieldUpdate( - object_utils.KeyPath('a'), Dict(x=Dict(a=1)).x, None, 1, 2 + utils.KeyPath('a'), Dict(x=Dict(a=1)).x, None, 1, 2 ).format(compact=True), 'FieldUpdate(parent_path=x, path=a, old_value=1, new_value=2)', ) @@ -59,34 +57,34 @@ def test_eq_ne(self): x = Dict() f = pg_typing.Field('x', pg_typing.Int()) self.assertEqual( - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2), ) # Targets are not the same instance. self.assertNotEqual( - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - base.FieldUpdate(object_utils.KeyPath('a'), Dict(), f, 1, 2), + base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(utils.KeyPath('a'), Dict(), f, 1, 2), ) # Fields are not the same instance. self.assertNotEqual( - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - base.FieldUpdate(object_utils.KeyPath('b'), x, copy.copy(f), 1, 2), + base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(utils.KeyPath('b'), x, copy.copy(f), 1, 2), ) self.assertNotEqual( - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 0, 2), + base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(utils.KeyPath('a'), x, f, 0, 2), ) self.assertNotEqual( - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 1), + base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2), + base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 1), ) self.assertNotEqual( - base.FieldUpdate(object_utils.KeyPath('a'), x, f, 1, 2), Dict() + base.FieldUpdate(utils.KeyPath('a'), x, f, 1, 2), Dict() ) diff --git a/pyglove/core/symbolic/boilerplate.py b/pyglove/core/symbolic/boilerplate.py index f083c20..273a6a8 100644 --- a/pyglove/core/symbolic/boilerplate.py +++ b/pyglove/core/symbolic/boilerplate.py @@ -15,11 +15,10 @@ import copy import inspect - from typing import Any, List, Optional, Type -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import flags from pyglove.core.symbolic import object as pg_object @@ -129,9 +128,9 @@ class _BoilerplateClass(base_cls): cls.auto_register = True allow_partial = value.allow_partial - def _freeze_field(path: object_utils.KeyPath, - field: pg_typing.Field, - value: Any) -> Any: + def _freeze_field( + path: utils.KeyPath, field: pg_typing.Field, value: Any + ) -> Any: # We do not do validation since Object is already in valid form. del path if not isinstance(field.key, pg_typing.ListKey): diff --git a/pyglove/core/symbolic/class_wrapper.py b/pyglove/core/symbolic/class_wrapper.py index b4d6694..41bf8d3 100644 --- a/pyglove/core/symbolic/class_wrapper.py +++ b/pyglove/core/symbolic/class_wrapper.py @@ -26,8 +26,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union from pyglove.core import detouring -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import dict as pg_dict # pylint: disable=unused-import from pyglove.core.symbolic import list as pg_list # pylint: disable=unused-import @@ -71,7 +71,7 @@ class _SubclassedWrapperBase(ClassWrapper): # the `__init__` method. auto_typing = False - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__(self, *args, **kwargs): """Overridden __init__ to construct symbolic wrapper only.""" # NOTE(daiyip): We avoid `__init__` to be called multiple times. @@ -100,7 +100,7 @@ def wrapped_cls_initialized(self): def __init_subclass__(cls): # Class wrappers inherit `__init__` from the user class. Therefore, we mark # all of them as explicitly overridden. - object_utils.explicit_method_override(cls.__init__) + utils.explicit_method_override(cls.__init__) super().__init_subclass__() if cls.__init__ is _SubclassedWrapperBase.__init__: @@ -129,7 +129,7 @@ def __init_subclass__(cls): init_arg_list, arg_fields = _extract_init_signature( cls, auto_doc=cls.auto_doc, auto_typing=cls.auto_typing) - @object_utils.explicit_method_override + @utils.explicit_method_override @functools.wraps(cls.__init__) def _sym_init(self, *args, **kwargs): _SubclassedWrapperBase.__init__(self, *args, **kwargs) @@ -522,7 +522,7 @@ def foo(): """ if not wrapper_classes: wrapper_classes = [] - for _, c in object_utils.JSONConvertible.registered_types(): + for _, c in utils.JSONConvertible.registered_types(): if (issubclass(c, ClassWrapper) and c not in (ClassWrapper, _SubclassedWrapperBase) and (not where or where(c)) @@ -544,13 +544,13 @@ def _extract_init_signature( # Read args docstr from both class doc string and __init__ doc string. args_docstr = dict() if cls.__doc__: - cls_docstr = object_utils.DocStr.parse(cls.__doc__) + cls_docstr = utils.DocStr.parse(cls.__doc__) args_docstr = cls_docstr.args if init_method.__doc__: - init_docstr = object_utils.DocStr.parse(init_method.__doc__) + init_docstr = utils.DocStr.parse(init_method.__doc__) args_docstr.update(init_docstr.args) - docstr = object_utils.DocStr( - object_utils.DocStrStyle.GOOGLE, + docstr = utils.DocStr( + utils.DocStrStyle.GOOGLE, short_description=None, long_description=None, examples=[], diff --git a/pyglove/core/symbolic/compounding.py b/pyglove/core/symbolic/compounding.py index 48afe8d..e4f1121 100644 --- a/pyglove/core/symbolic/compounding.py +++ b/pyglove/core/symbolic/compounding.py @@ -19,7 +19,7 @@ import types from typing import Any, Dict, List, Optional, Tuple, Type, Union -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.symbolic.base import Symbolic from pyglove.core.symbolic.object import Object import pyglove.core.typing as pg_typing @@ -39,7 +39,7 @@ def __init_subclass__(cls): # from the user class to compound with. Object.__init_subclass__(cls) - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__(self, *args, **kwargs): # `explicit_init` allows the `__init__` of the other classes that sit after # `Compound` to be bypassed. diff --git a/pyglove/core/symbolic/compounding_test.py b/pyglove/core/symbolic/compounding_test.py index c9e0108..812f3de 100644 --- a/pyglove/core/symbolic/compounding_test.py +++ b/pyglove/core/symbolic/compounding_test.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.compounding.""" - import abc import dataclasses import sys import unittest -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic.compounding import compound as pg_compound from pyglove.core.symbolic.compounding import compound_class as pg_compound_class from pyglove.core.symbolic.dict import Dict @@ -145,7 +143,7 @@ def test_user_class_with_side_effect_init(self): class A(Object): x: int - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__(self, x): super().__init__(x=x) assert type(self) is A # pylint: disable=unidiomatic-typecheck diff --git a/pyglove/core/symbolic/dict.py b/pyglove/core/symbolic/dict.py index ec670ec..e0c14cf 100644 --- a/pyglove/core/symbolic/dict.py +++ b/pyglove/core/symbolic/dict.py @@ -16,8 +16,8 @@ import typing from typing import Any, Callable, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import flags @@ -96,14 +96,16 @@ def on_change(updates): """ @classmethod - def partial(cls, - dict_obj: Optional[typing.Dict[Union[str, int], Any]] = None, - value_spec: Optional[pg_typing.Dict] = None, - *, - onchange_callback: Optional[Callable[ - [typing.Dict[object_utils.KeyPath, base.FieldUpdate]], None] - ] = None, # pylint: disable=bad-continuation - **kwargs) -> 'Dict': + def partial( + cls, + dict_obj: Optional[typing.Dict[Union[str, int], Any]] = None, + value_spec: Optional[pg_typing.Dict] = None, + *, + onchange_callback: Optional[ + Callable[[typing.Dict[utils.KeyPath, base.FieldUpdate]], None] + ] = None, # pylint: disable=bad-continuation + **kwargs, + ) -> 'Dict': """Class method that creates a partial Dict object.""" return cls(dict_obj, value_spec=value_spec, @@ -112,13 +114,15 @@ def partial(cls, **kwargs) @classmethod - def from_json(cls, - json_value: Any, - *, - value_spec: Optional[pg_typing.Dict] = None, - allow_partial: bool = False, - root_path: Optional[object_utils.KeyPath] = None, - **kwargs) -> 'Dict': + def from_json( + cls, + json_value: Any, + *, + value_spec: Optional[pg_typing.Dict] = None, + allow_partial: bool = False, + root_path: Optional[utils.KeyPath] = None, + **kwargs, + ) -> 'Dict': """Class method that load an symbolic Dict from a JSON value. Args: @@ -156,27 +160,31 @@ def from_json(cls, { k: base.from_json( v, - root_path=object_utils.KeyPath(k, root_path), + root_path=utils.KeyPath(k, root_path), allow_partial=allow_partial, - **kwargs - ) for k, v in json_value.items() + **kwargs, + ) + for k, v in json_value.items() }, value_spec=value_spec, root_path=root_path, allow_partial=allow_partial, ) - def __init__(self, - dict_obj: Union[ - None, - Iterable[Tuple[Union[str, int], Any]], - typing.Dict[Union[str, int], Any]] = None, - *, - value_spec: Optional[pg_typing.Dict] = None, - onchange_callback: Optional[Callable[ - [typing.Dict[object_utils.KeyPath, base.FieldUpdate]], None] - ] = None, # pylint: disable=bad-continuation - **kwargs): + def __init__( + self, + dict_obj: Union[ + None, + Iterable[Tuple[Union[str, int], Any]], + typing.Dict[Union[str, int], Any], + ] = None, + *, + value_spec: Optional[pg_typing.Dict] = None, + onchange_callback: Optional[ + Callable[[typing.Dict[utils.KeyPath, base.FieldUpdate]], None] + ] = None, # pylint: disable=bad-continuation + **kwargs, + ): """Constructor. Args: @@ -335,8 +343,8 @@ def _sym_parent_for_children(self) -> Optional[base.Symbolic]: return self def _sym_rebind( - self, path_value_pairs: typing.Dict[object_utils.KeyPath, Any] - ) -> List[base.FieldUpdate]: + self, path_value_pairs: typing.Dict[utils.KeyPath, Any] + ) -> List[base.FieldUpdate]: """Subclass specific rebind implementation.""" updates = [] for k, v in path_value_pairs.items(): @@ -360,7 +368,7 @@ def _sym_missing(self) -> typing.Dict[Union[str, int], Any]: if keys: for key in keys: v = self.sym_getattr(key) - if object_utils.MISSING_VALUE == v: + if utils.MISSING_VALUE == v: missing[key] = field.value.default else: if isinstance(v, base.Symbolic): @@ -514,14 +522,13 @@ def _sym_clone(self, deep: bool, memo=None) -> 'Dict': pass_through=True) def _update_children_paths( - self, - old_path: object_utils.KeyPath, - new_path: object_utils.KeyPath) -> None: + self, old_path: utils.KeyPath, new_path: utils.KeyPath + ) -> None: """Update children paths according to root_path of current node.""" del old_path for k, v in self.sym_items(): if isinstance(v, base.TopologyAware): - v.sym_setpath(object_utils.KeyPath(k, new_path)) + v.sym_setpath(utils.KeyPath(k, new_path)) def _set_item_without_permission_check( # pytype: disable=signature-mismatch # overriding-parameter-type-checks self, key: Union[str, int], value: Any) -> Optional[base.FieldUpdate]: @@ -550,7 +557,7 @@ def _set_item_without_permission_check( # pytype: disable=signature-mismatch # # Detach old value from object tree. if isinstance(old_value, base.TopologyAware): old_value.sym_setparent(None) - old_value.sym_setpath(object_utils.KeyPath()) + old_value.sym_setpath(utils.KeyPath()) if (pg_typing.MISSING_VALUE == value and (not field or isinstance(field.key, pg_typing.NonConstKey))): @@ -589,13 +596,15 @@ def _formalized_value( value = base.from_json( value, allow_partial=allow_partial, - root_path=object_utils.KeyPath(name, self.sym_path)) + root_path=utils.KeyPath(name, self.sym_path), + ) if field and flags.is_type_check_enabled(): value = field.apply( value, allow_partial=allow_partial, transform_fn=base.symbolic_transform_fn(self._allow_partial), - root_path=object_utils.KeyPath(name, self.sym_path)) + root_path=utils.KeyPath(name, self.sym_path), + ) return self._relocate_if_symbolic(name, value) @property @@ -603,8 +612,9 @@ def _subscribes_field_updates(self) -> bool: """Returns True if current dict subscribes field updates.""" return self._onchange_callback is not None - def _on_change(self, field_updates: typing.Dict[object_utils.KeyPath, - base.FieldUpdate]): + def _on_change( + self, field_updates: typing.Dict[utils.KeyPath, base.FieldUpdate] + ): """On change event of Dict.""" if self._onchange_callback: self._onchange_callback(field_updates) @@ -814,8 +824,8 @@ def sym_jsonify( hide_default_values: bool = False, exclude_keys: Optional[Sequence[Union[str, int]]] = None, use_inferred: bool = False, - **kwargs - ) -> object_utils.JSONValueType: + **kwargs, + ) -> utils.JSONValueType: """Converts current object to a dict with plain Python objects.""" exclude_keys = set(exclude_keys or []) if self._value_spec and self._value_spec.schema: @@ -858,11 +868,12 @@ def sym_jsonify( def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool, child_transform: Optional[ - Callable[[object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, ) -> Tuple[bool, 'Dict']: """Implement pg.typing.CustomTyping interface. @@ -881,9 +892,12 @@ def custom_apply( if self._value_spec: if value_spec and not value_spec.is_compatible(self._value_spec): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Dict (spec={self._value_spec!r}) cannot be assigned to an ' - f'incompatible field (spec={value_spec!r}).', path)) + f'incompatible field (spec={value_spec!r}).', + path, + ) + ) if self._allow_partial == allow_partial: proceed_with_standard_apply = False else: @@ -906,7 +920,7 @@ def format( exclude_keys: Optional[Set[Union[str, int]]] = None, use_inferred: bool = False, cls_name: Optional[str] = None, - bracket_type: object_utils.BracketType = object_utils.BracketType.CURLY, + bracket_type: utils.BracketType = utils.BracketType.CURLY, key_as_attribute: bool = False, extra_blankline_for_field_docstr: bool = False, **kwargs, @@ -948,7 +962,7 @@ def _should_include_key(key): v = self.sym_inferred(k, default=v) field_list.append((None, k, v)) - open_bracket, close_bracket = object_utils.bracket_chars(bracket_type) + open_bracket, close_bracket = utils.bracket_chars(bracket_type) if not field_list: return f'{cls_name}{open_bracket}{close_bracket}' @@ -956,7 +970,7 @@ def _should_include_key(key): s = [f'{cls_name}{open_bracket}'] kv_strs = [] for _, k, v in field_list: - v_str = object_utils.format( + v_str = utils.format( v, compact, verbose, @@ -967,7 +981,8 @@ def _should_include_key(key): python_format=python_format, use_inferred=use_inferred, extra_blankline_for_field_docstr=extra_blankline_for_field_docstr, - **kwargs) + **kwargs, + ) if not python_format or key_as_attribute: if isinstance(k, int): k = f'[{k}]' @@ -989,7 +1004,7 @@ def _should_include_key(key): description = typing.cast(pg_typing.Field, f).description for line in description.split('\n'): s.append(_indent(f'# {line}\n', root_indent + 1)) - v_str = object_utils.format( + v_str = utils.format( v, compact, verbose, @@ -1000,7 +1015,8 @@ def _should_include_key(key): python_format=python_format, use_inferred=use_inferred, extra_blankline_for_field_docstr=extra_blankline_for_field_docstr, - **kwargs) + **kwargs, + ) if not python_format: # Format in PyGlove's format (default). diff --git a/pyglove/core/symbolic/dict_test.py b/pyglove/core/symbolic/dict_test.py index 828bf79..3ccd4ce 100644 --- a/pyglove/core/symbolic/dict_test.py +++ b/pyglove/core/symbolic/dict_test.py @@ -11,16 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.Dict.""" - import copy import inspect import io import pickle import unittest -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import flags from pyglove.core.symbolic import inferred @@ -31,7 +29,7 @@ from pyglove.core.symbolic.pure_symbolic import PureSymbolic -MISSING_VALUE = object_utils.MISSING_VALUE +MISSING_VALUE = utils.MISSING_VALUE class DictTest(unittest.TestCase): @@ -799,7 +797,7 @@ def test_sym_has(self): self.assertTrue(sd.sym_has('y.z')) self.assertTrue(sd.sym_has('[1].a')) self.assertTrue(sd.sym_has('y[2]')) - self.assertTrue(sd.sym_has(object_utils.KeyPath.parse('y.z'))) + self.assertTrue(sd.sym_has(utils.KeyPath.parse('y.z'))) self.assertFalse(sd.sym_has('x.z')) def test_sym_get(self): @@ -1267,7 +1265,7 @@ def test_sym_path(self): self.assertEqual(sd.x.a.sym_path, 'x.a') self.assertEqual(sd.y[0].b.sym_path, 'y[0].b') - sd.sym_setpath(object_utils.KeyPath('a')) + sd.sym_setpath(utils.KeyPath('a')) self.assertEqual(sd.sym_path, 'a') self.assertEqual(sd.x.sym_path, 'a.x') self.assertEqual(sd.x.a.sym_path, 'a.x.a') @@ -1724,96 +1722,112 @@ def on_dict_change(field_updates): 'd': 'foo', # Unchanged. 'e': 'bar' }) - self.assertEqual(updates, [ - { # Notification to `sd.c[0]`. - 'p': base.FieldUpdate( - object_utils.KeyPath.parse('c[0].p'), - target=sd.c[0], - field=None, - old_value=1, - new_value=MISSING_VALUE), - 'q': base.FieldUpdate( - object_utils.KeyPath.parse('c[0].q'), - target=sd.c[0], - field=None, - old_value=MISSING_VALUE, - new_value=2), - }, - { # Notification to `sd.c`. - '[0].p': base.FieldUpdate( - object_utils.KeyPath.parse('c[0].p'), - target=sd.c[0], - field=None, - old_value=1, - new_value=MISSING_VALUE), - '[0].q': base.FieldUpdate( - object_utils.KeyPath.parse('c[0].q'), - target=sd.c[0], - field=None, - old_value=MISSING_VALUE, - new_value=2), - }, - { # Notification to `sd.b.y`. - 'z': base.FieldUpdate( - object_utils.KeyPath.parse('b.y.z'), - target=sd.b.y, - field=None, - old_value=MISSING_VALUE, - new_value=1), - }, - { # Notification to `sd.b`. - 'x': base.FieldUpdate( - object_utils.KeyPath.parse('b.x'), - target=sd.b, - field=None, - old_value=1, - new_value=2), - 'y.z': base.FieldUpdate( - object_utils.KeyPath.parse('b.y.z'), - target=sd.b.y, - field=None, - old_value=MISSING_VALUE, - new_value=1), - }, - { # Notification to `sd`. - 'a': base.FieldUpdate( - object_utils.KeyPath.parse('a'), - target=sd, - field=None, - old_value=1, - new_value=2), - 'b.x': base.FieldUpdate( - object_utils.KeyPath.parse('b.x'), - target=sd.b, - field=None, - old_value=1, - new_value=2), - 'b.y.z': base.FieldUpdate( - object_utils.KeyPath.parse('b.y.z'), - target=sd.b.y, - field=None, - old_value=MISSING_VALUE, - new_value=1), - 'c[0].p': base.FieldUpdate( - object_utils.KeyPath.parse('c[0].p'), - target=sd.c[0], - field=None, - old_value=1, - new_value=MISSING_VALUE), - 'c[0].q': base.FieldUpdate( - object_utils.KeyPath.parse('c[0].q'), - target=sd.c[0], - field=None, - old_value=MISSING_VALUE, - new_value=2), - 'e': base.FieldUpdate( - object_utils.KeyPath.parse('e'), - target=sd, - field=None, - old_value=MISSING_VALUE, - new_value='bar') - } - ]) + self.assertEqual( + updates, + [ + { # Notification to `sd.c[0]`. + 'p': base.FieldUpdate( + utils.KeyPath.parse('c[0].p'), + target=sd.c[0], + field=None, + old_value=1, + new_value=MISSING_VALUE, + ), + 'q': base.FieldUpdate( + utils.KeyPath.parse('c[0].q'), + target=sd.c[0], + field=None, + old_value=MISSING_VALUE, + new_value=2, + ), + }, + { # Notification to `sd.c`. + '[0].p': base.FieldUpdate( + utils.KeyPath.parse('c[0].p'), + target=sd.c[0], + field=None, + old_value=1, + new_value=MISSING_VALUE, + ), + '[0].q': base.FieldUpdate( + utils.KeyPath.parse('c[0].q'), + target=sd.c[0], + field=None, + old_value=MISSING_VALUE, + new_value=2, + ), + }, + { # Notification to `sd.b.y`. + 'z': base.FieldUpdate( + utils.KeyPath.parse('b.y.z'), + target=sd.b.y, + field=None, + old_value=MISSING_VALUE, + new_value=1, + ), + }, + { # Notification to `sd.b`. + 'x': base.FieldUpdate( + utils.KeyPath.parse('b.x'), + target=sd.b, + field=None, + old_value=1, + new_value=2, + ), + 'y.z': base.FieldUpdate( + utils.KeyPath.parse('b.y.z'), + target=sd.b.y, + field=None, + old_value=MISSING_VALUE, + new_value=1, + ), + }, + { # Notification to `sd`. + 'a': base.FieldUpdate( + utils.KeyPath.parse('a'), + target=sd, + field=None, + old_value=1, + new_value=2, + ), + 'b.x': base.FieldUpdate( + utils.KeyPath.parse('b.x'), + target=sd.b, + field=None, + old_value=1, + new_value=2, + ), + 'b.y.z': base.FieldUpdate( + utils.KeyPath.parse('b.y.z'), + target=sd.b.y, + field=None, + old_value=MISSING_VALUE, + new_value=1, + ), + 'c[0].p': base.FieldUpdate( + utils.KeyPath.parse('c[0].p'), + target=sd.c[0], + field=None, + old_value=1, + new_value=MISSING_VALUE, + ), + 'c[0].q': base.FieldUpdate( + utils.KeyPath.parse('c[0].q'), + target=sd.c[0], + field=None, + old_value=MISSING_VALUE, + new_value=2, + ), + 'e': base.FieldUpdate( + utils.KeyPath.parse('e'), + target=sd, + field=None, + old_value=MISSING_VALUE, + new_value='bar', + ), + }, + ], + ) def test_rebind_with_fn(self): sd = Dict(a=1, b=dict(x=2, y='foo', z=[0, 1, 2])) @@ -2096,7 +2110,7 @@ def test_compact_exclude_keys(self): def test_compact_python_format(self): self.assertEqual( - object_utils.format( + utils.format( self._dict, compact=True, python_format=True, markdown=True ), "`{'a1': 1, 'a2': {'b1': {'c1': [{'d1': MISSING_VALUE, " @@ -2106,9 +2120,12 @@ def test_compact_python_format(self): def test_noncompact_python_format(self): self.assertEqual( - object_utils.format( - self._dict, compact=False, verbose=False, - python_format=True, markdown=True + utils.format( + self._dict, + compact=False, + verbose=False, + python_format=True, + markdown=True, ), inspect.cleandoc(""" ``` diff --git a/pyglove/core/symbolic/diff.py b/pyglove/core/symbolic/diff.py index 1caf7f6..550a0fc 100644 --- a/pyglove/core/symbolic/diff.py +++ b/pyglove/core/symbolic/diff.py @@ -15,8 +15,8 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import list as pg_list from pyglove.core.symbolic import object as pg_object @@ -112,8 +112,7 @@ def format( return 'No diff' # When there is no diff, but the same value needs to be displayed # we simply return the value. - return object_utils.format( - self.value, compact, verbose, root_indent, **kwargs) + return utils.format(self.value, compact, verbose, root_indent, **kwargs) if self.is_leaf: exclude_keys = kwargs.pop('exclude_keys', None) exclude_keys = exclude_keys or set() @@ -129,7 +128,7 @@ def format( verbose=verbose, root_indent=root_indent, cls_name='', - bracket_type=object_utils.BracketType.SQUARE, + bracket_type=utils.BracketType.SQUARE, **kwargs, ) if self.left is self.right: @@ -141,7 +140,7 @@ def format( verbose=verbose, root_indent=root_indent, cls_name=cls_name, - bracket_type=object_utils.BracketType.ROUND, + bracket_type=utils.BracketType.ROUND, **kwargs, ) @@ -155,7 +154,7 @@ def _html_tree_view_summary( max_summary_len_for_str: int = 80, **kwargs, ) -> Optional[tree_view.Html]: - # pytype: enable=annotation-type-mismatch + # pytype: enable=annotation-type-mismatch if not bool(self): v = self.value if (isinstance(v, (int, float, bool, type(None))) @@ -199,11 +198,11 @@ def _html_tree_view_content( *, view: tree_view.HtmlTreeView, parent: Any = None, - root_path: Optional[object_utils.KeyPath] = None, + root_path: Optional[utils.KeyPath] = None, css_classes: Optional[Sequence[str]] = None, - **kwargs + **kwargs, ) -> tree_view.Html: - root_path = root_path or object_utils.KeyPath() + root_path = root_path or utils.KeyPath() if not bool(self): if self.value == Diff.MISSING: root = tree_view.Html.element( @@ -353,7 +352,8 @@ def diff( right: Any, flatten: bool = False, collapse: Union[bool, str, Callable[[Any, Any], bool]] = 'same_type', - mode: str = 'diff') -> object_utils.Nestable[Diff]: + mode: str = 'diff', +) -> utils.Nestable[Diff]: """Inspect the symbolic diff between two objects. For example:: @@ -479,7 +479,7 @@ def _get_container_ops(container): assert isinstance(container, base.Symbolic) return container.sym_hasattr, container.sym_getattr, container.sym_items - def _diff(x, y) -> Tuple[object_utils.Nestable[Diff], bool]: + def _diff(x, y) -> Tuple[utils.Nestable[Diff], bool]: if x is y or x == y: return (Diff(x, y), False) if not _should_collapse(x, y): @@ -533,5 +533,5 @@ def _child(l, index): if not has_diff and mode == 'diff': diff_value = Diff() if flatten: - diff_value = object_utils.flatten(diff_value) + diff_value = utils.flatten(diff_value) return diff_value diff --git a/pyglove/core/symbolic/flags.py b/pyglove/core/symbolic/flags.py index d08ffe6..5645f29 100644 --- a/pyglove/core/symbolic/flags.py +++ b/pyglove/core/symbolic/flags.py @@ -14,7 +14,7 @@ """Global, thread-local and scoped flags for handling symbolic objects.""" from typing import Any, Callable, ContextManager, Optional -from pyglove.core.object_utils import thread_local +from pyglove.core.utils import thread_local # diff --git a/pyglove/core/symbolic/functor.py b/pyglove/core/symbolic/functor.py index 6fb8a1f..280c40c 100644 --- a/pyglove/core/symbolic/functor.py +++ b/pyglove/core/symbolic/functor.py @@ -22,14 +22,14 @@ import typing from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import flags from pyglove.core.symbolic import object as pg_object -class Functor(pg_object.Object, object_utils.Functor): +class Functor(pg_object.Object, utils.Functor): """Symbolic functions (Functors). A symbolic function is a symbolic class with a ``__call__`` method, whose @@ -124,7 +124,7 @@ def _update_signatures_based_on_schema(cls): if not hasattr(cls, '__orig_init__'): setattr(cls, '__orig_init__', cls.__init__) - @object_utils.explicit_method_override + @utils.explicit_method_override @functools.wraps(pseudo_init) def _init(self, *args, **kwargs): self.__class__.__orig_init__(self, *args, **kwargs) @@ -148,14 +148,15 @@ def __new__(cls, *args, **kwargs): return instance() return instance - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__( self, *args, - root_path: Optional[object_utils.KeyPath] = None, + root_path: Optional[utils.KeyPath] = None, override_args: bool = False, ignore_extra_args: bool = False, - **kwargs): + **kwargs, + ): """Constructor. Args: @@ -182,8 +183,8 @@ def __init__( varargs = list(args[len(signature.args) :]) args = args[: len(signature.args)] else: - arg_phrase = object_utils.auto_plural(len(signature.args), 'argument') - was_phrase = object_utils.auto_plural(len(args), 'was', 'were') + arg_phrase = utils.auto_plural(len(signature.args), 'argument') + was_phrase = utils.auto_plural(len(args), 'was', 'were') raise TypeError( f'{signature.id}() takes {len(signature.args)} ' f'positional {arg_phrase} but {len(args)} {was_phrase} given.' @@ -257,8 +258,7 @@ def _sym_clone(self, deep: bool, memo: Any = None) -> 'Functor': # pylint: enable=protected-access return typing.cast(Functor, other) - def _on_change( - self, field_updates: Dict[object_utils.KeyPath, base.FieldUpdate]): + def _on_change(self, field_updates: Dict[utils.KeyPath, base.FieldUpdate]): """Custom handling field change to update bound args.""" for relative_path, update in field_updates.items(): assert relative_path @@ -406,8 +406,8 @@ def _parse_call_time_overrides( if ignore_extra_args: args = args[: len(signature.args)] else: - arg_phrase = object_utils.auto_plural(len(signature.args), 'argument') - was_phrase = object_utils.auto_plural(len(args), 'was', 'were') + arg_phrase = utils.auto_plural(len(signature.args), 'argument') + was_phrase = utils.auto_plural(len(args), 'was', 'were') raise TypeError( f'{signature.id}() takes {len(signature.args)} ' f'positional {arg_phrase} but {len(args)} {was_phrase} given.' @@ -483,9 +483,10 @@ def _parse_call_time_overrides( missing_required_arg_names.append(arg.name) if missing_required_arg_names: - arg_phrase = object_utils.auto_plural( - len(missing_required_arg_names), 'argument') - args_str = object_utils.comma_delimited_str(missing_required_arg_names) + arg_phrase = utils.auto_plural( + len(missing_required_arg_names), 'argument' + ) + args_str = utils.comma_delimited_str(missing_required_arg_names) raise TypeError( f'{signature.id}() missing {len(missing_required_arg_names)} ' f'required positional {arg_phrase}: {args_str}.' diff --git a/pyglove/core/symbolic/functor_test.py b/pyglove/core/symbolic/functor_test.py index 9096a96..45d43e2 100644 --- a/pyglove/core/symbolic/functor_test.py +++ b/pyglove/core/symbolic/functor_test.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.Functor.""" - import inspect import io import typing import unittest -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import flags from pyglove.core.symbolic.base import from_json_str as pg_from_json_str from pyglove.core.symbolic.dict import Dict @@ -31,7 +29,7 @@ from pyglove.core.symbolic.object import Object -MISSING_VALUE = object_utils.MISSING_VALUE +MISSING_VALUE = utils.MISSING_VALUE class FunctorTest(unittest.TestCase): diff --git a/pyglove/core/symbolic/inferred.py b/pyglove/core/symbolic/inferred.py index eff3b59..13a7b2b 100644 --- a/pyglove/core/symbolic/inferred.py +++ b/pyglove/core/symbolic/inferred.py @@ -14,8 +14,8 @@ """Common inferential values.""" from typing import Any, Tuple -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic.object import Object @@ -65,7 +65,7 @@ def infer(self, **kwargs) -> Any: if v == pg_typing.MISSING_VALUE: if parent is None: raise AttributeError( - object_utils.message_on_path( + utils.message_on_path( ( f'`{self.inference_key}` is not found under its context ' '(along its symbolic parent chain).' diff --git a/pyglove/core/symbolic/list.py b/pyglove/core/symbolic/list.py index db9a9a6..7bfe3da 100644 --- a/pyglove/core/symbolic/list.py +++ b/pyglove/core/symbolic/list.py @@ -18,8 +18,8 @@ import numbers import typing from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Tuple, Union -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import flags @@ -74,13 +74,16 @@ def on_change(updates): """ @classmethod - def partial(cls, - items: Optional[Iterable[Any]] = None, - *, - value_spec: Optional[pg_typing.List] = None, - onchange_callback: Optional[Callable[ - [Dict[object_utils.KeyPath, base.FieldUpdate]], None]] = None, - **kwargs) -> 'List': + def partial( + cls, + items: Optional[Iterable[Any]] = None, + *, + value_spec: Optional[pg_typing.List] = None, + onchange_callback: Optional[ + Callable[[Dict[utils.KeyPath, base.FieldUpdate]], None] + ] = None, + **kwargs, + ) -> 'List': """Class method that creates a partial List object.""" return cls(items, value_spec=value_spec, @@ -89,13 +92,15 @@ def partial(cls, **kwargs) @classmethod - def from_json(cls, - json_value: Any, - *, - value_spec: Optional[pg_typing.List] = None, - allow_partial: bool = False, - root_path: Optional[object_utils.KeyPath] = None, - **kwargs) -> 'List': + def from_json( + cls, + json_value: Any, + *, + value_spec: Optional[pg_typing.List] = None, + allow_partial: bool = False, + root_path: Optional[utils.KeyPath] = None, + **kwargs, + ) -> 'List': """Class method that load an symbolic List from a JSON value. Example:: @@ -136,10 +141,11 @@ def from_json(cls, [ base.from_json( v, - root_path=object_utils.KeyPath(i, root_path), + root_path=utils.KeyPath(i, root_path), allow_partial=allow_partial, - **kwargs - ) for i, v in enumerate(json_value) + **kwargs, + ) + for i, v in enumerate(json_value) ], value_spec=value_spec, root_path=root_path, @@ -151,12 +157,14 @@ def __init__( items: Optional[Iterable[Any]] = None, *, value_spec: Optional[pg_typing.List] = None, - onchange_callback: Optional[Callable[ - [Dict[object_utils.KeyPath, base.FieldUpdate]], None]] = None, + onchange_callback: Optional[ + Callable[[Dict[utils.KeyPath, base.FieldUpdate]], None] + ] = None, allow_partial: bool = False, accessor_writable: bool = True, sealed: bool = False, - root_path: Optional[object_utils.KeyPath] = None): + root_path: Optional[utils.KeyPath] = None, + ): """Constructor. Args: @@ -337,8 +345,8 @@ def _sym_missing(self) -> Dict[Any, Any]: return missing def _sym_rebind( - self, path_value_pairs: typing.Dict[object_utils.KeyPath, Any] - ) -> typing.List[base.FieldUpdate]: + self, path_value_pairs: typing.Dict[utils.KeyPath, Any] + ) -> typing.List[base.FieldUpdate]: """Subclass specific rebind implementation.""" updates = [] @@ -378,14 +386,13 @@ def seal(self, sealed: bool = True) -> 'List': return self def _update_children_paths( - self, - old_path: object_utils.KeyPath, - new_path: object_utils.KeyPath) -> None: + self, old_path: utils.KeyPath, new_path: utils.KeyPath + ) -> None: """Update children paths according to root_path of current node.""" del old_path for idx, item in self.sym_items(): if isinstance(item, base.TopologyAware): - item.sym_setpath(object_utils.KeyPath(idx, new_path)) + item.sym_setpath(utils.KeyPath(idx, new_path)) def _set_item_without_permission_check( # pytype: disable=signature-mismatch # overriding-parameter-type-checks self, key: int, value: Any) -> Optional[base.FieldUpdate]: @@ -432,13 +439,15 @@ def _formalized_value(self, idx: int, value: Any): value = base.from_json( value, allow_partial=allow_partial, - root_path=object_utils.KeyPath(idx, self.sym_path)) + root_path=utils.KeyPath(idx, self.sym_path), + ) if self._value_spec and flags.is_type_check_enabled(): value = self._value_spec.element.apply( value, allow_partial=allow_partial, transform_fn=base.symbolic_transform_fn(self._allow_partial), - root_path=object_utils.KeyPath(idx, self.sym_path)) + root_path=utils.KeyPath(idx, self.sym_path), + ) return self._relocate_if_symbolic(idx, value) @property @@ -446,8 +455,7 @@ def _subscribes_field_updates(self) -> bool: """Returns True if current list subscribes field updates.""" return self._onchange_callback is not None - def _on_change(self, - field_updates: Dict[object_utils.KeyPath, base.FieldUpdate]): + def _on_change(self, field_updates: Dict[utils.KeyPath, base.FieldUpdate]): """On change event of List.""" # Do nothing for now to handle changes of List. @@ -463,7 +471,7 @@ def _on_change(self, # Update paths for children. for idx, item in self.sym_items(): if isinstance(item, base.TopologyAware) and item.sym_path.key != idx: - item.sym_setpath(object_utils.KeyPath(idx, self.sym_path)) + item.sym_setpath(utils.KeyPath(idx, self.sym_path)) if self._onchange_callback is not None: self._onchange_callback(field_updates) @@ -723,11 +731,12 @@ def reverse(self) -> None: def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool, child_transform: Optional[ - Callable[[object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, ) -> Tuple[bool, 'List']: """Implement pg.typing.CustomTyping interface. @@ -746,9 +755,12 @@ def custom_apply( if self._value_spec: if value_spec and not value_spec.is_compatible(self._value_spec): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'List (spec={self._value_spec!r}) cannot be assigned to an ' - f'incompatible field (spec={value_spec!r}).', path)) + f'incompatible field (spec={value_spec!r}).', + path, + ) + ) if self._allow_partial == allow_partial: proceed_with_standard_apply = False else: @@ -758,9 +770,8 @@ def custom_apply( return (proceed_with_standard_apply, self) def sym_jsonify( - self, - use_inferred: bool = False, - **kwargs) -> object_utils.JSONValueType: + self, use_inferred: bool = False, **kwargs + ) -> utils.JSONValueType: """Converts current list to a list of plain Python objects.""" def json_item(idx): v = self.sym_getattr(idx) @@ -778,7 +789,7 @@ def format( python_format: bool = False, use_inferred: bool = False, cls_name: Optional[str] = None, - bracket_type: object_utils.BracketType = object_utils.BracketType.SQUARE, + bracket_type: utils.BracketType = utils.BracketType.SQUARE, **kwargs, ) -> str: """Formats this List.""" @@ -787,16 +798,22 @@ def _indent(text, indent): return ' ' * 2 * indent + text cls_name = cls_name or '' - open_bracket, close_bracket = object_utils.bracket_chars(bracket_type) + open_bracket, close_bracket = utils.bracket_chars(bracket_type) s = [f'{cls_name}{open_bracket}'] if compact: kv_strs = [] for idx, elem in self.sym_items(): if use_inferred and isinstance(elem, base.Inferential): elem = self.sym_inferred(idx, default=elem) - v_str = object_utils.format( - elem, compact, verbose, root_indent + 1, - python_format=python_format, use_inferred=use_inferred, **kwargs) + v_str = utils.format( + elem, + compact, + verbose, + root_indent + 1, + python_format=python_format, + use_inferred=use_inferred, + **kwargs, + ) if python_format: kv_strs.append(v_str) else: @@ -812,9 +829,15 @@ def _indent(text, indent): s.append('\n') else: s.append(',\n') - v_str = object_utils.format( - elem, compact, verbose, root_indent + 1, - python_format=python_format, use_inferred=use_inferred, **kwargs) + v_str = utils.format( + elem, + compact, + verbose, + root_indent + 1, + python_format=python_format, + use_inferred=use_inferred, + **kwargs, + ) if python_format: s.append(_indent(v_str, root_indent + 1)) else: diff --git a/pyglove/core/symbolic/list_test.py b/pyglove/core/symbolic/list_test.py index 1e3112a..f44fda7 100644 --- a/pyglove/core/symbolic/list_test.py +++ b/pyglove/core/symbolic/list_test.py @@ -20,8 +20,8 @@ from typing import Any import unittest -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import flags from pyglove.core.symbolic import inferred @@ -34,7 +34,7 @@ from pyglove.core.symbolic.pure_symbolic import PureSymbolic -MISSING_VALUE = object_utils.MISSING_VALUE +MISSING_VALUE = utils.MISSING_VALUE class ListTest(unittest.TestCase): @@ -685,7 +685,7 @@ def test_sym_has(self): self.assertTrue(sl.sym_has('[0].x')) self.assertTrue(sl.sym_has('[0].x[0]')) self.assertTrue(sl.sym_has('[0].x[0].y')) - self.assertTrue(sl.sym_has(object_utils.KeyPath.parse('[0].x[0].y'))) + self.assertTrue(sl.sym_has(utils.KeyPath.parse('[0].x[0].y'))) def test_sym_get(self): sl = List([dict(x=[dict(y=1)])]) @@ -1100,7 +1100,7 @@ def test_sym_path(self): self.assertEqual(sl[1].sym_path, '[1]') self.assertEqual(sl[1][0].b.sym_path, '[1][0].b') - sl.sym_setpath(object_utils.KeyPath('a')) + sl.sym_setpath(utils.KeyPath('a')) self.assertEqual(sl.sym_path, 'a') self.assertEqual(sl[0].sym_path, 'a[0]') self.assertEqual(sl[0].a.sym_path, 'a[0].a') @@ -1412,96 +1412,112 @@ def on_dict_change(field_updates): '[3]': 'foo', # Unchanged. '[4]': Insertion('bar') }) - self.assertEqual(updates, [ - { # Notification to `sl[2][0]`. - 'p': base.FieldUpdate( - object_utils.KeyPath.parse('[2][0].p'), - target=sl[2][0], - field=None, - old_value=1, - new_value=MISSING_VALUE), - 'q': base.FieldUpdate( - object_utils.KeyPath.parse('[2][0].q'), - target=sl[2][0], - field=None, - old_value=MISSING_VALUE, - new_value=2), - }, - { # Notification to `sl.c`. - '[0].p': base.FieldUpdate( - object_utils.KeyPath.parse('[2][0].p'), - target=sl[2][0], - field=None, - old_value=1, - new_value=MISSING_VALUE), - '[0].q': base.FieldUpdate( - object_utils.KeyPath.parse('[2][0].q'), - target=sl[2][0], - field=None, - old_value=MISSING_VALUE, - new_value=2), - }, - { # Notification to `sl[1].y`. - 'z': base.FieldUpdate( - object_utils.KeyPath.parse('[1].y.z'), - target=sl[1].y, - field=None, - old_value=MISSING_VALUE, - new_value=1), - }, - { # Notification to `sl.b`. - 'x': base.FieldUpdate( - object_utils.KeyPath.parse('[1].x'), - target=sl[1], - field=None, - old_value=1, - new_value=2), - 'y.z': base.FieldUpdate( - object_utils.KeyPath.parse('[1].y.z'), - target=sl[1].y, - field=None, - old_value=MISSING_VALUE, - new_value=1), - }, - { # Notification to `sl`. - '[0]': base.FieldUpdate( - object_utils.KeyPath.parse('[0]'), - target=sl, - field=None, - old_value=1, - new_value=2), - '[1].x': base.FieldUpdate( - object_utils.KeyPath.parse('[1].x'), - target=sl[1], - field=None, - old_value=1, - new_value=2), - '[1].y.z': base.FieldUpdate( - object_utils.KeyPath.parse('[1].y.z'), - target=sl[1].y, - field=None, - old_value=MISSING_VALUE, - new_value=1), - '[2][0].p': base.FieldUpdate( - object_utils.KeyPath.parse('[2][0].p'), - target=sl[2][0], - field=None, - old_value=1, - new_value=MISSING_VALUE), - '[2][0].q': base.FieldUpdate( - object_utils.KeyPath.parse('[2][0].q'), - target=sl[2][0], - field=None, - old_value=MISSING_VALUE, - new_value=2), - '[4]': base.FieldUpdate( - object_utils.KeyPath.parse('[4]'), - target=sl, - field=None, - old_value=MISSING_VALUE, - new_value='bar') - } - ]) + self.assertEqual( + updates, + [ + { # Notification to `sl[2][0]`. + 'p': base.FieldUpdate( + utils.KeyPath.parse('[2][0].p'), + target=sl[2][0], + field=None, + old_value=1, + new_value=MISSING_VALUE, + ), + 'q': base.FieldUpdate( + utils.KeyPath.parse('[2][0].q'), + target=sl[2][0], + field=None, + old_value=MISSING_VALUE, + new_value=2, + ), + }, + { # Notification to `sl.c`. + '[0].p': base.FieldUpdate( + utils.KeyPath.parse('[2][0].p'), + target=sl[2][0], + field=None, + old_value=1, + new_value=MISSING_VALUE, + ), + '[0].q': base.FieldUpdate( + utils.KeyPath.parse('[2][0].q'), + target=sl[2][0], + field=None, + old_value=MISSING_VALUE, + new_value=2, + ), + }, + { # Notification to `sl[1].y`. + 'z': base.FieldUpdate( + utils.KeyPath.parse('[1].y.z'), + target=sl[1].y, + field=None, + old_value=MISSING_VALUE, + new_value=1, + ), + }, + { # Notification to `sl.b`. + 'x': base.FieldUpdate( + utils.KeyPath.parse('[1].x'), + target=sl[1], + field=None, + old_value=1, + new_value=2, + ), + 'y.z': base.FieldUpdate( + utils.KeyPath.parse('[1].y.z'), + target=sl[1].y, + field=None, + old_value=MISSING_VALUE, + new_value=1, + ), + }, + { # Notification to `sl`. + '[0]': base.FieldUpdate( + utils.KeyPath.parse('[0]'), + target=sl, + field=None, + old_value=1, + new_value=2, + ), + '[1].x': base.FieldUpdate( + utils.KeyPath.parse('[1].x'), + target=sl[1], + field=None, + old_value=1, + new_value=2, + ), + '[1].y.z': base.FieldUpdate( + utils.KeyPath.parse('[1].y.z'), + target=sl[1].y, + field=None, + old_value=MISSING_VALUE, + new_value=1, + ), + '[2][0].p': base.FieldUpdate( + utils.KeyPath.parse('[2][0].p'), + target=sl[2][0], + field=None, + old_value=1, + new_value=MISSING_VALUE, + ), + '[2][0].q': base.FieldUpdate( + utils.KeyPath.parse('[2][0].q'), + target=sl[2][0], + field=None, + old_value=MISSING_VALUE, + new_value=2, + ), + '[4]': base.FieldUpdate( + utils.KeyPath.parse('[4]'), + target=sl, + field=None, + old_value=MISSING_VALUE, + new_value='bar', + ), + }, + ], + ) def test_rebind_with_fn(self): sl = List([0, dict(x=1, y='foo', z=[2, 3, 4])]) @@ -1716,7 +1732,7 @@ def test_compact(self): def test_compact_python_format(self): self.assertEqual( - object_utils.format( + utils.format( self._list, compact=True, python_format=True, markdown=True ), "`[{'a1': 1, 'a2': {'b1': {'c1': [{'d1': MISSING_VALUE, " @@ -1726,9 +1742,12 @@ def test_compact_python_format(self): def test_noncompact_python_format(self): self.assertEqual( - object_utils.format( - self._list, compact=False, verbose=False, - python_format=True, markdown=True + utils.format( + self._list, + compact=False, + verbose=False, + python_format=True, + markdown=True, ), inspect.cleandoc(""" ``` diff --git a/pyglove/core/symbolic/object.py b/pyglove/core/symbolic/object.py index 949d006..4bad3f0 100644 --- a/pyglove/core/symbolic/object.py +++ b/pyglove/core/symbolic/object.py @@ -20,8 +20,8 @@ from typing import Any, Dict, Iterator, List, Optional, Sequence, Union from pyglove.core import coding -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import dict as pg_dict from pyglove.core.symbolic import flags @@ -116,7 +116,7 @@ def register_for_deserialization( # Register class with 'type' property. for key in serialization_keys: - object_utils.JSONConvertible.register( + utils.JSONConvertible.register( key, cls, flags.is_repeated_class_registration_allowed() ) @@ -309,7 +309,7 @@ def __init_subclass__(cls): Args: user_cls: The source class that calls this class method. """ - object_utils.ensure_explicit_method_override( + utils.ensure_explicit_method_override( cls.__init__, ( '`pg.Object.__init__` is a PyGlove managed method. For setting up ' @@ -317,7 +317,8 @@ def __init_subclass__(cls): '`_on_init()`. If you do have a need to override `__init__` and ' 'know the implications, please decorate your overridden method ' 'with `@pg.explicit_method_override`.' - )) + ), + ) # Set `__serialization_key__` before JSONConvertible.__init_subclass__ # is called. @@ -363,11 +364,11 @@ def _normalize_schema(cls, schema: pg_typing.Schema) -> pg_typing.Schema: """Normalizes the schema before applying it.""" schema.set_name(cls.__type_name__) - docstr = object_utils.docstr(cls) + docstr = utils.docstr(cls) if docstr: schema.set_description(docstr.description) - def _formalize_field(path: object_utils.KeyPath, node: Any) -> bool: + def _formalize_field(path: utils.KeyPath, node: Any) -> bool: """Formalize field.""" if isinstance(node, pg_typing.Field): field = node @@ -385,27 +386,29 @@ def _formalize_field(path: object_utils.KeyPath, node: Any) -> bool: if isinstance(field.value, pg_typing.Dict): if field.value.schema is not None: field.value.schema.set_name(f'{schema.name}.{path.path}') - object_utils.traverse(field.value.schema.fields, _formalize_field, - None, path) + utils.traverse( + field.value.schema.fields, _formalize_field, None, path + ) elif isinstance(field.value, pg_typing.List): - _formalize_field(object_utils.KeyPath(0, path), field.value.element) + _formalize_field(utils.KeyPath(0, path), field.value.element) elif isinstance(field.value, pg_typing.Tuple): for i, elem in enumerate(field.value.elements): - _formalize_field(object_utils.KeyPath(i, path), elem) + _formalize_field(utils.KeyPath(i, path), elem) elif isinstance(field.value, pg_typing.Union): for i, c in enumerate(field.value.candidates): _formalize_field( - object_utils.KeyPath(i, path), - pg_typing.Field(field.key, c, 'Union sub-type.')) + utils.KeyPath(i, path), + pg_typing.Field(field.key, c, 'Union sub-type.'), + ) return True - object_utils.traverse(schema.fields, _formalize_field) + utils.traverse(schema.fields, _formalize_field) return schema @classmethod def _finalize_init_arg_list(cls) -> List[str]: """Finalizes init_arg_list based on schema.""" - # Update `init_arg_list`` based on the updated schema. + # Update `init_arg_list`` based on the updated schema. init_arg_list = cls.__schema__.metadata.get('init_arg_list', None) if init_arg_list is None: # Inherit from the first non-empty base if they have the same signature. @@ -476,7 +479,7 @@ def _update_signatures_based_on_schema(cls): # Create a new `__init__` that passes through all the arguments to # in `pg.Object.__init__`. This is needed for each class to use different # signature. - @object_utils.explicit_method_override + @utils.explicit_method_override @functools.wraps(pseudo_init) def _init(self, *args, **kwargs): # We pass through the arguments to `Object.__init__` instead of @@ -539,8 +542,8 @@ def from_json( json_value: Any, *, allow_partial: bool = False, - root_path: Optional[object_utils.KeyPath] = None, - **kwargs + root_path: Optional[utils.KeyPath] = None, + **kwargs, ) -> 'Object': """Class method that load an symbolic Object from a JSON value. @@ -588,15 +591,16 @@ class Foo(pg.Object): for k, v in json_value.items() }) - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__( self, *args, allow_partial: bool = False, sealed: Optional[bool] = None, - root_path: Optional[object_utils.KeyPath] = None, + root_path: Optional[utils.KeyPath] = None, explicit_init: bool = False, - **kwargs): + **kwargs, + ): """Create an Object instance. Args: @@ -638,8 +642,8 @@ def __init__( # Fill field_args and init_args from **kwargs. _, unmatched_keys = self.__class__.__schema__.resolve(list(kwargs.keys())) if unmatched_keys: - arg_phrase = object_utils.auto_plural(len(unmatched_keys), 'argument') - keys_str = object_utils.comma_delimited_str(unmatched_keys) + arg_phrase = utils.auto_plural(len(unmatched_keys), 'argument') + keys_str = utils.comma_delimited_str(unmatched_keys) raise TypeError( f'{self.__class__.__name__}.__init__() got unexpected ' f'keyword {arg_phrase}: {keys_str}') @@ -659,8 +663,8 @@ def __init__( field_args[vararg_name] = list(args[num_named_args:]) args = args[:num_named_args] elif len(args) > len(init_arg_names): - arg_phrase = object_utils.auto_plural(len(init_arg_names), 'argument') - was_phrase = object_utils.auto_plural(len(args), 'was', 'were') + arg_phrase = utils.auto_plural(len(init_arg_names), 'argument') + was_phrase = utils.auto_plural(len(args), 'was', 'were') raise TypeError( f'{self.__class__.__name__}.__init__() takes ' f'{len(init_arg_names)} positional {arg_phrase} but {len(args)} ' @@ -672,7 +676,7 @@ def __init__( for k, v in kwargs.items(): if k in field_args: - values_str = object_utils.comma_delimited_str([field_args[k], v]) + values_str = utils.comma_delimited_str([field_args[k], v]) raise TypeError( f'{self.__class__.__name__}.__init__() got multiple values for ' f'argument \'{k}\': {values_str}.') @@ -687,8 +691,8 @@ def __init__( and field.key not in field_args): missing_args.append(str(field.key)) if missing_args: - arg_phrase = object_utils.auto_plural(len(missing_args), 'argument') - keys_str = object_utils.comma_delimited_str(missing_args) + arg_phrase = utils.auto_plural(len(missing_args), 'argument') + keys_str = utils.comma_delimited_str(missing_args) raise TypeError( f'{self.__class__.__name__}.__init__() missing {len(missing_args)} ' f'required {arg_phrase}: {keys_str}.') @@ -738,8 +742,7 @@ def _on_bound(self) -> None: and during __init__. """ - def _on_change(self, - field_updates: Dict[object_utils.KeyPath, base.FieldUpdate]): + def _on_change(self, field_updates: Dict[utils.KeyPath, base.FieldUpdate]): """Event that is triggered when field values in the subtree are updated. This event will be called @@ -759,8 +762,7 @@ def _on_change(self, del field_updates return self._on_bound() - def _on_path_change( - self, old_path: object_utils.KeyPath, new_path: object_utils.KeyPath): + def _on_path_change(self, old_path: utils.KeyPath, new_path: utils.KeyPath): """Event that is triggered after the symbolic path changes.""" del old_path, new_path @@ -839,8 +841,8 @@ def _sym_getattr( # pytype: disable=signature-mismatch # overriding-parameter- return self._sym_attributes.sym_getattr(key) def _sym_rebind( - self, path_value_pairs: Dict[object_utils.KeyPath, Any] - ) -> List[base.FieldUpdate]: + self, path_value_pairs: Dict[utils.KeyPath, Any] + ) -> List[base.FieldUpdate]: """Rebind current object using object-form members.""" if base.treats_as_sealed(self): raise base.WritePermissionError( @@ -879,9 +881,8 @@ def seal(self, sealed: bool = True) -> 'Object': return self def _update_children_paths( - self, - old_path: object_utils.KeyPath, - new_path: object_utils.KeyPath) -> None: + self, old_path: utils.KeyPath, new_path: utils.KeyPath + ) -> None: """Update children paths according to root_path of current node.""" self._sym_attributes.sym_setpath(new_path) self._on_path_change(old_path, new_path) @@ -965,10 +966,10 @@ def __hash__(self) -> int: return self.sym_hash() return super().__hash__() - def sym_jsonify(self, **kwargs) -> object_utils.JSONValueType: + def sym_jsonify(self, **kwargs) -> utils.JSONValueType: """Converts current object to a dict of plain Python objects.""" json_dict = { - object_utils.JSONConvertible.TYPE_NAME_KEY: ( + utils.JSONConvertible.TYPE_NAME_KEY: ( self.__class__.__serialization_key__ ) } @@ -987,8 +988,9 @@ def format(self, root_indent, cls_name=self.__class__.__name__, key_as_attribute=True, - bracket_type=object_utils.BracketType.ROUND, - **kwargs) + bracket_type=utils.BracketType.ROUND, + **kwargs, + ) base.Symbolic.ObjectType = Object diff --git a/pyglove/core/symbolic/object_test.py b/pyglove/core/symbolic/object_test.py index c85a4ee..066f943 100644 --- a/pyglove/core/symbolic/object_test.py +++ b/pyglove/core/symbolic/object_test.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.Object.""" - import copy import inspect import io @@ -23,8 +21,8 @@ from typing import Any import unittest -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import flags from pyglove.core.symbolic import inferred @@ -42,7 +40,7 @@ from pyglove.core.views.html import tree_view # pylint: disable=unused-import -MISSING_VALUE = object_utils.MISSING_VALUE +MISSING_VALUE = utils.MISSING_VALUE class ObjectMetaTest(unittest.TestCase): @@ -205,7 +203,7 @@ def test_override_init(self): ]) class A(Object): - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__(self, x): super().__init__(int(x)) @@ -214,7 +212,7 @@ def __init__(self, x): class B(A): - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__(self, x): # pylint: disable=super-init-not-called # Forgot to call super().__init__ will trigger error. self.x = x @@ -802,8 +800,8 @@ def _on_bound(self): a = A(A(dict(y=A(1)))) self.assertTrue(a.sym_has('x')) self.assertTrue(a.sym_has('x.x')) - self.assertTrue(a.sym_has(object_utils.KeyPath.parse('x.x.y'))) - self.assertTrue(a.sym_has(object_utils.KeyPath.parse('x.x.y.x'))) + self.assertTrue(a.sym_has(utils.KeyPath.parse('x.x.y'))) + self.assertTrue(a.sym_has(utils.KeyPath.parse('x.x.y.x'))) self.assertFalse(a.sym_has('y')) # `y` is not a symbolic field. def test_sym_get(self): @@ -828,10 +826,10 @@ def _on_bound(self): self.assertIs(a.sym_get('x'), a.x) self.assertIs(a.sym_get('p'), a.sym_getattr('p')) self.assertIs(a.sym_get('x.x'), a.x.x) - self.assertIs(a.sym_get(object_utils.KeyPath.parse('x.x.y')), a.x.x.y) - self.assertIs(a.sym_get(object_utils.KeyPath.parse('x.x.y.x')), a.x.x.y.x) + self.assertIs(a.sym_get(utils.KeyPath.parse('x.x.y')), a.x.x.y) + self.assertIs(a.sym_get(utils.KeyPath.parse('x.x.y.x')), a.x.x.y.x) self.assertIs( - a.sym_get(object_utils.KeyPath.parse('x.x.y.p')), + a.sym_get(utils.KeyPath.parse('x.x.y.p')), a.x.x.y.sym_getattr('p'), ) self.assertIsNone(a.sym_get('x.x.y.q', use_inferred=True)) @@ -1595,7 +1593,7 @@ class A(Object): self.assertEqual(a.x.x.x.sym_path, 'x.x.x') self.assertEqual(a.x.x.x[0].sym_path, 'x.x.x[0]') - a.sym_setpath(object_utils.KeyPath('a')) + a.sym_setpath(utils.KeyPath('a')) self.assertEqual(a.sym_path, 'a') self.assertEqual(a.x.sym_path, 'a.x') self.assertEqual(a.x.x.sym_path, 'a.x.x') @@ -2075,7 +2073,7 @@ class B(Object): class C(B): """Custom __init__.""" - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__(self, a, b): super().__init__(b, x=a) @@ -2450,44 +2448,51 @@ def _onchange_child(field_updates): [ # Set default value from outer space (parent List) for field d1. { - 'd1': - base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d1'), - target=sd.a2.b1.c1[0], - field=sd.a2.b1.c1[0].value_spec.schema['d1'], - old_value=MISSING_VALUE, - new_value='foo') + 'd1': base.FieldUpdate( + path=utils.KeyPath.parse('a2.b1.c1[0].d1'), + target=sd.a2.b1.c1[0], + field=sd.a2.b1.c1[0].value_spec.schema['d1'], + old_value=MISSING_VALUE, + new_value='foo', + ) }, # Set default value from outer space (parent List) for field d2. { - 'd2': - base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d2'), - target=sd.a2.b1.c1[0], - field=sd.a2.b1.c1[0].value_spec.schema['d2'], - old_value=MISSING_VALUE, - new_value=True) - } - ]) + 'd2': base.FieldUpdate( + path=utils.KeyPath.parse('a2.b1.c1[0].d2'), + target=sd.a2.b1.c1[0], + field=sd.a2.b1.c1[0].value_spec.schema['d2'], + old_value=MISSING_VALUE, + new_value=True, + ) + }, + ], + ) # list get updated after bind with parent structures. - self.assertEqual(list_updates, [{ - '[0].d1': - base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d1'), - target=sd.a2.b1.c1[0], - field=sd.a2.b1.c1[0].value_spec.schema['d1'], - old_value=MISSING_VALUE, - new_value='foo') - }, { - '[0].d2': - base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d2'), - target=sd.a2.b1.c1[0], - field=sd.a2.b1.c1[0].value_spec.schema['d2'], - old_value=MISSING_VALUE, - new_value=True) - }]) + self.assertEqual( + list_updates, + [ + { + '[0].d1': base.FieldUpdate( + path=utils.KeyPath.parse('a2.b1.c1[0].d1'), + target=sd.a2.b1.c1[0], + field=sd.a2.b1.c1[0].value_spec.schema['d1'], + old_value=MISSING_VALUE, + new_value='foo', + ) + }, + { + '[0].d2': base.FieldUpdate( + path=utils.KeyPath.parse('a2.b1.c1[0].d2'), + target=sd.a2.b1.c1[0], + field=sd.a2.b1.c1[0].value_spec.schema['d2'], + old_value=MISSING_VALUE, + new_value=True, + ) + }, + ], + ) # There are no updates in root. self.assertEqual(root_updates, []) @@ -2510,28 +2515,28 @@ def _onchange_child(field_updates): root_updates[0], { 'a1': base.FieldUpdate( - path=object_utils.KeyPath.parse('a1'), + path=utils.KeyPath.parse('a1'), target=sd, field=sd.value_spec.schema['a1'], old_value=MISSING_VALUE, new_value=1, ), 'a2.b1.c1[0].d1': base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d1'), + path=utils.KeyPath.parse('a2.b1.c1[0].d1'), target=sd.a2.b1.c1[0], field=sd.a2.b1.c1[0].value_spec.schema['d1'], old_value='foo', new_value='bar', ), 'a2.b1.c1[0].d2': base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d2'), + path=utils.KeyPath.parse('a2.b1.c1[0].d2'), target=sd.a2.b1.c1[0], field=sd.a2.b1.c1[0].value_spec.schema['d2'], old_value=True, new_value=False, ), 'a2.b1.c1[0].d3.z': base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d3.z'), + path=utils.KeyPath.parse('a2.b1.c1[0].d3.z'), target=sd.a2.b1.c1[0].d3, field=sd.a2.b1.c1[0].d3.__class__.__schema__['z'], old_value=MISSING_VALUE, @@ -2547,21 +2552,21 @@ def _onchange_child(field_updates): # Root object rebind. { '[0].d1': base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d1'), + path=utils.KeyPath.parse('a2.b1.c1[0].d1'), target=sd.a2.b1.c1[0], field=sd.a2.b1.c1[0].value_spec.schema['d1'], old_value='foo', new_value='bar', ), '[0].d2': base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d2'), + path=utils.KeyPath.parse('a2.b1.c1[0].d2'), target=sd.a2.b1.c1[0], field=sd.a2.b1.c1[0].value_spec.schema['d2'], old_value=True, new_value=False, ), '[0].d3.z': base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d3.z'), + path=utils.KeyPath.parse('a2.b1.c1[0].d3.z'), target=sd.a2.b1.c1[0].d3, field=sd.a2.b1.c1[0].d3.__class__.__schema__['z'], old_value=MISSING_VALUE, @@ -2577,29 +2582,30 @@ def _onchange_child(field_updates): [ # Root object rebind. { - 'd1': - base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d1'), - target=sd.a2.b1.c1[0], - field=sd.a2.b1.c1[0].value_spec.schema['d1'], - old_value='foo', - new_value='bar'), - 'd2': - base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d2'), - target=sd.a2.b1.c1[0], - field=sd.a2.b1.c1[0].value_spec.schema['d2'], - old_value=True, - new_value=False), - 'd3.z': - base.FieldUpdate( - path=object_utils.KeyPath.parse('a2.b1.c1[0].d3.z'), - target=sd.a2.b1.c1[0].d3, - field=sd.a2.b1.c1[0].d3.__class__.__schema__['z'], - old_value=MISSING_VALUE, - new_value='foo') + 'd1': base.FieldUpdate( + path=utils.KeyPath.parse('a2.b1.c1[0].d1'), + target=sd.a2.b1.c1[0], + field=sd.a2.b1.c1[0].value_spec.schema['d1'], + old_value='foo', + new_value='bar', + ), + 'd2': base.FieldUpdate( + path=utils.KeyPath.parse('a2.b1.c1[0].d2'), + target=sd.a2.b1.c1[0], + field=sd.a2.b1.c1[0].value_spec.schema['d2'], + old_value=True, + new_value=False, + ), + 'd3.z': base.FieldUpdate( + path=utils.KeyPath.parse('a2.b1.c1[0].d3.z'), + target=sd.a2.b1.c1[0].d3, + field=sd.a2.b1.c1[0].d3.__class__.__schema__['z'], + old_value=MISSING_VALUE, + new_value='foo', + ), } - ]) + ], + ) def test_on_change_notification_order(self): change_order = [] @@ -2645,7 +2651,7 @@ def _on_parent_change(self, old_parent, new_parent): y.x = A() self.assertIs(x.old_parent, y) self.assertIsNone(x.new_parent) - self.assertEqual(x.sym_path, object_utils.KeyPath()) + self.assertEqual(x.sym_path, utils.KeyPath()) def test_on_path_change(self): @@ -2656,8 +2662,8 @@ def _on_path_change(self, old_path, new_path): self.new_path = new_path x = A() - x.sym_setpath(object_utils.KeyPath('a')) - self.assertEqual(x.old_path, object_utils.KeyPath()) + x.sym_setpath(utils.KeyPath('a')) + self.assertEqual(x.old_path, utils.KeyPath()) self.assertEqual(x.new_path, 'a') y = Dict(x=x) @@ -3059,7 +3065,7 @@ def test_standard_serialization(self): def test_serialization_with_json_convertible(self): - class Y(object_utils.JSONConvertible): + class Y(utils.JSONConvertible): TYPE_NAME = 'Y' @@ -3079,7 +3085,7 @@ def to_json(self, *args, **kwargs): def from_json(cls, json_dict, *args, **kwargs): return cls(json_dict.pop('value')) - object_utils.JSONConvertible.register(Y.TYPE_NAME, Y) + utils.JSONConvertible.register(Y.TYPE_NAME, Y) a = self._A(Y(1), y=True) self.assertEqual(base.from_json_str(a.to_json_str()), a) @@ -3286,9 +3292,7 @@ def test_compact(self): def test_compact_python_format(self): self.assertEqual( - object_utils.format( - self._a, compact=True, python_format=True, markdown=True - ), + utils.format(self._a, compact=True, python_format=True, markdown=True), "`A(x=[A(x=1, y=None), A(x='foo', y={'a': A(x=True, y=1.0)})], " 'y=MISSING_VALUE)`', ) @@ -3333,9 +3337,12 @@ class A(Object): def test_noncompact_python_format(self): self.assertEqual( - object_utils.format( - self._a, compact=False, verbose=False, python_format=True, - markdown=True + utils.format( + self._a, + compact=False, + verbose=False, + python_format=True, + markdown=True, ), inspect.cleandoc(""" ``` @@ -3502,10 +3509,10 @@ def fn(v, root_indent): return f() if f is not None else None return fn - with object_utils.str_format(custom_format=_method('_repr_xml_')): + with utils.str_format(custom_format=_method('_repr_xml_')): self.assertEqual(str(Bar(Foo())), 'Bar(\n foo = Foo()\n)') - with object_utils.str_format(custom_format=_method('_repr_html_')): + with utils.str_format(custom_format=_method('_repr_html_')): self.assertIn('', str(Bar(Foo()))) diff --git a/pyglove/core/symbolic/origin.py b/pyglove/core/symbolic/origin.py index 654e698..afbc3dd 100644 --- a/pyglove/core/symbolic/origin.py +++ b/pyglove/core/symbolic/origin.py @@ -16,11 +16,11 @@ import traceback from typing import Any, Callable, List, Optional -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.symbolic import flags -class Origin(object_utils.Formattable): +class Origin(utils.Formattable): """Class that represents the origin of a symbolic value. Origin is used for debugging the creation chain of a symbolic value, as @@ -158,14 +158,12 @@ def format( if isinstance(self._source, (str, type(None))): source_str = self._source else: - source_info = object_utils.format( + source_info = utils.format( self._source, compact, verbose, root_indent + 1, **kwargs ) - source_str = object_utils.RawText( - f'{source_info} at 0x{id(self._source):8x}' - ) + source_str = utils.RawText(f'{source_info} at 0x{id(self._source):8x}') - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('tag', self._tag, None), ('source', source_str, None), diff --git a/pyglove/core/symbolic/pure_symbolic.py b/pyglove/core/symbolic/pure_symbolic.py index 659e30d..2fd0ace 100644 --- a/pyglove/core/symbolic/pure_symbolic.py +++ b/pyglove/core/symbolic/pure_symbolic.py @@ -14,8 +14,8 @@ """Interfaces for pure symbolic objects.""" from typing import Any, Callable, Optional, Tuple -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils class PureSymbolic(pg_typing.CustomTyping): @@ -37,11 +37,12 @@ class PureSymbolic(pg_typing.CustomTyping): def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool, child_transform: Optional[ - Callable[[object_utils.KeyPath, pg_typing.Field, Any], Any]] = None + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, ) -> Tuple[bool, Any]: """Custom apply on a value based on its original value spec. diff --git a/pyglove/core/symbolic/ref.py b/pyglove/core/symbolic/ref.py index 4e6f0f6..f343900 100644 --- a/pyglove/core/symbolic/ref.py +++ b/pyglove/core/symbolic/ref.py @@ -16,8 +16,8 @@ import functools import typing from typing import Any, Callable, List, Optional, Tuple, Type -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.symbolic import base from pyglove.core.symbolic import object as pg_object from pyglove.core.views.html import tree_view @@ -100,7 +100,7 @@ def __new__(cls, value: Any, **kwargs): return object.__new__(cls) return value - @object_utils.explicit_method_override + @utils.explicit_method_override def __init__(self, value: Any, **kwargs) -> None: super().__init__(**kwargs) if isinstance(value, Ref): @@ -127,12 +127,13 @@ def infer(self, **kwargs) -> Any: def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: pg_typing.ValueSpec, allow_partial: bool = False, - child_transform: Optional[Callable[ - [object_utils.KeyPath, pg_typing.Field, Any], Any]] = None - ) -> Tuple[bool, Any]: + child_transform: Optional[ + Callable[[utils.KeyPath, pg_typing.Field, Any], Any] + ] = None, + ) -> Tuple[bool, Any]: """Validate candidates during value_spec binding time.""" del child_transform # Check if the field being assigned could accept the referenced value. @@ -166,9 +167,12 @@ def format( root_indent: int = 0, **kwargs: Any, ) -> str: - value_str = object_utils.format( + value_str = utils.format( self._value, - compact=compact, verbose=verbose, root_indent=root_indent + 1) + compact=compact, + verbose=verbose, + root_indent=root_indent + 1, + ) if compact: return f'{self.__class__.__name__}({value_str})' else: diff --git a/pyglove/core/tuning/local_backend.py b/pyglove/core/tuning/local_backend.py index 7a39046..9ea9c18 100644 --- a/pyglove/core/tuning/local_backend.py +++ b/pyglove/core/tuning/local_backend.py @@ -21,8 +21,8 @@ from pyglove.core import geno from pyglove.core import logging -from pyglove.core import object_utils from pyglove.core import symbolic +from pyglove.core import utils from pyglove.core.tuning import backend from pyglove.core.tuning.early_stopping import EarlyStoppingPolicy from pyglove.core.tuning.protocols import Feedback @@ -278,7 +278,7 @@ def format(self, ('step', self._best_trial.final_measurement.step), ('dna', self._best_trial.dna.format(compact=True)) ]) - return object_utils.format(json_repr, compact, False, root_indent, **kwargs) + return utils.format(json_repr, compact, False, root_indent, **kwargs) @backend.add_backend('in-memory') diff --git a/pyglove/core/tuning/protocols.py b/pyglove/core/tuning/protocols.py index 31f3f61..3e725fa 100644 --- a/pyglove/core/tuning/protocols.py +++ b/pyglove/core/tuning/protocols.py @@ -22,9 +22,9 @@ from pyglove.core import geno from pyglove.core import logging -from pyglove.core import object_utils from pyglove.core import symbolic from pyglove.core import typing as pg_typing +from pyglove.core import utils class _DataEntity(symbolic.Object): @@ -116,7 +116,7 @@ def get_reward_for_feedback( return tuple(metric_values) if len(metric_values) > 1 else metric_values[0] -class Result(object_utils.Formattable): +class Result(utils.Formattable): """Interface for tuning result.""" @property @@ -416,7 +416,7 @@ def skip_on_exception(unused_error): error_stack = traceback.format_exc() logging.warning('Skipping trial on unhandled exception: %s', error_stack) self.skip(error_stack) - return object_utils.catch_errors(exceptions, skip_on_exception) + return utils.catch_errors(exceptions, skip_on_exception) @contextlib.contextmanager def ignore_race_condition(self): diff --git a/pyglove/core/typing/annotation_conversion.py b/pyglove/core/typing/annotation_conversion.py index 261244c..abde706 100644 --- a/pyglove/core/typing/annotation_conversion.py +++ b/pyglove/core/typing/annotation_conversion.py @@ -18,7 +18,7 @@ import types import typing -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import annotated from pyglove.core.typing import class_schema from pyglove.core.typing import inspect as pg_inspect @@ -73,7 +73,7 @@ def _value_spec_from_default_value( elif isinstance(value, tuple): value_spec = vs.Tuple( [_value_spec_from_default_value(elem, False) for elem in value]) - elif inspect.isfunction(value) or isinstance(value, object_utils.Functor): + elif inspect.isfunction(value) or isinstance(value, utils.Functor): value_spec = vs.Callable() elif not isinstance(value, type): value_spec = vs.Object(type(value)) @@ -132,7 +132,7 @@ def _sub_value_spec_from_annotation( return vs.Union([vs.List(elem), vs.Tuple(elem)]) # Handling literals. elif origin is typing.Literal: - return vs.Enum(object_utils.MISSING_VALUE, args) + return vs.Enum(utils.MISSING_VALUE, args) # Handling dict. elif origin in (dict, typing.Dict, collections.abc.Mapping): if not args: diff --git a/pyglove/core/typing/callable_ext.py b/pyglove/core/typing/callable_ext.py index 75671ab..eff02ca 100644 --- a/pyglove/core/typing/callable_ext.py +++ b/pyglove/core/typing/callable_ext.py @@ -19,14 +19,14 @@ import types from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import callable_signature _TLS_KEY_PRESET_KWARGS = '__preset_kwargs__' -class PresetArgValue(object_utils.Formattable): +class PresetArgValue(utils.Formattable): """Value placeholder for arguments whose value will be provided by presets. Example: @@ -39,12 +39,12 @@ def foo(x, y=pg.PresetArgValue(default=1)) print(foo(x=1)) # 2: y=1 """ - def __init__(self, default: Any = object_utils.MISSING_VALUE): + def __init__(self, default: Any = utils.MISSING_VALUE): self.default = default @property def has_default(self) -> bool: - return self.default != object_utils.MISSING_VALUE + return self.default != utils.MISSING_VALUE def __eq__(self, other: Any) -> bool: return isinstance(other, PresetArgValue) and ( @@ -55,13 +55,13 @@ def __ne__(self, other: Any) -> bool: return not self.__eq__(other) def format(self, *args, **kwargs): - return object_utils.kvlist_str( + return utils.kvlist_str( [ - ('default', self.default, object_utils.MISSING_VALUE), + ('default', self.default, utils.MISSING_VALUE), ], label='PresetArgValue', *args, - **kwargs + **kwargs, ) @classmethod @@ -182,15 +182,15 @@ def preset_args( Current preset kwargs. """ - parent_presets = object_utils.thread_local_peek( + parent_presets = utils.thread_local_peek( _TLS_KEY_PRESET_KWARGS, _ArgPresets() ) current_preset = parent_presets.derive(kwargs, preset_name, inherit_preset) - object_utils.thread_local_push(_TLS_KEY_PRESET_KWARGS, current_preset) + utils.thread_local_push(_TLS_KEY_PRESET_KWARGS, current_preset) try: yield current_preset finally: - object_utils.thread_local_pop(_TLS_KEY_PRESET_KWARGS, None) + utils.thread_local_pop(_TLS_KEY_PRESET_KWARGS, None) def enable_preset_args( @@ -243,9 +243,7 @@ def decorator(func): @functools.wraps(func) def _func(*args, **kwargs): # Map positional arguments to keyword arguments. - presets = object_utils.thread_local_peek( - _TLS_KEY_PRESET_KWARGS, None - ) + presets = utils.thread_local_peek(_TLS_KEY_PRESET_KWARGS, None) preset_kwargs = presets.get_preset(preset_name) if presets else {} args, kwargs = PresetArgValue.resolve_args( args, kwargs, positional_arg_names, arg_defaults, preset_kwargs, diff --git a/pyglove/core/typing/callable_signature.py b/pyglove/core/typing/callable_signature.py index db474e5..6ed09ec 100644 --- a/pyglove/core/typing/callable_signature.py +++ b/pyglove/core/typing/callable_signature.py @@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, List, Optional, Union from pyglove.core import coding -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import class_schema from pyglove.core.typing import key_specs as ks @@ -141,7 +141,7 @@ class CallableType(enum.Enum): METHOD = 2 -class Signature(object_utils.Formattable): +class Signature(utils.Formattable): """PY3 function signature.""" def __init__(self, @@ -257,7 +257,7 @@ def format( **kwargs, ) -> str: """Format current object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('', self.id, ''), ('args', self.args, []), @@ -271,7 +271,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def annotate( @@ -288,7 +288,7 @@ def annotate( return_value = class_schema.ValueSpec.from_annotation( return_value, auto_typing=True ) - if object_utils.MISSING_VALUE != return_value.default: + if utils.MISSING_VALUE != return_value.default: raise ValueError('return value spec should not have default value.') self.return_value = return_value @@ -337,12 +337,12 @@ def update_arg(arg: Argument, field: class_schema.Field): or field.value.default is None ): field.value.set_default( - arg.value_spec.default, root_path=object_utils.KeyPath(arg.name) + arg.value_spec.default, root_path=utils.KeyPath(arg.name) ) if arg.value_spec.default != field.value.default: if field.value.is_noneable and not arg.value_spec.has_default: - # Special handling noneable which always comes with a default. - field.value.set_default(object_utils.MISSING_VALUE) + # Special handling noneable which always comes with a default. + field.value.set_default(utils.MISSING_VALUE) elif not ( # Special handling Dict type which always has default. isinstance(field.value, class_schema.ValueSpec.DictType) @@ -549,7 +549,7 @@ def from_callable( if not callable(callable_object): raise TypeError(f'{callable_object!r} is not callable.') - if isinstance(callable_object, object_utils.Functor): + if isinstance(callable_object, utils.Functor): assert callable_object.__signature__ is not None return callable_object.__signature__ @@ -566,15 +566,15 @@ def from_callable( description = None args_doc = {} if func.__doc__: - cls_doc = object_utils.DocStr.parse(func.__doc__) + cls_doc = utils.DocStr.parse(func.__doc__) description = cls_doc.short_description args_doc.update(cls_doc.args) if func.__init__.__doc__: - init_doc = object_utils.DocStr.parse(func.__init__.__doc__) + init_doc = utils.DocStr.parse(func.__init__.__doc__) args_doc.update(init_doc.args) - docstr = object_utils.DocStr( - object_utils.DocStrStyle.GOOGLE, + docstr = utils.DocStr( + utils.DocStrStyle.GOOGLE, short_description=description, long_description=None, examples=[], @@ -593,7 +593,7 @@ def from_callable( else CallableType.FUNCTION ) if auto_doc: - docstr = object_utils.docstr(func) + docstr = utils.docstr(func) sig = inspect.signature(func) module_name = getattr(func, '__module__', None) @@ -617,7 +617,7 @@ def from_signature( module_name: Optional[str] = None, qualname: Optional[str] = None, auto_typing: bool = False, - docstr: Union[str, object_utils.DocStr, None] = None, + docstr: Union[str, utils.DocStr, None] = None, parent_module: Optional[types.ModuleType] = None, ) -> 'Signature': """Returns PyGlove signature from Python signature. @@ -644,7 +644,7 @@ def from_signature( varkw = None if isinstance(docstr, str): - docstr = object_utils.DocStr.parse(docstr) + docstr = utils.DocStr.parse(docstr) def make_arg_spec(param: inspect.Parameter) -> Argument: """Makes argument spec from inspect.Parameter.""" @@ -708,7 +708,7 @@ def _append_arg( force_missing_as_default: bool = False, arg_prefix: str = ''): s = [f'{arg_prefix}{arg_name}'] - if arg_spec.annotation != object_utils.MISSING_VALUE: + if arg_spec.annotation != utils.MISSING_VALUE: s.append(f': _annotation_{arg_name}') exec_locals[f'_annotation_{arg_name}'] = arg_spec.annotation if not arg_prefix and (force_missing_as_default or arg_spec.has_default): @@ -761,7 +761,8 @@ def _append_arg( exec_globals=exec_globals, exec_locals=exec_locals, return_type=getattr( - self.return_value, 'annotation', coding.NO_TYPE_ANNOTATION) + self.return_value, 'annotation', coding.NO_TYPE_ANNOTATION + ), ) fn.__module__ = self.module_name fn.__name__ = self.name diff --git a/pyglove/core/typing/callable_signature_test.py b/pyglove/core/typing/callable_signature_test.py index 887973b..b14d853 100644 --- a/pyglove/core/typing/callable_signature_test.py +++ b/pyglove/core/typing/callable_signature_test.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.core.typing.callable_signature.""" - import copy import dataclasses import inspect from typing import List import unittest -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import annotation_conversion # pylint: disable=unused-import from pyglove.core.typing import callable_signature from pyglove.core.typing import class_schema @@ -628,8 +626,8 @@ class B: self.assertIsNotNone(signature.varkw) def test_signature_with_forward_declarations(self): - signature = callable_signature.signature(object_utils.KeyPath) - self.assertIs(signature.get_value_spec('parent').cls, object_utils.KeyPath) + signature = callable_signature.signature(utils.KeyPath) + self.assertIs(signature.get_value_spec('parent').cls, utils.KeyPath) class FromSchemaTest(unittest.TestCase): diff --git a/pyglove/core/typing/class_schema.py b/pyglove/core/typing/class_schema.py index 5a47676..f89313d 100644 --- a/pyglove/core/typing/class_schema.py +++ b/pyglove/core/typing/class_schema.py @@ -20,10 +20,10 @@ import types from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Type, Union -from pyglove.core import object_utils +from pyglove.core import utils -class KeySpec(object_utils.Formattable, object_utils.JSONConvertible): +class KeySpec(utils.Formattable, utils.JSONConvertible): """Interface for key specifications. A key specification determines what keys are acceptable for a symbolic @@ -94,7 +94,7 @@ def from_str(cls, key: str) -> 'KeySpec': assert False, 'Overridden in `key_specs.py`.' -class ForwardRef(object_utils.Formattable): +class ForwardRef(utils.Formattable): """Forward type reference.""" def __init__(self, module: types.ModuleType, name: str): @@ -147,7 +147,7 @@ def format( **kwargs ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('module', self.module.__name__, None), ('name', self.name, None), @@ -180,7 +180,7 @@ def __deepcopy__(self, memo) -> 'ForwardRef': return ForwardRef(self.module, self.name) -class ValueSpec(object_utils.Formattable, object_utils.JSONConvertible): +class ValueSpec(utils.Formattable, utils.JSONConvertible): """Interface for value specifications. A value specification defines what values are acceptable for a symbolic @@ -367,7 +367,7 @@ def set_default( self, default: Any, use_default_apply: bool = True, - root_path: Optional[object_utils.KeyPath] = None + root_path: Optional[utils.KeyPath] = None, ) -> 'ValueSpec': """Sets the default value and returns `self`. @@ -398,13 +398,14 @@ def default(self) -> Any: @property def has_default(self) -> bool: """Returns True if the default value is provided.""" - return self.default != object_utils.MISSING_VALUE + return self.default != utils.MISSING_VALUE @abc.abstractmethod def freeze( self, - permanent_value: Any = object_utils.MISSING_VALUE, - apply_before_use: bool = True) -> 'ValueSpec': + permanent_value: Any = utils.MISSING_VALUE, + apply_before_use: bool = True, + ) -> 'ValueSpec': """Sets the default value using a permanent value and freezes current spec. A frozen value spec will not accept any value that is not the default @@ -471,10 +472,11 @@ def apply( self, value: Any, allow_partial: bool = False, - child_transform: Optional[Callable[ - [object_utils.KeyPath, 'Field', Any], Any]] = None, - root_path: Optional[object_utils.KeyPath] = None, - ) -> Any: + child_transform: Optional[ + Callable[[utils.KeyPath, 'Field', Any], Any] + ] = None, + root_path: Optional[utils.KeyPath] = None, + ) -> Any: """Validates, completes and transforms the input value. Here is the procedure of ``apply``:: @@ -551,7 +553,7 @@ def from_annotation( assert False, 'Overridden in `annotation_conversion.py`.' -class Field(object_utils.Formattable, object_utils.JSONConvertible): +class Field(utils.Formattable, utils.JSONConvertible): """Class that represents the definition of one or a group of attributes. ``Field`` is held by a :class:`pyglove.Schema` object for defining the @@ -681,9 +683,11 @@ def apply( self, value: Any, allow_partial: bool = False, - transform_fn: Optional[Callable[ - [object_utils.KeyPath, 'Field', Any], Any]] = None, - root_path: Optional[object_utils.KeyPath] = None) -> Any: + transform_fn: Optional[ + Callable[[utils.KeyPath, 'Field', Any], Any] + ] = None, + root_path: Optional[utils.KeyPath] = None, + ) -> Any: """Apply current field to a value, which validate and complete the value. Args: @@ -735,7 +739,7 @@ def format( **kwargs, ) -> str: """Format this field into a string.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('key', self._key, None), ('value', self._value, None), @@ -746,7 +750,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def to_json(self, **kwargs: Any) -> Dict[str, Any]: @@ -775,7 +779,7 @@ def __ne__(self, other: Any) -> bool: return not self.__eq__(other) -class Schema(object_utils.Formattable, object_utils.JSONConvertible): +class Schema(utils.Formattable, utils.JSONConvertible): """Class that represents a schema. PyGlove's runtime type system is based on the concept of ``Schema`` ( @@ -959,13 +963,12 @@ def _merge_field( parent_field: Field, child_field: Field) -> Field: """Merge function on field with the same key.""" - if parent_field != object_utils.MISSING_VALUE: - if object_utils.MISSING_VALUE == child_field: + if parent_field != utils.MISSING_VALUE: + if utils.MISSING_VALUE == child_field: if (not self._allow_nonconst_keys and not parent_field.key.is_const): - hints = object_utils.kvlist_str([ - ('base', base.name, None), - ('path', path, None) - ]) + hints = utils.kvlist_str( + [('base', base.name, None), ('path', path, None)] + ) raise ValueError( f'Non-const key {parent_field.key} is not allowed to be ' f'added to the schema. ({hints})') @@ -974,16 +977,15 @@ def _merge_field( try: child_field.extend(parent_field) except Exception as e: # pylint: disable=broad-except - hints = object_utils.kvlist_str([ - ('base', base.name, None), - ('path', path, None) - ]) + hints = utils.kvlist_str( + [('base', base.name, None), ('path', path, None)] + ) raise e.__class__(f'{e} ({hints})').with_traceback( sys.exc_info()[2]) return child_field - self._fields = object_utils.merge([base.fields, self.fields], _merge_field) - self._metadata = object_utils.merge([base.metadata, self.metadata]) + self._fields = utils.merge([base.fields, self.fields], _merge_field) + self._metadata = utils.merge([base.metadata, self.metadata]) # Inherit dynamic field from base if it's not present in the child. if self._dynamic_field is None: @@ -1106,8 +1108,8 @@ def apply( dict_obj: Dict[str, Any], allow_partial: bool = False, child_transform: Optional[Callable[ - [object_utils.KeyPath, Field, Any], Any]] = None, - root_path: Optional[object_utils.KeyPath] = None, + [utils.KeyPath, Field, Any], Any]] = None, + root_path: Optional[utils.KeyPath] = None, ) -> Dict[str, Any]: # pyformat: disable # pyformat: disable """Apply this schema to a dict object, validate and transform it. @@ -1164,18 +1166,18 @@ def apply( keys.append(str(key_spec)) for key in keys: if dict_obj: - value = dict_obj.get(key, object_utils.MISSING_VALUE) + value = dict_obj.get(key, utils.MISSING_VALUE) else: - value = object_utils.MISSING_VALUE + value = utils.MISSING_VALUE # NOTE(daiyip): field.default_value may be MISSING_VALUE too # or partial. - if object_utils.MISSING_VALUE == value: + if utils.MISSING_VALUE == value: value = copy.deepcopy(field.default_value) new_value = field.apply( value, allow_partial=allow_partial, transform_fn=child_transform, - root_path=object_utils.KeyPath(key, root_path) + root_path=utils.KeyPath(key, root_path), ) # NOTE(daiyip): `pg.Dict.__getitem__`` has special logics in handling @@ -1189,10 +1191,12 @@ def apply( dict_obj[key] = new_value return dict_obj - def validate(self, - dict_obj: Dict[str, Any], - allow_partial: bool = False, - root_path: Optional[object_utils.KeyPath] = None) -> None: + def validate( + self, + dict_obj: Dict[str, Any], + allow_partial: bool = False, + root_path: Optional[utils.KeyPath] = None, + ) -> None: """Validates whether dict object is conformed with the schema.""" self.apply( copy.deepcopy(dict_obj), @@ -1257,12 +1261,12 @@ def format( root_indent: int = 0, *, cls_name: Optional[str] = None, - bracket_type: object_utils.BracketType = object_utils.BracketType.ROUND, + bracket_type: utils.BracketType = utils.BracketType.ROUND, fields_only: bool = False, **kwargs, ) -> str: """Format current Schema into nicely printed string.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('name', self.name, None), ('description', self.description, None), diff --git a/pyglove/core/typing/class_schema_test.py b/pyglove/core/typing/class_schema_test.py index fa577bd..0bf5cf4 100644 --- a/pyglove/core/typing/class_schema_test.py +++ b/pyglove/core/typing/class_schema_test.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.core.typing.class_schema.""" - import copy import inspect import sys from typing import Optional, Union, List import unittest -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import annotation_conversion # pylint: disable=unused-import from pyglove.core.typing import class_schema from pyglove.core.typing import custom_typing @@ -204,7 +202,7 @@ def test_repr(self): def test_json_conversion(self): def assert_json_conversion(f): - self.assertEqual(object_utils.from_json(f.to_json()), f) + self.assertEqual(utils.from_json(f.to_json()), f) assert_json_conversion(Field('a', vs.Int())) assert_json_conversion(Field('a', vs.Int(), 'description')) @@ -822,7 +820,7 @@ def test_json_conversion(self): schema = self._create_test_schema() schema.set_description('Foo') schema.set_name('Bar') - schema_copy = object_utils.from_json(schema.to_json()) + schema_copy = utils.from_json(schema.to_json()) # This compares fields only self.assertEqual(schema_copy, schema) diff --git a/pyglove/core/typing/custom_typing.py b/pyglove/core/typing/custom_typing.py index db51bd1..3ea24a8 100644 --- a/pyglove/core/typing/custom_typing.py +++ b/pyglove/core/typing/custom_typing.py @@ -16,7 +16,7 @@ import abc from typing import Any, Callable, Optional, Tuple -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import class_schema @@ -34,11 +34,12 @@ class CustomTyping(metaclass=abc.ABCMeta): @abc.abstractmethod def custom_apply( self, - path: object_utils.KeyPath, + path: utils.KeyPath, value_spec: class_schema.ValueSpec, allow_partial: bool, - child_transform: Optional[Callable[ - [object_utils.KeyPath, class_schema.Field, Any], Any]] = None + child_transform: Optional[ + Callable[[utils.KeyPath, class_schema.Field, Any], Any] + ] = None, ) -> Tuple[bool, Any]: """Custom apply on a value based on its original value spec. diff --git a/pyglove/core/typing/key_specs.py b/pyglove/core/typing/key_specs.py index 07d2472..b2dd589 100644 --- a/pyglove/core/typing/key_specs.py +++ b/pyglove/core/typing/key_specs.py @@ -16,7 +16,7 @@ import re from typing import Any, Dict, Optional -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing.class_schema import KeySpec @@ -38,7 +38,7 @@ def __str_kwargs__(self) -> Dict[str, Any]: return {} -class ConstStrKey(KeySpecBase, object_utils.StrKey): +class ConstStrKey(KeySpecBase, utils.StrKey): """Class that represents a constant string key. Example:: @@ -159,11 +159,9 @@ def regex(self): def format(self, **kwargs): """Format this object.""" - return object_utils.kvlist_str( - [ - ('regex', getattr(self._regex, 'pattern', None), None) - ], - label=self.__class__.__name__ + return utils.kvlist_str( + [('regex', getattr(self._regex, 'pattern', None), None)], + label=self.__class__.__name__, ) def to_json(self, **kwargs: Any) -> Dict[str, Any]: diff --git a/pyglove/core/typing/key_specs_test.py b/pyglove/core/typing/key_specs_test.py index d404465..7190a10 100644 --- a/pyglove/core/typing/key_specs_test.py +++ b/pyglove/core/typing/key_specs_test.py @@ -14,7 +14,7 @@ """Tests for pyglove.core.typing.key_specs.""" import unittest -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import key_specs as ks @@ -22,7 +22,7 @@ class KeySpecTest(unittest.TestCase): """Base class for KeySpec tests.""" def assert_json_conversion(self, spec: ks.KeySpec): - self.assertEqual(object_utils.from_json(object_utils.to_json(spec)), spec) + self.assertEqual(utils.from_json(utils.to_json(spec)), spec) class ConstStrKeyTest(KeySpecTest): @@ -35,9 +35,9 @@ def test_basics(self): self.assertEqual(key.text, 'a') self.assertNotEqual(key, 'b') self.assertIn(key, {'a': 1}) - with object_utils.str_format(markdown=True): + with utils.str_format(markdown=True): self.assertEqual(str(key), 'a') - with object_utils.str_format(markdown=True): + with utils.str_format(markdown=True): self.assertEqual(repr(key), 'a') self.assertTrue(key.match('a')) self.assertFalse(key.match('b')) diff --git a/pyglove/core/typing/type_conversion.py b/pyglove/core/typing/type_conversion.py index 8527cc7..69320e7 100644 --- a/pyglove/core/typing/type_conversion.py +++ b/pyglove/core/typing/type_conversion.py @@ -17,7 +17,7 @@ import datetime from typing import Any, Callable, Optional, Tuple, Type, Union -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import inspect as pg_inspect @@ -135,10 +135,9 @@ def _register_builtin_converters(): lambda x: calendar.timegm(x.timetuple())) # string <=> KeyPath. - register_converter(str, object_utils.KeyPath, - object_utils.KeyPath.parse) - register_converter(object_utils.KeyPath, str, lambda x: x.path) + register_converter(str, utils.KeyPath, utils.KeyPath.parse) + register_converter(utils.KeyPath, str, lambda x: x.path) _register_builtin_converters() -object_utils.JSONConvertible.TYPE_CONVERTER = get_json_value_converter +utils.JSONConvertible.TYPE_CONVERTER = get_json_value_converter diff --git a/pyglove/core/typing/type_conversion_test.py b/pyglove/core/typing/type_conversion_test.py index 697e39b..e5365f3 100644 --- a/pyglove/core/typing/type_conversion_test.py +++ b/pyglove/core/typing/type_conversion_test.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.core.typing.type_conversion.""" - import calendar import datetime import typing import unittest -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import annotation_conversion # pylint: disable=unused-import from pyglove.core.typing import type_conversion from pyglove.core.typing import value_specs as vs @@ -137,17 +135,19 @@ def test_datetime_to_int(self): def test_keypath_to_str(self): """Test built-in converter between string and KeyPath.""" self.assertEqual( - vs.Object(object_utils.KeyPath).apply('a.b.c').keys, - ['a', 'b', 'c']) - self.assertEqual( - vs.Union([vs.Object(object_utils.KeyPath), vs.Int()]).apply( - 'a.b.c').keys, - ['a', 'b', 'c']) + vs.Object(utils.KeyPath).apply('a.b.c').keys, ['a', 'b', 'c'] + ) self.assertEqual( - vs.Str().apply(object_utils.KeyPath.parse('a.b.c')), 'a.b.c') + vs.Union([vs.Object(utils.KeyPath), vs.Int()]).apply('a.b.c').keys, + ['a', 'b', 'c'], + ) + self.assertEqual(vs.Str().apply(utils.KeyPath.parse('a.b.c')), 'a.b.c') self.assertEqual( - type_conversion.get_json_value_converter(object_utils.KeyPath)( - object_utils.KeyPath.parse('a.b.c')), 'a.b.c') + type_conversion.get_json_value_converter(utils.KeyPath)( + utils.KeyPath.parse('a.b.c') + ), + 'a.b.c', + ) if __name__ == '__main__': diff --git a/pyglove/core/typing/typed_missing.py b/pyglove/core/typing/typed_missing.py index 07bd050..e641bdb 100644 --- a/pyglove/core/typing/typed_missing.py +++ b/pyglove/core/typing/typed_missing.py @@ -14,15 +14,15 @@ """Typed value placeholders.""" from typing import Any -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import class_schema # Non-typed missing value. -MISSING_VALUE = object_utils.MISSING_VALUE +MISSING_VALUE = utils.MISSING_VALUE -class MissingValue(object_utils.MissingValue, object_utils.Formattable): +class MissingValue(utils.MissingValue, utils.Formattable): """Class represents missing value **for a specific value spec**.""" def __init__(self, value_spec: class_schema.ValueSpec): @@ -37,15 +37,15 @@ def value_spec(self) -> class_schema.ValueSpec: def __eq__(self, other: Any) -> bool: """Operator ==. - NOTE: `MissingValue(value_spec) and `object_utils.MissingValue` are + NOTE: `MissingValue(value_spec) and `utils.MissingValue` are considered equal, but `MissingValue(value_spec1)` and `MissingValue(value_spec2)` are considered different. That being said, the 'eq' operation is not transitive. However in practice this is not a problem, since user always compare - against `schema.MISSING_VALUE` which is `object_utils.MissingValue`. + against `schema.MISSING_VALUE` which is `utils.MissingValue`. Therefore the `__hash__` function returns the same value with - `object_utils.MissingValue`. + `utils.MissingValue`. Args: other: the value to compare against. @@ -80,4 +80,3 @@ def format(self, def __deepcopy__(self, memo): """Avoid deep copy by copying value_spec by reference.""" return MissingValue(self.value_spec) - diff --git a/pyglove/core/typing/typed_missing_test.py b/pyglove/core/typing/typed_missing_test.py index 7b3a73a..f7ccd45 100644 --- a/pyglove/core/typing/typed_missing_test.py +++ b/pyglove/core/typing/typed_missing_test.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.core.typing.typed_missing.""" - import unittest -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import typed_missing from pyglove.core.typing import value_specs @@ -25,12 +23,12 @@ class MissingValueTest(unittest.TestCase): def test_eq(self): self.assertEqual( - typed_missing.MissingValue(value_specs.Int()), - object_utils.MISSING_VALUE) + typed_missing.MissingValue(value_specs.Int()), utils.MISSING_VALUE + ) self.assertEqual( - object_utils.MISSING_VALUE, - typed_missing.MissingValue(value_specs.Int())) + utils.MISSING_VALUE, typed_missing.MissingValue(value_specs.Int()) + ) self.assertEqual( typed_missing.MissingValue(value_specs.Int()), @@ -54,7 +52,8 @@ def test_hash(self): self.assertEqual( hash(typed_missing.MissingValue(value_specs.Int())), - hash(object_utils.MISSING_VALUE)) + hash(utils.MISSING_VALUE), + ) self.assertNotEqual( hash(typed_missing.MissingValue(value_specs.Int())), diff --git a/pyglove/core/typing/value_specs.py b/pyglove/core/typing/value_specs.py index a77489d..d0b0884 100644 --- a/pyglove/core/typing/value_specs.py +++ b/pyglove/core/typing/value_specs.py @@ -22,7 +22,7 @@ import sys import typing import __main__ -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import callable_signature from pyglove.core.typing import class_schema from pyglove.core.typing import inspect as pg_inspect @@ -35,7 +35,7 @@ from pyglove.core.typing.custom_typing import CustomTyping -MISSING_VALUE = object_utils.MISSING_VALUE +MISSING_VALUE = utils.MISSING_VALUE class _FrozenValuePlaceholder(CustomTyping): @@ -169,7 +169,7 @@ def set_default( self, default: typing.Any, use_default_apply: bool = True, - root_path: typing.Optional[object_utils.KeyPath] = None + root_path: typing.Optional[utils.KeyPath] = None, ) -> ValueSpec: """Set default value and returns `self`.""" # NOTE(daiyip): Default can be schema.MissingValue types, all are @@ -246,13 +246,13 @@ def apply( self, value: typing.Any, allow_partial: bool = False, - child_transform: typing.Optional[typing.Callable[ - [object_utils.KeyPath, Field, typing.Any], - typing.Any - ]] = None, - root_path: typing.Optional[object_utils.KeyPath] = None) -> typing.Any: # pyformat: disable pylint: disable=line-too-long + child_transform: typing.Optional[ + typing.Callable[[utils.KeyPath, Field, typing.Any], typing.Any] + ] = None, + root_path: typing.Optional[utils.KeyPath] = None, + ) -> typing.Any: # pyformat: disable pylint: disable=line-too-long """Apply spec to validate and complete value.""" - root_path = root_path or object_utils.KeyPath() + root_path = root_path or utils.KeyPath() if self.frozen and self.default is not _FROZEN_VALUE_PLACEHOLDER: # Always return the default value if a field is frozen. if MISSING_VALUE != value and self.default != value: @@ -291,9 +291,8 @@ def apply( value = self._transform(value) except Exception as e: # pylint: disable=broad-except raise e.__class__( - object_utils.message_on_path( - str(e), root_path) - ).with_traceback(sys.exc_info()[2]) + utils.message_on_path(str(e), root_path) + ).with_traceback(sys.exc_info()[2]) return self.skip_user_transform.apply( value, @@ -309,9 +308,12 @@ def apply( converter = type_conversion.get_converter(type(value), self.value_type) if converter is None: raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Expect {self.value_type} ' - f'but encountered {type(value)!r}: {value}.', root_path)) + f'but encountered {type(value)!r}: {value}.', + root_path, + ) + ) value = converter(value) # NOTE(daiyip): child nodes validation and transformation is done before @@ -325,15 +327,18 @@ def apply( self._validate(root_path, value) return value - def _validate(self, path: object_utils.KeyPath, value: typing.Any): + def _validate(self, path: utils.KeyPath, value: typing.Any): """Validation on applied value. Child class can override.""" - def _apply(self, - value: typing.Any, - allow_partial: bool, - child_transform: typing.Callable[ - [object_utils.KeyPath, Field, typing.Any], typing.Any], - root_path: object_utils.KeyPath) -> typing.Any: + def _apply( + self, + value: typing.Any, + allow_partial: bool, + child_transform: typing.Callable[ + [utils.KeyPath, Field, typing.Any], typing.Any + ], + root_path: utils.KeyPath, + ) -> typing.Any: """Customized apply so each subclass can override.""" del allow_partial del child_transform @@ -401,17 +406,17 @@ def format( **kwargs ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('default', self._default, MISSING_VALUE), ('noneable', self._is_noneable, False), - ('frozen', self._frozen, False) + ('frozen', self._frozen, False), ], label=self.__class__.__name__, compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) @@ -552,15 +557,18 @@ def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: **kwargs, ) - def _validate(self, path: object_utils.KeyPath, value: str) -> None: + def _validate(self, path: utils.KeyPath, value: str) -> None: """Validates applied value.""" if not self._regex: return if not self._regex.match(value): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'String {value!r} does not match ' - f'regular expression {self._regex.pattern!r}.', path)) + f'regular expression {self._regex.pattern!r}.', + path, + ) + ) @property def regex(self): @@ -595,7 +603,7 @@ def format( ) -> str: """Format this object.""" regex_pattern = self._regex.pattern if self._regex else None - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('default', self._default, MISSING_VALUE), ('regex', regex_pattern, None), @@ -606,7 +614,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def _eq(self, other: 'Str') -> bool: @@ -665,15 +673,17 @@ def max_value(self) -> typing.Optional[numbers.Number]: """Returns maximum value of acceptable values.""" return self._max_value - def _validate(self, path: object_utils.KeyPath, - value: numbers.Number) -> None: + def _validate(self, path: utils.KeyPath, value: numbers.Number) -> None: """Validates applied value.""" if ((self._min_value is not None and value < self._min_value) or (self._max_value is not None and value > self._max_value)): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Value {value} is out of range ' - f'(min={self._min_value}, max={self._max_value}).', path)) + f'(min={self._min_value}, max={self._max_value}).', + path, + ) + ) def _extend(self, base: 'Number') -> None: """Number specific extend.""" @@ -726,19 +736,19 @@ def format( **kwargs ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('default', self._default, MISSING_VALUE), ('min', self._min_value, None), ('max', self._max_value, None), ('noneable', self._is_noneable, False), - ('frozen', self._frozen, False) + ('frozen', self._frozen, False), ], label=self.__class__.__name__, compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: @@ -944,13 +954,14 @@ def values(self) -> typing.List[typing.Any]: """Returns all acceptable values of this spec.""" return self._values - def _validate(self, path: object_utils.KeyPath, value: typing.Any) -> None: + def _validate(self, path: utils.KeyPath, value: typing.Any) -> None: """Validates applied value.""" if value not in self._values: raise ValueError( - object_utils.message_on_path( - f'Value {value!r} is not in candidate list {self._values}.', - path)) + utils.message_on_path( + f'Value {value!r} is not in candidate list {self._values}.', path + ) + ) def _extend(self, base: 'Enum') -> None: """Enum specific extend.""" @@ -995,7 +1006,7 @@ def format( **kwargs ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('default', self._default, MISSING_VALUE), ('values', self._values, None), @@ -1005,7 +1016,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: @@ -1133,12 +1144,15 @@ def max_size(self) -> typing.Optional[int]: """Returns max size of the list.""" return self._element.key.max_value # pytype: disable=attribute-error # bind-properties - def _apply(self, - value: typing.List[typing.Any], - allow_partial: bool, - child_transform: typing.Callable[ - [object_utils.KeyPath, Field, typing.Any], typing.Any], - root_path: object_utils.KeyPath) -> typing.Any: + def _apply( + self, + value: typing.List[typing.Any], + allow_partial: bool, + child_transform: typing.Callable[ + [utils.KeyPath, Field, typing.Any], typing.Any + ], + root_path: utils.KeyPath, + ) -> typing.Any: """List specific apply.""" # NOTE(daiyip): for symbolic List, write access using `__setitem__` will # trigger permission error when `accessor_writable` is set to False. @@ -1155,27 +1169,35 @@ def _fn(i, v): getitem = getattr(value, 'sym_getattr', value.__getitem__) for i in range(len(value)): v = self._element.apply( - getitem(i), allow_partial=allow_partial, transform_fn=child_transform, - root_path=object_utils.KeyPath(i, root_path)) + getitem(i), + allow_partial=allow_partial, + transform_fn=child_transform, + root_path=utils.KeyPath(i, root_path), + ) if getitem(i) is not v: set_item(i, v) return value - def _validate( - self, path: object_utils.KeyPath, value: typing.List[typing.Any]): + def _validate(self, path: utils.KeyPath, value: typing.List[typing.Any]): """Validates applied value.""" if len(value) < self.min_size: raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Length of list {value!r} is less than ' - f'min size ({self.min_size}).', path)) + f'min size ({self.min_size}).', + path, + ) + ) if self.max_size is not None: if len(value) > self.max_size: raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Length of list {value!r} is greater than ' - f'max size ({self.max_size}).', path)) + f'max size ({self.max_size}).', + path, + ) + ) def _extend(self, base: 'List') -> None: """List specific extend.""" @@ -1203,7 +1225,7 @@ def format( **kwargs, ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('', self._element.value, None), ('min_size', self.min_size, 0), @@ -1216,7 +1238,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: @@ -1409,35 +1431,50 @@ def __len__(self) -> int: """Returns length of this tuple.""" return len(self._elements) if self.fixed_length else 0 - def _apply(self, - value: typing.Tuple[typing.Any, ...], - allow_partial: bool, - child_transform: typing.Callable[ - [object_utils.KeyPath, Field, typing.Any], typing.Any], - root_path: object_utils.KeyPath) -> typing.Any: + def _apply( + self, + value: typing.Tuple[typing.Any, ...], + allow_partial: bool, + child_transform: typing.Callable[ + [utils.KeyPath, Field, typing.Any], typing.Any + ], + root_path: utils.KeyPath, + ) -> typing.Any: """Tuple specific apply.""" if self.fixed_length: if len(value) != len(self.elements): raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Length of input tuple ({len(value)}) does not match the ' f'length of spec ({len(self.elements)}). ' - f'Input: {value}, Spec: {self!r}', root_path)) + f'Input: {value}, Spec: {self!r}', + root_path, + ) + ) else: if len(value) < self.min_size: raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Length of tuple {value} is less than ' - f'min size ({self.min_size}).', root_path)) + f'min size ({self.min_size}).', + root_path, + ) + ) if self.max_size is not None and len(value) > self.max_size: raise ValueError( - object_utils.message_on_path( + utils.message_on_path( f'Length of tuple {value} is greater than ' - f'max size ({self.max_size}).', root_path)) + f'max size ({self.max_size}).', + root_path, + ) + ) return tuple([ self._elements[i if self.fixed_length else 0].apply( # pylint: disable=g-complex-comprehension - v, allow_partial=allow_partial, transform_fn=child_transform, - root_path=object_utils.KeyPath(i, root_path)) + v, + allow_partial=allow_partial, + transform_fn=child_transform, + root_path=utils.KeyPath(i, root_path), + ) for i, v in enumerate(value) ]) @@ -1527,7 +1564,7 @@ def format( else: value = self._elements[0].value default_min, default_max = 0, None - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('', value, None), ('default', self._default, MISSING_VALUE), @@ -1540,7 +1577,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: @@ -1687,7 +1724,7 @@ def set_default( self, default: typing.Any, use_default_apply: bool = True, - root_path: typing.Optional[object_utils.KeyPath] = None, + root_path: typing.Optional[utils.KeyPath] = None, ) -> ValueSpec: if MISSING_VALUE == default and self._schema: self._use_generated_default = True @@ -1707,12 +1744,15 @@ def forward_refs(self) -> typing.Set[class_schema.ForwardRef]: forward_refs.update(field.value.forward_refs) return forward_refs - def _apply(self, - value: typing.Dict[typing.Any, typing.Any], - allow_partial: bool, - child_transform: typing.Callable[ - [object_utils.KeyPath, Field, typing.Any], typing.Any], - root_path: object_utils.KeyPath) -> typing.Any: + def _apply( + self, + value: typing.Dict[typing.Any, typing.Any], + allow_partial: bool, + child_transform: typing.Callable[ + [utils.KeyPath, Field, typing.Any], typing.Any + ], + root_path: utils.KeyPath, + ) -> typing.Any: """Dict specific apply.""" if not self._schema: return value @@ -1756,11 +1796,13 @@ def format( **kwargs, ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ - ('fields', - list(self._schema.values()) if self._schema else None, - None), + ( + 'fields', + list(self._schema.values()) if self._schema else None, + None, + ), ('noneable', self._is_noneable, False), ('frozen', self._frozen, False), ], @@ -1892,19 +1934,24 @@ def cls(self) -> typing.Type[typing.Any]: def value_type(self) -> typing.Type[typing.Any]: return self.cls - def _apply(self, - value: typing.Any, - allow_partial: bool, - child_transform: typing.Callable[ - [object_utils.KeyPath, Field, typing.Any], typing.Any], - root_path: object_utils.KeyPath) -> typing.Any: + def _apply( + self, + value: typing.Any, + allow_partial: bool, + child_transform: typing.Callable[ + [utils.KeyPath, Field, typing.Any], typing.Any + ], + root_path: utils.KeyPath, + ) -> typing.Any: """Object specific apply.""" del child_transform - if isinstance(value, object_utils.MaybePartial): + if isinstance(value, utils.MaybePartial): if not allow_partial and value.is_partial: raise ValueError( - object_utils.message_on_path( - f'Object {value} is not fully bound.', root_path)) + utils.message_on_path( + f'Object {value} is not fully bound.', root_path + ) + ) return value def extend(self, base: ValueSpec) -> ValueSpec: @@ -1955,9 +2002,9 @@ def format( name = self._forward_ref.name else: name = self._value_type.__name__ - return object_utils.kvlist_str( + return utils.kvlist_str( [ - ('', object_utils.RawText(name), None), + ('', utils.RawText(name), None), ('default', self._default, MISSING_VALUE), ('noneable', self._is_noneable, False), ('frozen', self._frozen, False), @@ -1966,7 +2013,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: @@ -2103,12 +2150,12 @@ def return_value(self) -> typing.Optional[ValueSpec]: """Value spec for return value.""" return self._return_value - def _validate(self, path: object_utils.KeyPath, value: typing.Any) -> None: + def _validate(self, path: utils.KeyPath, value: typing.Any) -> None: """Validate applied value.""" if not callable(value): raise TypeError( - object_utils.message_on_path( - f'Value is not callable: {value!r}.', path)) + utils.message_on_path(f'Value is not callable: {value!r}.', path) + ) # Shortcircuit if there is no signature to check. if not (self._args or self._kw or self._return_value): @@ -2120,10 +2167,12 @@ def _validate(self, path: object_utils.KeyPath, value: typing.Any) -> None: if len(self._args) > len(signature.args) and not signature.has_varargs: raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'{signature.id} only take {len(signature.args)} positional ' f'arguments, while {len(self._args)} is required by {self!r}.', - path)) + path, + ) + ) # Check positional arguments. for i in range(min(len(self._args), len(signature.args))): @@ -2131,10 +2180,12 @@ def _validate(self, path: object_utils.KeyPath, value: typing.Any) -> None: dest_spec = signature.args[i].value_spec if not dest_spec.is_compatible(src_spec): raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Value spec of positional argument {i} is not compatible. ' f'Expected: {dest_spec!r}, Actual: {src_spec!r}.', - path)) + path, + ) + ) if len(self._args) > len(signature.args): assert signature.varargs assert isinstance(signature.varargs.value_spec, List), signature.varargs @@ -2143,10 +2194,13 @@ def _validate(self, path: object_utils.KeyPath, value: typing.Any) -> None: src_spec = self._args[i] if not dest_spec.is_compatible(src_spec): raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Value spec of positional argument {i} is not compatible ' f'with the value spec of *{signature.varargs.name}. ' - f'Expected: {dest_spec!r}, Actual: {src_spec!r}.', path)) + f'Expected: {dest_spec!r}, Actual: {src_spec!r}.', + path, + ) + ) # Check keyword arguments. dest_args = signature.args + signature.kwonlyargs @@ -2159,37 +2213,46 @@ def _validate(self, path: object_utils.KeyPath, value: typing.Any) -> None: if dest_spec is not None: if not dest_spec.is_compatible(src_spec): raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Value spec of keyword argument {arg_name!r} is not ' f'compatible. Expected: {src_spec!r}, Actual: {dest_spec!r}.', - path)) + path, + ) + ) elif signature.varkw: assert isinstance(signature.varkw.value_spec, Dict), signature.varkw varkw_value_spec = signature.varkw.value_spec.schema.dynamic_field.value # pytype: disable=attribute-error if not varkw_value_spec.is_compatible(src_spec): raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Value spec of keyword argument {arg_name!r} is not ' - f'compatible with the value spec of ' + 'compatible with the value spec of ' f'**{signature.varkw.name}. ' f'Expected: {varkw_value_spec!r}, ' - f'Actual: {src_spec!r}.', path)) + f'Actual: {src_spec!r}.', + path, + ) + ) else: raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Keyword argument {arg_name!r} does not exist in {value!r}.', - path)) + path, + ) + ) # Check return value if (self._return_value and signature.return_value and not isinstance(signature.return_value, Any) and not self._return_value.is_compatible(signature.return_value)): raise TypeError( - object_utils.message_on_path( - f'Value spec for return value is not compatible. ' + utils.message_on_path( + 'Value spec for return value is not compatible. ' f'Expected: {self._return_value!r}, ' f'Actual: {signature.return_value!r} ({value!r}).', - path)) + path, + ) + ) def _extend(self, base: 'Callable') -> None: """Callable specific extension.""" @@ -2260,14 +2323,14 @@ def format( **kwargs, ) -> str: """Format this spec.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('args', self._args, []), ('kw', self._kw, []), ('returns', self._return_value, None), ('default', self._default, MISSING_VALUE), ('noneable', self._is_noneable, False), - ('frozen', self._frozen, False) + ('frozen', self._frozen, False), ], label=self.__class__.__name__, compact=compact, @@ -2359,14 +2422,14 @@ def __init__( returns=returns, default=default, transform=transform, - callable_type=object_utils.Functor, + callable_type=utils.Functor, is_noneable=is_noneable, frozen=frozen, ) def _annotate(self) -> typing.Any: """Annotate with PyType annotation.""" - return object_utils.Functor + return utils.Functor def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: exclude_keys = kwargs.pop('exclude_keys', set()) @@ -2431,12 +2494,14 @@ def forward_refs(self) -> typing.Set[class_schema.ForwardRef]: return set() return set([self._forward_ref]) - def _validate(self, path: object_utils.KeyPath, value: typing.Type) -> None: # pylint: disable=g-bare-generic + def _validate(self, path: utils.KeyPath, value: typing.Type) -> None: # pylint: disable=g-bare-generic """Validate applied value.""" if self.type_resolved and not pg_inspect.is_subclass(value, self.type): raise ValueError( - object_utils.message_on_path( - f'{value!r} is not a subclass of {self.type!r}', path)) + utils.message_on_path( + f'{value!r} is not a subclass of {self.type!r}', path + ) + ) def _is_compatible(self, other: 'Type') -> bool: """Type specific compatiblity check.""" @@ -2472,7 +2537,7 @@ def format( **kwargs, ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('', self._expected_type, None), ('default', self._default, MISSING_VALUE), @@ -2483,7 +2548,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: @@ -2673,14 +2738,15 @@ def get_candidate( return c return None - def _apply(self, - value: typing.Any, - allow_partial: bool, - child_transform: typing.Callable[ - [object_utils.KeyPath, Field, typing.Any], - typing.Any - ], - root_path: object_utils.KeyPath) -> typing.Any: + def _apply( + self, + value: typing.Any, + allow_partial: bool, + child_transform: typing.Callable[ + [utils.KeyPath, Field, typing.Any], typing.Any + ], + root_path: utils.KeyPath, + ) -> typing.Any: """Union specific apply.""" # Match strong-typed candidates first. if not self.type_resolved: @@ -2782,7 +2848,7 @@ def format( **kwargs, ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('', self._candidates, None), ('default', self._default, MISSING_VALUE), @@ -2794,7 +2860,7 @@ def format( verbose=verbose, root_indent=root_indent, list_wrap_threshold=kwargs.pop('list_wrap_threshold', 20), - **kwargs + **kwargs, ) def to_json(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]: @@ -2942,17 +3008,17 @@ def format( **kwargs, ) -> str: """Format this object.""" - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('default', self._default, MISSING_VALUE), ('frozen', self._frozen, False), - ('annotation', self._annotation, MISSING_VALUE) + ('annotation', self._annotation, MISSING_VALUE), ], label=self.__class__.__name__, compact=compact, verbose=verbose, root_indent=root_indent, - **kwargs + **kwargs, ) def annotate(self, annotation: typing.Any) -> 'Any': @@ -3022,7 +3088,7 @@ def _get_spec_callsite_module(): def ensure_value_spec( value_spec: class_schema.ValueSpec, src_spec: class_schema.ValueSpec, - root_path: typing.Optional[object_utils.KeyPath] = None + root_path: typing.Optional[utils.KeyPath] = None, ) -> typing.Optional[class_schema.ValueSpec]: """Extract counter part from value spec that matches dest spec type. @@ -3043,7 +3109,10 @@ def ensure_value_spec( return None if not src_spec.is_compatible(value_spec): raise TypeError( - object_utils.message_on_path( + utils.message_on_path( f'Source spec {src_spec} is not compatible with destination ' - f'spec {value_spec}.', root_path)) + f'spec {value_spec}.', + root_path, + ) + ) return value_spec diff --git a/pyglove/core/typing/value_specs_test.py b/pyglove/core/typing/value_specs_test.py index b2d6614..1085449 100644 --- a/pyglove/core/typing/value_specs_test.py +++ b/pyglove/core/typing/value_specs_test.py @@ -18,7 +18,7 @@ import typing import unittest -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.typing import annotation_conversion # pylint: disable=unused-import from pyglove.core.typing import callable_signature from pyglove.core.typing import class_schema @@ -36,11 +36,10 @@ class ValueSpecTest(unittest.TestCase): """Base class for value spec test.""" def assert_json_conversion(self, v): - self.assertEqual(object_utils.from_json(v.to_json()), v) + self.assertEqual(utils.from_json(v.to_json()), v) def assert_json_conversion_key(self, v, key): - self.assertEqual( - v.to_json()[object_utils.JSONConvertible.TYPE_NAME_KEY], key) + self.assertEqual(v.to_json()[utils.JSONConvertible.TYPE_NAME_KEY], key) class BoolTest(ValueSpecTest): @@ -1039,9 +1038,9 @@ def test_apply(self): self.assertEqual(vs.List(vs.Int().noneable()).apply([1, None]), [1, None]) # Automatic conversion: str -> KeyPath is a registered conversion. # See 'type_conversion.py'. - l = vs.List(vs.Object(object_utils.KeyPath)).apply(['a.b.c']) - self.assertIsInstance(l[0], object_utils.KeyPath) - self.assertEqual(l, [object_utils.KeyPath.parse('a.b.c')]) + l = vs.List(vs.Object(utils.KeyPath)).apply(['a.b.c']) + self.assertIsInstance(l[0], utils.KeyPath) + self.assertEqual(l, [utils.KeyPath.parse('a.b.c')]) self.assertEqual( vs.List(vs.Int()).apply( typed_missing.MISSING_VALUE, allow_partial=True), @@ -2046,7 +2045,7 @@ def test_json_conversion(self): x = vs.Dict([ ('a', int, 'field 1', dict(x=1)), ]).freeze(dict(a=1)) - y = object_utils.from_json(x.to_json()) + y = utils.from_json(x.to_json()) self.assert_json_conversion( vs.Dict([ ('a', int, 'field 1', dict(x=1)), @@ -2089,7 +2088,7 @@ def __init__(self, value=0): class C(A): pass - class D(C, object_utils.MaybePartial): + class D(C, utils.MaybePartial): def missing_values(self): return {'SOME_KEY': 'SOME_VALUE'} @@ -2378,7 +2377,7 @@ def test_generic(self): def test_value_type(self): self.assertIsNone(vs.Callable().value_type) - self.assertEqual(vs.Functor().annotation, object_utils.Functor) + self.assertEqual(vs.Functor().annotation, utils.Functor) def test_forward_refs(self): self.assertEqual(vs.Callable().forward_refs, set()) @@ -2584,7 +2583,7 @@ def _value_is_one(func): def test_apply_on_functor(self): - class FunctorWithRegularArgs(object_utils.Functor): + class FunctorWithRegularArgs(utils.Functor): __signature__ = Signature( callable_type=callable_signature.CallableType.FUNCTION, @@ -2630,7 +2629,7 @@ def __call__(self, a, b): def test_apply_on_functor_with_varargs(self): - class FunctorWithVarArgs(object_utils.Functor): + class FunctorWithVarArgs(utils.Functor): __signature__ = Signature( callable_type=callable_signature.CallableType.FUNCTION, @@ -2781,7 +2780,7 @@ def test_json_conversion(self): ) ) x = vs.Callable([vs.Int()], default=lambda x: x + 1).noneable() - y = object_utils.from_json(x.to_json()) + y = utils.from_json(x.to_json()) self.assert_json_conversion( vs.Callable([vs.Int()], default=lambda x: x + 1).noneable() ) diff --git a/pyglove/core/utils/__init__.py b/pyglove/core/utils/__init__.py new file mode 100644 index 0000000..d3c42d3 --- /dev/null +++ b/pyglove/core/utils/__init__.py @@ -0,0 +1,159 @@ +# Copyright 2022 The PyGlove Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pylint: disable=line-too-long +"""Utility library that provides common traits for objects in Python. + +Overview +-------- + +``pg.utils`` sits at the bottom of all PyGlove modules and empowers other +modules with the following features: + + +---------------------+--------------------------------------------+ + | Functionality | API | + +=====================+============================================+ + | Formatting | :class:`pg.Formattable`, | + | | | + | | :func:`pg.format`, | + | | | + | | :func:`pg.print`, | + | | | + | | :func:`pg.utils.kvlist_str`, | + | | | + | | :func:`pg.utils.quote_if_str`, | + | | | + | | :func:`pg.utils.message_on_path` | + +---------------------+--------------------------------------------+ + | Serialization | :class:`pg.JSONConvertible`, | + | | | + | | :func:`pg.registered_types`, | + | | | + | | :func:`pg.utils.to_json`, | + | | | + | | :func:`pg.utils.from_json`, | + +---------------------+--------------------------------------------+ + | Partial construction| :class:`pg.MaybePartial`, | + | | | + | | :const:`pg.MISSING_VALUE`. | + +---------------------+--------------------------------------------+ + | Hierarchical key | :class:`pg.KeyPath` | + | representation | | + +---------------------+--------------------------------------------+ + | Hierarchical object | :func:`pg.utils.traverse` | + | traversal | | + +---------------------+--------------------------------------------+ + | Hierarchical object | :func:`pg.utils.transform`, | + | transformation | | + | | :func:`pg.utils.merge`, | + | | | + | | :func:`pg.utils.canonicalize`, | + | | | + | | :func:`pg.utils.flatten` | + +---------------------+--------------------------------------------+ + | Docstr handling | :class:`pg.docstr`, | + +---------------------+--------------------------------------------+ + | Error handling | :class:`pg.catch_errors`, | + +---------------------+--------------------------------------------+ +""" +# pylint: enable=line-too-long +# pylint: disable=g-bad-import-order +# pylint: disable=g-importing-member + +# Handling JSON conversion. +from pyglove.core.utils.json_conversion import Nestable +from pyglove.core.utils.json_conversion import JSONValueType + +from pyglove.core.utils.json_conversion import JSONConvertible +from pyglove.core.utils.json_conversion import from_json +from pyglove.core.utils.json_conversion import to_json +from pyglove.core.utils.json_conversion import registered_types + +# Handling formatting. +from pyglove.core.utils.formatting import Formattable +from pyglove.core.utils.formatting import format # pylint: disable=redefined-builtin +from pyglove.core.utils.formatting import printv as print # pylint: disable=redefined-builtin +from pyglove.core.utils.formatting import kvlist_str +from pyglove.core.utils.formatting import quote_if_str +from pyglove.core.utils.formatting import maybe_markdown_quote +from pyglove.core.utils.formatting import comma_delimited_str +from pyglove.core.utils.formatting import camel_to_snake +from pyglove.core.utils.formatting import auto_plural +from pyglove.core.utils.formatting import BracketType +from pyglove.core.utils.formatting import bracket_chars +from pyglove.core.utils.formatting import RawText + +# Context managers for defining the default format for __str__ and __repr__. +from pyglove.core.utils.formatting import str_format +from pyglove.core.utils.formatting import repr_format + +# Value location. +from pyglove.core.utils.value_location import KeyPath +from pyglove.core.utils.value_location import KeyPathSet +from pyglove.core.utils.value_location import StrKey +from pyglove.core.utils.value_location import message_on_path + +# Value markers. +from pyglove.core.utils.missing import MissingValue +from pyglove.core.utils.missing import MISSING_VALUE + +# Handling hierarchical. +from pyglove.core.utils.hierarchical import traverse +from pyglove.core.utils.hierarchical import transform +from pyglove.core.utils.hierarchical import flatten +from pyglove.core.utils.hierarchical import canonicalize +from pyglove.core.utils.hierarchical import merge +from pyglove.core.utils.hierarchical import merge_tree +from pyglove.core.utils.hierarchical import is_partial +from pyglove.core.utils.hierarchical import try_listify_dict_with_int_keys + +# Common traits. +from pyglove.core.utils.common_traits import MaybePartial +from pyglove.core.utils.common_traits import Functor + +from pyglove.core.utils.common_traits import explicit_method_override +from pyglove.core.utils.common_traits import ensure_explicit_method_override + +# Handling thread local values. +from pyglove.core.utils.thread_local import thread_local_value_scope +from pyglove.core.utils.thread_local import thread_local_has +from pyglove.core.utils.thread_local import thread_local_set +from pyglove.core.utils.thread_local import thread_local_get +from pyglove.core.utils.thread_local import thread_local_del +from pyglove.core.utils.thread_local import thread_local_increment +from pyglove.core.utils.thread_local import thread_local_decrement +from pyglove.core.utils.thread_local import thread_local_push +from pyglove.core.utils.thread_local import thread_local_pop +from pyglove.core.utils.thread_local import thread_local_peek + +# Handling docstrings. +from pyglove.core.utils.docstr_utils import DocStr +from pyglove.core.utils.docstr_utils import DocStrStyle +from pyglove.core.utils.docstr_utils import DocStrEntry +from pyglove.core.utils.docstr_utils import DocStrExample +from pyglove.core.utils.docstr_utils import DocStrArgument +from pyglove.core.utils.docstr_utils import DocStrReturns +from pyglove.core.utils.docstr_utils import DocStrRaises +from pyglove.core.utils.docstr_utils import docstr + +# Handling exceptions. +from pyglove.core.utils.error_utils import catch_errors +from pyglove.core.utils.error_utils import CatchErrorsContext +from pyglove.core.utils.error_utils import ErrorInfo + +# Timing. +from pyglove.core.utils.timing import timeit +from pyglove.core.utils.timing import TimeIt + +# pylint: enable=g-importing-member +# pylint: enable=g-bad-import-order diff --git a/pyglove/core/object_utils/common_traits.py b/pyglove/core/utils/common_traits.py similarity index 100% rename from pyglove/core/object_utils/common_traits.py rename to pyglove/core/utils/common_traits.py diff --git a/pyglove/core/object_utils/common_traits_test.py b/pyglove/core/utils/common_traits_test.py similarity index 91% rename from pyglove/core/object_utils/common_traits_test.py rename to pyglove/core/utils/common_traits_test.py index 9227c20..6305f70 100644 --- a/pyglove/core/object_utils/common_traits_test.py +++ b/pyglove/core/utils/common_traits_test.py @@ -11,10 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.common_traits.""" - import unittest -from pyglove.core.object_utils import common_traits +from pyglove.core.utils import common_traits class ExplicitlyOverrideTest(unittest.TestCase): diff --git a/pyglove/core/object_utils/docstr_utils.py b/pyglove/core/utils/docstr_utils.py similarity index 100% rename from pyglove/core/object_utils/docstr_utils.py rename to pyglove/core/utils/docstr_utils.py diff --git a/pyglove/core/object_utils/docstr_utils_test.py b/pyglove/core/utils/docstr_utils_test.py similarity index 97% rename from pyglove/core/object_utils/docstr_utils_test.py rename to pyglove/core/utils/docstr_utils_test.py index 19b2eb0..6562336 100644 --- a/pyglove/core/object_utils/docstr_utils_test.py +++ b/pyglove/core/utils/docstr_utils_test.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.docstr_utils.""" - import inspect import unittest -from pyglove.core.object_utils import docstr_utils +from pyglove.core.utils import docstr_utils class DocStrTest(unittest.TestCase): diff --git a/pyglove/core/object_utils/error_utils.py b/pyglove/core/utils/error_utils.py similarity index 97% rename from pyglove/core/object_utils/error_utils.py rename to pyglove/core/utils/error_utils.py index 5fdd7c7..85584e1 100644 --- a/pyglove/core/object_utils/error_utils.py +++ b/pyglove/core/utils/error_utils.py @@ -21,8 +21,8 @@ import traceback from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union -from pyglove.core.object_utils import formatting -from pyglove.core.object_utils import json_conversion +from pyglove.core.utils import formatting +from pyglove.core.utils import json_conversion @dataclasses.dataclass(frozen=True) @@ -102,7 +102,7 @@ def catch_errors( Examples:: - with pg.object_utils.catch_errors( + with pg.utils.catch_errors( [ RuntimeErrror, (ValueError, 'Input is wrong.') diff --git a/pyglove/core/object_utils/error_utils_test.py b/pyglove/core/utils/error_utils_test.py similarity index 98% rename from pyglove/core/object_utils/error_utils_test.py rename to pyglove/core/utils/error_utils_test.py index 509b0ca..7d1a191 100644 --- a/pyglove/core/object_utils/error_utils_test.py +++ b/pyglove/core/utils/error_utils_test.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect import unittest -from pyglove.core.object_utils import error_utils +from pyglove.core.utils import error_utils class ErrorInfoTest(unittest.TestCase): diff --git a/pyglove/core/object_utils/formatting.py b/pyglove/core/utils/formatting.py similarity index 99% rename from pyglove/core/object_utils/formatting.py rename to pyglove/core/utils/formatting.py index 1f4e76f..2d7fb94 100644 --- a/pyglove/core/object_utils/formatting.py +++ b/pyglove/core/utils/formatting.py @@ -18,7 +18,7 @@ import io import sys from typing import Any, Callable, ContextManager, Dict, List, Optional, Sequence, Set, Tuple -from pyglove.core.object_utils import thread_local +from pyglove.core.utils import thread_local _TLS_STR_FORMAT_KWARGS = '_str_format_kwargs' diff --git a/pyglove/core/object_utils/formatting_test.py b/pyglove/core/utils/formatting_test.py similarity index 99% rename from pyglove/core/object_utils/formatting_test.py rename to pyglove/core/utils/formatting_test.py index 799e736..f1c427c 100644 --- a/pyglove/core/object_utils/formatting_test.py +++ b/pyglove/core/utils/formatting_test.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.formatting.""" import inspect import unittest -from pyglove.core.object_utils import formatting +from pyglove.core.utils import formatting class Foo(formatting.Formattable): diff --git a/pyglove/core/object_utils/hierarchical.py b/pyglove/core/utils/hierarchical.py similarity index 93% rename from pyglove/core/object_utils/hierarchical.py rename to pyglove/core/utils/hierarchical.py index b6af126..aa2ac2e 100644 --- a/pyglove/core/object_utils/hierarchical.py +++ b/pyglove/core/utils/hierarchical.py @@ -14,9 +14,9 @@ """Operating hierarchical object.""" from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from pyglove.core.object_utils import common_traits -from pyglove.core.object_utils.missing import MISSING_VALUE -from pyglove.core.object_utils.value_location import KeyPath +from pyglove.core.utils import common_traits +from pyglove.core.utils.missing import MISSING_VALUE +from pyglove.core.utils.value_location import KeyPath def traverse(value: Any, @@ -33,7 +33,7 @@ def preorder_visit(path, value): print(path) tree = {'a': [{'c': [1, 2]}, {'d': {'g': (3, 4)}}], 'b': 'foo'} - pg.object_utils.traverse(tree, preorder_visit) + pg.utils.traverse(tree, preorder_visit) # Should print: # 'a' @@ -48,10 +48,10 @@ def preorder_visit(path, value): Args: value: A maybe hierarchical value to traverse. - preorder_visitor_fn: Preorder visitor function. - Function signature is (path, value) -> should_continue. - postorder_visitor_fn: Postorder visitor function. - Function signature is (path, value) -> should_continue. + preorder_visitor_fn: Preorder visitor function. Function signature is (path, + value) -> should_continue. + postorder_visitor_fn: Postorder visitor function. Function signature is + (path, value) -> should_continue. root_path: The key path of the root value. Returns: @@ -111,7 +111,7 @@ def _remove_int(path, value): 'e': 'bar', 'f': 4 } - output = pg.object_utils.transform(inputs, _remove_int) + output = pg.utils.transform(inputs, _remove_int) assert output == { 'a': { 'c': ['bar'], @@ -123,11 +123,11 @@ def _remove_int(path, value): Args: value: Any python value type. If value is a list of dict, transformation will occur recursively. - transform_fn: Transform function in signature - (path, value) -> new value - If new value is MISSING_VALUE, key will be deleted. + transform_fn: Transform function in signature (path, value) -> new value If + new value is MISSING_VALUE, key will be deleted. root_path: KeyPath of the root. inplace: If True, perform transformation in place. + Returns: Transformed value. """ @@ -186,7 +186,7 @@ def flatten(src: Any, flatten_complex_keys: bool = True) -> Any: 'b': 'hi', 'c': None } - output = pg.object_utils.flatten(inputs) + output = pg.utils.flatten(inputs) assert output == { 'a.e': 1, 'a.f[0].g': 2, @@ -200,9 +200,9 @@ def flatten(src: Any, flatten_complex_keys: bool = True) -> Any: Args: src: source value to flatten. flatten_complex_keys: if True, complex keys such as 'x.y' will be flattened - as 'x'.'y'. For example: - {'a': {'b.c': 1}} will be flattened into {'a.b.c': 1} if this flag is on, - otherwise it will be flattened as {'a[b.c]': 1}. + as 'x'.'y'. For example: {'a': {'b.c': 1}} will be flattened into + {'a.b.c': 1} if this flag is on, otherwise it will be flattened as + {'a[b.c]': 1}. Returns: For primitive value types, `src` itself will be returned. @@ -464,7 +464,7 @@ def merge(value_list: List[Any], 'f': 10 } } - output = pg.object_utils.merge([original, patch]) + output = pg.utils.merge([original, patch]) assert output == { 'a': 1, # b is updated. @@ -486,14 +486,12 @@ def merge(value_list: List[Any], value. The merge process will keep input values intact. merge_fn: A function to handle value merge that will be called for updated or added keys. If a branch is added/updated, the root of branch will be - passed to merge_fn. - the signature of function is: - `(path, left_value, right_value) -> final_value` - If a key is only present in src dict, old_value is MISSING_VALUE; - If a key is only present in dest dict, new_value is MISSING_VALUE; - otherwise both new_value and old_value are filled. - If final_value is MISSING_VALUE for a path, it will be removed from its - parent collection. + passed to merge_fn. the signature of function is: `(path, left_value, + right_value) -> final_value` If a key is only present in src dict, + old_value is MISSING_VALUE; If a key is only present in dest dict, + new_value is MISSING_VALUE; otherwise both new_value and old_value are + filled. If final_value is MISSING_VALUE for a path, it will be removed + from its parent collection. Returns: A merged value. diff --git a/pyglove/core/object_utils/hierarchical_test.py b/pyglove/core/utils/hierarchical_test.py similarity index 99% rename from pyglove/core/object_utils/hierarchical_test.py rename to pyglove/core/utils/hierarchical_test.py index 911ec8d..1bbcb12 100644 --- a/pyglove/core/object_utils/hierarchical_test.py +++ b/pyglove/core/utils/hierarchical_test.py @@ -11,12 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.hierarchical.""" - import unittest -from pyglove.core.object_utils import common_traits -from pyglove.core.object_utils import hierarchical -from pyglove.core.object_utils import value_location +from pyglove.core.utils import common_traits +from pyglove.core.utils import hierarchical +from pyglove.core.utils import value_location class TraverseTest(unittest.TestCase): diff --git a/pyglove/core/object_utils/json_conversion.py b/pyglove/core/utils/json_conversion.py similarity index 100% rename from pyglove/core/object_utils/json_conversion.py rename to pyglove/core/utils/json_conversion.py diff --git a/pyglove/core/object_utils/json_conversion_test.py b/pyglove/core/utils/json_conversion_test.py similarity index 99% rename from pyglove/core/object_utils/json_conversion_test.py rename to pyglove/core/utils/json_conversion_test.py index 2c253b8..4021527 100644 --- a/pyglove/core/object_utils/json_conversion_test.py +++ b/pyglove/core/utils/json_conversion_test.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.json_conversion.""" - import abc import typing import unittest -from pyglove.core.object_utils import json_conversion from pyglove.core.typing import inspect as pg_inspect +from pyglove.core.utils import json_conversion class X: diff --git a/pyglove/core/object_utils/missing.py b/pyglove/core/utils/missing.py similarity index 92% rename from pyglove/core/object_utils/missing.py rename to pyglove/core/utils/missing.py index 96952b4..218f13e 100644 --- a/pyglove/core/object_utils/missing.py +++ b/pyglove/core/utils/missing.py @@ -14,8 +14,8 @@ """Representing missing value for a field.""" from typing import Any, Dict -from pyglove.core.object_utils import formatting -from pyglove.core.object_utils import json_conversion +from pyglove.core.utils import formatting +from pyglove.core.utils import json_conversion class MissingValue(formatting.Formattable, json_conversion.JSONConvertible): diff --git a/pyglove/core/object_utils/missing_test.py b/pyglove/core/utils/missing_test.py similarity index 89% rename from pyglove/core/object_utils/missing_test.py rename to pyglove/core/utils/missing_test.py index 558e54f..c3f1d11 100644 --- a/pyglove/core/object_utils/missing_test.py +++ b/pyglove/core/utils/missing_test.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.missing.""" - import unittest -from pyglove.core.object_utils import json_conversion -from pyglove.core.object_utils import missing +from pyglove.core.utils import json_conversion +from pyglove.core.utils import missing class MissingValueTest(unittest.TestCase): diff --git a/pyglove/core/object_utils/thread_local.py b/pyglove/core/utils/thread_local.py similarity index 100% rename from pyglove/core/object_utils/thread_local.py rename to pyglove/core/utils/thread_local.py diff --git a/pyglove/core/object_utils/thread_local_test.py b/pyglove/core/utils/thread_local_test.py similarity index 98% rename from pyglove/core/object_utils/thread_local_test.py rename to pyglove/core/utils/thread_local_test.py index a6b3355..18ff4b2 100644 --- a/pyglove/core/object_utils/thread_local_test.py +++ b/pyglove/core/utils/thread_local_test.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.thread_local.""" - import threading import time import unittest -from pyglove.core.object_utils import thread_local +from pyglove.core.utils import thread_local class ThreadLocalTest(unittest.TestCase): diff --git a/pyglove/core/object_utils/timing.py b/pyglove/core/utils/timing.py similarity index 97% rename from pyglove/core/object_utils/timing.py rename to pyglove/core/utils/timing.py index 2372ac7..c332608 100644 --- a/pyglove/core/object_utils/timing.py +++ b/pyglove/core/utils/timing.py @@ -18,9 +18,9 @@ import time from typing import Any, Dict, List, Optional -from pyglove.core.object_utils import error_utils -from pyglove.core.object_utils import json_conversion -from pyglove.core.object_utils import thread_local +from pyglove.core.utils import error_utils +from pyglove.core.utils import json_conversion +from pyglove.core.utils import thread_local class TimeIt: diff --git a/pyglove/core/object_utils/timing_test.py b/pyglove/core/utils/timing_test.py similarity index 97% rename from pyglove/core/object_utils/timing_test.py rename to pyglove/core/utils/timing_test.py index cac5399..22a07c9 100644 --- a/pyglove/core/object_utils/timing_test.py +++ b/pyglove/core/utils/timing_test.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import time import unittest -from pyglove.core.object_utils import json_conversion -from pyglove.core.object_utils import timing +from pyglove.core.utils import json_conversion +from pyglove.core.utils import timing class TimeItTest(unittest.TestCase): diff --git a/pyglove/core/object_utils/value_location.py b/pyglove/core/utils/value_location.py similarity index 99% rename from pyglove/core/object_utils/value_location.py rename to pyglove/core/utils/value_location.py index 5857b95..ad67a54 100644 --- a/pyglove/core/object_utils/value_location.py +++ b/pyglove/core/utils/value_location.py @@ -17,7 +17,7 @@ import copy as copy_lib import operator from typing import Any, Callable, Iterable, Iterator, List, Optional, Union -from pyglove.core.object_utils import formatting +from pyglove.core.utils import formatting class KeyPath(formatting.Formattable): @@ -822,7 +822,7 @@ class StrKey(metaclass=abc.ABCMeta): Example:: - class MyKey(pg.object_utils.StrKey): + class MyKey(pg.utils.StrKey): def __init__(self, name): self.name = name diff --git a/pyglove/core/object_utils/value_location_test.py b/pyglove/core/utils/value_location_test.py similarity index 99% rename from pyglove/core/object_utils/value_location_test.py rename to pyglove/core/utils/value_location_test.py index 99b553c..389f357 100644 --- a/pyglove/core/object_utils/value_location_test.py +++ b/pyglove/core/utils/value_location_test.py @@ -11,11 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for pyglove.object_utils.value_location.""" - import unittest -from pyglove.core.object_utils import formatting -from pyglove.core.object_utils import value_location +from pyglove.core.utils import formatting +from pyglove.core.utils import value_location KeyPath = value_location.KeyPath diff --git a/pyglove/core/views/base.py b/pyglove/core/views/base.py index 2d55155..2464157 100644 --- a/pyglove/core/views/base.py +++ b/pyglove/core/views/base.py @@ -136,18 +136,18 @@ def _myview2_render(self, value, **kwargs): from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Set, Type, Union from pyglove.core import io as pg_io -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils # Type definition for the value filter function. NodeFilter = Callable[ [ - object_utils.KeyPath, # The path to the value. - Any, # Current value. - Any, # Parent value + utils.KeyPath, # The path to the value. + Any, # Current value. + Any, # Parent value ], - bool # Whether to include the value. + bool, # Whether to include the value. ] @@ -157,7 +157,7 @@ def _myview2_render(self, value, **kwargs): _TLS_KEY_VIEW_OPTIONS = '__view_options__' -class Content(object_utils.Formattable, metaclass=abc.ABCMeta): +class Content(utils.Formattable, metaclass=abc.ABCMeta): """Content: A type of media to be displayed in a view. For example, `pg.Html` is a `Content` type that represents HTML to be @@ -171,7 +171,7 @@ class Content(object_utils.Formattable, metaclass=abc.ABCMeta): None ] - class SharedParts(object_utils.Formattable): + class SharedParts(utils.Formattable): """A part of the content that should appear just once. For example, `pg.Html.Styles` is a `SharedParts` type that represents @@ -244,7 +244,7 @@ def format( **kwargs ) -> str: if compact: - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('parts', self._parts, {}), ], @@ -252,7 +252,7 @@ def format( compact=compact, verbose=verbose, root_indent=root_indent, - bracket_type=object_utils.BracketType.ROUND, + bracket_type=utils.BracketType.ROUND, ) return self.content @@ -363,17 +363,16 @@ def format(self, """Formats the Content object.""" del kwargs if compact: - return object_utils.kvlist_str( + return utils.kvlist_str( [ ('content', self.content, ''), - ] + [ - (k, v, None) for k, v in self._shared_parts.items() - ], + ] + + [(k, v, None) for k, v in self._shared_parts.items()], label=self.__class__.__name__, compact=compact, verbose=verbose, root_indent=root_indent, - bracket_type=object_utils.BracketType.ROUND, + bracket_type=utils.BracketType.ROUND, ) return self.to_str(content_only=content_only) @@ -427,9 +426,9 @@ def view( value: Any, *, name: Optional[str] = None, - root_path: Optional[object_utils.KeyPath] = None, + root_path: Optional[utils.KeyPath] = None, view_id: str = 'html-tree-view', - **kwargs + **kwargs, ) -> Content: """Views an object through generating content based on a specific view. @@ -451,8 +450,7 @@ def view( with view_options(**kwargs) as options: view_object = View.create(view_id) return view_object.render( - value, name=name, root_path=root_path or object_utils.KeyPath(), - **options + value, name=name, root_path=root_path or utils.KeyPath(), **options ) @@ -471,14 +469,14 @@ def view_options(**kwargs) -> Iterator[Dict[str, Any]]: Yields: The merged keyword arguments. """ - parent_options = object_utils.thread_local_peek(_TLS_KEY_VIEW_OPTIONS, {}) + parent_options = utils.thread_local_peek(_TLS_KEY_VIEW_OPTIONS, {}) # Deep merge the two dict. - options = object_utils.merge([parent_options, kwargs]) - object_utils.thread_local_push(_TLS_KEY_VIEW_OPTIONS, options) + options = utils.merge([parent_options, kwargs]) + utils.thread_local_push(_TLS_KEY_VIEW_OPTIONS, options) try: yield options finally: - object_utils.thread_local_pop(_TLS_KEY_VIEW_OPTIONS) + utils.thread_local_pop(_TLS_KEY_VIEW_OPTIONS) class View(metaclass=abc.ABCMeta): @@ -697,8 +695,8 @@ def render( value: Any, *, name: Optional[str] = None, - root_path: Optional[object_utils.KeyPath] = None, - **kwargs + root_path: Optional[utils.KeyPath] = None, + **kwargs, ) -> Content: """Renders the input value. @@ -789,20 +787,18 @@ def _track_rendering( ) -> Iterator[Any]: """Context manager for tracking the value being rendered.""" del self - rendering_stack = object_utils.thread_local_get( + rendering_stack = utils.thread_local_get( _TLS_KEY_OPERAND_STACK_BY_METHOD, {} ) callsite_value = rendering_stack.get(view_method, None) rendering_stack[view_method] = value - object_utils.thread_local_set( - _TLS_KEY_OPERAND_STACK_BY_METHOD, rendering_stack - ) + utils.thread_local_set(_TLS_KEY_OPERAND_STACK_BY_METHOD, rendering_stack) try: yield callsite_value finally: if callsite_value is None: rendering_stack.pop(view_method) if not rendering_stack: - object_utils.thread_local_del(_TLS_KEY_OPERAND_STACK_BY_METHOD) + utils.thread_local_del(_TLS_KEY_OPERAND_STACK_BY_METHOD) else: rendering_stack[view_method] = callsite_value diff --git a/pyglove/core/views/html/base.py b/pyglove/core/views/html/base.py index 26a0eec..1079e8a 100644 --- a/pyglove/core/views/html/base.py +++ b/pyglove/core/views/html/base.py @@ -20,8 +20,8 @@ import typing from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union -from pyglove.core import object_utils from pyglove.core import typing as pg_typing +from pyglove.core import utils from pyglove.core.views import base NestableStr = Union[ @@ -33,11 +33,11 @@ NodeFilter = base.NodeFilter NodeColor = Callable[ [ - object_utils.KeyPath, # The path to the value. - Any, # Current value. - Any, # Parent value + utils.KeyPath, # The path to the value. + Any, # Current value. + Any, # Parent value ], - Optional[str] # The color of the node. + Optional[str], # The color of the node. ] @@ -386,7 +386,7 @@ def concate( cls, nestable_str: NestableStr, separator: str = ' ', dedup: bool = True ) -> Optional[str]: """Concates the string nodes in a nestable object.""" - flattened = object_utils.flatten(nestable_str) + flattened = utils.flatten(nestable_str) if isinstance(flattened, str): return flattened elif isinstance(flattened, dict): @@ -456,8 +456,8 @@ def render( value: Any, *, name: Optional[str] = None, - root_path: Optional[object_utils.KeyPath] = None, - **kwargs + root_path: Optional[utils.KeyPath] = None, + **kwargs, ) -> Html: """Renders the input value into an HTML object.""" # For customized HtmlConvertible objects, call their `to_html()` method. @@ -473,8 +473,8 @@ def _render( value: Any, *, name: Optional[str] = None, - root_path: Optional[object_utils.KeyPath] = None, - **kwargs + root_path: Optional[utils.KeyPath] = None, + **kwargs, ) -> Html: """View's implementation of HTML rendering.""" @@ -483,9 +483,9 @@ def to_html( value: Any, *, name: Optional[str] = None, - root_path: Optional[object_utils.KeyPath] = None, + root_path: Optional[utils.KeyPath] = None, view_id: str = 'html-tree-view', - **kwargs + **kwargs, ) -> Html: """Returns the HTML representation of a value. @@ -517,10 +517,10 @@ def to_html_str( value: Any, *, name: Optional[str] = None, - root_path: Optional[object_utils.KeyPath] = None, + root_path: Optional[utils.KeyPath] = None, view_id: str = 'html-tree-view', content_only: bool = False, - **kwargs + **kwargs, ) -> str: """Returns a HTML str for a value. @@ -545,4 +545,3 @@ def to_html_str( view_id=view_id, **kwargs ).to_str(content_only=content_only) - diff --git a/pyglove/core/views/html/controls/base.py b/pyglove/core/views/html/controls/base.py index 116e463..3deae39 100644 --- a/pyglove/core/views/html/controls/base.py +++ b/pyglove/core/views/html/controls/base.py @@ -19,7 +19,7 @@ import sys from typing import Annotated, Any, Dict, Iterator, List, Optional, Union -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.symbolic import object as pg_object from pyglove.core.views.html import base @@ -89,16 +89,16 @@ def _to_html(self, **kwargs) -> Html: @contextlib.contextmanager def track_scripts(cls) -> Iterator[List[str]]: del cls - all_tracked = object_utils.thread_local_get(_TLS_TRACKED_SCRIPTS, []) + all_tracked = utils.thread_local_get(_TLS_TRACKED_SCRIPTS, []) current = [] all_tracked.append(current) - object_utils.thread_local_set(_TLS_TRACKED_SCRIPTS, all_tracked) + utils.thread_local_set(_TLS_TRACKED_SCRIPTS, all_tracked) try: yield current finally: all_tracked.pop(-1) if not all_tracked: - object_utils.thread_local_del(_TLS_TRACKED_SCRIPTS) + utils.thread_local_del(_TLS_TRACKED_SCRIPTS) def _sync_members(self, **fields) -> None: """Synchronizes displayed values to members.""" @@ -121,7 +121,7 @@ def _run_javascript(self, code: str, debug: bool = False) -> None: _notebook.display(_notebook.Javascript(code)) # Track script execution. - all_tracked = object_utils.thread_local_get(_TLS_TRACKED_SCRIPTS, []) + all_tracked = utils.thread_local_get(_TLS_TRACKED_SCRIPTS, []) for tracked in all_tracked: tracked.append(code) diff --git a/pyglove/core/views/html/controls/progress_bar.py b/pyglove/core/views/html/controls/progress_bar.py index 2fc0bed..f9c05ae 100644 --- a/pyglove/core/views/html/controls/progress_bar.py +++ b/pyglove/core/views/html/controls/progress_bar.py @@ -16,7 +16,7 @@ import functools from typing import Annotated, List, Optional, Union -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.symbolic import object as pg_object # pylint: disable=g-importing-member from pyglove.core.views.html.base import Html @@ -73,10 +73,8 @@ def _to_html(self, **kwargs) -> Html: [], id=self.element_id(), styles=styles, - css_classes=[ - 'sub-progress', - object_utils.camel_to_snake(self.name, '-') - ] + self.css_classes, + css_classes=['sub-progress', utils.camel_to_snake(self.name, '-')] + + self.css_classes, ) def increment(self, delta: int = 1): diff --git a/pyglove/core/views/html/tree_view.py b/pyglove/core/views/html/tree_view.py index e19033b..413a8d2 100644 --- a/pyglove/core/views/html/tree_view.py +++ b/pyglove/core/views/html/tree_view.py @@ -16,13 +16,13 @@ import inspect from typing import Any, Callable, Dict, Iterable, Literal, Optional, Sequence, Tuple, Union -from pyglove.core import object_utils +from pyglove.core import utils from pyglove.core.symbolic import base as pg_symbolic from pyglove.core.views.html import base -KeyPath = object_utils.KeyPath -KeyPathSet = object_utils.KeyPathSet +KeyPath = utils.KeyPath +KeyPathSet = utils.KeyPathSet Html = base.Html HtmlView = base.HtmlView @@ -940,10 +940,13 @@ def value_repr() -> str: return repr(value) else: return value - return object_utils.format( + return utils.format( value, - compact=False, verbose=False, hide_default_values=True, - python_format=True, use_inferred=True, + compact=False, + verbose=False, + hide_default_values=True, + python_format=True, + use_inferred=True, max_bytes_len=64, ) return Html.element( @@ -1257,7 +1260,7 @@ def tooltip( del parent, kwargs if content is None: content = Html.escape( - object_utils.format( + utils.format( value, root_path=root_path, compact=False, @@ -1302,7 +1305,7 @@ def css_class_name(value: Any) -> Optional[str]: class_name = f'{value.__name__}-class' else: class_name = type(value).__name__ - return object_utils.camel_to_snake(class_name, '-') + return utils.camel_to_snake(class_name, '-') @staticmethod def init_uncollapse( @@ -1341,40 +1344,40 @@ def get_child_kwargs( @staticmethod def get_passthrough_kwargs( *, - enable_summary: Optional[bool] = object_utils.MISSING_VALUE, - enable_summary_for_str: bool = object_utils.MISSING_VALUE, - max_summary_len_for_str: int = object_utils.MISSING_VALUE, - enable_summary_tooltip: bool = object_utils.MISSING_VALUE, + enable_summary: Optional[bool] = utils.MISSING_VALUE, + enable_summary_for_str: bool = utils.MISSING_VALUE, + max_summary_len_for_str: int = utils.MISSING_VALUE, + enable_summary_tooltip: bool = utils.MISSING_VALUE, key_style: Union[ Literal['label', 'summary'], - Callable[[KeyPath, Any, Any], Literal['label', 'summary']] - ] = object_utils.MISSING_VALUE, + Callable[[KeyPath, Any, Any], Literal['label', 'summary']], + ] = utils.MISSING_VALUE, key_color: Union[ Tuple[Optional[str], Optional[str]], - Callable[[KeyPath, Any, Any], Tuple[Optional[str], Optional[str]]] - ] = object_utils.MISSING_VALUE, + Callable[[KeyPath, Any, Any], Tuple[Optional[str], Optional[str]]], + ] = utils.MISSING_VALUE, include_keys: Union[ Iterable[Union[int, str]], Callable[[KeyPath, Any, Any], Iterable[Union[int, str]]], - None - ] = object_utils.MISSING_VALUE, + None, + ] = utils.MISSING_VALUE, exclude_keys: Union[ Iterable[Union[int, str]], Callable[[KeyPath, Any, Any], Iterable[Union[int, str]]], - None - ] = object_utils.MISSING_VALUE, - enable_key_tooltip: bool = object_utils.MISSING_VALUE, + None, + ] = utils.MISSING_VALUE, + enable_key_tooltip: bool = utils.MISSING_VALUE, uncollapse: Union[ KeyPathSet, base.NodeFilter, None - ] = object_utils.MISSING_VALUE, - extra_flags: Optional[Dict[str, Any]] = object_utils.MISSING_VALUE, - highlight: Optional[base.NodeFilter] = object_utils.MISSING_VALUE, - lowlight: Optional[base.NodeFilter] = object_utils.MISSING_VALUE, - debug: bool = object_utils.MISSING_VALUE, + ] = utils.MISSING_VALUE, + extra_flags: Optional[Dict[str, Any]] = utils.MISSING_VALUE, + highlight: Optional[base.NodeFilter] = utils.MISSING_VALUE, + lowlight: Optional[base.NodeFilter] = utils.MISSING_VALUE, + debug: bool = utils.MISSING_VALUE, remove: Optional[Iterable[str]] = None, **kwargs, ): - # pytype: enable=annotation-type-mismatch + # pytype: enable=annotation-type-mismatch """Gets the rendering arguments to pass through to the child nodes.""" del kwargs passthrough_kwargs = dict( @@ -1386,23 +1389,22 @@ def get_passthrough_kwargs( key_style=key_style, key_color=key_color, include_keys=( - include_keys if callable(include_keys) - else object_utils.MISSING_VALUE + include_keys if callable(include_keys) else utils.MISSING_VALUE ), exclude_keys=( - exclude_keys if callable(exclude_keys) - else object_utils.MISSING_VALUE + exclude_keys if callable(exclude_keys) else utils.MISSING_VALUE ), uncollapse=uncollapse, highlight=highlight, lowlight=lowlight, extra_flags=extra_flags, - debug=debug + debug=debug, ) # Filter out missing values. passthrough_kwargs = { - k: v for k, v in passthrough_kwargs.items() - if v is not object_utils.MISSING_VALUE + k: v + for k, v in passthrough_kwargs.items() + if v is not utils.MISSING_VALUE } if remove: return { @@ -1467,7 +1469,7 @@ def get_kwargs( ) # Deep hierarchy merge. - return object_utils.merge_tree(call_kwargs, overriden_kwargs) + return utils.merge_tree(call_kwargs, overriden_kwargs) @staticmethod def merge_uncollapse(