Skip to content

Commit

Permalink
geometric_algebra: type MultiVector.map and componentwise
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Nov 13, 2024
1 parent 045d2dc commit f32859d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
12 changes: 6 additions & 6 deletions pymbolic/geometric_algebra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@
"""

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Generic, TypeVar, cast

import numpy as np

from pytools import memoize, memoize_method

from pymbolic.primitives import expr_dataclass
from pymbolic.typing import ArithmeticExpressionT
from pymbolic.primitives import expr_dataclass, is_zero
from pymbolic.typing import ArithmeticExpressionT, T


__doc__ = """
Expand Down Expand Up @@ -1104,7 +1104,7 @@ def as_vector(self, dtype=None):

# {{{ helper functions

def map(self, f):
def map(self, f: Callable[[CoeffT], CoeffT]) -> MultiVector[CoeffT]:
"""Return a new :class:`MultiVector` with coefficients mapped by
function *f*, which takes a single coefficient as input and returns the
new coefficient.
Expand All @@ -1127,14 +1127,14 @@ def map(self, f):
# }}}


def componentwise(f, expr):
def componentwise(f: Callable[[CoeffT], CoeffT], expr: T) -> T:
"""Apply function *f* componentwise to object arrays and
:class:`MultiVector` instances. *expr* is also allowed to
be a scalar.
"""

if isinstance(expr, MultiVector):
return expr.map(f)
return cast(T, expr.map(f))

from pytools.obj_array import obj_array_vectorize
return obj_array_vectorize(f, expr)
Expand Down
4 changes: 3 additions & 1 deletion pymbolic/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,9 @@ def map_multivector(self,
expr: MultiVector[ArithmeticExpressionT],
*args: P.args, **kwargs: P.kwargs
) -> ExpressionT:
return expr.map(lambda ch: self.rec(ch, *args, **kwargs))
# True fact: MultiVectors aren't expressions
return expr.map(lambda ch: cast(ArithmeticExpressionT,
self.rec(ch, *args, **kwargs))) # type: ignore[return-value]

def map_common_subexpression(self,
expr: p.CommonSubexpression,
Expand Down
2 changes: 1 addition & 1 deletion pymbolic/mapper/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def map_numpy_array(self, expr: np.ndarray) -> ResultT:
return result # type: ignore[return-value]

def map_multivector(self, expr: MultiVector) -> ResultT:
return expr.map(lambda ch: self.rec(ch))
return expr.map(lambda ch: self.rec(ch)) # type: ignore[return-value]

def map_common_subexpression_uncached(self, expr: p.CommonSubexpression) -> ResultT:
return self.rec(expr.child)
Expand Down

0 comments on commit f32859d

Please sign in to comment.