diff --git a/Orleans.sln b/Orleans.sln
index fdb05321bf8..8263ecb9adc 100644
--- a/Orleans.sln
+++ b/Orleans.sln
@@ -258,6 +258,12 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ChaoticCluster.ServiceDefau
 EndProject
 Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TestSerializerExternalModels", "test\Misc\TestSerializerExternalModels\TestSerializerExternalModels.csproj", "{5D587DDE-036D-4694-A314-8DDF270AC031}"
 EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Orleans.Journaling", "src\Orleans.Journaling\Orleans.Journaling.csproj", "{20EFDCFC-F3FE-5509-5950-516E90DE1E05}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Orleans.Journaling.Tests", "test\Orleans.Journaling.Tests\Orleans.Journaling.Tests.csproj", "{4A4D30F4-6D61-7A80-8352-D76BD29582E0}"
+EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Orleans.Journaling.AzureStorage", "src\Azure\Orleans.Journaling.AzureStorage\Orleans.Journaling.AzureStorage.csproj", "{E613A10D-757D-44BA-97C1-3D06C22BDB2E}"
+EndProject
 Global
 	GlobalSection(SolutionConfigurationPlatforms) = preSolution
 		Debug|Any CPU = Debug|Any CPU
@@ -672,6 +678,18 @@ Global
 		{5D587DDE-036D-4694-A314-8DDF270AC031}.Debug|Any CPU.Build.0 = Debug|Any CPU
 		{5D587DDE-036D-4694-A314-8DDF270AC031}.Release|Any CPU.ActiveCfg = Release|Any CPU
 		{5D587DDE-036D-4694-A314-8DDF270AC031}.Release|Any CPU.Build.0 = Release|Any CPU
+		{20EFDCFC-F3FE-5509-5950-516E90DE1E05}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+		{20EFDCFC-F3FE-5509-5950-516E90DE1E05}.Debug|Any CPU.Build.0 = Debug|Any CPU
+		{20EFDCFC-F3FE-5509-5950-516E90DE1E05}.Release|Any CPU.ActiveCfg = Release|Any CPU
+		{20EFDCFC-F3FE-5509-5950-516E90DE1E05}.Release|Any CPU.Build.0 = Release|Any CPU
+		{4A4D30F4-6D61-7A80-8352-D76BD29582E0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+		{4A4D30F4-6D61-7A80-8352-D76BD29582E0}.Debug|Any CPU.Build.0 = Debug|Any CPU
+		{4A4D30F4-6D61-7A80-8352-D76BD29582E0}.Release|Any CPU.ActiveCfg = Release|Any CPU
+		{4A4D30F4-6D61-7A80-8352-D76BD29582E0}.Release|Any CPU.Build.0 = Release|Any CPU
+		{E613A10D-757D-44BA-97C1-3D06C22BDB2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+		{E613A10D-757D-44BA-97C1-3D06C22BDB2E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+		{E613A10D-757D-44BA-97C1-3D06C22BDB2E}.Release|Any CPU.ActiveCfg = Release|Any CPU
+		{E613A10D-757D-44BA-97C1-3D06C22BDB2E}.Release|Any CPU.Build.0 = Release|Any CPU
 	EndGlobalSection
 	GlobalSection(SolutionProperties) = preSolution
 		HideSolutionNode = FALSE
@@ -796,6 +814,9 @@ Global
 		{76A549FA-69F1-4967-82B6-161A8B52C86B} = {2579A7F6-EBE8-485A-BB20-A5D19DB5612B}
 		{4004A79F-B6BB-4472-891B-AD1348AE3E93} = {2579A7F6-EBE8-485A-BB20-A5D19DB5612B}
 		{5D587DDE-036D-4694-A314-8DDF270AC031} = {70BCC54E-1618-4742-A079-07588065E361}
+		{20EFDCFC-F3FE-5509-5950-516E90DE1E05} = {4CD3AA9E-D937-48CA-BB6C-158E12257D23}
+		{4A4D30F4-6D61-7A80-8352-D76BD29582E0} = {A6573187-FD0D-4DF7-91D1-03E07E470C0A}
+		{E613A10D-757D-44BA-97C1-3D06C22BDB2E} = {4C5D66BF-EE1C-4DD8-8551-D1B7F3768A34}
 	EndGlobalSection
 	GlobalSection(ExtensibilityGlobals) = postSolution
 		SolutionGuid = {7BFB3429-B5BB-4DB1-95B4-67D77A864952}
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobLogStorage.cs b/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobLogStorage.cs
new file mode 100644
index 00000000000..d834dccf308
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobLogStorage.cs
@@ -0,0 +1,178 @@
+using Azure;
+using Azure.Storage.Blobs.Specialized;
+using Azure.Storage.Blobs.Models;
+using System.Runtime.CompilerServices;
+using Azure.Storage.Sas;
+using Orleans.Serialization.Buffers;
+using Microsoft.Extensions.Logging;
+
+namespace Orleans.Journaling;
+
+internal sealed partial class AzureAppendBlobLogStorage : IStateMachineStorage
+{
+    private static readonly AppendBlobCreateOptions CreateOptions = new() { Conditions = new() { IfNoneMatch = ETag.All } };
+    private readonly AppendBlobClient _client;
+    private readonly ILogger _logger;
+    private readonly LogExtentBuilder.ReadOnlyStream _stream;
+    private readonly AppendBlobAppendBlockOptions _appendOptions;
+    private bool _exists;
+    private int _numBlocks;
+
+    public bool IsCompactionRequested => _numBlocks > 10;
+
+    public AzureAppendBlobLogStorage(AppendBlobClient client, ILogger logger)
+    {
+        _client = client;
+        _logger = logger;
+        _stream = new();
+
+        // For the first request, if we have not performed a read yet, we want to guard against clobbering an existing blob.
+        _appendOptions = new AppendBlobAppendBlockOptions() { Conditions = new AppendBlobRequestConditions { IfNoneMatch = ETag.All } };
+    }
+
+    public async ValueTask AppendAsync(LogExtentBuilder value, CancellationToken cancellationToken)
+    {
+        if (!_exists)
+        {
+            var response = await _client.CreateAsync(CreateOptions, cancellationToken);
+            _appendOptions.Conditions.IfNoneMatch = default;
+            _appendOptions.Conditions.IfMatch = response.Value.ETag;
+            _exists = true;
+        }
+
+        _stream.SetBuilder(value);
+        var result = await _client.AppendBlockAsync(_stream, _appendOptions, cancellationToken).ConfigureAwait(false);
+        LogAppend(_logger, value.Length, _client.BlobContainerName, _client.Name);
+
+        _stream.Reset();
+        _appendOptions.Conditions.IfNoneMatch = default;
+        _appendOptions.Conditions.IfMatch = result.Value.ETag;
+        _numBlocks = result.Value.BlobCommittedBlockCount;
+    }
+
+    public async ValueTask DeleteAsync(CancellationToken cancellationToken)
+    {
+        var conditions = new BlobRequestConditions { IfMatch = _appendOptions.Conditions.IfMatch };
+        await _client.DeleteAsync(conditions: conditions, cancellationToken: cancellationToken).ConfigureAwait(false);
+
+        // Expect no blob to have been created when we append to it.
+        _appendOptions.Conditions.IfNoneMatch = ETag.All;
+        _appendOptions.Conditions.IfMatch = default;
+        _numBlocks = 0;
+    }
+
+    public async IAsyncEnumerable<LogExtent> ReadAsync([EnumeratorCancellation] CancellationToken cancellationToken)
+    {
+        Response<BlobDownloadStreamingResult> result;
+        try
+        {
+            // If the blob was not newly created, then download the blob.
+            result = await _client.DownloadStreamingAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
+        }
+        catch (RequestFailedException exception) when (exception.Status is 404)
+        {
+            _exists = false;
+            yield break;
+        }
+
+        // If the blob has a size of zero, check for a snapshot and restore the blob from the snapshot if one exists.
+        if (result.Value.Details.ContentLength == 0)
+        {
+            if (result.Value.Details.Metadata.TryGetValue("snapshot", out var snapshot) && snapshot is { Length: > 0 })
+            {
+                result = await CopyFromSnapshotAsync(result.Value.Details.ETag, snapshot, cancellationToken).ConfigureAwait(false);
+            }
+        }
+
+        _numBlocks = result.Value.Details.BlobCommittedBlockCount;
+        _appendOptions.Conditions.IfNoneMatch = default;
+        _appendOptions.Conditions.IfMatch = result.Value.Details.ETag;
+        _exists = true;
+
+        // Read everything into a single log segment. We could change this to read in chunks,
+        // yielding when the stream does not return synchronously, if we wanted to support larger state machines.
+        var rawStream = result.Value.Content;
+        using var buffer = new ArcBufferWriter();
+        while (true)
+        {
+            var mem = buffer.GetMemory();
+            var bytesRead = await rawStream.ReadAsync(mem, cancellationToken);
+            if (bytesRead == 0)
+            {
+                if (buffer.Length > 0)
+                {
+                    LogRead(_logger, buffer.Length, _client.BlobContainerName, _client.Name);
+                    yield return new LogExtent(buffer.ConsumeSlice(buffer.Length));
+                }
+
+                yield break;
+            }
+
+            buffer.AdvanceWriter(bytesRead);
+        }
+    }
+
+    private async Task<Response<BlobDownloadStreamingResult>> CopyFromSnapshotAsync(ETag eTag, string snapshotDetail, CancellationToken cancellationToken)
+    {
+        // Read snapshot and append it to the blob.
+        var snapshot = _client.WithSnapshot(snapshotDetail);
+        var uri = snapshot.GenerateSasUri(permissions: BlobSasPermissions.Read, expiresOn: DateTimeOffset.UtcNow.AddHours(1));
+        var copyResult = await _client.SyncCopyFromUriAsync(
+            uri,
+            new BlobCopyFromUriOptions { DestinationConditions = new BlobRequestConditions { IfNoneMatch = eTag } },
+            cancellationToken).ConfigureAwait(false);
+        if (copyResult.Value.CopyStatus is not CopyStatus.Success)
+        {
+            throw new InvalidOperationException($"Copy did not complete successfully. Status: {copyResult.Value.CopyStatus}");
+        }
+
+        var result = await _client.DownloadStreamingAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
+        _exists = true;
+        return result;
+    }
+
+    public async ValueTask ReplaceAsync(LogExtentBuilder value, CancellationToken cancellationToken)
+    {
+        // Create a snapshot of the blob for recovery purposes.
+        var blobSnapshot = await _client.CreateSnapshotAsync(conditions: _appendOptions.Conditions, cancellationToken: cancellationToken).ConfigureAwait(false);
+
+        // Open the blob for writing, overwriting existing contents.
+        var createOptions = new AppendBlobCreateOptions()
+        {
+            Conditions = _appendOptions.Conditions,
+            Metadata = new Dictionary<string, string> { ["snapshot"] = blobSnapshot.Value.Snapshot },
+        };
+        var createResult = await _client.CreateAsync(createOptions, cancellationToken).ConfigureAwait(false);
+        _appendOptions.Conditions.IfMatch = createResult.Value.ETag;
+        _appendOptions.Conditions.IfNoneMatch = default;
+
+        // Write the state machine snapshot.
+        _stream.SetBuilder(value);
+        var result = await _client.AppendBlockAsync(_stream, _appendOptions, cancellationToken).ConfigureAwait(false);
+        LogReplace(_logger, _client.BlobContainerName, _client.Name, value.Length);
+
+        _stream.Reset();
+        _appendOptions.Conditions.IfNoneMatch = default;
+        _appendOptions.Conditions.IfMatch = result.Value.ETag;
+        _numBlocks = result.Value.BlobCommittedBlockCount;
+
+        // Delete the blob snapshot.
+        await _client.WithSnapshot(blobSnapshot.Value.Snapshot).DeleteAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
+    }
+
+    [LoggerMessage(
+        Level = LogLevel.Debug,
+        Message = "Appended {Length} bytes to blob \"{ContainerName}/{BlobName}\"")]
+    private static partial void LogAppend(ILogger logger, long length, string containerName, string blobName);
+
+    [LoggerMessage(
+        Level = LogLevel.Debug,
+        Message = "Read {Length} bytes from blob \"{ContainerName}/{BlobName}\"")]
+    private static partial void LogRead(ILogger logger, long length, string containerName, string blobName);
+
+    [LoggerMessage(
+        Level = LogLevel.Debug,
+        Message = "Replaced blob \"{ContainerName}/{BlobName}\", writing {Length} bytes")]
+    private static partial void LogReplace(ILogger logger, string containerName, string blobName, long length);
+}
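For orientation before the configuration types below, here is a minimal sketch of how a caller might drive this storage; the loop and variable names are invented, the state machine manager in Orleans.Journaling owns the real equivalent:

    // Append-mostly persistence with periodic compaction (hypothetical caller).
    async ValueTask PersistAsync(IStateMachineStorage storage, LogExtentBuilder delta, LogExtentBuilder snapshot, CancellationToken ct)
    {
        if (storage.IsCompactionRequested)
        {
            // More than 10 committed blocks: rewrite the blob from a full snapshot.
            // ReplaceAsync snapshots the blob first, so a crash mid-rewrite is recoverable via ReadAsync.
            await storage.ReplaceAsync(snapshot, ct);
        }
        else
        {
            // Otherwise append one block; the ETag conditions make concurrent writers fail fast.
            await storage.AppendAsync(delta, ct);
        }
    }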
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobStateMachineStorageOptions.cs b/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobStateMachineStorageOptions.cs
new file mode 100644
index 00000000000..ba9f031acce
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobStateMachineStorageOptions.cs
@@ -0,0 +1,120 @@
+using Azure;
+using Azure.Storage.Blobs;
+using Azure.Storage;
+using Azure.Core;
+using Orleans.Runtime;
+
+namespace Orleans.Journaling;
+
+/// <summary>
+/// Options for configuring the Azure Append Blob state machine storage provider.
+/// </summary>
+public sealed class AzureAppendBlobStateMachineStorageOptions
+{
+    private BlobServiceClient? _blobServiceClient;
+
+    /// <summary>
+    /// Container name where state machine state is stored.
+    /// </summary>
+    public string ContainerName { get; set; } = DEFAULT_CONTAINER_NAME;
+    public const string DEFAULT_CONTAINER_NAME = "state";
+
+    /// <summary>
+    /// Gets or sets the delegate used to generate the blob name for a given grain.
+    /// </summary>
+    public Func<GrainId, string> GetBlobName { get; set; } = DefaultGetBlobName;
+
+    private static readonly Func<GrainId, string> DefaultGetBlobName = static (GrainId grainId) => $"{grainId}.bin";
+
+    /// <summary>
+    /// Options to be used when configuring the blob storage client, or <see langword="null"/> to use the default options.
+    /// </summary>
+    public BlobClientOptions? ClientOptions { get; set; }
+
+    /// <summary>
+    /// Gets or sets the client used to access the Azure Blob Service.
+    /// </summary>
+    public BlobServiceClient? BlobServiceClient
+    {
+        get => _blobServiceClient;
+        set
+        {
+            ArgumentNullException.ThrowIfNull(value);
+            _blobServiceClient = value;
+            CreateClient = ct => Task.FromResult(value);
+        }
+    }
+
+    /// <summary>
+    /// The optional delegate used to create a <see cref="BlobServiceClient"/> instance.
+    /// </summary>
+    internal Func<CancellationToken, Task<BlobServiceClient>>? CreateClient { get; private set; }
+
+    /// <summary>
+    /// Stage of silo lifecycle where storage should be initialized. Storage must be initialized prior to use.
+    /// </summary>
+    public int InitStage { get; set; } = DEFAULT_INIT_STAGE;
+    public const int DEFAULT_INIT_STAGE = ServiceLifecycleStage.ApplicationServices;
+
+    /// <summary>
+    /// A function for building container factory instances.
+    /// </summary>
+    public Func<IServiceProvider, AzureAppendBlobStateMachineStorageOptions, IBlobContainerFactory> BuildContainerFactory { get; set; }
+        = static (provider, options) => new DefaultBlobContainerFactory(options);
+
+    /// <summary>
+    /// Configures the <see cref="BlobServiceClient"/> using a connection string.
+    /// </summary>
+    public void ConfigureBlobServiceClient(string connectionString)
+    {
+        ArgumentException.ThrowIfNullOrWhiteSpace(connectionString);
+        CreateClient = ct => Task.FromResult(new BlobServiceClient(connectionString, ClientOptions));
+    }
+
+    /// <summary>
+    /// Configures the <see cref="BlobServiceClient"/> using an authenticated service URI.
+    /// </summary>
+    public void ConfigureBlobServiceClient(Uri serviceUri)
+    {
+        ArgumentNullException.ThrowIfNull(serviceUri);
+        CreateClient = ct => Task.FromResult(new BlobServiceClient(serviceUri, ClientOptions));
+    }
+
+    /// <summary>
+    /// Configures the <see cref="BlobServiceClient"/> using the provided callback.
+    /// </summary>
+    public void ConfigureBlobServiceClient(Func<CancellationToken, Task<BlobServiceClient>> createClientCallback)
+    {
+        CreateClient = createClientCallback ?? throw new ArgumentNullException(nameof(createClientCallback));
+    }
+
+    /// <summary>
+    /// Configures the <see cref="BlobServiceClient"/> using an authenticated service URI and a <see cref="TokenCredential"/>.
+    /// </summary>
+    public void ConfigureBlobServiceClient(Uri serviceUri, TokenCredential tokenCredential)
+    {
+        ArgumentNullException.ThrowIfNull(serviceUri);
+        ArgumentNullException.ThrowIfNull(tokenCredential);
+        CreateClient = ct => Task.FromResult(new BlobServiceClient(serviceUri, tokenCredential, ClientOptions));
+    }
+
+    /// <summary>
+    /// Configures the <see cref="BlobServiceClient"/> using an authenticated service URI and a <see cref="AzureSasCredential"/>.
+    /// </summary>
+    public void ConfigureBlobServiceClient(Uri serviceUri, AzureSasCredential azureSasCredential)
+    {
+        ArgumentNullException.ThrowIfNull(serviceUri);
+        ArgumentNullException.ThrowIfNull(azureSasCredential);
+        CreateClient = ct => Task.FromResult(new BlobServiceClient(serviceUri, azureSasCredential, ClientOptions));
+    }
+
+    /// <summary>
+    /// Configures the <see cref="BlobServiceClient"/> using an authenticated service URI and a <see cref="StorageSharedKeyCredential"/>.
+    /// </summary>
+    public void ConfigureBlobServiceClient(Uri serviceUri, StorageSharedKeyCredential sharedKeyCredential)
+    {
+        ArgumentNullException.ThrowIfNull(serviceUri);
+        ArgumentNullException.ThrowIfNull(sharedKeyCredential);
+        CreateClient = ct => Task.FromResult(new BlobServiceClient(serviceUri, sharedKeyCredential, ClientOptions));
+    }
+}
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobStateMachineStorageProvider.cs b/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobStateMachineStorageProvider.cs
new file mode 100644
index 00000000000..7d384a63844
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/AzureAppendBlobStateMachineStorageProvider.cs
@@ -0,0 +1,37 @@
+using Azure.Storage.Blobs.Specialized;
+using Microsoft.Extensions.Options;
+using Microsoft.Extensions.Logging;
+using Orleans.Runtime;
+
+namespace Orleans.Journaling;
+
+internal sealed class AzureAppendBlobStateMachineStorageProvider(
+    IOptions<AzureAppendBlobStateMachineStorageOptions> options,
+    IServiceProvider serviceProvider,
+    ILogger<AzureAppendBlobLogStorage> logger) : IStateMachineStorageProvider, ILifecycleParticipant<ISiloLifecycle>
+{
+    private readonly IBlobContainerFactory _containerFactory = options.Value.BuildContainerFactory(serviceProvider, options.Value);
+    private readonly AzureAppendBlobStateMachineStorageOptions _options = options.Value;
+
+    private async Task Initialize(CancellationToken cancellationToken)
+    {
+        var client = await _options.CreateClient!(cancellationToken);
+        await _containerFactory.InitializeAsync(client, cancellationToken).ConfigureAwait(false);
+    }
+
+    public IStateMachineStorage Create(IGrainContext grainContext)
+    {
+        var container = _containerFactory.GetBlobContainerClient(grainContext.GrainId);
+        var blobName = _options.GetBlobName(grainContext.GrainId);
+        var blobClient = container.GetAppendBlobClient(blobName);
+        return new AzureAppendBlobLogStorage(blobClient, logger);
+    }
+
+    public void Participate(ISiloLifecycle observer)
+    {
+        observer.Subscribe(
+            nameof(AzureAppendBlobStateMachineStorageProvider),
+            ServiceLifecycleStage.RuntimeInitialize,
+            onStart: Initialize);
+    }
+}
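Because `GetBlobName` and `BuildContainerFactory` are both settable on the options, the blob layout is customizable per deployment. A hypothetical configuration (the naming scheme below is illustrative, not part of this change):

    var options = new AzureAppendBlobStateMachineStorageOptions
    {
        ContainerName = "journal",
        // Invented layout: one virtual folder per grain type.
        GetBlobName = static grainId => $"{grainId.Type}/{grainId.Key}.bin",
    };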
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/AzureBlobStorageGrainJournalingProviderBuilder.cs b/src/Azure/Orleans.Journaling.AzureStorage/AzureBlobStorageGrainJournalingProviderBuilder.cs
new file mode 100644
index 00000000000..0254d312815
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/AzureBlobStorageGrainJournalingProviderBuilder.cs
@@ -0,0 +1,57 @@
+using Azure.Storage.Blobs;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.DependencyInjection;
+using Orleans;
+using Orleans.Hosting;
+using Orleans.Journaling;
+using Orleans.Providers;
+
+[assembly: RegisterProvider("AzureBlobStorage", "GrainJournaling", "Silo", typeof(AzureBlobStorageGrainJournalingProviderBuilder))]
+
+namespace Orleans.Hosting;
+
+internal sealed class AzureBlobStorageGrainJournalingProviderBuilder : IProviderBuilder<ISiloBuilder>
+{
+    public void Configure(ISiloBuilder builder, string name, IConfigurationSection configurationSection)
+    {
+        builder.AddAzureAppendBlobStateMachineStorage();
+        var optionsBuilder = builder.Services.AddOptions<AzureAppendBlobStateMachineStorageOptions>();
+        optionsBuilder.Configure<IServiceProvider>((options, services) =>
+        {
+            var containerName = configurationSection["ContainerName"];
+            if (!string.IsNullOrEmpty(containerName))
+            {
+                options.ContainerName = containerName;
+            }
+
+            var serviceKey = configurationSection["ServiceKey"];
+            if (!string.IsNullOrEmpty(serviceKey))
+            {
+                // Get a client by name.
+                options.BlobServiceClient = services.GetRequiredKeyedService<BlobServiceClient>(serviceKey);
+            }
+            else
+            {
+                // Construct a blob service client from a connection string.
+                var connectionName = configurationSection["ConnectionName"];
+                var connectionString = configurationSection["ConnectionString"];
+                if (!string.IsNullOrEmpty(connectionName) && string.IsNullOrEmpty(connectionString))
+                {
+                    var rootConfiguration = services.GetRequiredService<IConfiguration>();
+                    connectionString = rootConfiguration.GetConnectionString(connectionName);
+                }
+
+                if (!string.IsNullOrEmpty(connectionString))
+                {
+                    if (Uri.TryCreate(connectionString, UriKind.Absolute, out var uri))
+                    {
+                        options.BlobServiceClient = new(uri);
+                    }
+                    else
+                    {
+                        options.BlobServiceClient = new(connectionString);
+                    }
+                }
+            }
+        });
+    }
+}
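The builder above reads `ContainerName`, `ServiceKey`, `ConnectionName`, and `ConnectionString` from its configuration section, so a section shaped roughly like the following sketch would select this provider. The key nesting follows the usual Orleans provider-builder convention and is illustrative only:

    "GrainJournaling": {
      "ProviderType": "AzureBlobStorage",
      "ContainerName": "state",
      "ConnectionName": "JournalingStorage"
    }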
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/AzureBlobStorageHostingExtensions.cs b/src/Azure/Orleans.Journaling.AzureStorage/AzureBlobStorageHostingExtensions.cs
new file mode 100644
index 00000000000..99786a83b5f
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/AzureBlobStorageHostingExtensions.cs
@@ -0,0 +1,33 @@
+using Microsoft.Extensions.DependencyInjection;
+using Orleans.Configuration.Internal;
+using Orleans.Runtime;
+using Orleans.Hosting;
+
+namespace Orleans.Journaling;
+
+public static class AzureBlobStorageHostingExtensions
+{
+    public static ISiloBuilder AddAzureAppendBlobStateMachineStorage(this ISiloBuilder builder) => builder.AddAzureAppendBlobStateMachineStorage(configure: null);
+
+    public static ISiloBuilder AddAzureAppendBlobStateMachineStorage(this ISiloBuilder builder, Action<AzureAppendBlobStateMachineStorageOptions>? configure)
+    {
+        builder.AddStateMachineStorage();
+
+        var services = builder.Services;
+
+        var options = builder.Services.AddOptions<AzureAppendBlobStateMachineStorageOptions>();
+        if (configure is not null)
+        {
+            options.Configure(configure);
+        }
+
+        if (services.Any(service => service.ServiceType.Equals(typeof(AzureAppendBlobStateMachineStorageProvider))))
+        {
+            return builder;
+        }
+
+        builder.Services.AddSingleton<AzureAppendBlobStateMachineStorageProvider>();
+        builder.Services.AddFromExisting<IStateMachineStorageProvider, AzureAppendBlobStateMachineStorageProvider>();
+        builder.Services.AddFromExisting<ILifecycleParticipant<ISiloLifecycle>, AzureAppendBlobStateMachineStorageProvider>();
+        return builder;
+    }
+}
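Typical silo wiring for the extension above might look like this sketch; the account URI is a placeholder, and `DefaultAzureCredential` comes from the Azure.Identity package:

    using Azure.Identity;
    using Microsoft.Extensions.Hosting;
    using Orleans.Journaling;

    var host = Host.CreateDefaultBuilder(args)
        .UseOrleans(silo => silo.AddAzureAppendBlobStateMachineStorage(options =>
        {
            options.ContainerName = "journal";
            options.ConfigureBlobServiceClient(
                new Uri("https://myaccount.blob.core.windows.net"), // placeholder account
                new DefaultAzureCredential());
        }))
        .Build();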
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/DefaultBlobContainerFactory.cs b/src/Azure/Orleans.Journaling.AzureStorage/DefaultBlobContainerFactory.cs
new file mode 100644
index 00000000000..f2484d1c03a
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/DefaultBlobContainerFactory.cs
@@ -0,0 +1,26 @@
+using Azure.Storage.Blobs;
+using Orleans.Runtime;
+
+namespace Orleans.Journaling;
+
+/// <summary>
+/// A default blob container factory that uses the default container name.
+/// </summary>
+/// <remarks>
+/// Initializes a new instance of the <see cref="DefaultBlobContainerFactory"/> class.
+/// </remarks>
+/// <param name="options">The blob storage options</param>
+internal sealed class DefaultBlobContainerFactory(AzureAppendBlobStateMachineStorageOptions options) : IBlobContainerFactory
+{
+    private BlobContainerClient _defaultContainer = null!;
+
+    /// <inheritdoc/>
+    public BlobContainerClient GetBlobContainerClient(GrainId grainId) => _defaultContainer;
+
+    /// <inheritdoc/>
+    public async Task InitializeAsync(BlobServiceClient client, CancellationToken cancellationToken)
+    {
+        _defaultContainer = client.GetBlobContainerClient(options.ContainerName);
+        await _defaultContainer.CreateIfNotExistsAsync(cancellationToken: cancellationToken);
+    }
+}
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/IBlobContainerFactory.cs b/src/Azure/Orleans.Journaling.AzureStorage/IBlobContainerFactory.cs
new file mode 100644
index 00000000000..0e74463aadb
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/IBlobContainerFactory.cs
@@ -0,0 +1,25 @@
+using Azure.Storage.Blobs;
+using Orleans.Runtime;
+
+namespace Orleans.Journaling;
+
+/// <summary>
+/// A factory for building container clients for blob storage using GrainId
+/// </summary>
+public interface IBlobContainerFactory
+{
+    /// <summary>
+    /// Gets the container which should be used for the specified grain.
+    /// </summary>
+    /// <param name="grainId">The grain id</param>
+    /// <returns>A configured blob client</returns>
+    public BlobContainerClient GetBlobContainerClient(GrainId grainId);
+
+    /// <summary>
+    /// Initialize any required dependencies using the provided client and options.
+    /// </summary>
+    /// <param name="client">The connected blob client</param>
+    /// <param name="cancellationToken">A token used to cancel the request.</param>
+    /// <returns>A <see cref="Task"/> representing the asynchronous operation.</returns>
+    public Task InitializeAsync(BlobServiceClient client, CancellationToken cancellationToken);
+}
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/Orleans.Journaling.AzureStorage.csproj b/src/Azure/Orleans.Journaling.AzureStorage/Orleans.Journaling.AzureStorage.csproj
new file mode 100644
index 00000000000..de8e06acff0
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/Orleans.Journaling.AzureStorage.csproj
@@ -0,0 +1,23 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>net8.0</TargetFramework>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>enable</Nullable>
+    <VersionSuffix Condition="'$(VersionSuffix)' != ''">$(VersionSuffix).alpha.1</VersionSuffix>
+    <VersionSuffix Condition="'$(VersionSuffix)' == ''">alpha.1</VersionSuffix>
+  </PropertyGroup>
+
+
+
+
+
+
+
+
+
+
+
+
+
+</Project>
diff --git a/src/Azure/Orleans.Journaling.AzureStorage/Properties/AssemblyInfo.cs b/src/Azure/Orleans.Journaling.AzureStorage/Properties/AssemblyInfo.cs
new file mode 100644
index 00000000000..9d00e67f920
--- /dev/null
+++ b/src/Azure/Orleans.Journaling.AzureStorage/Properties/AssemblyInfo.cs
@@ -0,0 +1,3 @@
+using System.Diagnostics.CodeAnalysis;
+
+[assembly: Experimental("ORLEANSEXP005")]
diff --git a/src/Orleans.Core/Lifecycle/MigrationContext.cs b/src/Orleans.Core/Lifecycle/MigrationContext.cs
index 8b918b49f3c..cadd6df2eb0 100644
--- a/src/Orleans.Core/Lifecycle/MigrationContext.cs
+++ b/src/Orleans.Core/Lifecycle/MigrationContext.cs
@@ -79,7 +79,11 @@ public bool TryAddValue<T>(string key, T? value)
 
     public IEnumerable<string> Keys => this;
 
-    public void Dispose() => _buffer.Reset();
+    public void Dispose()
+    {
+        _buffer.Reset();
+        _buffer = default;
+    }
 
     public bool TryGetBytes(string key, out ReadOnlySequence<byte> value)
     {
diff --git a/src/Orleans.Journaling/DurableDictionary.cs b/src/Orleans.Journaling/DurableDictionary.cs
new file mode 100644
index 00000000000..1883d1c70fd
--- /dev/null
+++ b/src/Orleans.Journaling/DurableDictionary.cs
@@ -0,0 +1,305 @@
+using System.Buffers;
+using System.Collections;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using Microsoft.Extensions.DependencyInjection;
+using Orleans.Serialization.Buffers;
+using Orleans.Serialization.Codecs;
+using Orleans.Serialization.Session;
+
+namespace Orleans.Journaling;
+
+public interface IDurableDictionary<K, V> : IDictionary<K, V> where K : notnull
+{
+}
+
+[DebuggerTypeProxy(typeof(IDurableDictionaryDebugView<,>))]
+[DebuggerDisplay("Count = {Count}")]
+internal class DurableDictionary<K, V> : IDurableDictionary<K, V>, IDurableStateMachine where K : notnull
+{
+    private readonly SerializerSessionPool _serializerSessionPool;
+    private readonly IFieldCodec<K> _keyCodec;
+    private readonly IFieldCodec<V> _valueCodec;
+    private const byte VersionByte = 0;
+    private readonly Dictionary<K, V> _items = [];
+    private IStateMachineLogWriter? _storage;
+
+    protected DurableDictionary(IFieldCodec<K> keyCodec, IFieldCodec<V> valueCodec, SerializerSessionPool serializerSessionPool)
+    {
+        _keyCodec = keyCodec;
+        _valueCodec = valueCodec;
+        _serializerSessionPool = serializerSessionPool;
+    }
+
+    public DurableDictionary([ServiceKey] string key, IStateMachineManager manager, IFieldCodec<K> keyCodec, IFieldCodec<V> valueCodec, SerializerSessionPool serializerSessionPool) : this(keyCodec, valueCodec, serializerSessionPool)
+    {
+        ArgumentException.ThrowIfNullOrEmpty(key);
+        manager.RegisterStateMachine(key, this);
+    }
+
+    public V this[K key]
+    {
+        get => _items[key];
+
+        set
+        {
+            ApplySet(key, value);
+            AppendSet(key, value);
+        }
+    }
+
+    public int Count => _items.Count;
+
+    public ICollection<K> Keys => _items.Keys;
+
+    public ICollection<V> Values => _items.Values;
+
+    public bool IsReadOnly => ((ICollection<KeyValuePair<K, V>>)_items).IsReadOnly;
+
+    void IDurableStateMachine.Reset(IStateMachineLogWriter storage)
+    {
+        _items.Clear();
+        _storage = storage;
+    }
+
+    void IDurableStateMachine.Apply(ReadOnlySequence<byte> logEntry)
+    {
+        using var session = _serializerSessionPool.GetSession();
+        var reader = Reader.Create(logEntry, session);
+        var version = reader.ReadByte();
+        if (version != VersionByte)
+        {
+            throw new NotSupportedException($"This instance of {nameof(DurableDictionary)} supports version {(uint)VersionByte} and not version {(uint)version}.");
+        }
+
+        var commandType = (CommandType)reader.ReadVarUInt32();
+        switch (commandType)
+        {
+            case CommandType.Set:
+                ApplySet(ReadKey(ref reader), ReadValue(ref reader));
+                break;
+            case CommandType.Remove:
+                ApplyRemove(ReadKey(ref reader));
+                break;
+            case CommandType.Clear:
+                ApplyClear();
+                break;
+            case CommandType.Snapshot:
+                ApplySnapshot(ref reader);
+                break;
+            default:
+                throw new NotSupportedException($"Command type {commandType} is not supported");
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        K ReadKey(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var field = reader.ReadFieldHeader();
+            return _keyCodec.ReadValue(ref reader, field);
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        V ReadValue(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var field = reader.ReadFieldHeader();
+            return _valueCodec.ReadValue(ref reader, field);
+        }
+
+        void ApplySnapshot(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var count = (int)reader.ReadVarUInt32();
+            _items.Clear();
+            _items.EnsureCapacity(count);
+            for (var i = 0; i < count; i++)
+            {
+                var key = ReadKey(ref reader);
+                var value = ReadValue(ref reader);
+                ApplySet(key, value);
+            }
+        }
+    }
+
+    void IDurableStateMachine.AppendEntries(StateMachineStorageWriter logWriter)
+    {
+        // This state machine implementation appends log entries as the data structure is modified, so there is no need to perform separate writing here.
+    }
+
+    void IDurableStateMachine.AppendSnapshot(StateMachineStorageWriter snapshotWriter)
+    {
+        snapshotWriter.AppendEntry(static (self, bufferWriter) =>
+        {
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Snapshot);
+            writer.WriteVarUInt32((uint)self._items.Count);
+            foreach (var (key, value) in self._items)
+            {
+                self._keyCodec.WriteField(ref writer, 0, typeof(K), key);
+                self._valueCodec.WriteField(ref writer, 0, typeof(V), value);
+            }
+
+            writer.Commit();
+        }, this);
+    }
+
+    public void Clear()
+    {
+        ApplyClear();
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            using var session = state._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Clear);
+            writer.Commit();
+        },
+        this);
+    }
+
+    public bool Contains(K key) => _items.ContainsKey(key);
+
+    public bool Remove(K key)
+    {
+        if (ApplyRemove(key))
+        {
+            AppendRemove(key);
+            return true;
+        }
+
+        return false;
+    }
+
+    private void AppendRemove(K key)
+    {
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            var (self, key) = state;
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Remove);
+            self._keyCodec.WriteField(ref writer, 0, typeof(K), key);
+            writer.Commit();
+        }, (this, key));
+    }
+
+    IEnumerator IEnumerable.GetEnumerator() => _items.GetEnumerator();
+
+    private void AppendSet(K key, V value)
+    {
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            var (self, key, value) = state;
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Set);
+            self._keyCodec.WriteField(ref writer, 0, typeof(K), key);
+            self._valueCodec.WriteField(ref writer, 1, typeof(V), value);
+            writer.Commit();
+        },
+        (this, key, value));
+    }
+
+    protected virtual void OnSet(K key, V value) { }
+
+    private void ApplySet(K key, V value)
+    {
+        _items[key] = value;
+        OnSet(key, value);
+    }
+
+    private bool ApplyRemove(K key) => _items.Remove(key);
+    private void ApplyClear() => _items.Clear();
+
+    private IStateMachineLogWriter GetStorage()
+    {
+        Debug.Assert(_storage is not null);
+        return _storage;
+    }
+
+    public IDurableStateMachine DeepCopy() => throw new NotImplementedException();
+
+    public void Add(K key, V value)
+    {
+        _items.Add(key, value);
+        OnSet(key, value);
+        AppendSet(key, value);
+    }
+
+    public bool ContainsKey(K key) => _items.ContainsKey(key);
+    public bool TryGetValue(K key, [MaybeNullWhen(false)] out V value) => _items.TryGetValue(key, out value);
+    public void Add(KeyValuePair<K, V> item) => Add(item.Key, item.Value);
+    public bool Contains(KeyValuePair<K, V> item) => _items.Contains(item);
+    public void CopyTo(KeyValuePair<K, V>[] array, int arrayIndex) => ((ICollection<KeyValuePair<K, V>>)_items).CopyTo(array, arrayIndex);
+    public bool Remove(KeyValuePair<K, V> item)
+    {
+        if (((ICollection<KeyValuePair<K, V>>)_items).Remove(item))
+        {
+            AppendRemove(item.Key);
+            return true;
+        }
+
+        return false;
+    }
+
+    public IEnumerator<KeyValuePair<K, V>> GetEnumerator() => ((IEnumerable<KeyValuePair<K, V>>)_items).GetEnumerator();
+
+    private enum CommandType
+    {
+        Set = 0,
+        Remove = 1,
+        Clear = 2,
+        Snapshot = 3
+    }
+}
+[DebuggerDisplay("{Value}", Name = "[{Key}]")]
+internal readonly struct DebugViewDictionaryItem<TKey, TValue>
+{
+    public DebugViewDictionaryItem(TKey key, TValue value)
+    {
+        Key = key;
+        Value = value;
+    }
+
+    public DebugViewDictionaryItem(KeyValuePair<TKey, TValue> keyValue)
+    {
+        Key = keyValue.Key;
+        Value = keyValue.Value;
+    }
+
+    [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)]
+    public TKey Key { get; }
+
+    [DebuggerBrowsable(DebuggerBrowsableState.Collapsed)]
+    public TValue Value { get; }
+}
+
+internal sealed class IDurableDictionaryDebugView<TKey, TValue> where TKey : notnull
+{
+    private readonly IDurableDictionary<TKey, TValue> _dict;
+
+    public IDurableDictionaryDebugView(IDurableDictionary<TKey, TValue> dictionary)
+    {
+        ArgumentNullException.ThrowIfNull(dictionary);
+        _dict = dictionary;
+    }
+
+    [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
+    public DebugViewDictionaryItem<TKey, TValue>[] Items
+    {
+        get
+        {
+            var keyValuePairs = new KeyValuePair<TKey, TValue>[_dict.Count];
+            _dict.CopyTo(keyValuePairs, 0);
+            var items = new DebugViewDictionaryItem<TKey, TValue>[keyValuePairs.Length];
+            for (int i = 0; i < items.Length; i++)
+            {
+                items[i] = new DebugViewDictionaryItem<TKey, TValue>(keyValuePairs[i]);
+            }
+            return items;
+        }
+    }
+}
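For reference, the layout of a single DurableDictionary log entry, as written by AppendSet above and consumed by Apply during replay:

    // [byte]   VersionByte (currently 0)
    // [varint] CommandType (Set = 0, Remove = 1, Clear = 2, Snapshot = 3)
    // [field]  key, encoded by the K codec (field id 0)
    // [field]  value, encoded by the V codec (field id 1; Set and Snapshot entries only)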
diff --git a/src/Orleans.Journaling/DurableGrain.cs b/src/Orleans.Journaling/DurableGrain.cs
new file mode 100644
index 00000000000..80bc8e4641f
--- /dev/null
+++ b/src/Orleans.Journaling/DurableGrain.cs
@@ -0,0 +1,35 @@
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Orleans.Journaling;
+
+public abstract class DurableGrain : Grain, IGrainBase
+{
+    protected DurableGrain()
+    {
+        StateMachineManager = ServiceProvider.GetRequiredService<IStateMachineManager>();
+        if (StateMachineManager is ILifecycleParticipant<IGrainLifecycle> participant)
+        {
+            participant.Participate(((IGrainBase)this).GrainContext.ObservableLifecycle);
+        }
+    }
+
+    protected IStateMachineManager StateMachineManager { get; }
+
+    protected TStateMachine GetOrCreateStateMachine<TStateMachine>(string name) where TStateMachine : class, IDurableStateMachine
+        => GetOrCreateStateMachine(name, static sp => sp.GetRequiredService<TStateMachine>(), ServiceProvider);
+
+    protected TStateMachine GetOrCreateStateMachine<TState, TStateMachine>(string name, Func<TState, TStateMachine> createStateMachine, TState state) where TStateMachine : class, IDurableStateMachine
+    {
+        if (StateMachineManager.TryGetStateMachine(name, out var stateMachine))
+        {
+            return stateMachine as TStateMachine
+                ?? throw new InvalidOperationException($"A state machine named '{name}' already exists with an incompatible type {stateMachine.GetType()} versus {typeof(TStateMachine)}");
+        }
+
+        var result = createStateMachine(state);
+        StateMachineManager.RegisterStateMachine(name, result);
+        return result;
+    }
+
+    protected ValueTask WriteStateAsync(CancellationToken cancellationToken = default) => StateMachineManager.WriteStateAsync(cancellationToken);
+}
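A grain built on DurableGrain might look like the following sketch; the grain interface, key, and state machine name are invented, and the keyed-service injection assumes the `[ServiceKey]` registrations the durable types above rely on:

    public interface ICounterGrain : IGrainWithStringKey
    {
        ValueTask<long> Increment();
    }

    public sealed class CounterGrain(
        [FromKeyedServices("counts")] IDurableDictionary<string, long> counts) : DurableGrain, ICounterGrain
    {
        public async ValueTask<long> Increment()
        {
            // Mutations apply in memory and append to the journal immediately.
            counts["total"] = counts.TryGetValue("total", out var n) ? n + 1 : 1;
            await WriteStateAsync(); // flush the pending log entries before replying
            return counts["total"];
        }
    }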
diff --git a/src/Orleans.Journaling/DurableList.cs b/src/Orleans.Journaling/DurableList.cs
new file mode 100644
index 00000000000..366ea2d45e8
--- /dev/null
+++ b/src/Orleans.Journaling/DurableList.cs
@@ -0,0 +1,306 @@
+using System.Buffers;
+using System.Collections;
+using System.Collections.ObjectModel;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using Microsoft.Extensions.DependencyInjection;
+using Orleans.Serialization.Buffers;
+using Orleans.Serialization.Codecs;
+using Orleans.Serialization.Session;
+
+namespace Orleans.Journaling;
+
+public interface IDurableList<T> : IList<T>
+{
+    void AddRange(IEnumerable<T> collection);
+    ReadOnlyCollection<T> AsReadOnly();
+}
+
+[DebuggerTypeProxy(typeof(IDurableCollectionDebugView<>))]
+[DebuggerDisplay("Count = {Count}")]
+internal sealed class DurableList<T> : IDurableList<T>, IDurableStateMachine
+{
+    private readonly SerializerSessionPool _serializerSessionPool;
+    private readonly IFieldCodec<T> _codec;
+    private const byte VersionByte = 0;
+    private readonly List<T> _items = [];
+    private IStateMachineLogWriter? _storage;
+
+    public DurableList([ServiceKey] string key, IStateMachineManager manager, IFieldCodec<T> codec, SerializerSessionPool serializerSessionPool)
+    {
+        ArgumentException.ThrowIfNullOrEmpty(key);
+        _codec = codec;
+        _serializerSessionPool = serializerSessionPool;
+        manager.RegisterStateMachine(key, this);
+    }
+
+    public T this[int index]
+    {
+        get => _items[index];
+
+        set
+        {
+            if ((uint)index >= (uint)_items.Count)
+            {
+                ThrowIndexOutOfRange();
+            }
+
+            ApplySet(index, value);
+            GetStorage().AppendEntry(static (state, bufferWriter) =>
+            {
+                var (self, index, value) = state;
+                using var session = self._serializerSessionPool.GetSession();
+                var writer = Writer.Create(bufferWriter, session);
+                writer.WriteByte(VersionByte);
+                writer.WriteVarUInt32((uint)CommandType.Set);
+                writer.WriteVarUInt32((uint)index);
+                self._codec.WriteField(ref writer, 0, typeof(T), value!);
+                writer.Commit();
+            },
+            (this, index, value));
+        }
+    }
+
+    public int Count => _items.Count;
+
+    bool ICollection<T>.IsReadOnly => false;
+
+    void IDurableStateMachine.Reset(IStateMachineLogWriter storage)
+    {
+        _items.Clear();
+        _storage = storage;
+    }
+
+    void IDurableStateMachine.Apply(ReadOnlySequence<byte> logEntry)
+    {
+        using var session = _serializerSessionPool.GetSession();
+        var reader = Reader.Create(logEntry, session);
+        var version = reader.ReadByte();
+        if (version != VersionByte)
+        {
+            throw new NotSupportedException($"This instance of {nameof(DurableList)} supports version {(uint)VersionByte} and not version {(uint)version}.");
+        }
+
+        var commandType = (CommandType)reader.ReadVarUInt32();
+        switch (commandType)
+        {
+            case CommandType.Add:
+                ApplyAdd(ReadValue(ref reader));
+                break;
+            case CommandType.Set:
+                {
+                    var index = (int)reader.ReadVarUInt32();
+                    var value = ReadValue(ref reader);
+                    ApplySet(index, value);
+                }
+                break;
+            case CommandType.Insert:
+                {
+                    var index = (int)reader.ReadVarUInt32();
+                    var value = ReadValue(ref reader);
+                    ApplyInsert(index, value);
+                }
+                break;
+            case CommandType.Remove:
+                ApplyRemoveAt((int)reader.ReadVarUInt32());
+                break;
+            case CommandType.Clear:
+                ApplyClear();
+                break;
+            case CommandType.Snapshot:
+                ApplySnapshot(ref reader);
+                break;
+            default:
+                throw new NotSupportedException($"Command type {commandType} is not supported");
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        T ReadValue(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var field = reader.ReadFieldHeader();
+            return _codec.ReadValue(ref reader, field);
+        }
+
+        void ApplySnapshot(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var count = reader.ReadVarUInt32();
+            ApplyClear();
+            _items.EnsureCapacity((int)count);
+            for (var i = 0; i < count; i++)
+            {
+                ApplyAdd(ReadValue(ref reader));
+            }
+        }
+    }
+
+    void IDurableStateMachine.AppendEntries(StateMachineStorageWriter logWriter)
+    {
+        // This state machine implementation appends log entries as the data structure is modified, so there is no need to perform separate writing here.
+    }
+
+    void IDurableStateMachine.AppendSnapshot(StateMachineStorageWriter snapshotWriter)
+    {
+        snapshotWriter.AppendEntry(static (self, bufferWriter) =>
+        {
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Snapshot);
+            writer.WriteVarUInt32((uint)self._items.Count);
+            foreach (var item in self._items)
+            {
+                self._codec.WriteField(ref writer, 0, typeof(T), item);
+            }
+
+            writer.Commit();
+        }, this);
+    }
+
+    public void Add(T item)
+    {
+        ApplyAdd(item);
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            var (self, item) = state;
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Add);
+            self._codec.WriteField(ref writer, 0, typeof(T), item!);
+            writer.Commit();
+        },
+        (this, item));
+    }
+
+    public void Clear()
+    {
+        ApplyClear();
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            using var session = state._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Clear);
+            writer.Commit();
+        },
+        this);
+    }
+
+    public bool Contains(T item) => _items.Contains(item);
+    public void CopyTo(T[] array, int arrayIndex) => _items.CopyTo(array, arrayIndex);
+    public IEnumerator<T> GetEnumerator() => _items.GetEnumerator();
+    public int IndexOf(T item) => _items.IndexOf(item);
+
+    public void Insert(int index, T item)
+    {
+        ApplyInsert(index, item);
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            var (self, index, value) = state;
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Insert);
+            writer.WriteVarUInt32((uint)index);
+            self._codec.WriteField(ref writer, 0, typeof(T), value!);
+            writer.Commit();
+        },
+        (this, index, item));
+    }
+
+    public bool Remove(T item)
+    {
+        var index = _items.IndexOf(item);
+        if (index >= 0)
+        {
+            RemoveAt(index);
+            return true;
+        }
+
+        return false;
+    }
+
+    public void RemoveAt(int index)
+    {
+        ApplyRemoveAt(index);
+
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            var (self, index) = state;
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Remove);
+            writer.WriteVarUInt32((uint)index);
+            writer.Commit();
+        }, (this, index));
+    }
+
+    IEnumerator IEnumerable.GetEnumerator() => _items.GetEnumerator();
+
+    protected void ApplyAdd(T item) => _items.Add(item);
+    protected void ApplySet(int index, T item) => _items[index] = item;
+    protected void ApplyInsert(int index, T item) => _items.Insert(index, item);
+    protected void ApplyRemoveAt(int index) => _items.RemoveAt(index);
+    protected void ApplyClear() => _items.Clear();
+
+    [DoesNotReturn]
+    private static void ThrowIndexOutOfRange() => throw new ArgumentOutOfRangeException("index", "Index was out of range. Must be non-negative and less than the size of the collection");
+
+    private IStateMachineLogWriter GetStorage()
+    {
+        Debug.Assert(_storage is not null);
+        return _storage;
+    }
+
+    public IDurableStateMachine DeepCopy() => throw new NotImplementedException();
+
+    public void AddRange(IEnumerable<T> collection)
+    {
+        foreach (var element in collection)
+        {
+            Add(element);
+        }
+    }
+
+    public ReadOnlyCollection<T> AsReadOnly() => _items.AsReadOnly();
+
+    private enum CommandType
+    {
+        Add = 0,
+        Set = 1,
+        Insert = 2,
+        Remove = 3,
+        Clear = 4,
+        Snapshot = 5
+    }
+}
+
+internal sealed class IDurableCollectionDebugView<T>
+{
+    private readonly ICollection<T> _collection;
+
+    public IDurableCollectionDebugView(ICollection<T> collection)
+    {
+#if NET
+        ArgumentNullException.ThrowIfNull(collection);
+#else
+        if (collection is null)
+        {
+            throw new ArgumentNullException(nameof(collection));
+        }
+#endif
+
+        _collection = collection;
+    }
+
+    [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
+    public T[] Items
+    {
+        get
+        {
+            T[] items = new T[_collection.Count];
+            _collection.CopyTo(items, 0);
+            return items;
+        }
+    }
+}
\ No newline at end of file
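Every mutator above follows the same discipline: apply the change to `_items` in memory, then append a command describing it. Recovery replays those commands through `Apply`, so the in-memory list is reconstructed without a separate snapshot read. A sketch of the round trip, with a hypothetical `list` of type IDurableList<string> inside a DurableGrain method:

    list.Add("a");        // ApplyAdd in memory + CommandType.Add log entry
    list.Insert(0, "b");  // ApplyInsert in memory + CommandType.Insert log entry
    await WriteStateAsync();
    // A later activation replays Add("a") then Insert(0, "b"), yielding ["b", "a"].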
diff --git a/src/Orleans.Journaling/DurableNothing.cs b/src/Orleans.Journaling/DurableNothing.cs
new file mode 100644
index 00000000000..b9336e93d71
--- /dev/null
+++ b/src/Orleans.Journaling/DurableNothing.cs
@@ -0,0 +1,33 @@
+using System.Buffers;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Orleans.Journaling;
+
+/// <summary>
+/// A durable object which does nothing, used for retiring other durable types.
+/// </summary>
+public interface IDurableNothing
+{
+}
+
+/// <summary>
+/// A durable object which does nothing, used for retiring other durable types.
+/// </summary>
+internal sealed class DurableNothing : IDurableNothing, IDurableStateMachine
+{
+    public DurableNothing([ServiceKey] string key, IStateMachineManager manager)
+    {
+        ArgumentException.ThrowIfNullOrEmpty(key);
+        manager.RegisterStateMachine(key, this);
+    }
+
+    void IDurableStateMachine.Reset(IStateMachineLogWriter storage) { }
+
+    void IDurableStateMachine.Apply(ReadOnlySequence<byte> logEntry) { }
+
+    void IDurableStateMachine.AppendEntries(StateMachineStorageWriter logWriter) { }
+
+    void IDurableStateMachine.AppendSnapshot(StateMachineStorageWriter snapshotWriter) { }
+
+    public IDurableStateMachine DeepCopy() => this;
+}
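Retiring a state machine, then, is a matter of re-binding its key to `IDurableNothing` so that replay of existing journals still finds a consumer for the old entries. A sketch, with an invented grain and key, assuming the keyed-service registration pattern used by the other durable types:

    public sealed class StatsGrain(
        [FromKeyedServices("legacy-stats")] IDurableNothing retired) : DurableGrain
    {
        // Entries previously logged under "legacy-stats" are consumed and discarded,
        // freeing the name without breaking replay of existing journals.
    }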
diff --git a/src/Orleans.Journaling/DurableQueue.cs b/src/Orleans.Journaling/DurableQueue.cs
new file mode 100644
index 00000000000..61808e65240
--- /dev/null
+++ b/src/Orleans.Journaling/DurableQueue.cs
@@ -0,0 +1,237 @@
+using System.Buffers;
+using System.Collections;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using Microsoft.Extensions.DependencyInjection;
+using Orleans.Serialization.Buffers;
+using Orleans.Serialization.Codecs;
+using Orleans.Serialization.Session;
+
+namespace Orleans.Journaling;
+
+public interface IDurableQueue<T> : IEnumerable<T>, IReadOnlyCollection<T>
+{
+    void Clear();
+    bool Contains(T item);
+    void CopyTo(T[] array, int arrayIndex);
+    T Dequeue();
+    void Enqueue(T item);
+    T Peek();
+    bool TryDequeue([MaybeNullWhen(false)] out T item);
+    bool TryPeek([MaybeNullWhen(false)] out T item);
+}
+
+[DebuggerTypeProxy(typeof(DurableQueueDebugView<>))]
+[DebuggerDisplay("Count = {Count}")]
+internal sealed class DurableQueue<T> : IDurableQueue<T>, IDurableStateMachine
+{
+    private readonly SerializerSessionPool _serializerSessionPool;
+    private readonly IFieldCodec<T> _codec;
+    private const byte VersionByte = 0;
+    private readonly Queue<T> _items = new();
+    private IStateMachineLogWriter? _storage;
+
+    public DurableQueue([ServiceKey] string key, IStateMachineManager manager, IFieldCodec<T> codec, SerializerSessionPool serializerSessionPool)
+    {
+        ArgumentException.ThrowIfNullOrEmpty(key);
+        _codec = codec;
+        _serializerSessionPool = serializerSessionPool;
+        manager.RegisterStateMachine(key, this);
+    }
+
+    public int Count => _items.Count;
+
+    void IDurableStateMachine.Reset(IStateMachineLogWriter storage)
+    {
+        _items.Clear();
+        _storage = storage;
+    }
+
+    void IDurableStateMachine.Apply(ReadOnlySequence<byte> logEntry)
+    {
+        using var session = _serializerSessionPool.GetSession();
+        var reader = Reader.Create(logEntry, session);
+        var version = reader.ReadByte();
+        if (version != VersionByte)
+        {
+            throw new NotSupportedException($"This instance of {nameof(DurableQueue)} supports version {(uint)VersionByte} and not version {(uint)version}.");
+        }
+
+        var commandType = (CommandType)reader.ReadVarUInt32();
+        switch (commandType)
+        {
+            case CommandType.Enqueue:
+                ApplyEnqueue(ReadValue(ref reader));
+                break;
+            case CommandType.Dequeue:
+                _ = ApplyDequeue();
+                break;
+            case CommandType.Clear:
+                ApplyClear();
+                break;
+            case CommandType.Snapshot:
+                ApplySnapshot(ref reader);
+                break;
+            default:
+                throw new NotSupportedException($"Command type {commandType} is not supported");
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        T ReadValue(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var field = reader.ReadFieldHeader();
+            return _codec.ReadValue(ref reader, field);
+        }
+
+        void ApplySnapshot(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var count = (int)reader.ReadVarUInt32();
+            ApplyClear();
+            _items.EnsureCapacity(count);
+            for (var i = 0; i < count; i++)
+            {
+                ApplyEnqueue(ReadValue(ref reader));
+            }
+        }
+    }
+
+    void IDurableStateMachine.AppendEntries(StateMachineStorageWriter logWriter)
+    {
+        // This state machine implementation appends log entries as the data structure is modified, so there is no need to perform separate writing here.
+    }
+
+    void IDurableStateMachine.AppendSnapshot(StateMachineStorageWriter snapshotWriter)
+    {
+        snapshotWriter.AppendEntry(static (self, bufferWriter) =>
+        {
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Snapshot);
+            writer.WriteVarUInt32((uint)self._items.Count);
+            foreach (var item in self._items)
+            {
+                self._codec.WriteField(ref writer, 0, typeof(T), item);
+            }
+
+            writer.Commit();
+        }, this);
+    }
+
+    public void Clear()
+    {
+        ApplyClear();
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            using var session = state._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Clear);
+            writer.Commit();
+        },
+        this);
+    }
+
+    public T Peek() => _items.Peek();
+    public bool TryPeek([MaybeNullWhen(false)] out T item) => _items.TryPeek(out item);
+    public bool Contains(T item) => _items.Contains(item);
+    public void CopyTo(T[] array, int arrayIndex) => _items.CopyTo(array, arrayIndex);
+    public IEnumerator<T> GetEnumerator() => _items.GetEnumerator();
+
+    public void Enqueue(T item)
+    {
+        ApplyEnqueue(item);
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            var (self, value) = state;
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Enqueue);
+            self._codec.WriteField(ref writer, 0, typeof(T), value!);
+            writer.Commit();
+        },
+        (this, item));
+    }
+
+    public T Dequeue()
+    {
+        var result = ApplyDequeue();
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            var self = state;
+            using var session = self._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Dequeue);
+            writer.Commit();
+        }, this);
+        return result;
+    }
+
+    public bool TryDequeue([MaybeNullWhen(false)] out T item)
+    {
+        if (ApplyTryDequeue(out item))
+        {
+            GetStorage().AppendEntry(static (state, bufferWriter) =>
+            {
+                var self = state;
+                using var session = self._serializerSessionPool.GetSession();
+                var writer = Writer.Create(bufferWriter, session);
+                writer.WriteByte(VersionByte);
+                writer.WriteVarUInt32((uint)CommandType.Dequeue);
+                writer.Commit();
+            }, this);
+            return true;
+        }
+
+        return false;
+    }
+
+    IEnumerator IEnumerable.GetEnumerator() => _items.GetEnumerator();
+
+    protected void ApplyEnqueue(T item) => _items.Enqueue(item);
+    protected T ApplyDequeue() => _items.Dequeue();
+    protected bool ApplyTryDequeue([MaybeNullWhen(false)] out T value) => _items.TryDequeue(out value);
+    protected void ApplyClear() => _items.Clear();
+    [DoesNotReturn]
+    private static void ThrowIndexOutOfRange() => throw new ArgumentOutOfRangeException("index", "Index was out of range. Must be non-negative and less than the size of the collection");
+
+    private IStateMachineLogWriter GetStorage()
+    {
+        Debug.Assert(_storage is not null);
+        return _storage;
+    }
+
+    public IDurableStateMachine DeepCopy() => throw new NotImplementedException();
+
+    private enum CommandType
+    {
+        Enqueue = 0,
+        Dequeue = 1,
+        Clear = 2,
+        Snapshot = 3,
+    }
+}
+
+internal sealed class DurableQueueDebugView<T>
+{
+    private readonly DurableQueue<T> _queue;
+
+    public DurableQueueDebugView(DurableQueue<T> queue)
+    {
+        ArgumentNullException.ThrowIfNull(queue);
+
+        _queue = queue;
+    }
+
+    [DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
+    public T[] Items
+    {
+        get
+        {
+            return _queue.ToArray();
+        }
+    }
+}
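A hypothetical work-queue grain method pair using IDurableQueue<T>; the `_jobs` field and `Job` type are invented for illustration:

    public async ValueTask Enqueue(Job job)
    {
        _jobs.Enqueue(job);      // in-memory apply + Enqueue log entry
        await WriteStateAsync(); // durable once this completes
    }

    public async ValueTask<Job?> TryTake()
    {
        if (!_jobs.TryDequeue(out var job)) return null;
        await WriteStateAsync(); // persist the Dequeue entry before handing the job out
        return job;
    }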
diff --git a/src/Orleans.Journaling/DurableSet.cs b/src/Orleans.Journaling/DurableSet.cs
new file mode 100644
index 00000000000..371c293a708
--- /dev/null
+++ b/src/Orleans.Journaling/DurableSet.cs
@@ -0,0 +1,252 @@
+using System.Buffers;
+using System.Collections;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+using Microsoft.Extensions.DependencyInjection;
+using Orleans.Serialization.Buffers;
+using Orleans.Serialization.Codecs;
+using Orleans.Serialization.Session;
+
+namespace Orleans.Journaling;
+
+public interface IDurableSet<T> : ISet<T>, IReadOnlyCollection<T>, IReadOnlySet<T>
+{
+    new int Count { get; }
+    new bool Contains(T item);
+    new bool Add(T item);
+    new bool IsProperSubsetOf(IEnumerable<T> other);
+    new bool IsProperSupersetOf(IEnumerable<T> other);
+    new bool IsSubsetOf(IEnumerable<T> other);
+    new bool IsSupersetOf(IEnumerable<T> other);
+    new bool Overlaps(IEnumerable<T> other);
+    new bool SetEquals(IEnumerable<T> other);
+}
+
+[DebuggerTypeProxy(typeof(IDurableCollectionDebugView<>))]
+[DebuggerDisplay("Count = {Count}")]
+internal sealed class DurableSet<T> : IDurableSet<T>, IDurableStateMachine
+{
+    private readonly SerializerSessionPool _serializerSessionPool;
+    private readonly IFieldCodec<T> _codec;
+    private const byte VersionByte = 0;
+    private readonly HashSet<T> _items = [];
+    private IStateMachineLogWriter? _storage;
+
+    public DurableSet([ServiceKey] string key, IStateMachineManager manager, IFieldCodec<T> codec, SerializerSessionPool serializerSessionPool)
+    {
+        ArgumentException.ThrowIfNullOrEmpty(key);
+        _codec = codec;
+        _serializerSessionPool = serializerSessionPool;
+        manager.RegisterStateMachine(key, this);
+    }
+
+    public int Count => _items.Count;
+    public bool IsReadOnly => false;
+
+    void IDurableStateMachine.Reset(IStateMachineLogWriter storage)
+    {
+        _items.Clear();
+        _storage = storage;
+    }
+
+    void IDurableStateMachine.Apply(ReadOnlySequence<byte> logEntry)
+    {
+        using var session = _serializerSessionPool.GetSession();
+        var reader = Reader.Create(logEntry, session);
+        var version = reader.ReadByte();
+        if (version != VersionByte)
+        {
+            throw new NotSupportedException($"This instance of {nameof(DurableSet)} supports version {(uint)VersionByte} and not version {(uint)version}.");
+        }
+
+        var commandType = (CommandType)reader.ReadVarUInt32();
+        switch (commandType)
+        {
+            case CommandType.Add:
+                ApplyAdd(ReadValue(ref reader));
+                break;
+            case CommandType.Remove:
+                ApplyRemove(ReadValue(ref reader));
+                break;
+            case CommandType.Clear:
+                ApplyClear();
+                break;
+            case CommandType.Snapshot:
+                ApplySnapshot(ref reader);
+                break;
+            default:
+                throw new NotSupportedException($"Command type {commandType} is not supported");
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        T ReadValue(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var field = reader.ReadFieldHeader();
+            return _codec.ReadValue(ref reader, field);
+        }
+
+        void ApplySnapshot(ref Reader<ReadOnlySequenceInput> reader)
+        {
+            var count = (int)reader.ReadVarUInt32();
+            ApplyClear();
+            _items.EnsureCapacity(count);
+            for (var i = 0; i < count; i++)
+            {
+                ApplyAdd(ReadValue(ref reader));
+            }
+        }
+    }
+
+    void IDurableStateMachine.AppendEntries(StateMachineStorageWriter logWriter)
+    {
+        // This state machine implementation appends log entries as the data structure is modified, so there is no need to perform separate writing here.
+    }
+
+    void IDurableStateMachine.AppendSnapshot(StateMachineStorageWriter snapshotWriter)
+    {
+        snapshotWriter.AppendEntry(WriteSnapshotToBufferWriter, this);
+    }
+
+    private static void WriteSnapshotToBufferWriter(DurableSet<T> self, IBufferWriter<byte> bufferWriter)
+    {
+        using var session = self._serializerSessionPool.GetSession();
+        var writer = Writer.Create(bufferWriter, session);
+        writer.WriteByte(VersionByte);
+        writer.WriteVarUInt32((uint)CommandType.Snapshot);
+        writer.WriteVarUInt32((uint)self._items.Count);
+        foreach (var item in self._items)
+        {
+            self._codec.WriteField(ref writer, 0, typeof(T), item);
+        }
+
+        writer.Commit();
+    }
+
+    public void Clear()
+    {
+        ApplyClear();
+        GetStorage().AppendEntry(static (state, bufferWriter) =>
+        {
+            using var session = state._serializerSessionPool.GetSession();
+            var writer = Writer.Create(bufferWriter, session);
+            writer.WriteByte(VersionByte);
+            writer.WriteVarUInt32((uint)CommandType.Clear);
+            writer.Commit();
+        },
+        this);
+    }
+
+    public bool Contains(T item) => _items.Contains(item);
+    public void CopyTo(T[] array, int arrayIndex) => _items.CopyTo(array, arrayIndex);
+    public IEnumerator<T> GetEnumerator() => _items.GetEnumerator();
+
+    public bool Add(T item)
+    {
+        if (ApplyAdd(item))
+        {
+            GetStorage().AppendEntry(static (state, bufferWriter) =>
+            {
+                var (self, item) = state;
+                using var session = self._serializerSessionPool.GetSession();
+                var writer = Writer.Create(bufferWriter, session);
+                writer.WriteByte(VersionByte);
+                writer.WriteVarUInt32((uint)CommandType.Add);
+                self._codec.WriteField(ref writer, 0, typeof(T), item!);
+                writer.Commit();
+            },
+            (this, item));
+            return true;
+        }
+
+        return false;
+    }
+
+    public bool Remove(T item)
+    {
+        if (ApplyRemove(item))
+        {
+            GetStorage().AppendEntry(static (state, bufferWriter) =>
+            {
+                var (self, item) = state;
+                using var session = self._serializerSessionPool.GetSession();
+                var writer = Writer.Create(bufferWriter, session);
+                writer.WriteByte(VersionByte);
+                writer.WriteVarUInt32((uint)CommandType.Remove);
+                self._codec.WriteField(ref writer, 0, typeof(T), item!);
+                writer.Commit();
+            },
+            (this, item));
+            return true;
+        }
+
+        return false;
+    }
+
+    IEnumerator IEnumerable.GetEnumerator() => _items.GetEnumerator();
+
+    protected bool ApplyAdd(T item) => _items.Add(item);
+    protected bool ApplyRemove(T item) => _items.Remove(item);
+    protected void ApplyClear() => _items.Clear();
+    [DoesNotReturn]
+    private static void ThrowIndexOutOfRange() => throw new ArgumentOutOfRangeException("index", "Index was out of range. Must be non-negative and less than the size of the collection");
+
+    private IStateMachineLogWriter GetStorage()
+    {
+        Debug.Assert(_storage is not null);
+        return _storage;
+    }
+
+    public IDurableStateMachine DeepCopy() => throw new NotImplementedException();
+
+    public void ExceptWith(IEnumerable<T> other)
+    {
+        foreach (var item in other)
+        {
+            Remove(item);
+        }
+    }
+
+    public bool IsProperSubsetOf(IEnumerable<T> other) => _items.IsProperSubsetOf(other);
+    public bool IsProperSupersetOf(IEnumerable<T> other) => _items.IsProperSupersetOf(other);
+    public bool IsSubsetOf(IEnumerable<T> other) => _items.IsSubsetOf(other);
+    public bool IsSupersetOf(IEnumerable<T> other) => _items.IsSupersetOf(other);
+    public bool Overlaps(IEnumerable<T> other) => _items.Overlaps(other);
+    public bool SetEquals(IEnumerable<T> other) => _items.SetEquals(other);
+    void ICollection<T>.Add(T item) => Add(item);
+
+    public void IntersectWith(IEnumerable<T> other)
+    {
+        var initialCount = Count;
+        _items.IntersectWith(other);
+        if (Count != initialCount)
+        {
+            GetStorage().AppendEntry(WriteSnapshotToBufferWriter, this);
+        }
+    }
+
+    public void SymmetricExceptWith(IEnumerable<T> other)
+    {
+        var initialCount = Count;
+        _items.SymmetricExceptWith(other);
+        if (Count != initialCount)
+        {
+            GetStorage().AppendEntry(WriteSnapshotToBufferWriter, this);
+        }
+    }
+
+    public void UnionWith(IEnumerable<T> other)
+    {
+        foreach (var item in other)
+        {
+            Add(item);
+        }
+    }
+
+    private enum CommandType
+    {
+        Add = 0,
+        Remove = 1,
+        Clear = 2,
+        Snapshot = 3,
+    }
+}
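One asymmetry worth noting in the set implementation: `ExceptWith` and `UnionWith` log one entry per element, while `IntersectWith` and `SymmetricExceptWith` mutate the set in place and then log a full snapshot, since the affected elements are not enumerated up front. The observable behavior, sketched with a hypothetical `tags` set containing {"a", "b", "c"}:

    tags.ExceptWith(new[] { "a" });    // logs a single Remove("a") entry
    tags.IntersectWith(new[] { "b" }); // logs one Snapshot entry containing {"b"}
    await WriteStateAsync();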
OnPersisted { get; set; } + T IStorage.State + { + get => _value ??= Activator.CreateInstance(); + set => _value = value; + } + + string IStorage.Etag => $"{_version}"; + bool IStorage.RecordExists => _version > 0; + + private void OnValuePersisted() + { + ++_version; + OnPersisted?.Invoke(); + } + + void IDurableStateMachine.OnRecoveryCompleted() => OnValuePersisted(); + void IDurableStateMachine.OnWriteCompleted() => OnValuePersisted(); + + void IDurableStateMachine.Reset(IStateMachineLogWriter storage) => _value = default; + + void IDurableStateMachine.Apply(ReadOnlySequence logEntry) + { + using var session = _serializerSessionPool.GetSession(); + var reader = Reader.Create(logEntry, session); + var version = reader.ReadByte(); + if (version != VersionByte) + { + throw new NotSupportedException($"This instance of {nameof(DurableState)} supports version {(uint)VersionByte} and not version {(uint)version}."); + } + + var commandType = (CommandType)reader.ReadVarUInt32(); + switch (commandType) + { + case CommandType.ClearValue: + ClearValue(ref reader); + break; + case CommandType.SetValue: + SetValue(ref reader); + break; + default: + throw new NotSupportedException($"Command type {commandType} is not supported"); + } + + void SetValue(ref Reader reader) + { + var field = reader.ReadFieldHeader(); + _value = _codec.ReadValue(ref reader, field); + _version = reader.ReadVarUInt64(); + } + + void ClearValue(ref Reader reader) + { + _value = default; + _version = 0; + } + } + + void IDurableStateMachine.AppendEntries(StateMachineStorageWriter logWriter) => WriteState(logWriter); + + void IDurableStateMachine.AppendSnapshot(StateMachineStorageWriter snapshotWriter) => WriteState(snapshotWriter); + + public IDurableStateMachine DeepCopy() => throw new NotImplementedException(); + + private void WriteState(StateMachineStorageWriter writer) + { + writer.AppendEntry(static (self, bufferWriter) => + { + using var session = self._serializerSessionPool.GetSession(); + var writer = Writer.Create(bufferWriter, session); + writer.WriteByte(VersionByte); + writer.WriteVarUInt32((uint)CommandType.SetValue); + self._codec.WriteField(ref writer, 0, typeof(T), self._value!); + writer.WriteVarUInt64(self._version); + writer.Commit(); + }, this); + } + + Task IStorage.ClearStateAsync() => ((IStorage)this).ClearStateAsync(CancellationToken.None); + async Task IStorage.ClearStateAsync(CancellationToken cancellationToken) + { + _value = default; + _version = 0; + await _manager.WriteStateAsync(cancellationToken); + } + + Task IStorage.WriteStateAsync() => ((IStorage)this).WriteStateAsync(CancellationToken.None); + async Task IStorage.WriteStateAsync(CancellationToken cancellationToken) => await _manager.WriteStateAsync(cancellationToken); + Task IStorage.ReadStateAsync() => ((IStorage)this).ReadStateAsync(CancellationToken.None); + Task IStorage.ReadStateAsync(CancellationToken cancellationToken) => Task.CompletedTask; + + private enum CommandType + { + SetValue, + ClearValue, + } +} diff --git a/src/Orleans.Journaling/DurableTaskCompletionSource.cs b/src/Orleans.Journaling/DurableTaskCompletionSource.cs new file mode 100644 index 00000000000..77ff8f942eb --- /dev/null +++ b/src/Orleans.Journaling/DurableTaskCompletionSource.cs @@ -0,0 +1,228 @@ +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using Microsoft.Extensions.DependencyInjection; +using Orleans.Serialization; +using Orleans.Serialization.Buffers; +using Orleans.Serialization.Codecs; +using 
Orleans.Serialization.Session; + +namespace Orleans.Journaling; + +public interface IDurableTaskCompletionSource +{ + Task Task { get; } + DurableTaskCompletionSourceState State { get; } + + bool TrySetCanceled(); + bool TrySetException(Exception exception); + bool TrySetResult(T value); +} + +[DebuggerDisplay("Status = {Status}")] +internal sealed class DurableTaskCompletionSource : IDurableTaskCompletionSource, IDurableStateMachine +{ + private const byte SupportedVersion = 0; + private readonly SerializerSessionPool _serializerSessionPool; + private readonly IFieldCodec _codec; + private readonly IFieldCodec _exceptionCodec; + private readonly DeepCopier _copier; + private readonly DeepCopier _exceptionCopier; + + private TaskCompletionSource _completion = new(TaskCreationOptions.RunContinuationsAsynchronously); + private IStateMachineLogWriter? _storage; + private DurableTaskCompletionSourceStatus _status; + private T? _value; + private Exception? _exception; + + public DurableTaskCompletionSource( + [ServiceKey] string key, + IStateMachineManager manager, + IFieldCodec codec, + DeepCopier copier, + IFieldCodec exceptionCodec, + DeepCopier exceptionCopier, + SerializerSessionPool serializerSessionPool) + { + ArgumentNullException.ThrowIfNullOrEmpty(key); + _codec = codec; + _copier = copier; + _exceptionCodec = exceptionCodec; + _exceptionCopier = exceptionCopier; + _serializerSessionPool = serializerSessionPool; + manager.RegisterStateMachine(key, this); + } + + public bool TrySetResult(T value) + { + if (_status is not DurableTaskCompletionSourceStatus.Pending) + { + return false; + } + + _status = DurableTaskCompletionSourceStatus.Completed; + _value = _copier.Copy(value); + return true; + } + + public bool TrySetException(Exception exception) + { + if (_status is not DurableTaskCompletionSourceStatus.Pending) + { + return false; + } + + _status = DurableTaskCompletionSourceStatus.Faulted; + _exception = _exceptionCopier.Copy(exception); + return true; + } + + public bool TrySetCanceled() + { + if (_status is not DurableTaskCompletionSourceStatus.Pending) + { + return false; + } + + _status = DurableTaskCompletionSourceStatus.Canceled; + return true; + } + + public Task Task => _completion.Task; + + public DurableTaskCompletionSourceState State => _status switch + { + DurableTaskCompletionSourceStatus.Pending => new DurableTaskCompletionSourceState { Status = DurableTaskCompletionSourceStatus.Pending }, + DurableTaskCompletionSourceStatus.Completed => new DurableTaskCompletionSourceState { Status = DurableTaskCompletionSourceStatus.Completed, Value = _value }, + DurableTaskCompletionSourceStatus.Faulted => new DurableTaskCompletionSourceState { Status = DurableTaskCompletionSourceStatus.Faulted, Exception = _exception }, + DurableTaskCompletionSourceStatus.Canceled => new DurableTaskCompletionSourceState { Status = DurableTaskCompletionSourceStatus.Canceled }, + _ => throw new InvalidOperationException($"Unexpected status, \"{_status}\""), + }; + + private void OnValuePersisted() + { + switch (_status) + { + case DurableTaskCompletionSourceStatus.Completed: + _completion.TrySetResult(_value!); + break; + case DurableTaskCompletionSourceStatus.Faulted: + _completion.TrySetException(_exception!); + break; + case DurableTaskCompletionSourceStatus.Canceled: + _completion.TrySetCanceled(); + break; + default: + break; + } + } + + void IDurableStateMachine.OnRecoveryCompleted() => OnValuePersisted(); + void IDurableStateMachine.OnWriteCompleted() => OnValuePersisted(); + + void 
IDurableStateMachine.Reset(IStateMachineLogWriter storage) + { + // Reset the task completion source if necessary. + if (_completion.Task.IsCompleted) + { + _completion = new(TaskCreationOptions.RunContinuationsAsynchronously); + } + + _storage = storage; + } + + void IDurableStateMachine.Apply(ReadOnlySequence logEntry) + { + using var session = _serializerSessionPool.GetSession(); + var reader = Reader.Create(logEntry, session); + var version = reader.ReadByte(); + if (version != SupportedVersion) + { + throw new NotSupportedException($"This instance of {nameof(DurableTaskCompletionSource)} supports version {(uint)SupportedVersion} and not version {(uint)version}."); + } + + _status = (DurableTaskCompletionSourceStatus)reader.ReadVarUInt32(); + switch (_status) + { + case DurableTaskCompletionSourceStatus.Completed: + _value = ReadValue(ref reader); + break; + case DurableTaskCompletionSourceStatus.Faulted: + _exception = ReadException(ref reader); + break; + default: + break; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + T ReadValue(ref Reader reader) + { + var field = reader.ReadFieldHeader(); + return _codec.ReadValue(ref reader, field); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + Exception ReadException(ref Reader reader) + { + var field = reader.ReadFieldHeader(); + return _exceptionCodec.ReadValue(ref reader, field); + } + } + + void IDurableStateMachine.AppendEntries(StateMachineStorageWriter logWriter) + { + if (_status is not DurableTaskCompletionSourceStatus.Pending) + { + WriteState(logWriter); + } + } + + void IDurableStateMachine.AppendSnapshot(StateMachineStorageWriter snapshotWriter) => WriteState(snapshotWriter); + + private void WriteState(StateMachineStorageWriter writer) + { + writer.AppendEntry(static (self, bufferWriter) => + { + using var session = self._serializerSessionPool.GetSession(); + var writer = Writer.Create(bufferWriter, session); + writer.WriteByte(DurableTaskCompletionSource.SupportedVersion); + var status = self._status; + writer.WriteByte((byte)status); + if (status is DurableTaskCompletionSourceStatus.Completed) + { + self._codec.WriteField(ref writer, 0, typeof(T), self._value!); + } + else if (status is DurableTaskCompletionSourceStatus.Faulted) + { + self._exceptionCodec.WriteField(ref writer, 0, typeof(Exception), self._exception!); + } + + writer.Commit(); + }, this); + } + + public IDurableStateMachine DeepCopy() => throw new NotImplementedException(); +} + +[GenerateSerializer] +public enum DurableTaskCompletionSourceStatus : byte +{ + Pending = 0, + Completed, + Faulted, + Canceled +} + +[GenerateSerializer, Immutable] +public readonly struct DurableTaskCompletionSourceState +{ + [Id(0)] + public DurableTaskCompletionSourceStatus Status { get; init; } + + [Id(1)] + public T? Value { get; init; } + + [Id(2)] + public Exception? Exception { get; init; } +} + diff --git a/src/Orleans.Journaling/DurableValue.cs b/src/Orleans.Journaling/DurableValue.cs new file mode 100644 index 00000000000..17cb2d19765 --- /dev/null +++ b/src/Orleans.Journaling/DurableValue.cs @@ -0,0 +1,129 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using Microsoft.Extensions.DependencyInjection; +using Orleans.Serialization.Buffers; +using Orleans.Serialization.Codecs; +using Orleans.Serialization.Session; + +namespace Orleans.Journaling; + +public interface IDurableValue +{ + T? 
Value { get; set; } +} + +[DebuggerDisplay("{Value}")] +internal sealed class DurableValue : IDurableValue, IDurableStateMachine +{ + private const byte VersionByte = 0; + private readonly SerializerSessionPool _serializerSessionPool; + private readonly IFieldCodec _codec; + private IStateMachineLogWriter? _storage; + private T? _value; + private bool _isDirty; + + public DurableValue([ServiceKey] string key, IStateMachineManager manager, IFieldCodec codec, SerializerSessionPool serializerSessionPool) + { + ArgumentNullException.ThrowIfNullOrEmpty(key); + _codec = codec; + _serializerSessionPool = serializerSessionPool; + manager.RegisterStateMachine(key, this); + } + + public T? Value + { + get => _value; + set + { + _value = value; + OnModified(); + } + } + + public Action? OnPersisted { get; set; } + + private void OnValuePersisted() => OnPersisted?.Invoke(); + + public void OnModified() => _isDirty = true; + + void IDurableStateMachine.OnRecoveryCompleted() => OnValuePersisted(); + void IDurableStateMachine.OnWriteCompleted() => OnValuePersisted(); + + void IDurableStateMachine.Reset(IStateMachineLogWriter storage) + { + _value = default; + _storage = storage; + } + + void IDurableStateMachine.Apply(ReadOnlySequence logEntry) + { + using var session = _serializerSessionPool.GetSession(); + var reader = Reader.Create(logEntry, session); + var version = reader.ReadByte(); + if (version != VersionByte) + { + throw new NotSupportedException($"This instance of {nameof(DurableValue)} supports version {(uint)VersionByte} and not version {(uint)version}."); + } + + var commandType = (CommandType)reader.ReadVarUInt32(); + switch (commandType) + { + case CommandType.SetValue: + SetValue(ref reader); + break; + default: + throw new NotSupportedException($"Command type {commandType} is not supported"); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + T ReadValue(ref Reader reader) + { + var field = reader.ReadFieldHeader(); + return _codec.ReadValue(ref reader, field); + } + + void SetValue(ref Reader reader) => _value = ReadValue(ref reader); + } + + void IDurableStateMachine.AppendEntries(StateMachineStorageWriter logWriter) + { + if (_isDirty) + { + WriteState(logWriter); + _isDirty = false; + } + } + + void IDurableStateMachine.AppendSnapshot(StateMachineStorageWriter snapshotWriter) => WriteState(snapshotWriter); + + public IDurableStateMachine DeepCopy() => throw new NotImplementedException(); + + private void WriteState(StateMachineStorageWriter writer) + { + writer.AppendEntry(static (self, bufferWriter) => + { + using var session = self._serializerSessionPool.GetSession(); + var writer = Writer.Create(bufferWriter, session); + writer.WriteByte(VersionByte); + writer.WriteVarUInt32((uint)CommandType.SetValue); + self._codec.WriteField(ref writer, 0, typeof(T), self._value!); + writer.Commit(); + }, this); + } + + [DoesNotReturn] + private static void ThrowIndexOutOfRange() => throw new ArgumentOutOfRangeException("index", "Index was out of range. 
Must be non-negative and less than the size of the collection"); + + private IStateMachineLogWriter GetStorage() + { + Debug.Assert(_storage is not null); + return _storage; + } + + private enum CommandType + { + SetValue, + } +} diff --git a/src/Orleans.Journaling/HostingExtensions.cs b/src/Orleans.Journaling/HostingExtensions.cs new file mode 100644 index 00000000000..4eff8b3e93d --- /dev/null +++ b/src/Orleans.Journaling/HostingExtensions.cs @@ -0,0 +1,21 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Orleans.Journaling; +public static class HostingExtensions +{ + public static ISiloBuilder AddStateMachineStorage(this ISiloBuilder builder) + { + builder.Services.TryAddScoped(sp => sp.GetRequiredService().Create(sp.GetRequiredService())); + builder.Services.TryAddScoped(); + builder.Services.TryAddKeyedScoped(typeof(IDurableDictionary<,>), KeyedService.AnyKey, typeof(DurableDictionary<,>)); + builder.Services.TryAddKeyedScoped(typeof(IDurableList<>), KeyedService.AnyKey, typeof(DurableList<>)); + builder.Services.TryAddKeyedScoped(typeof(IDurableQueue<>), KeyedService.AnyKey, typeof(DurableQueue<>)); + builder.Services.TryAddKeyedScoped(typeof(IDurableSet<>), KeyedService.AnyKey, typeof(DurableSet<>)); + builder.Services.TryAddKeyedScoped(typeof(IDurableValue<>), KeyedService.AnyKey, typeof(DurableValue<>)); + builder.Services.TryAddKeyedScoped(typeof(IPersistentState<>), KeyedService.AnyKey, typeof(DurableState<>)); + builder.Services.TryAddKeyedScoped(typeof(IDurableTaskCompletionSource<>), KeyedService.AnyKey, typeof(DurableTaskCompletionSource<>)); + builder.Services.TryAddKeyedScoped(typeof(IDurableNothing), KeyedService.AnyKey, typeof(DurableNothing)); + return builder; + } +} diff --git a/src/Orleans.Journaling/IDurableStateMachine.cs b/src/Orleans.Journaling/IDurableStateMachine.cs new file mode 100644 index 00000000000..8c9c81d91ac --- /dev/null +++ b/src/Orleans.Journaling/IDurableStateMachine.cs @@ -0,0 +1,57 @@ +using System.Buffers; + +namespace Orleans.Journaling; + +/// +/// Interface for a state machine which can be persisted to durable storage. +/// +public interface IDurableStateMachine +{ + /// + /// Resets the state machine. + /// + /// + /// If the state machine has any volatile state, it must be cleared by this method. + /// This method can be called at any point in the state machine's lifetime, including during recovery. + /// + void Reset(IStateMachineLogWriter storage); + + /// + /// Called during recovery to apply the provided log entry or snapshot. + /// + /// The log entry or snapshot. + void Apply(ReadOnlySequence entry); + + /// + /// Notifies the state machine that all prior log entries and snapshots have been applied. + /// + /// + /// The state machine should not expect any additional calls after this method is called, + /// unless is called to reset the state machine to its initial state. + /// This method will be called before any or calls. + /// + void OnRecoveryCompleted() { } + + /// + /// Writes pending state changes to the log. + /// + /// The log writer. + void AppendEntries(StateMachineStorageWriter writer); + + /// + /// Writes a snapshot of the state machine to the provided writer. + /// + /// The log writer. + void AppendSnapshot(StateMachineStorageWriter writer); + + /// + /// Notifies the state machine that all prior log entries and snapshots which it has written have been written to stable storage. 
+ /// + void OnWriteCompleted() { } + + /// + /// Creates and returns a deep copy of this instance. All replicas must be independent such that changes to one do not affect any other. + /// + /// A replica of this instance. + IDurableStateMachine DeepCopy(); +} diff --git a/src/Orleans.Journaling/IStateMachineLogWriter.cs b/src/Orleans.Journaling/IStateMachineLogWriter.cs new file mode 100644 index 00000000000..0fed26d934a --- /dev/null +++ b/src/Orleans.Journaling/IStateMachineLogWriter.cs @@ -0,0 +1,25 @@ +using System.Buffers; + +namespace Orleans.Journaling; + +/// +/// Provides functionality for writing out-of-band log entries to the log for the state machine which holds this instance. +/// +public interface IStateMachineLogWriter +{ + /// + /// Appends an entry to the log for the state machine which holds this instance. + /// + /// The state, passed to the delegate. + /// The delegate invoked to append a log entry. + /// The state passed to . + void AppendEntry(Action> action, TState state); + + /// + /// Appends an entry to the log for the state machine which holds this instance. + /// + /// The state, passed to the delegate. + /// The delegate invoked to append a log entry. + /// The state passed to . + void AppendEntries(Action action, TState state); +} diff --git a/src/Orleans.Journaling/IStateMachineManager.cs b/src/Orleans.Journaling/IStateMachineManager.cs new file mode 100644 index 00000000000..a10afeeb7f6 --- /dev/null +++ b/src/Orleans.Journaling/IStateMachineManager.cs @@ -0,0 +1,44 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Orleans.Journaling; + +/// +/// Manages the state machines for a given grain. +/// +public interface IStateMachineManager +{ + /// + /// Initializes the state machine manager. + /// + /// The cancellation token. + /// A which represents the operation. + ValueTask InitializeAsync(CancellationToken cancellationToken); + + /// + /// Registers a state machine with the manager. + /// + /// The state machine's stable identifier. + /// The state machine instance to register. + void RegisterStateMachine(string name, IDurableStateMachine stateMachine); + + /// + /// Registers a state machine with the manager. + /// + /// The state machine's stable identifier. + /// The state machine instance to register. + bool TryGetStateMachine(string name, [NotNullWhen(true)] out IDurableStateMachine? stateMachine); + + /// + /// Prepares and persists an update to the log. + /// + /// The cancellation token. + /// A which represents the operation. + ValueTask WriteStateAsync(CancellationToken cancellationToken); + + /// + /// Resets this instance, removing any persistent state. + /// + /// The cancellation token. + /// A which represents the operation. + ValueTask DeleteStateAsync(CancellationToken cancellationToken); +} diff --git a/src/Orleans.Journaling/IStateMachineStorage.cs b/src/Orleans.Journaling/IStateMachineStorage.cs new file mode 100644 index 00000000000..32dca010f39 --- /dev/null +++ b/src/Orleans.Journaling/IStateMachineStorage.cs @@ -0,0 +1,42 @@ +namespace Orleans.Journaling; + +/// +/// Provides storage for state machines. +/// +public interface IStateMachineStorage +{ + /// + /// Returns an ordered collection of all log segments belonging to this instance. + /// + /// The cancellation token. + /// An ordered collection of all log segments belonging to this instance. + IAsyncEnumerable ReadAsync(CancellationToken cancellationToken); + + /// + /// Replaces the log with the provided value atomically. + /// + /// The value to write. 
+ /// The cancellation token. + /// A representing the operation. + ValueTask ReplaceAsync(LogExtentBuilder value, CancellationToken cancellationToken); + + /// + /// Appends the provided segment to the log atomically. + /// + /// The segment to append. + /// The cancellation token. + /// A representing the operation. + ValueTask AppendAsync(LogExtentBuilder value, CancellationToken cancellationToken); + + /// + /// Deletes the state machine's log atomically. + /// + /// The cancellation token. + /// A representing the operation. + ValueTask DeleteAsync(CancellationToken cancellationToken); + + /// + /// Gets a value indicating whether the state machine has requested a snapshot. + /// + bool IsCompactionRequested { get; } +} diff --git a/src/Orleans.Journaling/IStateMachineStorageProvider.cs b/src/Orleans.Journaling/IStateMachineStorageProvider.cs new file mode 100644 index 00000000000..21c2ac4bb8e --- /dev/null +++ b/src/Orleans.Journaling/IStateMachineStorageProvider.cs @@ -0,0 +1,6 @@ +namespace Orleans.Journaling; + +public interface IStateMachineStorageProvider +{ + IStateMachineStorage Create(IGrainContext grainContext); +} diff --git a/src/Orleans.Journaling/LogExtent.cs b/src/Orleans.Journaling/LogExtent.cs new file mode 100644 index 00000000000..c3c7bc01e4b --- /dev/null +++ b/src/Orleans.Journaling/LogExtent.cs @@ -0,0 +1,100 @@ +using System.Buffers; +using System.Collections; +using Orleans.Serialization.Buffers; +using System.Diagnostics; + +namespace Orleans.Journaling; + +/// +/// Represents a log segment which has been sealed and is no longer mutable. +/// +public sealed class LogExtent(ArcBuffer buffer) : IDisposable +{ + private ArcBuffer _buffer = buffer; + + public LogExtent() : this(new()) + { + } + + public bool IsEmpty => _buffer.Length == 0; + + internal EntryEnumerator Entries => EntryEnumerator.Create(this); + + public void Dispose() => _buffer.Dispose(); + + public readonly record struct Entry(StateMachineId StreamId, ReadOnlySequence Payload); + + internal struct EntryEnumerator : IEnumerable, IEnumerator, IDisposable + { + private LogExtent _logExtent; + private ReadOnlySequence _current; + private int _length; + + private EntryEnumerator(LogExtent logExtent) + { + _logExtent = logExtent; + _current = logExtent._buffer.AsReadOnlySequence(); + _length = -2; + } + + public readonly EntryEnumerator GetEnumerator() => this; + + public static EntryEnumerator Create(LogExtent logSegment) => new(logSegment); + + public bool MoveNext() + { + if (_length == -1) + { + ThrowEnumerationNotStartedOrEnded(); + } + + if (_length >= 0) + { + // Advance the cursor. + _current = _current.Slice(_length); + } + + if (_current.Length == 0) + { + _length = -1; + return false; + } + + var reader = Reader.Create(_current, null); + _length = (int)reader.ReadVarUInt32(); + _current = _current.Slice(reader.Position); + return true; + } + + public readonly Entry Current + { + get + { + if (_length < 0) + { + ThrowEnumerationNotStartedOrEnded(); + } + + var slice = _current.Slice(0, _length); + var reader = Reader.Create(slice, null); + var id = reader.ReadVarUInt32(); + return new(new(id), slice.Slice(reader.Position)); + } + } + + private readonly void ThrowEnumerationNotStartedOrEnded() + { + Debug.Assert(_length is (-1) or (-2)); + throw new InvalidOperationException(_length == -2 ? "Enumeration has not started." : "Enumeration has completed."); + } + + readonly object? 
IEnumerator.Current => Current; + + public void Reset() => this = new(_logExtent); + + public void Dispose() => _length = -1; + + readonly IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + readonly IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} diff --git a/src/Orleans.Journaling/LogExtentBuilder.ReadOnlyStream.cs b/src/Orleans.Journaling/LogExtentBuilder.ReadOnlyStream.cs new file mode 100644 index 00000000000..de189d4a708 --- /dev/null +++ b/src/Orleans.Journaling/LogExtentBuilder.ReadOnlyStream.cs @@ -0,0 +1,107 @@ +using System.Diagnostics; +using System.Numerics; + +namespace Orleans.Journaling; + +public sealed partial class LogExtentBuilder +{ + public sealed class ReadOnlyStream : Stream + { + private LogExtentBuilder? _builder; + private int _length; + private int _position; + + public ReadOnlyStream() { } + + public override bool CanRead => true; + public override bool CanSeek => true; + public override bool CanWrite => false; + public override long Length => _length; + public override long Position { get => _position; set => SetPosition((int)value); } + + public override int Read(byte[] buffer, int offset, int count) => Read(buffer.AsSpan(offset, count)); + + public override int Read(Span buffer) => throw new NotImplementedException(); + + public override long Seek(long offset, SeekOrigin origin) + { + switch (origin) + { + case SeekOrigin.Begin: + SetPosition((int)offset); + break; + case SeekOrigin.Current: + SetPosition(_position + (int)offset); + break; + case SeekOrigin.End: + SetPosition(_length - (int)offset); + break; + default: + throw new ArgumentOutOfRangeException(nameof(origin)); + } + + return Position; + } + + private void SetPosition(int value) + { + if (value > _length || value < 0) throw new ArgumentOutOfRangeException(nameof(value)); + _position = value; + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) => new(Read(buffer.Span)); + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => Task.FromResult(Read(buffer, offset, count)); + + public override void CopyTo(Stream destination, int bufferSize) + { + ValidateCopyToArguments(destination, bufferSize); + _builder!.CopyToAsync(destination, bufferSize, default).AsTask().GetAwaiter().GetResult(); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + ValidateCopyToArguments(destination, bufferSize); + + if (_position != 0) throw new NotImplementedException("Position must be zero for this copy operation"); + + return _builder!.CopyToAsync(destination, bufferSize, cancellationToken).AsTask(); + } + + public override void Flush() => throw GetReadOnlyException(); + public override void WriteByte(byte value) => throw GetReadOnlyException(); + public override void SetLength(long value) => throw GetReadOnlyException(); + public override void Write(byte[] buffer, int offset, int count) => throw GetReadOnlyException(); + public override void Write(ReadOnlySpan buffer) => throw GetReadOnlyException(); + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) => throw GetReadOnlyException(); + + public void SetBuilder(LogExtentBuilder builder) + { + _builder = builder; + _position = 0; + _length = ComputeLength(); + } + + public void Reset() + { + _builder = default; + _position = 0; + _length = 0; + } + + private int ComputeLength() + { + 
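+ // The stream's total length is the raw buffered payload plus one varint length
+ // prefix per entry; the prefixes live only in _entryLengths and are materialized
+ // when the data is copied out.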
Debug.Assert(_builder!._entryLengths is not null); + var length = _builder!._buffer.Length; + foreach (var entry in _builder!._entryLengths) + { + length += GetVarIntWidth(entry); + } + + return length; + } + + private static int GetVarIntWidth(uint value) => 1 + (int)((uint)BitOperations.Log2(value) / 7); + + private static NotSupportedException GetReadOnlyException() => new("This stream is read-only"); + } +} diff --git a/src/Orleans.Journaling/LogExtentBuilder.cs b/src/Orleans.Journaling/LogExtentBuilder.cs new file mode 100644 index 00000000000..6a4728f5f8b --- /dev/null +++ b/src/Orleans.Journaling/LogExtentBuilder.cs @@ -0,0 +1,160 @@ +using System.Buffers; +using System.Diagnostics; +using Orleans.Serialization.Buffers; +using Orleans.Serialization.Buffers.Adaptors; + +namespace Orleans.Journaling; + +/// +/// A mutable builder for creating log segments. +/// +public sealed partial class LogExtentBuilder(ArcBufferWriter buffer) : IDisposable, IBufferWriter +{ + private readonly List _entryLengths = []; + private readonly byte[] _scratch = new byte[8]; + private readonly ArcBufferWriter _buffer = buffer; + + public LogExtentBuilder() : this(new()) + { + } + + public long Length => _buffer.Length; + + public byte[] ToArray() + { + using var memoryStream = new PooledBufferStream(); + CopyTo(memoryStream, 4096); + return memoryStream.ToArray(); + } + + public StateMachineStorageWriter CreateLogWriter(StateMachineId id) => new(id, this); + + public bool IsEmpty => _buffer.Length == 0; + + internal void AppendEntry(StateMachineId id, byte[] value) => AppendEntry(id, (ReadOnlySpan)value); + internal void AppendEntry(StateMachineId id, Span value) => AppendEntry(id, (ReadOnlySpan)value); + internal void AppendEntry(StateMachineId id, Memory value) => AppendEntry(id, value.Span); + internal void AppendEntry(StateMachineId id, ReadOnlyMemory value) => AppendEntry(id, value.Span); + internal void AppendEntry(StateMachineId id, ArraySegment value) => AppendEntry(id, value.AsSpan()); + internal void AppendEntry(StateMachineId id, ReadOnlySpan value) + { + var startOffset = _buffer.Length; + var writer = Writer.Create(this, session: null); + writer.WriteVarUInt64(id.Value); + writer.Commit(); + + _buffer.Write(value); + + var endOffset = _buffer.Length; + _entryLengths.Add((uint)(endOffset - startOffset)); + } + + internal void AppendEntry(StateMachineId id, ReadOnlySequence value) + { + var startOffset = _buffer.Length; + + var writer = Writer.Create(this, session: null); + writer.WriteVarUInt64(id.Value); + writer.Commit(); + + _buffer.Write(value); + + var endOffset = _buffer.Length; + _entryLengths.Add((uint)(endOffset - startOffset)); + } + + internal void AppendEntry(StateMachineId id, Action> valueWriter, T value) + { + var startOffset = _buffer.Length; + + var writer = Writer.Create(this, session: null); + writer.WriteVarUInt64(id.Value); + writer.Commit(); + valueWriter(value, this); + + var endOffset = _buffer.Length; + _entryLengths.Add((uint)(endOffset - startOffset)); + } + + public void Reset() + { + _buffer.Reset(); + _entryLengths.Clear(); + } + + public void Dispose() => Reset(); + + // Implemented on this class to prevent the need to repeatedly box & unbox _buffer. 
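+ // Framing note: AppendEntry writes [varint: state machine id][payload] into _buffer
+ // and records the combined byte count in _entryLengths; the per-entry length prefix
+ // is only prepended when the extent is flushed via CopyTo/CopyToAsync. The resulting
+ // layout, as consumed by LogExtent.EntryEnumerator, is:
+ //
+ //   [varint: entry length][varint: state machine id][payload]   (repeated per entry)
+ //
+ // For example, a 300-byte entry carries a 2-byte length prefix, since 300 requires
+ // two 7-bit varint groups (see ReadOnlyStream.GetVarIntWidth).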
+ void IBufferWriter.Advance(int count) => _buffer.AdvanceWriter(count); + Memory IBufferWriter.GetMemory(int sizeHint) => _buffer.GetMemory(sizeHint); + Span IBufferWriter.GetSpan(int sizeHint) => _buffer.GetSpan(sizeHint); + + public async ValueTask CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + using var buffer = _buffer.PeekSlice(_buffer.Length); + var segments = buffer.MemorySegments; + var currentSegment = ReadOnlyMemory.Empty; + foreach (var entryLength in _entryLengths) + { + await destination.WriteAsync(GetLengthBytes(_scratch, entryLength), cancellationToken); + + var remainingEntryLength = entryLength; + while (remainingEntryLength > 0) + { + // Move to the next memory segment if necessary. + if (currentSegment.Length == 0) + { + var hasNext = segments.MoveNext(); + Debug.Assert(hasNext); + currentSegment = segments.Current; + continue; + } + + var copyLen = Math.Min((uint)bufferSize, Math.Min(remainingEntryLength, (uint)currentSegment.Length)); + + await destination.WriteAsync(currentSegment[..(int)copyLen], cancellationToken); + + remainingEntryLength -= copyLen; + currentSegment = currentSegment[(int)copyLen..]; + } + } + } + + public void CopyTo(Stream destination, int bufferSize) + { + using var buffer = _buffer.PeekSlice(_buffer.Length); + var segments = buffer.MemorySegments; + var currentSegment = ReadOnlyMemory.Empty; + foreach (var entryLength in _entryLengths) + { + destination.Write(GetLengthBytes(_scratch, entryLength).Span); + + var remainingEntryLength = entryLength; + while (remainingEntryLength > 0) + { + // Move to the next memory segment if necessary. + if (currentSegment.Length == 0) + { + var hasNext = segments.MoveNext(); + Debug.Assert(hasNext); + currentSegment = segments.Current; + continue; + } + + var copyLen = Math.Min((uint)bufferSize, Math.Min(remainingEntryLength, (uint)currentSegment.Length)); + + destination.Write(currentSegment[..(int)copyLen].Span); + + remainingEntryLength -= copyLen; + currentSegment = currentSegment[(int)copyLen..]; + } + } + } + + private static ReadOnlyMemory GetLengthBytes(byte[] scratch, uint length) + { + var writer = Writer.Create(scratch, null); + writer.WriteVarUInt32(length); + return new ReadOnlyMemory(scratch, 0, writer.Position); + } +} diff --git a/src/Orleans.Journaling/Orleans.Journaling.csproj b/src/Orleans.Journaling/Orleans.Journaling.csproj new file mode 100644 index 00000000000..6f584544d18 --- /dev/null +++ b/src/Orleans.Journaling/Orleans.Journaling.csproj @@ -0,0 +1,30 @@ + + + + Microsoft.Orleans.Journaling + Microsoft Orleans Journaling + Extensible persistence for grains based on replicated state machines. + $(PackageTags) Persistence State Machines + true + $(DefaultTargetFrameworks) + enable + enable + $(VersionSuffix).alpha.1 + alpha.1 + + + + + + + + + + + + + + + + + diff --git a/src/Orleans.Journaling/Properties/AssemblyInfo.cs b/src/Orleans.Journaling/Properties/AssemblyInfo.cs new file mode 100644 index 00000000000..9d00e67f920 --- /dev/null +++ b/src/Orleans.Journaling/Properties/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Diagnostics.CodeAnalysis; + +[assembly: Experimental("ORLEANSEXP005")] diff --git a/src/Orleans.Journaling/StateMachineId.cs b/src/Orleans.Journaling/StateMachineId.cs new file mode 100644 index 00000000000..2f7674e7f41 --- /dev/null +++ b/src/Orleans.Journaling/StateMachineId.cs @@ -0,0 +1,7 @@ +namespace Orleans.Journaling; + +/// +/// Identifies a state machine. +/// +/// The underlying identity value. 
+public readonly record struct StateMachineId(ulong Value); diff --git a/src/Orleans.Journaling/StateMachineManager.cs b/src/Orleans.Journaling/StateMachineManager.cs new file mode 100644 index 00000000000..f45527075c1 --- /dev/null +++ b/src/Orleans.Journaling/StateMachineManager.cs @@ -0,0 +1,452 @@ +using System.Buffers; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.Logging; +using Orleans.Runtime.Internal; +using Orleans.Serialization.Codecs; +using Orleans.Serialization.Session; + +namespace Orleans.Journaling; + +internal sealed partial class StateMachineManager : IStateMachineManager, ILifecycleParticipant, ILifecycleObserver, IDisposable +{ + private const int MinApplicationStateMachineId = 8; + private static readonly StringCodec StringCodec = new(); + private static readonly UInt64Codec UInt64Codec = new(); + private readonly object _lock = new(); + private readonly Dictionary _stateMachines = new(StringComparer.Ordinal); + private readonly Dictionary _stateMachinesMap = []; + private readonly IStateMachineStorage _storage; + private readonly ILogger _logger; + private readonly SingleWaiterAutoResetEvent _workSignal = new() { RunContinuationsAsynchronously = true }; + private readonly Queue _workQueue = new(); + private readonly CancellationTokenSource _shutdownCancellation = new(); + private readonly StateMachineManagerState _stateMachineIds; + private readonly Task _workLoop; + private ManagerState _state; + private Task? _pendingWrite; + private ulong _nextStateMachineId = MinApplicationStateMachineId; + private LogExtentBuilder? _currentLogSegment; + + public StateMachineManager( + IStateMachineStorage storage, + ILogger logger, + SerializerSessionPool serializerSessionPool) + { + _storage = storage; + _logger = logger; + + // The list of known state machines is itself stored as a durable state machine with the implicit id 0. + // This allows us to recover the list of state machines ids without having to store it separately. 
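+ // During recovery, id 0's entries are applied first, repopulating the name-to-id map;
+ // each recovered mapping re-attaches the corresponding registered state machine via
+ // OnSetStateMachineId before that machine's own log entries are replayed.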
+ _stateMachineIds = new StateMachineManagerState(this, StringCodec, UInt64Codec, serializerSessionPool); + _stateMachinesMap[0] = _stateMachineIds; + + _workLoop = Start(); + } + + public void RegisterStateMachine(string name, IDurableStateMachine stateMachine) + { + _shutdownCancellation.Token.ThrowIfCancellationRequested(); + ArgumentNullException.ThrowIfNullOrEmpty(name); + + lock (_lock) + { + _stateMachines.Add(name, stateMachine); + _workQueue.Enqueue(new WorkItem(WorkItemType.RegisterStateMachine, completion: null) + { + Context = name + }); + } + + _workSignal.Signal(); + } + + public async ValueTask InitializeAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + _shutdownCancellation.Token.ThrowIfCancellationRequested(); + + Task task; + lock (_lock) + { + var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + task = completion.Task; + _workQueue.Enqueue(new WorkItem(WorkItemType.Initialize, completion)); + } + + _workSignal.Signal(); + await task; + } + + private Task Start() + { + using var suppressExecutionContext = new ExecutionContextSuppressor(); + return WorkLoop(); + } + + private async Task WorkLoop() + { + var cancellationToken = _shutdownCancellation.Token; + using var cancellationRegistration = cancellationToken.Register(state => ((StateMachineManager)state!)._workSignal.Signal(), this); + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext | ConfigureAwaitOptions.ForceYielding); + var needsRecovery = true; + while (true) + { + try + { + await _workSignal.WaitAsync().ConfigureAwait(false); + cancellationToken.ThrowIfCancellationRequested(); + + while (true) + { + if (needsRecovery) + { + await RecoverAsync(cancellationToken).ConfigureAwait(false); + needsRecovery = false; + } + + WorkItem workItem; + lock (_lock) + { + if (!_workQueue.TryDequeue(out workItem)) + { + // Wait for the queue to be signaled again. + break; + } + } + + try + { + // Note that the implementation of each command is inlined to avoid allocating unnecessary async state machines. + // We are ok sacrificing some code organization for performance in the inner loop. + if (workItem.Type is WorkItemType.AppendLog or WorkItemType.WriteSnapshot) + { + // TODO: decide whether it's best to snapshot or append. Eg, by summing the size of the most recent snapshots and the current log length. + // If the current log length is greater than the snapshot size, then take a snapshot instead of appending more log entries. + var isSnapshot = workItem.Type is WorkItemType.WriteSnapshot; + LogExtentBuilder? logSegment; + lock (_lock) + { + if (isSnapshot && _currentLogSegment is { } existingSegment) + { + // If there are pending writes, reset them since they will be captured by the snapshot instead. + // If we did not do this, the log would begin with some writes which would be followed by a snapshot which also included those writes. + existingSegment.Reset(); + } + else + { + _currentLogSegment ??= new(); + } + + // The map of state machine ids is itself stored as a durable state machine with the id 0. + // This must be stored first, since it includes the identities of all other state machines, which are needed when replaying the log. + AppendUpdatesOrSnapshotStateMachine(_currentLogSegment, isSnapshot, 0, _stateMachineIds); + + foreach (var (id, stateMachine) in _stateMachinesMap) + { + if (id is 0 || stateMachine is null) + { + // Skip state machines which have been removed. 
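+ // (Id 0 itself was already appended above, ahead of all other state machines.)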
+ continue; + } + + AppendUpdatesOrSnapshotStateMachine(_currentLogSegment, isSnapshot, id, stateMachine); + } + + if (_currentLogSegment.IsEmpty) + { + logSegment = null; + } + else + { + logSegment = _currentLogSegment; + _currentLogSegment = null; + } + } + + if (logSegment is not null) + { + if (isSnapshot) + { + await _storage.ReplaceAsync(logSegment, cancellationToken).ConfigureAwait(false); + } + else + { + await _storage.AppendAsync(logSegment, cancellationToken).ConfigureAwait(false); + } + + // Notify all state machines that the operation completed. + lock (_lock) + { + foreach (var stateMachine in _stateMachines.Values) + { + stateMachine.OnWriteCompleted(); + } + } + } + } + else if (workItem.Type is WorkItemType.DeleteState) + { + // Clear storage. + await _storage.DeleteAsync(cancellationToken).ConfigureAwait(false); + + lock (_lock) + { + // Reset the state machine id collection. + _stateMachineIds.ResetVolatileState(); + + // Allocate new state machine ids for each state machine. + // Doing so will trigger a reset, since _stateMachineIds will call OnSetStateMachineId, which resets the state machine in question. + _nextStateMachineId = 1; + foreach (var (name, stateMachine) in _stateMachines) + { + var id = _nextStateMachineId++; + _stateMachineIds[name] = id; + } + } + } + else if (workItem.Type is WorkItemType.Initialize) + { + lock (_lock) + { + _state = ManagerState.Ready; + } + } + else if (workItem.Type is WorkItemType.RegisterStateMachine) + { + lock (_lock) + { + if (_state is not ManagerState.Unknown) + { + throw new NotSupportedException("Registering a state machine after activation is not supported."); + } + + var name = (string)workItem.Context!; + if (!_stateMachineIds.ContainsKey(name)) + { + // Doing so will trigger a reset, since _stateMachineIds will call OnSetStateMachineId, which resets the state machine in question. 
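+ // Names recovered from the log already have an id and were re-bound during
+ // recovery, so a fresh id is allocated only for names seen for the first time.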
+ _stateMachineIds[name] = _nextStateMachineId++; + } + } + } + else + { + Debug.Fail($"The command {workItem.Type} is unsupported"); + } + + workItem.CompletionSource?.SetResult(); + } + catch (Exception exception) + { + workItem.CompletionSource?.SetException(exception); + needsRecovery = true; + } + } + } + catch (Exception exception) + { + needsRecovery = true; + if (cancellationToken.IsCancellationRequested) + { + return; + } + + LogErrorProcessingWorkItems(_logger, exception); + } + } + } + + private static void AppendUpdatesOrSnapshotStateMachine(LogExtentBuilder logSegment, bool isSnapshot, ulong id, IDurableStateMachine stateMachine) + { + var writer = logSegment.CreateLogWriter(new(id)); + if (isSnapshot) + { + stateMachine.AppendSnapshot(writer); + } + else + { + stateMachine.AppendEntries(writer); + } + } + + public async ValueTask DeleteStateAsync(CancellationToken cancellationToken) + { + Task task; + lock (_lock) + { + var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + task = completion.Task; + _workQueue.Enqueue(new WorkItem(WorkItemType.DeleteState, completion)); + } + + _workSignal.Signal(); + await task; + } + + private async Task RecoverAsync(CancellationToken cancellationToken) + { + _stateMachineIds.ResetVolatileState(); + await foreach (var segment in _storage.ReadAsync(cancellationToken)) + { + cancellationToken.ThrowIfCancellationRequested(); + try + { + foreach (var entry in segment.Entries) + { + var stateMachine = _stateMachinesMap[entry.StreamId.Value]; + stateMachine.Apply(entry.Payload); + } + } + finally + { + segment.Dispose(); + } + } + + foreach (var stateMachine in _stateMachines.Values) + { + stateMachine.OnRecoveryCompleted(); + } + } + + public async ValueTask WriteStateAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + Task? pendingWrite; + var didEnqueue = false; + lock (_lock) + { + // If the pending write is faulted, recovery will need to be performed. + // For now, await it so that we can propagate the exception consistently. + if (_pendingWrite is not { IsFaulted: true }) + { + var completion = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + _pendingWrite = completion.Task; + var workItemType = _storage.IsCompactionRequested switch + { + true => WorkItemType.WriteSnapshot, + false => WorkItemType.AppendLog, + }; + + _workQueue.Enqueue(new WorkItem(workItemType, completion)); + didEnqueue = true; + } + + pendingWrite = _pendingWrite; + } + + if (didEnqueue) + { + _workSignal.Signal(); + } + + if (pendingWrite is { } task) + { + await task.WaitAsync(cancellationToken); + } + } + + private void OnSetStateMachineId(string name, ulong id) + { + lock (_lock) + { + if (id >= _nextStateMachineId) + { + _nextStateMachineId = id + 1; + } + + if (_stateMachines.TryGetValue(name, out var stateMachine)) + { + _stateMachinesMap[id] = stateMachine; + stateMachine.Reset(new StateMachineLogWriter(this, new(id))); + } + else + { + throw new InvalidOperationException($"State machine \"{name}\" (id: {id}) has not been registered on this state machine manager."); + } + } + } + + public bool TryGetStateMachine(string name, [NotNullWhen(true)] out IDurableStateMachine? 
stateMachine) => _stateMachines.TryGetValue(name, out stateMachine); + + void ILifecycleParticipant.Participate(IGrainLifecycle observer) => observer.Subscribe(GrainLifecycleStage.SetupState, this); + Task ILifecycleObserver.OnStart(CancellationToken cancellationToken) => InitializeAsync(cancellationToken).AsTask(); + async Task ILifecycleObserver.OnStop(CancellationToken cancellationToken) + { + _shutdownCancellation.Cancel(); + _workSignal.Signal(); + await _workLoop.WaitAsync(cancellationToken).ConfigureAwait(ConfigureAwaitOptions.ContinueOnCapturedContext | ConfigureAwaitOptions.SuppressThrowing); + } + + void IDisposable.Dispose() + { + _shutdownCancellation.Dispose(); + } + + private sealed class StateMachineLogWriter(StateMachineManager manager, StateMachineId streamId) : IStateMachineLogWriter + { + private readonly StateMachineManager _manager = manager; + private readonly StateMachineId _id = streamId; + + public void AppendEntry(Action> action, TState state) + { + lock (_manager._lock) + { + var segment = _manager._currentLogSegment ??= new(); + var logWriter = segment.CreateLogWriter(_id); + logWriter.AppendEntry(action, state); + } + } + + public void AppendEntries(Action action, TState state) + { + lock (_manager._lock) + { + var segment = _manager._currentLogSegment ??= new(); + var logWriter = segment.CreateLogWriter(_id); + action(state, logWriter); + } + } + } + + private readonly struct WorkItem(StateMachineManager.WorkItemType type, TaskCompletionSource? completion) + { + public WorkItemType Type { get; } = type; + public TaskCompletionSource? CompletionSource { get; } = completion; + public object? Context { get; init; } + } + + private enum WorkItemType + { + Initialize, + AppendLog, + WriteSnapshot, + DeleteState, + RegisterStateMachine + } + + private enum ManagerState + { + Unknown, + Ready, + } + + private sealed class StateMachineManagerState( + StateMachineManager manager, + IFieldCodec keyCodec, + IFieldCodec valueCodec, + SerializerSessionPool serializerSessionPool) : DurableDictionary(keyCodec, valueCodec, serializerSessionPool) + { + private readonly StateMachineManager _manager = manager; + + public void ResetVolatileState() => ((IDurableStateMachine)this).Reset(new StateMachineLogWriter(_manager, new(0))); + + protected override void OnSet(string key, ulong value) => _manager.OnSetStateMachineId(key, value); + } + + [LoggerMessage( + Level = LogLevel.Error, + Message = "Error processing work items.")] + private static partial void LogErrorProcessingWorkItems(ILogger logger, Exception exception); +} diff --git a/src/Orleans.Journaling/StateMachineStorageWriter.cs b/src/Orleans.Journaling/StateMachineStorageWriter.cs new file mode 100644 index 00000000000..2abc44f933f --- /dev/null +++ b/src/Orleans.Journaling/StateMachineStorageWriter.cs @@ -0,0 +1,24 @@ +using System.Buffers; + +namespace Orleans.Journaling; + +public readonly struct StateMachineStorageWriter +{ + private readonly StateMachineId _id; + private readonly LogExtentBuilder _segment; + + internal StateMachineStorageWriter(StateMachineId id, LogExtentBuilder segment) + { + _id = id; + _segment = segment; + } + + public void AppendEntry(byte[] value) => _segment.AppendEntry(_id, value); + public void AppendEntry(Span value) => _segment.AppendEntry(_id, value); + public void AppendEntry(Memory value) => _segment.AppendEntry(_id, value); + public void AppendEntry(ReadOnlyMemory value) => _segment.AppendEntry(_id, value); + public void AppendEntry(ArraySegment value) => _segment.AppendEntry(_id, 
value); + public void AppendEntry(ReadOnlySpan value) => _segment.AppendEntry(_id, value); + public void AppendEntry(ReadOnlySequence value) => _segment.AppendEntry(_id, value); + public void AppendEntry(Action> valueWriter, T value) => _segment.AppendEntry(_id, valueWriter, value); +} diff --git a/src/Orleans.Journaling/VolatileStateMachineStorage.cs b/src/Orleans.Journaling/VolatileStateMachineStorage.cs new file mode 100644 index 00000000000..c04264e3b20 --- /dev/null +++ b/src/Orleans.Journaling/VolatileStateMachineStorage.cs @@ -0,0 +1,59 @@ +using Orleans.Serialization.Buffers; +using System.Collections.Concurrent; +using System.Runtime.CompilerServices; + +namespace Orleans.Journaling; + +public sealed class VolatileStateMachineStorageProvider : IStateMachineStorageProvider +{ + private readonly ConcurrentDictionary _storage = new(); + public IStateMachineStorage Create(IGrainContext grainContext) => _storage.GetOrAdd(grainContext.GrainId, _ => new VolatileStateMachineStorage()); +} + +/// +/// An in-memory, volatile implementation of for non-durable use cases, such as development and testing. +/// +public sealed class VolatileStateMachineStorage : IStateMachineStorage +{ + private readonly List _segments = []; + + public bool IsCompactionRequested => _segments.Count > 10; + + /// + public async IAsyncEnumerable ReadAsync([EnumeratorCancellation] CancellationToken cancellationToken) + { + await Task.CompletedTask; + using var buffer = new ArcBufferWriter(); + foreach (var segment in _segments) + { + cancellationToken.ThrowIfCancellationRequested(); + buffer.Write(segment); + yield return new LogExtent(buffer.ConsumeSlice(segment.Length)); + } + } + + /// + public ValueTask AppendAsync(LogExtentBuilder segment, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + _segments.Add(segment.ToArray()); + return default; + } + + /// + public ValueTask ReplaceAsync(LogExtentBuilder snapshot, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + _segments.Clear(); + _segments.Add(snapshot.ToArray()); + return default; + } + + public ValueTask DeleteAsync(CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + _segments.Clear(); + return default; + } +} + diff --git a/src/Orleans.Runtime/Properties/AssemblyInfo.cs b/src/Orleans.Runtime/Properties/AssemblyInfo.cs index 4c2864c6a28..abf4e11da60 100644 --- a/src/Orleans.Runtime/Properties/AssemblyInfo.cs +++ b/src/Orleans.Runtime/Properties/AssemblyInfo.cs @@ -2,6 +2,7 @@ [assembly: InternalsVisibleTo("Orleans.Streaming")] [assembly: InternalsVisibleTo("Orleans.Reminders")] +[assembly: InternalsVisibleTo("Orleans.Journaling")] [assembly: InternalsVisibleTo("Orleans.TestingHost")] [assembly: InternalsVisibleTo("AWSUtils.Tests")] diff --git a/src/Orleans.Serialization.TestKit/FieldCodecTester.cs b/src/Orleans.Serialization.TestKit/FieldCodecTester.cs index 0a9ec599ea1..c0ef28914d5 100644 --- a/src/Orleans.Serialization.TestKit/FieldCodecTester.cs +++ b/src/Orleans.Serialization.TestKit/FieldCodecTester.cs @@ -14,6 +14,7 @@ using Xunit; using Orleans.Serialization.Serializers; using Xunit.Abstractions; +using Orleans.Serialization.GeneratedCodeHelpers; namespace Orleans.Serialization.TestKit { @@ -886,6 +887,15 @@ private void CanBeSkipped(TValue original) Assert.Equal(writerSession.ReferencedObjects.CurrentReferenceId, readerSession.ReferencedObjects.CurrentReferenceId); } + { + using var readerSession = 
_sessionPool.GetSession(); + var reader = Reader.Create(readResult.Buffer, readerSession); + var readField = reader.ReadFieldHeader(); + reader.ConsumeUnknownField(readField); + Assert.Equal(expectedLength, reader.Position); + Assert.Equal(writerSession.ReferencedObjects.CurrentReferenceId, readerSession.ReferencedObjects.CurrentReferenceId); + } + pipe.Reader.AdvanceTo(readResult.Buffer.End); pipe.Reader.Complete(); } diff --git a/src/Orleans.Serialization/Buffers/ArcBufferWriter.cs b/src/Orleans.Serialization/Buffers/ArcBufferWriter.cs new file mode 100644 index 00000000000..3276ef51f11 --- /dev/null +++ b/src/Orleans.Serialization/Buffers/ArcBufferWriter.cs @@ -0,0 +1,1471 @@ +#nullable enable +using System; +using System.Buffers; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; +using System.Collections.Generic; +using System.Collections; + +#if NET6_0_OR_GREATER +using System.Numerics; +#else +using Orleans.Serialization.Utilities; +#endif + +namespace Orleans.Serialization.Buffers; + +/// +/// A implementation implemented using pooled arrays which is specialized for creating instances. +/// +[StructLayout(LayoutKind.Auto)] +[Immutable] +public sealed class ArcBufferWriter : IBufferWriter, IDisposable +{ + // The first page. This is the page which consumers will consume from. + // This may be equal to the current page, or it may be a previous page. + private ArcBufferPage _readPage; + + // The current page. This is the page which will be written to when the next write occurs. + private ArcBufferPage _writePage; + + // The current page. This is the last page which was allocated. In the linked list formed by the pages, _first <= _current <= _tail. + private ArcBufferPage _tail; + + // The offset into the first page which has been consumed already. When this reaches the end of the page, the page can be unpinned. + private int _readIndex; + + // The total length of the buffer. + private int _totalLength; + + // Indicates whether the writer has been disposed. + private bool _disposed; + + /// + /// Gets the minimum page size. + /// + public const int MinimumPageSize = ArcBufferPagePool.MinimumPageSize; + + /// + /// Initializes a new instance of the struct. + /// + public ArcBufferWriter() + { + _readPage = _writePage = _tail = ArcBufferPagePool.Shared.Rent(); + Debug.Assert(_readPage.ReferenceCount == 0); + _readPage.Pin(_readPage.Version); + } + + /// + /// Gets the number of unconsumed bytes. + /// + public int Length + { + get + { + ThrowIfDisposed(); + return _totalLength - _readIndex; + } + } + + /// + /// Adds additional buffers to the destination list until the list has reached its capacity. + /// + /// The destination to add buffers to. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void ReplenishBuffers(List> buffers) + { + ThrowIfDisposed(); + + // Skip half-full pages in an attempt to minimize the number of buffers added to the destination + // at the expense of under-utilized memory. This could be tweaked up to increase page utilization. + const int MinimumUsablePageSize = MinimumPageSize / 2; + + while (buffers.Count < buffers.Capacity) + { + // Only use the current page if it is greater than the minimum "usable" page size. + if (_tail.WriteCapacity > MinimumUsablePageSize) + { + buffers.Add(_tail.WritableArraySegment); + } + + // Allocate a new page. 
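+ // The fresh page becomes _tail and is picked up on the next iteration, since a
+ // newly rented page's write capacity always exceeds the usable threshold.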
+ AllocatePage(0); + } + } + + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + void IBufferWriter.Advance(int count) => AdvanceWriter(count); + + /// + /// Advances the writer by the specified number of bytes. + /// + /// The numbers of bytes to advance the writer by. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void AdvanceWriter(int count) + { + ThrowIfDisposed(); + +#if NET5_0_OR_GREATER + ArgumentOutOfRangeException.ThrowIfLessThan(count, 0); +#else + if (count < 0) throw new ArgumentOutOfRangeException(nameof(count), "Length must be greater than or equal to 0."); +#endif + _totalLength += count; + while (true) + { + var amount = Math.Min(_writePage.WriteCapacity, count); + _writePage.Advance(amount); + count -= amount; + + if (count == 0) + { + break; + } + + var next = _writePage.Next; + Debug.Assert(next is not null); + _writePage = next; + } + } + + /// + /// Resets this instance, returning all memory. + /// + public void Reset() + { + ThrowIfDisposed(); + + UnpinAll(); + _totalLength = _readIndex = 0; + _readPage = _writePage = _tail = ArcBufferPagePool.Shared.Rent(); + Debug.Assert(_readPage.ReferenceCount == 0); + _readPage.Pin(_readPage.Version); + } + + /// + public void Dispose() + { + if (_disposed) return; + + UnpinAll(); + _totalLength = _readIndex = 0; + _readPage = _writePage = _tail = null!; + _disposed = true; + } + + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Memory GetMemory(int sizeHint = 0) + { + ThrowIfDisposed(); + + if (sizeHint >= _writePage.WriteCapacity) + { + return GetMemorySlow(sizeHint); + } + + return _writePage.WritableMemory; + } + + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span GetSpan(int sizeHint = 0) + { + ThrowIfDisposed(); + + if (sizeHint >= _writePage.WriteCapacity) + { + return GetSpanSlow(sizeHint); + } + + return _writePage.WritableSpan; + } + + /// + /// Attempts to read the provided number of bytes from the buffer. + /// + /// The destination, which may be used to hold the requested data if the data needs to be copied. + /// A span of either zero length, if the data is unavailable, or at least the requested length if the data is available. + public ReadOnlySpan Peek(scoped in Span destination) + { + ThrowIfDisposed(); + + // Single span. + var firstSpan = _readPage.AsSpan(_readIndex, _readPage.Length - _readIndex); + if (firstSpan.Length >= destination.Length) + { + return firstSpan; + } + + // Multiple spans. Create a slice without pinning it, since we would be immediately unpinning it. + Peek(destination); + return destination; + } + + /// Copies the contents of this writer to a span. + /// This method does not advance the read cursor. + public int Peek(Span output) + { + ThrowIfDisposed(); + + var bytesCopied = 0; + var current = _readPage; + var offset = _readIndex; + while (output.Length > 0 && current != null) + { + var segment = current.AsSpan(offset, current.Length - offset); + var copyLength = Math.Min(segment.Length, output.Length); + bytesCopied += copyLength; + var slice = segment[..copyLength]; + slice.CopyTo(output); + output = output[slice.Length..]; + current = current.Next; + offset = 0; + } + + return bytesCopied; + } + + /// + /// Writes the provided sequence to this buffer. + /// + /// The data to write. 
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Write(ReadOnlySequence input) + { + ThrowIfDisposed(); + + foreach (var segment in input) + { + Write(segment.Span); + } + } + + /// + /// Writes the provided value to this buffer. + /// + /// The data to write. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void Write(ReadOnlySpan value) + { + ThrowIfDisposed(); + + var destination = GetSpan(); + + // Fast path, try copying to the available memory directly + if (value.Length <= destination.Length) + { + value.CopyTo(destination); + AdvanceWriter(value.Length); + } + else + { + WriteMultiSegment(value, destination); + } + } + + private void WriteMultiSegment(in ReadOnlySpan source, Span destination) + { + var input = source; + while (true) + { + var writeSize = Math.Min(destination.Length, input.Length); + input[..writeSize].CopyTo(destination); + AdvanceWriter(writeSize); + input = input[writeSize..]; + if (input.Length > 0) + { + destination = GetSpan(); + + continue; + } + + return; + } + } + + /// + /// Unpins all pages. + /// + private void UnpinAll() + { + var current = _readPage; + while (current != null) + { + var previous = current; + current = previous.Next; + previous.Unpin(previous.Version); + } + } + + /// + /// Returns a slice of the provided length without marking the data referred to it as consumed. + /// + /// The number of bytes to consume. + /// A slice of unconsumed data. + public ArcBuffer PeekSlice(int count) + { + ThrowIfDisposed(); + +#if NET6_0_OR_GREATER + ArgumentOutOfRangeException.ThrowIfLessThan(count, 0); + ArgumentOutOfRangeException.ThrowIfGreaterThan(count, Length); +#else + if (count < 0) throw new ArgumentOutOfRangeException(nameof(count), "Length must be greater than or equal to 0."); + if (count > Length) throw new ArgumentOutOfRangeException(nameof(count), "Length must be less than or equal to the unconsumed length of the buffer."); +#endif + Debug.Assert(count >= 0); + Debug.Assert(count <= Length); + + var result = new ArcBuffer(_readPage, token: _readPage.Version, offset: _readIndex, count); + result.Pin(); + return result; + } + + /// + /// Consumes a slice of the provided length. + /// + /// The number of bytes to consume. + /// A buffer representing the consumed data. + public ArcBuffer ConsumeSlice(int count) + { + ThrowIfDisposed(); + + var result = PeekSlice(count); + + // Advance the cursor so that subsequent slice calls will return the next slice. + AdvanceReader(count); + + return result; + } + + /// + /// Advances the reader by the specified number of bytes. + /// + /// The number of bytes to advance the reader. + public void AdvanceReader(int count) + { + ThrowIfDisposed(); + + Debug.Assert(count >= 0); + Debug.Assert(count <= Length); + + _readIndex += count; + + // If this call would consume the entire first page and the page is not the current write page, unpin it. + while (_readIndex >= _readPage.Length && _writePage != _readPage) + { + // Advance the consumed length. + var current = _readPage; + _readIndex -= current.Length; + _totalLength -= current.Length; + + // Advance to the next page + Debug.Assert(current.Next is not null); + _readPage = current.Next!; + + // Unpin the page. 
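// Illustrative sketch (not part of this change): the consume side of the API above.
// ConsumeSlice pins the pages backing the returned ArcBuffer, so the data stays valid
// after the writer moves on; disposing the slice unpins the pages so they can return
// to the pool.
using Orleans.Serialization.Buffers;

static byte[] DrainToArray(ArcBufferWriter writer)
{
    using ArcBuffer slice = writer.ConsumeSlice(writer.Length);
    return slice.ToArray(); // copy out before Dispose unpins the pages
}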
+ current.Unpin(current.Version); + } + + Debug.Assert(_readPage is not null); + Debug.Assert(_readIndex <= _readPage.Length); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private Span GetSpanSlow(int sizeHint) => AdvanceWritePage(sizeHint).Array; + + [MethodImpl(MethodImplOptions.NoInlining)] + private Memory GetMemorySlow(int sizeHint) => AdvanceWritePage(sizeHint).AsMemory(0); + + private ArcBufferPage AllocatePage(int sizeHint) + { + Debug.Assert(_tail.Next is null); + + var newBuffer = ArcBufferPagePool.Shared.Rent(sizeHint); + Debug.Assert(newBuffer.ReferenceCount == 0); + newBuffer.Pin(newBuffer.Version); + _tail.SetNext(newBuffer, _tail.Version); + return _tail = newBuffer; + } + + private ArcBufferPage AdvanceWritePage(int sizeHint) + { + var next = _writePage.Next; + if (next is null) + { + next = AllocatePage(sizeHint); + } + + _writePage = next; + return next; + } + + private void ThrowIfDisposed() + { + if (_disposed) + throw new ObjectDisposedException(nameof(ArcBufferWriter)); + } +} + +internal sealed class ArcBufferPagePool +{ + public static ArcBufferPagePool Shared { get; } = new(); + public const int MinimumPageSize = 16 * 1024; + private readonly ConcurrentQueue _pages = new(); + private readonly ConcurrentQueue _largePages = new(); + + private ArcBufferPagePool() { } + + public ArcBufferPage Rent(int size = -1) + { + ArcBufferPage? block; + if (size <= MinimumPageSize) + { + if (!_pages.TryDequeue(out block)) + { + block = new ArcBufferPage(size); + } + } + else if (_largePages.TryDequeue(out block)) + { + block.ResizeLargeSegment(size); + return block; + } + + return block ?? new ArcBufferPage(size); + } + + internal void Return(ArcBufferPage block) + { + Debug.Assert(block.IsValid); + if (block.IsMinimumSize) + { + _pages.Enqueue(block); + } + else + { + _largePages.Enqueue(block); + } + } +} + +/// +/// A page of data. +/// +public sealed class ArcBufferPage +{ + // The current version of the page. Each time the page is return to the pool, the version is incremented. + // This helps to ensure that the page is not consumed after it has been returned to the pool. + // This is a guard against certain programming bugs. + private int _version; + + // The current reference count. This is used to ensure that a page is not returned to the pool while it is still in use. + private int _refCount; + + internal ArcBufferPage() + { + Array = []; + } + + internal ArcBufferPage(int length) + { +#if !NET6_0_OR_GREATER + Array = null!; +#endif + InitializeArray(length); + } + + public void ResizeLargeSegment(int length) + { + Debug.Assert(length > ArcBufferPagePool.MinimumPageSize); + InitializeArray(length); + } + +#if NET6_0_OR_GREATER + [MemberNotNull(nameof(Array))] +#endif + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void InitializeArray(int length) + { + if (length <= ArcBufferPagePool.MinimumPageSize) + { + Debug.Assert(Array is null); +#if NET6_0_OR_GREATER + var array = GC.AllocateUninitializedArray(ArcBufferPagePool.MinimumPageSize, pinned: true); +#else + var array = new byte[ArcBufferPagePool.MinimumPageSize]; +#endif + Array = array; + } + else + { + // Round up to a power of two. + length = (int)BitOperations.RoundUpToPowerOf2((uint)length); + + if (Array is not null) + { + // The segment has an appropriate size already. + if (Array.Length == length) + { + return; + } + + // The segment is being resized. + ArrayPool.Shared.Return(Array); + } + + Array = ArrayPool.Shared.Rent(length); + } + } + + /// + /// Gets the array underpinning the page. 
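// Illustrative sketch (not part of this change): the sizing policy implemented by
// ArcBufferPagePool.Rent and InitializeArray above. Requests at or below the 16 KiB
// minimum are served from a bucket of fixed-size pinned pages; larger requests are
// rounded up to the next power of two and rented from ArrayPool<byte>. Hypothetical
// helper and numbers; BitOperations.RoundUpToPowerOf2 requires .NET 6+, matching the
// #if guard in this file.
using System.Numerics;

static class PageSizing
{
    public const int MinimumPageSize = 16 * 1024;

    public static int EffectiveSize(int requested)
        => requested <= MinimumPageSize
            ? MinimumPageSize
            : (int)BitOperations.RoundUpToPowerOf2((uint)requested);
}

// PageSizing.EffectiveSize(1)      == 16_384
// PageSizing.EffectiveSize(16_384) == 16_384
// PageSizing.EffectiveSize(20_000) == 32_768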
+ /// + public byte[] Array { get; private set; } + + /// + /// Gets the number of bytes which have been written to the page. + /// + public int Length { get; private set; } + + /// + /// A containing the readable bytes from this page. + /// + public ReadOnlySpan ReadableSpan => Array.AsSpan(0, Length); + + /// + /// A containing the readable bytes from this page. + /// + public ReadOnlyMemory ReadableMemory => AsMemory(0, Length); + + /// + /// An containing the readable bytes from this page. + /// + public ArraySegment ReadableArraySegment => new(Array, 0, Length); + + /// + /// An containing the writable bytes from this page. + /// + public ArraySegment WritableArraySegment => new(Array, Length, Array.Length - Length); + + /// + /// Gets the next node. + /// + public ArcBufferPage? Next { get; protected set; } + + /// + /// Gets the current page version. + /// + public int Version => _version; + + /// + /// Gets a value indicating whether this page is valid. + /// + public bool IsValid => Array is { Length: > 0 }; + + /// + /// Gets a value indicating whether this page is equal to the minimum page size. + /// + public bool IsMinimumSize => Array.Length == ArcBufferPagePool.MinimumPageSize; + + /// + /// Gets the number of bytes in the page which are available for writing. + /// + public int WriteCapacity => Array.Length - Length; + + /// + /// Gets the writable memory in the page. + /// + public Memory WritableMemory => AsMemory(Length); + + /// + /// Gets a span representing the writable memory in the page. + /// + public Span WritableSpan => AsSpan(Length); + + /// + /// Gets the current page reference count. + /// + internal int ReferenceCount => _refCount; + + /// + /// Creates a new memory region over the portion of the target page beginning at a specified position. + /// + /// The offset into the array to return memory from. + /// The memory region. + public Memory AsMemory(int offset) + { +#if NET6_0_OR_GREATER + if (IsMinimumSize) + { + return MemoryMarshal.CreateFromPinnedArray(Array, offset, Array.Length - offset); + } +#endif + + return Array.AsMemory(offset); + } + + /// + /// Creates a new memory region over the portion of the target page beginning at a specified position with a specified length. + /// + /// The offset into the array that the memory region starts from. + /// The length of the memory region. + /// The memory region. + public Memory AsMemory(int offset, int length) + { +#if NET6_0_OR_GREATER + if (IsMinimumSize) + { + return MemoryMarshal.CreateFromPinnedArray(Array, offset, length); + } +#endif + + return Array.AsMemory(offset, length); + } + + /// + /// Returns an array segment pointing to the underlying array, starting from the provided offset, and having the provided length. + /// + /// The offset into the array that the array segment starts from. + /// The length of the array segment. + /// The array segment. + public ArraySegment AsArraySegment(int offset, int length) => new(Array, offset, length); + + /// + /// Gets a span pointing to the underlying array, starting from the provided offset. + /// + /// The offset. + /// The span. + public Span AsSpan(int offset) => Array.AsSpan(offset); + + /// + /// Gets a span pointing to the underlying array, starting from the provided offset. + /// + /// The offset. + /// The length. + /// The span. + public Span AsSpan(int offset, int length) => Array.AsSpan(offset, length); + + /// + /// Increases the number of bytes written to the page by the provided amount. 
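// Illustrative sketch (not part of this change): the page invariants implied by the
// properties above. Every byte of the backing array is either already written
// (ReadableSpan) or still available for writing (WritableSpan).
using System.Diagnostics;
using Orleans.Serialization.Buffers;

static void AssertPageInvariants(ArcBufferPage page)
{
    Debug.Assert(page.ReadableSpan.Length == page.Length);
    Debug.Assert(page.WriteCapacity == page.Array.Length - page.Length);
    Debug.Assert(page.ReadableSpan.Length + page.WritableSpan.Length == page.Array.Length);
}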
+ /// + /// The number of bytes to increase the length of this page by. + public void Advance(int bytes) + { + Debug.Assert(bytes >= 0, "Advance called with negative bytes"); + Length += bytes; + Debug.Assert(Length <= Array.Length); + } + + /// + /// Sets the next page in the sequence. + /// + /// The next page in the sequence. + /// The token, which must match the page's for this operation to be allowed. + public void SetNext(ArcBufferPage next, int token) + { + Debug.Assert(Next is null); + CheckValidity(token); + Debug.Assert(next is not null, "SetNext called with null next page"); + Debug.Assert(next != this, "SetNext called with self as next page"); + Next = next; + } + + /// + /// Pins this page to prevent it from being returned to the page pool. + /// + /// The token, which must match the page's for this operation to be allowed. + public void Pin(int token) + { + if (token != _version) + { + ThrowInvalidVersion(); + } + + Interlocked.Increment(ref _refCount); + } + + /// + /// Unpins this page, allowing it to be returned to the page pool. + /// + /// The token, which must match the page's for this operation to be allowed. + public void Unpin(int token) + { + CheckValidity(token); + if (Interlocked.Decrement(ref _refCount) == 0) + { + Return(); + } + } + + private void Return() + { + Debug.Assert(_refCount == 0); + Length = 0; + Next = default; + Interlocked.Increment(ref _version); + ArcBufferPagePool.Shared.Return(this); + } + + /// + /// Throws if the provided does not match the page's . + /// + /// The token, which must match the page's . + public void CheckValidity(int token) + { + if (token != _version) + { + ThrowInvalidVersion(); + } + + if (_refCount <= 0) + { + ThrowAccessViolation(); + } + } + + [DoesNotReturn] + private static void ThrowInvalidVersion() => throw new InvalidOperationException("An invalid token was provided when attempting to perform an operation on this page."); + + [DoesNotReturn] + private static void ThrowAccessViolation() => throw new InvalidOperationException("An attempt was made to access a page with an invalid reference count."); +} + +/// +/// Provides reader access to an . +/// +/// The writer. +public readonly struct ArcBufferReader(ArcBufferWriter writer) +{ + /// + /// Gets the number of unconsumed bytes. + /// + public int Length + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => writer.Length; + } + + /// + /// Attempts to read the provided number of bytes from the buffer. + /// + /// The destination, which may be used to hold the requested data if the data needs to be copied. + /// A span of either zero length, if the data is unavailable, or the requested length if the data is available. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ReadOnlySpan Peek(scoped in Span destination) => writer.Peek(in destination); + + /// + /// Returns a slice of the provided length without marking the data referred to it as consumed. + /// + /// The number of bytes to consume. + /// A slice of unconsumed data. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ArcBuffer PeekSlice(int count) => writer.PeekSlice(count); + + /// + /// Consumes a slice of the provided length. + /// + /// The number of bytes to consume. + /// A buffer representing the consumed data. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ArcBuffer ConsumeSlice(int count) => writer.ConsumeSlice(count); + + /// + /// Consumes the amount of data present in the span. 
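// Illustrative sketch (not part of this change): the version-token discipline used by
// Pin/Unpin above. A caller captures Version when it takes a dependency on the page;
// a stale token (captured before the page was recycled through the pool) makes the
// guard throw instead of silently reading reused memory.
using Orleans.Serialization.Buffers;

static void ReadWhilePinned(ArcBufferPage page)
{
    var token = page.Version;  // capture the page's current generation
    page.Pin(token);           // increment the reference count
    try
    {
        var data = page.ReadableSpan;
        // ... consume data ...
    }
    finally
    {
        page.Unpin(token);     // at zero references the page returns to the pool
    }
}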
+ /// + /// + public void Consume(Span output) + { + var count = writer.Peek(output); + if (count != output.Length) + { + throw new InvalidOperationException("Attempted to consume more data than is available."); + } + + writer.AdvanceReader(count); + } + + public void Skip(int count) + { + writer.AdvanceReader(count); + } +} + +/// +/// Represents a slice of a . +/// +/// +/// Initializes a new instance of the type. +/// +/// The first page in the sequence. +/// The token of the first page in the sequence. +/// The offset into the buffer at which this slice begins. +/// The length of this slice. +public struct ArcBuffer(ArcBufferPage first, int token, int offset, int length) : IDisposable +{ + /// + /// Gets the token of the first page pointed to by this slice. + /// + private int _firstPageToken = token; + + /// + /// Gets the first page. + /// + public readonly ArcBufferPage First = first; + + /// + /// Gets the offset into the first page at which this slice begins. + /// + public readonly int Offset = offset; + + /// + /// Gets the length of this sequence. + /// + public readonly int Length = length; + + /// Copies the contents of this writer to a span. + public readonly int CopyTo(Span output) + { + CheckValidity(); + if (output.Length < Length) + { + throw new ArgumentException("Destination span is not large enough to hold the buffer contents.", nameof(output)); + } + + var copied = 0; + foreach (var span in this) + { + var slice = span[..Math.Min(span.Length, output.Length)]; + slice.CopyTo(output); + output = output[slice.Length..]; + copied += slice.Length; + } + + return copied; + } + + /// Copies the contents of this writer to a pooled buffer. + public readonly void CopyTo(ArcBufferWriter output) + { + CheckValidity(); + foreach (var span in this) + { + output.Write(span); + } + } + + /// Copies the contents of this writer to a buffer writer. + public readonly void CopyTo(ref TBufferWriter output) where TBufferWriter : IBufferWriter + { + CheckValidity(); + foreach (var span in this) + { + Write(ref output, span); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void Write(ref TBufferWriter writer, ReadOnlySpan value) where TBufferWriter : IBufferWriter + { + var destination = writer.GetSpan(); + + // Fast path, try copying to the available memory directly + if (value.Length <= destination.Length) + { + value.CopyTo(destination); + writer.Advance(value.Length); + } + else + { + WriteMultiSegment(ref writer, value, destination); + } + } + + private static void WriteMultiSegment(ref TBufferWriter writer, in ReadOnlySpan source, Span destination) where TBufferWriter : IBufferWriter + { + var input = source; + while (true) + { + var writeSize = Math.Min(destination.Length, input.Length); + input[..writeSize].CopyTo(destination); + writer.Advance(writeSize); + input = input[writeSize..]; + if (input.Length > 0) + { + destination = writer.GetSpan(); + + if (destination.IsEmpty) + { + ThrowInsufficientSpaceException(); + } + + continue; + } + + return; + } + } + + [DoesNotReturn] + private static void ThrowInsufficientSpaceException() => throw new InvalidOperationException("Insufficient capacity in provided buffer"); + + /// + /// Returns a new which must not be accessed after disposing this instance. + /// + public readonly ReadOnlySequence AsReadOnlySequence() + { + var runningIndex = 0L; + ReadOnlySequenceSegment? first = null; + ReadOnlySequenceSegment? 
previous = null; + var endIndex = 0; + foreach (var memory in MemorySegments) + { + var current = new ReadOnlySequenceSegment(memory, runningIndex); + first ??= current; + endIndex = memory.Length; + runningIndex += endIndex; + previous?.SetNext(current); + previous = current; + } + + if (first is null) + { + return ReadOnlySequence.Empty; + } + + Debug.Assert(first is not null); + Debug.Assert(previous is not null); + if (previous == first) + { + return new ReadOnlySequence(first.Memory); + } + + return new ReadOnlySequence(first, 0, previous, endIndex); + } + + /// + /// Returns the data which has been written as an array. + /// + /// The data which has been written. + public readonly byte[] ToArray() + { + CheckValidity(); + var result = new byte[Length]; + CopyTo(result); + return result; + } + + /// + /// Throws if the buffer it no longer valid. + /// + private readonly void CheckValidity() => First.CheckValidity(_firstPageToken); + + public readonly ArcBuffer Slice(int offset) => Slice(offset, Length - offset); + + public readonly ArcBuffer Slice(int offset, int length) + { + var result = UnsafeSlice(offset, length); + result.Pin(); + return result; + } + + public readonly ArcBuffer UnsafeSlice(int offset, int length) + { +#if NET6_0_OR_GREATER + ArgumentOutOfRangeException.ThrowIfLessThan(length, 0); + ArgumentOutOfRangeException.ThrowIfLessThan(offset, 0); + ArgumentOutOfRangeException.ThrowIfGreaterThan(length + offset, Length); +#else + if (length < 0) throw new ArgumentOutOfRangeException(nameof(length), "Length must be greater than or equal to 0."); + if (offset < 0) throw new ArgumentOutOfRangeException(nameof(offset), "Offset must be greater than or equal to 0."); + if (length + offset > Length) throw new ArgumentOutOfRangeException($"{nameof(length)} + {nameof(offset)}", "Length plus offset must be less than or equal to the length of the buffer."); +#endif + + CheckValidity(); + Debug.Assert(offset >= 0); + Debug.Assert(length >= 0); + Debug.Assert(offset + length <= Length); + ArcBuffer result; + + // Navigate to the offset page & calculate the offset into the page. + if (Offset + offset < First.Length || length == 0) + { + // The slice starts within this page. + result = new ArcBuffer(First, token: _firstPageToken, Offset + offset, length); + } + else + { + // The slice starts within a subsequent page. + // Account for the first page, then navigate to the page which the offset falls in. + offset -= First.Length - Offset; + var page = First.Next; + Debug.Assert(page is not null); + + while (offset >= page.Length) + { + offset -= page.Length; + page = page.Next; + Debug.Assert(page is not null); + } + + result = new ArcBuffer(page, token: page.Version, offset, length); + } + + return result; + } + + /// + /// Pins this slice, preventing the referenced pages from being returned to the pool. + /// + public readonly void Pin() + { + CheckValidity(); + var pageEnumerator = Pages.GetEnumerator(); + if (pageEnumerator.MoveNext()) + { + var page = pageEnumerator.Current!; + page.Pin(_firstPageToken); + } + + while (pageEnumerator.MoveNext()) + { + var page = pageEnumerator.Current!; + page.Pin(page.Version); + } + } + + /// + /// Unpins this slice, allowing the referenced pages to be returned to the pool. 
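// Illustrative sketch (not part of this change): bridging a slice into System.Buffers.
// The ReadOnlySequence produced by AsReadOnlySequence aliases the slice's pages, so,
// as the remark above says, it must not outlive the slice that created it.
using System.Buffers;
using Orleans.Serialization.Buffers;

static long Checksum(ArcBufferWriter writer)
{
    using var slice = writer.PeekSlice(writer.Length); // non-consuming; pins the pages
    var reader = new SequenceReader<byte>(slice.AsReadOnlySequence());
    long sum = 0;
    while (reader.TryRead(out byte b)) sum += b;
    return sum;
}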
+ /// + public void Unpin() + { + var pageEnumerator = Pages.GetEnumerator(); + if (pageEnumerator.MoveNext()) + { + var page = pageEnumerator.Current!; + page.Unpin(_firstPageToken); + } + + while (pageEnumerator.MoveNext()) + { + var page = pageEnumerator.Current!; + page.Unpin(page.Version); + } + + _firstPageToken = -1; + } + + /// + public void Dispose() + { + if (_firstPageToken == -1) + { + // Already disposed. + return; + } + + Unpin(); + } + + /// + /// Returns an enumerator which can be used to enumerate the span segments referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly SpanEnumerator GetEnumerator() => new(this); + + /// + /// Returns an enumerator which can be used to enumerate the pages referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + internal readonly PageEnumerator Pages => new(this); + + /// + /// Returns an enumerator which can be used to enumerate the pages referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + internal readonly PageSegmentEnumerator PageSegments => new(this); + + /// + /// Returns an enumerator which can be used to enumerate the span segments referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly SpanEnumerator SpanSegments => new(this); + + /// + /// Returns an enumerator which can be used to enumerate the memory segments referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly MemoryEnumerator MemorySegments => new(this); + + /// + /// Returns an enumerator which can be used to enumerate the array segments referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly ArraySegmentEnumerator ArraySegments => new(this); + + /// + /// Defines a region of data within a page. + /// + public readonly struct PageSegment(ArcBufferPage page, int offset, int length) + { + /// + /// Gets the page which this segment refers to. + /// + public readonly ArcBufferPage Page = page; + + /// + /// Gets the offset into the page at which this region begins. + /// + public readonly int Offset = offset; + + /// + /// Gets the length of this region. + /// + public readonly int Length = length; + + /// + /// Gets a representation of this region. + /// + public readonly ReadOnlySpan Span => Page.AsSpan(Offset, Length); + + /// + /// Gets a representation of this region. + /// + public readonly ReadOnlyMemory Memory => Page.AsMemory(Offset, Length); + + /// + /// Gets an representation of this region. + /// + public readonly ArraySegment ArraySegment => Page.AsArraySegment(Offset, Length); + } + + /// + /// Enumerates over page segments in a . + /// + /// + /// Initializes a new instance of the type. + /// + /// The buffer to enumerate. + internal struct PageSegmentEnumerator(ArcBuffer slice) : IEnumerable, IEnumerator + { + internal readonly ArcBuffer Slice = slice; + private int _position; + private ArcBufferPage? _page = slice.Length > 0 ? slice.First : null; + + internal readonly ArcBufferPage First => Slice.First; + internal readonly int Offset => Slice.Offset; + internal readonly int Length => Slice.Length; + + /// + /// Gets this instance as an enumerator. + /// + /// This instance. + public readonly PageSegmentEnumerator GetEnumerator() => this; + + /// + /// Gets the element in the collection at the current position of the enumerator. 
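// Illustrative sketch (not part of this change): the enumerator family above offers
// one view per page region (span, memory, array segment, or raw page). Whichever view
// is chosen, the segment lengths always sum to the slice's Length.
using System.Diagnostics;
using Orleans.Serialization.Buffers;

static void AssertSegmentLengths(in ArcBuffer slice)
{
    var total = 0;
    foreach (var memory in slice.MemorySegments) total += memory.Length;
    Debug.Assert(total == slice.Length);
}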
+ /// + public PageSegment Current { get; private set; } + + /// + readonly object? IEnumerator.Current => Current; + + /// + /// Gets a value indicating whether enumeration has completed. + /// + public readonly bool IsCompleted => _page is null || _position == Length; + + /// + /// Advances the enumerator to the next element of the collection. + /// + /// if the enumerator was successfully advanced to the next element; if the enumerator has passed the end of the collection. + public bool MoveNext() + { + Debug.Assert(_position <= Length, "Enumerator position exceeds slice length"); + if (_page is null || _position == Length) + { + Current = default; + Debug.Assert(_position == Length, "Enumerator ended before reaching full length"); + return false; + } + + if (_page == First) + { + Debug.Assert(_position == 0); + Slice.CheckValidity(); + var offset = Offset; + var length = Math.Min(Length, _page.Length - offset); + Debug.Assert(length >= 0, "Calculated negative length for first segment"); + _position += length; + Current = new PageSegment(_page, offset, length); + _page = _page.Next; + return true; + } + + { + var length = Math.Min(Length - _position, _page.Length); + Debug.Assert(length >= 0, "Calculated negative length for subsequent segment"); + _position += length; + Current = new PageSegment(_page, 0, length); + _page = _page.Next; + return true; + } + } + + /// + readonly IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + readonly IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + void IEnumerator.Reset() + { + _position = 0; + _page = Slice.Length > 0 ? Slice.First : null; + Current = default; + } + + /// + readonly void IDisposable.Dispose() { } + } + + /// + /// Enumerates over pages in a . + /// + /// + /// Initializes a new instance of the type. + /// + /// The slice to enumerate. + internal struct PageEnumerator(ArcBuffer slice) : IEnumerable, IEnumerator + { + private PageSegmentEnumerator _enumerator = slice.PageSegments; + + /// + /// Gets this instance as an enumerator. + /// + /// This instance. + public readonly PageEnumerator GetEnumerator() => this; + + /// + /// Gets the element in the collection at the current position of the enumerator. + /// + public ArcBufferPage? Current { get; private set; } + + /// + readonly object? IEnumerator.Current => Current; + + /// + /// Advances the enumerator to the next element of the collection. + /// + /// if the enumerator was successfully advanced to the next element; if the enumerator has passed the end of the collection. + public bool MoveNext() + { + if (_enumerator.MoveNext()) + { + Current = _enumerator.Current.Page; + return true; + } + + Current = default; + return false; + } + + /// + readonly IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + readonly IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + void IEnumerator.Reset() + { + _enumerator = _enumerator.Slice.PageSegments; + Current = default; + } + + /// + readonly void IDisposable.Dispose() { } + } + + /// + /// Enumerates over spans of bytes in a . + /// + /// + /// Initializes a new instance of the type. + /// + /// The slice to enumerate. + public ref struct SpanEnumerator(ArcBuffer slice) + { + private PageSegmentEnumerator _enumerator = slice.PageSegments; + + /// + /// Gets this instance as an enumerator. + /// + /// This instance. + public readonly SpanEnumerator GetEnumerator() => this; + + /// + /// Gets the element in the collection at the current position of the enumerator. 
+ /// + public ReadOnlySpan Current { get; private set; } + + /// + /// Advances the enumerator to the next element of the collection. + /// + /// if the enumerator was successfully advanced to the next element; if the enumerator has passed the end of the collection. + public bool MoveNext() + { + if (_enumerator.MoveNext()) + { + Current = _enumerator.Current.Span; + return true; + } + + Current = default; + return false; + } + } + + /// + /// Enumerates over sequences of bytes in a . + /// + /// + /// Initializes a new instance of the type. + /// + /// The slice to enumerate. + public struct MemoryEnumerator(ArcBuffer slice) : IEnumerable>, IEnumerator> + { + private PageSegmentEnumerator _enumerator = slice.PageSegments; + + /// + /// Gets this instance as an enumerator. + /// + /// This instance. + public readonly MemoryEnumerator GetEnumerator() => this; + + /// + /// Gets the element in the collection at the current position of the enumerator. + /// + public ReadOnlyMemory Current { get; private set; } + + /// + readonly object? IEnumerator.Current => Current; + + /// + /// Advances the enumerator to the next element of the collection. + /// + /// if the enumerator was successfully advanced to the next element; if the enumerator has passed the end of the collection. + public bool MoveNext() + { + if (_enumerator.MoveNext()) + { + Current = _enumerator.Current.Memory; + return true; + } + + Current = default; + return false; + } + + /// + readonly IEnumerator> IEnumerable>.GetEnumerator() => GetEnumerator(); + + /// + readonly IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + void IEnumerator.Reset() + { + _enumerator = _enumerator.Slice.PageSegments; + Current = default; + } + + /// + readonly void IDisposable.Dispose() { } + } + + /// + /// Enumerates over array segments in a . + /// + /// + /// Initializes a new instance of the type. + /// + /// The slice to enumerate. + public struct ArraySegmentEnumerator(ArcBuffer slice) : IEnumerable>, IEnumerator> + { + private PageSegmentEnumerator _enumerator = slice.PageSegments; + + /// + /// Gets this instance as an enumerator. + /// + /// This instance. + public readonly ArraySegmentEnumerator GetEnumerator() => this; + + /// + /// Gets the element in the collection at the current position of the enumerator. + /// + public ArraySegment Current { get; private set; } + + /// + readonly object? IEnumerator.Current => Current; + + /// + /// Gets a value indicating whether enumeration has completed. + /// + public readonly bool IsCompleted => _enumerator.IsCompleted; + + /// + /// Advances the enumerator to the next element of the collection. + /// + /// if the enumerator was successfully advanced to the next element; if the enumerator has passed the end of the collection. 
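// Illustrative sketch (not part of this change): ArraySegments suits older APIs that
// require ArraySegment<byte>, such as scatter/gather socket sends. The list here is a
// hypothetical destination.
using System;
using System.Collections.Generic;
using Orleans.Serialization.Buffers;

static List<ArraySegment<byte>> CollectSegments(in ArcBuffer slice)
{
    var segments = new List<ArraySegment<byte>>();
    foreach (var segment in slice.ArraySegments) segments.Add(segment);
    return segments;
}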
+ public bool MoveNext() + { + if (_enumerator.MoveNext()) + { + Current = _enumerator.Current.ArraySegment; + return true; + } + + Current = default; + return false; + } + + /// + readonly IEnumerator> IEnumerable>.GetEnumerator() => GetEnumerator(); + + /// + readonly IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + void IEnumerator.Reset() + { + _enumerator = _enumerator.Slice.PageSegments; + Current = default; + } + + /// + readonly void IDisposable.Dispose() { } + } + + private sealed class ReadOnlySequenceSegment : ReadOnlySequenceSegment + { + public ReadOnlySequenceSegment(ReadOnlyMemory memory, long runningIndex) + { + Memory = memory; + RunningIndex = runningIndex; + } + + public void SetNext(ReadOnlySequenceSegment next) + { + Debug.Assert(Next is null); + Next = next; + } + } +} diff --git a/src/Orleans.Serialization/Buffers/PooledBuffer.cs b/src/Orleans.Serialization/Buffers/PooledBuffer.cs index 2dba8414373..c2f4ead8e31 100644 --- a/src/Orleans.Serialization/Buffers/PooledBuffer.cs +++ b/src/Orleans.Serialization/Buffers/PooledBuffer.cs @@ -274,7 +274,7 @@ public ReadOnlySequence AsReadOnlySequence() /// Returns an enumerator which can be used to enumerate the data referenced by this instance. /// /// An enumerator for the data contained in this instance. - public readonly BufferSlice.SpanEnumerator GetEnumerator() => new(Slice()); + public readonly MemoryEnumerator MemorySegments => new(this); /// /// Writes the provided sequence to this buffer. @@ -367,6 +367,191 @@ private void Commit() CurrentPosition = 0; } + /// + /// Enumerates over sequences of bytes in a . + /// + public struct MemoryEnumerator + { + private static readonly SequenceSegment InitialSegmentSentinel = new(); + private static readonly SequenceSegment FinalSegmentSentinel = new(); + private readonly PooledBuffer _buffer; + private int _position; + private SequenceSegment _segment; + + /// + /// Initializes a new instance of the type. + /// + /// The buffer to enumerate. + public MemoryEnumerator(PooledBuffer buffer) + { + _buffer = buffer; + _segment = InitialSegmentSentinel; + CurrentMemory = ReadOnlyMemory.Empty; + } + + /// + /// Returns an enumerator which can be used to enumerate the data referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly MemoryEnumerator GetEnumerator() => this; + + /// + /// Gets the element in the collection at the current position of the enumerator. + /// + public readonly ReadOnlyMemory Current => CurrentMemory; + public ReadOnlyMemory CurrentMemory; + + /// + /// Advances the enumerator to the next element of the collection. + /// + /// if the enumerator was successfully advanced to the next element; if the enumerator has passed the end of the collection. + public bool MoveNext() + { + if (ReferenceEquals(_segment, InitialSegmentSentinel)) + { + _segment = _buffer.First; + } + + var endPosition = _buffer.Length; + while (_segment != null && _segment != FinalSegmentSentinel) + { + var segment = _segment.CommittedMemory; + + // Find the starting segment and the offset to copy from. + var segmentLength = segment.Length; + if (segmentLength == 0) + { + CurrentMemory = ReadOnlyMemory.Empty; + _segment = FinalSegmentSentinel; + return false; + } + + CurrentMemory = segment[..segmentLength]; + _position += segmentLength; + _segment = _segment.Next as SequenceSegment; + return true; + } + + // Account for the uncommitted data at the end of the buffer. 
+ // The write head is only linked to the previous buffers when Commit() is called and it is set to null afterwards, + // meaning that if the write head is not null, the other buffers are not linked to it and it therefore has not been enumerated. + if (_segment != FinalSegmentSentinel && _buffer.CurrentPosition > 0 && _buffer.WriteHead is { } head) + { + var finalLength = _buffer.CurrentPosition; + if (finalLength == 0) + { + CurrentMemory = ReadOnlyMemory.Empty; + _segment = FinalSegmentSentinel; + return false; + } + + CurrentMemory = head.Array.AsMemory(0, finalLength); + _position += finalLength; + Debug.Assert(_position == _buffer.Length); + _segment = FinalSegmentSentinel; + return true; + } + + return false; + } + } + + /// + /// Enumerates over sequences of bytes in a . + /// + public ref struct SpanEnumerator + { + private static readonly SequenceSegment InitialSegmentSentinel = new(); + private static readonly SequenceSegment FinalSegmentSentinel = new(); + private readonly +#if NET8_0_OR_GREATER + ref readonly +#endif + PooledBuffer _buffer; + private int _position; + private SequenceSegment _segment; + + /// + /// Initializes a new instance of the type. + /// + /// The buffer to enumerate. + public SpanEnumerator(ref PooledBuffer buffer) + { + _buffer = +#if NET8_0_OR_GREATER + ref +#endif + buffer; + _segment = InitialSegmentSentinel; + Current = ReadOnlySpan.Empty; + } + + /// + /// Returns an enumerator which can be used to enumerate the data referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly SpanEnumerator GetEnumerator() => this; + + /// + /// Gets the element in the collection at the current position of the enumerator. + /// + public ReadOnlySpan Current { get; private set; } + + /// + /// Advances the enumerator to the next element of the collection. + /// + /// if the enumerator was successfully advanced to the next element; if the enumerator has passed the end of the collection. + public bool MoveNext() + { + if (ReferenceEquals(_segment, InitialSegmentSentinel)) + { + _segment = _buffer.First; + } + + var endPosition = _buffer.Length; + while (_segment != null && _segment != FinalSegmentSentinel) + { + var segment = _segment.CommittedMemory; + + // Find the starting segment and the offset to copy from. + var segmentLength = Math.Min(segment.Length, endPosition - _position); + if (segmentLength == 0) + { + Current = ReadOnlySpan.Empty; + _segment = FinalSegmentSentinel; + return false; + } + + Current = segment.Span[..segmentLength]; + _position += segmentLength; + _segment = _segment.Next as SequenceSegment; + return true; + } + + // Account for the uncommitted data at the end of the buffer. + // The write head is only linked to the previous buffers when Commit() is called and it is set to null afterwards, + // meaning that if the write head is not null, the other buffers are not linked to it and it therefore has not been enumerated. + if (_segment != FinalSegmentSentinel && _buffer.CurrentPosition > 0 && _buffer.WriteHead is { } head && _position < endPosition) + { + var finalLength = Math.Min(_buffer.CurrentPosition, endPosition - _position); + if (finalLength == 0) + { + Current = ReadOnlySpan.Empty; + _segment = FinalSegmentSentinel; + return false; + } + + Current = head.Array.AsSpan(0, finalLength); + _position += finalLength; + Debug.Assert(_position == endPosition); + _segment = FinalSegmentSentinel; + return true; + } + + return false; + } + } + /// /// Represents a slice of a . 
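// Illustrative sketch (not part of this change): consuming a PooledBuffer through the
// two new enumerators. MemorySegments enumerates over a copy of the struct, while the
// ref-struct SpanEnumerator (reached via the GetEnumerator extension added at the end
// of this file) takes the buffer by reference, so both can observe the uncommitted
// write head described in the comments above.
using System;
using Orleans.Serialization.Buffers;

static (long TotalBytes, int Segments) Inspect(ref PooledBuffer buffer)
{
    long totalBytes = 0;
    foreach (ReadOnlyMemory<byte> memory in buffer.MemorySegments)
        totalBytes += memory.Length;

    var segments = 0;
    foreach (ReadOnlySpan<byte> span in buffer.GetEnumerator())
        segments++;

    return (totalBytes, segments); // totalBytes equals buffer.Length
}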
/// @@ -471,6 +656,12 @@ public readonly byte[] ToArray() /// An enumerator for the data contained in this instance. public readonly SpanEnumerator GetEnumerator() => new(this); + /// + /// Returns an enumerator which can be used to enumerate the data referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly MemoryEnumerator MemorySegments => new(this); + /// /// Enumerates over spans of bytes in a . /// @@ -497,6 +688,12 @@ public SpanEnumerator(BufferSlice slice) internal readonly int Offset => _slice._offset; internal readonly int Length => _slice._length; + /// + /// Returns an enumerator which can be used to enumerate the data referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly SpanEnumerator GetEnumerator() => this; + /// /// Gets the element in the collection at the current position of the enumerator. /// @@ -578,6 +775,120 @@ public bool MoveNext() return false; } } + + /// + /// Enumerates over sequences of bytes in a . + /// + public struct MemoryEnumerator + { + private static readonly SequenceSegment InitialSegmentSentinel = new(); + private static readonly SequenceSegment FinalSegmentSentinel = new(); + private readonly BufferSlice _slice; + private int _position; + private SequenceSegment _segment; + + /// + /// Initializes a new instance of the type. + /// + /// The slice to enumerate. + public MemoryEnumerator(BufferSlice slice) + { + _slice = slice; + _segment = InitialSegmentSentinel; + Current = ReadOnlyMemory.Empty; + } + + internal readonly PooledBuffer Buffer => _slice._buffer; + internal readonly int Offset => _slice._offset; + internal readonly int Length => _slice._length; + + /// + /// Returns an enumerator which can be used to enumerate the data referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public readonly MemoryEnumerator GetEnumerator() => this; + + /// + /// Gets the element in the collection at the current position of the enumerator. + /// + public ReadOnlyMemory Current { get; private set; } + + /// + /// Advances the enumerator to the next element of the collection. + /// + /// if the enumerator was successfully advanced to the next element; if the enumerator has passed the end of the collection. + public bool MoveNext() + { + if (ReferenceEquals(_segment, InitialSegmentSentinel)) + { + _segment = Buffer.First; + } + + var endPosition = Offset + Length; + while (_segment != null && _segment != FinalSegmentSentinel) + { + var segment = _segment.CommittedMemory; + + // Find the starting segment and the offset to copy from. + int segmentOffset; + if (_position < Offset) + { + if (_position + segment.Length <= Offset) + { + // Start is in a subsequent segment + _position += segment.Length; + _segment = _segment.Next as SequenceSegment; + continue; + } + else + { + // Start is in this segment + segmentOffset = Offset; + } + } + else + { + segmentOffset = 0; + } + + var segmentLength = Math.Min(segment.Length - segmentOffset, endPosition - (_position + segmentOffset)); + if (segmentLength == 0) + { + Current = ReadOnlyMemory.Empty; + _segment = FinalSegmentSentinel; + return false; + } + + Current = segment.Slice(segmentOffset, segmentLength); + _position += segmentOffset + segmentLength; + _segment = _segment.Next as SequenceSegment; + return true; + } + + // Account for the uncommitted data at the end of the buffer. 
+ // The write head is only linked to the previous buffers when Commit() is called and it is set to null afterwards, + // meaning that if the write head is not null, the other buffers are not linked to it and it therefore has not been enumerated. + if (_segment != FinalSegmentSentinel && Buffer.CurrentPosition > 0 && Buffer.WriteHead is { } head && _position < endPosition) + { + var finalOffset = Math.Max(Offset - _position, 0); + var finalLength = Math.Min(Buffer.CurrentPosition, endPosition - (_position + finalOffset)); + if (finalLength == 0) + { + Current = ReadOnlyMemory.Empty; + _segment = FinalSegmentSentinel; + return false; + } + + Current = head.Array.AsMemory(finalOffset, finalLength); + _position += finalOffset + finalLength; + Debug.Assert(_position == endPosition); + _segment = FinalSegmentSentinel; + return true; + } + + return false; + } + } } private sealed class SequenceSegmentPool @@ -726,3 +1037,15 @@ public void Return() } } } + +/// +/// Extensions for . +/// +public static class PooledBufferExtensions +{ + /// + /// Returns an enumerator which can be used to enumerate the data referenced by this instance. + /// + /// An enumerator for the data contained in this instance. + public static PooledBuffer.SpanEnumerator GetEnumerator(this ref PooledBuffer buffer) => new(ref buffer); +} diff --git a/src/Orleans.Serialization/Cloning/IDeepCopier.cs b/src/Orleans.Serialization/Cloning/IDeepCopier.cs index 73c3655f62e..18f5fa32b36 100644 --- a/src/Orleans.Serialization/Cloning/IDeepCopier.cs +++ b/src/Orleans.Serialization/Cloning/IDeepCopier.cs @@ -1,3 +1,4 @@ +#nullable enable using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -31,7 +32,7 @@ public interface IDeepCopierProvider /// /// The type supported by the copier. /// A deep copier capable of copying instances of type , or if an appropriate copier was not found. - IDeepCopier TryGetDeepCopier(); + IDeepCopier? TryGetDeepCopier(); /// /// Gets a deep copier capable of copying instances of type . @@ -49,7 +50,7 @@ public interface IDeepCopierProvider /// The type supported by the returned copier. /// /// A deep copier capable of copying instances of type , or if an appropriate copier was not found. - IDeepCopier TryGetDeepCopier(Type type); + IDeepCopier? TryGetDeepCopier(Type type); /// /// Gets a base type copier capable of copying instances of type . @@ -69,7 +70,7 @@ public interface IDeepCopier /// /// Creates a deep copy of the provided untyped input. The type must still match the copier instance! /// - object DeepCopy(object input, CopyContext context); + object? DeepCopy(object? input, CopyContext context); } /// @@ -88,7 +89,7 @@ internal sealed class ShallowCopier : IOptionalDeepCopier public static readonly ShallowCopier Instance = new(); public bool IsShallowCopyable() => true; - public object DeepCopy(object input, CopyContext _) => input; + public object? DeepCopy(object? input, CopyContext _) => input; } /// @@ -102,7 +103,7 @@ public class ShallowCopier : IOptionalDeepCopier, IDeepCopier public T DeepCopy(T input, CopyContext _) => input; /// Returns the input value. - public object DeepCopy(object input, CopyContext _) => input; + public object? DeepCopy(object? input, CopyContext _) => input; } /// @@ -120,7 +121,7 @@ public interface IDeepCopier : IDeepCopier /// A copy of . T DeepCopy(T input, CopyContext context); - object IDeepCopier.DeepCopy(object input, CopyContext context) => DeepCopy((T)input, context); + object? IDeepCopier.DeepCopy(object? 
input, CopyContext context) => input is null ? null : DeepCopy((T)input, context); } /// @@ -190,22 +191,16 @@ public interface ISpecializableCopier /// Provides context for a copy operation. /// /// - public sealed class CopyContext : IDisposable + /// + /// Initializes a new instance of the class. + /// + /// The codec provider. + /// The action to call when this context is disposed. + public sealed class CopyContext(CodecProvider codecProvider, Action onDisposed) : IDisposable { private readonly Dictionary _copies = new(ReferenceEqualsComparer.Default); - private readonly CodecProvider _copierProvider; - private readonly Action _onDisposed; - - /// - /// Initializes a new instance of the class. - /// - /// The codec provider. - /// The action to call when this context is disposed. - public CopyContext(CodecProvider codecProvider, Action onDisposed) - { - _copierProvider = codecProvider; - _onDisposed = onDisposed; - } + private readonly CodecProvider _copierProvider = codecProvider; + private readonly Action _onDisposed = onDisposed; /// /// Returns the previously recorded copy of the provided object, if it exists. @@ -214,7 +209,7 @@ public CopyContext(CodecProvider codecProvider, Action onDisposed) /// The original object. /// The previously recorded copy of . /// if a copy of has been recorded, otherwise. - public bool TryGetCopy(object original, [NotNullWhen(true)] out T result) where T : class + public bool TryGetCopy(object? original, out T? result) where T : class { if (original is null) { @@ -253,15 +248,15 @@ public void RecordCopy(object original, object copy) /// The value type. /// The value. /// A copy of the provided value. - public T DeepCopy(T value) + public T? DeepCopy(T? value) { if (!typeof(T).IsValueType) { if (value is null) return default; } - var copier = _copierProvider.GetDeepCopier(value.GetType()); - return (T)copier.DeepCopy(value, this); + var copier = _copierProvider.GetDeepCopier(value!.GetType()); + return (T?)copier.DeepCopy(value, this); } /// @@ -357,15 +352,13 @@ private static bool IsShallowCopyableInternal(Type type) /// /// Converts an untyped copier into a strongly-typed copier. /// - internal sealed class UntypedCopierWrapper : IDeepCopier + internal sealed class UntypedCopierWrapper(IDeepCopier copier) : IDeepCopier { - private readonly IDeepCopier _copier; - - public UntypedCopierWrapper(IDeepCopier copier) => _copier = copier; + private readonly IDeepCopier _copier = copier; - public T DeepCopy(T original, CopyContext context) => (T)_copier.DeepCopy(original, context); + public T DeepCopy(T original, CopyContext context) => (T)_copier.DeepCopy(original, context)!; - public object DeepCopy(object original, CopyContext context) => _copier.DeepCopy(original, context); + public object? DeepCopy(object? original, CopyContext context) => _copier.DeepCopy(original, context); } /// @@ -397,16 +390,10 @@ public CopyContextPool(CodecProvider codecProvider) /// The context. 
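// Illustrative sketch (not part of this change): a hand-written copier against the
// newly annotated API above. The null-propagating default implementation of the
// untyped DeepCopy means a typed copier never sees a null input on the untyped path.
// Node and NodeCopier are hypothetical types.
#nullable enable
using Orleans.Serialization.Cloning;

public sealed class Node { public Node? Next { get; set; } }

public sealed class NodeCopier : IDeepCopier<Node>
{
    public Node DeepCopy(Node input, CopyContext context)
    {
        // Honor reference identity and cycles, mirroring the pattern the built-in
        // copiers in this diff use.
        if (context.TryGetCopy<Node>(input, out var existing)) return existing!;
        var copy = new Node();
        context.RecordCopy(input, copy);
        copy.Next = context.DeepCopy(input.Next);
        return copy;
    }
}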
private void Return(CopyContext context) => _pool.Return(context); - private readonly struct PoolPolicy : IPooledObjectPolicy + private readonly struct PoolPolicy(CodecProvider codecProvider, Action onDisposed) : IPooledObjectPolicy { - private readonly CodecProvider _codecProvider; - private readonly Action _onDisposed; - - public PoolPolicy(CodecProvider codecProvider, Action onDisposed) - { - _codecProvider = codecProvider; - _onDisposed = onDisposed; - } + private readonly CodecProvider _codecProvider = codecProvider; + private readonly Action _onDisposed = onDisposed; public CopyContext Create() => new(_codecProvider, _onDisposed); diff --git a/src/Orleans.Serialization/Codecs/ByteArrayCodec.cs b/src/Orleans.Serialization/Codecs/ByteArrayCodec.cs index cfc5963266b..7d0d58495b2 100644 --- a/src/Orleans.Serialization/Codecs/ByteArrayCodec.cs +++ b/src/Orleans.Serialization/Codecs/ByteArrayCodec.cs @@ -393,7 +393,7 @@ public void WriteField(ref Writer writer, uint fie ReferenceCodec.MarkValueField(writer.Session); writer.WriteFieldHeader(fieldIdDelta, expectedType, typeof(PooledBuffer), WireType.LengthPrefixed); writer.WriteVarUInt32((uint)value.Length); - foreach (var segment in value) + foreach (var segment in value.GetEnumerator()) { writer.Write(segment); } diff --git a/src/Orleans.Serialization/Codecs/DictionaryCodec.cs b/src/Orleans.Serialization/Codecs/DictionaryCodec.cs index 8ae7a8707c7..f519d02c6ca 100644 --- a/src/Orleans.Serialization/Codecs/DictionaryCodec.cs +++ b/src/Orleans.Serialization/Codecs/DictionaryCodec.cs @@ -1,8 +1,7 @@ -#nullable enable +#nullable enable using System; using System.Buffers; using System.Collections.Generic; -using System.Linq; using System.Reflection; using Orleans.Serialization.Buffers; using Orleans.Serialization.Cloning; @@ -180,12 +179,12 @@ public Dictionary DeepCopy(Dictionary input, CopyCon { if (context.TryGetCopy>(input, out var result)) { - return result; + return result!; } if (input.GetType() != typeof(Dictionary)) { - return context.DeepCopy(input); + return context.DeepCopy(input)!; } result = new Dictionary(input.Count, input.Comparer); @@ -344,4 +343,4 @@ void IBaseCodec>.Deserialize(ref Reader private static void ThrowLengthFieldMissing() => throw new RequiredFieldMissingException("Serialized dictionary is missing its length field."); } -} \ No newline at end of file +} diff --git a/src/Orleans.Serialization/Orleans.Serialization.csproj b/src/Orleans.Serialization/Orleans.Serialization.csproj index 0324b54783f..30142a62ee4 100644 --- a/src/Orleans.Serialization/Orleans.Serialization.csproj +++ b/src/Orleans.Serialization/Orleans.Serialization.csproj @@ -1,4 +1,4 @@ - + Microsoft.Orleans.Serialization @@ -28,4 +28,8 @@ + + + + diff --git a/src/Orleans.Serialization/Serializers/SurrogateCodec.cs b/src/Orleans.Serialization/Serializers/SurrogateCodec.cs index 329f3616f57..9392e8feeb7 100644 --- a/src/Orleans.Serialization/Serializers/SurrogateCodec.cs +++ b/src/Orleans.Serialization/Serializers/SurrogateCodec.cs @@ -1,4 +1,4 @@ -#nullable enable +#nullable enable using System; using System.Buffers; using System.Diagnostics.CodeAnalysis; @@ -50,7 +50,7 @@ public TField DeepCopy(TField input, CopyContext context) { if (context.TryGetCopy(input, out var result)) { - return result; + return result!; } var surrogate = _converter.ConvertToSurrogate(in input); @@ -148,4 +148,4 @@ public void DeepCopy(TField input, TField output, CopyContext context) [DoesNotReturn] private void ThrowNoPopulatorException() => throw new 
NotSupportedException($"Surrogate type {typeof(TConverter)} does not implement {typeof(IPopulator)} and therefore cannot be used in an inheritance hierarchy."); -} \ No newline at end of file +} diff --git a/test/Orleans.Journaling.Tests/DurableDictionaryTests.cs b/test/Orleans.Journaling.Tests/DurableDictionaryTests.cs new file mode 100644 index 00000000000..a3c69682865 --- /dev/null +++ b/test/Orleans.Journaling.Tests/DurableDictionaryTests.cs @@ -0,0 +1,202 @@ +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Orleans.Journaling.Tests; + +[TestCategory("BVT")] +public class DurableDictionaryTests : StateMachineTestBase +{ + [Fact] + public async Task DurableDictionary_BasicOperations_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + var keyCodec = CodecProvider.GetCodec(); + var valueCodec = CodecProvider.GetCodec(); + var dictionary = new DurableDictionary("testDict", sut.Manager, keyCodec, valueCodec, SessionPool); + await sut.Lifecycle.OnStart(); + + // Act - Add items + dictionary.Add("one", 1); + dictionary.Add("two", 2); + dictionary.Add("three", 3); + await manager.WriteStateAsync(CancellationToken.None); + + // Assert + Assert.Equal(3, dictionary.Count); + Assert.Equal(1, dictionary["one"]); + Assert.Equal(2, dictionary["two"]); + Assert.Equal(3, dictionary["three"]); + + // Act - Update item + dictionary["two"] = 22; + await manager.WriteStateAsync(CancellationToken.None); + + // Assert + Assert.Equal(22, dictionary["two"]); + + // Act - Remove item + var removed = dictionary.Remove("three"); + await manager.WriteStateAsync(CancellationToken.None); + + // Assert + Assert.True(removed); + Assert.Equal(2, dictionary.Count); + Assert.False(dictionary.ContainsKey("three")); + } + + [Fact] + public async Task DurableDictionary_Persistence_Test() + { + // Arrange + var sut = CreateTestSystem(); + var keyCodec = CodecProvider.GetCodec(); + var valueCodec = CodecProvider.GetCodec(); + var dictionary1 = new DurableDictionary("testDict", sut.Manager, keyCodec, valueCodec, SessionPool); + await sut.Lifecycle.OnStart(); + + // Act - Add items and persist + dictionary1.Add("one", 1); + dictionary1.Add("two", 2); + dictionary1.Add("three", 3); + await sut.Manager.WriteStateAsync(CancellationToken.None); + + // Create a new manager with the same storage + var sut2 = CreateTestSystem(storage: sut.Storage); + var dictionary2 = new DurableDictionary("testDict", sut2.Manager, keyCodec, valueCodec, SessionPool); + await sut2.Lifecycle.OnStart(); + + // Assert - Dictionary should be recovered + Assert.Equal(3, dictionary2.Count); + Assert.Equal(1, dictionary2["one"]); + Assert.Equal(2, dictionary2["two"]); + Assert.Equal(3, dictionary2["three"]); + } + + [Fact] + public async Task DurableDictionary_ComplexKeys_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + var keyCodec = CodecProvider.GetCodec(); + var valueCodec = CodecProvider.GetCodec(); + var dictionary = new DurableDictionary("complexDict", manager, keyCodec, valueCodec, SessionPool); + await sut.Lifecycle.OnStart(); + + // Act + var key1 = new TestKey { Id = 1, Name = "Key1" }; + var key2 = new TestKey { Id = 2, Name = "Key2" }; + + dictionary.Add(key1, "Value1"); + dictionary.Add(key2, "Value2"); + await manager.WriteStateAsync(CancellationToken.None); + + // Assert + Assert.Equal(2, dictionary.Count); + Assert.Equal("Value1", dictionary[key1]); + Assert.Equal("Value2", dictionary[key2]); + } + + [Fact] + public async Task 
DurableDictionary_ComplexValues_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + var keyCodec = CodecProvider.GetCodec(); + var valueCodec = CodecProvider.GetCodec(); + var dictionary = new DurableDictionary("peopleDict", manager, keyCodec, valueCodec, SessionPool); + await sut.Lifecycle.OnStart(); + + // Act + var person1 = new TestPerson { Id = 1, Name = "John", Age = 30 }; + var person2 = new TestPerson { Id = 2, Name = "Jane", Age = 25 }; + + dictionary.Add("person1", person1); + dictionary.Add("person2", person2); + await manager.WriteStateAsync(CancellationToken.None); + + // Assert + Assert.Equal(2, dictionary.Count); + Assert.Equal("John", dictionary["person1"].Name); + Assert.Equal(25, dictionary["person2"].Age); + + // Act - Update + dictionary["person1"].Age = 31; + await manager.WriteStateAsync(CancellationToken.None); + + // Assert + Assert.Equal(31, dictionary["person1"].Age); + } + + [Fact] + public async Task DurableDictionary_Clear_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + var keyCodec = CodecProvider.GetCodec(); + var valueCodec = CodecProvider.GetCodec(); + var dictionary = new DurableDictionary("clearDict", manager, keyCodec, valueCodec, SessionPool); + await sut.Lifecycle.OnStart(); + + // Add items + dictionary.Add("one", 1); + dictionary.Add("two", 2); + dictionary.Add("three", 3); + await manager.WriteStateAsync(CancellationToken.None); + + // Act - Clear + dictionary.Clear(); + await manager.WriteStateAsync(CancellationToken.None); + + // Assert + Assert.Empty(dictionary); + } + + [Fact] + public async Task DurableDictionary_Enumeration_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + var keyCodec = CodecProvider.GetCodec(); + var valueCodec = CodecProvider.GetCodec(); + var dictionary = new DurableDictionary("enumDict", manager, keyCodec, valueCodec, SessionPool); + await sut.Lifecycle.OnStart(); + + // Add items + var expectedPairs = new Dictionary + { + { "one", 1 }, + { "two", 2 }, + { "three", 3 } + }; + + foreach (var pair in expectedPairs) + { + dictionary.Add(pair.Key, pair.Value); + } + + await manager.WriteStateAsync(CancellationToken.None); + + // Act & Assert - Test enumeration + var actualPairs = dictionary.ToDictionary(kv => kv.Key, kv => kv.Value); + Assert.Equal(expectedPairs, actualPairs); + + // Test Keys and Values collections + Assert.Equal(expectedPairs.Keys, dictionary.Keys); + Assert.Equal(expectedPairs.Values, dictionary.Values); + } +} + +[GenerateSerializer] +public record class TestKey +{ + [Id(0)] + public int Id { get; set; } + [Id(1)] + public string? 
diff --git a/test/Orleans.Journaling.Tests/DurableGrainTests.cs b/test/Orleans.Journaling.Tests/DurableGrainTests.cs
new file mode 100644
index 00000000000..dee02dc67ad
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/DurableGrainTests.cs
@@ -0,0 +1,327 @@
+using Orleans.Core.Internal;
+using Xunit;
+
+namespace Orleans.Journaling.Tests;
+
+[TestCategory("BVT")]
+public class DurableGrainTests(IntegrationTestFixture fixture) : IClassFixture<IntegrationTestFixture>
+{
+    private IGrainFactory Client => fixture.Client;
+
+    [Fact]
+    public async Task DurableGrain_State_Persistence_Test()
+    {
+        // Arrange
+        var grain = Client.GetGrain<ITestDurableGrain>(Guid.NewGuid());
+
+        // Act - Set state properties and persist
+        await grain.SetTestValues("Test Name", 42);
+
+        // Assert
+        Assert.Equal("Test Name", await grain.GetName());
+        Assert.Equal(42, await grain.GetCounter());
+
+        // Force deactivation and get a new reference
+        var idBefore = await grain.GetActivationId();
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+        Assert.NotEqual(idBefore, await grain.GetActivationId());
+
+        // Assert - State should be recovered
+        Assert.Equal("Test Name", await grain.GetName());
+        Assert.Equal(42, await grain.GetCounter());
+    }
+
+    [Fact]
+    public async Task DurableGrain_Update_State_Test()
+    {
+        // Arrange
+        var grain = Client.GetGrain<ITestDurableGrain>(Guid.NewGuid());
+
+        // Act - Set state and persist
+        await grain.SetTestValues("Initial Name", 10);
+
+        // Update state and persist again
+        await grain.SetTestValues("Updated Name", 20);
+
+        // Assert
+        Assert.Equal("Updated Name", await grain.GetName());
+        Assert.Equal(20, await grain.GetCounter());
+
+        // Force deactivation and get a new reference
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+
+        // Assert - Updated state should be recovered
+        Assert.Equal("Updated Name", await grain.GetName());
+        Assert.Equal(20, await grain.GetCounter());
+    }
+
+    [Fact]
+    public async Task DurableGrain_Complex_Types_Test()
+    {
+        // Arrange
+        var grain = Client.GetGrain<ITestDurableGrainWithComplexState>(Guid.NewGuid());
+
+        // Act - Set complex state and persist
+        var person = new TestPerson { Id = 1, Name = "John Doe", Age = 30 };
+        var items = new List<string> { "Item1", "Item2", "Item3" };
+        await grain.SetTestValues(person, items);
+
+        // Assert
+        var retrievedPerson = await grain.GetPerson();
+        var retrievedItems = await grain.GetItems();
+
+        Assert.Equal("John Doe", retrievedPerson.Name);
+        Assert.Equal(3, retrievedItems.Count);
+
+        // Force deactivation and get a new reference
+        var idBefore = await grain.GetActivationId();
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+        Assert.NotEqual(idBefore, await grain.GetActivationId());
+
+        // Assert - Complex state should be recovered
+        retrievedPerson = await grain.GetPerson();
+        retrievedItems = await grain.GetItems();
+
+        Assert.NotNull(retrievedPerson);
+        Assert.Equal(1, retrievedPerson.Id);
+        Assert.Equal("John Doe", retrievedPerson.Name);
+        Assert.Equal(30, retrievedPerson.Age);
+
+        Assert.Equal(3, retrievedItems.Count);
+        Assert.Equal("Item1", retrievedItems[0]);
+        Assert.Equal("Item2", retrievedItems[1]);
+        Assert.Equal("Item3", retrievedItems[2]);
+    }
+
+    [Fact]
+    public async Task DurableGrain_Multiple_Collections_Test()
+    {
+        // Arrange
+        var grain = Client.GetGrain<ITestMultiCollectionGrain>(Guid.NewGuid());
+
+        // Act - Populate collections and persist
+        await grain.AddToDictionary("key1", 1);
+        await grain.AddToDictionary("key2", 2);
+        await grain.AddToList("item1");
+        await grain.AddToList("item2");
+        await grain.AddToQueue(100);
+        await grain.AddToQueue(200);
+        await grain.AddToSet("set1");
+        await grain.AddToSet("set2");
+
+        // Assert
+        Assert.Equal(2, await grain.GetDictionaryCount());
+        Assert.Equal(2, await grain.GetListCount());
+        Assert.Equal(2, await grain.GetQueueCount());
+        Assert.Equal(2, await grain.GetSetCount());
+
+        // Force deactivation and get a new reference
+        var idBefore = await grain.GetActivationId();
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+        Assert.NotEqual(idBefore, await grain.GetActivationId());
+
+        // Assert - All collections should be recovered
+        Assert.Equal(2, await grain.GetDictionaryCount());
+        Assert.Equal(1, await grain.GetDictionaryValue("key1"));
+        Assert.Equal(2, await grain.GetDictionaryValue("key2"));
+
+        Assert.Equal(2, await grain.GetListCount());
+        Assert.Equal("item1", await grain.GetListItem(0));
+        Assert.Equal("item2", await grain.GetListItem(1));
+
+        Assert.Equal(2, await grain.GetQueueCount());
+        Assert.Equal(100, await grain.PeekQueueItem());
+
+        Assert.Equal(2, await grain.GetSetCount());
+        Assert.True(await grain.ContainsSetItem("set1"));
+        Assert.True(await grain.ContainsSetItem("set2"));
+    }
+
+    [Fact]
+    public async Task DurableGrain_State_Modifications_Test()
+    {
+        // Arrange
+        var grain = Client.GetGrain<ITestMultiCollectionGrain>(Guid.NewGuid());
+
+        // Act - Populate initial state and persist
+        await grain.AddToDictionary("key1", 1);
+        await grain.AddToList("item1");
+        await grain.AddToQueue(100);
+        await grain.AddToSet("set1");
+
+        // Modify state and persist again
+        await grain.AddToDictionary("key2", 2);
+        await grain.AddToDictionary("key1", 10); // Update via interface method
+        await grain.AddToList("item2");
+        await grain.AddToQueue(200);
+        await grain.AddToSet("set2");
+
+        // Assert
+        Assert.Equal(2, await grain.GetDictionaryCount());
+        Assert.Equal(10, await grain.GetDictionaryValue("key1"));
+        Assert.Equal(2, await grain.GetListCount());
+        Assert.Equal(2, await grain.GetQueueCount());
+        Assert.Equal(2, await grain.GetSetCount());
+
+        // Force deactivation and get a new reference
+        var idBefore = await grain.GetActivationId();
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+        Assert.NotEqual(idBefore, await grain.GetActivationId());
+
+        // Assert - Modified state should be recovered
+        Assert.Equal(2, await grain.GetDictionaryCount());
+        Assert.Equal(10, await grain.GetDictionaryValue("key1"));
+        Assert.Equal(2, await grain.GetDictionaryValue("key2"));
+
+        Assert.Equal(2, await grain.GetListCount());
+        Assert.Equal("item1", await grain.GetListItem(0));
+        Assert.Equal("item2", await grain.GetListItem(1));
+
+        // Further modify the state
+        await grain.RemoveFromDictionary("key1");
+        await grain.RemoveListItemAt(0);
+        await grain.DequeueItem();
+        await grain.RemoveFromSet("set1");
+
+        // Assert the modifications
+        Assert.Equal(1, await grain.GetDictionaryCount());
+        Assert.Equal(1, await grain.GetListCount());
+        Assert.Equal(1, await grain.GetQueueCount());
+        Assert.Equal(1, await grain.GetSetCount());
+    }
+
+    [Fact]
+    public async Task Grain_State_Should_Persist_Between_Activations()
+    {
+        // Arrange - Get a reference to a grain
+        var grain = Client.GetGrain<ITestDurableGrainInterface>(Guid.NewGuid());
+
+        // Act - Set the grain state
+        await grain.SetValues("Test Name", 42);
+        var initialState = await grain.GetValues();
+
+        // Deactivate the grain forcefully
+        var idBefore = await grain.GetActivationId();
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+        Assert.NotEqual(idBefore, await grain.GetActivationId());
+
+        // Get the values from the grain (which will be reactivated)
+        var newState = await grain.GetValues();
+
+        // Assert
+        Assert.Equal(initialState.Name, newState.Name);
+        Assert.Equal(initialState.Counter, newState.Counter);
+    }
+
+    [Fact]
+    public async Task Grain_Should_Handle_Multiple_Collections()
+    {
+        // Arrange
+        var grain = Client.GetGrain<ITestMultiCollectionGrain>(Guid.NewGuid());
+
+        // Act - Add items to collections
+        await grain.AddToDictionary("key1", 1);
+        await grain.AddToDictionary("key2", 2);
+
+        await grain.AddToList("item1");
+        await grain.AddToList("item2");
+
+        await grain.AddToQueue(100);
+        await grain.AddToQueue(200);
+
+        await grain.AddToSet("set1");
+        await grain.AddToSet("set2");
+        await grain.AddToSet("set1"); // Duplicate, should be ignored
+
+        // Assert - Check counts
+        Assert.Equal(2, await grain.GetDictionaryCount());
+        Assert.Equal(2, await grain.GetListCount());
+        Assert.Equal(2, await grain.GetQueueCount());
+        Assert.Equal(2, await grain.GetSetCount());
+
+        // Deactivate the grain forcefully
+        var idBefore = await grain.GetActivationId();
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+        Assert.NotEqual(idBefore, await grain.GetActivationId());
+
+        // Assert - Check values after reactivation
+        Assert.Equal(1, await grain.GetDictionaryValue("key1"));
+        Assert.Equal(2, await grain.GetDictionaryValue("key2"));
+        Assert.Equal("item1", await grain.GetListItem(0));
+        Assert.Equal("item2", await grain.GetListItem(1));
+        Assert.Equal(100, await grain.PeekQueueItem());
+        Assert.True(await grain.ContainsSetItem("set1"));
+        Assert.True(await grain.ContainsSetItem("set2"));
+
+        // Act - Modify collections
+        await grain.RemoveFromDictionary("key1");
+        await grain.RemoveListItemAt(0);
+        await grain.DequeueItem();
+        await grain.RemoveFromSet("set1");
+
+        // Assert - Check counts after modifications
+        Assert.Equal(1, await grain.GetDictionaryCount());
+        Assert.Equal(1, await grain.GetListCount());
+        Assert.Equal(1, await grain.GetQueueCount());
+        Assert.Equal(1, await grain.GetSetCount());
+
+        // Deactivate the grain again
+        idBefore = await grain.GetActivationId();
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+        Assert.NotEqual(idBefore, await grain.GetActivationId());
+
+        // Assert - Check values after second reactivation
+        Assert.Equal(1, await grain.GetDictionaryCount());
+        Assert.Equal(1, await grain.GetListCount());
+        Assert.Equal(1, await grain.GetQueueCount());
+        Assert.Equal(1, await grain.GetSetCount());
+        Assert.Equal(2, await grain.GetDictionaryValue("key2"));
+        Assert.Equal("item2", await grain.GetListItem(0));
+        Assert.Equal(200, await grain.PeekQueueItem());
+        Assert.True(await grain.ContainsSetItem("set2"));
+    }
+
+    [Fact]
+    public async Task Grain_Should_Handle_Large_State()
+    {
+        // Arrange
+        var grain = Client.GetGrain<ITestMultiCollectionGrain>(Guid.NewGuid());
+
+        // Act - Add many items
+        const int itemCount = 1000;
+        for (int i = 0; i < itemCount; i++)
+        {
+            await grain.AddToDictionary($"key{i}", i);
+            if (i < 100) // Add fewer items to other collections to keep test runtime reasonable
+            {
+                await grain.AddToList($"item{i}");
+                await grain.AddToQueue(i);
+                await grain.AddToSet($"set{i}");
+            }
+        }
+
+        // Assert - Check counts
+        Assert.Equal(itemCount, await grain.GetDictionaryCount());
+        Assert.Equal(100, await grain.GetListCount());
+        Assert.Equal(100, await grain.GetQueueCount());
+        Assert.Equal(100, await grain.GetSetCount());
+
+        // Deactivate the grain forcefully
+        var idBefore = await grain.GetActivationId();
+        await grain.Cast<IGrainManagementExtension>().DeactivateOnIdle();
+        Assert.NotEqual(idBefore, await grain.GetActivationId());
+
+        // Assert - Check random values after reactivation
+        for (int i = 0; i < 10; i++)
+        {
+            var randomIndex = new Random().Next(0, itemCount - 1);
+            Assert.Equal(randomIndex, await grain.GetDictionaryValue($"key{randomIndex}"));
+
+            if (randomIndex < 100)
+            {
+                Assert.Equal($"item{randomIndex}", await grain.GetListItem(randomIndex));
+                Assert.True(await grain.ContainsSetItem($"set{randomIndex}"));
+            }
+        }
+    }
+}
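The grain tests above never touch storage directly; they lean on the grain-side pattern defined later in this diff (Grains.cs, TestMultiCollectionGrain.cs): durable state machines are constructor-injected by key, and the grain persists them explicitly with WriteStateAsync. A minimal sketch of that shape follows; ICounterGrain and CounterGrain are illustrative names invented here, not part of the change.

using Microsoft.Extensions.DependencyInjection;

namespace Orleans.Journaling.Tests;

public interface ICounterGrain : IGrainWithGuidKey
{
    Task Increment();
    Task<int> Get();
}

// Keyed injection binds this IDurableValue to the grain's journal under the name "counter".
public class CounterGrain([FromKeyedServices("counter")] IDurableValue<int> counter)
    : DurableGrain, ICounterGrain
{
    public async Task Increment()
    {
        counter.Value++;
        await WriteStateAsync(); // journal the change; replayed on the next activation
    }

    public Task<int> Get() => Task.FromResult(counter.Value);
}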
diff --git a/test/Orleans.Journaling.Tests/DurableQueueTests.cs b/test/Orleans.Journaling.Tests/DurableQueueTests.cs
new file mode 100644
index 00000000000..08c4f7e1433
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/DurableQueueTests.cs
@@ -0,0 +1,301 @@
+using Microsoft.Extensions.Logging;
+using Xunit;
+
+namespace Orleans.Journaling.Tests;
+
+[TestCategory("BVT")]
+public class DurableQueueTests : StateMachineTestBase
+{
+    [Fact]
+    public async Task DurableQueue_BasicOperations_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var queue = new DurableQueue<string>("testQueue", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Enqueue items
+        queue.Enqueue("one");
+        queue.Enqueue("two");
+        queue.Enqueue("three");
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal(3, queue.Count);
+
+        // Act - Peek
+        var peeked = queue.Peek();
+
+        // Assert - Peek doesn't remove the item
+        Assert.Equal("one", peeked);
+        Assert.Equal(3, queue.Count);
+
+        // Act - Dequeue
+        var dequeued1 = queue.Dequeue();
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal("one", dequeued1);
+        Assert.Equal(2, queue.Count);
+
+        // Act - Dequeue again
+        var dequeued2 = queue.Dequeue();
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal("two", dequeued2);
+        Assert.Single(queue);
+
+        // Act - Dequeue last item
+        var dequeued3 = queue.Dequeue();
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal("three", dequeued3);
+        Assert.Empty(queue);
+    }
+
+    [Fact]
+    public async Task DurableQueue_Persistence_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var codec = CodecProvider.GetCodec<string>();
+        var queue1 = new DurableQueue<string>("testQueue", sut.Manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Enqueue items and persist
+        queue1.Enqueue("one");
+        queue1.Enqueue("two");
+        queue1.Enqueue("three");
+        await sut.Manager.WriteStateAsync(CancellationToken.None);
+
+        // Create a new manager with the same storage
+        var sut2 = CreateTestSystem(storage: sut.Storage);
+        var queue2 = new DurableQueue<string>("testQueue", sut2.Manager, codec, SessionPool);
+        await sut2.Lifecycle.OnStart();
+
+        // Assert - Queue should be recovered
+        Assert.Equal(3, queue2.Count);
+        Assert.Equal("one", queue2.Peek());
+
+        // Act - Dequeue from recovered queue
+        var dequeued = queue2.Dequeue();
+        await sut2.Manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal("one", dequeued);
+        Assert.Equal(2, queue2.Count);
+    }
+
+    [Fact]
+    public async Task DurableQueue_ComplexValues_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<TestPerson>();
+        var queue = new DurableQueue<TestPerson>("personQueue", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act
+        var person1 = new TestPerson { Id = 1, Name = "John", Age = 30 };
+        var person2 = new TestPerson { Id = 2, Name = "Jane", Age = 25 };
+
+        queue.Enqueue(person1);
+        queue.Enqueue(person2);
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal(2, queue.Count);
+        var peeked = queue.Peek();
+        Assert.Equal("John", peeked.Name);
+
+        // Act - Dequeue
+        var dequeued = queue.Dequeue();
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Single(queue);
+        Assert.Equal("John", dequeued.Name);
+        Assert.Equal(30, dequeued.Age);
+    }
+
+    [Fact]
+    public async Task DurableQueue_Clear_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var queue = new DurableQueue<string>("clearQueue", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Add items
+        queue.Enqueue("one");
+        queue.Enqueue("two");
+        queue.Enqueue("three");
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Act - Clear
+        queue.Clear();
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Empty(queue);
+    }
+
+    [Fact]
+    public async Task DurableQueue_EmptyQueueOperations_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var queue = new DurableQueue<string>("emptyQueue", manager, codec, SessionPool);
+        await manager.WriteStateAsync(CancellationToken.None);
+        await sut.Lifecycle.OnStart();
+
+        // Assert
+        Assert.Empty(queue);
+
+        // Act & Assert - Peek and Dequeue on empty queue should throw
+        Assert.Throws<InvalidOperationException>(() => queue.Peek());
+        Assert.Throws<InvalidOperationException>(() => queue.Dequeue());
+    }
+
+    [Fact]
+    public async Task DurableQueue_Enumeration_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var queue = new DurableQueue<string>("enumQueue", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Add items
+        var expectedItems = new List<string> { "one", "two", "three" };
+
+        foreach (var item in expectedItems)
+        {
+            queue.Enqueue(item);
+        }
+
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Act
+        var actualItems = queue.ToList();
+
+        // Assert - Items should be in same order as enqueued
+        Assert.Equal(expectedItems, actualItems);
+    }
+
+    [Fact]
+    public async Task DurableQueue_LargeNumberOfOperations_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<int>();
+        var queue = new DurableQueue<int>("largeQueue", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Enqueue many items
+        const int itemCount = 1000;
+        for (int i = 0; i < itemCount; i++)
+        {
+            queue.Enqueue(i);
+        }
+
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal(itemCount, queue.Count);
+        Assert.Equal(0, queue.Peek());
+
+        // Create a new manager with the same storage to test recovery
+        var sut2 = CreateTestSystem(storage: sut.Storage);
+        var queue2 = new DurableQueue<int>("largeQueue", sut2.Manager, codec, SessionPool);
+        await sut2.Lifecycle.OnStart();
+
+        // Assert - Large queue is correctly recovered
+        Assert.Equal(itemCount, queue2.Count);
+
+        // Act - Dequeue all items and verify order
+        for (int i = 0; i < itemCount; i++)
+        {
+            var item = queue2.Dequeue();
+            Assert.Equal(i, item);
+        }
+
+        await sut2.Manager.WriteStateAsync(CancellationToken.None);
+        Assert.Empty(queue2);
+    }
+
+    [Fact]
+    public async Task DurableQueue_Concurrent_EnqueueDequeue_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<int>();
+        var queue = new DurableQueue<int>("concurrentQueue", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Simulate a queue with concurrent operations
+        const int batchSize = 100;
+
+        // First batch: add 100 items
+        for (int i = 0; i < batchSize; i++)
+        {
+            queue.Enqueue(i);
+        }
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Remove 50 items
+        for (int i = 0; i < batchSize / 2; i++)
+        {
+            queue.Dequeue();
+        }
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Add another 100 items
+        for (int i = batchSize; i < batchSize * 2; i++)
+        {
+            queue.Enqueue(i);
+        }
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal(batchSize + batchSize / 2, queue.Count); // Should have 150 items
+
+        // Create a new manager with the same storage to test recovery
+        var sut2 = CreateTestSystem(storage: sut.Storage);
+        var queue2 = new DurableQueue<int>("concurrentQueue", sut2.Manager, codec, SessionPool);
+        await sut2.Lifecycle.OnStart();
+
+        // Assert - Queue should be recovered with correct state and ordering
+        Assert.Equal(batchSize + batchSize / 2, queue2.Count);
+
+        // First values should be the second half of first batch
+        for (int i = batchSize / 2; i < batchSize; i++)
+        {
+            var item = queue2.Dequeue();
+            Assert.Equal(i, item);
+        }
+
+        // Then we should get the second batch
+        for (int i = batchSize; i < batchSize * 2; i++)
+        {
+            var item = queue2.Dequeue();
+            Assert.Equal(i, item);
+        }
+
+        await sut2.Manager.WriteStateAsync(CancellationToken.None);
+        Assert.Empty(queue2);
+    }
+}
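The concurrent enqueue/dequeue test encodes a FIFO invariant worth spelling out: 100 items enqueued, 50 dequeued, then 100 more enqueued leaves batchSize + batchSize / 2 = 150 items, and recovery must yield 50..99 followed by 100..199. The same arithmetic against a plain BCL Queue<int>, independent of the durable implementation:

// Model of the ordering the recovery assertions expect (standard library only).
var model = new Queue<int>();
for (var i = 0; i < 100; i++) model.Enqueue(i);   // first batch: 0..99
for (var i = 0; i < 50; i++) model.Dequeue();     // drop the first half: 0..49
for (var i = 100; i < 200; i++) model.Enqueue(i); // second batch: 100..199
// model.Count == 150; dequeue order is 50..99, then 100..199.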
diff --git a/test/Orleans.Journaling.Tests/DurableSetTests.cs b/test/Orleans.Journaling.Tests/DurableSetTests.cs
new file mode 100644
index 00000000000..f74bc885bed
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/DurableSetTests.cs
@@ -0,0 +1,280 @@
+using Microsoft.Extensions.Logging;
+using Xunit;
+
+namespace Orleans.Journaling.Tests;
+
+[TestCategory("BVT")]
+public class DurableSetTests : StateMachineTestBase
+{
+    [Fact]
+    public async Task DurableSet_BasicOperations_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var set = new DurableSet<string>("testSet", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Add items
+        bool added1 = set.Add("one");
+        bool added2 = set.Add("two");
+        bool added3 = set.Add("three");
+        bool duplicateAdded = set.Add("one"); // Adding duplicate
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.True(added1);
+        Assert.True(added2);
+        Assert.True(added3);
+        Assert.False(duplicateAdded); // Should not add duplicates
+        Assert.Equal(3, set.Count);
+        Assert.Contains("one", set);
+        Assert.Contains("two", set);
+        Assert.Contains("three", set);
+
+        // Act - Remove item
+        bool removed = set.Remove("two");
+        bool removedNonExisting = set.Remove("four"); // Remove non-existing
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.True(removed);
+        Assert.False(removedNonExisting);
+        Assert.Equal(2, set.Count);
+        Assert.Contains("one", set);
+        Assert.DoesNotContain("two", set);
+        Assert.Contains("three", set);
+    }
+
+    [Fact]
+    public async Task DurableSet_Persistence_Test()
+    {
+        // First manager and set
+        var sut = CreateTestSystem();
+        var codec = CodecProvider.GetCodec<string>();
+        var set1 = new DurableSet<string>("testSet", sut.Manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Add items and persist
+        set1.Add("one");
+        set1.Add("two");
+        set1.Add("three");
+        await sut.Manager.WriteStateAsync(CancellationToken.None);
+
+        // Create a new manager with the same storage
+        var sut2 = CreateTestSystem(storage: sut.Storage);
+        var set2 = new DurableSet<string>("testSet", sut2.Manager, codec, SessionPool);
+        await sut2.Lifecycle.OnStart();
+
+        // Assert - Set should be recovered
+        Assert.Equal(3, set2.Count);
+        Assert.Contains("one", set2);
+        Assert.Contains("two", set2);
+        Assert.Contains("three", set2);
+    }
+
+    [Fact]
+    public async Task DurableSet_ComplexValues_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<TestPerson>();
+        var set = new DurableSet<TestPerson>("personSet", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act
+        var person1 = new TestPerson { Id = 1, Name = "John", Age = 30 };
+        var person2 = new TestPerson { Id = 2, Name = "Jane", Age = 25 };
+        var person3 = new TestPerson { Id = 1, Name = "John", Age = 30 }; // Same as person1
+
+        set.Add(person1);
+        set.Add(person2);
+        bool duplicateAdded = set.Add(person3); // Should not add duplicate when overriding Equals
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.False(duplicateAdded);
+        Assert.Equal(2, set.Count);
+        Assert.Contains(person1, set);
+        Assert.Contains(person2, set);
+    }
+
+    [Fact]
+    public async Task DurableSet_Clear_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var set = new DurableSet<string>("clearSet", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Add items
+        set.Add("one");
+        set.Add("two");
+        set.Add("three");
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Act - Clear
+        set.Clear();
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Empty(set);
+    }
+
+    [Fact]
+    public async Task DurableSet_Enumeration_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var set = new DurableSet<string>("enumSet", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Add items
+        var expectedItems = new HashSet<string> { "one", "two", "three" };
+
+        foreach (var item in expectedItems)
+        {
+            set.Add(item);
+        }
+
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Act
+        var actualItems = set.ToHashSet();
+
+        // Assert
+        Assert.Equal(expectedItems, actualItems);
+    }
+
+    [Fact]
+    public async Task DurableSet_LargeNumberOfItems_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<int>();
+        var set = new DurableSet<int>("largeSet", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Add many items
+        const int itemCount = 1000;
+        for (int i = 0; i < itemCount; i++)
+        {
+            set.Add(i);
+        }
+
+        // Add some duplicates which should be ignored
+        for (int i = 0; i < 100; i++)
+        {
+            set.Add(i);
+        }
+
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal(itemCount, set.Count);
+
+        // Create a new manager with the same storage to test recovery
+        var sut2 = CreateTestSystem(storage: sut.Storage);
+        var set2 = new DurableSet<int>("largeSet", sut2.Manager, codec, SessionPool);
+        await sut2.Lifecycle.OnStart();
+
+        // Assert - Large set is correctly recovered
+        Assert.Equal(itemCount, set2.Count);
+        for (int i = 0; i < itemCount; i++)
+        {
+            Assert.Contains(i, set2);
+        }
+    }
+
+    [Fact]
+    public async Task DurableSet_SetOperations_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<int>();
+        var set1 = new DurableSet<int>("set1", manager, codec, SessionPool);
+        var set2 = new DurableSet<int>("set2", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Populate set1 with even numbers from 0 to 10
+        for (int i = 0; i <= 10; i += 2)
+        {
+            set1.Add(i);
+        }
+
+        // Populate set2 with numbers from 5 to 15
+        for (int i = 5; i <= 15; i++)
+        {
+            set2.Add(i);
+        }
+
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Act & Assert - Set operations
+        var set1HashSet = set1.ToHashSet();
+        var set2HashSet = set2.ToHashSet();
+
+        // Intersection
+        var intersection = new HashSet<int>(set1HashSet);
+        intersection.IntersectWith(set2HashSet);
+        Assert.Equal(new HashSet<int> { 6, 8, 10 }, intersection);
+
+        // Union
+        var union = new HashSet<int>(set1HashSet);
+        union.UnionWith(set2HashSet);
+        Assert.Equal(new HashSet<int> { 0, 2, 4, 6, 8, 10, 5, 7, 9, 11, 12, 13, 14, 15 }, union);
+
+        // Difference (set1 - set2)
+        var difference = new HashSet<int>(set1HashSet);
+        difference.ExceptWith(set2HashSet);
+        Assert.Equal(new HashSet<int> { 0, 2, 4 }, difference);
+    }
+
+    [Fact]
+    public async Task DurableSet_ExceptWith_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<int>();
+        var set = new DurableSet<int>("exceptSet", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Add numbers from 0 to 9
+        for (int i = 0; i < 10; i++)
+        {
+            set.Add(i);
+        }
+
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Act - Remove even numbers
+        var evens = new List<int>();
+        for (int i = 0; i < 10; i += 2)
+        {
+            evens.Add(i);
+        }
+
+        foreach (var even in evens)
+        {
+            set.Remove(even);
+        }
+
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert - Should only contain odd numbers
+        Assert.Equal(5, set.Count);
+        for (int i = 1; i < 10; i += 2)
+        {
+            Assert.Contains(i, set);
+        }
+    }
+}
diff --git a/test/Orleans.Journaling.Tests/DurableValueTests.cs b/test/Orleans.Journaling.Tests/DurableValueTests.cs
new file mode 100644
index 00000000000..c065800dec8
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/DurableValueTests.cs
@@ -0,0 +1,116 @@
+using Microsoft.Extensions.Logging;
+using Xunit;
+
+namespace Orleans.Journaling.Tests;
+
+[TestCategory("BVT")]
+public class DurableValueTests : StateMachineTestBase
+{
+    [Fact]
+    public async Task DurableValue_BasicOperations_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var durableValue = new DurableValue<string>("testValue", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Set initial value
+        durableValue.Value = "Hello World";
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal("Hello World", durableValue.Value);
+
+        // Act - Update value
+        durableValue.Value = "Updated Value";
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal("Updated Value", durableValue.Value);
+    }
+
+    [Fact]
+    public async Task DurableValue_Persistence_Test()
+    {
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<int>();
+        var durableValue = new DurableValue<int>("counter", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Modify and persist
+        durableValue.Value = 42;
+        await sut.Manager.WriteStateAsync(CancellationToken.None);
+
+        // Create a new manager with the same storage
+        var sut2 = CreateTestSystem(storage: sut.Storage);
+        var durableValue2 = new DurableValue<int>("counter", sut2.Manager, codec, SessionPool);
+        await sut2.Lifecycle.OnStart();
+
+        // Assert - Value should be recovered
+        Assert.Equal(42, durableValue2.Value);
+    }
+
+    [Fact]
+    public async Task DurableValue_NullValue_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<string>();
+        var durableValue = new DurableValue<string>("nullableValue", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Set to null
+        durableValue.Value = null;
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Null(durableValue.Value);
+
+        // Act - Update to non-null
+        durableValue.Value = "Not null anymore";
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal("Not null anymore", durableValue.Value);
+
+        // Act - Update back to null
+        durableValue.Value = null;
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Null(durableValue.Value);
+    }
+
+    [Fact]
+    public async Task DurableValue_ComplexType_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var manager = sut.Manager;
+        var codec = CodecProvider.GetCodec<TestPerson>();
+        var durableValue = new DurableValue<TestPerson>("person", manager, codec, SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act
+        var person = new TestPerson { Id = 1, Name = "John Doe", Age = 30 };
+        durableValue.Value = person;
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.NotNull(durableValue.Value);
+        Assert.Equal(1, durableValue.Value.Id);
+        Assert.Equal("John Doe", durableValue.Value.Name);
+        Assert.Equal(30, durableValue.Value.Age);
+
+        // Act - Update property
+        durableValue.Value.Age = 31;
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert
+        Assert.Equal(31, durableValue.Value.Age);
+    }
+}
diff --git a/test/Orleans.Journaling.Tests/Grains.cs b/test/Orleans.Journaling.Tests/Grains.cs
new file mode 100644
index 00000000000..b088c00d652
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/Grains.cs
@@ -0,0 +1,61 @@
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Orleans.Journaling.Tests;
+
+[GenerateSerializer]
+public sealed record TestDurableGrainState(string Name, int Counter);
+
+public class TestDurableGrain(
+    [FromKeyedServices("state")] IPersistentState<TestDurableGrainState> state) : DurableGrain, ITestDurableGrain
+{
+    private readonly Guid _activationId = Guid.NewGuid();
+    public Task<string> GetName() => Task.FromResult(state.State.Name);
+    public Task<int> GetCounter() => Task.FromResult(state.State.Counter);
+
+    public async Task SetTestValues(string name, int counter)
+    {
+        state.State = new(name, counter);
+        await WriteStateAsync();
+    }
+
+    public Task<Guid> GetActivationId() => Task.FromResult(_activationId);
+}
+
+public class TestDurableGrainWithComplexState(
+    [FromKeyedServices("person")] IDurableValue<TestPerson> person,
+    [FromKeyedServices("list")] IDurableList<string> list) : DurableGrain, ITestDurableGrainWithComplexState
+{
+    private readonly Guid _activationId = Guid.NewGuid();
+    private readonly IDurableValue<TestPerson> _person = person;
+    private readonly IDurableList<string> _list = list;
+
+    public Task<TestPerson> GetPerson() => Task.FromResult(_person.Value ?? new TestPerson());
+    public Task<IReadOnlyList<string>> GetItems() => Task.FromResult<IReadOnlyList<string>>(_list.AsReadOnly());
+
+    public async Task SetTestValues(TestPerson person, List<string> items)
+    {
+        _person.Value = person;
+        _list.Clear();
+        _list.AddRange(items);
+        await WriteStateAsync();
+    }
+
+    public Task<Guid> GetActivationId() => Task.FromResult(_activationId);
+}
+
+public interface ITestDurableGrain : IGrainWithGuidKey
+{
+    Task<Guid> GetActivationId();
+    Task SetTestValues(string name, int counter);
+    Task<string> GetName();
+    Task<int> GetCounter();
+}
+
+public interface ITestDurableGrainWithComplexState : IGrainWithGuidKey
+{
+    Task<Guid> GetActivationId();
+    Task SetTestValues(TestPerson person, List<string> items);
+    Task<TestPerson> GetPerson();
+    Task<IReadOnlyList<string>> GetItems();
+}
+
diff --git a/test/Orleans.Journaling.Tests/ITestDurableGrainInterface.cs b/test/Orleans.Journaling.Tests/ITestDurableGrainInterface.cs
new file mode 100644
index 00000000000..83e4cc4ea8d
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/ITestDurableGrainInterface.cs
@@ -0,0 +1,11 @@
+namespace Orleans.Journaling.Tests;
+
+/// <summary>
+/// Interface for the test durable grain
+/// </summary>
+public interface ITestDurableGrainInterface : IGrainWithGuidKey
+{
+    Task<Guid> GetActivationId();
+    Task SetValues(string name, int counter);
+    Task<(string Name, int Counter)> GetValues();
+}
\ No newline at end of file
diff --git a/test/Orleans.Journaling.Tests/ITestMultiCollectionGrain.cs b/test/Orleans.Journaling.Tests/ITestMultiCollectionGrain.cs
new file mode 100644
index 00000000000..79e3c95a9be
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/ITestMultiCollectionGrain.cs
@@ -0,0 +1,33 @@
+namespace Orleans.Journaling.Tests;
+
+/// <summary>
+/// Interface for the test multi-collection grain
+/// </summary>
+public interface ITestMultiCollectionGrain : IGrainWithGuidKey
+{
+    Task<Guid> GetActivationId();
+
+    // Dictionary operations
+    Task AddToDictionary(string key, int value);
+    Task RemoveFromDictionary(string key);
+    Task<int> GetDictionaryValue(string key);
+    Task<int> GetDictionaryCount();
+
+    // List operations
+    Task AddToList(string item);
+    Task RemoveListItemAt(int index);
+    Task<string> GetListItem(int index);
+    Task<int> GetListCount();
+
+    // Queue operations
+    Task AddToQueue(int item);
+    Task<int> DequeueItem();
+    Task<int> PeekQueueItem();
+    Task<int> GetQueueCount();
+
+    // Set operations
+    Task AddToSet(string item);
+    Task RemoveFromSet(string item);
+    Task<bool> ContainsSetItem(string item);
+    Task<int> GetSetCount();
+}
\ No newline at end of file
diff --git a/test/Orleans.Journaling.Tests/IntegrationTestFixture.cs b/test/Orleans.Journaling.Tests/IntegrationTestFixture.cs
new file mode 100644
index 00000000000..41b00d1eda7
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/IntegrationTestFixture.cs
@@ -0,0 +1,45 @@
+using System.Collections.Concurrent;
+using Microsoft.Extensions.DependencyInjection;
+using Orleans.TestingHost;
+using Xunit;
+
+namespace Orleans.Journaling.Tests;
+
+/// <summary>
+/// Base class for journaling tests with common setup using InProcessTestCluster
+/// </summary>
+public class IntegrationTestFixture : IAsyncLifetime
+{
+    public InProcessTestCluster Cluster { get; }
+    public IClusterClient Client => Cluster.Client;
+
+    public IntegrationTestFixture()
+    {
+        var builder = new InProcessTestClusterBuilder();
+        var storageProvider = new VolatileStateMachineStorageProvider();
+        builder.ConfigureSilo((options, siloBuilder) =>
+        {
+            siloBuilder.AddStateMachineStorage();
+            siloBuilder.Services.AddSingleton(storageProvider);
+        });
+        ConfigureTestCluster(builder);
+        Cluster = builder.Build();
+    }
+
+    protected virtual void ConfigureTestCluster(InProcessTestClusterBuilder builder)
{ + } + + public virtual async Task InitializeAsync() + { + await Cluster.DeployAsync(); + } + + public virtual async Task DisposeAsync() + { + if (Cluster != null) + { + await Cluster.DisposeAsync(); + } + } +} diff --git a/test/Orleans.Journaling.Tests/JournalingAzureStorageTestConfiguration.cs b/test/Orleans.Journaling.Tests/JournalingAzureStorageTestConfiguration.cs new file mode 100644 index 00000000000..2f6abdd8904 --- /dev/null +++ b/test/Orleans.Journaling.Tests/JournalingAzureStorageTestConfiguration.cs @@ -0,0 +1,33 @@ +using TestExtensions; +using Xunit; + +namespace Orleans.Journaling.Tests; + +internal static class JournalingAzureStorageTestConfiguration +{ + public static void CheckPreconditionsOrThrow() + { + if (TestDefaultConfiguration.UseAadAuthentication) + { + Skip.If(string.IsNullOrEmpty(TestDefaultConfiguration.DataBlobUri.ToString()), "DataBlobUri is not set. Skipping test."); + } + else + { + Skip.If(string.IsNullOrEmpty(TestDefaultConfiguration.DataConnectionString), "DataConnectionString is not set. Skipping test."); + } + } + + public static AzureAppendBlobStateMachineStorageOptions ConfigureTestDefaults(this AzureAppendBlobStateMachineStorageOptions options) + { + if (TestDefaultConfiguration.UseAadAuthentication) + { + options.BlobServiceClient = new(TestDefaultConfiguration.DataBlobUri, TestDefaultConfiguration.TokenCredential); + } + else + { + options.BlobServiceClient = new(TestDefaultConfiguration.DataConnectionString); + } + + return options; + } +} diff --git a/test/Orleans.Journaling.Tests/LogSegmentTests.cs b/test/Orleans.Journaling.Tests/LogSegmentTests.cs new file mode 100644 index 00000000000..2482e65f32d --- /dev/null +++ b/test/Orleans.Journaling.Tests/LogSegmentTests.cs @@ -0,0 +1,391 @@ +using Azure.Storage.Blobs; +using Azure.Storage.Blobs.Specialized; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Orleans.Configuration.Internal; +using Orleans.Serialization; +using Orleans.Serialization.Serializers; +using Orleans.Serialization.Session; +using TestExtensions; +using Xunit; + +namespace Orleans.Journaling.Tests; + +[TestCategory("AzureStorage"), TestCategory("Functional")] +public sealed class AzureStorageLogSegmentTests : LogSegmentTests +{ + public AzureStorageLogSegmentTests() + { + JournalingAzureStorageTestConfiguration.CheckPreconditionsOrThrow(); + } + + protected override void ConfigureServices(IServiceCollection services) + { + services.Configure(options => JournalingAzureStorageTestConfiguration.ConfigureTestDefaults(options)); + services.AddSingleton(); + services.AddFromExisting(); + services.AddFromExisting, AzureAppendBlobStateMachineStorageProvider>(); + } +} + +public sealed class InMemoryLogSegmentTests : LogSegmentTests +{ + protected override void ConfigureServices(IServiceCollection services) + { + services.AddSingleton(); + } +} + +/// +/// Base class for testing implementations. +/// Derived classes must implement to register the specific storage provider. +/// This class provides a suite of common tests for validating the behavior of +/// against different storage backends. +/// +public abstract class LogSegmentTests : IAsyncLifetime +{ + private IServiceProvider _serviceProvider = null!; + private SiloLifecycleSubject? 
_siloLifecycle; + private IStateMachineStorageProvider _storageProvider = null!; + + public virtual async Task InitializeAsync() + { + var services = new ServiceCollection(); + services.AddSerializer(); + services.AddLogging(); + ConfigureServices(services); + _serviceProvider = services.BuildServiceProvider(); + _siloLifecycle = new SiloLifecycleSubject(_serviceProvider.GetRequiredService>()); + _storageProvider = _serviceProvider.GetRequiredService(); + var participants = _serviceProvider.GetServices>(); + foreach (var participant in participants) + { + participant.Participate(_siloLifecycle); + } + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60)); + await _siloLifecycle.OnStart(cts.Token); + } + + public async Task DisposeAsync() + { + if (_siloLifecycle is not null) + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60)); + await _siloLifecycle.OnStop(cts.Token); + } + } + + protected abstract void ConfigureServices(IServiceCollection services); + + private (StateMachineManager Manager, DurableList List, IStateMachineStorage Storage) CreateTestComponents(string listName, GrainId grainId) + { + var sessionPool = _serviceProvider.GetRequiredService(); + var codecProvider = _serviceProvider.GetRequiredService(); + var grainContext = new TestGrainContext(grainId); // Use provided GrainId + var storage = _storageProvider.Create(grainContext); + var manager = new StateMachineManager(storage, _serviceProvider.GetRequiredService>(), sessionPool); + var list = new DurableList(listName, manager, codecProvider.GetCodec(), sessionPool); + return (manager, list, storage); + } + + /// + /// Tests basic Add, Update (by index), and RemoveAt operations. + /// + [SkippableFact] + public async Task DurableList_BasicOperations_Test() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var grainId = GrainId.Create("test-grain", $"basicList-{Guid.NewGuid()}"); // Unique ID for this test run + var (manager, list, _) = CreateTestComponents("basicList", grainId); + await manager.InitializeAsync(cts.Token); + + list.Add("one"); + list.Add("two"); + list.Add("three"); + await manager.WriteStateAsync(cts.Token); + + Assert.Equal(3, list.Count); + Assert.Equal("one", list[0]); + Assert.Equal("two", list[1]); + Assert.Equal("three", list[2]); + + list[1] = "updated"; + await manager.WriteStateAsync(cts.Token); + + Assert.Equal("updated", list[1]); + + list.RemoveAt(0); + await manager.WriteStateAsync(cts.Token); + + Assert.Equal(2, list.Count); + Assert.Equal("updated", list[0]); + Assert.Equal("three", list[1]); + } + + /// + /// Tests that list state is correctly persisted and can be recovered by a new instance. 
+ /// + [SkippableFact] + public async Task DurableList_Persistence_Test() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var listName = "persistenceList"; + var grainId = GrainId.Create("test-grain", $"{listName}-{Guid.NewGuid()}"); // Consistent GrainId for recovery + var (manager1, list1, storage) = CreateTestComponents(listName, grainId); + await manager1.InitializeAsync(cts.Token); + + list1.Add("one"); + list1.Add("two"); + list1.Add("three"); + await manager1.WriteStateAsync(cts.Token); + + var sessionPool = _serviceProvider.GetRequiredService(); + var codecProvider = _serviceProvider.GetRequiredService(); + var manager2 = new StateMachineManager(storage, _serviceProvider.GetRequiredService>(), sessionPool); + var list2 = new DurableList(listName, manager2, codecProvider.GetCodec(), sessionPool); + await manager2.InitializeAsync(cts.Token); + + Assert.Equal(3, list2.Count); + Assert.Equal("one", list2[0]); + Assert.Equal("two", list2[1]); + Assert.Equal("three", list2[2]); + } + + /// + /// Tests storing and retrieving complex objects, including updates to mutable properties. + /// + [SkippableFact] + public async Task DurableList_ComplexValues_Test() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var listName = "personList"; + var grainId = GrainId.Create("test-grain", $"{listName}-{Guid.NewGuid()}"); // Consistent GrainId for recovery + var (manager, list, _) = CreateTestComponents(listName, grainId); + await manager.InitializeAsync(cts.Token); + + var person1 = new TestPerson { Id = 1, Name = "John", Age = 30 }; + var person2 = new TestPerson { Id = 2, Name = "Jane", Age = 25 }; + + list.Add(person1); + list.Add(person2); + await manager.WriteStateAsync(cts.Token); + + Assert.Equal(2, list.Count); + Assert.Equal("John", list[0].Name); + Assert.Equal(25, list[1].Age); + + list[0] = list[0] with { Age = 31 }; + await manager.WriteStateAsync(cts.Token); + + // Re-read to confirm persistence of the change + var (manager2, list2, _) = CreateTestComponents(listName, grainId); // Use same GrainId to reload + await manager2.InitializeAsync(cts.Token); + Assert.Equal(31, list2[0].Age); + } + + /// + /// Tests the Clear operation and its persistence. + /// + [SkippableFact] + public async Task DurableList_Clear_Test() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var listName = "clearList"; + var grainId = GrainId.Create("test-grain", $"{listName}-{Guid.NewGuid()}"); // Consistent GrainId for recovery + var (manager, list, _) = CreateTestComponents(listName, grainId); + await manager.InitializeAsync(cts.Token); + + list.Add("one"); + list.Add("two"); + list.Add("three"); + await manager.WriteStateAsync(cts.Token); + + list.Clear(); + await manager.WriteStateAsync(cts.Token); + + Assert.Empty(list); + + // Verify persistence of Clear + var (manager2, list2, _) = CreateTestComponents(listName, grainId); // Use same GrainId to reload + await manager2.InitializeAsync(cts.Token); + Assert.Empty(list2); + } + + /// + /// Tests the Contains method and Remove (by value) operation. 
+ /// + [SkippableFact] + public async Task DurableList_Contains_Test() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var grainId = GrainId.Create("test-grain", $"containsList-{Guid.NewGuid()}"); // Unique ID for this test run + var (manager, list, _) = CreateTestComponents("containsList", grainId); + await manager.InitializeAsync(cts.Token); + + list.Add("one"); + list.Add("two"); + list.Add("three"); + await manager.WriteStateAsync(cts.Token); + + Assert.Contains("two", list); + Assert.DoesNotContain("four", list); + + list.Remove("two"); + await manager.WriteStateAsync(cts.Token); + + Assert.DoesNotContain("two", list); + } + + /// + /// Tests Insert and Remove (by value) operations. + /// + [SkippableFact] + public async Task DurableList_InsertAndRemove_Test() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var grainId = GrainId.Create("test-grain", $"insertList-{Guid.NewGuid()}"); // Unique ID for this test run + var (manager, list, _) = CreateTestComponents("insertList", grainId); + await manager.InitializeAsync(cts.Token); + + list.Add("one"); + list.Add("three"); + await manager.WriteStateAsync(cts.Token); + + list.Insert(1, "two"); + await manager.WriteStateAsync(cts.Token); + + Assert.Equal(3, list.Count); + Assert.Equal("one", list[0]); + Assert.Equal("two", list[1]); + Assert.Equal("three", list[2]); + + bool removed = list.Remove("two"); + await manager.WriteStateAsync(cts.Token); + + Assert.True(removed); + Assert.Equal(2, list.Count); + Assert.Equal("one", list[0]); + Assert.Equal("three", list[1]); + } + + /// + /// Tests list enumeration using ToList() (which relies on GetEnumerator). + /// + [SkippableFact] + public async Task DurableList_Enumeration_Test() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + var grainId = GrainId.Create("test-grain", $"enumList-{Guid.NewGuid()}"); // Unique ID for this test run + var (manager, list, _) = CreateTestComponents("enumList", grainId); + await manager.InitializeAsync(cts.Token); + + var expectedItems = new List { "one", "two", "three" }; + + foreach (var item in expectedItems) + { + list.Add(item); + } + + await manager.WriteStateAsync(cts.Token); + + var actualItems = list.ToList(); + + Assert.Equal(expectedItems, actualItems); + } + + /// + /// Tests behavior with a larger number of operations (add, update) and multiple writes, + /// potentially triggering snapshotting behavior in the storage provider. Also tests recovery. 
+ /// + [SkippableFact] + public async Task DurableList_LargeNumberOfOperations_And_Snapshot_Test() + { + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60)); // Increased timeout + var listName = "largeList"; + var grainId = GrainId.Create("test-grain", $"{listName}-{Guid.NewGuid()}"); // Consistent GrainId for recovery + var (manager, list, storage) = CreateTestComponents(listName, grainId); + await manager.InitializeAsync(cts.Token); + + const int itemCount = 100; // Reduced for faster testing, increase if needed + for (int i = 0; i < itemCount; i++) + { + list.Add(i); + } + + // Write multiple times to potentially trigger snapshotting + for (int j = 0; j < 5; ++j) + { + await manager.WriteStateAsync(cts.Token); + } + + Assert.Equal(itemCount, list.Count); + + for (int i = 0; i < itemCount; i += 2) + { + list[i] = list[i] * 2; + } + + // Write multiple times again + for (int j = 0; j < 5; ++j) + { + await manager.WriteStateAsync(cts.Token); + } + + for (int i = 0; i < itemCount; i++) + { + if (i % 2 == 0) Assert.Equal(i * 2, list[i]); + else Assert.Equal(i, list[i]); + } + + // Test recovery (potentially from snapshot) + var sessionPool = _serviceProvider.GetRequiredService(); + var codecProvider = _serviceProvider.GetRequiredService(); + var manager2 = new StateMachineManager(storage, _serviceProvider.GetRequiredService>(), sessionPool); // Reuses the storage object linked via grainId + var list2 = new DurableList(listName, manager2, codecProvider.GetCodec(), sessionPool); + await manager2.InitializeAsync(cts.Token); + + Assert.Equal(itemCount, list2.Count); + for (int i = 0; i < itemCount; i++) + { + if (i % 2 == 0) Assert.Equal(i * 2, list2[i]); + else Assert.Equal(i, list2[i]); + } + } + + // Keep TestGrainContext and add TestPerson record needed for one of the tests + [GenerateSerializer, Immutable] + internal sealed record TestPerson + { + [Id(0)] + public int Id { get; init; } + [Id(1)] + public string Name { get; init; } = ""; + [Id(2)] + public int Age { get; init; } + } + + internal sealed class TestGrainContext(GrainId grainId) : IGrainContext + { + public GrainReference GrainReference => throw new NotImplementedException(); + public GrainId GrainId => grainId; + public object? GrainInstance => throw new NotImplementedException(); + public ActivationId ActivationId => throw new NotImplementedException(); + public GrainAddress Address => throw new NotImplementedException(); + public IServiceProvider ActivationServices => throw new NotImplementedException(); + public IGrainLifecycle ObservableLifecycle => throw new NotImplementedException(); + public IWorkItemScheduler Scheduler => throw new NotImplementedException(); + public Task Deactivated => throw new NotImplementedException(); + + public void Activate(Dictionary? requestContext, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + public void Deactivate(DeactivationReason deactivationReason, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + public bool Equals(IGrainContext? other) => throw new NotImplementedException(); + public TComponent? GetComponent() where TComponent : class => throw new NotImplementedException(); + public TTarget? GetTarget() where TTarget : class => throw new NotImplementedException(); + public void Migrate(Dictionary? 
requestContext, CancellationToken cancellationToken = default) => throw new NotImplementedException(); + public void ReceiveMessage(object message) => throw new NotImplementedException(); + public void Rehydrate(IRehydrationContext context) => throw new NotImplementedException(); + public void SetComponent(TComponent? value) where TComponent : class => throw new NotImplementedException(); + } +} diff --git a/test/Orleans.Journaling.Tests/Orleans.Journaling.Tests.csproj b/test/Orleans.Journaling.Tests/Orleans.Journaling.Tests.csproj new file mode 100644 index 00000000000..db5f347c593 --- /dev/null +++ b/test/Orleans.Journaling.Tests/Orleans.Journaling.Tests.csproj @@ -0,0 +1,30 @@ + + + + enable + enable + true + $(TestTargetFrameworks) + + + + $(NoWarn);ORLEANSEXP005 + + + + + + + + + + + + + + + + + + + diff --git a/test/Orleans.Journaling.Tests/StateMachineManagerTests.cs b/test/Orleans.Journaling.Tests/StateMachineManagerTests.cs new file mode 100644 index 00000000000..7d5e47b2a21 --- /dev/null +++ b/test/Orleans.Journaling.Tests/StateMachineManagerTests.cs @@ -0,0 +1,222 @@ +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Orleans.Journaling.Tests; + +[TestCategory("BVT")] +public class StateMachineManagerTests : StateMachineTestBase +{ + [Fact] + public async Task StateMachineManager_RegisterStateMachine_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + var codec = CodecProvider.GetCodec(); + + // Act - Register state machines + var dictionary = new DurableDictionary("dict1", manager, CodecProvider.GetCodec(), codec, SessionPool); + var list = new DurableList("list1", manager, CodecProvider.GetCodec(), SessionPool); + var queue = new DurableQueue("queue1", manager, codec, SessionPool); + await sut.Lifecycle.OnStart(); + + // Add some data + dictionary.Add("key1", 1); + list.Add("item1"); + queue.Enqueue(42); + + // Write state + await manager.WriteStateAsync(CancellationToken.None); + + // Assert - Data is correctly stored + Assert.Equal(1, dictionary["key1"]); + Assert.Equal("item1", list[0]); + Assert.Equal(42, queue.Peek()); + } + + [Fact] + public async Task StateMachineManager_StateRecovery_Test() + { + // Arrange + var sut = CreateTestSystem(); + + // Create and populate state machines + var dictionary = new DurableDictionary("dict1", sut.Manager, CodecProvider.GetCodec(), CodecProvider.GetCodec(), SessionPool); + var list = new DurableList("list1", sut.Manager, CodecProvider.GetCodec(), SessionPool); + await sut.Lifecycle.OnStart(); + + dictionary.Add("key1", 1); + dictionary.Add("key2", 2); + list.Add("item1"); + list.Add("item2"); + + await sut.Manager.WriteStateAsync(CancellationToken.None); + + // Act - Create new manager with same storage + var sut2 = CreateTestSystem(storage: sut.Storage); + var recoveredDict = new DurableDictionary("dict1", sut2.Manager, CodecProvider.GetCodec(), CodecProvider.GetCodec(), SessionPool); + var recoveredList = new DurableList("list1", sut2.Manager, CodecProvider.GetCodec(), SessionPool); + await sut2.Lifecycle.OnStart(); + + // Assert - State should be recovered + Assert.Equal(2, recoveredDict.Count); + Assert.Equal(1, recoveredDict["key1"]); + Assert.Equal(2, recoveredDict["key2"]); + + Assert.Equal(2, recoveredList.Count); + Assert.Equal("item1", recoveredList[0]); + Assert.Equal("item2", recoveredList[1]); + } + + [Fact] + public async Task StateMachineManager_MultipleWriteStates_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + var dictionary = new 
DurableDictionary("dict1", sut.Manager, CodecProvider.GetCodec(), CodecProvider.GetCodec(), SessionPool); + await sut.Lifecycle.OnStart(); + + // Act - Multiple operations with WriteState in between + dictionary.Add("key1", 1); + await manager.WriteStateAsync(CancellationToken.None); + + dictionary.Add("key2", 2); + await manager.WriteStateAsync(CancellationToken.None); + + dictionary["key1"] = 10; + await manager.WriteStateAsync(CancellationToken.None); + + dictionary.Remove("key2"); + await manager.WriteStateAsync(CancellationToken.None); + + // Assert - Final state is correct + Assert.Single(dictionary); + Assert.Equal(10, dictionary["key1"]); + Assert.False(dictionary.ContainsKey("key2")); + + // Create new manager to verify recovery + var sut2 = CreateTestSystem(storage: sut.Storage); + var recoveredDict = new DurableDictionary("dict1", sut2.Manager, CodecProvider.GetCodec(), CodecProvider.GetCodec(), SessionPool); + await sut2.Lifecycle.OnStart(); + + // Assert - Recovery should have final state + Assert.Single(recoveredDict); + Assert.Equal(10, recoveredDict["key1"]); + Assert.False(recoveredDict.ContainsKey("key2")); + } + + [Fact] + public async Task StateMachineManager_MultipleStateMachines_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + + // Create multiple state machines with different types + var intDict = new DurableDictionary("intDict", manager, CodecProvider.GetCodec(), CodecProvider.GetCodec(), SessionPool); + var stringList = new DurableList("stringList", manager, CodecProvider.GetCodec(), SessionPool); + var personValue = new DurableValue("personValue", manager, CodecProvider.GetCodec(), SessionPool); + await sut.Lifecycle.OnStart(); + + // Act - Populate all state machines + intDict.Add(1, "one"); + intDict.Add(2, "two"); + + stringList.Add("item1"); + stringList.Add("item2"); + + personValue.Value = new TestPerson { Id = 100, Name = "Test Person", Age = 30 }; + + await manager.WriteStateAsync(CancellationToken.None); + + // Assert - All should have correct values + Assert.Equal(2, intDict.Count); + Assert.Equal("one", intDict[1]); + + Assert.Equal(2, stringList.Count); + Assert.Equal("item1", stringList[0]); + + Assert.NotNull(personValue.Value); + Assert.Equal(100, personValue.Value.Id); + Assert.Equal("Test Person", personValue.Value.Name); + + // Create new manager to verify recovery of multiple state machines + var sut2 = CreateTestSystem(storage: sut.Storage); + var recoveredIntDict = new DurableDictionary("intDict", sut2.Manager, CodecProvider.GetCodec(), CodecProvider.GetCodec(), SessionPool); + var recoveredStringList = new DurableList("stringList", sut2.Manager, CodecProvider.GetCodec(), SessionPool); + var recoveredPersonValue = new DurableValue("personValue", sut2.Manager, CodecProvider.GetCodec(), SessionPool); + await sut2.Lifecycle.OnStart(); + + // Assert - All should be recovered with correct values + Assert.Equal(2, recoveredIntDict.Count); + Assert.Equal("one", recoveredIntDict[1]); + + Assert.Equal(2, recoveredStringList.Count); + Assert.Equal("item1", recoveredStringList[0]); + + Assert.NotNull(recoveredPersonValue.Value); + Assert.Equal(100, recoveredPersonValue.Value.Id); + Assert.Equal("Test Person", recoveredPersonValue.Value.Name); + } + + [Fact] + public async Task StateMachineManager_Concurrency_Test() + { + // Arrange + var sut = CreateTestSystem(); + var manager = sut.Manager; + var dict1 = new DurableDictionary("dict1", manager, CodecProvider.GetCodec(), CodecProvider.GetCodec(), SessionPool); + var 
+        var dict2 = new DurableDictionary<string, int>("dict2", manager, CodecProvider.GetCodec<string>(), CodecProvider.GetCodec<int>(), SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Simulate concurrent operations on different state machines
+        dict1.Add("key1", 1);
+        dict2.Add("key1", 100);
+
+        dict1.Add("key2", 2);
+        dict2.Add("key2", 200);
+
+        await manager.WriteStateAsync(CancellationToken.None);
+
+        // Assert - Both state machines should have their correct values
+        Assert.Equal(2, dict1.Count);
+        Assert.Equal(2, dict2.Count);
+
+        Assert.Equal(1, dict1["key1"]);
+        Assert.Equal(100, dict2["key1"]);
+
+        Assert.Equal(2, dict1["key2"]);
+        Assert.Equal(200, dict2["key2"]);
+    }
+
+    [Fact]
+    public async Task StateMachineManager_LargeStateRecovery_Test()
+    {
+        // Arrange
+        var sut = CreateTestSystem();
+        var largeDict = new DurableDictionary<int, string>("largeDict", sut.Manager, CodecProvider.GetCodec<int>(), CodecProvider.GetCodec<string>(), SessionPool);
+        await sut.Lifecycle.OnStart();
+
+        // Act - Add many items
+        const int itemCount = 1000;
+        for (int i = 0; i < itemCount; i++)
+        {
+            largeDict.Add(i, $"Value {i}");
+        }
+
+        await sut.Manager.WriteStateAsync(CancellationToken.None);
+
+        // Create new manager for recovery
+        var sut2 = CreateTestSystem(storage: sut.Storage);
+        var recoveredDict = new DurableDictionary<int, string>("largeDict", sut2.Manager, CodecProvider.GetCodec<int>(), CodecProvider.GetCodec<string>(), SessionPool);
+        await sut2.Lifecycle.OnStart();
+
+        // Assert - All items should be recovered
+        Assert.Equal(itemCount, recoveredDict.Count);
+        for (int i = 0; i < itemCount; i++)
+        {
+            Assert.Equal($"Value {i}", recoveredDict[i]);
+        }
+    }
+}
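Taken together, these tests pin down the journaling contract: every WriteStateAsync appends to the log, and a fresh manager replays that log on OnStart. A minimal sketch of the cycle, reusing only the CreateTestSystem helper and constructor signatures shown above (the "hits"/"total" names are illustrative):

    var first = CreateTestSystem();
    var dict = new DurableDictionary<string, int>("hits", first.Manager, CodecProvider.GetCodec<string>(), CodecProvider.GetCodec<int>(), SessionPool);
    await first.Lifecycle.OnStart();                              // replays an empty journal
    dict["total"] = 1;
    await first.Manager.WriteStateAsync(CancellationToken.None);  // appends one log entry

    var second = CreateTestSystem(storage: first.Storage);        // same log, new manager
    var replayed = new DurableDictionary<string, int>("hits", second.Manager, CodecProvider.GetCodec<string>(), CodecProvider.GetCodec<int>(), SessionPool);
    await second.Lifecycle.OnStart();                             // replay restores replayed["total"] == 1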
diff --git a/test/Orleans.Journaling.Tests/StateMachineTestBase.cs b/test/Orleans.Journaling.Tests/StateMachineTestBase.cs
new file mode 100644
index 00000000000..78adb0435ac
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/StateMachineTestBase.cs
@@ -0,0 +1,66 @@
+using System.Collections.Immutable;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
+using Orleans.Serialization;
+using Orleans.Serialization.Serializers;
+using Orleans.Serialization.Session;
+
+namespace Orleans.Journaling.Tests;
+
+/// <summary>
+/// Base class for journaling tests with common setup
+/// </summary>
+public abstract class StateMachineTestBase
+{
+    protected readonly ServiceProvider ServiceProvider;
+    protected readonly SerializerSessionPool SessionPool;
+    protected readonly ICodecProvider CodecProvider;
+    protected readonly ILoggerFactory LoggerFactory;
+
+    protected StateMachineTestBase()
+    {
+        var services = new ServiceCollection();
+        services.AddSerializer();
+        services.AddLogging(builder => builder.AddConsole());
+
+        ServiceProvider = services.BuildServiceProvider();
+        SessionPool = ServiceProvider.GetRequiredService<SerializerSessionPool>();
+        CodecProvider = ServiceProvider.GetRequiredService<ICodecProvider>();
+        LoggerFactory = ServiceProvider.GetRequiredService<ILoggerFactory>();
+    }
+
+    /// <summary>
+    /// Creates an in-memory storage for testing
+    /// </summary>
+    protected virtual IStateMachineStorage CreateStorage()
+    {
+        return new VolatileStateMachineStorage();
+    }
+
+    /// <summary>
+    /// Creates a state machine manager with in-memory storage
+    /// </summary>
+    internal (IStateMachineManager Manager, IStateMachineStorage Storage, ILifecycleSubject Lifecycle) CreateTestSystem(IStateMachineStorage? storage = null)
+    {
+        storage ??= CreateStorage();
+        var logger = LoggerFactory.CreateLogger<StateMachineManager>();
+        var manager = new StateMachineManager(storage, logger, SessionPool);
+        var lifecycle = new GrainLifecycle(LoggerFactory.CreateLogger<GrainLifecycle>());
+        (manager as ILifecycleParticipant<IGrainLifecycle>)?.Participate(lifecycle);
+        return (manager, storage, lifecycle);
+    }
+
+    private class GrainLifecycle(ILogger logger) : LifecycleSubject(logger), IGrainLifecycle
+    {
+        private static readonly ImmutableDictionary<int, string> StageNames = GetStageNames(typeof(GrainLifecycleStage));
+
+        public void AddMigrationParticipant(IGrainMigrationParticipant participant) { }
+        public void RemoveMigrationParticipant(IGrainMigrationParticipant participant) { }
+
+        protected override string GetStageName(int stage)
+        {
+            if (StageNames.TryGetValue(stage, out var result)) return result;
+            return base.GetStageName(stage);
+        }
+    }
+}
diff --git a/test/Orleans.Journaling.Tests/TestDurableGrain.cs b/test/Orleans.Journaling.Tests/TestDurableGrain.cs
new file mode 100644
index 00000000000..24ce4352790
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/TestDurableGrain.cs
@@ -0,0 +1,23 @@
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Orleans.Journaling.Tests;
+
+public class DurableValueTestGrain(
+    [FromKeyedServices("name")] IDurableValue<string> name,
+    [FromKeyedServices("counter")] IDurableValue<int> counter) : DurableGrain, ITestDurableGrainInterface
+{
+    private readonly Guid _activationId = Guid.NewGuid();
+    private readonly IDurableValue<string> _name = name;
+    private readonly IDurableValue<int> _counter = counter;
+
+    public Task SetValues(string name, int counter)
+    {
+        _name.Value = name;
+        _counter.Value = counter;
+        return WriteStateAsync().AsTask();
+    }
+
+    public Task<(string Name, int Counter)> GetValues() => Task.FromResult((_name.Value!, _counter.Value!));
+
+    public Task<Guid> GetActivationId() => Task.FromResult(_activationId);
+}
\ No newline at end of file
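DurableValueTestGrain above is the grain-facing shape of the same pattern: keyed IDurableValue<T> services injected into a DurableGrain. A hypothetical cluster-side caller might look like the following; the TestCluster host and the grain interface's key type are assumptions, not part of this diff:

    // Assumes ITestDurableGrainInterface : IGrainWithGuidKey and an in-process Orleans TestCluster named "cluster".
    var grain = cluster.GrainFactory.GetGrain<ITestDurableGrainInterface>(Guid.NewGuid());
    await grain.SetValues("example", 1);            // persists both durable values via WriteStateAsync
    var (name, counter) = await grain.GetValues();  // ("example", 1), even after the grain is reactivated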
diff --git a/test/Orleans.Journaling.Tests/TestMultiCollectionGrain.cs b/test/Orleans.Journaling.Tests/TestMultiCollectionGrain.cs
new file mode 100644
index 00000000000..cc68385b7c1
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/TestMultiCollectionGrain.cs
@@ -0,0 +1,106 @@
+using Microsoft.Extensions.DependencyInjection;
+
+namespace Orleans.Journaling.Tests;
+
+public class TestMultiCollectionGrain(
+    [FromKeyedServices("dictionary")] IDurableDictionary<string, int> dictionary,
+    [FromKeyedServices("list")] IDurableList<string> list,
+    [FromKeyedServices("queue")] IDurableQueue<int> queue,
+    [FromKeyedServices("set")] IDurableSet<string> set) : DurableGrain, ITestMultiCollectionGrain
+{
+    private readonly Guid _activationId = Guid.NewGuid();
+
+    // Dictionary operations
+    public async Task AddToDictionary(string key, int value)
+    {
+        dictionary[key] = value;
+        await WriteStateAsync();
+    }
+
+    public async Task RemoveFromDictionary(string key)
+    {
+        dictionary.Remove(key);
+        await WriteStateAsync();
+    }
+
+    public async Task<int> GetDictionaryValue(string key)
+    {
+        return await Task.FromResult(dictionary[key]);
+    }
+
+    public async Task<int> GetDictionaryCount()
+    {
+        return await Task.FromResult(dictionary.Count);
+    }
+
+    // List operations
+    public async Task AddToList(string item)
+    {
+        list.Add(item);
+        await WriteStateAsync();
+    }
+
+    public async Task RemoveListItemAt(int index)
+    {
+        list.RemoveAt(index);
+        await WriteStateAsync();
+    }
+
+    public async Task<string> GetListItem(int index)
+    {
+        return await Task.FromResult(list[index]);
+    }
+
+    public async Task<int> GetListCount()
+    {
+        return await Task.FromResult(list.Count);
+    }
+
+    // Queue operations
+    public async Task AddToQueue(int item)
+    {
+        queue.Enqueue(item);
+        await WriteStateAsync();
+    }
+
+    public async Task<int> DequeueItem()
+    {
+        var item = queue.Dequeue();
+        return await Task.FromResult(item);
+    }
+
+    public async Task<int> PeekQueueItem()
+    {
+        return await Task.FromResult(queue.Peek());
+    }
+
+    public async Task<int> GetQueueCount()
+    {
+        return await Task.FromResult(queue.Count);
+    }
+
+    // Set operations
+    public async Task AddToSet(string item)
+    {
+        set.Add(item);
+        await WriteStateAsync();
+    }
+
+    public async Task RemoveFromSet(string item)
+    {
+        set.Remove(item);
+        await WriteStateAsync();
+    }
+
+    public async Task<bool> ContainsSetItem(string item)
+    {
+        return await Task.FromResult(set.Contains(item));
+    }
+
+    public async Task<int> GetSetCount()
+    {
+        return await Task.FromResult(set.Count);
+    }
+
+    public Task<Guid> GetActivationId() => Task.FromResult(_activationId);
+}
diff --git a/test/Orleans.Journaling.Tests/TestPerson.cs b/test/Orleans.Journaling.Tests/TestPerson.cs
new file mode 100644
index 00000000000..e0588bf272d
--- /dev/null
+++ b/test/Orleans.Journaling.Tests/TestPerson.cs
@@ -0,0 +1,15 @@
+namespace Orleans.Journaling.Tests;
+
+/// <summary>
+/// Test class used for complex object serialization testing
+/// </summary>
+[GenerateSerializer]
+public record class TestPerson
+{
+    [Id(0)]
+    public int Id { get; set; }
+    [Id(1)]
+    public string? Name { get; set; }
+    [Id(2)]
+    public int Age { get; set; }
+}
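TestPerson is a plain [GenerateSerializer] model; nothing journaling-specific is needed for it to round-trip. A minimal stand-alone sketch using the public Orleans.Serialization API (the service setup mirrors StateMachineTestBase above; illustrative only):

    var services = new ServiceCollection().AddSerializer().BuildServiceProvider();
    var serializer = services.GetRequiredService<Serializer>();
    var bytes = serializer.SerializeToArray(new TestPerson { Id = 100, Name = "Test Person", Age = 30 });
    var copy = serializer.Deserialize<TestPerson>(bytes);   // copy.Name == "Test Person"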
diff --git a/test/Orleans.Serialization.UnitTests/ArcBufferWriterTests.cs b/test/Orleans.Serialization.UnitTests/ArcBufferWriterTests.cs
new file mode 100644
index 00000000000..1b04fbb255e
--- /dev/null
+++ b/test/Orleans.Serialization.UnitTests/ArcBufferWriterTests.cs
@@ -0,0 +1,915 @@
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Linq;
+using Orleans.Serialization.Buffers;
+using Xunit;
+
+namespace Orleans.Serialization.UnitTests;
+
+[Trait("Category", "BVT")]
+public class ArcBufferWriterTests
+{
+    private const int PageSize = ArcBufferWriter.MinimumPageSize;
+#if NET6_0_OR_GREATER
+    private readonly Random _random = Random.Shared;
+#else
+    private readonly Random _random = new Random();
+#endif
+
+    /// <summary>
+    /// Verifies that writing data larger than a single page results in correct multi-page buffer management and correct data retrieval.
+    /// </summary>
+    [Fact]
+    public void MultiPageBuffer_CorrectlyHandlesLargeWritesAndRetrieval()
+    {
+        using var bufferWriter = new ArcBufferWriter();
+        var randomData = new byte[PageSize * 3];
+        _random.NextBytes(randomData);
+        int[] writeSizes = [1, 52, 125, 4096];
+        var i = 0;
+        while (bufferWriter.Length < randomData.Length)
+        {
+            var writeSize = Math.Min(randomData.Length - bufferWriter.Length, writeSizes[i++ % writeSizes.Length]);
+            bufferWriter.Write(randomData.AsSpan(bufferWriter.Length, writeSize));
+        }
+
+        {
+            using var wholeBuffer = bufferWriter.PeekSlice(randomData.Length);
+            Assert.Equal(3, wholeBuffer.Pages.Count());
+            Assert.Equal(3, wholeBuffer.PageSegments.Count());
+            Assert.Equal(3, wholeBuffer.MemorySegments.Count());
+            Assert.Equal(3, wholeBuffer.ArraySegments.Count());
+            Assert.Equal(randomData, wholeBuffer.AsReadOnlySequence().ToArray());
+
+            {
+                using var newWriter = new ArcBufferWriter();
+                newWriter.Write(wholeBuffer.AsReadOnlySequence());
+
+                Span<byte> headerBytes = stackalloc byte[8];
+                var result = newWriter.Peek(in headerBytes);
+                Assert.True(result.Length >= headerBytes.Length);
+                Assert.Equal(randomData[0..headerBytes.Length], result[..headerBytes.Length].ToArray());
+                var copiedData = new byte[newWriter.Length];
+                newWriter.Peek(copiedData);
+                newWriter.AdvanceReader(copiedData.Length);
+                Assert.Equal(0, newWriter.Length);
+                Assert.Equal(randomData, copiedData);
+            }
+
+            var spanCount = 0;
+            foreach (var span in wholeBuffer.SpanSegments)
+            {
+                Assert.Equal(PageSize, span.Length);
+                var spanArray = span.ToArray();
+                Assert.Equal(spanArray, wholeBuffer.ArraySegments.Skip(spanCount).Take(1).Single().ToArray());
+                Assert.Equal(spanArray, wholeBuffer.MemorySegments.Skip(spanCount).Take(1).Single().ToArray());
+                Assert.Equal(spanArray, wholeBuffer.PageSegments.Skip(spanCount).Take(1).Single().Span.ToArray());
+                Assert.Equal(spanArray, wholeBuffer.PageSegments.Skip(spanCount).Take(1).Single().Memory.ToArray());
+                Assert.Equal(spanArray, wholeBuffer.PageSegments.Skip(spanCount).Take(1).Single().ArraySegment.ToArray());
+                Assert.Equal(spanArray, wholeBuffer.AsReadOnlySequence().Slice(spanCount * PageSize, PageSize).ToArray());
+                ++spanCount;
+            }
+
+            Assert.Equal(3, spanCount);
+        }
+
+        Assert.Equal(randomData.Length, bufferWriter.Length);
+
+        {
+            using var peeked = bufferWriter.PeekSlice(3000);
+            using var slice = bufferWriter.ConsumeSlice(3000);
+            var sliceArray = slice.ToArray();
+            Assert.Equal(randomData.AsSpan(0, 3000).ToArray(), sliceArray);
+            Assert.Equal(sliceArray, peeked.ToArray());
+            Assert.Equal(sliceArray, peeked.AsReadOnlySequence().ToArray());
+
+            Assert.Equal(randomData.Length - sliceArray.Length, bufferWriter.Length);
+        }
+
+        {
+            using var peeked = bufferWriter.PeekSlice(3000);
+            using var slice = bufferWriter.ConsumeSlice(3000);
+            var sliceArray = slice.ToArray();
+            Assert.Equal(randomData.AsSpan(3000, 3000).ToArray(), sliceArray);
+            Assert.Equal(sliceArray, peeked.ToArray());
+            Assert.Equal(sliceArray, slice.AsReadOnlySequence().ToArray());
+
+            Assert.Equal(randomData.Length - sliceArray.Length * 2, bufferWriter.Length);
+        }
+
+        Assert.Equal(randomData.Length - 6000, bufferWriter.Length);
+    }
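The test above exercises every read view a slice exposes. For orientation, the produce/consume loop it is built around reduces to a few lines; this sketch uses only the ArcBufferWriter members already shown in the test and is illustrative, not part of the diff:

    using var writer = new ArcBufferWriter();
    writer.Write(new byte[] { 1, 2, 3 });                         // produce
    using var slice = writer.ConsumeSlice(3);                     // consume: advances the read head
    ReadOnlySequence<byte> sequence = slice.AsReadOnlySequence(); // zero-copy view over the same pages
    byte[] copy = slice.ToArray();                                // copying view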
+
+    /// <summary>
+    /// Verifies that page reference counts and versions are managed correctly as slices are consumed and disposed.
+    /// </summary>
+    [Fact]
+    public void PageBufferManagement_TracksReferenceCountsAndVersions()
+    {
+        var bufferWriter = new ArcBufferWriter();
+        var randomData = new byte[PageSize * 12];
+        _random.NextBytes(randomData);
+        bufferWriter.Write(randomData);
+
+        var peeked = bufferWriter.PeekSlice(randomData.Length);
+        var pages = peeked.Pages.ToList();
+        peeked.Dispose();
+
+        var expected = pages.Select((p, i) => (p.Version, p.ReferenceCount)).ToList();
+        CheckPages(pages, expected);
+
+        var slice = bufferWriter.ConsumeSlice(PageSize - 1);
+        slice.Dispose();
+
+        CheckPages(pages, expected);
+
+        slice = bufferWriter.ConsumeSlice(1);
+        CheckPages(pages, expected);
+        slice.Dispose();
+
+        expected[0] = (expected[0].Version + 1, 0);
+        CheckPages(pages, expected);
+
+        slice = bufferWriter.ConsumeSlice(PageSize);
+        CheckPages(pages, expected);
+        slice.Dispose();
+
+        expected[1] = (expected[1].Version + 1, 0);
+        CheckPages(pages, expected);
+
+        slice = bufferWriter.ConsumeSlice(PageSize + 1);
+        expected[3] = (expected[3].Version, expected[3].ReferenceCount + 1);
+        CheckPages(pages, expected);
+        slice.Dispose();
+
+        expected[2] = (expected[2].Version + 1, 0);
+        expected[3] = (expected[3].Version, expected[3].ReferenceCount - 1);
+        CheckPages(pages, expected);
+
+        Assert.Equal(randomData.Length - 1 - PageSize * 3, bufferWriter.Length);
+
+        bufferWriter.Dispose();
+        expected = expected.Take(3).Concat(expected.Skip(3).Select(e => (e.Version + 1, 0))).ToList();
+        CheckPages(pages, expected);
+
+        Assert.Throws<ObjectDisposedException>(() => bufferWriter.Length);
+
+        static void CheckPages(List<ArcBufferPage> pages, List<(int Version, int ReferenceCount)> expectedValues)
+        {
+            var index = 0;
+            foreach (var page in pages)
+            {
+                var expected = expectedValues[index];
+                CheckPage(page, expected.Version, expected.ReferenceCount);
+                ++index;
+            }
+        }
+
+        static void CheckPage(ArcBufferPage page, int expectedVersion, int expectedRefCount)
+        {
+            Assert.Equal(expectedVersion, page.Version);
+            Assert.Equal(expectedRefCount, page.ReferenceCount);
+        }
+    }
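The expected-version bookkeeping above follows from the page contract that the Pin/Unpin tests later in this file verify directly: a page's Version is a token that invalidates stale references once the page is recycled, and ReferenceCount tracks outstanding pins. In isolation (illustrative only, mirroring those later tests):

    var page = new ArcBufferPage(ArcBufferPagePool.MinimumPageSize);
    int token = page.Version;   // the version guards against stale (recycled) pages
    page.Pin(token);            // ReferenceCount == 1
    page.Pin(token);            // ReferenceCount == 2
    page.Unpin(token);
    page.Unpin(token);          // ReferenceCount == 0: the page may be recycled, bumping Version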
+
+    /// <summary>
+    /// Verifies that ReplenishBuffers provides correct buffer segments for socket-like reads and that all pages are eventually freed.
+    /// </summary>
+    [Fact]
+    public void ReplenishBuffers_ProvidesSegmentsAndFreesPages()
+    {
+        var bufferWriter = new ArcBufferWriter();
+        var randomData = new byte[PageSize * 16];
+        _random.NextBytes(randomData);
+        bufferWriter.Write([0]);
+        var pages = new List<ArcBufferPage>();
+        var firstSlice = bufferWriter.ConsumeSlice(1);
+        var firstPage = firstSlice.Pages.First();
+        firstSlice.Dispose();
+
+        var buffers = new List<ArraySegment<byte>>(capacity: 16);
+        var consumed = new List<ArcBuffer>();
+        int[] socketReadSizes = [256, 4096, 76, 12805, 4096, 26, 8094, 12345, 1, 0, 12345];
+        int[] messageReadSizes = [8, 1020, 8, 902, 8, 1203, 8, 8045, 0, 12034, 8, 1101, 8, 4096];
+        var messageReadIndex = 0;
+
+        ReadOnlySpan<byte> socket = randomData;
+        foreach (var readSize in socketReadSizes)
+        {
+            bufferWriter.ReplenishBuffers(buffers);
+
+            // Simulate reading from a socket.
+            Read(ref socket, readSize, buffers);
+            MaintainBufferList(buffers, readSize);
+            bufferWriter.AdvanceWriter(readSize);
+
+            // Add the newly allocated pages to the list for test assertion purposes.
+            using (var peeked = bufferWriter.PeekSlice(bufferWriter.Length))
+            {
+                pages.AddRange(peeked.Pages.Where(p => !pages.Contains(p)));
+            }
+
+            // Simulate consuming the socket data.
+            while (bufferWriter.Length > messageReadSizes[messageReadIndex % messageReadSizes.Length])
+            {
+                consumed.Add(bufferWriter.ConsumeSlice(messageReadSizes[messageReadIndex++ % messageReadSizes.Length]));
+            }
+        }
+
+        consumed.Add(bufferWriter.ConsumeSlice(bufferWriter.Length));
+
+        var totalReadSize = socketReadSizes.Sum();
+        Assert.Equal(totalReadSize, consumed.Sum(c => c.Length));
+        var consumedData = new byte[totalReadSize];
+        var consumerSpan = consumedData.AsSpan();
+        foreach (var buffer in consumed)
+        {
+            buffer.CopyTo(consumerSpan);
+            consumerSpan = consumerSpan[buffer.Length..];
+        }
+
+        Assert.Equal(randomData[..totalReadSize], consumedData);
+        foreach (var buffer in consumed)
+        {
+            buffer.Dispose();
+        }
+
+        bufferWriter.Dispose();
+
+        // Check that all pages were freed.
+        foreach (var page in pages)
+        {
+            Assert.Equal(0, page.ReferenceCount);
+        }
+
+        static void MaintainBufferList(List<ArraySegment<byte>> buffers, int readSize)
+        {
+            while (readSize > 0)
+            {
+                if (buffers[0].Count <= readSize)
+                {
+                    // Consume the buffer completely.
+                    readSize -= buffers[0].Count;
+                    buffers.RemoveAt(0);
+                }
+                else
+                {
+                    // Consume the buffer partially.
+                    buffers[0] = new(buffers[0].Array, buffers[0].Offset + readSize, buffers[0].Count - readSize);
+                    break;
+                }
+            }
+        }
+
+        static void Read(ref ReadOnlySpan<byte> socket, int readSize, List<ArraySegment<byte>> buffers)
+        {
+            var payload = socket[..readSize];
+            socket = socket[readSize..];
+            var bufferIndex = 0;
+            while (!payload.IsEmpty)
+            {
+                var output = buffers[bufferIndex];
+                var amount = Math.Min(output.Count, payload.Length);
+                payload[..amount].CopyTo(output);
+                payload = payload[amount..];
+                ++bufferIndex;
+            }
+        }
+    }
+
+    /// <summary>
+    /// Verifies that writing a buffer of a given size results in the correct reported length.
+    /// </summary>
+    [Fact]
+    public void WriteBuffer_UpdatesLengthCorrectly()
+    {
+        using var buffer = new ArcBufferWriter();
+        var data = new byte[1024];
+        _random.NextBytes(data);
+        buffer.Write(data);
+
+        // Assert
+        Assert.Equal(data.Length, buffer.Length);
+    }
+
+    /// <summary>
+    /// Verifies that peeking at a slice returns the correct data without consuming it.
+    /// </summary>
+    [Fact]
+    public void PeekSlice_ReturnsCorrectDataWithoutConsuming()
+    {
+        using var buffer = new ArcBufferWriter();
+        var data = new byte[1024];
+        _random.NextBytes(data);
+        buffer.Write(data);
+
+        using var peeked = buffer.PeekSlice(512);
+        // Assert
+        Assert.Equal(data.AsSpan(0, 512).ToArray(), peeked.ToArray());
+    }
+
+    /// <summary>
+    /// Verifies that consuming a slice returns the correct data and updates the buffer length.
+    /// </summary>
+    [Fact]
+    public void ConsumeSlice_ReturnsCorrectDataAndUpdatesLength()
+    {
+        using var buffer = new ArcBufferWriter();
+        var data = new byte[1024];
+        _random.NextBytes(data);
+        buffer.Write(data);
+
+        using var slice = buffer.ConsumeSlice(512);
+        using var subSlice = slice.Slice(256, 256);
+        // Assert
+        Assert.Equal(data.AsSpan(0, 512).ToArray(), slice.ToArray());
+        Assert.Equal(data.AsSpan(256, 256).ToArray(), subSlice.ToArray());
+        Assert.Equal(data.Length - slice.Length, buffer.Length);
+    }
+
+    /// <summary>
+    /// Verifies that using a slice after it has been unpinned throws an exception.
+    /// </summary>
+    [Fact]
+    public void UseAfterFree_ThrowsException()
+    {
+        using var buffer = new ArcBufferWriter();
+        var data = new byte[1024];
+        _random.NextBytes(data);
+        buffer.Write(data);
+
+        var slice = buffer.ConsumeSlice(512);
+        slice.Unpin();
+
+        // Assert
+        Assert.Throws<InvalidOperationException>(() => slice.ToArray());
+    }
+
+    /// <summary>
+    /// Verifies that double unpinning a slice throws, and that the buffer can be reset and disposed safely.
+    /// </summary>
+    [Fact]
+    public void DoubleFree_ThrowsAndBufferCanBeResetAndDisposed()
+    {
+        var buffer = new ArcBufferWriter();
+        var data = new byte[1024];
+        _random.NextBytes(data);
+        buffer.Write(data);
+
+        var slice = buffer.ConsumeSlice(512);
+        slice.Unpin();
+
+        // Assert
+        Assert.Throws<InvalidOperationException>(() => slice.Unpin());
+
+        Assert.Equal(512, buffer.Length);
+        buffer.Reset();
+        Assert.Equal(0, buffer.Length);
+
+        buffer.Dispose();
+    }
+
+    /// <summary>
+    /// Verifies that a new buffer is empty.
+    /// </summary>
+    [Fact]
+    public void NewBuffer_IsEmpty()
+    {
+        using var buffer = new ArcBufferWriter();
+        // Assert
+        Assert.Equal(0, buffer.Length);
+    }
+
+    /// <summary>
+    /// Verifies that writing an empty buffer does not change the buffer length.
+    /// </summary>
+    [Fact]
+    public void WriteEmptyBuffer_DoesNotChangeLength()
+    {
+        using var buffer = new ArcBufferWriter();
+        var data = Array.Empty<byte>();
+        _random.NextBytes(data);
+        buffer.Write(data);
+
+        // Assert
+        Assert.Equal(0, buffer.Length);
+    }
+
+    /// <summary>
+    /// Verifies that peeking at an empty buffer returns empty segments and throws when peeking past the end.
+    /// </summary>
+    [Fact]
+    public void PeekEmptyBuffer_ReturnsEmptyAndThrowsOnOverflow()
+    {
+        using var buffer = new ArcBufferWriter();
+        using var peeked = buffer.PeekSlice(0);
+        using var subSlice = peeked.Slice(0, 0);
+        Assert.Empty(peeked.Pages);
+        Assert.Empty(peeked.PageSegments);
+        Assert.Empty(peeked.ArraySegments);
+        Assert.Empty(peeked.MemorySegments);
+
+        Assert.Empty(subSlice.Pages);
+        Assert.Empty(subSlice.PageSegments);
+        Assert.Empty(subSlice.ArraySegments);
+        Assert.Empty(subSlice.MemorySegments);
+
+        // Assert
+        Assert.Equal(0, peeked.Length);
+        Assert.Throws<ArgumentOutOfRangeException>(() => buffer.PeekSlice(1));
+    }
+
+    /// <summary>
+    /// Verifies that consuming an empty buffer returns empty segments and throws when consuming past the end.
+    /// </summary>
+    [Fact]
+    public void ConsumeEmptyBuffer_ReturnsEmptyAndThrowsOnOverflow()
+    {
+        using var buffer = new ArcBufferWriter();
+        using var slice = buffer.ConsumeSlice(0);
+        using var subSlice = slice.Slice(0, 0);
+        Assert.Empty(slice.Pages);
+        Assert.Empty(slice.PageSegments);
+        Assert.Empty(slice.ArraySegments);
+        Assert.Empty(slice.MemorySegments);
+        Assert.Equal(0, slice.AsReadOnlySequence().Length);
+
+        Assert.Empty(subSlice.Pages);
+        Assert.Empty(subSlice.PageSegments);
+        Assert.Empty(subSlice.ArraySegments);
+        Assert.Empty(subSlice.MemorySegments);
+        Assert.Equal(0, subSlice.AsReadOnlySequence().Length);
+
+        // Assert
+        Assert.Equal(0, slice.Length);
+        Assert.Equal(0, buffer.Length);
+        Assert.Throws<ArgumentOutOfRangeException>(() => buffer.PeekSlice(1));
+        Assert.Throws<ArgumentOutOfRangeException>(() => buffer.ConsumeSlice(1));
+    }
+
+    /// <summary>
+    /// Verifies that disposing a slice after consuming a full page increments the page version.
+    /// </summary>
+    [Fact]
+    public void DisposeSliceAfterFullPageConsumption_IncrementsPageVersion()
+    {
+        using var bufferWriter = new ArcBufferWriter();
+        var data = new byte[ArcBufferPagePool.MinimumPageSize + 1];
+        _random.NextBytes(data);
+        bufferWriter.Write(data);
+
+        // Consuming the slice will cause the writer to release (unpin) those pages.
+        // Since we write more than one page (MinimumPageSize), we should have at least two pages.
+        // The write head will sit on the second page, leaving the first free to be consumed.
+        var slice = bufferWriter.ConsumeSlice(ArcBufferPagePool.MinimumPageSize);
+        var pages = new List<ArcBufferPage>(slice.Pages);
+
+        var initialVersions = pages.Select(p => p.Version).ToList();
+        slice.Dispose();
+
+        // Assert
+        foreach (var page in pages.Zip(initialVersions))
+        {
+            // Check that the versions have been incremented.
+            Assert.True(page.First.Version > page.Second);
+        }
+    }
+
+    /// <summary>
+    /// Verifies that after writing and then advancing the read head, the page version is incremented as expected.
+    /// </summary>
+    [Fact]
+    public void PageVersionIncrementAfterWriteAndReadHeadAdvance()
+    {
+        using var bufferWriter = new ArcBufferWriter();
+        var data = new byte[ArcBufferPagePool.MinimumPageSize];
+        _random.NextBytes(data);
+        bufferWriter.Write(data);
+
+        // Since we write exactly one page (MinimumPageSize), we should have exactly one page.
+        // The write head will sit on the first page, preventing it from being unpinned.
+        var slice = bufferWriter.ConsumeSlice(ArcBufferPagePool.MinimumPageSize);
+        var pages = new List<ArcBufferPage>(slice.Pages);
+
+        var initialVersions = pages.Select(p => p.Version).ToList();
+        slice.Dispose();
+
+        // Assert
+        foreach (var page in pages.Zip(initialVersions))
+        {
+            // Check that the versions have NOT been incremented.
+            Assert.False(page.First.Version > page.Second);
+        }
+
+        // Write one more byte, moving the write head to the second page.
+        bufferWriter.Write([0]);
+
+        // Advance the read head to trigger unpinning and version increment.
+        bufferWriter.AdvanceReader(1);
+
+        // Assert
+        foreach (var page in pages.Zip(initialVersions))
+        {
+            // Check that the versions HAVE now been incremented.
+            Assert.True(page.First.Version > page.Second);
+        }
+    }
+
+    /// <summary>
+    /// Verifies that all operations throw ObjectDisposedException after the buffer is disposed.
+    /// </summary>
+    [Fact]
+    public void DisposedBuffer_ThrowsOnAllOperations()
+    {
+        var buffer = new ArcBufferWriter();
+        buffer.Dispose();
+        Assert.Throws<ObjectDisposedException>(() => buffer.GetMemory(1));
+        Assert.Throws<ObjectDisposedException>(() => buffer.GetSpan(1));
+        Assert.Throws<ObjectDisposedException>(() => buffer.Write(new byte[1]));
+        Assert.Throws<ObjectDisposedException>(() => buffer.PeekSlice(0));
+        Assert.Throws<ObjectDisposedException>(() => buffer.ConsumeSlice(0));
+        Assert.Throws<ObjectDisposedException>(() => buffer.AdvanceWriter(1));
+        Assert.Throws<ObjectDisposedException>(() => buffer.AdvanceReader(0));
+        Assert.Throws<ObjectDisposedException>(() => buffer.Reset());
+        Assert.Throws<ObjectDisposedException>(() => buffer.ReplenishBuffers(new List<ArraySegment<byte>>(1)));
+    }
+
+    /// <summary>
+    /// Verifies that double-disposing an ArcBuffer slice is safe and does not throw.
+    /// </summary>
+    [Fact]
+    public void DoubleDisposeArcBuffer_IsSafe()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[100]);
+        var slice = buffer.PeekSlice(10);
+        slice.Dispose();
+        // Should not throw
+        slice.Dispose();
+    }
+
+    /// <summary>
+    /// Verifies that resetting a disposed buffer throws ObjectDisposedException.
+    /// </summary>
+    [Fact]
+    public void ResetAfterDispose_Throws()
+    {
+        var buffer = new ArcBufferWriter();
+        buffer.Dispose();
+        Assert.Throws<ObjectDisposedException>(() => buffer.Reset());
+    }
+
+    /// <summary>
+    /// Verifies that advancing the writer by a negative value throws ArgumentOutOfRangeException.
+    /// </summary>
+    [Fact]
+    public void AdvanceWriterNegative_Throws()
+    {
+        using var buffer = new ArcBufferWriter();
+        Assert.Throws<ArgumentOutOfRangeException>(() => buffer.AdvanceWriter(-1));
+    }
+
+    /// <summary>
+    /// Verifies that peeking or consuming more than is available throws ArgumentOutOfRangeException.
+    /// </summary>
+    [Fact]
+    public void AdvanceReaderNegativeOrTooLarge_Throws()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[10]);
+        Assert.Throws<ArgumentOutOfRangeException>(() => buffer.PeekSlice(11));
+        Assert.Throws<ArgumentOutOfRangeException>(() => buffer.ConsumeSlice(11));
+    }
+
+    /// <summary>
+    /// Verifies that calling Reset() after writing data spanning several pages returns all pages to the pool and empties the buffer.
+    /// </summary>
+    [Fact]
+    public void ResetReleasesAllPages_EmptiesBuffer()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[ArcBufferPagePool.MinimumPageSize * 3]);
+        buffer.Reset();
+        Assert.Equal(0, buffer.Length);
+    }
+
+    /// <summary>
+    /// Verifies that calling Dispose() multiple times on ArcBufferWriter is safe.
+    /// </summary>
+    [Fact]
+    public void DisposeMultipleTimes_IsSafe()
+    {
+        var buffer = new ArcBufferWriter();
+        buffer.Dispose();
+        buffer.Dispose();
+    }
+
+    /// <summary>
+    /// Verifies that writing or getting memory/span after Dispose() throws ObjectDisposedException.
+    /// </summary>
+    [Fact]
+    public void WriteAfterDispose_Throws()
+    {
+        var buffer = new ArcBufferWriter();
+        buffer.Dispose();
+        Assert.Throws<ObjectDisposedException>(() => buffer.Write(new byte[1]));
+        Assert.Throws<ObjectDisposedException>(() => buffer.GetMemory(1));
+        Assert.Throws<ObjectDisposedException>(() => buffer.GetSpan(1));
+    }
+
+    /// <summary>
+    /// Verifies that pinning and unpinning a page multiple times only returns it to the pool when the reference count reaches zero.
+    /// </summary>
+    [Fact]
+    public void PinUnpinReferenceCounting_WorksCorrectly()
+    {
+        var page = new ArcBufferPage(ArcBufferPagePool.MinimumPageSize);
+        int token = page.Version;
+        page.Pin(token);
+        page.Pin(token);
+        Assert.Equal(2, page.ReferenceCount);
+        page.Unpin(token);
+        Assert.Equal(1, page.ReferenceCount);
+        page.Unpin(token);
+        Assert.Equal(0, page.ReferenceCount);
+    }
+
+    /// <summary>
+    /// Verifies that unpinning a page with an incorrect version token throws InvalidOperationException.
+    /// </summary>
+    [Fact]
+    public void UnpinWithInvalidToken_Throws()
+    {
+        var page = new ArcBufferPage(ArcBufferPagePool.MinimumPageSize);
+        int token = page.Version;
+        page.Pin(token);
+        Assert.Throws<InvalidOperationException>(() => page.Unpin(token + 1));
+    }
+
+    /// <summary>
+    /// Verifies that CheckValidity throws if the reference count is zero or negative.
+    /// </summary>
+    [Fact]
+    public void CheckValidityWithInvalidRefCount_Throws()
+    {
+        var page = new ArcBufferPage(ArcBufferPagePool.MinimumPageSize);
+        int token = page.Version;
+        Assert.Throws<InvalidOperationException>(() => page.CheckValidity(token));
+    }
+
+    /// <summary>
+    /// Verifies that disposing a slice does not affect the original buffer.
+    /// </summary>
+    [Fact]
+    public void SliceDispose_DoesNotAffectOriginalBuffer()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[100]);
+        var slice = buffer.PeekSlice(50);
+        slice.Dispose();
+        Assert.Equal(100, buffer.Length);
+    }
+
+    /// <summary>
+    /// Verifies that UnsafeSlice does not increment the reference count.
+    /// </summary>
+    [Fact]
+    public void UnsafeSlice_DoesNotPinPages()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[100]);
+        var slice = buffer.PeekSlice(100);
+        var page = slice.First;
+        int before = page.ReferenceCount;
+        var unsafeSlice = slice.UnsafeSlice(10, 10);
+        Assert.Equal(before, unsafeSlice.First.ReferenceCount);
+    }
+
+    /// <summary>
+    /// Verifies that copying to a span that is too small throws.
+    /// </summary>
+    [Fact]
+    public void CopyToWithInsufficientDestination_Throws()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[100]);
+        var slice = buffer.PeekSlice(100);
+        var dest = new byte[50];
+        Assert.Throws<ArgumentException>(() => slice.CopyTo(dest.AsSpan()));
+    }
+
+    /// <summary>
+    /// Verifies that consuming more bytes than available throws.
+    /// </summary>
+    [Fact]
+    public void ConsumeMoreThanAvailable_Throws()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[10]);
+        Assert.Throws<ArgumentOutOfRangeException>(() => buffer.ConsumeSlice(20));
+    }
+
+    /// <summary>
+    /// Verifies that Skip() advances the read head.
+    /// </summary>
+    [Fact]
+    public void SkipAdvancesReadHead_WorksCorrectly()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[100]);
+        var reader = new ArcBufferReader(buffer);
+        reader.Skip(50);
+        Assert.Equal(50, reader.Length);
+    }
+
+    /// <summary>
+    /// Verifies that large pages are reused by the pool.
+    /// </summary>
+    [Fact]
+    public void LargePageReuse_Works()
+    {
+        var pool = ArcBufferPagePool.Shared;
+        var page1 = pool.Rent(ArcBufferPagePool.MinimumPageSize * 4);
+        int version1 = page1.Version;
+        page1.Pin(version1); // Pin the page
+        page1.Unpin(version1); // Return it to the pool
+        var page2 = pool.Rent(ArcBufferPagePool.MinimumPageSize * 4);
+        Assert.True(page2.Version > version1 || page2 != page1);
+    }
+
+    /// <summary>
+    /// Verifies that minimum-size pages are reused by the pool.
+    /// </summary>
+    [Fact]
+    public void MinimumPageReuse_Works()
+    {
+        var pool = ArcBufferPagePool.Shared;
+        var page1 = pool.Rent();
+        int version1 = page1.Version;
+        page1.Pin(version1); // Pin the page
+        page1.Unpin(version1); // Return it to the pool
+        var page2 = pool.Rent();
+        Assert.True(page2.Version > version1 || page2 != page1);
+    }
+
+    /// <summary>
+    /// Verifies boundary values for slicing, peeking, and consuming.
+    /// </summary>
+    [Fact]
+    public void BoundaryValue_SlicePeekConsume()
+    {
+        using var buffer = new ArcBufferWriter();
+        var data = new byte[PageSize * 2];
+        _random.NextBytes(data);
+        buffer.Write(data);
+
+        // Slice at start
+        using (var s = buffer.PeekSlice(0))
+        {
+            Assert.Equal(0, s.Length);
+        }
+        using (var s = buffer.PeekSlice(1))
+        {
+            Assert.Equal(data[0], s.ToArray()[0]);
+        }
+        using (var s = buffer.PeekSlice(data.Length))
+        {
+            Assert.Equal(data, s.ToArray());
+        }
+
+        // Slice at page boundary
+        using (var s = buffer.PeekSlice(PageSize))
+        {
+            Assert.Equal(data.Take(PageSize).ToArray(), s.ToArray());
+        }
+        using (var s = buffer.PeekSlice(PageSize + 1))
+        {
+            Assert.Equal(data.Take(PageSize + 1).ToArray(), s.ToArray());
+        }
+
+        // Consume at boundaries
+        using (var s = buffer.ConsumeSlice(0))
+        {
+            Assert.Equal(0, s.Length);
+        }
+        using (var s = buffer.ConsumeSlice(1))
+        {
+            Assert.Equal(data[0], s.ToArray()[0]);
+        }
+        using (var s = buffer.ConsumeSlice(PageSize - 1))
+        {
+            Assert.Equal(data.Skip(1).Take(PageSize - 1).ToArray(), s.ToArray());
+        }
+        using (var s = buffer.ConsumeSlice(PageSize))
+        {
+            Assert.Equal(data.Skip(PageSize).Take(PageSize).ToArray(), s.ToArray());
+        }
+    }
+
+    /// <summary>
+    /// Verifies that double-free and use-after-free are guarded.
+    /// </summary>
+    [Fact]
+    public void DoubleFree_And_UseAfterFree_Guards()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[100]);
+        var slice = buffer.PeekSlice(50);
+        slice.Dispose();
+        // Double dispose is safe
+        slice.Dispose();
+        // Unpin after dispose throws
+        Assert.Throws<InvalidOperationException>(() => slice.Unpin());
+        // Use after dispose throws
+        Assert.Throws<InvalidOperationException>(() => slice.ToArray());
+    }
+
+    /// <summary>
+    /// Verifies that memory is not leaked (reference count returns to zero) after all slices are disposed.
+    /// </summary>
+    [Fact]
+    public void NoMemoryLeak_ReferenceCountReturnsToZero()
+    {
+        var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[PageSize * 2]);
+        var slices = new List<ArcBuffer>();
+        for (int i = 0; i < 10; i++)
+        {
+            slices.Add(buffer.PeekSlice(PageSize));
+        }
+        var pages = slices[0].Pages.ToList();
+        foreach (var s in slices)
+        {
+            s.Dispose();
+        }
+        foreach (var p in pages)
+        {
+            Assert.Equal(1, p.ReferenceCount); // Only the buffer's own pin remains
+        }
+        buffer.Dispose();
+        foreach (var p in pages)
+        {
+            Assert.Equal(0, p.ReferenceCount);
+        }
+    }
+
+    /// <summary>
+    /// Verifies that slicing and peeking with zero-length and full-length works for empty and full buffers.
+    /// </summary>
+    [Fact]
+    public void EmptyAndFullBuffer_SlicePeek()
+    {
+        using var buffer = new ArcBufferWriter();
+        using (var s = buffer.PeekSlice(0))
+        {
+            Assert.Equal(0, s.Length);
+        }
+        buffer.Write(new byte[PageSize]);
+        using (var s = buffer.PeekSlice(PageSize))
+        {
+            Assert.Equal(PageSize, s.Length);
+        }
+        using (var s = buffer.ConsumeSlice(PageSize))
+        {
+            Assert.Equal(PageSize, s.Length);
+        }
+        Assert.Equal(0, buffer.Length);
+    }
+
+    /// <summary>
+    /// Verifies that slicing at the very end of the buffer returns an empty slice.
+    /// </summary>
+    [Fact]
+    public void SliceAtEnd_ReturnsEmpty()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[10]);
+        buffer.ConsumeSlice(10).Dispose();
+        using (var s = buffer.PeekSlice(0))
+        {
+            Assert.Equal(0, s.Length);
+        }
+        Assert.Throws<ArgumentOutOfRangeException>(() => buffer.PeekSlice(1));
+    }
+
+    /// <summary>
+    /// Verifies that pin/unpin on different slices of the same page does not leak memory.
+    /// </summary>
+    [Fact]
+    public void MultipleSlices_SamePage_NoLeak()
+    {
+        using var buffer = new ArcBufferWriter();
+        buffer.Write(new byte[PageSize]);
+        var s1 = buffer.PeekSlice(PageSize / 2);
+        var s2 = buffer.PeekSlice(PageSize / 2);
+        var page = s1.First;
+        Assert.True(page.ReferenceCount >= 2);
+        s1.Dispose();
+        Assert.True(page.ReferenceCount >= 1);
+        s2.Dispose();
+        Assert.Equal(1, page.ReferenceCount); // Only the buffer's own pin remains
+        buffer.Dispose();
+        Assert.Equal(0, page.ReferenceCount);
+    }
+}
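For reference, the socket-style flow that ReplenishBuffers_ProvidesSegmentsAndFreesPages simulates condenses to the sketch below. The single hand-written byte stands in for a socket Receive call, which is not part of this diff; everything else uses only members exercised by the tests above:

    using System;
    using System.Collections.Generic;
    using Orleans.Serialization.Buffers;

    using var writer = new ArcBufferWriter();
    var segments = new List<ArraySegment<byte>>();
    writer.ReplenishBuffers(segments);            // borrow writable segments from the writer's pages
    var first = segments[0];
    first.Array![first.Offset] = 0x2A;            // a socket Receive(segments) would fill these bytes
    writer.AdvanceWriter(1);                      // commit what the socket produced
    using var message = writer.ConsumeSlice(1);   // zero-copy hand-off to the message parser
    Console.WriteLine(message.ToArray()[0]);      // 42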