Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion candle-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub trait BackendStorage: Sized {
) -> Result<Self>;

fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self>;
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result<Self>;
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self>;
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

Expand Down
7 changes: 6 additions & 1 deletion candle-core/src/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ impl Tensor {
arg,
kernel_size,
stride,
padding,
} => {
if kernel_size != stride {
crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}")
Expand All @@ -369,7 +370,11 @@ impl Tensor {
// we scale the gradient for this case).
let node_upsampled = node.upsample_nearest2d(h, w)?;
let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?;
let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?;
let avg = if *padding == 0 {
mask.avg_pool2d_with_stride(*kernel_size, *stride)?
} else {
mask.max_pool2d_with_stride_padding(*kernel_size, *stride, *padding)?
};
let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad_arg)?;
Expand Down
44 changes: 34 additions & 10 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,20 +342,27 @@ impl Map1 for AvgPool2D {
}
}

struct MaxPool2D((usize, usize), (usize, usize));
struct MaxPool2D((usize, usize), (usize, usize), usize);

impl Map1 for MaxPool2D {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html
let (k_h, k_w) = self.0;
let (s_h, s_w) = self.1;
let padding = self.2;
let (b_sz, c, h, w) = layout.shape().dims4()?;
let stride = layout.stride();
let (stride_h, stride_w) = (stride[2], stride[3]);
let h_out = (h - k_h) / s_h + 1;
let w_out = (w - k_w) / s_w + 1;

// Calculate output dimensions with padding
let h_padded = h + 2 * padding;
let w_padded = w + 2 * padding;
let h_out = (h_padded - k_h) / s_h + 1;
let w_out = (w_padded - k_w) / s_w + 1;

let src_index = layout.start_offset();
let mut dst = vec![T::zero(); b_sz * c * h_out * w_out];

for b_idx in 0..b_sz {
let dst = &mut dst[b_idx * c * h_out * w_out..];
let src_index = src_index + b_idx * stride[0];
Expand All @@ -364,17 +371,33 @@ impl Map1 for MaxPool2D {
let src_index = src_index + c_idx * stride[1];
for h_idx in 0..h_out {
for w_idx in 0..w_out {
let mut largest =
src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w];
let mut largest = T::zero();
let mut found_valid = false;

for m in 0..k_h {
for n in 0..k_w {
let m = s_h * h_idx + m;
let n = s_w * w_idx + n;
if largest < src[src_index + m * stride_h + n * stride_w] {
largest = src[src_index + m * stride_h + n * stride_w]
let src_h = s_h * h_idx + m;
let src_w = s_w * w_idx + n;

// Check if we're within the original (unpadded) bounds
if src_h >= padding && src_h < h + padding &&
src_w >= padding && src_w < w + padding {
let actual_h = src_h - padding;
let actual_w = src_w - padding;
let val = src[src_index + actual_h * stride_h + actual_w * stride_w];
if !found_valid || largest < val {
largest = val;
found_valid = true;
}
}
}
}

// If no valid values were found (all padding), use zero
if !found_valid {
largest = T::zero();
}

dst[h_idx * w_out + w_idx] = largest;
}
}
Expand Down Expand Up @@ -2066,8 +2089,9 @@ impl BackendStorage for CpuStorage {
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: usize,
) -> Result<Self> {
MaxPool2D(kernel_size, stride).map(self, layout)
MaxPool2D(kernel_size, stride, padding).map(self, layout)
}

fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result<Self> {
Expand Down
14 changes: 11 additions & 3 deletions candle-core/src/cuda_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ struct Pool2D {
h_k: usize,
w_stride: usize,
h_stride: usize,
padding: usize,
op: PoolOp,
}

Expand All @@ -905,8 +906,13 @@ impl Map1 for Pool2D {
crate::bail!("unexpected input shape for pool {dims:?}")
};
let el = shape.elem_count();
let out_w = (dims[2] - self.w_k) / self.w_stride + 1;
let out_h = (dims[3] - self.h_k) / self.h_stride + 1;

// Calculate output dimensions with padding
let h_padded = dims[2] + 2 * self.padding;
let w_padded = dims[3] + 2 * self.padding;
let out_w = (h_padded - self.w_k) / self.w_stride + 1;
let out_h = (w_padded - self.h_k) / self.h_stride + 1;

let dst_el = out_w * out_h * dims[0] * dims[1];
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let kname = match self.op {
Expand All @@ -923,6 +929,7 @@ impl Map1 for Pool2D {
barg!(builder, self.h_k);
barg!(builder, self.w_stride);
barg!(builder, self.h_stride);
barg!(builder, self.padding);
builder.arg(&ds);
builder.arg(inp);
builder.arg(&out);
Expand Down Expand Up @@ -1889,13 +1896,14 @@ impl BackendStorage for CudaStorage {
Ok(Self { slice, device })
}

fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result<Self> {
fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize), padding: usize) -> Result<Self> {
let device = self.device().clone();
let slice = Pool2D {
w_k: k.0,
h_k: k.1,
w_stride: stride.0,
h_stride: stride.1,
padding,
op: PoolOp::Max,
}
.map(&self.slice, &device, l)?;
Expand Down
2 changes: 1 addition & 1 deletion candle-core/src/dummy_cuda_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ impl crate::backend::BackendStorage for CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}

fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

Expand Down
2 changes: 1 addition & 1 deletion candle-core/src/dummy_metal_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ impl crate::backend::BackendStorage for MetalStorage {
Err(Error::NotCompiledWithMetalSupport)
}

fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}

Expand Down
13 changes: 11 additions & 2 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,8 @@
let dst_el = out_w * out_h * b_size * channels;
let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?;
let command_buffers = self.device.command_buffer()?;
candle_metal_kernels::call_pool2d(

Check failure on line 1302 in candle-core/src/metal_backend/mod.rs

View workflow job for this annotation

GitHub Actions / Rustfmt

mismatched closing delimiter: `}`
candle_metal_kernels::call_pool2d_with_padding(
&self.device.device,
&command_buffers,
&self.device.kernels,
Expand All @@ -1312,6 +1313,7 @@
h_k,
w_stride,
h_stride,
padding,
&self.buffer,
&buffer,
)
Expand All @@ -1324,6 +1326,7 @@
inp_l: &Layout,
(w_k, h_k): (usize, usize),
(w_stride, h_stride): (usize, usize),
padding: usize,
) -> Result<Self> {
let shape = inp_l.shape();
let (b_size, channels, width, height) = shape.dims4()?;
Expand All @@ -1336,8 +1339,13 @@
DType::U32 => "max_pool2d_u32",
dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"),
};
let out_w = (width - w_k) / w_stride + 1;
let out_h = (height - h_k) / h_stride + 1;

// Calculate output dimensions with padding
let width_padded = width + 2 * padding;
let height_padded = height + 2 * padding;
let out_w = (width_padded - w_k) / w_stride + 1;
let out_h = (height_padded - h_k) / h_stride + 1;

let dst_el = out_w * out_h * b_size * channels;
let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?;
let command_buffers = self.device.command_buffer()?;
Expand All @@ -1354,6 +1362,7 @@
h_k,
w_stride,
h_stride,
padding,
&self.buffer,
&buffer,
)
Expand Down
1 change: 1 addition & 0 deletions candle-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ pub enum Op {
arg: Tensor,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: usize,
},

UpsampleNearest1D {
Expand Down
7 changes: 4 additions & 3 deletions candle-core/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -519,18 +519,19 @@ impl Storage {
layout: &Layout,
kernel_size: (usize, usize),
stride: (usize, usize),
padding: usize,
) -> Result<Self> {
match self {
Storage::Cpu(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?;
Ok(Self::Cpu(storage))
}
Self::Cuda(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?;
Ok(Self::Metal(storage))
}
}
Expand Down
53 changes: 52 additions & 1 deletion candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1228,10 +1228,61 @@
arg,
kernel_size,
stride,
padding: 0,
});
let storage = self
.storage()
.max_pool2d(self.layout(), kernel_size, stride)?;
.max_pool2d(self.layout(), kernel_size, stride, 0)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}

/// Same as `max_pool2d_with_stride` but with padding support.
///
/// # Arguments
///
/// * `kernel_size` - The size of the pooling window
/// * `stride` - The stride of the pooling operation, controlling how far the window
/// * `padding` - The amount of zero-padding to add to both sides of the height and width
/// # Examples
/// ```rust
/// use candle_core::{Tensor,Device,Shape};
/// let t = Tensor::rand(0.0f32, 1.0, (1, 1, 10, 10), &device)?;
/// let (kernel_size,stride,padding) = (2,2,2)

Check failure on line 1250 in candle-core/src/tensor.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

the `?` operator can only be used in a function that returns `Result` or `Option` (or another type that implements `FromResidual`)

Check failure on line 1250 in candle-core/src/tensor.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

cannot find value `device` in this scope
/// let t = t.max_pool2d_with_stride_padding(kernel_size,stride,padding)?;

Check failure on line 1251 in candle-core/src/tensor.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

expected `;`, found keyword `let`
/// assert_eq!(t.shape().dims(),[1,1,7,7]);

Check failure on line 1252 in candle-core/src/tensor.rs

View workflow job for this annotation

GitHub Actions / Test Suite (macOS-latest, stable)

the `?` operator can only be used in a function that returns `Result` or `Option` (or another type that implements `FromResidual`)
/// ```
pub fn max_pool2d_with_stride_padding<T: crate::ToUsize2>(
&self,
kernel_size: T,
stride: T,
padding: usize,
) -> Result<Self> {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;

// Calculate effective input size with padding
let h_padded = h + 2 * padding;
let w_padded = w + 2 * padding;

if h_padded < kernel_size.0 || w_padded < kernel_size.1 {
bail!("kernel-size {kernel_size:?} is larger than the padded input size {h_padded},{w_padded}")
}

// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h_padded - kernel_size.0) / stride.0 + 1;
let w_out = (w_padded - kernel_size.1) / stride.1 + 1;

let op = BackpropOp::new1(self, |arg| Op::MaxPool2D {
arg,
kernel_size,
stride,
padding,
});

let storage = self
.storage()
.max_pool2d(self.layout(), kernel_size, stride, padding)?;
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
}

Expand Down
14 changes: 9 additions & 5 deletions candle-kernels/src/conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ __device__ void max_pool2d(
const size_t h_k,
const size_t w_stride,
const size_t h_stride,
const size_t padding,
const size_t *info,
const T *src,
T *dst
Expand All @@ -459,8 +460,8 @@ __device__ void max_pool2d(
const size_t w_in = src_dims[2];
const size_t h_in = src_dims[3];

const size_t w_out = (w_in - w_k) / w_stride + 1;
const size_t h_out = (h_in - h_k) / h_stride + 1;
const size_t w_out = (w_in + 2 * padding - w_k) / w_stride + 1;
const size_t h_out = (h_in + 2 * padding - h_k) / h_stride + 1;
if (dst_i >= src_dims[0] * c * w_out * h_out) {
return;
}
Expand All @@ -476,14 +477,16 @@ __device__ void max_pool2d(
bool set = false;
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
size_t src_w = w_stride * dst_w + w_offset;
if (src_w >= w_in) {
if (src_w < padding || src_w >= w_in + padding) {
continue;
}
src_w -= padding;
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
size_t src_h = h_stride * dst_h + h_offset;
if (src_h >= h_in) {
if (src_h < padding || src_h >= h_in + padding) {
continue;
}
src_h -= padding;
const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
if (set) {
d = maxg(d, src[src_idx]);
Expand Down Expand Up @@ -671,11 +674,12 @@ extern "C" __global__ void FN_NAME( \
const size_t h_k, \
const size_t w_stride, \
const size_t h_stride, \
const size_t padding, \
const size_t *info, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
max_pool2d<TYPENAME>(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \
max_pool2d<TYPENAME>(src_numel, w_k, h_k, w_stride, h_stride, padding, info, src, dst); \
} \

#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
Expand Down
Loading
Loading