Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,13 @@ private static Expression BuildEqualToKeyExpr(
// SQL: WHERE key IS NOT NULL AND key = cursorValue.
keyExpr = Expression.AndAlso(
Expression.NotEqual(keyExpr, nullConst),
Expression.Equal(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero));
BuildEqualComparison(cursorKey, keyValueExpr, cursorExpr));
}
}
else
{
// SQL: WHERE key = cursorValue.
keyExpr = Expression.Equal(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
keyExpr = BuildEqualComparison(cursorKey, keyExpr, cursorExpr);
}

return keyExpr;
Expand Down Expand Up @@ -229,28 +225,22 @@ private static Expression BuildGreaterThanKeyExpr(
if (nullOrdering == NullOrdering.NativeNullsFirst)
{
// SQL: WHERE key > cursorValue.
keyExpr = Expression.GreaterThan(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
keyExpr = BuildGreaterThanComparison(cursorKey, keyValueExpr, cursorExpr);
}
else
{
// When nulls are last, null is greater than any non-null value.
// SQL: WHERE key IS NULL OR key > cursorValue.
keyExpr = Expression.OrElse(
Expression.Equal(keyExpr, nullConst),
Expression.GreaterThan(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero));
BuildGreaterThanComparison(cursorKey, keyValueExpr, cursorExpr));
}
}
}
else
{
// SQL: WHERE key > cursorValue.
keyExpr = Expression.GreaterThan(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
keyExpr = BuildGreaterThanComparison(cursorKey, keyExpr, cursorExpr);
}

return keyExpr;
Expand Down Expand Up @@ -296,25 +286,19 @@ private static Expression BuildLessThanKeyExpr(
// SQL: WHERE key IS NULL OR key < cursorValue.
keyExpr = Expression.OrElse(
Expression.Equal(keyExpr, nullConst),
Expression.LessThan(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero));
BuildLessThanComparison(cursorKey, keyValueExpr, cursorExpr));
}
else
{
// SQL: WHERE key < cursorValue.
keyExpr = Expression.LessThan(
Expression.Call(keyValueExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
keyExpr = BuildLessThanComparison(cursorKey, keyValueExpr, cursorExpr);
}
}
}
else
{
// SQL: WHERE key < cursorValue.
keyExpr = Expression.LessThan(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
keyExpr = BuildLessThanComparison(cursorKey, keyExpr, cursorExpr);
}

return keyExpr;
Expand Down Expand Up @@ -639,6 +623,63 @@ private static Expression CreateAndConvertParameter<T>(T value)
return lambda.Body;
}

private static Expression BuildEqualComparison(
CursorKey cursorKey,
Expression keyExpr,
Expression cursorExpr)
{
var comparisonType = Nullable.GetUnderlyingType(keyExpr.Type) ?? keyExpr.Type;

if (comparisonType.IsEnum)
{
return Expression.Equal(keyExpr, cursorExpr);
}

return Expression.Equal(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}

private static Expression BuildGreaterThanComparison(
CursorKey cursorKey,
Expression keyExpr,
Expression cursorExpr)
{
var comparisonType = Nullable.GetUnderlyingType(keyExpr.Type) ?? keyExpr.Type;

if (comparisonType.IsEnum)
{
var underlyingType = Enum.GetUnderlyingType(comparisonType);
return Expression.GreaterThan(
Expression.Convert(keyExpr, underlyingType),
Expression.Convert(cursorExpr, underlyingType));
}

return Expression.GreaterThan(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}

private static Expression BuildLessThanComparison(
CursorKey cursorKey,
Expression keyExpr,
Expression cursorExpr)
{
var comparisonType = Nullable.GetUnderlyingType(keyExpr.Type) ?? keyExpr.Type;

if (comparisonType.IsEnum)
{
var underlyingType = Enum.GetUnderlyingType(comparisonType);
return Expression.LessThan(
Expression.Convert(keyExpr, underlyingType),
Expression.Convert(cursorExpr, underlyingType));
}

return Expression.LessThan(
Expression.Call(keyExpr, cursorKey.CompareMethod, cursorExpr),
s_zero);
}

private static Expression ReplaceParameter(
LambdaExpression expression,
ParameterExpression replacement)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ public static class CursorKeySerializerRegistration
new BoolCursorKeySerializer(),
new UShortCursorKeySerializer(),
new UIntCursorKeySerializer(),
new ULongCursorKeySerializer()
new ULongCursorKeySerializer(),
new EnumCursorKeySerializer<byte>(),
new EnumCursorKeySerializer<sbyte>(),
new EnumCursorKeySerializer<short>(),
new EnumCursorKeySerializer<ushort>(),
new EnumCursorKeySerializer<int>(),
new EnumCursorKeySerializer<uint>(),
new EnumCursorKeySerializer<long>(),
new EnumCursorKeySerializer<ulong>()
];

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using System.Buffers.Text;
using System.Numerics;
using System.Reflection;

namespace GreenDonut.Data.Cursors.Serializers;

internal sealed class EnumCursorKeySerializer<T> : ICursorKeySerializer where T : struct, INumber<T>
{
private static readonly MethodInfo _compareTo = CompareToResolver.GetCompareToMethod<T>();

public bool IsSupported(Type type)
{
var enumType = Nullable.GetUnderlyingType(type) ?? type;
return enumType.IsEnum && Enum.GetUnderlyingType(enumType) == typeof(T);
}

public MethodInfo GetCompareToMethod(Type type)
=> _compareTo;

public object Parse(ReadOnlySpan<byte> formattedKey)
{
var t = typeof(T);

return t switch
{
_ when t == typeof(byte) && Utf8Parser.TryParse(formattedKey, out byte b, out _)
=> b,
_ when t == typeof(sbyte) && Utf8Parser.TryParse(formattedKey, out sbyte sb, out _)
=> sb,
_ when t == typeof(short) && Utf8Parser.TryParse(formattedKey, out short s, out _)
=> s,
_ when t == typeof(ushort) && Utf8Parser.TryParse(formattedKey, out ushort us, out _)
=> us,
_ when t == typeof(int) && Utf8Parser.TryParse(formattedKey, out int i, out _)
=> i,
_ when t == typeof(uint) && Utf8Parser.TryParse(formattedKey, out uint ui, out _)
=> ui,
_ when t == typeof(long) && Utf8Parser.TryParse(formattedKey, out long l, out _)
=> l,
_ when t == typeof(ulong) && Utf8Parser.TryParse(formattedKey, out ulong ul, out _)
=> ul,
_ => throw new InvalidOperationException("Unsupported enum type.")
};
}

public bool TryFormat(object key, Span<byte> buffer, out int written)
{
var t = typeof(T);

return t switch
{
_ when t == typeof(byte) => Utf8Formatter.TryFormat((byte)key, buffer, out written),
_ when t == typeof(sbyte) => Utf8Formatter.TryFormat((sbyte)key, buffer, out written),
_ when t == typeof(short) => Utf8Formatter.TryFormat((short)key, buffer, out written),
_ when t == typeof(ushort) => Utf8Formatter.TryFormat((ushort)key, buffer, out written),
_ when t == typeof(int) => Utf8Formatter.TryFormat((int)key, buffer, out written),
_ when t == typeof(uint) => Utf8Formatter.TryFormat((uint)key, buffer, out written),
_ when t == typeof(long) => Utf8Formatter.TryFormat((long)key, buffer, out written),
_ when t == typeof(ulong) => Utf8Formatter.TryFormat((ulong)key, buffer, out written),
_ => throw new InvalidOperationException("Unsupported enum type.")
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,15 @@ public async Task Fetch_First_2_Items_Second_Page_Descending_AllTypes()
{ "TimeOnly", context.Tests.OrderByDescending(t => t.TimeOnly) },
{ "UInt", context.Tests.OrderByDescending(t => t.UInt) },
{ "ULong", context.Tests.OrderByDescending(t => t.ULong) },
{ "UShort", context.Tests.OrderByDescending(t => t.UShort) }
{ "UShort", context.Tests.OrderByDescending(t => t.UShort) },
{ "ByteEnum", context.Tests.OrderByDescending(t => t.ByteEnum) },
{ "SbyteEnum", context.Tests.OrderByDescending(t => t.SbyteEnum) },
{ "ShortEnum", context.Tests.OrderByDescending(t => t.ShortEnum) },
{ "UshortEnum", context.Tests.OrderByDescending(t => t.UshortEnum) },
{ "IntEnum", context.Tests.OrderByDescending(t => t.IntEnum) },
{ "UintEnum", context.Tests.OrderByDescending(t => t.UintEnum) },
{ "LongEnum", context.Tests.OrderByDescending(t => t.LongEnum) },
{ "UlongEnum", context.Tests.OrderByDescending(t => t.UlongEnum) }
};

// Act
Expand All @@ -435,7 +443,16 @@ public async Task Fetch_First_2_Items_Second_Page_Descending_AllTypes()
}

// Assert
pages.MatchMarkdownSnapshot();
pages.ToDictionary(
p => p.Key,
p =>
p.Value.Select(
t =>
new
{
t.Id,
Value = t.GetType().GetProperty(p.Key)?.GetValue(t)
})).MatchMarkdownSnapshot();
}

private static async Task SeedAsync(string connectionString)
Expand Down Expand Up @@ -476,19 +493,19 @@ private static async Task SeedTestAsync(string connectionString)
await using var context = new CatalogContext(connectionString);
await context.Database.EnsureCreatedAsync();

for (var i = 1; i <= 10; i++)
for (var i = 1; i <= 8; i++)
{
var test = new Test
{
Id = i,
Bool = i % 2 == 0,
Bool = i > 4,
DateOnly = DateOnly.FromDateTime(DateTime.UnixEpoch.AddDays(i - 1)),
DateTime = DateTime.UnixEpoch.AddDays(i - 1),
DateTimeOffset = DateTimeOffset.UnixEpoch.AddDays(i - 1),
Decimal = i,
Double = i,
Float = i,
Guid = Guid.ParseExact($"0000000000000000000000000000000{i - 1}", "N"),
Guid = Guid.ParseExact($"0000000000000000000000000000000{i}", "N"),
Int = i,
Long = i,
Short = (short)i,
Expand All @@ -497,7 +514,15 @@ private static async Task SeedTestAsync(string connectionString)
TimeSpan = TimeSpan.FromHours(i),
UInt = (uint)i,
ULong = (ulong)i,
UShort = (ushort)i
UShort = (ushort)i,
ByteEnum = i > 4 ? TestByteEnum.Two : TestByteEnum.One,
SbyteEnum = i > 4 ? TestSbyteEnum.Two : TestSbyteEnum.One,
ShortEnum = i > 4 ? TestShortEnum.Two : TestShortEnum.One,
UshortEnum = i > 4 ? TestUshortEnum.Two : TestUshortEnum.One,
IntEnum = i > 4 ? TestIntEnum.Two : TestIntEnum.One,
UintEnum = i > 4 ? TestUintEnum.Two : TestUintEnum.One,
LongEnum = i > 4 ? TestLongEnum.Two : TestLongEnum.One,
UlongEnum = i > 4 ? TestUlongEnum.Two : TestUlongEnum.One
};

context.Tests.Add(test);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,68 @@ public class Test
public ulong ULong { get; set; }

public ushort UShort { get; set; }

public TestByteEnum ByteEnum { get; set; }

public TestSbyteEnum SbyteEnum { get; set; }

public TestShortEnum ShortEnum { get; set; }

public TestUshortEnum UshortEnum { get; set; }

public TestIntEnum IntEnum { get; set; }

public TestUintEnum UintEnum { get; set; }

public TestLongEnum LongEnum { get; set; }

public TestUlongEnum UlongEnum { get; set; }
}

public enum TestByteEnum : byte
{
One = 1,
Two = 2
}

public enum TestSbyteEnum : sbyte
{
One = 1,
Two = 2
}

public enum TestShortEnum : short
{
One = 1,
Two = 2
}

public enum TestUshortEnum : ushort
{
One = 1,
Two = 2
}

public enum TestIntEnum
{
One = 1,
Two = 2
}

public enum TestUintEnum : uint
{
One = 1,
Two = 2
}

public enum TestLongEnum : long
{
One = 1,
Two = 2
}

public enum TestUlongEnum : ulong
{
One = 1,
Two = 2
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ LIMIT @__p_2
## Expression 0

```text
[Microsoft.EntityFrameworkCore.Query.EntityQueryRootExpression].OrderBy(t => t.Name).ThenBy(t => t.Id).Where(t => ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.String]).value) > 0) OrElse ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.String]).value) == 0) AndAlso (t.Id.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.Int32]).value) > 0)))).Take(6)
[Microsoft.EntityFrameworkCore.Query.EntityQueryRootExpression].OrderBy(t => t.Name).ThenBy(t => t.Id).Where(t => ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.String]).value) > 0) OrElse ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.String]).value) == 0) AndAlso (t.Id.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.Int32]).value) > 0)))).Take(6)
```

## Result 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ LIMIT @__p_2
## Expression 0

```text
[Microsoft.EntityFrameworkCore.Query.EntityQueryRootExpression].OrderBy(t => t.Name).ThenBy(t => t.Id).Where(t => ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.String]).value) > 0) OrElse ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.String]).value) == 0) AndAlso (t.Id.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.Int32]).value) > 0)))).Take(6)
[Microsoft.EntityFrameworkCore.Query.EntityQueryRootExpression].OrderBy(t => t.Name).ThenBy(t => t.Id).Where(t => ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.String]).value) > 0) OrElse ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.String]).value) == 0) AndAlso (t.Id.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.Int32]).value) > 0)))).Take(6)
```

## Result 3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ LIMIT @__p_2
## Expression 0

```text
[Microsoft.EntityFrameworkCore.Query.EntityQueryRootExpression].OrderByDescending(t => t.Name).ThenByDescending(t => t.Id).Where(t => ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.String]).value) < 0) OrElse ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.String]).value) == 0) AndAlso (t.Id.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass14_0`1[System.Int32]).value) < 0)))).Take(6)
[Microsoft.EntityFrameworkCore.Query.EntityQueryRootExpression].OrderByDescending(t => t.Name).ThenByDescending(t => t.Id).Where(t => ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.String]).value) < 0) OrElse ((t.Name.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.String]).value) == 0) AndAlso (t.Id.CompareTo(value(GreenDonut.Data.Expressions.ExpressionHelpers+<>c__DisplayClass16_0`1[System.Int32]).value) < 0)))).Take(6)
```

## Result 3
Expand Down
Loading
Loading