Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.EntityFrameworkCore.Query;
using System.Collections;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Reflection;
Expand Down Expand Up @@ -41,6 +40,18 @@ private IncludeEvaluator() { }
/// <inheritdoc/>
public IQueryable<T> GetQuery<T>(IQueryable<T> query, ISpecification<T> specification) where T : class
{
if (specification is Specification<T> spec)
{
if (spec.OneOrManyIncludeExpressions.IsEmpty) return query;
if (spec.OneOrManyIncludeExpressions.SingleOrDefault is { } includeExpression)
{
var lambdaExpr = includeExpression.LambdaExpression;
var key = new CacheKey(typeof(T), lambdaExpr.ReturnType, null);
var include = _cache.GetOrAdd(key, CreateIncludeDelegate);
return (IQueryable<T>)include(query, lambdaExpr);
}
}

foreach (var includeExpression in specification.IncludeExpressions)
{
var lambdaExpr = includeExpression.LambdaExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ private IncludeStringEvaluator() { }
/// <inheritdoc/>
public IQueryable<T> GetQuery<T>(IQueryable<T> query, ISpecification<T> specification) where T : class
{
if (specification is Specification<T> spec)
{
if (spec.OneOrManyIncludeStrings.IsEmpty) return query;
if (spec.OneOrManyIncludeStrings.SingleOrDefault is { } includeString)
{
return query.Include(includeString);
}
}

foreach (var includeString in specification.IncludeStrings)
{
query = query.Include(includeString);
Expand Down
24 changes: 12 additions & 12 deletions src/Ardalis.Specification/Specification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ public class Specification<T, TResult> : Specification<T>, ISpecification<T, TRe
public class Specification<T> : ISpecification<T>
{
private const int DEFAULT_CAPACITY_SEARCH = 2;
private const int DEFAULT_CAPACITY_INCLUDE = 2;
private const int DEFAULT_CAPACITY_INCLUDESTRING = 1;

// It is utilized only during the building stage for the sub-chains. Once the state is built, we don't care about it anymore.
// The initial value is not important since the value is always initialized by the root of the chain.
Expand All @@ -43,8 +41,8 @@ public class Specification<T> : ISpecification<T>
private OneOrMany<WhereExpressionInfo<T>> _whereExpressions = new();
private List<SearchExpressionInfo<T>>? _searchExpressions;
private OneOrMany<OrderExpressionInfo<T>> _orderExpressions = new();
private List<IncludeExpressionInfo>? _includeExpressions;
private List<string>? _includeStrings;
private OneOrMany<IncludeExpressionInfo> _includeExpressions = new();
private OneOrMany<string> _includeStrings = new();
private Dictionary<string, object>? _items;
private OneOrMany<string> _queryTags = new();

Expand Down Expand Up @@ -94,8 +92,8 @@ public class Specification<T> : ISpecification<T>
// Specs are not intended to be thread-safe, so we don't need to worry about thread-safety here.
internal void Add(WhereExpressionInfo<T> whereExpression) => _whereExpressions.Add(whereExpression);
internal void Add(OrderExpressionInfo<T> orderExpression) => _orderExpressions.Add(orderExpression);
internal void Add(IncludeExpressionInfo includeExpression) => (_includeExpressions ??= new(DEFAULT_CAPACITY_INCLUDE)).Add(includeExpression);
internal void Add(string includeString) => (_includeStrings ??= new(DEFAULT_CAPACITY_INCLUDESTRING)).Add(includeString);
internal void Add(IncludeExpressionInfo includeExpression) => _includeExpressions.Add(includeExpression);
internal void Add(string includeString) => _includeStrings.Add(includeString);
internal void Add(SearchExpressionInfo<T> searchExpression)
{
if (_searchExpressions is null)
Expand Down Expand Up @@ -132,16 +130,18 @@ internal void Add(SearchExpressionInfo<T> searchExpression)
public IEnumerable<OrderExpressionInfo<T>> OrderExpressions => _orderExpressions.Values;

/// <inheritdoc/>
public IEnumerable<IncludeExpressionInfo> IncludeExpressions => _includeExpressions ?? Enumerable.Empty<IncludeExpressionInfo>();
public IEnumerable<IncludeExpressionInfo> IncludeExpressions => _includeExpressions.Values;

/// <inheritdoc/>
public IEnumerable<string> IncludeStrings => _includeStrings ?? Enumerable.Empty<string>();
public IEnumerable<string> IncludeStrings => _includeStrings.Values;

/// <inheritdoc/>
public IEnumerable<string> QueryTags => _queryTags.Values;

internal OneOrMany<WhereExpressionInfo<T>> OneOrManyWhereExpressions => _whereExpressions;
internal OneOrMany<OrderExpressionInfo<T>> OneOrManyOrderExpressions => _orderExpressions;
internal OneOrMany<IncludeExpressionInfo> OneOrManyIncludeExpressions => _includeExpressions;
internal OneOrMany<string> OneOrManyIncludeStrings => _includeStrings;
internal OneOrMany<string> OneOrManyQueryTags => _queryTags;

/// <inheritdoc/>
Expand Down Expand Up @@ -179,14 +179,14 @@ void ISpecification<T>.CopyTo(Specification<T> otherSpec)
otherSpec._whereExpressions = _whereExpressions.Clone();
}

if (_includeExpressions is not null)
if (!_includeExpressions.IsEmpty)
{
otherSpec._includeExpressions = _includeExpressions.ToList();
otherSpec._includeExpressions = _includeExpressions.Clone();
}

if (_includeStrings is not null)
if (!_includeStrings.IsEmpty)
{
otherSpec._includeStrings = _includeStrings.ToList();
otherSpec._includeStrings = _includeStrings.Clone();
}

if (!_orderExpressions.IsEmpty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,39 @@ public class IncludeEvaluatorTests(TestFactory factory) : IntegrationTest(factor
{
private static readonly IncludeEvaluator _evaluator = IncludeEvaluator.Instance;

[Fact]
public void QueriesMatch_GivenNoIncludeExpression()
{
var spec = new Specification<Store>();

var actual = _evaluator
.GetQuery(DbContext.Stores, spec)
.ToQueryString();

var expected = DbContext.Stores
.ToQueryString();

actual.Should().Be(expected);
}

[Fact]
public void QueriesMatch_GivenSingleIncludeExpression()
{
var spec = new Specification<Store>();
spec.Query
.Include(x => x.Products.Where(x => x.Id > 10));

var actual = _evaluator
.GetQuery(DbContext.Stores, spec)
.ToQueryString();

var expected = DbContext.Stores
.Include(x => x.Products.Where(x => x.Id > 10))
.ToQueryString();

actual.Should().Be(expected);
}

[Fact]
public void QueriesMatch_GivenIncludeExpressions()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,21 @@ public class IncludeStringEvaluatorTests(TestFactory factory) : IntegrationTest(
{
private static readonly IncludeStringEvaluator _evaluator = IncludeStringEvaluator.Instance;

[Fact]
public void QueriesMatch_GivenNoIncludeString()
{
var spec = new Specification<Store>();

var actual = _evaluator
.GetQuery(DbContext.Stores, spec)
.ToQueryString();

var expected = DbContext.Stores
.ToQueryString();

actual.Should().Be(expected);
}

[Fact]
public void QueriesMatch_GivenIncludeString()
{
Expand Down