2 changes: 1 addition & 1 deletion aesara/gpuarray/elemwise.py
@@ -468,7 +468,7 @@ def perform(self, node, inp, out, params):
 
         res = input
 
-        res = res.transpose(self.shuffle + self.drop)
+        res = res.transpose(self.transposition)
 
         shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
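Every hunk in this diff replaces the ad-hoc `self.shuffle + self.drop` (or `op.shuffle + op.drop`) with the new precomputed `transposition` attribute. The following is a minimal NumPy sketch, not the Op itself, of the transpose-then-reshape recipe that `DimShuffle.perform` implements; the shapes and `new_order` are invented for illustration:

```python
import numpy as np

# Input with a broadcastable (length-1) leading dimension, and
# new_order = (2, "x", 1): keep dims 2 and 1, insert a new axis at
# position 1, and drop the broadcastable dim 0.
x = np.empty((1, 5, 3))

shuffle = [2, 1]                # non-"x" entries of new_order
drop = [0]                      # broadcastable input dims missing from new_order
augment = [1]                   # positions of "x" in new_order
transposition = shuffle + drop  # the attribute this PR precomputes

res = x.transpose(transposition)         # shape (3, 5, 1)
shape = list(res.shape[: len(shuffle)])  # [3, 5]; dropped dims fall off the end
for augm in augment:
    shape.insert(augm, 1)                # [3, 1, 5]
res = res.reshape(shape)

assert res.shape == (3, 1, 5)
```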
2 changes: 1 addition & 1 deletion aesara/link/jax/dispatch.py
@@ -710,7 +710,7 @@ def reshape(x, shape):
 def jax_funcify_DimShuffle(op, **kwargs):
     def dimshuffle(x):
 
-        res = jnp.transpose(x, op.shuffle + op.drop)
+        res = jnp.transpose(x, op.transposition)
 
         shape = list(res.shape[: len(op.shuffle)])
 
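A hedged sketch of the dispatch pattern above: the op's attributes are baked into the returned closure at dispatch time, so the traced function only ever sees concrete tuples. The helper name `make_dimshuffle` and its arguments are invented here for illustration:

```python
import jax.numpy as jnp

def make_dimshuffle(transposition, n_shuffle, augment):
    # Mirrors jax_funcify_DimShuffle: constants are captured once,
    # so nothing op-related has to be traced.
    def dimshuffle(x):
        res = jnp.transpose(x, transposition)
        shape = list(res.shape[:n_shuffle])
        for augm in augment:
            shape.insert(augm, 1)
        return jnp.reshape(res, shape)
    return dimshuffle

f = make_dimshuffle((2, 1, 0), 2, (1,))
assert f(jnp.zeros((1, 5, 3))).shape == (3, 1, 5)
```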
4 changes: 2 additions & 2 deletions aesara/link/numba/dispatch/elemwise.py
@@ -319,7 +319,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
 @numba_funcify.register(DimShuffle)
 def numba_funcify_DimShuffle(op, **kwargs):
     shuffle = tuple(op.shuffle)
-    drop = tuple(op.drop)
+    transposition = tuple(op.transposition)
     augment = tuple(op.augment)
     inplace = op.inplace
 
@@ -352,7 +352,7 @@ def populate_new_shape(i, j, new_shape, shuffle_shape):
 
     @numba.njit
     def dimshuffle_inner(x, shuffle):
-        res = np.transpose(x, shuffle + drop)
+        res = np.transpose(x, transposition)
         shuffle_shape = res.shape[: len(shuffle)]
 
         new_shape = create_zeros_tuple()
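The Numba version relies on the same trick: `transposition` is materialized as a plain tuple before `dimshuffle_inner` is compiled, so Numba can treat it as a compile-time constant captured by the closure. A rough sketch of that pattern, assuming (as the dispatcher above does) that the installed Numba supports the `axes` argument of `np.transpose`:

```python
import numba
import numpy as np

transposition = (2, 1, 0)  # captured as a constant in the closure below

@numba.njit
def dimshuffle_inner(x):
    return np.transpose(x, transposition)

assert dimshuffle_inner(np.empty((1, 5, 3))).shape == (3, 5, 1)
```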
159 changes: 68 additions & 91 deletions aesara/tensor/c_code/dimshuffle.c
@@ -1,104 +1,81 @@
 #section support_code_apply
 
-int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject* input, PyArrayObject** res, PARAMS_TYPE* params) {
-    npy_bool* input_broadcastable;
-    npy_int64* new_order;
-    npy_intp nd_in;
-    npy_intp nd_out;
-    PyArrayObject* basename;
-    npy_intp* dimensions;
-    npy_intp* strides;
-
-    if (!PyArray_IS_C_CONTIGUOUS(params->input_broadcastable)) {
-        PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param input_broadcastable must be C-contiguous.");
-        return 1;
-    }
-    if (!PyArray_IS_C_CONTIGUOUS(params->_new_order)) {
-        PyErr_SetString(PyExc_RuntimeError, "DimShuffle: param _new_order must be C-contiguous.");
-        return 1;
-    }
-    input_broadcastable = (npy_bool*) PyArray_DATA(params->input_broadcastable);
-    new_order = (npy_int64*) PyArray_DATA(params->_new_order);
-    nd_in = PyArray_SIZE(params->input_broadcastable);
-    nd_out = PyArray_SIZE(params->_new_order);
-
-    /* check_input_nd */
-    if (PyArray_NDIM(input) != nd_in) {
-        PyErr_SetString(PyExc_NotImplementedError, "input nd");
-        return 1;
-    }
+int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res,
+                                   PARAMS_TYPE *params) {
 
-    /* clear_output */
-    if (*res)
-        Py_XDECREF(*res);
+  // This points to either the original input or a copy we create below.
+  // Either way, this is what we should be working on/with.
+  PyArrayObject *_input;
 
-    /* get_base */
-    if (params->inplace) {
-        basename = input;
-        Py_INCREF((PyObject*)basename);
-    } else {
-        basename =
-            (PyArrayObject*)PyArray_FromAny((PyObject*)input,
-                NULL, 0, 0, NPY_ARRAY_ALIGNED|NPY_ARRAY_ENSURECOPY, NULL);
-    }
+  if (*res)
+    Py_XDECREF(*res);
 
-    /* shape_statements and strides_statements */
-    dimensions = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
-    strides = (npy_intp*) malloc(nd_out * sizeof(npy_intp));
-    if (dimensions == NULL || strides == NULL) {
-        PyErr_NoMemory();
-        free(dimensions);
-        free(strides);
-        return 1;
-    };
-
-    for (npy_intp i = 0; i < nd_out; ++i) {
-        if (new_order[i] != -1) {
-            dimensions[i] = PyArray_DIMS(basename)[new_order[i]];
-            strides[i] = PyArray_DIMS(basename)[new_order[i]] == 1 ?
-                0 : PyArray_STRIDES(basename)[new_order[i]];
-        } else {
-            dimensions[i] = 1;
-            strides[i] = 0;
-        }
-    }
+  if (params->inplace) {
+    _input = input;
+    Py_INCREF((PyObject *)_input);
+  } else {
+    _input = (PyArrayObject *)PyArray_FromAny(
+        (PyObject *)input, NULL, 0, 0, NPY_ARRAY_ALIGNED | NPY_ARRAY_ENSURECOPY,
+        NULL);
+  }
 
-    /* set the strides of the broadcasted dimensions.
-     * This algorithm is from numpy: PyArray_Newshape() in
-     * cvs/numpy/numpy/core/src/multiarraymodule.c */
-    if (nd_out > 0) {
-        if (strides[nd_out - 1] == 0)
-            strides[nd_out - 1] = PyArray_DESCR(basename)->elsize;
-        for (npy_intp i = nd_out - 2; i > -1; --i) {
-            if (strides[i] == 0)
-                strides[i] = strides[i + 1] * dimensions[i + 1];
-        }
-    }
+  PyArray_Dims permute;
 
+  if (!PyArray_IntpConverter((PyObject *)params->transposition, &permute)) {
+    return 1;
+  }
+
-    /* close_bracket */
-    // create a new array.
-    *res = (PyArrayObject*)PyArray_New(&PyArray_Type, nd_out, dimensions,
-                                       PyArray_TYPE(basename), strides,
-                                       PyArray_DATA(basename), PyArray_ITEMSIZE(basename),
-                                       // borrow only the writable flag from the base
-                                       // the NPY_OWNDATA flag will default to 0.
-                                       (NPY_ARRAY_WRITEABLE * PyArray_ISWRITEABLE(basename)),
-                                       NULL);
-
-    if (*res == NULL) {
-        free(dimensions);
-        free(strides);
-        return 1;
+  /*
+    res = res.transpose(self.transposition)
+  */
+  PyArrayObject *transposed_input =
+      (PyArrayObject *)PyArray_Transpose(_input, &permute);
+
+  PyDimMem_FREE(permute.ptr);
+
+  npy_intp *res_shape = PyArray_DIMS(transposed_input);
+  npy_intp N_shuffle = PyArray_SIZE(params->shuffle);
+  npy_intp N_augment = PyArray_SIZE(params->augment);
+  npy_intp N = N_augment + N_shuffle;
+  npy_intp *_reshape_shape = (npy_intp *)malloc(N * sizeof(npy_intp));
+
+  if (_reshape_shape == NULL) {
+    PyErr_NoMemory();
+    free(_reshape_shape);
+    return 1;
   }
 
+  /*
+    shape = list(res.shape[: len(self.shuffle)])
+    for augm in self.augment:
+        shape.insert(augm, 1)
+  */
+  npy_intp aug_idx = 0;
+  int res_idx = 0;
+  for (npy_intp i = 0; i < N; i++) {
+    if (aug_idx < N_augment &&
+        i == *((npy_intp *)PyArray_GetPtr(params->augment, &aug_idx))) {
+      _reshape_shape[i] = 1;
+      aug_idx++;
+    } else {
+      _reshape_shape[i] = res_shape[res_idx];
+      res_idx++;
+    }
+  }
+
+  PyArray_Dims reshape_shape = {.ptr = _reshape_shape, .len = (int)N};
+
+  /* res = res.reshape(shape) */
+  *res = (PyArrayObject *)PyArray_Newshape(transposed_input, &reshape_shape,
+                                           NPY_CORDER);
+
-    // recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
-    PyArray_UpdateFlags(*res, NPY_ARRAY_UPDATE_ALL);
+  /* Py_XDECREF(transposed_input); */
 
-    // we are making a view in both inplace and non-inplace cases
-    PyArray_SetBaseObject(*res, (PyObject*)basename);
+  PyDimMem_FREE(reshape_shape.ptr);
 
-    free(strides);
-    free(dimensions);
+  if (!*res) {
+    return 1;
+  }
 
-    return 0;
+  return 0;
 }
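The shape-building loop in the new C code avoids Python's `list.insert` by making a single pass over the output positions, writing a 1 wherever the next (sorted) `augment` index matches and otherwise consuming the leading dims of the transposed array. A Python rendering of that loop, with a hypothetical helper name for illustration only:

```python
def build_reshape_shape(res_shape, n_shuffle, augment):
    # augment must be sorted, which is why __init__ now sorts it (see below)
    n = n_shuffle + len(augment)
    out, res_idx, aug_idx = [], 0, 0
    for i in range(n):
        if aug_idx < len(augment) and i == augment[aug_idx]:
            out.append(1)  # a broadcastable dim inserted by "x"
            aug_idx += 1
        else:
            out.append(res_shape[res_idx])  # next dim of the transposed array
            res_idx += 1
    return out

assert build_reshape_shape((3, 5), 2, [1]) == [3, 1, 5]
```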
79 changes: 27 additions & 52 deletions aesara/tensor/elemwise.py
@@ -119,47 +119,27 @@ class DimShuffle(ExternalCOp):
 
     @property
     def params_type(self):
-        # We can't directly create `params_type` as class attribute
-        # because of importation issues related to TensorType.
         return ParamsType(
-            input_broadcastable=TensorType(dtype="bool", broadcastable=(False,)),
-            _new_order=lvector,
+            transposition=TensorType(dtype="uint32", broadcastable=(False,)),
             shuffle=lvector,
             augment=lvector,
-            transposition=lvector,
             inplace=scalar_bool,
         )
 
-    @property
-    def _new_order(self):
-        # Param for C code.
-        # self.new_order may contain 'x', which is not a valid integer value.
-        # We replace it with -1.
-        return [(-1 if x == "x" else x) for x in self.new_order]
-
-    @property
-    def transposition(self):
-        return self.shuffle + self.drop
-
-    def __init__(self, input_broadcastable, new_order, inplace=True):
+    def __init__(self, input_broadcastable, new_order):
         super().__init__([self.c_func_file], self.c_func_name)
 
         self.input_broadcastable = tuple(input_broadcastable)
         self.new_order = tuple(new_order)
-        if inplace is True:
-            self.inplace = inplace
-        else:
-            raise ValueError(
-                "DimShuffle is inplace by default and hence the inplace for DimShuffle must be true"
-            )
+
+        self.inplace = True
 
         for i, j in enumerate(new_order):
             if j != "x":
-                # There is a bug in numpy that results in
-                # isinstance(x, integer_types) returning False for
-                # numpy integers. See
-                # <http://projects.scipy.org/numpy/ticket/2235>.
                 if not isinstance(j, (int, np.integer)):
                     raise TypeError(
-                        "DimShuffle indices must be python ints. "
-                        f"Got: '{j}' of type '{type(j)}'."
+                        "DimShuffle indices must be Python ints; got "
+                        f"{j} of type {type(j)}."
                     )
                 if j >= len(input_broadcastable):
                     raise ValueError(
@@ -169,31 +149,30 @@ def __init__(self, input_broadcastable, new_order, inplace=True):
                 if j in new_order[(i + 1) :]:
                     raise ValueError(
                         "The same input dimension may not appear "
-                        "twice in the list of output dimensions",
-                        new_order,
+                        f"twice in the list of output dimensions: {new_order}"
                     )
 
-        # list of dimensions of the input to drop
-        self.drop = []
+        # List of input dimensions to drop
+        drop = []
         for i, b in enumerate(input_broadcastable):
             if i not in new_order:
-                # we want to drop this dimension because it's not a value in
-                # new_order
-                if b == 1:  # 1 aka True
-                    self.drop.append(i)
+                # We want to drop this dimension because it's not a value in
+                # `new_order`
+                if b == 1:
+                    drop.append(i)
                 else:
-                    # we cannot drop non-broadcastable dimensions
+                    # We cannot drop non-broadcastable dimensions
                     raise ValueError(
-                        "You cannot drop a non-broadcastable dimension:",
-                        f" {input_broadcastable}, {new_order}",
+                        "Cannot drop a non-broadcastable dimension: "
+                        f"{input_broadcastable}, {new_order}"
                     )
 
-        # this is the list of the original dimensions that we keep
+        # This is the list of the original dimensions that we keep
         self.shuffle = [x for x in new_order if x != "x"]
 
-        # list of dimensions of the output that are broadcastable and were not
+        self.transposition = self.shuffle + drop
+        # List of dimensions of the output that are broadcastable and were not
         # in the original input
-        self.augment = [i for i, x in enumerate(new_order) if x == "x"]
+        self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"])
 
         if self.inplace:
             self.view_map = {0: [0]}
@@ -241,27 +220,23 @@ def __str__(self):
         return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
 
     def perform(self, node, inp, out, params):
-        (input,) = inp
+        (res,) = inp
         (storage,) = out
-        # drop
-        res = input
 
         if type(res) != np.ndarray and type(res) != np.memmap:
             raise TypeError(res)
 
-        # transpose
-        res = res.transpose(self.shuffle + self.drop)
+        res = res.transpose(self.transposition)
 
-        # augment
         shape = list(res.shape[: len(self.shuffle)])
         for augm in self.augment:
            shape.insert(augm, 1)
         res = res.reshape(shape)
 
-        # copy (if not inplace)
         if not self.inplace:
             res = np.copy(res)
 
-        storage[0] = np.asarray(res)  # asarray puts scalars back into array
+        storage[0] = np.asarray(res)
 
     def infer_shape(self, fgraph, node, shapes):
         (ishp,) = shapes
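A usage sketch of the refactored Op, assuming the imports resolve as on this branch; the attribute values follow directly from the `__init__` shown above:

```python
from aesara.tensor.elemwise import DimShuffle

op = DimShuffle(input_broadcastable=(True, False, False), new_order=(2, "x", 1))

assert op.shuffle == [2, 1]           # non-"x" entries of new_order
assert op.transposition == [2, 1, 0]  # shuffle + dropped broadcastable dims
assert op.augment == [1]              # positions of "x" in new_order
assert op.inplace                     # always True now; the inplace kwarg is gone
```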
2 changes: 1 addition & 1 deletion aesara/tensor/inplace.py
@@ -399,4 +399,4 @@ def conj_inplace(a):
 def transpose_inplace(x, **kwargs):
     "Perform a transpose on a tensor without copying the underlying storage"
     dims = list(range(x.ndim - 1, -1, -1))
-    return DimShuffle(x.broadcastable, dims, inplace=True)(x)
+    return DimShuffle(x.broadcastable, dims)(x)
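For completeness, a small sketch of the updated call site; this builds symbolic variables only, and nothing is computed here:

```python
import aesara.tensor as at
from aesara.tensor.inplace import transpose_inplace

x = at.tensor3("x")
y = transpose_inplace(x)  # builds DimShuffle(x.broadcastable, [2, 1, 0])(x), a view
```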