diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectAnomalyBySrCnn.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectAnomalyBySrCnn.cs new file mode 100644 index 0000000000..9b17e86244 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectAnomalyBySrCnn.cs @@ -0,0 +1,125 @@ +using System; +using System.Collections.Generic; +using System.IO; +using Microsoft.ML; +using Microsoft.ML.Data; +using Microsoft.ML.Transforms.TimeSeries; + +namespace Samples.Dynamic +{ + public static class DetectAnomalyBySrCnn + { + // This example creates a time series (list of Data with the i-th element corresponding to the i-th time slot). + // The estimator is applied then to identify spiking points in the series. + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging, + // as well as the source of randomness. + var ml = new MLContext(); + + // Generate sample series data with an anomaly + var data = new List(); + for (int index = 0; index < 20; index++) + { + data.Add(new TimeSeriesData(5)); + } + data.Add(new TimeSeriesData(10)); + for (int index = 0; index < 5; index++) + { + data.Add(new TimeSeriesData(5)); + } + + // Convert data to IDataView. + var dataView = ml.Data.LoadFromEnumerable(data); + + // Setup the estimator arguments + string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction); + string inputColumnName = nameof(TimeSeriesData.Value); + + // The transformed model. + ITransformer model = ml.Transforms.DetectAnomalyBySrCnn(outputColumnName, inputColumnName, 16, 5, 5, 3, 8, 0.35).Fit(dataView); + + // Create a time series prediction engine from the model. + var engine = model.CreateTimeSeriesPredictionFunction(ml); + + Console.WriteLine($"{outputColumnName} column obtained post-transformation."); + Console.WriteLine("Data\tAlert\tScore\tMag"); + + // Prediction column obtained post-transformation. + // Data Alert Score Mag + + // Create non-anomalous data and check for anomaly. + for (int index = 0; index < 20; index++) + { + // Anomaly detection. + PrintPrediction(5, engine.Predict(new TimeSeriesData(5))); + } + + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.03 0.18 + //5 0 0.03 0.18 + //5 0 0.03 0.18 + //5 0 0.03 0.18 + //5 0 0.03 0.18 + + // Anomaly. + PrintPrediction(10, engine.Predict(new TimeSeriesData(10))); + + //10 1 0.47 0.93 <-- alert is on, predicted anomaly + + // Checkpoint the model. + var modelPath = "temp.zip"; + engine.CheckPoint(ml, modelPath); + + // Load the model. + using (var file = File.OpenRead(modelPath)) + model = ml.Model.Load(file, out DataViewSchema schema); + + for (int index = 0; index < 5; index++) + { + // Anomaly detection. + PrintPrediction(5, engine.Predict(new TimeSeriesData(5))); + } + + //5 0 0.31 0.50 + //5 0 0.05 0.30 + //5 0 0.01 0.23 + //5 0 0.00 0.21 + //5 0 0.01 0.25 + } + + private static void PrintPrediction(float value, SrCnnAnomalyDetection prediction) => + Console.WriteLine("{0}\t{1}\t{2:0.00}\t{3:0.00}", value, prediction.Prediction[0], + prediction.Prediction[1], prediction.Prediction[2]); + + private class TimeSeriesData + { + public float Value; + + public TimeSeriesData(float value) + { + Value = value; + } + } + + private class SrCnnAnomalyDetection + { + [VectorType(3)] + public double[] Prediction { get; set; } + } + } +} diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectAnomalyBySrCnnBatchPrediction.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectAnomalyBySrCnnBatchPrediction.cs new file mode 100644 index 0000000000..ef1d2a0de9 --- /dev/null +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/TimeSeries/DetectAnomalyBySrCnnBatchPrediction.cs @@ -0,0 +1,98 @@ +using System; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; + +namespace Samples.Dynamic +{ + public static class DetectAnomalyBySrCnnBatchPrediction + { + public static void Example() + { + // Create a new ML context, for ML.NET operations. It can be used for exception tracking and logging, + // as well as the source of randomness. + var ml = new MLContext(); + + // Generate sample series data with an anomaly + var data = new List(); + for (int index = 0; index < 20; index++) + { + data.Add(new TimeSeriesData(5)); + } + data.Add(new TimeSeriesData(10)); + for (int index = 0; index < 5; index++) + { + data.Add(new TimeSeriesData(5)); + } + + // Convert data to IDataView. + var dataView = ml.Data.LoadFromEnumerable(data); + + // Setup the estimator arguments + string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction); + string inputColumnName = nameof(TimeSeriesData.Value); + + // The transformed data. + var transformedData = ml.Transforms.DetectAnomalyBySrCnn(outputColumnName, inputColumnName, 16, 5, 5, 3, 8, 0.35).Fit(dataView).Transform(dataView); + + // Getting the data of the newly created column as an IEnumerable of SrCnnAnomalyDetection. + var predictionColumn = ml.Data.CreateEnumerable(transformedData, reuseRowObject: false); + + Console.WriteLine($"{outputColumnName} column obtained post-transformation."); + Console.WriteLine("Data\tAlert\tScore\tMag"); + + int k = 0; + foreach (var prediction in predictionColumn) + PrintPrediction(data[k++].Value, prediction); + + //Prediction column obtained post-transformation. + //Data Alert Score Mag + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.00 0.00 + //5 0 0.03 0.18 + //5 0 0.03 0.18 + //5 0 0.03 0.18 + //5 0 0.03 0.18 + //5 0 0.03 0.18 + //10 1 0.47 0.93 + //5 0 0.31 0.50 + //5 0 0.05 0.30 + //5 0 0.01 0.23 + //5 0 0.00 0.21 + //5 0 0.01 0.25 + } + + private static void PrintPrediction(float value, SrCnnAnomalyDetection prediction) => + Console.WriteLine("{0}\t{1}\t{2:0.00}\t{3:0.00}", value, prediction.Prediction[0], + prediction.Prediction[1], prediction.Prediction[2]); + + private class TimeSeriesData + { + public float Value; + + public TimeSeriesData(float value) + { + Value = value; + } + } + + private class SrCnnAnomalyDetection + { + [VectorType(3)] + public double[] Prediction { get; set; } + } + } +} diff --git a/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs b/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs index e4a3da6761..f5261db8cc 100644 --- a/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.TimeSeries/ExtensionsCatalog.cs @@ -121,5 +121,29 @@ public static SsaChangePointEstimator DetectChangePointBySsa(this TransformsCata public static SsaSpikeEstimator DetectSpikeBySsa(this TransformsCatalog catalog, string outputColumnName, string inputColumnName, int confidence, int pvalueHistoryLength, int trainingWindowSize, int seasonalityWindowSize, AnomalySide side = AnomalySide.TwoSided, ErrorFunction errorFunction = ErrorFunction.SignedDifference) => new SsaSpikeEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, confidence, pvalueHistoryLength, trainingWindowSize, seasonalityWindowSize, inputColumnName, side, errorFunction); + + /// + /// Create , which detects timeseries anomalies using SRCNN algorithm. + /// + /// The transform's catalog. + /// Name of the column resulting from the transformation of . + /// The column data is a vector of . The vector contains 3 elements: alert (1 means anomaly while 0 means normal), raw score, and magnitude of spectual residual. + /// Name of column to transform. The column data must be . + /// The size of the sliding window for computing spectral residual. + /// The number of points to add back of training window. No more than windowSize, usually keep default value. + /// The number of pervious points used in prediction. No more than windowSize, usually keep default value. + /// The size of sliding window to generate a saliency map for the series. No more than windowSize, usually keep default value. + /// The size of sliding window to calculate the anomaly score for each data point. No more than windowSize. + /// The threshold to determine anomaly, score larger than the threshold is considered as anomaly. Should be in (0,1) + /// + /// + /// + /// + /// + public static SrCnnAnomalyEstimator DetectAnomalyBySrCnn(this TransformsCatalog catalog, string outputColumnName, string inputColumnName, + int windowSize=64, int backAddWindowSize=5, int lookaheadWindowSize=5, int averageingWindowSize=3, int judgementWindowSize=21, double threshold=0.3) + => new SrCnnAnomalyEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, windowSize, backAddWindowSize, lookaheadWindowSize, averageingWindowSize, judgementWindowSize, threshold, inputColumnName); } } diff --git a/src/Microsoft.ML.TimeSeries/SRCNNAnomalyDetector.cs b/src/Microsoft.ML.TimeSeries/SRCNNAnomalyDetector.cs new file mode 100644 index 0000000000..31f00c10d1 --- /dev/null +++ b/src/Microsoft.ML.TimeSeries/SRCNNAnomalyDetector.cs @@ -0,0 +1,287 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Transforms.TimeSeries; + +[assembly: LoadableClass(SrCnnAnomalyDetector.Summary, typeof(IDataTransform), typeof(SrCnnAnomalyDetector), typeof(SrCnnAnomalyDetector.Options), typeof(SignatureDataTransform), + SrCnnAnomalyDetector.UserName, SrCnnAnomalyDetector.LoaderSignature, SrCnnAnomalyDetector.ShortName)] + +[assembly: LoadableClass(SrCnnAnomalyDetector.Summary, typeof(IDataTransform), typeof(SrCnnAnomalyDetector), null, typeof(SignatureLoadDataTransform), + SrCnnAnomalyDetector.UserName, SrCnnAnomalyDetector.LoaderSignature)] + +[assembly: LoadableClass(SrCnnAnomalyDetector.Summary, typeof(SrCnnAnomalyDetector), null, typeof(SignatureLoadModel), + SrCnnAnomalyDetector.UserName, SrCnnAnomalyDetector.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(SrCnnAnomalyDetector), null, typeof(SignatureLoadRowMapper), + SrCnnAnomalyDetector.UserName, SrCnnAnomalyDetector.LoaderSignature)] + +namespace Microsoft.ML.Transforms.TimeSeries +{ + /// + /// resulting from fitting a . + /// + public sealed class SrCnnAnomalyDetector : SrCnnAnomalyDetectionBase, IStatefulTransformer + { + internal const string Summary = "This transform detects the anomalies in a time-series using SRCNN."; + internal const string LoaderSignature = "SrCnnAnomalyDetector"; + internal const string UserName = "SrCnn Anomaly Detection"; + internal const string ShortName = "srcnn"; + + internal sealed class Options : TransformInputBase + { + [Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src", + SortOrder = 1, Purpose = SpecialPurpose.ColumnName)] + public string Source; + + [Argument(ArgumentType.Required, HelpText = "The name of the new column.", + SortOrder = 2)] + public string Name; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the sliding window for computing spectral residual", ShortName = "wnd", + SortOrder = 101)] + public int WindowSize = 24; + + [Argument(ArgumentType.Required, HelpText = "The number of points to the back of training window.", + ShortName = "backwnd", SortOrder = 102)] + public int BackAddWindowSize = 5; + + [Argument(ArgumentType.Required, HelpText = "The number of pervious points used in prediction.", + ShortName = "aheadwnd", SortOrder = 103)] + public int LookaheadWindowSize = 5; + + [Argument(ArgumentType.Required, HelpText = "The size of sliding window to generate a saliency map for the series.", + ShortName = "avgwnd", SortOrder = 104)] + public int AvergingWindowSize = 3; + + [Argument(ArgumentType.Required, HelpText = "The size of sliding window to calculate the anomaly score for each data point.", + ShortName = "jdgwnd", SortOrder = 105)] + public int JudgementWindowSize = 21; + + [Argument(ArgumentType.Required, HelpText = "The threshold to determine anomaly, score larger than the threshold is considered as anomaly.", + ShortName = "thre", SortOrder = 106)] + public double Threshold = 0.3; + } + + private sealed class SrCnnArgument : SrCnnArgumentBase + { + public SrCnnArgument(Options options) + { + Source = options.Source; + Name = options.Name; + WindowSize = options.WindowSize; + InitialWindowSize = 0; + BackAddWindowSize = options.BackAddWindowSize; + LookaheadWindowSize = options.LookaheadWindowSize; + AvergingWindowSize = options.AvergingWindowSize; + JudgementWindowSize = options.JudgementWindowSize; + Threshold = options.Threshold; + } + + public SrCnnArgument(SrCnnAnomalyDetector transform) + { + Source = transform.InternalTransform.InputColumnName; + Name = transform.InternalTransform.OutputColumnName; + WindowSize = transform.InternalTransform.WindowSize; + InitialWindowSize = 0; + BackAddWindowSize = transform.InternalTransform.BackAddWindowSize; + LookaheadWindowSize = transform.InternalTransform.LookaheadWindowSize; + AvergingWindowSize = transform.InternalTransform.AvergingWindowSize; + JudgementWindowSize = transform.InternalTransform.JudgementWindowSize; + Threshold = transform.InternalTransform.AlertThreshold; + } + } + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "SRCNTRNS", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(SrCnnAnomalyDetector).Assembly.FullName); + } + + private static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(options, nameof(options)); + env.CheckValue(input, nameof(input)); + + return new SrCnnAnomalyDetector(env, options).MakeDataTransform(input); + } + + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + env.CheckValue(input, nameof(input)); + + return new SrCnnAnomalyDetector(env, ctx).MakeDataTransform(input); + } + + private static SrCnnAnomalyDetector Create(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + return new SrCnnAnomalyDetector(env, ctx); + } + + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, DataViewSchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + IStatefulTransformer IStatefulTransformer.Clone() + { + var clone = (SrCnnAnomalyDetector)MemberwiseClone(); + clone.InternalTransform.StateRef = (SrCnnAnomalyDetectionBaseCore.State)clone.InternalTransform.StateRef.Clone(); + clone.InternalTransform.StateRef.InitState(clone.InternalTransform, InternalTransform.Host); + return clone; + } + + internal SrCnnAnomalyDetector(IHostEnvironment env, Options options) + : base(new SrCnnArgument(options), LoaderSignature, env) + { + } + + internal SrCnnAnomalyDetector(IHostEnvironment env, ModelLoadContext ctx) + : base(env, ctx, LoaderSignature) + { + } + + private SrCnnAnomalyDetector(IHostEnvironment env, SrCnnAnomalyDetector transform) + : base(new SrCnnArgument(transform), LoaderSignature, env) + { + } + + private protected override void SaveModel(ModelSaveContext ctx) + { + InternalTransform.Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // + base.SaveModel(ctx); + } + } + + /// + /// Detect anomalies in time series using Spectral Residual(SR) algorithm + /// + /// + /// | + /// | Output column data type | 3-element vector of | + /// ### Background + /// At Microsoft, we develop a time-series anomaly detection service which helps customers to monitor the time-series continuously + /// and alert for potential incidents on time. To tackle the problem of time-series anomaly detection, + /// we propose a novel algorithm based on Spectral Residual (SR) and Convolutional Neural Network + /// (CNN). The SR model is borrowed from visual saliency detection domain to time-series anomaly detection. + /// And here we onboarded this SR algorithm firstly. + /// + /// The Spectral Residual (SR) algorithm is unsupervised, which means training step is not needed while using SR. It consists of three major steps: + /// (1) Fourier Transform to get the log amplitude spectrum; + /// (2) calculation of spectral residual; + /// (3) Inverse Fourier Transform that transforms the sequence back to spatial domain. + /// Mathematically, given a sequence $\mathbf{x}$, we have + /// $$A(f) = Amplitude(\mathfrak{F}(\mathbf{x}))\\P(f) = Phrase(\mathfrak{F}(\mathbf{x}))\\L(f) = log(A(f))\\AL(f) = h_n(f) \cdot L(f)\\R(f) = L(f) - AL(f)\\S(\mathbf{x}) = \mathfrak{F}^{-1}(exp(R(f) + P(f))^{2})$$ + /// where $\mathfrak{F}$ and $\mathfrak{F}^{-1}$ denote Fourier Transform and Inverse Fourier Transform respectively. + /// $\mathbf{x}$ is the input sequence with shape $n × 1$; $A(f)$ is the amplitude spectrum of sequence $\mathbf{x}$; + /// $P(f)$ is the corresponding phase spectrum of sequence $\mathbf{x}$; $L(f)$ is the log representation of $A(f)$; + /// and $AL(f)$ is the average spectrum of $L(f)$ which can be approximated by convoluting the input sequence by $h_n(f)$, + /// where $h_n(f)$ is an $n × n$ matrix defined as: + /// $$n_f(f) = \begin{bmatrix}1&1&1&\cdots&1\\1&1&1&\cdots&1\\\vdots&\vdots&\vdots&\ddots&\vdots\\1&1&1&\cdots&1\end{bmatrix}$$ + /// $R(f)$ is the spectral residual, i.e., the log spectrum $L(f)$ subtracting the averaged log spectrum $AL(f)$. + /// The spectral residual serves as a compressed representation of the sequence while the innovation part of the original sequence becomes more significant. + /// At last, we transfer the sequence back to spatial domain via Inverse Fourier Transform. The result sequence $S(\mathbf{x})$ is called the saliency map. + /// Given the saliency map $S(\mathbf{x})$, the output sequence $O(\mathbf{x})$ is computed by: + /// $$O(x_i) = \begin{cases}1, if \frac{S(x_i)-\overline{S(x_i)}}{S(x_i)} > \tau\\0,otherwise,\end{cases}$$ + /// where $x_i$ represents an arbitrary point in sequence $\mathbf{x}$; $S(x_i)$is the corresponding point in the saliency map; + /// and $\overline{S(x_i)}$ is the local average of the preceding points of $S(x_i)$. + /// + /// There are several parameters for SR algorithm. To obtain a model with good performance, + /// we suggest to tune windowSize and threshold at first, + /// these are the most important parameters to SR. Then you could search for an appropriate judgementWindowSize + /// which is no larger than windowSize. And for the remaining parameters, you could use the default value directly. + /// + /// * Link to the KDD 2019 paper will be updated after it goes public. + /// ]]> + /// + /// + /// + public sealed class SrCnnAnomalyEstimator : TrivialEstimator + { + /// Host environment. + /// Name of the column resulting from the transformation of . + /// The size of the sliding window for computing spectral residual. + /// The size of the sliding window for computing spectral residual. + /// The number of pervious points used in prediction. + /// The size of sliding window to generate a saliency map for the series. + /// The size of sliding window to calculate the anomaly score for each data point. + /// The threshold to determine anomaly, score larger than the threshold is considered as anomaly. + /// Name of column to transform. The column data must be . + internal SrCnnAnomalyEstimator(IHostEnvironment env, + string outputColumnName, + int windowSize, + int backAddWindowSize, + int lookaheadWindowSize, + int averagingWindowSize, + int judgementWindowSize, + double threshold = 0.3, + string inputColumnName = null) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(SrCnnAnomalyEstimator)), + new SrCnnAnomalyDetector(env, new SrCnnAnomalyDetector.Options + { + Source = inputColumnName ?? outputColumnName, + Name = outputColumnName, + WindowSize = windowSize, + BackAddWindowSize = backAddWindowSize, + LookaheadWindowSize = lookaheadWindowSize, + AvergingWindowSize = averagingWindowSize, + JudgementWindowSize = judgementWindowSize, + Threshold = threshold + })) + { + } + + internal SrCnnAnomalyEstimator(IHostEnvironment env, SrCnnAnomalyDetector.Options options) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(SrCnnAnomalyEstimator)), new SrCnnAnomalyDetector(env, options)) + { + } + + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryFindColumn(Transformer.InternalTransform.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName); + if (col.ItemType != NumberDataViewType.Single) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Transformer.InternalTransform.InputColumnName, "Single", col.GetTypeString()); + + var metadata = new List() { + new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextDataViewType.Instance, false) + }; + var resultDic = inputSchema.ToDictionary(x => x.Name); + resultDic[Transformer.InternalTransform.OutputColumnName] = new SchemaShape.Column( + Transformer.InternalTransform.OutputColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Double, false, new SchemaShape(metadata)); + + return new SchemaShape(resultDic.Values); + } + + } +} diff --git a/src/Microsoft.ML.TimeSeries/SrCnnAnomalyDetectionBase.cs b/src/Microsoft.ML.TimeSeries/SrCnnAnomalyDetectionBase.cs new file mode 100644 index 0000000000..566b5b6cd2 --- /dev/null +++ b/src/Microsoft.ML.TimeSeries/SrCnnAnomalyDetectionBase.cs @@ -0,0 +1,298 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Transforms.TimeSeries +{ + public class SrCnnAnomalyDetectionBase : IStatefulTransformer, ICanSaveModel + { + /// + /// Whether a call to should succeed, on an + /// appropriate schema. + /// + bool ITransformer.IsRowToRowMapper => ((ITransformer)InternalTransform).IsRowToRowMapper; + + /// + /// Create a clone of the transformer. Used for taking the snapshot of the state. + /// + IStatefulTransformer IStatefulTransformer.Clone() => InternalTransform.Clone(); + + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. + /// + public DataViewSchema GetOutputSchema(DataViewSchema inputSchema) => InternalTransform.GetOutputSchema(inputSchema); + + /// + /// Constructs a row-to-row mapper based on an input schema. If + /// is false, then an exception should be thrown. If the input schema is in any way + /// unsuitable for constructing the mapper, an exception should likewise be thrown. + /// + /// The input schema for which we should get the mapper. + /// The row to row mapper. + IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema) + => ((ITransformer)InternalTransform).GetRowToRowMapper(inputSchema); + + /// + /// Same as but also supports mechanism to save the state. + /// + /// The input schema for which we should get the mapper. + /// The row to row mapper. + public IRowToRowMapper GetStatefulRowToRowMapper(DataViewSchema inputSchema) + => ((IStatefulTransformer)InternalTransform).GetStatefulRowToRowMapper(inputSchema); + + /// + /// Initialize a transformer which will do lambda transfrom on input data in prediction engine. No actual transformations happen here, just schema validation. + /// + public IDataView Transform(IDataView input) => InternalTransform.Transform(input); + + /// + /// For saving a model into a repository. + /// + void ICanSaveModel.Save(ModelSaveContext ctx) => SaveModel(ctx); + + private protected virtual void SaveModel(ModelSaveContext ctx) + { + InternalTransform.SaveThis(ctx); + } + + /// + /// Creates a row mapper from Schema. + /// + internal IStatefulRowMapper MakeRowMapper(DataViewSchema schema) => InternalTransform.MakeRowMapper(schema); + + /// + /// Creates an IDataTransform from an IDataView. + /// + internal IDataTransform MakeDataTransform(IDataView input) => InternalTransform.MakeDataTransform(input); + + internal SrCnnAnomalyDetectionBaseCore InternalTransform { get; } + + internal SrCnnAnomalyDetectionBase(SrCnnArgumentBase args, string name, IHostEnvironment env) + { + InternalTransform = new SrCnnAnomalyDetectionBaseCore(args, name, env, this); + } + + internal SrCnnAnomalyDetectionBase(IHostEnvironment env, ModelLoadContext ctx, string name) + { + InternalTransform = new SrCnnAnomalyDetectionBaseCore(env, ctx, name, this); + } + + internal sealed class SrCnnAnomalyDetectionBaseCore : SrCnnTransformBase + { + internal SrCnnAnomalyDetectionBase Parent; + + public SrCnnAnomalyDetectionBaseCore(SrCnnArgumentBase args, string name, IHostEnvironment env, SrCnnAnomalyDetectionBase parent) + : base(args, name, env) + { + InitialWindowSize = WindowSize; + StateRef = new State(); + StateRef.InitState(WindowSize, InitialWindowSize, this, Host); + Parent = parent; + } + + public SrCnnAnomalyDetectionBaseCore(IHostEnvironment env, ModelLoadContext ctx, string name, SrCnnAnomalyDetectionBase parent) + : base(env, ctx, name) + { + StateRef = new State(ctx.Reader); + StateRef.InitState(this, Host); + Parent = parent; + } + + public override DataViewSchema GetOutputSchema(DataViewSchema inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + + if (!inputSchema.TryGetColumnIndex(InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName); + + var colType = inputSchema[col].Type; + if (colType != NumberDataViewType.Single) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", InputColumnName, NumberDataViewType.Single.ToString(), colType.ToString()); + + return Transform(new EmptyDataView(Host, inputSchema)).Schema; + } + + private protected override void SaveModel(ModelSaveContext ctx) + { + ((ICanSaveModel)Parent).Save(ctx); + } + + internal void SaveThis(ModelSaveContext ctx) + { + ctx.CheckAtModel(); + base.SaveModel(ctx); + + // *** Binary format *** + // + // State: StateRef + StateRef.Save(ctx.Writer); + } + + internal sealed class State : SrCnnStateBase + { + public State() + { + } + + internal State(BinaryReader reader) : base(reader) + { + WindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + InitialWindowedBuffer = TimeSeriesUtils.DeserializeFixedSizeQueueSingle(reader, Host); + } + + internal override void Save(BinaryWriter writer) + { + base.Save(writer); + TimeSeriesUtils.SerializeFixedSizeQueue(WindowedBuffer, writer); + TimeSeriesUtils.SerializeFixedSizeQueue(InitialWindowedBuffer, writer); + } + + private protected override void CloneCore(State state) + { + base.CloneCore(state); + Contracts.Assert(state is State); + var stateLocal = state as State; + stateLocal.WindowedBuffer = WindowedBuffer.Clone(); + stateLocal.InitialWindowedBuffer = InitialWindowedBuffer.Clone(); + } + + private protected override void LearnStateFromDataCore(FixedSizeQueue data) + { + } + + private protected override sealed void SpectralResidual(Single input, FixedSizeQueue data, ref VBufferEditor result) + { + // Step 1: Get backadd wave + List backAddList = BackAdd(data); + + // Step 2: FFT transformation + int length = backAddList.Count; + float[] fftRe = new float[length]; + float[] fftIm = new float[length]; + FftUtils.ComputeForwardFft(backAddList.ToArray(), Enumerable.Repeat(0.0f, length).ToArray(), fftRe, fftIm, length); + + // Step 3: Calculate mags of FFT + List magList = new List(); + for (int i = 0; i < length; ++i) + { + magList.Add(MathUtils.Sqrt((fftRe[i] * fftRe[i] + fftIm[i] * fftIm[i]))); + } + + // Step 4: Calculate spectral + List magLogList = magList.Select(x => x != 0 ? MathUtils.Log(x) : 0).ToList(); + List filteredLogList = AverageFilter(magLogList, Parent.AvergingWindowSize); + List spectralList = new List(); + for (int i = 0; i < magLogList.Count; ++i) + { + spectralList.Add(MathUtils.ExpSlow(magLogList[i] - filteredLogList[i])); + } + + // Step 5: IFFT transformation + float[] transRe = new float[length]; + float[] transIm = new float[length]; + for (int i = 0; i < length; ++i) + { + if (magLogList[i] != 0) + { + transRe[i] = fftRe[i] * spectralList[i] / magList[i]; + transIm[i] = fftIm[i] * spectralList[i] / magList[i]; + } + else + { + transRe[i] = 0; + transIm[i] = 0; + } + } + + float[] ifftRe = new float[length]; + float[] ifftIm = new float[length]; + FftUtils.ComputeBackwardFft(transRe, transIm, ifftRe, ifftIm, length); + + // Step 6: Calculate mag and ave_mag of IFFT + List ifftMagList = new List(); + for (int i = 0; i < length; ++i) + { + ifftMagList.Add(MathUtils.Sqrt((ifftRe[i] * ifftRe[i] + ifftIm[i] * ifftIm[i]))); + } + List filteredIfftMagList = AverageFilter(ifftMagList, Parent.JudgementWindowSize); + + // Step 7: Calculate score and set result + var score = CalculateSocre(ifftMagList[data.Count - 1], filteredIfftMagList[data.Count - 1]); + score /= 10.0f; + result.Values[1] = score; + + score = Math.Min(score, 1); + score = Math.Max(score, 0); + var detres = score > Parent.AlertThreshold ? 1 : 0; + result.Values[0] = detres; + + var mag = ifftMagList[data.Count - 1]; + result.Values[2] = mag; + } + + private List BackAdd(FixedSizeQueue data) + { + List predictArray = new List(); + for (int i = data.Count - Parent.LookaheadWindowSize - 2; i < data.Count - 1; ++i) + { + predictArray.Add(data[i]); + } + var predictedValue = PredictNext(predictArray); + List backAddArray = new List(); + for (int i = 0; i < data.Count; ++i) + { + backAddArray.Add(data[i]); + } + backAddArray.AddRange(Enumerable.Repeat(predictedValue, Parent.BackAddWindowSize)); + return backAddArray; + } + + private Single PredictNext(List data) + { + var n = data.Count; + Single slopeSum = 0.0f; + for (int i = 0; i < n - 1; ++i) + { + slopeSum += (data[n-1] - data[i]) / (n - 1 - i); + } + return (data[1] + slopeSum); + } + + private List AverageFilter(List data, int n) + { + Single cumsum = 0.0f; + List cumSumList = data.Select(x => cumsum += x).ToList(); + List cumSumShift = new List(cumSumList); + for (int i = n; i < cumSumList.Count; ++i) + { + cumSumList[i] = (cumSumList[i] - cumSumShift[i - n]) / n; + } + for (int i = 1; i < n; ++i) + { + cumSumList[i] /= (i + 1); + } + return cumSumList; + } + + private Single CalculateSocre(Single mag, Single avgMag) + { + double safeDivisor = avgMag; + if (safeDivisor < 1e-8) + { + safeDivisor = 1e-8; + } + return (float)(Math.Abs(mag - avgMag) / safeDivisor); + } + } + } + } +} diff --git a/src/Microsoft.ML.TimeSeries/SrCnnTransformBase.cs b/src/Microsoft.ML.TimeSeries/SrCnnTransformBase.cs new file mode 100644 index 0000000000..9f64fa2167 --- /dev/null +++ b/src/Microsoft.ML.TimeSeries/SrCnnTransformBase.cs @@ -0,0 +1,318 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Threading; +using Microsoft.ML.CommandLine; +using Microsoft.ML.Data; +using Microsoft.ML.Internal.Utilities; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Transforms.TimeSeries +{ + internal abstract class SrCnnArgumentBase + { + [Argument(ArgumentType.Required, HelpText = "The name of the source column.", ShortName = "src", + SortOrder = 1, Purpose = SpecialPurpose.ColumnName)] + public string Source; + + [Argument(ArgumentType.Required, HelpText = "The name of the new column.", + SortOrder = 2)] + public string Name; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the sliding window for computing spectral residual", ShortName = "wnd", + SortOrder = 3)] + public int WindowSize = 24; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The size of the initial window for computingd. The default value is set to 0, which means there is no initial window considered.", ShortName = "iwnd", + SortOrder = 4)] + public int InitialWindowSize = 0; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The number of points to the back of training window.", + ShortName = "backwnd", SortOrder = 5)] + public int BackAddWindowSize = 5; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The number of pervious points used in prediction.", + ShortName = "aheadwnd", SortOrder = 6)] + public int LookaheadWindowSize = 5; + + [Argument(ArgumentType.Required, HelpText = "The size of sliding window to generate a saliency map for the series.", + ShortName = "avgwnd", SortOrder = 7)] + public int AvergingWindowSize = 3; + + [Argument(ArgumentType.Required, HelpText = "The size of sliding window to generate a saliency map for the series.", + ShortName = "jdgwnd", SortOrder = 8)] + public int JudgementWindowSize = 21; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The threshold to determine anomaly, score larger than the threshold is considered as anomaly.", + ShortName = "thre", SortOrder = 9)] + public double Threshold = 0.3; + } + + internal abstract class SrCnnTransformBase : SequentialTransformerBase, TState> + where TState : SrCnnTransformBase.SrCnnStateBase, new() + { + internal int BackAddWindowSize { get; } + + internal int LookaheadWindowSize { get; } + + internal int AvergingWindowSize { get; } + + internal int JudgementWindowSize { get; } + + internal double AlertThreshold { get; } + + internal int OutputLength { get; } + + private protected SrCnnTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, + int backAddWindowSize, int lookaheadWindowSize, int averagingWindowSize, int judgementWindowSize, Double alertThreshold) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, outputColumnName, inputColumnName, new VectorDataViewType(NumberDataViewType.Double, 3)) + { + Host.CheckUserArg(backAddWindowSize > 0, nameof(SrCnnArgumentBase.BackAddWindowSize), "Must be non-negative"); + Host.CheckUserArg(lookaheadWindowSize > 0 && lookaheadWindowSize <= windowSize, nameof(SrCnnArgumentBase.LookaheadWindowSize), "Must be non-negative and not larger than window size"); + Host.CheckUserArg(averagingWindowSize > 0 && averagingWindowSize <= windowSize, nameof(SrCnnArgumentBase.AvergingWindowSize), "Must be non-negative and not larger than window size"); + Host.CheckUserArg(judgementWindowSize > 0 && judgementWindowSize <= windowSize, nameof(SrCnnArgumentBase.JudgementWindowSize), "Must be non-negative and not larger than window size"); + Host.CheckUserArg(alertThreshold > 0 && alertThreshold < 1, nameof(SrCnnArgumentBase.Threshold), "Must be in (0,1)"); + + BackAddWindowSize = backAddWindowSize; + LookaheadWindowSize = lookaheadWindowSize; + AvergingWindowSize = averagingWindowSize; + JudgementWindowSize = judgementWindowSize; + AlertThreshold = alertThreshold; + + OutputLength = 3; + } + + private protected SrCnnTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), ctx) + { + OutputLength = 3; + + byte temp; + temp = ctx.Reader.ReadByte(); + BackAddWindowSize = (int)temp; + Host.CheckDecode(BackAddWindowSize > 0); + + temp = ctx.Reader.ReadByte(); + LookaheadWindowSize = (int)temp; + Host.CheckDecode(LookaheadWindowSize > 0); + + temp = ctx.Reader.ReadByte(); + AvergingWindowSize = (int)temp; + Host.CheckDecode(AvergingWindowSize > 0); + + temp = ctx.Reader.ReadByte(); + JudgementWindowSize = (int)temp; + Host.CheckDecode(JudgementWindowSize > 0); + + AlertThreshold = ctx.Reader.ReadDouble(); + Host.CheckDecode(AlertThreshold >= 0 && AlertThreshold <= 1); + } + + private protected SrCnnTransformBase(SrCnnArgumentBase args, string name, IHostEnvironment env) + : this(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, + name, env, args.BackAddWindowSize, args.LookaheadWindowSize, args.AvergingWindowSize, args.JudgementWindowSize, args.Threshold) + { + } + + private protected override void SaveModel(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + + Host.Assert(WindowSize > 0); + Host.Assert(InitialWindowSize == WindowSize); + Host.Assert(BackAddWindowSize > 0); + Host.Assert(LookaheadWindowSize > 0); + Host.Assert(AvergingWindowSize > 0); + Host.Assert(JudgementWindowSize > 0); + Host.Assert(AlertThreshold >= 0 && AlertThreshold <= 1); + + base.SaveModel(ctx); + ctx.Writer.Write((byte)BackAddWindowSize); + ctx.Writer.Write((byte)LookaheadWindowSize); + ctx.Writer.Write((byte)AvergingWindowSize); + ctx.Writer.Write((byte)JudgementWindowSize); + ctx.Writer.Write(AlertThreshold); + } + + internal override IStatefulRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(Host, this, schema); + + internal sealed class Mapper : IStatefulRowMapper + { + private readonly IHost _host; + private readonly SrCnnTransformBase _parent; + private readonly DataViewSchema _parentSchema; + private readonly int _inputColumnIndex; + private readonly VBuffer> _slotNames; + private SrCnnStateBase State { get; set; } + + public Mapper(IHostEnvironment env, SrCnnTransformBase parent, DataViewSchema inputSchema) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(Mapper)); + _host.CheckValue(inputSchema, nameof(inputSchema)); + _host.CheckValue(parent, nameof(parent)); + + if (!inputSchema.TryGetColumnIndex(parent.InputColumnName, out _inputColumnIndex)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName); + + var colType = inputSchema[_inputColumnIndex].Type; + if (colType != NumberDataViewType.Single) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.InputColumnName, "Single", colType.ToString()); + + _parent = parent; + _parentSchema = inputSchema; + _slotNames = new VBuffer>(_parent.OutputLength, new[] { "Alert".AsMemory(), "Raw Score".AsMemory(), + "Mag".AsMemory()}); + + State = (SrCnnStateBase)_parent.StateRef; + } + + public DataViewSchema.DetachedColumn[] GetOutputColumns() + { + var meta = new DataViewSchema.Annotations.Builder(); + meta.AddSlotNames(_parent.OutputLength, GetSlotNames); + var info = new DataViewSchema.DetachedColumn[1]; + info[0] = new DataViewSchema.DetachedColumn(_parent.OutputColumnName, new VectorDataViewType(NumberDataViewType.Double, _parent.OutputLength), meta.ToAnnotations()); + return info; + } + + public void GetSlotNames(ref VBuffer> dst) => _slotNames.CopyTo(ref dst, 0, _parent.OutputLength); + + public Func GetDependencies(Func activeOutput) + { + if (activeOutput(0)) + return col => col == _inputColumnIndex; + else + return col => false; + } + + void ICanSaveModel.Save(ModelSaveContext ctx) => _parent.SaveModel(ctx); + + public Delegate[] CreateGetters(DataViewRow input, Func activeOutput, out Action disposer) + { + disposer = null; + var getters = new Delegate[1]; + if (activeOutput(0)) + getters[0] = MakeGetter(input, State); + + return getters; + } + + private delegate void ProcessData(ref TInput src, ref VBuffer dst); + + private Delegate MakeGetter(DataViewRow input, SrCnnStateBase state) + { + _host.AssertValue(input); + var srcGetter = input.GetGetter(input.Schema[_inputColumnIndex]); + ProcessData processData = _parent.WindowSize > 0 ? + (ProcessData)state.Process : state.ProcessWithoutBuffer; + + ValueGetter> valueGetter = (ref VBuffer dst) => + { + TInput src = default; + srcGetter(ref src); + processData(ref src, ref dst); + }; + return valueGetter; + } + + public Action CreatePinger(DataViewRow input, Func activeOutput, out Action disposer) + { + disposer = null; + Action pinger = null; + if (activeOutput(0)) + pinger = MakePinger(input, State); + + return pinger; + } + + private Action MakePinger(DataViewRow input, SrCnnStateBase state) + { + _host.AssertValue(input); + var srcGetter = input.GetGetter(input.Schema[_inputColumnIndex]); + Action pinger = (long rowPosition) => + { + TInput src = default; + srcGetter(ref src); + state.UpdateState(ref src, rowPosition, _parent.WindowSize > 0); + }; + return pinger; + } + + public void CloneState() + { + if (Interlocked.Increment(ref _parent.StateRefCount) > 1) + { + State = (SrCnnStateBase)_parent.StateRef.Clone(); + } + } + + public ITransformer GetTransformer() + { + return _parent; + } + } + + internal abstract class SrCnnStateBase : SequentialTransformerBase, TState>.StateBase + { + protected SrCnnTransformBase Parent; + + private protected SrCnnStateBase() { } + + private protected override void CloneCore(TState state) + { + base.CloneCore(state); + Contracts.Assert(state is SrCnnStateBase); + } + + private protected SrCnnStateBase(BinaryReader reader) : base(reader) + { + } + + internal override void Save(BinaryWriter writer) + { + base.Save(writer); + } + + private protected override void SetNaOutput(ref VBuffer dst) + { + var outputLength = Parent.OutputLength; + var editor = VBufferEditor.Create(ref dst, outputLength); + + for (int i = 0; i < outputLength; ++i) + editor.Values[i] = 0; + + dst = editor.Commit(); + } + + private protected sealed override void TransformCore(ref TInput input, FixedSizeQueue windowedBuffer, long iteration, ref VBuffer dst) + { + var outputLength = Parent.OutputLength; + + var result = VBufferEditor.Create(ref dst, outputLength); + result.Values.Fill(Double.NaN); + + SpectralResidual(input, windowedBuffer, ref result); + + dst = result.Commit(); + } + + private protected sealed override void InitializeStateCore(bool disk = false) + { + Parent = (SrCnnTransformBase)ParentTransform; + } + + private protected override void LearnStateFromDataCore(FixedSizeQueue data) + { + } + + private protected virtual void SpectralResidual(TInput input, FixedSizeQueue data, ref VBufferEditor result) + { + } + } + } +} diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index 05055b5613..281fbe96c2 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using System.IO; using Microsoft.ML.Data; @@ -41,6 +42,22 @@ public Data(float value) } } + private sealed class TimeSeriesData + { + public float Value; + + public TimeSeriesData(float value) + { + Value = value; + } + } + + private sealed class SrCnnAnomalyDetection + { + [VectorType(3)] + public double[] Prediction { get; set; } + } + [Fact] public void ChangeDetection() { @@ -276,5 +293,46 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine() Assert.Equal(0.14823824685192111, prediction.Change[2], precision: 5); // P-Value score Assert.Equal(1.5292508189989167E-07, prediction.Change[3], precision: 5); // Martingale score } + + [Fact] + public void AnomalyDetectionWithSrCnn() + { + var ml = new MLContext(); + + // Generate sample series data with an anomaly + var data = new List(); + for (int index = 0; index < 20; index++) + { + data.Add(new TimeSeriesData(5)); + } + data.Add(new TimeSeriesData(10)); + for (int index = 0; index < 5; index++) + { + data.Add(new TimeSeriesData(5)); + } + + // Convert data to IDataView. + var dataView = ml.Data.LoadFromEnumerable(data); + + // Setup the estimator arguments + string outputColumnName = nameof(SrCnnAnomalyDetection.Prediction); + string inputColumnName = nameof(TimeSeriesData.Value); + + // The transformed data. + var transformedData = ml.Transforms.DetectAnomalyBySrCnn(outputColumnName, inputColumnName, 16, 5, 5, 3, 8, 0.35).Fit(dataView).Transform(dataView); + + // Getting the data of the newly created column as an IEnumerable of SrCnnAnomalyDetection. + var predictionColumn = ml.Data.CreateEnumerable(transformedData, reuseRowObject: false); + + int k = 0; + foreach (var prediction in predictionColumn) + { + if (k == 20) + Assert.Equal(1, prediction.Prediction[0]); + else + Assert.Equal(0, prediction.Prediction[0]); + k += 1; + } + } } }