From 31dedc841ed452dfc767383d7fedd48476b0b171 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Jun 2024 20:48:40 +0100 Subject: [PATCH 01/11] fix, doc, perf: Fix issues with `Chart|LayerChart.encode`, 1.33x speedup to `infer_encoding_types` Fixes: - [Sphinx warning](https://altair-viz.github.io/user_guide/generated/toplevel/altair.Chart.html#altair.Chart) on `Chart.encode`. Also incorrectly under `Attributes` section - Preserve static typing previously found in `_encode_signature` but lost after `_EncodingMixin.encode` - Re-running `mypy` output 'Found 63 errors in 47 files (checked 360 source files)', tests/examples Perf: - This was a response to the `TODO` left at the top of `infer_encoding_types` - Will be adding the benchmark to the PR description --- altair/utils/core.py | 203 ++++++--- altair/vegalite/v5/api.py | 25 +- altair/vegalite/v5/schema/channels.py | 593 ++++++++++++++------------ tools/generate_schema_wrapper.py | 42 +- 4 files changed, 496 insertions(+), 367 deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index 88d03b6e7..5ccfdace8 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -18,18 +18,20 @@ Dict, Optional, Tuple, - Sequence, + Iterator, Type, cast, ) from types import ModuleType +from itertools import groupby +from operator import itemgetter import jsonschema import pandas as pd import numpy as np from pandas.api.types import infer_dtype -from altair.utils.schemapi import SchemaBase +from altair.utils.schemapi import SchemaBase, Undefined from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame if sys.version_info >= (3, 10): @@ -767,7 +769,132 @@ def display_traceback(in_ipython: bool = True): traceback.print_exception(*exc_info) -def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType): +_ChannelType = Literal["field", "datum", "value"] +_CHANNEL_CACHE: "_ChannelCache" +"""Singleton `_ChannelCache` instance. + +Initialized on first use. +""" + + +class _ChannelCache: + channel_to_name: Dict[Type[SchemaBase], str] + name_to_channel: Dict[str, Dict[_ChannelType, Type[SchemaBase]]] + + @classmethod + def from_channels(cls, channels: ModuleType, /) -> "_ChannelCache": + # - This branch is only kept for tests that depend on mocking `channels`. + # - No longer needs to pass around `channels` reference and rebuild every call. + c_to_n = { + c: c._encoding_name + for c in channels.__dict__.values() + if isinstance(c, type) + and issubclass(c, SchemaBase) + and hasattr(c, "_encoding_name") + } + self = cls.__new__(cls) + self.channel_to_name = c_to_n + self.name_to_channel = _invert_group_channels(c_to_n) + return self + + @classmethod + def from_cache(cls) -> "_ChannelCache": + global _CHANNEL_CACHE + try: + cached = _CHANNEL_CACHE + except NameError: + cached = cls.__new__(cls) + cached.channel_to_name = _init_channel_to_name() + cached.name_to_channel = _invert_group_channels(cached.channel_to_name) + _CHANNEL_CACHE = cached + return _CHANNEL_CACHE + + def get_encoding(self, tp: Type[Any], /) -> str: + if encoding := self.channel_to_name.get(tp): + return encoding + raise NotImplementedError(f"positional of type {type(tp).__name__!r}") + + def _wrap_in_channel(self, obj: Any, encoding: str, /): + if isinstance(obj, SchemaBase): + return obj + elif isinstance(obj, str): + obj = {"shorthand": obj} + elif isinstance(obj, (list, tuple)): + return [self._wrap_in_channel(el, encoding) for el in obj] + if channel := self.name_to_channel.get(encoding): + tp = channel["value" if "value" in obj else "field"] + try: + # Don't force validation here; some objects won't be valid until + # they're created in the context of a chart. + return tp.from_dict(obj, validate=False) + except jsonschema.ValidationError: + # our attempts at finding the correct class have failed + return obj + else: + warnings.warn(f"Unrecognized encoding channel {encoding!r}", stacklevel=1) + return obj + + def infer_encoding_types(self, kwargs: Dict[str, Any], /): + return { + encoding: self._wrap_in_channel(obj, encoding) + for encoding, obj in kwargs.items() + if obj is not Undefined + } + + +def _init_channel_to_name(): + """ + Construct a dictionary of channel type to encoding name. + + Note + ---- + The return type is not expressible using annotations, but is used + internally by `mypy`/`pyright` and avoid the need for type ignores. + + Returns + ------- + mapping: dict[type[``] | type[``] | type[``], str] + """ + from altair.vegalite.v5.schema import channels as ch + + mixins = ch.FieldChannelMixin, ch.ValueChannelMixin, ch.DatumChannelMixin + + return { + c: c._encoding_name + for c in ch.__dict__.values() + if isinstance(c, type) and issubclass(c, mixins) and issubclass(c, SchemaBase) + } + + +def _invert_group_channels( + m: Dict[Type[SchemaBase], str], / +) -> Dict[str, Dict[_ChannelType, Type[SchemaBase]]]: + """Grouped inverted index for `_ChannelCache.channel_to_name`.""" + + def _reduce(it: Iterator[Tuple[Type[Any], str]]) -> Any: + """Returns a 1-2 item dict, per channel. + + Never includes `datum`, as it is never utilized in `wrap_in_channel`. + """ + item: Dict[Any, Type[SchemaBase]] = {} + for tp, _ in it: + name = tp.__name__ + if name.endswith("Datum"): + continue + elif name.endswith("Value"): + sub_key = "value" + else: + sub_key = "field" + item[sub_key] = tp + return item + + grouper = groupby(m.items(), itemgetter(1)) + return {k: _reduce(chans) for k, chans in grouper} + + +def infer_encoding_types( + args: Tuple[Any, ...], kwargs: Dict[str, Any], channels: Optional[ModuleType] = None +): """Infer typed keyword arguments for args and kwargs Parameters @@ -785,68 +912,18 @@ def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: Modul All args and kwargs in a single dict, with keys and types based on the channels mapping. """ - # Construct a dictionary of channel type to encoding name - # TODO: cache this somehow? - channel_objs = (getattr(channels, name) for name in dir(channels)) - channel_objs = ( - c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase) + cache = ( + _ChannelCache.from_channels(channels) + if channels + else _ChannelCache.from_cache() ) - channel_to_name: Dict[Type[SchemaBase], str] = { - c: c._encoding_name for c in channel_objs - } - name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {} - for chan, name in channel_to_name.items(): - chans = name_to_channel.setdefault(name, {}) - if chan.__name__.endswith("Datum"): - key = "datum" - elif chan.__name__.endswith("Value"): - key = "value" - else: - key = "field" - chans[key] = chan - # First use the mapping to convert args to kwargs based on their types. for arg in args: - if isinstance(arg, (list, tuple)) and len(arg) > 0: - type_ = type(arg[0]) + el = next(iter(arg), None) if isinstance(arg, (list, tuple)) else arg + encoding = cache.get_encoding(type(el)) + if encoding not in kwargs: + kwargs[encoding] = arg else: - type_ = type(arg) - - encoding = channel_to_name.get(type_, None) - if encoding is None: - raise NotImplementedError("positional of type {}" "".format(type_)) - if encoding in kwargs: - raise ValueError("encoding {} specified twice.".format(encoding)) - kwargs[encoding] = arg - - def _wrap_in_channel_class(obj, encoding): - if isinstance(obj, SchemaBase): - return obj + raise ValueError(f"encoding {encoding!r} specified twice.") - if isinstance(obj, str): - obj = {"shorthand": obj} - - if isinstance(obj, (list, tuple)): - return [_wrap_in_channel_class(subobj, encoding) for subobj in obj] - - if encoding not in name_to_channel: - warnings.warn( - "Unrecognized encoding channel '{}'".format(encoding), stacklevel=1 - ) - return obj - - classes = name_to_channel[encoding] - cls = classes["value"] if "value" in obj else classes["field"] - - try: - # Don't force validation here; some objects won't be valid until - # they're created in the context of a chart. - return cls.from_dict(obj, validate=False) - except jsonschema.ValidationError: - # our attempts at finding the correct class have failed - return obj - - return { - encoding: _wrap_in_channel_class(obj, encoding) - for encoding, obj in kwargs.items() - } + return cache.infer_encoding_types(kwargs) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 6425c77fe..7b5e648d6 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -2727,24 +2727,7 @@ def resolve_scale(self, *args, **kwargs) -> Self: return self._set_resolve(scale=core.ScaleResolveMap(*args, **kwargs)) -class _EncodingMixin: - @utils.use_signature(channels._encode_signature) - def encode(self, *args, **kwargs) -> Self: - # Convert args to kwargs based on their types. - kwargs = utils.infer_encoding_types(args, kwargs, channels) - - # get a copy of the dict representation of the previous encoding - # ignore type as copy method comes from SchemaBase - copy = self.copy(deep=["encoding"]) # type: ignore[attr-defined] - encoding = copy._get("encoding", {}) - if isinstance(encoding, core.VegaLiteSchema): - encoding = {k: v for k, v in encoding._kwds.items() if v is not Undefined} - - # update with the new encodings, and apply them to the copy - encoding.update(kwargs) - copy.encoding = core.FacetedEncoding(**encoding) - return copy - +class _EncodingMixin(channels._EncodingMixin): def facet( self, facet: Union[str, channels.Facet, UndefinedType] = Undefined, @@ -3629,7 +3612,7 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def __iadd__(self, other: Union[core.LayerSpec, core.UnitSpec]) -> Self: + def __iadd__(self, other: Union["LayerChart", Chart]) -> Self: _check_if_valid_subspec(other, "LayerChart") _check_if_can_be_layered(other) self.layer.append(other) @@ -3637,12 +3620,12 @@ def __iadd__(self, other: Union[core.LayerSpec, core.UnitSpec]) -> Self: self.params, self.layer = _combine_subchart_params(self.params, self.layer) return self - def __add__(self, other: Union[core.LayerSpec, core.UnitSpec]) -> Self: + def __add__(self, other: Union["LayerChart", Chart]) -> Self: copy = self.copy(deep=["layer"]) copy += other return copy - def add_layers(self, *layers: Union[core.LayerSpec, core.UnitSpec]) -> Self: + def add_layers(self, *layers: Union["LayerChart", Chart]) -> Self: copy = self.copy(deep=["layer"]) for layer in layers: copy += layer diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index 089a534a6..73bdd231a 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -13,12 +13,15 @@ from . import core import pandas as pd from altair.utils.schemapi import Undefined, UndefinedType, with_property_setters -from altair.utils import parse_shorthand +from altair.utils import parse_shorthand, infer_encoding_types as _infer_encoding_types from typing import Any, overload, Sequence, List, Literal, Union, Optional from typing import Dict as TypingDict +from typing_extensions import Self class FieldChannelMixin: + _encoding_name: str + def to_dict( self, validate: bool = True, @@ -88,6 +91,8 @@ def to_dict( class ValueChannelMixin: + _encoding_name: str + def to_dict( self, validate: bool = True, @@ -111,6 +116,8 @@ def to_dict( class DatumChannelMixin: + _encoding_name: str + def to_dict( self, validate: bool = True, @@ -78427,276 +78434,314 @@ def __init__(self, value, **kwds): super(YOffsetValue, self).__init__(value=value, **kwds) -def _encode_signature( - self, - angle: Union[str, Angle, dict, AngleDatum, AngleValue, UndefinedType] = Undefined, - color: Union[str, Color, dict, ColorDatum, ColorValue, UndefinedType] = Undefined, - column: Union[str, Column, dict, UndefinedType] = Undefined, - description: Union[ - str, Description, dict, DescriptionValue, UndefinedType - ] = Undefined, - detail: Union[str, Detail, dict, list, UndefinedType] = Undefined, - facet: Union[str, Facet, dict, UndefinedType] = Undefined, - fill: Union[str, Fill, dict, FillDatum, FillValue, UndefinedType] = Undefined, - fillOpacity: Union[ - str, FillOpacity, dict, FillOpacityDatum, FillOpacityValue, UndefinedType - ] = Undefined, - href: Union[str, Href, dict, HrefValue, UndefinedType] = Undefined, - key: Union[str, Key, dict, UndefinedType] = Undefined, - latitude: Union[str, Latitude, dict, LatitudeDatum, UndefinedType] = Undefined, - latitude2: Union[ - str, Latitude2, dict, Latitude2Datum, Latitude2Value, UndefinedType - ] = Undefined, - longitude: Union[str, Longitude, dict, LongitudeDatum, UndefinedType] = Undefined, - longitude2: Union[ - str, Longitude2, dict, Longitude2Datum, Longitude2Value, UndefinedType - ] = Undefined, - opacity: Union[ - str, Opacity, dict, OpacityDatum, OpacityValue, UndefinedType - ] = Undefined, - order: Union[str, Order, dict, list, OrderValue, UndefinedType] = Undefined, - radius: Union[ - str, Radius, dict, RadiusDatum, RadiusValue, UndefinedType - ] = Undefined, - radius2: Union[ - str, Radius2, dict, Radius2Datum, Radius2Value, UndefinedType - ] = Undefined, - row: Union[str, Row, dict, UndefinedType] = Undefined, - shape: Union[str, Shape, dict, ShapeDatum, ShapeValue, UndefinedType] = Undefined, - size: Union[str, Size, dict, SizeDatum, SizeValue, UndefinedType] = Undefined, - stroke: Union[ - str, Stroke, dict, StrokeDatum, StrokeValue, UndefinedType - ] = Undefined, - strokeDash: Union[ - str, StrokeDash, dict, StrokeDashDatum, StrokeDashValue, UndefinedType - ] = Undefined, - strokeOpacity: Union[ - str, StrokeOpacity, dict, StrokeOpacityDatum, StrokeOpacityValue, UndefinedType - ] = Undefined, - strokeWidth: Union[ - str, StrokeWidth, dict, StrokeWidthDatum, StrokeWidthValue, UndefinedType - ] = Undefined, - text: Union[str, Text, dict, TextDatum, TextValue, UndefinedType] = Undefined, - theta: Union[str, Theta, dict, ThetaDatum, ThetaValue, UndefinedType] = Undefined, - theta2: Union[ - str, Theta2, dict, Theta2Datum, Theta2Value, UndefinedType - ] = Undefined, - tooltip: Union[str, Tooltip, dict, list, TooltipValue, UndefinedType] = Undefined, - url: Union[str, Url, dict, UrlValue, UndefinedType] = Undefined, - x: Union[str, X, dict, XDatum, XValue, UndefinedType] = Undefined, - x2: Union[str, X2, dict, X2Datum, X2Value, UndefinedType] = Undefined, - xError: Union[str, XError, dict, XErrorValue, UndefinedType] = Undefined, - xError2: Union[str, XError2, dict, XError2Value, UndefinedType] = Undefined, - xOffset: Union[ - str, XOffset, dict, XOffsetDatum, XOffsetValue, UndefinedType - ] = Undefined, - y: Union[str, Y, dict, YDatum, YValue, UndefinedType] = Undefined, - y2: Union[str, Y2, dict, Y2Datum, Y2Value, UndefinedType] = Undefined, - yError: Union[str, YError, dict, YErrorValue, UndefinedType] = Undefined, - yError2: Union[str, YError2, dict, YError2Value, UndefinedType] = Undefined, - yOffset: Union[ - str, YOffset, dict, YOffsetDatum, YOffsetValue, UndefinedType - ] = Undefined, -): - """Parameters - ---------- - - angle : str, :class:`Angle`, Dict, :class:`AngleDatum`, :class:`AngleValue` - Rotation angle of point and text marks. - color : str, :class:`Color`, Dict, :class:`ColorDatum`, :class:`ColorValue` - Color of the marks – either fill or stroke color based on the ``filled`` property - of mark definition. By default, ``color`` represents fill color for ``"area"``, - ``"bar"``, ``"tick"``, ``"text"``, ``"trail"``, ``"circle"``, and ``"square"`` / - stroke color for ``"line"`` and ``"point"``. - - **Default value:** If undefined, the default color depends on `mark config - `__ 's ``color`` - property. - - *Note:* 1) For fine-grained control over both fill and stroke colors of the marks, - please use the ``fill`` and ``stroke`` channels. The ``fill`` or ``stroke`` - encodings have higher precedence than ``color``, thus may override the ``color`` - encoding if conflicting encodings are specified. 2) See the scale documentation for - more information about customizing `color scheme - `__. - column : str, :class:`Column`, Dict - A field definition for the horizontal facet of trellis plots. - description : str, :class:`Description`, Dict, :class:`DescriptionValue` - A text description of this mark for ARIA accessibility (SVG output only). For SVG - output the ``"aria-label"`` attribute will be set to this description. - detail : str, :class:`Detail`, Dict, List - Additional levels of detail for grouping data in aggregate views and in line, trail, - and area marks without mapping data to a specific visual channel. - facet : str, :class:`Facet`, Dict - A field definition for the (flexible) facet of trellis plots. - - If either ``row`` or ``column`` is specified, this channel will be ignored. - fill : str, :class:`Fill`, Dict, :class:`FillDatum`, :class:`FillValue` - Fill color of the marks. **Default value:** If undefined, the default color depends - on `mark config `__ - 's ``color`` property. - - *Note:* The ``fill`` encoding has higher precedence than ``color``, thus may - override the ``color`` encoding if conflicting encodings are specified. - fillOpacity : str, :class:`FillOpacity`, Dict, :class:`FillOpacityDatum`, :class:`FillOpacityValue` - Fill opacity of the marks. - - **Default value:** If undefined, the default opacity depends on `mark config - `__ 's - ``fillOpacity`` property. - href : str, :class:`Href`, Dict, :class:`HrefValue` - A URL to load upon mouse click. - key : str, :class:`Key`, Dict - A data field to use as a unique key for data binding. When a visualization’s data is - updated, the key value will be used to match data elements to existing mark - instances. Use a key channel to enable object constancy for transitions over dynamic - data. - latitude : str, :class:`Latitude`, Dict, :class:`LatitudeDatum` - Latitude position of geographically projected marks. - latitude2 : str, :class:`Latitude2`, Dict, :class:`Latitude2Datum`, :class:`Latitude2Value` - Latitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, - ``"rect"``, and ``"rule"``. - longitude : str, :class:`Longitude`, Dict, :class:`LongitudeDatum` - Longitude position of geographically projected marks. - longitude2 : str, :class:`Longitude2`, Dict, :class:`Longitude2Datum`, :class:`Longitude2Value` - Longitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, - ``"rect"``, and ``"rule"``. - opacity : str, :class:`Opacity`, Dict, :class:`OpacityDatum`, :class:`OpacityValue` - Opacity of the marks. - - **Default value:** If undefined, the default opacity depends on `mark config - `__ 's ``opacity`` - property. - order : str, :class:`Order`, Dict, List, :class:`OrderValue` - Order of the marks. - - - * For stacked marks, this ``order`` channel encodes `stack order - `__. - * For line and trail marks, this ``order`` channel encodes order of data points in - the lines. This can be useful for creating `a connected scatterplot - `__. Setting - ``order`` to ``{"value": null}`` makes the line marks use the original order in - the data sources. - * Otherwise, this ``order`` channel encodes layer order of the marks. - - **Note** : In aggregate plots, ``order`` field should be ``aggregate`` d to avoid - creating additional aggregation grouping. - radius : str, :class:`Radius`, Dict, :class:`RadiusDatum`, :class:`RadiusValue` - The outer radius in pixels of arc marks. - radius2 : str, :class:`Radius2`, Dict, :class:`Radius2Datum`, :class:`Radius2Value` - The inner radius in pixels of arc marks. - row : str, :class:`Row`, Dict - A field definition for the vertical facet of trellis plots. - shape : str, :class:`Shape`, Dict, :class:`ShapeDatum`, :class:`ShapeValue` - Shape of the mark. - - - #. - For ``point`` marks the supported values include: - plotting shapes: ``"circle"``, - ``"square"``, ``"cross"``, ``"diamond"``, ``"triangle-up"``, ``"triangle-down"``, - ``"triangle-right"``, or ``"triangle-left"``. - the line symbol ``"stroke"`` - - centered directional shapes ``"arrow"``, ``"wedge"``, or ``"triangle"`` - a custom - `SVG path string - `__ (For correct - sizing, custom shape paths should be defined within a square bounding box with - coordinates ranging from -1 to 1 along both the x and y dimensions.) - - #. - For ``geoshape`` marks it should be a field definition of the geojson data - - **Default value:** If undefined, the default shape depends on `mark config - `__ 's ``shape`` - property. ( ``"circle"`` if unset.) - size : str, :class:`Size`, Dict, :class:`SizeDatum`, :class:`SizeValue` - Size of the mark. - - - * For ``"point"``, ``"square"`` and ``"circle"``, – the symbol size, or pixel area - of the mark. - * For ``"bar"`` and ``"tick"`` – the bar and tick's size. - * For ``"text"`` – the text's font size. - * Size is unsupported for ``"line"``, ``"area"``, and ``"rect"``. (Use ``"trail"`` - instead of line with varying size) - stroke : str, :class:`Stroke`, Dict, :class:`StrokeDatum`, :class:`StrokeValue` - Stroke color of the marks. **Default value:** If undefined, the default color - depends on `mark config - `__ 's ``color`` - property. - - *Note:* The ``stroke`` encoding has higher precedence than ``color``, thus may - override the ``color`` encoding if conflicting encodings are specified. - strokeDash : str, :class:`StrokeDash`, Dict, :class:`StrokeDashDatum`, :class:`StrokeDashValue` - Stroke dash of the marks. - - **Default value:** ``[1,0]`` (No dash). - strokeOpacity : str, :class:`StrokeOpacity`, Dict, :class:`StrokeOpacityDatum`, :class:`StrokeOpacityValue` - Stroke opacity of the marks. - - **Default value:** If undefined, the default opacity depends on `mark config - `__ 's - ``strokeOpacity`` property. - strokeWidth : str, :class:`StrokeWidth`, Dict, :class:`StrokeWidthDatum`, :class:`StrokeWidthValue` - Stroke width of the marks. - - **Default value:** If undefined, the default stroke width depends on `mark config - `__ 's - ``strokeWidth`` property. - text : str, :class:`Text`, Dict, :class:`TextDatum`, :class:`TextValue` - Text of the ``text`` mark. - theta : str, :class:`Theta`, Dict, :class:`ThetaDatum`, :class:`ThetaValue` - For arc marks, the arc length in radians if theta2 is not specified, otherwise the - start arc angle. (A value of 0 indicates up or “north”, increasing values proceed - clockwise.) - - For text marks, polar coordinate angle in radians. - theta2 : str, :class:`Theta2`, Dict, :class:`Theta2Datum`, :class:`Theta2Value` - The end angle of arc marks in radians. A value of 0 indicates up or “north”, - increasing values proceed clockwise. - tooltip : str, :class:`Tooltip`, Dict, List, :class:`TooltipValue` - The tooltip text to show upon mouse hover. Specifying ``tooltip`` encoding overrides - `the tooltip property in the mark definition - `__. - - See the `tooltip `__ - documentation for a detailed discussion about tooltip in Vega-Lite. - url : str, :class:`Url`, Dict, :class:`UrlValue` - The URL of an image mark. - x : str, :class:`X`, Dict, :class:`XDatum`, :class:`XValue` - X coordinates of the marks, or width of horizontal ``"bar"`` and ``"area"`` without - specified ``x2`` or ``width``. - - The ``value`` of this channel can be a number or a string ``"width"`` for the width - of the plot. - x2 : str, :class:`X2`, Dict, :class:`X2Datum`, :class:`X2Value` - X2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. - - The ``value`` of this channel can be a number or a string ``"width"`` for the width - of the plot. - xError : str, :class:`XError`, Dict, :class:`XErrorValue` - Error value of x coordinates for error specified ``"errorbar"`` and ``"errorband"``. - xError2 : str, :class:`XError2`, Dict, :class:`XError2Value` - Secondary error value of x coordinates for error specified ``"errorbar"`` and - ``"errorband"``. - xOffset : str, :class:`XOffset`, Dict, :class:`XOffsetDatum`, :class:`XOffsetValue` - Offset of x-position of the marks - y : str, :class:`Y`, Dict, :class:`YDatum`, :class:`YValue` - Y coordinates of the marks, or height of vertical ``"bar"`` and ``"area"`` without - specified ``y2`` or ``height``. - - The ``value`` of this channel can be a number or a string ``"height"`` for the - height of the plot. - y2 : str, :class:`Y2`, Dict, :class:`Y2Datum`, :class:`Y2Value` - Y2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. - - The ``value`` of this channel can be a number or a string ``"height"`` for the - height of the plot. - yError : str, :class:`YError`, Dict, :class:`YErrorValue` - Error value of y coordinates for error specified ``"errorbar"`` and ``"errorband"``. - yError2 : str, :class:`YError2`, Dict, :class:`YError2Value` - Secondary error value of y coordinates for error specified ``"errorbar"`` and - ``"errorband"``. - yOffset : str, :class:`YOffset`, Dict, :class:`YOffsetDatum`, :class:`YOffsetValue` - Offset of y-position of the marks - """ - ... +class _EncodingMixin: + def encode( + self, + *args: Any, + angle: Union[ + str, Angle, dict, AngleDatum, AngleValue, UndefinedType + ] = Undefined, + color: Union[ + str, Color, dict, ColorDatum, ColorValue, UndefinedType + ] = Undefined, + column: Union[str, Column, dict, UndefinedType] = Undefined, + description: Union[ + str, Description, dict, DescriptionValue, UndefinedType + ] = Undefined, + detail: Union[str, Detail, dict, list, UndefinedType] = Undefined, + facet: Union[str, Facet, dict, UndefinedType] = Undefined, + fill: Union[str, Fill, dict, FillDatum, FillValue, UndefinedType] = Undefined, + fillOpacity: Union[ + str, FillOpacity, dict, FillOpacityDatum, FillOpacityValue, UndefinedType + ] = Undefined, + href: Union[str, Href, dict, HrefValue, UndefinedType] = Undefined, + key: Union[str, Key, dict, UndefinedType] = Undefined, + latitude: Union[str, Latitude, dict, LatitudeDatum, UndefinedType] = Undefined, + latitude2: Union[ + str, Latitude2, dict, Latitude2Datum, Latitude2Value, UndefinedType + ] = Undefined, + longitude: Union[ + str, Longitude, dict, LongitudeDatum, UndefinedType + ] = Undefined, + longitude2: Union[ + str, Longitude2, dict, Longitude2Datum, Longitude2Value, UndefinedType + ] = Undefined, + opacity: Union[ + str, Opacity, dict, OpacityDatum, OpacityValue, UndefinedType + ] = Undefined, + order: Union[str, Order, dict, list, OrderValue, UndefinedType] = Undefined, + radius: Union[ + str, Radius, dict, RadiusDatum, RadiusValue, UndefinedType + ] = Undefined, + radius2: Union[ + str, Radius2, dict, Radius2Datum, Radius2Value, UndefinedType + ] = Undefined, + row: Union[str, Row, dict, UndefinedType] = Undefined, + shape: Union[ + str, Shape, dict, ShapeDatum, ShapeValue, UndefinedType + ] = Undefined, + size: Union[str, Size, dict, SizeDatum, SizeValue, UndefinedType] = Undefined, + stroke: Union[ + str, Stroke, dict, StrokeDatum, StrokeValue, UndefinedType + ] = Undefined, + strokeDash: Union[ + str, StrokeDash, dict, StrokeDashDatum, StrokeDashValue, UndefinedType + ] = Undefined, + strokeOpacity: Union[ + str, + StrokeOpacity, + dict, + StrokeOpacityDatum, + StrokeOpacityValue, + UndefinedType, + ] = Undefined, + strokeWidth: Union[ + str, StrokeWidth, dict, StrokeWidthDatum, StrokeWidthValue, UndefinedType + ] = Undefined, + text: Union[str, Text, dict, TextDatum, TextValue, UndefinedType] = Undefined, + theta: Union[ + str, Theta, dict, ThetaDatum, ThetaValue, UndefinedType + ] = Undefined, + theta2: Union[ + str, Theta2, dict, Theta2Datum, Theta2Value, UndefinedType + ] = Undefined, + tooltip: Union[ + str, Tooltip, dict, list, TooltipValue, UndefinedType + ] = Undefined, + url: Union[str, Url, dict, UrlValue, UndefinedType] = Undefined, + x: Union[str, X, dict, XDatum, XValue, UndefinedType] = Undefined, + x2: Union[str, X2, dict, X2Datum, X2Value, UndefinedType] = Undefined, + xError: Union[str, XError, dict, XErrorValue, UndefinedType] = Undefined, + xError2: Union[str, XError2, dict, XError2Value, UndefinedType] = Undefined, + xOffset: Union[ + str, XOffset, dict, XOffsetDatum, XOffsetValue, UndefinedType + ] = Undefined, + y: Union[str, Y, dict, YDatum, YValue, UndefinedType] = Undefined, + y2: Union[str, Y2, dict, Y2Datum, Y2Value, UndefinedType] = Undefined, + yError: Union[str, YError, dict, YErrorValue, UndefinedType] = Undefined, + yError2: Union[str, YError2, dict, YError2Value, UndefinedType] = Undefined, + yOffset: Union[ + str, YOffset, dict, YOffsetDatum, YOffsetValue, UndefinedType + ] = Undefined, + ) -> Self: + """Map properties of the data to visual properties of the chart (see :class:`FacetedEncoding`) + + Parameters + ---------- + angle : str, :class:`Angle`, Dict, :class:`AngleDatum`, :class:`AngleValue` + Rotation angle of point and text marks. + color : str, :class:`Color`, Dict, :class:`ColorDatum`, :class:`ColorValue` + Color of the marks – either fill or stroke color based on the ``filled`` property + of mark definition. By default, ``color`` represents fill color for ``"area"``, + ``"bar"``, ``"tick"``, ``"text"``, ``"trail"``, ``"circle"``, and ``"square"`` / + stroke color for ``"line"`` and ``"point"``. + + **Default value:** If undefined, the default color depends on `mark config + `__ 's ``color`` + property. + + *Note:* 1) For fine-grained control over both fill and stroke colors of the marks, + please use the ``fill`` and ``stroke`` channels. The ``fill`` or ``stroke`` + encodings have higher precedence than ``color``, thus may override the ``color`` + encoding if conflicting encodings are specified. 2) See the scale documentation for + more information about customizing `color scheme + `__. + column : str, :class:`Column`, Dict + A field definition for the horizontal facet of trellis plots. + description : str, :class:`Description`, Dict, :class:`DescriptionValue` + A text description of this mark for ARIA accessibility (SVG output only). For SVG + output the ``"aria-label"`` attribute will be set to this description. + detail : str, :class:`Detail`, Dict, List + Additional levels of detail for grouping data in aggregate views and in line, trail, + and area marks without mapping data to a specific visual channel. + facet : str, :class:`Facet`, Dict + A field definition for the (flexible) facet of trellis plots. + + If either ``row`` or ``column`` is specified, this channel will be ignored. + fill : str, :class:`Fill`, Dict, :class:`FillDatum`, :class:`FillValue` + Fill color of the marks. **Default value:** If undefined, the default color depends + on `mark config `__ + 's ``color`` property. + + *Note:* The ``fill`` encoding has higher precedence than ``color``, thus may + override the ``color`` encoding if conflicting encodings are specified. + fillOpacity : str, :class:`FillOpacity`, Dict, :class:`FillOpacityDatum`, :class:`FillOpacityValue` + Fill opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__ 's + ``fillOpacity`` property. + href : str, :class:`Href`, Dict, :class:`HrefValue` + A URL to load upon mouse click. + key : str, :class:`Key`, Dict + A data field to use as a unique key for data binding. When a visualization’s data is + updated, the key value will be used to match data elements to existing mark + instances. Use a key channel to enable object constancy for transitions over dynamic + data. + latitude : str, :class:`Latitude`, Dict, :class:`LatitudeDatum` + Latitude position of geographically projected marks. + latitude2 : str, :class:`Latitude2`, Dict, :class:`Latitude2Datum`, :class:`Latitude2Value` + Latitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, + ``"rect"``, and ``"rule"``. + longitude : str, :class:`Longitude`, Dict, :class:`LongitudeDatum` + Longitude position of geographically projected marks. + longitude2 : str, :class:`Longitude2`, Dict, :class:`Longitude2Datum`, :class:`Longitude2Value` + Longitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, + ``"rect"``, and ``"rule"``. + opacity : str, :class:`Opacity`, Dict, :class:`OpacityDatum`, :class:`OpacityValue` + Opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__ 's ``opacity`` + property. + order : str, :class:`Order`, Dict, List, :class:`OrderValue` + Order of the marks. + + + * For stacked marks, this ``order`` channel encodes `stack order + `__. + * For line and trail marks, this ``order`` channel encodes order of data points in + the lines. This can be useful for creating `a connected scatterplot + `__. Setting + ``order`` to ``{"value": null}`` makes the line marks use the original order in + the data sources. + * Otherwise, this ``order`` channel encodes layer order of the marks. + + **Note** : In aggregate plots, ``order`` field should be ``aggregate`` d to avoid + creating additional aggregation grouping. + radius : str, :class:`Radius`, Dict, :class:`RadiusDatum`, :class:`RadiusValue` + The outer radius in pixels of arc marks. + radius2 : str, :class:`Radius2`, Dict, :class:`Radius2Datum`, :class:`Radius2Value` + The inner radius in pixels of arc marks. + row : str, :class:`Row`, Dict + A field definition for the vertical facet of trellis plots. + shape : str, :class:`Shape`, Dict, :class:`ShapeDatum`, :class:`ShapeValue` + Shape of the mark. + + + #. + For ``point`` marks the supported values include: - plotting shapes: ``"circle"``, + ``"square"``, ``"cross"``, ``"diamond"``, ``"triangle-up"``, ``"triangle-down"``, + ``"triangle-right"``, or ``"triangle-left"``. - the line symbol ``"stroke"`` - + centered directional shapes ``"arrow"``, ``"wedge"``, or ``"triangle"`` - a custom + `SVG path string + `__ (For correct + sizing, custom shape paths should be defined within a square bounding box with + coordinates ranging from -1 to 1 along both the x and y dimensions.) + + #. + For ``geoshape`` marks it should be a field definition of the geojson data + + **Default value:** If undefined, the default shape depends on `mark config + `__ 's ``shape`` + property. ( ``"circle"`` if unset.) + size : str, :class:`Size`, Dict, :class:`SizeDatum`, :class:`SizeValue` + Size of the mark. + + + * For ``"point"``, ``"square"`` and ``"circle"``, – the symbol size, or pixel area + of the mark. + * For ``"bar"`` and ``"tick"`` – the bar and tick's size. + * For ``"text"`` – the text's font size. + * Size is unsupported for ``"line"``, ``"area"``, and ``"rect"``. (Use ``"trail"`` + instead of line with varying size) + stroke : str, :class:`Stroke`, Dict, :class:`StrokeDatum`, :class:`StrokeValue` + Stroke color of the marks. **Default value:** If undefined, the default color + depends on `mark config + `__ 's ``color`` + property. + + *Note:* The ``stroke`` encoding has higher precedence than ``color``, thus may + override the ``color`` encoding if conflicting encodings are specified. + strokeDash : str, :class:`StrokeDash`, Dict, :class:`StrokeDashDatum`, :class:`StrokeDashValue` + Stroke dash of the marks. + + **Default value:** ``[1,0]`` (No dash). + strokeOpacity : str, :class:`StrokeOpacity`, Dict, :class:`StrokeOpacityDatum`, :class:`StrokeOpacityValue` + Stroke opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__ 's + ``strokeOpacity`` property. + strokeWidth : str, :class:`StrokeWidth`, Dict, :class:`StrokeWidthDatum`, :class:`StrokeWidthValue` + Stroke width of the marks. + + **Default value:** If undefined, the default stroke width depends on `mark config + `__ 's + ``strokeWidth`` property. + text : str, :class:`Text`, Dict, :class:`TextDatum`, :class:`TextValue` + Text of the ``text`` mark. + theta : str, :class:`Theta`, Dict, :class:`ThetaDatum`, :class:`ThetaValue` + For arc marks, the arc length in radians if theta2 is not specified, otherwise the + start arc angle. (A value of 0 indicates up or “north”, increasing values proceed + clockwise.) + + For text marks, polar coordinate angle in radians. + theta2 : str, :class:`Theta2`, Dict, :class:`Theta2Datum`, :class:`Theta2Value` + The end angle of arc marks in radians. A value of 0 indicates up or “north”, + increasing values proceed clockwise. + tooltip : str, :class:`Tooltip`, Dict, List, :class:`TooltipValue` + The tooltip text to show upon mouse hover. Specifying ``tooltip`` encoding overrides + `the tooltip property in the mark definition + `__. + + See the `tooltip `__ + documentation for a detailed discussion about tooltip in Vega-Lite. + url : str, :class:`Url`, Dict, :class:`UrlValue` + The URL of an image mark. + x : str, :class:`X`, Dict, :class:`XDatum`, :class:`XValue` + X coordinates of the marks, or width of horizontal ``"bar"`` and ``"area"`` without + specified ``x2`` or ``width``. + + The ``value`` of this channel can be a number or a string ``"width"`` for the width + of the plot. + x2 : str, :class:`X2`, Dict, :class:`X2Datum`, :class:`X2Value` + X2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. + + The ``value`` of this channel can be a number or a string ``"width"`` for the width + of the plot. + xError : str, :class:`XError`, Dict, :class:`XErrorValue` + Error value of x coordinates for error specified ``"errorbar"`` and ``"errorband"``. + xError2 : str, :class:`XError2`, Dict, :class:`XError2Value` + Secondary error value of x coordinates for error specified ``"errorbar"`` and + ``"errorband"``. + xOffset : str, :class:`XOffset`, Dict, :class:`XOffsetDatum`, :class:`XOffsetValue` + Offset of x-position of the marks + y : str, :class:`Y`, Dict, :class:`YDatum`, :class:`YValue` + Y coordinates of the marks, or height of vertical ``"bar"`` and ``"area"`` without + specified ``y2`` or ``height``. + + The ``value`` of this channel can be a number or a string ``"height"`` for the + height of the plot. + y2 : str, :class:`Y2`, Dict, :class:`Y2Datum`, :class:`Y2Value` + Y2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. + + The ``value`` of this channel can be a number or a string ``"height"`` for the + height of the plot. + yError : str, :class:`YError`, Dict, :class:`YErrorValue` + Error value of y coordinates for error specified ``"errorbar"`` and ``"errorband"``. + yError2 : str, :class:`YError2`, Dict, :class:`YError2Value` + Secondary error value of y coordinates for error specified ``"errorbar"`` and + ``"errorband"``. + yOffset : str, :class:`YOffset`, Dict, :class:`YOffsetDatum`, :class:`YOffsetValue` + Offset of y-position of the marks + """ + # Compat prep for `infer_encoding_types` signature + kwargs = locals() + kwargs.pop("self") + args = kwargs.pop("args") + if args: + kwargs = {k: v for k, v in kwargs.items() if v is not Undefined} + + # Convert args to kwargs based on their types. + kwargs = _infer_encoding_types(args, kwargs) + # get a copy of the dict representation of the previous encoding + # ignore type as copy method comes from SchemaBase + copy = self.copy(deep=["encoding"]) # type: ignore[attr-defined] + encoding = copy._get("encoding", {}) + if isinstance(encoding, core.VegaLiteSchema): + encoding = {k: v for k, v in encoding._kwds.items() if v is not Undefined} + # update with the new encodings, and apply them to the copy + encoding.update(kwargs) + copy.encoding = core.FacetedEncoding(**encoding) + return copy diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 40958c202..521a898a2 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -99,6 +99,7 @@ def load_schema() -> dict: CHANNEL_MIXINS: Final = """ class FieldChannelMixin: + _encoding_name: str def to_dict( self, validate: bool = True, @@ -167,6 +168,7 @@ def to_dict( class ValueChannelMixin: + _encoding_name: str def to_dict( self, validate: bool = True, @@ -190,6 +192,7 @@ def to_dict( class DatumChannelMixin: + _encoding_name: str def to_dict( self, validate: bool = True, @@ -239,10 +242,30 @@ def configure_{prop}(self, *args, **kwargs) -> Self: return copy """ -ENCODE_SIGNATURE: Final = ''' -def _encode_signature({encode_method_args}): - """{docstring}""" - ... +ENCODE_METHOD: Final = ''' +class _EncodingMixin: + def encode({encode_method_args}) -> Self: + """Map properties of the data to visual properties of the chart (see :class:`FacetedEncoding`) + {docstring}""" + # Compat prep for `infer_encoding_types` signature + kwargs = locals() + kwargs.pop("self") + args = kwargs.pop("args") + if args: + kwargs = {{k: v for k, v in kwargs.items() if v is not Undefined}} + + # Convert args to kwargs based on their types. + kwargs = _infer_encoding_types(args, kwargs) + # get a copy of the dict representation of the previous encoding + # ignore type as copy method comes from SchemaBase + copy = self.copy(deep=['encoding']) # type: ignore[attr-defined] + encoding = copy._get('encoding', {{}}) + if isinstance(encoding, core.VegaLiteSchema): + encoding = {{k: v for k, v in encoding._kwds.items() if v is not Undefined}} + # update with the new encodings, and apply them to the copy + encoding.update(kwargs) + copy.encoding = core.FacetedEncoding(**encoding) + return copy ''' @@ -563,9 +586,10 @@ def generate_vegalite_channel_wrappers( "from . import core", "import pandas as pd", "from altair.utils.schemapi import Undefined, UndefinedType, with_property_setters", - "from altair.utils import parse_shorthand", + "from altair.utils import parse_shorthand, infer_encoding_types as _infer_encoding_types", "from typing import Any, overload, Sequence, List, Literal, Union, Optional", "from typing import Dict as TypingDict", + "from typing_extensions import Self", ] contents = [HEADER] contents.append(CHANNEL_MYPY_IGNORE_STATEMENTS) @@ -801,8 +825,8 @@ def vegalite_main(skip_download: bool = False) -> None: def _create_encode_signature( channel_infos: Dict[str, ChannelInfo], ) -> str: - signature_args: list[str] = ["self"] - docstring_parameters: list[str] = ["", "Parameters", "----------", ""] + signature_args: list[str] = ["self", "*args: Any"] + docstring_parameters: list[str] = ["", "Parameters", "----------"] for channel, info in channel_infos.items(): field_class_name = info.field_class_name assert ( @@ -842,9 +866,9 @@ def _create_encode_signature( if len(docstring_parameters) > 1: docstring_parameters += [""] docstring = indent_docstring( - docstring_parameters, indent_level=4, width=100, lstrip=True + docstring_parameters, indent_level=4, width=100, lstrip=False ) - return ENCODE_SIGNATURE.format( + return ENCODE_METHOD.format( encode_method_args=", ".join(signature_args), docstring=docstring ) From 0b19a342b4005183907f3be7b6d6e2461488abd5 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Jun 2024 20:54:05 +0100 Subject: [PATCH 02/11] fix(typing): Resolve assignment type errors revealed Incompatible types in assignment (expression has type "Chart", variable has type "DataFrame") --- tests/examples_arguments_syntax/select_detail.py | 4 ++-- tests/examples_methods_syntax/select_detail.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/examples_arguments_syntax/select_detail.py b/tests/examples_arguments_syntax/select_detail.py index a6b558df7..500668c4f 100644 --- a/tests/examples_arguments_syntax/select_detail.py +++ b/tests/examples_arguments_syntax/select_detail.py @@ -54,7 +54,7 @@ color=alt.condition(selector, 'id:O', alt.value('lightgray'), legend=None), ) -timeseries = base.mark_line().encode( +line = base.mark_line().encode( x='time', y=alt.Y('value', scale=alt.Scale(domain=(-15, 15))), color=alt.Color('id:O', legend=None) @@ -62,4 +62,4 @@ selector ) -points | timeseries +points | line diff --git a/tests/examples_methods_syntax/select_detail.py b/tests/examples_methods_syntax/select_detail.py index 4dead24c5..58bdb9dfd 100644 --- a/tests/examples_methods_syntax/select_detail.py +++ b/tests/examples_methods_syntax/select_detail.py @@ -54,7 +54,7 @@ color=alt.condition(selector, 'id:O', alt.value('lightgray'), legend=None), ) -timeseries = base.mark_line().encode( +line = base.mark_line().encode( x='time', y=alt.Y('value').scale(domain=(-15, 15)), color=alt.Color('id:O').legend(None) @@ -62,4 +62,4 @@ selector ) -points | timeseries +points | line From 0eb139d0fe9faa73cdf823e8ec3f26ce9dc09f8a Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Jun 2024 20:58:16 +0100 Subject: [PATCH 03/11] fix(typing): Resolve direct arg-type errors revealed `Color` -> `Fill` when passed to `fill` channel --- tests/examples_arguments_syntax/hexbins.py | 2 +- tests/examples_methods_syntax/hexbins.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/examples_arguments_syntax/hexbins.py b/tests/examples_arguments_syntax/hexbins.py index a8d91ac74..3005fcf00 100644 --- a/tests/examples_arguments_syntax/hexbins.py +++ b/tests/examples_arguments_syntax/hexbins.py @@ -30,7 +30,7 @@ labelPadding=20, tickOpacity=0, domainOpacity=0)), stroke=alt.value('black'), strokeWidth=alt.value(0.2), - fill=alt.Color('mean(temp_max):Q', scale=alt.Scale(scheme='darkblue')), + fill=alt.Fill('mean(temp_max):Q', scale=alt.Scale(scheme='darkblue')), tooltip=['month(' + xField + '):O', 'day(' + yField + '):O', 'mean(temp_max):Q'] ).transform_calculate( # This field is required for the hexagonal X-Offset diff --git a/tests/examples_methods_syntax/hexbins.py b/tests/examples_methods_syntax/hexbins.py index 26f3890a0..4e23e7bb0 100644 --- a/tests/examples_methods_syntax/hexbins.py +++ b/tests/examples_methods_syntax/hexbins.py @@ -32,7 +32,7 @@ .axis(labelPadding=20, tickOpacity=0, domainOpacity=0), stroke=alt.value('black'), strokeWidth=alt.value(0.2), - fill=alt.Color('mean(temp_max):Q').scale(scheme='darkblue'), + fill=alt.Fill('mean(temp_max):Q').scale(scheme='darkblue'), tooltip=['month(' + xField + '):O', 'day(' + yField + '):O', 'mean(temp_max):Q'] ).transform_calculate( # This field is required for the hexagonal X-Offset From 92d4ac3077a51bb670e5920be065c3f90105360c Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Jun 2024 21:18:42 +0100 Subject: [PATCH 04/11] fix(typing): Resolve `alt.condition` overload-related arg-type errors revealed 'error: Argument "color" to "encode" of "_EncodingMixin" has incompatible type "dict[Any, Any] | SchemaBase"; expected "str | Color | dict[Any, Any] | ColorDatum | ColorValue | UndefinedType" [arg-type]' --- altair/vegalite/v5/api.py | 45 +++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 7b5e648d6..8e426242e 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -341,6 +341,19 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: return False +_TestPredicateType = Union[str, _expr_core.Expression, core.PredicateComposition] +_PredicateType = Union[ + Parameter, + core.Expr, + TypingDict[str, Any], + _TestPredicateType, + _expr_core.OperatorMixin, +] +_ConditionType = TypingDict[str, Union[_TestPredicateType, Any]] +_DictOrStr = Union[TypingDict[str, Any], str] +_DictOrSchema = Union[core.SchemaBase, TypingDict[str, Any]] +_StatementType = Union[core.SchemaBase, _DictOrStr] + # ------------------------------------------------------------------------ # Top-Level Functions @@ -786,15 +799,33 @@ def binding_range(**kwargs): return core.BindRange(input="range", **kwargs) +_TSchemaBase = typing.TypeVar("_TSchemaBase", bound=core.SchemaBase) + + +@typing.overload +def condition( + predicate: _PredicateType, if_true: _StatementType, if_false: _TSchemaBase, **kwargs +) -> _TSchemaBase: ... +@typing.overload +def condition( + predicate: _PredicateType, if_true: str, if_false: str, **kwargs +) -> typing.NoReturn: ... +@typing.overload +def condition( + predicate: _PredicateType, if_true: _DictOrSchema, if_false: _DictOrStr, **kwargs +) -> TypingDict[str, Union[_ConditionType, Any]]: ... +@typing.overload +def condition( + predicate: _PredicateType, + if_true: _DictOrStr, + if_false: TypingDict[str, Any], + **kwargs, +) -> TypingDict[str, Union[_ConditionType, Any]]: ... # TODO: update the docstring def condition( - predicate: Union[ - Parameter, str, expr.Expression, core.Expr, core.PredicateComposition, dict - ], - # Types of these depends on where the condition is used so we probably - # can't be more specific here. - if_true: Any, - if_false: Any, + predicate: _PredicateType, + if_true: _StatementType, + if_false: _StatementType, **kwargs, ) -> Union[dict, core.SchemaBase]: """A conditional attribute or encoding From e4ab7052e9a1e62ff1fd80379864489c69a1e020 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 20 Jun 2024 21:54:38 +0100 Subject: [PATCH 05/11] test: update `infer_encoding_types` tests - New implementation does not use `**kwargs`, which eliminates an entire class of tests based on `.encode(invalidChannel=...)` as these now trigger a runtime error --- tests/utils/test_core.py | 11 ++--- tests/utils/test_schemapi.py | 95 +++++++++++------------------------- 2 files changed, 33 insertions(+), 73 deletions(-) diff --git a/tests/utils/test_core.py b/tests/utils/test_core.py index 27cd3b7ee..fd61518ee 100644 --- a/tests/utils/test_core.py +++ b/tests/utils/test_core.py @@ -260,6 +260,7 @@ def _getargs(*args, **kwargs): return args, kwargs +# NOTE: Dependent on a no longer needed implementation detail def test_infer_encoding_types(channels): expected = { "x": channels.X("xval"), @@ -285,8 +286,6 @@ def test_infer_encoding_types(channels): def test_infer_encoding_types_with_condition(): - channels = alt.channels - args, kwds = _getargs( size=alt.condition("pred1", alt.value(1), alt.value(2)), color=alt.condition("pred2", alt.value("red"), "cfield:N"), @@ -294,19 +293,19 @@ def test_infer_encoding_types_with_condition(): ) expected = { - "size": channels.SizeValue( + "size": alt.SizeValue( 2, condition=alt.ConditionalPredicateValueDefnumberExprRef( value=1, test=alt.Predicate("pred1") ), ), - "color": channels.Color( + "color": alt.Color( "cfield:N", condition=alt.ConditionalPredicateValueDefGradientstringnullExprRef( value="red", test=alt.Predicate("pred2") ), ), - "opacity": channels.OpacityValue( + "opacity": alt.OpacityValue( 0.2, condition=alt.ConditionalPredicateMarkPropFieldOrDatumDef( field=alt.FieldName("ofield"), @@ -315,7 +314,7 @@ def test_infer_encoding_types_with_condition(): ), ), } - assert infer_encoding_types(args, kwds, channels) == expected + assert infer_encoding_types(args, kwds) == expected def test_invalid_data_type(): diff --git a/tests/utils/test_schemapi.py b/tests/utils/test_schemapi.py index e1d0e5cc3..948458080 100644 --- a/tests/utils/test_schemapi.py +++ b/tests/utils/test_schemapi.py @@ -1,3 +1,4 @@ +# ruff: noqa: W291 import copy import io import inspect @@ -440,22 +441,6 @@ def chart_error_example__hconcat(): return points | text -def chart_error_example__invalid_channel(): - # Error: invalidChannel is an invalid encoding channel. Condition is correct - # but is added below as in previous implementations of Altair this interfered - # with finding the invalidChannel error - selection = alt.selection_point() - return ( - alt.Chart(data.barley()) - .mark_circle() - .add_params(selection) - .encode( - color=alt.condition(selection, alt.value("red"), alt.value("green")), - invalidChannel=None, - ) - ) - - def chart_error_example__invalid_y_option_value_unknown_x_option(): # Error 1: unknown is an invalid channel option for X # Error 2: Invalid Y option value "asdf" and unknown option "unknown" for X @@ -557,10 +542,10 @@ def chart_error_example__wrong_tooltip_type_in_layered_chart(): def chart_error_example__two_errors_in_layered_chart(): # Error 1: Wrong data type to pass to tooltip - # Error 2: invalidChannel is not a valid encoding channel + # Error 2: `Color` has no parameter named 'invalidChannel' return alt.layer( alt.Chart().mark_point().encode(tooltip=[{"wrong"}]), - alt.Chart().mark_line().encode(invalidChannel="unknown"), + alt.Chart().mark_line().encode(alt.Color(invalidChannel="unknown")), ) @@ -575,7 +560,7 @@ def chart_error_example__two_errors_in_complex_concat_layered_chart(): def chart_error_example__three_errors_in_complex_concat_layered_chart(): # Error 1: Wrong data type to pass to tooltip - # Error 2: invalidChannel is not a valid encoding channel + # Error 2: `Color` has no parameter named 'invalidChannel' # Error 3: Invalid value for bandPosition return ( chart_error_example__two_errors_in_layered_chart() @@ -585,7 +570,7 @@ def chart_error_example__three_errors_in_complex_concat_layered_chart(): def chart_error_example__two_errors_with_one_in_nested_layered_chart(): # Error 1: invalidOption is not a valid option for Scale - # Error 2: invalidChannel is not a valid encoding channel + # Error 2: `Color` has no parameter named 'invalidChannel' # In the final chart object, the `layer` attribute will look like this: # [alt.Chart(...), alt.Chart(...), alt.LayerChart(...)] @@ -621,7 +606,7 @@ def chart_error_example__two_errors_with_one_in_nested_layered_chart(): base = alt.Chart().encode(y=alt.datum(300)) - rule = base.mark_rule().encode(invalidChannel=2) + rule = base.mark_rule().encode(alt.Color(invalidChannel="unknown")) text = base.mark_text(text="hazardous") rule_text = rule + text @@ -648,6 +633,11 @@ def chart_error_example__four_errors(): ) +# NOTE: These 6 `TypeError`s are due to the removal of `kwargs` in Chart.encode +# That is, only declared keywords can now be used in Chart.encode, +# so these are not conditions worth checking against. + + @pytest.mark.parametrize( "chart_func, expected_error_message", [ @@ -691,18 +681,14 @@ def chart_error_example__four_errors(): Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. - Error 2: `Encoding` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidChannel' Existing parameter names are: - angle key order strokeDash tooltip xOffset - color latitude radius strokeOpacity url y - description latitude2 radius2 strokeWidth x y2 - detail longitude shape text x2 yError - fill longitude2 size theta xError yError2 - fillOpacity opacity stroke theta2 xError2 yOffset - href - - See the help for `Encoding` to read the full description of these parameters$""" # noqa: W291 + shorthand bin legend timeUnit + aggregate condition scale title + bandPosition field sort type + + See the help for `Color` to read the full description of these parameters$""" # noqa: W291 ), ), ( @@ -722,18 +708,14 @@ def chart_error_example__four_errors(): Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. - Error 2: `Encoding` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidChannel' Existing parameter names are: - angle key order strokeDash tooltip xOffset - color latitude radius strokeOpacity url y - description latitude2 radius2 strokeWidth x y2 - detail longitude shape text x2 yError - fill longitude2 size theta xError yError2 - fillOpacity opacity stroke theta2 xError2 yOffset - href + shorthand bin legend timeUnit + aggregate condition scale title + bandPosition field sort type - See the help for `Encoding` to read the full description of these parameters + See the help for `Color` to read the full description of these parameters Error 3: '4' is an invalid value for `bandPosition`. Valid values are of type 'number'.$""" # noqa: W291 ), @@ -754,18 +736,14 @@ def chart_error_example__four_errors(): See the help for `Scale` to read the full description of these parameters - Error 2: `Encoding` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidChannel' Existing parameter names are: - angle key order strokeDash tooltip xOffset - color latitude radius strokeOpacity url y - description latitude2 radius2 strokeWidth x y2 - detail longitude shape text x2 yError - fill longitude2 size theta xError yError2 - fillOpacity opacity stroke theta2 xError2 yOffset - href - - See the help for `Encoding` to read the full description of these parameters$""" # noqa: W291 + shorthand bin legend timeUnit + aggregate condition scale title + bandPosition field sort type + + See the help for `Color` to read the full description of these parameters$""" # noqa: W291 ), ), ( @@ -806,23 +784,6 @@ def chart_error_example__four_errors(): r"""'{'text': 'Horsepower', 'align': 'right'}' is an invalid value for `title`. Valid values are of type 'string', 'array', or 'null'.$""" ), ), - ( - chart_error_example__invalid_channel, - inspect.cleandoc( - r"""`Encoding` has no parameter named 'invalidChannel' - - Existing parameter names are: - angle key order strokeDash tooltip xOffset - color latitude radius strokeOpacity url y - description latitude2 radius2 strokeWidth x y2 - detail longitude shape text x2 yError - fill longitude2 size theta xError yError2 - fillOpacity opacity stroke theta2 xError2 yOffset - href - - See the help for `Encoding` to read the full description of these parameters$""" # noqa: W291 - ), - ), ( chart_error_example__invalid_timeunit_value, inspect.cleandoc( From fe736f98dba106da717073f541893c748d8136b7 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Jun 2024 13:33:15 +0100 Subject: [PATCH 06/11] test: Rename `invalidChannel` to `invalidArgument` Fixes https://github.com/vega/altair/pull/3444/files/e4ab7052e9a1e62ff1fd80379864489c69a1e020#r1657008627 --- tests/utils/test_schemapi.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/utils/test_schemapi.py b/tests/utils/test_schemapi.py index 948458080..0445a195d 100644 --- a/tests/utils/test_schemapi.py +++ b/tests/utils/test_schemapi.py @@ -542,10 +542,10 @@ def chart_error_example__wrong_tooltip_type_in_layered_chart(): def chart_error_example__two_errors_in_layered_chart(): # Error 1: Wrong data type to pass to tooltip - # Error 2: `Color` has no parameter named 'invalidChannel' + # Error 2: `Color` has no parameter named 'invalidArgument' return alt.layer( alt.Chart().mark_point().encode(tooltip=[{"wrong"}]), - alt.Chart().mark_line().encode(alt.Color(invalidChannel="unknown")), + alt.Chart().mark_line().encode(alt.Color(invalidArgument="unknown")), ) @@ -560,7 +560,7 @@ def chart_error_example__two_errors_in_complex_concat_layered_chart(): def chart_error_example__three_errors_in_complex_concat_layered_chart(): # Error 1: Wrong data type to pass to tooltip - # Error 2: `Color` has no parameter named 'invalidChannel' + # Error 2: `Color` has no parameter named 'invalidArgument' # Error 3: Invalid value for bandPosition return ( chart_error_example__two_errors_in_layered_chart() @@ -570,7 +570,7 @@ def chart_error_example__three_errors_in_complex_concat_layered_chart(): def chart_error_example__two_errors_with_one_in_nested_layered_chart(): # Error 1: invalidOption is not a valid option for Scale - # Error 2: `Color` has no parameter named 'invalidChannel' + # Error 2: `Color` has no parameter named 'invalidArgument' # In the final chart object, the `layer` attribute will look like this: # [alt.Chart(...), alt.Chart(...), alt.LayerChart(...)] @@ -606,7 +606,7 @@ def chart_error_example__two_errors_with_one_in_nested_layered_chart(): base = alt.Chart().encode(y=alt.datum(300)) - rule = base.mark_rule().encode(alt.Color(invalidChannel="unknown")) + rule = base.mark_rule().encode(alt.Color(invalidArgument="unknown")) text = base.mark_text(text="hazardous") rule_text = rule + text @@ -681,7 +681,7 @@ def chart_error_example__four_errors(): Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. - Error 2: `Color` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidArgument' Existing parameter names are: shorthand bin legend timeUnit @@ -708,7 +708,7 @@ def chart_error_example__four_errors(): Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. - Error 2: `Color` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidArgument' Existing parameter names are: shorthand bin legend timeUnit @@ -736,7 +736,7 @@ def chart_error_example__four_errors(): See the help for `Scale` to read the full description of these parameters - Error 2: `Color` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidArgument' Existing parameter names are: shorthand bin legend timeUnit From 0e0e167e11b298dd971707dd2f41d2ad5e54ad16 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Jun 2024 14:39:32 +0100 Subject: [PATCH 07/11] chore: remove PR note comment Fixes https://github.com/vega/altair/pull/3444#discussion_r1657014594 --- tests/utils/test_schemapi.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/utils/test_schemapi.py b/tests/utils/test_schemapi.py index 0445a195d..d2bdca0d1 100644 --- a/tests/utils/test_schemapi.py +++ b/tests/utils/test_schemapi.py @@ -633,11 +633,6 @@ def chart_error_example__four_errors(): ) -# NOTE: These 6 `TypeError`s are due to the removal of `kwargs` in Chart.encode -# That is, only declared keywords can now be used in Chart.encode, -# so these are not conditions worth checking against. - - @pytest.mark.parametrize( "chart_func, expected_error_message", [ From 064039dfb790e6fb4ec40f463baaea6a8eff1ab8 Mon Sep 17 00:00:00 2001 From: Dan Redding <125183946+dangotbanned@users.noreply.github.com> Date: Thu, 27 Jun 2024 14:41:16 +0100 Subject: [PATCH 08/11] docs: fix typo --- altair/utils/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index 5ccfdace8..222cc46fc 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -849,7 +849,7 @@ def _init_channel_to_name(): Note ---- The return type is not expressible using annotations, but is used - internally by `mypy`/`pyright` and avoid the need for type ignores. + internally by `mypy`/`pyright` and avoids the need for type ignores. Returns ------- From dbe8a74670cb3a54c03d98f4af3fc3c5a0de3e1b Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Thu, 27 Jun 2024 21:04:36 +0200 Subject: [PATCH 09/11] Exclude LookupData export from core.py to fix issue with mypy where it assumes that altair.LookupData comes from core.py instead of api.py --- altair/vegalite/v5/schema/core.py | 1 - tools/generate_schema_wrapper.py | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index 38bc0868a..e232d0206 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -246,7 +246,6 @@ "LogicalAndPredicate", "LogicalNotPredicate", "LogicalOrPredicate", - "LookupData", "LookupSelection", "LookupTransform", "Mark", diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index c632f1a65..77839ba6a 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -498,10 +498,10 @@ def generate_vegalite_schema_wrapper(schema_file: Path) -> str: child.basename.append(name) # Specify __all__ explicitly so that we can exclude the ones from the list - # of exported classes which are also defined in the channels module which takes - # precedent in the generated __init__.py file one level up where core.py - # and channels.py are imported. Importing both confuses type checkers. - it = (c for c in definitions.keys() - {"Color", "Text"} if not c.startswith("_")) + # of exported classes which are also defined in the channels or api modules which takes + # precedent in the generated __init__.py files one and two levels up. + # Importing these classes from multiple modules confuses type checkers. + it = (c for c in definitions.keys() - {"Color", "Text", "LookupData"} if not c.startswith("_")) all_ = [*sorted(it), "Root", "VegaLiteSchema", "SchemaBase", "load_schema"] contents = [ From be472e61d53ca2327e017b77dca5544fa4a154f0 Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Thu, 27 Jun 2024 21:04:58 +0200 Subject: [PATCH 10/11] Remove 'pd' and 'jsonschema' from __init__.py __all__. Unclear why they show up only now... --- tools/update_init_file.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/update_init_file.py b/tools/update_init_file.py index fc694b7d8..5e8bbc6ab 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -92,7 +92,7 @@ def relevant_attributes(namespace: dict[str, Any], /) -> list[str]: it = ( name for name, attr in namespace.items() - if (not name.startswith("_")) and _is_relevant(attr) + if (not name.startswith("_")) and _is_relevant(attr, name) ) return sorted(it) @@ -105,12 +105,13 @@ def _is_hashable(obj: Any) -> bool: return False -def _is_relevant(attr: Any, /) -> bool: +def _is_relevant(attr: Any, name: str, /) -> bool: """Predicate logic for filtering attributes.""" if ( getattr_static(attr, "_deprecated", False) or attr is TYPE_CHECKING or (_is_hashable(attr) and attr in _TYPING_CONSTRUCTS) + or name in {"pd", "jsonschema"} ): return False elif ismodule(attr): From f8dd7482696578133dbcf6b1cb2a792c6f3f4d2b Mon Sep 17 00:00:00 2001 From: Stefan Binder Date: Thu, 27 Jun 2024 21:07:05 +0200 Subject: [PATCH 11/11] Format code --- tools/generate_schema_wrapper.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index 77839ba6a..f1a73670f 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -501,7 +501,11 @@ def generate_vegalite_schema_wrapper(schema_file: Path) -> str: # of exported classes which are also defined in the channels or api modules which takes # precedent in the generated __init__.py files one and two levels up. # Importing these classes from multiple modules confuses type checkers. - it = (c for c in definitions.keys() - {"Color", "Text", "LookupData"} if not c.startswith("_")) + it = ( + c + for c in definitions.keys() - {"Color", "Text", "LookupData"} + if not c.startswith("_") + ) all_ = [*sorted(it), "Root", "VegaLiteSchema", "SchemaBase", "load_schema"] contents = [