Skip to content

Commit

Permalink
feat: resize onnx import (#1863)
Browse files Browse the repository at this point in the history
* feat: resize onnx import

* fix: resize import proc macro output

* fix: lint

* fix: simplify resize onnx

* fix: onnx-tests passing

* feedback: remove dead code and resolve merge conflicts
  • Loading branch information
mosure committed Jun 11, 2024
1 parent 671ec8c commit 71bd5ef
Show file tree
Hide file tree
Showing 12 changed files with 344 additions and 5 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ represent the corresponding Burn Op.
| [ReduceSumSquare][140] |||
| [Relu][141] |||
| [Reshape][142] |||
| [Resize][143] | ||
| [Resize][143] | ||
| [ReverseSequence][144] |||
| [RNN][145] |||
| [RoiAlign][146] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ fn main() {
.input("tests/reduce_sum/reduce_sum_opset13.onnx")
.input("tests/reduce_sum/reduce_sum_opset11.onnx")
.input("tests/reshape/reshape.onnx")
.input("tests/resize/resize.onnx")
.input("tests/shape/shape.onnx")
.input("tests/sigmoid/sigmoid.onnx")
.input("tests/sign/sign.onnx")
Expand Down
25 changes: 25 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ include_models!(
reduce_sum_opset11,
relu,
reshape,
resize,
shape,
sigmoid,
sign,
Expand Down Expand Up @@ -789,6 +790,30 @@ mod tests {
assert_eq!(output.to_data(), expected);
}

#[test]
fn resize() {
// Initialize the model without weights (because the exported file does not contain them)
let device = Default::default();
let model: resize::Model<Backend> = resize::Model::new(&device);

// Run the model
let input = Tensor::<Backend, 4>::from_floats(
[[[
[0.0, 1.0, 2.0, 3.0],
[4.0, 5.0, 6.0, 7.0],
[8.0, 9.0, 10.0, 11.0],
[12.0, 13.0, 14.0, 15.0],
]]],
&device,
);
let size = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 3], &device);

let output = model.forward(input, size);
let expected = Data::from([[[[0.0, 1.5, 3.0], [12.0, 13.5, 15.0]]]]);

assert_eq!(output.to_data(), expected);
}

#[test]
fn shape() {
let device = Default::default();
Expand Down
Binary file not shown.
35 changes: 35 additions & 0 deletions crates/burn-import/onnx-tests/tests/resize/resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/resize/resize.onnx

import onnx
from onnx import helper, TensorProto

def main() -> None:
input_tensor = helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [1, 1, 4, 4])
sizes_tensor = helper.make_tensor_value_info("sizes", TensorProto.INT64, [4])

resize_node = helper.make_node(
"Resize",
name="resize_node",
inputs=["input_tensor", "", "", "sizes"],
outputs=["output"],
mode="linear",
)

graph_def = helper.make_graph(
nodes=[resize_node],
name="ResizeGraph",
inputs=[input_tensor, sizes_tensor],
outputs=[
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1, 2, 2])
],
)

model_def = helper.make_model(graph_def, producer_name="resize")

onnx.save(model_def, "resize.onnx")


if __name__ == "__main__":
main()
7 changes: 5 additions & 2 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use super::{
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
MaxPool2d(MaxPool2dNode),
Range(RangeNode),
Reshape(ReshapeNode),
Resize(ResizeNode),
Slice(SliceNode),
Squeeze(SqueezeNode),
Sum(SumNode),
Expand Down Expand Up @@ -140,6 +141,7 @@ macro_rules! match_all {
Node::MaxPool2d(node) => $func(node),
Node::Range(node) => $func(node),
Node::Reshape(node) => $func(node),
Node::Resize(node) => $func(node),
Node::Slice(node) => $func(node),
Node::Squeeze(node) => $func(node),
Node::Sum(node) => $func(node),
Expand Down Expand Up @@ -188,6 +190,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::MaxPool2d(_) => "max_pool2d",
Node::Range(_) => "range",
Node::Reshape(_) => "reshape",
Node::Resize(_) => "resize",
Node::Slice(_) => "slice",
Node::Squeeze(_) => "squeeze",
Node::Sum(_) => "add",
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub(crate) mod random_normal;
pub(crate) mod random_uniform;
pub(crate) mod range;
pub(crate) mod reshape;
pub(crate) mod resize;
pub(crate) mod slice;
pub(crate) mod squeeze;
pub(crate) mod sum;
Expand Down
207 changes: 207 additions & 0 deletions crates/burn-import/src/burn/node/resize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
use super::{Node, NodeCodegen};
use crate::burn::{OtherType, Scope, TensorType, Type};
use burn::module::Module;
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Module, Debug, Clone)]
pub enum ResizeMode {
Nearest,
Linear,
Cubic,
}

#[derive(new, Module, Debug, Clone)]
pub struct ResizeOptions {
pub mode: ResizeMode,
}

#[derive(Debug, Clone)]
pub struct ResizeNode {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub output_size: TensorType,
pub config: ResizeOptions,
}

impl ResizeNode {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
output_size: TensorType,
config: ResizeOptions,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
burn::module::Ignored<InterpolateOptions>
},
),
input,
output,
output_size,
config,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ResizeNode {
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn input_types(&self) -> Vec<Type> {
vec![
Type::Tensor(self.input.clone()),
Type::Tensor(self.output_size.clone()),
]
}

fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;

let mode = match self.config.mode {
ResizeMode::Linear => quote! { InterpolateMode::Bilinear },
ResizeMode::Nearest => quote! { InterpolateMode::Nearest },
ResizeMode::Cubic => quote! { InterpolateMode::Bicubic },
};

let tokens = quote! {
let #name = InterpolateOptions {
mode: #mode,
};
let #name = burn::module::Ignored(#name);
};

Some(tokens)
}

fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
S::serialize_none(serializer)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output_size = scope.tensor_use_owned(&self.output_size, node_position);
let output = &self.output.name;

let field = &self.field.name;

quote! {
let output_size_raw = #output_size.to_data().value;
let mut output_size = [0usize; 2];

for (i, &x) in output_size_raw.iter().rev().take(2).rev().enumerate() {
output_size[i] = x.elem::<i64>() as usize;
}

let #output = interpolate(
#input,
output_size,
self.#field.0.clone(),
);
}
}

fn into_node(self) -> Node<PS> {
Node::Resize(self)
}

fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::tensor::ElementConversion");
imports.register("burn::tensor::module::interpolate");
imports.register("burn::tensor::ops::InterpolateMode");
imports.register("burn::tensor::ops::InterpolateOptions");
}
}

#[cfg(test)]
mod tests {
use burn::record::FullPrecisionSettings;

use super::*;
use crate::burn::{
graph::BurnGraph,
node::{resize::ResizeNode, test::assert_tokens},
TensorType,
};

#[test]
fn test_codegen_nodes() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(ResizeNode::new(
"resize",
TensorType::new_float("tensor1", 4),
TensorType::new_float("tensor2", 4),
TensorType::new_int("output_size", 1),
ResizeOptions::new(ResizeMode::Linear),
));

graph.register_input_output(
vec!["tensor1".to_string(), "output_size".to_string()],
vec!["tensor2".to_string()],
);

let expected = quote! {
use burn::tensor::module::interpolate;
use burn::tensor::ops::InterpolateMode;
use burn::tensor::ops::InterpolateOptions;
use burn::tensor::ElementConversion;
use burn::tensor::Int;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
resize: burn::module::Ignored<InterpolateOptions>,
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let resize = InterpolateOptions {
mode: InterpolateMode::Bilinear,
};
let resize = burn::module::Ignored(resize);
Self {
resize,
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(
&self,
tensor1: Tensor<B, 4>,
output_size: Tensor<B, 1, Int>
) -> Tensor<B, 4> {
let output_size_raw = output_size.to_data().value;
let mut output_size = [0usize; 2];

for (i, &x) in output_size_raw.iter().rev().take(2).rev().enumerate() {
output_size[i] = x.elem::<i64>() as usize;
}

let tensor2 = interpolate(tensor1, output_size, self.resize.0.clone());

tensor2
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
28 changes: 28 additions & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::ReduceSum => reduce_sum_update_outputs(node),
NodeType::Relu => same_as_input(node),
NodeType::Reshape => reshape_update_outputs(node),
NodeType::Resize => resize_update_outputs(node),
NodeType::Shape => shape_update_outputs(node),
NodeType::Sigmoid => same_as_input(node),
NodeType::Sign => same_as_input(node),
Expand Down Expand Up @@ -285,6 +286,33 @@ fn reshape_update_outputs(node: &mut Node) {
}
}

fn resize_update_outputs(node: &mut Node) {
let input = match &node.inputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Resize: invalid input type"),
};

let output = match &node.outputs[0].ty {
ArgType::Tensor(tensor) => tensor.clone(),
_ => panic!("Resize: invalid output type"),
};

let output_size = match &node.inputs[3].ty {
ArgType::Tensor(output_size) => output_size.clone(),
_ => panic!("Resize: invalid output_size type"),
};

if output_size.dim != 1 {
panic!("Resize: output_size must be 1D");
}

node.outputs[0].ty = ArgType::Tensor(TensorType {
dim: input.dim,
shape: None, // shape is calculated at runtime
..output
});
}

fn greater_update_outputs(node: &mut Node) {
match &node.inputs[0].ty {
ArgType::Tensor(tensor) => {
Expand Down
Loading

0 comments on commit 71bd5ef

Please sign in to comment.