diff --git a/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs b/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs new file mode 100644 index 0000000000..537f0611b9 --- /dev/null +++ b/src/Microsoft.ML.Featurizers/TimeSeriesImputer.cs @@ -0,0 +1,660 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Featurizers; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; +using static Microsoft.ML.Featurizers.CommonExtensions; + +[assembly: LoadableClass(typeof(TimeSeriesImputerTransformer), null, typeof(SignatureLoadModel), + TimeSeriesImputerTransformer.UserName, TimeSeriesImputerTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IDataTransform), typeof(TimeSeriesImputerTransformer), null, typeof(SignatureLoadDataTransform), + TimeSeriesImputerTransformer.UserName, TimeSeriesImputerTransformer.LoaderSignature)] + +[assembly: EntryPointModule(typeof(TimeSeriesTransformerEntrypoint))] + +namespace Microsoft.ML.Featurizers +{ + public static class TimeSeriesImputerExtensionClass + { + /// + /// Create a , Imputes missing rows and column data per grain. Operates on all columns in the IDataView. + /// Currently only float/double/string columns are supported for imputation strategies, and an empty string is considered "missing" for the + /// purpose of this estimator. Other column types will have the default value placed if a row is imputed. + /// + /// The transform catalog. + /// Column representing the time series. Should be of type + /// List of columns to use as grains + /// Mode of imputation for missing values in column. If not passed defaults to forward fill + public static TimeSeriesImputerEstimator ReplaceMissingTimeSeriesValues(this TransformsCatalog catalog, string timeSeriesColumn, string[] grainColumns, + TimeSeriesImputerEstimator.ImputationStrategy imputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill) + => new TimeSeriesImputerEstimator(CatalogUtils.GetEnvironment(catalog), timeSeriesColumn, grainColumns, null, TimeSeriesImputerEstimator.FilterMode.NoFilter, imputeMode, true); + + /// + /// Create a , Imputes missing rows and column data per grain. Applies the imputation strategy on + /// a filtered list of columns in the IDataView. Columns that are are excluded will have the default value for that data type used when a row + /// is imputed. Currently only float/double/string columns are supported for imputation strategies, and an empty string is considered "missing" for the + /// purpose of this estimator. + /// + /// The transform catalog. + /// Column representing the time series. Should be of type + /// List of columns to use as grains + /// List of columns to filter. If is than columns in the list will be ignored. + /// If is than values in the list are the only columns imputed. + /// Whether the list should include or exclude those columns. + /// Mode of imputation for missing values in column. If not passed defaults to forward fill + /// Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column and use the default value. If false, will stop and throw an error. + public static TimeSeriesImputerEstimator ReplaceMissingTimeSeriesValues(this TransformsCatalog catalog, string timeSeriesColumn, + string[] grainColumns, string[] filterColumns, TimeSeriesImputerEstimator.FilterMode filterMode = TimeSeriesImputerEstimator.FilterMode.Exclude, + TimeSeriesImputerEstimator.ImputationStrategy imputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill, + bool suppressTypeErrors = false) + => new TimeSeriesImputerEstimator(CatalogUtils.GetEnvironment(catalog), timeSeriesColumn, grainColumns, filterColumns, filterMode, imputeMode, suppressTypeErrors); + } + + /// + /// Imputes missing rows and column data per grain, based on the dates in the date column. This operation needs to happen to every column in the IDataView, + /// If you "filter" a column using the filterColumns and filterMode parameters, if a row is imputed the default value for that type will be used. + /// Currently only float/double/string columns are supported for imputation strategies, and an empty string is considered "missing" for the + /// purpose of this estimator. A new column is added to the schema after this operation is run. The column is called "IsRowImputed" and is a + /// boolean value representing if the row was created as a result of this operation or not. + /// + /// NOTE: It is not recommended to chain this multiple times. If a column is filtered, the default value is placed when a row is imputed, and the + /// default value is not null. Thus any other TimeSeriesImputers will not be able to replace those values anymore causing essentially a very + /// computationally expensive NO-OP. + /// + /// + /// + /// is not a trivial estimator and needs training. + /// + /// + /// ]]> + /// + /// + /// + /// + public sealed class TimeSeriesImputerEstimator : IEstimator + { + private Options _options; + internal const string IsRowImputedColumnName = "IsRowImputed"; + + private readonly IHost _host; + private static readonly List _currentSupportedTypes = new List { typeof(sbyte), typeof(byte), typeof(short), typeof(ushort), typeof(int), typeof(uint), + typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(string), typeof(ReadOnlyMemory)}; + + #region Options + internal sealed class Options : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "Column representing the time", Name = "TimeSeriesColumn", ShortName = "time", SortOrder = 1)] + public string TimeSeriesColumn; + + [Argument((ArgumentType.MultipleUnique | ArgumentType.Required), HelpText = "List of grain columns", Name = "GrainColumns", ShortName = "grains", SortOrder = 2)] + public string[] GrainColumns; + + // This transformer adds columns + [Argument(ArgumentType.MultipleUnique, HelpText = "Columns to filter", Name = "FilterColumns", ShortName = "filters", SortOrder = 2)] + public string[] FilterColumns; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Filter mode. Either include or exclude", Name = "FilterMode", ShortName = "fmode", SortOrder = 3)] + public FilterMode FilterMode = FilterMode.Exclude; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Mode for imputing, defaults to ForwardFill if not provided", Name = "ImputeMode", ShortName = "mode", SortOrder = 3)] + public ImputationStrategy ImputeMode = ImputationStrategy.ForwardFill; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error.", Name = "SupressTypeErrors", ShortName = "error", SortOrder = 3)] + public bool SupressTypeErrors = false; + } + + #endregion + + #region Class Enums + + public enum ImputationStrategy : byte + { + ForwardFill = 1, + BackFill = 2, + Median = 3, + // Interpolate = 4, interpolate not currently supported in the native code. + }; + + public enum FilterMode : byte + { + NoFilter = 1, + Include = 2, + Exclude = 3 + }; + + #endregion + + internal TimeSeriesImputerEstimator(IHostEnvironment env, string timeSeriesColumn, string[] grainColumns, string[] filterColumns, FilterMode filterMode, ImputationStrategy imputeMode, bool supressTypeErrors) + { + Contracts.CheckValue(env, nameof(env)); + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + _host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator"); + _host.CheckValue(timeSeriesColumn, nameof(timeSeriesColumn), "TimePoint column should not be null."); + _host.CheckNonEmpty(grainColumns, nameof(grainColumns), "Need at least one grain column."); + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + if (filterMode == FilterMode.Include) + _host.CheckNonEmpty(filterColumns, nameof(filterColumns), "Need at least 1 filter column if a FilterMode is specified"); + + _options = new Options + { + TimeSeriesColumn = timeSeriesColumn, + GrainColumns = grainColumns, + FilterColumns = filterColumns == null ? new string[] { } : filterColumns, + FilterMode = filterMode, + ImputeMode = imputeMode, + SupressTypeErrors = supressTypeErrors + }; + } + + internal TimeSeriesImputerEstimator(IHostEnvironment env, Options options) + { + Contracts.CheckValue(env, nameof(env)); + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + _host = Contracts.CheckRef(env, nameof(env)).Register("TimeSeriesImputerEstimator"); + _host.CheckValue(options.TimeSeriesColumn, nameof(options.TimeSeriesColumn), "TimePoint column should not be null."); + _host.CheckValue(options.GrainColumns, nameof(options.GrainColumns), "Grain columns should not be null."); + _host.CheckNonEmpty(options.GrainColumns, nameof(options.GrainColumns), "Need at least one grain column."); + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + if (options.FilterMode != FilterMode.NoFilter) + _host.CheckNonEmpty(options.FilterColumns, nameof(options.FilterColumns), "Need at least 1 filter column if a FilterMode is specified"); + + _options = options; + } + + public TimeSeriesImputerTransformer Fit(IDataView input) + { + // If we are not suppressing type errors make sure columns to impute only contain supported types. + if (!_options.SupressTypeErrors) + { + var columns = input.Schema.Where(x => !_options.GrainColumns.Contains(x.Name)); + if (_options.FilterMode == FilterMode.Exclude) + columns = columns.Where(x => !_options.FilterColumns.Contains(x.Name)); + else if (_options.FilterMode == FilterMode.Include) + columns = columns.Where(x => _options.FilterColumns.Contains(x.Name)); + + foreach (var column in columns) + { + if (!_currentSupportedTypes.Contains(column.Type.RawType)) + throw new InvalidOperationException($"Type {column.Type.RawType.ToString()} for column {column.Name} not a supported type."); + } + } + + return new TimeSeriesImputerTransformer(_host, _options, input); + } + + // Add one column called WasColumnImputed, otherwise everything stays the same. + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + var columns = inputSchema.ToDictionary(x => x.Name); + columns[IsRowImputedColumnName] = new SchemaShape.Column(IsRowImputedColumnName, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false); + return new SchemaShape(columns.Values); + } + } + + public sealed class TimeSeriesImputerTransformer : ITransformer, IDisposable + { + #region Class data members + + internal const string Summary = "Fills in missing row and values"; + internal const string UserName = "TimeSeriesImputer"; + internal const string ShortName = "tsi"; + internal const string LoadName = "TimeSeriesImputer"; + internal const string LoaderSignature = "TimeSeriesImputer"; + + private readonly IHost _host; + private readonly string _timeSeriesColumn; + private readonly string[] _grainColumns; + private readonly string[] _dataColumns; + private readonly string[] _allColumnNames; + private readonly bool _suppressTypeErrors; + private readonly TimeSeriesImputerEstimator.ImputationStrategy _imputeMode; + internal TransformerEstimatorSafeHandle TransformerHandle; + + #endregion + + // Normal constructor. + internal TimeSeriesImputerTransformer(IHostEnvironment host, TimeSeriesImputerEstimator.Options options, IDataView input) + { + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + + _host = host.Register(nameof(TimeSeriesImputerTransformer)); + _timeSeriesColumn = options.TimeSeriesColumn; + _grainColumns = options.GrainColumns; + _imputeMode = options.ImputeMode; + _suppressTypeErrors = options.SupressTypeErrors; + + IEnumerable tempDataColumns; + + if (options.FilterMode == TimeSeriesImputerEstimator.FilterMode.Exclude) + tempDataColumns = input.Schema.Where(x => !options.FilterColumns.Contains(x.Name)).Select(x => x.Name); + else if (options.FilterMode == TimeSeriesImputerEstimator.FilterMode.Include) + tempDataColumns = input.Schema.Where(x => options.FilterColumns.Contains(x.Name)).Select(x => x.Name); + else + tempDataColumns = input.Schema.Select(x => x.Name); + + // Time series and Grain columns should never be included in the data columns + _dataColumns = tempDataColumns.Where(x => x != _timeSeriesColumn && !_grainColumns.Contains(x)).ToArray(); + + // 1 is for the time series column. Make one array in the correct order of all the columns. + // Order is Timeseries column, All grain columns, All data columns. + _allColumnNames = new string[1 + _grainColumns.Length + _dataColumns.Length]; + _allColumnNames[0] = _timeSeriesColumn; + Array.Copy(_grainColumns, 0, _allColumnNames, 1, _grainColumns.Length); + Array.Copy(_dataColumns, 0, _allColumnNames, 1 + _grainColumns.Length, _dataColumns.Length); + + TransformerHandle = CreateTransformerFromEstimator(input); + } + + // Factory method for SignatureLoadModel. + internal TimeSeriesImputerTransformer(IHostEnvironment host, ModelLoadContext ctx) + { + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + _host = host.Register(nameof(TimeSeriesImputerTransformer)); + + // *** Binary format *** + // name of time series column + // length of grain column array + // all column names in grain column array + // length of filter column array + // all column names in filter column array + // byte value of filter mode + // byte value of impute mode + // length of C++ state array + // C++ byte state array + + _timeSeriesColumn = ctx.Reader.ReadString(); + + _grainColumns = new string[ctx.Reader.ReadInt32()]; + for (int i = 0; i < _grainColumns.Length; i++) + _grainColumns[i] = ctx.Reader.ReadString(); + + _dataColumns = new string[ctx.Reader.ReadInt32()]; + for (int i = 0; i < _dataColumns.Length; i++) + _dataColumns[i] = ctx.Reader.ReadString(); + + _imputeMode = (TimeSeriesImputerEstimator.ImputationStrategy)ctx.Reader.ReadByte(); + + _allColumnNames = new string[1 + _grainColumns.Length + _dataColumns.Length]; + _allColumnNames[0] = _timeSeriesColumn; + Array.Copy(_grainColumns, 0, _allColumnNames, 1, _grainColumns.Length); + Array.Copy(_dataColumns, 0, _allColumnNames, 1 + _grainColumns.Length, _dataColumns.Length); + + var nativeState = ctx.Reader.ReadByteArray(); + TransformerHandle = CreateTransformerFromSavedData(nativeState); + } + + private unsafe TransformerEstimatorSafeHandle CreateTransformerFromSavedData(byte[] nativeState) + { + fixed (byte* rawStatePointer = nativeState) + { + IntPtr dataSize = new IntPtr(nativeState.Count()); + var result = CreateTransformerFromSavedDataNative(rawStatePointer, dataSize, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + } + + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + return (IDataTransform)(new TimeSeriesImputerTransformer(env, ctx).Transform(input)); + } + + private unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimator(IDataView input) + { + IntPtr estimator; + IntPtr errorHandle; + bool success; + + var allColumns = input.Schema.Where(x => _allColumnNames.Contains(x.Name)).Select(x => TypedColumn.CreateTypedColumn(x, _dataColumns)).ToDictionary(x => x.Column.Name); + + // Create buffer to hold binary data + var columnBuffer = new byte[4096]; + + // Create TypeId[] for types of grain and data columns; + var dataColumnTypes = new TypeId[_dataColumns.Length]; + var grainColumnTypes = new TypeId[_grainColumns.Length]; + + foreach (var column in _grainColumns.Select((value, index) => new { index, value })) + grainColumnTypes[column.index] = allColumns[column.value].GetTypeId(); + + foreach (var column in _dataColumns.Select((value, index) => new { index, value })) + dataColumnTypes[column.index] = allColumns[column.value].GetTypeId(); + + fixed (bool* suppressErrors = &_suppressTypeErrors) + fixed (TypeId* rawDataColumnTypes = dataColumnTypes) + fixed (TypeId* rawGrainColumnTypes = grainColumnTypes) + { + success = CreateEstimatorNative(rawGrainColumnTypes, new IntPtr(grainColumnTypes.Length), rawDataColumnTypes, new IntPtr(dataColumnTypes.Length), _imputeMode, suppressErrors, out estimator, out errorHandle); + } + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorNative)) + { + var fitResult = FitResult.Continue; + while (fitResult != FitResult.Complete) + { + using (var cursor = input.GetRowCursorForAllColumns()) + { + // Initialize getters for start of loop + foreach (var column in allColumns.Values) + column.InitializeGetter(cursor); + + while ((fitResult == FitResult.Continue || fitResult == FitResult.ResetAndContinue) && cursor.MoveNext()) + { + BuildColumnByteArray(allColumns, ref columnBuffer, out int serializedDataLength); + + fixed (byte* bufferPointer = columnBuffer) + { + var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(serializedDataLength) }; + success = FitNative(estimatorHandler, binaryArchiveData, out fitResult, out errorHandle); + } + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + success = CompleteTrainingNative(estimatorHandler, out fitResult, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + } + + success = CreateTransformerFromEstimatorNative(estimatorHandler, out IntPtr transformer, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + } + + private void BuildColumnByteArray(Dictionary allColumns, ref byte[] columnByteBuffer, out int serializedDataLength) + { + serializedDataLength = 0; + foreach (var column in _allColumnNames) + { + var bytes = allColumns[column].GetSerializedValue(); + var byteLength = bytes.Length; + if (byteLength + serializedDataLength >= columnByteBuffer.Length) + Array.Resize(ref columnByteBuffer, columnByteBuffer.Length * 2); + + Array.Copy(bytes, 0, columnByteBuffer, serializedDataLength, byteLength); + serializedDataLength += byteLength; + } + } + + public bool IsRowToRowMapper => false; + + // Schema not changed + public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) + { + var columns = inputSchema.ToDictionary(x => x.Name); + var schemaBuilder = new DataViewSchema.Builder(); + schemaBuilder.AddColumns(inputSchema.AsEnumerable()); + schemaBuilder.AddColumn(TimeSeriesImputerEstimator.IsRowImputedColumnName, BooleanDataViewType.Instance); + + return schemaBuilder.ToSchema(); + } + + public IRowToRowMapper GetRowToRowMapper(DataViewSchema inputSchema) => throw new InvalidOperationException("Not a RowToRowMapper."); + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "TimeIm T", + verWrittenCur: 0x00010001, + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(TimeSeriesImputerTransformer).Assembly.FullName); + } + + public void Save(ModelSaveContext ctx) + { + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // name of time series column + // length of grain column array + // all column names in grain column array + // length of data column array + // all column names in data column array + // byte value of impute mode + // length of C++ state array + // C++ byte state array + + ctx.Writer.Write(_timeSeriesColumn); + ctx.Writer.Write(_grainColumns.Length); + foreach (var column in _grainColumns) + ctx.Writer.Write(column); + ctx.Writer.Write(_dataColumns.Length); + foreach (var column in _dataColumns) + ctx.Writer.Write(column); + ctx.Writer.Write((byte)_imputeMode); + var data = CreateTransformerSaveData(); + ctx.Writer.Write(data.Length); + ctx.Writer.Write(data); + } + + private byte[] CreateTransformerSaveData() + { + var success = CreateTransformerSaveDataNative(TransformerHandle, out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize)) + { + byte[] savedData = new byte[bufferSize.ToInt32()]; + Marshal.Copy(buffer, savedData, 0, savedData.Length); + return savedData; + } + } + + public IDataView Transform(IDataView input) => MakeDataTransform(input); + + internal TimeSeriesImputerDataView MakeDataTransform(IDataView input) + { + _host.CheckValue(input, nameof(input)); + + return new TimeSeriesImputerDataView(_host, input, _timeSeriesColumn, _grainColumns, _dataColumns, _allColumnNames, this); + } + + internal TransformerEstimatorSafeHandle CloneTransformer() => CreateTransformerFromSavedData(CreateTransformerSaveData()); + + public void Dispose() + { + if (!TransformerHandle.IsClosed) + TransformerHandle.Close(); + } + + #region C++ function declarations + // TODO: Update entry points + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateEstimatorNative(TypeId* grainTypes, IntPtr grainTypesSize, TypeId* dataTypes, IntPtr dataTypesSize, TimeSeriesImputerEstimator.ImputationStrategy strategy, bool* suppressTypeErrors, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyTransformer", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CompleteTraining", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CompleteTrainingNative(TransformerEstimatorSafeHandle estimator, out FitResult fitResult, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Fit", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool FitNative(TransformerEstimatorSafeHandle estimator, NativeBinaryArchiveData data, out FitResult fitResult, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerFromEstimator", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_CreateTransformerFromSavedData"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, out IntPtr transformer, out IntPtr errorHandle); + + #endregion + + #region Typed Columns + + private abstract class TypedColumn + { + internal readonly DataViewSchema.Column Column; + internal TypedColumn(DataViewSchema.Column column) + { + Column = column; + } + + internal abstract void InitializeGetter(DataViewRowCursor cursor); + internal abstract byte[] GetSerializedValue(); + internal abstract TypeId GetTypeId(); + + internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns) + { + var type = column.Type.RawType.ToString(); + if (type == typeof(sbyte).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(short).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(int).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(long).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(byte).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(ushort).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(uint).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(ulong).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(float).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(double).ToString()) + return new NumericTypedColumn(column, optionalColumns.Contains(column.Name)); + else if (type == typeof(ReadOnlyMemory).ToString()) + return new StringTypedColumn(column, optionalColumns.Contains(column.Name)); + + throw new InvalidOperationException($"Unsupported type {type}"); + } + } + + private abstract class TypedColumn : TypedColumn + { + private ValueGetter _getter; + private T _value; + + internal TypedColumn(DataViewSchema.Column column) : + base(column) + { + _value = default; + } + + internal override void InitializeGetter(DataViewRowCursor cursor) + { + _getter = cursor.GetGetter(Column); + } + + internal T GetValue() + { + _getter(ref _value); + return _value; + } + + internal override TypeId GetTypeId() + { + return typeof(T).GetNativeTypeIdFromType(); + } + } + + private class NumericTypedColumn : TypedColumn + { + private readonly bool _isNullable; + + internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable = false) : + base(column) + { + _isNullable = isNullable; + } + + internal override byte[] GetSerializedValue() + { + dynamic value = GetValue(); + byte[] bytes; + if (value.GetType() == typeof(byte)) + bytes = new byte[1] { value }; + bytes = BitConverter.GetBytes(value); + + if (_isNullable && value.GetType() != typeof(float) && value.GetType() != typeof(double)) + return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray(); + else + return bytes; + } + } + + private class StringTypedColumn : TypedColumn> + { + private readonly bool _isNullable; + + internal StringTypedColumn(DataViewSchema.Column column, bool isNullable = false) : + base(column) + { + _isNullable = isNullable; + } + + internal override byte[] GetSerializedValue() + { + var value = GetValue().ToString(); + var stringBytes = Encoding.UTF8.GetBytes(value); + if (_isNullable) + return new byte[] { Convert.ToByte(true) }.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray(); + return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray(); + } + } + + #endregion + } + + internal static class TimeSeriesTransformerEntrypoint + { + [TlcModule.EntryPoint(Name = "Transforms.TimeSeriesImputer", + Desc = TimeSeriesImputerTransformer.Summary, + UserName = TimeSeriesImputerTransformer.UserName, + ShortName = TimeSeriesImputerTransformer.ShortName)] + public static CommonOutputs.TransformOutput TimeSeriesImputer(IHostEnvironment env, TimeSeriesImputerEstimator.Options input) + { + var h = EntryPointUtils.CheckArgsAndCreateHost(env, TimeSeriesImputerTransformer.ShortName, input); + var xf = new TimeSeriesImputerEstimator(h, input).Fit(input.Data).Transform(input.Data); + return new CommonOutputs.TransformOutput() + { + Model = new TransformModelImpl(h, xf, input.Data), + OutputData = xf + }; + } + } +} diff --git a/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs b/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs new file mode 100644 index 0000000000..ccb33b99e9 --- /dev/null +++ b/src/Microsoft.ML.Featurizers/TimeSeriesImputerDataView.cs @@ -0,0 +1,766 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; +using Microsoft.ML.Runtime; +using Microsoft.Win32.SafeHandles; +using static Microsoft.ML.Featurizers.CommonExtensions; +using static Microsoft.ML.Featurizers.TimeSeriesImputerEstimator; + +namespace Microsoft.ML.Transforms +{ + + internal sealed class TimeSeriesImputerDataView : IDataTransform + { + #region Typed Columns + private TimeSeriesImputerTransformer _parent; + public class SharedColumnState + { + public bool SourceCanMoveNext { get; set; } + public int TransformedDataPosition { get; set; } + + // This array is used to hold the returned row data from the native transformer. Because we create rows in this transformer, the number + // of rows returned from the native code is not always consistent and so this has to be an array. + public NativeBinaryArchiveData[] TransformedData { get; set; } + + // Hold the serialized data that we are going to send to the native code for processing. + public byte[] ColumnBuffer { get; set; } + public TransformedDataSafeHandle TransformedDataHandler { get; set; } + } + + private abstract class TypedColumn + { + private protected SharedColumnState SharedState; + + internal readonly DataViewSchema.Column Column; + internal readonly bool IsImputed; + internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnState state) + { + Column = column; + SharedState = state; + IsImputed = isImputed; + } + + internal abstract Delegate GetGetter(); + internal abstract void InitializeGetter(DataViewRowCursor cursor, TransformerEstimatorSafeHandle transformerParent, string timeSeriesColumn, + string[] grainColumns, string[] dataColumns, string[] allColumnNames, Dictionary allColumns); + + internal abstract TypeId GetTypeId(); + internal abstract byte[] GetSerializedValue(); + internal abstract unsafe int GetDataSizeInBytes(byte* data, int currentOffset); + internal abstract void QueueNonImputedColumnValue(); + + public bool MoveNext(DataViewRowCursor cursor) + { + SharedState.TransformedDataPosition++; + + if (SharedState.TransformedData == null || SharedState.TransformedDataPosition >= SharedState.TransformedData.Length) + SharedState.SourceCanMoveNext = cursor.MoveNext(); + + if (!SharedState.SourceCanMoveNext) + if (SharedState.TransformedDataPosition >= SharedState.TransformedData.Length) + { + if (!SharedState.TransformedDataHandler.IsClosed) + SharedState.TransformedDataHandler.Dispose(); + return false; + } + + return true; + } + + internal static TypedColumn CreateTypedColumn(DataViewSchema.Column column, string[] optionalColumns, string[] allImputedColumns, SharedColumnState state) + { + var type = column.Type.RawType.ToString(); + if (type == typeof(sbyte).ToString()) + return new SByteTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(short).ToString()) + return new ShortTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(int).ToString()) + return new IntTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(long).ToString()) + return new LongTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(byte).ToString()) + return new ByteTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(ushort).ToString()) + return new UShortTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(uint).ToString()) + return new UIntTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(ulong).ToString()) + return new ULongTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(float).ToString()) + return new FloatTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(double).ToString()) + return new DoubleTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(ReadOnlyMemory).ToString()) + return new StringTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + else if (type == typeof(bool).ToString()) + return new BoolTypedColumn(column, optionalColumns.Contains(column.Name), allImputedColumns.Contains(column.Name), state); + + throw new InvalidOperationException($"Unsupported type {type}"); + } + } + + private abstract class TypedColumn : TypedColumn + { + private ValueGetter _getter; + private ValueGetter _sourceGetter; + private long _position; + private T _cache; + + // When columns are not being imputed, we need to store the column values in memory until they are used. + private protected Queue SourceQueue; + + internal TypedColumn(DataViewSchema.Column column, bool isImputed, SharedColumnState state) : + base(column, isImputed, state) + { + SourceQueue = new Queue(); + _position = -1; + } + + internal override Delegate GetGetter() + { + return _getter; + } + + internal override unsafe void InitializeGetter(DataViewRowCursor cursor, TransformerEstimatorSafeHandle transformer, string timeSeriesColumn, + string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, Dictionary allColumns) + { + if (Column.Name != IsRowImputedColumnName) + _sourceGetter = cursor.GetGetter(Column); + + _getter = (ref T dst) => + { + IntPtr errorHandle = IntPtr.Zero; + bool success = false; + if (SharedState.TransformedData == null || SharedState.TransformedDataPosition >= SharedState.TransformedData.Length) + { + // Free native memory if we are about to get more + if (SharedState.TransformedData != null && SharedState.TransformedDataPosition >= SharedState.TransformedData.Length) + SharedState.TransformedDataHandler.Dispose(); + + var outputDataSize = IntPtr.Zero; + NativeBinaryArchiveData* outputData = default; + while(outputDataSize == IntPtr.Zero && SharedState.SourceCanMoveNext) + { + BuildColumnByteArray(allColumns, allImputedColumnNames, out int bufferLength); + QueueDataForNonImputedColumns(allColumns, allImputedColumnNames); + fixed (byte* bufferPointer = SharedState.ColumnBuffer) + { + var binaryArchiveData = new NativeBinaryArchiveData() { Data = bufferPointer, DataSize = new IntPtr(bufferLength) }; + success = TransformDataNative(transformer, binaryArchiveData, out outputData, out outputDataSize, out errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + if (outputDataSize == IntPtr.Zero) + SharedState.SourceCanMoveNext = cursor.MoveNext(); + } + + if (!SharedState.SourceCanMoveNext) + success = FlushDataNative(transformer, out outputData, out outputDataSize, out errorHandle); + + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + if (outputDataSize.ToInt32() > 0) + { + SharedState.TransformedDataHandler = new TransformedDataSafeHandle((IntPtr)outputData, outputDataSize); + SharedState.TransformedData = new NativeBinaryArchiveData[outputDataSize.ToInt32()]; + for (int i = 0; i < outputDataSize.ToInt32(); i++) + { + SharedState.TransformedData[i] = *(outputData + i); + } + SharedState.TransformedDataPosition = 0; + } + } + + // Base case where we didn't impute the column + if (!allImputedColumnNames.Contains(Column.Name)) + { + var imputedData = SharedState.TransformedData[SharedState.TransformedDataPosition]; + // If the row was imputed we want to just return the default value for the type. + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(imputedData.Data, 0)) + { + dst = default; + } + else + { + // If the row wasn't imputed, get the original value for that row we stored in the queue and return that. + if (_position != cursor.Position) + { + _position = cursor.Position; + _cache = SourceQueue.Dequeue(); + } + dst = _cache; + } + } + // If we did impute the column then parse the data from the returned byte array. + else + { + var imputedData = SharedState.TransformedData[SharedState.TransformedDataPosition]; + int offset = 0; + foreach (var columnName in allImputedColumnNames) + { + var col = allColumns[columnName]; + if (col.Column.Name == Column.Name) + { + dst = GetDataFromNativeBinaryArchiveData(imputedData.Data, offset); + return; + } + + offset += col.GetDataSizeInBytes(imputedData.Data, offset); + } + + // This should never be hit. + dst = default; + } + }; + } + + private void QueueDataForNonImputedColumns(Dictionary allColumns, string[] allImputedColumnNames) + { + foreach (var column in allColumns.Where(x => !allImputedColumnNames.Contains(x.Value.Column.Name)).Select(x => x.Value)) + column.QueueNonImputedColumnValue(); + } + + internal override void QueueNonImputedColumnValue() + { + SourceQueue.Enqueue(GetSourceValue()); + } + + private void BuildColumnByteArray(Dictionary allColumns, string[] columns, out int bufferLength) + { + bufferLength = 0; + foreach(var column in columns.Where(x => x != IsRowImputedColumnName)) + { + var bytes = allColumns[column].GetSerializedValue(); + var byteLength = bytes.Length; + if (byteLength + bufferLength >= SharedState.ColumnBuffer.Length) + { + var buffer = SharedState.ColumnBuffer; + Array.Resize(ref buffer, SharedState.ColumnBuffer.Length * 2); + SharedState.ColumnBuffer = buffer; + } + + Array.Copy(bytes, 0, SharedState.ColumnBuffer, bufferLength, byteLength); + bufferLength += byteLength; + } + } + + private protected T GetSourceValue() + { + T value = default; + _sourceGetter(ref value); + return value; + } + + internal override TypeId GetTypeId() + { + return typeof(T).GetNativeTypeIdFromType(); + } + + internal unsafe abstract T GetDataFromNativeBinaryArchiveData(byte* data, int offset); + } + + private abstract class NumericTypedColumn : TypedColumn + { + private protected readonly bool IsNullable; + + internal NumericTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isImputed, state) + { + IsNullable = isNullable; + } + + internal override byte[] GetSerializedValue() + { + dynamic value = GetSourceValue(); + byte[] bytes; + if (value.GetType() == typeof(byte)) + bytes = new byte[1] { value }; + if (BitConverter.IsLittleEndian) + bytes = BitConverter.GetBytes(value); + else + bytes = BitConverter.GetBytes(value); + + if (IsNullable && value.GetType() != typeof(float) && value.GetType() != typeof(double)) + return new byte[1] { Convert.ToByte(true) }.Concat(bytes).ToArray(); + else + return bytes; + } + + internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset) + { + if (IsNullable && typeof(T) != typeof(float) && typeof(T) != typeof(double)) + return Marshal.SizeOf(default(T)) + sizeof(bool); + else + return Marshal.SizeOf(default(T)); + } + } + + private class ByteTypedColumn : NumericTypedColumn + { + internal ByteTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override byte GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(byte*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(byte*)(data + offset); + } + } + + private class SByteTypedColumn : NumericTypedColumn + { + internal SByteTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override sbyte GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(sbyte*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(sbyte*)(data + offset); + } + } + + private class ShortTypedColumn : NumericTypedColumn + { + internal ShortTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override short GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(short*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(short*)(data + offset); + } + } + + private class UShortTypedColumn : NumericTypedColumn + { + internal UShortTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override ushort GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(ushort*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(ushort*)(data + offset); + } + } + + private class IntTypedColumn : NumericTypedColumn + { + internal IntTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override int GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(int*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(int*)(data + offset); + } + } + + private class UIntTypedColumn : NumericTypedColumn + { + internal UIntTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override uint GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(uint*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(uint*)(data + offset); + } + } + + private class LongTypedColumn : NumericTypedColumn + { + internal LongTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override long GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(long*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(long*)(data + offset); + } + } + + private class ULongTypedColumn : NumericTypedColumn + { + internal ULongTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override ulong GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(ulong*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(ulong*)(data + offset); + } + } + + private class FloatTypedColumn : NumericTypedColumn + { + internal FloatTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override float GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + var bytes = new byte[sizeof(float)]; + Marshal.Copy((IntPtr)(data + offset), bytes, 0, sizeof(float)); + return BitConverter.ToSingle(bytes, 0); + } + } + + private class DoubleTypedColumn : NumericTypedColumn + { + internal DoubleTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override double GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + var bytes = new byte[sizeof(double)]; + Marshal.Copy((IntPtr)(data + offset), bytes, 0, sizeof(double)); + return BitConverter.ToDouble(bytes, 0); + } + } + + private class BoolTypedColumn : NumericTypedColumn + { + internal BoolTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isNullable, isImputed, state) + { + } + + internal unsafe override bool GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (IsNullable) + { + if (GetBoolFromNativeBinaryArchiveData(data, offset)) + return *(bool*)(data + offset + sizeof(bool)); + else + return default; + } + else + return *(bool*)(data + offset); + } + + internal static unsafe bool GetBoolFromNativeBinaryArchiveData(byte* data, int offset) + { + return *(bool*)(data + offset); + } + + internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset) + { + return sizeof(bool); + } + } + + private class StringTypedColumn : TypedColumn> + { + private readonly bool _isNullable; + internal StringTypedColumn(DataViewSchema.Column column, bool isNullable, bool isImputed, SharedColumnState state) : + base(column, isImputed, state) + { + _isNullable = isNullable; + } + + internal override byte[] GetSerializedValue() + { + var value = GetSourceValue().ToString(); + var stringBytes = Encoding.UTF8.GetBytes(value); + if (_isNullable) + return new byte[] { Convert.ToByte(true)}.Concat(BitConverter.GetBytes(stringBytes.Length)).Concat(stringBytes).ToArray(); + return BitConverter.GetBytes(stringBytes.Length).Concat(stringBytes).ToArray(); + } + + internal unsafe override ReadOnlyMemory GetDataFromNativeBinaryArchiveData(byte* data, int offset) + { + if (_isNullable) + { + if (!BoolTypedColumn.GetBoolFromNativeBinaryArchiveData(data, offset)) // If value not present return empty string + return new ReadOnlyMemory("".ToCharArray()); + + var size = *(uint*)(data + offset + 1); // Add 1 for the byte bool flag + + var bytes = new byte[size]; + Marshal.Copy((IntPtr)(data + offset + sizeof(uint) + 1), bytes, 0, (int)size); + return Encoding.UTF8.GetString(bytes).AsMemory(); + } + else + { + var size = *(uint*)(data + offset); + + var bytes = new byte[size]; + Marshal.Copy((IntPtr)(data + offset + sizeof(uint)), bytes, 0, (int)size); + return Encoding.UTF8.GetString(bytes).AsMemory(); + } + } + + internal override unsafe int GetDataSizeInBytes(byte* data, int currentOffset) + { + var size = *(uint*)(data + currentOffset); + if (_isNullable) + return 1 + (int)size + sizeof(uint); // + 1 for the byte bool flag + + return (int)size + sizeof(uint); + } + } + + #endregion + + #region Native Exports + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Transform"), SuppressUnmanagedCodeSecurity] + private static extern unsafe bool TransformDataNative(TransformerEstimatorSafeHandle transformer, /*in*/ NativeBinaryArchiveData data, out NativeBinaryArchiveData* outputData, out IntPtr outputDataSize, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_Transform"), SuppressUnmanagedCodeSecurity] + private static extern unsafe bool FlushDataNative(TransformerEstimatorSafeHandle transformer, out NativeBinaryArchiveData* outputData, out IntPtr outputDataSize, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "TimeSeriesImputerFeaturizer_BinaryArchive_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern unsafe bool DestroyTransformedDataNative(IntPtr data, IntPtr dataSize, out IntPtr errorHandle); + + #endregion + + #region Native SafeHandles + + internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private IntPtr _size; + public TransformedDataSafeHandle(IntPtr handle, IntPtr size) : base(true) + { + SetHandle(handle); + _size = size; + } + + protected override bool ReleaseHandle() + { + // Not sure what to do with error stuff here. There shoudln't ever be one though. + return DestroyTransformedDataNative(handle, _size, out IntPtr errorHandle); + } + } + + #endregion + + private readonly IHostEnvironment _host; + private readonly IDataView _source; + private readonly string _timeSeriesColumn; + private readonly string[] _dataColumns; + private readonly string[] _grainColumns; + private readonly string[] _allImputedColumnNames; + private readonly DataViewSchema _schema; + + internal TimeSeriesImputerDataView(IHostEnvironment env, IDataView input, string timeSeriesColumn, string[] grainColumns, string[] dataColumns, string[] allColumnNames, TimeSeriesImputerTransformer parent) + { + _host = env; + _source = input; + + _timeSeriesColumn = timeSeriesColumn; + _grainColumns = grainColumns; + _dataColumns = dataColumns; + _allImputedColumnNames = new string[] { IsRowImputedColumnName }.Concat(allColumnNames).ToArray(); + _parent = parent; + // Build new schema. + var schemaColumns = _source.Schema.ToDictionary(x => x.Name); + var schemaBuilder = new DataViewSchema.Builder(); + schemaBuilder.AddColumns(_source.Schema.AsEnumerable()); + schemaBuilder.AddColumn(IsRowImputedColumnName, BooleanDataViewType.Instance); + + _schema = schemaBuilder.ToSchema(); + } + + public bool CanShuffle => false; + + public DataViewSchema Schema => _schema; + + public IDataView Source => _source; + + public DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null) + { + _host.AssertValueOrNull(rand); + + var input = _source.GetRowCursorForAllColumns(); + return new Cursor(_host, input, _parent.CloneTransformer(), _timeSeriesColumn, _grainColumns, _dataColumns, _allImputedColumnNames, _schema); + } + + // Can't use parallel cursors so this defaults to calling non-parallel version + public DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) => + new DataViewRowCursor[] { GetRowCursor(columnsNeeded, rand) }; + + // Since we may add rows we don't know the row count + public long? GetRowCount() => null; + + public void Save(ModelSaveContext ctx) + { + _parent.Save(ctx); + } + + private sealed class Cursor : DataViewRowCursor + { + private readonly IChannelProvider _ch; + private DataViewRowCursor _input; + private long _position; + private bool _isGood; + private readonly Dictionary _allColumns; + private readonly DataViewSchema _schema; + private readonly TransformerEstimatorSafeHandle _transformer; + + public Cursor(IChannelProvider provider, DataViewRowCursor input, TransformerEstimatorSafeHandle transformer, string timeSeriesColumn, + string[] grainColumns, string[] dataColumns, string[] allImputedColumnNames, DataViewSchema schema) + { + _ch = provider; + _ch.CheckValue(input, nameof(input)); + + _input = input; + var length = input.Schema.Count; + _position = -1; + _schema = schema; + _transformer = transformer; + + var sharedState = new SharedColumnState() + { + SourceCanMoveNext = true, + ColumnBuffer = new byte[4096] + }; + + _allColumns = _schema.Select(x => TypedColumn.CreateTypedColumn(x, dataColumns, allImputedColumnNames, sharedState)).ToDictionary(x => x.Column.Name); ; + _allColumns[IsRowImputedColumnName] = new BoolTypedColumn(_schema[IsRowImputedColumnName], false, true, sharedState); + + foreach (var column in _allColumns.Values) + { + column.InitializeGetter(_input, transformer, timeSeriesColumn, grainColumns, dataColumns, allImputedColumnNames, _allColumns); + } + } + + public sealed override ValueGetter GetIdGetter() + { + return + (ref DataViewRowId val) => + { + _ch.Check(_isGood, RowCursorUtils.FetchValueStateError); + val = new DataViewRowId((ulong)Position, 0); + }; + } + + public sealed override DataViewSchema Schema => _schema; + + /// + /// Since rows will be generated all columns are active + /// + public override bool IsColumnActive(DataViewSchema.Column column) => true; + + protected override void Dispose(bool disposing) + { + if (!_transformer.IsClosed) + _transformer.Close(); + } + + /// + /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. + /// This throws if the column is not active in this row, or if the type + /// differs from this column's type. + /// + /// is the column's content type. + /// is the output column whose getter should be returned. + public override ValueGetter GetGetter(DataViewSchema.Column column) + { + _ch.Check(IsColumnActive(column)); + + var fn = _allColumns[column.Name].GetGetter() as ValueGetter; + if (fn == null) + throw _ch.Except("Invalid TValue in GetGetter: '{0}'", typeof(TValue)); + return fn; + } + + public override bool MoveNext() + { + _position++; + _isGood = _allColumns[IsRowImputedColumnName].MoveNext(_input); + return _isGood; + } + + public sealed override long Position => _position; + + public sealed override long Batch => _input.Batch; + } + } +} diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 9088dfc1eb..bcd7dc5c44 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -133,6 +133,7 @@ Transforms.SentimentAnalyzer Uses a pretrained sentiment model to score input st Transforms.TensorFlowScorer Transforms the data using the TensorFlow model. Microsoft.ML.Transforms.TensorFlowTransformer TensorFlowScorer Microsoft.ML.Transforms.TensorFlowEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.TextFeaturizer A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are normalized counts of (word and/or character) n-grams in a given tokenized text. Microsoft.ML.Transforms.Text.TextAnalytics TextTransform Microsoft.ML.Transforms.Text.TextFeaturizingEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.TextToKeyConverter Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Categorical TextToKey Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.TimeSeriesImputer Fills in missing row and values Microsoft.ML.Featurizers.TimeSeriesTransformerEntrypoint TimeSeriesImputer Microsoft.ML.Featurizers.TimeSeriesImputerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.ToString Turns the given column into a column of its string representation Microsoft.ML.Featurizers.ToStringTransformerEntrypoint ToString Microsoft.ML.Featurizers.ToStringTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.TrainTestDatasetSplitter Split the dataset into train and test sets Microsoft.ML.EntryPoints.TrainTestSplit Split Microsoft.ML.EntryPoints.TrainTestSplit+Input Microsoft.ML.EntryPoints.TrainTestSplit+Output Transforms.TreeLeafFeaturizer Trains a tree ensemble, or loads it from a file, then maps a numeric feature vector to three outputs: 1. A vector containing the individual tree outputs of the tree ensemble. 2. A vector indicating the leaves that the feature vector falls on in the tree ensemble. 3. A vector indicating the paths that the feature vector falls on in the tree ensemble. If a both a model file and a trainer are specified - will use the model file. If neither are specified, will train a default FastTree model. This can handle key labels by training a regression model towards their optionally permuted indices. Microsoft.ML.Data.TreeFeaturize Featurizer Microsoft.ML.Data.TreeEnsembleFeaturizerTransform+ArgumentsForEntryPoint Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 02313887d2..b2587d6cd6 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -23252,6 +23252,130 @@ "ITransformOutput" ] }, + { + "Name": "Transforms.TimeSeriesImputer", + "Desc": "Fills in missing row and values", + "FriendlyName": "TimeSeriesImputer", + "ShortName": "tsi", + "Inputs": [ + { + "Name": "TimeSeriesColumn", + "Type": "String", + "Desc": "Column representing the time", + "Aliases": [ + "time" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "GrainColumns", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "List of grain columns", + "Aliases": [ + "grains" + ], + "Required": true, + "SortOrder": 2.0, + "IsNullable": false + }, + { + "Name": "FilterColumns", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "Columns to filter", + "Aliases": [ + "filters" + ], + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "FilterMode", + "Type": { + "Kind": "Enum", + "Values": [ + "NoFilter", + "Include", + "Exclude" + ] + }, + "Desc": "Filter mode. Either include or exclude", + "Aliases": [ + "fmode" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": "Exclude" + }, + { + "Name": "ImputeMode", + "Type": { + "Kind": "Enum", + "Values": [ + "ForwardFill", + "BackFill", + "Median" + ] + }, + "Desc": "Mode for imputing, defaults to ForwardFill if not provided", + "Aliases": [ + "mode" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": "ForwardFill" + }, + { + "Name": "SupressTypeErrors", + "Type": "Bool", + "Desc": "Supress the errors that would occur if a column and impute mode are imcompatible. If true, will skip the column. If false, will stop and throw an error.", + "Aliases": [ + "error" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "Transforms.ToString", "Desc": "Turns the given column into a column of its string representation", diff --git a/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs b/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs new file mode 100644 index 0000000000..bed0a67e5f --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/TimeSeriesImputerTests.cs @@ -0,0 +1,471 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.Featurizers; +using System; +using Xunit; +using Xunit.Abstractions; +using System.Drawing.Printing; +using System.Linq; +using Microsoft.ML.TestFramework.Attributes; + +namespace Microsoft.ML.Tests.Transformers +{ + public class TimeSeriesImputerTests : TestDataPipeBase + { + public TimeSeriesImputerTests(ITestOutputHelper output) : base(output) + { + } + + private class TimeSeriesTwoGrainInput + { + public long date; + public string grainA; + public string grainB; + public float data; + } + + private class TimeSeriesOneGrainInput + { + public long date; + public string grainA; + public int dataA; + public float dataB; + public uint dataC; + } + + private class TimeSeriesOneGrainFloatInput + { + public long date; + public string grainA; + public float dataA; + } + + private class TimeSeriesOneGrainStringInput + { + public long date; + public string grainA; + public string dataA; + } + + [NotCentOS7Fact] + public void NotImputeOneColumn() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { + new TimeSeriesOneGrainInput() { date = 25, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }, + new TimeSeriesOneGrainInput() { date = 26, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }, + new TimeSeriesOneGrainInput() { date = 28, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 } + }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, new string[] { "dataB"}); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // We always output the same column as the input, plus adding a column saying whether the row was imputed or not. + Assert.Equal(6, schema.Count); + Assert.Equal("date", schema[0].Name); + Assert.Equal("grainA", schema[1].Name); + Assert.Equal("dataA", schema[2].Name); + Assert.Equal("dataB", schema[3].Name); + Assert.Equal("dataC", schema[4].Name); + Assert.Equal("IsRowImputed", schema[5].Name); + + // We are imputing 1 row, so total rows should be 4. + var preview = output.Preview(); + Assert.Equal(4, preview.RowView.Length); + + // Row that was imputed should have date of 27 + Assert.Equal(27L, preview.ColumnView[0].Values[2]); + + // Since we are not imputing data on one column and a row is getting imputed, its value should be default(T) + Assert.Equal(default(float), preview.ColumnView[3].Values[2]); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7Fact] + public void ImputeOnlyOneColumn() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { + new TimeSeriesOneGrainInput() { date = 25, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }, + new TimeSeriesOneGrainInput() { date = 26, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 }, + new TimeSeriesOneGrainInput() { date = 28, grainA = "A", dataA = 1, dataB = 2.0f, dataC = 5 } + }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, new string[] { "dataB"}, TimeSeriesImputerEstimator.FilterMode.Include); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // We always output the same column as the input, plus adding a column saying whether the row was imputed or not. + Assert.Equal(6, schema.Count); + Assert.Equal("date", schema[0].Name); + Assert.Equal("grainA", schema[1].Name); + Assert.Equal("dataA", schema[2].Name); + Assert.Equal("dataB", schema[3].Name); + Assert.Equal("dataC", schema[4].Name); + Assert.Equal("IsRowImputed", schema[5].Name); + + // We are imputing 1 row, so total rows should be 4. + var preview = output.Preview(); + Assert.Equal(4, preview.RowView.Length); + + // Row that was imputed should have date of 27 + Assert.Equal(27L, preview.ColumnView[0].Values[2]); + + // Since we are not imputing data on two columns and a row is getting imputed, its value should be default(T) + Assert.Equal(default(int), preview.ColumnView[2].Values[2]); + Assert.Equal(default(uint), preview.ColumnView[4].Values[2]); + + // Column that was imputed should have value of 2.0f + Assert.Equal(2.0f, preview.ColumnView[3].Values[2]); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7Fact] + public void Forwardfill() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = 2.0f }, + new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN }, + new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f }, + new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN }, + new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = float.NaN }}; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var prev = output.Preview(); + + // Should have 3 original columns + 1 more for IsRowImputed + Assert.Equal(4, output.Schema.Count); + + // Imputing rows with dates 2,4,6, so should have length of 8 + Assert.Equal(8, prev.RowView.Length); + + // Check that imputed rows have the correct dates + Assert.Equal(2L, prev.ColumnView[0].Values[2]); + Assert.Equal(4L, prev.ColumnView[0].Values[4]); + Assert.Equal(6L, prev.ColumnView[0].Values[6]); + + // Make sure grain was propagated correctly + Assert.Equal("A", prev.ColumnView[1].Values[2].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[4].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[6].ToString()); + + // Make sure forward fill is working as expected. All NA's should be replaced, and imputed rows should have correct values too + Assert.Equal(2.0f, prev.ColumnView[2].Values[1]); + Assert.Equal(2.0f, prev.ColumnView[2].Values[2]); + Assert.Equal(5.0f, prev.ColumnView[2].Values[4]); + Assert.Equal(5.0f, prev.ColumnView[2].Values[5]); + Assert.Equal(5.0f, prev.ColumnView[2].Values[6]); + Assert.Equal(5.0f, prev.ColumnView[2].Values[7]); + + // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest + Assert.Equal(false, prev.ColumnView[3].Values[0]); + Assert.Equal(false, prev.ColumnView[3].Values[1]); + Assert.Equal(true, prev.ColumnView[3].Values[2]); + Assert.Equal(false, prev.ColumnView[3].Values[3]); + Assert.Equal(true, prev.ColumnView[3].Values[4]); + Assert.Equal(false, prev.ColumnView[3].Values[5]); + Assert.Equal(true, prev.ColumnView[3].Values[6]); + Assert.Equal(false, prev.ColumnView[3].Values[7]); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7Fact] + public void EntryPoint() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new { ts = 1L, grain = 1970, c3 = 10, c4 = 19}, + new { ts = 2L, grain = 1970, c3 = 13, c4 = 12}, + new { ts = 3L, grain = 1970, c3 = 15, c4 = 16}, + new { ts = 5L, grain = 1970, c3 = 20, c4 = 19} + }; + + var data = mlContext.Data.LoadFromEnumerable(dataList); + TimeSeriesImputerEstimator.Options options = new TimeSeriesImputerEstimator.Options() { + TimeSeriesColumn = "ts", + GrainColumns = new[] { "grain" }, + FilterColumns = new[] { "c3", "c4" }, + FilterMode = TimeSeriesImputerEstimator.FilterMode.Include, + ImputeMode = TimeSeriesImputerEstimator.ImputationStrategy.ForwardFill, + Data = data + }; + + var entryOutput = TimeSeriesTransformerEntrypoint.TimeSeriesImputer(mlContext.Transforms.GetEnvironment(), options); + // Build the pipeline, fit, and transform it. + var output = entryOutput.OutputData; + + // Get the data from the first row and make sure it matches expected + var prev = output.Preview(); + + // Should have 4 original columns + 1 more for IsRowImputed + Assert.Equal(5, output.Schema.Count); + + // Imputing rows with date 4 so should have length of 5 + Assert.Equal(5, prev.RowView.Length); + + // Check that imputed rows have the correct dates + Assert.Equal(4L, prev.ColumnView[0].Values[3]); + + // Make sure grain was propagated correctly + Assert.Equal(1970, prev.ColumnView[1].Values[2]); + + // Make sure forward fill is working as expected. All NA's should be replaced, and imputed rows should have correct values too + Assert.Equal(15, prev.ColumnView[2].Values[3]); + Assert.Equal(16, prev.ColumnView[3].Values[3]); + + // Make sure IsRowImputed is true for row 4, false for the rest + Assert.Equal(false, prev.ColumnView[4].Values[0]); + Assert.Equal(false, prev.ColumnView[4].Values[1]); + Assert.Equal(false, prev.ColumnView[4].Values[2]); + Assert.Equal(true, prev.ColumnView[4].Values[3]); + Assert.Equal(false, prev.ColumnView[4].Values[4]); + + Done(); + } + + [NotCentOS7Fact] + public void Median() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = 2.0f }, + new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN }, + new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f }, + new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN }, + new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = float.NaN }}; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: true); + var model = pipeline.Fit(data); + + var output = model.Transform(data); + + var prev = output.Preview(); + + // Should have 3 original columns + 1 more for IsRowImputed + Assert.Equal(4, output.Schema.Count); + + // Imputing rows with dates 2,4,6, so should have length of 8 + Assert.Equal(8, prev.RowView.Length); + + // Check that imputed rows have the correct dates + Assert.Equal(2L, prev.ColumnView[0].Values[2]); + Assert.Equal(4L, prev.ColumnView[0].Values[4]); + Assert.Equal(6L, prev.ColumnView[0].Values[6]); + + // Make sure grain was propagated correctly + Assert.Equal("A", prev.ColumnView[1].Values[2].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[4].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[6].ToString()); + + // Make sure Median is working as expected. All NA's should be replaced, and imputed rows should have correct values too + Assert.Equal(3.5f, prev.ColumnView[2].Values[1]); + Assert.Equal(3.5f, prev.ColumnView[2].Values[2]); + Assert.Equal(3.5f, prev.ColumnView[2].Values[4]); + Assert.Equal(3.5f, prev.ColumnView[2].Values[5]); + Assert.Equal(3.5f, prev.ColumnView[2].Values[6]); + Assert.Equal(3.5f, prev.ColumnView[2].Values[7]); + + // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest + Assert.Equal(false, prev.ColumnView[3].Values[0]); + Assert.Equal(false, prev.ColumnView[3].Values[1]); + Assert.Equal(true, prev.ColumnView[3].Values[2]); + Assert.Equal(false, prev.ColumnView[3].Values[3]); + Assert.Equal(true, prev.ColumnView[3].Values[4]); + Assert.Equal(false, prev.ColumnView[3].Values[5]); + Assert.Equal(true, prev.ColumnView[3].Values[6]); + Assert.Equal(false, prev.ColumnView[3].Values[7]); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7Fact] + public void Backfill() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new TimeSeriesOneGrainFloatInput() { date = 0, grainA = "A", dataA = float.NaN }, + new TimeSeriesOneGrainFloatInput() { date = 1, grainA = "A", dataA = float.NaN }, + new TimeSeriesOneGrainFloatInput() { date = 3, grainA = "A", dataA = 5.0f }, + new TimeSeriesOneGrainFloatInput() { date = 5, grainA = "A", dataA = float.NaN }, + new TimeSeriesOneGrainFloatInput() { date = 7, grainA = "A", dataA = 2.0f }}; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var prev = output.Preview(); + + // Should have 3 original columns + 1 more for IsRowImputed + Assert.Equal(4, output.Schema.Count); + + // Imputing rows with dates 2,4,6, so should have length of 8 + Assert.Equal(8, prev.RowView.Length); + + // Check that imputed rows have the correct dates + Assert.Equal(2L, prev.ColumnView[0].Values[2]); + Assert.Equal(4L, prev.ColumnView[0].Values[4]); + Assert.Equal(6L, prev.ColumnView[0].Values[6]); + + // Make sure grain was propagated correctly + Assert.Equal("A", prev.ColumnView[1].Values[2].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[4].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[6].ToString()); + + // Make sure backfill is working as expected. All NA's should be replaced, and imputed rows should have correct values too + Assert.Equal(5.0f, prev.ColumnView[2].Values[0]); + Assert.Equal(5.0f, prev.ColumnView[2].Values[1]); + Assert.Equal(5.0f, prev.ColumnView[2].Values[2]); + Assert.Equal(2.0f, prev.ColumnView[2].Values[4]); + Assert.Equal(2.0f, prev.ColumnView[2].Values[5]); + Assert.Equal(2.0f, prev.ColumnView[2].Values[6]); + + // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest + Assert.Equal(false, prev.ColumnView[3].Values[0]); + Assert.Equal(false, prev.ColumnView[3].Values[1]); + Assert.Equal(true, prev.ColumnView[3].Values[2]); + Assert.Equal(false, prev.ColumnView[3].Values[3]); + Assert.Equal(true, prev.ColumnView[3].Values[4]); + Assert.Equal(false, prev.ColumnView[3].Values[5]); + Assert.Equal(true, prev.ColumnView[3].Values[6]); + Assert.Equal(false, prev.ColumnView[3].Values[7]); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7Fact] + public void BackfillTwoGrain() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new TimeSeriesTwoGrainInput() { date = 0, grainA = "A", grainB = "A", data = float.NaN}, + new TimeSeriesTwoGrainInput() { date = 1, grainA = "A", grainB = "A", data = 0.0f}, + new TimeSeriesTwoGrainInput() { date = 3, grainA = "A", grainB = "B", data = 1.0f}, + new TimeSeriesTwoGrainInput() { date = 5, grainA = "A", grainB = "B", data = float.NaN}, + new TimeSeriesTwoGrainInput() { date = 7, grainA = "A", grainB = "B", data = 2.0f }}; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA", "grainB" }, TimeSeriesImputerEstimator.ImputationStrategy.BackFill); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var prev = output.Preview(); + + // Should have 4 original columns + 1 more for IsRowImputed + Assert.Equal(5, output.Schema.Count); + + // Imputing rows with dates 4,6, so should have length of 8 + Assert.Equal(7, prev.RowView.Length); + + // Check that imputed rows have the correct dates + Assert.Equal(4L, prev.ColumnView[0].Values[3]); + Assert.Equal(6L, prev.ColumnView[0].Values[5]); + + // Make sure grain was propagated correctly + Assert.Equal("A", prev.ColumnView[1].Values[3].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[5].ToString()); + Assert.Equal("B", prev.ColumnView[2].Values[3].ToString()); + Assert.Equal("B", prev.ColumnView[2].Values[5].ToString()); + + // Make sure backfill is working as expected. All NA's should be replaced, and imputed rows should have correct values too + Assert.Equal(0.0f, prev.ColumnView[3].Values[0]); + Assert.Equal(2.0f, prev.ColumnView[3].Values[3]); + Assert.Equal(2.0f, prev.ColumnView[3].Values[4]); + Assert.Equal(2.0f, prev.ColumnView[3].Values[5]); + + // Make sure IsRowImputed is true for row 4,6 false for the rest + Assert.Equal(false, prev.ColumnView[4].Values[0]); + Assert.Equal(false, prev.ColumnView[4].Values[1]); + Assert.Equal(false, prev.ColumnView[4].Values[2]); + Assert.Equal(true, prev.ColumnView[4].Values[3]); + Assert.Equal(false, prev.ColumnView[4].Values[4]); + Assert.Equal(true, prev.ColumnView[4].Values[5]); + Assert.Equal(false, prev.ColumnView[4].Values[6]); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7Fact] + public void InvalidTypeForImputationStrategy() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new TimeSeriesOneGrainStringInput(){ date = 0L, grainA = "A", dataA = "zero" }, + new TimeSeriesOneGrainStringInput(){ date = 1L, grainA = "A", dataA = "one" }, + new TimeSeriesOneGrainStringInput(){ date = 3L, grainA = "A", dataA = "three" }, + new TimeSeriesOneGrainStringInput(){ date = 5L, grainA = "A", dataA = "five" }, + new TimeSeriesOneGrainStringInput(){ date = 7L, grainA = "A", dataA = "seven" }}; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // When suppressTypeErrors is set to false this will throw an error. + var pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: false); + var ex = Assert.Throws(() => pipeline.Fit(data)); + Assert.Equal("Only Numeric type columns are supported for ImputationStrategy median. (use suppressError flag to skip imputing non-numeric types)", ex.Message); + + // When suppressTypeErrors is set to true then the default value will be used. + pipeline = mlContext.Transforms.ReplaceMissingTimeSeriesValues("date", new string[] { "grainA" }, imputeMode: TimeSeriesImputerEstimator.ImputationStrategy.Median, filterColumns: null, suppressTypeErrors: true); + var model = pipeline.Fit(data); + + var output = model.Transform(data); + var prev = output.Preview(); + + // Should have 3 original columns + 1 more for IsRowImputed + Assert.Equal(4, output.Schema.Count); + + // Imputing rows with dates 2,4,6, so should have length of 8 + Assert.Equal(8, prev.RowView.Length); + + // Check that imputed rows have the default value + Assert.Equal("", prev.ColumnView[2].Values[2].ToString()); + Assert.Equal("", prev.ColumnView[2].Values[4].ToString()); + Assert.Equal("", prev.ColumnView[2].Values[6].ToString()); + + // Make sure grain was propagated correctly + Assert.Equal("A", prev.ColumnView[1].Values[2].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[4].ToString()); + Assert.Equal("A", prev.ColumnView[1].Values[6].ToString()); + + // Make sure original values stayed the same + Assert.Equal("zero", prev.ColumnView[2].Values[0].ToString()); + Assert.Equal("one", prev.ColumnView[2].Values[1].ToString()); + Assert.Equal("three", prev.ColumnView[2].Values[3].ToString()); + Assert.Equal("five", prev.ColumnView[2].Values[5].ToString()); + Assert.Equal("seven", prev.ColumnView[2].Values[7].ToString()); + + // Make sure IsRowImputed is true for row 2, 4,6 , false for the rest + Assert.Equal(false, prev.ColumnView[3].Values[0]); + Assert.Equal(false, prev.ColumnView[3].Values[1]); + Assert.Equal(true, prev.ColumnView[3].Values[2]); + Assert.Equal(false, prev.ColumnView[3].Values[3]); + Assert.Equal(true, prev.ColumnView[3].Values[4]); + Assert.Equal(false, prev.ColumnView[3].Values[5]); + Assert.Equal(true, prev.ColumnView[3].Values[6]); + Assert.Equal(false, prev.ColumnView[3].Values[7]); + + TestEstimatorCore(pipeline, data); + + Done(); + } + } +}