From 3dcf356d6a3844ffd115bfcc923c41e83706a8f1 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Wed, 4 Dec 2019 09:56:42 -0800 Subject: [PATCH 1/8] files and changes needed for the DateTimeTransformer --- .../Featurizers/DateTimeTransformer.cs | 84 ++ .../DateTimeTransformerDropColumns.cs | 80 ++ .../DateTimeTransformer.cs | 782 ++++++++++++++++++ .../Common/EntryPoints/core_ep-list.tsv | 1 + .../Common/EntryPoints/core_manifest.json | 151 ++++ .../UnitTests/TestEntryPoints.cs | 2 + .../Transformers/DateTimeTransformerTests.cs | 360 ++++++++ 7 files changed, 1460 insertions(+) create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs create mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs create mode 100644 src/Microsoft.ML.Featurizers/DateTimeTransformer.cs create mode 100644 test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs new file mode 100644 index 0000000000..9bf851b0ba --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class DateTimeTransformer + { + private class DateTimeInput + { + public long Date; + } + + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + // Future Date - 2025 June 30 + var samples = new[] { new DateTimeInput() { Date = 1751241600 } }; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for splitting the time features into individual columns + var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC"); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. We should have created 21 more columns with all the + // DateTime information split into its own columns + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Date + ", " + featureRow.DTCYear + ", " + featureRow.DTCMonth + ", " + + featureRow.DTCDay + ", " + featureRow.DTCHour + ", " + featureRow.DTCMinute + ", " + + featureRow.DTCSecond + ", " + featureRow.DTCAmPm + ", " + featureRow.DTCHour12 + ", " + + featureRow.DTCDayOfWeek + ", " + featureRow.DTCDayOfQuarter + ", " + featureRow.DTCDayOfYear + + ", " + featureRow.DTCWeekOfMonth + ", " + featureRow.DTCQuarterOfYear + ", " + featureRow.DTCHalfOfYear + + ", " + featureRow.DTCWeekIso + ", " + featureRow.DTCYearIso + ", " + featureRow.DTCMonthLabel + ", " + + featureRow.DTCAmPmLabel + ", " + featureRow.DTCDayOfWeekLabel + ", " + featureRow.DTCHolidayName + ", " + + featureRow.DTCIsPaidTimeOff); + + // Expected output: + // Features columns obtained post-transformation. + // 1751241600, 2025, 6, 30, 0, 0, 0, 0, 0, 1, 91, 180, 4, 2, 1, 27, 2025, June, am, Monday, , 0 + } + + // These columns start with DTC because that is the prefix we picked + private sealed class TransformedData + { + public long Date { get; set; } + public int DTCYear { get; set; } + public byte DTCMonth { get; set; } + public byte DTCDay { get; set; } + public byte DTCHour { get; set; } + public byte DTCMinute { get; set; } + public byte DTCSecond { get; set; } + public byte DTCAmPm { get; set; } + public byte DTCHour12 { get; set; } + public byte DTCDayOfWeek { get; set; } + public byte DTCDayOfQuarter { get; set; } + public ushort DTCDayOfYear { get; set; } + public ushort DTCWeekOfMonth { get; set; } + public byte DTCQuarterOfYear { get; set; } + public byte DTCHalfOfYear { get; set; } + public byte DTCWeekIso { get; set; } + public int DTCYearIso { get; set; } + public string DTCMonthLabel { get; set; } + public string DTCAmPmLabel { get; set; } + public string DTCDayOfWeekLabel { get; set; } + public string DTCHolidayName { get; set; } + public byte DTCIsPaidTimeOff { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs new file mode 100644 index 0000000000..680b338056 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs @@ -0,0 +1,80 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class DateTimeTransformerDropColumns + { + private class DateTimeInput + { + public long Date; + } + + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + // Future Date - 2025 June 30 + var samples = new[] { new DateTimeInput() { Date = 1751241600 } }; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for splitting the time features into individual columns + // All the columns listed here will be dropped. + var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC", DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff, + DateTimeTransformerEstimator.ColumnsProduced.Day, DateTimeTransformerEstimator.ColumnsProduced.QuarterOfYear, + DateTimeTransformerEstimator.ColumnsProduced.AmPm, DateTimeTransformerEstimator.ColumnsProduced.HolidayName); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. We should have created 16 more columns with all the + // DateTime information split into its own columns + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Date + ", " + featureRow.DTCYear + ", " + featureRow.DTCMonth + ", " + + featureRow.DTCHour + ", " + featureRow.DTCMinute + ", " + featureRow.DTCSecond + ", " + + featureRow.DTCHour12 + ", " + featureRow.DTCDayOfWeek + ", " + featureRow.DTCDayOfQuarter + ", " + + featureRow.DTCDayOfYear + ", " + featureRow.DTCWeekOfMonth + ", " + featureRow.DTCHalfOfYear + + ", " + featureRow.DTCWeekIso + ", " + featureRow.DTCYearIso + ", " + featureRow.DTCMonthLabel + ", " + + featureRow.DTCAmPmLabel + ", " + featureRow.DTCDayOfWeekLabel); + + // Expected output: + // Features columns obtained post-transformation. + // 1751241600, 2025, 6, 30, 0, 0, 0, 0, 0, 1, 91, 180, 4, 2, 1, 27, 2025, June, am, Monday + } + + // These columns start with DTC because that is the prefix we picked + private sealed class TransformedData + { + public long Date { get; set; } + public int DTCYear { get; set; } + public byte DTCMonth { get; set; } + public byte DTCHour { get; set; } + public byte DTCMinute { get; set; } + public byte DTCSecond { get; set; } + public byte DTCHour12 { get; set; } + public byte DTCDayOfWeek { get; set; } + public byte DTCDayOfQuarter { get; set; } + public ushort DTCDayOfYear { get; set; } + public ushort DTCWeekOfMonth { get; set; } + public byte DTCHalfOfYear { get; set; } + public byte DTCWeekIso { get; set; } + public int DTCYearIso { get; set; } + public string DTCMonthLabel { get; set; } + public string DTCAmPmLabel { get; set; } + public string DTCDayOfWeekLabel { get; set; } + } + } +} diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs new file mode 100644 index 0000000000..2f48bba2b1 --- /dev/null +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -0,0 +1,782 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Featurizers; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; +using Microsoft.Win32.SafeHandles; +using static Microsoft.ML.Featurizers.CommonExtensions; + +[assembly: LoadableClass(typeof(DateTimeTransformer), null, typeof(SignatureLoadModel), + DateTimeTransformer.UserName, DateTimeTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(DateTimeTransformer), null, typeof(SignatureLoadRowMapper), + DateTimeTransformer.UserName, DateTimeTransformer.LoaderSignature)] + +[assembly: EntryPointModule(typeof(DateTimeTransformerEntrypoint))] + +namespace Microsoft.ML.Featurizers +{ + + public static class DateTimeTransformerExtensionClass + { + /// + /// Create a , which splits up the input column specified by + /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. + /// This transformer will append the to all the output columns. If is empty, + /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. + /// + /// Transform catalog + /// Input column name + /// Prefix to add to the generated columns + /// List of columns to drop, if any + /// + public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, params DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop) + => new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop); + + /// + /// Create a , which splits up the input column specified by + /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. + /// This transformer will append the to all the output columns. If is empty, + /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. If you specify a country, + /// Holiday details will be looked up for that country as well. + /// + /// Transform catalog + /// Input column name + /// Prefix to add to the generated columns + /// List of columns to drop, if any + /// Country name to get holiday details for + /// + public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop = null, DateTimeTransformerEstimator.Countries country = DateTimeTransformerEstimator.Countries.None) + => new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop, country); + + #region ColumnsProduced static extentions + + internal static Type GetRawColumnType(this DateTimeTransformerEstimator.ColumnsProduced column) + { + switch (column) + { + case DateTimeTransformerEstimator.ColumnsProduced.Year: + case DateTimeTransformerEstimator.ColumnsProduced.YearIso: + return typeof(int); + case DateTimeTransformerEstimator.ColumnsProduced.DayOfYear: + case DateTimeTransformerEstimator.ColumnsProduced.WeekOfMonth: + return typeof(ushort); + case DateTimeTransformerEstimator.ColumnsProduced.MonthLabel: + case DateTimeTransformerEstimator.ColumnsProduced.AmPmLabel: + case DateTimeTransformerEstimator.ColumnsProduced.DayOfWeekLabel: + case DateTimeTransformerEstimator.ColumnsProduced.HolidayName: + return typeof(ReadOnlyMemory); + default: + return typeof(byte); + } + } + + #endregion + } + + /// + /// The DateTimeTransformerEstimator splits up a date into all of its sub parts as individual columns. It generates these fields with a user specified prefix: + /// int Year, byte Month, byte Day, byte Hour, byte Minute, byte Second, byte AmPm, byte Hour12, byte DayOfWeek, byte DayOfQuarter, + /// ushort DayOfYear, ushort WeekOfMonth, byte QuarterOfYear, byte HalfOfYear, byte WeekIso, int YearIso, string MonthLabel, string AmPmLabel, + /// string DayOfWeekLabel, string HolidayName, byte IsPaidTimeOff + /// + /// You can optionally specify a country and it will pull holiday information about the country as well + /// + /// + /// is a trivial estimator and does not need training. + /// + /// + /// ]]> + /// + /// + /// + /// + public sealed class DateTimeTransformerEstimator : IEstimator + { + private readonly Options _options; + + private readonly IHost _host; + + #region Options + internal sealed class Options : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "Input column", Name = "Source", ShortName = "src", SortOrder = 1)] + public string Source; + + // This transformer adds columns + [Argument(ArgumentType.Required, HelpText = "Output column prefix", Name = "Prefix", ShortName = "pre", SortOrder = 2)] + public string Prefix; + + [Argument(ArgumentType.MultipleUnique, HelpText = "Columns to drop after the DateTime Expansion", Name = "ColumnsToDrop", ShortName = "drop", SortOrder = 3)] + public ColumnsProduced[] ColumnsToDrop; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Country to get holidays for. Defaults to none if not passed", Name = "Country", ShortName = "ctry", SortOrder = 4)] + public Countries Country = Countries.None; + } + + #endregion + + public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop, Countries country = Countries.None) + { + + Contracts.CheckValue(env, nameof(env)); + _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); + _host.CheckValue(inputColumnName, nameof(inputColumnName), "Input column should not be null."); + + _options = new Options + { + Source = inputColumnName, + Prefix = columnPrefix, + ColumnsToDrop = columnsToDrop == null ? Array.Empty() : columnsToDrop, + Country = country + }; + } + + internal DateTimeTransformerEstimator(IHostEnvironment env, Options options) + { + + Contracts.CheckValue(env, nameof(env)); + _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); + + _options = options; + _options.ColumnsToDrop = _options.ColumnsToDrop == null ? Array.Empty() : _options.ColumnsToDrop; + } + + public DateTimeTransformer Fit(IDataView input) + { + return new DateTimeTransformer(_host, _options.Source, _options.Prefix, _options.ColumnsToDrop, _options.Country); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + var columns = inputSchema.ToDictionary(x => x.Name); + + foreach (ColumnsProduced column in Enum.GetValues(typeof(ColumnsProduced))) + if (_options.ColumnsToDrop == null || !_options.ColumnsToDrop.Contains(column)) + { + columns[_options.Prefix + column.ToString()] = new SchemaShape.Column(_options.Prefix + column.ToString(), SchemaShape.Column.VectorKind.Scalar, + ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()), false, null); + } + + return new SchemaShape(columns.Values); + } + + #region Enums + public enum ColumnsProduced : byte + { + Year = 1, Month, Day, Hour, Minute, Second, AmPm, Hour12, DayOfWeek, DayOfQuarter, DayOfYear, + WeekOfMonth, QuarterOfYear, HalfOfYear, WeekIso, YearIso, MonthLabel, AmPmLabel, DayOfWeekLabel, + HolidayName, IsPaidTimeOff + }; + + public enum Countries : byte + { + None = 1, + Argentina, Australia, Austria, Belarus, Belgium, Brazil, Canada, Colombia, Croatia, Czech, Denmark, + England, Finland, France, Germany, Hungary, India, Ireland, IsleofMan, Italy, Japan, Mexico, Netherlands, + NewZealand, NorthernIreland, Norway, Poland, Portugal, Scotland, Slovenia, SouthAfrica, Spain, Sweden, Switzerland, + Ukraine, UnitedKingdom, UnitedStates, Wales + } + + #endregion + } + + public sealed class DateTimeTransformer : RowToRowTransformerBase, IDisposable + { + #region Class data members + + internal const string Summary = "Splits a date time value into each individual component"; + internal const string UserName = "DateTime Transform"; + internal const string ShortName = "DateTimeTransform"; + internal const string LoadName = "DateTimeTransform"; + internal const string LoaderSignature = "DateTimeTransform"; + private DateTimeTypedColumn _column; + + private DateTimeTransformerEstimator.ColumnsProduced[] _columnsToDrop; + private byte[] _activeColumnMapping; + + #endregion + + public DateTimeTransformer(IHostEnvironment env, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop, DateTimeTransformerEstimator.Countries country) : + base(env.Register(nameof(DateTimeTransformer))) + { + + _columnsToDrop = columnsToDrop; + var activeColumnLength = Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)).Length - (_columnsToDrop == null ? 0 : _columnsToDrop.Length); + _activeColumnMapping = new byte[activeColumnLength]; + var index = 0; + foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + { + if (_columnsToDrop == null || !_columnsToDrop.Contains(column)) + { + _activeColumnMapping[index++] = (byte)column; + } + } + + _column = new DateTimeTypedColumn(inputColumnName, columnPrefix); + _column.CreateTransformerFromEstimator(country); + } + + // Factory method for SignatureLoadModel. + internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx) : + base(host.Register(nameof(DateTimeTransformer))) + { + + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + // *** Binary format *** + // name of input column + // column prefix + // byte length of columns to drop array + // byte array of columns to drop + // length of C++ state array + // C++ byte state array + + _column = new DateTimeTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString()); + + var dropColumnsLength = ctx.Reader.ReadInt32(); + if (dropColumnsLength > 0) + { + _columnsToDrop = new DateTimeTransformerEstimator.ColumnsProduced[dropColumnsLength]; + //read in enum bytes + for (int i = 0; i < dropColumnsLength; i++) + _columnsToDrop[i] = (DateTimeTransformerEstimator.ColumnsProduced)ctx.Reader.ReadByte(); + } + + _activeColumnMapping = new byte[Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)).Length - dropColumnsLength]; + var index = 0; + foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + { + if (_columnsToDrop == null || !_columnsToDrop.Contains(column)) + { + _activeColumnMapping[index++] = (byte)column; + } + } + + var dataLength = ctx.Reader.ReadInt32(); + var data = ctx.Reader.ReadByteArray(dataLength); + _column.CreateTransformerFromSavedData(data); + } + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) + => new DateTimeTransformer(env, ctx).MakeRowMapper(inputSchema); + + private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema); + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "DATETI T", + verWrittenCur: 0x00010001, + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(DateTimeTransformer).Assembly.FullName); + } + + private protected override void SaveModel(ModelSaveContext ctx) + { + + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // name of input column + // column prefix + // byte length of columns to drop array + // byte array of columns to drop + // length of C++ state array + // C++ byte state array + + ctx.Writer.Write(_column.Source); + ctx.Writer.Write(_column.Prefix); + + ctx.Writer.Write(_columnsToDrop == null ? 0 : _columnsToDrop.Length); + if (_columnsToDrop != null) + { + foreach (var toDrop in _columnsToDrop) + ctx.Writer.Write((byte)toDrop); + } + + var data = _column.CreateTransformerSaveData(); + ctx.Writer.Write(data.Length); + ctx.Writer.Write(data); + } + + public void Dispose() + { + _column.Dispose(); + } + + #region C++ Safe handle classes + + internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private readonly DestroyTransformedDataNative _destroyTransformedDataHandler; + + public TransformedDataSafeHandle(IntPtr handle, DestroyTransformedDataNative destroyTransformedDataHandler) : base(true) + { + SetHandle(handle); + _destroyTransformedDataHandler = destroyTransformedDataHandler; + } + + protected override bool ReleaseHandle() + { + // Not sure what to do with error stuff here. There shoudln't ever be one though. + return _destroyTransformedDataHandler(handle, out IntPtr errorHandle); + } + } + + #endregion + + #region TimePoint + + [StructLayoutAttribute(LayoutKind.Sequential)] + internal struct TimePoint + { + public int Year; + public byte Month; + public byte Day; + public byte Hour; + public byte Minute; + public byte Second; + public byte AmPm; + public byte Hour12; + public byte DayOfWeek; + public byte DayOfQuarter; + public ushort DayOfYear; + public ushort WeekOfMonth; + public byte QuarterOfYear; + public byte HalfOfYear; + public byte WeekIso; + public int YearIso; + public string MonthLabel; + public string AmPmLabel; + public string DayOfWeekLabel; + public string HolidayName; + public byte IsPaidTimeOff; + + internal unsafe TimePoint(byte* rawData) + { + int intPtrSize = sizeof(IntPtr); + + Year = *(int*)rawData; + rawData += 4; + + Month = *rawData++; + Day = *rawData++; + Hour = *rawData++; + Minute = *rawData++; + Second = *rawData++; + AmPm = *rawData++; + Hour12 = *rawData++; + DayOfWeek = *rawData++; + DayOfQuarter = *rawData++; + DayOfYear = *(ushort*)rawData; + rawData += 2; + + WeekOfMonth = *(ushort*)rawData; + rawData += 2; + + QuarterOfYear = *rawData++; + HalfOfYear = *rawData++; + WeekIso = *rawData++; + YearIso = *(int*)rawData; + rawData += 4; + + // Convert char * to string + MonthLabel = GetStringFromPointer(ref rawData, intPtrSize); + AmPmLabel = GetStringFromPointer(ref rawData, intPtrSize); + DayOfWeekLabel = GetStringFromPointer(ref rawData, intPtrSize); + HolidayName = GetStringFromPointer(ref rawData, intPtrSize); + IsPaidTimeOff = *rawData; + } + + // Converts a pointer to a native char* to a string and increments pointer by to the next value. + // The length of the string is stored at byte* + sizeof(IntPtr). + private static unsafe string GetStringFromPointer(ref byte* rawData, int intPtrSize) + { + byte[] buffer; + if (intPtrSize == 4) // 32 bit machine + buffer = new byte[*(uint*)(rawData + intPtrSize)]; + else // 64 bit machine + buffer = new byte[*(ulong*)(rawData + intPtrSize)]; + + if (buffer.Length == 0) + { + rawData += intPtrSize * 2; + return string.Empty; + } + + Marshal.Copy(new IntPtr(*(int**)rawData), buffer, 0, buffer.Length); + rawData += intPtrSize * 2; + + return Encoding.UTF8.GetString(buffer); + } + + }; + + #endregion + + #region BaseClass + + internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle); + internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle); + internal delegate bool DestroyTransformedDataNative(IntPtr output, out IntPtr errorHandle); + + internal abstract class TypedColumn : IDisposable + { + internal readonly string Source; + internal readonly string Prefix; + + internal TypedColumn(string source, string prefix) + { + Source = source; + Prefix = prefix; + } + + internal abstract void CreateTransformerFromEstimator(DateTimeTransformerEstimator.Countries country); + private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize); + private protected unsafe abstract bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle); + private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle); + private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle); + private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); + public abstract void Dispose(); + + private protected unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(DateTimeTransformerEstimator.Countries country) + { + bool success; + IntPtr errorHandle; + IntPtr estimator; + if (country == DateTimeTransformerEstimator.Countries.None) + { + success = CreateEstimatorHelper(null, null, out estimator, out errorHandle); + } + else + { + fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue)) + fixed (byte* countryPointer = Encoding.UTF8.GetBytes(Enum.GetName(typeof(DateTimeTransformerEstimator.Countries), country) + char.MinValue)) + { + success = CreateEstimatorHelper(countryPointer, dataRootDir, out estimator, out errorHandle); + } + } + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper)) + { + + success = CreateTransformerFromEstimatorHelper(estimatorHandler, out IntPtr transformer, out errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper); + } + } + + internal byte[] CreateTransformerSaveData() + { + + var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize)) + { + byte[] savedData = new byte[bufferSize.ToInt32()]; + Marshal.Copy(buffer, savedData, 0, savedData.Length); + return savedData; + } + } + + internal unsafe void CreateTransformerFromSavedData(byte[] data) + { + fixed (byte* rawData = data) + { + IntPtr dataSize = new IntPtr(data.Count()); + CreateTransformerFromSavedDataHelper(rawData, dataSize); + } + } + } + + internal abstract class TypedColumn : TypedColumn + { + internal TypedColumn(string source, string prefix) : + base(source, prefix) + { + } + + internal abstract TimePoint Transform(T input); + + } + + #endregion + + #region DateTimeTypedColumn + + internal sealed class DateTimeTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + internal DateTimeTypedColumn(string source, string prefix) : + base(source, prefix) + { + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateEstimatorNative(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override unsafe void CreateTransformerFromEstimator(DateTimeTransformerEstimator.Countries country) + { + _transformerHandler = CreateTransformerFromEstimatorBase(country); + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerFromSavedDataWithDataRoot"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, byte* dataRootDir, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue)) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, dataRootDir, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_Transform"), SuppressUnmanagedCodeSecurity] + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long input, out IntPtr output, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, out IntPtr errorHandle); + internal override TimePoint Transform(long input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var handler = new TransformedDataSafeHandle(output, DestroyTransformedDataNative)) + { + unsafe + { + return new TimePoint((byte*)output.ToPointer()); + } + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected unsafe override bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(countryName, dataRootDir, out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + + #endregion + + private sealed class Mapper : MapperBase + { + + #region Class data members + + private readonly DateTimeTransformer _parent; + private ConcurrentDictionary _cache; + private ConcurrentQueue _oldestKeys; + + #endregion + + public Mapper(DateTimeTransformer parent, DataViewSchema inputSchema) : + base(parent.Host.Register(nameof(Mapper)), inputSchema, parent) + { + _parent = parent; + _cache = new ConcurrentDictionary(); + _oldestKeys = new ConcurrentQueue(); + } + + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() + { + var columns = new List(); + + foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + if (_parent._columnsToDrop == null || !_parent._columnsToDrop.Contains(column)) + { + columns.Add(new DataViewSchema.DetachedColumn(_parent._column.Prefix + column.ToString(), + ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()))); + } + + return columns.ToArray(); + } + + private Delegate MakeGetter(DataViewRow input, int iinfo) + { + ValueGetter result = (ref T dst) => + { + long dateTime = default; + var getter = input.GetGetter(input.Schema[_parent._column.Source]); + getter(ref dateTime); + + if (!_cache.TryGetValue(dateTime, out TimePoint timePoint)) + { + _cache[dateTime] = _parent._column.Transform(dateTime); + _oldestKeys.Enqueue(dateTime); + timePoint = _cache[dateTime]; + + // If more than 100 cached items, remove 20 + if (_cache.Count > 100) + { + for (int i = 0; i < 20; i++) + { + long key; + while (!_oldestKeys.TryDequeue(out key)) { } + while (!_cache.TryRemove(key, out TimePoint removedValue)) { } + } + } + } + + if (iinfo == 0) + dst = (T)Convert.ChangeType(timePoint.Year, typeof(T)); + else if (iinfo == 1) + dst = (T)Convert.ChangeType(timePoint.Month, typeof(T)); + else if (iinfo == 2) + dst = (T)Convert.ChangeType(timePoint.Day, typeof(T)); + else if (iinfo == 3) + dst = (T)Convert.ChangeType(timePoint.Hour, typeof(T)); + else if (iinfo == 4) + dst = (T)Convert.ChangeType(timePoint.Minute, typeof(T)); + else if (iinfo == 5) + dst = (T)Convert.ChangeType(timePoint.Second, typeof(T)); + else if (iinfo == 6) + dst = (T)Convert.ChangeType(timePoint.AmPm, typeof(T)); + else if (iinfo == 7) + dst = (T)Convert.ChangeType(timePoint.Hour12, typeof(T)); + else if (iinfo == 8) + dst = (T)Convert.ChangeType(timePoint.DayOfWeek, typeof(T)); + else if (iinfo == 9) + dst = (T)Convert.ChangeType(timePoint.DayOfQuarter, typeof(T)); + else if (iinfo == 10) + dst = (T)Convert.ChangeType(timePoint.DayOfYear, typeof(T)); + else if (iinfo == 11) + dst = (T)Convert.ChangeType(timePoint.WeekOfMonth, typeof(T)); + else if (iinfo == 12) + dst = (T)Convert.ChangeType(timePoint.QuarterOfYear, typeof(T)); + else if (iinfo == 13) + dst = (T)Convert.ChangeType(timePoint.HalfOfYear, typeof(T)); + else if (iinfo == 14) + dst = (T)Convert.ChangeType(timePoint.WeekIso, typeof(T)); + else if (iinfo == 15) + dst = (T)Convert.ChangeType(timePoint.YearIso, typeof(T)); + else if (iinfo == 16) + dst = (T)Convert.ChangeType(timePoint.MonthLabel.AsMemory(), typeof(T)); + else if (iinfo == 17) + dst = (T)Convert.ChangeType(timePoint.AmPmLabel.AsMemory(), typeof(T)); + else if (iinfo == 18) + dst = (T)Convert.ChangeType(timePoint.DayOfWeekLabel.AsMemory(), typeof(T)); + else if (iinfo == 19) + dst = (T)Convert.ChangeType(timePoint.HolidayName.AsMemory(), typeof(T)); + else + dst = (T)Convert.ChangeType(timePoint.IsPaidTimeOff, typeof(T)); + }; + + return result; + } + + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) + { + disposer = null; + + var outputColumn = (int)_parent._activeColumnMapping[iinfo]; + + // Have to subtract 1 from the output column since the enum starts and 1 and not 0. + return Utils.MarshalInvoke(MakeGetter, ((DateTimeTransformerEstimator.ColumnsProduced)outputColumn).GetRawColumnType(), input, outputColumn - 1); + } + + private protected override Func GetDependenciesCore(Func activeOutput) + { + var active = new bool[InputSchema.Count]; + for (int i = 0; i < InputSchema.Count; i++) + { + if (InputSchema[i].Name.Equals(_parent._column.Source)) + { + active[i] = true; + } + } + + return col => active[col]; + } + + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); + } + } + + internal static class DateTimeTransformerEntrypoint + { + [TlcModule.EntryPoint(Name = "Transforms.DateTimeSplitter", + Desc = DateTimeTransformer.Summary, + UserName = DateTimeTransformer.UserName, + ShortName = DateTimeTransformer.ShortName)] + public static CommonOutputs.TransformOutput DateTimeSplit(IHostEnvironment env, DateTimeTransformerEstimator.Options input) + { + var h = EntryPointUtils.CheckArgsAndCreateHost(env, DateTimeTransformer.ShortName, input); + var xf = new DateTimeTransformerEstimator(h, input).Fit(input.Data).Transform(input.Data); + return new CommonOutputs.TransformOutput() + { + Model = new TransformModelImpl(h, xf, input.Data), + OutputData = xf + }; + } + } +} diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 4f2bcc426a..7125ed4188 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -85,6 +85,7 @@ Transforms.CombinerByContiguousGroupId Groups values of a scalar column into a v Transforms.ConditionalNormalizer Normalize the columns only if needed Microsoft.ML.Data.Normalize IfNeeded Microsoft.ML.Transforms.NormalizeTransform+MinMaxArguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput] Transforms.DatasetScorer Score a dataset with a predictor model Microsoft.ML.EntryPoints.ScoreModel Score Microsoft.ML.EntryPoints.ScoreModel+Input Microsoft.ML.EntryPoints.ScoreModel+Output Transforms.DatasetTransformScorer Score a dataset with a transform model Microsoft.ML.EntryPoints.ScoreModel ScoreUsingTransform Microsoft.ML.EntryPoints.ScoreModel+InputTransformScorer Microsoft.ML.EntryPoints.ScoreModel+Output +Transforms.DateTimeSplitter Splits a date time value into each individual component Microsoft.ML.Featurizers.DateTimeTransformerEntrypoint DateTimeSplit Microsoft.ML.Featurizers.DateTimeTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.Dictionarizer Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Text.TextAnalytics TermTransform Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureCombiner Combines all the features into one feature column. Microsoft.ML.EntryPoints.FeatureCombiner PrepareFeatures Microsoft.ML.EntryPoints.FeatureCombiner+FeatureCombinerInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureContributionCalculationTransformer For each data point, calculates the contribution of individual features to the model prediction. Microsoft.ML.Transforms.FeatureContributionEntryPoint FeatureContributionCalculation Microsoft.ML.Transforms.FeatureContributionCalculatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index c8e6d6e55c..ac968830dc 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -18077,6 +18077,157 @@ } ] }, + { + "Name": "Transforms.DateTimeSplitter", + "Desc": "Splits a date time value into each individual component", + "FriendlyName": "DateTime Transform", + "ShortName": "DateTimeTransform", + "Inputs": [ + { + "Name": "Source", + "Type": "String", + "Desc": "Input column", + "Aliases": [ + "src" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Prefix", + "Type": "String", + "Desc": "Output column prefix", + "Aliases": [ + "pre" + ], + "Required": true, + "SortOrder": 2.0, + "IsNullable": false + }, + { + "Name": "ColumnsToDrop", + "Type": { + "Kind": "Array", + "ItemType": { + "Kind": "Enum", + "Values": [ + "Year", + "Month", + "Day", + "Hour", + "Minute", + "Second", + "AmPm", + "Hour12", + "DayOfWeek", + "DayOfQuarter", + "DayOfYear", + "WeekOfMonth", + "QuarterOfYear", + "HalfOfYear", + "WeekIso", + "YearIso", + "MonthLabel", + "AmPmLabel", + "DayOfWeekLabel", + "HolidayName", + "IsPaidTimeOff" + ] + } + }, + "Desc": "Columns to drop after the DateTime Expansion", + "Aliases": [ + "drop" + ], + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Country", + "Type": { + "Kind": "Enum", + "Values": [ + "None", + "Argentina", + "Australia", + "Austria", + "Belarus", + "Belgium", + "Brazil", + "Canada", + "Colombia", + "Croatia", + "Czech", + "Denmark", + "England", + "Finland", + "France", + "Germany", + "Hungary", + "India", + "Ireland", + "IsleofMan", + "Italy", + "Japan", + "Mexico", + "Netherlands", + "NewZealand", + "NorthernIreland", + "Norway", + "Poland", + "Portugal", + "Scotland", + "Slovenia", + "SouthAfrica", + "Spain", + "Sweden", + "Switzerland", + "Ukraine", + "UnitedKingdom", + "UnitedStates", + "Wales" + ] + }, + "Desc": "Country to get holidays for. Defaults to none if not passed", + "Aliases": [ + "ctry" + ], + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": "None" + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "Transforms.Dictionarizer", "Desc": "Converts input values (words, numbers, etc.) to index in a dictionary.", diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index bd425c05ff..268e7fc390 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Data.IO; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Featurizers; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Model.OnnxConverter; @@ -330,6 +331,7 @@ public void EntryPointCatalogCheckDuplicateParams() Env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesProcessingEntryPoints).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(DateTimeTransformer).Assembly); var catalog = Env.ComponentCatalog; diff --git a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs new file mode 100644 index 0000000000..6514ecd90b --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs @@ -0,0 +1,360 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.Featurizers; +using System; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class DateTimeTransformerTests : TestDataPipeBase + { + public DateTimeTransformerTests(ITestOutputHelper output) : base(output) + { + } + + private class DateTimeInput + { + public long date; + } + + [Fact] + public void CorrectNumberOfColumnsAndSchema() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 0 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var columnPrefix = "DTC_"; + var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // Check the schema has 22 columns + Assert.Equal(22, schema.Count); + + // Make sure names with prefix and order are correct + Assert.Equal($"{columnPrefix}Year", schema[1].Name); + Assert.Equal($"{columnPrefix}Month", schema[2].Name); + Assert.Equal($"{columnPrefix}Day", schema[3].Name); + Assert.Equal($"{columnPrefix}Hour", schema[4].Name); + Assert.Equal($"{columnPrefix}Minute", schema[5].Name); + Assert.Equal($"{columnPrefix}Second", schema[6].Name); + Assert.Equal($"{columnPrefix}AmPm", schema[7].Name); + Assert.Equal($"{columnPrefix}Hour12", schema[8].Name); + Assert.Equal($"{columnPrefix}DayOfWeek", schema[9].Name); + Assert.Equal($"{columnPrefix}DayOfQuarter", schema[10].Name); + Assert.Equal($"{columnPrefix}DayOfYear", schema[11].Name); + Assert.Equal($"{columnPrefix}WeekOfMonth", schema[12].Name); + Assert.Equal($"{columnPrefix}QuarterOfYear", schema[13].Name); + Assert.Equal($"{columnPrefix}HalfOfYear", schema[14].Name); + Assert.Equal($"{columnPrefix}WeekIso", schema[15].Name); + Assert.Equal($"{columnPrefix}YearIso", schema[16].Name); + Assert.Equal($"{columnPrefix}MonthLabel", schema[17].Name); + Assert.Equal($"{columnPrefix}AmPmLabel", schema[18].Name); + Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[19].Name); + Assert.Equal($"{columnPrefix}HolidayName", schema[20].Name); + Assert.Equal($"{columnPrefix}IsPaidTimeOff", schema[21].Name); + + // Make sure types are correct + Assert.Equal(typeof(int), schema[1].Type.RawType); + Assert.Equal(typeof(byte), schema[2].Type.RawType); + Assert.Equal(typeof(byte), schema[3].Type.RawType); + Assert.Equal(typeof(byte), schema[4].Type.RawType); + Assert.Equal(typeof(byte), schema[5].Type.RawType); + Assert.Equal(typeof(byte), schema[6].Type.RawType); + Assert.Equal(typeof(byte), schema[7].Type.RawType); + Assert.Equal(typeof(byte), schema[8].Type.RawType); + Assert.Equal(typeof(byte), schema[9].Type.RawType); + Assert.Equal(typeof(byte), schema[10].Type.RawType); + Assert.Equal(typeof(ushort), schema[11].Type.RawType); + Assert.Equal(typeof(ushort), schema[12].Type.RawType); + Assert.Equal(typeof(byte), schema[13].Type.RawType); + Assert.Equal(typeof(byte), schema[14].Type.RawType); + Assert.Equal(typeof(byte), schema[15].Type.RawType); + Assert.Equal(typeof(int), schema[16].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[17].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[18].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[19].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[20].Type.RawType); + Assert.Equal(typeof(byte), schema[21].Type.RawType); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void DropOneColumn() + { + // TODO: This will fail until we figure out the C++ dll situation + + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 0 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var columnPrefix = "DTC_"; + var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // Check the schema has 21 columns + Assert.Equal(21, schema.Count); + + // Make sure names with prefix and order are correct + Assert.Equal($"{columnPrefix}Year", schema[1].Name); + Assert.Equal($"{columnPrefix}Month", schema[2].Name); + Assert.Equal($"{columnPrefix}Day", schema[3].Name); + Assert.Equal($"{columnPrefix}Hour", schema[4].Name); + Assert.Equal($"{columnPrefix}Minute", schema[5].Name); + Assert.Equal($"{columnPrefix}Second", schema[6].Name); + Assert.Equal($"{columnPrefix}AmPm", schema[7].Name); + Assert.Equal($"{columnPrefix}Hour12", schema[8].Name); + Assert.Equal($"{columnPrefix}DayOfWeek", schema[9].Name); + Assert.Equal($"{columnPrefix}DayOfQuarter", schema[10].Name); + Assert.Equal($"{columnPrefix}DayOfYear", schema[11].Name); + Assert.Equal($"{columnPrefix}WeekOfMonth", schema[12].Name); + Assert.Equal($"{columnPrefix}QuarterOfYear", schema[13].Name); + Assert.Equal($"{columnPrefix}HalfOfYear", schema[14].Name); + Assert.Equal($"{columnPrefix}WeekIso", schema[15].Name); + Assert.Equal($"{columnPrefix}YearIso", schema[16].Name); + Assert.Equal($"{columnPrefix}MonthLabel", schema[17].Name); + Assert.Equal($"{columnPrefix}AmPmLabel", schema[18].Name); + Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[19].Name); + Assert.Equal($"{columnPrefix}HolidayName", schema[20].Name); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void DropManyColumns() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 0 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var columnPrefix = "DTC_"; + var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff, + DateTimeTransformerEstimator.ColumnsProduced.Day, DateTimeTransformerEstimator.ColumnsProduced.QuarterOfYear, DateTimeTransformerEstimator.ColumnsProduced.AmPm); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // Check the schema has 18 columns + Assert.Equal(18, schema.Count); + + // Make sure names with prefix and order are correct + Assert.Equal($"{columnPrefix}Year", schema[1].Name); + Assert.Equal($"{columnPrefix}Month", schema[2].Name); + Assert.Equal($"{columnPrefix}Hour", schema[3].Name); + Assert.Equal($"{columnPrefix}Minute", schema[4].Name); + Assert.Equal($"{columnPrefix}Second", schema[5].Name); + Assert.Equal($"{columnPrefix}Hour12", schema[6].Name); + Assert.Equal($"{columnPrefix}DayOfWeek", schema[7].Name); + Assert.Equal($"{columnPrefix}DayOfQuarter", schema[8].Name); + Assert.Equal($"{columnPrefix}DayOfYear", schema[9].Name); + Assert.Equal($"{columnPrefix}WeekOfMonth", schema[10].Name); + Assert.Equal($"{columnPrefix}HalfOfYear", schema[11].Name); + Assert.Equal($"{columnPrefix}WeekIso", schema[12].Name); + Assert.Equal($"{columnPrefix}YearIso", schema[13].Name); + Assert.Equal($"{columnPrefix}MonthLabel", schema[14].Name); + Assert.Equal($"{columnPrefix}AmPmLabel", schema[15].Name); + Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[16].Name); + Assert.Equal($"{columnPrefix}HolidayName", schema[17].Name); + + // Make sure types are correct + Assert.Equal(typeof(int), schema[1].Type.RawType); + Assert.Equal(typeof(byte), schema[2].Type.RawType); + Assert.Equal(typeof(byte), schema[3].Type.RawType); + Assert.Equal(typeof(byte), schema[4].Type.RawType); + Assert.Equal(typeof(byte), schema[5].Type.RawType); + Assert.Equal(typeof(byte), schema[6].Type.RawType); + Assert.Equal(typeof(byte), schema[7].Type.RawType); + Assert.Equal(typeof(byte), schema[8].Type.RawType); + Assert.Equal(typeof(ushort), schema[9].Type.RawType); + Assert.Equal(typeof(ushort), schema[10].Type.RawType); + Assert.Equal(typeof(byte), schema[11].Type.RawType); + Assert.Equal(typeof(byte), schema[12].Type.RawType); + Assert.Equal(typeof(int), schema[13].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[14].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[15].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[16].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[17].Type.RawType); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void CanUseDateFromColumn() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void HolidayTest() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 157161600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC", country: DateTimeTransformerEstimator.Countries.Canada); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row for holidays is what we expect + Assert.Equal("Christmas Day", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void ManyRowsTest() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 12341 }, + new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 1234 }, new DateTimeInput() { date = 1751241600 }, + new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 12341 }, + new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 1234 }}; + + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview().RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [Fact] + public void EntryPointTest() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var options = new DateTimeTransformerEstimator.Options + { + ColumnsToDrop = null, + Source = "date", + Prefix = "pref_", + Data = data + }; + + var entryOutput = DateTimeTransformerEntrypoint.DateTimeSplit(mlContext.Transforms.GetEnvironment(), options); + var output = entryOutput.OutputData; + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + Done(); + } + } +} From 1bedb5907f143254e17e05edca4e378f0311140a Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 5 Dec 2019 10:21:31 -0800 Subject: [PATCH 2/8] updates from PR comments --- .../DateTimeTransformerDropColumns.cs | 6 +- .../DateTimeTransformer.cs | 158 ++++++++++++------ .../Transformers/DateTimeTransformerTests.cs | 10 +- 3 files changed, 113 insertions(+), 61 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs index 680b338056..5c85736d2f 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs @@ -28,9 +28,9 @@ public static void Example() // A pipeline for splitting the time features into individual columns // All the columns listed here will be dropped. - var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC", DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff, - DateTimeTransformerEstimator.ColumnsProduced.Day, DateTimeTransformerEstimator.ColumnsProduced.QuarterOfYear, - DateTimeTransformerEstimator.ColumnsProduced.AmPm, DateTimeTransformerEstimator.ColumnsProduced.HolidayName); + var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC", DateTimeEstimator.ColumnsProduced.IsPaidTimeOff, + DateTimeEstimator.ColumnsProduced.Day, DateTimeEstimator.ColumnsProduced.QuarterOfYear, + DateTimeEstimator.ColumnsProduced.AmPm, DateTimeEstimator.ColumnsProduced.HolidayName); // The transformed data. var transformedData = pipeline.Fit(dataview).Transform(dataview); diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs index 2f48bba2b1..fdcae35b0c 100644 --- a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -36,7 +36,7 @@ namespace Microsoft.ML.Featurizers public static class DateTimeTransformerExtensionClass { /// - /// Create a , which splits up the input column specified by + /// Create a , which splits up the input column specified by /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. /// This transformer will append the to all the output columns. If is empty, /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. @@ -45,12 +45,12 @@ public static class DateTimeTransformerExtensionClass /// Input column name /// Prefix to add to the generated columns /// List of columns to drop, if any - /// - public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, params DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop) - => new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop); + /// + public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, params DateTimeEstimator.ColumnsProduced[] columnsToDrop) + => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop); /// - /// Create a , which splits up the input column specified by + /// Create a , which splits up the input column specified by /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. /// This transformer will append the to all the output columns. If is empty, /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. If you specify a country, @@ -61,26 +61,26 @@ public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCa /// Prefix to add to the generated columns /// List of columns to drop, if any /// Country name to get holiday details for - /// - public static DateTimeTransformerEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop = null, DateTimeTransformerEstimator.Countries country = DateTimeTransformerEstimator.Countries.None) - => new DateTimeTransformerEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop, country); + /// + public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeEstimator.ColumnsProduced[] columnsToDrop = null, DateTimeEstimator.HolidayList country = DateTimeEstimator.HolidayList.None) + => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop, country); #region ColumnsProduced static extentions - internal static Type GetRawColumnType(this DateTimeTransformerEstimator.ColumnsProduced column) + internal static Type GetRawColumnType(this DateTimeEstimator.ColumnsProduced column) { switch (column) { - case DateTimeTransformerEstimator.ColumnsProduced.Year: - case DateTimeTransformerEstimator.ColumnsProduced.YearIso: + case DateTimeEstimator.ColumnsProduced.Year: + case DateTimeEstimator.ColumnsProduced.YearIso: return typeof(int); - case DateTimeTransformerEstimator.ColumnsProduced.DayOfYear: - case DateTimeTransformerEstimator.ColumnsProduced.WeekOfMonth: + case DateTimeEstimator.ColumnsProduced.DayOfYear: + case DateTimeEstimator.ColumnsProduced.WeekOfMonth: return typeof(ushort); - case DateTimeTransformerEstimator.ColumnsProduced.MonthLabel: - case DateTimeTransformerEstimator.ColumnsProduced.AmPmLabel: - case DateTimeTransformerEstimator.ColumnsProduced.DayOfWeekLabel: - case DateTimeTransformerEstimator.ColumnsProduced.HolidayName: + case DateTimeEstimator.ColumnsProduced.MonthLabel: + case DateTimeEstimator.ColumnsProduced.AmPmLabel: + case DateTimeEstimator.ColumnsProduced.DayOfWeekLabel: + case DateTimeEstimator.ColumnsProduced.HolidayName: return typeof(ReadOnlyMemory); default: return typeof(byte); @@ -114,9 +114,9 @@ internal static Type GetRawColumnType(this DateTimeTransformerEstimator.ColumnsP /// ]]> /// /// - /// - /// - public sealed class DateTimeTransformerEstimator : IEstimator + /// + /// + public sealed class DateTimeEstimator : IEstimator { private readonly Options _options; @@ -136,12 +136,12 @@ internal sealed class Options : TransformInputBase public ColumnsProduced[] ColumnsToDrop; [Argument(ArgumentType.AtMostOnce, HelpText = "Country to get holidays for. Defaults to none if not passed", Name = "Country", ShortName = "ctry", SortOrder = 4)] - public Countries Country = Countries.None; + public HolidayList Country = HolidayList.None; } #endregion - public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop, Countries country = Countries.None) + internal DateTimeEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop, HolidayList country = HolidayList.None) { Contracts.CheckValue(env, nameof(env)); @@ -157,7 +157,7 @@ public DateTimeTransformerEstimator(IHostEnvironment env, string inputColumnName }; } - internal DateTimeTransformerEstimator(IHostEnvironment env, Options options) + internal DateTimeEstimator(IHostEnvironment env, Options options) { Contracts.CheckValue(env, nameof(env)); @@ -189,18 +189,70 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) #region Enums public enum ColumnsProduced : byte { - Year = 1, Month, Day, Hour, Minute, Second, AmPm, Hour12, DayOfWeek, DayOfQuarter, DayOfYear, - WeekOfMonth, QuarterOfYear, HalfOfYear, WeekIso, YearIso, MonthLabel, AmPmLabel, DayOfWeekLabel, - HolidayName, IsPaidTimeOff + Year = 1, + Month = 2, + Day = 3, + Hour = 4, + Minute = 5, + Second = 6, + AmPm = 7, + Hour12 = 8, + DayOfWeek = 9, + DayOfQuarter = 10, + DayOfYear = 11, + WeekOfMonth = 12, + QuarterOfYear = 13, + HalfOfYear = 14, + WeekIso = 15, + YearIso = 16, + MonthLabel = 17, + AmPmLabel = 18, + DayOfWeekLabel = 19, + HolidayName = 20, + IsPaidTimeOff = 21 }; - public enum Countries : byte + public enum HolidayList : uint { None = 1, - Argentina, Australia, Austria, Belarus, Belgium, Brazil, Canada, Colombia, Croatia, Czech, Denmark, - England, Finland, France, Germany, Hungary, India, Ireland, IsleofMan, Italy, Japan, Mexico, Netherlands, - NewZealand, NorthernIreland, Norway, Poland, Portugal, Scotland, Slovenia, SouthAfrica, Spain, Sweden, Switzerland, - Ukraine, UnitedKingdom, UnitedStates, Wales + Argentina = 2, + Australia = 3, + Austria = 4, + Belarus = 5, + Belgium = 6, + Brazil = 7, + Canada = 8, + Colombia = 9, + Croatia = 10, + Czech = 11, + Denmark = 12, + England = 13, + Finland = 14, + France = 15, + Germany = 16, + Hungary = 17, + India = 18, + Ireland = 19, + IsleofMan = 20, + Italy = 21, + Japan = 22, + Mexico = 23, + Netherlands = 24, + NewZealand = 25, + NorthernIreland = 26, + Norway = 27, + Poland = 28, + Portugal = 29, + Scotland = 30, + Slovenia = 31, + SouthAfrica = 32, + Spain = 33, + Sweden = 34, + Switzerland = 35, + Ukraine = 36, + UnitedKingdom = 37, + UnitedStates = 38, + Wales = 39 } #endregion @@ -215,22 +267,22 @@ public sealed class DateTimeTransformer : RowToRowTransformerBase, IDisposable internal const string ShortName = "DateTimeTransform"; internal const string LoadName = "DateTimeTransform"; internal const string LoaderSignature = "DateTimeTransform"; - private DateTimeTypedColumn _column; + private LongTypedColumn _column; - private DateTimeTransformerEstimator.ColumnsProduced[] _columnsToDrop; + private DateTimeEstimator.ColumnsProduced[] _columnsToDrop; private byte[] _activeColumnMapping; #endregion - public DateTimeTransformer(IHostEnvironment env, string inputColumnName, string columnPrefix, DateTimeTransformerEstimator.ColumnsProduced[] columnsToDrop, DateTimeTransformerEstimator.Countries country) : + internal DateTimeTransformer(IHostEnvironment env, string inputColumnName, string columnPrefix, DateTimeEstimator.ColumnsProduced[] columnsToDrop, DateTimeEstimator.HolidayList country) : base(env.Register(nameof(DateTimeTransformer))) { _columnsToDrop = columnsToDrop; - var activeColumnLength = Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)).Length - (_columnsToDrop == null ? 0 : _columnsToDrop.Length); + var activeColumnLength = Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced)).Length - (_columnsToDrop == null ? 0 : _columnsToDrop.Length); _activeColumnMapping = new byte[activeColumnLength]; var index = 0; - foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + foreach (DateTimeEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced))) { if (_columnsToDrop == null || !_columnsToDrop.Contains(column)) { @@ -238,7 +290,7 @@ public DateTimeTransformer(IHostEnvironment env, string inputColumnName, string } } - _column = new DateTimeTypedColumn(inputColumnName, columnPrefix); + _column = new LongTypedColumn(inputColumnName, columnPrefix); _column.CreateTransformerFromEstimator(country); } @@ -257,20 +309,20 @@ internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx) : // length of C++ state array // C++ byte state array - _column = new DateTimeTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString()); + _column = new LongTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString()); var dropColumnsLength = ctx.Reader.ReadInt32(); if (dropColumnsLength > 0) { - _columnsToDrop = new DateTimeTransformerEstimator.ColumnsProduced[dropColumnsLength]; + _columnsToDrop = new DateTimeEstimator.ColumnsProduced[dropColumnsLength]; //read in enum bytes for (int i = 0; i < dropColumnsLength; i++) - _columnsToDrop[i] = (DateTimeTransformerEstimator.ColumnsProduced)ctx.Reader.ReadByte(); + _columnsToDrop[i] = (DateTimeEstimator.ColumnsProduced)ctx.Reader.ReadByte(); } - _activeColumnMapping = new byte[Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced)).Length - dropColumnsLength]; + _activeColumnMapping = new byte[Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced)).Length - dropColumnsLength]; var index = 0; - foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + foreach (DateTimeEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced))) { if (_columnsToDrop == null || !_columnsToDrop.Contains(column)) { @@ -462,7 +514,7 @@ internal TypedColumn(string source, string prefix) Prefix = prefix; } - internal abstract void CreateTransformerFromEstimator(DateTimeTransformerEstimator.Countries country); + internal abstract void CreateTransformerFromEstimator(DateTimeEstimator.HolidayList country); private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize); private protected unsafe abstract bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle); private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); @@ -471,19 +523,19 @@ internal TypedColumn(string source, string prefix) private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); public abstract void Dispose(); - private protected unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(DateTimeTransformerEstimator.Countries country) + private protected unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(DateTimeEstimator.HolidayList country) { bool success; IntPtr errorHandle; IntPtr estimator; - if (country == DateTimeTransformerEstimator.Countries.None) + if (country == DateTimeEstimator.HolidayList.None) { success = CreateEstimatorHelper(null, null, out estimator, out errorHandle); } else { fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue)) - fixed (byte* countryPointer = Encoding.UTF8.GetBytes(Enum.GetName(typeof(DateTimeTransformerEstimator.Countries), country) + char.MinValue)) + fixed (byte* countryPointer = Encoding.UTF8.GetBytes(Enum.GetName(typeof(DateTimeEstimator.HolidayList), country) + char.MinValue)) { success = CreateEstimatorHelper(countryPointer, dataRootDir, out estimator, out errorHandle); } @@ -546,10 +598,10 @@ internal TypedColumn(string source, string prefix) : #region DateTimeTypedColumn - internal sealed class DateTimeTypedColumn : TypedColumn + internal sealed class LongTypedColumn : TypedColumn { private TransformerEstimatorSafeHandle _transformerHandler; - internal DateTimeTypedColumn(string source, string prefix) : + internal LongTypedColumn(string source, string prefix) : base(source, prefix) { } @@ -564,7 +616,7 @@ internal DateTimeTypedColumn(string source, string prefix) : private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformer"), SuppressUnmanagedCodeSecurity] private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); - internal override unsafe void CreateTransformerFromEstimator(DateTimeTransformerEstimator.Countries country) + internal override unsafe void CreateTransformerFromEstimator(DateTimeEstimator.HolidayList country) { _transformerHandler = CreateTransformerFromEstimatorBase(country); } @@ -651,7 +703,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() { var columns = new List(); - foreach (DateTimeTransformerEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeTransformerEstimator.ColumnsProduced))) + foreach (DateTimeEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced))) if (_parent._columnsToDrop == null || !_parent._columnsToDrop.Contains(column)) { columns.Add(new DataViewSchema.DetachedColumn(_parent._column.Prefix + column.ToString(), @@ -741,7 +793,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func, ((DateTimeTransformerEstimator.ColumnsProduced)outputColumn).GetRawColumnType(), input, outputColumn - 1); + return Utils.MarshalInvoke(MakeGetter, ((DateTimeEstimator.ColumnsProduced)outputColumn).GetRawColumnType(), input, outputColumn - 1); } private protected override Func GetDependenciesCore(Func activeOutput) @@ -768,10 +820,10 @@ internal static class DateTimeTransformerEntrypoint Desc = DateTimeTransformer.Summary, UserName = DateTimeTransformer.UserName, ShortName = DateTimeTransformer.ShortName)] - public static CommonOutputs.TransformOutput DateTimeSplit(IHostEnvironment env, DateTimeTransformerEstimator.Options input) + public static CommonOutputs.TransformOutput DateTimeSplit(IHostEnvironment env, DateTimeEstimator.Options input) { var h = EntryPointUtils.CheckArgsAndCreateHost(env, DateTimeTransformer.ShortName, input); - var xf = new DateTimeTransformerEstimator(h, input).Fit(input.Data).Transform(input.Data); + var xf = new DateTimeEstimator(h, input).Fit(input.Data).Transform(input.Data); return new CommonOutputs.TransformOutput() { Model = new TransformModelImpl(h, xf, input.Data), diff --git a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs index 6514ecd90b..b8d234c58e 100644 --- a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs @@ -100,7 +100,7 @@ public void DropOneColumn() // Build the pipeline, fit, and transform it. var columnPrefix = "DTC_"; - var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff); + var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeEstimator.ColumnsProduced.IsPaidTimeOff); var model = pipeline.Fit(data); var output = model.Transform(data); var schema = output.Schema; @@ -143,8 +143,8 @@ public void DropManyColumns() // Build the pipeline, fit, and transform it. var columnPrefix = "DTC_"; - var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeTransformerEstimator.ColumnsProduced.IsPaidTimeOff, - DateTimeTransformerEstimator.ColumnsProduced.Day, DateTimeTransformerEstimator.ColumnsProduced.QuarterOfYear, DateTimeTransformerEstimator.ColumnsProduced.AmPm); + var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeEstimator.ColumnsProduced.IsPaidTimeOff, + DateTimeEstimator.ColumnsProduced.Day, DateTimeEstimator.ColumnsProduced.QuarterOfYear, DateTimeEstimator.ColumnsProduced.AmPm); var model = pipeline.Fit(data); var output = model.Transform(data); var schema = output.Schema; @@ -246,7 +246,7 @@ public void HolidayTest() var data = mlContext.Data.LoadFromEnumerable(dataList); // Build the pipeline, fit, and transform it. - var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC", country: DateTimeTransformerEstimator.Countries.Canada); + var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC", country: DateTimeEstimator.HolidayList.Canada); var model = pipeline.Fit(data); var output = model.Transform(data); @@ -317,7 +317,7 @@ public void EntryPointTest() var data = mlContext.Data.LoadFromEnumerable(dataList); // Build the pipeline, fit, and transform it. - var options = new DateTimeTransformerEstimator.Options + var options = new DateTimeEstimator.Options { ColumnsToDrop = null, Source = "date", From bee2616e08b964310595b63be35153d060db3e08 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 5 Dec 2019 11:10:11 -0800 Subject: [PATCH 3/8] removed ability to drop columns from DateTimeTransformer --- .../DateTimeTransformerDropColumns.cs | 80 ------------- .../DateTimeTransformer.cs | 101 ++++------------- .../Transformers/DateTimeTransformerTests.cs | 106 ------------------ 3 files changed, 21 insertions(+), 266 deletions(-) delete mode 100644 docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs deleted file mode 100644 index 5c85736d2f..0000000000 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformerDropColumns.cs +++ /dev/null @@ -1,80 +0,0 @@ -using System; -using System.Collections.Generic; -using Microsoft.ML; -using Microsoft.ML.Data; -using Microsoft.ML.Featurizers; - -namespace Samples.Dynamic -{ - public static class DateTimeTransformerDropColumns - { - private class DateTimeInput - { - public long Date; - } - - public static void Example() - { - // Create a new ML context, for ML.NET operations. It can be used for - // exception tracking and logging, as well as the source of randomness. - var mlContext = new MLContext(); - - // Create a small dataset as an IEnumerable. - // Future Date - 2025 June 30 - var samples = new[] { new DateTimeInput() { Date = 1751241600 } }; - - // Convert training data to IDataView. - var dataview = mlContext.Data.LoadFromEnumerable(samples); - - // A pipeline for splitting the time features into individual columns - // All the columns listed here will be dropped. - var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC", DateTimeEstimator.ColumnsProduced.IsPaidTimeOff, - DateTimeEstimator.ColumnsProduced.Day, DateTimeEstimator.ColumnsProduced.QuarterOfYear, - DateTimeEstimator.ColumnsProduced.AmPm, DateTimeEstimator.ColumnsProduced.HolidayName); - - // The transformed data. - var transformedData = pipeline.Fit(dataview).Transform(dataview); - - // Now let's take a look at what this did. We should have created 16 more columns with all the - // DateTime information split into its own columns - var featuresColumn = mlContext.Data.CreateEnumerable( - transformedData, reuseRowObject: false); - - // And we can write out a few rows - Console.WriteLine($"Features column obtained post-transformation."); - foreach (var featureRow in featuresColumn) - Console.WriteLine(featureRow.Date + ", " + featureRow.DTCYear + ", " + featureRow.DTCMonth + ", " + - featureRow.DTCHour + ", " + featureRow.DTCMinute + ", " + featureRow.DTCSecond + ", " + - featureRow.DTCHour12 + ", " + featureRow.DTCDayOfWeek + ", " + featureRow.DTCDayOfQuarter + ", " + - featureRow.DTCDayOfYear + ", " + featureRow.DTCWeekOfMonth + ", " + featureRow.DTCHalfOfYear + - ", " + featureRow.DTCWeekIso + ", " + featureRow.DTCYearIso + ", " + featureRow.DTCMonthLabel + ", " + - featureRow.DTCAmPmLabel + ", " + featureRow.DTCDayOfWeekLabel); - - // Expected output: - // Features columns obtained post-transformation. - // 1751241600, 2025, 6, 30, 0, 0, 0, 0, 0, 1, 91, 180, 4, 2, 1, 27, 2025, June, am, Monday - } - - // These columns start with DTC because that is the prefix we picked - private sealed class TransformedData - { - public long Date { get; set; } - public int DTCYear { get; set; } - public byte DTCMonth { get; set; } - public byte DTCHour { get; set; } - public byte DTCMinute { get; set; } - public byte DTCSecond { get; set; } - public byte DTCHour12 { get; set; } - public byte DTCDayOfWeek { get; set; } - public byte DTCDayOfQuarter { get; set; } - public ushort DTCDayOfYear { get; set; } - public ushort DTCWeekOfMonth { get; set; } - public byte DTCHalfOfYear { get; set; } - public byte DTCWeekIso { get; set; } - public int DTCYearIso { get; set; } - public string DTCMonthLabel { get; set; } - public string DTCAmPmLabel { get; set; } - public string DTCDayOfWeekLabel { get; set; } - } - } -} diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs index fdcae35b0c..790b1941b0 100644 --- a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -38,32 +38,28 @@ public static class DateTimeTransformerExtensionClass /// /// Create a , which splits up the input column specified by /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. - /// This transformer will append the to all the output columns. If is empty, - /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. + /// This transformer will append the to all the output columns. /// /// Transform catalog /// Input column name /// Prefix to add to the generated columns - /// List of columns to drop, if any /// - public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, params DateTimeEstimator.ColumnsProduced[] columnsToDrop) - => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop); + public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix) + => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix); /// /// Create a , which splits up the input column specified by /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. - /// This transformer will append the to all the output columns. If is empty, - /// then all the columns are returned. Otherwise, the columns listed in the array will be dropped from the return value. If you specify a country, + /// This transformer will append the to all the output columns. If you specify a country, /// Holiday details will be looked up for that country as well. /// /// Transform catalog /// Input column name /// Prefix to add to the generated columns - /// List of columns to drop, if any /// Country name to get holiday details for /// - public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeEstimator.ColumnsProduced[] columnsToDrop = null, DateTimeEstimator.HolidayList country = DateTimeEstimator.HolidayList.None) - => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, columnsToDrop, country); + public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country = DateTimeEstimator.HolidayList.None) + => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, country); #region ColumnsProduced static extentions @@ -114,8 +110,8 @@ internal static Type GetRawColumnType(this DateTimeEstimator.ColumnsProduced col /// ]]> /// /// - /// - /// + /// + /// public sealed class DateTimeEstimator : IEstimator { private readonly Options _options; @@ -132,16 +128,13 @@ internal sealed class Options : TransformInputBase [Argument(ArgumentType.Required, HelpText = "Output column prefix", Name = "Prefix", ShortName = "pre", SortOrder = 2)] public string Prefix; - [Argument(ArgumentType.MultipleUnique, HelpText = "Columns to drop after the DateTime Expansion", Name = "ColumnsToDrop", ShortName = "drop", SortOrder = 3)] - public ColumnsProduced[] ColumnsToDrop; - [Argument(ArgumentType.AtMostOnce, HelpText = "Country to get holidays for. Defaults to none if not passed", Name = "Country", ShortName = "ctry", SortOrder = 4)] public HolidayList Country = HolidayList.None; } #endregion - internal DateTimeEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, ColumnsProduced[] columnsToDrop, HolidayList country = HolidayList.None) + internal DateTimeEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, HolidayList country = HolidayList.None) { Contracts.CheckValue(env, nameof(env)); @@ -152,7 +145,6 @@ internal DateTimeEstimator(IHostEnvironment env, string inputColumnName, string { Source = inputColumnName, Prefix = columnPrefix, - ColumnsToDrop = columnsToDrop == null ? Array.Empty() : columnsToDrop, Country = country }; } @@ -164,12 +156,11 @@ internal DateTimeEstimator(IHostEnvironment env, Options options) _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); _options = options; - _options.ColumnsToDrop = _options.ColumnsToDrop == null ? Array.Empty() : _options.ColumnsToDrop; } public DateTimeTransformer Fit(IDataView input) { - return new DateTimeTransformer(_host, _options.Source, _options.Prefix, _options.ColumnsToDrop, _options.Country); + return new DateTimeTransformer(_host, _options.Source, _options.Prefix, _options.Country); } public SchemaShape GetOutputSchema(SchemaShape inputSchema) @@ -177,11 +168,10 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var columns = inputSchema.ToDictionary(x => x.Name); foreach (ColumnsProduced column in Enum.GetValues(typeof(ColumnsProduced))) - if (_options.ColumnsToDrop == null || !_options.ColumnsToDrop.Contains(column)) - { - columns[_options.Prefix + column.ToString()] = new SchemaShape.Column(_options.Prefix + column.ToString(), SchemaShape.Column.VectorKind.Scalar, - ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()), false, null); - } + { + columns[_options.Prefix + column.ToString()] = new SchemaShape.Column(_options.Prefix + column.ToString(), SchemaShape.Column.VectorKind.Scalar, + ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()), false, null); + } return new SchemaShape(columns.Values); } @@ -269,27 +259,11 @@ public sealed class DateTimeTransformer : RowToRowTransformerBase, IDisposable internal const string LoaderSignature = "DateTimeTransform"; private LongTypedColumn _column; - private DateTimeEstimator.ColumnsProduced[] _columnsToDrop; - private byte[] _activeColumnMapping; - #endregion - internal DateTimeTransformer(IHostEnvironment env, string inputColumnName, string columnPrefix, DateTimeEstimator.ColumnsProduced[] columnsToDrop, DateTimeEstimator.HolidayList country) : + internal DateTimeTransformer(IHostEnvironment env, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country) : base(env.Register(nameof(DateTimeTransformer))) { - - _columnsToDrop = columnsToDrop; - var activeColumnLength = Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced)).Length - (_columnsToDrop == null ? 0 : _columnsToDrop.Length); - _activeColumnMapping = new byte[activeColumnLength]; - var index = 0; - foreach (DateTimeEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced))) - { - if (_columnsToDrop == null || !_columnsToDrop.Contains(column)) - { - _activeColumnMapping[index++] = (byte)column; - } - } - _column = new LongTypedColumn(inputColumnName, columnPrefix); _column.CreateTransformerFromEstimator(country); } @@ -304,32 +278,11 @@ internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx) : // *** Binary format *** // name of input column // column prefix - // byte length of columns to drop array - // byte array of columns to drop // length of C++ state array // C++ byte state array _column = new LongTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString()); - var dropColumnsLength = ctx.Reader.ReadInt32(); - if (dropColumnsLength > 0) - { - _columnsToDrop = new DateTimeEstimator.ColumnsProduced[dropColumnsLength]; - //read in enum bytes - for (int i = 0; i < dropColumnsLength; i++) - _columnsToDrop[i] = (DateTimeEstimator.ColumnsProduced)ctx.Reader.ReadByte(); - } - - _activeColumnMapping = new byte[Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced)).Length - dropColumnsLength]; - var index = 0; - foreach (DateTimeEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced))) - { - if (_columnsToDrop == null || !_columnsToDrop.Contains(column)) - { - _activeColumnMapping[index++] = (byte)column; - } - } - var dataLength = ctx.Reader.ReadInt32(); var data = ctx.Reader.ReadByteArray(dataLength); _column.CreateTransformerFromSavedData(data); @@ -362,21 +315,12 @@ private protected override void SaveModel(ModelSaveContext ctx) // *** Binary format *** // name of input column // column prefix - // byte length of columns to drop array - // byte array of columns to drop // length of C++ state array // C++ byte state array ctx.Writer.Write(_column.Source); ctx.Writer.Write(_column.Prefix); - ctx.Writer.Write(_columnsToDrop == null ? 0 : _columnsToDrop.Length); - if (_columnsToDrop != null) - { - foreach (var toDrop in _columnsToDrop) - ctx.Writer.Write((byte)toDrop); - } - var data = _column.CreateTransformerSaveData(); ctx.Writer.Write(data.Length); ctx.Writer.Write(data); @@ -704,11 +648,10 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() var columns = new List(); foreach (DateTimeEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced))) - if (_parent._columnsToDrop == null || !_parent._columnsToDrop.Contains(column)) - { - columns.Add(new DataViewSchema.DetachedColumn(_parent._column.Prefix + column.ToString(), - ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()))); - } + { + columns.Add(new DataViewSchema.DetachedColumn(_parent._column.Prefix + column.ToString(), + ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()))); + } return columns.ToArray(); } @@ -790,10 +733,8 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func, ((DateTimeEstimator.ColumnsProduced)outputColumn).GetRawColumnType(), input, outputColumn - 1); + // Have to add 1 to iinfo since the enum starts at 1 + return Utils.MarshalInvoke(MakeGetter, ((DateTimeEstimator.ColumnsProduced)iinfo + 1).GetRawColumnType(), input, iinfo); } private protected override Func GetDependenciesCore(Func activeOutput) diff --git a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs index b8d234c58e..79505c55f7 100644 --- a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs @@ -89,111 +89,6 @@ public void CorrectNumberOfColumnsAndSchema() Done(); } - [Fact] - public void DropOneColumn() - { - // TODO: This will fail until we figure out the C++ dll situation - - MLContext mlContext = new MLContext(1); - var dataList = new[] { new DateTimeInput() { date = 0 } }; - var data = mlContext.Data.LoadFromEnumerable(dataList); - - // Build the pipeline, fit, and transform it. - var columnPrefix = "DTC_"; - var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeEstimator.ColumnsProduced.IsPaidTimeOff); - var model = pipeline.Fit(data); - var output = model.Transform(data); - var schema = output.Schema; - - // Check the schema has 21 columns - Assert.Equal(21, schema.Count); - - // Make sure names with prefix and order are correct - Assert.Equal($"{columnPrefix}Year", schema[1].Name); - Assert.Equal($"{columnPrefix}Month", schema[2].Name); - Assert.Equal($"{columnPrefix}Day", schema[3].Name); - Assert.Equal($"{columnPrefix}Hour", schema[4].Name); - Assert.Equal($"{columnPrefix}Minute", schema[5].Name); - Assert.Equal($"{columnPrefix}Second", schema[6].Name); - Assert.Equal($"{columnPrefix}AmPm", schema[7].Name); - Assert.Equal($"{columnPrefix}Hour12", schema[8].Name); - Assert.Equal($"{columnPrefix}DayOfWeek", schema[9].Name); - Assert.Equal($"{columnPrefix}DayOfQuarter", schema[10].Name); - Assert.Equal($"{columnPrefix}DayOfYear", schema[11].Name); - Assert.Equal($"{columnPrefix}WeekOfMonth", schema[12].Name); - Assert.Equal($"{columnPrefix}QuarterOfYear", schema[13].Name); - Assert.Equal($"{columnPrefix}HalfOfYear", schema[14].Name); - Assert.Equal($"{columnPrefix}WeekIso", schema[15].Name); - Assert.Equal($"{columnPrefix}YearIso", schema[16].Name); - Assert.Equal($"{columnPrefix}MonthLabel", schema[17].Name); - Assert.Equal($"{columnPrefix}AmPmLabel", schema[18].Name); - Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[19].Name); - Assert.Equal($"{columnPrefix}HolidayName", schema[20].Name); - - TestEstimatorCore(pipeline, data); - Done(); - } - - [Fact] - public void DropManyColumns() - { - MLContext mlContext = new MLContext(1); - var dataList = new[] { new DateTimeInput() { date = 0 } }; - var data = mlContext.Data.LoadFromEnumerable(dataList); - - // Build the pipeline, fit, and transform it. - var columnPrefix = "DTC_"; - var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix, DateTimeEstimator.ColumnsProduced.IsPaidTimeOff, - DateTimeEstimator.ColumnsProduced.Day, DateTimeEstimator.ColumnsProduced.QuarterOfYear, DateTimeEstimator.ColumnsProduced.AmPm); - var model = pipeline.Fit(data); - var output = model.Transform(data); - var schema = output.Schema; - - // Check the schema has 18 columns - Assert.Equal(18, schema.Count); - - // Make sure names with prefix and order are correct - Assert.Equal($"{columnPrefix}Year", schema[1].Name); - Assert.Equal($"{columnPrefix}Month", schema[2].Name); - Assert.Equal($"{columnPrefix}Hour", schema[3].Name); - Assert.Equal($"{columnPrefix}Minute", schema[4].Name); - Assert.Equal($"{columnPrefix}Second", schema[5].Name); - Assert.Equal($"{columnPrefix}Hour12", schema[6].Name); - Assert.Equal($"{columnPrefix}DayOfWeek", schema[7].Name); - Assert.Equal($"{columnPrefix}DayOfQuarter", schema[8].Name); - Assert.Equal($"{columnPrefix}DayOfYear", schema[9].Name); - Assert.Equal($"{columnPrefix}WeekOfMonth", schema[10].Name); - Assert.Equal($"{columnPrefix}HalfOfYear", schema[11].Name); - Assert.Equal($"{columnPrefix}WeekIso", schema[12].Name); - Assert.Equal($"{columnPrefix}YearIso", schema[13].Name); - Assert.Equal($"{columnPrefix}MonthLabel", schema[14].Name); - Assert.Equal($"{columnPrefix}AmPmLabel", schema[15].Name); - Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[16].Name); - Assert.Equal($"{columnPrefix}HolidayName", schema[17].Name); - - // Make sure types are correct - Assert.Equal(typeof(int), schema[1].Type.RawType); - Assert.Equal(typeof(byte), schema[2].Type.RawType); - Assert.Equal(typeof(byte), schema[3].Type.RawType); - Assert.Equal(typeof(byte), schema[4].Type.RawType); - Assert.Equal(typeof(byte), schema[5].Type.RawType); - Assert.Equal(typeof(byte), schema[6].Type.RawType); - Assert.Equal(typeof(byte), schema[7].Type.RawType); - Assert.Equal(typeof(byte), schema[8].Type.RawType); - Assert.Equal(typeof(ushort), schema[9].Type.RawType); - Assert.Equal(typeof(ushort), schema[10].Type.RawType); - Assert.Equal(typeof(byte), schema[11].Type.RawType); - Assert.Equal(typeof(byte), schema[12].Type.RawType); - Assert.Equal(typeof(int), schema[13].Type.RawType); - Assert.Equal(typeof(ReadOnlyMemory), schema[14].Type.RawType); - Assert.Equal(typeof(ReadOnlyMemory), schema[15].Type.RawType); - Assert.Equal(typeof(ReadOnlyMemory), schema[16].Type.RawType); - Assert.Equal(typeof(ReadOnlyMemory), schema[17].Type.RawType); - - TestEstimatorCore(pipeline, data); - Done(); - } - [Fact] public void CanUseDateFromColumn() { @@ -319,7 +214,6 @@ public void EntryPointTest() // Build the pipeline, fit, and transform it. var options = new DateTimeEstimator.Options { - ColumnsToDrop = null, Source = "date", Prefix = "pref_", Data = data From 0d10ae7d48134891c98eb5563155b3b8b2b44bb1 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 5 Dec 2019 15:21:45 -0800 Subject: [PATCH 4/8] changes from PR comments --- .../DateTimeTransformer.cs | 100 ++++++++++-------- .../Common/EntryPoints/core_ep-list.tsv | 2 +- .../Common/EntryPoints/core_manifest.json | 40 ------- 3 files changed, 59 insertions(+), 83 deletions(-) diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs index 790b1941b0..5440556917 100644 --- a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -379,62 +379,71 @@ internal struct TimePoint public string HolidayName; public byte IsPaidTimeOff; - internal unsafe TimePoint(byte* rawData) + internal TimePoint(ReadOnlySpan rawData, int intPtrSize) { - int intPtrSize = sizeof(IntPtr); - - Year = *(int*)rawData; - rawData += 4; - - Month = *rawData++; - Day = *rawData++; - Hour = *rawData++; - Minute = *rawData++; - Second = *rawData++; - AmPm = *rawData++; - Hour12 = *rawData++; - DayOfWeek = *rawData++; - DayOfQuarter = *rawData++; - DayOfYear = *(ushort*)rawData; - rawData += 2; - - WeekOfMonth = *(ushort*)rawData; - rawData += 2; - - QuarterOfYear = *rawData++; - HalfOfYear = *rawData++; - WeekIso = *rawData++; - YearIso = *(int*)rawData; - rawData += 4; + + int index = 0; + + Year = MemoryMarshal.Read(rawData); + index += 4; + + Month = rawData[index++]; + Day = rawData[index++]; + Hour = rawData[index++]; + Minute = rawData[index++]; + Second = rawData[index++]; + AmPm = rawData[index++]; + Hour12 = rawData[index++]; + DayOfWeek = rawData[index++]; + DayOfQuarter = rawData[index++]; + DayOfYear = MemoryMarshal.Read(rawData.Slice(index)); + index += 2; + + WeekOfMonth = MemoryMarshal.Read(rawData.Slice(index)); + index += 2; + + QuarterOfYear = rawData[index++]; + HalfOfYear = rawData[index++]; + WeekIso = rawData[index++]; + YearIso = MemoryMarshal.Read(rawData.Slice(index)); + index += 4; // Convert char * to string - MonthLabel = GetStringFromPointer(ref rawData, intPtrSize); - AmPmLabel = GetStringFromPointer(ref rawData, intPtrSize); - DayOfWeekLabel = GetStringFromPointer(ref rawData, intPtrSize); - HolidayName = GetStringFromPointer(ref rawData, intPtrSize); - IsPaidTimeOff = *rawData; + MonthLabel = GetStringFromPointer(ref rawData, ref index, intPtrSize); + AmPmLabel = GetStringFromPointer(ref rawData, ref index, intPtrSize); + DayOfWeekLabel = GetStringFromPointer(ref rawData, ref index, intPtrSize); + HolidayName = GetStringFromPointer(ref rawData, ref index, intPtrSize); + IsPaidTimeOff = rawData[index]; } // Converts a pointer to a native char* to a string and increments pointer by to the next value. // The length of the string is stored at byte* + sizeof(IntPtr). - private static unsafe string GetStringFromPointer(ref byte* rawData, int intPtrSize) + private static unsafe string GetStringFromPointer(ref ReadOnlySpan rawData, ref int index, int intPtrSize) { - byte[] buffer; - if (intPtrSize == 4) // 32 bit machine - buffer = new byte[*(uint*)(rawData + intPtrSize)]; + ulong stringLength; + ReadOnlySpan buffer; + if (intPtrSize == 4) // 32 bit machine + { + stringLength = MemoryMarshal.Read(rawData.Slice(index + intPtrSize)); + IntPtr stringData = new IntPtr(MemoryMarshal.Read(rawData.Slice(index))); + buffer = new ReadOnlySpan(stringData.ToPointer(), (int)stringLength); + } else // 64 bit machine - buffer = new byte[*(ulong*)(rawData + intPtrSize)]; + { + stringLength = MemoryMarshal.Read(rawData.Slice(index + intPtrSize)); + IntPtr stringData = new IntPtr(MemoryMarshal.Read(rawData.Slice(index))); + buffer = new ReadOnlySpan(stringData.ToPointer(), (int)stringLength); + } - if (buffer.Length == 0) + if (stringLength == 0) { - rawData += intPtrSize * 2; + index += intPtrSize * 2; return string.Empty; } - Marshal.Copy(new IntPtr(*(int**)rawData), buffer, 0, buffer.Length); - rawData += intPtrSize * 2; + index += intPtrSize * 2; - return Encoding.UTF8.GetString(buffer); + return Encoding.UTF8.GetString(buffer.ToArray()); } }; @@ -545,9 +554,15 @@ internal TypedColumn(string source, string prefix) : internal sealed class LongTypedColumn : TypedColumn { private TransformerEstimatorSafeHandle _transformerHandler; + private readonly int _intPtrSize; + private readonly int _structSize; internal LongTypedColumn(string source, string prefix) : base(source, prefix) { + _intPtrSize = IntPtr.Size; + + // The native struct is 25 bytes + 8 size_t. + _structSize = 25 + (_intPtrSize * 8); } [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateEstimator"), SuppressUnmanagedCodeSecurity] @@ -591,9 +606,10 @@ internal override TimePoint Transform(long input) using (var handler = new TransformedDataSafeHandle(output, DestroyTransformedDataNative)) { + // 29 plus size. unsafe { - return new TimePoint((byte*)output.ToPointer()); + return new TimePoint(new ReadOnlySpan(output.ToPointer(), _structSize), _intPtrSize); } } } diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 7125ed4188..8b242c5f31 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -85,7 +85,7 @@ Transforms.CombinerByContiguousGroupId Groups values of a scalar column into a v Transforms.ConditionalNormalizer Normalize the columns only if needed Microsoft.ML.Data.Normalize IfNeeded Microsoft.ML.Transforms.NormalizeTransform+MinMaxArguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput] Transforms.DatasetScorer Score a dataset with a predictor model Microsoft.ML.EntryPoints.ScoreModel Score Microsoft.ML.EntryPoints.ScoreModel+Input Microsoft.ML.EntryPoints.ScoreModel+Output Transforms.DatasetTransformScorer Score a dataset with a transform model Microsoft.ML.EntryPoints.ScoreModel ScoreUsingTransform Microsoft.ML.EntryPoints.ScoreModel+InputTransformScorer Microsoft.ML.EntryPoints.ScoreModel+Output -Transforms.DateTimeSplitter Splits a date time value into each individual component Microsoft.ML.Featurizers.DateTimeTransformerEntrypoint DateTimeSplit Microsoft.ML.Featurizers.DateTimeTransformerEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput +Transforms.DateTimeSplitter Splits a date time value into each individual component Microsoft.ML.Featurizers.DateTimeTransformerEntrypoint DateTimeSplit Microsoft.ML.Featurizers.DateTimeEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.Dictionarizer Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Text.TextAnalytics TermTransform Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureCombiner Combines all the features into one feature column. Microsoft.ML.EntryPoints.FeatureCombiner PrepareFeatures Microsoft.ML.EntryPoints.FeatureCombiner+FeatureCombinerInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureContributionCalculationTransformer For each data point, calculates the contribution of individual features to the model prediction. Microsoft.ML.Transforms.FeatureContributionEntryPoint FeatureContributionCalculation Microsoft.ML.Transforms.FeatureContributionCalculatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index ac968830dc..7f768d1926 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -18113,46 +18113,6 @@ "SortOrder": 2.0, "IsNullable": false }, - { - "Name": "ColumnsToDrop", - "Type": { - "Kind": "Array", - "ItemType": { - "Kind": "Enum", - "Values": [ - "Year", - "Month", - "Day", - "Hour", - "Minute", - "Second", - "AmPm", - "Hour12", - "DayOfWeek", - "DayOfQuarter", - "DayOfYear", - "WeekOfMonth", - "QuarterOfYear", - "HalfOfYear", - "WeekIso", - "YearIso", - "MonthLabel", - "AmPmLabel", - "DayOfWeekLabel", - "HolidayName", - "IsPaidTimeOff" - ] - } - }, - "Desc": "Columns to drop after the DateTime Expansion", - "Aliases": [ - "drop" - ], - "Required": false, - "SortOrder": 3.0, - "IsNullable": false, - "Default": null - }, { "Name": "Country", "Type": { From 6d5625b89ee02c31b1ba43e2ffde15504f35ab2a Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 9 Dec 2019 11:47:25 -0800 Subject: [PATCH 5/8] updates from PR comments --- .../DateTimeTransformer.cs | 21 +++++++++++-------- .../Microsoft.ML.Featurizers.csproj | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs index 5440556917..2e1646be9c 100644 --- a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -442,15 +442,18 @@ private static unsafe string GetStringFromPointer(ref ReadOnlySpan rawData } index += intPtrSize * 2; - +#if NETSTANDARD2_0 return Encoding.UTF8.GetString(buffer.ToArray()); +#else + return Encoding.UTF8.GetString(buffer); +#endif } }; - #endregion +#endregion - #region BaseClass +#region BaseClass internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle); internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle); @@ -547,9 +550,9 @@ internal TypedColumn(string source, string prefix) : } - #endregion +#endregion - #region DateTimeTypedColumn +#region DateTimeTypedColumn internal sealed class LongTypedColumn : TypedColumn { @@ -638,18 +641,18 @@ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffe CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); } - #endregion +#endregion private sealed class Mapper : MapperBase { - #region Class data members +#region Class data members private readonly DateTimeTransformer _parent; private ConcurrentDictionary _cache; private ConcurrentQueue _oldestKeys; - #endregion +#endregion public Mapper(DateTimeTransformer parent, DataViewSchema inputSchema) : base(parent.Host.Register(nameof(Mapper)), inputSchema, parent) @@ -674,10 +677,10 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() private Delegate MakeGetter(DataViewRow input, int iinfo) { + var getter = input.GetGetter(input.Schema[_parent._column.Source]); ValueGetter result = (ref T dst) => { long dateTime = default; - var getter = input.GetGetter(input.Schema[_parent._column.Source]); getter(ref dateTime); if (!_cache.TryGetValue(dateTime, out TimePoint timePoint)) diff --git a/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj index f3a35c5d85..e80548611c 100644 --- a/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj +++ b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj @@ -1,7 +1,7 @@  - netstandard2.0 + netstandard2.0;netcoreapp2.1 Microsoft.ML.Featurizers true From 02f808e6f8041b4a70a24c09662d76c98ff04394 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Wed, 11 Dec 2019 12:54:15 -0800 Subject: [PATCH 6/8] updates to fix build issues --- src/Microsoft.ML.Featurizers/DateTimeTransformer.cs | 4 ++-- src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs index 2e1646be9c..ecff6b1d71 100644 --- a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -490,7 +490,7 @@ private protected unsafe TransformerEstimatorSafeHandle CreateTransformerFromEst } else { - fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue)) + fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(AppDomain.CurrentDomain.BaseDirectory + char.MinValue)) fixed (byte* countryPointer = Encoding.UTF8.GetBytes(Enum.GetName(typeof(DateTimeEstimator.HolidayList), country) + char.MinValue)) { success = CreateEstimatorHelper(countryPointer, dataRootDir, out estimator, out errorHandle); @@ -587,7 +587,7 @@ internal override unsafe void CreateTransformerFromEstimator(DateTimeEstimator.H private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, byte* dataRootDir, out IntPtr transformer, out IntPtr errorHandle); private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) { - fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location) + char.MinValue)) + fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(AppDomain.CurrentDomain.BaseDirectory + char.MinValue)) { var result = CreateTransformerFromSavedDataNative(rawData, dataSize, dataRootDir, out IntPtr transformer, out IntPtr errorHandle); if (!result) diff --git a/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj index e80548611c..a221b11509 100644 --- a/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj +++ b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj @@ -7,7 +7,7 @@ - + From de2fc126901dd1b4dc7aabf960aab49a8c437841 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Thu, 12 Dec 2019 15:56:46 -0800 Subject: [PATCH 7/8] test for centos7 --- .../Featurizers/DateTimeTransformer.cs | 2 +- .../DateTimeTransformer.cs | 19 ++------ .../Attributes/NotCentOS7FactAttribute.cs | 45 +++++++++++++++++++ .../Transformers/DateTimeTransformerTests.cs | 19 ++++---- 4 files changed, 59 insertions(+), 26 deletions(-) create mode 100644 test/Microsoft.ML.TestFramework/Attributes/NotCentOS7FactAttribute.cs diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs index 9bf851b0ba..d28ef0b3bf 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs @@ -27,7 +27,7 @@ public static void Example() var dataview = mlContext.Data.LoadFromEnumerable(samples); // A pipeline for splitting the time features into individual columns - var pipeline = mlContext.Transforms.DateTimeTransformer("Date", "DTC"); + var pipeline = mlContext.Transforms.FeaturizeDateTime("Date", "DTC"); // The transformed data. var transformedData = pipeline.Fit(dataview).Transform(dataview); diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs index ecff6b1d71..78ebdef39f 100644 --- a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -35,18 +35,6 @@ namespace Microsoft.ML.Featurizers public static class DateTimeTransformerExtensionClass { - /// - /// Create a , which splits up the input column specified by - /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. - /// This transformer will append the to all the output columns. - /// - /// Transform catalog - /// Input column name - /// Prefix to add to the generated columns - /// - public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix) - => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix); - /// /// Create a , which splits up the input column specified by /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. @@ -58,7 +46,7 @@ public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catal /// Prefix to add to the generated columns /// Country name to get holiday details for /// - public static DateTimeEstimator DateTimeTransformer(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country = DateTimeEstimator.HolidayList.None) + public static DateTimeEstimator FeaturizeDateTime(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country = DateTimeEstimator.HolidayList.None) => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, country); #region ColumnsProduced static extentions @@ -104,14 +92,13 @@ internal static Type GetRawColumnType(this DateTimeEstimator.ColumnsProduced col /// | Input column data type | Int64 | /// | Output column data type | Columns and types listed in the summary | /// - /// The is a trivial estimator and does not need training. + /// The is a trivial estimator and does not need training. /// /// /// ]]> /// /// - /// - /// + /// public sealed class DateTimeEstimator : IEstimator { private readonly Options _options; diff --git a/test/Microsoft.ML.TestFramework/Attributes/NotCentOS7FactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/NotCentOS7FactAttribute.cs new file mode 100644 index 0000000000..849d765c21 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/NotCentOS7FactAttribute.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using Microsoft.ML.TestFrameworkCommon.Attributes; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests that wont run on CentOS7 + /// + public sealed class NotCentOS7FactAttribute : EnvironmentSpecificFactAttribute + { + public NotCentOS7FactAttribute() : base("These tests are not CentOS7 complient.") + { + } + protected override bool IsEnvironmentSupported() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + using (Process process = new Process()) + { + process.StartInfo.FileName = "/bin/bash"; + process.StartInfo.Arguments= "-c \"cat /etc/*-release\""; + process.StartInfo.UseShellExecute = false; + process.StartInfo.RedirectStandardOutput = true; + process.StartInfo.CreateNoWindow = true; + process.Start(); + + string distro = process.StandardOutput.ReadToEnd().Trim(); + + process.WaitForExit(); + if (distro.Contains("CentOS Linux 7")) + { + return false; + } + } + } + return true; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs index 79505c55f7..4b91b30586 100644 --- a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs @@ -8,6 +8,7 @@ using System; using Xunit; using Xunit.Abstractions; +using Microsoft.ML.TestFramework.Attributes; namespace Microsoft.ML.Tests.Transformers { @@ -22,7 +23,7 @@ private class DateTimeInput public long date; } - [Fact] + [NotCentOS7FactAttribute] public void CorrectNumberOfColumnsAndSchema() { MLContext mlContext = new MLContext(1); @@ -31,7 +32,7 @@ public void CorrectNumberOfColumnsAndSchema() // Build the pipeline, fit, and transform it. var columnPrefix = "DTC_"; - var pipeline = mlContext.Transforms.DateTimeTransformer("date", columnPrefix); + var pipeline = mlContext.Transforms.FeaturizeDateTime("date", columnPrefix); var model = pipeline.Fit(data); var output = model.Transform(data); var schema = output.Schema; @@ -89,7 +90,7 @@ public void CorrectNumberOfColumnsAndSchema() Done(); } - [Fact] + [NotCentOS7FactAttribute] public void CanUseDateFromColumn() { // Future Date - 2025 June 30 @@ -98,7 +99,7 @@ public void CanUseDateFromColumn() var data = mlContext.Data.LoadFromEnumerable(dataList); // Build the pipeline, fit, and transform it. - var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC"); + var pipeline = mlContext.Transforms.FeaturizeDateTime("date", "DTC"); var model = pipeline.Fit(data); var output = model.Transform(data); @@ -132,7 +133,7 @@ public void CanUseDateFromColumn() Done(); } - [Fact] + [NotCentOS7FactAttribute] public void HolidayTest() { // Future Date - 2025 June 30 @@ -141,7 +142,7 @@ public void HolidayTest() var data = mlContext.Data.LoadFromEnumerable(dataList); // Build the pipeline, fit, and transform it. - var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC", country: DateTimeEstimator.HolidayList.Canada); + var pipeline = mlContext.Transforms.FeaturizeDateTime("date", "DTC", country: DateTimeEstimator.HolidayList.Canada); var model = pipeline.Fit(data); var output = model.Transform(data); @@ -156,7 +157,7 @@ public void HolidayTest() Done(); } - [Fact] + [NotCentOS7FactAttribute] public void ManyRowsTest() { // Future Date - 2025 June 30 @@ -169,7 +170,7 @@ public void ManyRowsTest() var data = mlContext.Data.LoadFromEnumerable(dataList); // Build the pipeline, fit, and transform it. - var pipeline = mlContext.Transforms.DateTimeTransformer("date", "DTC"); + var pipeline = mlContext.Transforms.FeaturizeDateTime("date", "DTC"); var model = pipeline.Fit(data); var output = model.Transform(data); @@ -203,7 +204,7 @@ public void ManyRowsTest() Done(); } - [Fact] + [NotCentOS7FactAttribute] public void EntryPointTest() { // Future Date - 2025 June 30 From a93e8bfe870ed33131f3ae0e6259935e025ba710 Mon Sep 17 00:00:00 2001 From: Michael Sharp Date: Mon, 16 Dec 2019 13:48:09 -0800 Subject: [PATCH 8/8] Fixed CentOS7 check --- src/Microsoft.ML.Featurizers/Common.cs | 27 +++++++++++++++++++ .../DateTimeTransformer.cs | 11 +++++--- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.Featurizers/Common.cs b/src/Microsoft.ML.Featurizers/Common.cs index 0d2111d8a2..26158b453f 100644 --- a/src/Microsoft.ML.Featurizers/Common.cs +++ b/src/Microsoft.ML.Featurizers/Common.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Runtime.InteropServices; using System.Security; using System.Text; @@ -219,5 +220,31 @@ internal static TypeId GetNativeTypeIdFromType(this Type type) throw new InvalidOperationException($"Unsupported type {type}"); } + + // The Native Featurizers do not currently support CentOS7, this method checks the OS and returns true if it is CentOS7. + internal static bool OsIsCentOS7() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + using (Process process = new Process()) + { + process.StartInfo.FileName = "/bin/bash"; + process.StartInfo.Arguments = "-c \"cat /etc/*-release\""; + process.StartInfo.UseShellExecute = false; + process.StartInfo.RedirectStandardOutput = true; + process.StartInfo.CreateNoWindow = true; + process.Start(); + + string distro = process.StandardOutput.ReadToEnd().Trim(); + + process.WaitForExit(); + if (distro.Contains("CentOS Linux 7")) + { + return true; + } + } + } + return false; + } } } diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs index 78ebdef39f..1f100e49a5 100644 --- a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -127,6 +127,7 @@ internal DateTimeEstimator(IHostEnvironment env, string inputColumnName, string Contracts.CheckValue(env, nameof(env)); _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); _host.CheckValue(inputColumnName, nameof(inputColumnName), "Input column should not be null."); + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); _options = new Options { @@ -138,9 +139,9 @@ internal DateTimeEstimator(IHostEnvironment env, string inputColumnName, string internal DateTimeEstimator(IHostEnvironment env, Options options) { - Contracts.CheckValue(env, nameof(env)); _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); _options = options; } @@ -248,9 +249,11 @@ public sealed class DateTimeTransformer : RowToRowTransformerBase, IDisposable #endregion - internal DateTimeTransformer(IHostEnvironment env, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country) : - base(env.Register(nameof(DateTimeTransformer))) + internal DateTimeTransformer(IHostEnvironment host, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country) : + base(host.Register(nameof(DateTimeTransformer))) { + host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + _column = new LongTypedColumn(inputColumnName, columnPrefix); _column.CreateTransformerFromEstimator(country); } @@ -261,6 +264,8 @@ internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx) : { Host.CheckValue(ctx, nameof(ctx)); + host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + ctx.CheckAtModel(GetVersionInfo()); // *** Binary format *** // name of input column