diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs new file mode 100644 index 0000000000..d28ef0b3bf --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/Featurizers/DateTimeTransformer.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Featurizers; + +namespace Samples.Dynamic +{ + public static class DateTimeTransformer + { + private class DateTimeInput + { + public long Date; + } + + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for + // exception tracking and logging, as well as the source of randomness. + var mlContext = new MLContext(); + + // Create a small dataset as an IEnumerable. + // Future Date - 2025 June 30 + var samples = new[] { new DateTimeInput() { Date = 1751241600 } }; + + // Convert training data to IDataView. + var dataview = mlContext.Data.LoadFromEnumerable(samples); + + // A pipeline for splitting the time features into individual columns + var pipeline = mlContext.Transforms.FeaturizeDateTime("Date", "DTC"); + + // The transformed data. + var transformedData = pipeline.Fit(dataview).Transform(dataview); + + // Now let's take a look at what this did. We should have created 21 more columns with all the + // DateTime information split into its own columns + var featuresColumn = mlContext.Data.CreateEnumerable( + transformedData, reuseRowObject: false); + + // And we can write out a few rows + Console.WriteLine($"Features column obtained post-transformation."); + foreach (var featureRow in featuresColumn) + Console.WriteLine(featureRow.Date + ", " + featureRow.DTCYear + ", " + featureRow.DTCMonth + ", " + + featureRow.DTCDay + ", " + featureRow.DTCHour + ", " + featureRow.DTCMinute + ", " + + featureRow.DTCSecond + ", " + featureRow.DTCAmPm + ", " + featureRow.DTCHour12 + ", " + + featureRow.DTCDayOfWeek + ", " + featureRow.DTCDayOfQuarter + ", " + featureRow.DTCDayOfYear + + ", " + featureRow.DTCWeekOfMonth + ", " + featureRow.DTCQuarterOfYear + ", " + featureRow.DTCHalfOfYear + + ", " + featureRow.DTCWeekIso + ", " + featureRow.DTCYearIso + ", " + featureRow.DTCMonthLabel + ", " + + featureRow.DTCAmPmLabel + ", " + featureRow.DTCDayOfWeekLabel + ", " + featureRow.DTCHolidayName + ", " + + featureRow.DTCIsPaidTimeOff); + + // Expected output: + // Features columns obtained post-transformation. + // 1751241600, 2025, 6, 30, 0, 0, 0, 0, 0, 1, 91, 180, 4, 2, 1, 27, 2025, June, am, Monday, , 0 + } + + // These columns start with DTC because that is the prefix we picked + private sealed class TransformedData + { + public long Date { get; set; } + public int DTCYear { get; set; } + public byte DTCMonth { get; set; } + public byte DTCDay { get; set; } + public byte DTCHour { get; set; } + public byte DTCMinute { get; set; } + public byte DTCSecond { get; set; } + public byte DTCAmPm { get; set; } + public byte DTCHour12 { get; set; } + public byte DTCDayOfWeek { get; set; } + public byte DTCDayOfQuarter { get; set; } + public ushort DTCDayOfYear { get; set; } + public ushort DTCWeekOfMonth { get; set; } + public byte DTCQuarterOfYear { get; set; } + public byte DTCHalfOfYear { get; set; } + public byte DTCWeekIso { get; set; } + public int DTCYearIso { get; set; } + public string DTCMonthLabel { get; set; } + public string DTCAmPmLabel { get; set; } + public string DTCDayOfWeekLabel { get; set; } + public string DTCHolidayName { get; set; } + public byte DTCIsPaidTimeOff { get; set; } + } + } +} diff --git a/src/Microsoft.ML.Featurizers/Common.cs b/src/Microsoft.ML.Featurizers/Common.cs index 0d2111d8a2..26158b453f 100644 --- a/src/Microsoft.ML.Featurizers/Common.cs +++ b/src/Microsoft.ML.Featurizers/Common.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Runtime.InteropServices; using System.Security; using System.Text; @@ -219,5 +220,31 @@ internal static TypeId GetNativeTypeIdFromType(this Type type) throw new InvalidOperationException($"Unsupported type {type}"); } + + // The Native Featurizers do not currently support CentOS7, this method checks the OS and returns true if it is CentOS7. + internal static bool OsIsCentOS7() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + using (Process process = new Process()) + { + process.StartInfo.FileName = "/bin/bash"; + process.StartInfo.Arguments = "-c \"cat /etc/*-release\""; + process.StartInfo.UseShellExecute = false; + process.StartInfo.RedirectStandardOutput = true; + process.StartInfo.CreateNoWindow = true; + process.Start(); + + string distro = process.StandardOutput.ReadToEnd().Trim(); + + process.WaitForExit(); + if (distro.Contains("CentOS Linux 7")) + { + return true; + } + } + } + return false; + } } } diff --git a/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs new file mode 100644 index 0000000000..1f100e49a5 --- /dev/null +++ b/src/Microsoft.ML.Featurizers/DateTimeTransformer.cs @@ -0,0 +1,786 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Security; +using System.Text; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.EntryPoints; +using Microsoft.ML.Featurizers; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms; +using Microsoft.Win32.SafeHandles; +using static Microsoft.ML.Featurizers.CommonExtensions; + +[assembly: LoadableClass(typeof(DateTimeTransformer), null, typeof(SignatureLoadModel), + DateTimeTransformer.UserName, DateTimeTransformer.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(DateTimeTransformer), null, typeof(SignatureLoadRowMapper), + DateTimeTransformer.UserName, DateTimeTransformer.LoaderSignature)] + +[assembly: EntryPointModule(typeof(DateTimeTransformerEntrypoint))] + +namespace Microsoft.ML.Featurizers +{ + + public static class DateTimeTransformerExtensionClass + { + /// + /// Create a , which splits up the input column specified by + /// into all its individual datetime components. Input column must be of type Int64 representing the number of seconds since the unix epoc. + /// This transformer will append the to all the output columns. If you specify a country, + /// Holiday details will be looked up for that country as well. + /// + /// Transform catalog + /// Input column name + /// Prefix to add to the generated columns + /// Country name to get holiday details for + /// + public static DateTimeEstimator FeaturizeDateTime(this TransformsCatalog catalog, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country = DateTimeEstimator.HolidayList.None) + => new DateTimeEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName, columnPrefix, country); + + #region ColumnsProduced static extentions + + internal static Type GetRawColumnType(this DateTimeEstimator.ColumnsProduced column) + { + switch (column) + { + case DateTimeEstimator.ColumnsProduced.Year: + case DateTimeEstimator.ColumnsProduced.YearIso: + return typeof(int); + case DateTimeEstimator.ColumnsProduced.DayOfYear: + case DateTimeEstimator.ColumnsProduced.WeekOfMonth: + return typeof(ushort); + case DateTimeEstimator.ColumnsProduced.MonthLabel: + case DateTimeEstimator.ColumnsProduced.AmPmLabel: + case DateTimeEstimator.ColumnsProduced.DayOfWeekLabel: + case DateTimeEstimator.ColumnsProduced.HolidayName: + return typeof(ReadOnlyMemory); + default: + return typeof(byte); + } + } + + #endregion + } + + /// + /// The DateTimeTransformerEstimator splits up a date into all of its sub parts as individual columns. It generates these fields with a user specified prefix: + /// int Year, byte Month, byte Day, byte Hour, byte Minute, byte Second, byte AmPm, byte Hour12, byte DayOfWeek, byte DayOfQuarter, + /// ushort DayOfYear, ushort WeekOfMonth, byte QuarterOfYear, byte HalfOfYear, byte WeekIso, int YearIso, string MonthLabel, string AmPmLabel, + /// string DayOfWeekLabel, string HolidayName, byte IsPaidTimeOff + /// + /// You can optionally specify a country and it will pull holiday information about the country as well + /// + /// + /// is a trivial estimator and does not need training. + /// + /// + /// ]]> + /// + /// + /// + public sealed class DateTimeEstimator : IEstimator + { + private readonly Options _options; + + private readonly IHost _host; + + #region Options + internal sealed class Options : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "Input column", Name = "Source", ShortName = "src", SortOrder = 1)] + public string Source; + + // This transformer adds columns + [Argument(ArgumentType.Required, HelpText = "Output column prefix", Name = "Prefix", ShortName = "pre", SortOrder = 2)] + public string Prefix; + + [Argument(ArgumentType.AtMostOnce, HelpText = "Country to get holidays for. Defaults to none if not passed", Name = "Country", ShortName = "ctry", SortOrder = 4)] + public HolidayList Country = HolidayList.None; + } + + #endregion + + internal DateTimeEstimator(IHostEnvironment env, string inputColumnName, string columnPrefix, HolidayList country = HolidayList.None) + { + + Contracts.CheckValue(env, nameof(env)); + _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); + _host.CheckValue(inputColumnName, nameof(inputColumnName), "Input column should not be null."); + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + + _options = new Options + { + Source = inputColumnName, + Prefix = columnPrefix, + Country = country + }; + } + + internal DateTimeEstimator(IHostEnvironment env, Options options) + { + Contracts.CheckValue(env, nameof(env)); + _host = Contracts.CheckRef(env, nameof(env)).Register("DateTimeTransformerEstimator"); + _host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + + _options = options; + } + + public DateTimeTransformer Fit(IDataView input) + { + return new DateTimeTransformer(_host, _options.Source, _options.Prefix, _options.Country); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + var columns = inputSchema.ToDictionary(x => x.Name); + + foreach (ColumnsProduced column in Enum.GetValues(typeof(ColumnsProduced))) + { + columns[_options.Prefix + column.ToString()] = new SchemaShape.Column(_options.Prefix + column.ToString(), SchemaShape.Column.VectorKind.Scalar, + ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()), false, null); + } + + return new SchemaShape(columns.Values); + } + + #region Enums + public enum ColumnsProduced : byte + { + Year = 1, + Month = 2, + Day = 3, + Hour = 4, + Minute = 5, + Second = 6, + AmPm = 7, + Hour12 = 8, + DayOfWeek = 9, + DayOfQuarter = 10, + DayOfYear = 11, + WeekOfMonth = 12, + QuarterOfYear = 13, + HalfOfYear = 14, + WeekIso = 15, + YearIso = 16, + MonthLabel = 17, + AmPmLabel = 18, + DayOfWeekLabel = 19, + HolidayName = 20, + IsPaidTimeOff = 21 + }; + + public enum HolidayList : uint + { + None = 1, + Argentina = 2, + Australia = 3, + Austria = 4, + Belarus = 5, + Belgium = 6, + Brazil = 7, + Canada = 8, + Colombia = 9, + Croatia = 10, + Czech = 11, + Denmark = 12, + England = 13, + Finland = 14, + France = 15, + Germany = 16, + Hungary = 17, + India = 18, + Ireland = 19, + IsleofMan = 20, + Italy = 21, + Japan = 22, + Mexico = 23, + Netherlands = 24, + NewZealand = 25, + NorthernIreland = 26, + Norway = 27, + Poland = 28, + Portugal = 29, + Scotland = 30, + Slovenia = 31, + SouthAfrica = 32, + Spain = 33, + Sweden = 34, + Switzerland = 35, + Ukraine = 36, + UnitedKingdom = 37, + UnitedStates = 38, + Wales = 39 + } + + #endregion + } + + public sealed class DateTimeTransformer : RowToRowTransformerBase, IDisposable + { + #region Class data members + + internal const string Summary = "Splits a date time value into each individual component"; + internal const string UserName = "DateTime Transform"; + internal const string ShortName = "DateTimeTransform"; + internal const string LoadName = "DateTimeTransform"; + internal const string LoaderSignature = "DateTimeTransform"; + private LongTypedColumn _column; + + #endregion + + internal DateTimeTransformer(IHostEnvironment host, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country) : + base(host.Register(nameof(DateTimeTransformer))) + { + host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + + _column = new LongTypedColumn(inputColumnName, columnPrefix); + _column.CreateTransformerFromEstimator(country); + } + + // Factory method for SignatureLoadModel. + internal DateTimeTransformer(IHostEnvironment host, ModelLoadContext ctx) : + base(host.Register(nameof(DateTimeTransformer))) + { + + Host.CheckValue(ctx, nameof(ctx)); + host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported"); + + ctx.CheckAtModel(GetVersionInfo()); + // *** Binary format *** + // name of input column + // column prefix + // length of C++ state array + // C++ byte state array + + _column = new LongTypedColumn(ctx.Reader.ReadString(), ctx.Reader.ReadString()); + + var dataLength = ctx.Reader.ReadInt32(); + var data = ctx.Reader.ReadByteArray(dataLength); + _column.CreateTransformerFromSavedData(data); + } + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) + => new DateTimeTransformer(env, ctx).MakeRowMapper(inputSchema); + + private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema); + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "DATETI T", + verWrittenCur: 0x00010001, + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(DateTimeTransformer).Assembly.FullName); + } + + private protected override void SaveModel(ModelSaveContext ctx) + { + + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // name of input column + // column prefix + // length of C++ state array + // C++ byte state array + + ctx.Writer.Write(_column.Source); + ctx.Writer.Write(_column.Prefix); + + var data = _column.CreateTransformerSaveData(); + ctx.Writer.Write(data.Length); + ctx.Writer.Write(data); + } + + public void Dispose() + { + _column.Dispose(); + } + + #region C++ Safe handle classes + + internal class TransformedDataSafeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + private readonly DestroyTransformedDataNative _destroyTransformedDataHandler; + + public TransformedDataSafeHandle(IntPtr handle, DestroyTransformedDataNative destroyTransformedDataHandler) : base(true) + { + SetHandle(handle); + _destroyTransformedDataHandler = destroyTransformedDataHandler; + } + + protected override bool ReleaseHandle() + { + // Not sure what to do with error stuff here. There shoudln't ever be one though. + return _destroyTransformedDataHandler(handle, out IntPtr errorHandle); + } + } + + #endregion + + #region TimePoint + + [StructLayoutAttribute(LayoutKind.Sequential)] + internal struct TimePoint + { + public int Year; + public byte Month; + public byte Day; + public byte Hour; + public byte Minute; + public byte Second; + public byte AmPm; + public byte Hour12; + public byte DayOfWeek; + public byte DayOfQuarter; + public ushort DayOfYear; + public ushort WeekOfMonth; + public byte QuarterOfYear; + public byte HalfOfYear; + public byte WeekIso; + public int YearIso; + public string MonthLabel; + public string AmPmLabel; + public string DayOfWeekLabel; + public string HolidayName; + public byte IsPaidTimeOff; + + internal TimePoint(ReadOnlySpan rawData, int intPtrSize) + { + + int index = 0; + + Year = MemoryMarshal.Read(rawData); + index += 4; + + Month = rawData[index++]; + Day = rawData[index++]; + Hour = rawData[index++]; + Minute = rawData[index++]; + Second = rawData[index++]; + AmPm = rawData[index++]; + Hour12 = rawData[index++]; + DayOfWeek = rawData[index++]; + DayOfQuarter = rawData[index++]; + DayOfYear = MemoryMarshal.Read(rawData.Slice(index)); + index += 2; + + WeekOfMonth = MemoryMarshal.Read(rawData.Slice(index)); + index += 2; + + QuarterOfYear = rawData[index++]; + HalfOfYear = rawData[index++]; + WeekIso = rawData[index++]; + YearIso = MemoryMarshal.Read(rawData.Slice(index)); + index += 4; + + // Convert char * to string + MonthLabel = GetStringFromPointer(ref rawData, ref index, intPtrSize); + AmPmLabel = GetStringFromPointer(ref rawData, ref index, intPtrSize); + DayOfWeekLabel = GetStringFromPointer(ref rawData, ref index, intPtrSize); + HolidayName = GetStringFromPointer(ref rawData, ref index, intPtrSize); + IsPaidTimeOff = rawData[index]; + } + + // Converts a pointer to a native char* to a string and increments pointer by to the next value. + // The length of the string is stored at byte* + sizeof(IntPtr). + private static unsafe string GetStringFromPointer(ref ReadOnlySpan rawData, ref int index, int intPtrSize) + { + ulong stringLength; + ReadOnlySpan buffer; + if (intPtrSize == 4) // 32 bit machine + { + stringLength = MemoryMarshal.Read(rawData.Slice(index + intPtrSize)); + IntPtr stringData = new IntPtr(MemoryMarshal.Read(rawData.Slice(index))); + buffer = new ReadOnlySpan(stringData.ToPointer(), (int)stringLength); + } + else // 64 bit machine + { + stringLength = MemoryMarshal.Read(rawData.Slice(index + intPtrSize)); + IntPtr stringData = new IntPtr(MemoryMarshal.Read(rawData.Slice(index))); + buffer = new ReadOnlySpan(stringData.ToPointer(), (int)stringLength); + } + + if (stringLength == 0) + { + index += intPtrSize * 2; + return string.Empty; + } + + index += intPtrSize * 2; +#if NETSTANDARD2_0 + return Encoding.UTF8.GetString(buffer.ToArray()); +#else + return Encoding.UTF8.GetString(buffer); +#endif + } + + }; + +#endregion + +#region BaseClass + + internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle); + internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle); + internal delegate bool DestroyTransformedDataNative(IntPtr output, out IntPtr errorHandle); + + internal abstract class TypedColumn : IDisposable + { + internal readonly string Source; + internal readonly string Prefix; + + internal TypedColumn(string source, string prefix) + { + Source = source; + Prefix = prefix; + } + + internal abstract void CreateTransformerFromEstimator(DateTimeEstimator.HolidayList country); + private protected abstract unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize); + private protected unsafe abstract bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle); + private protected abstract bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + private protected abstract bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle); + private protected abstract bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle); + private protected abstract bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); + public abstract void Dispose(); + + private protected unsafe TransformerEstimatorSafeHandle CreateTransformerFromEstimatorBase(DateTimeEstimator.HolidayList country) + { + bool success; + IntPtr errorHandle; + IntPtr estimator; + if (country == DateTimeEstimator.HolidayList.None) + { + success = CreateEstimatorHelper(null, null, out estimator, out errorHandle); + } + else + { + fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(AppDomain.CurrentDomain.BaseDirectory + char.MinValue)) + fixed (byte* countryPointer = Encoding.UTF8.GetBytes(Enum.GetName(typeof(DateTimeEstimator.HolidayList), country) + char.MinValue)) + { + success = CreateEstimatorHelper(countryPointer, dataRootDir, out estimator, out errorHandle); + } + } + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + using (var estimatorHandler = new TransformerEstimatorSafeHandle(estimator, DestroyEstimatorHelper)) + { + + success = CreateTransformerFromEstimatorHelper(estimatorHandler, out IntPtr transformer, out errorHandle); + if (!success) + { + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + } + + return new TransformerEstimatorSafeHandle(transformer, DestroyTransformerHelper); + } + } + + internal byte[] CreateTransformerSaveData() + { + + var success = CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var savedDataHandle = new SaveDataSafeHandle(buffer, bufferSize)) + { + byte[] savedData = new byte[bufferSize.ToInt32()]; + Marshal.Copy(buffer, savedData, 0, savedData.Length); + return savedData; + } + } + + internal unsafe void CreateTransformerFromSavedData(byte[] data) + { + fixed (byte* rawData = data) + { + IntPtr dataSize = new IntPtr(data.Count()); + CreateTransformerFromSavedDataHelper(rawData, dataSize); + } + } + } + + internal abstract class TypedColumn : TypedColumn + { + internal TypedColumn(string source, string prefix) : + base(source, prefix) + { + } + + internal abstract TimePoint Transform(T input); + + } + +#endregion + +#region DateTimeTypedColumn + + internal sealed class LongTypedColumn : TypedColumn + { + private TransformerEstimatorSafeHandle _transformerHandler; + private readonly int _intPtrSize; + private readonly int _structSize; + internal LongTypedColumn(string source, string prefix) : + base(source, prefix) + { + _intPtrSize = IntPtr.Size; + + // The native struct is 25 bytes + 8 size_t. + _structSize = 25 + (_intPtrSize * 8); + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateEstimator"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateEstimatorNative(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle); + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyEstimatorNative(IntPtr estimator, out IntPtr errorHandle); // Should ONLY be called by safe handle + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerFromEstimator"), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerFromEstimatorNative(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformer"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformerNative(IntPtr transformer, out IntPtr errorHandle); + internal override unsafe void CreateTransformerFromEstimator(DateTimeEstimator.HolidayList country) + { + _transformerHandler = CreateTransformerFromEstimatorBase(country); + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerFromSavedDataWithDataRoot"), SuppressUnmanagedCodeSecurity] + private static unsafe extern bool CreateTransformerFromSavedDataNative(byte* rawData, IntPtr bufferSize, byte* dataRootDir, out IntPtr transformer, out IntPtr errorHandle); + private protected override unsafe void CreateTransformerFromSavedDataHelper(byte* rawData, IntPtr dataSize) + { + fixed (byte* dataRootDir = Encoding.UTF8.GetBytes(AppDomain.CurrentDomain.BaseDirectory + char.MinValue)) + { + var result = CreateTransformerFromSavedDataNative(rawData, dataSize, dataRootDir, out IntPtr transformer, out IntPtr errorHandle); + if (!result) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + _transformerHandler = new TransformerEstimatorSafeHandle(transformer, DestroyTransformerNative); + } + } + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_Transform"), SuppressUnmanagedCodeSecurity] + private static extern bool TransformDataNative(TransformerEstimatorSafeHandle transformer, long input, out IntPtr output, out IntPtr errorHandle); + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_DestroyTransformedData"), SuppressUnmanagedCodeSecurity] + private static extern bool DestroyTransformedDataNative(IntPtr output, out IntPtr errorHandle); + internal override TimePoint Transform(long input) + { + var success = TransformDataNative(_transformerHandler, input, out IntPtr output, out IntPtr errorHandle); + if (!success) + throw new Exception(GetErrorDetailsAndFreeNativeMemory(errorHandle)); + + using (var handler = new TransformedDataSafeHandle(output, DestroyTransformedDataNative)) + { + // 29 plus size. + unsafe + { + return new TimePoint(new ReadOnlySpan(output.ToPointer(), _structSize), _intPtrSize); + } + } + } + + public override void Dispose() + { + if (!_transformerHandler.IsClosed) + _transformerHandler.Dispose(); + } + + private protected unsafe override bool CreateEstimatorHelper(byte* countryName, byte* dataRootDir, out IntPtr estimator, out IntPtr errorHandle) => + CreateEstimatorNative(countryName, dataRootDir, out estimator, out errorHandle); + + private protected override bool CreateTransformerFromEstimatorHelper(TransformerEstimatorSafeHandle estimator, out IntPtr transformer, out IntPtr errorHandle) => + CreateTransformerFromEstimatorNative(estimator, out transformer, out errorHandle); + + private protected override bool DestroyEstimatorHelper(IntPtr estimator, out IntPtr errorHandle) => + DestroyEstimatorNative(estimator, out errorHandle); + + private protected override bool DestroyTransformerHelper(IntPtr transformer, out IntPtr errorHandle) => + DestroyTransformerNative(transformer, out errorHandle); + + [DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateTransformerSaveData", CallingConvention = CallingConvention.Cdecl), SuppressUnmanagedCodeSecurity] + private static extern bool CreateTransformerSaveDataNative(TransformerEstimatorSafeHandle transformer, out IntPtr buffer, out IntPtr bufferSize, out IntPtr error); + private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffer, out IntPtr bufferSize, out IntPtr errorHandle) => + CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle); + } + +#endregion + + private sealed class Mapper : MapperBase + { + +#region Class data members + + private readonly DateTimeTransformer _parent; + private ConcurrentDictionary _cache; + private ConcurrentQueue _oldestKeys; + +#endregion + + public Mapper(DateTimeTransformer parent, DataViewSchema inputSchema) : + base(parent.Host.Register(nameof(Mapper)), inputSchema, parent) + { + _parent = parent; + _cache = new ConcurrentDictionary(); + _oldestKeys = new ConcurrentQueue(); + } + + protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore() + { + var columns = new List(); + + foreach (DateTimeEstimator.ColumnsProduced column in Enum.GetValues(typeof(DateTimeEstimator.ColumnsProduced))) + { + columns.Add(new DataViewSchema.DetachedColumn(_parent._column.Prefix + column.ToString(), + ColumnTypeExtensions.PrimitiveTypeFromType(column.GetRawColumnType()))); + } + + return columns.ToArray(); + } + + private Delegate MakeGetter(DataViewRow input, int iinfo) + { + var getter = input.GetGetter(input.Schema[_parent._column.Source]); + ValueGetter result = (ref T dst) => + { + long dateTime = default; + getter(ref dateTime); + + if (!_cache.TryGetValue(dateTime, out TimePoint timePoint)) + { + _cache[dateTime] = _parent._column.Transform(dateTime); + _oldestKeys.Enqueue(dateTime); + timePoint = _cache[dateTime]; + + // If more than 100 cached items, remove 20 + if (_cache.Count > 100) + { + for (int i = 0; i < 20; i++) + { + long key; + while (!_oldestKeys.TryDequeue(out key)) { } + while (!_cache.TryRemove(key, out TimePoint removedValue)) { } + } + } + } + + if (iinfo == 0) + dst = (T)Convert.ChangeType(timePoint.Year, typeof(T)); + else if (iinfo == 1) + dst = (T)Convert.ChangeType(timePoint.Month, typeof(T)); + else if (iinfo == 2) + dst = (T)Convert.ChangeType(timePoint.Day, typeof(T)); + else if (iinfo == 3) + dst = (T)Convert.ChangeType(timePoint.Hour, typeof(T)); + else if (iinfo == 4) + dst = (T)Convert.ChangeType(timePoint.Minute, typeof(T)); + else if (iinfo == 5) + dst = (T)Convert.ChangeType(timePoint.Second, typeof(T)); + else if (iinfo == 6) + dst = (T)Convert.ChangeType(timePoint.AmPm, typeof(T)); + else if (iinfo == 7) + dst = (T)Convert.ChangeType(timePoint.Hour12, typeof(T)); + else if (iinfo == 8) + dst = (T)Convert.ChangeType(timePoint.DayOfWeek, typeof(T)); + else if (iinfo == 9) + dst = (T)Convert.ChangeType(timePoint.DayOfQuarter, typeof(T)); + else if (iinfo == 10) + dst = (T)Convert.ChangeType(timePoint.DayOfYear, typeof(T)); + else if (iinfo == 11) + dst = (T)Convert.ChangeType(timePoint.WeekOfMonth, typeof(T)); + else if (iinfo == 12) + dst = (T)Convert.ChangeType(timePoint.QuarterOfYear, typeof(T)); + else if (iinfo == 13) + dst = (T)Convert.ChangeType(timePoint.HalfOfYear, typeof(T)); + else if (iinfo == 14) + dst = (T)Convert.ChangeType(timePoint.WeekIso, typeof(T)); + else if (iinfo == 15) + dst = (T)Convert.ChangeType(timePoint.YearIso, typeof(T)); + else if (iinfo == 16) + dst = (T)Convert.ChangeType(timePoint.MonthLabel.AsMemory(), typeof(T)); + else if (iinfo == 17) + dst = (T)Convert.ChangeType(timePoint.AmPmLabel.AsMemory(), typeof(T)); + else if (iinfo == 18) + dst = (T)Convert.ChangeType(timePoint.DayOfWeekLabel.AsMemory(), typeof(T)); + else if (iinfo == 19) + dst = (T)Convert.ChangeType(timePoint.HolidayName.AsMemory(), typeof(T)); + else + dst = (T)Convert.ChangeType(timePoint.IsPaidTimeOff, typeof(T)); + }; + + return result; + } + + protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func activeOutput, out Action disposer) + { + disposer = null; + + // Have to add 1 to iinfo since the enum starts at 1 + return Utils.MarshalInvoke(MakeGetter, ((DateTimeEstimator.ColumnsProduced)iinfo + 1).GetRawColumnType(), input, iinfo); + } + + private protected override Func GetDependenciesCore(Func activeOutput) + { + var active = new bool[InputSchema.Count]; + for (int i = 0; i < InputSchema.Count; i++) + { + if (InputSchema[i].Name.Equals(_parent._column.Source)) + { + active[i] = true; + } + } + + return col => active[col]; + } + + private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx); + } + } + + internal static class DateTimeTransformerEntrypoint + { + [TlcModule.EntryPoint(Name = "Transforms.DateTimeSplitter", + Desc = DateTimeTransformer.Summary, + UserName = DateTimeTransformer.UserName, + ShortName = DateTimeTransformer.ShortName)] + public static CommonOutputs.TransformOutput DateTimeSplit(IHostEnvironment env, DateTimeEstimator.Options input) + { + var h = EntryPointUtils.CheckArgsAndCreateHost(env, DateTimeTransformer.ShortName, input); + var xf = new DateTimeEstimator(h, input).Fit(input.Data).Transform(input.Data); + return new CommonOutputs.TransformOutput() + { + Model = new TransformModelImpl(h, xf, input.Data), + OutputData = xf + }; + } + } +} diff --git a/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj index f3a35c5d85..a221b11509 100644 --- a/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj +++ b/src/Microsoft.ML.Featurizers/Microsoft.ML.Featurizers.csproj @@ -1,13 +1,13 @@  - netstandard2.0 + netstandard2.0;netcoreapp2.1 Microsoft.ML.Featurizers true - + diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 4f2bcc426a..8b242c5f31 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -85,6 +85,7 @@ Transforms.CombinerByContiguousGroupId Groups values of a scalar column into a v Transforms.ConditionalNormalizer Normalize the columns only if needed Microsoft.ML.Data.Normalize IfNeeded Microsoft.ML.Transforms.NormalizeTransform+MinMaxArguments Microsoft.ML.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput] Transforms.DatasetScorer Score a dataset with a predictor model Microsoft.ML.EntryPoints.ScoreModel Score Microsoft.ML.EntryPoints.ScoreModel+Input Microsoft.ML.EntryPoints.ScoreModel+Output Transforms.DatasetTransformScorer Score a dataset with a transform model Microsoft.ML.EntryPoints.ScoreModel ScoreUsingTransform Microsoft.ML.EntryPoints.ScoreModel+InputTransformScorer Microsoft.ML.EntryPoints.ScoreModel+Output +Transforms.DateTimeSplitter Splits a date time value into each individual component Microsoft.ML.Featurizers.DateTimeTransformerEntrypoint DateTimeSplit Microsoft.ML.Featurizers.DateTimeEstimator+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.Dictionarizer Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Transforms.Text.TextAnalytics TermTransform Microsoft.ML.Transforms.ValueToKeyMappingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureCombiner Combines all the features into one feature column. Microsoft.ML.EntryPoints.FeatureCombiner PrepareFeatures Microsoft.ML.EntryPoints.FeatureCombiner+FeatureCombinerInput Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput Transforms.FeatureContributionCalculationTransformer For each data point, calculates the contribution of individual features to the model prediction. Microsoft.ML.Transforms.FeatureContributionEntryPoint FeatureContributionCalculation Microsoft.ML.Transforms.FeatureContributionCalculatingTransformer+Options Microsoft.ML.EntryPoints.CommonOutputs+TransformOutput diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index c8e6d6e55c..7f768d1926 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -18077,6 +18077,117 @@ } ] }, + { + "Name": "Transforms.DateTimeSplitter", + "Desc": "Splits a date time value into each individual component", + "FriendlyName": "DateTime Transform", + "ShortName": "DateTimeTransform", + "Inputs": [ + { + "Name": "Source", + "Type": "String", + "Desc": "Input column", + "Aliases": [ + "src" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Prefix", + "Type": "String", + "Desc": "Output column prefix", + "Aliases": [ + "pre" + ], + "Required": true, + "SortOrder": 2.0, + "IsNullable": false + }, + { + "Name": "Country", + "Type": { + "Kind": "Enum", + "Values": [ + "None", + "Argentina", + "Australia", + "Austria", + "Belarus", + "Belgium", + "Brazil", + "Canada", + "Colombia", + "Croatia", + "Czech", + "Denmark", + "England", + "Finland", + "France", + "Germany", + "Hungary", + "India", + "Ireland", + "IsleofMan", + "Italy", + "Japan", + "Mexico", + "Netherlands", + "NewZealand", + "NorthernIreland", + "Norway", + "Poland", + "Portugal", + "Scotland", + "Slovenia", + "SouthAfrica", + "Spain", + "Sweden", + "Switzerland", + "Ukraine", + "UnitedKingdom", + "UnitedStates", + "Wales" + ] + }, + "Desc": "Country to get holidays for. Defaults to none if not passed", + "Aliases": [ + "ctry" + ], + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": "None" + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "Transforms.Dictionarizer", "Desc": "Converts input values (words, numbers, etc.) to index in a dictionary.", diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index bd425c05ff..268e7fc390 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -12,6 +12,7 @@ using Microsoft.ML.Data; using Microsoft.ML.Data.IO; using Microsoft.ML.EntryPoints; +using Microsoft.ML.Featurizers; using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Model; using Microsoft.ML.Model.OnnxConverter; @@ -330,6 +331,7 @@ public void EntryPointCatalogCheckDuplicateParams() Env.ComponentCatalog.RegisterAssembly(typeof(SaveOnnxCommand).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(TimeSeriesProcessingEntryPoints).Assembly); Env.ComponentCatalog.RegisterAssembly(typeof(ParquetLoader).Assembly); + Env.ComponentCatalog.RegisterAssembly(typeof(DateTimeTransformer).Assembly); var catalog = Env.ComponentCatalog; diff --git a/test/Microsoft.ML.TestFramework/Attributes/NotCentOS7FactAttribute.cs b/test/Microsoft.ML.TestFramework/Attributes/NotCentOS7FactAttribute.cs new file mode 100644 index 0000000000..849d765c21 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/Attributes/NotCentOS7FactAttribute.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using Microsoft.ML.TestFrameworkCommon.Attributes; + +namespace Microsoft.ML.TestFramework.Attributes +{ + /// + /// A fact for tests that wont run on CentOS7 + /// + public sealed class NotCentOS7FactAttribute : EnvironmentSpecificFactAttribute + { + public NotCentOS7FactAttribute() : base("These tests are not CentOS7 complient.") + { + } + protected override bool IsEnvironmentSupported() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + using (Process process = new Process()) + { + process.StartInfo.FileName = "/bin/bash"; + process.StartInfo.Arguments= "-c \"cat /etc/*-release\""; + process.StartInfo.UseShellExecute = false; + process.StartInfo.RedirectStandardOutput = true; + process.StartInfo.CreateNoWindow = true; + process.Start(); + + string distro = process.StandardOutput.ReadToEnd().Trim(); + + process.WaitForExit(); + if (distro.Contains("CentOS Linux 7")) + { + return false; + } + } + } + return true; + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs new file mode 100644 index 0000000000..4b91b30586 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/DateTimeTransformerTests.cs @@ -0,0 +1,255 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.RunTests; +using Microsoft.ML.Featurizers; +using System; +using Xunit; +using Xunit.Abstractions; +using Microsoft.ML.TestFramework.Attributes; + +namespace Microsoft.ML.Tests.Transformers +{ + public class DateTimeTransformerTests : TestDataPipeBase + { + public DateTimeTransformerTests(ITestOutputHelper output) : base(output) + { + } + + private class DateTimeInput + { + public long date; + } + + [NotCentOS7FactAttribute] + public void CorrectNumberOfColumnsAndSchema() + { + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 0 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var columnPrefix = "DTC_"; + var pipeline = mlContext.Transforms.FeaturizeDateTime("date", columnPrefix); + var model = pipeline.Fit(data); + var output = model.Transform(data); + var schema = output.Schema; + + // Check the schema has 22 columns + Assert.Equal(22, schema.Count); + + // Make sure names with prefix and order are correct + Assert.Equal($"{columnPrefix}Year", schema[1].Name); + Assert.Equal($"{columnPrefix}Month", schema[2].Name); + Assert.Equal($"{columnPrefix}Day", schema[3].Name); + Assert.Equal($"{columnPrefix}Hour", schema[4].Name); + Assert.Equal($"{columnPrefix}Minute", schema[5].Name); + Assert.Equal($"{columnPrefix}Second", schema[6].Name); + Assert.Equal($"{columnPrefix}AmPm", schema[7].Name); + Assert.Equal($"{columnPrefix}Hour12", schema[8].Name); + Assert.Equal($"{columnPrefix}DayOfWeek", schema[9].Name); + Assert.Equal($"{columnPrefix}DayOfQuarter", schema[10].Name); + Assert.Equal($"{columnPrefix}DayOfYear", schema[11].Name); + Assert.Equal($"{columnPrefix}WeekOfMonth", schema[12].Name); + Assert.Equal($"{columnPrefix}QuarterOfYear", schema[13].Name); + Assert.Equal($"{columnPrefix}HalfOfYear", schema[14].Name); + Assert.Equal($"{columnPrefix}WeekIso", schema[15].Name); + Assert.Equal($"{columnPrefix}YearIso", schema[16].Name); + Assert.Equal($"{columnPrefix}MonthLabel", schema[17].Name); + Assert.Equal($"{columnPrefix}AmPmLabel", schema[18].Name); + Assert.Equal($"{columnPrefix}DayOfWeekLabel", schema[19].Name); + Assert.Equal($"{columnPrefix}HolidayName", schema[20].Name); + Assert.Equal($"{columnPrefix}IsPaidTimeOff", schema[21].Name); + + // Make sure types are correct + Assert.Equal(typeof(int), schema[1].Type.RawType); + Assert.Equal(typeof(byte), schema[2].Type.RawType); + Assert.Equal(typeof(byte), schema[3].Type.RawType); + Assert.Equal(typeof(byte), schema[4].Type.RawType); + Assert.Equal(typeof(byte), schema[5].Type.RawType); + Assert.Equal(typeof(byte), schema[6].Type.RawType); + Assert.Equal(typeof(byte), schema[7].Type.RawType); + Assert.Equal(typeof(byte), schema[8].Type.RawType); + Assert.Equal(typeof(byte), schema[9].Type.RawType); + Assert.Equal(typeof(byte), schema[10].Type.RawType); + Assert.Equal(typeof(ushort), schema[11].Type.RawType); + Assert.Equal(typeof(ushort), schema[12].Type.RawType); + Assert.Equal(typeof(byte), schema[13].Type.RawType); + Assert.Equal(typeof(byte), schema[14].Type.RawType); + Assert.Equal(typeof(byte), schema[15].Type.RawType); + Assert.Equal(typeof(int), schema[16].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[17].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[18].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[19].Type.RawType); + Assert.Equal(typeof(ReadOnlyMemory), schema[20].Type.RawType); + Assert.Equal(typeof(byte), schema[21].Type.RawType); + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7FactAttribute] + public void CanUseDateFromColumn() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.FeaturizeDateTime("date", "DTC"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7FactAttribute] + public void HolidayTest() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 157161600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.FeaturizeDateTime("date", "DTC", country: DateTimeEstimator.HolidayList.Canada); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row for holidays is what we expect + Assert.Equal("Christmas Day", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7FactAttribute] + public void ManyRowsTest() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 12341 }, + new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 1234 }, new DateTimeInput() { date = 1751241600 }, + new DateTimeInput() { date = 1751241600 }, new DateTimeInput() { date = 12341 }, + new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 134 }, new DateTimeInput() { date = 1234 }}; + + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var pipeline = mlContext.Transforms.FeaturizeDateTime("date", "DTC"); + var model = pipeline.Fit(data); + var output = model.Transform(data); + + // Get the data from the first row and make sure it matches expected + var row = output.Preview().RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + TestEstimatorCore(pipeline, data); + Done(); + } + + [NotCentOS7FactAttribute] + public void EntryPointTest() + { + // Future Date - 2025 June 30 + MLContext mlContext = new MLContext(1); + var dataList = new[] { new DateTimeInput() { date = 1751241600 } }; + var data = mlContext.Data.LoadFromEnumerable(dataList); + + // Build the pipeline, fit, and transform it. + var options = new DateTimeEstimator.Options + { + Source = "date", + Prefix = "pref_", + Data = data + }; + + var entryOutput = DateTimeTransformerEntrypoint.DateTimeSplit(mlContext.Transforms.GetEnvironment(), options); + var output = entryOutput.OutputData; + + // Get the data from the first row and make sure it matches expected + var row = output.Preview(1).RowView[0].Values; + + // Assert the data from the first row is what we expect + Assert.Equal(2025, row[1].Value); // Year + Assert.Equal((byte)6, row[2].Value); // Month + Assert.Equal((byte)30, row[3].Value); // Day + Assert.Equal((byte)0, row[4].Value); // Hour + Assert.Equal((byte)0, row[5].Value); // Minute + Assert.Equal((byte)0, row[6].Value); // Second + Assert.Equal((byte)0, row[7].Value); // AmPm + Assert.Equal((byte)0, row[8].Value); // Hour12 + Assert.Equal((byte)1, row[9].Value); // DayOfWeek + Assert.Equal((byte)91, row[10].Value); // DayOfQuarter + Assert.Equal((ushort)180, row[11].Value); // DayOfYear + Assert.Equal((ushort)4, row[12].Value); // WeekOfMonth + Assert.Equal((byte)2, row[13].Value); // QuarterOfYear + Assert.Equal((byte)1, row[14].Value); // HalfOfYear + Assert.Equal((byte)27, row[15].Value); // WeekIso + Assert.Equal(2025, row[16].Value); // YearIso + Assert.Equal("June", row[17].Value.ToString()); // MonthLabel + Assert.Equal("am", row[18].Value.ToString()); // AmPmLabel + Assert.Equal("Monday", row[19].Value.ToString()); // DayOfWeekLabel + Assert.Equal("", row[20].Value.ToString()); // HolidayName + Assert.Equal((byte)0, row[21].Value); // IsPaidTimeOff + + Done(); + } + } +}