diff --git a/src/LinqTests/Acceptance/nullable_types.cs b/src/LinqTests/Acceptance/nullable_types.cs index bcd8bb5339..dc0965697f 100644 --- a/src/LinqTests/Acceptance/nullable_types.cs +++ b/src/LinqTests/Acceptance/nullable_types.cs @@ -1,6 +1,7 @@ using System; using System.Linq; using System.Threading.Tasks; +using Marten; using Marten.Testing.Documents; using Marten.Testing.Harness; using Shouldly; @@ -70,6 +71,50 @@ public async Task query_against_null_3() .ShouldBe(1); } + [Fact] + public async Task query_against_nullable_bool_not_true() + { + var target1 = new Target { NullableBoolean = null }; + theSession.Store(target1); + + var target2 = new Target { NullableBoolean = true }; + theSession.Store(target2); + + var target3 = new Target { NullableBoolean = false }; + theSession.Store(target3); + + await theSession.SaveChangesAsync(); + + theSession.Logger = new TestOutputMartenLogger(_output); + + var list = await theSession.Query().Where(x => x.NullableBoolean != true).ToListAsync(); + list.Count.ShouldBe(2); + list.Any(x => x.Id == target1.Id).ShouldBeTrue(); + list.Any(x => x.Id == target3.Id).ShouldBeTrue(); + } + + [Fact] + public async Task query_against_nullable_bool_not_false() + { + var target1 = new Target { NullableBoolean = null }; + theSession.Store(target1); + + var target2 = new Target { NullableBoolean = true }; + theSession.Store(target2); + + var target3 = new Target { NullableBoolean = false }; + theSession.Store(target3); + + await theSession.SaveChangesAsync(); + + theSession.Logger = new TestOutputMartenLogger(_output); + + var list = await theSession.Query().Where(x => x.NullableBoolean != false).ToListAsync(); + list.Count.ShouldBe(2); + list.Any(x => x.Id == target1.Id).ShouldBeTrue(); + list.Any(x => x.Id == target2.Id).ShouldBeTrue(); + } + [Fact] public async Task query_against_null_4() { diff --git a/src/Marten/Linq/Members/BooleanMember.cs b/src/Marten/Linq/Members/BooleanMember.cs index 339d230fcd..8253162f0d 100644 --- a/src/Marten/Linq/Members/BooleanMember.cs +++ b/src/Marten/Linq/Members/BooleanMember.cs @@ -1,5 +1,8 @@ #nullable enable +using System; +using System.Linq.Expressions; using System.Reflection; +using JasperFx.Core.Reflection; using Marten.Linq.SqlGeneration.Filters; using Weasel.Postgresql.SqlGeneration; @@ -7,14 +10,40 @@ namespace Marten.Linq.Members; internal class BooleanMember: QueryableMember, IComparableMember, IBooleanMember { + private readonly bool _isNullable; + public BooleanMember(IQueryableMember parent, Casing casing, MemberInfo member, string pgType): base(parent, casing, member) { TypedLocator = $"CAST({RawLocator} as {pgType})"; + + _isNullable = member.GetRawMemberType().IsNullable(); } public ISqlFragment BuildIsTrueFragment() { return new BooleanFieldIsTrue(this); } + + public override ISqlFragment CreateComparison(string op, ConstantExpression constant) + { + if (constant.Value == null) + { + return op == "=" ? new IsNullFilter(this) : new IsNotNullFilter(this); + } + + if (_isNullable && op == "!=") + { + if (constant.Value.Equals(true)) + { + return CompoundWhereFragment.Or(new IsNullFilter(this), base.CreateComparison(op, constant)); + } + else + { + return CompoundWhereFragment.Or(new IsNullFilter(this), base.CreateComparison(op, constant)); + } + } + + return base.CreateComparison(op, constant); + } }