diff --git a/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs b/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs index 91add2ad98..7231a14487 100644 --- a/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs +++ b/src/Microsoft.ML.Mkl.Components/VectorWhitening.cs @@ -13,6 +13,7 @@ using Microsoft.ML.Internal.CpuMath; using Microsoft.ML.Internal.Internallearn; using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Model.OnnxConverter; using Microsoft.ML.Runtime; using Microsoft.ML.Transforms; @@ -546,7 +547,7 @@ public static extern int Svd(Layout layout, SvdJob jobu, SvdJob jobvt, private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema); - private sealed class Mapper : OneToOneMapperBase + private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private readonly VectorWhiteningTransformer _parent; private readonly int[] _cols; @@ -607,6 +608,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func 0) ? ex.Rank : cslotSrc; + var model = _parent._models[iinfo]; ValueGetter> del = (ref VBuffer dst) => @@ -618,6 +620,51 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func true; + + /* Exports every column pair whose input column is known to ctx: looks up the input variable name, calls ctx.AddIntermediateVariable for the output column with _srcTypes[iinfo], and hands node creation to SaveAsOnnxCore. Pairs whose input column is not in ctx are skipped. NOTE(review): the output variable is declared with the source type even when rank is smaller than dimension - confirm that this is intended. */ public void SaveAsOnnx(OnnxContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + int numColumns = _parent.ColumnPairs.Length; + for (int iinfo = 0; iinfo < numColumns; ++iinfo) + { + string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName; + if (!ctx.ContainsColumn(inputColumnName)) + continue; + + string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName; + string srcVariableName = ctx.GetVariableName(inputColumnName); + string dstVariableName = ctx.AddIntermediateVariable(_srcTypes[iinfo], outputColumnName, true); + SaveAsOnnxCore(ctx, iinfo, srcVariableName, dstVariableName); + } + } + + /* Writes the ONNX nodes for column iinfo: a Gemm node followed by a Transpose node whose output is dstVariableName. Asserts that the model holds dimension times dimension values, where dimension is the value count of the source type, and that Kind is PrincipalComponentAnalysis or ZeroPhaseComponentAnalysis. */ private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) + { + var model = _parent._models[iinfo]; + int dimension = 
_srcTypes[iinfo].GetValueCount(); + Host.Assert(model.Length == dimension * dimension); + + var parameters = _parent._columns[iinfo]; + Host.Assert(parameters.Kind == WhiteningKind.PrincipalComponentAnalysis || parameters.Kind == WhiteningKind.ZeroPhaseComponentAnalysis); + + /* rank is parameters.Rank only for PrincipalComponentAnalysis with a positive Rank; otherwise the full dimension is used */ int rank = (parameters.Kind == WhiteningKind.PrincipalComponentAnalysis && parameters.Rank > 0) ? parameters.Rank : dimension; + Host.CheckParam(rank <= dimension, nameof(rank), "Rank must be at most the dimension of untransformed data."); + + long[] modelDimension = { rank, dimension }; + + var opType = "Gemm"; + /* only the first rank times dimension values of model are exported, declared with shape { rank, dimension } */ var modelName = ctx.AddInitializer(model.Take(rank * dimension), modelDimension, "model"); + var zeroValueName = ctx.AddInitializer((float)0); + + var gemmOutput = ctx.AddIntermediateVariable(null, "GemmOutput", true); + /* Gemm inputs in order: model initializer, source variable, zero scalar; the transB attribute is then set to 1 */ var node = ctx.CreateNode(opType, new[] { modelName, srcVariableName, zeroValueName }, new[] { gemmOutput }, ctx.GetNodeName(opType), ""); + node.AddAttribute("transB", 1); + + /* the Gemm result is passed through a Transpose node that writes dstVariableName */ opType = "Transpose"; + ctx.CreateNode(opType, new[] { gemmOutput }, new[] { dstVariableName }, ctx.GetNodeName(opType), ""); + } + private ValueGetter GetSrcGetter(DataViewRow input, int iinfo) { Host.AssertValue(input); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 82f75fafe5..240b7c0c87 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -309,6 +309,40 @@ public void BinaryClassificationTrainersOnnxConversionTest() Done(); } + /* Fits two VectorWhiteningEstimator stages on the features column (whitened1 with default arguments, whitened2 with PrincipalComponentAnalysis and rank 5), converts the fitted model to ONNX and, when IsOnnxRuntimeSupported() returns true, compares the ML.NET output with the ONNX output for both columns. */ [Fact] + public void TestVectorWhiteningOnnxConversionTest() + { + var mlContext = new MLContext(seed: 1); + string dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); + var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("label", DataKind.Single, 11), + new TextLoader.Column("features", DataKind.Single, 0, 10) + }, hasHeader: true, separatorChar: ';'); + + var pipeline = new 
VectorWhiteningEstimator(mlContext, "whitened1", "features") + .Append(new VectorWhiteningEstimator(mlContext, "whitened2", "features", kind: WhiteningKind.PrincipalComponentAnalysis, rank: 5)); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + + // Compare the whitened columns produced by ML.NET and by ONNX Runtime. + if (IsOnnxRuntimeSupported()) + { + var onnxFileName = $"VectorWhitening.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + SaveOnnxModel(onnxModel, onnxModelPath, null); + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedR4VectorColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); // whitened1 + CompareSelectedR4VectorColumns(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); // whitened2 + } + Done(); + } + private class DataPoint { [VectorType(3)]