Conversation

@slckl (Contributor) commented Oct 3, 2025

I'm trying to make convolutions go faster on cpu, and one relatively low-hanging fruit happened to be in StridedIndex.

StridedIndex has a relatively expensive Iterator::next implementation, which chains/zips several iterators on every next call. By simply precomputing the zipped structure, I'm seeing small but consistent improvements for everything that uses StridedIndex, particularly the cpu conv code.
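The gist of the change, as a minimal sketch (simplified names and state, roughly how a strided-layout index walk looks, not the exact code in this PR): the (dim, stride) pairs are zipped once at construction, so next only walks a single precomputed Vec instead of re-zipping three iterators per element.

// Minimal sketch: precompute the (dim, stride) pairs once instead of
// zipping `dims` and `stride` on every `next()` call. Names are simplified;
// the real StridedIndex carries more state.
struct StridedIndexSketch {
    // Precomputed (dim_size, stride) pairs, innermost dimension last.
    dims_and_strides: Vec<(usize, usize)>,
    // Current multi-dimensional index, same length as dims_and_strides.
    multi_index: Vec<usize>,
    // Storage offset of the element that will be yielded next, if any.
    next_storage_index: Option<usize>,
}

impl StridedIndexSketch {
    fn new(dims: &[usize], stride: &[usize], start_offset: usize) -> Self {
        let elem_count: usize = dims.iter().product();
        Self {
            // The zip happens once here instead of inside `next()`.
            dims_and_strides: dims.iter().copied().zip(stride.iter().copied()).collect(),
            multi_index: vec![0; dims.len()],
            next_storage_index: if elem_count == 0 { None } else { Some(start_offset) },
        }
    }
}

impl Iterator for StridedIndexSketch {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let storage_index = self.next_storage_index?;
        // Advance the multi-index like an odometer, adjusting the storage
        // offset with the precomputed stride of each dimension.
        let mut updated = false;
        let mut next_storage_index = storage_index;
        for (multi_i, (dim, stride)) in self
            .multi_index
            .iter_mut()
            .zip(self.dims_and_strides.iter())
            .rev()
        {
            let next_i = *multi_i + 1;
            if next_i < *dim {
                *multi_i = next_i;
                next_storage_index += *stride;
                updated = true;
                break;
            } else {
                next_storage_index -= *multi_i * *stride;
                *multi_i = 0;
            }
        }
        self.next_storage_index = if updated { Some(next_storage_index) } else { None };
        Some(storage_index)
    }
}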

I've attached 2 cargo bench runs on cpu, before and after the changes made in this PR.
cpu.baseline.txt
cpu.precomp.strided.2.txt

A few excerpts:

baseline
---
baseline cpu_conv_transpose2d_f32/iter
                        time:   [320.07 µs 324.01 µs 329.06 µs]
                        thrpt:  [115.93 MiB/s 117.73 MiB/s 119.18 MiB/s]

precomputed cpu_conv_transpose2d_f32/iter
                        time:   [299.14 µs 301.07 µs 303.90 µs]
                        thrpt:  [125.52 MiB/s 126.71 MiB/s 127.52 MiB/s]
---
baseline cpu_conv_transpose2d_f16/iter
                        time:   [538.14 µs 543.90 µs 551.24 µs]
                        thrpt:  [34.601 MiB/s 35.068 MiB/s 35.444 MiB/s]

precomputed cpu_conv_transpose2d_f16/iter
                        time:   [517.83 µs 519.02 µs 521.06 µs]
                        thrpt:  [36.605 MiB/s 36.749 MiB/s 36.833 MiB/s]
---
baseline cpu_conv_transpose2d_bf16/iter
                        time:   [654.04 µs 662.47 µs 669.64 µs]
                        thrpt:  [28.483 MiB/s 28.791 MiB/s 29.162 MiB/s]

precomputed cpu_conv_transpose2d_bf16/iter
                        time:   [605.58 µs 612.43 µs 621.15 µs]
                        thrpt:  [30.707 MiB/s 31.144 MiB/s 31.496 MiB/s]

For the multi-index lookup path, this yields a few percent improvement because this iterator is hot enough.
@ivarflakstad (Member) commented:

Very nice 👌
I've found cpu benchmarking to be less reliable than, say, cuda/metal (larger noise ratio), but the changes I'm seeing from your results seem real.
However, I think we need to look into this further and run some more benchmarks, because I'm unable to reproduce your results on my machine.
What are you running on? :)

@slckl (Contributor, Author) commented Oct 4, 2025

This is super noisy, yes. I stumbled upon this while dissecting conv2d performance on cpu. I'm running this on a laptop (even more noise, yay): i7-12700H, Ubuntu 24.

So, what we're actually trying to optimize here is the copy_strided_src function, which in the conv2d case copies the result of the convolution to a contiguous output tensor. In my experiments on cpu, this function accounts for roughly 30% of the runtime of a single conv2d pass.

A sample run of a 3x3 kernel on a batch of 2 images with resolution 320 x 320:

-- im2col: 7.658279ms
-- kernel setup: 751.496µs
-- kernel exec: 1.525553ms
-- copy_strided_src: 5.81544ms
conv2d 15.792275ms

For comparison, a pytorch eager run with the same input:

Average conv2d time over 100 runs: 2.092 ms

The kernel itself is already fast (thanks to the gemm library), but there's substantial overhead before and after it, and pytorch eager runs circles around candle here. Unfortunately, copy_strided_src is already very simple (just a couple of for loops and memcpys); the real fix would be avoiding it somehow.
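For reference, such a strided-to-contiguous copy boils down to something like the sketch below (a hypothetical helper with a simplified signature, assuming the innermost source dimension is contiguous; the real copy_strided_src handles more cases):

/// Hypothetical sketch of a strided-to-contiguous copy, assuming the
/// innermost dimension of `src` is contiguous so each innermost run can be
/// copied with a single memcpy (`copy_from_slice`).
fn copy_strided_to_contiguous(
    src: &[f32],
    dst: &mut [f32],
    dims: &[usize],
    stride: &[usize],
    src_offset: usize,
) {
    let inner = *dims.last().unwrap_or(&1);
    let outer_dims = &dims[..dims.len().saturating_sub(1)];
    let outer_strides = &stride[..stride.len().saturating_sub(1)];
    let outer_count: usize = outer_dims.iter().product();

    let mut multi_index = vec![0usize; outer_dims.len()];
    for block in 0..outer_count {
        // Source offset of this innermost run, computed from the strides.
        let start: usize = src_offset
            + multi_index
                .iter()
                .zip(outer_strides.iter())
                .map(|(i, s)| i * s)
                .sum::<usize>();
        // One memcpy per innermost run.
        dst[block * inner..(block + 1) * inner].copy_from_slice(&src[start..start + inner]);
        // Advance the outer multi-index, odometer style.
        for d in (0..multi_index.len()).rev() {
            multi_index[d] += 1;
            if multi_index[d] < outer_dims[d] {
                break;
            }
            multi_index[d] = 0;
        }
    }
}

Most of the cost is offset bookkeeping plus memory traffic, which is why avoiding the copy altogether looks like the real win.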

The only thing that produced any noticeable improvement was the StridedIndex precomputation, but the improvement is so small that I may have just gotten a couple of lucky runs.

I'm ok with not accepting this; it's a very small win compared to the big stuff (although it also drops a couple of lifetimes, so there's a small code-simplification win).

For repro, I've tried to simplify my experiment code.

Put this inside candle-nn/tests/exp_conv2d.rs; it should work on both main and this branch.

Run it with cargo test -r -p candle-nn --test exp_conv2d -- --nocapture

exp_conv2d.rs:

use std::time::{Duration, Instant};

use anyhow::Result;
use candle::{DType, Device, Module, Tensor};
use candle_nn::{Conv2dConfig, VarBuilder, VarMap};

fn test_tensor(dims: (usize, usize, usize, usize), device: &Device) -> Result<Tensor> {
    let (batch_size, channels, height, width) = dims;

    // Create deterministic input tensor with hardcoded values for reproducibility
    let mut input_data = vec![0.0f32; batch_size * channels * height * width];

    for b in 0..batch_size {
        for c in 0..channels {
            for h in 0..height {
                for w in 0..width {
                    // Create a deterministic pattern based on batch, channel, and position
                    let value = (b + 1) as f32 * 0.1
                        + (c + 1) as f32 * 0.01
                        + (h * width + w) as f32 * 0.001;
                    let idx = b * channels * height * width + c * height * width + h * width + w;
                    input_data[idx] = value;
                }
            }
        }
    }

    let input_tensor = Tensor::from_vec(input_data, (batch_size, channels, height, width), device)?;

    Ok(input_tensor)
}

#[test]
fn just_conv() -> Result<()> {
    // let device = Device::new_cuda(0)?;
    let device = Device::Cpu;

    // Create deterministic input: batch_size=2, channels=3, height=320, width=320
    let (batch_size, in_channels, height, width) = (2, 3, 320, 320);
    let dims = (batch_size, in_channels, height, width);

    let input_tensor = test_tensor(dims, &device)?;

    // Create conv2d layer: 3 input channels, 16 output channels, 3x3 kernel, stride=1, padding=1
    let (out_channels, kernel_size, stride, padding) = (16, 3, 1, 1);

    // Create VarMap and VarBuilder for the conv layer
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    let conv_config = Conv2dConfig {
        padding,
        stride,
        dilation: 1,
        groups: 1,
        cudnn_fwd_algo: None,
    };

    let conv_layer = candle_nn::conv2d_no_bias(
        in_channels,
        out_channels,
        kernel_size,
        conv_config,
        vb.pp("conv"),
    )?;
    const ITERS: usize = 200;
    const WARMUP: usize = 50;
    let mut min = Duration::MAX;
    let mut max = Duration::ZERO;
    let mut total = Duration::ZERO;
    for i in 0..ITERS + WARMUP {
        // println!("---");
        let start = Instant::now();
        let _ = conv_layer.forward(&input_tensor)?;
        device.synchronize()?;
        let elapsed = start.elapsed();
        // Skip warmup iterations so exactly ITERS samples are accumulated.
        if i >= WARMUP {
            total += elapsed;
            if elapsed < min {
                min = elapsed;
            }
            if elapsed > max {
                max = elapsed;
            }
        }
        // println!("conv2d {:?}", start.elapsed());
    }
    device.synchronize()?;

    println!(
        "{ITERS} iters, min/avg/max: [{:?} {:?} {:?}]",
        min,
        total / ITERS as u32,
        max
    );

    Ok(())
}

@ivarflakstad (Member) commented:

Don't worry, we're not giving up on this quite yet 😉

I'm actually getting better results when I set const USE_IM2COL_CONV2D: bool = false;. How about you? (I'm on a Mac, so it might be completely different.)

Also, since you've already got everything set up: what does the pytorch call stack/trace look like? If I were to guess, I'd say they are calling an optimized oneDNN strided conv2d kernel.

@slckl (Contributor, Author) commented Oct 7, 2025

The non-im2col version is slower for me: 29.5 ms on average vs 23.4 ms with im2col.

Pytorch uses something called aten::mkldnn_convolution.
The code can be found here: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/mkldnn/Conv.cpp
It doesn't have any nice high-level comments, however.

oneDNN has some documentation on convolutions here: https://github.com/uxlfoundation/oneDNN/blob/main/doc/primitives/convolution.md#algorithms

They seem to have:

  1. an optimized direct convolution impl,
  2. a fallback implicit gemm (which is also what cudnn does, afaik),
  3. specialized kernels.

No one seems to be using im2col anymore 😄

For candle to be fast, we probably have to focus on optimizing direct convolutions or implementing implicit gemm. If I understand correctly, fast implicit gemm should be a thing on cpu as well...
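For context, a direct convolution in its most naive form is just the loop nest below (a sketch added here purely for illustration: single image, stride 1, no padding or dilation; optimized direct kernels add blocking, vectorization, and a better loop order on top of this instead of materializing an im2col matrix):

/// Naive direct 2D convolution sketch (single image, no padding/dilation,
/// stride 1). Row-major layouts: input is [c_in, h, w], kernel is
/// [c_out, c_in, kh, kw], output is [c_out, h - kh + 1, w - kw + 1].
fn direct_conv2d_naive(
    input: &[f32],
    kernel: &[f32],
    output: &mut [f32],
    (c_in, h, w): (usize, usize, usize),
    (c_out, kh, kw): (usize, usize, usize),
) {
    let (out_h, out_w) = (h - kh + 1, w - kw + 1);
    for co in 0..c_out {
        for oy in 0..out_h {
            for ox in 0..out_w {
                let mut acc = 0.0f32;
                for ci in 0..c_in {
                    for ky in 0..kh {
                        for kx in 0..kw {
                            let iv = input[ci * h * w + (oy + ky) * w + (ox + kx)];
                            let kv = kernel[co * c_in * kh * kw + ci * kh * kw + ky * kw + kx];
                            acc += iv * kv;
                        }
                    }
                }
                output[co * out_h * out_w + oy * out_w + ox] = acc;
            }
        }
    }
}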

Perhaps an issue should be opened about speeding up cpu convolutions?

@slckl (Contributor, Author) commented Oct 7, 2025

Created a follow-up issue with a few more timings from another machine: #3119

@slckl (Contributor, Author) commented Oct 18, 2025

I've included your changes as well and added some more comments.
The .t() unary sqrt bench gives me these numbers:

main branch

cpu_sqrt_F32/iter       time:   [2.6323 ms 2.6493 ms 2.6791 ms]
                        thrpt:  [1.4581 GiB/s 1.4745 GiB/s 1.4840 GiB/s]

cpu_sqrt_BF16/iter      time:   [2.7795 ms 2.8006 ms 2.8286 ms]
                        thrpt:  [707.06 MiB/s 714.12 MiB/s 719.55 MiB/s]

cpu_sqrt_F16/iter       time:   [7.0622 ms 7.0849 ms 7.1176 ms]
                        thrpt:  [280.99 MiB/s 282.29 MiB/s 283.20 MiB/s]

strided-idx

cpu_sqrt_F32/iter       time:   [2.5412 ms 2.5598 ms 2.5897 ms]
                        thrpt:  [1.5084 GiB/s 1.5260 GiB/s 1.5371 GiB/s]

cpu_sqrt_BF16/iter      time:   [2.6430 ms 2.6612 ms 2.6871 ms]
                        thrpt:  [744.30 MiB/s 751.54 MiB/s 756.73 MiB/s]

cpu_sqrt_F16/iter       time:   [7.0443 ms 7.0654 ms 7.0996 ms]
                        thrpt:  [281.71 MiB/s 283.07 MiB/s 283.92 MiB/s]

But the perf change is very noisy, and a bad run can fall behind main.

I think it can still be merged for the slightly more readable code, if nothing else.

@slckl requested a review from ivarflakstad on October 18, 2025 at 21:04
@slckl (Contributor, Author) commented Oct 25, 2025

Implemented size_hint and ExactSizeIterator as well.
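For reference, on this kind of index iterator that amounts to carrying a remaining-element counter and reporting it exactly, roughly like the standalone sketch below (simplified to a single fixed stride; not the exact PR code). With an exact size_hint, ExactSizeIterator comes for free and consumers such as collect or Vec::extend can preallocate.

/// Standalone sketch of exact size reporting for a strided-index-style
/// iterator (simplified: a single fixed step, the multi-dimensional
/// bookkeeping is omitted).
struct CountingStridedIndex {
    remaining: usize,
    next_storage_index: usize,
    step: usize,
}

impl CountingStridedIndex {
    fn new(elem_count: usize, start_offset: usize, step: usize) -> Self {
        Self { remaining: elem_count, next_storage_index: start_offset, step }
    }
}

impl Iterator for CountingStridedIndex {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        if self.remaining == 0 {
            return None;
        }
        let idx = self.next_storage_index;
        self.next_storage_index += self.step;
        self.remaining -= 1;
        Some(idx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exact: lower bound == upper bound == remaining elements.
        (self.remaining, Some(self.remaining))
    }
}

// `len()` defaults to `size_hint().0`, which is exact here.
impl ExactSizeIterator for CountingStridedIndex {}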

It does seem to give an improvement on the broadcast_add benchmarks; numbers are best of 3, on an i7-12700H:

main branch
cpu_broadcast_add_f32/iter
                        time:   [4.3552 ms 4.3672 ms 4.3829 ms]
                        thrpt:  [14.037 MiB/s 14.088 MiB/s 14.126 MiB/s]

cpu_broadcast_add_f16/iter
                        time:   [9.9714 ms 10.029 ms 10.104 ms]
                        thrpt:  [3.0445 MiB/s 3.0674 MiB/s 3.0850 MiB/s]

cpu_broadcast_add_bf16/iter
                        time:   [9.3572 ms 9.3862 ms 9.4246 ms]
                        thrpt:  [3.2640 MiB/s 3.2773 MiB/s 3.2875 MiB/s]

strided-idx branch aka this PR:

cpu_broadcast_add_f32/iter
                        time:   [4.2696 ms 4.2927 ms 4.3204 ms]
                        thrpt:  [14.240 MiB/s 14.332 MiB/s 14.410 MiB/s]

cpu_broadcast_add_f16/iter
                        time:   [8.3756 ms 8.4021 ms 8.4401 ms]
                        thrpt:  [3.6447 MiB/s 3.6612 MiB/s 3.6728 MiB/s]

cpu_broadcast_add_bf16/iter
                        time:   [9.1538 ms 9.1703 ms 9.1929 ms]
                        thrpt:  [3.3462 MiB/s 3.3545 MiB/s 3.3605 MiB/s]
