Skip to content

Commit

Permalink
Rename type aliases for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Nov 21, 2024
1 parent 6f52279 commit 31dbd39
Show file tree
Hide file tree
Showing 27 changed files with 448 additions and 360 deletions.
12 changes: 10 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,18 @@
("https://immutabledict.corenting.fr/", None)
}
autodoc_type_aliases = {
"ExpressionT": "ExpressionT",
"ArithmeticExpressionT": "ArithmeticExpressionT",
"Expression": "Expression",
"ArithmeticExpression": "ArithmeticExpression",
}


nitpick_ignore_regex = [
# Avoids this error. Not sure where to even look.
# <unknown>:1: WARNING: py:class reference target not found: ExpressionNode [ref.class] # noqa: E501
["py:class", r"ExpressionNode"],
]


import sys


Expand Down
5 changes: 3 additions & 2 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,11 @@ You can also easily define your own objects to use inside an expression:

.. doctest::

>>> from pymbolic.primitives import Expression, expr_dataclass
>>> from pymbolic import ExpressionNode, expr_dataclass
>>> from pymbolic.typing import Expression
>>>
>>> @expr_dataclass()
... class FancyOperator(Expression):
... class FancyOperator(ExpressionNode):
... operand: Expression
...
>>> u
Expand Down
3 changes: 2 additions & 1 deletion doc/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ Parser

.. function:: parse(expr_str)

Return a :class:`pymbolic.primitives.Expression` tree corresponding to *expr_str*.
Return a :class:`pymbolic.primitives.ExpressionNode` tree corresponding
to *expr_str*.

The parser is also relatively easy to extend. See the source code of the following
class.
Expand Down
85 changes: 49 additions & 36 deletions pymbolic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,54 +24,60 @@
"""


from pymbolic.version import VERSION_TEXT as __version__ # noqa

from . import parser
from . import compiler
from functools import partial

from .mapper import evaluator
from .mapper import stringifier
from .mapper import dependency
from .mapper import substitutor
from .mapper import differentiator
from .mapper import distributor
from .mapper import flattener
from . import primitives
from pytools import module_getattr_for_deprecations

from .primitives import (Variable as var, # noqa: N813
from . import compiler, parser, primitives
from .compiler import compile
from .mapper import (
dependency,
differentiator,
distributor,
evaluator,
flattener,
stringifier,
substitutor,
)
from .mapper.differentiator import differentiate, differentiate as diff
from .mapper.distributor import distribute, distribute as expand
from .mapper.evaluator import evaluate, evaluate_kw
from .mapper.flattener import flatten
from .mapper.substitutor import substitute
from .parser import parse
from .primitives import ( # noqa: N813
ExpressionNode,
Variable,
Expression,
variables,
flattened_sum,
subscript,
Variable as var,
disable_subscript_by_getitem,
expr_dataclass,
flattened_product,
quotient,
flattened_sum,
linear_combination,
make_common_subexpression as cse,
make_sym_vector,
disable_subscript_by_getitem,
expr_dataclass,
quotient,
subscript,
variables,
)
from .parser import parse
from .mapper.evaluator import evaluate
from .mapper.evaluator import evaluate_kw
from .compiler import compile
from .mapper.substitutor import substitute
from .mapper.differentiator import differentiate as diff
from .mapper.differentiator import differentiate
from .mapper.distributor import distribute as expand
from .mapper.distributor import distribute
from .mapper.flattener import flatten
from .typing import NumberT, ScalarT, ArithmeticExpressionT, ExpressionT, BoolT
from .typing import (
ArithmeticExpression,
Bool,
Expression,
Expression as _TypingExpression,
Number,
Scalar,
)
from pymbolic.version import VERSION_TEXT as __version__ # noqa


__all__ = (
"ArithmeticExpressionT",
"BoolT",
"ArithmeticExpression",
"Bool",
"Expression",
"ExpressionT",
"NumberT",
"ScalarT",
"ExpressionNode",
"Number",
"Scalar",
"Variable",
"compile",
"compiler",
Expand Down Expand Up @@ -105,3 +111,10 @@
"var",
"variables",
)

__getattr__ = partial(module_getattr_for_deprecations, __name__, {
"ExpressionT": ("pymbolic.typing.Expression", _TypingExpression, 2026),
"ArithmeticExpressionT": ("ArithmeticExpression", ArithmeticExpression, 2026),
"BoolT": ("Bool", Bool, 2026),
"ScalarT": ("Scalar", Scalar, 2026),
})
2 changes: 1 addition & 1 deletion pymbolic/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def tag_common_subexpressions(exprs):
get_key = NormalizedKeyGetter()
ucm = UseCountMapper(get_key)

if isinstance(exprs, prim.Expression):
if isinstance(exprs, prim.ExpressionNode):
raise TypeError("exprs should be an iterable of expressions")

for expr in exprs:
Expand Down
6 changes: 3 additions & 3 deletions pymbolic/geometric_algebra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from pytools import memoize, memoize_method

from pymbolic.primitives import expr_dataclass, is_zero
from pymbolic.typing import ArithmeticExpressionT, T
from pymbolic.typing import ArithmeticExpression, T


__doc__ = """
Expand Down Expand Up @@ -293,7 +293,7 @@ def get_euclidean_space(n: int) -> Space:
# }}}


CoeffT = TypeVar("CoeffT", bound=ArithmeticExpressionT)
CoeffT = TypeVar("CoeffT", bound=ArithmeticExpression)


# {{{ blade product weights
Expand Down Expand Up @@ -428,7 +428,7 @@ def _cast_to_mv(obj: Any, space: Space) -> MultiVector:
class MultiVector(Generic[CoeffT]):
r"""An immutable multivector type. Its implementation follows [DFM].
It is pickleable, and not picky about what data is used as coefficients.
It supports :class:`pymbolic.primitives.Expression` objects of course,
It supports :class:`pymbolic.primitives.ExpressionNode` objects of course,
but it can take just about any other scalar-ish coefficients.
.. autoattribute:: data
Expand Down
10 changes: 6 additions & 4 deletions pymbolic/geometric_algebra/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,23 @@
PREC_NONE,
StringifyMapper as StringifyMapperBase,
)
from pymbolic.primitives import Expression
from pymbolic.primitives import ExpressionNode


class IdentityMapper(IdentityMapperBase[P]):
def map_nabla(
self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs) -> Expression:
self, expr: prim.Nabla, *args: P.args, **kwargs: P.kwargs
) -> ExpressionNode:
return expr

def map_nabla_component(self,
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs) -> Expression:
expr: prim.NablaComponent, *args: P.args, **kwargs: P.kwargs
) -> ExpressionNode:
return expr

def map_derivative_source(self,
expr: prim.DerivativeSource, *args: P.args, **kwargs: P.kwargs
) -> Expression:
) -> ExpressionNode:
operand = self.rec(expr.operand, *args, **kwargs)
if operand is expr.operand:
return expr
Expand Down
8 changes: 4 additions & 4 deletions pymbolic/geometric_algebra/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from collections.abc import Hashable
from typing import ClassVar

from pymbolic.primitives import Expression, Variable, expr_dataclass
from pymbolic.typing import ExpressionT
from pymbolic.primitives import ExpressionNode, Variable, expr_dataclass
from pymbolic.typing import Expression


class MultiVectorVariable(Variable):
Expand All @@ -39,7 +39,7 @@ class MultiVectorVariable(Variable):

# {{{ geometric calculus

class _GeometricCalculusExpression(Expression):
class _GeometricCalculusExpression(ExpressionNode):
def stringifier(self):
from pymbolic.geometric_algebra.mapper import StringifyMapper
return StringifyMapper
Expand All @@ -58,7 +58,7 @@ class Nabla(_GeometricCalculusExpression):

@expr_dataclass()
class DerivativeSource(_GeometricCalculusExpression):
operand: ExpressionT
operand: Expression
nabla_id: Hashable


Expand Down
6 changes: 3 additions & 3 deletions pymbolic/interop/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

import pymbolic.primitives as p
from pymbolic.mapper import CachedMapper
from pymbolic.typing import ExpressionT
from pymbolic.typing import Expression


__doc__ = r'''
Expand Down Expand Up @@ -263,7 +263,7 @@ def map_variable(self, expr) -> ast.expr:
return ast.Name(id=expr.name)

def _map_multi_children_op(self,
children: tuple[ExpressionT, ...],
children: tuple[Expression, ...],
op_type: ast.operator) -> ast.expr:
rec_children = [self.rec(child) for child in children]
result = rec_children[-1]
Expand Down Expand Up @@ -435,7 +435,7 @@ def to_python_ast(expr) -> ast.expr:
return PymbolicToASTMapper()(expr)


def to_evaluatable_python_function(expr: ExpressionT,
def to_evaluatable_python_function(expr: Expression,
fn_name: str
) -> str:
"""
Expand Down
30 changes: 15 additions & 15 deletions pymbolic/interop/matchpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@
)

import pymbolic.primitives as p
from pymbolic.typing import ScalarT
from pymbolic.typing import Scalar as PbScalar


ExprT: TypeAlias = Expression
ConstantT = TypeVar("ConstantT")
ToMatchpyT = Callable[[p.Expression], ExprT]
FromMatchpyT = Callable[[ExprT], p.Expression]
ToMatchpyT = Callable[[p.ExpressionNode], ExprT]
FromMatchpyT = Callable[[ExprT], p.ExpressionNode]


_NOT_OPERAND_METADATA = {"not_an_operand": True}
Expand Down Expand Up @@ -95,7 +95,7 @@ def __lt__(self, other):


@op_dataclass
class Scalar(_Constant[ScalarT]):
class Scalar(_Constant[PbScalar]):
_mapper_method: str = "map_scalar"


Expand Down Expand Up @@ -360,11 +360,11 @@ def _get_operand_at_path(expr: PymbolicOp, path: tuple[int, ...]) -> PymbolicOp:
return result


def match(subject: p.Expression,
pattern: p.Expression,
def match(subject: p.ExpressionNode,
pattern: p.ExpressionNode,
to_matchpy_expr: ToMatchpyT | None = None,
from_matchpy_expr: FromMatchpyT | None = None
) -> Iterator[Mapping[str, p.Expression | ScalarT]]:
) -> Iterator[Mapping[str, p.ExpressionNode | PbScalar]]:
from matchpy import Pattern, match

from .tofrom import FromMatchpyExpressionMapper, ToMatchpyExpressionMapper
Expand All @@ -383,12 +383,12 @@ def match(subject: p.Expression,
for name, expr in subst.items()}


def match_anywhere(subject: p.Expression,
pattern: p.Expression,
def match_anywhere(subject: p.ExpressionNode,
pattern: p.ExpressionNode,
to_matchpy_expr: ToMatchpyT | None = None,
from_matchpy_expr: FromMatchpyT | None = None
) -> Iterator[tuple[Mapping[str, p.Expression | ScalarT],
p.Expression | ScalarT]
) -> Iterator[tuple[Mapping[str, p.ExpressionNode | PbScalar],
p.ExpressionNode | PbScalar]
]:
from matchpy import Pattern, match_anywhere

Expand All @@ -409,8 +409,8 @@ def match_anywhere(subject: p.Expression,
from_matchpy_expr(_get_operand_at_path(m_subject, path)))


def make_replacement_rule(pattern: p.Expression,
replacement: Callable[..., p.Expression],
def make_replacement_rule(pattern: p.ExpressionNode,
replacement: Callable[..., p.ExpressionNode],
to_matchpy_expr: ToMatchpyT | None = None,
from_matchpy_expr: FromMatchpyT | None = None
) -> ReplacementRule:
Expand All @@ -437,11 +437,11 @@ def make_replacement_rule(pattern: p.Expression,
from_matchpy_expr))


def replace_all(expression: p.Expression,
def replace_all(expression: p.ExpressionNode,
rules: Iterable[ReplacementRule],
to_matchpy_expr: ToMatchpyT | None = None,
from_matchpy_expr: FromMatchpyT | None = None
) -> p.Expression | tuple[p.Expression, ...]:
) -> p.ExpressionNode | tuple[p.ExpressionNode, ...]:
import collections.abc as abc

from matchpy import replace_all
Expand Down
5 changes: 3 additions & 2 deletions pymbolic/interop/matchpy/tofrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pymbolic.primitives as p
from pymbolic.interop.matchpy.mapper import Mapper as BaseMatchPyMapper
from pymbolic.mapper import Mapper as BasePymMapper
from pymbolic.typing import Scalar as PbScalar


# {{{ to matchpy
Expand Down Expand Up @@ -117,7 +118,7 @@ def map_star_wildcard(self, expr: p.StarWildcard) -> m.Wildcard:
# {{{ from matchpy

class FromMatchpyExpressionMapper(BaseMatchPyMapper):
def map_scalar(self, expr: m.Scalar) -> m.ScalarT:
def map_scalar(self, expr: m.Scalar) -> PbScalar:
return expr.value

def map_variable(self, expr: m.Variable) -> p.Variable:
Expand Down Expand Up @@ -200,7 +201,7 @@ def map_if(self, expr: m.If) -> p.If:

@dataclass(frozen=True, eq=True)
class ToFromReplacement:
f: Callable[..., p.Expression]
f: Callable[..., p.ExpressionNode]
to_matchpy_expr: m.ToMatchpyT
from_matchpy_expr: m.FromMatchpyT

Expand Down
Loading

0 comments on commit 31dbd39

Please sign in to comment.