diff --git a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs index 2d96d81687..4ab48e5dfb 100644 --- a/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs +++ b/src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoaderCursor.cs @@ -236,7 +236,7 @@ private Delegate CreateGetterDelegate(int col) private ValueGetter CreateBooleanGetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref bool value) => value = DataReader.GetBoolean(columnIndex); + return (ref bool value) => value = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetBoolean(columnIndex); } private ValueGetter CreateByteGetterDelegate(ColInfo colInfo) @@ -254,61 +254,61 @@ private ValueGetter CreateDateTimeGetterDelegate(ColInfo colInfo) private ValueGetter CreateDoubleGetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref double value) => value = DataReader.GetDouble(columnIndex); + return (ref double value) => value = DataReader.IsDBNull(columnIndex) ? double.NaN : DataReader.GetDouble(columnIndex); } private ValueGetter CreateInt16GetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref short value) => value = DataReader.GetInt16(columnIndex); + return (ref short value) => value = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt16(columnIndex); } private ValueGetter CreateInt32GetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref int value) => value = DataReader.GetInt32(columnIndex); + return (ref int value) => value = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt32(columnIndex); } private ValueGetter CreateInt64GetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref long value) => value = DataReader.GetInt64(columnIndex); + return (ref long value) => value = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetInt64(columnIndex); } private ValueGetter CreateSByteGetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref sbyte value) => value = (sbyte)DataReader.GetByte(columnIndex); + return (ref sbyte value) => value = DataReader.IsDBNull(columnIndex) ? default : (sbyte)DataReader.GetByte(columnIndex); } private ValueGetter CreateSingleGetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref float value) => value = DataReader.GetFloat(columnIndex); + return (ref float value) => value = DataReader.IsDBNull(columnIndex) ? float.NaN : DataReader.GetFloat(columnIndex); } private ValueGetter> CreateStringGetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref ReadOnlyMemory value) => value = DataReader.GetString(columnIndex).AsMemory(); + return (ref ReadOnlyMemory value) => value = DataReader.IsDBNull(columnIndex) ? default : DataReader.GetString(columnIndex).AsMemory(); } private ValueGetter CreateUInt16GetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref ushort value) => value = (ushort)DataReader.GetInt16(columnIndex); + return (ref ushort value) => value = DataReader.IsDBNull(columnIndex) ? default : (ushort)DataReader.GetInt16(columnIndex); } private ValueGetter CreateUInt32GetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref uint value) => value = (uint)DataReader.GetInt32(columnIndex); + return (ref uint value) => value = DataReader.IsDBNull(columnIndex) ? default : (uint)DataReader.GetInt32(columnIndex); } private ValueGetter CreateUInt64GetterDelegate(ColInfo colInfo) { int columnIndex = GetColumnIndex(colInfo); - return (ref ulong value) => value = (ulong)DataReader.GetInt64(columnIndex); + return (ref ulong value) => value = DataReader.IsDBNull(columnIndex) ? default : (ulong)DataReader.GetInt64(columnIndex); } private int GetColumnIndex(ColInfo colInfo) diff --git a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs index 34405a7a0c..c2bc27df75 100644 --- a/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs +++ b/test/Microsoft.ML.Tests/DatabaseLoaderTests.cs @@ -379,7 +379,7 @@ public override int GetOrdinal(string name) public override int GetValues(object[] values) => throw new NotImplementedException(); - public override bool IsDBNull(int ordinal) => throw new NotImplementedException(); + public override bool IsDBNull(int ordinal) => false; public override bool NextResult() => throw new NotImplementedException();