From 486f28eb4e555530a08710de4af4dd3a29d31611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andris=20V=C4=81ravs?= Date: Sat, 4 Oct 2025 00:26:29 +0300 Subject: [PATCH 1/3] minor performance improvement for StridedIndex * precompute inner multi-index lookup, this yields a few % improvement due to this iterator being hot enough --- candle-core/src/layout.rs | 4 ++-- candle-core/src/strided_index.rs | 41 +++++++++++++++----------------- candle-core/src/tensor.rs | 4 ++-- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index 91e50481ec..949695848b 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -187,11 +187,11 @@ impl Layout { }) } - pub(crate) fn strided_index(&self) -> crate::StridedIndex<'_> { + pub(crate) fn strided_index(&self) -> crate::StridedIndex { crate::StridedIndex::from_layout(self) } - pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks<'_> { + pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks { let mut block_len = 1; let mut contiguous_dims = 0; // These are counted from the right. for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() { diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index 92734b8447..db02492879 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -3,15 +3,13 @@ use crate::Layout; /// An iterator over offset position for items of an N-dimensional arrays stored in a /// flat buffer using some potential strides. 
#[derive(Debug)] -pub struct StridedIndex<'a> { +pub struct StridedIndex { next_storage_index: Option<usize>, - multi_index: Vec<usize>, - dims: &'a [usize], - stride: &'a [usize], + multi_index: Vec<((usize, usize), usize)>, } -impl<'a> StridedIndex<'a> { - pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self { +impl StridedIndex { + pub(crate) fn new(dims: &[usize], stride: &[usize], start_offset: usize) -> Self { let elem_count: usize = dims.iter().product(); let next_storage_index = if elem_count == 0 { None @@ -19,41 +17,40 @@ impl<'a> StridedIndex<'a> { // This applies to the scalar case. Some(start_offset) }; + // This iterator is hot enough that precomputing the zipped structure is worth it. + let multi_index: Vec<_> = vec![0usize; dims.len()] + .into_iter() + .zip(dims.iter().copied()) + .zip(stride.iter().copied()) + .rev() + .collect(); StridedIndex { next_storage_index, - multi_index: vec![0; dims.len()], - dims, - stride, + multi_index, } } - pub(crate) fn from_layout(l: &'a Layout) -> Self { + pub(crate) fn from_layout(l: &Layout) -> Self { Self::new(l.dims(), l.stride(), l.start_offset()) } } -impl Iterator for StridedIndex<'_> { +impl Iterator for StridedIndex { type Item = usize; fn next(&mut self) -> Option<Self::Item> { let storage_index = self.next_storage_index?; let mut updated = false; let mut next_storage_index = storage_index; - for ((multi_i, max_i), stride_i) in self - .multi_index - .iter_mut() - .zip(self.dims.iter()) - .zip(self.stride.iter()) - .rev() - { + for ((multi_i, max_i), stride_i) in self.multi_index.iter_mut() { let next_i = *multi_i + 1; if next_i < *max_i { *multi_i = next_i; updated = true; - next_storage_index += stride_i; + next_storage_index += *stride_i; break; } else { - next_storage_index -= *multi_i * stride_i; + next_storage_index -= *multi_i * *stride_i; *multi_i = 0 } } @@ -67,13 +64,13 @@ impl Iterator for StridedIndex<'_> { } #[derive(Debug)] -pub enum StridedBlocks<'a> { +pub enum StridedBlocks { 
SingleBlock { start_offset: usize, len: usize, }, MultipleBlocks { - block_start_index: StridedIndex<'a>, + block_start_index: StridedIndex, block_len: usize, }, } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index d71630212d..952374c2e6 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1742,7 +1742,7 @@ impl Tensor { /// Returns an iterator over position of the elements in the storage when ranging over the /// index tuples in lexicographic order. - pub fn strided_index(&self) -> crate::StridedIndex<'_> { + pub fn strided_index(&self) -> crate::StridedIndex { self.layout.strided_index() } @@ -1750,7 +1750,7 @@ impl Tensor { /// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator /// will only return the start offset and the size would be the number of elements in the /// tensor. - pub fn strided_blocks(&self) -> crate::StridedBlocks<'_> { + pub fn strided_blocks(&self) -> crate::StridedBlocks { self.layout.strided_blocks() } From c33633fbdbde1ce85a5727ad0bc4389fdaf6bffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andris=20V=C4=81ravs?= Date: Sat, 18 Oct 2025 23:53:00 +0300 Subject: [PATCH 2/3] candle-core: slightly faster StridedIndex::new() + more comments --- candle-core/src/strided_index.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index db02492879..f5efd469e1 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -5,7 +5,8 @@ use crate::Layout; #[derive(Debug)] pub struct StridedIndex { next_storage_index: Option<usize>, - multi_index: Vec<((usize, usize), usize)>, + // (current_index_in_dim, max_index_in_dim, stride_for_dim) + multi_index: Vec<(usize, usize, usize)>, } impl StridedIndex { @@ -17,12 +18,13 @@ impl StridedIndex { // This applies to the scalar case. 
Some(start_offset) }; - // This iterator is hot enough that precomputing the zipped structure is worth it. - let multi_index: Vec<_> = vec![0usize; dims.len()] - .into_iter() - .zip(dims.iter().copied()) - .zip(stride.iter().copied()) + // Precompute multi index iterator. + // For each dim, we have (current_index_in_dim, max_index_in_dim, stride_for_dim) + let multi_index: Vec<_> = dims + .iter() + .zip(stride.iter()) .rev() + .map(|(dim, stride)| (0, *dim, *stride)) .collect(); StridedIndex { next_storage_index, @@ -42,7 +44,7 @@ impl Iterator for StridedIndex { let storage_index = self.next_storage_index?; let mut updated = false; let mut next_storage_index = storage_index; - for ((multi_i, max_i), stride_i) in self.multi_index.iter_mut() { + for (multi_i, max_i, stride_i) in self.multi_index.iter_mut() { let next_i = *multi_i + 1; if next_i < *max_i { *multi_i = next_i; From 010fd643187ea2236979999b68b235e080b85c10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andris=20V=C4=81ravs?= Date: Sun, 26 Oct 2025 00:11:55 +0300 Subject: [PATCH 3/3] candle-core: impl size hint/exact length for StridedIndex --- candle-core/src/strided_index.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index f5efd469e1..2e6292b64b 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -7,6 +7,7 @@ pub struct StridedIndex { next_storage_index: Option<usize>, // (current_index_in_dim, max_index_in_dim, stride_for_dim) multi_index: Vec<(usize, usize, usize)>, + remaining: usize, } impl StridedIndex { @@ -29,6 +30,7 @@ impl StridedIndex { StridedIndex { next_storage_index, multi_index, + remaining: elem_count, } } @@ -40,6 +42,7 @@ impl StridedIndex { impl Iterator for StridedIndex { type Item = usize; + #[inline] fn next(&mut self) -> Option<Self::Item> { let storage_index = self.next_storage_index?; let mut updated = false; @@ -56,6 +59,7 @@ impl Iterator for StridedIndex { *multi_i = 0 } } 
+ self.remaining -= 1; self.next_storage_index = if updated { Some(next_storage_index) } else { @@ -63,6 +67,18 @@ impl Iterator for StridedIndex { }; Some(storage_index) } + + #[inline] + fn size_hint(&self) -> (usize, Option<usize>) { + (self.remaining, Some(self.remaining)) + } +} + +impl ExactSizeIterator for StridedIndex { + #[inline] + fn len(&self) -> usize { + self.remaining + } } #[derive(Debug)]