diff --git a/aesara/tensor/basic_opt.py b/aesara/tensor/basic_opt.py
index d0b4e4f4bb..6d92396412 100644
--- a/aesara/tensor/basic_opt.py
+++ b/aesara/tensor/basic_opt.py
@@ -1486,28 +1486,34 @@ def apply(self, fgraph):
 @register_specialize("local_alloc_elemwise")
 @local_optimizer([Elemwise])
 def local_elemwise_alloc(fgraph, node):
-    """
-    elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-      -> elemwise(x, y.TensorType(BROADCAST CONDITION))
-
-    elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION))
-      -> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))
-
-    BROADCAST CONDITION: the condition is that the one input that are
-    not to be optimized to have the same broadcast pattern as the
-    output.
-
-    We can change the `Alloc` by a `DimShuffle` as the `Elemwise` already have
-    the shape info. The `DimShuffle` will be faster to exec.
-
-    TODO: Global optimizer that lifts the assert to the beginning of the graph?
-    TODO: Optimize all inputs when possible -- currently when all inputs have
-    an `Alloc` all but one is optimized.
-
+    r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
+
+    `Alloc`\s are effectively a type of `Elemwise` operation
+    (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
+    this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
+    `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
+    broadcasts).
+
+    In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
+    `Alloc`\s.
+
+    The rewrite essentially performs the following replacement:
+    ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
+    when ``y.shape`` for some input ``y`` (or the combined shapes of the
+    non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
+
+    In its current form, it also explicitly accounts for `DimShuffle`\s of
+    `Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
+    introduces them as a canonicalization of `Alloc`\s with leading
+    broadcastable dimensions.
     """
     if not isinstance(node.op, Elemwise):
         return False
 
+    # Rewrite is only applicable when there are at least two inputs
+    if len(node.inputs) == 1:
+        return None
+
     if len(node.outputs) > 1:
         # Ensure all outputs have the same broadcast pattern
         # This is a supposition that I'm not sure is always true.
@@ -1546,8 +1552,9 @@ def dimshuffled_alloc(i):
     ):
         return False
 
-    # Search for input that we can use as a baseline for the dimensions.
-    assert_op_idx = -1
+    # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
+    # baseline for the dimensions.
+    assert_op_idx = None
     for idx, i in enumerate(node.inputs):
         if i.type.broadcastable == node.outputs[0].type.broadcastable:
             # Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
@@ -1558,31 +1565,14 @@ def dimshuffled_alloc(i):
                 assert_op_idx = idx
                 break
 
-    # It may be the case that only `Alloc` and `DimShuffle` of `Alloc` exist.
-    if assert_op_idx < 0:
-        # We want to optimize as many `Alloc`s as possible. When
-        # there is more than one then do all but one. number of
-        # inputs with `Alloc` or `DimShuffle` `Alloc`
-        l2 = [
-            i
-            for i in node.inputs
-            if (i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)))
-        ]
-        # If only one `Alloc` or `DimShuffle` `Alloc`, it is the one we
-        # will use for the shape. So no `Alloc` would be removed.
-        if len(l2) > 1:
-            # One contains inputs with `Alloc` or `DimShuffle` `Alloc`
-            # only. Its length will always be at least one, as we
-            # checked that before
-            l = [
-                idx
-                for idx, i in enumerate(node.inputs)
-                if i.broadcastable == node.outputs[0].broadcastable
-            ]
-            assert_op_idx = l[0]  # The first one is as good as any to use.
-        else:
-            # Nothing would be optimized!
-            return False
+    # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
+    if assert_op_idx is None:
+        for idx, i in enumerate(node.inputs):
+            if (i.type.broadcastable == node.outputs[0].type.broadcastable) and (
+                i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
+            ):
+                assert_op_idx = idx
+                break
 
     assert_op_in = node.inputs[assert_op_idx]
     cmp_op = assert_op_in
@@ -1590,13 +1580,7 @@ def dimshuffled_alloc(i):
     same_shape = fgraph.shape_feature.same_shape
     for i in node.inputs:
         # Remove `Alloc`
-        if (
-            i.owner
-            and isinstance(i.owner.op, Alloc)
-            and not i.owner.inputs[0].type.is_super(i.owner.outputs[0].type)
-        ):
-            # when `i.owner.inputs[0].type.is_super(i.owner.outputs[0].type)` we
-            # will remove that `Alloc` later
+        if i.owner and isinstance(i.owner.op, Alloc):
             assert i.type.ndim == cmp_op.ndim
             if config.experimental__local_alloc_elemwise_assert:
                 get_shape = fgraph.shape_feature.get_shape
@@ -1610,7 +1594,16 @@ def dimshuffled_alloc(i):
                         cond.append(eq(i_shp, cmp_shp))
                 if cond:
                     assert_op_in = assert_op(assert_op_in, *cond)
-            new_i.append(i.owner.inputs[0])
+            alloc_input = i.owner.inputs[0]
+            if alloc_input.ndim != i.ndim:
+                # The `Alloc` can add dimensions to the value.
+                # We replace those cases with a `DimShuffle` here.
+                nb_dim_to_add = i.ndim - alloc_input.ndim
+                alloc_input = alloc_input.dimshuffle(
+                    ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
+                )
+            copy_stack_trace(i, alloc_input)
+            new_i.append(alloc_input)
 
         # Remove `Alloc` in `DimShuffle`
         elif i.owner and dimshuffled_alloc(i):
@@ -1626,28 +1619,30 @@ def dimshuffled_alloc(i):
                     assert_op_in = assert_op(assert_op_in, *assert_cond)
             alloc_input = i.owner.inputs[0].owner.inputs[0]
             if alloc_input.ndim != i.owner.inputs[0].ndim:
-                # The `Alloc` can add dimension to the value
-                # We add a `DimShuffle` to add them.
-                # We let later optimization merge the multiple `DimShuffle`
+                # The `Alloc` can add dimensions to the value.
+                # We replace those cases with a `DimShuffle` here.
+                # We let later optimizations merge the nested `DimShuffle`s
                 nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
                 alloc_input = alloc_input.dimshuffle(
                     ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
                 )
 
-            # We need to keep the `DimShuffle`. It could swap axes or
+            # We need to keep the old `DimShuffle`. It could swap axes or
             # add dimensions anywhere.
             r_i = i.owner.op(alloc_input)
-
-            # Copy stack trace from i to new_i
             copy_stack_trace(i, r_i)
             new_i.append(r_i)
+
         else:
             new_i.append(i)
 
     new_i[assert_op_idx] = assert_op_in
 
-    ret = node.op(*new_i, return_list=True)
+    # If this assert is triggered, it means we are recreating an equivalent graph
+    # which would result in a cyclical merge optimization.
+    if all(new is old for new, old in zip(new_i, node.inputs)):
+        return
 
-    # Copy over stack trace from previous outputs to new outputs.
+    ret = node.op(*new_i, return_list=True)
     copy_stack_trace(node.outputs, ret)
     return ret
@@ -1809,19 +1804,9 @@ def local_useless_alloc(fgraph, node):
 @register_specialize
 @register_stabilize
 @register_canonicalize
-@local_optimizer([alloc])
-def local_canonicalize_alloc(fgraph, node):
-    """If the input type is the same as the output type (dtype and broadcast)
-    there is no change in the shape of the input. So this is just a simple copy
-    of the input. This is not needed. (as local_useless_alloc)
-
-    Also, it will canonicalize alloc by creating Dimshuffle after the
-    alloc to introduce the dimensions of constant size 1.
-
-    See https://github.com/Theano/Theano/issues/4072 to know why this
-    is needed.
-
-    """
+@local_optimizer([Alloc])
+def local_alloc_sink_dimshuffle(fgraph, node):
+    r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s."""
     op = node.op
     if not isinstance(op, Alloc):
         return False
@@ -1829,22 +1814,7 @@ def local_canonicalize_alloc(fgraph, node):
     inp = node.inputs[0]
     output = node.outputs[0]
 
-    # Check if dtype and broadcast remain the same.
-    if (
-        inp.type.dtype == output.type.dtype
-        and inp.type.broadcastable == output.type.broadcastable
-    ):
-        # We don't need to copy over any stack traces here
-        return [inp]
-
-    # Allow local_merge_alloc to do its work first
-    clients = fgraph.clients[output]
-    for client, i in clients:
-        if client != "output" and isinstance(client.op, Alloc):
-            return
-
     # Check if alloc adds a broadcastable dimension with shape 1.
-
     output_shape = node.inputs[1:]
     num_dims_with_size_1_added_to_left = 0
     for i in range(len(output_shape) - inp.ndim):
@@ -1852,6 +1822,7 @@ def local_canonicalize_alloc(fgraph, node):
             num_dims_with_size_1_added_to_left += 1
         else:
             break
+
     new_output_shape = output_shape[num_dims_with_size_1_added_to_left:]
     if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= inp.ndim:
         if (
diff --git a/tests/tensor/test_basic_opt.py b/tests/tensor/test_basic_opt.py
index 4ac08aa5b5..a7270e8709 100644
--- a/tests/tensor/test_basic_opt.py
+++ b/tests/tensor/test_basic_opt.py
@@ -40,7 +40,7 @@
     ShapeFeature,
     apply_rebroadcast_opt,
     assert_op,
-    local_canonicalize_alloc,
+    local_alloc_sink_dimshuffle,
    local_dimshuffle_lift,
     local_merge_alloc,
     local_reshape_to_dimshuffle,
@@ -1423,8 +1423,7 @@ def test_basic_fill(self):
 
         # The optimization 'locall_fill_to_alloc' should call at.alloc,
         # which should return x and not alloc(x, ...)
-        mode = mode_opt.excluding("local_canonicalize_alloc")
-        f = function([x], [y], mode=mode)
+        f = function([x], [y], mode=mode_opt.including("local_fill_to_alloc"))
         assert not any(
             [isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
         )
@@ -1433,9 +1432,12 @@ def test_basic_tile(self):
         x = matrix("x")
         y = at.tile(x, (1,) * 2)
 
-        mode = mode_opt.including("local_canonicalize_alloc")
+        mode = mode_opt.including(
+            "local_dimshuffle_lift",
+            "local_useless_dimshuffle_in_reshape",
+            "local_alloc_sink_dimshuffle",
+        )
         f = function([x], [y], mode=mode)
-        [node.op.__class__ for node in f.maker.fgraph.toposort()]
         assert not any(
             [isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
         )
@@ -1454,7 +1456,7 @@ def test_useless_alloc_with_shape_one(self, x, has_alloc):
         g = FunctionGraph(outputs=[x])
         assert any(isinstance(node.op, Alloc) for node in g.toposort())
 
-        alloc_lift = out2in(local_canonicalize_alloc)
+        alloc_lift = out2in(local_alloc_sink_dimshuffle)
         alloc_lift.optimize(g)
 
         if has_alloc:
@@ -3217,7 +3219,7 @@ def test_local_Unique_Alloc_lift(
     # The remaining exclusions simply allow us to perform the check below that
     # makes sure the original `Alloc` is present in our reference (sub)graph.
     opt_mode = default_mode.excluding(
-        "local_useless_alloc", "local_canonicalize_alloc", "local_Unique_Alloc_lift"
+        "local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift"
     )
     y_fn = function([x], [y, y_opt], mode=opt_mode)
     # Make sure that the original `Alloc` is used to compute the reference `y`
@@ -3505,3 +3507,65 @@ def test_Shape_i_canonicalize():
     assert isinstance(y_opt.owner.op, Shape_i)
     assert y_opt.owner.op.i == 0
     assert y_opt.owner.inputs[0] == x
+
+
+@pytest.mark.parametrize(
+    "expr, x_shape, y_shape",
+    [
+        pytest.param(
+            lambda x, y: at.mul(y, at.alloc(1, x)),
+            (),
+            (),
+            marks=pytest.mark.xfail(reason="Not implemented"),
+        ),
+        (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
+        (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
+        (lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)), (15, 1), (15, 1)),
+        (lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)), (15, 2), (15, 2)),
+        (lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y), (15, 2), (2, 15)),
+        (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)),
+        (
+            lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y),
+            (15, 2),
+            (2, 15),
+        ),
+    ],
+)
+def test_local_elemwise_alloc(expr, x_shape, y_shape):
+    x = at.tensor("int64", (False,) * len(x_shape))
+    y = at.tensor("int64", (False,) * len(y_shape))
+    z = expr(x, y)
+
+    z_opt = aesara.function(
+        [x, y],
+        z,
+        mode=get_default_mode().including("local_elemwise_alloc"),
+        on_unused_input="ignore",
+    )
+
+    assert not any(isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort())
+
+    z_no_opt = aesara.function(
+        [x, y],
+        z,
+        mode=get_default_mode().excluding("local_elemwise_alloc"),
+        on_unused_input="ignore",
+    )
+
+    x_val = np.arange(np.prod(x_shape), dtype=np.int64).reshape(x_shape)
+    y_val = np.arange(np.prod(y_shape), dtype=np.int64).reshape(y_shape)
+
+    res = z_opt(x_val, y_val)
+    exp_res = z_no_opt(x_val, y_val)
+    assert np.array_equal(res, exp_res)
+
+
+def test_local_elemwise_alloc_single_input():
+    # Test that rewrite is not triggered when there is only one Alloc in an Elemwise
+    x = at.matrix("x")
+    z = at.exp(at.alloc(x, 15, 1))
+
+    z_fg = FunctionGraph(outputs=[z], copy_inputs=False, features=[ShapeFeature()])
+
+    z_opt_fg = optimize_graph(z_fg, clone=False, include=["local_elemwise_alloc"])
+    assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes)
diff --git a/tests/tensor/test_subtensor_opt.py b/tests/tensor/test_subtensor_opt.py
index 42c7aac635..4c6a4debd2 100644
--- a/tests/tensor/test_subtensor_opt.py
+++ b/tests/tensor/test_subtensor_opt.py
@@ -1860,7 +1860,7 @@ def test_remove_alloc_wo_dimshuffle(self):
         # Exclude local_useless_alloc, since it does not introduce
         # assert in all the same cases.
         self.fast_run_mode = self.fast_run_mode.excluding(
-            "local_useless_alloc", "local_canonicalize_alloc"
+            "local_useless_alloc", "local_alloc_sink_dimshuffle"
         )
         # No optimization on alloc
         func = function(
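Illustration (not part of the patch): a minimal sketch of the rewrite's effect on a small graph, along the lines of the new `test_local_elemwise_alloc` cases above. The `optimize_graph` import path below is an assumption, since the test module's import block is not shown in this diff, and the printed forms in the comments are glosses rather than exact `debugprint` output.

    import aesara.tensor as at
    from aesara.graph.fg import FunctionGraph
    from aesara.graph.opt_utils import optimize_graph  # assumed import path
    from aesara.printing import debugprint
    from aesara.tensor.basic_opt import ShapeFeature

    x = at.matrix("x")
    y = at.matrix("y")
    # The `Alloc` broadcasts `x` to shape (15, 2) before the multiplication.
    z = at.mul(at.alloc(x, 15, 2), y)
    debugprint(z)  # roughly: Elemwise{mul}(Alloc(x, 15, 2), y)

    # The rewrite needs a `ShapeFeature` on the graph (it uses `fgraph.shape_feature`).
    fg = FunctionGraph(outputs=[z], copy_inputs=False, features=[ShapeFeature()])
    optimize_graph(fg, clone=False, include=["local_elemwise_alloc"])

    # `y` already has the output's broadcast pattern, so the `Alloc` is dropped and
    # `x` feeds the `Elemwise` directly (an `Assert` guarding the shapes may appear,
    # depending on `config.experimental__local_alloc_elemwise_assert`).
    debugprint(fg.outputs[0])  # roughly: Elemwise{mul}(x, y)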