Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,28 @@ public Mapper(KeyToVectorMappingTransformer parent, DataViewSchema inputSchema)
_parent = parent;
_infos = CreateInfos(inputSchema);
_types = new VectorDataViewType[_parent.ColumnPairs.Length];
for (int i = 0; i < _parent.ColumnPairs.Length; i++)

// The following try catch block is designed to provide user a better exception message
// by providing related column name when some exceptions occur in VectorDataViewType()(e.g overflow)
// This change is related with https://github.com/dotnet/machinelearning/issues/5211
try
Comment thread
gh-yewang marked this conversation as resolved.
{
int valueCount = _infos[i].TypeSrc.GetValueCount();
int keyCount = _infos[i].TypeSrc.GetItemType().GetKeyCountAsInt32(Host);
if (_parent._columns[i].OutputCountVector || valueCount == 1)
_types[i] = new VectorDataViewType(NumberDataViewType.Single, keyCount);
else
_types[i] = new VectorDataViewType(NumberDataViewType.Single, valueCount, keyCount);
for (int i = 0; i < _parent.ColumnPairs.Length; i++)
Comment thread
gh-yewang marked this conversation as resolved.
{
int valueCount = _infos[i].TypeSrc.GetValueCount();
int keyCount = _infos[i].TypeSrc.GetItemType().GetKeyCountAsInt32(Host);
if (_parent._columns[i].OutputCountVector || valueCount == 1)
_types[i] = new VectorDataViewType(NumberDataViewType.Single, keyCount);
else
_types[i] = new VectorDataViewType(NumberDataViewType.Single, valueCount, keyCount);
}
}
catch (Exception e)
{
var errorMsg = e.Message + " Related column: ";
foreach (var info in _infos)
errorMsg += info.Name + " ";
throw Host.Except(errorMsg);
}
}

Expand Down
72 changes: 72 additions & 0 deletions test/Microsoft.ML.AutoML.Tests/AutoMLFailureTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.TestFramework;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.ML.AutoML.Test
{

public class AutoMLFailureTests : BaseTestClass
{
public AutoMLFailureTests(ITestOutputHelper output) : base(output)
{
}

public class ModelInput
{
[ColumnName("Label"), LoadColumn(0)]
public int Label { get; set; }


[ColumnName("ProblematicColumn"), LoadColumn(1)]
public string ProblematicColumn { get; set; }

}

[Fact]
public void CrossValidationOverflowTest()
{
// This test is introduced for https://github.com/dotnet/machinelearning/issues/5211
// that provides users an informational exception message
MLContext mlContext = new MLContext(1);

IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
path: GetDataPath("cross_validation_overflow_dataset.txt"),
hasHeader: true,
Comment thread
gh-yewang marked this conversation as resolved.
Outdated
separatorChar: '\t',
allowQuoting: true,
allowSparse: false);

IDataView testDataView = mlContext.Data.BootstrapSample(trainingDataView);

ExperimentResult<MulticlassClassificationMetrics> experimentResult = mlContext.Auto()
.CreateMulticlassClassificationExperiment(60)
Comment thread
gh-yewang marked this conversation as resolved.
Outdated
.Execute(trainingDataView, labelColumnName: "Label");
RunDetail<MulticlassClassificationMetrics> bestRun = experimentResult.BestRun;
IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView);

try
{
var testMetrics = mlContext.MulticlassClassification.CrossValidate(
testDataViewWithBestScore,
bestRun.Estimator,
numberOfFolds: 5,
labelColumnName: "Label");
Assert.True(false);
}
catch (System.Exception ex)
{
Assert.Contains("Arithmetic operation resulted in an overflow. Related column: ProblematicColumn", ex.Message);
return;
}

}
}
}


Loading