diff --git a/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs b/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs
new file mode 100644
index 0000000000..537f0611b9
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs
@@ -0,0 +1,660 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML;
+using Microsoft.ML.CommandLine;
+using Microsoft.ML.Data;
+using Microsoft.ML.EntryPoints;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Internal.Utilities;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Transforms;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+
+[assembly: LoadableClass(typeof(TimeSeriesImputerTransformer), null, typeof(SignatureLoadModel),
+ TimeSeriesImputerTransformer.UserName, TimeSeriesImputerTransformer.LoaderSignature)]
+
+[assembly: LoadableClass(typeof(IDataTransform), typeof(TimeSeriesImputerTransformer), null, typeof(SignatureLoadDataTransform),
+ TimeSeriesImputerTransformer.UserName, TimeSeriesImputerTransformer.LoaderSignature)]
+
+[assembly: EntryPointModule(typeof(TimeSeriesTransformerEntrypoint))]
+
+namespace Microsoft.ML.Featurizers
+{
+ public static class TimeSeriesImputerExtensionClass
+ {
+ ///
+ /// Create a , Imputes missing rows and column data per grain. Operates on all columns in the IDataView.
+ /// Currently only float/double/string columns are supported for imputation strategies, and an empty string is considered "missing" for the
+ /// purpose of this estimator. Other column types will have the default value placed if a row is imputed.
+ ///
+ /// The transform catalog.
+ /// Column representing the time series. Should be of type
+ /// List of columns to use as grains
+ /// Mode of imputation for missing values in column. If not passed defaults to forward fill
+ public static TimeSeriesImputerEstimator ReplaceMissingTimeSeriesValues(this TransformsCatalog catalog, string timeSeriesColumn, string[] grainColumns,
+ TimeSeriesImputerEstimator.ImputationStrategy imputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill)
+ => new TimeSeriesImputerEstimator(CatalogUtils.GetEnvironment(catalog), timeSeriesColumn, grainColumns, null, TimeSeriesImputerEstimator.FilterMode.NoFilter, imputeMode, true);
+
+ ///
+ /// Create a , Imputes missing rows and column data per grain. Applies the imputation strategy on
+ /// a filtered list of columns in the IDataView. Columns that are are excluded will have the default value for that data type used when a row
+ /// is imputed. Currently only float/double/string columns are supported for imputation strategies, and an empty string is considered "missing" for the
+ /// purpose of this estimator.
+ ///
+ /// The transform catalog.
+ /// Column representing the time series. Should be of type
+ /// List of columns to use as grains
+ /// List of columns to filter. If is than columns in the list will be ignored.
+ /// If is than values in the list are the only columns imputed.
+ /// Whether the list should include or exclude those columns.
+ /// Mode of imputation for missing values in column. If not passed defaults to forward fill
+ /// Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column and use the default value. If false, will stop and throw an error.
+ public static TimeSeriesImputerEstimator ReplaceMissingTimeSeriesValues(this TransformsCatalog catalog, string timeSeriesColumn,
+ string[] grainColumns, string[] filterColumns, TimeSeriesImputerEstimator.FilterMode filterMode = TimeSeriesImputerEstimator.FilterMode.Exclude,
+ TimeSeriesImputerEstimator.ImputationStrategy imputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill,
+ bool suppressTypeErrors = false)
+ => new TimeSeriesImputerEstimator(CatalogUtils.GetEnvironment(catalog), timeSeriesColumn, grainColumns, filterColumns, filterMode, imputeMode, suppressTypeErrors);
+ }
+
+ ///
+ /// Imputes missing rows and column data per grain, based on the dates in the date column. This operation needs to happen to every column in the IDataView,
+ /// If you "filter" a column using the filterColumns and filterMode parameters, if a row is imputed the default value for that type will be used.
+ /// Currently only float/double/string columns are supported for imputation strategies, and an empty string is considered "missing" for the
+ /// purpose of this estimator. A new column is added to the schema after this operation is run. The column is called "IsRowImputed" and is a
+ /// boolean value representing if the row was created as a result of this operation or not.
+ ///
+ /// NOTE: It is not recommended to chain this multiple times. If a column is filtered, the default value is placed when a row is imputed, and the
+ /// default value is not null. Thus any other TimeSeriesImputers will not be able to replace those values anymore causing essentially a very
+ /// computationally expensive NO-OP.
+ ///
+ ///
+ ///
+ /// is not a trivial estimator and needs training.
+ ///
+ ///
+ /// ]]>
+ ///
+ ///
+ ///
+ ///
+ public sealed class TimeSeriesImputerEstimator : IEstimator
+ {
+ private Options _options;
+ internal const string IsRowImputedColumnName = "IsRowImputed";
+
+ private readonly IHost _host;
+ private static readonly List _currentSupportedTypes = new List { typeof(sbyte), typeof(byte), typeof(short), typeof(ushort), typeof(int), typeof(uint),
+ typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(string), typeof(ReadOnlyMemory)};
+
+ #region Options
+ internal sealed class Options : TransformInputBase
+ {
+ [Argument(ArgumentType.Required, HelpText = "Column representing the time", Name = "TimeSeriesColumn", ShortName = "time", SortOrder = 1)]
+ public string TimeSeriesColumn;
+
+ [Argument((ArgumentType.MultipleUnique | ArgumentType.Required), HelpText = "List of grain columns", Name = "GrainColumns", ShortName = "grains", SortOrder = 2)]
+ public string[] GrainColumns;
+
+ // This transformer adds columns
+ [Argument(ArgumentType.MultipleUnique, HelpText = "Columns to filter", Name = "FilterColumns", ShortName = "filters", SortOrder = 2)]
+ public string[] FilterColumns;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Filter mode. Either include or exclude", Name = "FilterMode", ShortName = "fmode", SortOrder = 3)]
+ public FilterMode FilterMode = FilterMode.Exclude;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Mode for imputing, defaults to ForwardFill if not provided", Name = "ImputeMode", ShortName = "mode", SortOrder = 3)]
+ public ImputationStrategy ImputeMode = ImputationStrategy.ForwardFill;
+
+ [Argument(ArgumentType.AtMostOnce, HelpText = "Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error.", Name = "SupressTypeErrors", ShortName = "error", SortOrder = 3)]
+ public bool SupressTypeErrors = false;
+ }
+
+ #endregion
+
+ #region Class Enums
+
+ public enum ImputationStrategy : byte
+ {
+ ForwardFill = 1,
+ BackFill = 2,
+ Median = 3,
+ // Interpolate = 4, interpolate not currently supported in the native code.
+ };
+
+ public enum FilterMode : byte
+ {
+ NoFilter = 1,
+ Include = 2,
+ Exclude = 3
+ };
+
+ #endregion
+
+ internal TimeSeriesImputerEstimator(IHostEnvironment env, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, FilterMode filterMode, ImputationStrategy imputeMode, bool supressTypeErrors)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
+ _host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator");
+ _host.CheckValue(timeSeriesColumn, nameof(timeSeriesColumn), "TimePoint column should not be null.");
+ _host.CheckNonEmpty(grainColumns, nameof(grainColumns), "Need at least one grain column.");
+ _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
+ if (filterMode == FilterMode.Include)
+ _host.CheckNonEmpty(filterColumns, nameof(filterColumns), "Need at least 1 filter column if a FilterMode is specified");
+
+ _options = new Options
+ {
+ TimeSeriesColumn = timeSeriesColumn,
+ GrainColumns = grainColumns,
+ FilterColumns = filterColumns == null ? new string[] { } : filterColumns,
+ FilterMode = filterMode,
+ ImputeMode = imputeMode,
+ SupressTypeErrors = supressTypeErrors
+ };
+ }
+
+ internal TimeSeriesImputerEstimator(IHostEnvironment env, Options options)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
+ _host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator");
+ _host.CheckValue(options.TimeSeriesColumn, nameof(options.TimeSeriesColumn), "TimePoint column should not be null.");
+ _host.CheckValue(options.GrainColumns, nameof(options.GrainColumns), "Grain columns should not be null.");
+ _host.CheckNonEmpty(options.GrainColumns, nameof(options.GrainColumns), "Need at least one grain column.");
+ _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
+ if (options.FilterMode != FilterMode.NoFilter)
+ _host.CheckNonEmpty(options.FilterColumns, nameof(options.FilterColumns), "Need at least 1 filter column if a FilterMode is specified");
+
+ _options = options;
+ }
+
+ public TimeSeriesImputerTransformer Fit(IDataView input)
+ {
+ // If we are not suppressing type errors make sure columns to impute only contain supported types.
+ if (!_options.SupressTypeErrors)
+ {
+ var columns = input.Schema.Where(x => !_options.GrainColumns.Contains(x.Name));
+ if (_options.FilterMode == FilterMode.Exclude)
+ columns = columns.Where(x => !_options.FilterColumns.Contains(x.Name));
+ else if (_options.FilterMode == FilterMode.Include)
+ columns = columns.Where(x => _options.FilterColumns.Contains(x.Name));
+
+ foreach (var column in columns)
+ {
+ if (!_currentSupportedTypes.Contains(column.Type.RawType))
+ throw new InvalidOperationException($"Type {column.Type.RawType.ToString()} for column {column.Name} not a supported type.");
+ }
+ }
+
+ return new TimeSeriesImputerTransformer(_host, _options, input);
+ }
+
+ // Add one column called WasColumnImputed, otherwise everything stays the same.
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ var columns = inputSchema.ToDictionary(x => x.Name);
+ columns[IsRowImputedColumnName] = new SchemaShape.Column(IsRowImputedColumnName, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false);
+ return new SchemaShape(columns.Values);
+ }
+ }
+
+ public sealed class TimeSeriesImputerTransformer : ITransformer, IDisposable
+ {
+ #region Class data members
+
+ internal const string Summary = "Fills in missing row and values";
+ internal const string UserName = "TimeSeriesImputer";
+ internal const string ShortName = "tsi";
+ internal const string LoadName = "TimeSeriesImputer";
+ internal const string LoaderSignature = "TimeSeriesImputer";
+
+ private readonly IHost _host;
+ private readonly string _timeSeriesColumn;
+ private readonly string[] _grainColumns;
+ private readonly string[] _dataColumns;
+ private readonly string[] _allColumnNames;
+ private readonly bool _suppressTypeErrors;
+ private readonly TimeSeriesImputerEstimator.ImputationStrategy _imputeMode;
+ internal TransformerEstimatorSafeHandle TransformerHandle;
+
+ #endregion
+
+ // Normal constructor.
+ internal TimeSeriesImputerTransformer(IHostEnvironment host, TimeSeriesImputerEstimator.Options options, IDataView input)
+ {
+ _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
+
+ _host = host.Register(nameof(TimeSeriesImputerTransformer));
+ _timeSeriesColumn = options.TimeSeriesColumn;
+ _grainColumns = options.GrainColumns;
+ _imputeMode = options.ImputeMode;
+ _suppressTypeErrors = options.SupressTypeErrors;
+
+ IEnumerable tempDataColumns;
+
+ if (options.FilterMode == TimeSeriesImputerEstimator.FilterMode.Exclude)
+ tempDataColumns = input.Schema.Where(x => !options.FilterColumns.Contains(x.Name)).Select(x => x.Name);
+ else if (options.FilterMode == TimeSeriesImputerEstimator.FilterMode.Include)
+ tempDataColumns = input.Schema.Where(x => options.FilterColumns.Contains(x.Name)).Select(x => x.Name);
+ else
+ tempDataColumns = input.Schema.Select(x => x.Name);
+
+ // Time series and Grain columns should never be included in the data columns
+ _dataColumns = tempDataColumns.Where(x => x != _timeSeriesColumn && !_grainColumns.Contains(x)).ToArray();
+
+ // 1 is for the time series column. Make one array in the correct order of all the columns.
+ // Order is Timeseries column, All grain columns, All data columns.
+ _allColumnNames = new string[1 + _grainColumns.Length + _dataColumns.Length];
+ _allColumnNames[0] = _timeSeriesColumn;
+ Array.Copy(_grainColumns, 0, _allColumnNames, 1, _grainColumns.Length);
+ Array.Copy(_dataColumns, 0, _allColumnNames, 1 + _grainColumns.Length, _dataColumns.Length);
+
+ TransformerHandle = CreateTransformerFromEstimator(input);
+ }
+
+ // Factory method for SignatureLoadModel.
+ internal TimeSeriesImputerTransformer(IHostEnvironment host, ModelLoadContext ctx)
+ {
+ _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");
+ _host = host.Register(nameof(TimeSeriesImputerTransformer));
+
+ // *** Binary format ***
+ // name of time series column
+ // length of grain column array
+ // all column names in grain column array
+ // length of filter column array
+ // all column names in filter column array
+ // byte value of filter mode
+ // byte value of impute mode
+ // length of C++ state array
+ // C++ byte state array
+
+ _timeSeriesColumn = ctx.Reader.ReadString();
+
+ _grainColumns = new string[ctx.Reader.ReadInt32()];
+ for (int i = 0; i < _grainColumns.Length; i++)
+ _grainColumns[i] = ctx.Reader.ReadString();
+
+ _dataColumns = new string[ctx.Reader.ReadInt32()];
+ for (int i = 0; i < _dataColumns.Length; i++)
+ _dataColumns[i] = ctx.Reader.ReadString();
+
+ _imputeMode = (TimeSeriesImputerEstimator.ImputationStrategy)ctx.Reader.ReadByte();
+
+ _allColumnNames = new string[1 + _grainColumns.Length + _dataColumns.Length];
+ _allColumnNames[0] = _timeSeriesColumn;
+ Array.Copy(_grainColumns, 0, _allColumnNames, 1, _grainColumns.Length);
+ Array.Copy(_dataColumns, 0, _allColumnNames, 1 + _grainColumns.Length, _dataColumns.Length);
+
+ var nativeState = ctx.Reader.ReadByteArray();
+ TransformerHandle = CreateTransformerFromSavedData(nativeState);
+ }
+
+ private unsafe TransformerEstimatorSafeHandle CreateTransformerFromSavedData(byte[] nativeState)
+ {
+ fixed (byte* rawStatePointer = nativeState)
+ {
+ IntPtr dataSize = new IntPtr(nativeState.Count());
+ var result = CreateTransformerFromSavedDataNative(rawStatePointer, dataSize, out IntPtr transformer, out IntPtr errorHandle);
+ if (!result)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+ }
+
+ // Factory method for SignatureLoadDataTransform.
+ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
+ {
+ return (IDataTransform)(new TimeSeriesImputerTransformer(env, ctx).Transform(input));
+ }
+
+ private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDataView input)
+ {
+ IntPtr estimator;
+ IntPtr errorHandle;
+ bool success;
+
+ var allColumns = input.Schema.Where(x => _allColumnNames.Contains(x.Name)).Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns)).ToDictionary(x => x.Column.Name);
+
+ // Create buffer to hold binary data
+ var columnBuffer = new byte[4096];
+
+ // Create TypeId[] for types of grain and data columns;
+ var dataColumnTypes = new TypeId[_dataColumns.Length];
+ var grainColumnTypes = new TypeId[_grainColumns.Length];
+
+ foreach (var column in _grainColumns.Select((value, index) => new { index, value }))
+ grainColumnTypes[column.index] = allColumns[column.value].GetTypeId();
+
+ foreach (var column in _dataColumns.Select((value, index) => new { index, value }))
+ dataColumnTypes[column.index] = allColumns[column.value].GetTypeId();
+
+ fixed (bool* suppressErrors = &_suppressTypeErrors)
+ fixed (TypeId* rawDataColumnTypes = dataColumnTypes)
+ fixed (TypeId* rawGrainColumnTypes = grainColumnTypes)
+ {
+ success = CreateEstimatorNative(rawGrainColumnTypes, new IntPtr(grainColumnTypes.Length), rawDataColumnTypes, new IntPtr(dataColumnTypes.Length), _imputeMode, suppressErrors, out estimator, out errorHandle);
+ }
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorNative))
+ {
+ var fitResult = FitResult.Continue;
+ while (fitResult != FitResult.Complete)
+ {
+ using (var cursor = input.GetRowCursorForAllColumns())
+ {
+ // Initialize getters for start of loop
+ foreach (var column in allColumns.Values)
+ column.InitializeGetter(cursor);
+
+ while ((fitResult == FitResult.Continue || fitResult == FitResult.ResetAndContinue) && cursor.MoveNext())
+ {
+ BuildColumnByteArray(allColumns, ref columnBuffer, out int serializedDataLength);
+
+ fixed (byte* bufferPointer = columnBuffer)
+ {
+ var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(serializedDataLength) };
+ success = FitNative(estimatorHandler, binaryArchiveData, out fitResult, out errorHandle);
+ }
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ success = CompleteTrainingNative(estimatorHandler, out fitResult, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+ }
+
+ success = CreateTransformerFromEstimatorNative(estimatorHandler, out IntPtr transformer, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative);
+ }
+ }
+
+ private void BuildColumnByteArray(Dictionary allColumns, ref byte[] columnByteBuffer, out int serializedDataLength)
+ {
+ serializedDataLength = 0;
+ foreach (var column in _allColumnNames)
+ {
+ var bytes = allColumns[column].GetSerializedValue();
+ var byteLength = bytes.Length;
+ if (byteLength + serializedDataLength >= columnByteBuffer.Length)
+ Array.Resize(ref columnByteBuffer, columnByteBuffer.Length * 2);
+
+ Array.Copy(bytes, 0, columnByteBuffer, serializedDataLength, byteLength);
+ serializedDataLength += byteLength;
+ }
+ }
+
+ public bool IsRowToRowMapper => false;
+
+ // Schema not changed
+ public DataViewSchema GetOutputSchema(DataViewSchema inputSchema)
+ {
+ var columns = inputSchema.ToDictionary(x => x.Name);
+ var schemaBuilder = new DataViewSchema.Builder();
+ schemaBuilder.AddColumns(inputSchema.AsEnumerable());
+ schemaBuilder.AddColumn(TimeSeriesImputerEstimator.IsRowImputedColumnName, BooleanDataViewType.Instance);
+
+ return schemaBuilder.ToSchema();
+ }
+
+ public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => throw new InvalidOperationException("Not a RowToRowMapper.");
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "TimeIm T",
+ verWrittenCur: 0x00010001,
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature,
+ loaderAssemblyName: typeof(TimeSeriesImputerTransformer).Assembly.FullName);
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ _host.CheckValue(ctx, nameof(ctx));
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ // *** Binary format ***
+ // name of time series column
+ // length of grain column array
+ // all column names in grain column array
+ // length of data column array
+ // all column names in data column array
+ // byte value of impute mode
+ // length of C++ state array
+ // C++ byte state array
+
+ ctx.Writer.Write(_timeSeriesColumn);
+ ctx.Writer.Write(_grainColumns.Length);
+ foreach (var column in _grainColumns)
+ ctx.Writer.Write(column);
+ ctx.Writer.Write(_dataColumns.Length);
+ foreach (var column in _dataColumns)
+ ctx.Writer.Write(column);
+ ctx.Writer.Write((byte)_imputeMode);
+ var data = CreateTransformerSaveData();
+ ctx.Writer.Write(data.Length);
+ ctx.Writer.Write(data);
+ }
+
+ private byte[] CreateTransformerSaveData()
+ {
+ var success = CreateTransformerSaveDataNative(TransformerHandle, out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize))
+ {
+ byte[] savedData = new byte[bufferSize.ToInt32()];
+ Marshal.Copy(buffer, savedData, 0, savedData.Length);
+ return savedData;
+ }
+ }
+
+ public IDataView Transform(IDataView input) => MakeDataTransform(input);
+
+ internal TimeSeriesImputerDataView MakeDataTransform(IDataView input)
+ {
+ _host.CheckValue(input, nameof(input));
+
+ return new TimeSeriesImputerDataView(_host, input, _timeSeriesColumn, _grainColumns, _dataColumns, _allColumnNames, this);
+ }
+
+ internal TransformerEstimatorSafeHandle CloneTransformer() => CreateTransformerFromSavedData(CreateTransformerSaveData());
+
+ public void Dispose()
+ {
+ if (!TransformerHandle.IsClosed)
+ TransformerHandle.Close();
+ }
+
+ #region C++ function declarations
+ // TODO: Update entry points
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateEstimatorNative(TypeId* grainTypes, IntPtr grainTypesSize, TypeId* dataTypes, IntPtr dataTypesSize, TimeSeriesImputerEstimator.ImputationStrategy strategy, bool* suppressTypeErrors, out IntPtr estimator, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, NativeBinaryArchiveData data, out FitResult fitResult, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity]
+ private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity]
+ private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle);
+
+ #endregion
+
+ #region Typed Columns
+
+ private abstract class TypedColumn
+ {
+ internal readonly DataViewSchema.Column Column;
+ internal TypedColumn(DataViewSchema.Column column)
+ {
+ Column = column;
+ }
+
+ internal abstract void InitializeGetter(DataViewRowCursor cursor);
+ internal abstract byte[] GetSerializedValue();
+ internal abstract TypeId GetTypeId();
+
+ internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns)
+ {
+ var type = column.Type.RawType.ToString();
+ if (type == typeof(sbyte).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(short).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(int).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(long).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(byte).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(ushort).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(uint).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(ulong).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(float).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(double).ToString())
+ return new NumericTypedColumn(column, optionalColumns.Contains(column.Name));
+ else if (type == typeof(ReadOnlyMemory).ToString())
+ return new StringTypedColumn(column, optionalColumns.Contains(column.Name));
+
+ throw new InvalidOperationException($"Unsupported type {type}");
+ }
+ }
+
+ private abstract class TypedColumn : TypedColumn
+ {
+ private ValueGetter _getter;
+ private T _value;
+
+ internal TypedColumn(DataViewSchema.Column column) :
+ base(column)
+ {
+ _value = default;
+ }
+
+ internal override void InitializeGetter(DataViewRowCursor cursor)
+ {
+ _getter = cursor.GetGetter(Column);
+ }
+
+ internal T GetValue()
+ {
+ _getter(ref _value);
+ return _value;
+ }
+
+ internal override TypeId GetTypeId()
+ {
+ return typeof(T).GetNativeTypeIdFromType();
+ }
+ }
+
+ private class NumericTypedColumn : TypedColumn
+ {
+ private readonly bool _isNullable;
+
+ internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable = false) :
+ base(column)
+ {
+ _isNullable = isNullable;
+ }
+
+ internal override byte[] GetSerializedValue()
+ {
+ dynamic value = GetValue();
+ byte[] bytes;
+ if (value.GetType() == typeof(byte))
+ bytes = new byte[1] { value };
+ bytes = BitConverter.GetBytes(value);
+
+ if (_isNullable && value.GetType() != typeof(float) && value.GetType() != typeof(double))
+ return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray();
+ else
+ return bytes;
+ }
+ }
+
+ private class StringTypedColumn : TypedColumn>
+ {
+ private readonly bool _isNullable;
+
+ internal StringTypedColumn(DataViewSchema.Column column, bool isNullable = false) :
+ base(column)
+ {
+ _isNullable = isNullable;
+ }
+
+ internal override byte[] GetSerializedValue()
+ {
+ var value = GetValue().ToString();
+ var stringBytes = Encoding.UTF8.GetBytes(value);
+ if (_isNullable)
+ return new byte[] { Convert.ToByte(true) }.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray();
+ return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray();
+ }
+ }
+
+ #endregion
+ }
+
+ internal static class TimeSeriesTransformerEntrypoint
+ {
+ [TlcModule.EntryPoint(Name = "Transforms.TimeSeriesImputer",
+ Desc = TimeSeriesImputerTransformer.Summary,
+ UserName = TimeSeriesImputerTransformer.UserName,
+ ShortName = TimeSeriesImputerTransformer.ShortName)]
+ public static CommonOutputs.TransformOutput TimeSeriesImputer(IHostEnvironment env, TimeSeriesImputerEstimator.Options input)
+ {
+ var h = EntryPointUtils.CheckArgsAndCreateHost(env, TimeSeriesImputerTransformer.ShortName, input);
+ var xf = new TimeSeriesImputerEstimator(h, input).Fit(input.Data).Transform(input.Data);
+ return new CommonOutputs.TransformOutput()
+ {
+ Model = new TransformModelImpl(h, xf, input.Data),
+ OutputData = xf
+ };
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs b/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs
new file mode 100644
index 0000000000..ccb33b99e9
--- /dev/null
+++ b/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs
@@ -0,0 +1,766 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Security;
+using System.Text;
+using Microsoft.ML.Data;
+using Microsoft.ML.Featurizers;
+using Microsoft.ML.Runtime;
+using Microsoft.Win32.SafeHandles;
+using static Microsoft.ML.Featurizers.CommonExtensions;
+using static Microsoft.ML.Featurizers.TimeSeriesImputerEstimator;
+
+namespace Microsoft.ML.Transforms
+{
+
+ internal sealed class TimeSeriesImputerDataView : IDataTransform
+ {
+ #region Typed Columns
+ private TimeSeriesImputerTransformer _parent;
+ public class SharedColumnState
+ {
+ public bool SourceCanMoveNext { get; set; }
+ public int TransformedDataPosition { get; set; }
+
+ // This array is used to hold the returned row data from the native transformer. Because we create rows in this transformer, the number
+ // of rows returned from the native code is not always consistent and so this has to be an array.
+ public NativeBinaryArchiveData[] TransformedData { get; set; }
+
+ // Hold the serialized data that we are going to send to the native code for processing.
+ public byte[] ColumnBuffer { get; set; }
+ public TransformedDataSafeHandle TransformedDataHandler { get; set; }
+ }
+
+ private abstract class TypedColumn
+ {
+ private protected SharedColumnState SharedState;
+
+ internal readonly DataViewSchema.Column Column;
+ internal readonly bool IsImputed;
+ internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnState state)
+ {
+ Column = column;
+ SharedState = state;
+ IsImputed = isImputed;
+ }
+
+ internal abstract Delegate GetGetter();
+ internal abstract void InitializeGetter(DataViewRowCursor cursor, TransformerEstimatorSafeHandle transformerParent, string timeSeriesColumn,
+ string[] grainColumns, string[] dataColumns, string[] allColumnNames, Dictionary allColumns);
+
+ internal abstract TypeId GetTypeId();
+ internal abstract byte[] GetSerializedValue();
+ internal abstract unsafe int GetDataSizeInBytes(byte* data, int currentOffset);
+ internal abstract void QueueNonImputedColumnValue();
+
+ public bool MoveNext(DataViewRowCursor cursor)
+ {
+ SharedState.TransformedDataPosition++;
+
+ if (SharedState.TransformedData == null || SharedState.TransformedDataPosition >= SharedState.TransformedData.Length)
+ SharedState.SourceCanMoveNext = cursor.MoveNext();
+
+ if (!SharedState.SourceCanMoveNext)
+ if (SharedState.TransformedDataPosition >= SharedState.TransformedData.Length)
+ {
+ if (!SharedState.TransformedDataHandler.IsClosed)
+ SharedState.TransformedDataHandler.Dispose();
+ return false;
+ }
+
+ return true;
+ }
+
+ internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns, string[] allImputedColumns, SharedColumnState state)
+ {
+ var type = column.Type.RawType.ToString();
+ if (type == typeof(sbyte).ToString())
+ return new SByteTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(short).ToString())
+ return new ShortTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(int).ToString())
+ return new IntTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(long).ToString())
+ return new LongTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(byte).ToString())
+ return new ByteTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(ushort).ToString())
+ return new UShortTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(uint).ToString())
+ return new UIntTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(ulong).ToString())
+ return new ULongTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(float).ToString())
+ return new FloatTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(double).ToString())
+ return new DoubleTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(ReadOnlyMemory).ToString())
+ return new StringTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+ else if (type == typeof(bool).ToString())
+ return new BoolTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state);
+
+ throw new InvalidOperationException($"Unsupported type {type}");
+ }
+ }
+
+ private abstract class TypedColumn : TypedColumn
+ {
+ private ValueGetter _getter;
+ private ValueGetter _sourceGetter;
+ private long _position;
+ private T _cache;
+
+ // When columns are not being imputed, we need to store the column values in memory until they are used.
+ private protected Queue SourceQueue;
+
+ internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnState state) :
+ base(column, isImputed, state)
+ {
+ SourceQueue = new Queue();
+ _position = -1;
+ }
+
+ internal override Delegate GetGetter()
+ {
+ return _getter;
+ }
+
+ internal override unsafe void InitializeGetter(DataViewRowCursor cursor, TransformerEstimatorSafeHandle transformer, string timeSeriesColumn,
+ string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, Dictionary allColumns)
+ {
+ if (Column.Name != IsRowImputedColumnName)
+ _sourceGetter = cursor.GetGetter(Column);
+
+ _getter = (ref T dst) =>
+ {
+ IntPtr errorHandle = IntPtr.Zero;
+ bool success = false;
+ if (SharedState.TransformedData == null || SharedState.TransformedDataPosition >= SharedState.TransformedData.Length)
+ {
+ // Free native memory if we are about to get more
+ if (SharedState.TransformedData != null && SharedState.TransformedDataPosition >= SharedState.TransformedData.Length)
+ SharedState.TransformedDataHandler.Dispose();
+
+ var outputDataSize = IntPtr.Zero;
+ NativeBinaryArchiveData* outputData = default;
+ while(outputDataSize == IntPtr.Zero && SharedState.SourceCanMoveNext)
+ {
+ BuildColumnByteArray(allColumns, allImputedColumnNames, out int bufferLength);
+ QueueDataForNonImputedColumns(allColumns, allImputedColumnNames);
+ fixed (byte* bufferPointer = SharedState.ColumnBuffer)
+ {
+ var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(bufferLength) };
+ success = TransformDataNative(transformer, binaryArchiveData, out outputData, out outputDataSize, out errorHandle);
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+ }
+
+ if (outputDataSize == IntPtr.Zero)
+ SharedState.SourceCanMoveNext = cursor.MoveNext();
+ }
+
+ if (!SharedState.SourceCanMoveNext)
+ success = FlushDataNative(transformer, out outputData, out outputDataSize, out errorHandle);
+
+ if (!success)
+ throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle));
+
+ if (outputDataSize.ToInt32() > 0)
+ {
+ SharedState.TransformedDataHandler = new TransformedDataSafeHandle((IntPtr)outputData, outputDataSize);
+ SharedState.TransformedData = new NativeBinaryArchiveData[outputDataSize.ToInt32()];
+ for (int i = 0; i < outputDataSize.ToInt32(); i++)
+ {
+ SharedState.TransformedData[i] = *(outputData + i);
+ }
+ SharedState.TransformedDataPosition = 0;
+ }
+ }
+
+ // Base case where we didn't impute the column
+ if (!allImputedColumnNames.Contains(Column.Name))
+ {
+ var imputedData = SharedState.TransformedData[SharedState.TransformedDataPosition];
+ // If the row was imputed we want to just return the default value for the type.
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(imputedData.Data, 0))
+ {
+ dst = default;
+ }
+ else
+ {
+ // If the row wasn't imputed, get the original value for that row we stored in the queue and return that.
+ if (_position != cursor.Position)
+ {
+ _position = cursor.Position;
+ _cache = SourceQueue.Dequeue();
+ }
+ dst = _cache;
+ }
+ }
+ // If we did impute the column then parse the data from the returned byte array.
+ else
+ {
+ var imputedData = SharedState.TransformedData[SharedState.TransformedDataPosition];
+ int offset = 0;
+ foreach (var columnName in allImputedColumnNames)
+ {
+ var col = allColumns[columnName];
+ if (col.Column.Name == Column.Name)
+ {
+ dst = GetDataFromNativeBinaryArchiveData(imputedData.Data, offset);
+ return;
+ }
+
+ offset += col.GetDataSizeInBytes(imputedData.Data, offset);
+ }
+
+ // This should never be hit.
+ dst = default;
+ }
+ };
+ }
+
+ private void QueueDataForNonImputedColumns(Dictionary allColumns, string[] allImputedColumnNames)
+ {
+ foreach (var column in allColumns.Where(x => !allImputedColumnNames.Contains(x.Value.Column.Name)).Select(x => x.Value))
+ column.QueueNonImputedColumnValue();
+ }
+
+ internal override void QueueNonImputedColumnValue()
+ {
+ SourceQueue.Enqueue(GetSourceValue());
+ }
+
+ private void BuildColumnByteArray(Dictionary allColumns, string[] columns, out int bufferLength)
+ {
+ bufferLength = 0;
+ foreach(var column in columns.Where(x => x != IsRowImputedColumnName))
+ {
+ var bytes = allColumns[column].GetSerializedValue();
+ var byteLength = bytes.Length;
+ if (byteLength + bufferLength >= SharedState.ColumnBuffer.Length)
+ {
+ var buffer = SharedState.ColumnBuffer;
+ Array.Resize(ref buffer, SharedState.ColumnBuffer.Length * 2);
+ SharedState.ColumnBuffer = buffer;
+ }
+
+ Array.Copy(bytes, 0, SharedState.ColumnBuffer, bufferLength, byteLength);
+ bufferLength += byteLength;
+ }
+ }
+
+ private protected T GetSourceValue()
+ {
+ T value = default;
+ _sourceGetter(ref value);
+ return value;
+ }
+
+ internal override TypeId GetTypeId()
+ {
+ return typeof(T).GetNativeTypeIdFromType();
+ }
+
+ internal unsafe abstract T GetDataFromNativeBinaryArchiveData(byte* data, int offset);
+ }
+
+ private abstract class NumericTypedColumn : TypedColumn
+ {
+ private protected readonly bool IsNullable;
+
+ internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isImputed, state)
+ {
+ IsNullable = isNullable;
+ }
+
+ internal override byte[] GetSerializedValue()
+ {
+ dynamic value = GetSourceValue();
+ byte[] bytes;
+ if (value.GetType() == typeof(byte))
+ bytes = new byte[1] { value };
+ if (BitConverter.IsLittleEndian)
+ bytes = BitConverter.GetBytes(value);
+ else
+ bytes = BitConverter.GetBytes(value);
+
+ if (IsNullable && value.GetType() != typeof(float) && value.GetType() != typeof(double))
+ return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray();
+ else
+ return bytes;
+ }
+
+ internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset)
+ {
+ if (IsNullable && typeof(T) != typeof(float) && typeof(T) != typeof(double))
+ return Marshal.SizeOf(default(T)) + sizeof(bool);
+ else
+ return Marshal.SizeOf(default(T));
+ }
+ }
+
+ private class ByteTypedColumn : NumericTypedColumn
+ {
+ internal ByteTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override byte GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(byte*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(byte*)(data + offset);
+ }
+ }
+
+ private class SByteTypedColumn : NumericTypedColumn
+ {
+ internal SByteTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override sbyte GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(sbyte*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(sbyte*)(data + offset);
+ }
+ }
+
+ private class ShortTypedColumn : NumericTypedColumn
+ {
+ internal ShortTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override short GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(short*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(short*)(data + offset);
+ }
+ }
+
+ private class UShortTypedColumn : NumericTypedColumn
+ {
+ internal UShortTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override ushort GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(ushort*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(ushort*)(data + offset);
+ }
+ }
+
+ private class IntTypedColumn : NumericTypedColumn
+ {
+ internal IntTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override int GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(int*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(int*)(data + offset);
+ }
+ }
+
+ private class UIntTypedColumn : NumericTypedColumn
+ {
+ internal UIntTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override uint GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(uint*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(uint*)(data + offset);
+ }
+ }
+
+ private class LongTypedColumn : NumericTypedColumn
+ {
+ internal LongTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override long GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(long*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(long*)(data + offset);
+ }
+ }
+
+ private class ULongTypedColumn : NumericTypedColumn
+ {
+ internal ULongTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override ulong GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(ulong*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(ulong*)(data + offset);
+ }
+ }
+
+ private class FloatTypedColumn : NumericTypedColumn
+ {
+ internal FloatTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override float GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ var bytes = new byte[sizeof(float)];
+ Marshal.Copy((IntPtr)(data + offset), bytes, 0, sizeof(float));
+ return BitConverter.ToSingle(bytes, 0);
+ }
+ }
+
+ private class DoubleTypedColumn : NumericTypedColumn
+ {
+ internal DoubleTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override double GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ var bytes = new byte[sizeof(double)];
+ Marshal.Copy((IntPtr)(data + offset), bytes, 0, sizeof(double));
+ return BitConverter.ToDouble(bytes, 0);
+ }
+ }
+
+ private class BoolTypedColumn : NumericTypedColumn
+ {
+ internal BoolTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isNullable, isImputed, state)
+ {
+ }
+
+ internal unsafe override bool GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (IsNullable)
+ {
+ if (GetBoolFromNativeBinaryArchiveData(data, offset))
+ return *(bool*)(data + offset + sizeof(bool));
+ else
+ return default;
+ }
+ else
+ return *(bool*)(data + offset);
+ }
+
+ internal static unsafe bool GetBoolFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ return *(bool*)(data + offset);
+ }
+
+ internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset)
+ {
+ return sizeof(bool);
+ }
+ }
+
+ private class StringTypedColumn : TypedColumn>
+ {
+ private readonly bool _isNullable;
+ internal StringTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) :
+ base(column, isImputed, state)
+ {
+ _isNullable = isNullable;
+ }
+
+ internal override byte[] GetSerializedValue()
+ {
+ var value = GetSourceValue().ToString();
+ var stringBytes = Encoding.UTF8.GetBytes(value);
+ if (_isNullable)
+ return new byte[] { Convert.ToByte(true)}.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray();
+ return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray();
+ }
+
+ internal unsafe override ReadOnlyMemory GetDataFromNativeBinaryArchiveData(byte* data, int offset)
+ {
+ if (_isNullable)
+ {
+ if (!BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) // If value not present return empty string
+ return new ReadOnlyMemory("".ToCharArray());
+
+ var size = *(uint*)(data + offset + 1); // Add 1 for the byte bool flag
+
+ var bytes = new byte[size];
+ Marshal.Copy((IntPtr)(data + offset + sizeof(uint) + 1), bytes, 0, (int)size);
+ return Encoding.UTF8.GetString(bytes).AsMemory();
+ }
+ else
+ {
+ var size = *(uint*)(data + offset);
+
+ var bytes = new byte[size];
+ Marshal.Copy((IntPtr)(data + offset + sizeof(uint)), bytes, 0, (int)size);
+ return Encoding.UTF8.GetString(bytes).AsMemory();
+ }
+ }
+
+ internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset)
+ {
+ var size = *(uint*)(data + currentOffset);
+ if (_isNullable)
+ return 1 + (int)size + sizeof(uint); // + 1 for the byte bool flag
+
+ return (int)size + sizeof(uint);
+ }
+ }
+
+ #endregion
+
+ #region Native Exports
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern unsafe bool TransformDataNative(TransformerEstimatorSafeHandle transformer, /*in*/ NativeBinaryArchiveData data, out NativeBinaryArchiveData* outputData, out IntPtr outputDataSize, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Transform"), SuppressUnmanagedCodeSecurity]
+ private static extern unsafe bool FlushDataNative(TransformerEstimatorSafeHandle transformer, out NativeBinaryArchiveData* outputData, out IntPtr outputDataSize, out IntPtr errorHandle);
+
+ [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyTransformedData"), SuppressUnmanagedCodeSecurity]
+ private static extern unsafe bool DestroyTransformedDataNative(IntPtr data, IntPtr dataSize, out IntPtr errorHandle);
+
+ #endregion
+
+ #region Native SafeHandles
+
+ internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid
+ {
+ private IntPtr _size;
+ public TransformedDataSafeHandle(IntPtr handle, IntPtr size) : base(true)
+ {
+ SetHandle(handle);
+ _size = size;
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ // Not sure what to do with error stuff here. There shoudln't ever be one though.
+ return DestroyTransformedDataNative(handle, _size, out IntPtr errorHandle);
+ }
+ }
+
+ #endregion
+
+ private readonly IHostEnvironment _host;
+ private readonly IDataView _source;
+ private readonly string _timeSeriesColumn;
+ private readonly string[] _dataColumns;
+ private readonly string[] _grainColumns;
+ private readonly string[] _allImputedColumnNames;
+ private readonly DataViewSchema _schema;
+
+ internal TimeSeriesImputerDataView(IHostEnvironment env, IDataView input, string timeSeriesColumn, string[] grainColumns, string[] dataColumns, string[] allColumnNames, TimeSeriesImputerTransformer parent)
+ {
+ _host = env;
+ _source = input;
+
+ _timeSeriesColumn = timeSeriesColumn;
+ _grainColumns = grainColumns;
+ _dataColumns = dataColumns;
+ _allImputedColumnNames = new string[] { IsRowImputedColumnName }.Concat(allColumnNames).ToArray();
+ _parent = parent;
+ // Build new schema.
+ var schemaColumns = _source.Schema.ToDictionary(x => x.Name);
+ var schemaBuilder = new DataViewSchema.Builder();
+ schemaBuilder.AddColumns(_source.Schema.AsEnumerable());
+ schemaBuilder.AddColumn(IsRowImputedColumnName, BooleanDataViewType.Instance);
+
+ _schema = schemaBuilder.ToSchema();
+ }
+
+ public bool CanShuffle => false;
+
+ public DataViewSchema Schema => _schema;
+
+ public IDataView Source => _source;
+
+ public DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null)
+ {
+ _host.AssertValueOrNull(rand);
+
+ var input = _source.GetRowCursorForAllColumns();
+ return new Cursor(_host, input, _parent.CloneTransformer(), _timeSeriesColumn, _grainColumns, _dataColumns, _allImputedColumnNames, _schema);
+ }
+
+ // Can't use parallel cursors so this defaults to calling non-parallel version
+ public DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) =>
+ new DataViewRowCursor[] { GetRowCursor(columnsNeeded, rand) };
+
+ // Since we may add rows we don't know the row count
+ public long? GetRowCount() => null;
+
+ public void Save(ModelSaveContext ctx)
+ {
+ _parent.Save(ctx);
+ }
+
+ private sealed class Cursor : DataViewRowCursor
+ {
+ private readonly IChannelProvider _ch;
+ private DataViewRowCursor _input;
+ private long _position;
+ private bool _isGood;
+ private readonly Dictionary _allColumns;
+ private readonly DataViewSchema _schema;
+ private readonly TransformerEstimatorSafeHandle _transformer;
+
+ public Cursor(IChannelProvider provider, DataViewRowCursor input, TransformerEstimatorSafeHandle transformer, string timeSeriesColumn,
+ string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, DataViewSchema schema)
+ {
+ _ch = provider;
+ _ch.CheckValue(input, nameof(input));
+
+ _input = input;
+ var length = input.Schema.Count;
+ _position = -1;
+ _schema = schema;
+ _transformer = transformer;
+
+ var sharedState = new SharedColumnState()
+ {
+ SourceCanMoveNext = true,
+ ColumnBuffer = new byte[4096]
+ };
+
+ _allColumns = _schema.Select(x => TypedColumn.CreateTypedColumn(x, dataColumns, allImputedColumnNames, sharedState)).ToDictionary(x => x.Column.Name); ;
+ _allColumns[IsRowImputedColumnName] = new BoolTypedColumn(_schema[IsRowImputedColumnName], false, true, sharedState);
+
+ foreach (var column in _allColumns.Values)
+ {
+ column.InitializeGetter(_input, transformer, timeSeriesColumn, grainColumns, dataColumns, allImputedColumnNames, _allColumns);
+ }
+ }
+
+ public sealed override ValueGetter GetIdGetter()
+ {
+ return
+ (ref DataViewRowId val) =>
+ {
+ _ch.Check(_isGood, RowCursorUtils.FetchValueStateError);
+ val = new DataViewRowId((ulong)Position, 0);
+ };
+ }
+
+ public sealed override DataViewSchema Schema => _schema;
+
+ ///
+ /// Since rows will be generated all columns are active
+ ///
+ public override bool IsColumnActive(DataViewSchema.Column column) => true;
+
+ protected override void Dispose(bool disposing)
+ {
+ if (!_transformer.IsClosed)
+ _transformer.Close();
+ }
+
+ ///
+ /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row.
+ /// This throws if the column is not active in this row, or if the type
+ /// differs from this column's type.
+ ///
+ /// is the column's content type.
+ /// is the output column whose getter should be returned.
+ public override ValueGetter GetGetter(DataViewSchema.Column column)
+ {
+ _ch.Check(IsColumnActive(column));
+
+ var fn = _allColumns[column.Name].GetGetter() as ValueGetter;
+ if (fn == null)
+ throw _ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue));
+ return fn;
+ }
+
+ public override bool MoveNext()
+ {
+ _position++;
+ _isGood = _allColumns[IsRowImputedColumnName].MoveNext(_input);
+ return _isGood;
+ }
+
+ public sealed override long Position => _position;
+
+ public sealed override long Batch => _input.Batch;
+ }
+ }
+}
diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
index 9088dfc1eb..bcd7dc5c44 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
+++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv
@@ -133,6 +133,7 @@ Transforms.SentimentAnalyzer Uses a pretrained sentiment model to score input st
Transforms.TensorFlowScorer Transforms the data using the TensorFlow model. Microsoft.ML.Transforms.TensorFlowTransformer TensorFlowScorer Microsoft.ML.Transforms.TensorFlowEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.TextFeaturizer A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are normalized counts of (word and/or character) n-grams in a given tokenized text. Microsoft.ML.Transforms.Text.TextAnalytics TextTransform Microsoft.ML.Transforms.Text.TextFeaturizingEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.TextToKeyConverter Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Categorical TextToKey Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
+Transforms.TimeSeriesImputer Fills in missing row and values Microsoft.ML.Featurizers.TimeSeriesTransformerEntrypoint TimeSeriesImputer Microsoft.ML.Featurizers.TimeSeriesImputerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.ToString Turns the given column into a column of its string representation Microsoft.ML.Featurizers.ToStringTransformerEntrypoint ToString Microsoft.ML.Featurizers.ToStringTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
Transforms.TrainTestDatasetSplitter Split the dataset into train and test sets Microsoft.ML.EntryPoints.TrainTestSplit Split Microsoft.ML.EntryPoints.TrainTestSplit+Input Microsoft.ML.EntryPoints.TrainTestSplit+Output
Transforms.TreeLeafFeaturizer Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector to three outputs: 1. A vector containing the individual tree outputs of the tree ensemble. 2. A vector indicating the leaves that the feature vector falls on in the tree ensemble. 3. A vector indicating the paths that the feature vector falls on in the tree ensemble. If a both a model file and a trainer are specified - will use the model file. If neither are specified, will train a default FastTree model. This can handle key labels by training a regression model towards their optionally permuted indices. Microsoft.ML.Data.TreeFeaturize Featurizer Microsoft.ML.Data.TreeEnsembleFeaturizerTransform+ArgumentsForEntryPoint Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 02313887d2..b2587d6cd6 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -23252,6 +23252,130 @@
"ITransformOutput"
]
},
+ {
+ "Name": "Transforms.TimeSeriesImputer",
+ "Desc": "Fills in missing row and values",
+ "FriendlyName": "TimeSeriesImputer",
+ "ShortName": "tsi",
+ "Inputs": [
+ {
+ "Name": "TimeSeriesColumn",
+ "Type": "String",
+ "Desc": "Column representing the time",
+ "Aliases": [
+ "time"
+ ],
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "Data",
+ "Type": "DataView",
+ "Desc": "Input dataset",
+ "Required": true,
+ "SortOrder": 1.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "GrainColumns",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "List of grain columns",
+ "Aliases": [
+ "grains"
+ ],
+ "Required": true,
+ "SortOrder": 2.0,
+ "IsNullable": false
+ },
+ {
+ "Name": "FilterColumns",
+ "Type": {
+ "Kind": "Array",
+ "ItemType": "String"
+ },
+ "Desc": "Columns to filter",
+ "Aliases": [
+ "filters"
+ ],
+ "Required": false,
+ "SortOrder": 2.0,
+ "IsNullable": false,
+ "Default": null
+ },
+ {
+ "Name": "FilterMode",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [
+ "NoFilter",
+ "Include",
+ "Exclude"
+ ]
+ },
+ "Desc": "Filter mode. Either include or exclude",
+ "Aliases": [
+ "fmode"
+ ],
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": "Exclude"
+ },
+ {
+ "Name": "ImputeMode",
+ "Type": {
+ "Kind": "Enum",
+ "Values": [
+ "ForwardFill",
+ "BackFill",
+ "Median"
+ ]
+ },
+ "Desc": "Mode for imputing, defaults to ForwardFill if not provided",
+ "Aliases": [
+ "mode"
+ ],
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": "ForwardFill"
+ },
+ {
+ "Name": "SupressTypeErrors",
+ "Type": "Bool",
+ "Desc": "Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error.",
+ "Aliases": [
+ "error"
+ ],
+ "Required": false,
+ "SortOrder": 3.0,
+ "IsNullable": false,
+ "Default": false
+ }
+ ],
+ "Outputs": [
+ {
+ "Name": "OutputData",
+ "Type": "DataView",
+ "Desc": "Transformed dataset"
+ },
+ {
+ "Name": "Model",
+ "Type": "TransformModel",
+ "Desc": "Transform model"
+ }
+ ],
+ "InputKind": [
+ "ITransformInput"
+ ],
+ "OutputKind": [
+ "ITransformOutput"
+ ]
+ },
{
"Name": "Transforms.ToString",
"Desc": "Turns the given column into a column of its string representation",
diff --git a/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs b/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs
new file mode 100644
index 0000000000..bed0a67e5f
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs
@@ -0,0 +1,471 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.RunTests;
+using Microsoft.ML.Featurizers;
+using System;
+using Xunit;
+using Xunit.Abstractions;
+using System.Drawing.Printing;
+using System.Linq;
+using Microsoft.ML.TestFramework.Attributes;
+
+namespace Microsoft.ML.Tests.Transformers
+{
+ public class TimeSeriesImputerTests : TestDataPipeBase
+ {
+ public TimeSeriesImputerTests(ITestOutputHelper output) : base(output)
+ {
+ }
+
+ private class TimeSeriesTwoGrainInput
+ {
+ public long date;
+ public string grainA;
+ public string grainB;
+ public float data;
+ }
+
+ private class TimeSeriesOneGrainInput
+ {
+ public long date;
+ public string grainA;
+ public int dataA;
+ public float dataB;
+ public uint dataC;
+ }
+
+ private class TimeSeriesOneGrainFloatInput
+ {
+ public long date;
+ public string grainA;
+ public float dataA;
+ }
+
+ private class TimeSeriesOneGrainStringInput
+ {
+ public long date;
+ public string grainA;
+ public string dataA;
+ }
+
+ [NotCentOS7Fact]
+ public void NotImputeOneColumn()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] {
+ new TimeSeriesOneGrainInput() { date = 25, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
+ new TimeSeriesOneGrainInput() { date = 26, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
+ new TimeSeriesOneGrainInput() { date = 28, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }
+ };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, new string[] { "dataB"});
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ // We always output the same column as the input, plus adding a column saying whether the row was imputed or not.
+ Assert.Equal(6, schema.Count);
+ Assert.Equal("date", schema[0].Name);
+ Assert.Equal("grainA", schema[1].Name);
+ Assert.Equal("dataA", schema[2].Name);
+ Assert.Equal("dataB", schema[3].Name);
+ Assert.Equal("dataC", schema[4].Name);
+ Assert.Equal("IsRowImputed", schema[5].Name);
+
+ // We are imputing 1 row, so total rows should be 4.
+ var preview = output.Preview();
+ Assert.Equal(4, preview.RowView.Length);
+
+ // Row that was imputed should have date of 27
+ Assert.Equal(27L, preview.ColumnView[0].Values[2]);
+
+ // Since we are not imputing data on one column and a row is getting imputed, its value should be default(T)
+ Assert.Equal(default(float), preview.ColumnView[3].Values[2]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [NotCentOS7Fact]
+ public void ImputeOnlyOneColumn()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] {
+ new TimeSeriesOneGrainInput() { date = 25, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
+ new TimeSeriesOneGrainInput() { date = 26, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 },
+ new TimeSeriesOneGrainInput() { date = 28, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }
+ };
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, new string[] { "dataB"}, TimeSeriesImputerEstimator.FilterMode.Include);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var schema = output.Schema;
+
+ // We always output the same column as the input, plus adding a column saying whether the row was imputed or not.
+ Assert.Equal(6, schema.Count);
+ Assert.Equal("date", schema[0].Name);
+ Assert.Equal("grainA", schema[1].Name);
+ Assert.Equal("dataA", schema[2].Name);
+ Assert.Equal("dataB", schema[3].Name);
+ Assert.Equal("dataC", schema[4].Name);
+ Assert.Equal("IsRowImputed", schema[5].Name);
+
+ // We are imputing 1 row, so total rows should be 4.
+ var preview = output.Preview();
+ Assert.Equal(4, preview.RowView.Length);
+
+ // Row that was imputed should have date of 27
+ Assert.Equal(27L, preview.ColumnView[0].Values[2]);
+
+ // Since we are not imputing data on two columns and a row is getting imputed, its value should be default(T)
+ Assert.Equal(default(int), preview.ColumnView[2].Values[2]);
+ Assert.Equal(default(uint), preview.ColumnView[4].Values[2]);
+
+ // Column that was imputed should have value of 2.0f
+ Assert.Equal(2.0f, preview.ColumnView[3].Values[2]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [NotCentOS7Fact]
+ public void Forwardfill()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = 2.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = float.NaN }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" });
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var prev = output.Preview();
+
+ // Should have 3 original columns + 1 more for IsRowImputed
+ Assert.Equal(4, output.Schema.Count);
+
+ // Imputing rows with dates 2,4,6, so should have length of 8
+ Assert.Equal(8, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(2L, prev.ColumnView[0].Values[2]);
+ Assert.Equal(4L, prev.ColumnView[0].Values[4]);
+ Assert.Equal(6L, prev.ColumnView[0].Values[6]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[2].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[4].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[6].ToString());
+
+ // Make sure forward fill is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[1]);
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[2]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[4]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[5]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[6]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[7]);
+
+ // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest
+ Assert.Equal(false, prev.ColumnView[3].Values[0]);
+ Assert.Equal(false, prev.ColumnView[3].Values[1]);
+ Assert.Equal(true, prev.ColumnView[3].Values[2]);
+ Assert.Equal(false, prev.ColumnView[3].Values[3]);
+ Assert.Equal(true, prev.ColumnView[3].Values[4]);
+ Assert.Equal(false, prev.ColumnView[3].Values[5]);
+ Assert.Equal(true, prev.ColumnView[3].Values[6]);
+ Assert.Equal(false, prev.ColumnView[3].Values[7]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [NotCentOS7Fact]
+ public void EntryPoint()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new { ts = 1L, grain = 1970, c3 = 10, c4 = 19},
+ new { ts = 2L, grain = 1970, c3 = 13, c4 = 12},
+ new { ts = 3L, grain = 1970, c3 = 15, c4 = 16},
+ new { ts = 5L, grain = 1970, c3 = 20, c4 = 19}
+ };
+
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+ TimeSeriesImputerEstimator.Options options = new TimeSeriesImputerEstimator.Options() {
+ TimeSeriesColumn = "ts",
+ GrainColumns = new[] { "grain" },
+ FilterColumns = new[] { "c3", "c4" },
+ FilterMode = TimeSeriesImputerEstimator.FilterMode.Include,
+ ImputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill,
+ Data = data
+ };
+
+ var entryOutput = TimeSeriesTransformerEntrypoint.TimeSeriesImputer(mlContext.Transforms.GetEnvironment(), options);
+ // Build the pipeline, fit, and transform it.
+ var output = entryOutput.OutputData;
+
+ // Get the data from the first row and make sure it matches expected
+ var prev = output.Preview();
+
+ // Should have 4 original columns + 1 more for IsRowImputed
+ Assert.Equal(5, output.Schema.Count);
+
+ // Imputing rows with date 4 so should have length of 5
+ Assert.Equal(5, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(4L, prev.ColumnView[0].Values[3]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal(1970, prev.ColumnView[1].Values[2]);
+
+ // Make sure forward fill is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(15, prev.ColumnView[2].Values[3]);
+ Assert.Equal(16, prev.ColumnView[3].Values[3]);
+
+ // Make sure IsRowImputed is true for row 4, false for the rest
+ Assert.Equal(false, prev.ColumnView[4].Values[0]);
+ Assert.Equal(false, prev.ColumnView[4].Values[1]);
+ Assert.Equal(false, prev.ColumnView[4].Values[2]);
+ Assert.Equal(true, prev.ColumnView[4].Values[3]);
+ Assert.Equal(false, prev.ColumnView[4].Values[4]);
+
+ Done();
+ }
+
+ [NotCentOS7Fact]
+ public void Median()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = 2.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = float.NaN }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: true);
+ var model = pipeline.Fit(data);
+
+ var output = model.Transform(data);
+
+ var prev = output.Preview();
+
+ // Should have 3 original columns + 1 more for IsRowImputed
+ Assert.Equal(4, output.Schema.Count);
+
+ // Imputing rows with dates 2,4,6, so should have length of 8
+ Assert.Equal(8, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(2L, prev.ColumnView[0].Values[2]);
+ Assert.Equal(4L, prev.ColumnView[0].Values[4]);
+ Assert.Equal(6L, prev.ColumnView[0].Values[6]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[2].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[4].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[6].ToString());
+
+ // Make sure Median is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[1]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[2]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[4]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[5]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[6]);
+ Assert.Equal(3.5f, prev.ColumnView[2].Values[7]);
+
+ // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest
+ Assert.Equal(false, prev.ColumnView[3].Values[0]);
+ Assert.Equal(false, prev.ColumnView[3].Values[1]);
+ Assert.Equal(true, prev.ColumnView[3].Values[2]);
+ Assert.Equal(false, prev.ColumnView[3].Values[3]);
+ Assert.Equal(true, prev.ColumnView[3].Values[4]);
+ Assert.Equal(false, prev.ColumnView[3].Values[5]);
+ Assert.Equal(true, prev.ColumnView[3].Values[6]);
+ Assert.Equal(false, prev.ColumnView[3].Values[7]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [NotCentOS7Fact]
+ public void Backfill()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f },
+ new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN },
+ new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = 2.0f }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var prev = output.Preview();
+
+ // Should have 3 original columns + 1 more for IsRowImputed
+ Assert.Equal(4, output.Schema.Count);
+
+ // Imputing rows with dates 2,4,6, so should have length of 8
+ Assert.Equal(8, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(2L, prev.ColumnView[0].Values[2]);
+ Assert.Equal(4L, prev.ColumnView[0].Values[4]);
+ Assert.Equal(6L, prev.ColumnView[0].Values[6]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[2].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[4].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[6].ToString());
+
+ // Make sure backfill is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[0]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[1]);
+ Assert.Equal(5.0f, prev.ColumnView[2].Values[2]);
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[4]);
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[5]);
+ Assert.Equal(2.0f, prev.ColumnView[2].Values[6]);
+
+ // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest
+ Assert.Equal(false, prev.ColumnView[3].Values[0]);
+ Assert.Equal(false, prev.ColumnView[3].Values[1]);
+ Assert.Equal(true, prev.ColumnView[3].Values[2]);
+ Assert.Equal(false, prev.ColumnView[3].Values[3]);
+ Assert.Equal(true, prev.ColumnView[3].Values[4]);
+ Assert.Equal(false, prev.ColumnView[3].Values[5]);
+ Assert.Equal(true, prev.ColumnView[3].Values[6]);
+ Assert.Equal(false, prev.ColumnView[3].Values[7]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [NotCentOS7Fact]
+ public void BackfillTwoGrain()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesTwoGrainInput() { date = 0, grainA = "A", grainB = "A", data = float.NaN},
+ new TimeSeriesTwoGrainInput() { date = 1, grainA = "A", grainB = "A", data = 0.0f},
+ new TimeSeriesTwoGrainInput() { date = 3, grainA = "A", grainB = "B", data = 1.0f},
+ new TimeSeriesTwoGrainInput() { date = 5, grainA = "A", grainB = "B", data = float.NaN},
+ new TimeSeriesTwoGrainInput() { date = 7, grainA = "A", grainB = "B", data = 2.0f }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // Build the pipeline, fit, and transform it.
+ var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA", "grainB" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill);
+ var model = pipeline.Fit(data);
+ var output = model.Transform(data);
+ var prev = output.Preview();
+
+ // Should have 4 original columns + 1 more for IsRowImputed
+ Assert.Equal(5, output.Schema.Count);
+
+ // Imputing rows with dates 4,6, so should have length of 8
+ Assert.Equal(7, prev.RowView.Length);
+
+ // Check that imputed rows have the correct dates
+ Assert.Equal(4L, prev.ColumnView[0].Values[3]);
+ Assert.Equal(6L, prev.ColumnView[0].Values[5]);
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[3].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[5].ToString());
+ Assert.Equal("B", prev.ColumnView[2].Values[3].ToString());
+ Assert.Equal("B", prev.ColumnView[2].Values[5].ToString());
+
+ // Make sure backfill is working as expected. All NA's should be replaced, and imputed rows should have correct values too
+ Assert.Equal(0.0f, prev.ColumnView[3].Values[0]);
+ Assert.Equal(2.0f, prev.ColumnView[3].Values[3]);
+ Assert.Equal(2.0f, prev.ColumnView[3].Values[4]);
+ Assert.Equal(2.0f, prev.ColumnView[3].Values[5]);
+
+ // Make sure IsRowImputed is true for row 4,6 false for the rest
+ Assert.Equal(false, prev.ColumnView[4].Values[0]);
+ Assert.Equal(false, prev.ColumnView[4].Values[1]);
+ Assert.Equal(false, prev.ColumnView[4].Values[2]);
+ Assert.Equal(true, prev.ColumnView[4].Values[3]);
+ Assert.Equal(false, prev.ColumnView[4].Values[4]);
+ Assert.Equal(true, prev.ColumnView[4].Values[5]);
+ Assert.Equal(false, prev.ColumnView[4].Values[6]);
+
+ TestEstimatorCore(pipeline, data);
+ Done();
+ }
+
+ [NotCentOS7Fact]
+ public void InvalidTypeForImputationStrategy()
+ {
+ MLContext mlContext = new MLContext(1);
+ var dataList = new[] { new TimeSeriesOneGrainStringInput(){ date = 0L, grainA = "A", dataA = "zero" },
+ new TimeSeriesOneGrainStringInput(){ date = 1L, grainA = "A", dataA = "one" },
+ new TimeSeriesOneGrainStringInput(){ date = 3L, grainA = "A", dataA = "three" },
+ new TimeSeriesOneGrainStringInput(){ date = 5L, grainA = "A", dataA = "five" },
+ new TimeSeriesOneGrainStringInput(){ date = 7L, grainA = "A", dataA = "seven" }};
+ var data = mlContext.Data.LoadFromEnumerable(dataList);
+
+ // When suppressTypeErrors is set to false this will throw an error.
+ var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: false);
+ var ex = Assert.Throws(() => pipeline.Fit(data));
+ Assert.Equal("Only Numeric type columns are supported for ImputationStrategy median. (use suppressError flag to skip imputing non-numeric types)", ex.Message);
+
+ // When suppressTypeErrors is set to true then the default value will be used.
+ pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: true);
+ var model = pipeline.Fit(data);
+
+ var output = model.Transform(data);
+ var prev = output.Preview();
+
+ // Should have 3 original columns + 1 more for IsRowImputed
+ Assert.Equal(4, output.Schema.Count);
+
+ // Imputing rows with dates 2,4,6, so should have length of 8
+ Assert.Equal(8, prev.RowView.Length);
+
+ // Check that imputed rows have the default value
+ Assert.Equal("", prev.ColumnView[2].Values[2].ToString());
+ Assert.Equal("", prev.ColumnView[2].Values[4].ToString());
+ Assert.Equal("", prev.ColumnView[2].Values[6].ToString());
+
+ // Make sure grain was propagated correctly
+ Assert.Equal("A", prev.ColumnView[1].Values[2].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[4].ToString());
+ Assert.Equal("A", prev.ColumnView[1].Values[6].ToString());
+
+ // Make sure original values stayed the same
+ Assert.Equal("zero", prev.ColumnView[2].Values[0].ToString());
+ Assert.Equal("one", prev.ColumnView[2].Values[1].ToString());
+ Assert.Equal("three", prev.ColumnView[2].Values[3].ToString());
+ Assert.Equal("five", prev.ColumnView[2].Values[5].ToString());
+ Assert.Equal("seven", prev.ColumnView[2].Values[7].ToString());
+
+ // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest
+ Assert.Equal(false, prev.ColumnView[3].Values[0]);
+ Assert.Equal(false, prev.ColumnView[3].Values[1]);
+ Assert.Equal(true, prev.ColumnView[3].Values[2]);
+ Assert.Equal(false, prev.ColumnView[3].Values[3]);
+ Assert.Equal(true, prev.ColumnView[3].Values[4]);
+ Assert.Equal(false, prev.ColumnView[3].Values[5]);
+ Assert.Equal(true, prev.ColumnView[3].Values[6]);
+ Assert.Equal(false, prev.ColumnView[3].Values[7]);
+
+ TestEstimatorCore(pipeline, data);
+
+ Done();
+ }
+ }
+}