Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading.Tasks;
using Marten;
using Marten.Linq;
using Marten.Testing.Documents;
using Marten.Testing.Harness;
using Shouldly;

namespace LinqTests.Bugs;

public class compiled_query_problem_with_search_by_string_and_string_collections(DefaultStoreFixture fixture): IntegrationContext(fixture)
{
protected override async Task fixtureSetup()
{
await theStore.Advanced.ResetAllData();
}

public class IssuesByTitles: ICompiledListQuery<Issue>, IQueryPlanning
{
public required string[] Titles { get; set; }
public required string Status { get; set; }

public Expression<Func<IMartenQueryable<Issue>, IEnumerable<Issue>>> QueryIs()
{
return query => query.Where(x => x.Status == Status && x.Title.IsOneOf(Titles));
}
void IQueryPlanning.SetUniqueValuesForQueryPlanning()
{
Status = "status";
Titles = ["title"];
}
}

[Fact]
public async Task can_search_isOneOf_strings_with_compiled_queries_and_query_planning()
{
var issue1 = new Issue { Title = "Issue1", Status = "Open" };
var issue2 = new Issue { Title = "Issue2", Status = "Open"};
var issue3 = new Issue { Title = "Issue3", Status = "Open" };

theSession.Store(issue1, issue2, issue3);
await theSession.SaveChangesAsync();

await using var session = theStore.QuerySession();
var query = new IssuesByTitles { Titles = [issue1.Title, issue2.Title], Status = issue1.Status };
var issues = await session.QueryAsync(query);

issues.Count().ShouldBe(2);
}
}
76 changes: 76 additions & 0 deletions src/Marten/Internal/CompiledQueries/ArrayParameterFinder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace Marten.Internal.CompiledQueries;

/// <summary>
/// Parameter finder for array-typed query members (string[], Guid[], int[], etc.)
/// used in compiled queries with operators like IsOneOf().
/// </summary>
internal class ArrayParameterFinder<TElement> : IParameterFinder
{
private readonly Func<int, TElement[]> _uniqueElementValues;

public ArrayParameterFinder(Func<int, TElement[]> uniqueElementValues)
{
_uniqueElementValues = uniqueElementValues;
}

public Type DotNetType => typeof(TElement[]);

public Queue<object> UniqueValueQueue(Type type)
{
// Each unique value is itself a TElement[] with unique content
var queue = new Queue<object>();
for (var i = 0; i < 20; i++)
{
queue.Enqueue(_uniqueElementValues(i + 1));
}
return queue;
}

public bool Matches(Type memberType)
{
return memberType == typeof(TElement[]);
}

public bool AreValuesUnique(object query, CompiledQueryPlan plan)
{
var members = plan.QueryMembers.OfType<IQueryMember<TElement[]>>().ToArray();

if (members.Length == 0)
{
return true;
}

// For arrays, check that each member has a distinct array (by reference or content)
return members.Select(x => x.GetValue(query))
.Distinct(new ArrayContentComparer<TElement>())
.Count() == members.Length;
}
}

internal class ArrayContentComparer<T> : IEqualityComparer<T[]?>
{
public bool Equals(T[]? x, T[]? y)
{
if (ReferenceEquals(x, y)) return true;
if (x == null || y == null) return false;
return x.SequenceEqual(y);
}

public int GetHashCode(T[]? obj)
{
if (obj == null) return 0;
unchecked
{
var hash = 17;
foreach (var item in obj)
{
hash = hash * 31 + (item?.GetHashCode() ?? 0);
}
return hash;
}
}
}
16 changes: 16 additions & 0 deletions src/Marten/Internal/CompiledQueries/CompiledQueryPlan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ private void sortMembers()
{
IncludeMembers.Add(member);
}
else if (memberType.IsArray && QueryCompiler.Finders.Any(x => x.Matches(memberType.GetElementType()!)))
{
// Arrays like string[], int[], Guid[] etc. whose element type has a registered
// parameter finder should be treated as query parameters, NOT as include members.
// This check must come before the IList<> check since arrays implement IList<T>.
if (member is PropertyInfo)
{
var queryMember = typeof(PropertyQueryMember<>).CloseAndBuildAs<IQueryMember>(member, memberType);
QueryMembers.Add(queryMember);
}
else if (member is FieldInfo)
{
var queryMember = typeof(FieldQueryMember<>).CloseAndBuildAs<IQueryMember>(member, memberType);
QueryMembers.Add(queryMember);
}
}
else if (memberType.Closes(typeof(IList<>)))
{
IncludeMembers.Add(member);
Expand Down
39 changes: 38 additions & 1 deletion src/Marten/Internal/CompiledQueries/ParameterUsage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,47 @@ private static string npgsqlDataTypeInCodeFor(NpgsqlParameter parameter)
private void generateSimpleCode(GeneratedMethod method, MemberInfo member, Type memberType,
string parametersVariableName)
{
method.Frames.Code($@"
// Array types like string[], Guid[], int[] need composite NpgsqlDbType (Array | ElementType)
// which can't be passed as a single enum value to the code generation template
if (memberType.IsArray)
{
var dbTypeCode = npgsqlArrayDbTypeCodeFor(memberType);
method.Frames.Code(
$"{parametersVariableName}[{Index}].NpgsqlDbType = {dbTypeCode};\n" +
$"{parametersVariableName}[{Index}].Value = _query.{member.Name};");
}
else
{
method.Frames.Code($@"
{parametersVariableName}[{Index}].NpgsqlDbType = {{0}};
{parametersVariableName}[{Index}].Value = _query.{member.Name};
", PostgresqlProvider.Instance.ToParameterType(memberType));
}
}

private static string npgsqlArrayDbTypeCodeFor(Type arrayType)
{
var elementType = arrayType.GetElementType()!;
var npgsqlTypeName = typeof(NpgsqlDbType).FullNameInCode();

if (elementType == typeof(string))
return $"{npgsqlTypeName}.{NpgsqlDbType.Array} | {npgsqlTypeName}.{NpgsqlDbType.Varchar}";
if (elementType == typeof(Guid))
return $"{npgsqlTypeName}.{NpgsqlDbType.Array} | {npgsqlTypeName}.{NpgsqlDbType.Uuid}";
if (elementType == typeof(int))
return $"{npgsqlTypeName}.{NpgsqlDbType.Array} | {npgsqlTypeName}.{NpgsqlDbType.Integer}";
if (elementType == typeof(long))
return $"{npgsqlTypeName}.{NpgsqlDbType.Array} | {npgsqlTypeName}.{NpgsqlDbType.Bigint}";
if (elementType == typeof(float))
return $"{npgsqlTypeName}.{NpgsqlDbType.Array} | {npgsqlTypeName}.{NpgsqlDbType.Real}";
if (elementType == typeof(decimal))
return $"{npgsqlTypeName}.{NpgsqlDbType.Array} | {npgsqlTypeName}.{NpgsqlDbType.Numeric}";
if (elementType == typeof(DateTime))
return $"{npgsqlTypeName}.{NpgsqlDbType.Array} | {npgsqlTypeName}.{NpgsqlDbType.Timestamp}";
if (elementType == typeof(DateTimeOffset))
return $"{npgsqlTypeName}.{NpgsqlDbType.Array} | {npgsqlTypeName}.{NpgsqlDbType.TimestampTz}";

throw new NotSupportedException($"Array type {arrayType.FullNameInCode()} is not supported for compiled query parameters");
}

private void generateEnumCode(GeneratedMethod method, StoreOptions storeOptions, MemberInfo member,
Expand Down
32 changes: 32 additions & 0 deletions src/Marten/Internal/CompiledQueries/QueryCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,32 @@ static QueryCompiler()

return values;
});

// Register array-type finders for compiled query members used with IsOneOf() etc.
forArrayType<string>(count =>
{
var values = new string[count];
for (var i = 0; i < count; i++) values[i] = $"_plan_{Guid.NewGuid():N}";
return values;
});
forArrayType<Guid>(count =>
{
var values = new Guid[count];
for (var i = 0; i < count; i++) values[i] = Guid.NewGuid();
return values;
});
forArrayType<int>(count =>
{
var values = new int[count];
for (var i = 0; i < count; i++) values[i] = -(500000 + i);
return values;
});
forArrayType<long>(count =>
{
var values = new long[count];
for (var i = 0; i < count; i++) values[i] = -(600000L + i);
return values;
});
}

private static void forType<T>(Func<int, T[]> uniqueValues)
Expand All @@ -119,6 +145,12 @@ private static void forType<T>(Func<int, T[]> uniqueValues)
Finders.Add(finder);
}

private static void forArrayType<T>(Func<int, T[]> uniqueElementValues)
{
var finder = new ArrayParameterFinder<T>(uniqueElementValues);
Finders.Add(finder);
}

public static CompiledQueryPlan BuildQueryPlan(QuerySession session, Type queryType, StoreOptions storeOptions)
{
var querySignature = queryType.FindInterfaceThatCloses(typeof(ICompiledQuery<,>));
Expand Down
21 changes: 20 additions & 1 deletion src/Marten/Internal/CompiledQueries/QueryMember.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,29 @@ public void TryWriteValue(UniqueValueSource valueSource, object query)

public MemberInfo Member { get; }

private static bool valuesAreEqual(object value, object? parameterValue)
{
if (value.Equals(parameterValue)) return true;

// For array types (string[], Guid[], int[], etc.), Equals() does reference comparison.
// We need structural comparison to match array parameter values from compiled query planning.
if (value is Array valueArray && parameterValue is Array paramArray)
{
if (valueArray.Length != paramArray.Length) return false;
for (var i = 0; i < valueArray.Length; i++)
{
if (!Equals(valueArray.GetValue(i), paramArray.GetValue(i))) return false;
}
return true;
}

return false;
}

private bool tryToFind(NpgsqlParameter parameter, ICompiledQueryAwareFilter[] filters,
object value, out ICompiledQueryAwareFilter? filterUsed)
{
if (filters.All(x => x.ParameterName != parameter.ParameterName) && value.Equals(parameter.Value))
if (filters.All(x => x.ParameterName != parameter.ParameterName) && valuesAreEqual(value, parameter.Value))
{
filterUsed = null;
return true;
Expand Down
Loading