From d5fe6704ade455565d7086231ed028e1bbb39f0e Mon Sep 17 00:00:00 2001 From: "Jeremy D. Miller" Date: Wed, 11 Mar 2026 06:08:18 -0500 Subject: [PATCH] Add GroupBy LINQ operator, CheckExistsAsync, AssignTagWhere, and EventsExistByTags - Implement GroupBy LINQ operator with key selectors, aggregates (Count, Sum, Min, Max, Avg), composite keys, and HAVING clauses - Add CheckExistsAsync API with strong-typed ID support and batch query integration - Add EventsExistByTags for checking event existence by tag conditions in DCB scenarios - Add AssignTagWhere operation for retroactively tagging events matching a LINQ WHERE clause - Include tests and documentation for all new features Co-Authored-By: Claude Opus 4.6 --- docs/.vitepress/config.mts | 1 + docs/documents/querying/check-exists.md | 35 ++ docs/documents/querying/linq/operators.md | 55 +- docs/events/dcb.md | 11 + .../Reading/check_document_exists.cs | 195 +++++++ .../Dcb/assign_tag_where_tests.cs | 207 ++++++++ .../dcb_tag_query_and_consistency_tests.cs | 86 +++ src/LinqTests/Operators/group_by_operator.cs | 225 ++++++++ .../Events/Dcb/EventsExistByTagsHandler.cs | 130 +++++ src/Marten/Events/EventStore.Dcb.cs | 8 + src/Marten/Events/EventStore.cs | 44 ++ src/Marten/Events/IEventStoreOperations.cs | 15 + .../Operations/AssignTagWhereOperation.cs | 65 +++ src/Marten/IQuerySession.cs | 50 ++ .../Sessions/QuerySession.CheckExists.cs | 86 +++ .../Linq/CollectionUsage.Compilation.cs | 157 ++++++ src/Marten/Linq/CollectionUsage.cs | 1 + src/Marten/Linq/GroupByData.cs | 15 + .../Linq/Parsing/GroupBySelectParser.cs | 488 ++++++++++++++++++ .../Linq/Parsing/Operators/GroupByOperator.cs | 27 + .../Linq/Parsing/Operators/OperatorLibrary.cs | 1 + .../QueryHandlers/CheckExistsByIdHandler.cs | 68 +++ .../Linq/SqlGeneration/SelectorStatement.cs | 26 + .../BatchQuerying/BatchedQuery.Events.cs | 8 + .../Services/BatchQuerying/BatchedQuery.cs | 52 ++ .../Services/BatchQuerying/IBatchedQuery.cs | 53 ++ .../check_exists_with_strong_typed_ids.cs | 181 +++++++ 27 files changed, 2287 insertions(+), 3 deletions(-) create mode 100644 docs/documents/querying/check-exists.md create mode 100644 src/DocumentDbTests/Reading/check_document_exists.cs create mode 100644 src/EventSourcingTests/Dcb/assign_tag_where_tests.cs create mode 100644 src/LinqTests/Operators/group_by_operator.cs create mode 100644 src/Marten/Events/Dcb/EventsExistByTagsHandler.cs create mode 100644 src/Marten/Events/Operations/AssignTagWhereOperation.cs create mode 100644 src/Marten/Internal/Sessions/QuerySession.CheckExists.cs create mode 100644 src/Marten/Linq/GroupByData.cs create mode 100644 src/Marten/Linq/Parsing/GroupBySelectParser.cs create mode 100644 src/Marten/Linq/Parsing/Operators/GroupByOperator.cs create mode 100644 src/Marten/Linq/QueryHandlers/CheckExistsByIdHandler.cs create mode 100644 src/ValueTypeTests/StrongTypedId/check_exists_with_strong_typed_ids.cs diff --git a/docs/.vitepress/config.mts b/docs/.vitepress/config.mts index 3c960926ea..930d1b3948 100644 --- a/docs/.vitepress/config.mts +++ b/docs/.vitepress/config.mts @@ -116,6 +116,7 @@ const config: UserConfig = { { text: 'Querying Documents', link: '/documents/querying/', collapsed: true, items: [ { text: 'Loading Documents by Id', link: '/documents/querying/byid' }, + { text: 'Checking Document Existence', link: '/documents/querying/check-exists' }, { text: 'Querying Documents with Linq', link: '/documents/querying/linq/' }, { text: 'Supported Linq Operators', link: '/documents/querying/linq/operators' }, { text: 'Querying within Child Collections', link: '/documents/querying/linq/child-collections' }, diff --git a/docs/documents/querying/check-exists.md b/docs/documents/querying/check-exists.md new file mode 100644 index 0000000000..7ffb0df2a9 --- /dev/null +++ b/docs/documents/querying/check-exists.md @@ -0,0 +1,35 @@ +# Checking Document Existence + +Sometimes you only need to know whether a document with a given id exists in the database, without actually loading and deserializing the full document. Marten provides the `CheckExistsAsync` API for this purpose, which issues a lightweight `SELECT EXISTS(...)` query against PostgreSQL. This avoids the overhead of JSON deserialization and object materialization, making it significantly more efficient than loading the document just to check if it's there. + +## Usage + +`CheckExistsAsync` is available on `IQuerySession` (and therefore also on `IDocumentSession`). It supports all identity types: `Guid`, `int`, `long`, `string`, and strongly-typed identifiers. + + + + +## Supported Identity Types + +| Id Type | Supported | +|---------|-----------| +| `Guid` | Yes | +| `int` | Yes | +| `long` | Yes | +| `string` | Yes | +| `object` | Yes (for dynamic id types) | +| Strong-typed ids (Vogen, record structs, etc.) | Yes (via `object` overload) | + +## Batched Queries + +`CheckExists` is also available as part of [batched queries](/documents/querying/batched-queries), allowing you to check existence of multiple documents in a single round-trip to the database: + + + + +## Behavior Notes + +- Returns `true` if the document exists, `false` otherwise. +- Respects soft-delete filters: if a document type uses soft deletes, a soft-deleted document will return `false`. +- Respects multi-tenancy: the check is scoped to the current session's tenant. +- Does **not** load the document into the identity map or trigger any deserialization. diff --git a/docs/documents/querying/linq/operators.md b/docs/documents/querying/linq/operators.md index afdbd325c7..6ed99125df 100644 --- a/docs/documents/querying/linq/operators.md +++ b/docs/documents/querying/linq/operators.md @@ -217,10 +217,59 @@ public void using_take_and_skip(IDocumentSession session) TODO -- link to the paging support -## Grouping Operators +## GroupBy() -Sorry, but Marten does not yet support `GroupBy()`. You can track [this GitHub issue](https://github.com/JasperFx/marten/issues/569) to follow -any future work on this Linq operator. +Marten supports the `GroupBy()` LINQ operator for grouping documents by one or more keys and computing aggregate values. GroupBy translates to SQL `GROUP BY` with aggregate functions like `COUNT`, `SUM`, `MIN`, `MAX`, and `AVG`. + +### Simple Key with Aggregates + + + + +### Composite Key + +You can group by multiple properties using an anonymous type: + +```csharp +var results = await session.Query() + .GroupBy(x => new { x.Color, x.String }) + .Select(g => new { Color = g.Key.Color, Text = g.Key.String, Count = g.Count() }) + .ToListAsync(); +``` + +### Where Before GroupBy + +Filter documents before grouping with a standard `Where()` clause: + +```csharp +var results = await session.Query() + .Where(x => x.Number > 20) + .GroupBy(x => x.Color) + .Select(g => new { Color = g.Key, Count = g.Count() }) + .ToListAsync(); +``` + +### HAVING (Where After GroupBy) + +Filter groups with a `Where()` clause after `GroupBy()` -- this translates to SQL `HAVING`: + +```csharp +var results = await session.Query() + .GroupBy(x => x.Color) + .Where(g => g.Count() > 1) + .Select(g => new { Color = g.Key, Count = g.Count() }) + .ToListAsync(); +``` + +### Supported Aggregates + +The following aggregate methods are supported within GroupBy projections: + +- `g.Count()` / `g.LongCount()` -- `COUNT(*)` +- `g.Sum(x => x.Property)` -- `SUM(property)` +- `g.Min(x => x.Property)` -- `MIN(property)` +- `g.Max(x => x.Property)` -- `MAX(property)` +- `g.Average(x => x.Property)` -- `AVG(property)` ## Distinct() diff --git a/docs/events/dcb.md b/docs/events/dcb.md index 126002580f..11f35b6ebb 100644 --- a/docs/events/dcb.md +++ b/docs/events/dcb.md @@ -224,6 +224,17 @@ catch (DcbConcurrencyException ex) The consistency check only detects events that match the **same tag query**. Events appended to unrelated tags or streams will not cause a violation. ::: +## Checking Event Existence + +If you only need to know whether any events matching a tag query exist -- without loading or deserializing them -- use `EventsExistAsync`. This is a lightweight `SELECT EXISTS(...)` query that avoids the overhead of fetching and materializing event data: + + + + +This is useful for guard clauses and validation logic in DCB workflows where you need to check preconditions before appending new events. + +`EventsExistAsync` is also available in batch queries via `batch.Events.EventsExist(query)`. + ## How It Works ### Storage diff --git a/src/DocumentDbTests/Reading/check_document_exists.cs b/src/DocumentDbTests/Reading/check_document_exists.cs new file mode 100644 index 0000000000..536d2ee42b --- /dev/null +++ b/src/DocumentDbTests/Reading/check_document_exists.cs @@ -0,0 +1,195 @@ +using System; +using System.Threading.Tasks; +using Marten; +using Marten.Testing.Documents; +using Marten.Testing.Harness; +using Shouldly; +using Xunit; + +namespace DocumentDbTests.Reading; + +public class check_document_exists: IntegrationContext +{ + public check_document_exists(DefaultStoreFixture fixture): base(fixture) + { + } + + [Fact] + public async Task check_exists_by_guid_id_hit() + { + var doc = new GuidDoc { Id = Guid.NewGuid() }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var exists = await theSession.CheckExistsAsync(doc.Id); + exists.ShouldBeTrue(); + } + + [Fact] + public async Task check_exists_by_guid_id_miss() + { + var exists = await theSession.CheckExistsAsync(Guid.NewGuid()); + exists.ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_by_int_id_hit() + { + var doc = new IntDoc { Id = 42 }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var exists = await theSession.CheckExistsAsync(42); + exists.ShouldBeTrue(); + } + + [Fact] + public async Task check_exists_by_int_id_miss() + { + var exists = await theSession.CheckExistsAsync(999999); + exists.ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_by_long_id_hit() + { + var doc = new LongDoc { Id = 200L }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var exists = await theSession.CheckExistsAsync(200L); + exists.ShouldBeTrue(); + } + + [Fact] + public async Task check_exists_by_long_id_miss() + { + var exists = await theSession.CheckExistsAsync(999999L); + exists.ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_by_string_id_hit() + { + var doc = new StringDoc { Id = "test-doc" }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var exists = await theSession.CheckExistsAsync("test-doc"); + exists.ShouldBeTrue(); + } + + [Fact] + public async Task check_exists_by_string_id_miss() + { + var exists = await theSession.CheckExistsAsync("nonexistent"); + exists.ShouldBeFalse(); + } + + #region sample_check_exists_usage + + [Fact] + public async Task check_exists_by_object_id() + { + var doc = new GuidDoc { Id = Guid.NewGuid() }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + // Use the object overload for dynamic id types + var exists = await theSession.CheckExistsAsync((object)doc.Id); + exists.ShouldBeTrue(); + } + + #endregion +} + +public class check_document_exists_in_batch: IntegrationContext +{ + public check_document_exists_in_batch(DefaultStoreFixture fixture): base(fixture) + { + } + + #region sample_check_exists_batch_usage + + [Fact] + public async Task check_exists_in_batch_by_guid_id() + { + var doc = new GuidDoc { Id = Guid.NewGuid() }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists(doc.Id); + var existsMiss = batch.CheckExists(Guid.NewGuid()); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } + + #endregion + + [Fact] + public async Task check_exists_in_batch_by_int_id() + { + var doc = new IntDoc { Id = 77 }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists(77); + var existsMiss = batch.CheckExists(888888); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_in_batch_by_long_id() + { + var doc = new LongDoc { Id = 300L }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists(300L); + var existsMiss = batch.CheckExists(999999L); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_in_batch_by_string_id() + { + var doc = new StringDoc { Id = "batch-test" }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists("batch-test"); + var existsMiss = batch.CheckExists("nope"); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_in_batch_by_object_id() + { + var doc = new GuidDoc { Id = Guid.NewGuid() }; + theSession.Store(doc); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists((object)doc.Id); + var existsMiss = batch.CheckExists((object)Guid.NewGuid()); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } +} diff --git a/src/EventSourcingTests/Dcb/assign_tag_where_tests.cs b/src/EventSourcingTests/Dcb/assign_tag_where_tests.cs new file mode 100644 index 0000000000..167ecb5ec6 --- /dev/null +++ b/src/EventSourcingTests/Dcb/assign_tag_where_tests.cs @@ -0,0 +1,207 @@ +#nullable enable +using System; +using System.Linq; +using System.Threading.Tasks; +using JasperFx.Events; +using JasperFx.Events.Tags; +using Marten; +using Marten.Events; +using Marten.Testing.Harness; +using Shouldly; +using Xunit; + +namespace EventSourcingTests.Dcb; + +public record RegionId(Guid Value); + +public record OrderPlaced(string OrderNumber, decimal Amount); +public record OrderShipped(string OrderNumber); +public record OrderCancelled(string OrderNumber, string Reason); + +[Collection("OneOffs")] +public class assign_tag_where_tests : OneOffConfigurationsContext, IAsyncLifetime +{ + private RegionId _eastRegion = null!; + private RegionId _westRegion = null!; + + public Task InitializeAsync() + { + _eastRegion = new RegionId(Guid.NewGuid()); + _westRegion = new RegionId(Guid.NewGuid()); + + StoreOptions(opts => + { + opts.Events.AddEventType(); + opts.Events.AddEventType(); + opts.Events.AddEventType(); + + opts.Events.RegisterTagType("region"); + }); + + return Task.CompletedTask; + } + + public Task DisposeAsync() => Task.CompletedTask; + + [Fact] + public async Task assign_tag_where_by_event_type_name() + { + // Append events WITHOUT tags + var stream1 = Guid.NewGuid(); + theSession.Events.Append(stream1, + new OrderPlaced("ORD-1", 100m), + new OrderShipped("ORD-1")); + await theSession.SaveChangesAsync(); + + // Now retroactively tag all OrderPlaced events with a region + await using var session2 = theStore.LightweightSession(); + var orderPlacedTypeName = theStore.Options.EventGraph.EventMappingFor().EventTypeName; + session2.Events.AssignTagWhere( + e => e.EventTypeName == orderPlacedTypeName, + _eastRegion); + await session2.SaveChangesAsync(); + + // Query by tag - should find only the OrderPlaced event + await using var session3 = theStore.LightweightSession(); + var query = new EventTagQuery().Or(_eastRegion); + var events = await session3.Events.QueryByTagsAsync(query); + + events.Count.ShouldBe(1); + events[0].Data.ShouldBeOfType().OrderNumber.ShouldBe("ORD-1"); + } + + [Fact] + public async Task assign_tag_where_by_stream_id() + { + var stream1 = Guid.NewGuid(); + var stream2 = Guid.NewGuid(); + + theSession.Events.Append(stream1, + new OrderPlaced("ORD-1", 100m), + new OrderShipped("ORD-1")); + theSession.Events.Append(stream2, + new OrderPlaced("ORD-2", 200m)); + await theSession.SaveChangesAsync(); + + // Tag all events in stream1 only + await using var session2 = theStore.LightweightSession(); + session2.Events.AssignTagWhere( + e => e.StreamId == stream1, + _eastRegion); + await session2.SaveChangesAsync(); + + // Query - should find only the 2 events from stream1 + await using var session3 = theStore.LightweightSession(); + var query = new EventTagQuery().Or(_eastRegion); + var events = await session3.Events.QueryByTagsAsync(query); + + events.Count.ShouldBe(2); + events.ShouldAllBe(e => e.StreamId == stream1); + } + + [Fact] + public async Task assign_tag_where_with_compound_predicate() + { + var stream1 = Guid.NewGuid(); + + theSession.Events.Append(stream1, + new OrderPlaced("ORD-1", 100m), + new OrderShipped("ORD-1"), + new OrderCancelled("ORD-1", "changed mind")); + await theSession.SaveChangesAsync(); + + // Tag events that are of type OrderPlaced or OrderCancelled + await using var session2 = theStore.LightweightSession(); + var placedType = theStore.Options.EventGraph.EventMappingFor().EventTypeName; + var cancelledType = theStore.Options.EventGraph.EventMappingFor().EventTypeName; + + session2.Events.AssignTagWhere( + e => e.EventTypeName == placedType || e.EventTypeName == cancelledType, + _eastRegion); + await session2.SaveChangesAsync(); + + // Query - should find 2 events (placed + cancelled, NOT shipped) + await using var session3 = theStore.LightweightSession(); + var query = new EventTagQuery().Or(_eastRegion); + var events = await session3.Events.QueryByTagsAsync(query); + + events.Count.ShouldBe(2); + events.Select(e => e.Data.GetType()).ShouldContain(typeof(OrderPlaced)); + events.Select(e => e.Data.GetType()).ShouldContain(typeof(OrderCancelled)); + events.Select(e => e.Data.GetType()).ShouldNotContain(typeof(OrderShipped)); + } + + [Fact] + public async Task assign_tag_where_is_idempotent() + { + var stream1 = Guid.NewGuid(); + theSession.Events.Append(stream1, new OrderPlaced("ORD-1", 100m)); + await theSession.SaveChangesAsync(); + + var placedType = theStore.Options.EventGraph.EventMappingFor().EventTypeName; + + // Assign the same tag twice - should not fail or duplicate + await using var session2 = theStore.LightweightSession(); + session2.Events.AssignTagWhere( + e => e.EventTypeName == placedType, _eastRegion); + await session2.SaveChangesAsync(); + + await using var session3 = theStore.LightweightSession(); + session3.Events.AssignTagWhere( + e => e.EventTypeName == placedType, _eastRegion); + await session3.SaveChangesAsync(); + + // Should still just find 1 event + await using var session4 = theStore.LightweightSession(); + var query = new EventTagQuery().Or(_eastRegion); + var events = await session4.Events.QueryByTagsAsync(query); + events.Count.ShouldBe(1); + } + + [Fact] + public async Task assign_tag_where_does_not_affect_unmatched_events() + { + var stream1 = Guid.NewGuid(); + var stream2 = Guid.NewGuid(); + + theSession.Events.Append(stream1, new OrderPlaced("ORD-1", 100m)); + theSession.Events.Append(stream2, new OrderPlaced("ORD-2", 200m)); + await theSession.SaveChangesAsync(); + + // Only tag events in stream1 + await using var session2 = theStore.LightweightSession(); + session2.Events.AssignTagWhere( + e => e.StreamId == stream1, _eastRegion); + await session2.SaveChangesAsync(); + + // Tag events in stream2 with different region + await using var session3 = theStore.LightweightSession(); + session3.Events.AssignTagWhere( + e => e.StreamId == stream2, _westRegion); + await session3.SaveChangesAsync(); + + // Verify east only has stream1 + await using var session4 = theStore.LightweightSession(); + var eastEvents = await session4.Events.QueryByTagsAsync( + new EventTagQuery().Or(_eastRegion)); + eastEvents.Count.ShouldBe(1); + eastEvents[0].StreamId.ShouldBe(stream1); + + // Verify west only has stream2 + var westEvents = await session4.Events.QueryByTagsAsync( + new EventTagQuery().Or(_westRegion)); + westEvents.Count.ShouldBe(1); + westEvents[0].StreamId.ShouldBe(stream2); + } + + [Fact] + public async Task assign_tag_where_throws_for_unregistered_tag_type() + { + var unregisteredTag = new StudentId(Guid.NewGuid()); + + Should.Throw(() => + { + theSession.Events.AssignTagWhere(e => e.Sequence > 0, unregisteredTag); + }); + } +} diff --git a/src/EventSourcingTests/Dcb/dcb_tag_query_and_consistency_tests.cs b/src/EventSourcingTests/Dcb/dcb_tag_query_and_consistency_tests.cs index 8c5646ef8d..42ae985394 100644 --- a/src/EventSourcingTests/Dcb/dcb_tag_query_and_consistency_tests.cs +++ b/src/EventSourcingTests/Dcb/dcb_tag_query_and_consistency_tests.cs @@ -517,6 +517,92 @@ await Should.ThrowAsync(async () => }); } + #region sample_marten_dcb_events_exist_async + [Fact] + public async Task events_exist_returns_true_when_matching_events_found() + { + var studentId = new StudentId(Guid.NewGuid()); + var courseId = new CourseId(Guid.NewGuid()); + var streamId = Guid.NewGuid(); + + var enrolled = theSession.Events.BuildEvent(new StudentEnrolled("Alice", "Math")); + enrolled.WithTag(studentId, courseId); + theSession.Events.Append(streamId, enrolled); + await theSession.SaveChangesAsync(); + + // Check existence -- lightweight, no event loading + var query = new EventTagQuery().Or(studentId); + var exists = await theSession.Events.EventsExistAsync(query); + exists.ShouldBeTrue(); + } + #endregion + + [Fact] + public async Task events_exist_returns_false_when_no_matching_events() + { + var studentId = new StudentId(Guid.NewGuid()); + + var query = new EventTagQuery().Or(studentId); + var exists = await theSession.Events.EventsExistAsync(query); + exists.ShouldBeFalse(); + } + + [Fact] + public async Task events_exist_with_event_type_filter() + { + var studentId = new StudentId(Guid.NewGuid()); + var courseId = new CourseId(Guid.NewGuid()); + var streamId = Guid.NewGuid(); + + var enrolled = theSession.Events.BuildEvent(new StudentEnrolled("Alice", "Math")); + enrolled.WithTag(studentId, courseId); + theSession.Events.Append(streamId, enrolled); + await theSession.SaveChangesAsync(); + + // Should find StudentEnrolled + var query1 = new EventTagQuery().Or(studentId); + (await theSession.Events.EventsExistAsync(query1)).ShouldBeTrue(); + + // Should NOT find AssignmentSubmitted (none appended) + var query2 = new EventTagQuery().Or(studentId); + (await theSession.Events.EventsExistAsync(query2)).ShouldBeFalse(); + } + + [Fact] + public async Task events_exist_via_batch_query_positive() + { + var studentId = new StudentId(Guid.NewGuid()); + var courseId = new CourseId(Guid.NewGuid()); + var streamId = Guid.NewGuid(); + + var enrolled = theSession.Events.BuildEvent(new StudentEnrolled("Alice", "Math")); + enrolled.WithTag(studentId, courseId); + theSession.Events.Append(streamId, enrolled); + await theSession.SaveChangesAsync(); + + await using var session2 = theStore.LightweightSession(); + var batch = session2.CreateBatchQuery(); + var query = new EventTagQuery().Or(studentId); + var existsTask = batch.Events.EventsExist(query); + await batch.Execute(); + + (await existsTask).ShouldBeTrue(); + } + + [Fact] + public async Task events_exist_via_batch_query_negative() + { + var studentId = new StudentId(Guid.NewGuid()); + + await using var session2 = theStore.LightweightSession(); + var batch = session2.CreateBatchQuery(); + var query = new EventTagQuery().Or(studentId); + var existsTask = batch.Events.EventsExist(query); + await batch.Execute(); + + (await existsTask).ShouldBeFalse(); + } + [Fact] public async Task fetch_for_writing_by_tags_throws_on_empty_query() { diff --git a/src/LinqTests/Operators/group_by_operator.cs b/src/LinqTests/Operators/group_by_operator.cs new file mode 100644 index 0000000000..286f23e2b3 --- /dev/null +++ b/src/LinqTests/Operators/group_by_operator.cs @@ -0,0 +1,225 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Marten; +using Marten.Testing.Documents; +using Marten.Testing.Harness; +using Shouldly; + +namespace LinqTests.Operators; + +public class group_by_operator: OneOffConfigurationsContext +{ + private IDocumentStore _store; + private IDocumentSession _session; + + private async Task SetupTargetData() + { + _store = StoreOptions(opts => + { + opts.Schema.For(); + }); + + _session = _store.LightweightSession(); + _disposables.Add(_session); + + // Deterministic data for GroupBy tests + var targets = new[] + { + new Target { Color = Colors.Blue, Number = 10, String = "Alpha", Double = 1.5 }, + new Target { Color = Colors.Blue, Number = 20, String = "Alpha", Double = 2.5 }, + new Target { Color = Colors.Green, Number = 30, String = "Beta", Double = 3.5 }, + new Target { Color = Colors.Green, Number = 40, String = "Beta", Double = 4.5 }, + new Target { Color = Colors.Green, Number = 50, String = "Gamma", Double = 5.5 }, + new Target { Color = Colors.Red, Number = 60, String = "Gamma", Double = 6.5 }, + }; + + _session.Store(targets); + await _session.SaveChangesAsync(); + } + + #region sample_group_by_simple_key_with_count + + [Fact] + public async Task group_by_simple_key_with_count() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => x.Color) + .Select(g => new { Color = g.Key, Count = g.Count() }) + .ToListAsync(); + + results.Count.ShouldBe(3); + results.Single(x => x.Color == Colors.Blue).Count.ShouldBe(2); + results.Single(x => x.Color == Colors.Green).Count.ShouldBe(3); + results.Single(x => x.Color == Colors.Red).Count.ShouldBe(1); + } + + #endregion + + [Fact] + public async Task group_by_simple_key_with_sum() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => x.Color) + .Select(g => new { Color = g.Key, Total = g.Sum(x => x.Number) }) + .ToListAsync(); + + results.Count.ShouldBe(3); + results.Single(x => x.Color == Colors.Blue).Total.ShouldBe(30); + results.Single(x => x.Color == Colors.Green).Total.ShouldBe(120); + results.Single(x => x.Color == Colors.Red).Total.ShouldBe(60); + } + + [Fact] + public async Task group_by_simple_key_with_multiple_aggregates() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => x.Color) + .Select(g => new + { + Color = g.Key, + Count = g.Count(), + Total = g.Sum(x => x.Number), + Min = g.Min(x => x.Number), + Max = g.Max(x => x.Number) + }) + .ToListAsync(); + + results.Count.ShouldBe(3); + + var blue = results.Single(x => x.Color == Colors.Blue); + blue.Count.ShouldBe(2); + blue.Total.ShouldBe(30); + blue.Min.ShouldBe(10); + blue.Max.ShouldBe(20); + + var green = results.Single(x => x.Color == Colors.Green); + green.Count.ShouldBe(3); + green.Total.ShouldBe(120); + green.Min.ShouldBe(30); + green.Max.ShouldBe(50); + } + + [Fact] + public async Task group_by_string_key() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => x.String) + .Select(g => new { Key = g.Key, Count = g.Count() }) + .ToListAsync(); + + results.Count.ShouldBe(3); + results.Single(x => x.Key == "Alpha").Count.ShouldBe(2); + results.Single(x => x.Key == "Beta").Count.ShouldBe(2); + results.Single(x => x.Key == "Gamma").Count.ShouldBe(2); + } + + [Fact] + public async Task group_by_with_where_before_group() + { + await SetupTargetData(); + + var results = await _session.Query() + .Where(x => x.Number > 20) + .GroupBy(x => x.Color) + .Select(g => new { Color = g.Key, Count = g.Count() }) + .ToListAsync(); + + // Blue has 10, 20 -> both filtered out + // Green has 30, 40, 50 -> 3 pass + // Red has 60 -> 1 passes + results.Count.ShouldBe(2); + results.Single(x => x.Color == Colors.Green).Count.ShouldBe(3); + results.Single(x => x.Color == Colors.Red).Count.ShouldBe(1); + } + + [Fact] + public async Task group_by_with_having() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => x.Color) + .Where(g => g.Count() > 1) + .Select(g => new { Color = g.Key, Count = g.Count() }) + .ToListAsync(); + + // Blue=2, Green=3, Red=1 -> Red filtered by HAVING + results.Count.ShouldBe(2); + results.ShouldNotContain(x => x.Color == Colors.Red); + } + + [Fact] + public async Task group_by_composite_key() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => new { x.Color, x.String }) + .Select(g => new { Color = g.Key.Color, Text = g.Key.String, Count = g.Count() }) + .ToListAsync(); + + // Blue+Alpha=2, Green+Beta=2, Green+Gamma=1, Red+Gamma=1 + results.Count.ShouldBe(4); + results.Single(x => x.Color == Colors.Blue && x.Text == "Alpha").Count.ShouldBe(2); + results.Single(x => x.Color == Colors.Green && x.Text == "Beta").Count.ShouldBe(2); + results.Single(x => x.Color == Colors.Green && x.Text == "Gamma").Count.ShouldBe(1); + results.Single(x => x.Color == Colors.Red && x.Text == "Gamma").Count.ShouldBe(1); + } + + [Fact] + public async Task group_by_with_average() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => x.Color) + .Select(g => new { Color = g.Key, Avg = g.Average(x => x.Double) }) + .ToListAsync(); + + results.Count.ShouldBe(3); + results.Single(x => x.Color == Colors.Blue).Avg.ShouldBe(2.0, tolerance: 0.01); + results.Single(x => x.Color == Colors.Green).Avg.ShouldBe(4.5, tolerance: 0.01); + results.Single(x => x.Color == Colors.Red).Avg.ShouldBe(6.5, tolerance: 0.01); + } + + [Fact] + public async Task group_by_select_key_only() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => x.Color) + .Select(g => g.Key) + .ToListAsync(); + + results.Count.ShouldBe(3); + results.ShouldContain(Colors.Blue); + results.ShouldContain(Colors.Green); + results.ShouldContain(Colors.Red); + } + + [Fact] + public async Task group_by_with_long_count() + { + await SetupTargetData(); + + var results = await _session.Query() + .GroupBy(x => x.Color) + .Select(g => new { Color = g.Key, Count = g.LongCount() }) + .ToListAsync(); + + results.Count.ShouldBe(3); + results.Single(x => x.Color == Colors.Blue).Count.ShouldBe(2L); + results.Single(x => x.Color == Colors.Green).Count.ShouldBe(3L); + } +} diff --git a/src/Marten/Events/Dcb/EventsExistByTagsHandler.cs b/src/Marten/Events/Dcb/EventsExistByTagsHandler.cs new file mode 100644 index 0000000000..6d1f2cb922 --- /dev/null +++ b/src/Marten/Events/Dcb/EventsExistByTagsHandler.cs @@ -0,0 +1,130 @@ +#nullable enable +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using JasperFx.Events.Tags; +using Marten.Internal; +using Marten.Internal.Sessions; +using Marten.Linq.QueryHandlers; +using Weasel.Postgresql; + +namespace Marten.Events.Dcb; + +internal class EventsExistByTagsHandler: IQueryHandler +{ + private readonly DocumentStore _store; + private readonly EventTagQuery _query; + + public EventsExistByTagsHandler(DocumentStore store, EventTagQuery query) + { + _store = store; + _query = query; + } + + public void ConfigureCommand(ICommandBuilder builder, IMartenSession session) + { + var conditions = _query.Conditions; + if (conditions.Count == 0) + { + throw new ArgumentException("EventTagQuery must have at least one condition."); + } + + var distinctTagTypes = conditions.Select(c => c.TagType).Distinct().ToList(); + var schema = _store.Events.DatabaseSchemaName; + + builder.Append("select exists (select 1 from "); + + var first = true; + for (var i = 0; i < distinctTagTypes.Count; i++) + { + var tagType = distinctTagTypes[i]; + var registration = _store.Events.FindTagType(tagType) + ?? throw new InvalidOperationException( + $"Tag type '{tagType.Name}' is not registered. Call RegisterTagType<{tagType.Name}>() first."); + + var alias = $"t{i}"; + if (first) + { + builder.Append(schema); + builder.Append(".mt_event_tag_"); + builder.Append(registration.TableSuffix); + builder.Append(" "); + builder.Append(alias); + first = false; + } + else + { + builder.Append(" inner join "); + builder.Append(schema); + builder.Append(".mt_event_tag_"); + builder.Append(registration.TableSuffix); + builder.Append(" "); + builder.Append(alias); + builder.Append(" on t0.seq_id = "); + builder.Append(alias); + builder.Append(".seq_id"); + } + } + + // Join to mt_events only if we need event type filtering + var hasEventTypeFilter = conditions.Any(c => c.EventType != null); + if (hasEventTypeFilter) + { + builder.Append(" inner join "); + builder.Append(schema); + builder.Append(".mt_events e on t0.seq_id = e.seq_id"); + } + + builder.Append(" where ("); + for (var i = 0; i < conditions.Count; i++) + { + if (i > 0) + { + builder.Append(" or "); + } + + var condition = conditions[i]; + var tagIndex = distinctTagTypes.IndexOf(condition.TagType); + var alias = $"t{tagIndex}"; + + builder.Append("("); + builder.Append(alias); + builder.Append(".value = "); + + var registration = _store.Events.FindTagType(condition.TagType)!; + var value = registration.ExtractValue(condition.TagValue); + builder.AppendParameter(value); + + if (condition.EventType != null) + { + builder.Append(" and e.type = "); + var eventTypeName = _store.Events.EventMappingFor(condition.EventType).EventTypeName; + builder.AppendParameter(eventTypeName); + } + + builder.Append(")"); + } + + builder.Append(") limit 1)"); + } + + public bool Handle(DbDataReader reader, IMartenSession session) + { + return reader.Read() && reader.GetBoolean(0); + } + + public async Task HandleAsync(DbDataReader reader, IMartenSession session, CancellationToken token) + { + return await reader.ReadAsync(token).ConfigureAwait(false) && + await reader.GetFieldValueAsync(0, token).ConfigureAwait(false); + } + + public Task StreamJson(Stream stream, DbDataReader reader, CancellationToken token) + { + throw new NotSupportedException(); + } +} diff --git a/src/Marten/Events/EventStore.Dcb.cs b/src/Marten/Events/EventStore.Dcb.cs index fd8fb40c5e..103b44b682 100644 --- a/src/Marten/Events/EventStore.Dcb.cs +++ b/src/Marten/Events/EventStore.Dcb.cs @@ -16,6 +16,14 @@ namespace Marten.Events; internal partial class EventStore { + public async Task EventsExistAsync(EventTagQuery query, CancellationToken cancellation = default) + { + await _session.Database.EnsureStorageExistsAsync(typeof(IEvent), cancellation).ConfigureAwait(false); + + var handler = new EventsExistByTagsHandler(_store, query); + return await _session.ExecuteHandlerAsync(handler, cancellation).ConfigureAwait(false); + } + public async Task> QueryByTagsAsync(EventTagQuery query, CancellationToken cancellation = default) { diff --git a/src/Marten/Events/EventStore.cs b/src/Marten/Events/EventStore.cs index 89f59ebf85..13478b9caf 100644 --- a/src/Marten/Events/EventStore.cs +++ b/src/Marten/Events/EventStore.cs @@ -1,11 +1,17 @@ #nullable enable using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; using JasperFx.Events; +using Marten.Events.Operations; using Marten.Events.Protected; using Marten.Internal.Sessions; +using Marten.Linq.Parsing; using Marten.Storage; +using Weasel.Postgresql.SqlGeneration; namespace Marten.Events; @@ -32,5 +38,43 @@ public void OverwriteEvent(IEvent e) _session.QueueOperation(op); } + public void AssignTagWhere(Expression> expression, object tag) + { + if (expression == null) throw new ArgumentNullException(nameof(expression)); + if (tag == null) throw new ArgumentNullException(nameof(tag)); + + var tagType = tag.GetType(); + var registration = _store.Events.FindTagType(tagType) + ?? throw new InvalidOperationException( + $"Tag type '{tagType.Name}' is not registered. Call RegisterTagType<{tagType.Name}>() first."); + + var value = registration.ExtractValue(tag); + var schema = _store.Events.DatabaseSchemaName; + // Parse the expression into a SQL WHERE fragment using EventQueryMapping + var mapping = new EventQueryMapping(_store.Options); + var holder = new SimpleWhereFragmentHolder(); + var parser = new WhereClauseParser(_store.Options, mapping.QueryMembers, holder); + parser.Visit(expression.Body); + + ISqlFragment whereFragment = holder.Fragments.Count switch + { + 0 => throw new ArgumentException("Expression did not produce any WHERE clause."), + 1 => holder.Fragments[0], + _ => CompoundWhereFragment.And(holder.Fragments) + }; + + var op = new AssignTagWhereOperation(schema, registration, value, whereFragment); + _session.QueueOperation(op); + } + + private class SimpleWhereFragmentHolder: IWhereFragmentHolder + { + public List Fragments { get; } = new(); + + public void Register(ISqlFragment filter) + { + if (filter != null) Fragments.Add(filter); + } + } } diff --git a/src/Marten/Events/IEventStoreOperations.cs b/src/Marten/Events/IEventStoreOperations.cs index fc1f535c02..eb564feec0 100644 --- a/src/Marten/Events/IEventStoreOperations.cs +++ b/src/Marten/Events/IEventStoreOperations.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq.Expressions; using System.Threading; using System.Threading.Tasks; using JasperFx.Events; @@ -416,6 +417,20 @@ Task WriteExclusivelyToAggregate(string id, Func, Task> writi /// Guid CompletelyReplaceEvent(long sequence, T eventBody) where T : class; + /// + /// Retroactively assign a tag to all events matching the given LINQ predicate. + /// The tag must be of a registered tag type. The operation is queued and applied at SaveChangesAsync time. + /// + /// LINQ predicate against IEvent properties (e.g. EventTypeName, StreamId, Timestamp) + /// Tag value whose type must be registered via RegisterTagType + void AssignTagWhere(Expression> expression, object tag); + + /// + /// Check whether any events exist that match the given tag query, without loading the events. + /// This is a lightweight existence check useful for DCB guard clauses. + /// + Task EventsExistAsync(EventTagQuery query, CancellationToken cancellation = default); + /// /// Query events by their tags using the DCB pattern. /// Returns events matching any of the OR'd conditions in the query, ordered by seq_id. diff --git a/src/Marten/Events/Operations/AssignTagWhereOperation.cs b/src/Marten/Events/Operations/AssignTagWhereOperation.cs new file mode 100644 index 0000000000..aca9fd9b2a --- /dev/null +++ b/src/Marten/Events/Operations/AssignTagWhereOperation.cs @@ -0,0 +1,65 @@ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using JasperFx.Events; +using JasperFx.Events.Tags; +using Marten.Internal; +using Marten.Internal.Operations; +using Weasel.Postgresql; +using Weasel.Postgresql.SqlGeneration; + +namespace Marten.Events.Operations; + +/// +/// Retroactively assigns a tag to all events matching a WHERE clause. +/// Generates: INSERT INTO schema.mt_event_tag_{suffix} (value, seq_id) +/// SELECT @value, d.seq_id FROM schema.mt_events as d WHERE {where} +/// ON CONFLICT DO NOTHING +/// +internal class AssignTagWhereOperation: IStorageOperation +{ + private readonly string _schemaName; + private readonly ITagTypeRegistration _registration; + private readonly object _value; + private readonly ISqlFragment _whereFragment; + + public AssignTagWhereOperation(string schemaName, ITagTypeRegistration registration, object value, + ISqlFragment whereFragment) + { + _schemaName = schemaName; + _registration = registration; + _value = value; + _whereFragment = whereFragment; + } + + public void ConfigureCommand(ICommandBuilder builder, IMartenSession session) + { + builder.Append("insert into "); + builder.Append(_schemaName); + builder.Append(".mt_event_tag_"); + builder.Append(_registration.TableSuffix); + builder.Append(" (value, seq_id) select "); + builder.AppendParameter(_value); + builder.Append(", d.seq_id from "); + builder.Append(_schemaName); + builder.Append(".mt_events as d where "); + _whereFragment.Apply(builder); + builder.Append(" on conflict do nothing"); + } + + public Type DocumentType => typeof(IEvent); + + public void Postprocess(DbDataReader reader, IList exceptions) + { + // No-op + } + + public Task PostprocessAsync(DbDataReader reader, IList exceptions, CancellationToken token) + { + return Task.CompletedTask; + } + + public OperationRole Role() => OperationRole.Events; +} diff --git a/src/Marten/IQuerySession.cs b/src/Marten/IQuerySession.cs index 7602c1b43f..0d13888204 100644 --- a/src/Marten/IQuerySession.cs +++ b/src/Marten/IQuerySession.cs @@ -72,6 +72,56 @@ public interface IQuerySession: IDisposable, IAsyncDisposable /// string TenantId { get; } + /// + /// Check if a document of type T with the given string id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + /// + Task CheckExistsAsync(string id, CancellationToken token = default) where T : notnull; + + /// + /// Check if a document of type T with the given int id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + /// + Task CheckExistsAsync(int id, CancellationToken token = default) where T : notnull; + + /// + /// Check if a document of type T with the given long id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + /// + Task CheckExistsAsync(long id, CancellationToken token = default) where T : notnull; + + /// + /// Check if a document of type T with the given Guid id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + /// + Task CheckExistsAsync(Guid id, CancellationToken token = default) where T : notnull; + + /// + /// Check if a document of type T with the given id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + /// + Task CheckExistsAsync(object id, CancellationToken token = default) where T : notnull; + /// /// Asynchronously find or load a single document of type T by a string id /// diff --git a/src/Marten/Internal/Sessions/QuerySession.CheckExists.cs b/src/Marten/Internal/Sessions/QuerySession.CheckExists.cs new file mode 100644 index 0000000000..34b9058a86 --- /dev/null +++ b/src/Marten/Internal/Sessions/QuerySession.CheckExists.cs @@ -0,0 +1,86 @@ +#nullable enable + +using System; +using System.Threading; +using System.Threading.Tasks; +using JasperFx.Core.Reflection; +using Marten.Exceptions; +using Marten.Internal.Storage; +using Marten.Linq.QueryHandlers; + +namespace Marten.Internal.Sessions; + +public partial class QuerySession +{ + public async Task CheckExistsAsync(string id, CancellationToken token = default) where T : notnull + { + assertNotDisposed(); + await Database.EnsureStorageExistsAsync(typeof(T), token).ConfigureAwait(false); + var storage = StorageFor(); + var handler = new CheckExistsByIdHandler(storage, id); + return await ExecuteHandlerAsync(handler, token).ConfigureAwait(false); + } + + public async Task CheckExistsAsync(int id, CancellationToken token = default) where T : notnull + { + assertNotDisposed(); + await Database.EnsureStorageExistsAsync(typeof(T), token).ConfigureAwait(false); + var storage = StorageFor(); + + if (storage is IDocumentStorage i) + { + var handler = new CheckExistsByIdHandler(i, id); + return await ExecuteHandlerAsync(handler, token).ConfigureAwait(false); + } + + if (storage is IDocumentStorage l) + { + var handler = new CheckExistsByIdHandler(l, id); + return await ExecuteHandlerAsync(handler, token).ConfigureAwait(false); + } + + throw new DocumentIdTypeMismatchException( + $"The identity type for document type {typeof(T).FullNameInCode()} is not numeric"); + } + + public async Task CheckExistsAsync(long id, CancellationToken token = default) where T : notnull + { + assertNotDisposed(); + await Database.EnsureStorageExistsAsync(typeof(T), token).ConfigureAwait(false); + var storage = StorageFor(); + var handler = new CheckExistsByIdHandler(storage, id); + return await ExecuteHandlerAsync(handler, token).ConfigureAwait(false); + } + + public async Task CheckExistsAsync(Guid id, CancellationToken token = default) where T : notnull + { + assertNotDisposed(); + await Database.EnsureStorageExistsAsync(typeof(T), token).ConfigureAwait(false); + var storage = StorageFor(); + var handler = new CheckExistsByIdHandler(storage, id); + return await ExecuteHandlerAsync(handler, token).ConfigureAwait(false); + } + + public async Task CheckExistsAsync(object id, CancellationToken token = default) where T : notnull + { + assertNotDisposed(); + await Database.EnsureStorageExistsAsync(typeof(T), token).ConfigureAwait(false); + var loader = typeof(ExistsChecker<>).CloseAndBuildAs(id.GetType()); + return await loader.CheckExistsAsync(id, this, token).ConfigureAwait(false); + } + + private interface IExistsChecker + { + Task CheckExistsAsync(object id, QuerySession session, CancellationToken token = default) where T : notnull; + } + + private class ExistsChecker: IExistsChecker where TId : notnull + { + public async Task CheckExistsAsync(object id, QuerySession session, CancellationToken token = default) where T : notnull + { + var storage = session.StorageFor(); + var handler = new CheckExistsByIdHandler(storage, (TId)id); + return await session.ExecuteHandlerAsync(handler, token).ConfigureAwait(false); + } + } +} diff --git a/src/Marten/Linq/CollectionUsage.Compilation.cs b/src/Marten/Linq/CollectionUsage.Compilation.cs index e17d470850..ae9df58c31 100644 --- a/src/Marten/Linq/CollectionUsage.Compilation.cs +++ b/src/Marten/Linq/CollectionUsage.Compilation.cs @@ -1,7 +1,9 @@ #nullable enable using System; +using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; using JasperFx.Core.Reflection; using Marten.Exceptions; using Marten.Internal; @@ -37,6 +39,12 @@ public Statement BuildTopStatement(IMartenSession session, IQueryableMemberColle statement.ParseWhereClause(WhereExpressions, session, collection, storage); + // If this is a GroupBy query, handle it separately + if (GroupByData != null) + { + return CompileGroupBy(session, statement, collection, statistics); + } + ParseIncludes(collection, session); if (Includes.Any()) { @@ -373,6 +381,155 @@ public Statement CompileGroupJoin(IMartenSession session, return joinStatement; } + public Statement CompileGroupBy(IMartenSession session, + SelectorStatement statement, IQueryableMemberCollection collection, QueryStatistics? statistics) + { + var groupBy = GroupByData!; + var groupingUsage = Inner; // The IGrouping usage with SelectExpression and WhereExpressions + + if (groupingUsage?.SelectExpression == null) + { + throw new BadLinqExpressionException( + "GroupBy must be followed by a Select() projection. Marten does not support returning IGrouping directly."); + } + + // Find the grouping parameter from the select expression + // The SelectExpression is the body of the lambda; we need the original lambda's parameter + // The grouping parameter was on the Select's lambda: g => new { ... } + // We need to find it from the expression tree + ParameterExpression groupingParam = FindGroupingParameter(groupingUsage.SelectExpression); + + var parser = new GroupBySelectParser( + _options.Serializer(), + collection, + groupBy.KeySelector, + groupingUsage.SelectExpression, + groupingParam); + + // Set GROUP BY columns + foreach (var col in parser.GroupByColumns) + { + statement.GroupByColumns.Add(col); + } + + // Set SELECT clause + if (parser.IsScalar) + { + var fragment = parser.ScalarFragment; + if (fragment is IQueryableMember member) + { + if (member.MemberType == typeof(string)) + { + statement.SelectClause = + new NewScalarStringSelectClause(member.RawLocator, statement.SelectClause.FromObject); + } + else if (member.MemberType.IsSimple() || member.MemberType == typeof(Guid) || + member.MemberType == typeof(decimal) || member.MemberType == typeof(DateTimeOffset)) + { + statement.SelectClause = + typeof(NewScalarSelectClause<>).CloseAndBuildAs(member, + statement.SelectClause.FromObject, + member.MemberType); + } + } + else if (fragment is LiteralSql literal) + { + // Aggregate scalar like count(*) + statement.SelectClause = + new NewScalarSelectClause(literal.Text, statement.SelectClause.FromObject); + } + } + else + { + var resultType = groupingUsage.SelectExpression.Type; + statement.SelectClause = + typeof(SelectDataSelectClause<>).CloseAndBuildAs( + statement.SelectClause.FromObject, + parser.NewObject, + resultType); + } + + // Process HAVING from the grouping usage's WhereExpressions + if (groupingUsage.WhereExpressions.Any()) + { + // Build key member dictionaries for the HAVING resolver + var keyMembers = new Dictionary(); + IQueryableMember simpleKeyMember = null; + bool isCompositeKey; + + var keyBody = groupBy.KeySelector.Body; + if (keyBody is NewExpression newExpr) + { + isCompositeKey = true; + var parameters = newExpr.Constructor!.GetParameters(); + for (var i = 0; i < parameters.Length; i++) + { + keyMembers[parameters[i].Name!] = collection.MemberFor(newExpr.Arguments[i]); + } + } + else + { + isCompositeKey = false; + simpleKeyMember = collection.MemberFor(keyBody); + } + + foreach (var whereExpr in groupingUsage.WhereExpressions) + { + var havingFragment = GroupBySelectParser.ResolveHavingFragment( + whereExpr, collection, groupBy.KeySelector, keyMembers, simpleKeyMember, isCompositeKey); + statement.HavingClauses.Add(havingFragment); + } + } + + // Transfer downstream operators from the grouping usage's Inner (if any) + // e.g., OrderBy, Take, Skip after Select + var downstream = groupingUsage.Inner; + if (downstream != null) + { + statement.Limit ??= downstream._limit; + statement.Offset ??= downstream._offset; + + if (downstream.SingleValueMode.HasValue) + { + SingleValueMode = downstream.SingleValueMode; + } + + if (downstream.IsAny) + { + IsAny = true; + } + } + + // Apply single value mode (Count, First, etc. after the GroupBy+Select) + ProcessSingleValueModeIfAny(statement, session, collection, statistics); + + return statement.Top(); + } + + private static ParameterExpression FindGroupingParameter(Expression expression) + { + var finder = new GroupingParameterFinder(); + finder.Visit(expression); + return finder.Parameter ?? throw new BadLinqExpressionException( + "Could not find the IGrouping parameter in the GroupBy Select expression"); + } + + private class GroupingParameterFinder: ExpressionVisitor + { + public ParameterExpression? Parameter { get; private set; } + + protected override Expression VisitParameter(ParameterExpression node) + { + if (Parameter == null && node.Type.IsGenericType && + node.Type.GetGenericTypeDefinition() == typeof(IGrouping<,>)) + { + Parameter = node; + } + + return base.VisitParameter(node); + } + } + public Statement CompileSelectMany(IMartenSession session, SelectorStatement parentStatement, ICollectionMember collectionMember, QueryStatistics? statistics) { diff --git a/src/Marten/Linq/CollectionUsage.cs b/src/Marten/Linq/CollectionUsage.cs index 58cacf2894..18ceab024d 100644 --- a/src/Marten/Linq/CollectionUsage.cs +++ b/src/Marten/Linq/CollectionUsage.cs @@ -30,6 +30,7 @@ public CollectionUsage(StoreOptions options, Type elementType) public Expression SelectMany { get; set; } = null!; public MethodCallExpression? SelectManyCallExpression { get; set; } public GroupJoinData? GroupJoinData { get; set; } + public GroupByData? GroupByData { get; set; } public void WriteLimit(int limit) diff --git a/src/Marten/Linq/GroupByData.cs b/src/Marten/Linq/GroupByData.cs new file mode 100644 index 0000000000..bdf3591cfa --- /dev/null +++ b/src/Marten/Linq/GroupByData.cs @@ -0,0 +1,15 @@ +#nullable enable +using System.Linq.Expressions; + +namespace Marten.Linq; + +/// +/// Holds parsed GroupBy expression components for LINQ GroupBy translation to SQL GROUP BY. +/// +public class GroupByData +{ + /// + /// The key selector lambda (e.g., x => x.Color) + /// + public LambdaExpression KeySelector { get; set; } = null!; +} diff --git a/src/Marten/Linq/Parsing/GroupBySelectParser.cs b/src/Marten/Linq/Parsing/GroupBySelectParser.cs new file mode 100644 index 0000000000..06cff2c465 --- /dev/null +++ b/src/Marten/Linq/Parsing/GroupBySelectParser.cs @@ -0,0 +1,488 @@ +#nullable disable +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using Marten.Exceptions; +using Marten.Linq.Members; +using Marten.Linq.SqlGeneration; +using Weasel.Postgresql; +using Weasel.Postgresql.SqlGeneration; + +namespace Marten.Linq.Parsing; + +/// +/// Parses the Select() projection on an IGrouping to build the SQL SELECT columns, +/// GROUP BY keys, and aggregate expressions (COUNT, SUM, MIN, MAX, AVG). +/// +internal class GroupBySelectParser: ExpressionVisitor +{ + private readonly ISerializer _serializer; + private readonly IQueryableMemberCollection _collection; + private readonly LambdaExpression _keySelector; + private readonly ParameterExpression _groupingParameter; + + // For composite keys: maps anonymous type member name to IQueryableMember + private readonly Dictionary _keyMembers = new(); + // For simple keys: the single key member + private IQueryableMember _simpleKeyMember; + private bool _isCompositeKey; + + private string _currentField; + private bool _hasStarted; + + public NewObject NewObject { get; private set; } + public List GroupByColumns { get; } = new(); + + // For scalar select (e.g., .Select(g => g.Key) or .Select(g => g.Count())) + public ISqlFragment ScalarFragment { get; private set; } + public bool IsScalar { get; private set; } + + public GroupBySelectParser( + ISerializer serializer, + IQueryableMemberCollection collection, + LambdaExpression keySelector, + Expression selectBody, + ParameterExpression groupingParameter) + { + _serializer = serializer; + _collection = collection; + _keySelector = keySelector; + _groupingParameter = groupingParameter; + + NewObject = new NewObject(serializer); + ParseKeySelector(); + Visit(selectBody); + } + + private void ParseKeySelector() + { + var body = _keySelector.Body; + + if (body is NewExpression newExpr) + { + // Composite key: x => new { x.Color, x.Number } + _isCompositeKey = true; + var parameters = newExpr.Constructor!.GetParameters(); + for (var i = 0; i < parameters.Length; i++) + { + var member = _collection.MemberFor(newExpr.Arguments[i]); + _keyMembers[parameters[i].Name!] = member; + GroupByColumns.Add(member.TypedLocator); + } + } + else if (body is MemberInitExpression memberInit) + { + // Composite key with member init: x => new KeyClass { Color = x.Color } + _isCompositeKey = true; + foreach (var binding in memberInit.Bindings.OfType()) + { + var member = _collection.MemberFor(binding.Expression); + _keyMembers[binding.Member.Name] = member; + GroupByColumns.Add(member.TypedLocator); + } + } + else + { + // Simple key: x => x.Color + _isCompositeKey = false; + _simpleKeyMember = _collection.MemberFor(body); + GroupByColumns.Add(_simpleKeyMember.TypedLocator); + } + } + + protected override Expression VisitNew(NewExpression node) + { + if (_hasStarted) + { + // Nested new expression - not supported for now + throw new BadLinqExpressionException( + "Marten does not support nested constructors in GroupBy projections"); + } + + _hasStarted = true; + + var parameters = node.Constructor!.GetParameters(); + for (var i = 0; i < parameters.Length; i++) + { + _currentField = parameters[i].Name; + Visit(node.Arguments[i]); + } + + return node; + } + + protected override Expression VisitMemberInit(MemberInitExpression node) + { + _hasStarted = true; + + // Visit constructor args first + var parameters = node.NewExpression.Constructor!.GetParameters(); + for (var i = 0; i < parameters.Length; i++) + { + _currentField = parameters[i].Name; + Visit(node.NewExpression.Arguments[i]); + } + + // Then visit member bindings + foreach (var binding in node.Bindings.OfType()) + { + _currentField = binding.Member.Name; + Visit(binding.Expression); + } + + return node; + } + + protected override Expression VisitMember(MemberExpression node) + { + // Check if this is g.Key + if (IsGroupingKeyAccess(node)) + { + if (_isCompositeKey) + { + // g.Key for composite key - this shouldn't happen directly in a well-formed projection + // But if it does, we can't represent the whole anonymous key as a single SQL expression + throw new BadLinqExpressionException( + "Cannot select the entire composite GroupBy key directly. Access individual key members like g.Key.Color instead."); + } + + if (_currentField != null) + { + NewObject.Members[_currentField] = _simpleKeyMember; + _currentField = null; + } + else + { + // Scalar select: .Select(g => g.Key) + IsScalar = true; + ScalarFragment = _simpleKeyMember; + } + + return node; + } + + // Check if this is g.Key.PropertyName (composite key member access) + if (node.Expression is MemberExpression innerMember && IsGroupingKeyAccess(innerMember)) + { + var memberName = node.Member.Name; + if (_keyMembers.TryGetValue(memberName, out var keyMember)) + { + if (_currentField != null) + { + NewObject.Members[_currentField] = keyMember; + _currentField = null; + } + else + { + IsScalar = true; + ScalarFragment = keyMember; + } + + return node; + } + + throw new BadLinqExpressionException( + $"Unknown composite key member '{memberName}' in GroupBy projection"); + } + + return base.VisitMember(node); + } + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + var aggregateSql = TryResolveAggregate(node); + if (aggregateSql != null) + { + if (_currentField != null) + { + NewObject.Members[_currentField] = new LiteralSql(aggregateSql); + _currentField = null; + } + else + { + IsScalar = true; + ScalarFragment = new LiteralSql(aggregateSql); + } + + return node; + } + + return base.VisitMethodCall(node); + } + + private string TryResolveAggregate(MethodCallExpression node) + { + var methodName = node.Method.Name; + + // Parameterless: g.Count(), g.LongCount() + if (methodName is "Count" or "LongCount") + { + if (node.Arguments.Count == 1 && IsGroupingParameter(node.Arguments[0])) + { + return "count(*)"; + } + + // With predicate: g.Count(x => x.Flag) + if (node.Arguments.Count == 2 && IsGroupingParameter(node.Arguments[0])) + { + var predicateSql = ResolvePredicate(node.Arguments[1]); + return $"count(*) filter (where {predicateSql})"; + } + } + + // Aggregate with selector: g.Sum(x => x.Number), g.Min(...), g.Max(...), g.Average(...) + if (methodName is "Sum" or "Min" or "Max" or "Average") + { + if (node.Arguments.Count == 2 && IsGroupingParameter(node.Arguments[0])) + { + var selectorLambda = ExtractLambda(node.Arguments[1]); + if (selectorLambda != null) + { + var member = _collection.MemberFor(selectorLambda.Body); + var sqlOp = methodName == "Average" ? "avg" : methodName.ToLowerInvariant(); + return $"{sqlOp}({member.TypedLocator})"; + } + } + } + + return null; + } + + /// + /// Resolves a HAVING clause predicate from the grouping's Where expression. + /// Returns the SQL for the predicate. + /// + public static ISqlFragment ResolveHavingFragment( + Expression expression, + IQueryableMemberCollection collection, + LambdaExpression keySelector, + Dictionary keyMembers, + IQueryableMember simpleKeyMember, + bool isCompositeKey) + { + var resolver = new HavingExpressionResolver(collection, keySelector, keyMembers, simpleKeyMember, isCompositeKey); + return resolver.Resolve(expression); + } + + private bool IsGroupingKeyAccess(MemberExpression node) + { + return node.Member.Name == "Key" + && node.Expression is ParameterExpression param + && param == _groupingParameter; + } + + private bool IsGroupingParameter(Expression node) + { + return node is ParameterExpression param && param == _groupingParameter; + } + + private static LambdaExpression ExtractLambda(Expression expr) + { + if (expr is UnaryExpression unary) + { + expr = unary.Operand; + } + + return expr as LambdaExpression; + } + + private string ResolvePredicate(Expression expr) + { + var lambda = ExtractLambda(expr); + if (lambda == null) + { + throw new BadLinqExpressionException("Expected a lambda predicate in GroupBy aggregate"); + } + + // Simple predicate support: x => x.Flag + var member = _collection.MemberFor(lambda.Body); + return $"{member.TypedLocator} = True"; + } +} + +/// +/// Translates Where() expressions on IGrouping to SQL HAVING clauses. +/// Supports aggregate comparisons like g.Count() > 5, g.Sum(x => x.Number) >= 100. +/// +internal class HavingExpressionResolver +{ + private readonly IQueryableMemberCollection _collection; + private readonly LambdaExpression _keySelector; + private readonly Dictionary _keyMembers; + private readonly IQueryableMember _simpleKeyMember; + private readonly bool _isCompositeKey; + + public HavingExpressionResolver( + IQueryableMemberCollection collection, + LambdaExpression keySelector, + Dictionary keyMembers, + IQueryableMember simpleKeyMember, + bool isCompositeKey) + { + _collection = collection; + _keySelector = keySelector; + _keyMembers = keyMembers; + _simpleKeyMember = simpleKeyMember; + _isCompositeKey = isCompositeKey; + } + + public ISqlFragment Resolve(Expression expression) + { + if (expression is BinaryExpression binary) + { + return ResolveBinary(binary); + } + + throw new BadLinqExpressionException( + "Marten only supports binary comparison expressions in GroupBy HAVING clauses"); + } + + private ISqlFragment ResolveBinary(BinaryExpression binary) + { + // Handle AND/OR + if (binary.NodeType == ExpressionType.AndAlso) + { + var left = Resolve(binary.Left); + var right = Resolve(binary.Right); + return new CompoundFragment("and", left, right); + } + + if (binary.NodeType == ExpressionType.OrElse) + { + var left = Resolve(binary.Left); + var right = Resolve(binary.Right); + return new CompoundFragment("or", left, right); + } + + var op = binary.NodeType switch + { + ExpressionType.Equal => "=", + ExpressionType.NotEqual => "!=", + ExpressionType.GreaterThan => ">", + ExpressionType.GreaterThanOrEqual => ">=", + ExpressionType.LessThan => "<", + ExpressionType.LessThanOrEqual => "<=", + _ => throw new BadLinqExpressionException( + $"Unsupported comparison operator '{binary.NodeType}' in GroupBy HAVING clause") + }; + + var leftSql = ResolveOperand(binary.Left); + var rightSql = ResolveOperand(binary.Right); + + return new HavingComparisonFragment(leftSql, op, rightSql); + } + + private string ResolveOperand(Expression expr) + { + // Aggregate call: g.Count(), g.Sum(x => x.Number) + if (expr is MethodCallExpression method) + { + return ResolveAggregateCall(method) + ?? throw new BadLinqExpressionException( + $"Unsupported method '{method.Method.Name}' in GroupBy HAVING clause"); + } + + // Constant + if (expr is ConstantExpression constant) + { + return constant.Value?.ToString() ?? "NULL"; + } + + // Key access: g.Key + if (expr is MemberExpression member && member.Member.Name == "Key") + { + if (_isCompositeKey) + { + throw new BadLinqExpressionException( + "Cannot use composite key directly in HAVING clause"); + } + + return _simpleKeyMember!.TypedLocator; + } + + // Try to evaluate as constant + if (expr.TryToParseConstant(out var c)) + { + return c.Value?.ToString() ?? "NULL"; + } + + throw new BadLinqExpressionException( + $"Unsupported expression type '{expr.NodeType}' in GroupBy HAVING clause"); + } + + private string ResolveAggregateCall(MethodCallExpression node) + { + var methodName = node.Method.Name; + + if (methodName is "Count" or "LongCount") + { + return "count(*)"; + } + + if (methodName is "Sum" or "Min" or "Max" or "Average" && node.Arguments.Count >= 2) + { + var lambda = ExtractLambda(node.Arguments[1]); + if (lambda != null) + { + var member = _collection.MemberFor(lambda.Body); + var sqlOp = methodName == "Average" ? "avg" : methodName.ToLowerInvariant(); + return $"{sqlOp}({member.TypedLocator})"; + } + } + + return null; + } + + private static LambdaExpression ExtractLambda(Expression expr) + { + if (expr is UnaryExpression unary) expr = unary.Operand; + return expr as LambdaExpression; + } +} + +internal class HavingComparisonFragment: ISqlFragment +{ + private readonly string _left; + private readonly string _op; + private readonly string _right; + + public HavingComparisonFragment(string left, string op, string right) + { + _left = left; + _op = op; + _right = right; + } + + public void Apply(ICommandBuilder builder) + { + builder.Append(_left); + builder.Append(" "); + builder.Append(_op); + builder.Append(" "); + builder.Append(_right); + } +} + +internal class CompoundFragment: ISqlFragment +{ + private readonly string _separator; + private readonly ISqlFragment _left; + private readonly ISqlFragment _right; + + public CompoundFragment(string separator, ISqlFragment left, ISqlFragment right) + { + _separator = separator; + _left = left; + _right = right; + } + + public void Apply(ICommandBuilder builder) + { + builder.Append("("); + _left.Apply(builder); + builder.Append($" {_separator} "); + _right.Apply(builder); + builder.Append(")"); + } +} diff --git a/src/Marten/Linq/Parsing/Operators/GroupByOperator.cs b/src/Marten/Linq/Parsing/Operators/GroupByOperator.cs new file mode 100644 index 0000000000..328cb48c8a --- /dev/null +++ b/src/Marten/Linq/Parsing/Operators/GroupByOperator.cs @@ -0,0 +1,27 @@ +#nullable enable +using System.Linq.Expressions; + +namespace Marten.Linq.Parsing.Operators; + +public class GroupByOperator: LinqOperator +{ + public GroupByOperator(): base("GroupBy") + { + } + + public override void Apply(ILinqQuery query, MethodCallExpression expression) + { + // GroupBy signature: source.GroupBy(keySelector) + // expression.Arguments[0] = source + // expression.Arguments[1] = key selector + + var usage = query.CollectionUsageFor(expression); + + var keyExpr = expression.Arguments[1].UnBox(); + + usage.GroupByData = new GroupByData + { + KeySelector = (LambdaExpression)keyExpr + }; + } +} diff --git a/src/Marten/Linq/Parsing/Operators/OperatorLibrary.cs b/src/Marten/Linq/Parsing/Operators/OperatorLibrary.cs index 40990298b0..2ea1a1013f 100644 --- a/src/Marten/Linq/Parsing/Operators/OperatorLibrary.cs +++ b/src/Marten/Linq/Parsing/Operators/OperatorLibrary.cs @@ -28,6 +28,7 @@ public OperatorLibrary() Add(); Add(); Add(); + Add(); Add(); Add(); Add(); // TODO -- is this necessary? diff --git a/src/Marten/Linq/QueryHandlers/CheckExistsByIdHandler.cs b/src/Marten/Linq/QueryHandlers/CheckExistsByIdHandler.cs new file mode 100644 index 0000000000..4ac055a93a --- /dev/null +++ b/src/Marten/Linq/QueryHandlers/CheckExistsByIdHandler.cs @@ -0,0 +1,68 @@ +#nullable enable +using System; +using System.Data.Common; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using JasperFx.Core.Reflection; +using Marten.Internal; +using Marten.Internal.Storage; +using Marten.Services; +using Weasel.Postgresql; + +namespace Marten.Linq.QueryHandlers; + +internal class CheckExistsByIdHandler: IQueryHandler where T : notnull where TId : notnull +{ + private readonly TId _id; + private readonly IDocumentStorage storage; + private static readonly Type[] _identityTypes = [typeof(int), typeof(long), typeof(string), typeof(Guid)]; + + public CheckExistsByIdHandler(IDocumentStorage documentStorage, TId id) + { + storage = documentStorage; + _id = id; + } + + public void ConfigureCommand(ICommandBuilder sql, IMartenSession session) + { + sql.Append("select exists(select 1 from "); + sql.Append(storage.FromObject); + sql.Append(" as d where id = "); + + if (_identityTypes.Contains(typeof(TId))) + { + sql.AppendParameter(_id); + } + else + { + var valueType = ValueTypeInfo.ForType(typeof(TId)); + typeof(Appender<,>).CloseAndBuildAs>(valueType, typeof(TId), valueType.SimpleType) + .Append(sql, _id); + } + + storage.AddTenancyFilter(sql, session.TenantId); + sql.Append(")"); + } + + public bool Handle(DbDataReader reader, IMartenSession session) + { + return reader.Read() && reader.GetBoolean(0); + } + + public async Task HandleAsync(DbDataReader reader, IMartenSession session, CancellationToken token) + { + if (await reader.ReadAsync(token).ConfigureAwait(false)) + { + return reader.GetBoolean(0); + } + + return false; + } + + public Task StreamJson(Stream stream, DbDataReader reader, CancellationToken token) + { + throw new NotSupportedException("StreamJson is not supported for CheckExistsByIdHandler"); + } +} diff --git a/src/Marten/Linq/SqlGeneration/SelectorStatement.cs b/src/Marten/Linq/SqlGeneration/SelectorStatement.cs index 7dd4b9c0d9..4379bb392f 100644 --- a/src/Marten/Linq/SqlGeneration/SelectorStatement.cs +++ b/src/Marten/Linq/SqlGeneration/SelectorStatement.cs @@ -1,5 +1,6 @@ #nullable enable using System; +using System.Collections.Generic; using System.Linq; using JasperFx.Core; using Marten.Internal; @@ -25,6 +26,9 @@ public class SelectorStatement: Statement, IWhereFragmentHolder public bool IsDistinct { get; set; } + public List GroupByColumns { get; } = new(); + public List HavingClauses { get; } = new(); + public void Register(ISqlFragment fragment) { Wheres.Add(fragment); @@ -47,6 +51,28 @@ protected override void configure(ICommandBuilder sql) } } + if (GroupByColumns.Count > 0) + { + sql.Append(" GROUP BY "); + sql.Append(GroupByColumns[0]); + for (var i = 1; i < GroupByColumns.Count; i++) + { + sql.Append(", "); + sql.Append(GroupByColumns[i]); + } + } + + if (HavingClauses.Count > 0) + { + sql.Append(" HAVING "); + HavingClauses[0].Apply(sql); + for (var i = 1; i < HavingClauses.Count; i++) + { + sql.Append(" and "); + HavingClauses[i].Apply(sql); + } + } + Ordering.Apply(sql); if (Offset.HasValue) diff --git a/src/Marten/Services/BatchQuerying/BatchedQuery.Events.cs b/src/Marten/Services/BatchQuerying/BatchedQuery.Events.cs index 63b8748da1..7b7ea7f3f7 100644 --- a/src/Marten/Services/BatchQuerying/BatchedQuery.Events.cs +++ b/src/Marten/Services/BatchQuerying/BatchedQuery.Events.cs @@ -183,6 +183,14 @@ public async Task> FetchForExclusiveWriting(string key) where return AddItem(handler); } + public Task EventsExist(EventTagQuery query) + { + _documentTypes.Add(typeof(IEvent)); + var store = (DocumentStore)Parent.DocumentStore; + var handler = new EventsExistByTagsHandler(store, query); + return AddItem(handler); + } + public Task> FetchForWritingByTags(EventTagQuery query) where T : class { Parent.AssertIsDocumentSession(); diff --git a/src/Marten/Services/BatchQuerying/BatchedQuery.cs b/src/Marten/Services/BatchQuerying/BatchedQuery.cs index 76b047cce1..7b191ba654 100644 --- a/src/Marten/Services/BatchQuerying/BatchedQuery.cs +++ b/src/Marten/Services/BatchQuerying/BatchedQuery.cs @@ -136,6 +136,58 @@ public Task QueryByPlan(IBatchQueryPlan plan) return plan.Fetch(this); } + public Task CheckExists(string id) where T : class + { + return checkExists(id); + } + + public Task CheckExists(int id) where T : class + { + return checkExists(id); + } + + public Task CheckExists(long id) where T : class + { + return checkExists(id); + } + + public Task CheckExists(Guid id) where T : class + { + return checkExists(id); + } + + public Task CheckExists(object id) where T : class + { + var checker = typeof(ExistsChecker<>).CloseAndBuildAs(id.GetType()); + return checker.CheckExists(id, this); + } + + private Task checkExists(TId id) where T : class where TId : notnull + { + _documentTypes.Add(typeof(T)); + var storage = Parent.StorageFor(); + if (storage is IDocumentStorage s) + { + var handler = new CheckExistsByIdHandler(s, id); + return AddItem(handler); + } + + throw new DocumentIdTypeMismatchException(storage, typeof(TId)); + } + + private interface IExistsChecker + { + Task CheckExists(object id, BatchedQuery parent) where T : class; + } + + private class ExistsChecker: IExistsChecker where TId : notnull + { + public Task CheckExists(object id, BatchedQuery parent) where T : class + { + return parent.checkExists((TId)id); + } + } + private Task load(TId id) where T : class where TId : notnull { _documentTypes.Add(typeof(T)); diff --git a/src/Marten/Services/BatchQuerying/IBatchedQuery.cs b/src/Marten/Services/BatchQuerying/IBatchedQuery.cs index 5a567b3609..b9bdbaa9bc 100644 --- a/src/Marten/Services/BatchQuerying/IBatchedQuery.cs +++ b/src/Marten/Services/BatchQuerying/IBatchedQuery.cs @@ -138,6 +138,14 @@ Task> FetchForExclusiveWriting(string key) /// Task FetchLatest(string id) where T : class; + /// + /// Check whether any events exist that match the given tag query, without loading the events. + /// This is a lightweight existence check useful for DCB guard clauses. + /// + /// + /// + Task EventsExist(EventTagQuery query); + /// /// Fetch events matching a tag query and aggregate them into type T with a DCB consistency boundary. /// At SaveChangesAsync time, will throw DcbConcurrencyException if new matching events were appended. @@ -157,6 +165,51 @@ public interface IBatchedQuery QuerySession Parent { get; } + /// + /// Check if a document of type T with the given string id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + Task CheckExists(string id) where T : class; + + /// + /// Check if a document of type T with the given int id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + Task CheckExists(int id) where T : class; + + /// + /// Check if a document of type T with the given long id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + Task CheckExists(long id) where T : class; + + /// + /// Check if a document of type T with the given Guid id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + Task CheckExists(Guid id) where T : class; + + /// + /// Check if a document of type T with the given id exists in the database + /// without loading or deserializing the document + /// + /// + /// + /// + Task CheckExists(object id) where T : class; + /// /// Load a single document of Type "T" by id /// diff --git a/src/ValueTypeTests/StrongTypedId/check_exists_with_strong_typed_ids.cs b/src/ValueTypeTests/StrongTypedId/check_exists_with_strong_typed_ids.cs new file mode 100644 index 0000000000..19a708cabf --- /dev/null +++ b/src/ValueTypeTests/StrongTypedId/check_exists_with_strong_typed_ids.cs @@ -0,0 +1,181 @@ +using System; +using System.Threading.Tasks; +using JasperFx.CodeGeneration; +using JasperFx.Core; +using Marten; +using Marten.Testing.Harness; +using Shouldly; + +namespace ValueTypeTests.StrongTypedId; + +public class check_exists_with_strong_typed_ids: IDisposable, IAsyncDisposable +{ + private readonly DocumentStore theStore; + private IDocumentSession theSession; + + public check_exists_with_strong_typed_ids() + { + theStore = DocumentStore.For(opts => + { + opts.Connection(ConnectionSource.ConnectionString); + opts.DatabaseSchemaName = "strong_typed_exists"; + + opts.ApplicationAssembly = GetType().Assembly; + opts.GeneratedCodeMode = TypeLoadMode.Auto; + opts.GeneratedCodeOutputPath = + AppContext.BaseDirectory.ParentDirectory().ParentDirectory().ParentDirectory().AppendPath("Internal", "Generated"); + }); + + theSession = theStore.LightweightSession(); + } + + public void Dispose() + { + theStore?.Dispose(); + theSession?.Dispose(); + } + + public async ValueTask DisposeAsync() + { + if (theStore != null) + { + await theStore.DisposeAsync(); + } + } + + [Fact] + public async Task check_exists_with_guid_strong_typed_id_hit() + { + var invoice = new Invoice2(); + theSession.Store(invoice); + await theSession.SaveChangesAsync(); + + var exists = await theSession.CheckExistsAsync((object)invoice.Id!.Value); + exists.ShouldBeTrue(); + } + + [Fact] + public async Task check_exists_with_guid_strong_typed_id_miss() + { + var exists = await theSession.CheckExistsAsync((object)new Invoice2Id(Guid.NewGuid())); + exists.ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_with_int_strong_typed_id_hit() + { + var order = new Order2 { Id = new Order2Id(42), Name = "Test" }; + theSession.Store(order); + await theSession.SaveChangesAsync(); + + var exists = await theSession.CheckExistsAsync((object)order.Id!.Value); + exists.ShouldBeTrue(); + } + + [Fact] + public async Task check_exists_with_int_strong_typed_id_miss() + { + var exists = await theSession.CheckExistsAsync((object)new Order2Id(999999)); + exists.ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_with_long_strong_typed_id_hit() + { + var issue = new Issue2 { Id = new Issue2Id(500L), Name = "Test" }; + theSession.Store(issue); + await theSession.SaveChangesAsync(); + + var exists = await theSession.CheckExistsAsync((object)issue.Id!.Value); + exists.ShouldBeTrue(); + } + + [Fact] + public async Task check_exists_with_long_strong_typed_id_miss() + { + var exists = await theSession.CheckExistsAsync((object)new Issue2Id(999999L)); + exists.ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_with_string_strong_typed_id_hit() + { + var team = new Team2 { Id = new Team2Id("team-exists-test"), Name = "Test" }; + theSession.Store(team); + await theSession.SaveChangesAsync(); + + var exists = await theSession.CheckExistsAsync((object)team.Id!.Value); + exists.ShouldBeTrue(); + } + + [Fact] + public async Task check_exists_with_string_strong_typed_id_miss() + { + var exists = await theSession.CheckExistsAsync((object)new Team2Id("nonexistent")); + exists.ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_in_batch_with_guid_strong_typed_id() + { + var invoice = new Invoice2(); + theSession.Store(invoice); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists((object)invoice.Id!.Value); + var existsMiss = batch.CheckExists((object)new Invoice2Id(Guid.NewGuid())); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_in_batch_with_int_strong_typed_id() + { + var order = new Order2 { Id = new Order2Id(88), Name = "Batch Test" }; + theSession.Store(order); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists((object)order.Id!.Value); + var existsMiss = batch.CheckExists((object)new Order2Id(777777)); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_in_batch_with_long_strong_typed_id() + { + var issue = new Issue2 { Id = new Issue2Id(600L), Name = "Batch Test" }; + theSession.Store(issue); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists((object)issue.Id!.Value); + var existsMiss = batch.CheckExists((object)new Issue2Id(888888L)); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } + + [Fact] + public async Task check_exists_in_batch_with_string_strong_typed_id() + { + var team = new Team2 { Id = new Team2Id("batch-exists-test"), Name = "Batch Test" }; + theSession.Store(team); + await theSession.SaveChangesAsync(); + + var batch = theSession.CreateBatchQuery(); + var existsHit = batch.CheckExists((object)team.Id!.Value); + var existsMiss = batch.CheckExists((object)new Team2Id("not-there")); + await batch.Execute(); + + (await existsHit).ShouldBeTrue(); + (await existsMiss).ShouldBeFalse(); + } +}