@@ -3888,6 +3888,62 @@ def _impl_v12(cls, inputs, attr, params):
38883888 return _op .einsum (inputs , equation )
38893889
38903890
3891+ class RandomNormal (OnnxOpConverter ):
3892+ """Operator converter for random_normal"""
3893+
3894+ @classmethod
3895+ def _impl_v1 (cls , inputs , attr , params ):
3896+ dtype = get_type (attr .get ("dtype" , 1 ))
3897+ mean = attr .get ("mean" , 0.0 )
3898+ scale = attr .get ("scale" , 1.0 )
3899+ seed = attr .get ("seed" , None )
3900+ shape = attr ["shape" ]
3901+
3902+ assert dtype in [
3903+ "float32" ,
3904+ "float64" ,
3905+ ], "Only float random value generation is currently supported."
3906+
3907+ if seed is None :
3908+ seed = np .random .randint (1e6 )
3909+ else :
3910+ seed = int (seed )
3911+ key = _random .threefry_key (seed )
3912+ output = _op .random .normal (key , shape , dtype = dtype , mean = mean , scale = scale )
3913+ _ , vals = _expr .TupleWrapper (output , 2 )
3914+ return vals
3915+
3916+
3917+ class RandomNormalLike (OnnxOpConverter ):
3918+ """Operator converter for random_normal_like"""
3919+
3920+ @classmethod
3921+ def _impl_v1 (cls , inputs , attr , params ):
3922+ dtype = attr .get ("dtype" , None )
3923+ scale = attr .get ("scale" , 1.0 )
3924+ mean = attr .get ("mean" , 0.0 )
3925+ seed = attr .get ("seed" , None )
3926+ shape = infer_shape (inputs [0 ])
3927+ if dtype is None :
3928+ dtype = infer_type (inputs [0 ]).checked_type .dtype
3929+ else :
3930+ dtype = get_type (dtype )
3931+
3932+ assert dtype in [
3933+ "float32" ,
3934+ "float64" ,
3935+ ], "Only float random value generation is currently supported."
3936+
3937+ if seed is None :
3938+ seed = np .random .randint (1e6 )
3939+ else :
3940+ seed = int (seed )
3941+ key = _random .threefry_key (seed )
3942+ output = _op .random .normal (key , shape , dtype = dtype , mean = mean , scale = scale )
3943+ _ , vals = _expr .TupleWrapper (output , 2 )
3944+ return vals
3945+
3946+
38913947class RandomUniform (OnnxOpConverter ):
38923948 """Operator converter for random_uniform"""
38933949
@@ -3906,6 +3962,38 @@ def _impl_v1(cls, inputs, attr, params):
39063962
39073963 if seed is None :
39083964 seed = np .random .randint (1e6 )
3965+ else :
3966+ seed = int (seed )
3967+ key = _random .threefry_key (seed )
3968+ output = _op .random .uniform (key , shape , dtype = dtype , low = low , high = high )
3969+ _ , vals = _expr .TupleWrapper (output , 2 )
3970+ return vals
3971+
3972+
3973+ class RandomUniformLike (OnnxOpConverter ):
3974+ """Operator converter for random_uniform_like"""
3975+
3976+ @classmethod
3977+ def _impl_v1 (cls , inputs , attr , params ):
3978+ dtype = attr .get ("dtype" , None )
3979+ high = attr .get ("high" , 1.0 )
3980+ low = attr .get ("low" , 0.0 )
3981+ seed = attr .get ("seed" , None )
3982+ shape = infer_shape (inputs [0 ])
3983+ if dtype is None :
3984+ dtype = infer_type (inputs [0 ]).checked_type .dtype
3985+ else :
3986+ dtype = get_type (dtype )
3987+
3988+ assert dtype in [
3989+ "float32" ,
3990+ "float64" ,
3991+ ], "Only float random value generation is currently supported."
3992+
3993+ if seed is None :
3994+ seed = np .random .randint (1e6 )
3995+ else :
3996+ seed = int (seed )
39093997 key = _random .threefry_key (seed )
39103998 output = _op .random .uniform (key , shape , dtype = dtype , low = low , high = high )
39113999 _ , vals = _expr .TupleWrapper (output , 2 )
@@ -4396,7 +4484,10 @@ def _get_convert_map(opset):
43964484 "QLinearGlobalAveragePool" : QLinearGlobalAveragePool .get_converter (opset ),
43974485 "QLinearLeakyRelu" : QLinearLeakyRelu .get_converter (opset ),
43984486 # Random number generation.
4487+ "RandomNormal" : RandomNormal .get_converter (opset ),
4488+ "RandomNormalLike" : RandomNormalLike .get_converter (opset ),
43994489 "RandomUniform" : RandomUniform .get_converter (opset ),
4490+ "RandomUniformLike" : RandomUniformLike .get_converter (opset ),
44004491 # Loss functions / training
44014492 "NegativeLogLikelihoodLoss" : NegativeLogLikelihoodLoss .get_converter (opset ),
44024493 "SoftmaxCrossEntropyLoss" : SoftmaxCrossEntropyLoss .get_converter (opset ),
0 commit comments