Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 60 additions & 89 deletions aesara/tensor/basic_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,28 +1486,34 @@ def apply(self, fgraph):
@register_specialize("local_alloc_elemwise")
@local_optimizer([Elemwise])
def local_elemwise_alloc(fgraph, node):
"""
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-> elemwise(x, y.TensorType(BROADCAST CONDITION))

elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION))
-> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))

BROADCAST CONDITION: the condition is that the one input that are
not to be optimized to have the same broadcast pattern as the
output.

We can change the `Alloc` by a `DimShuffle` as the `Elemwise` already have
the shape info. The `DimShuffle` will be faster to exec.

TODO: Global optimizer that lifts the assert to the beginning of the graph?
TODO: Optimize all inputs when possible -- currently when all inputs have
an `Alloc` all but one is optimized.

r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.

`Alloc`\s are effectively a type of `Elemwise` operation
(e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
`Elemwise`\s of the `Alloc`'s first/value input (i.e. the value it
broadcasts).

In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
`Alloc`\s.

The rewrite essentially performs the following replacement:
``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
when ``y.shape`` for some input ``y`` (or the combined shapes of the
non-`Alloc`\s) is sufficient to maintain the same/correct output shape.

In its current form, it also explicitly accounts for `DimShuffle`\s of
`Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
introduces them as a canonicalization of `Alloc`\s with leading
broadcastable dimensions.
"""
if not isinstance(node.op, Elemwise):
return False

# Rewrite is only applicable when there are at least two inputs
if len(node.inputs) == 1:
return None

if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true.
Expand Down Expand Up @@ -1546,8 +1552,9 @@ def dimshuffled_alloc(i):
):
return False

# Search for input that we can use as a baseline for the dimensions.
assert_op_idx = -1
# Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
# baseline for the dimensions.
assert_op_idx = None
for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
Expand All @@ -1558,45 +1565,22 @@ def dimshuffled_alloc(i):
assert_op_idx = idx
break

# It may be the case that only `Alloc` and `DimShuffle` of `Alloc` exist.
if assert_op_idx < 0:
# We want to optimize as many `Alloc`s as possible. When
# there is more than one then do all but one. number of
# inputs with `Alloc` or `DimShuffle` `Alloc`
l2 = [
i
for i in node.inputs
if (i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)))
]
# If only one `Alloc` or `DimShuffle` `Alloc`, it is the one we
# will use for the shape. So no `Alloc` would be removed.
if len(l2) > 1:
# One contains inputs with `Alloc` or `DimShuffle` `Alloc`
# only. Its length will always be at least one, as we
# checked that before
l = [
idx
for idx, i in enumerate(node.inputs)
if i.broadcastable == node.outputs[0].broadcastable
]
assert_op_idx = l[0] # The first one is as good as any to use.
else:
# Nothing would be optimized!
return False
# If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
if assert_op_idx is None:
for idx, i in enumerate(node.inputs):
if (i.type.broadcastable == node.outputs[0].type.broadcastable) and (
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
):
assert_op_idx = idx
break

assert_op_in = node.inputs[assert_op_idx]
cmp_op = assert_op_in
new_i = []
same_shape = fgraph.shape_feature.same_shape
for i in node.inputs:
# Remove `Alloc`
if (
i.owner
and isinstance(i.owner.op, Alloc)
and not i.owner.inputs[0].type.is_super(i.owner.outputs[0].type)
):
# when `i.owner.inputs[0].type.is_super(i.owner.outputs[0].type)` we
# will remove that `Alloc` later
if i.owner and isinstance(i.owner.op, Alloc):
assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape
Expand All @@ -1610,7 +1594,16 @@ def dimshuffled_alloc(i):
cond.append(eq(i_shp, cmp_shp))
if cond:
assert_op_in = assert_op(assert_op_in, *cond)
new_i.append(i.owner.inputs[0])
alloc_input = i.owner.inputs[0]
if alloc_input.ndim != i.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
nb_dim_to_add = i.ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
copy_stack_trace(i, alloc_input)
new_i.append(alloc_input)

# Remove `Alloc` in `DimShuffle`
elif i.owner and dimshuffled_alloc(i):
Expand All @@ -1626,28 +1619,30 @@ def dimshuffled_alloc(i):
assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim:
# The `Alloc` can add dimension to the value
# We add a `DimShuffle` to add them.
# We let later optimization merge the multiple `DimShuffle`
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
# We let later optimizations merge the nested `DimShuffle`s
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)

# We need to keep the `DimShuffle`. It could swap axes or
# We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere.
r_i = i.owner.op(alloc_input)

# Copy stack trace from i to new_i
copy_stack_trace(i, r_i)
new_i.append(r_i)

else:
new_i.append(i)
new_i[assert_op_idx] = assert_op_in

ret = node.op(*new_i, return_list=True)
# If this assert is triggered, it means we are recreating an equivalent graph
# which would result in a cyclical merge optimization.
if all(new is old for new, old in zip(new_i, node.inputs)):
return

# Copy over stack trace from previous outputs to new outputs.
ret = node.op(*new_i, return_list=True)
copy_stack_trace(node.outputs, ret)
return ret

Expand Down Expand Up @@ -1809,49 +1804,25 @@ def local_useless_alloc(fgraph, node):
@register_specialize
@register_stabilize
@register_canonicalize
@local_optimizer([alloc])
def local_canonicalize_alloc(fgraph, node):
"""If the input type is the same as the output type (dtype and broadcast)
there is no change in the shape of the input. So this is just a simple copy
of the input. This is not needed. (as local_useless_alloc)

Also, it will canonicalize alloc by creating Dimshuffle after the
alloc to introduce the dimensions of constant size 1.

See https://github.com/Theano/Theano/issues/4072 to know why this
is needed.

"""
@local_optimizer([Alloc])
def local_alloc_sink_dimshuffle(fgraph, node):
r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s."""
op = node.op
if not isinstance(op, Alloc):
return False

inp = node.inputs[0]
output = node.outputs[0]

# Check if dtype and broadcast remain the same.
if (
inp.type.dtype == output.type.dtype
and inp.type.broadcastable == output.type.broadcastable
):
# We don't need to copy over any stack traces here
return [inp]

# Allow local_merge_alloc to do its work first
clients = fgraph.clients[output]
for client, i in clients:
if client != "output" and isinstance(client.op, Alloc):
return

# Check if alloc adds a broadcastable dimension with shape 1.

output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape) - inp.ndim):
if extract_constant(output_shape[i], only_process_constants=True) == 1:
num_dims_with_size_1_added_to_left += 1
else:
break

new_output_shape = output_shape[num_dims_with_size_1_added_to_left:]
if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= inp.ndim:
if (
Expand Down
78 changes: 71 additions & 7 deletions tests/tensor/test_basic_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
ShapeFeature,
apply_rebroadcast_opt,
assert_op,
local_canonicalize_alloc,
local_alloc_sink_dimshuffle,
local_dimshuffle_lift,
local_merge_alloc,
local_reshape_to_dimshuffle,
Expand Down Expand Up @@ -1423,8 +1423,7 @@ def test_basic_fill(self):

# The optimization 'local_fill_to_alloc' should call at.alloc,
# which should return x and not alloc(x, ...)
mode = mode_opt.excluding("local_canonicalize_alloc")
f = function([x], [y], mode=mode)
f = function([x], [y], mode=mode_opt.including("local_fill_to_alloc"))
assert not any(
[isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
)
Expand All @@ -1433,9 +1432,12 @@ def test_basic_tile(self):
x = matrix("x")
y = at.tile(x, (1,) * 2)

mode = mode_opt.including("local_canonicalize_alloc")
mode = mode_opt.including(
"local_dimshuffle_lift",
"local_useless_dimshuffle_in_reshape",
"local_alloc_sink_dimshuffle",
)
f = function([x], [y], mode=mode)
[node.op.__class__ for node in f.maker.fgraph.toposort()]

assert not any(
[isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
Expand All @@ -1454,7 +1456,7 @@ def test_useless_alloc_with_shape_one(self, x, has_alloc):
g = FunctionGraph(outputs=[x])
assert any(isinstance(node.op, Alloc) for node in g.toposort())

alloc_lift = out2in(local_canonicalize_alloc)
alloc_lift = out2in(local_alloc_sink_dimshuffle)
alloc_lift.optimize(g)

if has_alloc:
Expand Down Expand Up @@ -3217,7 +3219,7 @@ def test_local_Unique_Alloc_lift(
# The remaining exclusions simply allow us to perform the check below that
# makes sure the original `Alloc` is present in our reference (sub)graph.
opt_mode = default_mode.excluding(
"local_useless_alloc", "local_canonicalize_alloc", "local_Unique_Alloc_lift"
"local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift"
)
y_fn = function([x], [y, y_opt], mode=opt_mode)
# Make sure that the original `Alloc` is used to compute the reference `y`
Expand Down Expand Up @@ -3505,3 +3507,65 @@ def test_Shape_i_canonicalize():
assert isinstance(y_opt.owner.op, Shape_i)
assert y_opt.owner.op.i == 0
assert y_opt.owner.inputs[0] == x


@pytest.mark.parametrize(
"expr, x_shape, y_shape",
[
pytest.param(
lambda x, y: at.mul(y, at.alloc(1, x)),
(),
(),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
(lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)), (15, 2), (15, 2)),
(lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y), (15, 2), (2, 15)),
(lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)),
(
lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y),
(15, 2),
(2, 15),
),
],
)
def test_local_elemwise_alloc(expr, x_shape, y_shape):
x = at.tensor("int64", (False,) * len(x_shape))
y = at.tensor("int64", (False,) * len(y_shape))
z = expr(x, y)

z_opt = aesara.function(
[x, y],
z,
mode=get_default_mode().including("local_elemwise_alloc"),
on_unused_input="ignore",
)

assert not any(isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort())

z_no_opt = aesara.function(
[x, y],
z,
mode=get_default_mode().excluding("local_elemwise_alloc"),
on_unused_input="ignore",
)

x_val = np.arange(np.prod(x_shape), dtype=np.int64).reshape(x_shape)
y_val = np.arange(np.prod(y_shape), dtype=np.int64).reshape(y_shape)

res = z_opt(x_val, y_val)
exp_res = z_no_opt(x_val, y_val)
assert np.array_equal(res, exp_res)


def test_local_elemwise_alloc_single_input():
# Test that rewrite is not triggered when there is only one Alloc in an Elemwise
x = at.matrix("x")
z = at.exp(at.alloc(x, 15, 1))

z_fg = FunctionGraph(outputs=[z], copy_inputs=False, features=[ShapeFeature()])

z_opt_fg = optimize_graph(z_fg, clone=False, include=["local_elemwise_alloc"])
assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes)
2 changes: 1 addition & 1 deletion tests/tensor/test_subtensor_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,7 +1860,7 @@ def test_remove_alloc_wo_dimshuffle(self):
# Exclude local_useless_alloc, since it does not introduce
# assert in all the same cases.
self.fast_run_mode = self.fast_run_mode.excluding(
"local_useless_alloc", "local_canonicalize_alloc"
"local_useless_alloc", "local_alloc_sink_dimshuffle"
)
# No optimization on alloc
func = function(
Expand Down