diff --git a/src/Microsoft.Data.Analysis/DataFrame.cs b/src/Microsoft.Data.Analysis/DataFrame.cs
index 8eb04797aa..1bfb4a4784 100644
--- a/src/Microsoft.Data.Analysis/DataFrame.cs
+++ b/src/Microsoft.Data.Analysis/DataFrame.cs
@@ -367,6 +367,25 @@ public GroupBy GroupBy(string columnName)
             DataFrameColumn column = _columnCollection[columnIndex];
             return column.GroupBy(columnIndex, this);
         }
+
+        /// <summary>
+        /// Groups the rows of the <see cref="DataFrame"/> by unique values in the <paramref name="columnName"/> column.
+        /// </summary>
+        /// <typeparam name="TKey">Type of column used for grouping</typeparam>
+        /// <param name="columnName">The column used to group unique values</param>
+        /// <returns>A GroupBy object that stores the group information.</returns>
+        /// <exception cref="InvalidCastException">The grouping built for the column is not a <see cref="GroupBy{TKey}"/>.</exception>
+        public GroupBy<TKey> GroupBy<TKey>(string columnName)
+        {
+            GroupBy<TKey> group = GroupBy(columnName) as GroupBy<TKey>;
+            if (group != null)
+            {
+                return group;
+            }
+
+            DataFrameColumn column = this[columnName];
+            throw new InvalidCastException(String.Format(Strings.BadColumnCastDuringGrouping, columnName, column.DataType, typeof(TKey)));
+        }
 
         // In GroupBy and ReadCsv calls, columns get resized. We need to set the RowCount to reflect the true Length of the DataFrame. This does internal validation
         internal void SetTableRowCount(long rowCount)
diff --git a/src/Microsoft.Data.Analysis/GroupBy.cs b/src/Microsoft.Data.Analysis/GroupBy.cs
index 5d8013e9b6..64642272d5 100644
--- a/src/Microsoft.Data.Analysis/GroupBy.cs
+++ b/src/Microsoft.Data.Analysis/GroupBy.cs
@@ -3,7 +3,9 @@
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Collections;
 using System.Collections.Generic;
+using System.Linq;
 
 namespace Microsoft.Data.Analysis
 {
@@ -72,6 +74,33 @@ public abstract class GroupBy
 
     public class GroupBy<TKey> : GroupBy
     {
+        #region Internal class that implements IGrouping LINQ interface
+        /// <summary>
+        /// One group of the enclosing GroupBy: a key plus the sequence of rows that share it.
+        /// </summary>
+        private class Grouping : IGrouping<TKey, DataFrameRow>
+        {
+            private readonly TKey _key;
+            // Stored as handed in and not copied; every GetEnumerator call enumerates it again.
+            private readonly IEnumerable<DataFrameRow> _rows;
+
+            public Grouping(TKey key, IEnumerable<DataFrameRow> rows)
+            {
+                _key = key;
+                _rows = rows;
+            }
+
+            /// <summary>The value shared by every row in this group.</summary>
+            public TKey Key => _key;
+
+            public IEnumerator<DataFrameRow> GetEnumerator() => _rows.GetEnumerator();
+
+            // Non-generic enumeration goes through the same sequence as the generic one.
+            IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+        }
+
+        #endregion
+
         private int _groupByColumnIndex;
         private IDictionary<TKey, ICollection<long>> _keyToRowIndicesMap;
         private DataFrame _dataFrame;
@@ -464,5 +493,15 @@ public override DataFrame Mean(params string[] columnNames)
             return ret;
         }
 
+        /// <summary>
+        /// Returns a collection of Grouping objects, where each object represents a set of DataFrameRows having the same Key.
+        /// </summary>
+        /// <remarks>
+        /// Built with deferred LINQ queries, so groupings and their rows are recreated on each enumeration; call ToList() to snapshot.
+        /// </remarks>
+        public IEnumerable<IGrouping<TKey, DataFrameRow>> Groupings
+        {
+            get { return _keyToRowIndicesMap.Select(kvp => new Grouping(kvp.Key, kvp.Value.Select(index => _dataFrame.Rows[index]))); }
+        }
     }
 }
diff --git a/src/Microsoft.Data.Analysis/Strings.Designer.cs b/src/Microsoft.Data.Analysis/Strings.Designer.cs
index ff3cd6cadd..4b24665bf3 100644
--- a/src/Microsoft.Data.Analysis/Strings.Designer.cs
+++ b/src/Microsoft.Data.Analysis/Strings.Designer.cs
@@ -69,6 +69,15 @@ internal static string BadColumnCast {
             }
         }
 
+        /// <summary>
+        ///   Looks up a localized string similar to Cannot cast elements of column '{0}' type of {1} to type {2} used as TKey in grouping .
+        /// </summary>
+        internal static string BadColumnCastDuringGrouping {
+            get {
+                return ResourceManager.GetString("BadColumnCastDuringGrouping", resourceCulture);
+            }
+        }
+
         /// <summary>
         ///   Looks up a localized string similar to Line {0} cannot be parsed with the current Delimiters..
/// @@ -365,7 +374,7 @@ internal static string NotSupportedColumnType { return ResourceManager.GetString("NotSupportedColumnType", resourceCulture); } } - + /// /// Looks up a localized string similar to Delimiters is null.. /// diff --git a/src/Microsoft.Data.Analysis/Strings.resx b/src/Microsoft.Data.Analysis/Strings.resx index de91078cec..db6dfb4984 100644 --- a/src/Microsoft.Data.Analysis/Strings.resx +++ b/src/Microsoft.Data.Analysis/Strings.resx @@ -120,6 +120,9 @@ Cannot cast column holding {0} values to type {1} + + Cannot cast elements of column '{0}' type of {1} to type {2} used as TKey in grouping + Line {0} cannot be parsed with the current Delimiters. diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameGroupByTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameGroupByTests.cs new file mode 100644 index 0000000000..fdbc859f7b --- /dev/null +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameGroupByTests.cs @@ -0,0 +1,116 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Xunit;
+
+namespace Microsoft.Data.Analysis.Tests
+{
+    /// <summary>Tests for the typed <see cref="GroupBy{TKey}"/> grouping API and its Groupings property.</summary>
+    public class DataFrameGroupByTests
+    {
+        [Fact]
+        public void TestGroupingWithTKeyTypeofString()
+        {
+            const int length = 11;
+
+            //Create test dataframe (numbers starting from 0 up to length)
+            DataFrame df = MakeTestDataFrameWithParityAndTensColumns(length);
+
+            //Materialize once: Groupings is a deferred query that is re-evaluated on every enumeration
+            var grouping = df.GroupBy<string>("Parity").Groupings.ToList();
+
+            //Check groups count
+            Assert.Equal(2, grouping.Count);
+
+            //Check number of elements in each group
+            var oddGroup = grouping.FirstOrDefault(gr => gr.Key == "odd");
+            Assert.NotNull(oddGroup);
+            Assert.Equal(length / 2, oddGroup.Count());
+
+            var evenGroup = grouping.FirstOrDefault(gr => gr.Key == "even");
+            Assert.NotNull(evenGroup);
+            Assert.Equal(length / 2 + length % 2, evenGroup.Count());
+        }
+
+        [Fact]
+        public void TestGroupingWithTKey_CornerCases()
+        {
+            //Empty dataframe: there must be no groups at all
+            var df = MakeTestDataFrameWithParityAndTensColumns(0);
+            var grouping = df.GroupBy<string>("Parity").Groupings;
+            Assert.Empty(grouping);
+
+            //Single row (index 0 is even): exactly one group, keyed "even"
+            df = MakeTestDataFrameWithParityAndTensColumns(1);
+            grouping = df.GroupBy<string>("Parity").Groupings;
+            Assert.Single(grouping);
+            Assert.Equal("even", grouping.First().Key);
+        }
+
+        [Fact]
+        public void TestGroupingWithTKeyPrimitiveType()
+        {
+            const int length = 55;
+
+            //Create test dataframe (numbers starting from 0 up to length)
+            DataFrame df = MakeTestDataFrameWithParityAndTensColumns(length);
+
+            //Group elements by int column, that contain the amount of full tens in each int
+            var groupings = df.GroupBy<int>("Tens").Groupings.ToDictionary(g => g.Key, g => g.ToList());
+
+            //Get the amount of all number based columns
+            int numberColumnsCount = df.Columns.Count - 2; //except "Parity" and "Tens" columns
+
+            //Check each full group
+            for (int i = 0; i < length / 10; i++)
+            {
+                var rows = groupings[i];
+                Assert.Equal(10, rows.Count);
+
+                for (int colIndex = 0; colIndex < numberColumnsCount; colIndex++)
+                {
+                    //Convert the column once per group instead of once per Assert.Contains call
+                    var values = rows.Select(row => Convert.ToInt32(row[colIndex])).ToList();
+
+                    for (int j = 0; j < 10; j++)
+                    {
+                        Assert.Contains(i * 10 + j, values);
+                    }
+                }
+            }
+
+            //Last group should contain smaller amount of items
+            Assert.Equal(length % 10, groupings[length / 10].Count);
+        }
+
+        [Fact]
+        public void TestGroupingWithTKeyOfWrongType()
+        {
+            //Create test dataframe (numbers starting from 0 up to length)
+            DataFrame df = MakeTestDataFrameWithParityAndTensColumns(1);
+
+            //"Tens" is an Int32 column, so any other TKey (double here) must be rejected
+            Assert.Throws<InvalidCastException>(() => df.GroupBy<double>("Tens"));
+        }
+
+        /// <summary>
+        /// Takes the dataframe from DataFrameTests.MakeDataFrameWithNumericColumns and appends a string
+        /// "Parity" column ("even"/"odd" by row index) and an Int32 "Tens" column (row index / 10).
+        /// </summary>
+        private DataFrame MakeTestDataFrameWithParityAndTensColumns(int length)
+        {
+            DataFrame df = DataFrameTests.MakeDataFrameWithNumericColumns(length, false);
+            DataFrameColumn parityColumn = new StringDataFrameColumn("Parity", Enumerable.Range(0, length).Select(x => x % 2 == 0 ? "even" : "odd"));
+            DataFrameColumn tensColumn = new Int32DataFrameColumn("Tens", Enumerable.Range(0, length).Select(x => x / 10));
+            df.Columns.Insert(df.Columns.Count, parityColumn);
+            df.Columns.Insert(df.Columns.Count, tensColumn);
+
+            return df;
+        }
+    }
+}