diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs index baf3400eaf..0579bc5f5b 100644 --- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs @@ -28,7 +28,9 @@ private protected RowToRowTransformerBase(IHost host) bool ITransformer.IsRowToRowMapper => true; - IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) => GetRowToRowMapperCore(inputSchema); + + protected virtual IRowToRowMapper GetRowToRowMapperCore(DataViewSchema inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema), MakeRowMapper); @@ -37,16 +39,20 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) [BestFriend] private protected abstract IRowMapper MakeRowMapper(DataViewSchema schema); - public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) + public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => GetOutputSchemaCore(inputSchema); + + protected virtual DataViewSchema GetOutputSchemaCore(DataViewSchema inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); var mapper = MakeRowMapper(inputSchema); return RowToRowMapperTransform.GetOutputSchema(inputSchema, mapper); } - public IDataView Transform(IDataView input) => MakeDataTransform(input); + public IDataView Transform(IDataView input) => MakeDataTransformCore(input); - [BestFriend] + private protected virtual IDataView MakeDataTransformCore(IDataView input) => MakeDataTransform(input); + + [BestFriend] // MYTODO: Since this is "BestFriend" here, should I also make the MakeDataTransformCore in OnnxDataTransform a "BestFriend"? What's the purpose? private protected RowToRowMapperTransform MakeDataTransform(IDataView input) { Host.CheckValue(input, nameof(input)); diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs index 530e3193f3..a95f175313 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs @@ -37,7 +37,7 @@ namespace Microsoft.ML.Transforms.Onnx /// /// resulting from fitting an . /// - public sealed class OnnxTransformer : RowToRowTransformerBase + public sealed class OnnxTransformer : RowToRowTransformerBase // MYTODO: Should I consider not to inherit from this, since now OnnxTransformer would be able to drop columns and not use the RowToRowMapperTransform? { /// /// A class used for capturing shape information from command line. @@ -134,12 +134,18 @@ private static VersionInfo GetVersionInfo() // Factory method for SignatureDataTransform private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) { - return new OnnxTransformer(env, options).MakeDataTransform(input); + var transformer = new OnnxTransformer(env, options); + var mapper = new Mapper(transformer, input.Schema); + return new OnnxDataTransform(env, input, mapper); } // Factory method for SignatureLoadDataTransform private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - => Create(env, ctx).MakeDataTransform(input); + { + var transformer = OnnxTransformer.Create(env, ctx); + var mapper = new Mapper(transformer, input.Schema); + return new OnnxDataTransform(env, input, mapper); + } // Factory method for SignatureLoadModel. private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx) @@ -188,8 +194,7 @@ private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx } // Factory method for SignatureLoadRowMapper. - private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) - => Create(env, ctx).MakeRowMapper(inputSchema); + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); // MYTODO: In what scenario is this called? Should I worry that the mapper, only by itself, isn't capable of dropping columns? private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer))) @@ -325,7 +330,24 @@ private protected override void SaveModel(ModelSaveContext ctx) } } - private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); + private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); // MYTODO: Could I erase this? If I stop inheriting from RTRTB? + + protected override IRowToRowMapper GetRowToRowMapperCore(DataViewSchema inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + return new OnnxDataTransform(Host, new EmptyDataView(Host, inputSchema), new Mapper(this, inputSchema)); + } + + protected override DataViewSchema GetOutputSchemaCore(DataViewSchema inputSchema) + { + return OnnxDataTransform.GetOutputSchema(inputSchema, new Mapper(this, inputSchema)); + } + + private protected override IDataView MakeDataTransformCore(IDataView input) + { + Host.CheckValue(input, nameof(input)); + return new OnnxDataTransform(Host, input, new Mapper(this, input.Schema)); + } /// /// This design assumes that all unknown dimensions are 1s. It also convert scalar shape [] in ONNX to [1]. @@ -340,6 +362,19 @@ private static IEnumerable AdjustDimensions(OnnxShape shape) return new[] { 1 }; } + /// + /// In order to fully support onnx exportability from , it was decided + /// that the should drop all columns that are used as input of the Onnx model, + /// from the input schema. + /// + /// Any column that was already inside the input schema, but which isn't used by the onnx model itself, + /// should simply propagate to the output. + /// + internal string[] GetDropColumnsNames() + { + return Model.ModelInfo.InputNames.ToArray(); + } + private sealed class Mapper : MapperBase { private readonly OnnxTransformer _parent; @@ -407,6 +442,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) : if (typeValueCount % valCount != 0) throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}."); } + } protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() @@ -687,6 +723,422 @@ public NamedOnnxValue GetNamedOnnxValue() return OnnxUtils.CreateNamedOnnxValue(_colName, _vBufferDense.GetValues(), _tensorShape); } } + + /// + /// + /// + public string[] GetDropColumnsNames() + { + return _parent.GetDropColumnsNames(); + } + } + + /// + /// Similar to , but this class will enable dropping columns from the input + /// schema, in order to let OnnxTransformer support the onnx export. + /// + [BestFriend] // MYTODO: Is this necessary? + internal sealed class Bindings // MYTODO: Should I move this inside OnnxDataTransform? + { + // MYTODO: Should I simply inherit from ColumnBindings, since everything is the same except for the constructor (specifically, only, the way it created the _colMap)? + + // Indices of columns in the merged schema. Old indices are as is, new indices are stored as ~idx. + private readonly int[] _colMap; + + /// + /// The indices of added columns in the . + /// + public IReadOnlyList AddedColumnIndices { get; } + + /// + /// The input schema. + /// + public DataViewSchema InputSchema { get; } + + /// + /// The merged schema. + /// + public DataViewSchema Schema { get; } + + /// + /// Create a new instance of . + /// + /// The input schema that we're adding columns to. + /// Names of the columns to drop, so that they don't propagate from the input schema + /// The columns being added. + public Bindings(DataViewSchema input, List dropColumnsNames, DataViewSchema.DetachedColumn[] addedColumns) + { + Contracts.CheckValue(input, nameof(input)); + Contracts.CheckValue(addedColumns, nameof(addedColumns)); + + InputSchema = input; + + // Construct the indices. + // Drop the indicated columns + // And drop all hidden columns + var indices = new List(); + var namesUsed = new HashSet(); + for (int i = 0; i < input.Count; i++) + { + if (InputSchema[i].IsHidden || dropColumnsNames.Contains(InputSchema[i].Name)) // MYTODO: Should I drop all hidden columns? Only the ones that are inside the dropColumnsNames list? + continue; + + namesUsed.Add(input[i].Name); + indices.Add(i); + } + + for (int i = 0; i < addedColumns.Length; i++) + { + string name = addedColumns[i].Name; + if (namesUsed.Add(name)) + { + // New name. Append to the end. + indices.Add(~i); + } + else + { + // Old name. Find last instance and add after it. + for (int j = indices.Count - 1; j >= 0; j--) + { + var colName = indices[j] >= 0 ? input[indices[j]].Name : addedColumns[~indices[j]].Name; + if (colName == name) + { + indices.Insert(j + 1, ~i); + break; + } + } + } + } + + // Contracts.Assert(indices.Count == addedColumns.Length + input.Count); // MYTODO: This assertion is no longer valid, and I can't think of a better one + + // Create the output schema. + var schemaColumns = indices.Select(idx => idx >= 0 ? new DataViewSchema.DetachedColumn(input[idx]) : addedColumns[~idx]); + Schema = SchemaExtensions.MakeSchema(schemaColumns); + + // Memorize column maps. + _colMap = indices.ToArray(); + var addedIndices = new int[addedColumns.Length]; + for (int i = 0; i < _colMap.Length; i++) + { + int colIndex = _colMap[i]; + if (colIndex < 0) + { + Contracts.Assert(addedIndices[~colIndex] == 0); + addedIndices[~colIndex] = i; + } + } + + AddedColumnIndices = addedIndices.AsReadOnly(); + } + + /// + /// This maps a column index for this schema to either a source column index (when + /// is true), or to an "iinfo" index of an added column + /// (when is false). + /// + /// Whether the return index is for a source column + /// The column index for this schema + /// The index (either source index or iinfo index) + public int MapColumnIndex(out bool isSrcColumn, int col) + { + Contracts.Assert(0 <= col && col < _colMap.Length); + int index = _colMap[col]; + if (index < 0) + { + index = ~index; + Contracts.Assert(index < AddedColumnIndices.Count); + isSrcColumn = false; + } + else + { + Contracts.Assert(index < InputSchema.Count); + isSrcColumn = true; + } + return index; + } + + /// + /// The given predicate maps from output column index to whether the column is active. + /// This builds an array of bools of length Input.ColumnCount containing the results of calling + /// predicate on the output column index corresponding to each input column index. + /// + public bool[] GetActiveInput(Func predicate) + { + Contracts.AssertValue(predicate); + + var active = new bool[InputSchema.Count]; + for (int dst = 0; dst < _colMap.Length; dst++) + { + int src = _colMap[dst]; + Contracts.Assert(-AddedColumnIndices.Count <= src && src < InputSchema.Count); + if (src >= 0 && predicate(dst)) + active[src] = true; + } + return active; + } + } + + private class OnnxDataTransform : RowToRowTransformBase, IRowToRowMapper + { + // MYTODO: Is it even worth it to have this OnnxDataTransform class when it (including the RowImpl and Cursor) + // are identical to the RowToRowMapperTransform? The differences are: + // - This one expects specifically a OnnxTransformer.Mapper as _mapper from where to get the GetColumnsNames, whereas RTRMT expects a generic IRowMapper + // - This one has a _bindings object which is off type OnnxTransformer.Bindings, whereas RTRMT expects a generic ColumnsBindings + // - This one in here has a differend override for the Save method + // - This one in here doesn't have (but I don't know if it could have) methods related to SaveOnnx, SavePfa, ApplyToData, and VersionInfo of RTRMT. + // - RTRMT has an extra member called "_mapperFactory" that is used in ApplyToData + + private protected override void SaveModel(ModelSaveContext ctx) => (_mapper as IRowMapper).Save(ctx); // MYTODO: This is the only thing that differ between this and RTRMT. Wonder if it would work if I used theirs instead? + + private readonly Mapper _mapper; + private readonly Bindings _bindings; + + public override DataViewSchema OutputSchema => _bindings.Schema; + + public OnnxDataTransform(IHostEnvironment env, IDataView input, Mapper mapper) + : base(env.Register(nameof(OnnxDataTransform)), input) + { + _mapper = mapper; + _bindings = new Bindings(input.Schema, mapper.GetDropColumnsNames().ToList(), (mapper as IRowMapper).GetOutputColumns()); + } + + public static DataViewSchema GetOutputSchema(DataViewSchema inputSchema, Mapper mapper) + { + Contracts.CheckValue(inputSchema, nameof(inputSchema)); + Contracts.CheckValue(mapper, nameof(mapper)); + return new Bindings(inputSchema, mapper.GetDropColumnsNames().ToList(), (mapper as IRowMapper).GetOutputColumns()).Schema; + } + + /// + /// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount), + /// and the needed active input columns, given a predicate for the needed active output columns. + /// + private bool[] GetActive(Func predicate, out IEnumerable inputColumns) + { + int n = _bindings.Schema.Count; + var active = Utils.BuildArray(n, predicate); + Contracts.Assert(active.Length == n); + + var activeInput = _bindings.GetActiveInput(predicate); + Contracts.Assert(activeInput.Length == _bindings.InputSchema.Count); + + // Get a predicate that determines which outputs are active. + var predicateOut = GetActiveOutputColumns(active); + + // Now map those to active input columns. + var predicateIn = (_mapper as IRowMapper).GetDependencies(predicateOut); + + // Combine the two sets of input columns. + inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index] || predicateIn(col.Index)); + + return active; + } + + private Func GetActiveOutputColumns(bool[] active) + { + Contracts.AssertValue(active); + Contracts.Assert(active.Length == _bindings.Schema.Count); + + return + col => + { + Contracts.Assert(0 <= col && col < _bindings.AddedColumnIndices.Count); + return 0 <= col && col < _bindings.AddedColumnIndices.Count && active[_bindings.AddedColumnIndices[col]]; + }; + } + + protected override bool? ShouldUseParallelCursors(Func predicate) + { + Host.AssertValue(predicate, "predicate"); + if (_bindings.AddedColumnIndices.Any(predicate)) // MYTODO: This is copied from RowToRowMapperTransform. Why is this the case, and it ignores all the other columns that propagate from the input? + return true; + return null; + } + + protected override DataViewRowCursor GetRowCursorCore(IEnumerable columnsNeeded, Random rand = null) + { + var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema); + var active = GetActive(predicate, out IEnumerable inputCols); + + return new Cursor(Host, Source.GetRowCursor(inputCols, rand), this, active); + } + + public override DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) + { + Host.CheckValueOrNull(rand); + + var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema); + var active = GetActive(predicate, out IEnumerable inputCols); + + var inputs = Source.GetRowCursorSet(inputCols, n, rand); + Host.AssertNonEmpty(inputs); + + if (inputs.Length == 1 && n > 1 && _bindings.AddedColumnIndices.Any(predicate)) // MYTODO: This is copied from TowToRowMapperTransform. Shouldn't the last check actually call ShouldUseParallel? + inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n); + Host.AssertNonEmpty(inputs); + + var cursors = new DataViewRowCursor[inputs.Length]; + for (int i = 0; i < inputs.Length; i++) + cursors[i] = new Cursor(Host, inputs[i], this, active); + return cursors; + } + + /// + /// Given a set of output columns, return the input columns that are needed to generate those output columns. + /// + IEnumerable IRowToRowMapper.GetDependencies(IEnumerable dependingColumns) + { + var predicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, OutputSchema); + GetActive(predicate, out var inputColumns); + return inputColumns; + } + + public DataViewSchema InputSchema => Source.Schema; + + DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable activeColumns) + { + Host.CheckValue(input, nameof(input)); + Host.CheckValue(activeColumns, nameof(activeColumns)); + Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to"); + + using (var ch = Host.Start("GetEntireRow")) + { + var activeArr = new bool[OutputSchema.Count]; + foreach (var column in activeColumns) + { + Host.Assert(column.Index < activeArr.Length, $"The columns {activeColumns.Select(c => c.Name)} are not suitable for the OutputSchema."); + activeArr[column.Index] = true; + } + var pred = GetActiveOutputColumns(activeArr); + var getters = (_mapper as IRowMapper).CreateGetters(input, pred, out Action disp); + + return new RowImpl(input, this, OutputSchema, getters, disp); + } + } + + // MYTODO: Should I also copy in here the ApplyToData method from RowToRowMapperTransform? + + private sealed class RowImpl : WrappingRow + { + private readonly Delegate[] _getters; + private readonly OnnxDataTransform _parent; + private readonly Action _disposer; + + public override DataViewSchema Schema { get; } + + public RowImpl(DataViewRow input, OnnxDataTransform parent, DataViewSchema schema, Delegate[] getters, Action disposer) + : base(input) + { + _parent = parent; + Schema = schema; + _getters = getters; + _disposer = disposer; + } + + protected override void DisposeCore(bool disposing) + { + if (disposing) + _disposer?.Invoke(); + } + + /// + /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. + /// This throws if the column is not active in this row, or if the type + /// differs from this column's type. + /// + /// is the column's content type. + /// is the output column whose getter should be returned. + public override ValueGetter GetGetter(DataViewSchema.Column column) + { + bool isSrc; + int index = _parent._bindings.MapColumnIndex(out isSrc, column.Index); + if (isSrc) + return Input.GetGetter(Input.Schema[index]); + + Contracts.Assert(_getters[index] != null); + var fn = _getters[index] as ValueGetter; + if (fn == null) + throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); + return fn; + } + + /// + /// Returns whether the given column is active in this row. + /// + public override bool IsColumnActive(DataViewSchema.Column column) + { + bool isSrc; + int index = _parent._bindings.MapColumnIndex(out isSrc, column.Index); + if (isSrc) + return Input.IsColumnActive(Schema[index]); + return _getters[index] != null; + } + } + + private sealed class Cursor : SynchronizedCursorBase + { + private readonly Delegate[] _getters; + private readonly bool[] _active; + private readonly Bindings _bindings; + private readonly Action _disposer; + private bool _disposed; + + public override DataViewSchema Schema => _bindings.Schema; + + public Cursor(IChannelProvider provider, DataViewRowCursor realInput, OnnxDataTransform parent, bool[] active) + : base(provider, realInput) + { + var pred = parent.GetActiveOutputColumns(active); + _getters = (parent._mapper as IRowMapper).CreateGetters(realInput, pred, out _disposer); + _active = active; + _bindings = parent._bindings; + } + + /// + /// Returns whether the given column is active in this row. + /// + public override bool IsColumnActive(DataViewSchema.Column column) + { + Ch.Check(column.Index < _bindings.Schema.Count); + return _active[column.Index]; + } + + /// + /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. + /// This throws if the column is not active in this row, or if the type + /// differs from this column's type. + /// + /// is the column's content type. + /// is the output column whose getter should be returned. + public override ValueGetter GetGetter(DataViewSchema.Column column) + { + Ch.Check(IsColumnActive(column)); + + bool isSrc; + int index = _bindings.MapColumnIndex(out isSrc, column.Index); + if (isSrc) + return Input.GetGetter(Input.Schema[index]); + + Ch.AssertValue(_getters); + var getter = _getters[index]; + Ch.Assert(getter != null); + var fn = getter as ValueGetter; + if (fn == null) + throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); + return fn; + } + + protected override void Dispose(bool disposing) + { + if (_disposed) + return; + if (disposing) + _disposer?.Invoke(); + _disposed = true; + base.Dispose(disposing); + } + } } } @@ -775,8 +1227,6 @@ internal OnnxScoringEstimator(IHostEnvironment env, OnnxTransformer transformer) public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.ToDictionary(x => x.Name); - var resultDic = inputSchema.ToDictionary(x => x.Name); // This loop checks if all input columns needed in the underlying transformer can be found // in inputSchema. @@ -806,6 +1256,9 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString()); } + var droppedInputs = new List(Transformer.GetDropColumnsNames()); + var resultDic = inputSchema.Where(col => !droppedInputs.Contains(col.Name)).ToDictionary(x => x.Name); // MYTODO: Is this enough? Does SchemaShape should also worry about "hidden" columns? + for (var i = 0; i < Transformer.Outputs.Length; i++) { resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i], diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs index ca9110260c..a48e2175b8 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs @@ -26,7 +26,7 @@ internal sealed class OnnxModel : IDisposable { /// /// OnnxModelInfo contains the data that we should get from - /// OnnxRuntime API once that functionality is added. + /// the OnnxRuntime API /// public sealed class OnnxModelInfo { diff --git a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs index c60351e54e..0788eb8505 100644 --- a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs +++ b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs @@ -139,6 +139,19 @@ public void TestSimpleCase() } catch (ArgumentOutOfRangeException) { } catch (InvalidOperationException) { } + + try + { + // MYTODO: Is this side effect acceptable? Should I remove this from this test? + // OnnxTransformer will drop all the columns in the input that have + // a corresponding input in the onnx model. + // A side effect of this, is that after applying a pretrained onnx model + // like the one in here, we won't be able to use or access the input + // columns of the model. + var transformedData = pipe.Fit(dataView).Transform(dataView); + var col = transformedData.Schema["data_0"]; + } + catch (ArgumentOutOfRangeException) { } } [OnnxTheory] diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index c342052627..b60caaa5f8 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -1699,8 +1699,10 @@ public void FeatureSelectionOnnxTest() - [Fact] - public void SelectColumnsOnnxTest() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ColumnSelectingOnnxTest(bool keepColumns) { var mlContext = new MLContext(seed: 1); @@ -1718,13 +1720,34 @@ public void SelectColumnsOnnxTest() new TextLoader.Column("Mitoses", DataKind.Int32, 9), }); - var pipeline = mlContext.Transforms.ReplaceMissingValues("Size").Append(mlContext.Transforms.SelectColumns(new[] { "Size", "Shape", "Thickness", "Label" })); + var pipeline = mlContext.Transforms.ReplaceMissingValues("Size") + .Append(mlContext.Transforms.SelectColumns(new[] { "Size", "Shape", "Thickness", "Label" })); + + if(!keepColumns) + { + // The ColumnSelectingTransformer can both select what columns to keep, as done above, + // or to choose what columns to drop. + // + // When keeping columns, it defaults to drop *all* hidden columns, + // when dropping columns it drops *all* the columns with the given names, + // but keeps the hidden columns with names not listed. + // + // Here, it drops columns: + pipeline = mlContext.Transforms.ReplaceMissingValues("Size") + .Append(mlContext.Transforms.DropColumns(new[] { "Adhesion", "EpithelialSize", "BlandChromatin", "NormalNucleoli", "Mitoses"})); + + // Current implementation of OnnxTransformer drops *all* input columns that are used as inputs of + // the onnx model. So it will have a behavior similar to the SelectColumns, but will differ from DropColumns. + } var model = pipeline.Fit(dataView); var transformedData = model.Transform(dataView); var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); var onnxFileName = "selectcolumns.onnx"; + if (!keepColumns) + onnxFileName = "dropcolumns.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); SaveOnnxModel(onnxModel, onnxModelPath, null); @@ -1739,11 +1762,29 @@ public void SelectColumnsOnnxTest() var onnxResult = onnxTransformer.Transform(dataView); // Verify that onnx output has only the four columns we selected from the input + // And the the output of the onnxmodel has the same length as the output schema Assert.Equal(4, outputNames.Length); - Assert.Equal("Size.output", outputNames[0]); - Assert.Equal("Shape.output", outputNames[1]); - Assert.Equal("Thickness.output", outputNames[2]); - Assert.Equal("Label.output", outputNames[3]); + + if (keepColumns) + { + // The order in the output of the onnx model is the same as in the SelectColumns parameters: + Assert.Equal("Size.output", outputNames[0]); + Assert.Equal("Shape.output", outputNames[1]); + Assert.Equal("Thickness.output", outputNames[2]); + Assert.Equal("Label.output", outputNames[3]); + + Assert.Equal(transformedData.Schema.Count, onnxResult.Schema.Count); // Both schemas have 4 columns + } + else + { + // The order in the output of the onnx model is the same as in the text loader: + Assert.Equal("Label.output", outputNames[0]); + Assert.Equal("Thickness.output", outputNames[1]); + Assert.Equal("Size.output", outputNames[2]); + Assert.Equal("Shape.output", outputNames[3]); + + Assert.Equal(transformedData.Schema.Count - 1, onnxResult.Schema.Count); // transformedData schema keeps the column hidden by ReplaceMissingValues + } CompareSelectedColumns("Size", "Size", transformedData, onnxResult); CompareSelectedColumns("Shape", "Shape", transformedData, onnxResult); @@ -1751,11 +1792,156 @@ public void SelectColumnsOnnxTest() CompareSelectedColumns("Label", "Label", transformedData, onnxResult); } - onnxFileName = "SelectColumns.txt"; - var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms"); - var onnxTextModelPath = GetOutputPath(subDir, onnxFileName); - SaveOnnxModel(onnxModel, null, onnxTextModelPath); - CheckEquality(subDir, onnxFileName, digitsOfPrecision: 1); + if(keepColumns) + { + onnxFileName = "SelectColumns.txt"; + var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms"); + var onnxTextModelPath = GetOutputPath(subDir, onnxFileName); + SaveOnnxModel(onnxModel, null, onnxTextModelPath); + CheckEquality(subDir, onnxFileName, digitsOfPrecision: 1); + } + + Done(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void ColumnSelectingOnnxTestColumnPropagation(bool saveToDisk) + { + // By default when exporting a ML.NET model to Onnx, + // the onnx model will contain an "input node" for every + // column in the input schema of the input dataview. + // "Output nodes" for each one of the columns in the output + // schema are also added. The input columns that were originally + // propagated to the output, also get a connection in the onnx model + // to propagate that information from each input node to the corresponding + // output node. + // + // Only when using the ColumnSelectingTransformer is it possible to drop + // input columns to make them dissapear from the output. Since they dissapear from + // the output schema, the exported onnx model doesn't contain output nodes for the + // dropped columns. The OnnxTransformer shouldn't propagate this input columns to the output + // but it should drop them as well. + // + // On the other hand, OnnxTransformer should propagate all the columns that are in its input schema, + // but which are not mentioned in the onnx model (particullarly not mentioned in the input nodes), + // because this would mean that the onnx model should simply ignore them, without dropping them. + // + // In general, to support the above behavior the following cases were agreed, and tested in here: + // + // Case 1. An input column that has a corresponding input node in the onnx model, but not connected to a corresponding output node, + // should be dropped and not propagated to the output schema. This case can only happen if the input column was dropped in the pipeline + // that was exported (case 1a), or if another column with the same name is created in the pipeline (case 1b). + // Case 2. An input column that has a corresponding input node in the onnx model directly connected to an output node with the same name, + // should be propagated from the input to the output schema. This is the usual behavior unless a column is dropped or hidden in the + // original pipeline (which then would be case 1). + // Case 3. A column that was created and dropped inside the pipeline that was exported to onnx, should be dropped. + // Case 4. A column that was created but not dropped inside the pipeline that was exported to onnx, should not be dropped + // Case 5. An input column that doesn't have a corresponding input node in the onnx model (i.e., that isn't even mentioned + // by the model) should be propagated as if being untouched. + // Case 6. An input column that doesn't have a corresponding input node in the onnx model, but that has the same name as an output + // column created by the onnx transformer, should be hidden. + // + // In summary, to behave correctly in those cases, it was decided to drop all input columns that are used as input nodes + // of the onnx model, and leave the other columns untouched. Then add all the output columns corresponding to the output nodes of + // the onnx model (this way case 2 and case 4 are accomplished, since the OnnxTransformer itself will actually create + // new columns and copy the original values to the new column, simulating schema propagation). Add the output columns to the left, + // so to hide any column that was propagated (following case 6). + + var mlContext = new MLContext(seed: 1); + + string dataPath = GetDataPath("breast-cancer.txt"); + + var dataView = ML.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("Size", DataKind.Single, 2), + new TextLoader.Column("Shape", DataKind.Single, 3), + new TextLoader.Column("Adhesion", DataKind.Single, 4), + new TextLoader.Column("EpithelialSize", DataKind.Single, 5), + new TextLoader.Column("BlandChromatin", DataKind.Single, 7), + }); + + var pipeline = + mlContext.Transforms.Concatenate("Features", "Shape", "Adhesion", "EpithelialSize", "BlandChromatin") + .Append(mlContext.BinaryClassification.Trainers.AveragedPerceptron()) + .Append(mlContext.Transforms.CopyColumns("Size", "EpithelialSize")) + .Append(mlContext.Transforms.SelectColumns(new[] { "Shape", "Size", "Label", "PredictedLabel" })); + + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + + var onnxFileName = "selectcolumns-columnpropagation.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + + SaveOnnxModel(onnxModel, onnxModelPath, null); + + if (IsOnnxRuntimeSupported()) + { + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + + IDataView onnxResult = onnxTransformer.Transform(dataView); + if (saveToDisk) + { + var modelPath = GetOutputPath("selectcolumns - columnpropagation.zip"); + mlContext.Model.Save(onnxTransformer, dataView.Schema, modelPath); + var loadedModel = mlContext.Model.Load(modelPath, out _); + onnxResult = loadedModel.Transform(dataView); + } + + // "Shape" and "Label" are case 2 so they will be propagated + // "Adhesion" and "BlandChromain" are case 1a so they won't be propagated + // The original "Size" is case 1b, because it's hidden by the CopyColumns + // "Features" is case 3, so won't be propagated + // The new "Size" is case 4, and "PredictedLabel" is case 4 so they will make it to the output + + Assert.Equal(4, transformedData.Schema.Count); + Assert.Equal(transformedData.Schema.Count, onnxResult.Schema.Count); + Assert.Equal(transformedData.Schema[0].Name, onnxResult.Schema[0].Name); // "Shape" + Assert.Equal(transformedData.Schema[1].Name, onnxResult.Schema[1].Name); // "Size" + Assert.Equal(transformedData.Schema[2].Name, onnxResult.Schema[2].Name); // "Label" + Assert.Equal(transformedData.Schema[3].Name, onnxResult.Schema[3].Name); // "PredictedLabel" + + CompareSelectedColumns("Shape", "Shape", transformedData, onnxResult); + CompareSelectedColumns("Size", "Size", transformedData, onnxResult); + CompareSelectedColumns("Label", "Label", transformedData, onnxResult); + CompareSelectedColumns("PredictedLabel", "PredictedLabel", transformedData, onnxResult); + + // Now load again the input, but also include more columns that weren't originally there when creating the onnx model inputs + var dataView2 = ML.Data.LoadFromTextFile(dataPath, new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("Size", DataKind.Single, 2), + new TextLoader.Column("Shape", DataKind.Single, 3), + new TextLoader.Column("Adhesion", DataKind.Single, 4), + new TextLoader.Column("EpithelialSize", DataKind.Single, 5), + new TextLoader.Column("BlandChromatin", DataKind.Single, 7), + new TextLoader.Column("NormalNucleoli", DataKind.Single, 8), // With a new name and column, to showcase Case 5, will propagate + new TextLoader.Column("PredictedLabel", DataKind.Single, 9), // With same name as onnx model output, to showcase Case 6, will be hidden + }); + + var onnxResult2 = onnxTransformer.Transform(dataView2); + Assert.Equal(6, onnxResult2.Schema.Count); // 5 propagated plus 1 hidden + + Assert.Equal("NormalNucleoli", onnxResult2.Schema[0].Name); + Assert.Equal("PredictedLabel", onnxResult2.Schema[1].Name); + Assert.True(onnxResult2.Schema[1].IsHidden); // Case 6 + + Assert.Equal("PredictedLabel", onnxResult2.Schema[2].Name); // It's moved before the others, because it hides the other column + Assert.Equal("Shape", onnxResult2.Schema[3].Name); + Assert.Equal("Size", onnxResult2.Schema[4].Name); + Assert.Equal("Label", onnxResult2.Schema[5].Name); + + // CompareSelectedColumns("NormalNucleoli", "NormalNucleoli", dataView2, onnxResult2); // MYTODO: can't use this method because it expects the column on the left dataview to be a vector type + CompareSelectedColumns("Shape", "Shape", onnxResult2, onnxResult); + CompareSelectedColumns("Size", "Size", onnxResult2, onnxResult); + CompareSelectedColumns("Label", "Label", onnxResult2, onnxResult); + CompareSelectedColumns("PredictedLabel", "PredictedLabel", onnxResult2, onnxResult); + } Done(); }