diff --git a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
index baf3400eaf..0579bc5f5b 100644
--- a/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
+++ b/src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
@@ -28,7 +28,9 @@ private protected RowToRowTransformerBase(IHost host)
bool ITransformer.IsRowToRowMapper => true;
- IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
+ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) => GetRowToRowMapperCore(inputSchema);
+
+ protected virtual IRowToRowMapper GetRowToRowMapperCore(DataViewSchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema), MakeRowMapper);
@@ -37,16 +39,20 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
[BestFriend]
private protected abstract IRowMapper MakeRowMapper(DataViewSchema schema);
- public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
+ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => GetOutputSchemaCore(inputSchema);
+
+ protected virtual DataViewSchema GetOutputSchemaCore(DataViewSchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
var mapper = MakeRowMapper(inputSchema);
return RowToRowMapperTransform.GetOutputSchema(inputSchema, mapper);
}
- public IDataView Transform(IDataView input) => MakeDataTransform(input);
+ public IDataView Transform(IDataView input) => MakeDataTransformCore(input);
- [BestFriend]
+ private protected virtual IDataView MakeDataTransformCore(IDataView input) => MakeDataTransform(input);
+
+ [BestFriend] // MYTODO: Since this is "BestFriend" here, should I also make the MakeDataTransformCore in OnnxDataTransform a "BestFriend"? What's the purpose?
private protected RowToRowMapperTransform MakeDataTransform(IDataView input)
{
Host.CheckValue(input, nameof(input));
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
index 530e3193f3..a95f175313 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
@@ -37,7 +37,7 @@ namespace Microsoft.ML.Transforms.Onnx
///
/// resulting from fitting an .
///
- public sealed class OnnxTransformer : RowToRowTransformerBase
+ public sealed class OnnxTransformer : RowToRowTransformerBase // MYTODO: Should I consider not to inherit from this, since now OnnxTransformer would be able to drop columns and not use the RowToRowMapperTransform?
{
///
/// A class used for capturing shape information from command line.
@@ -134,12 +134,18 @@ private static VersionInfo GetVersionInfo()
// Factory method for SignatureDataTransform
private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
- return new OnnxTransformer(env, options).MakeDataTransform(input);
+ var transformer = new OnnxTransformer(env, options);
+ var mapper = new Mapper(transformer, input.Schema);
+ return new OnnxDataTransform(env, input, mapper);
}
// Factory method for SignatureLoadDataTransform
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
- => Create(env, ctx).MakeDataTransform(input);
+ {
+ var transformer = OnnxTransformer.Create(env, ctx);
+ var mapper = new Mapper(transformer, input.Schema);
+ return new OnnxDataTransform(env, input, mapper);
+ }
// Factory method for SignatureLoadModel.
private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
@@ -188,8 +194,7 @@ private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx
}
// Factory method for SignatureLoadRowMapper.
- private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
- => Create(env, ctx).MakeRowMapper(inputSchema);
+ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); // MYTODO: In what scenario is this called? Should I worry that the mapper, only by itself, isn't capable of dropping columns?
private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
@@ -325,7 +330,24 @@ private protected override void SaveModel(ModelSaveContext ctx)
}
}
- private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema);
+ private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema); // MYTODO: Could I erase this? If I stop inheriting from RTRTB?
+
+ protected override IRowToRowMapper GetRowToRowMapperCore(DataViewSchema inputSchema)
+ {
+ Host.CheckValue(inputSchema, nameof(inputSchema));
+ return new OnnxDataTransform(Host, new EmptyDataView(Host, inputSchema), new Mapper(this, inputSchema));
+ }
+
+ // Computes the output schema through OnnxDataTransform, so that the columns consumed by the
+ // ONNX model are dropped instead of being propagated from the input schema.
+ protected override DataViewSchema GetOutputSchemaCore(DataViewSchema inputSchema)
+ {
+     // Validate up front, like the base implementation and the sibling overrides do, so that a
+     // null schema is reported against this argument rather than from inside the Mapper constructor.
+     Host.CheckValue(inputSchema, nameof(inputSchema));
+     return OnnxDataTransform.GetOutputSchema(inputSchema, new Mapper(this, inputSchema));
+ }
+
+ private protected override IDataView MakeDataTransformCore(IDataView input)
+ {
+ Host.CheckValue(input, nameof(input));
+ return new OnnxDataTransform(Host, input, new Mapper(this, input.Schema));
+ }
///
/// This design assumes that all unknown dimensions are 1s. It also convert scalar shape [] in ONNX to [1].
@@ -340,6 +362,19 @@ private static IEnumerable AdjustDimensions(OnnxShape shape)
return new[] { 1 };
}
+ /// <summary>
+ /// Returns the names of the input-schema columns that must not propagate to the output:
+ /// the input names of the ONNX model (Model.ModelInfo.InputNames).
+ ///
+ /// In order to fully support onnx exportability of pipelines that drop columns (presumably those using
+ /// ColumnSelectingTransformer; the referenced type name was lost from this comment), it was decided
+ /// that the OnnxTransformer should drop all columns that are used as input of the Onnx model,
+ /// from the input schema.
+ ///
+ /// Any column that was already inside the input schema, but which isn't used by the onnx model itself,
+ /// should simply propagate to the output.
+ /// </summary>
+ /// <returns>An array with the names of the ONNX model inputs.</returns>
+ internal string[] GetDropColumnsNames()
+ {
+ return Model.ModelInfo.InputNames.ToArray();
+ }
+
private sealed class Mapper : MapperBase
{
private readonly OnnxTransformer _parent;
@@ -407,6 +442,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
if (typeValueCount % valCount != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
}
+
}
protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
@@ -687,6 +723,422 @@ public NamedOnnxValue GetNamedOnnxValue()
return OnnxUtils.CreateNamedOnnxValue(_colName, _vBufferDense.GetValues(), _tensorShape);
}
}
+
+ /// <summary>
+ /// Names of the input columns that should be dropped rather than propagated to the output.
+ /// Simply forwards to the parent transformer's GetDropColumnsNames.
+ /// </summary>
+ public string[] GetDropColumnsNames()
+ {
+ return _parent.GetDropColumnsNames();
+ }
+ }
+
+ ///
+ /// Similar to , but this class will enable dropping columns from the input
+ /// schema, in order to let OnnxTransformer support the onnx export.
+ ///
+ [BestFriend] // MYTODO: Is this necessary?
+ internal sealed class Bindings // MYTODO: Should I move this inside OnnxDataTransform?
+ {
+ // MYTODO: Should I simply inherit from ColumnBindings, since everything is the same except for the constructor (specifically, only, the way it created the _colMap)?
+
+ // Indices of columns in the merged schema. Old indices are as is, new indices are stored as ~idx.
+ private readonly int[] _colMap;
+
+ ///
+ /// The indices of added columns in the .
+ ///
+ public IReadOnlyList AddedColumnIndices { get; }
+
+ ///
+ /// The input schema.
+ ///
+ public DataViewSchema InputSchema { get; }
+
+ ///
+ /// The merged schema.
+ ///
+ public DataViewSchema Schema { get; }
+
+ /// <summary>
+ /// Create a new instance of <see cref="Bindings"/>.
+ /// </summary>
+ /// <param name="input">The input schema that we're adding columns to.</param>
+ /// <param name="dropColumnsNames">Names of the columns to drop, so that they don't propagate from the input schema.</param>
+ /// <param name="addedColumns">The columns being added.</param>
+ public Bindings(DataViewSchema input, List<string> dropColumnsNames, DataViewSchema.DetachedColumn[] addedColumns)
+ {
+     Contracts.CheckValue(input, nameof(input));
+     Contracts.CheckValue(dropColumnsNames, nameof(dropColumnsNames));
+     Contracts.CheckValue(addedColumns, nameof(addedColumns));
+
+     InputSchema = input;
+
+     // Hash the names once: the membership test below runs for every input column, and a
+     // List.Contains there would make schema construction quadratic for wide schemas.
+     var namesToDrop = new HashSet<string>(dropColumnsNames);
+
+     // Construct the indices.
+     // Drop the indicated columns
+     // And drop all hidden columns
+     // MYTODO: Should I drop all hidden columns? Only the ones that are inside the dropColumnsNames list?
+     var indices = new List<int>();
+     var namesUsed = new HashSet<string>();
+     for (int i = 0; i < input.Count; i++)
+     {
+         var column = input[i];
+         if (column.IsHidden || namesToDrop.Contains(column.Name))
+             continue;
+
+         namesUsed.Add(column.Name);
+         indices.Add(i);
+     }
+
+     // Number of input columns that survive; used for the size invariant below.
+     int keptInputCount = indices.Count;
+
+     for (int i = 0; i < addedColumns.Length; i++)
+     {
+         string name = addedColumns[i].Name;
+         if (namesUsed.Add(name))
+         {
+             // New name. Append to the end.
+             indices.Add(~i);
+         }
+         else
+         {
+             // Old name. Find last instance and add after it.
+             for (int j = indices.Count - 1; j >= 0; j--)
+             {
+                 var colName = indices[j] >= 0 ? input[indices[j]].Name : addedColumns[~indices[j]].Name;
+                 if (colName == name)
+                 {
+                     indices.Insert(j + 1, ~i);
+                     break;
+                 }
+             }
+         }
+     }
+
+     // Every added column is placed exactly once, and only the kept input columns remain.
+     // (The ColumnBindings-style "added + all inputs" invariant no longer holds once columns are dropped.)
+     Contracts.Assert(indices.Count == keptInputCount + addedColumns.Length);
+
+     // Create the output schema.
+     var schemaColumns = indices.Select(idx => idx >= 0 ? new DataViewSchema.DetachedColumn(input[idx]) : addedColumns[~idx]);
+     Schema = SchemaExtensions.MakeSchema(schemaColumns);
+
+     // Memorize column maps.
+     _colMap = indices.ToArray();
+     var addedIndices = new int[addedColumns.Length];
+     for (int i = 0; i < _colMap.Length; i++)
+     {
+         int colIndex = _colMap[i];
+         if (colIndex < 0)
+         {
+             Contracts.Assert(addedIndices[~colIndex] == 0);
+             addedIndices[~colIndex] = i;
+         }
+     }
+
+     AddedColumnIndices = addedIndices.AsReadOnly();
+ }
+
+ /// <summary>
+ /// Translates a column index of the merged schema into either an index into the input schema
+ /// (when <paramref name="isSrcColumn"/> is true), or an "iinfo" index of an added column
+ /// (when <paramref name="isSrcColumn"/> is false).
+ /// </summary>
+ /// <param name="isSrcColumn">Whether the return index is for a source column</param>
+ /// <param name="col">The column index for this schema</param>
+ /// <returns>The index (either source index or iinfo index)</returns>
+ public int MapColumnIndex(out bool isSrcColumn, int col)
+ {
+     Contracts.Assert(0 <= col && col < _colMap.Length);
+
+     // Non-negative entries are input-schema indices; added columns are stored as ~iinfo.
+     int mapped = _colMap[col];
+     isSrcColumn = mapped >= 0;
+     if (isSrcColumn)
+     {
+         Contracts.Assert(mapped < InputSchema.Count);
+         return mapped;
+     }
+
+     int iinfo = ~mapped;
+     Contracts.Assert(iinfo < AddedColumnIndices.Count);
+     return iinfo;
+ }
+
+ /// <summary>
+ /// The given predicate maps from output column index to whether the column is active.
+ /// This builds an array of bools of length InputSchema.Count; an entry is true when the
+ /// output column that maps to that input column satisfies the predicate. Input columns that
+ /// have no entry in the column map (dropped or hidden ones) stay false.
+ /// </summary>
+ public bool[] GetActiveInput(Func predicate)
+ {
+ Contracts.AssertValue(predicate);
+
+ var active = new bool[InputSchema.Count];
+ for (int dst = 0; dst < _colMap.Length; dst++)
+ {
+ int src = _colMap[dst];
+ Contracts.Assert(-AddedColumnIndices.Count <= src && src < InputSchema.Count);
+ // Negative entries are added columns (stored as ~iinfo) and have no input column to mark.
+ if (src >= 0 && predicate(dst))
+ active[src] = true;
+ }
+ return active;
+ }
+ }
+
+ private class OnnxDataTransform : RowToRowTransformBase, IRowToRowMapper
+ {
+ // MYTODO: Is it even worth it to have this OnnxDataTransform class when it (including the RowImpl and Cursor)
+ // is almost identical to the RowToRowMapperTransform? The differences are:
+ // - This one expects specifically an OnnxTransformer.Mapper as _mapper from where to get the GetDropColumnsNames, whereas RTRMT expects a generic IRowMapper
+ // - This one has a _bindings object which is of type OnnxTransformer.Bindings, whereas RTRMT expects a generic ColumnBindings
+ // - This one in here has a different override for the Save method
+ // - This one in here doesn't have (but I don't know if it could have) methods related to SaveOnnx, SavePfa, ApplyToData, and VersionInfo of RTRMT.
+ // - RTRMT has an extra member called "_mapperFactory" that is used in ApplyToData
+
+ private protected override void SaveModel(ModelSaveContext ctx) => (_mapper as IRowMapper).Save(ctx); // MYTODO: This is the only thing that differ between this and RTRMT. Wonder if it would work if I used theirs instead?
+
+ private readonly Mapper _mapper;
+ private readonly Bindings _bindings;
+
+ public override DataViewSchema OutputSchema => _bindings.Schema;
+
+ public OnnxDataTransform(IHostEnvironment env, IDataView input, Mapper mapper)
+ : base(env.Register(nameof(OnnxDataTransform)), input)
+ {
+ _mapper = mapper;
+ _bindings = new Bindings(input.Schema, mapper.GetDropColumnsNames().ToList(), (mapper as IRowMapper).GetOutputColumns());
+ }
+
+ public static DataViewSchema GetOutputSchema(DataViewSchema inputSchema, Mapper mapper)
+ {
+ Contracts.CheckValue(inputSchema, nameof(inputSchema));
+ Contracts.CheckValue(mapper, nameof(mapper));
+ return new Bindings(inputSchema, mapper.GetDropColumnsNames().ToList(), (mapper as IRowMapper).GetOutputColumns()).Schema;
+ }
+
+ ///
+ /// Produces the set of active columns for the data view (as a bool[] of length bindings.ColumnCount),
+ /// and the needed active input columns, given a predicate for the needed active output columns.
+ ///
+ private bool[] GetActive(Func predicate, out IEnumerable inputColumns)
+ {
+ int n = _bindings.Schema.Count;
+ var active = Utils.BuildArray(n, predicate);
+ Contracts.Assert(active.Length == n);
+
+ var activeInput = _bindings.GetActiveInput(predicate);
+ Contracts.Assert(activeInput.Length == _bindings.InputSchema.Count);
+
+ // Get a predicate that determines which outputs are active.
+ var predicateOut = GetActiveOutputColumns(active);
+
+ // Now map those to active input columns.
+ var predicateIn = (_mapper as IRowMapper).GetDependencies(predicateOut);
+
+ // Combine the two sets of input columns.
+ inputColumns = _bindings.InputSchema.Where(col => activeInput[col.Index] || predicateIn(col.Index));
+
+ return active;
+ }
+
+ private Func GetActiveOutputColumns(bool[] active)
+ {
+ Contracts.AssertValue(active);
+ Contracts.Assert(active.Length == _bindings.Schema.Count);
+
+ return
+ col =>
+ {
+ Contracts.Assert(0 <= col && col < _bindings.AddedColumnIndices.Count);
+ return 0 <= col && col < _bindings.AddedColumnIndices.Count && active[_bindings.AddedColumnIndices[col]];
+ };
+ }
+
+ protected override bool? ShouldUseParallelCursors(Func predicate)
+ {
+ Host.AssertValue(predicate, "predicate");
+ if (_bindings.AddedColumnIndices.Any(predicate)) // MYTODO: This is copied from RowToRowMapperTransform. Why is this the case, and it ignores all the other columns that propagate from the input?
+ return true;
+ return null;
+ }
+
+ protected override DataViewRowCursor GetRowCursorCore(IEnumerable columnsNeeded, Random rand = null)
+ {
+ var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
+ var active = GetActive(predicate, out IEnumerable inputCols);
+
+ return new Cursor(Host, Source.GetRowCursor(inputCols, rand), this, active);
+ }
+
+ public override DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null)
+ {
+ Host.CheckValueOrNull(rand);
+
+ var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
+ var active = GetActive(predicate, out IEnumerable inputCols);
+
+ var inputs = Source.GetRowCursorSet(inputCols, n, rand);
+ Host.AssertNonEmpty(inputs);
+
+ if (inputs.Length == 1 && n > 1 && _bindings.AddedColumnIndices.Any(predicate)) // MYTODO: This is copied from TowToRowMapperTransform. Shouldn't the last check actually call ShouldUseParallel?
+ inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n);
+ Host.AssertNonEmpty(inputs);
+
+ var cursors = new DataViewRowCursor[inputs.Length];
+ for (int i = 0; i < inputs.Length; i++)
+ cursors[i] = new Cursor(Host, inputs[i], this, active);
+ return cursors;
+ }
+
+ ///
+ /// Given a set of output columns, return the input columns that are needed to generate those output columns.
+ ///
+ IEnumerable IRowToRowMapper.GetDependencies(IEnumerable dependingColumns)
+ {
+ var predicate = RowCursorUtils.FromColumnsToPredicate(dependingColumns, OutputSchema);
+ GetActive(predicate, out var inputColumns);
+ return inputColumns;
+ }
+
+ public DataViewSchema InputSchema => Source.Schema;
+
+ // Wraps the given input row into a row exposing the output schema of this transform.
+ // Getters are created only for the added (mapper-produced) columns listed in activeColumns;
+ // source columns are served by the wrapped input row.
+ DataViewRow IRowToRowMapper.GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
+ {
+     Host.CheckValue(input, nameof(input));
+     Host.CheckValue(activeColumns, nameof(activeColumns));
+     Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");
+
+     using (var ch = Host.Start("GetEntireRow"))
+     {
+         var activeArr = new bool[OutputSchema.Count];
+         foreach (var column in activeColumns)
+         {
+             // Join the names explicitly: interpolating the Select(...) result directly would print
+             // the iterator's type name instead of the column names.
+             Host.Assert(column.Index < activeArr.Length, $"The columns {string.Join(", ", activeColumns.Select(c => c.Name))} are not suitable for the OutputSchema.");
+             activeArr[column.Index] = true;
+         }
+         var pred = GetActiveOutputColumns(activeArr);
+         var getters = (_mapper as IRowMapper).CreateGetters(input, pred, out Action disp);
+
+         return new RowImpl(input, this, OutputSchema, getters, disp);
+     }
+ }
+
+ // MYTODO: Should I also copy in here the ApplyToData method from RowToRowMapperTransform?
+
+ // Row returned by IRowToRowMapper.GetRow: source columns are forwarded to the wrapped input row,
+ // added columns are served from the getters created by the mapper.
+ private sealed class RowImpl : WrappingRow
+ {
+     private readonly Delegate[] _getters;
+     private readonly OnnxDataTransform _parent;
+     private readonly Action _disposer;
+
+     public override DataViewSchema Schema { get; }
+
+     public RowImpl(DataViewRow input, OnnxDataTransform parent, DataViewSchema schema, Delegate[] getters, Action disposer)
+         : base(input)
+     {
+         _parent = parent;
+         Schema = schema;
+         _getters = getters;
+         _disposer = disposer;
+     }
+
+     protected override void DisposeCore(bool disposing)
+     {
+         if (disposing)
+             _disposer?.Invoke();
+     }
+
+     /// <summary>
+     /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
+     /// This throws if the column is not active in this row, or if the type
+     /// <typeparamref name="TValue"/> differs from this column's type.
+     /// </summary>
+     /// <typeparam name="TValue"> is the column's content type.</typeparam>
+     /// <param name="column"> is the output column whose getter should be returned.</param>
+     public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
+     {
+         bool isSrc;
+         int index = _parent._bindings.MapColumnIndex(out isSrc, column.Index);
+         if (isSrc)
+             return Input.GetGetter<TValue>(Input.Schema[index]);
+
+         Contracts.Assert(_getters[index] != null);
+         var fn = _getters[index] as ValueGetter<TValue>;
+         if (fn == null)
+             throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
+         return fn;
+     }
+
+     /// <summary>
+     /// Returns whether the given column is active in this row.
+     /// </summary>
+     public override bool IsColumnActive(DataViewSchema.Column column)
+     {
+         bool isSrc;
+         int index = _parent._bindings.MapColumnIndex(out isSrc, column.Index);
+         if (isSrc)
+         {
+             // For a source column, index is a position in the INPUT schema. Because columns can be
+             // dropped, output and input positions no longer line up, so the column must be looked
+             // up in Input.Schema (as GetGetter does), not in the output Schema.
+             return Input.IsColumnActive(Input.Schema[index]);
+         }
+         return _getters[index] != null;
+     }
+ }
+
+ private sealed class Cursor : SynchronizedCursorBase
+ {
+ private readonly Delegate[] _getters;
+ private readonly bool[] _active;
+ private readonly Bindings _bindings;
+ private readonly Action _disposer;
+ private bool _disposed;
+
+ public override DataViewSchema Schema => _bindings.Schema;
+
+ public Cursor(IChannelProvider provider, DataViewRowCursor realInput, OnnxDataTransform parent, bool[] active)
+ : base(provider, realInput)
+ {
+ var pred = parent.GetActiveOutputColumns(active);
+ _getters = (parent._mapper as IRowMapper).CreateGetters(realInput, pred, out _disposer);
+ _active = active;
+ _bindings = parent._bindings;
+ }
+
+ ///
+ /// Returns whether the given column is active in this row.
+ ///
+ public override bool IsColumnActive(DataViewSchema.Column column)
+ {
+ Ch.Check(column.Index < _bindings.Schema.Count);
+ return _active[column.Index];
+ }
+
+ ///
+ /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
+ /// This throws if the column is not active in this row, or if the type
+ /// differs from this column's type.
+ ///
+ /// is the column's content type.
+ /// is the output column whose getter should be returned.
+ public override ValueGetter GetGetter(DataViewSchema.Column column)
+ {
+ Ch.Check(IsColumnActive(column));
+
+ bool isSrc;
+ int index = _bindings.MapColumnIndex(out isSrc, column.Index);
+ if (isSrc)
+ return Input.GetGetter(Input.Schema[index]);
+
+ Ch.AssertValue(_getters);
+ var getter = _getters[index];
+ Ch.Assert(getter != null);
+ var fn = getter as ValueGetter;
+ if (fn == null)
+ throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
+ return fn;
+ }
+
+ protected override void Dispose(bool disposing)
+ {
+ if (_disposed)
+ return;
+ if (disposing)
+ _disposer?.Invoke();
+ _disposed = true;
+ base.Dispose(disposing);
+ }
+ }
}
}
@@ -775,8 +1227,6 @@ internal OnnxScoringEstimator(IHostEnvironment env, OnnxTransformer transformer)
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
- var result = inputSchema.ToDictionary(x => x.Name);
- var resultDic = inputSchema.ToDictionary(x => x.Name);
// This loop checks if all input columns needed in the underlying transformer can be found
// in inputSchema.
@@ -806,6 +1256,9 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
}
+ var droppedInputs = new List(Transformer.GetDropColumnsNames());
+ var resultDic = inputSchema.Where(col => !droppedInputs.Contains(col.Name)).ToDictionary(x => x.Name); // MYTODO: Is this enough? Does SchemaShape should also worry about "hidden" columns?
+
for (var i = 0; i < Transformer.Outputs.Length; i++)
{
resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
index ca9110260c..a48e2175b8 100644
--- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
+++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs
@@ -26,7 +26,7 @@ internal sealed class OnnxModel : IDisposable
{
///
/// OnnxModelInfo contains the data that we should get from
- /// OnnxRuntime API once that functionality is added.
+ /// the OnnxRuntime API
///
public sealed class OnnxModelInfo
{
diff --git a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
index c60351e54e..0788eb8505 100644
--- a/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
+++ b/test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
@@ -139,6 +139,19 @@ public void TestSimpleCase()
}
catch (ArgumentOutOfRangeException) { }
catch (InvalidOperationException) { }
+
+ try
+ {
+ // MYTODO: Is this side effect acceptable? Should I remove this from this test?
+ // OnnxTransformer will drop all the columns in the input that have
+ // a corresponding input in the onnx model.
+ // A side effect of this, is that after applying a pretrained onnx model
+ // like the one in here, we won't be able to use or access the input
+ // columns of the model.
+ var transformedData = pipe.Fit(dataView).Transform(dataView);
+ var col = transformedData.Schema["data_0"];
+ }
+ catch (ArgumentOutOfRangeException) { }
}
[OnnxTheory]
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index c342052627..b60caaa5f8 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -1699,8 +1699,10 @@ public void FeatureSelectionOnnxTest()
- [Fact]
- public void SelectColumnsOnnxTest()
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public void ColumnSelectingOnnxTest(bool keepColumns)
{
var mlContext = new MLContext(seed: 1);
@@ -1718,13 +1720,34 @@ public void SelectColumnsOnnxTest()
new TextLoader.Column("Mitoses", DataKind.Int32, 9),
});
- var pipeline = mlContext.Transforms.ReplaceMissingValues("Size").Append(mlContext.Transforms.SelectColumns(new[] { "Size", "Shape", "Thickness", "Label" }));
+ var pipeline = mlContext.Transforms.ReplaceMissingValues("Size")
+ .Append(mlContext.Transforms.SelectColumns(new[] { "Size", "Shape", "Thickness", "Label" }));
+
+ if(!keepColumns)
+ {
+ // The ColumnSelectingTransformer can both select what columns to keep, as done above,
+ // or to choose what columns to drop.
+ //
+ // When keeping columns, it defaults to drop *all* hidden columns,
+ // when dropping columns it drops *all* the columns with the given names,
+ // but keeps the hidden columns with names not listed.
+ //
+ // Here, it drops columns:
+ pipeline = mlContext.Transforms.ReplaceMissingValues("Size")
+ .Append(mlContext.Transforms.DropColumns(new[] { "Adhesion", "EpithelialSize", "BlandChromatin", "NormalNucleoli", "Mitoses"}));
+
+ // Current implementation of OnnxTransformer drops *all* input columns that are used as inputs of
+ // the onnx model. So it will have a behavior similar to the SelectColumns, but will differ from DropColumns.
+ }
var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
var onnxFileName = "selectcolumns.onnx";
+ if (!keepColumns)
+ onnxFileName = "dropcolumns.onnx";
+
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);
@@ -1739,11 +1762,29 @@ public void SelectColumnsOnnxTest()
var onnxResult = onnxTransformer.Transform(dataView);
// Verify that onnx output has only the four columns we selected from the input
+ // And that the output of the onnx model has the same length as the output schema
Assert.Equal(4, outputNames.Length);
- Assert.Equal("Size.output", outputNames[0]);
- Assert.Equal("Shape.output", outputNames[1]);
- Assert.Equal("Thickness.output", outputNames[2]);
- Assert.Equal("Label.output", outputNames[3]);
+
+ if (keepColumns)
+ {
+ // The order in the output of the onnx model is the same as in the SelectColumns parameters:
+ Assert.Equal("Size.output", outputNames[0]);
+ Assert.Equal("Shape.output", outputNames[1]);
+ Assert.Equal("Thickness.output", outputNames[2]);
+ Assert.Equal("Label.output", outputNames[3]);
+
+ Assert.Equal(transformedData.Schema.Count, onnxResult.Schema.Count); // Both schemas have 4 columns
+ }
+ else
+ {
+ // The order in the output of the onnx model is the same as in the text loader:
+ Assert.Equal("Label.output", outputNames[0]);
+ Assert.Equal("Thickness.output", outputNames[1]);
+ Assert.Equal("Size.output", outputNames[2]);
+ Assert.Equal("Shape.output", outputNames[3]);
+
+ Assert.Equal(transformedData.Schema.Count - 1, onnxResult.Schema.Count); // transformedData schema keeps the column hidden by ReplaceMissingValues
+ }
CompareSelectedColumns("Size", "Size", transformedData, onnxResult);
CompareSelectedColumns("Shape", "Shape", transformedData, onnxResult);
@@ -1751,11 +1792,156 @@ public void SelectColumnsOnnxTest()
CompareSelectedColumns("Label", "Label", transformedData, onnxResult);
}
- onnxFileName = "SelectColumns.txt";
- var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms");
- var onnxTextModelPath = GetOutputPath(subDir, onnxFileName);
- SaveOnnxModel(onnxModel, null, onnxTextModelPath);
- CheckEquality(subDir, onnxFileName, digitsOfPrecision: 1);
+ if(keepColumns)
+ {
+ onnxFileName = "SelectColumns.txt";
+ var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms");
+ var onnxTextModelPath = GetOutputPath(subDir, onnxFileName);
+ SaveOnnxModel(onnxModel, null, onnxTextModelPath);
+ CheckEquality(subDir, onnxFileName, digitsOfPrecision: 1);
+ }
+
+ Done();
+ }
+
+ [Theory]
+ [InlineData(false)]
+ [InlineData(true)]
+ public void ColumnSelectingOnnxTestColumnPropagation(bool saveToDisk)
+ {
+ // By default when exporting a ML.NET model to Onnx,
+ // the onnx model will contain an "input node" for every
+ // column in the input schema of the input dataview.
+ // "Output nodes" for each one of the columns in the output
+ // schema are also added. The input columns that were originally
+ // propagated to the output, also get a connection in the onnx model
+ // to propagate that information from each input node to the corresponding
+ // output node.
+ //
+ // Only when using the ColumnSelectingTransformer is it possible to drop
+ // input columns to make them disappear from the output. Since they disappear from
+ // the output schema, the exported onnx model doesn't contain output nodes for the
+ // dropped columns. The OnnxTransformer shouldn't propagate these input columns to the output
+ // but it should drop them as well.
+ //
+ // On the other hand, OnnxTransformer should propagate all the columns that are in its input schema,
+ // but which are not mentioned in the onnx model (particularly not mentioned in the input nodes),
+ // because this would mean that the onnx model should simply ignore them, without dropping them.
+ //
+ // In general, to support the above behavior the following cases were agreed, and tested in here:
+ //
+ // Case 1. An input column that has a corresponding input node in the onnx model, but not connected to a corresponding output node,
+ // should be dropped and not propagated to the output schema. This case can only happen if the input column was dropped in the pipeline
+ // that was exported (case 1a), or if another column with the same name is created in the pipeline (case 1b).
+ // Case 2. An input column that has a corresponding input node in the onnx model directly connected to an output node with the same name,
+ // should be propagated from the input to the output schema. This is the usual behavior unless a column is dropped or hidden in the
+ // original pipeline (which then would be case 1).
+ // Case 3. A column that was created and dropped inside the pipeline that was exported to onnx, should be dropped.
+ // Case 4. A column that was created but not dropped inside the pipeline that was exported to onnx, should not be dropped
+ // Case 5. An input column that doesn't have a corresponding input node in the onnx model (i.e., that isn't even mentioned
+ // by the model) should be propagated as if being untouched.
+ // Case 6. An input column that doesn't have a corresponding input node in the onnx model, but that has the same name as an output
+ // column created by the onnx transformer, should be hidden.
+ //
+ // In summary, to behave correctly in those cases, it was decided to drop all input columns that are used as input nodes
+ // of the onnx model, and leave the other columns untouched. Then add all the output columns corresponding to the output nodes of
+ // the onnx model (this way case 2 and case 4 are accomplished, since the OnnxTransformer itself will actually create
+ // new columns and copy the original values to the new column, simulating schema propagation). Add the output columns to the left,
+ // so as to hide any column that was propagated (following case 6).
+
+ var mlContext = new MLContext(seed: 1);
+
+ string dataPath = GetDataPath("breast-cancer.txt");
+
+ var dataView = ML.Data.LoadFromTextFile(dataPath, new[] {
+ new TextLoader.Column("Label", DataKind.Boolean, 0),
+ new TextLoader.Column("Size", DataKind.Single, 2),
+ new TextLoader.Column("Shape", DataKind.Single, 3),
+ new TextLoader.Column("Adhesion", DataKind.Single, 4),
+ new TextLoader.Column("EpithelialSize", DataKind.Single, 5),
+ new TextLoader.Column("BlandChromatin", DataKind.Single, 7),
+ });
+
+ var pipeline =
+ mlContext.Transforms.Concatenate("Features", "Shape", "Adhesion", "EpithelialSize", "BlandChromatin")
+ .Append(mlContext.BinaryClassification.Trainers.AveragedPerceptron())
+ .Append(mlContext.Transforms.CopyColumns("Size", "EpithelialSize"))
+ .Append(mlContext.Transforms.SelectColumns(new[] { "Shape", "Size", "Label", "PredictedLabel" }));
+
+ var model = pipeline.Fit(dataView);
+ var transformedData = model.Transform(dataView);
+ var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
+
+ var onnxFileName = "selectcolumns-columnpropagation.onnx";
+ var onnxModelPath = GetOutputPath(onnxFileName);
+
+ SaveOnnxModel(onnxModel, onnxModelPath, null);
+
+ if (IsOnnxRuntimeSupported())
+ {
+ // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
+ string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+ string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
+ var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
+ var onnxTransformer = onnxEstimator.Fit(dataView);
+
+ IDataView onnxResult = onnxTransformer.Transform(dataView);
+ if (saveToDisk)
+ {
+ var modelPath = GetOutputPath("selectcolumns - columnpropagation.zip");
+ mlContext.Model.Save(onnxTransformer, dataView.Schema, modelPath);
+ var loadedModel = mlContext.Model.Load(modelPath, out _);
+ onnxResult = loadedModel.Transform(dataView);
+ }
+
+ // "Shape" and "Label" are case 2 so they will be propagated
+ // "Adhesion" and "BlandChromatin" are case 1a so they won't be propagated
+ // The original "Size" is case 1b, because it's hidden by the CopyColumns
+ // "Features" is case 3, so won't be propagated
+ // The new "Size" is case 4, and "PredictedLabel" is case 4 so they will make it to the output
+
+ Assert.Equal(4, transformedData.Schema.Count);
+ Assert.Equal(transformedData.Schema.Count, onnxResult.Schema.Count);
+ Assert.Equal(transformedData.Schema[0].Name, onnxResult.Schema[0].Name); // "Shape"
+ Assert.Equal(transformedData.Schema[1].Name, onnxResult.Schema[1].Name); // "Size"
+ Assert.Equal(transformedData.Schema[2].Name, onnxResult.Schema[2].Name); // "Label"
+ Assert.Equal(transformedData.Schema[3].Name, onnxResult.Schema[3].Name); // "PredictedLabel"
+
+ CompareSelectedColumns("Shape", "Shape", transformedData, onnxResult);
+ CompareSelectedColumns("Size", "Size", transformedData, onnxResult);
+ CompareSelectedColumns("Label", "Label", transformedData, onnxResult);
+ CompareSelectedColumns("PredictedLabel", "PredictedLabel", transformedData, onnxResult);
+
+ // Now load the input again, but also include more columns that weren't originally there when creating the onnx model inputs
+ var dataView2 = ML.Data.LoadFromTextFile(dataPath, new[] {
+ new TextLoader.Column("Label", DataKind.Boolean, 0),
+ new TextLoader.Column("Size", DataKind.Single, 2),
+ new TextLoader.Column("Shape", DataKind.Single, 3),
+ new TextLoader.Column("Adhesion", DataKind.Single, 4),
+ new TextLoader.Column("EpithelialSize", DataKind.Single, 5),
+ new TextLoader.Column("BlandChromatin", DataKind.Single, 7),
+ new TextLoader.Column("NormalNucleoli", DataKind.Single, 8), // With a new name and column, to showcase Case 5, will propagate
+ new TextLoader.Column("PredictedLabel", DataKind.Single, 9), // With same name as onnx model output, to showcase Case 6, will be hidden
+ });
+
+ var onnxResult2 = onnxTransformer.Transform(dataView2);
+ Assert.Equal(6, onnxResult2.Schema.Count); // 5 propagated plus 1 hidden
+
+ Assert.Equal("NormalNucleoli", onnxResult2.Schema[0].Name);
+ Assert.Equal("PredictedLabel", onnxResult2.Schema[1].Name);
+ Assert.True(onnxResult2.Schema[1].IsHidden); // Case 6
+
+ Assert.Equal("PredictedLabel", onnxResult2.Schema[2].Name); // It's moved before the others, because it hides the other column
+ Assert.Equal("Shape", onnxResult2.Schema[3].Name);
+ Assert.Equal("Size", onnxResult2.Schema[4].Name);
+ Assert.Equal("Label", onnxResult2.Schema[5].Name);
+
+ // CompareSelectedColumns("NormalNucleoli", "NormalNucleoli", dataView2, onnxResult2); // MYTODO: can't use this method because it expects the column on the left dataview to be a vector type
+ CompareSelectedColumns("Shape", "Shape", onnxResult2, onnxResult);
+ CompareSelectedColumns("Size", "Size", onnxResult2, onnxResult);
+ CompareSelectedColumns("Label", "Label", onnxResult2, onnxResult);
+ CompareSelectedColumns("PredictedLabel", "PredictedLabel", onnxResult2, onnxResult);
+ }
Done();
}