Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add TaggedExpression #688

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,7 @@ information provided. Now we will count the operations:

>>> op_map = lp.get_op_map(knl, subgroup_size=32)
>>> print(op_map)
Op(np:dtype('float32'), add, subgroup, "stats_knl"): ...
Op(np:dtype('float32'), add, subgroup, "stats_knl", None): ...

Each line of output will look roughly like::

Expand Down Expand Up @@ -1628,7 +1628,7 @@ together into keys containing only the specified fields:

>>> op_map_dtype = op_map.group_by('dtype')
>>> print(op_map_dtype)
Op(np:dtype('float32'), None, None): ...
Op(np:dtype('float32'), None, None, None): ...
<BLANKLINE>
>>> f32op_count = op_map_dtype[lp.Op(dtype=np.float32)
... ].eval_with_dict(param_dict)
Expand Down
22 changes: 17 additions & 5 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,10 +636,14 @@ class Op(ImmutableRecord):

A :class:`str` representing the kernel name where the operation occurred.

.. attribute:: tags

A :class:`frozenset` of tags to the operation.

"""

def __init__(self, dtype=None, name=None, count_granularity=None,
kernel_name=None):
kernel_name=None, tags=None):
if count_granularity not in CountGranularity.ALL+[None]:
raise ValueError("Op.__init__: count_granularity '%s' is "
"not allowed. count_granularity options: %s"
Expand All @@ -651,15 +655,17 @@ def __init__(self, dtype=None, name=None, count_granularity=None,

super().__init__(dtype=dtype, name=name,
count_granularity=count_granularity,
kernel_name=kernel_name)
kernel_name=kernel_name,
tags=tags)

def __repr__(self):
# Record.__repr__ overridden for consistent ordering and conciseness
if self.kernel_name is not None:
return (f"Op({self.dtype}, {self.name}, {self.count_granularity},"
f' "{self.kernel_name}")')
f' "{self.kernel_name}", {self.tags})')
else:
return f"Op({self.dtype}, {self.name}, {self.count_granularity})"
return f"Op({self.dtype}, {self.name}, " + \
f"{self.count_granularity}, {self.tags})"

# }}}

Expand Down Expand Up @@ -724,7 +730,7 @@ class MemAccess(ImmutableRecord):
work-group executes on a single compute unit with all work-items within
the work-group sharing local memory. A sub-group is an
implementation-dependent grouping of work-items within a work-group,
analagous to an NVIDIA CUDA warp.
analogous to an NVIDIA CUDA warp.

.. attribute:: kernel_name

Expand Down Expand Up @@ -922,6 +928,12 @@ def map_constant(self, expr):
map_tagged_variable = map_constant
map_variable = map_constant

def map_with_tag(self, expr):
opmap = self.rec(expr.expr)
for op in opmap.count_map:
op.tags = expr.tags
return opmap
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would overwrite tags of subexpressions that already have tags.

self.rec(expr.expr, expr.tags)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is 50adbc4 what you had in mind?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, generally.


def map_call(self, expr):
from loopy.symbolic import ResolvedFunction
assert isinstance(expr.function, ResolvedFunction)
Expand Down
46 changes: 46 additions & 0 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@
# {{{ mappers with support for loopy-specific primitives

class IdentityMapperMixin:
def map_with_tag(self, expr, *args, **kwargs):
new_expr = self.rec(expr.expr, *args, **kwargs)
return WithTag(expr.tags, new_expr)

def map_literal(self, expr, *args, **kwargs):
return expr

Expand Down Expand Up @@ -207,6 +211,12 @@ def map_common_subexpression_uncached(self, expr):


class WalkMapperMixin:
def map_with_tag(self, expr, *args, **kwargs):
if not self.visit(expr, *args, **kwargs):
return

self.rec(expr.expr, *args, **kwargs)

def map_literal(self, expr, *args, **kwargs):
self.visit(expr, *args, **kwargs)

Expand Down Expand Up @@ -273,6 +283,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase):


class CombineMapper(CombineMapperBase):
def map_with_tag(self, expr, *args, **kwargs):
return self.rec(expr.expr, *args, **kwargs)

def map_reduction(self, expr, *args, **kwargs):
return self.rec(expr.expr, *args, **kwargs)

Expand All @@ -298,6 +311,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase,


class StringifyMapper(StringifyMapperBase):
def map_with_tag(self, expr, *args):
from pymbolic.mapper.stringifier import PREC_NONE
return f"WithTag({expr.tags}, {self.rec(expr.expr, PREC_NONE)}"

def map_literal(self, expr, *args):
return expr.s

Expand Down Expand Up @@ -440,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs):
def map_loopy_function_identifier(self, expr, *args, **kwargs):
return set()

def map_with_tag(self, expr, *args, **kwargs):
deps = self.rec(expr.expr, *args, **kwargs)
return deps

def map_sub_array_ref(self, expr, *args, **kwargs):
deps = self.rec(expr.subscript, *args, **kwargs)
return deps - set(expr.swept_inames)
Expand Down Expand Up @@ -712,6 +733,31 @@ def copy(self, *, name=None, tags=None):
mapper_method = intern("map_tagged_variable")


class WithTag(LoopyExpressionBase):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class WithTag(LoopyExpressionBase):
class TaggedExpression(LoopyExpressionBase):

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 45e14e7

"""
Represents a frozenset of tags attached to an :attr:`expr`.
"""

init_arg_names = ("tags", "expr")

def __init__(self, tags, expr):
self.tags = tags
self.expr = expr

def __getinitargs__(self):
return (self.tags, self.expr)

def get_hash(self):
return hash((self.__class__, self.tags, self.expr))

def is_equal(self, other):
return (other.__class__ == self.__class__
and other.tags == self.tags
and other.expr == self.expr)

mapper_method = intern("map_with_tag")


class Reduction(LoopyExpressionBase):
"""
Represents a reduction operation on :attr:`expr` across :attr:`inames`.
Expand Down
58 changes: 58 additions & 0 deletions test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,64 @@ def test_no_loop_ops():
assert f64_mul == 1


from pytools.tag import Tag


class MyCostTag1(Tag):
pass


class MyCostTag2(Tag):
pass


class MyCostTagSum(Tag):
pass


def test_op_with_tag():
from loopy.symbolic import WithTag
from pymbolic.primitives import Subscript, Variable, Sum

n = 500

knl = lp.make_kernel(
"{[i]: 0<=i<n}",
[
lp.Assignment("c[i]", WithTag(frozenset((MyCostTagSum(),)),
Sum(
(WithTag(frozenset((MyCostTag1(),)),
Subscript(Variable("a"), Variable("i"))),
WithTag(frozenset((MyCostTag2(),)),
Subscript(Variable("b"), Variable("i")))))))
])

knl = lp.add_dtypes(knl, {"a": np.float64, "b": np.float64})

params = {"n": n}

op_map = lp.get_op_map(knl, subgroup_size=32)

f64_add = op_map.filter_by(dtype=[np.float64]).eval_and_sum(params)
assert f64_add == n

f64_add = op_map.filter_by(
tags=[frozenset((MyCostTagSum(),))]).eval_and_sum(params)
assert f64_add == n

f64_add = op_map.filter_by(
tags=[frozenset((MyCostTag1(),))]).eval_and_sum(params)
assert f64_add == 0

f64_add = op_map.filter_by(
tags=[frozenset((MyCostTag2(),))]).eval_and_sum(params)
assert f64_add == 0

f64_add = op_map.filter_by(
tags=[frozenset((MyCostTag2(), MyCostTagSum()))]).eval_and_sum(params)
assert f64_add == 0


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down