From a0b6881588df2b4b66ebc054682cc79b94146e33 Mon Sep 17 00:00:00 2001 From: tiruka Date: Thu, 24 Oct 2024 16:34:30 +0900 Subject: [PATCH 1/6] add random normal like python code to generate onnx model --- .../random_normal_like.onnx | 15 ++++++ .../random_normal_like/random_normal_like.py | 49 +++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.onnx create mode 100644 crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.py diff --git a/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.onnx b/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.onnx new file mode 100644 index 0000000000..6e4b6f97c6 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.onnx @@ -0,0 +1,15 @@ +pytorch2.2.0: +P +onnx::RandomNormalLike_01/RandomNormalLike"RandomNormalLike* +dtype +main_graphZ. +onnx::RandomNormalLike_0 + + + +b +1 + + + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.py b/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.py new file mode 100644 index 0000000000..64aed94e3b --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/random_normal_like/random_normal_like.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +# used to generate model: random_normal_like.onnx + +import torch +import torch.nn as nn + + +class RandomNormalLikeModel(nn.Module): + def __init__(self): + super(RandomNormalLikeModel, self).__init__() + + def forward(self, x): + return torch.randn_like(x) + + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + + # Set print options for better precision output + torch.set_printoptions(precision=8) + + # Export Random NormalLike Model + model = RandomNormalLikeModel() + model.eval() + device = torch.device("cpu") + + # Generate test input: a 2D matrix or batch of 2D matrices + file_name = "random_normal_like.onnx" + test_input = torch.randn(2, 4, 4, device=device) # 2 batches of 4x4 matrices + torch.onnx.export(model, + test_input, + file_name, + verbose=False, + opset_version=16) + + print("Finished exporting model to {}".format(file_name)) + + # Output some test data for use in the test + print("Test input data: {}".format(test_input)) + print("Test input data shape: {}".format(test_input.shape)) + output = model.forward(test_input) + print("Test output data shape: {}".format(output.shape)) + print("Test output: {}".format(output)) + + +if __name__ == '__main__': + main() \ No newline at end of file From 4eb3ac2c867a1ce3b896cf7a6a57b3111a42b0b5 Mon Sep 17 00:00:00 2001 From: tiruka Date: Tue, 29 Oct 2024 13:17:53 +0900 Subject: [PATCH 2/6] add random normal like node --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- crates/burn-import/src/burn/node/base.rs | 8 +- crates/burn-import/src/burn/node/mod.rs | 1 + .../src/burn/node/random_normal_like.rs | 117 ++++++++++++++++++ crates/onnx-ir/src/dim_inference.rs | 1 + 5 files changed, 126 insertions(+), 3 deletions(-) create mode 100644 crates/burn-import/src/burn/node/random_normal_like.rs diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index 15eb22db5a..27930aac97 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -133,7 +133,7 @@ represent the corresponding Burn Op. | [QLinearMatMul][124] | ❌ | ❌ | | [QuantizeLinear][125] | ❌ | ❌ | | [RandomNormal][126] | ✅ | ✅ | -| [RandomNormalLike][127] | ❌ | ✅ | +| [RandomNormalLike][127] | ✅ | ✅ | | [RandomUniform][128] | ✅ | ✅ | | [RandomUniformLike][129] | ❌ | ✅ | | [Range][130] | ✅ | ✅ | diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 427bbdcad7..e6e89653e8 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -10,7 +10,8 @@ use super::{ gather_elements::GatherElementsNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, - prelu::PReluNode, random_normal::RandomNormalNode, random_uniform::RandomUniformNode, + prelu::PReluNode, random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode, + random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, @@ -121,8 +122,9 @@ pub enum Node { Unary(UnaryNode), Unsqueeze(UnsqueezeNode), Where(WhereNode), - RandomUniform(RandomUniformNode), RandomNormal(RandomNormalNode), + RandomNormalLike(RandomNormalLikeNode), + RandomUniform(RandomUniformNode), ConstantOfShape(ConstantOfShapeNode), // For now, we have to keep the precision settings in order to correctly serialize the fields // into the right data types. @@ -172,6 +174,7 @@ macro_rules! match_all { Node::Unsqueeze(node) => $func(node), Node::Where(node) => $func(node), Node::RandomNormal(node) => $func(node), + Node::RandomNormalLike(node) => $func(node), Node::RandomUniform(node) => $func(node), Node::ConstantOfShape(node) => $func(node), _ => unimplemented!(), @@ -230,6 +233,7 @@ impl Node { Node::Unsqueeze(_) => "unsqueeze", Node::Where(_) => "where", Node::RandomNormal(_) => "random_normal", + Node::RandomNormalLike(_) => "random_normal_like", Node::RandomUniform(_) => "random_uniform", Node::ConstantOfShape(_) => "constant_of_shape", _ => unimplemented!(), diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index 6313d25f63..57dcae1c6f 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -30,6 +30,7 @@ pub(crate) mod mean; pub(crate) mod pad; pub(crate) mod prelu; pub(crate) mod random_normal; +pub(crate) mod random_normal_like; pub(crate) mod random_uniform; pub(crate) mod range; pub(crate) mod reshape; diff --git a/crates/burn-import/src/burn/node/random_normal_like.rs b/crates/burn-import/src/burn/node/random_normal_like.rs new file mode 100644 index 0000000000..e9bbc7a715 --- /dev/null +++ b/crates/burn-import/src/burn/node/random_normal_like.rs @@ -0,0 +1,117 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, Type}; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, new)] +pub struct RandomNormalLikeNode { + pub mean: f64, + pub scale: f64, + pub input: TensorType, // Input tensor to copy shape from + pub output: TensorType, +} + +impl RandomNormalLikeNode { + // Get shape from the input tensor + fn get_output_shape(&self) -> TokenStream { + let shape_it = self.input.shape.as_ref().expect("Input tensor has no shape!").iter(); + quote! { Shape::new([#(#shape_it),*]) } + } + + // Set distribution parameters based on mean and scale + fn get_distribution(&self) -> TokenStream { + let mean = self.mean; + let std_deviation = self.scale; // Scale parameter as per ONNX specs + quote! { Distribution::Normal(#mean, #std_deviation) } + } +} + +impl NodeCodegen for RandomNormalLikeNode { + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] // Input tensor type + } + + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream { + let output = &self.output.name; + let shape = self.get_output_shape(); + let dist = self.get_distribution(); + quote! { + let #output = Tensor::random(#shape, #dist, &*self.device); + } + } + + fn into_node(self) -> Node { + Node::RandomNormalLike(self) + } + + fn register_imports(&self, imports: &mut crate::burn::BurnImports) { + imports.register("burn::tensor::Distribution"); + imports.register("burn::prelude::Shape"); + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::test::assert_tokens, + TensorType, TensorKind, + }; + + #[test] + fn test_random_normal_like_codegen() { + let mut graph = BurnGraph::::default(); + + graph.register(RandomNormalLikeNode::new( + 0.0f64, + 1.0f64, + TensorType::new("input", 2, TensorKind::Float, Some(vec![2, 3])), + TensorType::new("output", 2, TensorKind::Float, Some(vec![2, 3])), + )); + + graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); + + let expected = quote! { + use burn::prelude::Shape; + use burn::tensor::Distribution; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, input: Tensor) -> Tensor { + let output = Tensor::random( + Shape::new([2usize, 3usize]), + Distribution::Normal(0f64, 1f64), + &*self.device, + ); + output + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index cd22a1403a..8655cf2a4e 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -61,6 +61,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::PRelu => same_as_input_broadcast(node), NodeType::Pow => same_as_input_broadcast(node), NodeType::RandomNormal => random_update_output(node), + NodeType::RandomNormalLike => same_as_input(node), NodeType::RandomUniform => random_update_output(node), NodeType::Range => range_update_outputs(node), NodeType::Reciprocal => same_as_input(node), From 2013178196912ea1a0458ccdd81f111b28ce90df Mon Sep 17 00:00:00 2001 From: tiruka Date: Wed, 30 Oct 2024 20:41:57 +0900 Subject: [PATCH 3/6] modify onnx burn to add new op --- crates/burn-import/src/burn/node/base.rs | 7 ++--- .../src/burn/node/random_normal_like.rs | 19 ++++++------ crates/burn-import/src/onnx/to_burn.rs | 25 ++++++++++++++++ crates/onnx-ir/src/dim_inference.rs | 30 ++++++++++++++++++- 4 files changed, 67 insertions(+), 14 deletions(-) diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index e6e89653e8..123a29a223 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -11,10 +11,9 @@ use super::{ layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, random_normal_like::RandomNormalLikeNode, - random_uniform::RandomUniformNode, - range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, slice::SliceNode, - squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, unary::UnaryNode, - unsqueeze::UnsqueezeNode, + random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, resize::ResizeNode, + slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, tile::TileNode, trilu::TriluNode, + unary::UnaryNode, unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; diff --git a/crates/burn-import/src/burn/node/random_normal_like.rs b/crates/burn-import/src/burn/node/random_normal_like.rs index e9bbc7a715..695305293b 100644 --- a/crates/burn-import/src/burn/node/random_normal_like.rs +++ b/crates/burn-import/src/burn/node/random_normal_like.rs @@ -8,14 +8,19 @@ use quote::quote; pub struct RandomNormalLikeNode { pub mean: f64, pub scale: f64, - pub input: TensorType, // Input tensor to copy shape from + pub input: TensorType, pub output: TensorType, } impl RandomNormalLikeNode { // Get shape from the input tensor fn get_output_shape(&self) -> TokenStream { - let shape_it = self.input.shape.as_ref().expect("Input tensor has no shape!").iter(); + let shape_it = self + .input + .shape + .as_ref() + .expect("Input tensor has no shape!") + .iter(); quote! { Shape::new([#(#shape_it),*]) } } @@ -29,7 +34,7 @@ impl RandomNormalLikeNode { impl NodeCodegen for RandomNormalLikeNode { fn input_types(&self) -> Vec { - vec![Type::Tensor(self.input.clone())] // Input tensor type + vec![Type::Tensor(self.input.clone())] } fn output_types(&self) -> Vec { @@ -57,13 +62,9 @@ impl NodeCodegen for RandomNormalLikeNode { #[cfg(test)] mod tests { - use burn::record::FullPrecisionSettings; use super::*; - use crate::burn::{ - graph::BurnGraph, - node::test::assert_tokens, - TensorType, TensorKind, - }; + use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorKind, TensorType}; + use burn::record::FullPrecisionSettings; #[test] fn test_random_normal_like_codegen() { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 4b081f6e43..16d56284cc 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -44,6 +44,7 @@ use crate::{ pad::PadNode, prelu::PReluNode, random_normal::RandomNormalNode, + random_normal_like::RandomNormalLikeNode, random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, @@ -345,6 +346,9 @@ impl ParsedOnnxGraph { NodeType::Tile => graph.register(Self::tile_conversion(node)), NodeType::Trilu => graph.register(Self::trilu_conversion(node)), NodeType::RandomNormal => graph.register(Self::random_normal_conversion(node)), + NodeType::RandomNormalLike => { + graph.register(Self::random_normal_like_conversion(node)) + } NodeType::ConstantOfShape => { graph.register(Self::constant_of_shape_conversion(node)) } @@ -472,6 +476,27 @@ impl ParsedOnnxGraph { RandomNormalNode::new(output_type, mean, scale) } + fn random_normal_like_conversion(node: Node) -> RandomNormalLikeNode { + let input = TensorType::from(node.inputs.first().unwrap()); + let output = TensorType::from(node.outputs.first().unwrap()); + let mean = node + .attrs + .get("mean") + .map(|val| val.clone().into_f32() as f64) + .unwrap_or(0.0f64); + let scale = node + .attrs + .get("scale") + .map(|val| val.clone().into_f32() as f64) + .unwrap_or(1.0f64); + + if node.attrs.contains_key("seed") { + warn!("seed attribute is not supported!"); + } + + RandomNormalLikeNode::new(mean, scale, input, output) + } + pub(crate) fn constant_of_shape_conversion(node: Node) -> ConstantOfShapeNode { // Additional types needed for ConstantOfShape: use crate::burn::node::constant_of_shape::ConstantValue; diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 8655cf2a4e..eabbb781f7 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -61,7 +61,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::PRelu => same_as_input_broadcast(node), NodeType::Pow => same_as_input_broadcast(node), NodeType::RandomNormal => random_update_output(node), - NodeType::RandomNormalLike => same_as_input(node), + NodeType::RandomNormalLike => random_normal_like_update_output(node), NodeType::RandomUniform => random_update_output(node), NodeType::Range => range_update_outputs(node), NodeType::Reciprocal => same_as_input(node), @@ -205,6 +205,34 @@ fn random_update_output(node: &mut Node) { }) } +/// Reads & interprets an optional `dtype` attribute +fn random_normal_like_update_output(node: &mut Node) { + let dtype = node + .attrs + .get("dtype") + .map(|val| DataType::from_i32(val.clone().into_i32()).unwrap()) + .unwrap_or(DataType::FLOAT); + + let elem_type = match dtype { + DataType::FLOAT => ElementType::Float32, + DataType::FLOAT16 => ElementType::Float16, + DataType::DOUBLE => ElementType::Float64, + _ => panic!("Tensor with type {dtype:?} not supported for random output"), + }; + + if let ArgType::Tensor(tensor) = &node.inputs[0].clone().ty { + if let Some(shape) = tensor.shape.clone() { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type, + dim: shape.len(), + shape: Some(shape), + }) + } + } else { + panic!("Only tensor input is valid"); + } +} + /// Infer the shape of the output tensor of a Conv2d node fn linear_update_outputs(node: &mut Node) { // Extract the configuration of the linear layer (inputs are known) From f987270efce6220d8a739432bd0f0c1815415b59 Mon Sep 17 00:00:00 2001 From: tiruka Date: Wed, 30 Oct 2024 21:50:32 +0900 Subject: [PATCH 4/6] add test on test onnx --- crates/burn-import/onnx-tests/build.rs | 1 + crates/burn-import/onnx-tests/tests/test_onnx.rs | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 198f20eb9c..2229626c00 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -75,6 +75,7 @@ fn main() { .input("tests/pow/pow_int.onnx") .input("tests/prelu/prelu.onnx") .input("tests/random_normal/random_normal.onnx") + .input("tests/random_normal_like/random_normal_like.onnx") .input("tests/random_uniform/random_uniform.onnx") .input("tests/range/range.onnx") .input("tests/recip/recip.onnx") diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 24ddaeca45..7cbf1fac15 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -1,4 +1,4 @@ -#![no_std] +// #![no_std] /// Include generated models in the `model` directory in the target directory. macro_rules! include_models { @@ -84,6 +84,7 @@ include_models!( pow_int, prelu, random_normal, + random_normal_like, random_uniform, range, recip, @@ -2157,6 +2158,18 @@ mod tests { assert_eq!(expected_shape, output.shape()); } + #[test] + fn random_normal_like() { + let device = Default::default(); + let model = random_normal_like::Model::::new(&device); + let input = TensorData::zeros::(Shape::from([2, 4, 4])); + let expected_shape = Shape::from([2, 4, 4]); + + let output = model.forward(input.into()); + + assert_eq!(expected_shape, output.shape()); + } + #[test] fn constant_of_shape() { // This tests shape is being passed directly to the model From 4fee47a64eff58a82e46d8dd8db008c034f7b487 Mon Sep 17 00:00:00 2001 From: tiruka Date: Wed, 30 Oct 2024 22:02:58 +0900 Subject: [PATCH 5/6] revert commentouts --- crates/burn-import/onnx-tests/tests/test_onnx.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-import/onnx-tests/tests/test_onnx.rs b/crates/burn-import/onnx-tests/tests/test_onnx.rs index 7cbf1fac15..a4acd51477 100644 --- a/crates/burn-import/onnx-tests/tests/test_onnx.rs +++ b/crates/burn-import/onnx-tests/tests/test_onnx.rs @@ -1,4 +1,4 @@ -// #![no_std] +#![no_std] /// Include generated models in the `model` directory in the target directory. macro_rules! include_models { From ae7a82586a58cfec7287d3ac07c386a892c94b38 Mon Sep 17 00:00:00 2001 From: tiruka Date: Thu, 31 Oct 2024 21:58:05 +0900 Subject: [PATCH 6/6] fix review points to respond to dynamically shape --- .../src/burn/node/random_normal_like.rs | 25 +++---------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/crates/burn-import/src/burn/node/random_normal_like.rs b/crates/burn-import/src/burn/node/random_normal_like.rs index 695305293b..c502f9c76a 100644 --- a/crates/burn-import/src/burn/node/random_normal_like.rs +++ b/crates/burn-import/src/burn/node/random_normal_like.rs @@ -13,21 +13,10 @@ pub struct RandomNormalLikeNode { } impl RandomNormalLikeNode { - // Get shape from the input tensor - fn get_output_shape(&self) -> TokenStream { - let shape_it = self - .input - .shape - .as_ref() - .expect("Input tensor has no shape!") - .iter(); - quote! { Shape::new([#(#shape_it),*]) } - } - // Set distribution parameters based on mean and scale fn get_distribution(&self) -> TokenStream { let mean = self.mean; - let std_deviation = self.scale; // Scale parameter as per ONNX specs + let std_deviation = self.scale; quote! { Distribution::Normal(#mean, #std_deviation) } } } @@ -43,10 +32,10 @@ impl NodeCodegen for RandomNormalLikeNode { fn forward(&self, _scope: &mut Scope, _node_position: usize) -> TokenStream { let output = &self.output.name; - let shape = self.get_output_shape(); + let input = &self.input.name; let dist = self.get_distribution(); quote! { - let #output = Tensor::random(#shape, #dist, &*self.device); + let #output = #input.random_like(#dist); } } @@ -56,7 +45,6 @@ impl NodeCodegen for RandomNormalLikeNode { fn register_imports(&self, imports: &mut crate::burn::BurnImports) { imports.register("burn::tensor::Distribution"); - imports.register("burn::prelude::Shape"); } } @@ -80,7 +68,6 @@ mod tests { graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]); let expected = quote! { - use burn::prelude::Shape; use burn::tensor::Distribution; use burn::{ module::Module, @@ -103,11 +90,7 @@ mod tests { } #[allow(clippy::let_and_return, clippy::approx_constant)] pub fn forward(&self, input: Tensor) -> Tensor { - let output = Tensor::random( - Shape::new([2usize, 3usize]), - Distribution::Normal(0f64, 1f64), - &*self.device, - ); + let output = input.random_like(Distribution::Normal(0f64, 1f64)); output } }