From 928bc683dc14b8471e6f8b891a14e21b14e089f7 Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Wed, 22 Oct 2025 13:03:36 +0800 Subject: [PATCH 1/5] add padding support for max pool --- candle-core/src/backend.rs | 2 +- candle-core/src/backprop.rs | 7 +++- candle-core/src/cpu_backend/mod.rs | 44 ++++++++++++++++++++------ candle-core/src/cuda_backend/mod.rs | 14 ++++++-- candle-core/src/dummy_cuda_backend.rs | 2 +- candle-core/src/dummy_metal_backend.rs | 2 +- candle-core/src/metal_backend/mod.rs | 11 +++++-- candle-core/src/op.rs | 1 + candle-core/src/storage.rs | 7 ++-- candle-core/src/tensor.rs | 42 +++++++++++++++++++++++- 10 files changed, 109 insertions(+), 23 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index a85f8d36d2..6ddf18dec6 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -66,7 +66,7 @@ pub trait BackendStorage: Sized { ) -> Result; fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result; fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index a14306657b..d8b04cb4f7 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -358,6 +358,7 @@ impl Tensor { arg, kernel_size, stride, + padding, } => { if kernel_size != stride { crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}") @@ -369,7 +370,11 @@ impl Tensor { // we scale the gradient for this case). let node_upsampled = node.upsample_nearest2d(h, w)?; let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?; - let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?; + let avg = if *padding == 0 { + mask.avg_pool2d_with_stride(*kernel_size, *stride)? + } else { + mask.max_pool2d_with_stride_padding(*kernel_size, *stride, *padding)? + }; let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 06edfe8d14..79b3cd2783 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -342,20 +342,27 @@ impl Map1 for AvgPool2D { } } -struct MaxPool2D((usize, usize), (usize, usize)); +struct MaxPool2D((usize, usize), (usize, usize), usize); impl Map1 for MaxPool2D { fn f(&self, src: &[T], layout: &Layout) -> Result> { // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html let (k_h, k_w) = self.0; let (s_h, s_w) = self.1; + let padding = self.2; let (b_sz, c, h, w) = layout.shape().dims4()?; let stride = layout.stride(); let (stride_h, stride_w) = (stride[2], stride[3]); - let h_out = (h - k_h) / s_h + 1; - let w_out = (w - k_w) / s_w + 1; + + // Calculate output dimensions with padding + let h_padded = h + 2 * padding; + let w_padded = w + 2 * padding; + let h_out = (h_padded - k_h) / s_h + 1; + let w_out = (w_padded - k_w) / s_w + 1; + let src_index = layout.start_offset(); let mut dst = vec![T::zero(); b_sz * c * h_out * w_out]; + for b_idx in 0..b_sz { let dst = &mut dst[b_idx * c * h_out * w_out..]; let src_index = src_index + b_idx * stride[0]; @@ -364,17 +371,33 @@ impl Map1 for MaxPool2D { let src_index = src_index + c_idx * stride[1]; for h_idx in 0..h_out { for w_idx in 0..w_out { - let mut largest = - src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w]; + let mut largest = T::zero(); + let mut found_valid = false; + for m in 0..k_h { for n in 0..k_w { - let m = s_h * h_idx + m; - let n = s_w * w_idx + n; - if largest < src[src_index + m * stride_h + n * stride_w] { - largest = src[src_index + m * stride_h + n * stride_w] + let src_h = s_h * h_idx + m; + let src_w = s_w * w_idx + n; + + // Check if we're within the original (unpadded) bounds + if src_h >= padding && src_h < h + padding && + src_w >= padding && src_w < w + padding { + let actual_h = src_h - padding; + let actual_w = src_w - padding; + let val = src[src_index + actual_h * stride_h + actual_w * stride_w]; + if !found_valid || largest < val { + largest = val; + found_valid = true; + } } } } + + // If no valid values were found (all padding), use zero + if !found_valid { + largest = T::zero(); + } + dst[h_idx * w_out + w_idx] = largest; } } @@ -2066,8 +2089,9 @@ impl BackendStorage for CpuStorage { layout: &Layout, kernel_size: (usize, usize), stride: (usize, usize), + padding: usize, ) -> Result { - MaxPool2D(kernel_size, stride).map(self, layout) + MaxPool2D(kernel_size, stride, padding).map(self, layout) } fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result { diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index b1f166a6ac..726f092d6a 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -885,6 +885,7 @@ struct Pool2D { h_k: usize, w_stride: usize, h_stride: usize, + padding: usize, op: PoolOp, } @@ -905,8 +906,13 @@ impl Map1 for Pool2D { crate::bail!("unexpected input shape for pool {dims:?}") }; let el = shape.elem_count(); - let out_w = (dims[2] - self.w_k) / self.w_stride + 1; - let out_h = (dims[3] - self.h_k) / self.h_stride + 1; + + // Calculate output dimensions with padding + let h_padded = dims[2] + 2 * self.padding; + let w_padded = dims[3] + 2 * self.padding; + let out_w = (h_padded - self.w_k) / self.w_stride + 1; + let out_h = (w_padded - self.h_k) / self.h_stride + 1; + let dst_el = out_w * out_h * dims[0] * dims[1]; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let kname = match self.op { @@ -923,6 +929,7 @@ impl Map1 for Pool2D { barg!(builder, self.h_k); barg!(builder, self.w_stride); barg!(builder, self.h_stride); + barg!(builder, self.padding); builder.arg(&ds); builder.arg(inp); builder.arg(&out); @@ -1889,13 +1896,14 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result { + fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize), padding: usize) -> Result { let device = self.device().clone(); let slice = Pool2D { w_k: k.0, h_k: k.1, w_stride: stride.0, h_stride: stride.1, + padding, op: PoolOp::Max, } .map(&self.slice, &device, l)?; diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 329099354b..c943d52268 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -195,7 +195,7 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index de43f243fb..624af65411 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -199,7 +199,7 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 3f47f6a4d2..6f6be9add8 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1324,6 +1324,7 @@ impl BackendStorage for MetalStorage { inp_l: &Layout, (w_k, h_k): (usize, usize), (w_stride, h_stride): (usize, usize), + padding: usize, ) -> Result { let shape = inp_l.shape(); let (b_size, channels, width, height) = shape.dims4()?; @@ -1336,8 +1337,13 @@ impl BackendStorage for MetalStorage { DType::U32 => "max_pool2d_u32", dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"), }; - let out_w = (width - w_k) / w_stride + 1; - let out_h = (height - h_k) / h_stride + 1; + + // Calculate output dimensions with padding + let width_padded = width + 2 * padding; + let height_padded = height + 2 * padding; + let out_w = (width_padded - w_k) / w_stride + 1; + let out_h = (height_padded - h_k) / h_stride + 1; + let dst_el = out_w * out_h * b_size * channels; let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?; let command_buffers = self.device.command_buffer()?; @@ -1354,6 +1360,7 @@ impl BackendStorage for MetalStorage { h_k, w_stride, h_stride, + padding, &self.buffer, &buffer, ) diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 8e24368ff1..226526a1b1 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -135,6 +135,7 @@ pub enum Op { arg: Tensor, kernel_size: (usize, usize), stride: (usize, usize), + padding: usize, }, UpsampleNearest1D { diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 32af582473..7bbc8b4180 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -519,18 +519,19 @@ impl Storage { layout: &Layout, kernel_size: (usize, usize), stride: (usize, usize), + padding: usize, ) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.max_pool2d(layout, kernel_size, stride)?; + let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.max_pool2d(layout, kernel_size, stride)?; + let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?; Ok(Self::Cuda(storage)) } Self::Metal(storage) => { - let storage = storage.max_pool2d(layout, kernel_size, stride)?; + let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?; Ok(Self::Metal(storage)) } } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index d71630212d..669e71ad4c 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1228,10 +1228,50 @@ impl Tensor { arg, kernel_size, stride, + padding: 0, }); let storage = self .storage() - .max_pool2d(self.layout(), kernel_size, stride)?; + .max_pool2d(self.layout(), kernel_size, stride, 0)?; + Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + } + + /// Same as `max_pool2d_with_stride` but with padding support. + /// + /// The padding parameter adds padding around the input tensor before applying max pooling. + /// The padding is applied symmetrically (same amount on all sides). + pub fn max_pool2d_with_stride_padding( + &self, + kernel_size: T, + stride: T, + padding: usize, + ) -> Result { + let kernel_size = kernel_size.to_usize2(); + let stride = stride.to_usize2(); + let (n, c, h, w) = self.dims4()?; + + // Calculate effective input size with padding + let h_padded = h + 2 * padding; + let w_padded = w + 2 * padding; + + if h_padded < kernel_size.0 || w_padded < kernel_size.1 { + bail!("kernel-size {kernel_size:?} is larger than the padded input size {h_padded},{w_padded}") + } + + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d + let h_out = (h_padded - kernel_size.0) / stride.0 + 1; + let w_out = (w_padded - kernel_size.1) / stride.1 + 1; + + let op = BackpropOp::new1(self, |arg| Op::MaxPool2D { + arg, + kernel_size, + stride, + padding, + }); + + let storage = self + .storage() + .max_pool2d(self.layout(), kernel_size, stride, padding)?; Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) } From 88ffe3537bc69f79d0219c27dc360f101b5ea2bd Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Wed, 22 Oct 2025 13:34:50 +0800 Subject: [PATCH 2/5] add padding support for max pool --- candle-core/src/tensor.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 669e71ad4c..26f2abbc36 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1238,8 +1238,19 @@ impl Tensor { /// Same as `max_pool2d_with_stride` but with padding support. /// - /// The padding parameter adds padding around the input tensor before applying max pooling. - /// The padding is applied symmetrically (same amount on all sides). + /// # Arguments + /// + /// * `kernel_size` - The size of the pooling window + /// * `stride` - The stride of the pooling operation, controlling how far the window + /// * `padding` - The amount of zero-padding to add to both sides of the height and width + /// # Examples + /// ```rust + /// use candle_core::{Tensor,Device,Shape}; + /// let t = Tensor::rand(0.0f32, 1.0, (1, 1, 10, 10), &device)?; + /// let (kernel_size,stride,padding) = (2,2,2) + /// let t = t.max_pool2d_with_stride_padding(kernel_size,stride,padding)?; + /// assert_eq!(t.shape().dims(),[1,1,7,7]); + /// ``` pub fn max_pool2d_with_stride_padding( &self, kernel_size: T, From 3b65fa8c610c6203273daf4f116ae3ca94e48ef9 Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Wed, 22 Oct 2025 21:00:17 +0800 Subject: [PATCH 3/5] add metal and cuda kernel --- candle-core/src/metal_backend/mod.rs | 2 ++ candle-kernels/src/conv.cu | 14 +++++--- .../src/kernels/convolution.rs | 34 +++++++++++++++++++ candle-metal-kernels/src/metal_src/conv.metal | 14 +++++--- 4 files changed, 54 insertions(+), 10 deletions(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 6f6be9add8..557b039ab2 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1300,6 +1300,7 @@ impl BackendStorage for MetalStorage { let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?; let command_buffers = self.device.command_buffer()?; candle_metal_kernels::call_pool2d( + candle_metal_kernels::call_pool2d_with_padding( &self.device.device, &command_buffers, &self.device.kernels, @@ -1312,6 +1313,7 @@ impl BackendStorage for MetalStorage { h_k, w_stride, h_stride, + padding, &self.buffer, &buffer, ) diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index 3f15e0ad2e..33f51c54cb 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -446,6 +446,7 @@ __device__ void max_pool2d( const size_t h_k, const size_t w_stride, const size_t h_stride, + const size_t padding, const size_t *info, const T *src, T *dst @@ -459,8 +460,8 @@ __device__ void max_pool2d( const size_t w_in = src_dims[2]; const size_t h_in = src_dims[3]; - const size_t w_out = (w_in - w_k) / w_stride + 1; - const size_t h_out = (h_in - h_k) / h_stride + 1; + const size_t w_out = (w_in + 2 * padding - w_k) / w_stride + 1; + const size_t h_out = (h_in + 2 * padding - h_k) / h_stride + 1; if (dst_i >= src_dims[0] * c * w_out * h_out) { return; } @@ -476,14 +477,16 @@ __device__ void max_pool2d( bool set = false; for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { size_t src_w = w_stride * dst_w + w_offset; - if (src_w >= w_in) { + if (src_w < padding || src_w >= w_in + padding) { continue; } + src_w -= padding; for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { size_t src_h = h_stride * dst_h + h_offset; - if (src_h >= h_in) { + if (src_h < padding || src_h >= h_in + padding) { continue; } + src_h -= padding; const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; if (set) { d = maxg(d, src[src_idx]); @@ -671,11 +674,12 @@ extern "C" __global__ void FN_NAME( \ const size_t h_k, \ const size_t w_stride, \ const size_t h_stride, \ + const size_t padding, \ const size_t *info, \ const TYPENAME *src, \ TYPENAME *dst \ ) { \ - max_pool2d(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \ + max_pool2d(src_numel, w_k, h_k, w_stride, h_stride, padding, info, src, dst); \ } \ #define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \ diff --git a/candle-metal-kernels/src/kernels/convolution.rs b/candle-metal-kernels/src/kernels/convolution.rs index 6b2e5fcf96..eef6671e71 100644 --- a/candle-metal-kernels/src/kernels/convolution.rs +++ b/candle-metal-kernels/src/kernels/convolution.rs @@ -167,6 +167,40 @@ pub fn call_pool2d( Ok(()) } +#[allow(clippy::too_many_arguments)] +pub fn call_pool2d_with_padding( + device: &Device, + ep: impl EncoderProvider, + kernels: &Kernels, + name: &'static str, + shape: &[usize], + strides: &[usize], + out_w: usize, + out_h: usize, + w_k: usize, + h_k: usize, + w_stride: usize, + h_stride: usize, + padding: usize, + input: &Buffer, + output: &Buffer, +) -> Result<(), MetalKernelError> { + let dst_el = out_w * out_h * shape[0] * shape[1]; + let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; + let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); + let encoder = ep.encoder(); + let encoder: &ComputeCommandEncoder = encoder.as_ref(); + encoder.set_compute_pipeline_state(&pipeline); + set_params!( + encoder, + (w_k, h_k, w_stride, h_stride, padding, shape, strides, input, output) + ); + encoder.use_resource(input, MTLResourceUsage::Read); + encoder.use_resource(output, MTLResourceUsage::Write); + encoder.dispatch_thread_groups(thread_group_count, thread_group_size); + Ok(()) +} + #[allow(clippy::too_many_arguments)] pub fn call_conv_transpose1d( device: &Device, diff --git a/candle-metal-kernels/src/metal_src/conv.metal b/candle-metal-kernels/src/metal_src/conv.metal index fbe19bb87f..d4e38a463d 100644 --- a/candle-metal-kernels/src/metal_src/conv.metal +++ b/candle-metal-kernels/src/metal_src/conv.metal @@ -332,6 +332,7 @@ METAL_FUNC void max_pool2d( constant size_t &h_k, constant size_t &w_stride, constant size_t &h_stride, + constant size_t &padding, constant size_t *src_dims, constant size_t *src_strides, device const T *src, @@ -342,8 +343,8 @@ METAL_FUNC void max_pool2d( const size_t w_in = src_dims[2]; const size_t h_in = src_dims[3]; - const size_t w_out = (w_in - w_k) / w_stride + 1; - const size_t h_out = (h_in - h_k) / h_stride + 1; + const size_t w_out = (w_in + 2 * padding - w_k) / w_stride + 1; + const size_t h_out = (h_in + 2 * padding - h_k) / h_stride + 1; if (tid >= src_dims[0] * c * w_out * h_out) { return; } @@ -358,14 +359,16 @@ METAL_FUNC void max_pool2d( bool set = false; for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { size_t src_w = w_stride * dst_w + w_offset; - if (src_w >= w_in){ + if (src_w < padding || src_w >= w_in + padding){ continue; } + src_w -= padding; for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { size_t src_h = h_stride * dst_h + h_offset; - if (src_h >= h_in) { + if (src_h < padding || src_h >= h_in + padding) { continue; } + src_h -= padding; const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; if (set) { d = MAX(d, src[src_idx]); @@ -385,13 +388,14 @@ kernel void FN_NAME( \ constant size_t &h_k, \ constant size_t &w_s, \ constant size_t &h_s, \ + constant size_t &padding, \ constant size_t *src_dims, \ constant size_t *src_s, \ device const TYPENAME *src, \ device TYPENAME *dst, \ uint tid [[ thread_position_in_grid ]] \ ) { \ - max_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ + max_pool2d(w_k, h_k, w_s, h_s, padding, src_dims, src_s, src, dst, tid); \ } \ From d706e7111338fdf6e07b3fad0be4c1b32e751efa Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Sun, 26 Oct 2025 14:18:17 +0800 Subject: [PATCH 4/5] 1.change the max_pool2d_with_stride to call the max_pool2d_with_stride_padding \n2. combine the new call_pool2d_with_padding into original max_pool2d function \n3. change the metal_backend --- candle-core/src/backend.rs | 8 ++++- candle-core/src/cpu_backend/mod.rs | 24 +++++++------ candle-core/src/cuda_backend/mod.rs | 12 +++++-- candle-core/src/dummy_cuda_backend.rs | 8 ++++- candle-core/src/dummy_metal_backend.rs | 8 ++++- candle-core/src/metal_backend/mod.rs | 6 ++-- candle-core/src/tensor.rs | 30 ++++------------ .../src/kernels/convolution.rs | 34 +------------------ 8 files changed, 55 insertions(+), 75 deletions(-) diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index 6ddf18dec6..95cd6eb19a 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -66,7 +66,13 @@ pub trait BackendStorage: Sized { ) -> Result; fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result; + fn max_pool2d( + &self, + _: &Layout, + _: (usize, usize), + _: (usize, usize), + _: usize, + ) -> Result; fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result; diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 79b3cd2783..070e6703aa 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -353,16 +353,16 @@ impl Map1 for MaxPool2D { let (b_sz, c, h, w) = layout.shape().dims4()?; let stride = layout.stride(); let (stride_h, stride_w) = (stride[2], stride[3]); - + // Calculate output dimensions with padding let h_padded = h + 2 * padding; let w_padded = w + 2 * padding; let h_out = (h_padded - k_h) / s_h + 1; let w_out = (w_padded - k_w) / s_w + 1; - + let src_index = layout.start_offset(); let mut dst = vec![T::zero(); b_sz * c * h_out * w_out]; - + for b_idx in 0..b_sz { let dst = &mut dst[b_idx * c * h_out * w_out..]; let src_index = src_index + b_idx * stride[0]; @@ -373,18 +373,22 @@ impl Map1 for MaxPool2D { for w_idx in 0..w_out { let mut largest = T::zero(); let mut found_valid = false; - + for m in 0..k_h { for n in 0..k_w { let src_h = s_h * h_idx + m; let src_w = s_w * w_idx + n; - + // Check if we're within the original (unpadded) bounds - if src_h >= padding && src_h < h + padding && - src_w >= padding && src_w < w + padding { + if src_h >= padding + && src_h < h + padding + && src_w >= padding + && src_w < w + padding + { let actual_h = src_h - padding; let actual_w = src_w - padding; - let val = src[src_index + actual_h * stride_h + actual_w * stride_w]; + let val = + src[src_index + actual_h * stride_h + actual_w * stride_w]; if !found_valid || largest < val { largest = val; found_valid = true; @@ -392,12 +396,12 @@ impl Map1 for MaxPool2D { } } } - + // If no valid values were found (all padding), use zero if !found_valid { largest = T::zero(); } - + dst[h_idx * w_out + w_idx] = largest; } } diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 726f092d6a..1ffb24b4fe 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -906,13 +906,13 @@ impl Map1 for Pool2D { crate::bail!("unexpected input shape for pool {dims:?}") }; let el = shape.elem_count(); - + // Calculate output dimensions with padding let h_padded = dims[2] + 2 * self.padding; let w_padded = dims[3] + 2 * self.padding; let out_w = (h_padded - self.w_k) / self.w_stride + 1; let out_h = (w_padded - self.h_k) / self.h_stride + 1; - + let dst_el = out_w * out_h * dims[0] * dims[1]; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let kname = match self.op { @@ -1896,7 +1896,13 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize), padding: usize) -> Result { + fn max_pool2d( + &self, + l: &Layout, + k: (usize, usize), + stride: (usize, usize), + padding: usize, + ) -> Result { let device = self.device().clone(); let slice = Pool2D { w_k: k.0, diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index c943d52268..f312b8b772 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -195,7 +195,13 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result { + fn max_pool2d( + &self, + _: &Layout, + _: (usize, usize), + _: (usize, usize), + _: usize, + ) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index 624af65411..cb46ed7bb9 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -199,7 +199,13 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize), _: usize) -> Result { + fn max_pool2d( + &self, + _: &Layout, + _: (usize, usize), + _: (usize, usize), + _: usize, + ) -> Result { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 557b039ab2..0b7622f358 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1297,10 +1297,10 @@ impl BackendStorage for MetalStorage { let out_w = (width - w_k) / w_stride + 1; let out_h = (height - h_k) / h_stride + 1; let dst_el = out_w * out_h * b_size * channels; + let padding = 0; let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?; let command_buffers = self.device.command_buffer()?; candle_metal_kernels::call_pool2d( - candle_metal_kernels::call_pool2d_with_padding( &self.device.device, &command_buffers, &self.device.kernels, @@ -1339,13 +1339,13 @@ impl BackendStorage for MetalStorage { DType::U32 => "max_pool2d_u32", dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"), }; - + // Calculate output dimensions with padding let width_padded = width + 2 * padding; let height_padded = height + 2 * padding; let out_w = (width_padded - w_k) / w_stride + 1; let out_h = (height_padded - h_k) / h_stride + 1; - + let dst_el = out_w * out_h * b_size * channels; let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?; let command_buffers = self.device.command_buffer()?; diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 26f2abbc36..611819759a 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1215,25 +1215,9 @@ impl Tensor { kernel_size: T, stride: T, ) -> Result { - let kernel_size = kernel_size.to_usize2(); + let sz = kernel_size.to_usize2(); let stride = stride.to_usize2(); - let (n, c, h, w) = self.dims4()?; - if h < kernel_size.0 || w < kernel_size.1 { - bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}") - } - // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d - let h_out = (h - kernel_size.0) / stride.0 + 1; - let w_out = (w - kernel_size.1) / stride.1 + 1; - let op = BackpropOp::new1(self, |arg| Op::MaxPool2D { - arg, - kernel_size, - stride, - padding: 0, - }); - let storage = self - .storage() - .max_pool2d(self.layout(), kernel_size, stride, 0)?; - Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) + self.max_pool2d_with_stride_padding(sz, stride, 0) } /// Same as `max_pool2d_with_stride` but with padding support. @@ -1260,26 +1244,26 @@ impl Tensor { let kernel_size = kernel_size.to_usize2(); let stride = stride.to_usize2(); let (n, c, h, w) = self.dims4()?; - + // Calculate effective input size with padding let h_padded = h + 2 * padding; let w_padded = w + 2 * padding; - + if h_padded < kernel_size.0 || w_padded < kernel_size.1 { bail!("kernel-size {kernel_size:?} is larger than the padded input size {h_padded},{w_padded}") } - + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d let h_out = (h_padded - kernel_size.0) / stride.0 + 1; let w_out = (w_padded - kernel_size.1) / stride.1 + 1; - + let op = BackpropOp::new1(self, |arg| Op::MaxPool2D { arg, kernel_size, stride, padding, }); - + let storage = self .storage() .max_pool2d(self.layout(), kernel_size, stride, padding)?; diff --git a/candle-metal-kernels/src/kernels/convolution.rs b/candle-metal-kernels/src/kernels/convolution.rs index eef6671e71..3cc7680fae 100644 --- a/candle-metal-kernels/src/kernels/convolution.rs +++ b/candle-metal-kernels/src/kernels/convolution.rs @@ -136,39 +136,6 @@ pub fn call_upsample_nearest_2d( #[allow(clippy::too_many_arguments)] pub fn call_pool2d( - device: &Device, - ep: impl EncoderProvider, - kernels: &Kernels, - name: &'static str, - shape: &[usize], - strides: &[usize], - out_w: usize, - out_h: usize, - w_k: usize, - h_k: usize, - w_stride: usize, - h_stride: usize, - input: &Buffer, - output: &Buffer, -) -> Result<(), MetalKernelError> { - let dst_el = out_w * out_h * shape[0] * shape[1]; - let pipeline = kernels.load_pipeline(device, Source::Conv, name)?; - let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); - let encoder = ep.encoder(); - let encoder: &ComputeCommandEncoder = encoder.as_ref(); - encoder.set_compute_pipeline_state(&pipeline); - set_params!( - encoder, - (w_k, h_k, w_stride, h_stride, shape, strides, input, output) - ); - encoder.use_resource(input, MTLResourceUsage::Read); - encoder.use_resource(output, MTLResourceUsage::Write); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -pub fn call_pool2d_with_padding( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, @@ -201,6 +168,7 @@ pub fn call_pool2d_with_padding( Ok(()) } + #[allow(clippy::too_many_arguments)] pub fn call_conv_transpose1d( device: &Device, From 4369f919dd8ad782a76c1a6a903acee6b3005802 Mon Sep 17 00:00:00 2001 From: Donjuanplatinum Date: Sun, 26 Oct 2025 14:40:49 +0800 Subject: [PATCH 5/5] fix the example for max_pool2d_with_stride_padding --- candle-core/src/tensor.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 611819759a..ce64da2dde 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1230,10 +1230,11 @@ impl Tensor { /// # Examples /// ```rust /// use candle_core::{Tensor,Device,Shape}; - /// let t = Tensor::rand(0.0f32, 1.0, (1, 1, 10, 10), &device)?; - /// let (kernel_size,stride,padding) = (2,2,2) + /// let t = Tensor::rand(0.0f32, 1.0, (1, 1, 10, 10), &Device::Cpu)?; + /// let (kernel_size,stride,padding) = (2,2,2); /// let t = t.max_pool2d_with_stride_padding(kernel_size,stride,padding)?; /// assert_eq!(t.shape().dims(),[1,1,7,7]); + /// # Ok::<(), candle_core::Error>(()) /// ``` pub fn max_pool2d_with_stride_padding( &self,