Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -246,14 +246,28 @@ public Mapper(KeyToVectorMappingTransformer parent, DataViewSchema inputSchema)
_parent = parent;
_infos = CreateInfos(inputSchema);
_types = new VectorDataViewType[_parent.ColumnPairs.Length];
for (int i = 0; i < _parent.ColumnPairs.Length; i++)

// The try/catch block below gives the user a clearer exception message: when an OverflowException
// is thrown while the output types are built (e.g. by VectorDataViewType()), the column names are appended to the message.
// This change is related to https://github.com/dotnet/machinelearning/issues/5211
try
{
int valueCount = _infos[i].TypeSrc.GetValueCount();
int keyCount = _infos[i].TypeSrc.GetItemType().GetKeyCountAsInt32(Host);
if (_parent._columns[i].OutputCountVector || valueCount == 1)
_types[i] = new VectorDataViewType(NumberDataViewType.Single, keyCount);
else
_types[i] = new VectorDataViewType(NumberDataViewType.Single, valueCount, keyCount);
for (int i = 0; i < _parent.ColumnPairs.Length; i++)
{
int valueCount = _infos[i].TypeSrc.GetValueCount();
int keyCount = _infos[i].TypeSrc.GetItemType().GetKeyCountAsInt32(Host);
if (_parent._columns[i].OutputCountVector || valueCount == 1)
_types[i] = new VectorDataViewType(NumberDataViewType.Single, keyCount);
else
_types[i] = new VectorDataViewType(NumberDataViewType.Single, valueCount, keyCount);
}
}
catch (OverflowException e)
{
var errorMsg = e.Message + " Related column: ";
foreach (var info in _infos)
errorMsg += info.Name + " ";
throw Host.Except(errorMsg);
}
}

Expand Down
47 changes: 47 additions & 0 deletions test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -222,5 +223,51 @@ public void TestOldSavingAndLoading()
var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms);
}
}

/// <summary>
/// Row type for the overflow regression test: an integer label loaded from column 0
/// and a string column loaded from column 1.
/// </summary>
public class ModelInput
{
    [ColumnName("Label")]
    [LoadColumn(0)]
    public int Label { get; set; }

    [ColumnName("ProblematicColumn")]
    [LoadColumn(1)]
    public string ProblematicColumn { get; set; }
}

/// <summary>
/// Lazily produces 1000 rows. For row index i, Label is i % 3 and
/// ProblematicColumn is the string form of i % 200.
/// </summary>
static IEnumerable<ModelInput> GetData()
{
    const int rowCount = 1000;

    // Enumerable.Range + Select is deferred, so rows are still created on demand.
    return Enumerable.Range(0, rowCount).Select(index => new ModelInput
    {
        Label = index % 3,
        ProblematicColumn = (index % 200).ToString()
    });
}

/// <summary>
/// Regression test for https://github.com/dotnet/machinelearning/issues/5211.
/// Verifies that the overflow raised when the pipeline is fitted a second time, on data it has
/// already featurized, carries an informational message that names the related column.
/// </summary>
[Fact]
public void KeyToVectorOverflowTest()
{
    MLContext mlContext = new MLContext(1);

    IDataView dataview = mlContext.Data.LoadFromEnumerable(GetData());

    var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
        .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("ProblematicColumn"));

    // First pass: fit and transform the raw data.
    var featurizedData = pipeline.Fit(dataview).Transform(dataview);

    // Second pass: fitting the same pipeline on its own output applies OneHotHashEncoding
    // to "ProblematicColumn" again, which is the scenario expected to overflow.
    // Assert.ThrowsAny reports a clear "no exception thrown" failure if Fit unexpectedly succeeds,
    // rather than a try/catch that would swallow its own assertion failure.
    var ex = Assert.ThrowsAny<Exception>(() => { pipeline.Fit(featurizedData); });

    Assert.Contains("Arithmetic operation resulted in an overflow. Related column: ProblematicColumn", ex.Message);
}
}
}