Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 44 additions & 46 deletions altair/utils/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Sequence,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -833,6 +834,38 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]:
return obj is Undefined


@overload
def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ...
@overload
def _shallow_copy(obj: Any) -> Any: ...
def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any:
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, (list, dict)):
return obj.copy()
else:
return obj


@overload
def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ...
@overload
def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
copy = partial(_deep_copy, by_ref=by_ref)
if isinstance(obj, SchemaBase):
args = (copy(arg) for arg in obj._args)
kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()}
with debug_mode(False):
return obj.__class__(*args, **kwds)
elif isinstance(obj, list):
return [copy(v) for v in obj]
elif isinstance(obj, dict):
return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()}
else:
return obj


class SchemaBase:
"""
Base class for schema wrappers.
Expand Down Expand Up @@ -870,7 +903,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None:
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)

def copy( # noqa: C901
def copy(
self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None
) -> Self:
"""
Expand All @@ -887,53 +920,11 @@ def copy( # noqa: C901
A list of keys for which the contents should not be copied, but
only stored by reference.
"""

def _shallow_copy(obj):
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, list):
return obj[:]
elif isinstance(obj, dict):
return obj.copy()
else:
return obj

def _deep_copy(obj, ignore: list[str] | None = None):
if ignore is None:
ignore = []
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj._kwds.items()
}
with debug_mode(False):
return obj.__class__(*args, **kwds)
elif isinstance(obj, list):
return [_deep_copy(v, ignore=ignore) for v in obj]
elif isinstance(obj, dict):
return {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj.items()
}
else:
return obj

try:
deep = list(deep) # type: ignore[arg-type]
except TypeError:
deep_is_list = False
else:
deep_is_list = True

if deep and not deep_is_list:
return _deep_copy(self, ignore=ignore)

if deep is True:
return cast("Self", _deep_copy(self, set(ignore) if ignore else set()))
with debug_mode(False):
copy = self.__class__(*self._args, **self._kwds)
if deep_is_list:
# Assert statement is for the benefit of Mypy
assert isinstance(deep, list)
if _is_iterable(deep):
for attr in deep:
copy[attr] = _shallow_copy(copy._get(attr))
return copy
Expand Down Expand Up @@ -1240,6 +1231,13 @@ def __dir__(self) -> list[str]:

TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase)

_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any])
"""
Types which have an implementation in ``SchemaBase.copy()``.

All other types are returned **by reference**.
"""


def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]:
return isinstance(obj, dict)
Expand Down
90 changes: 44 additions & 46 deletions tools/schemapi/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Sequence,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -831,6 +832,38 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]:
return obj is Undefined


@overload
def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ...
@overload
def _shallow_copy(obj: Any) -> Any: ...
def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any:
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, (list, dict)):
return obj.copy()
else:
return obj


@overload
def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ...
@overload
def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
copy = partial(_deep_copy, by_ref=by_ref)
if isinstance(obj, SchemaBase):
args = (copy(arg) for arg in obj._args)
kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()}
with debug_mode(False):
return obj.__class__(*args, **kwds)
elif isinstance(obj, list):
return [copy(v) for v in obj]
elif isinstance(obj, dict):
return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()}
else:
return obj


class SchemaBase:
"""
Base class for schema wrappers.
Expand Down Expand Up @@ -868,7 +901,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None:
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)

def copy( # noqa: C901
def copy(
self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None
) -> Self:
"""
Expand All @@ -885,53 +918,11 @@ def copy( # noqa: C901
A list of keys for which the contents should not be copied, but
only stored by reference.
"""

def _shallow_copy(obj):
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, list):
return obj[:]
elif isinstance(obj, dict):
return obj.copy()
else:
return obj

def _deep_copy(obj, ignore: list[str] | None = None):
if ignore is None:
ignore = []
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj._kwds.items()
}
with debug_mode(False):
return obj.__class__(*args, **kwds)
elif isinstance(obj, list):
return [_deep_copy(v, ignore=ignore) for v in obj]
elif isinstance(obj, dict):
return {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj.items()
}
else:
return obj

try:
deep = list(deep) # type: ignore[arg-type]
except TypeError:
deep_is_list = False
else:
deep_is_list = True

if deep and not deep_is_list:
return _deep_copy(self, ignore=ignore)

if deep is True:
return cast("Self", _deep_copy(self, set(ignore) if ignore else set()))
with debug_mode(False):
copy = self.__class__(*self._args, **self._kwds)
if deep_is_list:
# Assert statement is for the benefit of Mypy
assert isinstance(deep, list)
if _is_iterable(deep):
for attr in deep:
copy[attr] = _shallow_copy(copy._get(attr))
return copy
Expand Down Expand Up @@ -1238,6 +1229,13 @@ def __dir__(self) -> list[str]:

TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase)

_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any])
"""
Types which have an implementation in ``SchemaBase.copy()``.

All other types are returned **by reference**.
"""


def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]:
return isinstance(obj, dict)
Expand Down