Skip to content

Commit

Permalink
Switch to image::imageops::resize for resizing in the imagenet demo
Browse files Browse the repository at this point in the history
The `imageops::resize` function produces better quality outputs when the output
is small, and has been significantly downscaled from the input, because it does
antialiasing. The implementation in the `Resize` operator does not.

This improves the alignment between prediction outputs of the
`tools/export-timm-model.py` script and the imagenet demo. The results are still
not perfectly identical because the transforms don't exactly match.
  • Loading branch information
robertknight committed Nov 18, 2023
1 parent b2f920f commit 958177f
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions examples/imagenet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,27 @@ fn read_image<N: Fn(usize, f32) -> f32>(
normalize_pixel: N,
out_chan_order: ChannelOrder,
out_dim_order: DimOrder,
out_height: u32,
out_width: u32,
) -> Result<Tensor<f32>, Box<dyn Error>> {
let input_img = image::open(path)?;
let input_img = input_img.into_rgb8();

// Resize the image using the `imageops::resize` function from the `image`
// crate rather than using Wasnn's `resize` operator because
// `imageops::resize` supports antialiasing. This significantly improves
// output image quality and thus prediction accuracy when the output is
// small (eg. 224 or 256px).
//
// The outputs of `imageops::resize` still don't match PyTorch exactly
// though, which can lead to small differences in prediction outputs.
let input_img = image::imageops::resize(
&input_img,
out_width,
out_height,
image::imageops::FilterType::Triangle,
);

let (width, height) = input_img.dimensions();

// Map input channel index, in RGB order, to output channel index
Expand Down Expand Up @@ -197,18 +215,6 @@ fn main() -> Result<(), Box<dyn Error>> {
},
};

let img_tensor = read_image(
&args.image,
normalize_pixel,
args.config.chan_order,
args.config.dim_order,
)?;

let (height, width) = match args.config.dim_order {
DimOrder::Nchw => (img_tensor.size(2), img_tensor.size(3)),
DimOrder::Nhwc => (img_tensor.size(1), img_tensor.size(2)),
};

let input_id = model
.input_ids()
.get(0)
Expand Down Expand Up @@ -237,11 +243,14 @@ fn main() -> Result<(), Box<dyn Error>> {
}
};

let img_tensor = if height != in_height as usize || width != in_width as usize {
img_tensor.resize_image([in_height, in_width])?
} else {
img_tensor
};
let img_tensor = read_image(
&args.image,
normalize_pixel,
args.config.chan_order,
args.config.dim_order,
in_height as u32,
in_width as u32,
)?;

let logits: NdTensor<f32, 2> = model.run_one(img_tensor.view().into(), None)?.try_into()?;

Expand Down

0 comments on commit 958177f

Please sign in to comment.