diff --git a/src/Microsoft.Data.Analysis/DataFrame.cs b/src/Microsoft.Data.Analysis/DataFrame.cs
index 8eb04797aa..1bfb4a4784 100644
--- a/src/Microsoft.Data.Analysis/DataFrame.cs
+++ b/src/Microsoft.Data.Analysis/DataFrame.cs
@@ -367,6 +367,25 @@ public GroupBy GroupBy(string columnName)
DataFrameColumn column = _columnCollection[columnIndex];
return column.GroupBy(columnIndex, this);
}
+
+ ///
+ /// Groups the rows of the by unique values in the column.
+ ///
+ /// Type of column used for grouping
+ /// The column used to group unique values
+ /// A GroupBy object that stores the group information.
+ public GroupBy GroupBy(string columnName)
+ {
+ GroupBy group = GroupBy(columnName) as GroupBy;
+
+ if (group == null)
+ {
+ DataFrameColumn column = this[columnName];
+ throw new InvalidCastException(String.Format(Strings.BadColumnCastDuringGrouping, columnName, column.DataType, typeof(TKey)));
+ }
+
+ return group;
+ }
// In GroupBy and ReadCsv calls, columns get resized. We need to set the RowCount to reflect the true Length of the DataFrame. This does internal validation
internal void SetTableRowCount(long rowCount)
diff --git a/src/Microsoft.Data.Analysis/GroupBy.cs b/src/Microsoft.Data.Analysis/GroupBy.cs
index 5d8013e9b6..64642272d5 100644
--- a/src/Microsoft.Data.Analysis/GroupBy.cs
+++ b/src/Microsoft.Data.Analysis/GroupBy.cs
@@ -3,7 +3,9 @@
// See the LICENSE file in the project root for more information.
using System;
+using System.Collections;
using System.Collections.Generic;
+using System.Linq;
namespace Microsoft.Data.Analysis
{
@@ -72,6 +74,33 @@ public abstract class GroupBy
public class GroupBy : GroupBy
{
+ #region Internal class that implements IGrouping LINQ interface
+ private class Grouping : IGrouping
+ {
+ private readonly TKey _key;
+ private readonly IEnumerable _rows;
+
+ public Grouping(TKey key, IEnumerable rows)
+ {
+ _key = key;
+ _rows = rows;
+ }
+
+ public TKey Key => _key;
+
+ public IEnumerator GetEnumerator()
+ {
+ return _rows.GetEnumerator();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return _rows.GetEnumerator();
+ }
+ }
+
+ #endregion
+
private int _groupByColumnIndex;
private IDictionary> _keyToRowIndicesMap;
private DataFrame _dataFrame;
@@ -464,5 +493,15 @@ public override DataFrame Mean(params string[] columnNames)
return ret;
}
+ ///
+ /// Returns a collection of Grouping objects, where each object represent a set of DataFrameRows having the same Key
+ ///
+ public IEnumerable> Groupings
+ {
+ get
+ {
+ return _keyToRowIndicesMap.Select(kvp => new Grouping(kvp.Key, kvp.Value.Select(index => _dataFrame.Rows[index])));
+ }
+ }
}
}
diff --git a/src/Microsoft.Data.Analysis/Strings.Designer.cs b/src/Microsoft.Data.Analysis/Strings.Designer.cs
index ff3cd6cadd..4b24665bf3 100644
--- a/src/Microsoft.Data.Analysis/Strings.Designer.cs
+++ b/src/Microsoft.Data.Analysis/Strings.Designer.cs
@@ -69,6 +69,15 @@ internal static string BadColumnCast {
}
}
+ ///
+ /// Looks up a localized string similar to Cannot cast elements of column '{0}' type of {1} to type {2} used as TKey in grouping .
+ ///
+ internal static string BadColumnCastDuringGrouping {
+ get {
+ return ResourceManager.GetString("BadColumnCastDuringGrouping", resourceCulture);
+ }
+ }
+
///
/// Looks up a localized string similar to Line {0} cannot be parsed with the current Delimiters..
///
@@ -365,7 +374,7 @@ internal static string NotSupportedColumnType {
return ResourceManager.GetString("NotSupportedColumnType", resourceCulture);
}
}
-
+
///
/// Looks up a localized string similar to Delimiters is null..
///
diff --git a/src/Microsoft.Data.Analysis/Strings.resx b/src/Microsoft.Data.Analysis/Strings.resx
index de91078cec..db6dfb4984 100644
--- a/src/Microsoft.Data.Analysis/Strings.resx
+++ b/src/Microsoft.Data.Analysis/Strings.resx
@@ -120,6 +120,9 @@
Cannot cast column holding {0} values to type {1}
+
+ Cannot cast elements of column '{0}' type of {1} to type {2} used as TKey in grouping
+
Line {0} cannot be parsed with the current Delimiters.
diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameGroupByTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameGroupByTests.cs
new file mode 100644
index 0000000000..fdbc859f7b
--- /dev/null
+++ b/test/Microsoft.Data.Analysis.Tests/DataFrameGroupByTests.cs
@@ -0,0 +1,116 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Xunit;
+
+namespace Microsoft.Data.Analysis.Tests
+{
+ public class DataFrameGroupByTests
+ {
+ [Fact]
+ public void TestGroupingWithTKeyTypeofString()
+ {
+ const int length = 11;
+
+ //Create test dataframe (numbers starting from 0 up to lenght)
+ DataFrame df = MakeTestDataFrameWithParityAndTensColumns(length);
+
+ var grouping = df.GroupBy("Parity").Groupings;
+
+ //Check groups count
+ Assert.Equal(2, grouping.Count());
+
+ //Check number of elements in each group
+ var oddGroup = grouping.Where(gr => gr.Key == "odd").FirstOrDefault();
+ Assert.NotNull(oddGroup);
+ Assert.Equal(length/2, oddGroup.Count());
+
+ var evenGroup = grouping.Where(gr => gr.Key == "even").FirstOrDefault();
+ Assert.NotNull(evenGroup);
+ Assert.Equal(length / 2 + length % 2, evenGroup.Count());
+
+
+ }
+
+ [Fact]
+ public void TestGroupingWithTKey_CornerCases()
+ {
+ //Check corner cases
+ var df = MakeTestDataFrameWithParityAndTensColumns(0);
+ var grouping = df.GroupBy("Parity").Groupings;
+ Assert.Empty(grouping);
+
+
+ df = MakeTestDataFrameWithParityAndTensColumns(1);
+ grouping = df.GroupBy("Parity").Groupings;
+ Assert.Single(grouping);
+ Assert.Equal("even", grouping.First().Key);
+ }
+
+
+ [Fact]
+ public void TestGroupingWithTKeyPrimitiveType()
+ {
+ const int length = 55;
+
+ //Create test dataframe (numbers starting from 0 up to lenght)
+ DataFrame df = MakeTestDataFrameWithParityAndTensColumns(length);
+
+ //Group elements by int column, that contain the amount of full tens in each int
+ var groupings = df.GroupBy("Tens").Groupings.ToDictionary(g => g.Key, g => g.ToList());
+
+ //Get the amount of all number based columns
+ int numberColumnsCount = df.Columns.Count - 2; //except "Parity" and "Tens" columns
+
+ //Check each group
+ for (int i = 0; i < length / 10; i++)
+ {
+ Assert.Equal(10, groupings[i].Count());
+
+ var rows = groupings[i];
+ for (int colIndex = 0; colIndex < numberColumnsCount; colIndex++)
+ {
+ var values = rows.Select(row => Convert.ToInt32(row[colIndex]));
+
+ for (int j = 0; j < 10; j++)
+ {
+ Assert.Contains(i * 10 + j, values);
+ }
+ }
+ }
+
+ //Last group should contain smaller amount of items
+ Assert.Equal(length % 10, groupings[length / 10].Count());
+ }
+
+ [Fact]
+ public void TestGroupingWithTKeyOfWrongType()
+ {
+
+ var message = string.Empty;
+
+ //Create test dataframe (numbers starting from 0 up to lenght)
+ DataFrame df = MakeTestDataFrameWithParityAndTensColumns(1);
+
+ //Use wrong type for grouping
+ Assert.Throws(() => df.GroupBy("Tens"));
+ }
+
+
+ private DataFrame MakeTestDataFrameWithParityAndTensColumns(int length)
+ {
+ DataFrame df = DataFrameTests.MakeDataFrameWithNumericColumns(length, false);
+ DataFrameColumn parityColumn = new StringDataFrameColumn("Parity", Enumerable.Range(0, length).Select(x => x % 2 == 0 ? "even" : "odd"));
+ DataFrameColumn tensColumn = new Int32DataFrameColumn("Tens", Enumerable.Range(0, length).Select(x => x / 10));
+ df.Columns.Insert(df.Columns.Count, parityColumn);
+ df.Columns.Insert(df.Columns.Count, tensColumn);
+
+ return df;
+ }
+ }
+}