diff --git a/src/LinqTests/Acceptance/query_with_inheritance.cs b/src/LinqTests/Acceptance/query_with_inheritance.cs index 0f9d846cee..2283eefbe7 100644 --- a/src/LinqTests/Acceptance/query_with_inheritance.cs +++ b/src/LinqTests/Acceptance/query_with_inheritance.cs @@ -1,8 +1,10 @@ using System; using System.Linq; +using System.Linq.Expressions; using System.Threading.Tasks; using JasperFx; using Marten; +using Marten.Linq; using Marten.Testing.Harness; using Shouldly; using Weasel.Core; @@ -26,14 +28,19 @@ public class Smurf: ISmurf public interface IPapaSmurf: ISmurf { + bool IsVillageLeader { get; set; } } public class PapaSmurf: Smurf, IPapaSmurf { + public bool IsVillageLeader { get; set; } + + public bool IsPapa { get; set; } = true; } public class PapySmurf: Smurf, IPapaSmurf { + public bool IsVillageLeader { get; set; } } public class BrainySmurf: PapaSmurf @@ -240,5 +247,36 @@ public async Task get_all_subclasses_of_an_interface() theSession.Query().Count().ShouldBe(3); } + [Fact] + public async Task search_on_property_of_subclass() + { + var smurf = new Smurf {Ability = "Follow the herd"}; + var papa = new PapaSmurf {Ability = "Lead", IsVillageLeader = true }; + var papy = new PapySmurf {Ability = "Lead"}; + var brainy = new BrainySmurf {Ability = "Invent"}; + theSession.Store(smurf, papa, brainy, papy); + + await theSession.SaveChangesAsync(); + + (await theSession.Query().WhereSub(x => x.IsVillageLeader).CountAsync()).ShouldBe(1); + } + + [Fact] + public async Task search_on_property_of_subclass_and_parent() + { + var smurf = new Smurf {Ability = "Follow the herd"}; + var papa = new PapaSmurf {Ability = "Lead" }; + var papy = new PapySmurf {Ability = "Lead"}; + var brainy = new BrainySmurf {Ability = "Invent"}; + theSession.Store(smurf, papa, brainy, papy); + + await theSession.SaveChangesAsync(); + + (await theSession.Query() + .WhereSub(x => x.IsPapa) + .Where(x => x.Ability == "Invent") + .CountAsync()).ShouldBe(1); + } + #endregion } diff --git a/src/Marten/Linq/IMartenQueryable.cs b/src/Marten/Linq/IMartenQueryable.cs index b5e1ae30ee..713ec5d93d 100644 --- a/src/Marten/Linq/IMartenQueryable.cs +++ b/src/Marten/Linq/IMartenQueryable.cs @@ -143,4 +143,6 @@ IMartenQueryableIncludeBuilder Include( /// IMartenQueryableIncludeBuilder Include( IDictionary> dictionary) where TInclude : notnull where TKey : notnull; + + IMartenQueryable WhereSub(Expression> predicate) where TSub : T; } diff --git a/src/Marten/Linq/MartenLinqQueryable.cs b/src/Marten/Linq/MartenLinqQueryable.cs index 4eb8f34579..68e3b10b3f 100644 --- a/src/Marten/Linq/MartenLinqQueryable.cs +++ b/src/Marten/Linq/MartenLinqQueryable.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using JasperFx.Core; +using JasperFx.Core.Reflection; using Marten.Internal.Sessions; using Marten.Linq.Includes; using Marten.Linq.Parsing; @@ -125,6 +126,11 @@ public IMartenQueryableIncludeBuilder Include return new MartenQueryableIncludeBuilder(this, dictionary); } + public IMartenQueryable WhereSub(Expression> predicate) where TSub : T => + (IMartenQueryable)this.Where( + Expression.Lambda>(predicate.Body, Expression.Parameter(typeof(T), "x")) + ); + public IEnumerator GetEnumerator() { return Provider.Execute>(Expression).GetEnumerator();