-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Image transforms become Estimators #753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
70d7611
1eeda48
5e32910
f0469e1
d2073eb
f05ec52
a03994b
ef5cb62
55b33af
ac46be9
013bfa4
11aa1fc
281e731
baa5c34
245c344
e7dc348
d8c0648
3e845fd
fe5374f
e660eeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using Microsoft.ML.Core.Data; | ||
|
|
||
| namespace Microsoft.ML.Runtime.Data | ||
| { | ||
| /// <summary> | ||
| /// The trivial implementation of <see cref="IEstimator{TTransformer}"/> that already has | ||
| /// the transformer and returns it on every call to <see cref="Fit(IDataView)"/>. | ||
| /// | ||
| /// Concrete implementations still have to provide the schema propagation mechanism, since | ||
| /// there is no easy way to infer it from the transformer. | ||
| /// </summary> | ||
| public abstract class TrivialEstimator<TTransformer> : IEstimator<TTransformer> | ||
| where TTransformer : class, ITransformer | ||
| { | ||
| /// <summary>Host supplied at construction; also used to validate the constructor arguments.</summary> | ||
| protected readonly IHost Host; | ||
|
|
||
| /// <summary>The already-built transformer that <see cref="Fit(IDataView)"/> returns.</summary> | ||
| protected readonly TTransformer Transformer; | ||
|
|
||
| /// <summary> | ||
| /// Stores the host and the transformer. The transformer must not be null. | ||
| /// </summary> | ||
| /// <param name="host">Host to keep in <see cref="Host"/>.</param> | ||
| /// <param name="transformer">Transformer to return from every call to <see cref="Fit(IDataView)"/>.</param> | ||
| protected TrivialEstimator(IHost host, TTransformer transformer) | ||
| { | ||
| Contracts.AssertValue(host); | ||
| host.CheckValue(transformer, nameof(transformer)); | ||
|
|
||
| Host = host; | ||
| Transformer = transformer; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Returns the stored transformer; <paramref name="input"/> is not inspected. | ||
| /// </summary> | ||
| public TTransformer Fit(IDataView input) | ||
| { | ||
| return Transformer; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Schema propagation, to be provided by each concrete estimator. | ||
| /// </summary> | ||
| public abstract SchemaShape GetOutputSchema(SchemaShape inputSchema); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| // Licensed to the .NET Foundation under one or more agreements. | ||
| // The .NET Foundation licenses this file to you under the MIT license. | ||
| // See the LICENSE file in the project root for more information. | ||
|
|
||
| using System; | ||
| using System.Collections.Generic; | ||
| using System.Linq; | ||
| using Microsoft.ML.Core.Data; | ||
| using Microsoft.ML.Runtime.Model; | ||
|
|
||
| namespace Microsoft.ML.Runtime.Data | ||
| { | ||
| public abstract class OneToOneTransformerBase : ITransformer, ICanSaveModel | ||
| { | ||
| protected readonly IHost Host; | ||
| protected readonly (string input, string output)[] ColumnPairs; | ||
|
|
||
| /// <summary> | ||
| /// Validates the (input, output) column name pairs and stores them together with the host. | ||
| /// </summary> | ||
| /// <param name="host">Host used for the argument checks and kept in <see cref="Host"/>.</param> | ||
| /// <param name="columns">Column name pairs. Every name must be non-empty and each output name may appear only once.</param> | ||
| protected OneToOneTransformerBase(IHost host, (string input, string output)[] columns) | ||
| { | ||
| Contracts.AssertValue(host); | ||
| host.CheckValue(columns, nameof(columns)); | ||
|
|
||
| // Output names seen so far; a repeated name is rejected on its second occurrence. | ||
| var seenOutputs = new HashSet<string>(); | ||
| for (int i = 0; i < columns.Length; i++) | ||
| { | ||
| var column = columns[i]; | ||
| host.CheckNonEmpty(column.input, nameof(columns)); | ||
| host.CheckNonEmpty(column.output, nameof(columns)); | ||
|
|
||
| bool isNewOutput = seenOutputs.Add(column.output); | ||
| if (!isNewOutput) | ||
| throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times"); | ||
| } | ||
|
|
||
| ColumnPairs = columns; | ||
| Host = host; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Loads the (input, output) column name pairs in the layout written by <see cref="SaveColumns"/>. | ||
| /// </summary> | ||
| /// <param name="host">Host to keep in <see cref="Host"/>.</param> | ||
| /// <param name="ctx">Load context whose reader yields the pair count followed by the column names.</param> | ||
| /// <exception cref="FormatException">The stored number of column pairs is negative.</exception> | ||
| protected OneToOneTransformerBase(IHost host, ModelLoadContext ctx) | ||
| { | ||
| // Same argument validation as the other constructor and SaveColumns. | ||
| Contracts.AssertValue(host); | ||
| host.CheckValue(ctx, nameof(ctx)); | ||
| Host = host; | ||
|
|
||
| // *** Binary format *** | ||
| // int: number of added columns | ||
| // for each added column | ||
| // int: id of output column name | ||
| // int: id of input column name | ||
|
|
||
| int n = ctx.Reader.ReadInt32(); | ||
| // A negative count cannot be produced by SaveColumns; report it clearly instead of failing in the array allocation. | ||
| if (n < 0) | ||
| throw new FormatException($"Invalid number of column pairs in the model: {n}"); | ||
| ColumnPairs = new (string input, string output)[n]; | ||
| for (int i = 0; i < n; i++) | ||
| { | ||
| // The output name is stored before the input name. | ||
| string output = ctx.LoadNonEmptyString(); | ||
| string input = ctx.LoadNonEmptyString(); | ||
| ColumnPairs[i] = (input, output); | ||
| } | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Saves this transformer to <paramref name="ctx"/>. Implemented by derived classes, which can call | ||
| /// <see cref="SaveColumns"/> to write the column pairs. | ||
| /// </summary> | ||
| public abstract void Save(ModelSaveContext ctx); | ||
|
|
||
| /// <summary> | ||
| /// Writes <see cref="ColumnPairs"/> to <paramref name="ctx"/>: the pair count, then for every pair | ||
| /// the output column name followed by the input column name. | ||
| /// </summary> | ||
| /// <param name="ctx">Save context to write to; must not be null.</param> | ||
| protected void SaveColumns(ModelSaveContext ctx) | ||
| { | ||
| Host.CheckValue(ctx, nameof(ctx)); | ||
|
|
||
| // *** Binary format *** | ||
| // int: number of added columns | ||
| // for each added column | ||
| // int: id of output column name | ||
| // int: id of input column name | ||
|
|
||
| ctx.Writer.Write(ColumnPairs.Length); | ||
| foreach (var (input, output) in ColumnPairs) | ||
| { | ||
| // Output first, then input: the loading constructor reads the names in this order. | ||
| ctx.SaveNonEmptyString(output); | ||
| ctx.SaveNonEmptyString(input); | ||
| } | ||
| } | ||
|
|
||
| private void CheckInput(ISchema inputSchema, int col, out int srcCol) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
make this public and call it in GetOutputSchema in each estimator. #ByDesign
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Cannot do this, because `GetOutputSchema` in the estimator operates over `SchemaShape`, not `ISchema`. In reply to: 213863333 [](ancestors = 213863333)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if only ISchema and SchemaShape were relatives.... In reply to: 213863552 [](ancestors = 213863552,213863333) |
||
| { | ||
| Contracts.AssertValue(inputSchema); | ||
| Contracts.Assert(0 <= col && col < ColumnPairs.Length); | ||
|
|
||
| if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol)) | ||
| throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input); | ||
| CheckInputColumn(inputSchema, col, srcCol); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Hook for additional validation of an input column. Called by <c>CheckInput</c> once the input | ||
| /// column name has been found in <paramref name="inputSchema"/>. | ||
| /// </summary> | ||
| /// <param name="inputSchema">Schema in which the input column was found.</param> | ||
| /// <param name="col">Index into <see cref="ColumnPairs"/>.</param> | ||
| /// <param name="srcCol">Index of the input column in <paramref name="inputSchema"/>.</param> | ||
| protected virtual void CheckInputColumn(ISchema inputSchema, int col, int srcCol) | ||
| { | ||
| // By default, there are no extra checks. | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Creates the row mapper for the given input schema. <see cref="MakeDataTransform"/> passes the | ||
| /// result to a new <see cref="RowToRowMapperTransform"/>. | ||
| /// </summary> | ||
| protected abstract IRowMapper MakeRowMapper(ISchema schema); | ||
|
|
||
| /// <summary> | ||
| /// Checks every configured input column against <paramref name="inputSchema"/> and returns the | ||
| /// schema this transformer produces for it. | ||
| /// </summary> | ||
| /// <param name="inputSchema">Schema of the data to be transformed; must not be null.</param> | ||
| public ISchema GetOutputSchema(ISchema inputSchema) | ||
| { | ||
| Host.CheckValue(inputSchema, nameof(inputSchema)); | ||
|
|
||
| // Check that all the input columns are present and correct. | ||
| for (int iinfo = 0; iinfo < ColumnPairs.Length; iinfo++) | ||
| { | ||
| // Only the validation matters here; the resolved column index is not needed. | ||
| CheckInput(inputSchema, iinfo, out _); | ||
| } | ||
|
|
||
| // The output schema is read off the result of transforming an empty view over the input schema. | ||
| var emptyInput = new EmptyDataView(Host, inputSchema); | ||
| var transformed = Transform(emptyInput); | ||
| return transformed.Schema; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Applies this transformer to <paramref name="input"/> by delegating to <see cref="MakeDataTransform"/>. | ||
| /// </summary> | ||
| public IDataView Transform(IDataView input) | ||
| { | ||
| return MakeDataTransform(input); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Wraps <paramref name="input"/> in a <see cref="RowToRowMapperTransform"/> that uses the row mapper | ||
| /// created by <see cref="MakeRowMapper"/> for the input's schema. | ||
| /// </summary> | ||
| /// <param name="input">Data to transform; must not be null.</param> | ||
| protected RowToRowMapperTransform MakeDataTransform(IDataView input) | ||
| { | ||
| Host.CheckValue(input, nameof(input)); | ||
|
|
||
| var mapper = MakeRowMapper(input.Schema); | ||
| return new RowToRowMapperTransform(Host, input, mapper); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Base class for the row mappers of one-to-one transformers. At construction it resolves every | ||
| /// input column of the parent transformer against the input schema. | ||
| /// </summary> | ||
| protected abstract class MapperBase : IRowMapper | ||
| { | ||
| protected readonly IHost Host; | ||
| /// <summary> | ||
| /// Maps the index of a column pair of the parent to the index of its input column in <see cref="InputSchema"/>. | ||
| /// </summary> | ||
| protected readonly Dictionary<int, int> ColMapNewToOld; | ||
| protected readonly ISchema InputSchema; | ||
| private readonly OneToOneTransformerBase _parent; | ||
|
|
||
| /// <summary> | ||
| /// Validates the parent's input columns against <paramref name="inputSchema"/> and records | ||
| /// where each of them is located. | ||
| /// </summary> | ||
| protected MapperBase(IHost host, OneToOneTransformerBase parent, ISchema inputSchema) | ||
| { | ||
| Contracts.AssertValue(host); | ||
| Contracts.AssertValue(parent); | ||
| Contracts.AssertValue(inputSchema); | ||
|
|
||
| Host = host; | ||
| _parent = parent; | ||
|
|
||
| ColMapNewToOld = new Dictionary<int, int>(); | ||
| for (int i = 0; i < _parent.ColumnPairs.Length; i++) | ||
| { | ||
| // Throws if the input column is missing or fails the parent's extra checks. | ||
| _parent.CheckInput(inputSchema, i, out int srcCol); | ||
| ColMapNewToOld.Add(i, srcCol); | ||
| } | ||
| InputSchema = inputSchema; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Returns a predicate over input column indices that is true for exactly the input columns | ||
| /// feeding an output column selected by <paramref name="activeOutput"/>. | ||
| /// </summary> | ||
| public Func<int, bool> GetDependencies(Func<int, bool> activeOutput) | ||
| { | ||
| var active = new bool[InputSchema.ColumnCount]; | ||
| foreach (var pair in ColMapNewToOld) | ||
| if (activeOutput(pair.Key)) | ||
| active[pair.Value] = true; | ||
| return col => active[col]; | ||
| } | ||
|
|
||
| public abstract RowMapperColumnInfo[] GetOutputColumns(); | ||
|
|
||
| /// <summary>Saving is delegated to the parent transformer.</summary> | ||
| public void Save(ModelSaveContext ctx) => _parent.Save(ctx); | ||
|
|
||
| /// <summary> | ||
| /// Creates one getter per column pair of the parent. Entries of inactive outputs are left null. | ||
| /// </summary> | ||
| /// <param name="input">Row to read from; its schema is asserted to be <see cref="InputSchema"/>.</param> | ||
| /// <param name="activeOutput">Predicate selecting the output columns for which a getter is created.</param> | ||
| /// <param name="disposer">Action that runs every non-null disposer returned by <see cref="MakeGetter"/>, or null when there is none.</param> | ||
| public Delegate[] CreateGetters(IRow input, Func<int, bool> activeOutput, out Action disposer) | ||
| { | ||
| Contracts.Assert(input.Schema == InputSchema); | ||
| var result = new Delegate[_parent.ColumnPairs.Length]; | ||
| var disposers = new Action[_parent.ColumnPairs.Length]; | ||
| for (int i = 0; i < _parent.ColumnPairs.Length; i++) | ||
| { | ||
| if (!activeOutput(i)) | ||
| continue; | ||
| result[i] = MakeGetter(input, i, out disposers[i]); | ||
| } | ||
| if (disposers.Any(x => x != null)) | ||
| { | ||
| disposer = () => | ||
| { | ||
| // Slots of inactive outputs and of getters without a disposer are null; skip them. | ||
| foreach (var act in disposers) | ||
| act?.Invoke(); | ||
| }; | ||
| } | ||
| else | ||
| disposer = null; | ||
| return result; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Creates the getter for the column pair with index <paramref name="iinfo"/>. | ||
| /// <paramref name="disposer"/> may be set to null when there is nothing to release. | ||
| /// </summary> | ||
| protected abstract Delegate MakeGetter(IRow input, int iinfo, out Action disposer); | ||
| } | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is somewhat problematic for key columns. How do we represent the two distinct concepts of key columns of known and unknown counts? #Pending
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, `ItemType` is never supposed to be a key. For example, for a scalar key column, the correct representation is `Kind = Scalar, ItemType = PrimitiveType.U4, IsKey = true`. I think I should even enforce it in the constructor.
I don't think we should even have key types of unknown counts: it is already causing issues for the `Term` transform, and I'm yet to see any benefit of this type. In reply to: 213151407 [](ancestors = 213151407)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that keys of unknown size should go away, but are we going to do this in this PR?
That we're representing a key-type using a column type which is not a key-type at all, and whose only connection to it is that their `DataKind` happens to be the same, is rather odd and unfortunate. I hope we can imagine a better way here, though I'm not sure I see one right off the bat. In reply to: 213156549 [](ancestors = 213156549,213151407)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are also representing a vector type using a column type which is not a vector-type at all, and whose only connection to it is that their `ItemType` happens to be the same :) In reply to: 214125274 [](ancestors = 214125274,213156549,213151407)