diff --git a/src/LinqTests/Bugs/Bug_4169_compiled_query_contains_startswith_endswith_escape.cs b/src/LinqTests/Bugs/Bug_4169_compiled_query_contains_startswith_endswith_escape.cs new file mode 100644 index 0000000000..c53267fc80 --- /dev/null +++ b/src/LinqTests/Bugs/Bug_4169_compiled_query_contains_startswith_endswith_escape.cs @@ -0,0 +1,93 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading.Tasks; +using Marten; +using Marten.Linq; +using Marten.Testing.Harness; +using Shouldly; + +namespace LinqTests.Bugs; + +public class Bug_4169_compiled_query_contains_startswith_endswith_escape : BugIntegrationContext +{ + public record WildcardDocument(Guid Id, string Name); + + // Each compiled query type must be unique to guarantee fresh code generation + + public class ContainsWithPercent : ICompiledListQuery + { + public string Name { get; set; } + + public ContainsWithPercent(string name) => Name = name; + + public Expression, IEnumerable>> QueryIs() + => q => q.Where(x => x.Name.Contains(Name)); + } + + public class StartsWithWithPercent : ICompiledListQuery + { + public string Name { get; set; } + + public StartsWithWithPercent(string name) => Name = name; + + public Expression, IEnumerable>> QueryIs() + => q => q.Where(x => x.Name.StartsWith(Name)); + } + + public class EndsWithWithPercent : ICompiledListQuery + { + public string Name { get; set; } + + public EndsWithWithPercent(string name) => Name = name; + + public Expression, IEnumerable>> QueryIs() + => q => q.Where(x => x.Name.EndsWith(Name)); + } + + [Fact] + public async Task compiled_contains_should_not_treat_percent_as_wildcard() + { + var match = new WildcardDocument(Guid.NewGuid(), "100% Complete"); + var noMatch = new WildcardDocument(Guid.NewGuid(), "100 Complete"); + theSession.Store(match); + theSession.Store(noMatch); + await theSession.SaveChangesAsync(); + + var results = (await theSession.QueryAsync(new ContainsWithPercent("100%"))).ToList(); + + results.Count.ShouldBe(1); + results[0].Id.ShouldBe(match.Id); + } + + [Fact] + public async Task compiled_starts_with_should_not_treat_percent_as_wildcard() + { + var match = new WildcardDocument(Guid.NewGuid(), "100% of target"); + var noMatch = new WildcardDocument(Guid.NewGuid(), "100 of target"); + theSession.Store(match); + theSession.Store(noMatch); + await theSession.SaveChangesAsync(); + + var results = (await theSession.QueryAsync(new StartsWithWithPercent("100%"))).ToList(); + + results.Count.ShouldBe(1); + results[0].Id.ShouldBe(match.Id); + } + + [Fact] + public async Task compiled_ends_with_should_not_treat_percent_as_wildcard() + { + var match = new WildcardDocument(Guid.NewGuid(), "score: 100%"); + var noMatch = new WildcardDocument(Guid.NewGuid(), "score: 100"); + theSession.Store(match); + theSession.Store(noMatch); + await theSession.SaveChangesAsync(); + + var results = (await theSession.QueryAsync(new EndsWithWithPercent("100%"))).ToList(); + + results.Count.ShouldBe(1); + results[0].Id.ShouldBe(match.Id); + } +} diff --git a/src/LinqTests/Bugs/Bug_4169_compiled_query_ilike_escape.cs b/src/LinqTests/Bugs/Bug_4169_compiled_query_ilike_escape.cs new file mode 100644 index 0000000000..a297efa208 --- /dev/null +++ b/src/LinqTests/Bugs/Bug_4169_compiled_query_ilike_escape.cs @@ -0,0 +1,103 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading.Tasks; +using Marten; +using Marten.Linq; +using Marten.Testing.Harness; +using Shouldly; + +namespace LinqTests.Bugs; + +public class Bug_4169_compiled_query_ilike_escape : BugIntegrationContext +{ + public record IlikeEscapeDocument(Guid Id, string DisplayName); + + // Separate compiled query types per test to guarantee each triggers fresh code generation + // with the specific special character value + + public class FindByDisplayNameWithPercent : ICompiledListQuery + { + public string DisplayName { get; set; } + + public FindByDisplayNameWithPercent(string displayName) + { + DisplayName = displayName; + } + + public Expression, IEnumerable>> QueryIs() + { + return q => q.Where(x => x.DisplayName.Equals(DisplayName, StringComparison.InvariantCultureIgnoreCase)); + } + } + + public class FindByDisplayNameWithUnderscore : ICompiledListQuery + { + public string DisplayName { get; set; } + + public FindByDisplayNameWithUnderscore(string displayName) + { + DisplayName = displayName; + } + + public Expression, IEnumerable>> QueryIs() + { + return q => q.Where(x => x.DisplayName.Equals(DisplayName, StringComparison.InvariantCultureIgnoreCase)); + } + } + + public class FindByDisplayNameWithBackslash : ICompiledListQuery + { + public string DisplayName { get; set; } + + public FindByDisplayNameWithBackslash(string displayName) + { + DisplayName = displayName; + } + + public Expression, IEnumerable>> QueryIs() + { + return q => q.Where(x => x.DisplayName.Equals(DisplayName, StringComparison.InvariantCultureIgnoreCase)); + } + } + + [Fact] + public async Task compiled_query_with_percentage_in_equals_ignore_case() + { + var doc = new IlikeEscapeDocument(Guid.NewGuid(), "100% Complete"); + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var results = (await theSession.QueryAsync(new FindByDisplayNameWithPercent("100% Complete"))).ToList(); + + results.Count.ShouldBe(1); + results[0].Id.ShouldBe(doc.Id); + } + + [Fact] + public async Task compiled_query_with_underscore_in_equals_ignore_case() + { + var doc = new IlikeEscapeDocument(Guid.NewGuid(), "hello_world"); + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var results = (await theSession.QueryAsync(new FindByDisplayNameWithUnderscore("hello_world"))).ToList(); + + results.Count.ShouldBe(1); + results[0].Id.ShouldBe(doc.Id); + } + + [Fact] + public async Task compiled_query_with_backslash_in_equals_ignore_case() + { + var doc = new IlikeEscapeDocument(Guid.NewGuid(), @"path\to\file"); + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var results = (await theSession.QueryAsync(new FindByDisplayNameWithBackslash(@"path\to\file"))).ToList(); + + results.Count.ShouldBe(1); + results[0].Id.ShouldBe(doc.Id); + } +} diff --git a/src/LinqTests/Bugs/Bug_4169_compiled_query_startswith_endswith_swap.cs b/src/LinqTests/Bugs/Bug_4169_compiled_query_startswith_endswith_swap.cs new file mode 100644 index 0000000000..a57d06fda3 --- /dev/null +++ b/src/LinqTests/Bugs/Bug_4169_compiled_query_startswith_endswith_swap.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading.Tasks; +using Marten; +using Marten.Linq; +using Marten.Testing.Harness; +using Shouldly; + +namespace LinqTests.Bugs; + +public class Bug_4169_compiled_query_startswith_endswith_swap : BugIntegrationContext +{ + public record SwapDocument(Guid Id, string Name); + + // Select a scalar value to force StatelessCompiledQuery (no document selector) + + public class StartsWithProjection : ICompiledListQuery + { + public string Prefix { get; set; } + + public StartsWithProjection(string prefix) => Prefix = prefix; + + public Expression, IEnumerable>> QueryIs() + => q => q.Where(x => x.Name.StartsWith(Prefix)).Select(x => x.Name); + } + + public class EndsWithProjection : ICompiledListQuery + { + public string Suffix { get; set; } + + public EndsWithProjection(string suffix) => Suffix = suffix; + + public Expression, IEnumerable>> QueryIs() + => q => q.Where(x => x.Name.EndsWith(Suffix)).Select(x => x.Name); + } + + [Fact] + public async Task compiled_starts_with_projection_should_match_prefix_not_suffix() + { + var match = new SwapDocument(Guid.NewGuid(), "hello world"); + var noMatch = new SwapDocument(Guid.NewGuid(), "world hello"); + theSession.Store(match); + theSession.Store(noMatch); + await theSession.SaveChangesAsync(); + + var results = (await theSession.QueryAsync(new StartsWithProjection("hello"))).ToList(); + + results.Count.ShouldBe(1); + results[0].ShouldBe("hello world"); + } + + [Fact] + public async Task compiled_ends_with_projection_should_match_suffix_not_prefix() + { + var match = new SwapDocument(Guid.NewGuid(), "world hello"); + var noMatch = new SwapDocument(Guid.NewGuid(), "hello world"); + theSession.Store(match); + theSession.Store(noMatch); + await theSession.SaveChangesAsync(); + + var results = (await theSession.QueryAsync(new EndsWithProjection("hello"))).ToList(); + + results.Count.ShouldBe(1); + results[0].ShouldBe("world hello"); + } +} diff --git a/src/Marten/Internal/CompiledQueries/ClonedCompiledQuery.cs b/src/Marten/Internal/CompiledQueries/ClonedCompiledQuery.cs index a30d28d6d8..dbc68b7855 100644 --- a/src/Marten/Internal/CompiledQueries/ClonedCompiledQuery.cs +++ b/src/Marten/Internal/CompiledQueries/ClonedCompiledQuery.cs @@ -42,16 +42,26 @@ public Task HandleAsync(DbDataReader reader, IMartenSession session, Cance protected string StartsWith(string value) { - return $"{value}%"; + return $"{EscapeLikeValue(value)}%"; } protected string ContainsString(string value) { - return $"%{value}%"; + return $"%{EscapeLikeValue(value)}%"; } protected string EndsWith(string value) { - return $"%{value}"; + return $"%{EscapeLikeValue(value)}"; + } + + protected string EqualsIgnoreCaseValue(string value) + { + return EscapeLikeValue(value); + } + + private static string EscapeLikeValue(string value) + { + return value.Replace("\\", "\\\\").Replace("%", "\\%").Replace("_", "\\_"); } } diff --git a/src/Marten/Internal/CompiledQueries/ComplexCompiledQuery.cs b/src/Marten/Internal/CompiledQueries/ComplexCompiledQuery.cs index 3df4a3ac38..d3d761842f 100644 --- a/src/Marten/Internal/CompiledQueries/ComplexCompiledQuery.cs +++ b/src/Marten/Internal/CompiledQueries/ComplexCompiledQuery.cs @@ -41,16 +41,26 @@ public Task HandleAsync(DbDataReader reader, IMartenSession session, Cance protected string StartsWith(string value) { - return $"%{value}"; + return $"{EscapeLikeValue(value)}%"; } protected string ContainsString(string value) { - return $"%{value}%"; + return $"%{EscapeLikeValue(value)}%"; } protected string EndsWith(string value) { - return $"{value}%"; + return $"%{EscapeLikeValue(value)}"; + } + + protected string EqualsIgnoreCaseValue(string value) + { + return EscapeLikeValue(value); + } + + private static string EscapeLikeValue(string value) + { + return value.Replace("\\", "\\\\").Replace("%", "\\%").Replace("_", "\\_"); } } diff --git a/src/Marten/Internal/CompiledQueries/StatelessCompiledQuery.cs b/src/Marten/Internal/CompiledQueries/StatelessCompiledQuery.cs index 308ea3377d..c65aa0ee51 100644 --- a/src/Marten/Internal/CompiledQueries/StatelessCompiledQuery.cs +++ b/src/Marten/Internal/CompiledQueries/StatelessCompiledQuery.cs @@ -37,16 +37,26 @@ public Task HandleAsync(DbDataReader reader, IMartenSession session, Cance protected string StartsWith(string value) { - return $"%{value}"; + return $"{EscapeLikeValue(value)}%"; } protected string ContainsString(string value) { - return $"%{value}%"; + return $"%{EscapeLikeValue(value)}%"; } protected string EndsWith(string value) { - return $"{value}%"; + return $"%{EscapeLikeValue(value)}"; + } + + protected string EqualsIgnoreCaseValue(string value) + { + return EscapeLikeValue(value); + } + + private static string EscapeLikeValue(string value) + { + return value.Replace("\\", "\\\\").Replace("%", "\\%").Replace("_", "\\_"); } } diff --git a/src/Marten/Linq/Parsing/Methods/Strings/StringEquals.cs b/src/Marten/Linq/Parsing/Methods/Strings/StringEquals.cs index f7584ebf4f..ec9fe8e6d4 100644 --- a/src/Marten/Linq/Parsing/Methods/Strings/StringEquals.cs +++ b/src/Marten/Linq/Parsing/Methods/Strings/StringEquals.cs @@ -1,8 +1,12 @@ #nullable enable using System; +using System.Reflection; +using JasperFx.CodeGeneration; using JasperFx.Core.Reflection; +using Marten.Internal.CompiledQueries; using Marten.Linq.Members; using Marten.Linq.SqlGeneration.Filters; +using NpgsqlTypes; using Weasel.Postgresql; using Weasel.Postgresql.SqlGeneration; @@ -26,22 +30,48 @@ protected override ISqlFragment buildFilter(bool caseInsensitive, IQueryableMemb } } -internal class StringEqualsIgnoreCaseFilter : ISqlFragment +internal class StringEqualsIgnoreCaseFilter : ISqlFragment, ICompiledQueryAwareFilter { public IQueryableMember Member { get; } public CommandParameter Value { get; } + private readonly string _rawValue; + private MemberInfo? _queryMember; public StringEqualsIgnoreCaseFilter(IQueryableMember member, CommandParameter value) { Member = member; - Value = new CommandParameter(StringComparisonParser.EscapeValue(value.Value?.ToString() ?? string.Empty)); + _rawValue = value.Value as string ?? string.Empty; + Value = new CommandParameter(StringComparisonParser.EscapeValue(_rawValue)); } public void Apply(ICommandBuilder builder) { builder.Append(Member.RawLocator); builder.Append(StringComparisonParser.CaseInSensitiveLike); - Value.Apply(builder); + builder.AppendParameter(StringComparisonParser.EscapeValue(_rawValue)); + ParameterName = builder.LastParameterName; } + public bool TryMatchValue(object value, MemberInfo member) + { + if (_rawValue.Equals(value)) + { + _queryMember = member; + return true; + } + + return false; + } + + public void GenerateCode(GeneratedMethod method, int parameterIndex, string parametersVariableName) + { + var maskedValue = $"EqualsIgnoreCaseValue(_query.{_queryMember!.Name})"; + + method.Frames.Code($@" +{parametersVariableName}[{parameterIndex}].NpgsqlDbType = {{0}}; +{parametersVariableName}[{parameterIndex}].Value = {maskedValue}; +", NpgsqlDbType.Varchar); + } + + public string? ParameterName { get; private set; } }