Skip to content

Commit

Permalink
Add ONNX op Random Normal Like (#2441)
Browse files Browse the repository at this point in the history
* add random normal like python code to generate onnx model

* add random normal like node

* modify onnx burn to add new op

* add test on test onnx

* revert commentouts

* fix review points to respond to dynamically shape
  • Loading branch information
tiruka authored Nov 4, 2024
1 parent 15a79c1 commit 970b9dc
Show file tree
Hide file tree
Showing 10 changed files with 243 additions and 6 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ represent the corresponding Burn Op.
| [QLinearMatMul][124] |||
| [QuantizeLinear][125] |||
| [RandomNormal][126] |||
| [RandomNormalLike][127] | ||
| [RandomNormalLike][127] | ||
| [RandomUniform][128] |||
| [RandomUniformLike][129] |||
| [Range][130] |||
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ fn main() {
.input("tests/pow/pow_int.onnx")
.input("tests/prelu/prelu.onnx")
.input("tests/random_normal/random_normal.onnx")
.input("tests/random_normal_like/random_normal_like.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/range/range.onnx")
.input("tests/recip/recip.onnx")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pytorch2.2.0:�
P
onnx::RandomNormalLike_01/RandomNormalLike"RandomNormalLike*
dtype�
main_graphZ.
onnx::RandomNormalLike_0



b
1



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: random_normal_like.onnx

import torch
import torch.nn as nn


class RandomNormalLikeModel(nn.Module):
def __init__(self):
super(RandomNormalLikeModel, self).__init__()

def forward(self, x):
return torch.randn_like(x)


def main():
# Set seed for reproducibility
torch.manual_seed(42)

# Set print options for better precision output
torch.set_printoptions(precision=8)

# Export Random NormalLike Model
model = RandomNormalLikeModel()
model.eval()
device = torch.device("cpu")

# Generate test input: a 2D matrix or batch of 2D matrices
file_name = "random_normal_like.onnx"
test_input = torch.randn(2, 4, 4, device=device) # 2 batches of 4x4 matrices
torch.onnx.export(model,
test_input,
file_name,
verbose=False,
opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data: {}".format(test_input))
print("Test input data shape: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))


if __name__ == '__main__':
main()
13 changes: 13 additions & 0 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ include_models!(
pow_int,
prelu,
random_normal,
random_normal_like,
random_uniform,
range,
recip,
Expand Down Expand Up @@ -2157,6 +2158,18 @@ mod tests {
assert_eq!(expected_shape, output.shape());
}

#[test]
fn random_normal_like() {
let device = Default::default();
let model = random_normal_like::Model::<Backend>::new(&device);
let input = TensorData::zeros::<f64, _>(Shape::from([2, 4, 4]));
let expected_shape = Shape::from([2, 4, 4]);

let output = model.forward(input.into());

assert_eq!(expected_shape, output.shape());
}

#[test]
fn constant_of_shape() {
// This tests shape is being passed directly to the model
Expand Down
13 changes: 8 additions & 5 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use super::{
gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode,
random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode,
slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -121,8 +121,9 @@ pub enum Node<PS: PrecisionSettings> {
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
RandomUniform(RandomUniformNode),
RandomNormal(RandomNormalNode),
RandomNormalLike(RandomNormalLikeNode),
RandomUniform(RandomUniformNode),
ConstantOfShape(ConstantOfShapeNode),
// For now, we have to keep the precision settings in order to correctly serialize the fields
// into the right data types.
Expand Down Expand Up @@ -172,6 +173,7 @@ macro_rules! match_all {
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
Node::RandomNormal(node) => $func(node),
Node::RandomNormalLike(node) => $func(node),
Node::RandomUniform(node) => $func(node),
Node::ConstantOfShape(node) => $func(node),
_ => unimplemented!(),
Expand Down Expand Up @@ -230,6 +232,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",
Node::RandomNormal(_) => "random_normal",
Node::RandomNormalLike(_) => "random_normal_like",
Node::RandomUniform(_) => "random_uniform",
Node::ConstantOfShape(_) => "constant_of_shape",
_ => unimplemented!(),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub(crate) mod mean;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_normal_like;
pub(crate) mod random_uniform;
pub(crate) mod range;
pub(crate) mod reshape;
Expand Down
101 changes: 101 additions & 0 deletions crates/burn-import/src/burn/node/random_normal_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct RandomNormalLikeNode {
pub mean: f64,
pub scale: f64,
pub input: TensorType,
pub output: TensorType,
}

impl RandomNormalLikeNode {
// Set distribution parameters based on mean and scale
fn get_distribution(&self) -> TokenStream {
let mean = self.mean;
let std_deviation = self.scale;
quote! { Distribution::Normal(#mean, #std_deviation) }
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for RandomNormalLikeNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}

fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
let output = &self.output.name;
let input = &self.input.name;
let dist = self.get_distribution();
quote! {
let #output = #input.random_like(#dist);
}
}

fn into_node(self) -> Node<PS> {
Node::RandomNormalLike(self)
}

fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::tensor::Distribution");
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorKind, TensorType};
use burn::record::FullPrecisionSettings;

#[test]
fn test_random_normal_like_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(RandomNormalLikeNode::new(
0.0f64,
1.0f64,
TensorType::new("input", 2, TensorKind::Float, Some(vec![2, 3])),
TensorType::new("output", 2, TensorKind::Float, Some(vec![2, 3])),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::tensor::Distribution;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let output = input.random_like(Distribution::Normal(0f64, 1f64));
output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
25 changes: 25 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::{
pad::PadNode,
prelu::PReluNode,
random_normal::RandomNormalNode,
random_normal_like::RandomNormalLikeNode,
random_uniform::RandomUniformNode,
range::RangeNode,
reshape::ReshapeNode,
Expand Down Expand Up @@ -345,6 +346,9 @@ impl ParsedOnnxGraph {
NodeType::Tile => graph.register(Self::tile_conversion(node)),
NodeType::Trilu => graph.register(Self::trilu_conversion(node)),
NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)),
NodeType::RandomNormalLike => {
graph.register(Self::random_normal_like_conversion(node))
}
NodeType::ConstantOfShape => {
graph.register(Self::constant_of_shape_conversion(node))
}
Expand Down Expand Up @@ -472,6 +476,27 @@ impl ParsedOnnxGraph {
RandomNormalNode::new(output_type, mean, scale)
}

fn random_normal_like_conversion(node: Node) -> RandomNormalLikeNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let mean = node
.attrs
.get("mean")
.map(|val| val.clone().into_f32() as f64)
.unwrap_or(0.0f64);
let scale = node
.attrs
.get("scale")
.map(|val| val.clone().into_f32() as f64)
.unwrap_or(1.0f64);

if node.attrs.contains_key("seed") {
warn!("seed attribute is not supported!");
}

RandomNormalLikeNode::new(mean, scale, input, output)
}

pub(crate) fn constant_of_shape_conversion(node: Node) -> ConstantOfShapeNode {
// Additional types needed for ConstantOfShape:
use crate::burn::node::constant_of_shape::ConstantValue;
Expand Down
29 changes: 29 additions & 0 deletions crates/onnx-ir/src/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::PRelu => same_as_input_broadcast(node),
NodeType::Pow => same_as_input_broadcast(node),
NodeType::RandomNormal => random_update_output(node),
NodeType::RandomNormalLike => random_normal_like_update_output(node),
NodeType::RandomUniform => random_update_output(node),
NodeType::Range => range_update_outputs(node),
NodeType::Reciprocal => same_as_input(node),
Expand Down Expand Up @@ -204,6 +205,34 @@ fn random_update_output(node: &mut Node) {
})
}

/// Reads & interprets an optional `dtype` attribute
fn random_normal_like_update_output(node: &mut Node) {
let dtype = node
.attrs
.get("dtype")
.map(|val| DataType::from_i32(val.clone().into_i32()).unwrap())
.unwrap_or(DataType::FLOAT);

let elem_type = match dtype {
DataType::FLOAT => ElementType::Float32,
DataType::FLOAT16 => ElementType::Float16,
DataType::DOUBLE => ElementType::Float64,
_ => panic!("Tensor with type {dtype:?} not supported for random output"),
};

if let ArgType::Tensor(tensor) = &node.inputs[0].clone().ty {
if let Some(shape) = tensor.shape.clone() {
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type,
dim: shape.len(),
shape: Some(shape),
})
}
} else {
panic!("Only tensor input is valid");
}
}

/// Infer the shape of the output tensor of a Conv2d node
fn linear_update_outputs(node: &mut Node) {
// Extract the configuration of the linear layer (inputs are known)
Expand Down

0 comments on commit 970b9dc

Please sign in to comment.