diff --git a/src/impl_methods.rs b/src/impl_methods.rs
index 01a0a9a4d..13a94ff0c 100644
--- a/src/impl_methods.rs
+++ b/src/impl_methods.rs
@@ -1516,6 +1516,17 @@ impl ArrayRef
/// }
/// ```
pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D>
+ {
+ self.axis_windows_with_stride(axis, window_size, 1)
+ }
+
+ /// Returns a producer which traverses over windows of a given length and
+ /// stride along an axis.
+ ///
+ /// Note that a calling this method with a stride of 1 is equivalent to
+ /// calling [`ArrayRef::axis_windows()`].
+ pub fn axis_windows_with_stride(&self, axis: Axis, window_size: usize, stride_size: usize)
+ -> AxisWindows<'_, A, D>
{
let axis_index = axis.index();
@@ -1527,7 +1538,12 @@ impl ArrayRef
self.shape()
);
- AxisWindows::new(self.view(), axis, window_size)
+ ndassert!(
+ stride_size >0,
+ "Stride size must be greater than zero"
+ );
+
+ AxisWindows::new_with_stride(self.view(), axis, window_size, stride_size)
}
/// Return a view of the diagonal elements of the array.
diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs
index afdaaa895..f3442c0af 100644
--- a/src/iterators/windows.rs
+++ b/src/iterators/windows.rs
@@ -141,7 +141,7 @@ pub struct AxisWindows<'a, A, D>
impl<'a, A, D: Dimension> AxisWindows<'a, A, D>
{
- pub(crate) fn new(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize) -> Self
+ pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, axis: Axis, window_size: usize, stride_size: usize) -> Self
{
let window_strides = a.strides.clone();
let axis_idx = axis.index();
@@ -150,10 +150,11 @@ impl<'a, A, D: Dimension> AxisWindows<'a, A, D>
window[axis_idx] = window_size;
let ndim = window.ndim();
- let mut unit_stride = D::zeros(ndim);
- unit_stride.slice_mut().fill(1);
+ let mut stride = D::zeros(ndim);
+ stride.slice_mut().fill(1);
+ stride[axis_idx] = stride_size;
- let base = build_base(a, window.clone(), unit_stride);
+ let base = build_base(a, window.clone(), stride);
AxisWindows {
base,
axis_idx,
diff --git a/tests/windows.rs b/tests/windows.rs
index 6506d8301..4d4d0d7d7 100644
--- a/tests/windows.rs
+++ b/tests/windows.rs
@@ -294,6 +294,148 @@ fn tests_axis_windows_3d_zips_with_1d()
assert_eq!(b,arr1(&[207, 261]));
}
+/// Test verifies that non existent Axis results in panic
+#[test]
+#[should_panic]
+fn axis_windows_with_stride_outofbound()
+{
+ let a = Array::from_iter(10..37)
+ .into_shape_with_order((3, 3, 3))
+ .unwrap();
+ a.axis_windows_with_stride(Axis(4), 2, 2);
+}
+
+/// Test verifies that zero sizes results in panic
+#[test]
+#[should_panic]
+fn axis_windows_with_stride_zero_size()
+{
+ let a = Array::from_iter(10..37)
+ .into_shape_with_order((3, 3, 3))
+ .unwrap();
+ a.axis_windows_with_stride(Axis(0), 0, 2);
+}
+
+/// Test verifies that zero stride results in panic
+#[test]
+#[should_panic]
+fn axis_windows_with_stride_zero_stride()
+{
+ let a = Array::from_iter(10..37)
+ .into_shape_with_order((3, 3, 3))
+ .unwrap();
+ a.axis_windows_with_stride(Axis(0), 2, 0);
+}
+
+/// Test verifies that over sized windows yield nothing
+#[test]
+fn axis_windows_with_stride_oversized()
+{
+ let a = Array::from_iter(10..37)
+ .into_shape_with_order((3, 3, 3))
+ .unwrap();
+ let mut iter = a.axis_windows_with_stride(Axis(2), 4, 2).into_iter();
+ assert_eq!(iter.next(), None);
+}
+
+/// Simple test for iterating 1d-arrays via `Axis Windows`.
+#[test]
+fn test_axis_windows_with_stride_1d()
+{
+ let a = Array::from_iter(10..20).into_shape_with_order(10).unwrap();
+
+ itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 5, 2), vec![
+ arr1(&[10, 11, 12, 13, 14]),
+ arr1(&[12, 13, 14, 15, 16]),
+ arr1(&[14, 15, 16, 17, 18]),
+ ]);
+
+ itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 5, 3), vec![
+ arr1(&[10, 11, 12, 13, 14]),
+ arr1(&[13, 14, 15, 16, 17]),
+ ]);
+}
+
+/// Simple test for iterating 2d-arrays via `Axis Windows`.
+#[test]
+fn test_axis_windows_with_stride_2d()
+{
+ let a = Array::from_iter(10..30)
+ .into_shape_with_order((5, 4))
+ .unwrap();
+
+ itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 1), vec![
+ arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
+ arr2(&[[14, 15, 16, 17], [18, 19, 20, 21]]),
+ arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]),
+ arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]),
+ ]);
+
+ itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 2), vec![
+ arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
+ arr2(&[[18, 19, 20, 21], [22, 23, 24, 25]]),
+ ]);
+
+ itertools::assert_equal(a.axis_windows_with_stride(Axis(0), 2, 3), vec![
+ arr2(&[[10, 11, 12, 13], [14, 15, 16, 17]]),
+ arr2(&[[22, 23, 24, 25], [26, 27, 28, 29]]),
+ ]);
+}
+
+/// Simple test for iterating 3d-arrays via `Axis Windows`.
+#[test]
+fn test_axis_windows_with_stride_3d()
+{
+ let a = Array::from_iter(0..27)
+ .into_shape_with_order((3, 3, 3))
+ .unwrap();
+
+ itertools::assert_equal(a.axis_windows_with_stride(Axis(1), 2, 1), vec![
+ arr3(&[
+ [[0, 1, 2], [3, 4, 5]],
+ [[9, 10, 11], [12, 13, 14]],
+ [[18, 19, 20], [21, 22, 23]],
+ ]),
+ arr3(&[
+ [[3, 4, 5], [6, 7, 8]],
+ [[12, 13, 14], [15, 16, 17]],
+ [[21, 22, 23], [24, 25, 26]],
+ ]),
+ ]);
+
+ itertools::assert_equal(a.axis_windows_with_stride(Axis(1), 2, 2), vec![
+ arr3(&[
+ [[0, 1, 2], [3, 4, 5]],
+ [[9, 10, 11], [12, 13, 14]],
+ [[18, 19, 20], [21, 22, 23]],
+ ]),
+ ]);
+}
+
+#[test]
+fn tests_axis_windows_with_stride_3d_zips_with_1d()
+{
+ let a = Array::from_iter(0..27)
+ .into_shape_with_order((3, 3, 3))
+ .unwrap();
+ let mut b1 = Array::zeros(2);
+ let mut b2 = Array::zeros(1);
+
+ Zip::from(b1.view_mut())
+ .and(a.axis_windows_with_stride(Axis(1), 2, 1))
+ .for_each(|b, a| {
+ *b = a.sum();
+ });
+ assert_eq!(b1,arr1(&[207, 261]));
+
+ Zip::from(b2.view_mut())
+ .and(a.axis_windows_with_stride(Axis(1), 2, 2))
+ .for_each(|b, a| {
+ *b = a.sum();
+ });
+ assert_eq!(b2,arr1(&[207]));
+}
+
#[test]
fn test_window_neg_stride()
{