Skip to content

Commit 8ca27c9

Browse files
author
Becca McHenry
committed
fix all the tests
1 parent fe95784 commit 8ca27c9

File tree

3 files changed

+160
-5
lines changed

3 files changed

+160
-5
lines changed

src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public partial class VBufferDataFrameColumn<T> : DataFrameColumn, IEnumerable<VB
2828
/// </summary>
2929
/// <param name="name">The name of the column.</param>
3030
/// <param name="length">Length of values</param>
31-
public VBufferDataFrameColumn(string name, long length = 0) : base(name, 0, typeof(VBuffer<T>))
31+
public VBufferDataFrameColumn(string name, long length = 0) : base(name, length, typeof(VBuffer<T>))
3232
{
3333
int numberOfBuffersRequired = Math.Max((int)(length / int.MaxValue), 1);
3434
for (int i = 0; i < numberOfBuffersRequired; i++)
@@ -205,6 +205,137 @@ protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, D
205205
}
206206
}
207207

208+
private VBufferDataFrameColumn<T> Clone(PrimitiveDataFrameColumn<bool> boolColumn)
209+
{
210+
if (boolColumn.Length > Length)
211+
throw new ArgumentException(Strings.MapIndicesExceedsColumnLenth, nameof(boolColumn));
212+
VBufferDataFrameColumn<T> ret = new VBufferDataFrameColumn<T>(Name, 0);
213+
for (long i = 0; i < boolColumn.Length; i++)
214+
{
215+
bool? value = boolColumn[i];
216+
if (value.HasValue && value.Value == true)
217+
ret.Append(this[i]);
218+
}
219+
return ret;
220+
}
221+
222+
private VBufferDataFrameColumn<T> Clone(PrimitiveDataFrameColumn<long> mapIndices = null, bool invertMapIndex = false)
223+
{
224+
if (mapIndices is null)
225+
{
226+
VBufferDataFrameColumn<T> ret = new VBufferDataFrameColumn<T>(Name, Length);
227+
for (long i = 0; i < Length; i++)
228+
{
229+
ret[i] = this[i];
230+
}
231+
return ret;
232+
}
233+
else
234+
{
235+
return CloneImplementation(mapIndices, invertMapIndex);
236+
}
237+
}
238+
239+
private VBufferDataFrameColumn<T> Clone(PrimitiveDataFrameColumn<int> mapIndices, bool invertMapIndex = false)
240+
{
241+
return CloneImplementation(mapIndices, invertMapIndex);
242+
}
243+
244+
private VBufferDataFrameColumn<T> CloneImplementation<U>(PrimitiveDataFrameColumn<U> mapIndices, bool invertMapIndices = false, long numberOfNullsToAppend = 0)
245+
where U : unmanaged
246+
{
247+
mapIndices = mapIndices ?? throw new ArgumentNullException(nameof(mapIndices));
248+
VBufferDataFrameColumn<T> ret = new VBufferDataFrameColumn<T>(Name, mapIndices.Length);
249+
250+
List<VBuffer<T>> setBuffer = ret._vBuffers[0];
251+
long setBufferMinRange = 0;
252+
long setBufferMaxRange = int.MaxValue;
253+
List<VBuffer<T>> getBuffer = _vBuffers[0];
254+
long getBufferMinRange = 0;
255+
long getBufferMaxRange = int.MaxValue;
256+
long maxCapacity = int.MaxValue;
257+
if (mapIndices.DataType == typeof(long))
258+
{
259+
PrimitiveDataFrameColumn<long> longMapIndices = mapIndices as PrimitiveDataFrameColumn<long>;
260+
longMapIndices.ApplyElementwise((long? mapIndex, long rowIndex) =>
261+
{
262+
long index = rowIndex;
263+
if (invertMapIndices)
264+
index = longMapIndices.Length - 1 - index;
265+
if (index < setBufferMinRange || index >= setBufferMaxRange)
266+
{
267+
int bufferIndex = (int)(index / maxCapacity);
268+
setBuffer = ret._vBuffers[bufferIndex];
269+
setBufferMinRange = bufferIndex * maxCapacity;
270+
setBufferMaxRange = (bufferIndex + 1) * maxCapacity;
271+
}
272+
index -= setBufferMinRange;
273+
274+
if (mapIndex.Value < getBufferMinRange || mapIndex.Value >= getBufferMaxRange)
275+
{
276+
int bufferIndex = (int)(mapIndex.Value / maxCapacity);
277+
getBuffer = _vBuffers[bufferIndex];
278+
getBufferMinRange = bufferIndex * maxCapacity;
279+
getBufferMaxRange = (bufferIndex + 1) * maxCapacity;
280+
}
281+
int bufferLocalMapIndex = (int)(mapIndex - getBufferMinRange);
282+
VBuffer<T> value = getBuffer[bufferLocalMapIndex];
283+
setBuffer[(int)index] = value;
284+
285+
return mapIndex;
286+
});
287+
}
288+
else if (mapIndices.DataType == typeof(int))
289+
{
290+
PrimitiveDataFrameColumn<int> intMapIndices = mapIndices as PrimitiveDataFrameColumn<int>;
291+
intMapIndices.ApplyElementwise((int? mapIndex, long rowIndex) =>
292+
{
293+
long index = rowIndex;
294+
if (invertMapIndices)
295+
index = intMapIndices.Length - 1 - index;
296+
297+
VBuffer<T> value = getBuffer[mapIndex.Value];
298+
setBuffer[(int)index] = value;
299+
300+
return mapIndex;
301+
});
302+
}
303+
else
304+
{
305+
Debug.Assert(false, nameof(mapIndices.DataType));
306+
}
307+
308+
return ret;
309+
}
310+
311+
public new VBufferDataFrameColumn<T> Clone(DataFrameColumn mapIndices, bool invertMapIndices, long numberOfNullsToAppend)
312+
{
313+
VBufferDataFrameColumn<T> clone;
314+
if (!(mapIndices is null))
315+
{
316+
Type dataType = mapIndices.DataType;
317+
if (dataType != typeof(long) && dataType != typeof(int) && dataType != typeof(bool))
318+
throw new ArgumentException(String.Format(Strings.MultipleMismatchedValueType, typeof(long), typeof(int), typeof(bool)), nameof(mapIndices));
319+
if (mapIndices.DataType == typeof(long))
320+
clone = Clone(mapIndices as PrimitiveDataFrameColumn<long>, invertMapIndices);
321+
else if (dataType == typeof(int))
322+
clone = Clone(mapIndices as PrimitiveDataFrameColumn<int>, invertMapIndices);
323+
else
324+
clone = Clone(mapIndices as PrimitiveDataFrameColumn<bool>);
325+
}
326+
else
327+
{
328+
clone = Clone();
329+
}
330+
331+
return clone;
332+
}
333+
334+
protected override DataFrameColumn CloneImplementation(DataFrameColumn mapIndices = null, bool invertMapIndices = false, long numberOfNullsToAppend = 0)
335+
{
336+
return Clone(mapIndices, invertMapIndices, numberOfNullsToAppend);
337+
}
338+
208339
private static VectorDataViewType GetDataViewType()
209340
{
210341
if (typeof(T) == typeof(bool))

test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
using Xunit;
1212
using Microsoft.ML.Trainers;
1313

14-
1514
namespace Microsoft.Data.Analysis.Tests
1615
{
1716
public partial class DataFrameIDataViewTests
@@ -23,7 +22,7 @@ public void TestIDataView()
2322

2423
DataDebuggerPreview preview = dataView.Preview();
2524
Assert.Equal(10, preview.RowView.Length);
26-
Assert.Equal(16, preview.ColumnView.Length);
25+
Assert.Equal(17, preview.ColumnView.Length);
2726

2827
Assert.Equal("Byte", preview.ColumnView[0].Column.Name);
2928
Assert.Equal((byte)0, preview.ColumnView[0].Values[0]);
@@ -88,6 +87,10 @@ public void TestIDataView()
8887
Assert.Equal("ArrowString", preview.ColumnView[15].Column.Name);
8988
Assert.Equal("foo".ToString(), preview.ColumnView[15].Values[0].ToString());
9089
Assert.Equal("foo".ToString(), preview.ColumnView[15].Values[1].ToString());
90+
91+
Assert.Equal("VBuffer", preview.ColumnView[16].Column.Name);
92+
Assert.Equal("Dense vector of size 5", preview.ColumnView[16].Values[0].ToString());
93+
Assert.Equal("Dense vector of size 5", preview.ColumnView[16].Values[1].ToString());
9194
}
9295

9396
[Fact]
@@ -125,7 +128,7 @@ public void TestIDataViewWithNulls()
125128

126129
DataDebuggerPreview preview = dataView.Preview();
127130
Assert.Equal(length, preview.RowView.Length);
128-
Assert.Equal(16, preview.ColumnView.Length);
131+
Assert.Equal(17, preview.ColumnView.Length);
129132

130133
Assert.Equal("Byte", preview.ColumnView[0].Column.Name);
131134
Assert.Equal((byte)0, preview.ColumnView[0].Values[0]);
@@ -238,12 +241,16 @@ public void TestIDataViewWithNulls()
238241
Assert.Equal("foo", preview.ColumnView[15].Values[4].ToString());
239242
Assert.Equal("", preview.ColumnView[15].Values[5].ToString()); // null row
240243
Assert.Equal("foo", preview.ColumnView[15].Values[6].ToString());
244+
245+
Assert.Equal("VBuffer", preview.ColumnView[16].Column.Name);
246+
Assert.True(preview.ColumnView[16].Values[0] is VBuffer<int>);
247+
Assert.True(preview.ColumnView[16].Values[6] is VBuffer<int>);
241248
}
242249

243250
[Fact]
244251
public void TestDataFrameFromIDataView()
245252
{
246-
DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false);
253+
DataFrame df = DataFrameTests.MakeDataFrameWithAllMutableAndArrowColumnTypes(10, withNulls: false);
247254
df.Columns.Remove("Char"); // Because chars are returned as uint16 by IDataView, so end up comparing CharDataFrameColumn to UInt16DataFrameColumn and fail asserts
248255
IDataView dfAsIDataView = df;
249256
DataFrame newDf = dfAsIDataView.ToDataFrame();

test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,28 @@ public static ArrowStringDataFrameColumn CreateArrowStringColumn(int length, boo
7575
return new ArrowStringDataFrameColumn("ArrowString", dataMemory, offsetMemory, nullMemory, length, nullCount);
7676
}
7777

78+
public static VBufferDataFrameColumn<int> CreateVBufferDataFrame(int length)
79+
{
80+
var buffers = Enumerable.Repeat(new VBuffer<int>(5, new[] { 0, 1, 2, 3, 4 }), length).ToArray();
81+
return new VBufferDataFrameColumn<int>("VBuffer", buffers);
82+
}
83+
7884
public static DataFrame MakeDataFrameWithAllColumnTypes(int length, bool withNulls = true)
85+
{
86+
DataFrame df = MakeDataFrameWithAllMutableAndArrowColumnTypes(length, withNulls);
87+
88+
var vBufferColumn = CreateVBufferDataFrame(length);
89+
df.Columns.Insert(df.Columns.Count, vBufferColumn);
90+
91+
return df;
92+
}
93+
94+
public static DataFrame MakeDataFrameWithAllMutableAndArrowColumnTypes(int length, bool withNulls = true)
7995
{
8096
DataFrame df = MakeDataFrameWithAllMutableColumnTypes(length, withNulls);
8197
DataFrameColumn arrowStringColumn = CreateArrowStringColumn(length, withNulls);
8298
df.Columns.Insert(df.Columns.Count, arrowStringColumn);
99+
83100
return df;
84101
}
85102

0 commit comments

Comments
 (0)