diff --git a/altair/utils/core.py b/altair/utils/core.py index 71db80861..15453028e 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -16,20 +16,22 @@ Callable, TypeVar, Any, - Sequence, + Iterator, cast, Literal, Protocol, TYPE_CHECKING, runtime_checkable, ) +from itertools import groupby +from operator import itemgetter import jsonschema import pandas as pd import numpy as np from pandas.api.types import infer_dtype -from altair.utils.schemapi import SchemaBase +from altair.utils.schemapi import SchemaBase, Undefined from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame if sys.version_info >= (3, 10): @@ -773,9 +775,133 @@ def display_traceback(in_ipython: bool = True): traceback.print_exception(*exc_info) +_ChannelType = Literal["field", "datum", "value"] +_CHANNEL_CACHE: _ChannelCache +"""Singleton `_ChannelCache` instance. + +Initialized on first use. +""" + + +class _ChannelCache: + channel_to_name: dict[type[SchemaBase], str] + name_to_channel: dict[str, dict[_ChannelType, type[SchemaBase]]] + + @classmethod + def from_channels(cls, channels: ModuleType, /) -> _ChannelCache: + # - This branch is only kept for tests that depend on mocking `channels`. + # - No longer needs to pass around `channels` reference and rebuild every call. + c_to_n = { + c: c._encoding_name + for c in channels.__dict__.values() + if isinstance(c, type) + and issubclass(c, SchemaBase) + and hasattr(c, "_encoding_name") + } + self = cls.__new__(cls) + self.channel_to_name = c_to_n + self.name_to_channel = _invert_group_channels(c_to_n) + return self + + @classmethod + def from_cache(cls) -> _ChannelCache: + global _CHANNEL_CACHE + try: + cached = _CHANNEL_CACHE + except NameError: + cached = cls.__new__(cls) + cached.channel_to_name = _init_channel_to_name() + cached.name_to_channel = _invert_group_channels(cached.channel_to_name) + _CHANNEL_CACHE = cached + return _CHANNEL_CACHE + + def get_encoding(self, tp: type[Any], /) -> str: + if encoding := self.channel_to_name.get(tp): + return encoding + msg = f"positional of type {type(tp).__name__!r}" + raise NotImplementedError(msg) + + def _wrap_in_channel(self, obj: Any, encoding: str, /): + if isinstance(obj, SchemaBase): + return obj + elif isinstance(obj, str): + obj = {"shorthand": obj} + elif isinstance(obj, (list, tuple)): + return [self._wrap_in_channel(el, encoding) for el in obj] + if channel := self.name_to_channel.get(encoding): + tp = channel["value" if "value" in obj else "field"] + try: + # Don't force validation here; some objects won't be valid until + # they're created in the context of a chart. + return tp.from_dict(obj, validate=False) + except jsonschema.ValidationError: + # our attempts at finding the correct class have failed + return obj + else: + warnings.warn(f"Unrecognized encoding channel {encoding!r}", stacklevel=1) + return obj + + def infer_encoding_types(self, kwargs: dict[str, Any], /): + return { + encoding: self._wrap_in_channel(obj, encoding) + for encoding, obj in kwargs.items() + if obj is not Undefined + } + + +def _init_channel_to_name(): + """ + Construct a dictionary of channel type to encoding name. + + Note + ---- + The return type is not expressible using annotations, but is used + internally by `mypy`/`pyright` and avoids the need for type ignores. + + Returns + ------- + mapping: dict[type[``] | type[``] | type[``], str] + """ + from altair.vegalite.v5.schema import channels as ch + + mixins = ch.FieldChannelMixin, ch.ValueChannelMixin, ch.DatumChannelMixin + + return { + c: c._encoding_name + for c in ch.__dict__.values() + if isinstance(c, type) and issubclass(c, mixins) and issubclass(c, SchemaBase) + } + + +def _invert_group_channels( + m: dict[type[SchemaBase], str], / +) -> dict[str, dict[_ChannelType, type[SchemaBase]]]: + """Grouped inverted index for `_ChannelCache.channel_to_name`.""" + + def _reduce(it: Iterator[tuple[type[Any], str]]) -> Any: + """Returns a 1-2 item dict, per channel. + + Never includes `datum`, as it is never utilized in `wrap_in_channel`. + """ + item: dict[Any, type[SchemaBase]] = {} + for tp, _ in it: + name = tp.__name__ + if name.endswith("Datum"): + continue + elif name.endswith("Value"): + sub_key = "value" + else: + sub_key = "field" + item[sub_key] = tp + return item + + grouper = groupby(m.items(), itemgetter(1)) + return {k: _reduce(chans) for k, chans in grouper} + + def infer_encoding_types( - args: Sequence[Any], kwargs: t.MutableMapping[str, Any], channels: ModuleType -) -> dict[str, SchemaBase | list | dict[str, str] | Any]: + args: tuple[Any, ...], kwargs: dict[str, Any], channels: ModuleType | None = None +): """Infer typed keyword arguments for args and kwargs Parameters @@ -793,68 +919,19 @@ def infer_encoding_types( All args and kwargs in a single dict, with keys and types based on the channels mapping. """ - # Construct a dictionary of channel type to encoding name - # TODO: cache this somehow? - channel_objs = (getattr(channels, name) for name in dir(channels)) - channel_objs = ( - c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase) + cache = ( + _ChannelCache.from_channels(channels) + if channels + else _ChannelCache.from_cache() ) - channel_to_name: dict[type[SchemaBase], str] = { - c: c._encoding_name for c in channel_objs - } - name_to_channel: dict[str, dict[str, type[SchemaBase]]] = {} - for chan, name in channel_to_name.items(): - chans = name_to_channel.setdefault(name, {}) - if chan.__name__.endswith("Datum"): - key = "datum" - elif chan.__name__.endswith("Value"): - key = "value" - else: - key = "field" - chans[key] = chan - # First use the mapping to convert args to kwargs based on their types. for arg in args: - if isinstance(arg, (list, tuple)) and len(arg) > 0: - type_ = type(arg[0]) + el = next(iter(arg), None) if isinstance(arg, (list, tuple)) else arg + encoding = cache.get_encoding(type(el)) + if encoding not in kwargs: + kwargs[encoding] = arg else: - type_ = type(arg) - - encoding = channel_to_name.get(type_) - if encoding is None: - msg = f"positional of type {type_}" "" - raise NotImplementedError(msg) - if encoding in kwargs: - msg = f"encoding {encoding} specified twice." + msg = f"encoding {encoding!r} specified twice." raise ValueError(msg) - kwargs[encoding] = arg - - def _wrap_in_channel_class(obj, encoding): - if isinstance(obj, SchemaBase): - return obj - - if isinstance(obj, str): - obj = {"shorthand": obj} - - if isinstance(obj, (list, tuple)): - return [_wrap_in_channel_class(subobj, encoding) for subobj in obj] - - if encoding not in name_to_channel: - warnings.warn(f"Unrecognized encoding channel '{encoding}'", stacklevel=1) - return obj - classes = name_to_channel[encoding] - cls = classes["value"] if "value" in obj else classes["field"] - - try: - # Don't force validation here; some objects won't be valid until - # they're created in the context of a chart. - return cls.from_dict(obj, validate=False) - except jsonschema.ValidationError: - # our attempts at finding the correct class have failed - return obj - - return { - encoding: _wrap_in_channel_class(obj, encoding) - for encoding, obj in kwargs.items() - } + return cache.infer_encoding_types(kwargs) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 715b132e0..f25491865 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -8,6 +8,7 @@ import itertools from typing import Union, cast, Any, Iterable, Literal, IO, TYPE_CHECKING from typing_extensions import TypeAlias +import typing from .schema import core, channels, mixins, Undefined, SCHEMA_URL @@ -74,8 +75,6 @@ Step, RepeatRef, NonNormalizedSpec, - LayerSpec, - UnitSpec, UrlData, SequenceGenerator, GraticuleGenerator, @@ -381,6 +380,19 @@ def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: return False +_TestPredicateType = Union[str, _expr_core.Expression, core.PredicateComposition] +_PredicateType = Union[ + Parameter, + core.Expr, + typing.Dict[str, Any], + _TestPredicateType, + _expr_core.OperatorMixin, +] +_ConditionType = typing.Dict[str, Union[_TestPredicateType, Any]] +_DictOrStr = Union[typing.Dict[str, Any], str] +_DictOrSchema = Union[core.SchemaBase, typing.Dict[str, Any]] +_StatementType = Union[core.SchemaBase, _DictOrStr] + # ------------------------------------------------------------------------ # Top-Level Functions @@ -826,18 +838,33 @@ def binding_range(**kwargs): return core.BindRange(input="range", **kwargs) +_TSchemaBase = typing.TypeVar("_TSchemaBase", bound=core.SchemaBase) + + +@typing.overload +def condition( + predicate: _PredicateType, if_true: _StatementType, if_false: _TSchemaBase, **kwargs +) -> _TSchemaBase: ... +@typing.overload +def condition( + predicate: _PredicateType, if_true: str, if_false: str, **kwargs +) -> typing.NoReturn: ... +@typing.overload +def condition( + predicate: _PredicateType, if_true: _DictOrSchema, if_false: _DictOrStr, **kwargs +) -> dict[str, _ConditionType | Any]: ... +@typing.overload +def condition( + predicate: _PredicateType, + if_true: _DictOrStr, + if_false: dict[str, Any], + **kwargs, +) -> dict[str, _ConditionType | Any]: ... # TODO: update the docstring def condition( - predicate: Parameter - | str - | Expression - | Expr - | PredicateComposition - | dict[str, Any], - # Types of these depends on where the condition is used so we probably - # can't be more specific here. - if_true: Any, - if_false: Any, + predicate: _PredicateType, + if_true: _StatementType, + if_false: _StatementType, **kwargs, ) -> dict[str, Any] | SchemaBase: """A conditional attribute or encoding @@ -2729,24 +2756,7 @@ def resolve_scale(self, *args, **kwargs) -> Self: return self._set_resolve(scale=core.ScaleResolveMap(*args, **kwargs)) -class _EncodingMixin: - @utils.use_signature(channels._encode_signature) - def encode(self, *args, **kwargs) -> Self: - # Convert args to kwargs based on their types. - kwargs = utils.infer_encoding_types(args, kwargs, channels) - - # get a copy of the dict representation of the previous encoding - # ignore type as copy method comes from SchemaBase - copy = self.copy(deep=["encoding"]) # type: ignore[attr-defined] - encoding = copy._get("encoding", {}) - if isinstance(encoding, core.VegaLiteSchema): - encoding = {k: v for k, v in encoding._kwds.items() if v is not Undefined} - - # update with the new encodings, and apply them to the copy - encoding.update(kwargs) - copy.encoding = core.FacetedEncoding(**encoding) - return copy - +class _EncodingMixin(channels._EncodingMixin): def facet( self, facet: Optional[str | Facet] = Undefined, @@ -3614,7 +3624,7 @@ def transformed_data( return transformed_data(self, row_limit=row_limit, exclude=exclude) - def __iadd__(self, other: LayerSpec | UnitSpec) -> Self: + def __iadd__(self, other: LayerChart | Chart) -> Self: _check_if_valid_subspec(other, "LayerChart") _check_if_can_be_layered(other) self.layer.append(other) @@ -3622,12 +3632,12 @@ def __iadd__(self, other: LayerSpec | UnitSpec) -> Self: self.params, self.layer = _combine_subchart_params(self.params, self.layer) return self - def __add__(self, other: LayerSpec | UnitSpec) -> Self: + def __add__(self, other: LayerChart | Chart) -> Self: copy = self.copy(deep=["layer"]) copy += other return copy - def add_layers(self, *layers: LayerSpec | UnitSpec) -> Self: + def add_layers(self, *layers: LayerChart | Chart) -> Self: copy = self.copy(deep=["layer"]) for layer in layers: copy += layer diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index 60d71f1a8..8a92ce045 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -15,6 +15,7 @@ import pandas as pd +from altair.utils import infer_encoding_types as _infer_encoding_types from altair.utils import parse_shorthand from altair.utils.schemapi import Undefined, with_property_setters @@ -22,6 +23,8 @@ # ruff: noqa: F405 if TYPE_CHECKING: + from typing_extensions import Self + from altair import Parameter, SchemaBase from altair.utils.schemapi import Optional @@ -29,6 +32,8 @@ class FieldChannelMixin: + _encoding_name: str + def to_dict( self, validate: bool = True, @@ -92,6 +97,8 @@ def to_dict( class ValueChannelMixin: + _encoding_name: str + def to_dict( self, validate: bool = True, @@ -115,6 +122,8 @@ def to_dict( class DatumChannelMixin: + _encoding_name: str + def to_dict( self, validate: bool = True, @@ -34194,259 +34203,289 @@ def __init__(self, value, **kwds): super().__init__(value=value, **kwds) -def _encode_signature( - self, - angle: Optional[str | Angle | dict | AngleDatum | AngleValue] = Undefined, - color: Optional[str | Color | dict | ColorDatum | ColorValue] = Undefined, - column: Optional[str | Column | dict] = Undefined, - description: Optional[str | Description | dict | DescriptionValue] = Undefined, - detail: Optional[str | Detail | dict | list] = Undefined, - facet: Optional[str | Facet | dict] = Undefined, - fill: Optional[str | Fill | dict | FillDatum | FillValue] = Undefined, - fillOpacity: Optional[ - str | FillOpacity | dict | FillOpacityDatum | FillOpacityValue - ] = Undefined, - href: Optional[str | Href | dict | HrefValue] = Undefined, - key: Optional[str | Key | dict] = Undefined, - latitude: Optional[str | Latitude | dict | LatitudeDatum] = Undefined, - latitude2: Optional[ - str | Latitude2 | dict | Latitude2Datum | Latitude2Value - ] = Undefined, - longitude: Optional[str | Longitude | dict | LongitudeDatum] = Undefined, - longitude2: Optional[ - str | Longitude2 | dict | Longitude2Datum | Longitude2Value - ] = Undefined, - opacity: Optional[str | Opacity | dict | OpacityDatum | OpacityValue] = Undefined, - order: Optional[str | Order | dict | list | OrderValue] = Undefined, - radius: Optional[str | Radius | dict | RadiusDatum | RadiusValue] = Undefined, - radius2: Optional[str | Radius2 | dict | Radius2Datum | Radius2Value] = Undefined, - row: Optional[str | Row | dict] = Undefined, - shape: Optional[str | Shape | dict | ShapeDatum | ShapeValue] = Undefined, - size: Optional[str | Size | dict | SizeDatum | SizeValue] = Undefined, - stroke: Optional[str | Stroke | dict | StrokeDatum | StrokeValue] = Undefined, - strokeDash: Optional[ - str | StrokeDash | dict | StrokeDashDatum | StrokeDashValue - ] = Undefined, - strokeOpacity: Optional[ - str | StrokeOpacity | dict | StrokeOpacityDatum | StrokeOpacityValue - ] = Undefined, - strokeWidth: Optional[ - str | StrokeWidth | dict | StrokeWidthDatum | StrokeWidthValue - ] = Undefined, - text: Optional[str | Text | dict | TextDatum | TextValue] = Undefined, - theta: Optional[str | Theta | dict | ThetaDatum | ThetaValue] = Undefined, - theta2: Optional[str | Theta2 | dict | Theta2Datum | Theta2Value] = Undefined, - tooltip: Optional[str | Tooltip | dict | list | TooltipValue] = Undefined, - url: Optional[str | Url | dict | UrlValue] = Undefined, - x: Optional[str | X | dict | XDatum | XValue] = Undefined, - x2: Optional[str | X2 | dict | X2Datum | X2Value] = Undefined, - xError: Optional[str | XError | dict | XErrorValue] = Undefined, - xError2: Optional[str | XError2 | dict | XError2Value] = Undefined, - xOffset: Optional[str | XOffset | dict | XOffsetDatum | XOffsetValue] = Undefined, - y: Optional[str | Y | dict | YDatum | YValue] = Undefined, - y2: Optional[str | Y2 | dict | Y2Datum | Y2Value] = Undefined, - yError: Optional[str | YError | dict | YErrorValue] = Undefined, - yError2: Optional[str | YError2 | dict | YError2Value] = Undefined, - yOffset: Optional[str | YOffset | dict | YOffsetDatum | YOffsetValue] = Undefined, -): - """Parameters - ---------- - - angle : str, :class:`Angle`, Dict, :class:`AngleDatum`, :class:`AngleValue` - Rotation angle of point and text marks. - color : str, :class:`Color`, Dict, :class:`ColorDatum`, :class:`ColorValue` - Color of the marks - either fill or stroke color based on the ``filled`` property - of mark definition. By default, ``color`` represents fill color for ``"area"``, - ``"bar"``, ``"tick"``, ``"text"``, ``"trail"``, ``"circle"``, and ``"square"`` / - stroke color for ``"line"`` and ``"point"``. - - **Default value:** If undefined, the default color depends on `mark config - `__ 's ``color`` - property. - - *Note:* 1) For fine-grained control over both fill and stroke colors of the marks, - please use the ``fill`` and ``stroke`` channels. The ``fill`` or ``stroke`` - encodings have higher precedence than ``color``, thus may override the ``color`` - encoding if conflicting encodings are specified. 2) See the scale documentation for - more information about customizing `color scheme - `__. - column : str, :class:`Column`, Dict - A field definition for the horizontal facet of trellis plots. - description : str, :class:`Description`, Dict, :class:`DescriptionValue` - A text description of this mark for ARIA accessibility (SVG output only). For SVG - output the ``"aria-label"`` attribute will be set to this description. - detail : str, :class:`Detail`, Dict, List - Additional levels of detail for grouping data in aggregate views and in line, trail, - and area marks without mapping data to a specific visual channel. - facet : str, :class:`Facet`, Dict - A field definition for the (flexible) facet of trellis plots. - - If either ``row`` or ``column`` is specified, this channel will be ignored. - fill : str, :class:`Fill`, Dict, :class:`FillDatum`, :class:`FillValue` - Fill color of the marks. **Default value:** If undefined, the default color depends - on `mark config `__ - 's ``color`` property. - - *Note:* The ``fill`` encoding has higher precedence than ``color``, thus may - override the ``color`` encoding if conflicting encodings are specified. - fillOpacity : str, :class:`FillOpacity`, Dict, :class:`FillOpacityDatum`, :class:`FillOpacityValue` - Fill opacity of the marks. - - **Default value:** If undefined, the default opacity depends on `mark config - `__ 's - ``fillOpacity`` property. - href : str, :class:`Href`, Dict, :class:`HrefValue` - A URL to load upon mouse click. - key : str, :class:`Key`, Dict - A data field to use as a unique key for data binding. When a visualization's data is - updated, the key value will be used to match data elements to existing mark - instances. Use a key channel to enable object constancy for transitions over dynamic - data. - latitude : str, :class:`Latitude`, Dict, :class:`LatitudeDatum` - Latitude position of geographically projected marks. - latitude2 : str, :class:`Latitude2`, Dict, :class:`Latitude2Datum`, :class:`Latitude2Value` - Latitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, - ``"rect"``, and ``"rule"``. - longitude : str, :class:`Longitude`, Dict, :class:`LongitudeDatum` - Longitude position of geographically projected marks. - longitude2 : str, :class:`Longitude2`, Dict, :class:`Longitude2Datum`, :class:`Longitude2Value` - Longitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, - ``"rect"``, and ``"rule"``. - opacity : str, :class:`Opacity`, Dict, :class:`OpacityDatum`, :class:`OpacityValue` - Opacity of the marks. - - **Default value:** If undefined, the default opacity depends on `mark config - `__ 's ``opacity`` - property. - order : str, :class:`Order`, Dict, List, :class:`OrderValue` - Order of the marks. - - - * For stacked marks, this ``order`` channel encodes `stack order - `__. - * For line and trail marks, this ``order`` channel encodes order of data points in - the lines. This can be useful for creating `a connected scatterplot - `__. Setting - ``order`` to ``{"value": null}`` makes the line marks use the original order in - the data sources. - * Otherwise, this ``order`` channel encodes layer order of the marks. - - **Note** : In aggregate plots, ``order`` field should be ``aggregate`` d to avoid - creating additional aggregation grouping. - radius : str, :class:`Radius`, Dict, :class:`RadiusDatum`, :class:`RadiusValue` - The outer radius in pixels of arc marks. - radius2 : str, :class:`Radius2`, Dict, :class:`Radius2Datum`, :class:`Radius2Value` - The inner radius in pixels of arc marks. - row : str, :class:`Row`, Dict - A field definition for the vertical facet of trellis plots. - shape : str, :class:`Shape`, Dict, :class:`ShapeDatum`, :class:`ShapeValue` - Shape of the mark. - - - #. - For ``point`` marks the supported values include: - plotting shapes: ``"circle"``, - ``"square"``, ``"cross"``, ``"diamond"``, ``"triangle-up"``, ``"triangle-down"``, - ``"triangle-right"``, or ``"triangle-left"``. - the line symbol ``"stroke"`` - - centered directional shapes ``"arrow"``, ``"wedge"``, or ``"triangle"`` - a custom - `SVG path string - `__ (For correct - sizing, custom shape paths should be defined within a square bounding box with - coordinates ranging from -1 to 1 along both the x and y dimensions.) - - #. - For ``geoshape`` marks it should be a field definition of the geojson data - - **Default value:** If undefined, the default shape depends on `mark config - `__ 's ``shape`` - property. ( ``"circle"`` if unset.) - size : str, :class:`Size`, Dict, :class:`SizeDatum`, :class:`SizeValue` - Size of the mark. - - - * For ``"point"``, ``"square"`` and ``"circle"``, - the symbol size, or pixel area - of the mark. - * For ``"bar"`` and ``"tick"`` - the bar and tick's size. - * For ``"text"`` - the text's font size. - * Size is unsupported for ``"line"``, ``"area"``, and ``"rect"``. (Use ``"trail"`` - instead of line with varying size) - stroke : str, :class:`Stroke`, Dict, :class:`StrokeDatum`, :class:`StrokeValue` - Stroke color of the marks. **Default value:** If undefined, the default color - depends on `mark config - `__ 's ``color`` - property. - - *Note:* The ``stroke`` encoding has higher precedence than ``color``, thus may - override the ``color`` encoding if conflicting encodings are specified. - strokeDash : str, :class:`StrokeDash`, Dict, :class:`StrokeDashDatum`, :class:`StrokeDashValue` - Stroke dash of the marks. - - **Default value:** ``[1,0]`` (No dash). - strokeOpacity : str, :class:`StrokeOpacity`, Dict, :class:`StrokeOpacityDatum`, :class:`StrokeOpacityValue` - Stroke opacity of the marks. - - **Default value:** If undefined, the default opacity depends on `mark config - `__ 's - ``strokeOpacity`` property. - strokeWidth : str, :class:`StrokeWidth`, Dict, :class:`StrokeWidthDatum`, :class:`StrokeWidthValue` - Stroke width of the marks. - - **Default value:** If undefined, the default stroke width depends on `mark config - `__ 's - ``strokeWidth`` property. - text : str, :class:`Text`, Dict, :class:`TextDatum`, :class:`TextValue` - Text of the ``text`` mark. - theta : str, :class:`Theta`, Dict, :class:`ThetaDatum`, :class:`ThetaValue` - For arc marks, the arc length in radians if theta2 is not specified, otherwise the - start arc angle. (A value of 0 indicates up or “north”, increasing values proceed - clockwise.) - - For text marks, polar coordinate angle in radians. - theta2 : str, :class:`Theta2`, Dict, :class:`Theta2Datum`, :class:`Theta2Value` - The end angle of arc marks in radians. A value of 0 indicates up or “north”, - increasing values proceed clockwise. - tooltip : str, :class:`Tooltip`, Dict, List, :class:`TooltipValue` - The tooltip text to show upon mouse hover. Specifying ``tooltip`` encoding overrides - `the tooltip property in the mark definition - `__. - - See the `tooltip `__ - documentation for a detailed discussion about tooltip in Vega-Lite. - url : str, :class:`Url`, Dict, :class:`UrlValue` - The URL of an image mark. - x : str, :class:`X`, Dict, :class:`XDatum`, :class:`XValue` - X coordinates of the marks, or width of horizontal ``"bar"`` and ``"area"`` without - specified ``x2`` or ``width``. - - The ``value`` of this channel can be a number or a string ``"width"`` for the width - of the plot. - x2 : str, :class:`X2`, Dict, :class:`X2Datum`, :class:`X2Value` - X2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. - - The ``value`` of this channel can be a number or a string ``"width"`` for the width - of the plot. - xError : str, :class:`XError`, Dict, :class:`XErrorValue` - Error value of x coordinates for error specified ``"errorbar"`` and ``"errorband"``. - xError2 : str, :class:`XError2`, Dict, :class:`XError2Value` - Secondary error value of x coordinates for error specified ``"errorbar"`` and - ``"errorband"``. - xOffset : str, :class:`XOffset`, Dict, :class:`XOffsetDatum`, :class:`XOffsetValue` - Offset of x-position of the marks - y : str, :class:`Y`, Dict, :class:`YDatum`, :class:`YValue` - Y coordinates of the marks, or height of vertical ``"bar"`` and ``"area"`` without - specified ``y2`` or ``height``. - - The ``value`` of this channel can be a number or a string ``"height"`` for the - height of the plot. - y2 : str, :class:`Y2`, Dict, :class:`Y2Datum`, :class:`Y2Value` - Y2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. - - The ``value`` of this channel can be a number or a string ``"height"`` for the - height of the plot. - yError : str, :class:`YError`, Dict, :class:`YErrorValue` - Error value of y coordinates for error specified ``"errorbar"`` and ``"errorband"``. - yError2 : str, :class:`YError2`, Dict, :class:`YError2Value` - Secondary error value of y coordinates for error specified ``"errorbar"`` and - ``"errorband"``. - yOffset : str, :class:`YOffset`, Dict, :class:`YOffsetDatum`, :class:`YOffsetValue` - Offset of y-position of the marks - """ +class _EncodingMixin: + def encode( + self, + *args: Any, + angle: Optional[str | Angle | dict | AngleDatum | AngleValue] = Undefined, + color: Optional[str | Color | dict | ColorDatum | ColorValue] = Undefined, + column: Optional[str | Column | dict] = Undefined, + description: Optional[str | Description | dict | DescriptionValue] = Undefined, + detail: Optional[str | Detail | dict | list] = Undefined, + facet: Optional[str | Facet | dict] = Undefined, + fill: Optional[str | Fill | dict | FillDatum | FillValue] = Undefined, + fillOpacity: Optional[ + str | FillOpacity | dict | FillOpacityDatum | FillOpacityValue + ] = Undefined, + href: Optional[str | Href | dict | HrefValue] = Undefined, + key: Optional[str | Key | dict] = Undefined, + latitude: Optional[str | Latitude | dict | LatitudeDatum] = Undefined, + latitude2: Optional[ + str | Latitude2 | dict | Latitude2Datum | Latitude2Value + ] = Undefined, + longitude: Optional[str | Longitude | dict | LongitudeDatum] = Undefined, + longitude2: Optional[ + str | Longitude2 | dict | Longitude2Datum | Longitude2Value + ] = Undefined, + opacity: Optional[ + str | Opacity | dict | OpacityDatum | OpacityValue + ] = Undefined, + order: Optional[str | Order | dict | list | OrderValue] = Undefined, + radius: Optional[str | Radius | dict | RadiusDatum | RadiusValue] = Undefined, + radius2: Optional[ + str | Radius2 | dict | Radius2Datum | Radius2Value + ] = Undefined, + row: Optional[str | Row | dict] = Undefined, + shape: Optional[str | Shape | dict | ShapeDatum | ShapeValue] = Undefined, + size: Optional[str | Size | dict | SizeDatum | SizeValue] = Undefined, + stroke: Optional[str | Stroke | dict | StrokeDatum | StrokeValue] = Undefined, + strokeDash: Optional[ + str | StrokeDash | dict | StrokeDashDatum | StrokeDashValue + ] = Undefined, + strokeOpacity: Optional[ + str | StrokeOpacity | dict | StrokeOpacityDatum | StrokeOpacityValue + ] = Undefined, + strokeWidth: Optional[ + str | StrokeWidth | dict | StrokeWidthDatum | StrokeWidthValue + ] = Undefined, + text: Optional[str | Text | dict | TextDatum | TextValue] = Undefined, + theta: Optional[str | Theta | dict | ThetaDatum | ThetaValue] = Undefined, + theta2: Optional[str | Theta2 | dict | Theta2Datum | Theta2Value] = Undefined, + tooltip: Optional[str | Tooltip | dict | list | TooltipValue] = Undefined, + url: Optional[str | Url | dict | UrlValue] = Undefined, + x: Optional[str | X | dict | XDatum | XValue] = Undefined, + x2: Optional[str | X2 | dict | X2Datum | X2Value] = Undefined, + xError: Optional[str | XError | dict | XErrorValue] = Undefined, + xError2: Optional[str | XError2 | dict | XError2Value] = Undefined, + xOffset: Optional[ + str | XOffset | dict | XOffsetDatum | XOffsetValue + ] = Undefined, + y: Optional[str | Y | dict | YDatum | YValue] = Undefined, + y2: Optional[str | Y2 | dict | Y2Datum | Y2Value] = Undefined, + yError: Optional[str | YError | dict | YErrorValue] = Undefined, + yError2: Optional[str | YError2 | dict | YError2Value] = Undefined, + yOffset: Optional[ + str | YOffset | dict | YOffsetDatum | YOffsetValue + ] = Undefined, + ) -> Self: + """Map properties of the data to visual properties of the chart (see :class:`FacetedEncoding`) + + Parameters + ---------- + angle : str, :class:`Angle`, Dict, :class:`AngleDatum`, :class:`AngleValue` + Rotation angle of point and text marks. + color : str, :class:`Color`, Dict, :class:`ColorDatum`, :class:`ColorValue` + Color of the marks - either fill or stroke color based on the ``filled`` property + of mark definition. By default, ``color`` represents fill color for ``"area"``, + ``"bar"``, ``"tick"``, ``"text"``, ``"trail"``, ``"circle"``, and ``"square"`` / + stroke color for ``"line"`` and ``"point"``. + + **Default value:** If undefined, the default color depends on `mark config + `__ 's ``color`` + property. + + *Note:* 1) For fine-grained control over both fill and stroke colors of the marks, + please use the ``fill`` and ``stroke`` channels. The ``fill`` or ``stroke`` + encodings have higher precedence than ``color``, thus may override the ``color`` + encoding if conflicting encodings are specified. 2) See the scale documentation for + more information about customizing `color scheme + `__. + column : str, :class:`Column`, Dict + A field definition for the horizontal facet of trellis plots. + description : str, :class:`Description`, Dict, :class:`DescriptionValue` + A text description of this mark for ARIA accessibility (SVG output only). For SVG + output the ``"aria-label"`` attribute will be set to this description. + detail : str, :class:`Detail`, Dict, List + Additional levels of detail for grouping data in aggregate views and in line, trail, + and area marks without mapping data to a specific visual channel. + facet : str, :class:`Facet`, Dict + A field definition for the (flexible) facet of trellis plots. + + If either ``row`` or ``column`` is specified, this channel will be ignored. + fill : str, :class:`Fill`, Dict, :class:`FillDatum`, :class:`FillValue` + Fill color of the marks. **Default value:** If undefined, the default color depends + on `mark config `__ + 's ``color`` property. + + *Note:* The ``fill`` encoding has higher precedence than ``color``, thus may + override the ``color`` encoding if conflicting encodings are specified. + fillOpacity : str, :class:`FillOpacity`, Dict, :class:`FillOpacityDatum`, :class:`FillOpacityValue` + Fill opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__ 's + ``fillOpacity`` property. + href : str, :class:`Href`, Dict, :class:`HrefValue` + A URL to load upon mouse click. + key : str, :class:`Key`, Dict + A data field to use as a unique key for data binding. When a visualization's data is + updated, the key value will be used to match data elements to existing mark + instances. Use a key channel to enable object constancy for transitions over dynamic + data. + latitude : str, :class:`Latitude`, Dict, :class:`LatitudeDatum` + Latitude position of geographically projected marks. + latitude2 : str, :class:`Latitude2`, Dict, :class:`Latitude2Datum`, :class:`Latitude2Value` + Latitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, + ``"rect"``, and ``"rule"``. + longitude : str, :class:`Longitude`, Dict, :class:`LongitudeDatum` + Longitude position of geographically projected marks. + longitude2 : str, :class:`Longitude2`, Dict, :class:`Longitude2Datum`, :class:`Longitude2Value` + Longitude-2 position for geographically projected ranged ``"area"``, ``"bar"``, + ``"rect"``, and ``"rule"``. + opacity : str, :class:`Opacity`, Dict, :class:`OpacityDatum`, :class:`OpacityValue` + Opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__ 's ``opacity`` + property. + order : str, :class:`Order`, Dict, List, :class:`OrderValue` + Order of the marks. + + + * For stacked marks, this ``order`` channel encodes `stack order + `__. + * For line and trail marks, this ``order`` channel encodes order of data points in + the lines. This can be useful for creating `a connected scatterplot + `__. Setting + ``order`` to ``{"value": null}`` makes the line marks use the original order in + the data sources. + * Otherwise, this ``order`` channel encodes layer order of the marks. + + **Note** : In aggregate plots, ``order`` field should be ``aggregate`` d to avoid + creating additional aggregation grouping. + radius : str, :class:`Radius`, Dict, :class:`RadiusDatum`, :class:`RadiusValue` + The outer radius in pixels of arc marks. + radius2 : str, :class:`Radius2`, Dict, :class:`Radius2Datum`, :class:`Radius2Value` + The inner radius in pixels of arc marks. + row : str, :class:`Row`, Dict + A field definition for the vertical facet of trellis plots. + shape : str, :class:`Shape`, Dict, :class:`ShapeDatum`, :class:`ShapeValue` + Shape of the mark. + + + #. + For ``point`` marks the supported values include: - plotting shapes: ``"circle"``, + ``"square"``, ``"cross"``, ``"diamond"``, ``"triangle-up"``, ``"triangle-down"``, + ``"triangle-right"``, or ``"triangle-left"``. - the line symbol ``"stroke"`` - + centered directional shapes ``"arrow"``, ``"wedge"``, or ``"triangle"`` - a custom + `SVG path string + `__ (For correct + sizing, custom shape paths should be defined within a square bounding box with + coordinates ranging from -1 to 1 along both the x and y dimensions.) + + #. + For ``geoshape`` marks it should be a field definition of the geojson data + + **Default value:** If undefined, the default shape depends on `mark config + `__ 's ``shape`` + property. ( ``"circle"`` if unset.) + size : str, :class:`Size`, Dict, :class:`SizeDatum`, :class:`SizeValue` + Size of the mark. + + + * For ``"point"``, ``"square"`` and ``"circle"``, - the symbol size, or pixel area + of the mark. + * For ``"bar"`` and ``"tick"`` - the bar and tick's size. + * For ``"text"`` - the text's font size. + * Size is unsupported for ``"line"``, ``"area"``, and ``"rect"``. (Use ``"trail"`` + instead of line with varying size) + stroke : str, :class:`Stroke`, Dict, :class:`StrokeDatum`, :class:`StrokeValue` + Stroke color of the marks. **Default value:** If undefined, the default color + depends on `mark config + `__ 's ``color`` + property. + + *Note:* The ``stroke`` encoding has higher precedence than ``color``, thus may + override the ``color`` encoding if conflicting encodings are specified. + strokeDash : str, :class:`StrokeDash`, Dict, :class:`StrokeDashDatum`, :class:`StrokeDashValue` + Stroke dash of the marks. + + **Default value:** ``[1,0]`` (No dash). + strokeOpacity : str, :class:`StrokeOpacity`, Dict, :class:`StrokeOpacityDatum`, :class:`StrokeOpacityValue` + Stroke opacity of the marks. + + **Default value:** If undefined, the default opacity depends on `mark config + `__ 's + ``strokeOpacity`` property. + strokeWidth : str, :class:`StrokeWidth`, Dict, :class:`StrokeWidthDatum`, :class:`StrokeWidthValue` + Stroke width of the marks. + + **Default value:** If undefined, the default stroke width depends on `mark config + `__ 's + ``strokeWidth`` property. + text : str, :class:`Text`, Dict, :class:`TextDatum`, :class:`TextValue` + Text of the ``text`` mark. + theta : str, :class:`Theta`, Dict, :class:`ThetaDatum`, :class:`ThetaValue` + For arc marks, the arc length in radians if theta2 is not specified, otherwise the + start arc angle. (A value of 0 indicates up or “north”, increasing values proceed + clockwise.) + + For text marks, polar coordinate angle in radians. + theta2 : str, :class:`Theta2`, Dict, :class:`Theta2Datum`, :class:`Theta2Value` + The end angle of arc marks in radians. A value of 0 indicates up or “north”, + increasing values proceed clockwise. + tooltip : str, :class:`Tooltip`, Dict, List, :class:`TooltipValue` + The tooltip text to show upon mouse hover. Specifying ``tooltip`` encoding overrides + `the tooltip property in the mark definition + `__. + + See the `tooltip `__ + documentation for a detailed discussion about tooltip in Vega-Lite. + url : str, :class:`Url`, Dict, :class:`UrlValue` + The URL of an image mark. + x : str, :class:`X`, Dict, :class:`XDatum`, :class:`XValue` + X coordinates of the marks, or width of horizontal ``"bar"`` and ``"area"`` without + specified ``x2`` or ``width``. + + The ``value`` of this channel can be a number or a string ``"width"`` for the width + of the plot. + x2 : str, :class:`X2`, Dict, :class:`X2Datum`, :class:`X2Value` + X2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. + + The ``value`` of this channel can be a number or a string ``"width"`` for the width + of the plot. + xError : str, :class:`XError`, Dict, :class:`XErrorValue` + Error value of x coordinates for error specified ``"errorbar"`` and ``"errorband"``. + xError2 : str, :class:`XError2`, Dict, :class:`XError2Value` + Secondary error value of x coordinates for error specified ``"errorbar"`` and + ``"errorband"``. + xOffset : str, :class:`XOffset`, Dict, :class:`XOffsetDatum`, :class:`XOffsetValue` + Offset of x-position of the marks + y : str, :class:`Y`, Dict, :class:`YDatum`, :class:`YValue` + Y coordinates of the marks, or height of vertical ``"bar"`` and ``"area"`` without + specified ``y2`` or ``height``. + + The ``value`` of this channel can be a number or a string ``"height"`` for the + height of the plot. + y2 : str, :class:`Y2`, Dict, :class:`Y2Datum`, :class:`Y2Value` + Y2 coordinates for ranged ``"area"``, ``"bar"``, ``"rect"``, and ``"rule"``. + + The ``value`` of this channel can be a number or a string ``"height"`` for the + height of the plot. + yError : str, :class:`YError`, Dict, :class:`YErrorValue` + Error value of y coordinates for error specified ``"errorbar"`` and ``"errorband"``. + yError2 : str, :class:`YError2`, Dict, :class:`YError2Value` + Secondary error value of y coordinates for error specified ``"errorbar"`` and + ``"errorband"``. + yOffset : str, :class:`YOffset`, Dict, :class:`YOffsetDatum`, :class:`YOffsetValue` + Offset of y-position of the marks + """ + # Compat prep for `infer_encoding_types` signature + kwargs = locals() + kwargs.pop("self") + args = kwargs.pop("args") + if args: + kwargs = {k: v for k, v in kwargs.items() if v is not Undefined} + + # Convert args to kwargs based on their types. + kwargs = _infer_encoding_types(args, kwargs) + # get a copy of the dict representation of the previous encoding + # ignore type as copy method comes from SchemaBase + copy = self.copy(deep=["encoding"]) # type: ignore[attr-defined] + encoding = copy._get("encoding", {}) + if isinstance(encoding, core.VegaLiteSchema): + encoding = {k: v for k, v in encoding._kwds.items() if v is not Undefined} + # update with the new encodings, and apply them to the copy + encoding.update(kwargs) + copy.encoding = core.FacetedEncoding(**encoding) + return copy diff --git a/altair/vegalite/v5/schema/core.py b/altair/vegalite/v5/schema/core.py index 38bc0868a..e232d0206 100644 --- a/altair/vegalite/v5/schema/core.py +++ b/altair/vegalite/v5/schema/core.py @@ -246,7 +246,6 @@ "LogicalAndPredicate", "LogicalNotPredicate", "LogicalOrPredicate", - "LookupData", "LookupSelection", "LookupTransform", "Mark", diff --git a/pyproject.toml b/pyproject.toml index ea128f70e..0748bc478 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,6 +216,8 @@ extend-safe-fixes=[ "EM102", # trailing-whitespace "W291", + # blank line contains whitespace + "W293" ] # https://docs.astral.sh/ruff/preview/#using-rules-that-are-in-preview diff --git a/tests/examples_arguments_syntax/hexbins.py b/tests/examples_arguments_syntax/hexbins.py index a8d91ac74..3005fcf00 100644 --- a/tests/examples_arguments_syntax/hexbins.py +++ b/tests/examples_arguments_syntax/hexbins.py @@ -30,7 +30,7 @@ labelPadding=20, tickOpacity=0, domainOpacity=0)), stroke=alt.value('black'), strokeWidth=alt.value(0.2), - fill=alt.Color('mean(temp_max):Q', scale=alt.Scale(scheme='darkblue')), + fill=alt.Fill('mean(temp_max):Q', scale=alt.Scale(scheme='darkblue')), tooltip=['month(' + xField + '):O', 'day(' + yField + '):O', 'mean(temp_max):Q'] ).transform_calculate( # This field is required for the hexagonal X-Offset diff --git a/tests/examples_arguments_syntax/select_detail.py b/tests/examples_arguments_syntax/select_detail.py index a6b558df7..500668c4f 100644 --- a/tests/examples_arguments_syntax/select_detail.py +++ b/tests/examples_arguments_syntax/select_detail.py @@ -54,7 +54,7 @@ color=alt.condition(selector, 'id:O', alt.value('lightgray'), legend=None), ) -timeseries = base.mark_line().encode( +line = base.mark_line().encode( x='time', y=alt.Y('value', scale=alt.Scale(domain=(-15, 15))), color=alt.Color('id:O', legend=None) @@ -62,4 +62,4 @@ selector ) -points | timeseries +points | line diff --git a/tests/examples_methods_syntax/hexbins.py b/tests/examples_methods_syntax/hexbins.py index 26f3890a0..4e23e7bb0 100644 --- a/tests/examples_methods_syntax/hexbins.py +++ b/tests/examples_methods_syntax/hexbins.py @@ -32,7 +32,7 @@ .axis(labelPadding=20, tickOpacity=0, domainOpacity=0), stroke=alt.value('black'), strokeWidth=alt.value(0.2), - fill=alt.Color('mean(temp_max):Q').scale(scheme='darkblue'), + fill=alt.Fill('mean(temp_max):Q').scale(scheme='darkblue'), tooltip=['month(' + xField + '):O', 'day(' + yField + '):O', 'mean(temp_max):Q'] ).transform_calculate( # This field is required for the hexagonal X-Offset diff --git a/tests/examples_methods_syntax/select_detail.py b/tests/examples_methods_syntax/select_detail.py index 4dead24c5..58bdb9dfd 100644 --- a/tests/examples_methods_syntax/select_detail.py +++ b/tests/examples_methods_syntax/select_detail.py @@ -54,7 +54,7 @@ color=alt.condition(selector, 'id:O', alt.value('lightgray'), legend=None), ) -timeseries = base.mark_line().encode( +line = base.mark_line().encode( x='time', y=alt.Y('value').scale(domain=(-15, 15)), color=alt.Color('id:O').legend(None) @@ -62,4 +62,4 @@ selector ) -points | timeseries +points | line diff --git a/tests/utils/test_core.py b/tests/utils/test_core.py index 4bcdf0e88..e127885d8 100644 --- a/tests/utils/test_core.py +++ b/tests/utils/test_core.py @@ -260,6 +260,7 @@ def _getargs(*args, **kwargs): return args, kwargs +# NOTE: Dependent on a no longer needed implementation detail def test_infer_encoding_types(channels): expected = { "x": channels.X("xval"), @@ -285,8 +286,6 @@ def test_infer_encoding_types(channels): def test_infer_encoding_types_with_condition(): - channels = alt.channels - args, kwds = _getargs( size=alt.condition("pred1", alt.value(1), alt.value(2)), color=alt.condition("pred2", alt.value("red"), "cfield:N"), @@ -294,19 +293,19 @@ def test_infer_encoding_types_with_condition(): ) expected = { - "size": channels.SizeValue( + "size": alt.SizeValue( 2, condition=alt.ConditionalPredicateValueDefnumberExprRef( value=1, test=alt.Predicate("pred1") ), ), - "color": channels.Color( + "color": alt.Color( "cfield:N", condition=alt.ConditionalPredicateValueDefGradientstringnullExprRef( value="red", test=alt.Predicate("pred2") ), ), - "opacity": channels.OpacityValue( + "opacity": alt.OpacityValue( 0.2, condition=alt.ConditionalPredicateMarkPropFieldOrDatumDef( field=alt.FieldName("ofield"), @@ -315,7 +314,7 @@ def test_infer_encoding_types_with_condition(): ), ), } - assert infer_encoding_types(args, kwds, channels) == expected + assert infer_encoding_types(args, kwds) == expected def test_invalid_data_type(): diff --git a/tests/utils/test_schemapi.py b/tests/utils/test_schemapi.py index 9cc201c6b..2fafa43a8 100644 --- a/tests/utils/test_schemapi.py +++ b/tests/utils/test_schemapi.py @@ -1,3 +1,4 @@ +# ruff: noqa: W291 import copy import io import inspect @@ -436,22 +437,6 @@ def chart_error_example__hconcat(): return points | text -def chart_error_example__invalid_channel(): - # Error: invalidChannel is an invalid encoding channel. Condition is correct - # but is added below as in previous implementations of Altair this interfered - # with finding the invalidChannel error - selection = alt.selection_point() - return ( - alt.Chart(data.barley()) - .mark_circle() - .add_params(selection) - .encode( - color=alt.condition(selection, alt.value("red"), alt.value("green")), - invalidChannel=None, - ) - ) - - def chart_error_example__invalid_y_option_value_unknown_x_option(): # Error 1: unknown is an invalid channel option for X # Error 2: Invalid Y option value "asdf" and unknown option "unknown" for X @@ -553,10 +538,10 @@ def chart_error_example__wrong_tooltip_type_in_layered_chart(): def chart_error_example__two_errors_in_layered_chart(): # Error 1: Wrong data type to pass to tooltip - # Error 2: invalidChannel is not a valid encoding channel + # Error 2: `Color` has no parameter named 'invalidArgument' return alt.layer( alt.Chart().mark_point().encode(tooltip=[{"wrong"}]), - alt.Chart().mark_line().encode(invalidChannel="unknown"), + alt.Chart().mark_line().encode(alt.Color(invalidArgument="unknown")), ) @@ -571,7 +556,7 @@ def chart_error_example__two_errors_in_complex_concat_layered_chart(): def chart_error_example__three_errors_in_complex_concat_layered_chart(): # Error 1: Wrong data type to pass to tooltip - # Error 2: invalidChannel is not a valid encoding channel + # Error 2: `Color` has no parameter named 'invalidArgument' # Error 3: Invalid value for bandPosition return ( chart_error_example__two_errors_in_layered_chart() @@ -581,7 +566,7 @@ def chart_error_example__three_errors_in_complex_concat_layered_chart(): def chart_error_example__two_errors_with_one_in_nested_layered_chart(): # Error 1: invalidOption is not a valid option for Scale - # Error 2: invalidChannel is not a valid encoding channel + # Error 2: `Color` has no parameter named 'invalidArgument' # In the final chart object, the `layer` attribute will look like this: # [alt.Chart(...), alt.Chart(...), alt.LayerChart(...)] @@ -617,7 +602,7 @@ def chart_error_example__two_errors_with_one_in_nested_layered_chart(): base = alt.Chart().encode(y=alt.datum(300)) - rule = base.mark_rule().encode(invalidChannel=2) + rule = base.mark_rule().encode(alt.Color(invalidArgument="unknown")) text = base.mark_text(text="hazardous") rule_text = rule + text @@ -665,7 +650,7 @@ def chart_error_example__four_errors(): Error 2: 'asdf' is an invalid value for `stack`. Valid values are: - One of \['zero', 'center', 'normalize'\] - - Of type 'null' or 'boolean'$""" # noqa: W291 + - Of type 'null' or 'boolean'$""" ), ), ( @@ -687,18 +672,14 @@ def chart_error_example__four_errors(): Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. - Error 2: `Encoding` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidArgument' Existing parameter names are: - angle key order strokeDash tooltip xOffset - color latitude radius strokeOpacity url y - description latitude2 radius2 strokeWidth x y2 - detail longitude shape text x2 yError - fill longitude2 size theta xError yError2 - fillOpacity opacity stroke theta2 xError2 yOffset - href - - See the help for `Encoding` to read the full description of these parameters$""" # noqa: W291 + shorthand bin legend timeUnit + aggregate condition scale title + bandPosition field sort type + + See the help for `Color` to read the full description of these parameters$""" ), ), ( @@ -718,20 +699,16 @@ def chart_error_example__four_errors(): Error 1: '{'wrong'}' is an invalid value for `field`. Valid values are of type 'string' or 'object'. - Error 2: `Encoding` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidArgument' Existing parameter names are: - angle key order strokeDash tooltip xOffset - color latitude radius strokeOpacity url y - description latitude2 radius2 strokeWidth x y2 - detail longitude shape text x2 yError - fill longitude2 size theta xError yError2 - fillOpacity opacity stroke theta2 xError2 yOffset - href + shorthand bin legend timeUnit + aggregate condition scale title + bandPosition field sort type - See the help for `Encoding` to read the full description of these parameters + See the help for `Color` to read the full description of these parameters - Error 3: '4' is an invalid value for `bandPosition`. Valid values are of type 'number'.$""" # noqa: W291 + Error 3: '4' is an invalid value for `bandPosition`. Valid values are of type 'number'.$""" ), ), ( @@ -750,18 +727,14 @@ def chart_error_example__four_errors(): See the help for `Scale` to read the full description of these parameters - Error 2: `Encoding` has no parameter named 'invalidChannel' + Error 2: `Color` has no parameter named 'invalidArgument' Existing parameter names are: - angle key order strokeDash tooltip xOffset - color latitude radius strokeOpacity url y - description latitude2 radius2 strokeWidth x y2 - detail longitude shape text x2 yError - fill longitude2 size theta xError yError2 - fillOpacity opacity stroke theta2 xError2 yOffset - href - - See the help for `Encoding` to read the full description of these parameters$""" # noqa: W291 + shorthand bin legend timeUnit + aggregate condition scale title + bandPosition field sort type + + See the help for `Color` to read the full description of these parameters$""" ), ), ( @@ -775,7 +748,7 @@ def chart_error_example__four_errors(): background data padding spacing usermeta bounds datasets - See the help for `VConcatChart` to read the full description of these parameters$""" # noqa: W291 + See the help for `VConcatChart` to read the full description of these parameters$""" ), ), ( @@ -802,23 +775,6 @@ def chart_error_example__four_errors(): r"""'{'text': 'Horsepower', 'align': 'right'}' is an invalid value for `title`. Valid values are of type 'string', 'array', or 'null'.$""" ), ), - ( - chart_error_example__invalid_channel, - inspect.cleandoc( - r"""`Encoding` has no parameter named 'invalidChannel' - - Existing parameter names are: - angle key order strokeDash tooltip xOffset - color latitude radius strokeOpacity url y - description latitude2 radius2 strokeWidth x y2 - detail longitude shape text x2 yError - fill longitude2 size theta xError yError2 - fillOpacity opacity stroke theta2 xError2 yOffset - href - - See the help for `Encoding` to read the full description of these parameters$""" # noqa: W291 - ), - ), ( chart_error_example__invalid_timeunit_value, inspect.cleandoc( @@ -867,7 +823,7 @@ def chart_error_example__four_errors(): axis impute stack type bandPosition - See the help for `X` to read the full description of these parameters$""" # noqa: W291 + See the help for `X` to read the full description of these parameters$""" ), ), ( @@ -907,7 +863,7 @@ def chart_error_example__four_errors(): axis impute stack type bandPosition - See the help for `X` to read the full description of these parameters$""" # noqa: W291 + See the help for `X` to read the full description of these parameters$""" ), ), ], diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index b9f4ef18d..f1a73670f 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -69,6 +69,7 @@ def load_schema() -> dict: CHANNEL_MIXINS: Final = """ class FieldChannelMixin: + _encoding_name: str def to_dict( self, validate: bool = True, @@ -134,6 +135,7 @@ def to_dict( class ValueChannelMixin: + _encoding_name: str def to_dict( self, validate: bool = True, @@ -157,6 +159,7 @@ def to_dict( class DatumChannelMixin: + _encoding_name: str def to_dict( self, validate: bool = True, @@ -203,10 +206,30 @@ def configure_{prop}(self, *args, **kwargs) -> Self: return copy """ -ENCODE_SIGNATURE: Final = ''' -def _encode_signature({encode_method_args}): - """{docstring}""" - ... +ENCODE_METHOD: Final = ''' +class _EncodingMixin: + def encode({encode_method_args}) -> Self: + """Map properties of the data to visual properties of the chart (see :class:`FacetedEncoding`) + {docstring}""" + # Compat prep for `infer_encoding_types` signature + kwargs = locals() + kwargs.pop("self") + args = kwargs.pop("args") + if args: + kwargs = {{k: v for k, v in kwargs.items() if v is not Undefined}} + + # Convert args to kwargs based on their types. + kwargs = _infer_encoding_types(args, kwargs) + # get a copy of the dict representation of the previous encoding + # ignore type as copy method comes from SchemaBase + copy = self.copy(deep=['encoding']) # type: ignore[attr-defined] + encoding = copy._get('encoding', {{}}) + if isinstance(encoding, core.VegaLiteSchema): + encoding = {{k: v for k, v in encoding._kwds.items() if v is not Undefined}} + # update with the new encodings, and apply them to the copy + encoding.update(kwargs) + copy.encoding = core.FacetedEncoding(**encoding) + return copy ''' @@ -475,10 +498,14 @@ def generate_vegalite_schema_wrapper(schema_file: Path) -> str: child.basename.append(name) # Specify __all__ explicitly so that we can exclude the ones from the list - # of exported classes which are also defined in the channels module which takes - # precedent in the generated __init__.py file one level up where core.py - # and channels.py are imported. Importing both confuses type checkers. - it = (c for c in definitions.keys() - {"Color", "Text"} if not c.startswith("_")) + # of exported classes which are also defined in the channels or api modules which takes + # precedent in the generated __init__.py files one and two levels up. + # Importing these classes from multiple modules confuses type checkers. + it = ( + c + for c in definitions.keys() - {"Color", "Text", "LookupData"} + if not c.startswith("_") + ) all_ = [*sorted(it), "Root", "VegaLiteSchema", "SchemaBase", "load_schema"] contents = [ @@ -536,21 +563,22 @@ def generate_vegalite_channel_wrappers( imports = imports or [ "from __future__ import annotations\n", - "import sys", - "from . import core", + "from typing import Any, overload, Sequence, List, Literal, Union, TYPE_CHECKING", "import pandas as pd", - "from altair.utils.schemapi import Undefined, UndefinedType, with_property_setters", + "from altair.utils.schemapi import Undefined, with_property_setters", + "from altair.utils import infer_encoding_types as _infer_encoding_types", "from altair.utils import parse_shorthand", - "from typing import Any, overload, Sequence, List, Literal, Union, TYPE_CHECKING", + "from . import core", ] contents = [ HEADER, CHANNEL_MYPY_IGNORE_STATEMENTS, *imports, _type_checking_only_imports( - "from altair import Parameter, SchemaBase # noqa: F401", - "from altair.utils.schemapi import Optional # noqa: F401", + "from altair import Parameter, SchemaBase", + "from altair.utils.schemapi import Optional", "from ._typing import * # noqa: F403", + "from typing_extensions import Self", ), CHANNEL_MIXINS, ] @@ -798,8 +826,8 @@ def vegalite_main(skip_download: bool = False) -> None: def _create_encode_signature( channel_infos: dict[str, ChannelInfo], ) -> str: - signature_args: list[str] = ["self"] - docstring_parameters: list[str] = ["", "Parameters", "----------", ""] + signature_args: list[str] = ["self", "*args: Any"] + docstring_parameters: list[str] = ["", "Parameters", "----------"] for channel, info in channel_infos.items(): field_class_name = info.field_class_name assert ( @@ -843,9 +871,9 @@ def _create_encode_signature( if len(docstring_parameters) > 1: docstring_parameters += [""] docstring = indent_docstring( - docstring_parameters, indent_level=4, width=100, lstrip=True + docstring_parameters, indent_level=4, width=100, lstrip=False ) - return ENCODE_SIGNATURE.format( + return ENCODE_METHOD.format( encode_method_args=", ".join(signature_args), docstring=docstring ) diff --git a/tools/update_init_file.py b/tools/update_init_file.py index fc694b7d8..5e8bbc6ab 100644 --- a/tools/update_init_file.py +++ b/tools/update_init_file.py @@ -92,7 +92,7 @@ def relevant_attributes(namespace: dict[str, Any], /) -> list[str]: it = ( name for name, attr in namespace.items() - if (not name.startswith("_")) and _is_relevant(attr) + if (not name.startswith("_")) and _is_relevant(attr, name) ) return sorted(it) @@ -105,12 +105,13 @@ def _is_hashable(obj: Any) -> bool: return False -def _is_relevant(attr: Any, /) -> bool: +def _is_relevant(attr: Any, name: str, /) -> bool: """Predicate logic for filtering attributes.""" if ( getattr_static(attr, "_deprecated", False) or attr is TYPE_CHECKING or (_is_hashable(attr) and attr in _TYPING_CONSTRUCTS) + or name in {"pd", "jsonschema"} ): return False elif ismodule(attr):