diff --git a/Directory.Build.props b/Directory.Build.props index 86b5ee23ec..fdccb70665 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,7 +1,7 @@ - 8.27.0 + 8.28.0 13.0 Jeremy D. Miller;Babu Annamalai;Jaedyn Tonee https://martendb.io/logo.png diff --git a/Directory.Packages.props b/Directory.Packages.props index f5470a21fc..4fa53f6917 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -13,14 +13,16 @@ - - + + + + diff --git a/src/Marten.SourceGeneration/DiscoveredMartenTypesEmitter.cs b/src/Marten.SourceGeneration/DiscoveredMartenTypesEmitter.cs new file mode 100644 index 0000000000..a502d960a0 --- /dev/null +++ b/src/Marten.SourceGeneration/DiscoveredMartenTypesEmitter.cs @@ -0,0 +1,64 @@ +using System.Collections.Generic; +using System.Text; + +namespace Marten.SourceGeneration; + +/// +/// Emits the DiscoveredMartenTypes static manifest class containing +/// compile-time discovered document, projection, and event types. +/// +internal static class DiscoveredMartenTypesEmitter +{ + public static string Emit( + List documentTypes, + List projectionTypes, + List eventTypes) + { + var sb = new StringBuilder(); + + sb.AppendLine("// "); + sb.AppendLine("// This file is generated by Marten.SourceGeneration."); + sb.AppendLine("// Do not edit manually."); + sb.AppendLine("#nullable enable"); + sb.AppendLine(); + sb.AppendLine("using System;"); + sb.AppendLine("using System.Collections.Generic;"); + sb.AppendLine(); + sb.AppendLine("namespace Marten.Generated;"); + sb.AppendLine(); + sb.AppendLine("/// "); + sb.AppendLine("/// Compile-time discovered Marten types. Used by StoreOptions to bypass"); + sb.AppendLine("/// runtime assembly scanning when source generation is opted into."); + sb.AppendLine("/// "); + sb.AppendLine("public static class DiscoveredMartenTypes"); + sb.AppendLine("{"); + + // Document types + EmitTypeArray(sb, "DocumentTypes", documentTypes); + sb.AppendLine(); + + // Projection types + EmitTypeArray(sb, "ProjectionTypes", projectionTypes); + sb.AppendLine(); + + // Event types + EmitTypeArray(sb, "EventTypes", eventTypes); + + sb.AppendLine("}"); + + return sb.ToString(); + } + + private static void EmitTypeArray(StringBuilder sb, string propertyName, List types) + { + sb.AppendLine($" public static IReadOnlyList {propertyName} => new Type[]"); + sb.AppendLine(" {"); + + foreach (var type in types) + { + sb.AppendLine($" typeof({type}),"); + } + + sb.AppendLine(" };"); + } +} diff --git a/src/Marten.SourceGeneration/Marten.SourceGeneration.csproj b/src/Marten.SourceGeneration/Marten.SourceGeneration.csproj new file mode 100644 index 0000000000..c970284b04 --- /dev/null +++ b/src/Marten.SourceGeneration/Marten.SourceGeneration.csproj @@ -0,0 +1,24 @@ + + + + Source Generator for Marten document and projection type discovery + netstandard2.0 + 12 + enable + false + $(NoWarn);CS8603;CS8602;CS8604 + true + true + + + + + + + + + + false + + + diff --git a/src/Marten.SourceGeneration/MartenTypeDiscoveryGenerator.cs b/src/Marten.SourceGeneration/MartenTypeDiscoveryGenerator.cs new file mode 100644 index 0000000000..248c48bfe4 --- /dev/null +++ b/src/Marten.SourceGeneration/MartenTypeDiscoveryGenerator.cs @@ -0,0 +1,306 @@ +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; + +namespace Marten.SourceGeneration; + +/// +/// Incremental source generator that discovers Marten document types, projection types, +/// and event types at compile time, emitting a static manifest class that can be used +/// to bypass runtime assembly scanning on startup. +/// +[Generator] +public class MartenTypeDiscoveryGenerator : IIncrementalGenerator +{ + // Attribute-based document discovery + private const string DocumentAliasAttributeFullName = "Marten.Schema.DocumentAliasAttribute"; + + // Projection base class names (unbound generic or non-generic) + private static readonly string[] ProjectionBaseTypeNames = new[] + { + "Marten.Events.Aggregation.SingleStreamProjection", + "Marten.Events.Projections.MultiStreamProjection", + "Marten.Events.Projections.EventProjection" + }; + + // Method names used in projections that take event parameters + private static readonly HashSet ApplyMethodNames = new HashSet + { + "Apply", "Create", "ShouldDelete" + }; + + public void Initialize(IncrementalGeneratorInitializationContext context) + { + // Pipeline 1: Find classes with [DocumentAlias] attribute (document types) + var documentTypesByAttribute = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => IsClassWithAttributes(node), + transform: static (ctx, _) => GetDocumentAliasType(ctx)) + .Where(static info => info != null); + + // Pipeline 2: Find classes used with IDocumentSession.Store() or IQuerySession.Query() + var documentTypesByUsage = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => IsGenericInvocation(node), + transform: static (ctx, _) => GetSessionUsageType(ctx)) + .Where(static info => info != null); + + // Pipeline 3: Find projection types (classes inheriting from projection base classes) + var projectionTypes = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: static (node, _) => IsClassWithBaseList(node), + transform: static (ctx, _) => GetProjectionTypeInfo(ctx)) + .Where(static info => info != null); + + // Combine all pipelines with the compilation + var allData = context.CompilationProvider + .Combine(documentTypesByAttribute.Collect()) + .Combine(documentTypesByUsage.Collect()) + .Combine(projectionTypes.Collect()); + + context.RegisterSourceOutput(allData, static (spc, source) => + { + var compilation = source.Left.Left.Left; + var docsByAttribute = source.Left.Left.Right; + var docsByUsage = source.Left.Right; + var projections = source.Right; + + Execute(compilation, docsByAttribute, docsByUsage, projections, spc); + }); + } + + private static bool IsClassWithAttributes(SyntaxNode node) + { + return node is ClassDeclarationSyntax classDecl + && classDecl.AttributeLists.Count > 0; + } + + private static bool IsClassWithBaseList(SyntaxNode node) + { + return node is ClassDeclarationSyntax classDecl + && classDecl.BaseList != null + && classDecl.BaseList.Types.Count > 0; + } + + private static bool IsGenericInvocation(SyntaxNode node) + { + // Looking for method calls like session.Store() or session.Query() + return node is InvocationExpressionSyntax invocation + && invocation.Expression is MemberAccessExpressionSyntax memberAccess + && memberAccess.Name is GenericNameSyntax; + } + + private static DiscoveredTypeInfo GetDocumentAliasType(GeneratorSyntaxContext context) + { + var classDecl = (ClassDeclarationSyntax)context.Node; + var model = context.SemanticModel; + var classSymbol = model.GetDeclaredSymbol(classDecl) as INamedTypeSymbol; + + if (classSymbol == null) return null; + + foreach (var attr in classSymbol.GetAttributes()) + { + var attrName = attr.AttributeClass?.ToDisplayString(); + if (attrName == DocumentAliasAttributeFullName) + { + return new DiscoveredTypeInfo( + classSymbol.ToDisplayString(), + DiscoveredKind.Document); + } + } + + return null; + } + + private static DiscoveredTypeInfo GetSessionUsageType(GeneratorSyntaxContext context) + { + var invocation = (InvocationExpressionSyntax)context.Node; + var memberAccess = (MemberAccessExpressionSyntax)invocation.Expression; + var genericName = (GenericNameSyntax)memberAccess.Name; + + var methodName = genericName.Identifier.Text; + if (methodName != "Store" && methodName != "Query") return null; + + var model = context.SemanticModel; + var symbolInfo = model.GetSymbolInfo(invocation); + var methodSymbol = symbolInfo.Symbol as IMethodSymbol ?? symbolInfo.CandidateSymbols.OfType().FirstOrDefault(); + + if (methodSymbol == null) return null; + + // Check that the containing type is IDocumentSession or IQuerySession + var containingType = methodSymbol.ContainingType?.ToDisplayString(); + if (containingType == null) return null; + + bool isDocumentSession = containingType.StartsWith("Marten.IDocumentSession") + || containingType.StartsWith("Marten.IQuerySession"); + if (!isDocumentSession) return null; + + // Get the type argument + if (methodSymbol.TypeArguments.Length == 0) return null; + var typeArg = methodSymbol.TypeArguments[0]; + + // Skip open generic types + if (typeArg is ITypeParameterSymbol) return null; + + return new DiscoveredTypeInfo( + typeArg.ToDisplayString(), + DiscoveredKind.Document); + } + + private static ProjectionTypeInfo GetProjectionTypeInfo(GeneratorSyntaxContext context) + { + var classDecl = (ClassDeclarationSyntax)context.Node; + var model = context.SemanticModel; + var classSymbol = model.GetDeclaredSymbol(classDecl) as INamedTypeSymbol; + + if (classSymbol == null || classSymbol.IsAbstract) return null; + + // Walk the base type chain to find projection base classes + var currentBase = classSymbol.BaseType; + while (currentBase != null) + { + var baseFullName = currentBase.OriginalDefinition.ToDisplayString(); + + foreach (var projectionBase in ProjectionBaseTypeNames) + { + if (baseFullName.StartsWith(projectionBase)) + { + // Discovered a projection type. Now find event types from Apply/Create methods. + var eventTypes = DiscoverEventTypes(classSymbol); + return new ProjectionTypeInfo( + classSymbol.ToDisplayString(), + baseFullName, + eventTypes); + } + } + + currentBase = currentBase.BaseType; + } + + return null; + } + + private static List DiscoverEventTypes(INamedTypeSymbol projectionClass) + { + var eventTypes = new List(); + var seen = new HashSet(); + + // Walk up the type hierarchy to find Apply/Create methods + var current = projectionClass; + while (current != null) + { + foreach (var member in current.GetMembers()) + { + if (member is IMethodSymbol method && ApplyMethodNames.Contains(method.Name)) + { + // The first parameter (or second if first is the aggregate) is typically the event type + foreach (var param in method.Parameters) + { + var paramType = param.Type; + if (paramType is ITypeParameterSymbol) continue; + + var paramTypeName = paramType.ToDisplayString(); + + // Skip known Marten/infrastructure types + if (paramTypeName.StartsWith("Marten.") || + paramTypeName.StartsWith("JasperFx.") || + paramTypeName.StartsWith("System.") || + paramTypeName == "object") + continue; + + if (seen.Add(paramTypeName)) + { + eventTypes.Add(paramTypeName); + } + } + } + } + + current = current.BaseType; + } + + return eventTypes; + } + + private static void Execute( + Compilation compilation, + ImmutableArray docsByAttribute, + ImmutableArray docsByUsage, + ImmutableArray projections, + SourceProductionContext context) + { + // Collect all unique document types + var documentTypes = new HashSet(); + foreach (var doc in docsByAttribute) + { + if (doc != null) documentTypes.Add(doc.FullTypeName); + } + foreach (var doc in docsByUsage) + { + if (doc != null) documentTypes.Add(doc.FullTypeName); + } + + // Collect all unique projection types and event types + var projectionTypeNames = new HashSet(); + var eventTypeNames = new HashSet(); + + foreach (var proj in projections) + { + if (proj == null) continue; + projectionTypeNames.Add(proj.FullTypeName); + foreach (var eventType in proj.EventTypes) + { + eventTypeNames.Add(eventType); + } + } + + // Only emit if we found something + if (documentTypes.Count == 0 && projectionTypeNames.Count == 0 && eventTypeNames.Count == 0) + return; + + var source = DiscoveredMartenTypesEmitter.Emit( + documentTypes.OrderBy(x => x).ToList(), + projectionTypeNames.OrderBy(x => x).ToList(), + eventTypeNames.OrderBy(x => x).ToList()); + + context.AddSource("DiscoveredMartenTypes.g.cs", + SourceText.From(source, Encoding.UTF8)); + } +} + +internal sealed class DiscoveredTypeInfo +{ + public DiscoveredTypeInfo(string fullTypeName, DiscoveredKind kind) + { + FullTypeName = fullTypeName; + Kind = kind; + } + + public string FullTypeName { get; } + public DiscoveredKind Kind { get; } +} + +internal enum DiscoveredKind +{ + Document, + Projection, + Event +} + +internal sealed class ProjectionTypeInfo +{ + public ProjectionTypeInfo(string fullTypeName, string baseTypeName, List eventTypes) + { + FullTypeName = fullTypeName; + BaseTypeName = baseTypeName; + EventTypes = eventTypes; + } + + public string FullTypeName { get; } + public string BaseTypeName { get; } + public List EventTypes { get; } +} diff --git a/src/Marten/DocumentStore.cs b/src/Marten/DocumentStore.cs index 42284d1b6e..3ba0a934bf 100644 --- a/src/Marten/DocumentStore.cs +++ b/src/Marten/DocumentStore.cs @@ -65,7 +65,10 @@ public DocumentStore(StoreOptions options) StorageFeatures.PostProcessConfiguration(); Events.Initialize(this); Options.Projections.DiscoverGeneratedEvolvers(AppDomain.CurrentDomain.GetAssemblies()); - DiscoverNaturalKeyAggregates(AppDomain.CurrentDomain.GetAssemblies()); + // Note: Natural key aggregates are discovered lazily when FetchForWriting + // is called with a type that has [NaturalKey]. Assembly-level scanning was + // removed because it caused spurious InvalidProjectionException failures + // when test projects share compile references with incompatible stream identity types. Options.Projections.AssertValidity(Options); if (Options.LogFactory != null) diff --git a/src/Marten/Events/Aggregation/SingleStreamProjection.cs b/src/Marten/Events/Aggregation/SingleStreamProjection.cs index ba9e1d9ca6..c8d24d02b2 100644 --- a/src/Marten/Events/Aggregation/SingleStreamProjection.cs +++ b/src/Marten/Events/Aggregation/SingleStreamProjection.cs @@ -168,6 +168,10 @@ internal bool IsIdTypeValidForStream(Type idType, StoreOptions options, out Type protected IEnumerable validateDocumentIdentity(StoreOptions options, DocumentMapping mapping) { + // Skip ID type validation for aggregates that use natural keys — + // they intentionally have an ID type that differs from the stream identity + if (NaturalKeyDefinition != null) yield break; + var matches = IsIdTypeValidForStream(mapping.IdType, options, out var expectedType, out var valueTypeInfo); if (!matches) { diff --git a/src/Marten/Events/Daemon/Progress/ProjectionProgressStatement.cs b/src/Marten/Events/Daemon/Progress/ProjectionProgressStatement.cs index b3f21d0f7e..e97b1c7fe2 100644 --- a/src/Marten/Events/Daemon/Progress/ProjectionProgressStatement.cs +++ b/src/Marten/Events/Daemon/Progress/ProjectionProgressStatement.cs @@ -29,7 +29,7 @@ protected override void configure(ICommandBuilder builder) { if (_events.UseOptimizedProjectionRebuilds && _events.EnableExtendedProgressionTracking) { - builder.Append($"select name, last_seq_id, mode, rebuild_threshold, assigned_node, heartbeat, agent_status, pause_reason, running_on_node from {_events.DatabaseSchemaName}.mt_event_progression"); + builder.Append($"select name, last_seq_id, mode, rebuild_threshold, assigned_node, heartbeat, agent_status, pause_reason, running_on_node, warning_behind_threshold, critical_behind_threshold from {_events.DatabaseSchemaName}.mt_event_progression"); } else if (_events.UseOptimizedProjectionRebuilds) { @@ -37,7 +37,7 @@ protected override void configure(ICommandBuilder builder) } else if (_events.EnableExtendedProgressionTracking) { - builder.Append($"select name, last_seq_id, heartbeat, agent_status, pause_reason, running_on_node from {_events.DatabaseSchemaName}.mt_event_progression"); + builder.Append($"select name, last_seq_id, heartbeat, agent_status, pause_reason, running_on_node, warning_behind_threshold, critical_behind_threshold from {_events.DatabaseSchemaName}.mt_event_progression"); } else { diff --git a/src/Marten/Events/Daemon/Progress/ShardStateSelector.cs b/src/Marten/Events/Daemon/Progress/ShardStateSelector.cs index 94efd596bc..2cedf9106a 100644 --- a/src/Marten/Events/Daemon/Progress/ShardStateSelector.cs +++ b/src/Marten/Events/Daemon/Progress/ShardStateSelector.cs @@ -69,6 +69,18 @@ public async Task ResolveAsync(DbDataReader reader, CancellationToke state.RunningOnNode = await reader.GetFieldValueAsync(nextIndex, token).ConfigureAwait(false); } nextIndex++; + + if (!await reader.IsDBNullAsync(nextIndex, token).ConfigureAwait(false)) + { + state.WarningBehindThreshold = await reader.GetFieldValueAsync(nextIndex, token).ConfigureAwait(false); + } + nextIndex++; + + if (!await reader.IsDBNullAsync(nextIndex, token).ConfigureAwait(false)) + { + state.CriticalBehindThreshold = await reader.GetFieldValueAsync(nextIndex, token).ConfigureAwait(false); + } + nextIndex++; } return state; diff --git a/src/Marten/Events/Schema/EventProgressionTable.cs b/src/Marten/Events/Schema/EventProgressionTable.cs index e58c65974d..5b84fa00dd 100644 --- a/src/Marten/Events/Schema/EventProgressionTable.cs +++ b/src/Marten/Events/Schema/EventProgressionTable.cs @@ -31,6 +31,8 @@ public EventProgressionTable(EventGraph eventGraph): base(new PostgresqlObjectNa AddColumn("agent_status", "varchar(20)").AllowNulls(); AddColumn("pause_reason", "text").AllowNulls(); AddColumn("running_on_node", "integer").AllowNulls(); + AddColumn("warning_behind_threshold", "bigint").AllowNulls(); + AddColumn("critical_behind_threshold", "bigint").AllowNulls(); } PrimaryKeyName = "pk_mt_event_progression"; diff --git a/src/Marten/Events/Schema/NaturalKeyTable.cs b/src/Marten/Events/Schema/NaturalKeyTable.cs index d47a925827..3fa3c5a021 100644 --- a/src/Marten/Events/Schema/NaturalKeyTable.cs +++ b/src/Marten/Events/Schema/NaturalKeyTable.cs @@ -1,13 +1,16 @@ +using System.IO; +using System.Linq; using JasperFx.Events; using Marten.Events.Archiving; using Marten.Storage; using Marten.Storage.Metadata; +using Weasel.Core; using Weasel.Postgresql; using Weasel.Postgresql.Tables; namespace Marten.Events.Schema; -internal class NaturalKeyTable: Table +internal class NaturalKeyTable: Table, ISchemaObject { public NaturalKeyTable(EventGraph events, NaturalKeyDefinition naturalKey) : base(new PostgresqlObjectName(events.DatabaseSchemaName, @@ -47,7 +50,7 @@ public NaturalKeyTable(EventGraph events, NaturalKeyDefinition naturalKey) ColumnNames = new[] { streamCol, TenantIdColumn.Name, "is_archived" }, LinkedNames = new[] { "id", TenantIdColumn.Name, "is_archived" }, LinkedTable = new PostgresqlObjectName(events.DatabaseSchemaName, StreamsTable.TableName), - OnDelete = CascadeAction.Cascade + OnDelete = Weasel.Postgresql.CascadeAction.Cascade }); } else @@ -58,7 +61,7 @@ public NaturalKeyTable(EventGraph events, NaturalKeyDefinition naturalKey) ColumnNames = new[] { streamCol, "is_archived" }, LinkedNames = new[] { "id", "is_archived" }, LinkedTable = new PostgresqlObjectName(events.DatabaseSchemaName, StreamsTable.TableName), - OnDelete = CascadeAction.Cascade + OnDelete = Weasel.Postgresql.CascadeAction.Cascade }); } } @@ -70,7 +73,7 @@ public NaturalKeyTable(EventGraph events, NaturalKeyDefinition naturalKey) ColumnNames = new[] { streamCol, TenantIdColumn.Name }, LinkedNames = new[] { "id", TenantIdColumn.Name }, LinkedTable = new PostgresqlObjectName(events.DatabaseSchemaName, StreamsTable.TableName), - OnDelete = CascadeAction.Cascade + OnDelete = Weasel.Postgresql.CascadeAction.Cascade }); } else @@ -81,7 +84,7 @@ public NaturalKeyTable(EventGraph events, NaturalKeyDefinition naturalKey) ColumnNames = new[] { streamCol }, LinkedNames = new[] { "id" }, LinkedTable = new PostgresqlObjectName(events.DatabaseSchemaName, StreamsTable.TableName), - OnDelete = CascadeAction.Cascade + OnDelete = Weasel.Postgresql.CascadeAction.Cascade }); } @@ -92,4 +95,79 @@ public NaturalKeyTable(EventGraph events, NaturalKeyDefinition naturalKey) Columns = new[] { streamCol } }); } + + /// + /// Explicit ISchemaObject implementation to make FK creation idempotent. + /// When CREATE TABLE IF NOT EXISTS is a no-op (table already exists from a prior + /// test or concurrent migration), the base Table.WriteCreateStatement generates + /// ALTER TABLE ADD CONSTRAINT which fails with "constraint already exists". + /// This wraps each FK in a DO $$ IF NOT EXISTS guard. + /// + void ISchemaObject.WriteCreateStatement(Migrator migrator, TextWriter writer) + { + // Write the CREATE TABLE portion + if (migrator.TableCreation == CreationStyle.DropThenCreate) + { + writer.WriteLine("DROP TABLE IF EXISTS {0} CASCADE;", Identifier); + writer.WriteLine("CREATE TABLE {0} (", Identifier); + } + else + { + writer.WriteLine("CREATE TABLE IF NOT EXISTS {0} (", Identifier); + } + + var lines = Columns + .Select(column => column.ToDeclaration()) + .ToList(); + + if (PrimaryKeyColumns.Any()) + { + lines.Add($"CONSTRAINT {PrimaryKeyName} PRIMARY KEY ({string.Join(", ", PrimaryKeyColumns)})"); + } + + for (var i = 0; i < lines.Count - 1; i++) + { + writer.WriteLine(lines[i] + ","); + } + writer.WriteLine(lines.Last()); + + if (Partitioning != null) + { + Partitioning.WritePartitionBy(writer); + } + else + { + writer.WriteLine(");"); + } + + // Write FKs with idempotent guard to avoid "constraint already exists" + foreach (var foreignKey in ForeignKeys) + { + writer.WriteLine(); + writer.WriteLine("DO $$ BEGIN"); + writer.Write("IF NOT EXISTS (SELECT 1 FROM pg_constraint WHERE conname = '"); + writer.Write(foreignKey.Name); + writer.WriteLine("') THEN"); + writer.WriteLine(foreignKey.ToDDL(this)); + writer.WriteLine("END IF;"); + writer.WriteLine("END $$;"); + } + + // Write indexes with IF NOT EXISTS for idempotency + foreach (var index in Indexes) + { + writer.WriteLine(); + var indexDdl = index.ToDDL(this); + // Inject IF NOT EXISTS into CREATE INDEX statement + indexDdl = indexDdl.Replace("CREATE INDEX ", "CREATE INDEX IF NOT EXISTS "); + indexDdl = indexDdl.Replace("CREATE UNIQUE INDEX ", "CREATE UNIQUE INDEX IF NOT EXISTS "); + writer.WriteLine(indexDdl); + } + + if (Partitioning != null) + { + writer.WriteLine(); + Partitioning.WriteCreateStatement(writer, this); + } + } } diff --git a/src/Marten/Storage/MasterTableTenancy.cs b/src/Marten/Storage/MasterTableTenancy.cs index ba668b3e7f..50f7992b50 100644 --- a/src/Marten/Storage/MasterTableTenancy.cs +++ b/src/Marten/Storage/MasterTableTenancy.cs @@ -19,7 +19,7 @@ namespace Marten.Storage; -public class MasterTableTenancy: ITenancy, ITenancyWithMasterDatabase +public class MasterTableTenancy: ITenancy, ITenancyWithMasterDatabase, IDynamicTenantSource { private readonly MasterTableTenancyOptions _configuration; private readonly Lazy _dataSource; @@ -90,7 +90,7 @@ public async ValueTask> BuildDatabases() } await using var reader = await ((DbCommand)conn - .CreateCommand($"select tenant_id, connection_string from {_schemaName}.{TenantTable.TableName}")) + .CreateCommand($"select tenant_id, connection_string from {_schemaName}.{TenantTable.TableName} where {TenantTable.DisabledColumn} = false")) .ExecuteReaderAsync().ConfigureAwait(false); while (await reader.ReadAsync().ConfigureAwait(false)) @@ -246,6 +246,137 @@ await _dataSource.Value .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); } + #region IDynamicTenantSource + + async ValueTask ITenantedSource.FindAsync(string tenantId) + { + tenantId = _options.TenantIdStyle.MaybeCorrectTenantId(tenantId); + var connectionString = (string)await _dataSource.Value + .CreateCommand($"select connection_string from {_schemaName}.{TenantTable.TableName} where tenant_id = :id and {TenantTable.DisabledColumn} = false") + .With("id", tenantId) + .ExecuteScalarAsync(CancellationToken.None).ConfigureAwait(false); + + if (connectionString.IsEmpty()) + { + throw new UnknownTenantIdException(tenantId); + } + + return _configuration.CorrectConnectionString(connectionString); + } + + Task ITenantedSource.RefreshAsync() + { + // Reset the databases cache so next access re-reads from the master table + _databases = ImHashMap.Empty; + return Task.CompletedTask; + } + + IReadOnlyList ITenantedSource.AllActive() + { + return _databases.Enumerate().Select(x => + { + // Return the connection strings for active databases + return x.Value.Identifier; + }).Distinct().ToList(); + } + + IReadOnlyList> ITenantedSource.AllActiveByTenant() + { + return _databases.Enumerate().Select(pair => new Assignment(pair.Key, pair.Key)).ToList(); + } + + public async Task AddTenantAsync(string tenantId, string connectionValue) + { + tenantId = _options.TenantIdStyle.MaybeCorrectTenantId(tenantId); + await AddDatabaseRecordAsync(tenantId, connectionValue).ConfigureAwait(false); + + // Eagerly create and cache the database + connectionValue = _configuration.CorrectConnectionString(connectionValue); + var database = new MartenDatabase(_options, _options.NpgsqlDataSourceFactory.Create(connectionValue), tenantId); + database.TenantIds.Fill(tenantId); + _databases = _databases.AddOrUpdate(tenantId, database); + } + + public async Task DisableTenantAsync(string tenantId) + { + tenantId = _options.TenantIdStyle.MaybeCorrectTenantId(tenantId); + await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false); + + await _dataSource.Value + .CreateCommand($"update {_schemaName}.{TenantTable.TableName} set {TenantTable.DisabledColumn} = true where tenant_id = :id") + .With("id", tenantId) + .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); + + // Evict from cache and dispose + if (_databases.TryFind(tenantId, out var database)) + { + _databases = _databases.Remove(tenantId); +#pragma warning disable VSTHRD103 + database.Dispose(); +#pragma warning restore VSTHRD103 + } + } + + public async Task RemoveTenantAsync(string tenantId) + { + tenantId = _options.TenantIdStyle.MaybeCorrectTenantId(tenantId); + + // Evict from cache and dispose before deleting the record + if (_databases.TryFind(tenantId, out var database)) + { + _databases = _databases.Remove(tenantId); +#pragma warning disable VSTHRD103 + database.Dispose(); +#pragma warning restore VSTHRD103 + } + + await DeleteDatabaseRecordAsync(tenantId).ConfigureAwait(false); + } + + public async Task> AllDisabledAsync() + { + await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false); + + var list = new List(); + await using var conn = _dataSource.Value.CreateConnection(); + await conn.OpenAsync().ConfigureAwait(false); + + try + { + await using var reader = await ((DbCommand)conn + .CreateCommand($"select tenant_id from {_schemaName}.{TenantTable.TableName} where {TenantTable.DisabledColumn} = true")) + .ExecuteReaderAsync().ConfigureAwait(false); + + while (await reader.ReadAsync().ConfigureAwait(false)) + { + list.Add(await reader.GetFieldValueAsync(0).ConfigureAwait(false)); + } + + await reader.CloseAsync().ConfigureAwait(false); + } + finally + { + await conn.CloseAsync().ConfigureAwait(false); + } + + return list; + } + + public async Task EnableTenantAsync(string tenantId) + { + tenantId = _options.TenantIdStyle.MaybeCorrectTenantId(tenantId); + await maybeApplyChanges(_tenantDatabase.Value).ConfigureAwait(false); + + await _dataSource.Value + .CreateCommand($"update {_schemaName}.{TenantTable.TableName} set {TenantTable.DisabledColumn} = false where tenant_id = :id") + .With("id", tenantId) + .ExecuteNonQueryAsync(CancellationToken.None).ConfigureAwait(false); + + // The tenant will be lazily loaded into cache on next access + } + + #endregion + private async Task maybeApplyChanges(TenantLookupDatabase tenantDatabase) { if (!_hasAppliedChanges && (_configuration.AutoCreate ?? _options.AutoCreateSchemaObjects) != AutoCreate.None) @@ -288,7 +419,7 @@ private async Task seedDatabasesAsync(NpgsqlConnection conn) { tenantId = _options.TenantIdStyle.MaybeCorrectTenantId(tenantId); var connectionString = (string)await _dataSource.Value - .CreateCommand($"select connection_string from {_schemaName}.{TenantTable.TableName} where tenant_id = :id") + .CreateCommand($"select connection_string from {_schemaName}.{TenantTable.TableName} where tenant_id = :id and {TenantTable.DisabledColumn} = false") .With("id", tenantId) .ExecuteScalarAsync(CancellationToken.None).ConfigureAwait(false); @@ -356,7 +487,10 @@ public TenantTable(string schemaName): base(new DbObjectName(schemaName, TableNa { AddColumn(StorageConstants.TenantIdColumn).AsPrimaryKey(); AddColumn(StorageConstants.ConnectionStringColumn).NotNull(); + AddColumn(DisabledColumn).DefaultValueByExpression("false").NotNull(); } + + public const string DisabledColumn = "disabled"; } public async ValueTask DescribeDatabasesAsync(CancellationToken token) diff --git a/src/Marten/StoreOptions.SourceGeneration.cs b/src/Marten/StoreOptions.SourceGeneration.cs new file mode 100644 index 0000000000..f79fcc859c --- /dev/null +++ b/src/Marten/StoreOptions.SourceGeneration.cs @@ -0,0 +1,84 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Marten.Events.Projections; + +namespace Marten; + +public partial class StoreOptions +{ + /// + /// Attempt to use a source-generated type manifest from the given assembly + /// (or the ApplicationAssembly if not specified) to register document types, + /// projection types, and event types at startup without runtime assembly scanning. + /// This is an opt-in feature that requires the Marten.SourceGeneration analyzer + /// package to be referenced in the consuming project. + /// + /// + /// The assembly to search for the generated manifest. Defaults to ApplicationAssembly. + /// + /// True if a source-generated manifest was found and applied; false otherwise. + public bool TryUseSourceGeneratedDiscovery(Assembly? assembly = null) + { + assembly ??= ApplicationAssembly; + if (assembly == null) return false; + + var manifestType = assembly.GetType("Marten.Generated.DiscoveredMartenTypes"); + if (manifestType == null) return false; + + ApplySourceGeneratedManifest(manifestType); + return true; + } + + /// + /// Use a source-generated type manifest to register document types, + /// projection types, and event types. This is a convenience method + /// that directly accepts the manifest type from the generated code. + /// + /// The generated DiscoveredMartenTypes type. + public void UseSourceGeneratedDiscovery(Type manifestType) + { + ApplySourceGeneratedManifest(manifestType); + } + + private void ApplySourceGeneratedManifest(Type manifestType) + { + // Register document types + var documentTypesProp = manifestType.GetProperty("DocumentTypes", + BindingFlags.Public | BindingFlags.Static); + if (documentTypesProp != null) + { + var documentTypes = documentTypesProp.GetValue(null) as IReadOnlyList; + if (documentTypes != null) + { + foreach (var docType in documentTypes) + { + RegisterDocumentType(docType); + } + } + } + + // Register event types + var eventTypesProp = manifestType.GetProperty("EventTypes", + BindingFlags.Public | BindingFlags.Static); + if (eventTypesProp != null) + { + var eventTypes = eventTypesProp.GetValue(null) as IReadOnlyList; + if (eventTypes != null) + { + foreach (var eventType in eventTypes) + { + Events.AddEventType(eventType); + } + } + } + + // Note: Projection registration is NOT done automatically here because + // projections require a ProjectionLifecycle (Inline, Async, Live) which + // cannot be reliably inferred at compile time. Users should continue to + // register projections explicitly via StoreOptions.Projections.Add(). + // The manifest's ProjectionTypes property is available for tooling and + // diagnostics (e.g., drift detection). + } +}