diff --git a/examples/imagenet.rs b/examples/imagenet.rs
index 63765d25..da53a97e 100644
--- a/examples/imagenet.rs
+++ b/examples/imagenet.rs
@@ -58,9 +58,27 @@ fn read_image<N: Fn(f32) -> f32>(
     normalize_pixel: N,
     out_chan_order: ChannelOrder,
     out_dim_order: DimOrder,
+    out_height: u32,
+    out_width: u32,
 ) -> Result<NdTensor<f32, 4>, Box<dyn Error>> {
     let input_img = image::open(path)?;
     let input_img = input_img.into_rgb8();
+
+    // Resize the image using the `imageops::resize` function from the `image`
+    // crate rather than using Wasnn's `resize` operator because
+    // `imageops::resize` supports antialiasing. This significantly improves
+    // output image quality and thus prediction accuracy when the output is
+    // small (e.g. 224 or 256px).
+    //
+    // The outputs of `imageops::resize` still don't match PyTorch exactly
+    // though, which can lead to small differences in prediction outputs.
+    let input_img = image::imageops::resize(
+        &input_img,
+        out_width,
+        out_height,
+        image::imageops::FilterType::Triangle,
+    );
+
     let (width, height) = input_img.dimensions();
 
     // Map input channel index, in RGB order, to output channel index
@@ -197,18 +215,6 @@ fn main() -> Result<(), Box<dyn Error>> {
         },
     };
 
-    let img_tensor = read_image(
-        &args.image,
-        normalize_pixel,
-        args.config.chan_order,
-        args.config.dim_order,
-    )?;
-
-    let (height, width) = match args.config.dim_order {
-        DimOrder::Nchw => (img_tensor.size(2), img_tensor.size(3)),
-        DimOrder::Nhwc => (img_tensor.size(1), img_tensor.size(2)),
-    };
-
     let input_id = model
         .input_ids()
         .get(0)
@@ -237,11 +243,14 @@ fn main() -> Result<(), Box<dyn Error>> {
         }
     };
 
-    let img_tensor = if height != in_height as usize || width != in_width as usize {
-        img_tensor.resize_image([in_height, in_width])?
-    } else {
-        img_tensor
-    };
+    let img_tensor = read_image(
+        &args.image,
+        normalize_pixel,
+        args.config.chan_order,
+        args.config.dim_order,
+        in_height as u32,
+        in_width as u32,
+    )?;
 
     let logits: NdTensor<f32, 2> = model.run_one(img_tensor.view().into(), None)?.try_into()?;
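
To illustrate the approach the patch takes (resizing with `image::imageops::resize` and an antialiasing filter before building the model input), here is a minimal standalone sketch using only the `image` crate. The `load_nchw` helper name, the plain [0, 1] normalization, and the use of a flat `Vec<f32>` instead of Wasnn's `NdTensor` are illustrative assumptions, not part of the patch.

// A minimal sketch of the resize-then-pack approach, assuming only the
// `image` crate. The `load_nchw` name and the simple normalization are
// illustrative; the real example builds an NdTensor and applies the
// model-specific `normalize_pixel` function instead.
use image::imageops::{self, FilterType};

// Load an RGB image, resize it with antialiasing, and pack the pixels into a
// flat Vec<f32> in CHW order (a single-image NCHW layout with batch size 1).
fn load_nchw(path: &str, out_height: u32, out_width: u32) -> Result<Vec<f32>, image::ImageError> {
    let img = image::open(path)?.into_rgb8();

    // `imageops::resize` takes (width, height) in that order; the `Triangle`
    // (bilinear) filter gives antialiased downscaling.
    let img = imageops::resize(&img, out_width, out_height, FilterType::Triangle);

    let (width, height) = img.dimensions();
    let plane = (width * height) as usize;
    let mut chw = vec![0.0f32; 3 * plane];
    for (x, y, pixel) in img.enumerate_pixels() {
        for c in 0..3 {
            // Scale u8 values to [0, 1]; a real model would also apply its own
            // per-channel mean/std normalization here.
            chw[c * plane + (y * width + x) as usize] = pixel.0[c] as f32 / 255.0;
        }
    }
    Ok(chw)
}

Calling something like `load_nchw("cat.jpg", 224, 224)` would yield a CHW buffer sized for a typical ImageNet input; the patched example instead threads the input height and width reported by the model through `read_image`, as shown in the diff above.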