diff --git a/src/DotNetBridge/RunGraph.cs b/src/DotNetBridge/RunGraph.cs index e2e1dfc9..63a10e01 100644 --- a/src/DotNetBridge/RunGraph.cs +++ b/src/DotNetBridge/RunGraph.cs @@ -8,14 +8,15 @@ using System.Globalization; using System.IO; using System.Linq; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; -using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.EntryPoints.JsonUtils; -using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Data.IO; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.EntryPoints.JsonUtils; +using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.FeatureSelection; using Newtonsoft.Json; using Newtonsoft.Json.Linq; @@ -73,13 +74,13 @@ private static void SaveIdvToFile(IDataView idv, string path, IHost host) using (var fs = File.OpenWrite(path)) { - saver.SaveData(fs, idv, Utils.GetIdentityPermutation(idv.Schema.ColumnCount) - .Where(x => !idv.Schema.IsHidden(x) && saver.IsColumnSavable(idv.Schema.GetColumnType(x))) + saver.SaveData(fs, idv, Utils.GetIdentityPermutation(idv.Schema.Count) + .Where(x => !idv.Schema[x].IsHidden && saver.IsColumnSavable(idv.Schema[x].Type)) .ToArray()); } } - private static void SavePredictorModelToFile(IPredictorModel model, string path, IHost host) + private static void SavePredictorModelToFile(PredictorModel model, string path, IHost host) { using (var fs = File.OpenWrite(path)) model.Save(host, fs); @@ -155,7 +156,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s Contracts.Assert(iDv < dvNative.Length); // prefetch all columns dv = dvNative[iDv++]; - var prefetch = new int[dv.Schema.ColumnCount]; + var prefetch = new int[dv.Schema.Count]; for (int i = 0; i < prefetch.Length; i++) prefetch[i] = i; dv = new CacheDataView(host, dv, prefetch); @@ -167,7 +168,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s if (!string.IsNullOrWhiteSpace(path)) { using (var fs = File.OpenRead(path)) - pm = new PredictorModel(host, fs); + pm = new PredictorModelImpl(host, fs); } else throw host.Except("Model must be loaded from a file"); @@ -178,7 +179,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s if (!string.IsNullOrWhiteSpace(path)) { using (var fs = File.OpenRead(path)) - tm = new TransformModel(host, fs); + tm = new TransformModelImpl(host, fs); } else throw host.Except("Model must be loaded from a file"); @@ -224,7 +225,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s } break; case TlcModule.DataKind.PredictorModel: - var pm = runner.GetOutput(varName); + var pm = runner.GetOutput(varName); if (!string.IsNullOrWhiteSpace(path)) { SavePredictorModelToFile(pm, path, host); @@ -233,7 +234,7 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s throw host.Except("Returning in-memory models is not supported"); break; case TlcModule.DataKind.TransformModel: - var tm = runner.GetOutput(varName); + var tm = runner.GetOutput(varName); if (!string.IsNullOrWhiteSpace(path)) { using (var fs = File.OpenWrite(path)) @@ -245,9 +246,9 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s case TlcModule.DataKind.Array: var objArray = runner.GetOutput(varName); - if (objArray is IPredictorModel[]) + if (objArray is PredictorModel[]) { - var modelArray = (IPredictorModel[])objArray; + var modelArray = (PredictorModel[])objArray; // Save each model separately for (var i = 0; i < modelArray.Length; i++) { @@ -284,35 +285,32 @@ private static void RunGraphCore(EnvironmentBlock* penv, IHostEnvironment env, s private static Dictionary ProcessColumns(ref IDataView view, int maxSlots, IHostEnvironment env) { Dictionary result = null; - List drop = null; - for (int i = 0; i < view.Schema.ColumnCount; i++) + List drop = null; + for (int i = 0; i < view.Schema.Count; i++) { - if (view.Schema.IsHidden(i)) + if (view.Schema[i].IsHidden) continue; - var columnName = view.Schema.GetColumnName(i); - var columnType = view.Schema.GetColumnType(i); + var columnName = view.Schema[i].Name; + var columnType = view.Schema[i].Type; if (columnType.IsKnownSizeVector) { Utils.Add(ref result, columnName, new ColumnMetadataInfo(true, null, null)); if (maxSlots > 0 && columnType.ValueCount > maxSlots) { Utils.Add(ref drop, - new DropSlotsTransform.Column() - { - Name = columnName, - Source = columnName, - Slots = new[] { new DropSlotsTransform.Range() { Min = maxSlots } } - }); + new SlotsDroppingTransformer.ColumnInfo( + input: columnName, + slots: (maxSlots, null))); } } else if (columnType.IsKey) { Dictionary> map = null; - if (columnType.KeyCount > 0 && view.Schema.HasKeyNames(i, columnType.KeyCount)) + if (columnType.KeyCount > 0 && view.Schema[i].HasKeyValues(columnType.KeyCount)) { var keyNames = default(VBuffer>); - view.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, i, ref keyNames); + view.Schema[i].Metadata.GetValue(MetadataUtils.Kinds.KeyValues, ref keyNames); map = keyNames.Items().ToDictionary(kv => (uint)kv.Key, kv => kv.Value); } Utils.Add(ref result, columnName, new ColumnMetadataInfo(false, null, map)); @@ -320,7 +318,10 @@ private static Dictionary ProcessColumns(ref IDataVi } if (drop != null) - view = new DropSlotsTransform(env, new DropSlotsTransform.Arguments() { Column = drop.ToArray() }, view); + { + var slotDropper = new SlotsDroppingTransformer(env, drop.ToArray()); + view = slotDropper.Transform(view); + } return result; }