Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 32 additions & 31 deletions src/DotNetBridge/RunGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
using System.Globalization;
using System.IO;
using System.Linq;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.CommandLine;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.EntryPoints.JsonUtils;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Data.IO;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.EntryPoints.JsonUtils;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.FeatureSelection;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;

Expand Down Expand Up @@ -73,13 +74,13 @@ private static void SaveIdvToFile(IDataView idv, string path, IHost host)

using (var fs = File.OpenWrite(path))
{
saver.SaveData(fs, idv, Utils.GetIdentityPermutation(idv.Schema.ColumnCount)
.Where(x => !idv.Schema.IsHidden(x) && saver.IsColumnSavable(idv.Schema.GetColumnType(x)))
saver.SaveData(fs, idv, Utils.GetIdentityPermutation(idv.Schema.Count)
.Where(x => !idv.Schema[x].IsHidden && saver.IsColumnSavable(idv.Schema[x].Type))
.ToArray());
}
}

private static void SavePredictorModelToFile(IPredictorModel model, string path, IHost host)
private static void SavePredictorModelToFile(PredictorModel model, string path, IHost host)
{
using (var fs = File.OpenWrite(path))
model.Save(host, fs);
Expand Down Expand Up @@ -155,7 +156,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s
Contracts.Assert(iDv < dvNative.Length);
// prefetch all columns
dv = dvNative[iDv++];
var prefetch = new int[dv.Schema.ColumnCount];
var prefetch = new int[dv.Schema.Count];
for (int i = 0; i < prefetch.Length; i++)
prefetch[i] = i;
dv = new CacheDataView(host, dv, prefetch);
Expand All @@ -167,7 +168,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s
if (!string.IsNullOrWhiteSpace(path))
{
using (var fs = File.OpenRead(path))
pm = new PredictorModel(host, fs);
pm = new PredictorModelImpl(host, fs);
}
else
throw host.Except("Model must be loaded from a file");
Expand All @@ -178,7 +179,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s
if (!string.IsNullOrWhiteSpace(path))
{
using (var fs = File.OpenRead(path))
tm = new TransformModel(host, fs);
tm = new TransformModelImpl(host, fs);
}
else
throw host.Except("Model must be loaded from a file");
Expand Down Expand Up @@ -224,7 +225,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s
}
break;
case TlcModule.DataKind.PredictorModel:
var pm = runner.GetOutput<IPredictorModel>(varName);
var pm = runner.GetOutput<PredictorModel>(varName);
if (!string.IsNullOrWhiteSpace(path))
{
SavePredictorModelToFile(pm, path, host);
Expand All @@ -233,7 +234,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s
throw host.Except("Returning in-memory models is not supported");
break;
case TlcModule.DataKind.TransformModel:
var tm = runner.GetOutput<ITransformModel>(varName);
var tm = runner.GetOutput<TransformModel>(varName);
if (!string.IsNullOrWhiteSpace(path))
{
using (var fs = File.OpenWrite(path))
Expand All @@ -245,9 +246,9 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s

case TlcModule.DataKind.Array:
var objArray = runner.GetOutput<object[]>(varName);
if (objArray is IPredictorModel[])
if (objArray is PredictorModel[])
{
var modelArray = (IPredictorModel[])objArray;
var modelArray = (PredictorModel[])objArray;
// Save each model separately
for (var i = 0; i < modelArray.Length; i++)
{
Expand Down Expand Up @@ -284,43 +285,43 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s
private static Dictionary<string, ColumnMetadataInfo> ProcessColumns(ref IDataView view, int maxSlots, IHostEnvironment env)
{
Dictionary<string, ColumnMetadataInfo> result = null;
List<DropSlotsTransform.Column> drop = null;
for (int i = 0; i < view.Schema.ColumnCount; i++)
List<SlotsDroppingTransformer.ColumnInfo> drop = null;
for (int i = 0; i < view.Schema.Count; i++)
{
if (view.Schema.IsHidden(i))
if (view.Schema[i].IsHidden)
continue;

var columnName = view.Schema.GetColumnName(i);
var columnType = view.Schema.GetColumnType(i);
var columnName = view.Schema[i].Name;
var columnType = view.Schema[i].Type;
if (columnType.IsKnownSizeVector)
{
Utils.Add(ref result, columnName, new ColumnMetadataInfo(true, null, null));
if (maxSlots > 0 && columnType.ValueCount > maxSlots)
{
Utils.Add(ref drop,
new DropSlotsTransform.Column()
{
Name = columnName,
Source = columnName,
Slots = new[] { new DropSlotsTransform.Range() { Min = maxSlots } }
});
new SlotsDroppingTransformer.ColumnInfo(
input: columnName,
slots: (maxSlots, null)));
}
}
else if (columnType.IsKey)
{
Dictionary<uint, ReadOnlyMemory<char>> map = null;
if (columnType.KeyCount > 0 && view.Schema.HasKeyNames(i, columnType.KeyCount))
if (columnType.KeyCount > 0 && view.Schema[i].HasKeyValues(columnType.KeyCount))
{
var keyNames = default(VBuffer<ReadOnlyMemory<char>>);
view.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, i, ref keyNames);
view.Schema[i].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyNames);
map = keyNames.Items().ToDictionary(kv => (uint)kv.Key, kv => kv.Value);
}
Utils.Add(ref result, columnName, new ColumnMetadataInfo(false, null, map));
}
}

if (drop != null)
view = new DropSlotsTransform(env, new DropSlotsTransform.Arguments() { Column = drop.ToArray() }, view);
{
var slotDropper = new SlotsDroppingTransformer(env, drop.ToArray());
view = slotDropper.Transform(view);
}

return result;
}
Expand Down