Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7be5229
Merged with Harish commit with partial solution:
antoniovs1029 Feb 20, 2020
0ac2da1
Merge remote-tracking branch 'upstream/master' into is25onnxColSelect…
antoniovs1029 Mar 25, 2020
00a26b2
Revert "Merged with Harish commit with partial solution:"
antoniovs1029 Mar 25, 2020
944b066
Revert "Revert "Merged with Harish commit with partial solution:""
antoniovs1029 Mar 25, 2020
5a6a1b0
Modified test, because now output schema should be the same as onnx m…
antoniovs1029 Mar 25, 2020
4dbdc26
Actually check that the output schema has dropped the columns
antoniovs1029 Mar 25, 2020
a1b2b6e
Further modifications to make this work with all the existing tests
antoniovs1029 Mar 25, 2020
dfed82e
Remove unnecessary outputschema property on OnnxTransformer
antoniovs1029 Mar 25, 2020
816c66f
Move OutputSchema logic to OnnxDataTransform instead of Mapper
antoniovs1029 Mar 30, 2020
e928b3b
Added the use of ColumnBindings on OnnxDataTransform
antoniovs1029 Mar 30, 2020
d604cf4
Added MYTODO to comments
antoniovs1029 Mar 30, 2020
e9f4def
Still not working. GetActive() can't return inputcolumns from 2 diffe…
antoniovs1029 Apr 2, 2020
357648e
* Removed ColumnSelectingTransformer from OnnxDataTransform
antoniovs1029 Apr 7, 2020
0687f52
Drop columns inside OnnxDataTransformer.Bindings and added comments
antoniovs1029 Apr 7, 2020
9191ea9
Revert changes in OnnxTransformTests
antoniovs1029 Apr 7, 2020
88bd905
Added comment
antoniovs1029 Apr 7, 2020
183ba26
Added test for the different possible cases
antoniovs1029 Apr 9, 2020
315aa52
Added comments
antoniovs1029 Apr 9, 2020
6759324
Added test for drop columns
antoniovs1029 Apr 9, 2020
0fd1557
Fixed mistakes in ColumnSelectingOnnxTestColumnPropagation
antoniovs1029 Apr 9, 2020
f73e097
Added side effect test of dropping input columns
antoniovs1029 Apr 9, 2020
26e72ec
Updated comments
antoniovs1029 Apr 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/Microsoft.ML.Data/Transforms/RowToRowTransformerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ private protected RowToRowTransformerBase(IHost host)

bool ITransformer.IsRowToRowMapper => true;

IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) => GetRowToRowMapperCore(inputSchema);

protected virtual IRowToRowMapper GetRowToRowMapperCore(DataViewSchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
return new RowToRowMapperTransform(Host, new EmptyDataView(Host, inputSchema), MakeRowMapper(inputSchema), MakeRowMapper);
Expand All @@ -37,14 +39,18 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
[BestFriend]
private protected abstract IRowMapper MakeRowMapper(DataViewSchema schema);

public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => GetOutputSchemaCore(inputSchema);

protected virtual DataViewSchema GetOutputSchemaCore(DataViewSchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
var mapper = MakeRowMapper(inputSchema);
return RowToRowMapperTransform.GetOutputSchema(inputSchema, mapper);
}

public IDataView Transform(IDataView input) => MakeDataTransform(input);
public IDataView Transform(IDataView input) => MakeDataTransformCore(input);

private protected virtual IDataView MakeDataTransformCore(IDataView input) => MakeDataTransform(input);

[BestFriend]
private protected RowToRowMapperTransform MakeDataTransform(IDataView input)
Expand Down
226 changes: 221 additions & 5 deletions src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ internal sealed class Options : TransformInputBase
/// </summary>
internal DataViewType[] OutputTypes { get; }

public readonly DataViewSchema OutputSchema;

private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
Expand All @@ -134,12 +136,18 @@ private static VersionInfo GetVersionInfo()
// Factory method for SignatureDataTransform
private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
{
return new OnnxTransformer(env, options).MakeDataTransform(input);
var transformer = new OnnxTransformer(env, options);
var mapper = new Mapper(transformer, input.Schema);
return new OnnxDataTransform(env, input, mapper);
}

// Factory method for SignatureLoadDataTransform
private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
=> Create(env, ctx).MakeDataTransform(input);
{
var transformer = OnnxTransformer.Create(env, ctx);
var mapper = new Mapper(transformer, input.Schema);
return new OnnxDataTransform(env, input, mapper);
}

// Factory method for SignatureLoadModel.
private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
Expand Down Expand Up @@ -188,8 +196,7 @@ private static OnnxTransformer Create(IHostEnvironment env, ModelLoadContext ctx
}

// Factory method for SignatureLoadRowMapper.
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) => new Mapper(Create(env, ctx), inputSchema);

private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes = null) :
base(Contracts.CheckRef(env, nameof(env)).Register(nameof(OnnxTransformer)))
Expand Down Expand Up @@ -244,6 +251,12 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
OutputTypes[i] = outputInfo.DataViewType;
}
_options = options;

var schemaBuilder = new DataViewSchema.Builder();
for (var i = 0; i < Outputs.Length; i++)
schemaBuilder.AddColumn(Outputs[i], OutputTypes[i]);

OutputSchema = schemaBuilder.ToSchema();
}

/// <summary>
Expand Down Expand Up @@ -327,6 +340,23 @@ private protected override void SaveModel(ModelSaveContext ctx)

private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema);

protected override IRowToRowMapper GetRowToRowMapperCore(DataViewSchema inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
return new OnnxDataTransform(Host, new EmptyDataView(Host, inputSchema), new Mapper(this, inputSchema));
}

protected override DataViewSchema GetOutputSchemaCore(DataViewSchema inputSchema)
{
return new OnnxDataTransform(Host, new EmptyDataView(Host, inputSchema), new Mapper(this, inputSchema)).OutputSchema;
}

private protected override IDataView MakeDataTransformCore(IDataView input)
{
Host.CheckValue(input, nameof(input));
return new OnnxDataTransform(Host, input, new Mapper(this, input.Schema));
}

Comment on lines +343 to +351

@antoniovs1029 antoniovs1029 Mar 25, 2020

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Besides the fact that I don't fully like modifying RowToRowTransformer to make these methods overridable, I don't like the way this looks either. It makes it kinda tricky to have a different transform set the output of this mapper, but it's the only way to go with this approach... it gets weirder because the OnnxDataTransform simply gets the output by accessing members that exist on this mapper.

But it's necessary that OnnxTransformer, OnnxTransformer.Mapper, OnnxDataTransform, and OnnxScoringEstimator all give the same output schema, since different code paths will require the output schema from any of these classes. So the entanglement here seems to be necessary. The only alternative is to have each of the classes determine the output schema by themselves, but it might become more difficult to maintain the same code on different places.

/// <summary>
/// This design assumes that all unknown dimensions are 1s. It also convert scalar shape [] in ONNX to [1].
/// [TODO] We should infer the unknown shape from input data instead of forcing them to be 1.
Expand Down Expand Up @@ -357,6 +387,11 @@ private sealed class Mapper : MapperBase
/// </summary>
private readonly Type[] _inputOnnxTypes;

private readonly DataViewSchema _outputSchema;

//public DataViewSchema OutputSchema => _parent.GetOutputSchema(InputSchema);
public DataViewSchema OutputSchema => _outputSchema;

public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
base(Contracts.CheckRef(parent, nameof(parent)).Host.Register(nameof(Mapper)), inputSchema, parent)
{
Expand Down Expand Up @@ -407,8 +442,12 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
if (typeValueCount % valCount != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {String.Join(",", inputShape)}, but input data is of length {typeValueCount}.");
}

_outputSchema = GetOutputSchema();
}

public DataViewSchema.DetachedColumn[] GetOutputColumns() => GetOutputColumnsCore();

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var stdSuffix = ".output";
Expand All @@ -426,6 +465,15 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
return info;
}

private DataViewSchema GetOutputSchema()
{
var infos = GetOutputColumnsCore();
var schemaBuilder = new DataViewSchema.Builder();
schemaBuilder.AddColumns(infos);

return schemaBuilder.ToSchema();
}

private void AddSlotNames(string columnName, DataViewSchema.Annotations.Builder builder)
{
var graph = _parent.Model.Graph;
Expand All @@ -450,6 +498,7 @@ private void AddSlotNames(string columnName, DataViewSchema.Annotations.Builder

builder.AddSlotNames(count, getter);
}
public Func<int, bool> GetDependencies(Func<int, bool> activeOutput) => GetDependenciesCore(activeOutput);

private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
{
Expand Down Expand Up @@ -688,6 +737,173 @@ public NamedOnnxValue GetNamedOnnxValue()
}
}
}

private class OnnxDataTransform : TransformBase, IRowToRowMapper
{
private readonly Mapper _mapper;
private readonly IRowMapper _mapperIf;

public OnnxDataTransform(IHostEnvironment env, IDataView input, Mapper mapper)
:base(env.Register(nameof(OnnxDataTransform)), input)
{
_mapper = mapper;
_mapperIf = mapper as IRowMapper;
}

public DataViewSchema Schema => OutputSchema;

public DataViewSchema InputSchema => Source.Schema;

public override DataViewSchema OutputSchema => _mapper.OutputSchema;

public override long? GetRowCount() => Source.GetRowCount();

public void Save(ModelSaveContext ctx) => _mapperIf.Save(ctx);

public IEnumerable<DataViewSchema.Column> GetDependencies(IEnumerable<DataViewSchema.Column> dependingColumns)
{
return Source.Schema;
}

public DataViewRow GetRow(DataViewRow input, IEnumerable<DataViewSchema.Column> activeColumns)
{
Host.CheckValue(input, nameof(input));
Host.CheckValue(activeColumns, nameof(activeColumns));
Host.Check(input.Schema == Source.Schema, "Schema of input row must be the same as the schema the mapper is bound to");

using (var ch = Host.Start("GetEntireRow"))
{
var pred = RowCursorUtils.FromColumnsToPredicate(activeColumns, Schema);
var getters = _mapperIf.CreateGetters(input, pred, out Action disp);
return new RowImpl(input, this, Schema, getters, disp);
}
}

protected override DataViewRowCursor GetRowCursorCore(IEnumerable<DataViewSchema.Column> columnsNeeded, Random rand = null)
{
var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema);
var active = Utils.BuildArray(Schema.Count, predicate);
return new Cursor(Host, Source.GetRowCursor(Source.Schema, rand), this, active);
}

public override DataViewRowCursor[] GetRowCursorSet(IEnumerable<DataViewSchema.Column> columnsNeeded, int n, Random rand = null)
{
Host.CheckValueOrNull(rand);

var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, OutputSchema);
var active = Utils.BuildArray(Schema.Count, predicate);

var inputs = Source.GetRowCursorSet(Source.Schema, n, rand);
Host.AssertNonEmpty(inputs);

if (inputs.Length == 1 && n > 1 && Enumerable.Range(0, Schema.Count).Any(predicate))
inputs = DataViewUtils.CreateSplitCursors(Host, inputs[0], n);
Host.AssertNonEmpty(inputs);

var cursors = new DataViewRowCursor[inputs.Length];
for (int i = 0; i < inputs.Length; i++)
cursors[i] = new Cursor(Host, inputs[i], this, active);
return cursors;
}

private protected override void SaveModel(ModelSaveContext ctx) => _mapperIf.Save(ctx);

protected override bool? ShouldUseParallelCursors(Func<int, bool> predicate)
{
return true;
}

private sealed class RowImpl : WrappingRow
{
private readonly Delegate[] _getters;
private readonly OnnxDataTransform _parent;
private readonly Action _disposer;

public override DataViewSchema Schema { get; }

public RowImpl(DataViewRow input, OnnxDataTransform parent, DataViewSchema schema, Delegate[] getters, Action disposer)
: base(input)
{
_parent = parent;
Schema = schema;
_getters = getters;
_disposer = disposer;
}

protected override void DisposeCore(bool disposing)
{
if (disposing)
_disposer?.Invoke();
}

/// <summary>
/// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
/// This throws if the column is not active in this row, or if the type
/// <typeparamref name="TValue"/> differs from this column's type.
/// </summary>
/// <typeparam name="TValue"> is the column's content type.</typeparam>
/// <param name="column"> is the output column whose getter should be returned.</param>
public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
{
int index = column.Index;
Contracts.Assert(_getters[index] != null);
var fn = _getters[index] as ValueGetter<TValue>;
if (fn == null)
throw Contracts.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
return fn;
}

/// <summary>
/// Returns whether the given column is active in this row.
/// </summary>
public override bool IsColumnActive(DataViewSchema.Column column)
{
return _getters[column.Index] != null;
}
}

private sealed class Cursor : SynchronizedCursorBase
{
private readonly OnnxDataTransform _parent;
private readonly Delegate[] _getters;
private readonly bool[] _active;
private readonly Action _disposer;
private bool _disposed;

public Cursor(IChannelProvider provider, DataViewRowCursor input, OnnxDataTransform parent, bool[] active)
: base(provider, input)
{
_parent = parent;
Func<int, bool> pred = c => active[c];
_getters = parent._mapperIf.CreateGetters(input, pred, out _disposer);
_active = active;
}

public override DataViewSchema Schema => _parent.OutputSchema;

public override ValueGetter<TValue> GetGetter<TValue>(DataViewSchema.Column column)
{
var getter = _getters[column.Index];
Ch.Assert(getter != null);
var fn = getter as ValueGetter<TValue>;
if (fn == null)
throw Ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
return fn;
}

protected override void Dispose(bool disposing)
{
if (_disposed)
return;
if (disposing)
_disposer?.Invoke();
_disposed = true;
base.Dispose(disposing);
}

public override bool IsColumnActive(DataViewSchema.Column column) => _active[column.Index];
}
}
}

/// <summary>
Expand Down Expand Up @@ -772,7 +988,6 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
Host.CheckValue(inputSchema, nameof(inputSchema));
var result = inputSchema.ToDictionary(x => x.Name);
var resultDic = inputSchema.ToDictionary(x => x.Name);

// This loop checks if all input columns needed in the underlying transformer can be found
// in inputSchema.
Expand Down Expand Up @@ -802,6 +1017,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
}

var resultDic = new Dictionary<string, SchemaShape.Column>();
for (var i = 0; i < Transformer.Outputs.Length; i++)
{
resultDic[Transformer.Outputs[i]] = new SchemaShape.Column(Transformer.Outputs[i],
Expand Down
17 changes: 10 additions & 7 deletions test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,6 @@ private class ImageDataPoint
[ImageType(Height, Width)]
public Bitmap Image { get; set; }

/// <summary>
/// Output of ONNX model. It contains probabilities of all classes.
/// </summary>
[ColumnName("softmaxout_1")]
public float[] Scores { get; set; }

public ImageDataPoint()
{
Image = null;
Expand All @@ -385,6 +379,15 @@ public ImageDataPoint(Color color)
}
}

private class OutputImageDataPoint
{
/// <summary>
/// Output of ONNX model. It contains probabilities of all classes.
/// </summary>
[ColumnName("softmaxout_1")]
public float[] Scores { get; set; }
}

/// <summary>
/// Test applying ONNX transform on in-memory image.
/// </summary>
Expand Down Expand Up @@ -416,7 +419,7 @@ public void OnnxModelInMemoryImage()
// Convert IDataView back to IEnumerable<ImageDataPoint> so that user can inspect the output, column "softmaxout_1", of the ONNX transform.
// Note that Column "softmaxout_1" would be stored in ImageDataPont.Scores because the added attributed [ColumnName("softmaxout_1")]
// tells that ImageDataPont.Scores is equivalent to column "softmaxout_1".
var transformedDataPoints = ML.Data.CreateEnumerable<ImageDataPoint>(onnx, false).ToList();

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test needed to be changed because now we're not propagating the input Image column to the output.

var transformedDataPoints = ML.Data.CreateEnumerable<OutputImageDataPoint>(onnx, false).ToList();

// The scores are probabilities of all possible classes, so they should all be positive.
foreach (var dataPoint in transformedDataPoints)
Expand Down
2 changes: 2 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1743,7 +1743,9 @@ public void SelectColumnsOnnxTest()
var onnxResult = onnxTransformer.Transform(dataView);

// Verify that onnx output has only the four columns we selected from the input
// And the the output of the onnxmodel has the same length as the output schema
Assert.Equal(4, outputNames.Length);
Assert.Equal(outputNames.Length, onnxResult.Schema.Count);
Assert.Equal("Size.output", outputNames[0]);
Assert.Equal("Shape.output", outputNames[1]);
Assert.Equal("Thickness.output", outputNames[2]);
Expand Down