Skip to content

Commit bf7fc29

Browse files
authored
fix: Resolve Then copy TypeError (#3553)
1 parent 49c8f9a commit bf7fc29

File tree

4 files changed

+30
-0
lines changed

4 files changed

+30
-0
lines changed

Diff for: altair/utils/schemapi.py

+3
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,9 @@ def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
854854
def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
855855
copy = partial(_deep_copy, by_ref=by_ref)
856856
if isinstance(obj, SchemaBase):
857+
if copier := getattr(obj, "__deepcopy__", None):
858+
with debug_mode(False):
859+
return copier(obj)
857860
args = (copy(arg) for arg in obj._args)
858861
kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()}
859862
with debug_mode(False):

Diff for: altair/vegalite/v5/api.py

+3
Original file line numberDiff line numberDiff line change
@@ -1061,6 +1061,9 @@ def to_dict(self, *args: Any, **kwds: Any) -> _Conditional[_C]: # type: ignore[
10611061
m = super().to_dict(*args, **kwds)
10621062
return _Conditional(condition=m["condition"])
10631063

1064+
def __deepcopy__(self, memo: Any) -> Self:
1065+
return type(self)(_Conditional(condition=_deepcopy(self.condition)))
1066+
10641067

10651068
class ChainedWhen(_BaseWhen):
10661069
"""

Diff for: tests/vegalite/v5/test_api.py

+21
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,27 @@ def test_when_condition_parity(
698698
assert chart_condition == chart_when
699699

700700

701+
def test_when_then_interactive() -> None:
702+
"""Copy-related regression found in https://github.com/vega/altair/pull/3394#issuecomment-2302995453."""
703+
source = "https://cdn.jsdelivr.net/npm/[email protected]/data/movies.json"
704+
predicate = (alt.datum.IMDB_Rating == None) | ( # noqa: E711
705+
alt.datum.Rotten_Tomatoes_Rating == None # noqa: E711
706+
)
707+
708+
chart = (
709+
alt.Chart(source)
710+
.mark_point(invalid=None)
711+
.encode(
712+
x="IMDB_Rating:Q",
713+
y="Rotten_Tomatoes_Rating:Q",
714+
color=alt.when(predicate).then(alt.value("grey")), # type: ignore[arg-type]
715+
)
716+
)
717+
assert chart.interactive()
718+
assert chart.copy()
719+
assert chart.to_dict()
720+
721+
701722
def test_selection_to_dict():
702723
brush = alt.selection_interval()
703724

Diff for: tools/schemapi/schemapi.py

+3
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,9 @@ def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
852852
def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
853853
copy = partial(_deep_copy, by_ref=by_ref)
854854
if isinstance(obj, SchemaBase):
855+
if copier := getattr(obj, "__deepcopy__", None):
856+
with debug_mode(False):
857+
return copier(obj)
855858
args = (copy(arg) for arg in obj._args)
856859
kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()}
857860
with debug_mode(False):

0 commit comments

Comments
 (0)