Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Microsoft.ML.Transforms/CountFeatureSelection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.InputColumnName);
if (!CountFeatureSelectionUtils.IsValidColumnType(col.ItemType))
throw _host.ExceptUserArg(nameof(inputSchema), "Column '{0}' does not have compatible type. Expected types are float, double or string.", colPair.InputColumnName);
if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
throw _host.ExceptUserArg(nameof(inputSchema), $"Variable length column '{col.Name}' is not allowed");
var metadata = new List<SchemaShape.Column>();
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
metadata.Add(slotMeta);
Expand Down
13 changes: 13 additions & 0 deletions src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,27 @@ public ITransformer Fit(IDataView input)
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
_host.CheckValue(inputSchema, nameof(inputSchema));

if (!inputSchema.TryFindColumn(_labelColumnName, out var label))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "label", $"Label column '{_labelColumnName}' not found in input schema");
if (!(label.IsKey || MutualInformationFeatureSelectionUtils.IsValidColumnType(label.ItemType)))
{
throw _host.ExceptUserArg(nameof(inputSchema),
$"Label column '{_labelColumnName}' does not have compatible type. Expected types are float, double, int, bool and key.");
}

var result = inputSchema.ToDictionary(x => x.Name);
foreach (var colPair in _columns)
{
if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col))
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName);
if (!MutualInformationFeatureSelectionUtils.IsValidColumnType(col.ItemType))
{
throw _host.ExceptUserArg(nameof(inputSchema),
"Column '{0}' does not have compatible type. Expected types are float, double, int, bool and key.", colPair.inputColumnName);
}
if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
throw _host.ExceptUserArg(nameof(inputSchema), $"Variable length column '{col.Name}' is not allowed");
var metadata = new List<SchemaShape.Column>();
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.SlotNames, out var slotMeta))
metadata.Add(slotMeta);
Expand Down
36 changes: 36 additions & 0 deletions test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
Expand Down Expand Up @@ -220,5 +221,40 @@ public void TestMutualInformationOldSavingAndLoading()
}
Done();
}

[Fact]
public void TestFeatureSelectionWithBadInput()
{
string dataPath = GetDataPath("breast-cancer.txt");
var dataView = ML.Data.LoadFromTextFile(dataPath, new[] {
new TextLoader.Column("BadLabel", DataKind.UInt32, 0),
new TextLoader.Column("Label", DataKind.Single, 0),
new TextLoader.Column("Features", DataKind.String, 1, 9),
});

var ex = Assert.Throws<ArgumentOutOfRangeException>(() =>
{
var pipeline = ML.Transforms.Text.TokenizeIntoWords("Features")
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("Features"));
var model = pipeline.GetOutputSchema(SchemaShape.Create(dataView.Schema));
});
Assert.Contains("Variable length column 'Features' is not allowed", ex.Message);

ex = Assert.Throws<ArgumentOutOfRangeException>(() =>
{
var pipeline = ML.Transforms.Text.TokenizeIntoWords("Features")
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("Features", labelColumnName: "BadLabel"));
var model = pipeline.GetOutputSchema(SchemaShape.Create(dataView.Schema));
});
Assert.Contains("Label column 'BadLabel' does not have compatible type. Expected types are float, double, int, bool and key.", ex.Message);

ex = Assert.Throws<ArgumentOutOfRangeException>(() =>
{
var pipeline = ML.Transforms.Text.TokenizeIntoWords("Features")
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("Features"));
var model = pipeline.GetOutputSchema(SchemaShape.Create(dataView.Schema));
});
Assert.Contains("Column 'Features' does not have compatible type. Expected types are float, double, int, bool and key.", ex.Message);
}
}
}