Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion candle-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;

fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(
&self,
_: &Layout,
_: (usize, usize),
_: (usize, usize),
_: usize,
) -> Result<Self>;
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

Expand Down
7 changes: 6 additions & 1 deletion candle-core/src/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ impl Tensor {
arg,
kernel_size,
stride,
padding,
} => {
if kernel_size != stride {
crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
Expand All @@ -369,7 +370,11 @@ impl Tensor {
// we scale the gradient for this case).
let node_upsampled = node.upsample_nearest2d(h, w)?;
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
let avg = if *padding == 0 {
mask.avg_pool2d_with_stride(*kernel_size, *stride)?
} else {
mask.max_pool2d_with_stride_padding(*kernel_size, *stride, *padding)?
};
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
Expand Down
48 changes: 38 additions & 10 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,20 +343,27 @@ impl Map1 for AvgPool2D {
}
}

struct MaxPool2D((usize, usize), (usize, usize));
struct MaxPool2D((usize, usize), (usize, usize), usize);

impl Map1 for MaxPool2D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
let (k_h, k_w) = self.0;
let (s_h, s_w) = self.1;
let padding = self.2;
let (b_sz, c, h, w) = layout.shape().dims4()?;
let stride = layout.stride();
let (stride_h, stride_w) = (stride[2], stride[3]);
let h_out = (h - k_h) / s_h + 1;
let w_out = (w - k_w) / s_w + 1;

// Calculate output dimensions with padding
let h_padded = h + 2 * padding;
let w_padded = w + 2 * padding;
let h_out = (h_padded - k_h) / s_h + 1;
let w_out = (w_padded - k_w) / s_w + 1;

let src_index = layout.start_offset();
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];

for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * h_out * w_out..];
let src_index = src_index + b_idx * stride[0];
Expand All @@ -365,17 +372,37 @@ impl Map1 for MaxPool2D {
let src_index = src_index + c_idx * stride[1];
for h_idx in 0..h_out {
for w_idx in 0..w_out {
let mut largest =
src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
let mut largest = T::zero();
let mut found_valid = false;

for m in 0..k_h {
for n in 0..k_w {
let m = s_h * h_idx + m;
let n = s_w * w_idx + n;
if largest < src[src_index + m * stride_h + n * stride_w] {
largest = src[src_index + m * stride_h + n * stride_w]
let src_h = s_h * h_idx + m;
let src_w = s_w * w_idx + n;

// Check if we're within the original (unpadded) bounds
if src_h >= padding
&& src_h < h + padding
&& src_w >= padding
&& src_w < w + padding
{
let actual_h = src_h - padding;
let actual_w = src_w - padding;
let val =
src[src_index + actual_h * stride_h + actual_w * stride_w];
if !found_valid || largest < val {
largest = val;
found_valid = true;
}
}
}
}

// If no valid values were found (all padding), use zero
if !found_valid {
largest = T::zero();
}

dst[h_idx * w_out + w_idx] = largest;
}
}
Expand Down Expand Up @@ -1979,8 +2006,9 @@ impl BackendStorage for CpuStorage {
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: usize,
) -> Result<Self> {
MaxPool2D(kernel_size, stride).map(self, layout)
MaxPool2D(kernel_size, stride, padding).map(self, layout)
}

fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
Expand Down
20 changes: 17 additions & 3 deletions candle-core/src/cuda_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ struct Pool2D {
h_k: usize,
w_stride: usize,
h_stride: usize,
padding: usize,
op: PoolOp,
}

Expand All @@ -905,8 +906,13 @@ impl Map1 for Pool2D {
crate::bail!("unexpected input shape for pool {dims:?}")
};
let el = shape.elem_count();
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
let out_h = (dims[3] - self.h_k) / self.h_stride + 1;

// Calculate output dimensions with padding
let h_padded = dims[2] + 2 * self.padding;
let w_padded = dims[3] + 2 * self.padding;
let out_w = (h_padded - self.w_k) / self.w_stride + 1;
let out_h = (w_padded - self.h_k) / self.h_stride + 1;

let dst_el = out_w * out_h * dims[0] * dims[1];
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let kname = match self.op {
Expand All @@ -923,6 +929,7 @@ impl Map1 for Pool2D {
barg!(builder, self.h_k);
barg!(builder, self.w_stride);
barg!(builder, self.h_stride);
barg!(builder, self.padding);
builder.arg(&ds);
builder.arg(inp);
builder.arg(&out);
Expand Down Expand Up @@ -1889,13 +1896,20 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}

fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
fn max_pool2d(
&self,
l: &Layout,
k: (usize, usize),
stride: (usize, usize),
padding: usize,
) -> Result<Self> {
let device = self.device().clone();
let slice = Pool2D {
w_k: k.0,
h_k: k.1,
w_stride: stride.0,
h_stride: stride.1,
padding,
op: PoolOp::Max,
}
.map(&self.slice, &device, l)?;
Expand Down
8 changes: 7 additions & 1 deletion candle-core/src/dummy_cuda_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,13 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
fn max_pool2d(
&self,
_: &Layout,
_: (usize, usize),
_: (usize, usize),
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

Expand Down
8 changes: 7 additions & 1 deletion candle-core/src/dummy_metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,13 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}

fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
fn max_pool2d(
&self,
_: &Layout,
_: (usize, usize),
_: (usize, usize),
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

Expand Down
13 changes: 11 additions & 2 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,7 @@ impl BackendStorage for MetalStorage {
let out_w = (width - w_k) / w_stride + 1;
let out_h = (height - h_k) / h_stride + 1;
let dst_el = out_w * out_h * b_size * channels;
let padding = 0;
let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?;
let command_buffers = self.device.command_buffer()?;
candle_metal_kernels::call_pool2d(
Expand All @@ -1312,6 +1313,7 @@ impl BackendStorage for MetalStorage {
h_k,
w_stride,
h_stride,
padding,
&self.buffer,
&buffer,
)
Expand All @@ -1324,6 +1326,7 @@ impl BackendStorage for MetalStorage {
inp_l: &Layout,
(w_k, h_k): (usize, usize),
(w_stride, h_stride): (usize, usize),
padding: usize,
) -> Result<Self> {
let shape = inp_l.shape();
let (b_size, channels, width, height) = shape.dims4()?;
Expand All @@ -1336,8 +1339,13 @@ impl BackendStorage for MetalStorage {
DType::U32 => "max_pool2d_u32",
dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"),
};
let out_w = (width - w_k) / w_stride + 1;
let out_h = (height - h_k) / h_stride + 1;

// Calculate output dimensions with padding
let width_padded = width + 2 * padding;
let height_padded = height + 2 * padding;
let out_w = (width_padded - w_k) / w_stride + 1;
let out_h = (height_padded - h_k) / h_stride + 1;

let dst_el = out_w * out_h * b_size * channels;
let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?;
let command_buffers = self.device.command_buffer()?;
Expand All @@ -1354,6 +1362,7 @@ impl BackendStorage for MetalStorage {
h_k,
w_stride,
h_stride,
padding,
&self.buffer,
&buffer,
)
Expand Down
1 change: 1 addition & 0 deletions candle-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ pub enum Op {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: usize,
},

UpsampleNearest1D {
Expand Down
7 changes: 4 additions & 3 deletions candle-core/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,18 +519,19 @@ impl Storage {
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: usize,
) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?;
Ok(Self::Metal(storage))
}
}
Expand Down
46 changes: 41 additions & 5 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1214,24 +1214,60 @@ impl Tensor {
&self,
kernel_size: T,
stride: T,
) -> Result<Self> {
let sz = kernel_size.to_usize2();
let stride = stride.to_usize2();
self.max_pool2d_with_stride_padding(sz, stride, 0)
}

/// Same as `max_pool2d_with_stride` but with padding support.
///
/// # Arguments
///
/// * `kernel_size` - The size of the pooling window
/// * `stride` - The stride of the pooling operation, controlling how far the window
/// * `padding` - The amount of zero-padding to add to both sides of the height and width
/// # Examples
/// ```rust
/// use candle_core::{Tensor,Device,Shape};
/// let t = Tensor::rand(0.0f32, 1.0, (1, 1, 10, 10), &Device::Cpu)?;
/// let (kernel_size,stride,padding) = (2,2,2);
/// let t = t.max_pool2d_with_stride_padding(kernel_size,stride,padding)?;
/// assert_eq!(t.shape().dims(),[1,1,7,7]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn max_pool2d_with_stride_padding<T: crate::ToUsize2>(
&self,
kernel_size: T,
stride: T,
padding: usize,
) -> Result<Self> {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
if h < kernel_size.0 || w < kernel_size.1 {
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")

// Calculate effective input size with padding
let h_padded = h + 2 * padding;
let w_padded = w + 2 * padding;

if h_padded < kernel_size.0 || w_padded < kernel_size.1 {
bail!("kernel-size {kernel_size:?} is larger than the padded input size {h_padded},{w_padded}")
}

// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
let h_out = (h_padded - kernel_size.0) / stride.0 + 1;
let w_out = (w_padded - kernel_size.1) / stride.1 + 1;

let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
arg,
kernel_size,
stride,
padding,
});

let storage = self
.storage()
.max_pool2d(self.layout(), kernel_size, stride)?;
.max_pool2d(self.layout(), kernel_size, stride, padding)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}

Expand Down
Loading