diff --git a/.github/workflows/on-push-do-ci-build-postgis-pgvector.yml b/.github/workflows/on-push-do-ci-build-postgis-pgvector.yml new file mode 100644 index 0000000000..992ab8682c --- /dev/null +++ b/.github/workflows/on-push-do-ci-build-postgis-pgvector.yml @@ -0,0 +1,129 @@ +name: Build & Test - Extensions - PostGIS / PgVector + +# Dedicated workflow for the Marten.PostGIS and Marten.PgVector test projects. +# Each suite needs a Postgres image with the matching extension preinstalled — +# the default `postgres:15-alpine` service container used by the other CI +# workflows doesn't ship `postgis` or `vector`, so we run two jobs side-by-side +# with officially-maintained extension images. + +on: + push: + branches: + - master + paths-ignore: + - 'documentation/**' + - 'docs/**' + - 'azure-pipelines.yml' + pull_request: + branches: + - master + paths-ignore: + - 'documentation/**' + - 'docs/**' + - 'azure-pipelines.yml' + +env: + DOTNET_CLI_TELEMETRY_OPTOUT: 1 + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: 1 + pg_db: marten_testing + pg_user: postgres + CONFIGURATION: Release + FRAMEWORK: net9.0 + DISABLE_TEST_PARALLELIZATION: true + DEFAULT_SERIALIZER: "Newtonsoft" + NUKE_TELEMETRY_OPTOUT: true + +jobs: + postgis: + name: PostGIS extension tests + runs-on: ubuntu-latest + timeout-minutes: 20 + services: + postgres: + image: postgis/postgis:17-3.5 + ports: + - 5432:5432 + env: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: ${{ env.pg_db }} + NAMEDATALEN: 150 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + --user postgres + steps: + - uses: actions/checkout@v6 + + - name: Install .NET + uses: actions/setup-dotnet@v5 + with: + dotnet-version: | + 9.0.x + 10.0.x + + - name: Optimize database for running tests faster + run: | + PG_CONTAINER_NAME=$(docker ps --filter expose=5432/tcp --format {{.Names}}) + docker exec $PG_CONTAINER_NAME bash -c "echo -e '\nfsync = off' >> /var/lib/postgresql/data/postgresql.conf" + 
docker exec $PG_CONTAINER_NAME bash -c "echo -e '\nfull_page_writes = off' >> /var/lib/postgresql/data/postgresql.conf" + docker exec $PG_CONTAINER_NAME bash -c "echo -e '\nsynchronous_commit = off' >> /var/lib/postgresql/data/postgresql.conf" + docker container restart $PG_CONTAINER_NAME + shell: bash + + - name: compile + run: ./build.sh compile + shell: bash + + - name: test-postgis + if: ${{ success() || failure() }} + run: ./build.sh TestPostGIS + shell: bash + + pgvector: + name: PgVector extension tests + runs-on: ubuntu-latest + timeout-minutes: 20 + services: + postgres: + image: pgvector/pgvector:pg17 + ports: + - 5432:5432 + env: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: ${{ env.pg_db }} + NAMEDATALEN: 150 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + --user postgres + steps: + - uses: actions/checkout@v6 + + - name: Install .NET + uses: actions/setup-dotnet@v5 + with: + dotnet-version: | + 9.0.x + 10.0.x + + - name: Optimize database for running tests faster + run: | + PG_CONTAINER_NAME=$(docker ps --filter expose=5432/tcp --format {{.Names}}) + docker exec $PG_CONTAINER_NAME bash -c "echo -e '\nfsync = off' >> /var/lib/postgresql/data/postgresql.conf" + docker exec $PG_CONTAINER_NAME bash -c "echo -e '\nfull_page_writes = off' >> /var/lib/postgresql/data/postgresql.conf" + docker exec $PG_CONTAINER_NAME bash -c "echo -e '\nsynchronous_commit = off' >> /var/lib/postgresql/data/postgresql.conf" + docker container restart $PG_CONTAINER_NAME + shell: bash + + - name: compile + run: ./build.sh compile + shell: bash + + - name: test-pgvector + if: ${{ success() || failure() }} + run: ./build.sh TestPgVector + shell: bash diff --git a/.nuke/build.schema.json b/.nuke/build.schema.json index 391d9215f4..2b8f517709 100644 --- a/.nuke/build.schema.json +++ b/.nuke/build.schema.json @@ -57,6 +57,8 @@ "TestMultiTenancy", "TestNodaTime", "TestPatching", + "TestPgVector", + "TestPostGIS", 
"TestValueTypes" ] }, diff --git a/Directory.Packages.props b/Directory.Packages.props index 6bb7227f1e..5d98c605d8 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -43,12 +43,15 @@ + + + @@ -60,6 +63,7 @@ + diff --git a/build/build.cs b/build/build.cs index fbd39e4120..0d1a815f4c 100644 --- a/build/build.cs +++ b/build/build.cs @@ -42,7 +42,9 @@ class Build : NukeBuild Target TestExtensions => _ => _ .DependsOn(TestNodaTime) - .DependsOn(TestAspnetcore); + .DependsOn(TestAspnetcore) + .DependsOn(TestPostGIS) + .DependsOn(TestPgVector); Target Init => _ => _ .Executes(() => @@ -136,6 +138,30 @@ class Build : NukeBuild .SetFramework(Framework)); }); + Target TestPostGIS => _ => _ + .ProceedAfterFailure() + .Executes(() => + { + DotNetTest(c => c + .SetProjectFile("src/Marten.PostGIS.Tests") + .SetConfiguration(Configuration) + .EnableNoBuild() + .EnableNoRestore() + .SetFramework(Framework)); + }); + + Target TestPgVector => _ => _ + .ProceedAfterFailure() + .Executes(() => + { + DotNetTest(c => c + .SetProjectFile("src/Marten.PgVector.Tests") + .SetConfiguration(Configuration) + .EnableNoBuild() + .EnableNoRestore() + .SetFramework(Framework)); + }); + Target TestCore => _ => _ .ProceedAfterFailure() .Executes(() => diff --git a/docker-compose.yml b/docker-compose.yml index 54c904bab5..f100ae55bb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,14 @@ version: '3' services: postgresql: - image: ${POSTGRES_IMAGE:-ionx/postgres-plv8:12.8} + # Custom build layering PostGIS 3 + pgvector on top of the official + # multi-arch `postgres:17` image (see docker/postgres/Dockerfile). + # The Debian packages `postgresql-17-postgis-3` and + # `postgresql-17-pgvector` are available on both amd64 and arm64, so + # this works on Apple-silicon hosts. PLv8 was dropped — Marten core + # SQL no longer requires it. 
+ build: + context: ./docker/postgres ports: - "5432:5432" environment: diff --git a/docker/postgres/Dockerfile b/docker/postgres/Dockerfile new file mode 100644 index 0000000000..ba15663543 --- /dev/null +++ b/docker/postgres/Dockerfile @@ -0,0 +1,16 @@ +# Custom PostgreSQL image for Marten local development. +# +# Layers PostGIS 3 (`postgresql-17-postgis-3`) and pgvector +# (`postgresql-17-pgvector`) on top of the official multi-arch `postgres:17` +# image. Both Debian packages are available for amd64 and arm64, so this +# builds and runs on Apple-silicon hosts without emulation. +# +# Used by `docker-compose.yml` to back the Marten test suite plus the +# Marten.PostGIS / Marten.PgVector extension test projects. +FROM postgres:17 + +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + postgresql-17-postgis-3 \ + postgresql-17-pgvector \ + && rm -rf /var/lib/apt/lists/* diff --git a/docs/.vitepress/config.mts b/docs/.vitepress/config.mts index bcae7a1701..705a84dd28 100644 --- a/docs/.vitepress/config.mts +++ b/docs/.vitepress/config.mts @@ -424,6 +424,14 @@ const config: UserConfig = { text: 'Read Replicas', link: '/postgres/read-replicas' }, + { + text: 'PostGIS spatial support', + link: '/postgres/postgis' + }, + { + text: 'pgvector support', + link: '/postgres/pgvector' + }, // { // text: 'Backup and restore', // link: '/postgres/backup-restore/' diff --git a/docs/cSpell.json b/docs/cSpell.json index f88e0bc6ec..9c39bd83ad 100644 --- a/docs/cSpell.json +++ b/docs/cSpell.json @@ -104,6 +104,8 @@ "ngram", "ngrams", "pgcrypto", + "pgvector", + "Ollama", "Jeremie", "Chassaing", "dedup" diff --git a/docs/postgres/pgvector.md b/docs/postgres/pgvector.md new file mode 100644 index 0000000000..9bfc26f427 --- /dev/null +++ b/docs/postgres/pgvector.md @@ -0,0 +1,181 @@ +# pgvector Support + +`Marten.PgVector` is an optional companion package that adds vector-similarity support to Marten on top of the [pgvector](https://github.com/pgvector/pgvector) 
PostgreSQL extension. It is published from the Marten repo under the MIT license and ships as the `Marten.PgVector` NuGet package. + +What it gives you: + +- a one-line `UsePgVector()` opt-in that registers the `vector` extension on every database Marten manages (including per-tenant databases) +- a `VectorSearchAsync` extension on `IQuerySession` that runs index-accelerated nearest-neighbor searches against a vector-typed property of a document +- a `VectorProjection` base class for event-sourced projections that maintain an embedding table alongside your stream, with content-hash skipping so unchanged content is not re-embedded +- an `IEmbeddingProvider` interface — `Marten.PgVector` is AI-model-agnostic; bring OpenAI, Ollama, a local model, or anything else + +## Installation + +```shell +dotnet add package Marten.PgVector +``` + +Your local PostgreSQL must ship the `vector` extension. The Dockerfile under `docker/postgres/Dockerfile` in this repo layers `postgresql-17-pgvector` (and `postgresql-17-postgis-3`) on the official multi-arch `postgres:17` image. In CI the [`pgvector/pgvector:pg17`](https://hub.docker.com/r/pgvector/pgvector) image is used. + +## Enabling pgvector on a store + +```csharp +using Marten; +using Marten.PgVector; + +var store = DocumentStore.For(opts => +{ + opts.Connection(connectionString); + + // 1. Adds CREATE EXTENSION IF NOT EXISTS vector to every database + // 2. Calls NpgsqlDataSourceBuilder.UseVector() so the Pgvector.Vector + // type round-trips through Npgsql + opts.UsePgVector(); + + opts.RegisterDocumentType(); +}); +``` + +`UsePgVector()` is multi-tenant aware. The single-server-per-tenant, master-table, and sharded tenancy strategies all create the extension in each tenant database via Marten's `ExtendedSchemaObjects`, which addresses the long-standing issue of extensions only being created on the default database ([#2515](https://github.com/JasperFx/marten/issues/2515)). 
+ +## Storing vectors on a document + +Put a `float[]` (or `Pgvector.Vector`) on the document. The array round-trips into the JSONB document and is cast to `vector(N)` at query time. + +```csharp +public class ProductWithVector +{ + public Guid Id { get; set; } + public string Name { get; set; } = ""; + + // Stored as a float[] inside JSONB; cast to vector() at query time. + public float[]? Embedding { get; set; } + public string Category { get; set; } = ""; +} +``` + +## Vector similarity search + +`VectorSearchAsync` runs an ordered nearest-neighbor query against the chosen vector property. The three distance functions match the three pgvector index operator classes: + +| `DistanceFunction` | pgvector operator | Index ops class | Typical use | +| ------------------ | ----------------- | --------------------- | ------------------------- | +| `L2` | `<->` | `vector_l2_ops` | Euclidean distance | +| `Cosine` (default) | `<=>` | `vector_cosine_ops` | Text embeddings | +| `InnerProduct` | `<#>` | `vector_ip_ops` | Inner product (negative) | + +```csharp +using Pgvector; + +var queryVector = new Vector(new float[] { 1.0f, 0.0f, 0.0f }); + +await using var q = store.QuerySession(); + +var hits = await q.VectorSearchAsync( + x => x.Embedding, + queryVector, + limit: 10, + distance: DistanceFunction.L2); +``` + +In conjoined multi-tenancy stores (`AllDocumentsAreMultiTenanted` + a tenant-scoped session) the search adds an automatic `tenant_id` filter so a tenant only sees its own vectors. Database-per-tenant setups are isolated at the connection level and need no extra filtering. + +## Event-sourced vector projection + +`VectorProjection` is a base class for projections that maintain an embedding table alongside your stream. It handles the boilerplate of mapping events to text, hashing content, calling your `IEmbeddingProvider`, and writing the embeddings — skipping the embedding API call when content has not changed. 
+ +```csharp +public record ProductCreated(Guid ProductId, string Name, string Description); +public record ProductUpdated(Guid ProductId, string Description); +public record ProductDeleted(Guid ProductId); + +public class ProductSearchProjection : VectorProjection +{ + public ProductSearchProjection(IEmbeddingProvider provider) + : base("product_search_vectors", provider) + { + } + + protected override void Configure(VectorProjectionMapping map) + { + map.Map<ProductCreated>( + e => $"{e.Name} {e.Description}", + e => e.ProductId); + + map.Map<ProductUpdated>( + e => e.Description, + e => e.ProductId); + + map.Delete<ProductDeleted>(); + } +} +``` + +Register it like any other projection, and also register the projection's storage table as a schema object so Marten creates it: + +```csharp +var projection = new ProductSearchProjection(myEmbeddingProvider); + +var store = DocumentStore.For(opts => +{ + opts.Connection(connectionString); + opts.UsePgVector(); + + opts.Projections.Add(projection, ProjectionLifecycle.Async); + opts.Storage.ExtendedSchemaObjects.Add( + projection.BuildTable(opts.Events.DatabaseSchemaName)); + + opts.Events.AddEventType<ProductCreated>(); + opts.Events.AddEventType<ProductUpdated>(); + opts.Events.AddEventType<ProductDeleted>(); +}); +``` + +The created table has the shape: + +| Column | Type | Notes | +| -------------- | ------------- | --------------------------------------------------------------------------- | +| `id` | `uuid` | Primary key — the projection's logical identity (defaults to `StreamId`) | +| `embedding` | `vector(N)` | `N` comes from `IEmbeddingProvider.Dimensions` | +| `content_text` | `text` | The source text that was embedded | +| `content_hash` | `text` | SHA-256 of `content_text` — used to skip re-embedding | +| `metadata` | `jsonb` | Reserved for caller-supplied metadata | +| `last_updated` | `timestamptz` | `now()` default, refreshed on upsert | + +### Querying the projection table + +`VectorProjectionSearchAsync` runs the canonical ordered-by-distance query against the projection table and returns the 
`Guid` id, distance, and the original content text: + +```csharp +var results = await q.VectorProjectionSearchAsync( + "product_search_vectors", + (await myEmbeddingProvider.GenerateEmbeddingsAsync(new[] { "red running shoes" }))[0], + limit: 10, + distance: DistanceFunction.Cosine); + +foreach (var r in results) +{ + Console.WriteLine($"{r.Id} distance={r.Distance} {r.ContentText}"); +} +``` + +## Bring-your-own embeddings + +`Marten.PgVector` does not ship a default embedding provider — implement `IEmbeddingProvider` against the model you want: + +```csharp +public interface IEmbeddingProvider +{ + int Dimensions { get; } + Task<Vector[]> GenerateEmbeddingsAsync(string[] texts, CancellationToken ct = default); +} +``` + +`Dimensions` must match the `vector(N)` column the projection creates. Changing `Dimensions` after the table has been created requires a migration — `pgvector` does not let you change the column width in place. + +## Notes & limitations + +- `VectorSearchAsync` runs raw SQL through the session's connection — it does not go through Marten's LINQ provider or compiled-query cache. Document instances are deserialized via the store's `ISerializer`. +- The vector value lives inside the JSONB document and is cast at query time (`(d.data->>'<PropertyName>')::vector(N)`). For large tables, add an [HNSW or IVFFlat](https://github.com/pgvector/pgvector#indexing) index on that expression to keep similarity queries index-accelerated. +- Only simple member access expressions are supported in the vector property selector (`x => x.Embedding`), matching the Marten LINQ conventions. +- `VectorProjection` requires async execution — the synchronous `IProjection.Apply` overload throws. diff --git a/docs/postgres/postgis.md b/docs/postgres/postgis.md new file mode 100644 index 0000000000..e7c0e0c01d --- /dev/null +++ b/docs/postgres/postgis.md @@ -0,0 +1,134 @@ +# PostGIS Spatial Support + +`Marten.PostGIS` is an optional companion package that adds spatial-data support to Marten on top of the [PostGIS](https://postgis.net/) PostgreSQL extension. 
It is published from the Marten repo under the MIT license and ships as the `Marten.PostGIS` NuGet package. + +What it gives you: + +- a one-line `UsePostGIS()` opt-in that registers the `postgis` extension on every database Marten manages (including per-tenant databases) +- Newtonsoft.Json converters that round-trip [NetTopologySuite](https://github.com/NetTopologySuite/NetTopologySuite) geometry types (`Point`, `Polygon`, `LineString`, …) into the JSONB document +- four query helpers — `NearestToAsync`, `WithinDistanceAsync`, `ContainingAsync`, `IntersectingAsync` — that translate to the canonical PostGIS operators + +## Installation + +```shell +dotnet add package Marten.PostGIS +``` + +Your local PostgreSQL must ship the `postgis` extension. The Dockerfile under `docker/postgres/Dockerfile` in this repo layers `postgresql-17-postgis-3` (and `postgresql-17-pgvector`) on the official multi-arch `postgres:17` image. + +## Enabling PostGIS on a store + +```csharp +using Marten; +using Marten.PostGIS; + +var store = DocumentStore.For(opts => +{ + opts.Connection(connectionString); + + // 1. Adds CREATE EXTENSION IF NOT EXISTS postgis to every database + // 2. Calls NpgsqlDataSourceBuilder.UseNetTopologySuite() so NTS types + // round-trip through Npgsql + // 3. Swaps in a JsonNetSerializer with the NTS GeoJsonSerializer + // converters registered, so NTS geometries serialize as GeoJSON + // inside the document's JSONB column + opts.UsePostGIS(); + + opts.RegisterDocumentType(); +}); +``` + +`UsePostGIS()` is multi-tenant aware. Multi-database setups (single-server-per-tenant, master-table tenancy, sharded tenancy) all create the extension in each tenant database via Marten's `ExtendedSchemaObjects`. + +## Modelling a spatial document + +Put any NetTopologySuite geometry type on your document. 
The default factory `new GeometryFactory(new PrecisionModel(), 4326)` corresponds to [WGS 84](https://en.wikipedia.org/wiki/World_Geodetic_System) — the standard lat/lon coordinate system. + +```csharp +using NetTopologySuite.Geometries; + +public class StoreLocation +{ + public Guid Id { get; set; } + public string Name { get; set; } = ""; + public Point? Location { get; set; } +} + +public class ServiceArea +{ + public Guid Id { get; set; } + public string Name { get; set; } = ""; + public Polygon? Boundary { get; set; } +} +``` + +Insert and load just like any Marten document: + +```csharp +var wgs84 = new GeometryFactory(new PrecisionModel(), 4326); + +await using (var session = store.LightweightSession()) +{ + session.Store(new StoreLocation + { + Id = Guid.NewGuid(), + Name = "Downtown Store", + Location = wgs84.CreatePoint(new Coordinate(-122.33, 47.61)) + }); + await session.SaveChangesAsync(); +} +``` + +## Spatial queries + +The query helpers are extension methods on `IQuerySession`. 
They take a lambda picking the spatial property, a NetTopologySuite (NTS) geometry, and (for distance-flavoured queries) a `SpatialType`: + +| `SpatialType` | PostGIS cast | When to use | +| --------------------- | -------------- | ---------------------------------------------------------------------------------------------- | +| `Geography` (default) | `::geography` | Lat/lon on Earth — distances are in **metres**, accurate for global data | +| `Geometry` | `::geometry` | Cartesian (projected) plane — faster, distances are in the SRID's units (degrees for WGS 84) | + +### Nearest neighbor + +```csharp +await using var q = store.QuerySession(); + +var nearest = await q.NearestToAsync<StoreLocation>( + x => x.Location, + point: wgs84.CreatePoint(new Coordinate(-122.33, 47.61)), + limit: 5, + spatialType: SpatialType.Geometry); +``` + +Translates to `ORDER BY <expr>::<type> <-> $1 LIMIT $2` using the [`<->` KNN operator](https://postgis.net/docs/geometry_distance_knn.html), which is index-accelerated when a GiST index exists on that cast expression. + +### Within a distance + +```csharp +var nearby = await q.WithinDistanceAsync<StoreLocation>( + x => x.Location, + point: downtownSeattle, + distanceMeters: 5000, + spatialType: SpatialType.Geography); +``` + +Translates to `ST_DWithin(<expr>::<type>, $1, $2)` — the canonical index-accelerated distance filter. + +### Containing / intersecting + +```csharp +var coveringAreas = await q.ContainingAsync<ServiceArea>( + x => x.Boundary, downtownSeattle, SpatialType.Geometry); + +var overlappingAreas = await q.IntersectingAsync<ServiceArea>( + x => x.Boundary, marketBoundary, SpatialType.Geometry); +``` + +These map to [`ST_Contains`](https://postgis.net/docs/ST_Contains.html) and [`ST_Intersects`](https://postgis.net/docs/ST_Intersects.html) respectively. + +## Notes & limitations + +- The query helpers run raw SQL through the session's connection — they do not go through Marten's LINQ provider or compiled-query cache. Document instances are deserialized via the store's `ISerializer`. 
+- The spatial value lives inside the JSONB document and is cast to PostGIS types at query time (`ST_GeomFromGeoJSON(d.data->'')::`). For large tables, add a [functional GiST index](https://postgis.net/workshops/postgis-intro/indexing.html) on that expression to keep the spatial operators index-accelerated. +- The Newtonsoft `JsonNetSerializer` is registered for you by `UsePostGIS()`. If you have your own serializer configuration, call `UsePostGIS()` first and tweak the serializer afterwards. +- Only simple member access expressions are supported in the spatial property selector (`x => x.Location`), matching the Marten LINQ conventions. diff --git a/src/Marten.PgVector.Tests/ConjoinedTenancy/conjoined_vector_tests.cs b/src/Marten.PgVector.Tests/ConjoinedTenancy/conjoined_vector_tests.cs new file mode 100644 index 0000000000..4774614497 --- /dev/null +++ b/src/Marten.PgVector.Tests/ConjoinedTenancy/conjoined_vector_tests.cs @@ -0,0 +1,135 @@ +using Marten.Storage; +using Marten.PgVector; +using Marten.PgVector.Tests.SingleTenancy; +using Marten.Testing.Harness; +using Pgvector; +using Shouldly; +using Xunit; + +namespace Marten.PgVector.Tests.ConjoinedTenancy; + +[Collection("Marten.PgVector")] +public class conjoined_vector_tests : IAsyncLifetime +{ + private DocumentStore _store = null!; + + public async Task InitializeAsync() + { + _store = DocumentStore.For(opts => + { + opts.Connection(ConnectionSource.ConnectionString); + opts.DatabaseSchemaName = "pgvector_conjoined"; + opts.AutoCreateSchemaObjects = JasperFx.AutoCreate.All; + + opts.UsePgVector(); + + // Enable conjoined multi-tenancy + opts.Policies.AllDocumentsAreMultiTenanted(); + opts.Events.TenancyStyle = TenancyStyle.Conjoined; + + opts.RegisterDocumentType(); + }); + + await _store.Advanced.Clean.CompletelyRemoveAllAsync(); + await _store.Storage.ApplyAllConfiguredChangesToDatabaseAsync(); + } + + public Task DisposeAsync() + { + _store?.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public 
async Task can_store_and_query_vectors_per_tenant() + { + // Store vectors for tenant A + await using (var session = _store.LightweightSession("tenant_a")) + { + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), + Name = "Tenant A Widget", + Embedding = new float[] { 1, 0, 0 } + }); + await session.SaveChangesAsync(); + } + + // Store vectors for tenant B + await using (var session = _store.LightweightSession("tenant_b")) + { + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), + Name = "Tenant B Widget", + Embedding = new float[] { 0, 1, 0 } + }); + await session.SaveChangesAsync(); + } + + // Query tenant A — should only see tenant A's data + await using (var q = _store.QuerySession("tenant_a")) + { + var results = await q.Query().ToListAsync(); + results.Count.ShouldBe(1); + results[0].Name.ShouldBe("Tenant A Widget"); + } + + // Query tenant B — should only see tenant B's data + await using (var q = _store.QuerySession("tenant_b")) + { + var results = await q.Query().ToListAsync(); + results.Count.ShouldBe(1); + results[0].Name.ShouldBe("Tenant B Widget"); + } + } + + [Fact] + public async Task vector_search_respects_tenant_isolation() + { + // Store different vectors per tenant + await using (var session = _store.LightweightSession("search_a")) + { + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), Name = "A-Close", + Embedding = new float[] { 0.9f, 0.1f, 0 } + }); + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), Name = "A-Far", + Embedding = new float[] { 0, 0, 1 } + }); + await session.SaveChangesAsync(); + } + + await using (var session = _store.LightweightSession("search_b")) + { + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), Name = "B-Close", + Embedding = new float[] { 0.95f, 0.05f, 0 } + }); + await session.SaveChangesAsync(); + } + + // Search tenant A for vectors near [1, 0, 0] + var queryVector = new Vector(new float[] { 1, 0, 0 }); + + await using var qa = 
_store.QuerySession("search_a"); + var resultsA = await qa.VectorSearchAsync( + x => x.Embedding, queryVector, limit: 10, distance: DistanceFunction.L2); + + // Should only find tenant A's documents + resultsA.Count.ShouldBe(2); + resultsA.ShouldAllBe(r => r.Name.StartsWith("A-")); + + // Search tenant B + await using var qb = _store.QuerySession("search_b"); + var resultsB = await qb.VectorSearchAsync( + x => x.Embedding, queryVector, limit: 10, distance: DistanceFunction.L2); + + resultsB.Count.ShouldBe(1); + resultsB[0].Name.ShouldBe("B-Close"); + } +} diff --git a/src/Marten.PgVector.Tests/Helpers/FakeEmbeddingProvider.cs b/src/Marten.PgVector.Tests/Helpers/FakeEmbeddingProvider.cs new file mode 100644 index 0000000000..8df9f12a1d --- /dev/null +++ b/src/Marten.PgVector.Tests/Helpers/FakeEmbeddingProvider.cs @@ -0,0 +1,43 @@ +using Marten.PgVector.Projection; +using Pgvector; + +namespace Marten.PgVector.Tests.Helpers; + +/// +/// Deterministic fake embedding provider for testing. +/// Generates a simple hash-based vector from the input text. +/// +public class FakeEmbeddingProvider : IEmbeddingProvider +{ + public int Dimensions { get; } + + public FakeEmbeddingProvider(int dimensions = 3) + { + Dimensions = dimensions; + } + + public Task GenerateEmbeddingsAsync(string[] texts, CancellationToken ct = default) + { + var results = new Vector[texts.Length]; + for (int i = 0; i < texts.Length; i++) + { + results[i] = GenerateVector(texts[i]); + } + return Task.FromResult(results); + } + + /// + /// Generate a deterministic vector from text — same text always produces the same vector. 
+ /// + public Vector GenerateVector(string text) + { + var hash = text.GetHashCode(); + var values = new float[Dimensions]; + for (int i = 0; i < Dimensions; i++) + { + // Deterministic but varied per dimension + values[i] = (float)Math.Sin(hash + i * 7) * 0.5f + 0.5f; + } + return new Vector(values); + } +} diff --git a/src/Marten.PgVector.Tests/Marten.PgVector.Tests.csproj b/src/Marten.PgVector.Tests/Marten.PgVector.Tests.csproj new file mode 100644 index 0000000000..a25921b557 --- /dev/null +++ b/src/Marten.PgVector.Tests/Marten.PgVector.Tests.csproj @@ -0,0 +1,22 @@ + + + false + enable + enable + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/src/Marten.PgVector.Tests/MultiDatabase/database_per_tenant_vector_tests.cs b/src/Marten.PgVector.Tests/MultiDatabase/database_per_tenant_vector_tests.cs new file mode 100644 index 0000000000..269e9dc1ff --- /dev/null +++ b/src/Marten.PgVector.Tests/MultiDatabase/database_per_tenant_vector_tests.cs @@ -0,0 +1,218 @@ +using Marten.PgVector; +using Marten.PgVector.Tests.SingleTenancy; +using Marten.Storage; +using Marten.Testing.Harness; +using Npgsql; +using Pgvector; +using Shouldly; +using Weasel.Postgresql; +using Weasel.Postgresql.Migrations; +using Xunit; + +namespace Marten.PgVector.Tests.MultiDatabase; + +/// +/// Tests that verify pgvector works with database-per-tenant multi-tenancy. +/// Also serves as verification for issue #2515 — PostgreSQL extensions must be +/// created in EACH tenant database, not just the default. 
+/// +[Collection("Marten.PgVector")] +public class database_per_tenant_vector_tests : IAsyncLifetime +{ + private static readonly string[] TenantDatabases = { "pgvector_tenant1", "pgvector_tenant2", "pgvector_tenant3" }; + + private DocumentStore _store = null!; + private readonly Dictionary _tenantConnStrs = new(); + + public async Task InitializeAsync() + { + // Create the per-tenant databases if they don't exist and drop any stale schema + await using (var conn = new NpgsqlConnection(ConnectionSource.ConnectionString)) + { + await conn.OpenAsync(); + + foreach (var db in TenantDatabases) + { + _tenantConnStrs[db] = await CreateDatabaseIfNotExists(conn, db); + } + } + + _store = DocumentStore.For(opts => + { + opts.DatabaseSchemaName = "pgvector_mt"; + opts.AutoCreateSchemaObjects = JasperFx.AutoCreate.All; + + opts.UsePgVector(); + opts.RegisterDocumentType(); + + // Configure database-per-tenant using SingleServerMultiTenancy + opts.MultiTenantedDatabases(x => + { + foreach (var db in TenantDatabases) + { + x.AddSingleTenantDatabase(_tenantConnStrs[db], db); + } + }); + }); + + await _store.Storage.ApplyAllConfiguredChangesToDatabaseAsync(); + } + + public Task DisposeAsync() + { + _store?.Dispose(); + return Task.CompletedTask; + } + + private static async Task CreateDatabaseIfNotExists(NpgsqlConnection conn, string databaseName) + { + var builder = new NpgsqlConnectionStringBuilder(ConnectionSource.ConnectionString); + + var exists = await conn.DatabaseExists(databaseName); + if (!exists) + { + await new DatabaseSpecification().BuildDatabase(conn, databaseName); + } + + builder.Database = databaseName; + var connectionString = builder.ConnectionString; + + // Wipe any prior pgvector_mt schema so the test runs clean + await SchemaUtils.DropSchema(connectionString, "pgvector_mt"); + + return connectionString; + } + + /// + /// Issue #2515 verification: the pgvector extension must be created in each tenant database. 
+ /// + [Fact] + public async Task vector_extension_is_created_in_each_tenant_database() + { + foreach (var db in TenantDatabases) + { + await using var conn = new NpgsqlConnection(_tenantConnStrs[db]); + await conn.OpenAsync(); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1 FROM pg_extension WHERE extname = 'vector'"; + var result = await cmd.ExecuteScalarAsync(); + + result.ShouldNotBeNull( + $"pgvector extension was NOT created in database '{db}'. " + + "This is the core issue from #2515 — extensions must be created in all tenant databases."); + } + } + + [Fact] + public async Task document_tables_exist_in_each_tenant_database() + { + foreach (var db in TenantDatabases) + { + await using var conn = new NpgsqlConnection(_tenantConnStrs[db]); + await conn.OpenAsync(); + + var tables = await conn.ExistingTablesAsync(); + tables.Any(t => t.Name.Contains("mt_doc_productwithvector")) + .ShouldBeTrue($"Document table not found in database '{db}'"); + } + } + + [Fact] + public async Task can_store_and_load_documents_per_tenant() + { + // Store in tenant1 + await using (var session = _store.LightweightSession(TenantDatabases[0])) + { + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), Name = "T1 Product", + Embedding = new float[] { 1, 0, 0 } + }); + await session.SaveChangesAsync(); + } + + // Store in tenant2 + await using (var session = _store.LightweightSession(TenantDatabases[1])) + { + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), Name = "T2 Product", + Embedding = new float[] { 0, 1, 0 } + }); + await session.SaveChangesAsync(); + } + + // Query tenant1 + await using (var q = _store.QuerySession(TenantDatabases[0])) + { + var results = await q.Query().ToListAsync(); + results.Count.ShouldBe(1); + results[0].Name.ShouldBe("T1 Product"); + } + + // Query tenant2 + await using (var q = _store.QuerySession(TenantDatabases[1])) + { + var results = await q.Query().ToListAsync(); + results.Count.ShouldBe(1); + 
results[0].Name.ShouldBe("T2 Product"); + } + + // Tenant3 should be empty + await using (var q = _store.QuerySession(TenantDatabases[2])) + { + var results = await q.Query().ToListAsync(); + results.Count.ShouldBe(0); + } + } + + [Fact] + public async Task vector_search_works_per_tenant_database() + { + // Store vectors in different tenant databases + await using (var session = _store.LightweightSession(TenantDatabases[0])) + { + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), Name = "T1-Near", + Embedding = new float[] { 0.9f, 0.1f, 0 } + }); + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), Name = "T1-Far", + Embedding = new float[] { 0, 0, 1 } + }); + await session.SaveChangesAsync(); + } + + await using (var session = _store.LightweightSession(TenantDatabases[1])) + { + session.Store(new ProductWithVector + { + Id = Guid.NewGuid(), Name = "T2-Near", + Embedding = new float[] { 0.95f, 0.05f, 0 } + }); + await session.SaveChangesAsync(); + } + + var queryVector = new Vector(new float[] { 1, 0, 0 }); + + // Search tenant1 — should only find tenant1's documents + await using var q1 = _store.QuerySession(TenantDatabases[0]); + var results1 = await q1.VectorSearchAsync( + x => x.Embedding, queryVector, limit: 10, distance: DistanceFunction.L2); + + results1.Count.ShouldBe(2); + results1.ShouldAllBe(r => r.Name.StartsWith("T1-")); + results1[0].Name.ShouldBe("T1-Near"); // closest first + + // Search tenant2 — should only find tenant2's documents + await using var q2 = _store.QuerySession(TenantDatabases[1]); + var results2 = await q2.VectorSearchAsync( + x => x.Embedding, queryVector, limit: 10, distance: DistanceFunction.L2); + + results2.Count.ShouldBe(1); + results2[0].Name.ShouldBe("T2-Near"); + } +} diff --git a/src/Marten.PgVector.Tests/PgVectorCollection.cs b/src/Marten.PgVector.Tests/PgVectorCollection.cs new file mode 100644 index 0000000000..91468e182b --- /dev/null +++ b/src/Marten.PgVector.Tests/PgVectorCollection.cs @@ -0,0 
// xUnit per-test setup: builds a pgvector-enabled store on the "pgvector_tests"
// schema, wipes whatever a previous test left behind, then applies the schema.
public async Task InitializeAsync()
{
    _store = DocumentStore.For(ConfigureStore);

    var store = _store;
    await store.Advanced.Clean.CompletelyRemoveAllAsync();
    await store.Storage.ApplyAllConfiguredChangesToDatabaseAsync();

    static void ConfigureStore(StoreOptions options)
    {
        options.Connection(ConnectionSource.ConnectionString);
        options.DatabaseSchemaName = "pgvector_tests";
        options.AutoCreateSchemaObjects = JasperFx.AutoCreate.All;

        options.UsePgVector();
        options.RegisterDocumentType<ProductWithVector>();
    }
}
// Stores three axis-aligned unit vectors plus one near the X axis, then checks that an
// L2 search for [1, 0, 0] limited to two rows returns "A" followed by "Near A".
[Fact]
public async Task can_search_by_vector_similarity()
{
    static ProductWithVector Product(string name, params float[] embedding) =>
        new() { Id = Guid.NewGuid(), Name = name, Embedding = embedding };

    ProductWithVector[] seeded =
    {
        Product("A", 1, 0, 0),
        Product("B", 0, 1, 0),
        Product("C", 0, 0, 1),
        Product("Near A", 0.9f, 0.1f, 0),
    };

    await using (var writer = _store.LightweightSession())
    {
        foreach (var item in seeded)
        {
            writer.Store(item);
        }

        await writer.SaveChangesAsync();
    }

    var probe = new Vector(new float[] { 1, 0, 0 });

    await using var reader = _store.QuerySession();
    var hits = await reader.VectorSearchAsync<ProductWithVector>(
        x => x.Embedding, probe, limit: 2, distance: DistanceFunction.L2);

    hits.Count.ShouldBe(2);
    hits[0].Name.ShouldBe("A");
    hits[1].Name.ShouldBe("Near A");
}
/// <summary>
/// Test projection that keeps one embedding row per product in the
/// "product_search_vectors" table.
/// </summary>
public class ProductSearchProjection(IEmbeddingProvider provider)
    : VectorProjection("product_search_vectors", provider)
{
    /// <summary>
    /// Created events embed "Name Description", updated events embed the new description,
    /// and deleted events remove the row. Both mappings key the row by the event's ProductId.
    /// </summary>
    protected override void Configure(VectorProjectionMapping map)
    {
        map
            .Map<ProductCreated>(
                contentSelector: created => $"{created.Name} {created.Description}",
                idSelector: created => created.ProductId)
            .Map<ProductUpdated>(
                contentSelector: updated => updated.Description,
                idSelector: updated => updated.ProductId)
            .Delete<ProductDeleted>();
    }
}
// Appending a ProductCreated event through the inline projection must produce exactly one
// embedding row, keyed by the product id and carrying the "Name Description" text.
[Fact]
public async Task creates_embedding_on_event_append()
{
    const string name = "Widget";
    const string description = "A fantastic widget for all purposes";
    const string embeddedText = "Widget A fantastic widget for all purposes";

    var productId = Guid.NewGuid();

    await using var session = _store.LightweightSession();
    var created = new ProductCreated(productId, name, description);
    session.Events.StartStream(productId, created);
    await session.SaveChangesAsync();

    // Probe the projection table with the embedding of the exact stored text.
    var probe = _embedder.GenerateVector(embeddedText);
    var hits = await session.VectorProjectionSearchAsync(
        "product_search_vectors",
        probe,
        limit: 10,
        distance: DistanceFunction.L2);

    hits.Count.ShouldBe(1);

    var hit = hits[0];
    hit.Id.ShouldBe(productId);
    hit.ContentText.ShouldBe(embeddedText);
}
// Verifies the content-hash short circuit: appending an event whose extracted text is
// identical to what is already stored must not call the embedding provider again.
[Fact]
public async Task skips_re_embedding_when_content_unchanged()
{
    var productId = Guid.NewGuid();
    var callCount = 0;
    var countingEmbedder = new CallCountingEmbeddingProvider(_embedder, () => callCount++);

    var projection = new ProductSearchProjection(countingEmbedder);

    // This test needs its own store (separate schema, counting embedder). A `using`
    // declaration guarantees the store is disposed even when an assertion below throws;
    // a trailing store.Dispose() would be skipped on failure and leak the store.
    using var store = DocumentStore.For(opts =>
    {
        opts.Connection(ConnectionSource.ConnectionString);
        opts.DatabaseSchemaName = "pgvector_hash_tests";
        opts.AutoCreateSchemaObjects = JasperFx.AutoCreate.All;
        opts.UsePgVector();
        opts.Projections.Add(projection, ProjectionLifecycle.Inline);
        opts.Storage.ExtendedSchemaObjects.Add(projection.BuildTable("pgvector_hash_tests"));
        opts.Events.AddEventType<ProductCreated>();
    });

    await store.Advanced.Clean.CompletelyRemoveAllAsync();
    await store.Storage.ApplyAllConfiguredChangesToDatabaseAsync();

    // First append — should call embedder
    await using (var session = store.LightweightSession())
    {
        session.Events.StartStream(productId,
            new ProductCreated(productId, "Widget", "Same content"));
        await session.SaveChangesAsync();
    }

    var firstCallCount = callCount;
    firstCallCount.ShouldBeGreaterThan(0);

    // Append same content again — embedder should NOT be called
    await using (var session = store.LightweightSession())
    {
        session.Events.Append(productId,
            new ProductCreated(productId, "Widget", "Same content"));
        await session.SaveChangesAsync();
    }

    callCount.ShouldBe(firstCallCount); // No additional calls
}
// Three products are projected in one commit; searching with the exact embedding of the
// first one must return all three, with the exact match first at distance zero.
[Fact]
public async Task multiple_products_searchable_by_similarity()
{
    var redShoes = Guid.NewGuid();
    var blueShoes = Guid.NewGuid();
    var gardenHose = Guid.NewGuid();

    await using (var writer = _store.LightweightSession())
    {
        var events = writer.Events;
        events.StartStream(redShoes,
            new ProductCreated(redShoes, "Red Shoes", "Bright red running shoes"));
        events.StartStream(blueShoes,
            new ProductCreated(blueShoes, "Blue Shoes", "Navy blue casual shoes"));
        events.StartStream(gardenHose,
            new ProductCreated(gardenHose, "Garden Hose", "50ft expandable garden hose"));
        await writer.SaveChangesAsync();
    }

    var probe = _embedder.GenerateVector("Red Shoes Bright red running shoes");

    await using var reader = _store.QuerySession();
    var hits = await reader.VectorProjectionSearchAsync(
        "product_search_vectors",
        probe,
        limit: 10,
        distance: DistanceFunction.L2);

    hits.Count.ShouldBe(3);

    // The probe is the embedding of the first product's own text, so it ranks first
    // and its L2 distance is exactly zero.
    var best = hits[0];
    best.Id.ShouldBe(redShoes);
    best.Distance.ShouldBe(0f);
}
/// <summary>
/// Distance functions supported by pgvector for similarity search.
/// </summary>
public enum DistanceFunction
{
    /// <summary>
    /// Euclidean (L2) distance. Operator: <c>&lt;-&gt;</c>.
    /// Index ops class: <c>vector_l2_ops</c>.
    /// </summary>
    L2,

    /// <summary>
    /// Cosine distance. Operator: <c>&lt;=&gt;</c>.
    /// Index ops class: <c>vector_cosine_ops</c>.
    /// Most common for text embeddings.
    /// </summary>
    Cosine,

    /// <summary>
    /// Inner product (negative). Operator: <c>&lt;#&gt;</c>.
    /// Index ops class: <c>vector_ip_ops</c>.
    /// </summary>
    InnerProduct
}

/// <summary>
/// Maps <see cref="DistanceFunction"/> values to the SQL fragments pgvector expects.
/// </summary>
internal static class DistanceFunctionExtensions
{
    /// <summary>
    /// Returns the pgvector binary operator used to compute this distance.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="f"/> is not a defined <see cref="DistanceFunction"/> member.
    /// </exception>
    public static string Operator(this DistanceFunction f) => f switch
    {
        DistanceFunction.L2 => "<->",
        DistanceFunction.Cosine => "<=>",
        DistanceFunction.InnerProduct => "<#>",
        // Carry the offending value so a bad cast is diagnosable from the exception alone.
        _ => throw new ArgumentOutOfRangeException(nameof(f), f, "Unknown pgvector distance function")
    };

    /// <summary>
    /// Returns the pgvector index operator class that matches this distance.
    /// </summary>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="f"/> is not a defined <see cref="DistanceFunction"/> member.
    /// </exception>
    public static string OpsClass(this DistanceFunction f) => f switch
    {
        DistanceFunction.L2 => "vector_l2_ops",
        DistanceFunction.Cosine => "vector_cosine_ops",
        DistanceFunction.InnerProduct => "vector_ip_ops",
        _ => throw new ArgumentOutOfRangeException(nameof(f), f, "Unknown pgvector distance function")
    };
}
/// <summary>
/// Search for documents of type <typeparamref name="T"/> by vector similarity.
/// The vector is read from the document's JSONB <c>data</c> column, cast to
/// <c>vector(N)</c> server-side, and compared against <paramref name="queryVector"/>.
/// </summary>
/// <param name="session">The query session; determines the database (and tenant) searched.</param>
/// <param name="vectorProperty">A simple property or field access selecting the vector member.</param>
/// <param name="queryVector">The probe vector. Its length is used as the <c>vector(N)</c> dimension.</param>
/// <param name="limit">Maximum number of documents to return. Must not be negative.</param>
/// <param name="distance">The pgvector distance operator to order by.</param>
/// <param name="token">Cancels opening the connection and reading the results.</param>
/// <returns>Matching documents ordered nearest-first; documents whose vector member is null are excluded.</returns>
/// <exception cref="ArgumentNullException">A required argument is null.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/> is negative.</exception>
/// <exception cref="ArgumentException">
/// <paramref name="queryVector"/> is empty, or <paramref name="vectorProperty"/> is not a simple member access.
/// </exception>
public static async Task<IReadOnlyList<T>> VectorSearchAsync<T>(
    this IQuerySession session,
    Expression<Func<T, object?>> vectorProperty,
    Vector queryVector,
    int limit = 10,
    DistanceFunction distance = DistanceFunction.Cosine,
    CancellationToken token = default) where T : class
{
    ArgumentNullException.ThrowIfNull(session);
    ArgumentNullException.ThrowIfNull(vectorProperty);
    ArgumentNullException.ThrowIfNull(queryVector);
    // Fail here with a clear message instead of surfacing a PostgreSQL
    // "LIMIT must not be negative" error from the server.
    ArgumentOutOfRangeException.ThrowIfNegative(limit);

    var dimensions = queryVector.ToArray().Length;
    if (dimensions == 0)
    {
        // The dimension is spliced into "vector(N)"; N = 0 would produce invalid SQL.
        throw new ArgumentException("Query vector must have at least one dimension", nameof(queryVector));
    }

    var store = (DocumentStore)session.DocumentStore;
    var tableName = ((IReadOnlyStoreOptions)store.Options).Schema.For<T>();

    // The JSONB key is the CLR member name as written in the expression.
    // NOTE(review): this does not consult the serializer, so it presumably breaks when the
    // store is configured with a casing policy that renames members (e.g. camelCase) — confirm.
    var member = GetMemberInfo(vectorProperty);
    var jsonPath = member.Name;

    var op = distance.Operator();

    var whereClause = $"d.data->>'{jsonPath}' IS NOT NULL";

    // Apply tenant_id filtering only for conjoined tenancy (shared table with tenant_id column).
    // For database-per-tenant, isolation is handled by connecting to the tenant's database.
    var tenantId = session.TenantId;
    var isSingleDatabase = store.Options.Tenancy.Cardinality == JasperFx.Descriptors.DatabaseCardinality.Single;
    var hasTenantFilter = isSingleDatabase
                          && !string.IsNullOrEmpty(tenantId)
                          && tenantId != JasperFx.StorageConstants.DefaultTenantId;

    if (hasTenantFilter)
    {
        whereClause += " AND d.tenant_id = $3";
    }

    // Pass the query vector as its text form ([f1,f2,…]) and cast to vector(N)
    // server-side instead of binding a Pgvector.Vector parameter. UseVector()
    // registers a Pgvector.Vector ↔ "vector" OID mapping on the NpgsqlDataSource,
    // but the data source caches pg_type the first time it opens a connection —
    // if the "vector" extension is created later (e.g. by Marten's schema
    // migration on the same data source), the cache is stale and parameter
    // resolution throws "Cannot resolve 'vector' to a fully qualified datatype
    // name." Routing through text + an explicit cast avoids that.
    var sql = $"select d.data from {tableName} d " +
              $"WHERE {whereClause} " +
              $"ORDER BY (d.data->>'{jsonPath}')::vector({dimensions}) {op} $1::vector({dimensions}) LIMIT $2";

    var results = new List<T>();

    var database = session.As<QuerySession>().Database;
    await using var conn = database.CreateConnection();
    await conn.OpenAsync(token).ConfigureAwait(false);

    await using var cmd = conn.CreateCommand();
    cmd.CommandText = sql;
    cmd.Parameters.Add(new NpgsqlParameter { Value = queryVector.ToString(), NpgsqlDbType = NpgsqlTypes.NpgsqlDbType.Text });
    cmd.Parameters.Add(new NpgsqlParameter { Value = limit });
    if (hasTenantFilter)
    {
        cmd.Parameters.Add(new NpgsqlParameter { Value = tenantId });
    }

    await using var reader = await cmd.ExecuteReaderAsync(token).ConfigureAwait(false);
    var serializer = store.Serializer;

    while (await reader.ReadAsync(token).ConfigureAwait(false))
    {
        var json = await reader.GetFieldValueAsync<string>(0, token).ConfigureAwait(false);

        // Dispose the per-row buffer stream once the document has been materialized.
        using var stream = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(json));
        var doc = serializer.FromJson<T>(stream);
        if (doc is not null)
        {
            results.Add(doc);
        }
    }

    return results;
}

/// <summary>
/// Resolves the member targeted by a simple <c>x => x.Member</c> expression, looking through
/// the Convert node the compiler inserts when the member type is boxed to object.
/// </summary>
/// <exception cref="ArgumentException">The expression is anything other than a member access.</exception>
private static MemberInfo GetMemberInfo<T>(Expression<Func<T, object?>> expression)
{
    var body = expression.Body;
    if (body is UnaryExpression { NodeType: ExpressionType.Convert } unary)
    {
        body = unary.Operand;
    }

    return body switch
    {
        MemberExpression memberExpr => memberExpr.Member,
        _ => throw new ArgumentException("Expression must be a simple property or field access")
    };
}
/// <summary>
/// Creates the projection and immediately runs <see cref="Configure"/> to collect the
/// event-to-content mappings.
/// </summary>
/// <param name="tableName">
/// Unqualified name of the embedding table. It is interpolated into SQL as an identifier,
/// so it must be a trusted, valid PostgreSQL table name.
/// </param>
/// <param name="provider">Generates the embeddings and supplies the vector dimension.</param>
/// <exception cref="ArgumentException"><paramref name="tableName"/> is null, empty or whitespace.</exception>
/// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
protected VectorProjection(string tableName, IEmbeddingProvider provider)
{
    // Fail at registration time rather than with a malformed SQL statement or a
    // NullReferenceException the first time events are projected.
    ArgumentException.ThrowIfNullOrWhiteSpace(tableName);
    ArgumentNullException.ThrowIfNull(provider);

    _tableName = tableName;
    _provider = provider;

    // Configure is invoked from the base constructor, i.e. before any derived-class
    // constructor body has run — overrides must not depend on derived constructor state.
    Configure(new VectorProjectionMapping(this));
}
/// <summary>
/// Applies a batch of events to the embedding table: rows are deleted for registered delete
/// event types and upserted (with a freshly generated embedding) for mapped event types
/// whose extracted content differs from the stored content hash.
/// </summary>
/// <remarks>
/// Events are folded in order into one net operation per row id, so the last event for an
/// id wins. This keeps "create then delete" in a single batch from leaving a stale row and
/// ensures only the final content of each row is hashed and embedded.
/// </remarks>
/// <param name="operations">The active document operations; only used to reach the store.</param>
/// <param name="allEvents">Events to apply, in the order they should take effect.</param>
/// <param name="cancellation">Cancels database and embedding-provider calls.</param>
/// <exception cref="InvalidOperationException">
/// The embedding provider returned a different number of vectors than texts requested.
/// </exception>
private async Task ApplyEventsAsync(IDocumentOperations operations, IReadOnlyList<IEvent> allEvents,
    CancellationToken cancellation)
{
    if (allEvents.Count == 0) return;

    var store = (DocumentStore)operations.DocumentStore;
    var schemaName = store.Options.Events.DatabaseSchemaName;
    var qualifiedTable = QualifiedTableName(schemaName);

    // Net effect per row id after replaying the batch in order: text => upsert, null => delete.
    var pending = new Dictionary<Guid, string?>();

    foreach (var @event in allEvents)
    {
        if (_deleteTypes.Contains(@event.EventType))
        {
            // Delete events carry no id selector; the row is keyed by the stream id.
            pending[@event.StreamId] = null;
            continue;
        }

        // First registered mapping for the event type wins.
        var mapping = _mappings.FirstOrDefault(m => m.EventType == @event.EventType);
        if (mapping is null) continue;

        var id = mapping.ExtractId(@event);
        var content = mapping.ExtractContent(@event);
        if (content is not null)
        {
            pending[id] = content;
        }
    }

    if (pending.Count == 0) return;

    var deletions = pending.Where(p => p.Value is null).Select(p => p.Key).ToArray();
    var upserts = pending
        .Where(p => p.Value is not null)
        .Select(p => (Id: p.Key, Content: p.Value!))
        .ToList();

    // NOTE(review): writes go through a separate connection to the store's default database
    // rather than through `operations`, so they presumably are not part of the session's
    // transaction nor routed to a tenant database — confirm this is intended.
    var database = store.Storage.Database;
    await using var conn = database.CreateConnection();
    await conn.OpenAsync(cancellation).ConfigureAwait(false);

    if (deletions.Length > 0)
    {
        // One round trip for all deletions instead of one statement per row.
        await using var deleteCmd = conn.CreateCommand();
        deleteCmd.CommandText = $"DELETE FROM {qualifiedTable} WHERE id = ANY($1)";
        deleteCmd.Parameters.Add(new NpgsqlParameter { Value = deletions });
        await deleteCmd.ExecuteNonQueryAsync(cancellation).ConfigureAwait(false);
    }

    if (upserts.Count == 0) return;

    // Fetch existing content hashes so unchanged rows can skip the embedding call.
    var ids = upserts.Select(u => u.Id).ToArray();
    var existingHashes = new Dictionary<Guid, string>();

    await using (var hashCmd = conn.CreateCommand())
    {
        hashCmd.CommandText = $"SELECT id, content_hash FROM {qualifiedTable} WHERE id = ANY($1)";
        hashCmd.Parameters.Add(new NpgsqlParameter { Value = ids });

        await using var reader = await hashCmd.ExecuteReaderAsync(cancellation).ConfigureAwait(false);
        while (await reader.ReadAsync(cancellation).ConfigureAwait(false))
        {
            existingHashes[reader.GetGuid(0)] = reader.GetString(1);
        }
    }

    // Keep only rows whose content actually changed.
    var needsEmbedding = new List<(Guid Id, string Content, string Hash)>();
    foreach (var (id, content) in upserts)
    {
        var hash = ComputeHash(content);
        if (existingHashes.TryGetValue(id, out var existing) && existing == hash)
            continue;

        needsEmbedding.Add((id, content, hash));
    }

    if (needsEmbedding.Count == 0) return;

    // Batch-generate embeddings
    var texts = needsEmbedding.Select(e => e.Content).ToArray();
    var embeddings = await _provider.GenerateEmbeddingsAsync(texts, cancellation).ConfigureAwait(false);

    if (embeddings is null || embeddings.Length != texts.Length)
    {
        // A misbehaving provider would otherwise surface as an IndexOutOfRangeException
        // part-way through the upsert loop, after some rows were already written.
        throw new InvalidOperationException(
            $"Embedding provider returned {embeddings?.Length ?? 0} embeddings for {texts.Length} texts");
    }

    // Bind each embedding as its text form and cast to vector(N) server-side: the
    // NpgsqlDataSource type-info cache is unreliable across schema migrations that
    // create the "vector" extension.
    var dimensions = _provider.Dimensions;
    var upsertSql =
        $"INSERT INTO {qualifiedTable} (id, embedding, content_text, content_hash, last_updated) " +
        $"VALUES ($1, $2::vector({dimensions}), $3, $4, now()) " +
        $"ON CONFLICT (id) DO UPDATE SET embedding = $2::vector({dimensions}), content_text = $3, content_hash = $4, last_updated = now()";

    for (var i = 0; i < needsEmbedding.Count; i++)
    {
        var (id, content, hash) = needsEmbedding[i];
        var embedding = embeddings[i];

        await using var upsertCmd = conn.CreateCommand();
        upsertCmd.CommandText = upsertSql;
        upsertCmd.Parameters.Add(new NpgsqlParameter { Value = id });
        upsertCmd.Parameters.Add(new NpgsqlParameter { Value = embedding.ToString(), NpgsqlDbType = NpgsqlTypes.NpgsqlDbType.Text });
        upsertCmd.Parameters.Add(new NpgsqlParameter { Value = content });
        upsertCmd.Parameters.Add(new NpgsqlParameter { Value = hash });

        await upsertCmd.ExecuteNonQueryAsync(cancellation).ConfigureAwait(false);
    }
}
/// <summary>
/// Binds one event type to the delegates that produce the embedding text and the row id.
/// </summary>
internal class VectorEventMapping<TEvent>(
    Func<TEvent, string> contentSelector,
    Func<TEvent, Guid>? idSelector) : IVectorEventMapping
{
    /// <summary>The event type this mapping handles.</summary>
    public Type EventType => typeof(TEvent);

    /// <summary>
    /// Returns the row id for the event: the id selector's result when one was supplied,
    /// otherwise the event's stream id.
    /// </summary>
    public Guid ExtractId(IEvent @event) =>
        idSelector is null
            ? @event.StreamId
            : idSelector((TEvent)@event.Data);

    /// <summary>
    /// Returns the text to embed for the event, or null when the content selector
    /// (or the cast of the event data) throws.
    /// </summary>
    public string? ExtractContent(IEvent @event)
    {
        try
        {
            var payload = (TEvent)@event.Data;
            return contentSelector(payload);
        }
        catch
        {
            // Best effort: any failure while extracting content yields no content.
            return null;
        }
    }
}
/// <summary>
/// Search a VectorProjection's embedding table by similarity.
/// </summary>
/// <param name="session">The query session; determines the database searched.</param>
/// <param name="projectionTableName">
/// Unqualified name of the projection table, resolved inside the event store schema.
/// It is interpolated into the SQL text as an identifier and cannot be parameterized —
/// never pass untrusted input.
/// </param>
/// <param name="queryVector">The probe vector. Its length is used as the <c>vector(N)</c> dimension.</param>
/// <param name="limit">Maximum number of rows to return. Must not be negative.</param>
/// <param name="distance">The pgvector distance operator to order by.</param>
/// <param name="token">Cancels opening the connection and reading the results.</param>
/// <returns>Rows ordered nearest-first, each with its id, distance and stored text.</returns>
/// <exception cref="ArgumentNullException">A required argument is null.</exception>
/// <exception cref="ArgumentException">
/// <paramref name="projectionTableName"/> is blank, or <paramref name="queryVector"/> is empty.
/// </exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="limit"/> is negative.</exception>
public static async Task<IReadOnlyList<VectorSearchResult>> VectorProjectionSearchAsync(
    this IQuerySession session,
    string projectionTableName,
    Vector queryVector,
    int limit = 10,
    DistanceFunction distance = DistanceFunction.Cosine,
    CancellationToken token = default)
{
    ArgumentNullException.ThrowIfNull(session);
    ArgumentException.ThrowIfNullOrWhiteSpace(projectionTableName);
    ArgumentNullException.ThrowIfNull(queryVector);
    // Fail here with a clear message instead of a server-side "LIMIT must not be negative".
    ArgumentOutOfRangeException.ThrowIfNegative(limit);

    var dimensions = queryVector.ToArray().Length;
    if (dimensions == 0)
    {
        // The dimension is spliced into "vector(N)"; N = 0 would produce invalid SQL.
        throw new ArgumentException("Query vector must have at least one dimension", nameof(queryVector));
    }

    var store = (DocumentStore)session.DocumentStore;
    var schemaName = store.Options.Events.DatabaseSchemaName;
    var qualifiedTable = $"{schemaName}.{projectionTableName}";
    var op = distance.Operator();

    // Bind the query vector as its text form and cast to vector(N) server-side to bypass
    // the NpgsqlDataSource type-info cache, which can be stale when the "vector" extension
    // is created at migration time on the same data source.
    var sql = $"SELECT id, embedding {op} $1::vector({dimensions}) as distance, content_text " +
              $"FROM {qualifiedTable} " +
              $"ORDER BY embedding {op} $1::vector({dimensions}) LIMIT $2";

    var results = new List<VectorSearchResult>();

    var database = session.As<QuerySession>().Database;
    await using var conn = database.CreateConnection();
    await conn.OpenAsync(token).ConfigureAwait(false);

    await using var cmd = conn.CreateCommand();
    cmd.CommandText = sql;
    cmd.Parameters.Add(new NpgsqlParameter { Value = queryVector.ToString(), NpgsqlDbType = NpgsqlDbType.Text });
    cmd.Parameters.Add(new NpgsqlParameter { Value = limit });

    await using var reader = await cmd.ExecuteReaderAsync(token).ConfigureAwait(false);

    while (await reader.ReadAsync(token).ConfigureAwait(false))
    {
        results.Add(new VectorSearchResult
        {
            Id = reader.GetGuid(0),
            Distance = reader.GetFloat(1),
            ContentText = reader.IsDBNull(2) ? null : reader.GetString(2)
        });
    }

    return results;
}
columnName = null) + { + var member = GetMemberInfo(memberExpression); + Registrations.Add(new VectorFieldRegistration( + typeof(TDoc), member, dimensions, distance, columnName)); + return this; + } + + private static MemberInfo GetMemberInfo(Expression> expression) + { + var body = expression.Body; + + // Handle convert/unbox for value types + if (body is UnaryExpression { NodeType: ExpressionType.Convert } unary) + { + body = unary.Operand; + } + + return body switch + { + MemberExpression memberExpr => memberExpr.Member, + _ => throw new ArgumentException("Expression must be a simple property or field access") + }; + } +} diff --git a/src/Marten.PostGIS.Tests/Marten.PostGIS.Tests.csproj b/src/Marten.PostGIS.Tests/Marten.PostGIS.Tests.csproj new file mode 100644 index 0000000000..fd4b05221f --- /dev/null +++ b/src/Marten.PostGIS.Tests/Marten.PostGIS.Tests.csproj @@ -0,0 +1,22 @@ + + + false + enable + enable + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + diff --git a/src/Marten.PostGIS.Tests/spatial_query_tests.cs b/src/Marten.PostGIS.Tests/spatial_query_tests.cs new file mode 100644 index 0000000000..e0bc34fa87 --- /dev/null +++ b/src/Marten.PostGIS.Tests/spatial_query_tests.cs @@ -0,0 +1,227 @@ +using Marten.PostGIS; +using Marten.Testing.Harness; +using NetTopologySuite.Geometries; +using Shouldly; +using Xunit; + +namespace Marten.PostGIS.Tests; + +public class spatial_query_tests : IAsyncLifetime +{ + private DocumentStore _store = null!; + private static readonly GeometryFactory Wgs84 = new(new PrecisionModel(), 4326); + + public async Task InitializeAsync() + { + _store = DocumentStore.For(opts => + { + opts.Connection(ConnectionSource.ConnectionString); + opts.DatabaseSchemaName = "postgis_tests"; + opts.AutoCreateSchemaObjects = JasperFx.AutoCreate.All; + + opts.UsePostGIS(); + opts.RegisterDocumentType(); + opts.RegisterDocumentType(); + }); + + await 
_store.Advanced.Clean.CompletelyRemoveAllAsync(); + await _store.Storage.ApplyAllConfiguredChangesToDatabaseAsync(); + } + + public Task DisposeAsync() + { + _store?.Dispose(); + return Task.CompletedTask; + } + + [Fact] + public async Task postgis_extension_is_created() + { + await using var conn = _store.Storage.Database.CreateConnection(); + await conn.OpenAsync(); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1 FROM pg_extension WHERE extname = 'postgis'"; + var result = await cmd.ExecuteScalarAsync(); + result.ShouldNotBeNull(); + } + + [Fact] + public async Task can_store_and_load_document_with_point() + { + var store = new StoreLocation + { + Id = Guid.NewGuid(), + Name = "Downtown Store", + Location = Wgs84.CreatePoint(new Coordinate(-122.33, 47.61)) + }; + + await using (var session = _store.LightweightSession()) + { + session.Store(store); + await session.SaveChangesAsync(); + } + + await using (var q = _store.QuerySession()) + { + var loaded = await q.LoadAsync(store.Id); + loaded.ShouldNotBeNull(); + loaded.Name.ShouldBe("Downtown Store"); + // NTS Point is serialized as GeoJSON in the JSONB data column + } + } + + [Fact] + public async Task can_store_and_load_document_with_polygon() + { + var area = new ServiceArea + { + Id = Guid.NewGuid(), + Name = "Metro Area", + Boundary = Wgs84.CreatePolygon(new[] + { + new Coordinate(-122.5, 47.5), + new Coordinate(-122.0, 47.5), + new Coordinate(-122.0, 47.8), + new Coordinate(-122.5, 47.8), + new Coordinate(-122.5, 47.5) // close the ring + }) + }; + + await using (var session = _store.LightweightSession()) + { + session.Store(area); + await session.SaveChangesAsync(); + } + + await using (var q = _store.QuerySession()) + { + var loaded = await q.LoadAsync(area.Id); + loaded.ShouldNotBeNull(); + loaded.Name.ShouldBe("Metro Area"); + } + } + + [Fact] + public async Task nearest_to_returns_ordered_by_distance() + { + // Seattle area stores + var stores = new[] + { + new StoreLocation 
{ Id = Guid.NewGuid(), Name = "Capitol Hill", + Location = Wgs84.CreatePoint(new Coordinate(-122.32, 47.63)) }, + new StoreLocation { Id = Guid.NewGuid(), Name = "Bellevue", + Location = Wgs84.CreatePoint(new Coordinate(-122.20, 47.61)) }, + new StoreLocation { Id = Guid.NewGuid(), Name = "Tacoma", + Location = Wgs84.CreatePoint(new Coordinate(-122.44, 47.25)) }, + }; + + await using (var session = _store.LightweightSession()) + { + foreach (var s in stores) session.Store(s); + await session.SaveChangesAsync(); + } + + // Search from downtown Seattle + var downtown = Wgs84.CreatePoint(new Coordinate(-122.33, 47.61)); + + await using var q = _store.QuerySession(); + var results = await q.NearestToAsync( + x => x.Location, downtown, limit: 3, spatialType: SpatialType.Geometry); + + results.Count.ShouldBe(3); + results[0].Name.ShouldBe("Capitol Hill"); // closest + results[2].Name.ShouldBe("Tacoma"); // farthest + } + + [Fact] + public async Task within_distance_filters_correctly() + { + var stores = new[] + { + new StoreLocation { Id = Guid.NewGuid(), Name = "Nearby", + Location = Wgs84.CreatePoint(new Coordinate(-122.33, 47.62)) }, + new StoreLocation { Id = Guid.NewGuid(), Name = "Far Away", + Location = Wgs84.CreatePoint(new Coordinate(-120.0, 45.0)) }, + }; + + await using (var session = _store.LightweightSession()) + { + foreach (var s in stores) session.Store(s); + await session.SaveChangesAsync(); + } + + // Search within ~50km of downtown Seattle using geometry (degrees) + var downtown = Wgs84.CreatePoint(new Coordinate(-122.33, 47.61)); + + await using var q = _store.QuerySession(); + var results = await q.WithinDistanceAsync( + x => x.Location, downtown, distanceMeters: 0.5, // ~50km in degrees + spatialType: SpatialType.Geometry); + + results.Count.ShouldBe(1); + results[0].Name.ShouldBe("Nearby"); + } + + [Fact] + public async Task containing_finds_polygons_that_contain_a_point() + { + var areas = new[] + { + new ServiceArea + { + Id = Guid.NewGuid(), 
Name = "Seattle Metro", + Boundary = Wgs84.CreatePolygon(new[] + { + new Coordinate(-122.5, 47.4), + new Coordinate(-122.0, 47.4), + new Coordinate(-122.0, 47.8), + new Coordinate(-122.5, 47.8), + new Coordinate(-122.5, 47.4) + }) + }, + new ServiceArea + { + Id = Guid.NewGuid(), Name = "Portland Metro", + Boundary = Wgs84.CreatePolygon(new[] + { + new Coordinate(-123.0, 45.3), + new Coordinate(-122.3, 45.3), + new Coordinate(-122.3, 45.7), + new Coordinate(-123.0, 45.7), + new Coordinate(-123.0, 45.3) + }) + } + }; + + await using (var session = _store.LightweightSession()) + { + foreach (var a in areas) session.Store(a); + await session.SaveChangesAsync(); + } + + // Find which areas contain downtown Seattle + var downtown = Wgs84.CreatePoint(new Coordinate(-122.33, 47.61)); + + await using var q = _store.QuerySession(); + var results = await q.ContainingAsync( + x => x.Boundary, downtown, spatialType: SpatialType.Geometry); + + results.Count.ShouldBe(1); + results[0].Name.ShouldBe("Seattle Metro"); + } +} + +public class StoreLocation +{ + public Guid Id { get; set; } + public string Name { get; set; } = ""; + public Point? Location { get; set; } +} + +public class ServiceArea +{ + public Guid Id { get; set; } + public string Name { get; set; } = ""; + public Polygon? Boundary { get; set; } +} diff --git a/src/Marten.PostGIS/Marten.PostGIS.csproj b/src/Marten.PostGIS/Marten.PostGIS.csproj new file mode 100644 index 0000000000..a8085df8e2 --- /dev/null +++ b/src/Marten.PostGIS/Marten.PostGIS.csproj @@ -0,0 +1,28 @@ + + + PostGIS spatial extension for Marten — enables the "postgis" PostgreSQL extension and exposes spatial query helpers (NearestTo, WithinDistance, Containing, Intersecting) backed by NetTopologySuite geometry types. 
+ true + true + true + false + true + true + true + true + enable + enable + Marten.PostGIS + + + + + + + + + + + + + + diff --git a/src/Marten.PostGIS/PostGISExtensions.cs b/src/Marten.PostGIS/PostGISExtensions.cs new file mode 100644 index 0000000000..8d4de98271 --- /dev/null +++ b/src/Marten.PostGIS/PostGISExtensions.cs @@ -0,0 +1,244 @@ +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using JasperFx.Core.Reflection; +using Marten.Internal.Sessions; +using NetTopologySuite.Geometries; +using Npgsql; +using Npgsql.NetTopologySuite; +using NpgsqlTypes; +using Weasel.Postgresql; + +namespace Marten.PostGIS; + +public static class PostGISExtensions +{ + /// + /// Enable PostGIS support for this Marten store. This registers the PostgreSQL + /// "postgis" extension and configures Npgsql to handle NTS spatial types on all + /// data sources, including tenant databases. + /// + public static StoreOptions UsePostGIS(this StoreOptions opts) + { + opts.ConfigureNpgsqlDataSourceBuilder(b => b.UseNetTopologySuite()); + opts.Storage.ExtendedSchemaObjects.Add(new Extension("postgis")); + + // Register NTS GeoJSON converters with the Newtonsoft.Json serializer + // so that NTS types (Point, Polygon, etc.) are properly serialized/deserialized + // in Marten's JSONB document storage + var serializer = new Marten.Services.JsonNetSerializer(); + serializer.Configure(s => + { + var geoJsonSerializer = NetTopologySuite.IO.GeoJsonSerializer.Create( + s, new GeometryFactory(new PrecisionModel(), 4326)); + foreach (var converter in geoJsonSerializer.Converters) + { + s.Converters.Add(converter); + } + }); + opts.Serializer(serializer); + + return opts; + } + + /// + /// Find the nearest documents to a point, ordered by distance. + /// Uses PostGIS ST_Distance for ordering. 
+ /// + public static async Task> NearestToAsync( + this IQuerySession session, + Expression> spatialProperty, + Point point, + int limit = 10, + SpatialType spatialType = SpatialType.Geography) where T : class + { + var store = (DocumentStore)session.DocumentStore; + var tableName = ((IReadOnlyStoreOptions)store.Options).Schema.For(); + var member = GetMemberInfo(spatialProperty); + var jsonPath = member.Name; + var pgType = spatialType == SpatialType.Geography ? "geography" : "geometry"; + + // Use the <-> KNN operator for index-accelerated nearest neighbor + var sql = $"SELECT d.data FROM {tableName} d " + + $"WHERE d.data->>'{jsonPath}' IS NOT NULL " + + $"ORDER BY ST_GeomFromGeoJSON(d.data->'{jsonPath}')::{ pgType} <-> $1 " + + $"LIMIT $2"; + + return await ExecuteSpatialQuery(session, store, sql, point, limit); + } + + /// + /// Find all documents whose spatial property is within a given distance of a point. + /// Uses PostGIS ST_DWithin for index-accelerated distance filtering. + /// + public static async Task> WithinDistanceAsync( + this IQuerySession session, + Expression> spatialProperty, + Point point, + double distanceMeters, + SpatialType spatialType = SpatialType.Geography) where T : class + { + var store = (DocumentStore)session.DocumentStore; + var tableName = ((IReadOnlyStoreOptions)store.Options).Schema.For(); + var member = GetMemberInfo(spatialProperty); + var jsonPath = member.Name; + var pgType = spatialType == SpatialType.Geography ? 
"geography" : "geometry"; + + var sql = $"SELECT d.data FROM {tableName} d " + + $"WHERE ST_DWithin(" + + $"ST_GeomFromGeoJSON(d.data->'{jsonPath}')::{pgType}, " + + $"$1, $2)"; + + var results = new List(); + var database = session.As().Database; + await using var conn = database.CreateConnection(); + await conn.OpenAsync().ConfigureAwait(false); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + cmd.Parameters.Add(new NpgsqlParameter { Value = point, NpgsqlDbType = NpgsqlDbType.Geometry }); + cmd.Parameters.Add(new NpgsqlParameter { Value = distanceMeters }); + + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var serializer = store.Serializer; + + while (await reader.ReadAsync().ConfigureAwait(false)) + { + var json = await reader.GetFieldValueAsync(0).ConfigureAwait(false); + var doc = serializer.FromJson(new MemoryStream(Encoding.UTF8.GetBytes(json))); + if (doc != null) results.Add(doc); + } + + return results; + } + + /// + /// Find all documents whose spatial property contains the given geometry. + /// Uses PostGIS ST_Contains. + /// + public static async Task> ContainingAsync( + this IQuerySession session, + Expression> spatialProperty, + Geometry geometry, + SpatialType spatialType = SpatialType.Geography) where T : class + { + var store = (DocumentStore)session.DocumentStore; + var tableName = ((IReadOnlyStoreOptions)store.Options).Schema.For(); + var member = GetMemberInfo(spatialProperty); + var jsonPath = member.Name; + var pgType = spatialType == SpatialType.Geography ? 
"geography" : "geometry"; + + var sql = $"SELECT d.data FROM {tableName} d " + + $"WHERE ST_Contains(" + + $"ST_GeomFromGeoJSON(d.data->'{jsonPath}')::{pgType}, " + + $"$1)"; + + var results = new List(); + var database = session.As().Database; + await using var conn = database.CreateConnection(); + await conn.OpenAsync().ConfigureAwait(false); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + cmd.Parameters.Add(new NpgsqlParameter { Value = geometry, NpgsqlDbType = NpgsqlDbType.Geometry }); + + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var serializer = store.Serializer; + + while (await reader.ReadAsync().ConfigureAwait(false)) + { + var json = await reader.GetFieldValueAsync(0).ConfigureAwait(false); + var doc = serializer.FromJson(new MemoryStream(Encoding.UTF8.GetBytes(json))); + if (doc != null) results.Add(doc); + } + + return results; + } + + /// + /// Find all documents whose spatial property intersects the given geometry. + /// Uses PostGIS ST_Intersects. + /// + public static async Task> IntersectingAsync( + this IQuerySession session, + Expression> spatialProperty, + Geometry geometry, + SpatialType spatialType = SpatialType.Geography) where T : class + { + var store = (DocumentStore)session.DocumentStore; + var tableName = ((IReadOnlyStoreOptions)store.Options).Schema.For(); + var member = GetMemberInfo(spatialProperty); + var jsonPath = member.Name; + var pgType = spatialType == SpatialType.Geography ? 
"geography" : "geometry"; + + var sql = $"SELECT d.data FROM {tableName} d " + + $"WHERE ST_Intersects(" + + $"ST_GeomFromGeoJSON(d.data->'{jsonPath}')::{pgType}, " + + $"$1)"; + + var results = new List(); + var database = session.As().Database; + await using var conn = database.CreateConnection(); + await conn.OpenAsync().ConfigureAwait(false); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + cmd.Parameters.Add(new NpgsqlParameter { Value = geometry, NpgsqlDbType = NpgsqlDbType.Geometry }); + + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var serializer = store.Serializer; + + while (await reader.ReadAsync().ConfigureAwait(false)) + { + var json = await reader.GetFieldValueAsync(0).ConfigureAwait(false); + var doc = serializer.FromJson(new MemoryStream(Encoding.UTF8.GetBytes(json))); + if (doc != null) results.Add(doc); + } + + return results; + } + + #region Private helpers + + private static async Task> ExecuteSpatialQuery( + IQuerySession session, DocumentStore store, string sql, + Point point, int limit) where T : class + { + var results = new List(); + var database = session.As().Database; + await using var conn = database.CreateConnection(); + await conn.OpenAsync().ConfigureAwait(false); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + cmd.Parameters.Add(new NpgsqlParameter { Value = point, NpgsqlDbType = NpgsqlDbType.Geometry }); + cmd.Parameters.Add(new NpgsqlParameter { Value = limit }); + + await using var reader = await cmd.ExecuteReaderAsync().ConfigureAwait(false); + var serializer = store.Serializer; + + while (await reader.ReadAsync().ConfigureAwait(false)) + { + var json = await reader.GetFieldValueAsync(0).ConfigureAwait(false); + var doc = serializer.FromJson(new MemoryStream(Encoding.UTF8.GetBytes(json))); + if (doc != null) results.Add(doc); + } + + return results; + } + + private static MemberInfo GetMemberInfo(Expression> expression) + { + var body = 
expression.Body; + if (body is UnaryExpression { NodeType: ExpressionType.Convert } unary) + body = unary.Operand; + + return body switch + { + MemberExpression memberExpr => memberExpr.Member, + _ => throw new ArgumentException("Expression must be a simple property or field access") + }; + } + + #endregion +} diff --git a/src/Marten.PostGIS/SpatialType.cs b/src/Marten.PostGIS/SpatialType.cs new file mode 100644 index 0000000000..e30588bfa5 --- /dev/null +++ b/src/Marten.PostGIS/SpatialType.cs @@ -0,0 +1,20 @@ +namespace Marten.PostGIS; + +/// +/// PostGIS spatial column type. +/// +public enum SpatialType +{ + /// + /// Geodetic coordinate system (lat/lon on Earth's surface). + /// Distances in meters. Accurate for global data. + /// This is the default. + /// + Geography, + + /// + /// Cartesian coordinate system (projected plane). + /// Faster operations but distances are in the coordinate system's units. + /// + Geometry +} diff --git a/src/Marten.slnx b/src/Marten.slnx index 483a1cdfcc..b4e281e65d 100644 --- a/src/Marten.slnx +++ b/src/Marten.slnx @@ -20,6 +20,14 @@ + + + + + + + +