diff --git a/Microsoft.Azure.Cosmos/src/Linq/SubtreeEvaluator.cs b/Microsoft.Azure.Cosmos/src/Linq/SubtreeEvaluator.cs index 4d67fbd67d..7b6c8e57c9 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/SubtreeEvaluator.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/SubtreeEvaluator.cs @@ -5,6 +5,7 @@ namespace Microsoft.Azure.Cosmos.Linq { using System; using System.Collections.Generic; + using System.Collections.ObjectModel; using System.Linq.Expressions; using System.Reflection; @@ -40,7 +41,21 @@ public override Expression Visit(Expression expression) protected override Expression VisitMemberInit(MemberInitExpression node) { - return node; + // Rebuild the MemberInit manually and intentionally do NOT visit node.NewExpression. + // The Nominator nominates a parameterless `new T()` as a candidate (CanBeEvaluated + // returns true for any non-Parameter / non-Lambda expression). Routing it through + // our overridden Visit would fold it into a ConstantExpression of the constructed + // CLR instance. Expression.MemberInit requires a NewExpression as its first argument, + // not a ConstantExpression, so that path would throw InvalidOperationException at + // runtime. We only need to recurse into the bindings to fold closure-captured + // variables (and other independent sub-trees) in initializers — see issue #1664. + ReadOnlyCollection newBindings = Visit(node.Bindings, this.VisitMemberBinding); + if (newBindings == node.Bindings) + { + return node; + } + + return Expression.MemberInit(node.NewExpression, newBindings); } private Expression EvaluateMemberAccess(Expression expression) diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Linq/ConstantEvaluatorTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Linq/ConstantEvaluatorTests.cs new file mode 100644 index 0000000000..255ea1ddb5 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Linq/ConstantEvaluatorTests.cs @@ -0,0 +1,206 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Linq +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Linq.Expressions; + using Microsoft.VisualStudio.TestTools.UnitTesting; + + [TestClass] + public class ConstantEvaluatorTests + { + [TestMethod] + public void ClosuresInsideMemberInitExpressionAreFolded() + { + int captured = 1; + Expression> expression = x => new TestClass { Property = x + captured }; + + Expression folded = ConstantEvaluator.PartialEval(expression.Body); + + MemberInitExpression memberInit = AssertMemberInit(folded, typeof(TestClass)); + MemberAssignment assignment = AssertSingleMemberAssignment(memberInit, nameof(TestClass.Property)); + BinaryExpression binary = AssertBinary(assignment.Expression, ExpressionType.Add); + AssertParameter(binary.Left, "x"); + AssertConstant(binary.Right, 1); + } + + [TestMethod] + public void ClosuresInsideAnonymousObjectAreFolded() + { + int captured = 1; + Expression> expression = x => new { Property = x + captured }; + + Expression folded = ConstantEvaluator.PartialEval(expression.Body); + + NewExpression newExpression = AssertNew(folded); + Assert.AreEqual(1, newExpression.Arguments.Count); + BinaryExpression binary = AssertBinary(newExpression.Arguments[0], ExpressionType.Add); + AssertParameter(binary.Left, "x"); + AssertConstant(binary.Right, 1); + } + + // Regression test for https://github.com/Azure/azure-cosmos-dotnet-v3/issues/1664. + // Mirrors the original bug report shape: a dictionary indexer keyed by a closure-captured + // variable, used inside a class member initializer. Before the fix this folded to + // `{"k": "Test"}["k"]` instead of `"Test"`, producing invalid SQL at the Cosmos backend. + // The parameter reference `q.StringProperty` anchors the MemberInit so the Nominator + // cannot collapse the entire expression to a single constant — only the closure-only + // dictionary indexer sub-tree should fold. + [TestMethod] + public void ClosuresUsedAsDictionaryIndexerInsideMemberInitAreFolded() + { + Dictionary map = new Dictionary { ["k"] = "Test" }; + string capturedKey = "k"; + Expression> expression = + q => new TestClass { StringProperty = q.StringProperty + map[capturedKey] }; + + Expression folded = ConstantEvaluator.PartialEval(expression.Body); + + MemberInitExpression memberInit = AssertMemberInit(folded, typeof(TestClass)); + MemberAssignment assignment = AssertSingleMemberAssignment(memberInit, nameof(TestClass.StringProperty)); + BinaryExpression binary = AssertBinary(assignment.Expression, ExpressionType.Add); + // Left side stays as `q.StringProperty` (parameter-bound member access). + MemberExpression leftMember = binary.Left as MemberExpression; + Assert.IsNotNull(leftMember, $"Expected MemberExpression on left of Add but got {binary.Left?.NodeType.ToString() ?? ""}."); + AssertParameter(leftMember.Expression, "q"); + Assert.AreEqual(nameof(TestClass.StringProperty), leftMember.Member.Name); + // Right side must be the folded literal — not a `Dictionary[indexer]` expression. + AssertConstant(binary.Right, "Test"); + } + + [TestMethod] + public void ClosuresInsideNestedMemberInitAreFolded() + { + int captured = 7; + Expression> expression = x => new OuterTestClass + { + Inner = new TestClass { Property = x + captured } + }; + + Expression folded = ConstantEvaluator.PartialEval(expression.Body); + + MemberInitExpression outerInit = AssertMemberInit(folded, typeof(OuterTestClass)); + MemberAssignment innerAssignment = AssertSingleMemberAssignment(outerInit, nameof(OuterTestClass.Inner)); + MemberInitExpression innerInit = AssertMemberInit(innerAssignment.Expression, typeof(TestClass)); + MemberAssignment propertyAssignment = AssertSingleMemberAssignment(innerInit, nameof(TestClass.Property)); + BinaryExpression binary = AssertBinary(propertyAssignment.Expression, ExpressionType.Add); + AssertParameter(binary.Left, "x"); + AssertConstant(binary.Right, 7); + } + + [TestMethod] + public void ClosuresInsideMemberMemberBindingAreFolded() + { + int captured = 9; + Expression> expression = x => new OuterTestClass + { + Inner = { Property = x + captured } + }; + + Expression folded = ConstantEvaluator.PartialEval(expression.Body); + + MemberInitExpression outerInit = AssertMemberInit(folded, typeof(OuterTestClass)); + Assert.AreEqual(1, outerInit.Bindings.Count); + MemberMemberBinding nested = outerInit.Bindings[0] as MemberMemberBinding; + Assert.IsNotNull(nested, "Expected a MemberMemberBinding for nested initializer syntax."); + Assert.AreEqual(nameof(OuterTestClass.Inner), nested.Member.Name); + Assert.AreEqual(1, nested.Bindings.Count); + MemberAssignment propertyAssignment = nested.Bindings[0] as MemberAssignment; + Assert.IsNotNull(propertyAssignment, "Expected MemberAssignment inside MemberMemberBinding."); + Assert.AreEqual(nameof(TestClass.Property), propertyAssignment.Member.Name); + BinaryExpression binary = AssertBinary(propertyAssignment.Expression, ExpressionType.Add); + AssertParameter(binary.Left, "x"); + AssertConstant(binary.Right, 9); + } + + [TestMethod] + public void ClosuresInsideMemberListBindingAreFolded() + { + int captured = 11; + Expression> expression = x => new OuterWithListTestClass + { + Items = { x + captured } + }; + + Expression folded = ConstantEvaluator.PartialEval(expression.Body); + + MemberInitExpression outerInit = AssertMemberInit(folded, typeof(OuterWithListTestClass)); + Assert.AreEqual(1, outerInit.Bindings.Count); + MemberListBinding listBinding = outerInit.Bindings[0] as MemberListBinding; + Assert.IsNotNull(listBinding, "Expected a MemberListBinding for collection initializer syntax."); + Assert.AreEqual(nameof(OuterWithListTestClass.Items), listBinding.Member.Name); + Assert.AreEqual(1, listBinding.Initializers.Count); + Assert.AreEqual(1, listBinding.Initializers[0].Arguments.Count); + BinaryExpression binary = AssertBinary(listBinding.Initializers[0].Arguments[0], ExpressionType.Add); + AssertParameter(binary.Left, "x"); + AssertConstant(binary.Right, 11); + } + + private static MemberInitExpression AssertMemberInit(Expression expression, Type expectedType) + { + MemberInitExpression memberInit = expression as MemberInitExpression; + Assert.IsNotNull(memberInit, $"Expected MemberInitExpression but got {expression?.NodeType.ToString() ?? ""}."); + Assert.AreEqual(expectedType, memberInit.Type); + return memberInit; + } + + private static NewExpression AssertNew(Expression expression) + { + NewExpression newExpression = expression as NewExpression; + Assert.IsNotNull(newExpression, $"Expected NewExpression but got {expression?.NodeType.ToString() ?? ""}."); + return newExpression; + } + + private static MemberAssignment AssertSingleMemberAssignment(MemberInitExpression memberInit, string memberName) + { + Assert.AreEqual(1, memberInit.Bindings.Count, $"Expected a single binding for member '{memberName}'."); + MemberAssignment assignment = memberInit.Bindings[0] as MemberAssignment; + Assert.IsNotNull(assignment, $"Expected MemberAssignment for member '{memberName}' but got {memberInit.Bindings[0].BindingType}."); + Assert.AreEqual(memberName, assignment.Member.Name); + return assignment; + } + + private static BinaryExpression AssertBinary(Expression expression, ExpressionType nodeType) + { + BinaryExpression binary = expression as BinaryExpression; + Assert.IsNotNull(binary, $"Expected BinaryExpression but got {expression?.NodeType.ToString() ?? ""}."); + Assert.AreEqual(nodeType, binary.NodeType); + return binary; + } + + private static void AssertParameter(Expression expression, string parameterName) + { + ParameterExpression parameter = expression as ParameterExpression; + Assert.IsNotNull(parameter, $"Expected ParameterExpression '{parameterName}' but got {expression?.NodeType.ToString() ?? ""}."); + Assert.AreEqual(parameterName, parameter.Name); + } + + private static void AssertConstant(Expression expression, T expectedValue) + { + ConstantExpression constant = expression as ConstantExpression; + Assert.IsNotNull(constant, $"Expected ConstantExpression with value '{expectedValue}' but got {expression?.NodeType.ToString() ?? ""}."); + Assert.AreEqual(expectedValue, constant.Value); + } + + private class TestClass + { + public int Property { get; set; } + + public string StringProperty { get; set; } + } + + private class OuterTestClass + { + public TestClass Inner { get; set; } = new TestClass(); + } + + private class OuterWithListTestClass + { + public List Items { get; } = new List(); + } + } +} diff --git a/changelog.md b/changelog.md index 9761217a21..a008b800a7 100644 --- a/changelog.md +++ b/changelog.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 #### Bugs Fixed +- [5298](https://github.com/Azure/azure-cosmos-dotnet-v3/pull/5298) LINQ: Fixes constant folding for closure-captured variables inside MemberInitExpression (resolves #1664). Previously, the recursion that partially evaluates expressions terminated whenever it encountered a `MemberInitExpression` node, so captured variables inside object initializers were not folded, producing invalid translated SQL. + #### Other Changes ### [3.61.0-preview.0](https://www.nuget.org/packages/Microsoft.Azure.Cosmos/3.61.0-preview.0) - 2026-5-18