99 changes: 50 additions & 49 deletions burn-book/src/building-blocks/tensor.md
@@ -131,55 +131,56 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t

Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

| Burn | PyTorch Equivalent |
|----------------------------------------------|---------------------------------------------------------------------------|
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` |
| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dtype()` | `tensor.dtype` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.roll(shifts, dims)`                  | `tensor.roll(shifts, dims)`                                               |
| `tensor.roll_dim(shift, dim)` | `tensor.roll([shift], [dim])` |
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(s![range;step])` | `tensor[(*ranges,)]` or `tensor[start:end:step]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.slice_fill(ranges, value)` | `tensor[(*ranges,)] = value` |
| `tensor.slice_dim(dim, range)` | N/A |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.take(dim, indices)` | `numpy.take(tensor, indices, dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.transpose()` | `tensor.T` |
| `tensor.t()` | `tensor.T` |
| `tensor.unfold(dim, size, step)` | `tensor.unfold(dim, size, step)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| `tensor.unsqueeze_dims(dims)` | N/A |
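
For example, the newly added `unfold` turns one dimension into sliding windows. A minimal sketch of the expected shape semantics (the explicit output-rank annotation is an assumption here, mirroring other rank-changing ops):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // A [2, 5] tensor, unfolded along dim 1 with size = 3 and step = 1.
    let x = Tensor::<NdArray, 2>::zeros([2, 5], &device);

    // (5 - 3) / 1 + 1 = 3 complete windows; the window elements move into a
    // new trailing dimension, so the result has shape [2, 3, 3].
    let windows: Tensor<NdArray, 3> = x.unfold(1, 3, 1);
    assert_eq!(windows.dims(), [2, 3, 3]);
}
```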

### Numeric Operations

9 changes: 9 additions & 0 deletions crates/burn-autodiff/src/ops/bool_tensor.rs
@@ -107,4 +107,13 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
    fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
        B::bool_repeat_dim(tensor, dim, times)
    }

    fn bool_unfold(
        tensor: BoolTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> BoolTensor<Self> {
        B::bool_unfold(tensor, dim, size, step)
    }
}
9 changes: 9 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -377,4 +377,13 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
        B::int_cast(tensor, dtype)
    }

    fn int_unfold(
        tensor: IntTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> IntTensor<Self> {
        B::int_unfold(tensor, dim, size, step)
    }
}
9 changes: 9 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -2592,6 +2592,15 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>

    // TODO: Implement float_prod and float_sum
    // https://github.com/tracel-ai/burn/issues/1458

    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        AutodiffTensor::new(B::float_unfold(tensor.primitive, dim, size, step))
    }
}

#[derive(Debug, Clone)]
Expand Down
32 changes: 28 additions & 4 deletions crates/burn-candle/src/ops/base.rs
@@ -1,13 +1,14 @@
use std::cmp::max;
use std::marker::PhantomData;

use crate::{
    Candle, CandleDevice, CandleTensor,
    element::{CandleElement, FloatCandleElement, IntCandleElement},
};
use burn_tensor::ops::unfold::{calculate_unfold_shape, calculate_unfold_windows};
use burn_tensor::{Element, Shape, TensorData, TensorMetadata, backend::Backend};
use candle_core::{Layout, WithDType};
use half::{bf16, f16};

use super::tensor;

@@ -193,6 +194,29 @@ pub fn expand(tensor: CandleTensor, shape: Shape) -> CandleTensor {
    CandleTensor::new(tensor.tensor.broadcast_as(shape.dims).unwrap())
}

pub fn unfold(tensor: CandleTensor, dim: usize, size: usize, step: usize) -> CandleTensor {
    let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
    let windows = result_shape[dim];

    let mut select_ranges = tensor.shape().into_ranges();
    let new_axis = select_ranges.len();

    let mut stack = Vec::with_capacity(windows);
    for widx in 0..windows {
        let start = widx * step;
        let end = start + size;
        select_ranges[dim] = start..end;

        let window_slice = slice(tensor.clone(), &select_ranges);

        // Append a trailing axis, then move the window elements into it:
        // [pre.., size, post..] -> [pre.., size, post.., 1] -> [pre.., 1, post.., size].
        // `unsqueeze` must come first, since `new_axis` is out of bounds until then.
        let window_slice = CandleTensor::new(window_slice.tensor.unsqueeze(new_axis).unwrap());
        let window_slice = swap_dims(window_slice, dim, new_axis);

        stack.push(window_slice);
    }
    // Concatenating the per-window singleton axes yields [pre.., windows, post.., size].
    cat(stack, dim)
}
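
Unlike the stride-based view used by `burn-cubecl` below, this Candle implementation materializes each window by slicing and concatenating, since the wrapper builds the result through ordinary Candle ops rather than by constructing a tensor with custom strides.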

pub fn sign(tensor: CandleTensor) -> CandleTensor {
    CandleTensor::new(tensor.tensor.sign().unwrap())
}
11 changes: 10 additions & 1 deletion crates/burn-candle/src/ops/bool_tensor.rs
@@ -9,7 +9,7 @@ use crate::{
    element::{CandleElement, FloatCandleElement, IntCandleElement},
};

use super::base::{expand, permute, unfold};

impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
    fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
@@ -136,4 +136,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
    fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
        expand(tensor, shape)
    }

    fn bool_unfold(
        tensor: BoolTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> BoolTensor<Self> {
        unfold(tensor, dim, size, step)
    }
}
11 changes: 10 additions & 1 deletion crates/burn-candle/src/ops/int_tensor.rs
@@ -9,7 +9,7 @@ use crate::{
    element::{CandleElement, FloatCandleElement, IntCandleElement},
};

use super::base::{expand, permute, sign, unfold};

impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
    fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
@@ -384,6 +384,15 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
        expand(tensor, shape)
    }

    fn int_unfold(
        tensor: IntTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> IntTensor<Self> {
        unfold(tensor, dim, size, step)
    }

    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
        sign(tensor)
    }
11 changes: 10 additions & 1 deletion crates/burn-candle/src/ops/tensor.rs
@@ -12,7 +12,7 @@ use crate::{
    element::{CandleElement, FloatCandleElement, IntCandleElement},
};

use super::base::{expand, permute, sign, unfold};

impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
    fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {
@@ -460,6 +460,15 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
        expand(tensor, shape)
    }

    fn float_unfold(
        tensor: FloatTensor<Self>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> FloatTensor<Self> {
        unfold(tensor, dim, size, step)
    }

    fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
        sign(tensor)
    }
42 changes: 42 additions & 0 deletions crates/burn-cubecl/src/ops/base.rs
@@ -1,5 +1,6 @@
use crate::{CubeRuntime, element::CubeElement, kernel, tensor::CubeTensor};
use burn_common::tensor::{ReshapeAction, reshape_action};
use burn_tensor::ops::unfold::calculate_unfold_shape;
use burn_tensor::{
    Shape, TensorData,
    quantization::{QTensorPrimitive, QuantLevel},
@@ -213,3 +214,44 @@ pub(crate) fn max_line_size_many<R: CubeRuntime>(tensors: &[&CubeTensor<R>], dim

    vec.unwrap_or(0)
}

/// Unfold windows along a dimension.
///
/// Returns a view of the tensor containing all complete windows of size `size` in
/// dimension `dim`, where consecutive windows advance by `step` elements.
///
/// The number of windows is `(shape[dim] - size) / step + 1` (integer division),
/// or `0` when `size > shape[dim]`.
///
/// The new view replaces the unfolded dimension with two dimensions:
/// one in the position of the original dimension, with size equal to the number of windows,
/// and one appended to the right-most position, with size equal to `size`.
///
/// # Arguments
///
/// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
/// * `dim` - the dimension to unfold.
/// * `size` - the size of each unfolded window.
/// * `step` - the step between each window.
///
/// # Returns
///
/// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
pub fn unfold<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    dim: usize,
    size: usize,
    step: usize,
) -> CubeTensor<R> {
    // Clone the shape so `tensor` can still be moved into the struct update below.
    let shape = calculate_unfold_shape(tensor.shape.clone(), dim, size, step);

    // Windows `step` apart start `step * stride` elements apart in memory, while
    // the elements inside a window keep the original stride of `dim`.
    let d_stride = tensor.strides[dim];
    let mut strides = tensor.strides.clone();
    strides[dim] = step * d_stride;
    strides.push(d_stride);

    CubeTensor {
        shape: shape.into(),
        strides,
        ..tensor
    }
}
}
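
To make the stride arithmetic concrete, here is a small self-contained sketch (plain Rust, with hypothetical example values) of the shape and stride computation performed above:

```rust
// Worked example of the stride trick: a contiguous [4, 6] tensor with strides
// [6, 1], unfolded along dim 1 with size = 2 and step = 2.
fn main() {
    let (shape, strides) = ([4usize, 6], [6usize, 1]);
    let (dim, size, step) = (1usize, 2usize, 2usize);

    // (6 - 2) / 2 + 1 = 3 complete windows.
    let windows = (shape[dim] - size) / step + 1;

    // The unfolded dim holds the window count; `size` is appended on the right.
    let mut new_shape = shape.to_vec();
    new_shape[dim] = windows;
    new_shape.push(size);
    assert_eq!(new_shape, vec![4, 3, 2]);

    // Advancing one window skips `step` elements of the original dim, while
    // the appended axis reuses the original stride.
    let d_stride = strides[dim];
    let mut new_strides = strides.to_vec();
    new_strides[dim] = step * d_stride; // 2 * 1 = 2
    new_strides.push(d_stride); // 1
    assert_eq!(new_strides, vec![6, 2, 1]);
}
```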