Skip to content

Commit

Permalink
Handle cases like .First(where) and .Sum
Browse files Browse the repository at this point in the history
  • Loading branch information
zoriya committed Oct 10, 2023
1 parent aa7cf4f commit 4fb6013
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 10 deletions.
20 changes: 20 additions & 0 deletions samples/BasicSample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,31 @@ public static void Main(string[] args)
{
Console.WriteLine($"User name: {u.FullName}");
}

foreach (var u in dbContext.Users.ToList())
{
Console.WriteLine($"User name: {u.FullName}");
}

foreach (var u in dbContext.Users.OrderBy(x => x.FullName))
{
Console.WriteLine($"User name: {u.FullName}");
}
}

{
foreach (var u in dbContext.Users.Where(x => x.TotalSpent >= 1))
{
Console.WriteLine($"User name: {u.FullName}");
}
}

{
var result = dbContext.Users.FirstOrDefault();
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");

result = dbContext.Users.FirstOrDefault(x => x.TotalSpent > 1);
Console.WriteLine($"Our first user {result.FullName} has spent {result.TotalSpent}");
}

{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
Expand All @@ -17,9 +18,26 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
private bool _disableRootRewrite;
private IEntityType? _entityType;

private readonly MethodInfo _select;
private readonly MethodInfo _where;

public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver)
{
_resolver = projectionExpressionResolver;
_select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(x => x.Name == nameof(Queryable.Select))
.First(x =>
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
.GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
);
_where = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(x => x.Name == nameof(Queryable.Where))
.First(x =>
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
.GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
);
}

bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out LambdaExpression? reflectedExpression)
Expand All @@ -45,6 +63,7 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La

if (_disableRootRewrite)
{
// This boolean is enabled when a "Select" is encountered
return ret;
}

Expand All @@ -53,10 +72,62 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
// Probably a First() or ToList()
case MethodCallExpression { Arguments.Count: > 0, Object: null } call when _entityType != null:
{
// if return type != IQueryable {
// if return type is IEnuberable {
// // case of a ToList()
// return (ret.arg[0]).Select(...).ToList() or the other method
// } else {
// // case of a Max()
// return ret;
// }
// } else if retrun type == entitytype {
// // case of a first()
// return obj.MyMap(x => new Obj {});
// }


if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable)))
{
// Generic case where the return type is still a IQueryable<T>
return _AddProjectableSelect(call, _entityType);
}

if (call.Method.ReturnType == _entityType.ClrType)
{
// case of a .First(), .SingleAsync()
if (call.Arguments.Count != 1 && true /* Add && arg.count == 1 exist */)
{
// .First(x => whereCondition), since we need to add a select after the last condition but
// before the query become executed by EF (before the .First()), we rewrite the .First(where)
// as .Where(where).Select(x => ...).First()

var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments);
// The call instance is based on the wrong polymorphied method.
var first = call.Method.DeclaringType?.GetMethods()
.FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1);
if (first == null)
{
// Unknown case that should not happen.
return call;
}

return Expression.Call(null, first.MakeGenericMethod(_entityType.ClrType), _AddProjectableSelect(where, _entityType));
}

// .First() without arguments is the same case as bellow so we let it fallthrough
}
else if (!call.Method.ReturnType.IsAssignableTo(typeof(IEnumerable)))
{
// case of something like a .Max(), .Sum()
return call;
}

// return type is IEnumerable<EntityType> or EntityType (in case of fallthrough from a .First())

// case of something like .ToList(), .ToArrayAsync()
var self = _AddProjectableSelect(call.Arguments.First(), _entityType);
return call.Update(null, call.Arguments.Skip(1).Prepend(self));
}
// Probably a foreach call
case QueryRootExpression root:
return _AddProjectableSelect(root, root.EntityType);
default:
Expand Down Expand Up @@ -170,14 +241,7 @@ private Expression _AddProjectableSelect(Expression node, IEntityType entityType
.Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField"));

// Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty })
var select = typeof(Queryable).GetMethods(BindingFlags.Static | BindingFlags.Public)
.Where(x => x.Name == nameof(Queryable.Select))
.First(x =>
x.GetParameters().Last().ParameterType // Expression<Func<T, Ret>>
.GetGenericArguments().First() // Func<T, Ret>
.GetGenericArguments().Length == 2 // Separate between Func<T, Ret> and Func<T, int, Ret>
)
.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
var xParam = Expression.Parameter(entityType.ClrType);
return Expression.Call(
null,
Expand Down

0 comments on commit 4fb6013

Please sign in to comment.