diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index a85f8d36d2..95cd6eb19a 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -66,7 +66,13 @@ pub trait BackendStorage: Sized { ) -> Result; fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result; + fn max_pool2d( + &self, + _: &Layout, + _: (usize, usize), + _: (usize, usize), + _: usize, + ) -> Result; fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result; fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index a14306657b..d8b04cb4f7 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -358,6 +358,7 @@ impl Tensor { arg, kernel_size, stride, + padding, } => { if kernel_size != stride { crate::bail!("backward not supported for maxpool2d if ksize {kernel_size:?} != stride {stride:?}") @@ -369,7 +370,11 @@ impl Tensor { // we scale the gradient for this case). let node_upsampled = node.upsample_nearest2d(h, w)?; let mask = arg.eq(&node_upsampled)?.to_dtype(arg.dtype())?; - let avg = mask.avg_pool2d_with_stride(*kernel_size, *stride)?; + let avg = if *padding == 0 { + mask.avg_pool2d_with_stride(*kernel_size, *stride)? + } else { + mask.max_pool2d_with_stride_padding(*kernel_size, *stride, *padding)? + }; let grad_arg = ((grad * avg)?.upsample_nearest2d(h, w)? * mask)?; let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad_arg)?; diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 8d8219ec9d..76a161a54d 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -343,20 +343,27 @@ impl Map1 for AvgPool2D { } } -struct MaxPool2D((usize, usize), (usize, usize)); +struct MaxPool2D((usize, usize), (usize, usize), usize); impl Map1 for MaxPool2D { fn f(&self, src: &[T], layout: &Layout) -> Result> { // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html let (k_h, k_w) = self.0; let (s_h, s_w) = self.1; + let padding = self.2; let (b_sz, c, h, w) = layout.shape().dims4()?; let stride = layout.stride(); let (stride_h, stride_w) = (stride[2], stride[3]); - let h_out = (h - k_h) / s_h + 1; - let w_out = (w - k_w) / s_w + 1; + + // Calculate output dimensions with padding + let h_padded = h + 2 * padding; + let w_padded = w + 2 * padding; + let h_out = (h_padded - k_h) / s_h + 1; + let w_out = (w_padded - k_w) / s_w + 1; + let src_index = layout.start_offset(); let mut dst = vec![T::zero(); b_sz * c * h_out * w_out]; + for b_idx in 0..b_sz { let dst = &mut dst[b_idx * c * h_out * w_out..]; let src_index = src_index + b_idx * stride[0]; @@ -365,17 +372,37 @@ impl Map1 for MaxPool2D { let src_index = src_index + c_idx * stride[1]; for h_idx in 0..h_out { for w_idx in 0..w_out { - let mut largest = - src[src_index + s_h * h_idx * stride_h + s_w * w_idx * stride_w]; + let mut largest = T::zero(); + let mut found_valid = false; + for m in 0..k_h { for n in 0..k_w { - let m = s_h * h_idx + m; - let n = s_w * w_idx + n; - if largest < src[src_index + m * stride_h + n * stride_w] { - largest = src[src_index + m * stride_h + n * stride_w] + let src_h = s_h * h_idx + m; + let src_w = s_w * w_idx + n; + + // Check if we're within the original (unpadded) bounds + if src_h >= padding + && src_h < h + padding + && src_w >= padding + && src_w < w + padding + { + let actual_h = src_h - padding; + let actual_w = src_w - padding; + let val = + src[src_index + actual_h * stride_h + actual_w * stride_w]; + if !found_valid || largest < val { + largest = val; + found_valid = true; + } } } } + + // If no valid values were found (all padding), use zero + if !found_valid { + largest = T::zero(); + } + dst[h_idx * w_out + w_idx] = largest; } } @@ -1979,8 +2006,9 @@ impl BackendStorage for CpuStorage { layout: &Layout, kernel_size: (usize, usize), stride: (usize, usize), + padding: usize, ) -> Result { - MaxPool2D(kernel_size, stride).map(self, layout) + MaxPool2D(kernel_size, stride, padding).map(self, layout) } fn upsample_nearest1d(&self, layout: &Layout, sz: usize) -> Result { diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index b1f166a6ac..1ffb24b4fe 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -885,6 +885,7 @@ struct Pool2D { h_k: usize, w_stride: usize, h_stride: usize, + padding: usize, op: PoolOp, } @@ -905,8 +906,13 @@ impl Map1 for Pool2D { crate::bail!("unexpected input shape for pool {dims:?}") }; let el = shape.elem_count(); - let out_w = (dims[2] - self.w_k) / self.w_stride + 1; - let out_h = (dims[3] - self.h_k) / self.h_stride + 1; + + // Calculate output dimensions with padding + let h_padded = dims[2] + 2 * self.padding; + let w_padded = dims[3] + 2 * self.padding; + let out_w = (h_padded - self.w_k) / self.w_stride + 1; + let out_h = (w_padded - self.h_k) / self.h_stride + 1; + let dst_el = out_w * out_h * dims[0] * dims[1]; let cfg = LaunchConfig::for_num_elems(dst_el as u32); let kname = match self.op { @@ -923,6 +929,7 @@ impl Map1 for Pool2D { barg!(builder, self.h_k); barg!(builder, self.w_stride); barg!(builder, self.h_stride); + barg!(builder, self.padding); builder.arg(&ds); builder.arg(inp); builder.arg(&out); @@ -1889,13 +1896,20 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn max_pool2d(&self, l: &Layout, k: (usize, usize), stride: (usize, usize)) -> Result { + fn max_pool2d( + &self, + l: &Layout, + k: (usize, usize), + stride: (usize, usize), + padding: usize, + ) -> Result { let device = self.device().clone(); let slice = Pool2D { w_k: k.0, h_k: k.1, w_stride: stride.0, h_stride: stride.1, + padding, op: PoolOp::Max, } .map(&self.slice, &device, l)?; diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 329099354b..f312b8b772 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -195,7 +195,13 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + fn max_pool2d( + &self, + _: &Layout, + _: (usize, usize), + _: (usize, usize), + _: usize, + ) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/dummy_metal_backend.rs b/candle-core/src/dummy_metal_backend.rs index de43f243fb..cb46ed7bb9 100644 --- a/candle-core/src/dummy_metal_backend.rs +++ b/candle-core/src/dummy_metal_backend.rs @@ -199,7 +199,13 @@ impl crate::backend::BackendStorage for MetalStorage { Err(Error::NotCompiledWithMetalSupport) } - fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result { + fn max_pool2d( + &self, + _: &Layout, + _: (usize, usize), + _: (usize, usize), + _: usize, + ) -> Result { Err(Error::NotCompiledWithMetalSupport) } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 3f47f6a4d2..0b7622f358 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1297,6 +1297,7 @@ impl BackendStorage for MetalStorage { let out_w = (width - w_k) / w_stride + 1; let out_h = (height - h_k) / h_stride + 1; let dst_el = out_w * out_h * b_size * channels; + let padding = 0; let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?; let command_buffers = self.device.command_buffer()?; candle_metal_kernels::call_pool2d( @@ -1312,6 +1313,7 @@ impl BackendStorage for MetalStorage { h_k, w_stride, h_stride, + padding, &self.buffer, &buffer, ) @@ -1324,6 +1326,7 @@ impl BackendStorage for MetalStorage { inp_l: &Layout, (w_k, h_k): (usize, usize), (w_stride, h_stride): (usize, usize), + padding: usize, ) -> Result { let shape = inp_l.shape(); let (b_size, channels, width, height) = shape.dims4()?; @@ -1336,8 +1339,13 @@ impl BackendStorage for MetalStorage { DType::U32 => "max_pool2d_u32", dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"), }; - let out_w = (width - w_k) / w_stride + 1; - let out_h = (height - h_k) / h_stride + 1; + + // Calculate output dimensions with padding + let width_padded = width + 2 * padding; + let height_padded = height + 2 * padding; + let out_w = (width_padded - w_k) / w_stride + 1; + let out_h = (height_padded - h_k) / h_stride + 1; + let dst_el = out_w * out_h * b_size * channels; let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?; let command_buffers = self.device.command_buffer()?; @@ -1354,6 +1362,7 @@ impl BackendStorage for MetalStorage { h_k, w_stride, h_stride, + padding, &self.buffer, &buffer, ) diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 8e24368ff1..226526a1b1 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -135,6 +135,7 @@ pub enum Op { arg: Tensor, kernel_size: (usize, usize), stride: (usize, usize), + padding: usize, }, UpsampleNearest1D { diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 32af582473..7bbc8b4180 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -519,18 +519,19 @@ impl Storage { layout: &Layout, kernel_size: (usize, usize), stride: (usize, usize), + padding: usize, ) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.max_pool2d(layout, kernel_size, stride)?; + let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.max_pool2d(layout, kernel_size, stride)?; + let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?; Ok(Self::Cuda(storage)) } Self::Metal(storage) => { - let storage = storage.max_pool2d(layout, kernel_size, stride)?; + let storage = storage.max_pool2d(layout, kernel_size, stride, padding)?; Ok(Self::Metal(storage)) } } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index d71630212d..ce64da2dde 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1214,24 +1214,60 @@ impl Tensor { &self, kernel_size: T, stride: T, + ) -> Result { + let sz = kernel_size.to_usize2(); + let stride = stride.to_usize2(); + self.max_pool2d_with_stride_padding(sz, stride, 0) + } + + /// Same as `max_pool2d_with_stride` but with padding support. + /// + /// # Arguments + /// + /// * `kernel_size` - The size of the pooling window + /// * `stride` - The stride of the pooling operation, controlling how far the window + /// * `padding` - The amount of zero-padding to add to both sides of the height and width + /// # Examples + /// ```rust + /// use candle_core::{Tensor,Device,Shape}; + /// let t = Tensor::rand(0.0f32, 1.0, (1, 1, 10, 10), &Device::Cpu)?; + /// let (kernel_size,stride,padding) = (2,2,2); + /// let t = t.max_pool2d_with_stride_padding(kernel_size,stride,padding)?; + /// assert_eq!(t.shape().dims(),[1,1,7,7]); + /// # Ok::<(), candle_core::Error>(()) + /// ``` + pub fn max_pool2d_with_stride_padding( + &self, + kernel_size: T, + stride: T, + padding: usize, ) -> Result { let kernel_size = kernel_size.to_usize2(); let stride = stride.to_usize2(); let (n, c, h, w) = self.dims4()?; - if h < kernel_size.0 || w < kernel_size.1 { - bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}") + + // Calculate effective input size with padding + let h_padded = h + 2 * padding; + let w_padded = w + 2 * padding; + + if h_padded < kernel_size.0 || w_padded < kernel_size.1 { + bail!("kernel-size {kernel_size:?} is larger than the padded input size {h_padded},{w_padded}") } + // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d - let h_out = (h - kernel_size.0) / stride.0 + 1; - let w_out = (w - kernel_size.1) / stride.1 + 1; + let h_out = (h_padded - kernel_size.0) / stride.0 + 1; + let w_out = (w_padded - kernel_size.1) / stride.1 + 1; + let op = BackpropOp::new1(self, |arg| Op::MaxPool2D { arg, kernel_size, stride, + padding, }); + let storage = self .storage() - .max_pool2d(self.layout(), kernel_size, stride)?; + .max_pool2d(self.layout(), kernel_size, stride, padding)?; Ok(from_storage(storage, (n, c, h_out, w_out), op, false)) } diff --git a/candle-kernels/src/conv.cu b/candle-kernels/src/conv.cu index 3f15e0ad2e..33f51c54cb 100644 --- a/candle-kernels/src/conv.cu +++ b/candle-kernels/src/conv.cu @@ -446,6 +446,7 @@ __device__ void max_pool2d( const size_t h_k, const size_t w_stride, const size_t h_stride, + const size_t padding, const size_t *info, const T *src, T *dst @@ -459,8 +460,8 @@ __device__ void max_pool2d( const size_t w_in = src_dims[2]; const size_t h_in = src_dims[3]; - const size_t w_out = (w_in - w_k) / w_stride + 1; - const size_t h_out = (h_in - h_k) / h_stride + 1; + const size_t w_out = (w_in + 2 * padding - w_k) / w_stride + 1; + const size_t h_out = (h_in + 2 * padding - h_k) / h_stride + 1; if (dst_i >= src_dims[0] * c * w_out * h_out) { return; } @@ -476,14 +477,16 @@ __device__ void max_pool2d( bool set = false; for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { size_t src_w = w_stride * dst_w + w_offset; - if (src_w >= w_in) { + if (src_w < padding || src_w >= w_in + padding) { continue; } + src_w -= padding; for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { size_t src_h = h_stride * dst_h + h_offset; - if (src_h >= h_in) { + if (src_h < padding || src_h >= h_in + padding) { continue; } + src_h -= padding; const size_t src_idx = src_idx0 + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3]; if (set) { d = maxg(d, src[src_idx]); @@ -671,11 +674,12 @@ extern "C" __global__ void FN_NAME( \ const size_t h_k, \ const size_t w_stride, \ const size_t h_stride, \ + const size_t padding, \ const size_t *info, \ const TYPENAME *src, \ TYPENAME *dst \ ) { \ - max_pool2d(src_numel, w_k, h_k, w_stride, h_stride, info, src, dst); \ + max_pool2d(src_numel, w_k, h_k, w_stride, h_stride, padding, info, src, dst); \ } \ #define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \ diff --git a/candle-metal-kernels/src/kernels/convolution.rs b/candle-metal-kernels/src/kernels/convolution.rs index 6b2e5fcf96..3cc7680fae 100644 --- a/candle-metal-kernels/src/kernels/convolution.rs +++ b/candle-metal-kernels/src/kernels/convolution.rs @@ -148,6 +148,7 @@ pub fn call_pool2d( h_k: usize, w_stride: usize, h_stride: usize, + padding: usize, input: &Buffer, output: &Buffer, ) -> Result<(), MetalKernelError> { @@ -159,7 +160,7 @@ pub fn call_pool2d( encoder.set_compute_pipeline_state(&pipeline); set_params!( encoder, - (w_k, h_k, w_stride, h_stride, shape, strides, input, output) + (w_k, h_k, w_stride, h_stride, padding, shape, strides, input, output) ); encoder.use_resource(input, MTLResourceUsage::Read); encoder.use_resource(output, MTLResourceUsage::Write); @@ -167,6 +168,7 @@ pub fn call_pool2d( Ok(()) } + #[allow(clippy::too_many_arguments)] pub fn call_conv_transpose1d( device: &Device, diff --git a/candle-metal-kernels/src/metal_src/conv.metal b/candle-metal-kernels/src/metal_src/conv.metal index fbe19bb87f..d4e38a463d 100644 --- a/candle-metal-kernels/src/metal_src/conv.metal +++ b/candle-metal-kernels/src/metal_src/conv.metal @@ -332,6 +332,7 @@ METAL_FUNC void max_pool2d( constant size_t &h_k, constant size_t &w_stride, constant size_t &h_stride, + constant size_t &padding, constant size_t *src_dims, constant size_t *src_strides, device const T *src, @@ -342,8 +343,8 @@ METAL_FUNC void max_pool2d( const size_t w_in = src_dims[2]; const size_t h_in = src_dims[3]; - const size_t w_out = (w_in - w_k) / w_stride + 1; - const size_t h_out = (h_in - h_k) / h_stride + 1; + const size_t w_out = (w_in + 2 * padding - w_k) / w_stride + 1; + const size_t h_out = (h_in + 2 * padding - h_k) / h_stride + 1; if (tid >= src_dims[0] * c * w_out * h_out) { return; } @@ -358,14 +359,16 @@ METAL_FUNC void max_pool2d( bool set = false; for (size_t w_offset = 0; w_offset < w_k; ++w_offset) { size_t src_w = w_stride * dst_w + w_offset; - if (src_w >= w_in){ + if (src_w < padding || src_w >= w_in + padding){ continue; } + src_w -= padding; for (size_t h_offset = 0; h_offset < h_k; ++h_offset) { size_t src_h = h_stride * dst_h + h_offset; - if (src_h >= h_in) { + if (src_h < padding || src_h >= h_in + padding) { continue; } + src_h -= padding; const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3]; if (set) { d = MAX(d, src[src_idx]); @@ -385,13 +388,14 @@ kernel void FN_NAME( \ constant size_t &h_k, \ constant size_t &w_s, \ constant size_t &h_s, \ + constant size_t &padding, \ constant size_t *src_dims, \ constant size_t *src_s, \ device const TYPENAME *src, \ device TYPENAME *dst, \ uint tid [[ thread_position_in_grid ]] \ ) { \ - max_pool2d(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \ + max_pool2d(w_k, h_k, w_s, h_s, padding, src_dims, src_s, src, dst, tid); \ } \