From f006948378b77e4e4ba20d4183b59b0f0c3a3af3 Mon Sep 17 00:00:00 2001 From: Fati Iseni Date: Fri, 13 Jun 2025 10:59:08 +0200 Subject: [PATCH] Refactor the state of include expressions as OneOrMany. --- .../Evaluators/IncludeEvaluator.cs | 13 +++++++- .../Evaluators/IncludeStringEvaluator.cs | 9 +++++ src/Ardalis.Specification/Specification.cs | 24 +++++++------- .../Evaluators/IncludeEvaluatorTests.cs | 33 +++++++++++++++++++ .../Evaluators/IncludeStringEvaluatorTests.cs | 15 +++++++++ 5 files changed, 81 insertions(+), 13 deletions(-) diff --git a/src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs b/src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs index 87925228..15343cdd 100644 --- a/src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs +++ b/src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs @@ -1,5 +1,4 @@ using Microsoft.EntityFrameworkCore.Query; -using System.Collections; using System.Collections.Concurrent; using System.Diagnostics; using System.Reflection; @@ -41,6 +40,18 @@ private IncludeEvaluator() { } /// public IQueryable GetQuery(IQueryable query, ISpecification specification) where T : class { + if (specification is Specification spec) + { + if (spec.OneOrManyIncludeExpressions.IsEmpty) return query; + if (spec.OneOrManyIncludeExpressions.SingleOrDefault is { } includeExpression) + { + var lambdaExpr = includeExpression.LambdaExpression; + var key = new CacheKey(typeof(T), lambdaExpr.ReturnType, null); + var include = _cache.GetOrAdd(key, CreateIncludeDelegate); + return (IQueryable)include(query, lambdaExpr); + } + } + foreach (var includeExpression in specification.IncludeExpressions) { var lambdaExpr = includeExpression.LambdaExpression; diff --git a/src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeStringEvaluator.cs b/src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeStringEvaluator.cs index 72c57354..9e5273a6 100644 --- a/src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeStringEvaluator.cs +++ b/src/Ardalis.Specification.EntityFrameworkCore/Evaluators/IncludeStringEvaluator.cs @@ -10,6 +10,15 @@ private IncludeStringEvaluator() { } /// public IQueryable GetQuery(IQueryable query, ISpecification specification) where T : class { + if (specification is Specification spec) + { + if (spec.OneOrManyIncludeStrings.IsEmpty) return query; + if (spec.OneOrManyIncludeStrings.SingleOrDefault is { } includeString) + { + return query.Include(includeString); + } + } + foreach (var includeString in specification.IncludeStrings) { query = query.Include(includeString); diff --git a/src/Ardalis.Specification/Specification.cs b/src/Ardalis.Specification/Specification.cs index abdde5da..a87cbb4e 100644 --- a/src/Ardalis.Specification/Specification.cs +++ b/src/Ardalis.Specification/Specification.cs @@ -27,8 +27,6 @@ public class Specification : Specification, ISpecification : ISpecification { private const int DEFAULT_CAPACITY_SEARCH = 2; - private const int DEFAULT_CAPACITY_INCLUDE = 2; - private const int DEFAULT_CAPACITY_INCLUDESTRING = 1; // It is utilized only during the building stage for the sub-chains. Once the state is built, we don't care about it anymore. // The initial value is not important since the value is always initialized by the root of the chain. @@ -43,8 +41,8 @@ public class Specification : ISpecification private OneOrMany> _whereExpressions = new(); private List>? _searchExpressions; private OneOrMany> _orderExpressions = new(); - private List? _includeExpressions; - private List? _includeStrings; + private OneOrMany _includeExpressions = new(); + private OneOrMany _includeStrings = new(); private Dictionary? _items; private OneOrMany _queryTags = new(); @@ -94,8 +92,8 @@ public class Specification : ISpecification // Specs are not intended to be thread-safe, so we don't need to worry about thread-safety here. internal void Add(WhereExpressionInfo whereExpression) => _whereExpressions.Add(whereExpression); internal void Add(OrderExpressionInfo orderExpression) => _orderExpressions.Add(orderExpression); - internal void Add(IncludeExpressionInfo includeExpression) => (_includeExpressions ??= new(DEFAULT_CAPACITY_INCLUDE)).Add(includeExpression); - internal void Add(string includeString) => (_includeStrings ??= new(DEFAULT_CAPACITY_INCLUDESTRING)).Add(includeString); + internal void Add(IncludeExpressionInfo includeExpression) => _includeExpressions.Add(includeExpression); + internal void Add(string includeString) => _includeStrings.Add(includeString); internal void Add(SearchExpressionInfo searchExpression) { if (_searchExpressions is null) @@ -132,16 +130,18 @@ internal void Add(SearchExpressionInfo searchExpression) public IEnumerable> OrderExpressions => _orderExpressions.Values; /// - public IEnumerable IncludeExpressions => _includeExpressions ?? Enumerable.Empty(); + public IEnumerable IncludeExpressions => _includeExpressions.Values; /// - public IEnumerable IncludeStrings => _includeStrings ?? Enumerable.Empty(); + public IEnumerable IncludeStrings => _includeStrings.Values; /// public IEnumerable QueryTags => _queryTags.Values; internal OneOrMany> OneOrManyWhereExpressions => _whereExpressions; internal OneOrMany> OneOrManyOrderExpressions => _orderExpressions; + internal OneOrMany OneOrManyIncludeExpressions => _includeExpressions; + internal OneOrMany OneOrManyIncludeStrings => _includeStrings; internal OneOrMany OneOrManyQueryTags => _queryTags; /// @@ -179,14 +179,14 @@ void ISpecification.CopyTo(Specification otherSpec) otherSpec._whereExpressions = _whereExpressions.Clone(); } - if (_includeExpressions is not null) + if (!_includeExpressions.IsEmpty) { - otherSpec._includeExpressions = _includeExpressions.ToList(); + otherSpec._includeExpressions = _includeExpressions.Clone(); } - if (_includeStrings is not null) + if (!_includeStrings.IsEmpty) { - otherSpec._includeStrings = _includeStrings.ToList(); + otherSpec._includeStrings = _includeStrings.Clone(); } if (!_orderExpressions.IsEmpty) diff --git a/tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs b/tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs index 57c5e90b..f82e3d1f 100644 --- a/tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs +++ b/tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs @@ -5,6 +5,39 @@ public class IncludeEvaluatorTests(TestFactory factory) : IntegrationTest(factor { private static readonly IncludeEvaluator _evaluator = IncludeEvaluator.Instance; + [Fact] + public void QueriesMatch_GivenNoIncludeExpression() + { + var spec = new Specification(); + + var actual = _evaluator + .GetQuery(DbContext.Stores, spec) + .ToQueryString(); + + var expected = DbContext.Stores + .ToQueryString(); + + actual.Should().Be(expected); + } + + [Fact] + public void QueriesMatch_GivenSingleIncludeExpression() + { + var spec = new Specification(); + spec.Query + .Include(x => x.Products.Where(x => x.Id > 10)); + + var actual = _evaluator + .GetQuery(DbContext.Stores, spec) + .ToQueryString(); + + var expected = DbContext.Stores + .Include(x => x.Products.Where(x => x.Id > 10)) + .ToQueryString(); + + actual.Should().Be(expected); + } + [Fact] public void QueriesMatch_GivenIncludeExpressions() { diff --git a/tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeStringEvaluatorTests.cs b/tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeStringEvaluatorTests.cs index b8b4592a..13fc8c12 100644 --- a/tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeStringEvaluatorTests.cs +++ b/tests/Ardalis.Specification.EntityFrameworkCore.Tests/Evaluators/IncludeStringEvaluatorTests.cs @@ -5,6 +5,21 @@ public class IncludeStringEvaluatorTests(TestFactory factory) : IntegrationTest( { private static readonly IncludeStringEvaluator _evaluator = IncludeStringEvaluator.Instance; + [Fact] + public void QueriesMatch_GivenNoIncludeString() + { + var spec = new Specification(); + + var actual = _evaluator + .GetQuery(DbContext.Stores, spec) + .ToQueryString(); + + var expected = DbContext.Stores + .ToQueryString(); + + actual.Should().Be(expected); + } + [Fact] public void QueriesMatch_GivenIncludeString() {