Skip to content

Commit

Permalink
Query: Make IEntityType part of EntityQueryable
Browse files Browse the repository at this point in the history
Part of #9914

With shared entity types query root can no longer be identified using just type.

Part of #18923
  • Loading branch information
smitpatel committed Feb 6, 2020
1 parent dd843a2 commit 9062743
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public override ShapedQueryExpression TranslateSubquery(Expression expression)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
[Obsolete("Use overload which takes IEntityType.")]
protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType)
{
Check.NotNull(elementType, nameof(elementType));
Expand All @@ -114,6 +115,29 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(Type elemen
false));
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType)
{
Check.NotNull(entityType, nameof(entityType));

var selectExpression = _sqlExpressionFactory.Select(entityType);

return new ShapedQueryExpression(
selectExpression,
new EntityShaperExpression(
entityType,
new ProjectionBindingExpression(
selectExpression,
new ProjectionMember(),
typeof(ValueBuffer)),
false));
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,18 @@ protected InMemoryQueryableMethodTranslatingExpressionVisitor(
protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVisitor()
=> new InMemoryQueryableMethodTranslatingExpressionVisitor(this);

[Obsolete("Use overload which takes IEntityType.")]
protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType)
{
Check.NotNull(elementType, nameof(elementType));

return CreateShapedQueryExpression(_model.FindEntityType(elementType));
}

private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType)
protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType)
=> CreateShapedQueryExpressionStatic(entityType);

private static ShapedQueryExpression CreateShapedQueryExpressionStatic(IEntityType entityType)
{
var queryExpression = new InMemoryQueryExpression(entityType);

Expand Down Expand Up @@ -1033,7 +1037,7 @@ private Expression TryExpand(Expression source, MemberIdentity member)
var foreignKey = navigation.ForeignKey;
if (navigation.IsCollection)
{
var innerShapedQuery = CreateShapedQueryExpression(targetEntityType);
var innerShapedQuery = CreateShapedQueryExpressionStatic(targetEntityType);
var innerQueryExpression = (InMemoryQueryExpression)innerShapedQuery.QueryExpression;

var makeNullable = foreignKey.PrincipalKey.Properties
Expand Down Expand Up @@ -1081,7 +1085,7 @@ ProjectionBindingExpression projectionBindingExpression
var innerShaper = entityProjectionExpression.BindNavigation(navigation);
if (innerShaper == null)
{
var innerShapedQuery = CreateShapedQueryExpression(targetEntityType);
var innerShapedQuery = CreateShapedQueryExpressionStatic(targetEntityType);
var innerQueryExpression = (InMemoryQueryExpression)innerShapedQuery.QueryExpression;

var makeNullable = foreignKey.PrincipalKey.Properties
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& methodCallExpression.Method.Name == nameof(RelationalQueryableExtensions.FromSqlOnQueryable))
{
var sql = (string)((ConstantExpression)methodCallExpression.Arguments[1]).Value;
var queryable = (IQueryable)((ConstantExpression)methodCallExpression.Arguments[0]).Value;
return CreateShapedQueryExpression(queryable.ElementType, sql, methodCallExpression.Arguments[2]);
var queryable = (IEntityQueryable)((ConstantExpression)methodCallExpression.Arguments[0]).Value;

return CreateShapedQueryExpression(
queryable.EntityType, _sqlExpressionFactory.Select(queryable.EntityType, sql, methodCallExpression.Arguments[2]));
}

return base.VisitMethodCall(methodCallExpression);
Expand All @@ -81,6 +83,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVisitor()
=> new RelationalQueryableMethodTranslatingExpressionVisitor(this);

[Obsolete("Use overload which takes IEntityType.")]
protected override ShapedQueryExpression CreateShapedQueryExpression(Type elementType)
{
Check.NotNull(elementType, nameof(elementType));
Expand All @@ -91,12 +94,11 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(Type elemen
return CreateShapedQueryExpression(entityType, queryExpression);
}

private ShapedQueryExpression CreateShapedQueryExpression(Type elementType, string sql, Expression arguments)
protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType)
{
var entityType = _model.FindEntityType(elementType);
var queryExpression = _sqlExpressionFactory.Select(entityType, sql, arguments);
Check.NotNull(entityType, nameof(entityType));

return CreateShapedQueryExpression(entityType, queryExpression);
return CreateShapedQueryExpression(entityType, _sqlExpressionFactory.Select(entityType));
}

private static ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, SelectExpression selectExpression)
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/Internal/InternalDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ private EntityQueryable<TEntity> EntityQueryable
}

private EntityQueryable<TEntity> CreateEntityQueryable()
=> new EntityQueryable<TEntity>(_context.GetDependencies().QueryProvider);
=> new EntityQueryable<TEntity>(_context.GetDependencies().QueryProvider, EntityType);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,26 +52,34 @@ public virtual void ProcessModelFinalized(
var queryFilter = entityType.GetQueryFilter();
if (queryFilter != null)
{
entityType.SetQueryFilter((LambdaExpression)DbSetAccessRewriter.Visit(queryFilter));
entityType.SetQueryFilter((LambdaExpression)DbSetAccessRewriter.Rewrite(modelBuilder.Metadata, queryFilter));
}

var definingQuery = entityType.GetDefiningQuery();
if (definingQuery != null)
{
entityType.SetDefiningQuery((LambdaExpression)DbSetAccessRewriter.Visit(definingQuery));
entityType.SetDefiningQuery((LambdaExpression)DbSetAccessRewriter.Rewrite(modelBuilder.Metadata, definingQuery));
}
}
}

protected class DbSetAccessRewritingExpressionVisitor : ExpressionVisitor
{
private readonly Type _contextType;
private IModel _model;

public DbSetAccessRewritingExpressionVisitor(Type contextType)
{
_contextType = contextType;
}

public Expression Rewrite(IModel model, Expression expression)
{
_model = model;

return Visit(expression);
}

protected override Expression VisitMember(MemberExpression memberExpression)
{
Check.NotNull(memberExpression, nameof(memberExpression));
Expand All @@ -80,9 +88,10 @@ protected override Expression VisitMember(MemberExpression memberExpression)
&& (memberExpression.Expression.Type.IsAssignableFrom(_contextType)
|| _contextType.IsAssignableFrom(memberExpression.Expression.Type))
&& memberExpression.Type.IsGenericType
&& memberExpression.Type.GetGenericTypeDefinition() == typeof(DbSet<>))
&& memberExpression.Type.GetGenericTypeDefinition() == typeof(DbSet<>)
&& _model != null)
{
return NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(memberExpression.Type.GetGenericArguments()[0]);
return NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(FindEntityType(memberExpression.Type));
}

return base.VisitMember(memberExpression);
Expand All @@ -96,14 +105,16 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
&& methodCallExpression.Object != null
&& typeof(DbContext).IsAssignableFrom(methodCallExpression.Object.Type)
&& methodCallExpression.Type.IsGenericType
&& methodCallExpression.Type.GetGenericTypeDefinition() == typeof(DbSet<>))
&& methodCallExpression.Type.GetGenericTypeDefinition() == typeof(DbSet<>)
&& _model != null)
{
return NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(
methodCallExpression.Type.GetGenericArguments()[0]);
return NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(FindEntityType(methodCallExpression.Type));
}

return base.VisitMethodCall(methodCallExpression);
}

private IEntityType FindEntityType(Type dbSetType) => _model.FindRuntimeEntityType(dbSetType.GetGenericArguments()[0]);
}
}
}
23 changes: 23 additions & 0 deletions src/EFCore/Query/IEntityQueryable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using Microsoft.EntityFrameworkCore.Metadata;

namespace Microsoft.EntityFrameworkCore.Query
{
/// <summary>
/// An interface to identify query roots in LINQ.
/// </summary>
public interface IEntityQueryable
{
/// <summary>
/// Detach context if associated with this query root.
/// </summary>
IEntityQueryable DetachContext();

/// <summary>
/// Return entity type this query root references.
/// </summary>
IEntityType EntityType { get; }
}
}
14 changes: 8 additions & 6 deletions src/EFCore/Query/Internal/AsyncQueryProviderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query.Internal
Expand All @@ -24,24 +25,25 @@ public static class AsyncQueryProviderExtensions
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public static ConstantExpression CreateEntityQueryableExpression(
[NotNull] this IAsyncQueryProvider entityQueryProvider, [NotNull] Type type)
[NotNull] this IAsyncQueryProvider entityQueryProvider, [NotNull] IEntityType entityType)
{
Check.NotNull(entityQueryProvider, nameof(entityQueryProvider));
Check.NotNull(type, nameof(type));
Check.NotNull(entityType, nameof(entityType));

return Expression.Constant(
_createEntityQueryableMethod
.MakeGenericMethod(type)
.MakeGenericMethod(entityType.ClrType)
.Invoke(
null, new object[] { entityQueryProvider }));
null, new object[] { entityQueryProvider, entityType }));
}

private static readonly MethodInfo _createEntityQueryableMethod
= typeof(AsyncQueryProviderExtensions)
.GetTypeInfo().GetDeclaredMethod(nameof(CreateEntityQueryable));

[UsedImplicitly]
private static EntityQueryable<TResult> CreateEntityQueryable<TResult>(IAsyncQueryProvider entityQueryProvider)
=> new EntityQueryable<TResult>(entityQueryProvider);
private static EntityQueryable<TResult> CreateEntityQueryable<TResult>(
IAsyncQueryProvider entityQueryProvider, IEntityType entityType)
=> new EntityQueryable<TResult>(entityQueryProvider, entityType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio
return constantExpression.IsEntityQueryable()
? new EntityReferenceExpression(
constantExpression,
_queryCompilationContext.Model.FindEntityType(((IQueryable)constantExpression.Value).ElementType))
((IEntityQueryable)constantExpression.Value).EntityType)
: (Expression)constantExpression;
}

Expand Down
22 changes: 16 additions & 6 deletions src/EFCore/Query/Internal/EntityQueryable`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.Query.Internal
Expand All @@ -24,25 +25,25 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal
public class EntityQueryable<TResult>
: IOrderedQueryable<TResult>,
IAsyncEnumerable<TResult>,
IDetachableContext,
IEntityQueryable,
IListSource
{
private static readonly EntityQueryable<TResult> _detached
= new EntityQueryable<TResult>(NullAsyncQueryProvider.Instance);

private readonly IAsyncQueryProvider _queryProvider;
private readonly IEntityType _entityType;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public EntityQueryable([NotNull] IAsyncQueryProvider queryProvider)
public EntityQueryable([NotNull] IAsyncQueryProvider queryProvider, [NotNull] IEntityType entityType)
{
Check.NotNull(queryProvider, nameof(queryProvider));
Check.NotNull(entityType, nameof(entityType));

_queryProvider = queryProvider;
_entityType = entityType;
Expression = Expression.Constant(this);
}

Expand Down Expand Up @@ -118,7 +119,16 @@ public virtual IAsyncEnumerator<TResult> GetAsyncEnumerator(CancellationToken ca
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
IDetachableContext IDetachableContext.DetachContext() => _detached;
IEntityQueryable IEntityQueryable.DetachContext()
=> new EntityQueryable<TResult>(NullAsyncQueryProvider.Instance, _entityType);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
IEntityType IEntityQueryable.EntityType => _entityType;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
22 changes: 0 additions & 22 deletions src/EFCore/Query/Internal/IDetachableContext.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,7 @@ protected Expression ExpandNavigation(
return ownedExpansion;
}

var innerQueryableType = targetType.ClrType;
var innerQueryable = NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(innerQueryableType);
var innerQueryable = NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(targetType);
var innerSource = (NavigationExpansionExpression)_navigationExpandingExpressionVisitor.Visit(innerQueryable);
if (entityReference.IncludePaths.ContainsKey(navigation))
{
Expand Down Expand Up @@ -685,9 +684,7 @@ protected override Expression VisitConstant(ConstantExpression constantExpressio

if (constantExpression.IsEntityQueryable())
{
var entityType =
_navigationExpandingExpressionVisitor._queryCompilationContext.Model.FindEntityType(
((IQueryable)constantExpression.Value).ElementType);
var entityType = ((IEntityQueryable)constantExpression.Value).EntityType;
if (entityType == _entityType)
{
return _navigationExpandingExpressionVisitor.CreateNavigationExpansionExpression(constantExpression, entityType);
Expand Down
Loading

0 comments on commit 9062743

Please sign in to comment.