Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 2 additions & 35 deletions src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -305,41 +305,8 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output
}
}

if (string.IsNullOrEmpty(stratificationColumn))
{
stratificationColumn = "StratificationColumn";
int tmp;
int inc = 0;
while (input.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
stratificationColumn = string.Format("StratificationColumn_{0:000}", ++inc);
var keyGenArgs = new GenerateNumberTransform.Options();
var col = new GenerateNumberTransform.Column();
col.Name = stratificationColumn;
keyGenArgs.Columns = new[] { col };
output = new GenerateNumberTransform(Host, keyGenArgs, input);
}
else
{
int col;
if (!input.Schema.TryGetColumnIndex(stratificationColumn, out col))
throw ch.ExceptUserArg(nameof(Arguments.StratificationColumn), "Column '{0}' does not exist", stratificationColumn);
var type = input.Schema[col].Type;
if (!RangeFilter.IsValidRangeFilterColumnType(ch, type))
{
ch.Info("Hashing the stratification column");
var origStratCol = stratificationColumn;
stratificationColumn = input.Schema.GetTempColumnName("strat");

// HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
var itemType = type.GetItemType();
if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
input = new TypeConvertingTransformer(Host, origStratCol, DataKind.Int64, origStratCol).Transform(input);

output = new HashingEstimator(Host, stratificationColumn, origStratCol, 30).Fit(input).Transform(input);
}
}

return stratificationColumn;
var newStratificationColumn = DataOperationsCatalog.CreateGroupPreservationColumn(Host, ref output, stratificationColumn);
return newStratificationColumn;
}

private bool TryGetOverallMetrics(Dictionary<string, IDataView>[] metrics, out List<IDataView> overallList)
Expand Down
81 changes: 59 additions & 22 deletions src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -413,18 +413,18 @@ public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, s
_env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
_env.CheckValueOrNull(samplingKeyColumnName);

EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed);
var newSamplingKeyColumn = CreateGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true);

var trainFilter = new RangeFilter(_env, new RangeFilter.Options()
{
Column = samplingKeyColumnName,
Column = newSamplingKeyColumn,
Min = 0,
Max = testFraction,
Complement = true
}, data);
var testFilter = new RangeFilter(_env, new RangeFilter.Options()
{
Column = samplingKeyColumnName,
Column = newSamplingKeyColumn,
Min = 0,
Max = testFraction,
Complement = false
Expand Down Expand Up @@ -455,9 +455,9 @@ public IReadOnlyList<TrainTestData> CrossValidationSplit(IDataView data, int num
_env.CheckValue(data, nameof(data));
_env.CheckParam(numberOfFolds > 1, nameof(numberOfFolds), "Must be more than 1");
_env.CheckValueOrNull(samplingKeyColumnName);
EnsureGroupPreservationColumn(_env, ref data, ref samplingKeyColumnName, seed);
var newSamplingKeyColumn = CreateGroupPreservationColumn(_env, ref data, samplingKeyColumnName, seed, fallbackInEnvSeed: true);
var result = new List<TrainTestData>();
foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, samplingKeyColumnName))
foreach (var split in CrossValidationSplit(_env, data, numberOfFolds, newSamplingKeyColumn))
result.Add(split);
return result;
}
Expand Down Expand Up @@ -491,50 +491,87 @@ internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment
}

/// <summary>
/// Ensures the provided <paramref name="samplingKeyColumn"/> is valid for <see cref="RangeFilter"/>, hashing it if necessary, or creates a new column if <paramref name="samplingKeyColumn"/> is null.
/// Based on the input samplingKeyColumn creates a new column that will be used by the callers to apply a RangeFilter that will produce train-test split
/// or cross-validation splits.
///
/// Notice that the new column might get dropped by the callers of this method after using it.
/// </summary>
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
/// <param name="env">IHostEnvironment of the caller</param>
/// <param name="data">DataView that should contain the "samplingKeyColumn". The new column will be added to this DataView.</param>
/// <param name="samplingKeyColumn">Name of the column that will be used as base of the new GroupPreservationColumn.
/// Notice that in other places in the code this column, and/or the column that this method creates,
/// are referred to as "SamplingKeyColumn", "GroupPreservationColumn" or "StratificationColumn".</param>
/// <param name="seed">The seed used by the transformers that will create the new column</param>
/// <param name="fallbackInEnvSeed">Whether to fall back to the seed of <paramref name="env"/> when <paramref name="seed"/> is null. If <paramref name="seed"/> is null and this parameter is false, no seed is used.</param>
/// <returns>The name of the new column</returns>
[BestFriend]
internal static string CreateGroupPreservationColumn(IHostEnvironment env, ref IDataView data, string samplingKeyColumn, int? seed = null, bool fallbackInEnvSeed = false)
{
Contracts.CheckValue(env, nameof(env));
// We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
// build a single hash of it. If it is not, we generate a random number.
Contracts.CheckValueOrNull(samplingKeyColumn);

var newSamplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
int? seedToUse;

if(seed.HasValue)
{
seedToUse = seed.Value;
}
else if(fallbackInEnvSeed)
{
ISeededEnvironment seededEnv = (ISeededEnvironment)env;
seedToUse = seededEnv.Seed;
}
else
{
seedToUse = null;
}

// We need to handle two cases. First case: samplingKeyColumn is not provided, so we generate a random number.
if (samplingKeyColumn == null)
{
samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
data = new GenerateNumberTransform(env, data, newSamplingKeyColumn, (uint?)seedToUse);
}
else
{
// If samplingKeyColumn was provided we will make a new column based on it, but using a temporary
// name, as it might be dropped elsewhere in the code

if (!data.Schema.TryGetColumnIndex(samplingKeyColumn, out int stratCol))
throw env.ExceptSchemaMismatch(nameof(samplingKeyColumn), "SamplingKeyColumn", samplingKeyColumn);

var type = data.Schema[stratCol].Type;
if (!RangeFilter.IsValidRangeFilterColumnType(env, type))
{
var origStratCol = samplingKeyColumn;
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
var hashInputColumnName = samplingKeyColumn;
// HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
var itemType = type.GetItemType();
if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
data = new TypeConvertingTransformer(env, origStratCol, DataKind.Int64, origStratCol).Transform(data);
{
data = new TypeConvertingTransformer(env, newSamplingKeyColumn, DataKind.Int64, samplingKeyColumn).Transform(data);
hashInputColumnName = newSamplingKeyColumn;
}

var localSeed = seed.HasValue ? seed : ((ISeededEnvironment)env).Seed.HasValue ? ((ISeededEnvironment)env).Seed : null;
var columnOptions =
localSeed.HasValue ?
new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, (uint)localSeed.Value, combine: true) :
new HashingEstimator.ColumnOptions(samplingKeyColumn, origStratCol, 30, combine: true);
seedToUse.HasValue ?
new HashingEstimator.ColumnOptions(newSamplingKeyColumn, hashInputColumnName, 30, (uint)seedToUse.Value, combine: true) :
new HashingEstimator.ColumnOptions(newSamplingKeyColumn, hashInputColumnName, 30, combine: true);
data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
}
else
{
if (!data.Schema[samplingKeyColumn].IsNormalized() && (type == NumberDataViewType.Single || type == NumberDataViewType.Double))
if (data.Schema[samplingKeyColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double))
{
data = new ColumnCopyingEstimator(env, (newSamplingKeyColumn, samplingKeyColumn)).Fit(data).Transform(data);
}
else
{
var origStratCol = samplingKeyColumn;
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(newSamplingKeyColumn, samplingKeyColumn, ensureZeroUntouched: false)).Fit(data).Transform(data);
}
Comment thread
antoniovs1029 marked this conversation as resolved.
}
}

return newSamplingKeyColumn;
}
}
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/TrainCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEs
Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
Environment.CheckValueOrNull(samplingKeyColumn);

DataOperationsCatalog.EnsureGroupPreservationColumn(Environment, ref data, ref samplingKeyColumn, seed);
var newSamplingKeyColumn = DataOperationsCatalog.CreateGroupPreservationColumn(Environment, ref data, samplingKeyColumn, seed, fallbackInEnvSeed: true);
var result = new CrossValidationResult[numFolds];
int fold = 0;
// Sequential per-fold training.
// REVIEW: we could have a parallel implementation here. We would need to
// spawn off a separate host per fold in that case.
foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, numFolds, samplingKeyColumn))
foreach (var split in DataOperationsCatalog.CrossValidationSplit(Environment, data, numFolds, newSamplingKeyColumn))
{
var model = estimator.Fit(split.TrainSet);
var scoredTest = model.Transform(split.TestSet);
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.EntryPoints/CVSplit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
Expand Down Expand Up @@ -53,7 +54,7 @@ public static Output Split(IHostEnvironment env, Input input)

var data = input.Data;

var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn);
var stratCol = DataOperationsCatalog.CreateGroupPreservationColumn(env, ref data, input.StratificationColumn);

int n = input.NumFolds;
var output = new Output
Expand Down
55 changes: 1 addition & 54 deletions src/Microsoft.ML.EntryPoints/TrainTestSplit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public static Output Split(IHostEnvironment env, Input input)
EntryPointUtils.CheckInputArgs(host, input);

var data = input.Data;
var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn);
var stratCol = DataOperationsCatalog.CreateGroupPreservationColumn(env, ref data, input.StratificationColumn);

IDataView trainData = new RangeFilter(host,
new RangeFilter.Options { Column = stratCol, Min = 0, Max = input.Fraction, Complement = false }, data);
Expand All @@ -64,57 +64,4 @@ public static Output Split(IHostEnvironment env, Input input)
}

}

internal static class SplitUtils
{
    /// <summary>
    /// Adds to <paramref name="data"/> a new, uniquely named column that callers can pass to
    /// <see cref="RangeFilter"/> to split the data, and returns that column's name.
    ///
    /// The returned name is always a freshly generated temporary name, never
    /// <paramref name="stratificationColumn"/> itself, and the user's column is never overwritten.
    /// This lets callers drop the returned column after splitting without losing user data.
    /// </summary>
    /// <param name="host">Host used for argument checks and for constructing the transforms.</param>
    /// <param name="data">Input data; replaced with a view that also contains the new column.</param>
    /// <param name="stratificationColumn">Optional name of an existing column to base the new column on.
    /// When null, the new column is filled by <see cref="GenerateNumberTransform"/>.</param>
    /// <returns>The name of the newly added column.</returns>
    public static string CreateStratificationColumn(IHost host, ref IDataView data, string stratificationColumn = null)
    {
        host.CheckValue(data, nameof(data));
        host.CheckValueOrNull(stratificationColumn);

        // Pick a unique name for the new column so it cannot clash with anything already in the schema.
        const string stratColName = "StratificationKey";
        string stratCol = data.Schema.GetTempColumnName(stratColName);

        // No user-provided column: generate a number per row and use that.
        if (stratificationColumn == null)
        {
            data = new GenerateNumberTransform(host,
                new GenerateNumberTransform.Options
                {
                    Columns = new[] { new GenerateNumberTransform.Column { Name = stratCol } }
                }, data);
            return stratCol;
        }

        var col = data.Schema.GetColumnOrNull(stratificationColumn);
        if (col == null)
            throw host.ExceptSchemaMismatch(nameof(stratificationColumn), "Stratification", stratificationColumn);

        var type = col.Value.Type;
        if (!RangeFilter.IsValidRangeFilterColumnType(host, type))
        {
            // The column cannot be range-filtered directly, so hash it into the new column.
            var hashInputColumnName = stratificationColumn;

            // HashingEstimator currently handles all primitive types except for DateTime, DateTimeOffset and TimeSpan.
            var itemType = type.GetItemType();
            if (itemType is DateTimeDataViewType || itemType is DateTimeOffsetDataViewType || itemType is TimeSpanDataViewType)
            {
                // Convert into the temporary column rather than over the user's column, so the
                // original values stay intact in the returned data view.
                data = new TypeConvertingTransformer(host, stratCol, DataKind.Int64, stratificationColumn).Transform(data);
                hashInputColumnName = stratCol;
            }

            var columnOptions = new HashingEstimator.ColumnOptions(stratCol, hashInputColumnName, 30, combine: true);
            data = new HashingEstimator(host, columnOptions).Fit(data).Transform(data);
        }
        else if (data.Schema[stratificationColumn].IsNormalized() || (type != NumberDataViewType.Single && type != NumberDataViewType.Double))
        {
            // The column is usable as-is. Copy it under the temporary name instead of returning the
            // user's column name, because callers may drop the returned column once the split is done.
            data = new ColumnCopyingEstimator(host, (stratCol, stratificationColumn)).Fit(data).Transform(data);
        }
        else
        {
            // Non-normalized Single/Double column: min-max normalize it into the new column.
            // ensureZeroUntouched is false to match DataOperationsCatalog.CreateGroupPreservationColumn.
            // NOTE(review): presumably this is needed so negative values are also shifted into the
            // range that RangeFilter samples from; confirm against the normalizer documentation.
            data = new NormalizingEstimator(host,
                new NormalizingEstimator.MinMaxColumnOptions(stratCol, stratificationColumn, ensureZeroUntouched: false))
                .Fit(data).Transform(data);
        }

        return stratCol;
    }
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
maml.exe CV tr=lr{l1=1.0 l2=0.1 ot=1e-3 nt=1} strat=Strat threads=- norm=Warn loader=text{col=Features:R4:9-14 col=Label:R4:0 col=Strat:TX:1 header+} data=%Data% out=%Output%
Hashing the stratification column
Warning: A normalizer is needed for this trainer. Either add a normalizing transform or use the 'norm=Auto', 'norm=Yes' or 'norm=No' options.
Beginning optimization
num vars: 7
Expand Down
Loading