From 470b26d68cb8c76c40b9327a580afc4c504f9c5f Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 24 Jan 2020 16:51:22 +0200 Subject: [PATCH] Disallow bad input types in the estimator's GetOutputSchema method --- .../CountFeatureSelection.cs | 2 ++ .../MutualInformationFeatureSelection.cs | 13 +++++++ .../Transformers/FeatureSelectionTests.cs | 36 +++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index 682b5660e4..1e706772a9 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -171,6 +171,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.InputColumnName); if (!CountFeatureSelectionUtils.IsValidColumnType(col.ItemType)) throw _host.ExceptUserArg(nameof(inputSchema), "Column '{0}' does not have compatible type. Expected types are float, double or string.", colPair.InputColumnName); + if (col.Kind == SchemaShape.Column.VectorKind.VariableVector) + throw _host.ExceptUserArg(nameof(inputSchema), $"Variable length column '{col.Name}' is not allowed"); var metadata = new List(); if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index 3daa7625dd..16c093935e 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -219,14 +219,27 @@ public ITransformer Fit(IDataView input) public SchemaShape GetOutputSchema(SchemaShape inputSchema) { _host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryFindColumn(_labelColumnName, out var label)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "label", $"Label column '{_labelColumnName}' not found in input schema"); + if (!(label.IsKey || MutualInformationFeatureSelectionUtils.IsValidColumnType(label.ItemType))) + { + throw _host.ExceptUserArg(nameof(inputSchema), + $"Label column '{_labelColumnName}' does not have compatible type. Expected types are float, double, int, bool and key."); + } + var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in _columns) { if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col)) throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName); if (!MutualInformationFeatureSelectionUtils.IsValidColumnType(col.ItemType)) + { throw _host.ExceptUserArg(nameof(inputSchema), "Column '{0}' does not have compatible type. Expected types are float, double, int, bool and key.", colPair.inputColumnName); + } + if (col.Kind == SchemaShape.Column.VectorKind.VariableVector) + throw _host.ExceptUserArg(nameof(inputSchema), $"Variable length column '{col.Name}' is not allowed"); var metadata = new List(); if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); diff --git a/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs b/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs index ed7f0cd8e0..f14f2821df 100644 --- a/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.IO; using Microsoft.ML.Data; using Microsoft.ML.Data.IO; @@ -220,5 +221,40 @@ public void TestMutualInformationOldSavingAndLoading() } Done(); } + + [Fact] + public void TestFeatureSelectionWithBadInput() + { + string dataPath = GetDataPath("breast-cancer.txt"); + var dataView = ML.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("BadLabel", DataKind.UInt32, 0), + new TextLoader.Column("Label", DataKind.Single, 0), + new TextLoader.Column("Features", DataKind.String, 1, 9), + }); + + var ex = Assert.Throws(() => + { + var pipeline = ML.Transforms.Text.TokenizeIntoWords("Features") + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("Features")); + var model = pipeline.GetOutputSchema(SchemaShape.Create(dataView.Schema)); + }); + Assert.Contains("Variable length column 'Features' is not allowed", ex.Message); + + ex = Assert.Throws(() => + { + var pipeline = ML.Transforms.Text.TokenizeIntoWords("Features") + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("Features", labelColumnName: "BadLabel")); + var model = pipeline.GetOutputSchema(SchemaShape.Create(dataView.Schema)); + }); + Assert.Contains("Label column 'BadLabel' does not have compatible type. Expected types are float, double, int, bool and key.", ex.Message); + + ex = Assert.Throws(() => + { + var pipeline = ML.Transforms.Text.TokenizeIntoWords("Features") + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("Features")); + var model = pipeline.GetOutputSchema(SchemaShape.Create(dataView.Schema)); + }); + Assert.Contains("Column 'Features' does not have compatible type. Expected types are float, double, int, bool and key.", ex.Message); + } } }