Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNX op Random Normal Like #2441

Merged
merged 6 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ represent the corresponding Burn Op.
| [QLinearMatMul][124] | ❌ | ❌ |
| [QuantizeLinear][125] | ❌ | ❌ |
| [RandomNormal][126] | ✅ | ✅ |
| [RandomNormalLike][127] | | ✅ |
| [RandomNormalLike][127] | | ✅ |
| [RandomUniform][128] | ✅ | ✅ |
| [RandomUniformLike][129] | ❌ | ✅ |
| [Range][130] | ✅ | ✅ |
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ fn main() {
.input("tests/pow/pow_int.onnx")
.input("tests/prelu/prelu.onnx")
.input("tests/random_normal/random_normal.onnx")
.input("tests/random_normal_like/random_normal_like.onnx")
.input("tests/random_uniform/random_uniform.onnx")
.input("tests/range/range.onnx")
.input("tests/recip/recip.onnx")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pytorch2.2.0:�
P
onnx::RandomNormalLike_01/RandomNormalLike"RandomNormalLike*
dtype�
main_graphZ.
onnx::RandomNormalLike_0



b
1



B
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python3

# used to generate model: random_normal_like.onnx

import torch
import torch.nn as nn


class RandomNormalLikeModel(nn.Module):
def __init__(self):
super(RandomNormalLikeModel, self).__init__()

def forward(self, x):
return torch.randn_like(x)


def main():
# Set seed for reproducibility
torch.manual_seed(42)

# Set print options for better precision output
torch.set_printoptions(precision=8)

# Export Random NormalLike Model
model = RandomNormalLikeModel()
model.eval()
device = torch.device("cpu")

# Generate test input: a 2D matrix or batch of 2D matrices
file_name = "random_normal_like.onnx"
test_input = torch.randn(2, 4, 4, device=device) # 2 batches of 4x4 matrices
torch.onnx.export(model,
test_input,
file_name,
verbose=False,
opset_version=16)

print("Finished exporting model to {}".format(file_name))

# Output some test data for use in the test
print("Test input data: {}".format(test_input))
print("Test input data shape: {}".format(test_input.shape))
output = model.forward(test_input)
print("Test output data shape: {}".format(output.shape))
print("Test output: {}".format(output))


if __name__ == '__main__':
main()
13 changes: 13 additions & 0 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ include_models!(
pow_int,
prelu,
random_normal,
random_normal_like,
random_uniform,
range,
recip,
Expand Down Expand Up @@ -2157,6 +2158,18 @@ mod tests {
assert_eq!(expected_shape, output.shape());
}

#[test]
fn random_normal_like() {
let device = Default::default();
let model = random_normal_like::Model::<Backend>::new(&device);
let input = TensorData::zeros::<f64, _>(Shape::from([2, 4, 4]));
let expected_shape = Shape::from([2, 4, 4]);

let output = model.forward(input.into());

assert_eq!(expected_shape, output.shape());
}

#[test]
fn constant_of_shape() {
// This tests shape is being passed directly to the model
Expand Down
13 changes: 8 additions & 5 deletions crates/burn-import/src/burn/node/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use super::{
gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode,
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode,
range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode,
squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode,
unsqueeze::UnsqueezeNode,
prelu::PReluNode, random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode,
random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode,
slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode,
unary::UnaryNode, unsqueeze::UnsqueezeNode,
};
use crate::burn::{BurnImports, Scope, Type};
use burn::backend::NdArray;
Expand Down Expand Up @@ -121,8 +121,9 @@ pub enum Node<PS: PrecisionSettings> {
Unary(UnaryNode),
Unsqueeze(UnsqueezeNode),
Where(WhereNode),
RandomUniform(RandomUniformNode),
RandomNormal(RandomNormalNode),
RandomNormalLike(RandomNormalLikeNode),
RandomUniform(RandomUniformNode),
ConstantOfShape(ConstantOfShapeNode),
// For now, we have to keep the precision settings in order to correctly serialize the fields
// into the right data types.
Expand Down Expand Up @@ -172,6 +173,7 @@ macro_rules! match_all {
Node::Unsqueeze(node) => $func(node),
Node::Where(node) => $func(node),
Node::RandomNormal(node) => $func(node),
Node::RandomNormalLike(node) => $func(node),
Node::RandomUniform(node) => $func(node),
Node::ConstantOfShape(node) => $func(node),
_ => unimplemented!(),
Expand Down Expand Up @@ -230,6 +232,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Unsqueeze(_) => "unsqueeze",
Node::Where(_) => "where",
Node::RandomNormal(_) => "random_normal",
Node::RandomNormalLike(_) => "random_normal_like",
Node::RandomUniform(_) => "random_uniform",
Node::ConstantOfShape(_) => "constant_of_shape",
_ => unimplemented!(),
Expand Down
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub(crate) mod mean;
pub(crate) mod pad;
pub(crate) mod prelu;
pub(crate) mod random_normal;
pub(crate) mod random_normal_like;
pub(crate) mod random_uniform;
pub(crate) mod range;
pub(crate) mod reshape;
Expand Down
101 changes: 101 additions & 0 deletions crates/burn-import/src/burn/node/random_normal_like.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use super::{Node, NodeCodegen};
use crate::burn::{Scope, TensorType, Type};
use burn::record::PrecisionSettings;
use proc_macro2::TokenStream;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct RandomNormalLikeNode {
pub mean: f64,
pub scale: f64,
pub input: TensorType,
pub output: TensorType,
}

impl RandomNormalLikeNode {
// Set distribution parameters based on mean and scale
fn get_distribution(&self) -> TokenStream {
let mean = self.mean;
let std_deviation = self.scale;
quote! { Distribution::Normal(#mean, #std_deviation) }
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for RandomNormalLikeNode {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}

fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}

fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream {
let output = &self.output.name;
let input = &self.input.name;
let dist = self.get_distribution();
quote! {
let #output = #input.random_like(#dist);
}
}

fn into_node(self) -> Node<PS> {
Node::RandomNormalLike(self)
}

fn register_imports(&self, imports: &mut crate::burn::BurnImports) {
imports.register("burn::tensor::Distribution");
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorKind, TensorType};
use burn::record::FullPrecisionSettings;

#[test]
fn test_random_normal_like_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(RandomNormalLikeNode::new(
0.0f64,
1.0f64,
TensorType::new("input", 2, TensorKind::Float, Some(vec![2, 3])),
TensorType::new("output", 2, TensorKind::Float, Some(vec![2, 3])),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::tensor::Distribution;
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
phantom: core::marker::PhantomData<B>,
device: burn::module::Ignored<B::Device>,
}

impl<B: Backend> Model<B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
Self {
phantom: core::marker::PhantomData,
device: burn::module::Ignored(device.clone()),
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
let output = input.random_like(Distribution::Normal(0f64, 1f64));
output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
25 changes: 25 additions & 0 deletions crates/burn-import/src/onnx/to_burn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::{
pad::PadNode,
prelu::PReluNode,
random_normal::RandomNormalNode,
random_normal_like::RandomNormalLikeNode,
random_uniform::RandomUniformNode,
range::RangeNode,
reshape::ReshapeNode,
Expand Down Expand Up @@ -345,6 +346,9 @@ impl ParsedOnnxGraph {
NodeType::Tile => graph.register(Self::tile_conversion(node)),
NodeType::Trilu => graph.register(Self::trilu_conversion(node)),
NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)),
NodeType::RandomNormalLike => {
graph.register(Self::random_normal_like_conversion(node))
}
NodeType::ConstantOfShape => {
graph.register(Self::constant_of_shape_conversion(node))
}
Expand Down Expand Up @@ -472,6 +476,27 @@ impl ParsedOnnxGraph {
RandomNormalNode::new(output_type, mean, scale)
}

fn random_normal_like_conversion(node: Node) -> RandomNormalLikeNode {
let input = TensorType::from(node.inputs.first().unwrap());
let output = TensorType::from(node.outputs.first().unwrap());
let mean = node
.attrs
.get("mean")
.map(|val| val.clone().into_f32() as f64)
.unwrap_or(0.0f64);
let scale = node
.attrs
.get("scale")
.map(|val| val.clone().into_f32() as f64)
.unwrap_or(1.0f64);

if node.attrs.contains_key("seed") {
warn!("seed attribute is not supported!");
}

RandomNormalLikeNode::new(mean, scale, input, output)
}

pub(crate) fn constant_of_shape_conversion(node: Node) -> ConstantOfShapeNode {
// Additional types needed for ConstantOfShape:
use crate::burn::node::constant_of_shape::ConstantValue;
Expand Down
29 changes: 29 additions & 0 deletions crates/onnx-ir/src/dim_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub fn dim_inference(node: &mut Node) {
NodeType::PRelu => same_as_input_broadcast(node),
NodeType::Pow => same_as_input_broadcast(node),
NodeType::RandomNormal => random_update_output(node),
NodeType::RandomNormalLike => random_normal_like_update_output(node),
NodeType::RandomUniform => random_update_output(node),
NodeType::Range => range_update_outputs(node),
NodeType::Reciprocal => same_as_input(node),
Expand Down Expand Up @@ -204,6 +205,34 @@ fn random_update_output(node: &mut Node) {
})
}

/// Reads & interprets an optional `dtype` attribute
fn random_normal_like_update_output(node: &mut Node) {
let dtype = node
.attrs
.get("dtype")
.map(|val| DataType::from_i32(val.clone().into_i32()).unwrap())
.unwrap_or(DataType::FLOAT);

let elem_type = match dtype {
DataType::FLOAT => ElementType::Float32,
DataType::FLOAT16 => ElementType::Float16,
DataType::DOUBLE => ElementType::Float64,
_ => panic!("Tensor with type {dtype:?} not supported for random output"),
};

if let ArgType::Tensor(tensor) = &node.inputs[0].clone().ty {
if let Some(shape) = tensor.shape.clone() {
node.outputs[0].ty = ArgType::Tensor(TensorType {
elem_type,
dim: shape.len(),
shape: Some(shape),
})
}
} else {
panic!("Only tensor input is valid");
}
}

/// Infer the shape of the output tensor of a Conv2d node
fn linear_update_outputs(node: &mut Node) {
// Extract the configuration of the linear layer (inputs are known)
Expand Down