Add a pass to flag arrays only differing in tags #420

Open
inducer opened this issue Mar 8, 2023 · 8 comments

Comments

@inducer
Owner

inducer commented Mar 8, 2023

@majosm reported a situation where a large compile-time difference was observed depending on whether an array had a tag. This is plausible, as even just differing tags can lead to arrays not being viewed as equal and therefore failing to be merged in common subexpression elimination. This means that this value (and all its dependents, if both versions are used) is computed multiple times. If the pattern occurs multiple times, this could lead to exponential growth of DAG size.

All of this is almost certainly unintended, so we should at least warn about it (if not raise an error). What I have in mind is a pass that strips all tags and flags the situation in which that process produces multiple versions of the same array that compare equal after stripping.
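
A minimal sketch of such a pass, assuming a toy `Node` class in place of pytato's array types (`Node` and `find_tag_only_duplicates` are hypothetical names, not pytato API). It strips tags bottom-up in a single memoized traversal and records every pair of distinct nodes that share the same tag-free form:

```python
from dataclasses import dataclass, replace
from typing import Dict, FrozenSet, List, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for an array expression node (not pytato's API)."""
    op: str
    children: Tuple["Node", ...] = ()
    tags: FrozenSet[str] = frozenset()


def find_tag_only_duplicates(root: Node) -> List[Tuple[Node, Node]]:
    """Return pairs of distinct nodes that compare equal once tags are stripped."""
    stripped_cache: Dict[Node, Node] = {}  # original node -> tag-free copy
    first_seen: Dict[Node, Node] = {}      # tag-free copy -> first original seen
    collisions: List[Tuple[Node, Node]] = []

    def strip(node: Node) -> Node:
        if node in stripped_cache:
            return stripped_cache[node]
        # Rebuild the node without tags, on top of tag-free children.
        stripped = Node(node.op, tuple(strip(c) for c in node.children))
        stripped_cache[node] = stripped
        other = first_seen.setdefault(stripped, node)
        if other != node:
            # Same tag-free form, different original: tags are the only difference.
            collisions.append((other, node))
        return stripped

    strip(root)
    return collisions


x, y = Node("x"), Node("y")
tmp = Node("add", (x, y))
tmp1 = replace(tmp, tags=frozenset({"ImplStored"}))
out = Node("add", (Node("mul2", (tmp,)), Node("mul3", (tmp1,))))

for a, b in find_tag_only_duplicates(out):
    print(a.op, sorted(a.tags), sorted(b.tags))
```

In this sketch any dependents that differ only through a tagged ancestor would be reported as well. A real pass would likely also need to handle per-axis tags and avoid deep recursion on large DAGs.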

@majosm
Collaborator

majosm commented Mar 14, 2023

> What I have in mind is a pass that strips all tags and flags the situation in which that process produces multiple versions of the same array that compare equal after stripping.

Where might be a good place to insert this pass? (Not very familiar with the overall structure of pytato yet.)

@inducer
Owner Author

inducer commented Mar 14, 2023

I think it would come down to adding a function in analysis that perhaps uses a custom WalkMapper for the traversal. That function could then be called from somewhere in the appropriate array context, to actually perform the check.

@kaushikcfd
Collaborator

Here's one way to do it:

(py311_env) $ cat remove_tags_and_merge.py 
import pytato as pt
import numpy as np
from pytools.tag import Tag


def remove_tag_t(expr, tag_t):
    def _rec_remove_tag_t(expr):
        # Only arrays carry tags; everything else passes through unchanged.
        if isinstance(expr, pt.Array):
            if tags_to_remove := expr.tags_of_type(tag_t):
                return expr.without_tags(tags_to_remove,
                                         verify_existence=False)
        return expr

    expr = pt.transform.map_and_copy(expr, _rec_remove_tag_t)
    return pt.transform.BranchMorpher()(expr)


x = pt.make_placeholder("x", (10, 4), np.float64)
y = pt.make_placeholder("y", (10, 4), np.float64)

tmp = x + y
tmp1 = tmp.tagged(pt.tags.ImplStored())

out = 2*tmp + 3*tmp1

print(pt.analysis.get_num_nodes(out))
print(pt.analysis.get_num_nodes(remove_tag_t(out, tag_t=Tag)))
(py311_env) $ python remove_tags_and_merge.py 
8
7

@kaushikcfd
Collaborator

> This is plausible, as even just differing tags can lead to arrays not being viewed as equal and therefore failing to be merged in common subexpression elimination

This is true, but if it's just one node differing in the tag, then something else is wrong here as the subexpressions for the diverging nodes would still be the same and the relative difference in runtime/compile time should have been insignificant.

@inducer
Owner Author

inducer commented Mar 15, 2023

> This is true, but if it's just one node differing in the tag, then something else is wrong here as the subexpressions for the diverging nodes would still be the same and the relative difference in runtime/compile time should have been insignificant.

Are you sure? Wouldn't depending nodes necessarily also compare non-equal?
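
To illustrate the point about dependents, here is a small sketch using a hypothetical `Node` class with structural equality (an assumption for illustration, not pytato's types). Once one node differs by a tag, every node built on top of it also compares non-equal, so the whole downstream chain is duplicated:

```python
from dataclasses import dataclass, replace
from typing import FrozenSet, Set, Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical expression node; equality is structural and includes tags."""
    op: str
    children: Tuple["Node", ...] = ()
    tags: FrozenSet[str] = frozenset()


def unique_nodes(root: Node) -> Set[Node]:
    """Collect the distinct nodes reachable from root (what deduplication keeps)."""
    seen: Set[Node] = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(node.children)
    return seen


def expensive(a: Node) -> Node:
    """A hypothetical chain of three dependent operations built on top of a."""
    for op in ("sin", "exp", "sqrt"):
        a = Node(op, (a,))
    return a


x = Node("x")
base = Node("neg", (x,))
base_tagged = replace(base, tags=frozenset({"ImplStored"}))

# Both operands are the same expression, so the chain is shared.
same = Node("add", (expensive(base), expensive(base)))
# One operand starts from the tagged copy, so its entire chain is distinct.
split = Node("add", (expensive(base), expensive(base_tagged)))

print(len(unique_nodes(same)), len(unique_nodes(split)))
```

The predecessors (`x` here) are still shared, but the tagged node and all three of its successors are counted a second time.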

@inducer
Owner Author

inducer commented Mar 15, 2023

> Here's one way to do it:

Thanks for providing that! It's quick, but it's got a few downsides: It has quite a few traversals, and it doesn't explicitly identify the offending nodes.

@kaushikcfd
Collaborator

> Are you sure? Wouldn't depending nodes necessarily also compare non-equal?

Aah fair. I was only thinking of the predecessors and not the successors. Thanks for the correction!

> Thanks for providing that! It's quick, but it's got a few downsides: It has quite a few traversals, and it doesn't explicitly identify the offending nodes.

Yep, it's a starting point. However, extending it to cover the points you raise shouldn't take more than another 50 lines, I think :).

@kaushikcfd
Collaborator

kaushikcfd commented Mar 15, 2023

FWIW, this is more in line with what you suggested:

import pytato as pt
import numpy as np
from typing import Dict


class MyWalkMapper(pt.transform.CachedWalkMapper):
    def __init__(self):
        super().__init__()
        self.stripped_ary_to_ary: Dict[pt.Array, pt.Array] = {}

    def get_cache_key(self, expr):
        return id(expr)

    def post_visit(self, expr: pt.transform.ArrayOrNames):
        if isinstance(expr, pt.Array):
            from pytato.array import (_get_default_tags,
                                      _get_default_axes)
            tagless_expr = expr.copy(
                tags=_get_default_tags(),
                axes=_get_default_axes(expr.ndim))
            try:
                colliding_expr = self.stripped_ary_to_ary[tagless_expr]
            except KeyError:
                # First array with this tag-free form: remember it.
                self.stripped_ary_to_ary[tagless_expr] = expr
            else:
                if colliding_expr != expr:
                    raise ValueError(f"Arrays '{colliding_expr}' and '{expr}'"
                                     " are semantically the same array except"
                                     " for the attached metadata => will lead"
                                     " to inefficient generated code.")


x = pt.make_placeholder("x", (10, 4), np.float64)
y = pt.make_placeholder("y", (10, 4), np.float64)

tmp = x + y
tmp1 = tmp.tagged(pt.tags.ImplStored())

out = 2*tmp + 3*tmp1

MyWalkMapper()(out)
