diff --git a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs index 60e32327ff..8a3d255767 100644 --- a/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/SlotsDroppingTransformer.cs @@ -454,7 +454,6 @@ private static readonly FuncInstanceMethodInfo1 _makeVecTrivia private readonly SlotsDroppingTransformer _parent; private readonly int[] _cols; private readonly DataViewType[] _srcTypes; - private readonly DataViewType[] _rawTypes; private readonly DataViewType[] _dstTypes; private readonly SlotDropper[] _slotDropper; // Track if all the slots of the column are to be dropped. @@ -467,7 +466,6 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema) _parent = parent; _cols = new int[_parent.ColumnPairs.Length]; _srcTypes = new DataViewType[_parent.ColumnPairs.Length]; - _rawTypes = new DataViewType[_parent.ColumnPairs.Length]; _dstTypes = new DataViewType[_parent.ColumnPairs.Length]; _slotDropper = new SlotDropper[_parent.ColumnPairs.Length]; _suppressed = new bool[_parent.ColumnPairs.Length]; @@ -480,8 +478,8 @@ public Mapper(SlotsDroppingTransformer parent, DataViewSchema inputSchema) _srcTypes[i] = inputSchema[_cols[i]].Type; VectorDataViewType srcVectorType = _srcTypes[i] as VectorDataViewType; - _rawTypes[i] = srcVectorType?.ItemType ?? _srcTypes[i]; - if (!IsValidColumnType(_rawTypes[i])) + var rawType = srcVectorType?.ItemType ?? _srcTypes[i]; + if (!IsValidColumnType(rawType)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); int valueCount = srcVectorType?.Size ?? 
1; @@ -896,27 +894,26 @@ public void SaveAsOnnx(OnnxContext ctx) public bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName) { string opType; - if (_srcTypes[iinfo] is VectorDataViewType) + var slots = _slotDropper[iinfo].GetPreservedSlots(); + // vector column is not suppressed + if (slots.Count() > 0) { opType = "GatherElements"; - IEnumerable slots = _slotDropper[iinfo].GetPreservedSlots(); var slotsVar = ctx.AddInitializer(slots, new long[] { 1, slots.Count() }, "PreservedSlots"); var node = ctx.CreateNode(opType, new[] { srcVariableName, slotsVar }, new[] { dstVariableName }, ctx.GetNodeName(opType), ""); node.AddAttribute("axis", 1); } + // When the vector/scalar column is suppressed, we simply output a constant 1x1 tensor holding the item type's default value else { string constVal; - long[] dims = { 1, 1 }; - float[] floatVals = { 0.0f }; - long[] keyVals = { 0 }; - string[] stringVals = { "" }; - if (_rawTypes[iinfo] is TextDataViewType) - constVal = ctx.AddInitializer(stringVals, dims); - else if (_rawTypes[iinfo] is KeyDataViewType) - constVal = ctx.AddInitializer(keyVals, dims); + var type = _srcTypes[iinfo].GetItemType(); + if (type == TextDataViewType.Instance) + constVal = ctx.AddInitializer(new string[] { "" }, new long[] { 1, 1 }); + else if (type == NumberDataViewType.Single) + constVal = ctx.AddInitializer(new float[] { 0 }, new long[] { 1, 1 }); else - constVal = ctx.AddInitializer(floatVals, dims); + constVal = ctx.AddInitializer(new double[] { 0 }, new long[] { 1, 1 }); opType = "Identity"; ctx.CreateNode(opType, constVal, dstVariableName, ctx.GetNodeName(opType), ""); diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index 1e706772a9..1c09ae54ef 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -28,7 +28,7 @@ namespace Microsoft.ML.Transforms /// | | | /// | -- | -- | /// | Does this 
estimator need to look at the data to train its parameters? | Yes | - /// | Input column data type | Vector or scalar of numeric, [text](xref:Microsoft.ML.Data.TextDataViewType) or [key](xref:Microsoft.ML.Data.KeyDataViewType) data types| + /// | Input column data type | Vector or scalar of <xref:System.Single>, <xref:System.Double> or [text](xref:Microsoft.ML.Data.TextDataViewType) data types| /// | Output column data type | Same as the input column| /// | Exportable to ONNX | Yes | /// diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index c342052627..f365c7f423 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1648,57 +1648,62 @@ public void UseKeyDataViewTypeAsUInt32InOnnxInput() Done(); } - [Fact] - public void FeatureSelectionOnnxTest() + [Theory] + [InlineData(DataKind.String)] + [InlineData(DataKind.Single)] + [InlineData(DataKind.Double)] + public void FeatureSelectionOnnxTest(DataKind dataKind) { var mlContext = new MLContext(seed: 1); string dataPath = GetDataPath("breast-cancer.txt"); - var dataView = ML.Data.LoadFromTextFile(dataPath, new[] { - new TextLoader.Column("ScalarFloat", DataKind.Single, 6), - new TextLoader.Column("VectorFloat", DataKind.Single, 1, 4), - new TextLoader.Column("VectorDouble", DataKind.Double, 4, 8), + var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("Scalar", dataKind, 6), + new TextLoader.Column("Vector", dataKind, 1, 6), new TextLoader.Column("Label", DataKind.Boolean, 0) }); - var columns = new[] { - new CountFeatureSelectingEstimator.ColumnOptions("FeatureSelectDouble", "VectorDouble", count: 1), - new CountFeatureSelectingEstimator.ColumnOptions("ScalFeatureSelectMissing690", "ScalarFloat", count: 690), - new CountFeatureSelectingEstimator.ColumnOptions("ScalFeatureSelectMissing100", "ScalarFloat", count: 100), - new CountFeatureSelectingEstimator.ColumnOptions("VecFeatureSelectMissing690", 
"VectorDouble", count: 690), - new CountFeatureSelectingEstimator.ColumnOptions("VecFeatureSelectMissing100", "VectorDouble", count: 100) - }; - var pipeline = ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("FeatureSelect", "VectorFloat", count: 1) - .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount(columns)) - .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelectMIScalarFloat", "ScalarFloat")) - .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("FeatureSelectMIVectorFloat", "VectorFloat")); + IEstimator[] pipelines = + { + // one or more features selected + mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("VectorOutput", "Vector", count: 690). + Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("ScalarOutput", "Scalar", count: 100)), - var model = pipeline.Fit(dataView); - var transformedData = model.Transform(dataView); - var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + // no feature selected => column suppressed + mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("VectorOutput", "Vector", count: 800). + Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("ScalarOutput", "Scalar", count: 800)), - var onnxFileName = "countfeatures.onnx"; - var onnxModelPath = GetOutputPath(onnxFileName); + mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("VectorOutput", "Vector"). 
+ Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("ScalarOutput", "Scalar")) + }; + for (int i = 0; i < pipelines.Length; i++) + { + // There's currently no support for suppressed string columns, since ONNX string variable initialization is not supported + if (dataKind == DataKind.String && i > 0) + break; + var model = pipelines[i].Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); - SaveOnnxModel(onnxModel, onnxModelPath, null); + var onnxFileName = "countfeatures.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); - if (IsOnnxRuntimeSupported()) - { - // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. - var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath); - var onnxTransformer = onnxEstimator.Fit(dataView); - var onnxResult = onnxTransformer.Transform(dataView); - CompareSelectedColumns("FeatureSelectMIScalarFloat", "FeatureSelectMIScalarFloat", transformedData, onnxResult); - CompareSelectedColumns("FeatureSelectMIVectorFloat", "FeatureSelectMIVectorFloat", transformedData, onnxResult); - CompareSelectedColumns("ScalFeatureSelectMissing690", "ScalFeatureSelectMissing690", transformedData, onnxResult); - CompareSelectedColumns("VecFeatureSelectMissing690", "VecFeatureSelectMissing690", transformedData, onnxResult); + SaveOnnxModel(onnxModel, onnxModelPath, null); + + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareResults("VectorOutput", "VectorOutput", transformedData, onnxResult); + CompareResults("ScalarOutput", "ScalarOutput", transformedData, onnxResult); + } } Done(); } - - [Fact] public void SelectColumnsOnnxTest() {