diff --git a/src/Microsoft.Data.Analysis/DataFrame.cs b/src/Microsoft.Data.Analysis/DataFrame.cs index 39f4865b4d..8eb04797aa 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.cs @@ -335,7 +335,7 @@ public DataFrame Sample(int numberOfRows) int shuffleLowerLimit = 0; int shuffleUpperLimit = (int)Math.Min(Int32.MaxValue, Rows.Count); - + int[] shuffleArray = Enumerable.Range(0, shuffleUpperLimit).ToArray(); Random rand = new Random(); while (shuffleLowerLimit < numberOfRows) @@ -349,7 +349,7 @@ public DataFrame Sample(int numberOfRows) ArraySegment segment = new ArraySegment(shuffleArray, 0, shuffleLowerLimit); PrimitiveDataFrameColumn indices = new PrimitiveDataFrameColumn("indices", segment); - + return Clone(indices); } @@ -623,12 +623,16 @@ private void OnColumnsChanged() private DataFrame Sort(string columnName, bool isAscending) { DataFrameColumn column = Columns[columnName]; - DataFrameColumn sortIndices = column.GetAscendingSortIndices(); + PrimitiveDataFrameColumn sortIndices = column.GetAscendingSortIndices(out Int64DataFrameColumn nullIndices); + for (long i = 0; i < nullIndices.Length; i++) + { + sortIndices.Append(nullIndices[i]); + } List newColumns = new List(Columns.Count); for (int i = 0; i < Columns.Count; i++) { DataFrameColumn oldColumn = Columns[i]; - DataFrameColumn newColumn = oldColumn.Clone(sortIndices, !isAscending, oldColumn.NullCount); + DataFrameColumn newColumn = oldColumn.Clone(sortIndices, !isAscending); Debug.Assert(newColumn.NullCount == oldColumn.NullCount); newColumns.Add(newColumn); } diff --git a/src/Microsoft.Data.Analysis/DataFrameColumn.cs b/src/Microsoft.Data.Analysis/DataFrameColumn.cs index bd21d6fe96..618d8d2d47 100644 --- a/src/Microsoft.Data.Analysis/DataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/DataFrameColumn.cs @@ -199,7 +199,7 @@ public object this[long rowIndex] /// public virtual DataFrameColumn Sort(bool ascending = true) { - PrimitiveDataFrameColumn 
sortIndices = GetAscendingSortIndices(); + PrimitiveDataFrameColumn sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _); return Clone(sortIndices, !ascending, NullCount); } @@ -331,7 +331,11 @@ public virtual StringDataFrameColumn Info() /// public virtual DataFrameColumn Description() => throw new NotImplementedException(); - internal virtual PrimitiveDataFrameColumn GetAscendingSortIndices() => throw new NotImplementedException(); + /// <summary> + /// Returns the indices of non-null values that, when applied, result in this column being sorted in ascending order. Also returns the indices of null values in <paramref name="nullIndices"/>. + /// </summary> + /// <param name="nullIndices">Indices of values that are <see langword="null"/>.</param> + internal virtual PrimitiveDataFrameColumn GetAscendingSortIndices(out Int64DataFrameColumn nullIndices) => throw new NotImplementedException(); internal delegate long GetBufferSortIndex(int bufferIndex, int sortIndex); internal delegate ValueTuple GetValueAndBufferSortIndexAtBuffer(int bufferIndex, int valueIndex); diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs index 60f7e21046..d83c601777 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.Sort.cs @@ -14,20 +14,21 @@ public partial class PrimitiveDataFrameColumn : DataFrameColumn { public new PrimitiveDataFrameColumn Sort(bool ascending = true) { - PrimitiveDataFrameColumn sortIndices = GetAscendingSortIndices(); + PrimitiveDataFrameColumn sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _); return Clone(sortIndices, !ascending, NullCount); } - internal override PrimitiveDataFrameColumn GetAscendingSortIndices() + internal override PrimitiveDataFrameColumn GetAscendingSortIndices(out Int64DataFrameColumn nullIndices) { - // The return sortIndices contains only the non null indices. 
- GetSortIndices(Comparer.Default, out PrimitiveDataFrameColumn sortIndices); + Int64DataFrameColumn sortIndices = GetSortIndices(Comparer.Default, out nullIndices); return sortIndices; } - private void GetSortIndices(IComparer comparer, out PrimitiveDataFrameColumn columnSortIndices) + private Int64DataFrameColumn GetSortIndices(IComparer comparer, out Int64DataFrameColumn columnNullIndices) { List> bufferSortIndices = new List>(_columnContainer.Buffers.Count); + columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount); + long nullIndicesSlot = 0; // Sort each buffer first for (int b = 0; b < _columnContainer.Buffers.Count; b++) { @@ -35,15 +36,24 @@ private void GetSortIndices(IComparer comparer, out PrimitiveDataFrameColumn< ReadOnlySpan nullBitMapSpan = _columnContainer.NullBitMapBuffers[b].ReadOnlySpan; int[] sortIndices = new int[buffer.Length]; for (int i = 0; i < buffer.Length; i++) + { sortIndices[i] = i; + } IntrospectiveSort(buffer.ReadOnlySpan, buffer.Length, sortIndices, comparer); // Bug fix: QuickSort is not stable. 
When PrimitiveDataFrameColumn has null values and default values, they move around List nonNullSortIndices = new List(); for (int i = 0; i < sortIndices.Length; i++) { - if (_columnContainer.IsValid(nullBitMapSpan, sortIndices[i])) + int localSortIndex = sortIndices[i]; + if (_columnContainer.IsValid(nullBitMapSpan, localSortIndex)) + { nonNullSortIndices.Add(sortIndices[i]); - + } + else + { + columnNullIndices[nullIndicesSlot] = localSortIndex + (long)b * _columnContainer.Buffers[0].Length; /* widen to long before multiplying: b * Length would otherwise be computed in int and can wrap once the column spans more than one large buffer */ + nullIndicesSlot++; + } } bufferSortIndices.Add(nonNullSortIndices); } @@ -90,11 +100,13 @@ ValueTuple GetFirstNonNullValueAndBufferIndexStartingAtIndex(int bufferI heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferIndex.Item1, new List>() { (valueAndBufferIndex.Item2, i) }); } } - columnSortIndices = new PrimitiveDataFrameColumn("SortIndices"); + Int64DataFrameColumn columnSortIndices = new Int64DataFrameColumn("SortIndices"); GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Count); GetValueAndBufferSortIndexAtBuffer getValueAndBufferSortIndexAtBuffer = new GetValueAndBufferSortIndexAtBuffer((int bufferIndex, int sortIndex) => GetFirstNonNullValueAndBufferIndexStartingAtIndex(bufferIndex, sortIndex)); GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Count); PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex, columnSortIndices, getBufferSortIndex, getValueAndBufferSortIndexAtBuffer, getBufferLengthAtIndex); + + return columnSortIndices; } } } diff --git a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs index a7e7d20cb9..6f54641468 100644 --- a/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs @@ -225,7 
+225,7 @@ public override double Median() // Not the most efficient implementation. Using a selection algorithm here would be O(n) instead of O(nLogn) if (Length == 0) return 0; - PrimitiveDataFrameColumn sortIndices = GetAscendingSortIndices(); + PrimitiveDataFrameColumn sortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _); long middle = sortIndices.Length / 2; double middleValue = (double)Convert.ChangeType(this[sortIndices[middle].Value].Value, typeof(double)); if (Length % 2 == 0) diff --git a/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs b/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs index 7ada30e10c..1f0983c942 100644 --- a/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/StringDataFrameColumn.cs @@ -171,25 +171,32 @@ public IEnumerator GetEnumerator() public new StringDataFrameColumn Sort(bool ascending = true) { - PrimitiveDataFrameColumn columnSortIndices = GetAscendingSortIndices(); + PrimitiveDataFrameColumn columnSortIndices = GetAscendingSortIndices(out Int64DataFrameColumn _); return Clone(columnSortIndices, !ascending, NullCount); } - internal override PrimitiveDataFrameColumn GetAscendingSortIndices() + internal override PrimitiveDataFrameColumn GetAscendingSortIndices(out Int64DataFrameColumn nullIndices) { - GetSortIndices(Comparer.Default, out PrimitiveDataFrameColumn columnSortIndices); + PrimitiveDataFrameColumn columnSortIndices = GetSortIndices(Comparer.Default, out nullIndices); return columnSortIndices; } - private void GetSortIndices(Comparer comparer, out PrimitiveDataFrameColumn columnSortIndices) + private PrimitiveDataFrameColumn GetSortIndices(Comparer comparer, out Int64DataFrameColumn columnNullIndices) { List bufferSortIndices = new List(_stringBuffers.Count); + columnNullIndices = new Int64DataFrameColumn("NullIndices", NullCount); + long nullIndicesSlot = 0; foreach (List buffer in _stringBuffers) { var sortIndices = new int[buffer.Count]; for (int i = 0; i < 
buffer.Count; i++) { sortIndices[i] = i; + if (buffer[i] == null) + { + columnNullIndices[nullIndicesSlot] = i + (long)bufferSortIndices.Count * _stringBuffers[0].Count; /* global row index of the null: the per-buffer stride is the first buffer's length (the same stride getBufferSortIndex uses), computed in long so it cannot wrap; int.MaxValue overflowed int for every buffer after the first */ + nullIndicesSlot++; + } } // TODO: Refactor the sort routine to also work with IList? string[] array = buffer.ToArray(); @@ -227,11 +234,12 @@ ValueTuple GetFirstNonNullValueStartingAtIndex(int stringBufferInde heapOfValueAndListOfTupleOfSortAndBufferIndex.Add(valueAndBufferSortIndex.Item1, new List>() { (valueAndBufferSortIndex.Item2, i) }); } } - columnSortIndices = new PrimitiveDataFrameColumn("SortIndices"); + PrimitiveDataFrameColumn columnSortIndices = new PrimitiveDataFrameColumn("SortIndices"); GetBufferSortIndex getBufferSortIndex = new GetBufferSortIndex((int bufferIndex, int sortIndex) => (bufferSortIndices[bufferIndex][sortIndex]) + bufferIndex * bufferSortIndices[0].Length); GetValueAndBufferSortIndexAtBuffer getValueAtBuffer = new GetValueAndBufferSortIndexAtBuffer((int bufferIndex, int sortIndex) => GetFirstNonNullValueStartingAtIndex(bufferIndex, sortIndex)); GetBufferLengthAtIndex getBufferLengthAtIndex = new GetBufferLengthAtIndex((int bufferIndex) => bufferSortIndices[bufferIndex].Length); PopulateColumnSortIndicesWithHeap(heapOfValueAndListOfTupleOfSortAndBufferIndex, columnSortIndices, getBufferSortIndex, getValueAtBuffer, getBufferLengthAtIndex); + return columnSortIndices; } public new StringDataFrameColumn Clone(DataFrameColumn mapIndices, bool invertMapIndices, long numberOfNullsToAppend) diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 300babbffb..8011c0ea6d 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -815,10 +815,10 @@ public void TestOrderBy() // Sort by "Int" in descending order sortedDf = df.OrderByDescending("Int"); - Assert.Null(sortedDf.Columns["Int"][19]); - Assert.Equal(-1, sortedDf.Columns["Int"][18]); - 
Assert.Equal(100, sortedDf.Columns["Int"][1]); - Assert.Equal(2000, sortedDf.Columns["Int"][0]); + Assert.Null(sortedDf.Columns["Int"][0]); + Assert.Equal(-1, sortedDf.Columns["Int"][19]); + Assert.Equal(100, sortedDf.Columns["Int"][2]); + Assert.Equal(2000, sortedDf.Columns["Int"][1]); // Sort by "String" in ascending order sortedDf = df.OrderBy("String"); @@ -829,9 +829,9 @@ public void TestOrderBy() // Sort by "String" in descending order sortedDf = df.OrderByDescending("String"); - Assert.Null(sortedDf.Columns["Int"][19]); - Assert.Equal(8, sortedDf.Columns["Int"][1]); - Assert.Equal(9, sortedDf.Columns["Int"][0]); + Assert.Null(sortedDf.Columns["Int"][0]); + Assert.Equal(8, sortedDf.Columns["Int"][2]); + Assert.Equal(9, sortedDf.Columns["Int"][1]); } [Fact] @@ -920,6 +920,43 @@ public void TestPrimitiveColumnSort(int numberOfNulls) Assert.Null(sortedIntColumn[9]); } + [Fact] + public void TestSortWithDifferentNullCountsInColumns() + { + DataFrame dataFrame = MakeDataFrameWithAllMutableColumnTypes(10); + dataFrame["Int"][3] = null; + dataFrame["String"][3] = null; + DataFrame sorted = dataFrame.OrderBy("Int"); + void Verify(DataFrame sortedDataFrame) + { + Assert.Equal(10, sortedDataFrame.Rows.Count); + DataFrameRow lastRow = sortedDataFrame.Rows[sortedDataFrame.Rows.Count - 1]; + DataFrameRow penultimateRow = sortedDataFrame.Rows[sortedDataFrame.Rows.Count - 2]; + foreach (object value in lastRow) + { + Assert.Null(value); + } + + for (int i = 0; i < sortedDataFrame.Columns.Count; i++) + { + string columnName = sortedDataFrame.Columns[i].Name; + if (columnName != "String" && columnName != "Int") + { + Assert.Equal(dataFrame[columnName][3], penultimateRow[i]); + } + else /* columnName is "String" or "Int" here: the two columns whose row 3 was set to null above */ + { + Assert.Null(penultimateRow[i]); + } + } + } + + Verify(sorted); + + sorted = dataFrame.OrderBy("String"); + Verify(sorted); + } + private void VerifyJoin(DataFrame join, DataFrame left, DataFrame right, JoinAlgorithm joinAlgorithm) { 
Int64DataFrameColumn mapIndices = new Int64DataFrameColumn("map", join.Rows.Count);