diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index 91e50481ec..949695848b 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -187,11 +187,11 @@ impl Layout {
         })
     }
 
-    pub(crate) fn strided_index(&self) -> crate::StridedIndex<'_> {
+    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
         crate::StridedIndex::from_layout(self)
     }
 
-    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
+    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
         let mut block_len = 1;
         let mut contiguous_dims = 0; // These are counted from the right.
         for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs
index 92734b8447..2e6292b64b 100644
--- a/candle-core/src/strided_index.rs
+++ b/candle-core/src/strided_index.rs
@@ -3,15 +3,15 @@ use crate::Layout;
 /// An iterator over offset position for items of an N-dimensional arrays stored in a
 /// flat buffer using some potential strides.
 #[derive(Debug)]
-pub struct StridedIndex<'a> {
+pub struct StridedIndex {
     next_storage_index: Option<usize>,
-    multi_index: Vec<usize>,
-    dims: &'a [usize],
-    stride: &'a [usize],
+    // (current_index_in_dim, max_index_in_dim, stride_for_dim)
+    multi_index: Vec<(usize, usize, usize)>,
+    remaining: usize,
 }
 
-impl<'a> StridedIndex<'a> {
-    pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
+impl StridedIndex {
+    pub(crate) fn new(dims: &[usize], stride: &[usize], start_offset: usize) -> Self {
         let elem_count: usize = dims.iter().product();
         let next_storage_index = if elem_count == 0 {
             None
@@ -19,44 +19,47 @@ impl<'a> StridedIndex<'a> {
             // This applies to the scalar case.
             Some(start_offset)
         };
+        // Precompute multi index iterator.
+        // For each dim, we have (current_index_in_dim, max_index_in_dim, stride_for_dim)
+        let multi_index: Vec<_> = dims
+            .iter()
+            .zip(stride.iter())
+            .rev()
+            .map(|(dim, stride)| (0, *dim, *stride))
+            .collect();
         StridedIndex {
             next_storage_index,
-            multi_index: vec![0; dims.len()],
-            dims,
-            stride,
+            multi_index,
+            remaining: elem_count,
         }
     }
 
-    pub(crate) fn from_layout(l: &'a Layout) -> Self {
+    pub(crate) fn from_layout(l: &Layout) -> Self {
         Self::new(l.dims(), l.stride(), l.start_offset())
     }
 }
 
-impl Iterator for StridedIndex<'_> {
+impl Iterator for StridedIndex {
     type Item = usize;
 
+    #[inline]
     fn next(&mut self) -> Option<Self::Item> {
         let storage_index = self.next_storage_index?;
         let mut updated = false;
         let mut next_storage_index = storage_index;
-        for ((multi_i, max_i), stride_i) in self
-            .multi_index
-            .iter_mut()
-            .zip(self.dims.iter())
-            .zip(self.stride.iter())
-            .rev()
-        {
+        for (multi_i, max_i, stride_i) in self.multi_index.iter_mut() {
             let next_i = *multi_i + 1;
             if next_i < *max_i {
                 *multi_i = next_i;
                 updated = true;
-                next_storage_index += stride_i;
+                next_storage_index += *stride_i;
                 break;
             } else {
-                next_storage_index -= *multi_i * stride_i;
+                next_storage_index -= *multi_i * *stride_i;
                 *multi_i = 0
             }
         }
+        self.remaining -= 1;
         self.next_storage_index = if updated {
             Some(next_storage_index)
         } else {
@@ -64,16 +67,28 @@ impl Iterator for StridedIndex<'_> {
         };
         Some(storage_index)
     }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.remaining, Some(self.remaining))
+    }
+}
+
+impl ExactSizeIterator for StridedIndex {
+    #[inline]
+    fn len(&self) -> usize {
+        self.remaining
+    }
 }
 
 #[derive(Debug)]
-pub enum StridedBlocks<'a> {
+pub enum StridedBlocks {
     SingleBlock {
         start_offset: usize,
         len: usize,
     },
     MultipleBlocks {
-        block_start_index: StridedIndex<'a>,
+        block_start_index: StridedIndex,
         block_len: usize,
     },
 }
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index d71630212d..952374c2e6 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1742,7 +1742,7 @@ impl Tensor {
 
     /// Returns an iterator over position of the elements in the storage when ranging over the
     /// index tuples in lexicographic order.
-    pub fn strided_index(&self) -> crate::StridedIndex<'_> {
+    pub fn strided_index(&self) -> crate::StridedIndex {
         self.layout.strided_index()
     }
 
@@ -1750,7 +1750,7 @@ impl Tensor {
     /// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator
     /// will only return the start offset and the size would be the number of elements in the
     /// tensor.
-    pub fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
+    pub fn strided_blocks(&self) -> crate::StridedBlocks {
         self.layout.strided_blocks()
     }
 