diff --git a/src/LinqTests/Bugs/Bug_3009_select_before_where.cs b/src/LinqTests/Bugs/Bug_3009_select_before_where.cs new file mode 100644 index 0000000000..d6b90c118d --- /dev/null +++ b/src/LinqTests/Bugs/Bug_3009_select_before_where.cs @@ -0,0 +1,247 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Marten; +using Marten.Testing.Documents; +using Marten.Testing.Harness; +using Shouldly; + +namespace LinqTests.Bugs; + +public class Bug_3009_select_before_where: BugIntegrationContext +{ + [Fact] + public async Task select_before_where_with_different_type() + { + var doc1 = new DocWithInner { Id = Guid.NewGuid(), Name = "one", Inner = new InnerDoc { Value = 10, Text = "low" } }; + var doc2 = new DocWithInner { Id = Guid.NewGuid(), Name = "two", Inner = new InnerDoc { Value = 50, Text = "mid" } }; + var doc3 = new DocWithInner { Id = Guid.NewGuid(), Name = "three", Inner = new InnerDoc { Value = 90, Text = "high" } }; + + theSession.Store(doc1, doc2, doc3); + await theSession.SaveChangesAsync(); + + // Select().Where() - the problematic ordering from GH-3009 + var results = await theSession.Query() + .Select(x => x.Inner) + .Where(x => x.Value > 40) + .ToListAsync(); + + results.Count.ShouldBe(2); + results.ShouldContain(x => x.Value == 50); + results.ShouldContain(x => x.Value == 90); + } + + [Fact] + public async Task select_before_where_matches_where_before_select_different_type() + { + var doc1 = new DocWithInner { Id = Guid.NewGuid(), Name = "one", Inner = new InnerDoc { Value = 10, Text = "low" } }; + var doc2 = new DocWithInner { Id = Guid.NewGuid(), Name = "two", Inner = new InnerDoc { Value = 50, Text = "mid" } }; + var doc3 = new DocWithInner { Id = Guid.NewGuid(), Name = "three", Inner = new InnerDoc { Value = 90, Text = "high" } }; + + theSession.Store(doc1, doc2, doc3); + await theSession.SaveChangesAsync(); + + // Normal order: Where().Select() + var expected = await theSession.Query() + .Where(x => x.Inner.Value > 40) + .Select(x => x.Inner) + .ToListAsync(); + + // Reversed order: Select().Where() + var actual = await theSession.Query() + .Select(x => x.Inner) + .Where(x => x.Value > 40) + .ToListAsync(); + + actual.Count.ShouldBe(expected.Count); + actual.Select(x => x.Value).OrderBy(x => x) + .ShouldBe(expected.Select(x => x.Value).OrderBy(x => x)); + } + + [Fact] + public async Task select_before_where_with_same_type() + { + // Target.Inner is also of type Target, so this tests same-type Select hoisting + var targets = Target.GenerateRandomData(50).ToArray(); + await theStore.BulkInsertAsync(targets); + + // Only targets where Inner is not null + var targetsWithInner = targets.Where(x => x.Inner != null).ToArray(); + + // Normal order + var expected = await theSession.Query() + .Where(x => x.Inner != null && x.Inner.Number > 0) + .Select(x => x.Inner) + .ToListAsync(); + + // Reversed order + var actual = await theSession.Query() + .Select(x => x.Inner) + .Where(x => x != null && x.Number > 0) + .ToListAsync(); + + actual.Count.ShouldBe(expected.Count); + } + + [Fact] + public async Task select_before_multiple_where_clauses() + { + var doc1 = new DocWithInner { Id = Guid.NewGuid(), Name = "one", Inner = new InnerDoc { Value = 10, Text = "alpha" } }; + var doc2 = new DocWithInner { Id = Guid.NewGuid(), Name = "two", Inner = new InnerDoc { Value = 50, Text = "beta" } }; + var doc3 = new DocWithInner { Id = Guid.NewGuid(), Name = "three", Inner = new InnerDoc { Value = 90, Text = "gamma" } }; + + theSession.Store(doc1, doc2, doc3); + await theSession.SaveChangesAsync(); + + // Multiple Where clauses after Select + var results = await theSession.Query() + .Select(x => x.Inner) + .Where(x => x.Value > 5) + .Where(x => x.Value < 80) + .ToListAsync(); + + results.Count.ShouldBe(2); + results.ShouldContain(x => x.Value == 10); + results.ShouldContain(x => x.Value == 50); + } + + [Fact] + public async Task select_before_where_with_string_comparison() + { + var doc1 = new DocWithInner { Id = Guid.NewGuid(), Name = "one", Inner = new InnerDoc { Value = 10, Text = "alpha" } }; + var doc2 = new DocWithInner { Id = Guid.NewGuid(), Name = "two", Inner = new InnerDoc { Value = 50, Text = "beta" } }; + var doc3 = new DocWithInner { Id = Guid.NewGuid(), Name = "three", Inner = new InnerDoc { Value = 90, Text = "gamma" } }; + + theSession.Store(doc1, doc2, doc3); + await theSession.SaveChangesAsync(); + + var results = await theSession.Query() + .Select(x => x.Inner) + .Where(x => x.Text == "beta") + .ToListAsync(); + + results.Count.ShouldBe(1); + results[0].Value.ShouldBe(50); + } + + [Fact] + public async Task select_before_where_with_first_or_default() + { + var doc1 = new DocWithInner { Id = Guid.NewGuid(), Name = "one", Inner = new InnerDoc { Value = 10, Text = "alpha" } }; + var doc2 = new DocWithInner { Id = Guid.NewGuid(), Name = "two", Inner = new InnerDoc { Value = 50, Text = "beta" } }; + + theSession.Store(doc1, doc2); + await theSession.SaveChangesAsync(); + + var result = await theSession.Query() + .Select(x => x.Inner) + .Where(x => x.Text == "beta") + .FirstOrDefaultAsync(); + + result.ShouldNotBeNull(); + result.Value.ShouldBe(50); + } + + [Fact] + public async Task select_before_where_with_count() + { + var doc1 = new DocWithInner { Id = Guid.NewGuid(), Name = "one", Inner = new InnerDoc { Value = 10, Text = "alpha" } }; + var doc2 = new DocWithInner { Id = Guid.NewGuid(), Name = "two", Inner = new InnerDoc { Value = 50, Text = "beta" } }; + var doc3 = new DocWithInner { Id = Guid.NewGuid(), Name = "three", Inner = new InnerDoc { Value = 90, Text = "gamma" } }; + + theSession.Store(doc1, doc2, doc3); + await theSession.SaveChangesAsync(); + + var count = await theSession.Query() + .Select(x => x.Inner) + .Where(x => x.Value > 40) + .CountAsync(); + + count.ShouldBe(2); + } + + [Fact] + public async Task where_before_and_after_select() + { + var doc1 = new DocWithInner { Id = Guid.NewGuid(), Name = "one", Inner = new InnerDoc { Value = 10, Text = "alpha" } }; + var doc2 = new DocWithInner { Id = Guid.NewGuid(), Name = "two", Inner = new InnerDoc { Value = 50, Text = "beta" } }; + var doc3 = new DocWithInner { Id = Guid.NewGuid(), Name = "three", Inner = new InnerDoc { Value = 90, Text = "gamma" } }; + + theSession.Store(doc1, doc2, doc3); + await theSession.SaveChangesAsync(); + + // Pre-Select Where on document + post-Select Where on projected type + var results = await theSession.Query() + .Where(x => x.Name != "one") + .Select(x => x.Inner) + .Where(x => x.Value < 80) + .ToListAsync(); + + results.Count.ShouldBe(1); + results[0].Value.ShouldBe(50); + } + + [Fact] + public async Task select_deep_member_before_where() + { + var doc1 = new DocWithNested + { + Id = Guid.NewGuid(), + Level1 = new Level1Doc + { + Level2 = new Level2Doc { Score = 100, Label = "a" } + } + }; + var doc2 = new DocWithNested + { + Id = Guid.NewGuid(), + Level1 = new Level1Doc + { + Level2 = new Level2Doc { Score = 200, Label = "b" } + } + }; + + theSession.Store(doc1, doc2); + await theSession.SaveChangesAsync(); + + // Select a deeply nested member, then filter + var results = await theSession.Query() + .Select(x => x.Level1.Level2) + .Where(x => x.Score > 150) + .ToListAsync(); + + results.Count.ShouldBe(1); + results[0].Label.ShouldBe("b"); + } +} + +public class DocWithInner +{ + public Guid Id { get; set; } + public string Name { get; set; } + public InnerDoc Inner { get; set; } +} + +public class InnerDoc +{ + public int Value { get; set; } + public string Text { get; set; } +} + +public class DocWithNested +{ + public Guid Id { get; set; } + public Level1Doc Level1 { get; set; } +} + +public class Level1Doc +{ + public Level2Doc Level2 { get; set; } +} + +public class Level2Doc +{ + public int Score { get; set; } + public string Label { get; set; } +} diff --git a/src/Marten/Linq/Parsing/Operators/SelectOperator.cs b/src/Marten/Linq/Parsing/Operators/SelectOperator.cs index 0c4235c301..a73c9ecc63 100644 --- a/src/Marten/Linq/Parsing/Operators/SelectOperator.cs +++ b/src/Marten/Linq/Parsing/Operators/SelectOperator.cs @@ -1,4 +1,5 @@ #nullable enable +using System; using System.Linq; using System.Linq.Expressions; @@ -12,8 +13,14 @@ public SelectOperator(): base("Select") public override void Apply(ILinqQuery query, MethodCallExpression expression) { - var usage = query.CollectionUsageFor(expression); + // Capture the current usage before CollectionUsageFor potentially creates a new one. + // Due to outermost-to-innermost expression tree traversal, any WhereExpressions + // already on the current usage were added by operators that come AFTER Select in + // user code (e.g., .Select(...).Where(...)) and need to be hoisted. + var previousUsage = query.CurrentUsage; + var select = expression.Arguments.Last(); + LambdaExpression? selectLambda = null; if (select is UnaryExpression e) { select = e.Operand; @@ -21,9 +28,72 @@ public override void Apply(ILinqQuery query, MethodCallExpression expression) if (select is LambdaExpression l) { + selectLambda = l; select = l.Body; } + var usage = query.CollectionUsageFor(expression); + + // Expression hoisting for .Select().Where() chains (GH-3009). + // When the select body is a simple member access (e.g., x => x.Inner), we can + // rewrite post-Select Where expressions to reference the original document type + // by prepending the select member path. + if (selectLambda != null && select is MemberExpression memberSelect && previousUsage != null + && previousUsage.WhereExpressions.Any()) + { + if (previousUsage == usage) + { + // Same-type case: Select projects to the same type (e.g., Target.Inner is Target). + // Force a new CollectionUsage so post-Select wheres are separated from pre-Select ones. + usage = query.StartNewCollectionUsageFor(expression); + } + + // Rewrite and move Where expressions from the post-Select usage to the document usage + HoistWhereExpressions(previousUsage, usage, previousUsage.ElementType, memberSelect); + } + usage.SelectExpression = select; } + + internal static void HoistWhereExpressions( + CollectionUsage source, CollectionUsage target, + Type projectedType, MemberExpression selectBody) + { + var rewriter = new PostSelectExpressionRewriter(projectedType, selectBody); + + foreach (var where in source.WhereExpressions) + { + target.WhereExpressions.Add(rewriter.Visit(where)); + } + + source.WhereExpressions.Clear(); + } + + /// + /// Rewrites expressions from a post-Select Where clause by replacing ParameterExpression + /// references of the projected type with the Select body (a MemberExpression). + /// For example, transforms y.Value > 5 (where y is the projected type) into + /// x.Inner.Value > 5 (where x.Inner is the Select body). + /// + private class PostSelectExpressionRewriter: ExpressionVisitor + { + private readonly Type _projectedType; + private readonly MemberExpression _selectBody; + + public PostSelectExpressionRewriter(Type projectedType, MemberExpression selectBody) + { + _projectedType = projectedType; + _selectBody = selectBody; + } + + protected override Expression VisitParameter(ParameterExpression node) + { + if (node.Type == _projectedType) + { + return _selectBody; + } + + return base.VisitParameter(node); + } + } }