Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions candle-core/src/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,11 @@ impl Layout {
})
}

pub(crate) fn strided_index(&self) -> crate::StridedIndex<'_> {
pub(crate) fn strided_index(&self) -> crate::StridedIndex {
crate::StridedIndex::from_layout(self)
}

pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
let mut block_len = 1;
let mut contiguous_dims = 0; // These are counted from the right.
for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
Expand Down
41 changes: 19 additions & 22 deletions candle-core/src/strided_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,57 +3,54 @@ use crate::Layout;
/// An iterator over offset position for items of an N-dimensional arrays stored in a
/// flat buffer using some potential strides.
#[derive(Debug)]
pub struct StridedIndex<'a> {
pub struct StridedIndex {
next_storage_index: Option<usize>,
multi_index: Vec<usize>,
dims: &'a [usize],
stride: &'a [usize],
multi_index: Vec<((usize, usize), usize)>,
}

impl<'a> StridedIndex<'a> {
pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
impl StridedIndex {
pub(crate) fn new(dims: &[usize], stride: &[usize], start_offset: usize) -> Self {
let elem_count: usize = dims.iter().product();
let next_storage_index = if elem_count == 0 {
None
} else {
// This applies to the scalar case.
Some(start_offset)
};
// This iterator is hot enough that precomputing the zipped structure is worth it.
let multi_index: Vec<_> = vec![0usize; dims.len()]
.into_iter()
.zip(dims.iter().copied())
.zip(stride.iter().copied())
.rev()
.collect();
StridedIndex {
next_storage_index,
multi_index: vec![0; dims.len()],
dims,
stride,
multi_index,
}
}

pub(crate) fn from_layout(l: &'a Layout) -> Self {
pub(crate) fn from_layout(l: &Layout) -> Self {
Self::new(l.dims(), l.stride(), l.start_offset())
}
}

impl Iterator for StridedIndex<'_> {
impl Iterator for StridedIndex {
type Item = usize;

fn next(&mut self) -> Option<Self::Item> {
let storage_index = self.next_storage_index?;
let mut updated = false;
let mut next_storage_index = storage_index;
for ((multi_i, max_i), stride_i) in self
.multi_index
.iter_mut()
.zip(self.dims.iter())
.zip(self.stride.iter())
.rev()
{
for ((multi_i, max_i), stride_i) in self.multi_index.iter_mut() {
let next_i = *multi_i + 1;
if next_i < *max_i {
*multi_i = next_i;
updated = true;
next_storage_index += stride_i;
next_storage_index += *stride_i;
break;
} else {
next_storage_index -= *multi_i * stride_i;
next_storage_index -= *multi_i * *stride_i;
*multi_i = 0
}
}
Expand All @@ -67,13 +64,13 @@ impl Iterator for StridedIndex<'_> {
}

#[derive(Debug)]
pub enum StridedBlocks<'a> {
pub enum StridedBlocks {
SingleBlock {
start_offset: usize,
len: usize,
},
MultipleBlocks {
block_start_index: StridedIndex<'a>,
block_start_index: StridedIndex,
block_len: usize,
},
}
4 changes: 2 additions & 2 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1742,15 +1742,15 @@ impl Tensor {

/// Returns an iterator over position of the elements in the storage when ranging over the
/// index tuples in lexicographic order.
pub fn strided_index(&self) -> crate::StridedIndex<'_> {
pub fn strided_index(&self) -> crate::StridedIndex {
self.layout.strided_index()
}

/// Similar to `strided_index` but returns the position of the start of each contiguous block
/// as well as the length of the contiguous blocks. For a contiguous tensor, the index iterator
/// will only return the start offset and the size would be the number of elements in the
/// tensor.
pub fn strided_blocks(&self) -> crate::StridedBlocks<'_> {
pub fn strided_blocks(&self) -> crate::StridedBlocks {
self.layout.strided_blocks()
}

Expand Down
Loading