slightly improve StridedIndex performance #3112
Conversation
… multi-index lookup; this yields a few % improvement, since this iterator is hot enough for the saving to show up.
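To make that concrete, here is a minimal sketch of the idea in plain Rust (not candle's actual `StridedIndex` code; `StridedIndexSketch` and its field names are made up for illustration): the (dim, stride) pairs are zipped once at construction time, so the hot `next` call only walks a precomputed `Vec` while advancing the multi-index odometer-style.

```rust
struct StridedIndexSketch {
    dims_and_strides: Vec<(usize, usize)>, // zipped once at construction, not on every `next`
    multi_index: Vec<usize>,               // per-dimension coordinates, advanced like an odometer
    storage_offset: Option<usize>,         // offset of the next element, `None` once exhausted
}

impl StridedIndexSketch {
    fn new(dims: &[usize], strides: &[usize], start_offset: usize) -> Self {
        let elem_count: usize = dims.iter().product();
        Self {
            dims_and_strides: dims.iter().copied().zip(strides.iter().copied()).collect(),
            multi_index: vec![0; dims.len()],
            storage_offset: if elem_count == 0 { None } else { Some(start_offset) },
        }
    }
}

impl Iterator for StridedIndexSketch {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let current = self.storage_offset?;
        // Advance the innermost dimension first; on overflow, reset it and carry outwards.
        let mut advanced = false;
        let mut offset = current;
        for (coord, &(dim, stride)) in self
            .multi_index
            .iter_mut()
            .rev()
            .zip(self.dims_and_strides.iter().rev())
        {
            if *coord + 1 < dim {
                *coord += 1;
                offset += stride;
                advanced = true;
                break;
            } else {
                offset -= *coord * stride; // roll this dimension back to 0 and carry
                *coord = 0;
            }
        }
        self.storage_offset = if advanced { Some(offset) } else { None };
        Some(current)
    }
}

fn main() {
    // A 2x3 view with strides [1, 2] (a transposed layout) visits offsets 0, 2, 4, 1, 3, 5.
    let offsets: Vec<usize> = StridedIndexSketch::new(&[2, 3], &[1, 2], 0).collect();
    println!("{offsets:?}");
}
```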
Very nice 👌
This is super noisy, yes. I stumbled upon this while dissecting …. So, what we're actually trying to optimize here is ….

A sample run of a 3x3 kernel on a batch of 2 images with resolution 320 x 320: …

For comparison, a pytorch eager run with the same input: …

The kernel is already fast (thanks to …). The only thing that produced any noticeable improvement was the ….

I'm ok with not accepting this; it's a very small win compared to the big stuff (although this also loses a couple of lifetimes, so that's a code win).

For repro, I've tried to simplify my experiment code. Put this inside … and run it with … (one plausible invocation is noted after the listing).

exp_conv2d.rs:

```rust
use std::time::{Duration, Instant};
use anyhow::Result;
use candle::{DType, Device, Module, Tensor};
use candle_nn::{Conv2dConfig, VarBuilder, VarMap};

fn test_tensor(dims: (usize, usize, usize, usize), device: &Device) -> Result<Tensor> {
    let (batch_size, channels, height, width) = dims;
    // Create deterministic input tensor with hardcoded values for reproducibility
    let mut input_data = vec![0.0f32; batch_size * channels * height * width];
    for b in 0..batch_size {
        for c in 0..channels {
            for h in 0..height {
                for w in 0..width {
                    // Create a deterministic pattern based on batch, channel, and position
                    let value = (b + 1) as f32 * 0.1
                        + (c + 1) as f32 * 0.01
                        + (h * width + w) as f32 * 0.001;
                    let idx =
                        b * channels * height * width + c * height * width + h * width + w;
                    input_data[idx] = value;
                }
            }
        }
    }
    let input_tensor = Tensor::from_vec(input_data, (batch_size, channels, height, width), device)?;
    Ok(input_tensor)
}

#[test]
fn just_conv() -> Result<()> {
    // let device = Device::new_cuda(0)?;
    let device = Device::Cpu;
    // Create deterministic input: batch_size=2, channels=3, height=320, width=320
    let (batch_size, in_channels, height, width) = (2, 3, 320, 320);
    let dims = (batch_size, in_channels, height, width);
    let input_tensor = test_tensor(dims, &device)?;
    // Create conv2d layer: 3 input channels, 16 output channels, 3x3 kernel, stride=1, padding=1
    let (out_channels, kernel_size, stride, padding) = (16, 3, 1, 1);
    // Create VarMap and VarBuilder for the conv layer
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let conv_config = Conv2dConfig {
        padding,
        stride,
        dilation: 1,
        groups: 1,
        cudnn_fwd_algo: None,
    };
    let conv_layer = candle_nn::conv2d_no_bias(
        in_channels,
        out_channels,
        kernel_size,
        conv_config,
        vb.pp("conv"),
    )?;

    const ITERS: usize = 200;
    const WARMUP: usize = 50;
    let mut min = Duration::MAX;
    let mut max = Duration::ZERO;
    let mut total = Duration::ZERO;
    for i in 0..ITERS + WARMUP {
        // println!("---");
        let start = Instant::now();
        let _ = conv_layer.forward(&input_tensor)?;
        device.synchronize()?;
        let elapsed = start.elapsed();
        // Only time post-warmup iterations (`>=` so exactly ITERS samples go into the average).
        if i >= WARMUP {
            total += elapsed;
            if elapsed < min {
                min = elapsed;
            }
            if elapsed > max {
                max = elapsed;
            }
        }
        // println!("conv2d {:?}", start.elapsed());
    }
    device.synchronize()?;
    println!(
        "{ITERS} iters, min/avg/max: [{:?} {:?} {:?}]",
        min,
        total / ITERS as u32,
        max
    );
    Ok(())
}
```
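One plausible way to run a benchmark test like this, assuming the file is saved as `candle-nn/tests/exp_conv2d.rs` inside a candle checkout (the exact location and command are an assumption, not stated above), is `cargo test --release -p candle-nn --test exp_conv2d -- --nocapture`; `--release` matters for meaningful timings and `--nocapture` lets the `println!` output through.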
Don't worry, we're not giving up on this quite yet 😉 I'm actually getting better results when I set ….

Also, since you've already got everything set up: what does the pytorch call stack/trace look like? If I were to guess, I would say they are calling an optimized oneDNN strided conv2d kernel.
Non-im2col version for me is slower: 29.5 ms on average vs 23.4 ms with im2col.

Pytorch uses something called …. oneDNN has some documentation here regarding convolutions: https://github.com/uxlfoundation/oneDNN/blob/main/doc/primitives/convolution.md#algorithms. They seem to have: …
No one seems to be using im2col anymore 😄 For candle to be fast, we probably have to focus on optimizing direct convolutions or implementing implicit gemm. If I understand correctly, fast implicit gemm should be a thing on cpu as well... Perhaps an issue should be opened about speeding up cpu convolutions?
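For context on the trade-off being discussed, below is a minimal, generic sketch of the im2col lowering in plain Rust (not candle's kernels; stride 1, no padding, a single image, and the function names are made up for illustration). The input patches are unrolled into a `(c_in * k * k) x (h_out * w_out)` matrix so the convolution becomes one matrix multiplication, at the cost of materializing that buffer; a direct convolution fuses these loops and accumulates in place without it, and, roughly speaking, an implicit-gemm kernel keeps the GEMM loop structure but computes the unrolled entries on the fly instead of storing them.

```rust
/// im2col: unroll k x k patches of a (c_in, h, w) image into a matrix whose
/// rows are (channel, ky, kx) triples and whose columns are output pixels.
fn im2col(input: &[f32], c_in: usize, h: usize, w: usize, k: usize) -> Vec<f32> {
    let (h_out, w_out) = (h - k + 1, w - k + 1);
    let mut cols = vec![0.0f32; c_in * k * k * h_out * w_out];
    for c in 0..c_in {
        for ky in 0..k {
            for kx in 0..k {
                let row = (c * k + ky) * k + kx;
                for oy in 0..h_out {
                    for ox in 0..w_out {
                        let col = oy * w_out + ox;
                        cols[row * h_out * w_out + col] =
                            input[c * h * w + (oy + ky) * w + (ox + kx)];
                    }
                }
            }
        }
    }
    cols
}

/// The convolution then reduces to a (c_out x c_in*k*k) by (c_in*k*k x h_out*w_out) matmul.
fn conv2d_via_im2col(
    input: &[f32],
    weight: &[f32], // laid out as (c_out, c_in, k, k), row-major
    c_in: usize,
    c_out: usize,
    h: usize,
    w: usize,
    k: usize,
) -> Vec<f32> {
    let (h_out, w_out) = (h - k + 1, w - k + 1);
    let cols = im2col(input, c_in, h, w, k);
    let (m, n, kk) = (c_out, h_out * w_out, c_in * k * k);
    let mut out = vec![0.0f32; m * n];
    // Naive GEMM; a real implementation would call an optimized matmul here.
    for i in 0..m {
        for p in 0..kk {
            let a = weight[i * kk + p];
            for j in 0..n {
                out[i * n + j] += a * cols[p * n + j];
            }
        }
    }
    out
}

fn main() {
    // Tiny smoke test: 1 channel, 3x3 input, 2x2 kernel of ones -> 2x2 patch sums.
    let input: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let weight = vec![1.0f32; 4];
    let out = conv2d_via_im2col(&input, &weight, 1, 1, 3, 3, 2);
    println!("{out:?}"); // [12.0, 16.0, 24.0, 28.0]
}
```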
Created follow-up issue with a few more timings from another machine: #3119
I've included your changes as well + added some more comments.
Implemented …. Does seem to have an improvement on the ….
I'm trying to make convolutions go faster on cpu, and one relatively low-hanging fruit happened to be in `StridedIndex`.

`StridedIndex` has a relatively expensive `Iterator::next` implementation, which repeatedly chains/zips several iterators on every `next` call. By simply precomputing the zipped structure, I'm seeing small but consistent improvements for everything that uses `StridedIndex`, particularly cpu conv stuff.

I've attached 2 `cargo bench` runs on cpu, before and after the changes made in this PR:

cpu.baseline.txt
cpu.precomp.strided.2.txt

A few outtakes: