diff --git a/.dotnet/src/Custom/Chat/ChatClient.cs b/.dotnet/src/Custom/Chat/ChatClient.cs index e8ae289a7..67cb06ed4 100644 --- a/.dotnet/src/Custom/Chat/ChatClient.cs +++ b/.dotnet/src/Custom/Chat/ChatClient.cs @@ -200,19 +200,12 @@ public virtual StreamingClientResult CompleteChatStreaming( PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options); requestMessage.BufferResponse = false; Shim.Pipeline.Send(requestMessage); - PipelineResponse response = requestMessage.ExtractResponse(); - - if (response.IsError) + if (requestMessage.Response.IsError) { - throw new ClientResultException(response); + throw new ClientResultException(requestMessage.Response); } - - ClientResult genericResult = ClientResult.FromResponse(response); - return StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( - responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e))); + return StreamingClientResult + .Create(requestMessage.Response); } /// @@ -238,19 +231,12 @@ public virtual async Task> CompleteCh PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options); requestMessage.BufferResponse = false; await Shim.Pipeline.SendAsync(requestMessage).ConfigureAwait(false); - PipelineResponse response = requestMessage.ExtractResponse(); - - if (response.IsError) + if (requestMessage.Response.IsError) { - throw new ClientResultException(response); + throw new ClientResultException(requestMessage.Response); } - - ClientResult genericResult = ClientResult.FromResponse(response); - return StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( - responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e))); + return StreamingClientResult + .Create(requestMessage.Response); } private Internal.Models.CreateChatCompletionRequest CreateInternalRequest( diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs index d359ea956..c5959bea0 100644 --- a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs @@ -181,156 +181,4 @@ internal StreamingChatUpdate( ToolCallUpdate = toolCallUpdate; LogProbabilities = logProbabilities; } - - internal static List DeserializeStreamingChatUpdates(JsonElement element) - { - List results = []; - if (element.ValueKind == JsonValueKind.Null) - { - return results; - } - string id = default; - DateTimeOffset created = default; - string systemFingerprint = null; - foreach (JsonProperty property in element.EnumerateObject()) - { - if (property.NameEquals("id"u8)) - { - id = property.Value.GetString(); - continue; - } - if (property.NameEquals("created"u8)) - { - created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - continue; - } - if (property.NameEquals("system_fingerprint")) - { - systemFingerprint = property.Value.GetString(); - continue; - } - if (property.NameEquals("choices"u8)) - { - foreach (JsonElement choiceElement in property.Value.EnumerateArray()) - { - ChatRole? role = null; - string contentUpdate = null; - string functionName = null; - string functionArgumentsUpdate = null; - int choiceIndex = 0; - ChatFinishReason? finishReason = null; - List toolCallUpdates = []; - ChatLogProbabilityCollection logProbabilities = new([]); - - foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) - { - if (choiceProperty.NameEquals("index"u8)) - { - choiceIndex = choiceProperty.Value.GetInt32(); - continue; - } - if (choiceProperty.NameEquals("finish_reason"u8)) - { - if (choiceProperty.Value.ValueKind == JsonValueKind.Null) - { - finishReason = null; - continue; - } - finishReason = choiceProperty.Value.GetString() switch - { - "stop" => ChatFinishReason.Stopped, - "length" => ChatFinishReason.Length, - "tool_calls" => ChatFinishReason.ToolCalls, - "function_call" => ChatFinishReason.FunctionCall, - "content_filter" => ChatFinishReason.ContentFilter, - _ => throw new ArgumentException(nameof(finishReason)), - }; - continue; - } - if (choiceProperty.NameEquals("delta"u8)) - { - foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) - { - if (deltaProperty.NameEquals("role"u8)) - { - role = deltaProperty.Value.GetString() switch - { - "system" => ChatRole.System, - "user" => ChatRole.User, - "assistant" => ChatRole.Assistant, - "tool" => ChatRole.Tool, - "function" => ChatRole.Function, - _ => throw new ArgumentException(nameof(role)), - }; - continue; - } - if (deltaProperty.NameEquals("content"u8)) - { - contentUpdate = deltaProperty.Value.GetString(); - continue; - } - if (deltaProperty.NameEquals("function_call"u8)) - { - foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) - { - if (functionProperty.NameEquals("name"u8)) - { - functionName = functionProperty.Value.GetString(); - continue; - } - if (functionProperty.NameEquals("arguments"u8)) - { - functionArgumentsUpdate = functionProperty.Value.GetString(); - } - } - } - if (deltaProperty.NameEquals("tool_calls")) - { - foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray()) - { - toolCallUpdates.Add( - StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement)); - } - } - } - } - if (choiceProperty.NameEquals("logprobs"u8)) - { - Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs - = Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs( - choiceProperty.Value); - logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs); - } - } - // In the unlikely event that more than one tool call arrives on a single chunk, we'll generate - // separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop. - if (toolCallUpdates.Count == 0) - { - toolCallUpdates.Add(null); - } - foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) - { - results.Add(new StreamingChatUpdate( - id, - created, - systemFingerprint, - choiceIndex, - role, - contentUpdate, - finishReason, - functionName, - functionArgumentsUpdate, - toolCallUpdate, - logProbabilities)); - } - } - continue; - } - } - if (results.Count == 0) - { - results.Add(new StreamingChatUpdate(id, created, systemFingerprint)); - } - return results; - } } diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.Serialization.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.Serialization.cs new file mode 100644 index 000000000..224878b72 --- /dev/null +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.Serialization.cs @@ -0,0 +1,181 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace OpenAI.Chat; + +internal partial class StreamingChatUpdateCollection : IJsonModel +{ + StreamingChatUpdateCollection IJsonModel.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + => ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdateCollection, ref reader, options); + + StreamingChatUpdateCollection IPersistableModel.Create(BinaryData data, ModelReaderWriterOptions options) + => ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdateCollection, data, options); + + void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + => ModelSerializationHelpers.SerializeInstance(this, SerializeStreamingChatUpdateCollections, writer, options); + + BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) + => ModelSerializationHelpers.SerializeInstance(this, options); + + string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; + + internal static StreamingChatUpdateCollection DeserializeStreamingChatUpdateCollection( + JsonElement sseDataJson, + ModelReaderWriterOptions options = default) + { + List results = []; + if (sseDataJson.ValueKind == JsonValueKind.Null) + { + return new(results); + } + string id = default; + DateTimeOffset created = default; + string systemFingerprint = null; + foreach (JsonProperty property in sseDataJson.EnumerateObject()) + { + if (property.NameEquals("id"u8)) + { + id = property.Value.GetString(); + continue; + } + if (property.NameEquals("created"u8)) + { + created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + continue; + } + if (property.NameEquals("system_fingerprint")) + { + systemFingerprint = property.Value.GetString(); + continue; + } + if (property.NameEquals("choices"u8)) + { + foreach (JsonElement choiceElement in property.Value.EnumerateArray()) + { + ChatRole? role = null; + string contentUpdate = null; + string functionName = null; + string functionArgumentsUpdate = null; + int choiceIndex = 0; + ChatFinishReason? finishReason = null; + List toolCallUpdates = []; + ChatLogProbabilityCollection logProbabilities = new([]); + + foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) + { + if (choiceProperty.NameEquals("index"u8)) + { + choiceIndex = choiceProperty.Value.GetInt32(); + continue; + } + if (choiceProperty.NameEquals("finish_reason"u8)) + { + if (choiceProperty.Value.ValueKind == JsonValueKind.Null) + { + finishReason = null; + continue; + } + finishReason = choiceProperty.Value.GetString() switch + { + "stop" => ChatFinishReason.Stopped, + "length" => ChatFinishReason.Length, + "tool_calls" => ChatFinishReason.ToolCalls, + "function_call" => ChatFinishReason.FunctionCall, + "content_filter" => ChatFinishReason.ContentFilter, + _ => throw new ArgumentException(nameof(finishReason)), + }; + continue; + } + if (choiceProperty.NameEquals("delta"u8)) + { + foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) + { + if (deltaProperty.NameEquals("role"u8)) + { + role = deltaProperty.Value.GetString() switch + { + "system" => ChatRole.System, + "user" => ChatRole.User, + "assistant" => ChatRole.Assistant, + "tool" => ChatRole.Tool, + "function" => ChatRole.Function, + _ => throw new ArgumentException(nameof(role)), + }; + continue; + } + if (deltaProperty.NameEquals("content"u8)) + { + contentUpdate = deltaProperty.Value.GetString(); + continue; + } + if (deltaProperty.NameEquals("function_call"u8)) + { + foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) + { + if (functionProperty.NameEquals("name"u8)) + { + functionName = functionProperty.Value.GetString(); + continue; + } + if (functionProperty.NameEquals("arguments"u8)) + { + functionArgumentsUpdate = functionProperty.Value.GetString(); + } + } + } + if (deltaProperty.NameEquals("tool_calls")) + { + foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray()) + { + toolCallUpdates.Add( + StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement)); + } + } + } + } + if (choiceProperty.NameEquals("logprobs"u8)) + { + Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs + = Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs( + choiceProperty.Value); + logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs); + } + } + // In the unlikely event that more than one tool call arrives on a single chunk, we'll generate + // separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop. + if (toolCallUpdates.Count == 0) + { + toolCallUpdates.Add(null); + } + foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) + { + results.Add(new( + id, + created, + systemFingerprint, + choiceIndex, + role, + contentUpdate, + finishReason, + functionName, + functionArgumentsUpdate, + toolCallUpdate)); + } + } + continue; + } + } + if (results.Count == 0) + { + results.Add(new(id, created, systemFingerprint)); + } + return new(results); + } + + internal static void SerializeStreamingChatUpdateCollections(StreamingChatUpdateCollection StreamingChatUpdateCollection, Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + throw new NotSupportedException(nameof(StreamingChatUpdateCollection)); + } +} diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs new file mode 100644 index 000000000..d0e2d79cb --- /dev/null +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs @@ -0,0 +1,12 @@ +namespace OpenAI.Chat; + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text.Json; + +internal partial class StreamingChatUpdateCollection : ReadOnlyCollection +{ + internal StreamingChatUpdateCollection() : this([]) { } + internal StreamingChatUpdateCollection(IList list) : base(list) { } +} diff --git a/.dotnet/src/Utility/AsyncServerSentEventEnumerator.cs b/.dotnet/src/Utility/AsyncServerSentEventEnumerator.cs new file mode 100644 index 000000000..934773cf8 --- /dev/null +++ b/.dotnet/src/Utility/AsyncServerSentEventEnumerator.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI; + +internal sealed class AsyncServerSentEventEnumerator : IAsyncEnumerator, IDisposable, IAsyncDisposable +{ + private static readonly ReadOnlyMemory _doneToken = "[DONE]".AsMemory(); + + private readonly ServerSentEventReader _reader; + private CancellationToken _cancellationToken; + private bool _disposedValue; + + public ServerSentEvent Current { get; private set; } + + public AsyncServerSentEventEnumerator(ServerSentEventReader reader, CancellationToken cancellationToken = default) + { + _reader = reader; + _cancellationToken = cancellationToken; + } + + public async ValueTask MoveNextAsync() + { + ServerSentEvent? nextEvent = await _reader.TryGetNextEventAsync(_cancellationToken).ConfigureAwait(false); + if (nextEvent.HasValue) + { + if (nextEvent.Value.Data.Span.SequenceEqual(_doneToken.Span)) + { + return false; + } + Current = nextEvent.Value; + return true; + } + return false; + } + + private void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _reader.Dispose(); + } + + _disposedValue = true; + } + } + + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + public ValueTask DisposeAsync() + { + Dispose(); + return new ValueTask(); + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/AsyncServerSentEventJsonDataEnumerator.cs b/.dotnet/src/Utility/AsyncServerSentEventJsonDataEnumerator.cs new file mode 100644 index 000000000..cd915b882 --- /dev/null +++ b/.dotnet/src/Utility/AsyncServerSentEventJsonDataEnumerator.cs @@ -0,0 +1,68 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading.Tasks; + +namespace OpenAI; + +internal class AsyncServerSentEventJsonDataEnumerator : AsyncServerSentEventJsonDataEnumerator + where T : IJsonModel +{ + public AsyncServerSentEventJsonDataEnumerator(AsyncServerSentEventEnumerator eventEnumerator) : base(eventEnumerator) + { } +} + +internal class AsyncServerSentEventJsonDataEnumerator : IAsyncEnumerator, IDisposable, IAsyncDisposable + where TJsonDataType : IJsonModel +{ + private AsyncServerSentEventEnumerator _eventEnumerator; + private IEnumerator _currentInstanceEnumerator; + + public TInstanceType Current { get; private set; } + + public AsyncServerSentEventJsonDataEnumerator(AsyncServerSentEventEnumerator eventEnumerator) + { + _eventEnumerator = eventEnumerator; + } + + public async ValueTask MoveNextAsync() + { + if (_currentInstanceEnumerator?.MoveNext() == true) + { + Current = _currentInstanceEnumerator.Current; + return true; + } + if (await _eventEnumerator.MoveNextAsync()) + { + using JsonDocument eventDocument = JsonDocument.Parse(_eventEnumerator.Current.Data); + BinaryData eventData = BinaryData.FromObjectAsJson(eventDocument.RootElement); + TJsonDataType jsonData = ModelReaderWriter.Read(eventData); + if (jsonData is TInstanceType singleInstanceData) + { + Current = singleInstanceData; + return true; + } + if (jsonData is IEnumerable instanceCollectionData) + { + _currentInstanceEnumerator = instanceCollectionData.GetEnumerator(); + if (_currentInstanceEnumerator.MoveNext() == true) + { + Current = _currentInstanceEnumerator.Current; + return true; + } + } + } + return false; + } + + public async ValueTask DisposeAsync() + { + await _eventEnumerator.DisposeAsync(); + } + + public void Dispose() + { + _eventEnumerator.Dispose(); + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ModelSerializationHelpers.cs b/.dotnet/src/Utility/ModelSerializationHelpers.cs new file mode 100644 index 000000000..08356afed --- /dev/null +++ b/.dotnet/src/Utility/ModelSerializationHelpers.cs @@ -0,0 +1,110 @@ +using System; +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace OpenAI; + +internal static partial class ModelSerializationHelpers +{ + internal static TOutput DeserializeNewInstance( + UInstanceInput existingInstance, + Func deserializationFunc, + ref Utf8JsonReader reader, + ModelReaderWriterOptions options) + where UInstanceInput : IJsonModel + { + options ??= new("W"); + var format = options.Format == "W" ? ((IJsonModel)existingInstance).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(UInstanceInput)} does not support '{format}' format."); + } + + using JsonDocument document = JsonDocument.ParseValue(ref reader); + return deserializationFunc.Invoke(document.RootElement, options); + } + + internal static TOutput DeserializeNewInstance( + UInstanceInput existingInstance, + Func deserializationFunc, + BinaryData data, + ModelReaderWriterOptions options) + where UInstanceInput : IPersistableModel + { + options ??= new("W"); + var format = options.Format == "W" ? ((IPersistableModel)existingInstance).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + { + using JsonDocument document = JsonDocument.Parse(data); + return deserializationFunc.Invoke(document.RootElement, options)!; + } + default: + throw new FormatException($"The model {nameof(UInstanceInput)} does not support '{format}' format."); + } + } + + internal static void SerializeInstance( + UInstanceInput instance, + Action serializationFunc, + Utf8JsonWriter writer, + ModelReaderWriterOptions options) + where UInstanceInput : IJsonModel + { + options ??= new ModelReaderWriterOptions("W"); + AssertSupportedJsonWriteFormat(instance, options); + serializationFunc.Invoke(instance, writer, options); + } + + internal static void SerializeInstance( + T instance, + Action serializationFunc, + Utf8JsonWriter writer, + ModelReaderWriterOptions options) + where T : IJsonModel + => SerializeInstance(instance, serializationFunc, writer, options); + + internal static BinaryData SerializeInstance( + UInstanceInput instance, + ModelReaderWriterOptions options) + where UInstanceInput : IPersistableModel + { + options ??= new("W"); + AssertSupportedPersistableWriteFormat(instance, options); + return ModelReaderWriter.Write(instance, options); + } + + internal static BinaryData SerializeInstance(T instance, ModelReaderWriterOptions options) + where T : IPersistableModel + => SerializeInstance(instance, options); + + internal static void AssertSupportedJsonWriteFormat(T instance, ModelReaderWriterOptions options) + where T : IJsonModel + => AssertSupportedJsonWriteFormat(instance, options); + + internal static void AssertSupportedJsonWriteFormat(UInstanceInput instance, ModelReaderWriterOptions options) + where UInstanceInput : IJsonModel + { + var format = options.Format == "W" ? ((IJsonModel)instance).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(UInstanceInput)} does not support '{format}' format."); + } + } + + internal static void AssertSupportedPersistableWriteFormat(T instance, ModelReaderWriterOptions options) + where T : IPersistableModel + => AssertSupportedPersistableWriteFormat(instance, options); + + internal static void AssertSupportedPersistableWriteFormat(UInstanceInput instance, ModelReaderWriterOptions options) + where UInstanceInput : IPersistableModel + { + var format = options.Format == "W" ? ((IPersistableModel)instance).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(UInstanceInput)} does not support '{format}' format."); + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEvent.cs b/.dotnet/src/Utility/ServerSentEvent.cs new file mode 100644 index 000000000..ea91b1889 --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEvent.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal readonly struct ServerSentEvent +{ + // Gets the value of the SSE "event type" buffer, used to distinguish between event kinds. + public ReadOnlyMemory EventName { get; } + // Gets the value of the SSE "data" buffer, which holds the payload of the server-sent event. + public ReadOnlyMemory Data { get; } + // Gets the value of the "last event ID" buffer, with which a user agent can reestablish a session. + public ReadOnlyMemory LastEventId { get; } + // If present, gets the defined "retry" value for the event, which represents the delay before reconnecting. + public TimeSpan? ReconnectionTime { get; } + + private readonly IReadOnlyList _fields; + private readonly string _multiLineData; + + internal ServerSentEvent(IReadOnlyList fields) + { + _fields = fields; + StringBuilder multiLineDataBuilder = null; + for (int i = 0; i < _fields.Count; i++) + { + ReadOnlyMemory fieldValue = _fields[i].Value; + switch (_fields[i].FieldType) + { + case ServerSentEventFieldKind.Event: + EventName = fieldValue; + break; + case ServerSentEventFieldKind.Data: + { + if (multiLineDataBuilder != null) + { + multiLineDataBuilder.Append(fieldValue); + } + else if (Data.IsEmpty) + { + Data = fieldValue; + } + else + { + multiLineDataBuilder ??= new(); + multiLineDataBuilder.Append(fieldValue); + Data = null; + } + break; + } + case ServerSentEventFieldKind.Id: + LastEventId = fieldValue; + break; + case ServerSentEventFieldKind.Retry: + ReconnectionTime = Int32.TryParse(fieldValue.ToString(), out int retry) ? TimeSpan.FromMilliseconds(retry) : null; + break; + default: + break; + } + if (multiLineDataBuilder != null) + { + _multiLineData = multiLineDataBuilder.ToString(); + Data = _multiLineData.AsMemory(); + } + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEventField.cs b/.dotnet/src/Utility/ServerSentEventField.cs new file mode 100644 index 000000000..c6032c931 --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEventField.cs @@ -0,0 +1,64 @@ +using System; + +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal readonly struct ServerSentEventField +{ + public ServerSentEventFieldKind FieldType { get; } + + // TODO: we should not expose UTF16 publicly + public ReadOnlyMemory Value + { + get + { + if (_valueStartIndex >= _original.Length) + { + return ReadOnlyMemory.Empty; + } + else + { + return _original.AsMemory(_valueStartIndex); + } + } + } + + private readonly string _original; + private readonly int _valueStartIndex; + + internal ServerSentEventField(string line) + { + _original = line; + int colonIndex = _original.AsSpan().IndexOf(':'); + + ReadOnlyMemory fieldName = colonIndex < 0 ? _original.AsMemory(): _original.AsMemory(0, colonIndex); + FieldType = fieldName.Span switch + { + var x when x.SequenceEqual(s_eventFieldName.Span) => ServerSentEventFieldKind.Event, + var x when x.SequenceEqual(s_dataFieldName.Span) => ServerSentEventFieldKind.Data, + var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => ServerSentEventFieldKind.Id, + var x when x.SequenceEqual(s_retryFieldName.Span) => ServerSentEventFieldKind.Retry, + _ => ServerSentEventFieldKind.Ignored, + }; + + if (colonIndex < 0) + { + _valueStartIndex = _original.Length; + } + else if (colonIndex + 1 < _original.Length && _original[colonIndex + 1] == ' ') + { + _valueStartIndex = colonIndex + 2; + } + else + { + _valueStartIndex = colonIndex + 1; + } + } + + public override string ToString() => _original; + + private static readonly ReadOnlyMemory s_eventFieldName = "event".AsMemory(); + private static readonly ReadOnlyMemory s_dataFieldName = "data".AsMemory(); + private static readonly ReadOnlyMemory s_lastEventIdFieldName = "id".AsMemory(); + private static readonly ReadOnlyMemory s_retryFieldName = "retry".AsMemory(); +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEventFieldKind.cs b/.dotnet/src/Utility/ServerSentEventFieldKind.cs new file mode 100644 index 000000000..c3597b0ff --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEventFieldKind.cs @@ -0,0 +1,11 @@ +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal enum ServerSentEventFieldKind +{ + Event, + Data, + Id, + Retry, + Ignored +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEventReader.cs b/.dotnet/src/Utility/ServerSentEventReader.cs new file mode 100644 index 000000000..d47938254 --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEventReader.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI; + +internal sealed class ServerSentEventReader : IDisposable +{ + private readonly Stream _stream; + private readonly StreamReader _reader; + private bool _disposedValue; + + public ServerSentEventReader(Stream stream) + { + _stream = stream; + _reader = new StreamReader(stream); + } + + /// + /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) + { + List fields = []; + + while (!cancellationToken.IsCancellationRequested) + { + string line = _reader.ReadLine(); + if (line == null) + { + // A null line indicates end of input + return null; + } + else if (line.Length == 0) + { + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + fields = []; + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); + } + } + + return null; + } + + /// + /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) + { + List fields = []; + + while (!cancellationToken.IsCancellationRequested) + { + string line = await _reader.ReadLineAsync().ConfigureAwait(false); + if (line == null) + { + // A null line indicates end of input + return null; + } + else if (line.Length == 0) + { + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); + } + } + + return null; + } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _reader.Dispose(); + _stream.Dispose(); + } + + _disposedValue = true; + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseAsyncEnumerator.cs b/.dotnet/src/Utility/SseAsyncEnumerator.cs deleted file mode 100644 index 743a1bedd..000000000 --- a/.dotnet/src/Utility/SseAsyncEnumerator.cs +++ /dev/null @@ -1,59 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Threading; - -namespace OpenAI; - -internal static class SseAsyncEnumerator -{ - internal static async IAsyncEnumerable EnumerateFromSseStream( - Stream stream, - Func> multiElementDeserializer, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - try - { - using SseReader sseReader = new(stream); - while (!cancellationToken.IsCancellationRequested) - { - SseLine? sseEvent = await sseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent is not null) - { - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - { - throw new InvalidDataException(); - } - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - break; - } - using JsonDocument sseMessageJson = JsonDocument.Parse(value); - IEnumerable newItems = multiElementDeserializer.Invoke(sseMessageJson.RootElement); - foreach (T item in newItems) - { - yield return item; - } - } - } - } - finally - { - // Always dispose the stream immediately once enumeration is complete for any reason - stream.Dispose(); - } - } - - internal static IAsyncEnumerable EnumerateFromSseStream( - Stream stream, - Func elementDeserializer, - CancellationToken cancellationToken = default) - => EnumerateFromSseStream( - stream, - (element) => new T[] { elementDeserializer.Invoke(element) }, - cancellationToken); -} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseLine.cs b/.dotnet/src/Utility/SseLine.cs deleted file mode 100644 index 4d82315f9..000000000 --- a/.dotnet/src/Utility/SseLine.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System; - -namespace OpenAI; - -// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream -internal readonly struct SseLine -{ - private readonly string _original; - private readonly int _colonIndex; - private readonly int _valueIndex; - - public static SseLine Empty { get; } = new SseLine(string.Empty, 0, false); - - internal SseLine(string original, int colonIndex, bool hasSpaceAfterColon) - { - _original = original; - _colonIndex = colonIndex; - _valueIndex = colonIndex + (hasSpaceAfterColon ? 2 : 1); - } - - public bool IsEmpty => _original.Length == 0; - public bool IsComment => !IsEmpty && _original[0] == ':'; - - // TODO: we should not expose UTF16 publicly - public ReadOnlyMemory FieldName => _original.AsMemory(0, _colonIndex); - public ReadOnlyMemory FieldValue => _original.AsMemory(_valueIndex); - - public override string ToString() => _original; -} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index cf0301408..cab725caf 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -1,118 +1,124 @@ using System; -using System.ClientModel; -using System.ClientModel.Internal; +using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace OpenAI; internal sealed class SseReader : IDisposable +{ + private readonly Stream _stream; + private readonly StreamReader _reader; + private bool _disposedValue; + + public SseReader(Stream stream) { - private readonly Stream _stream; - private readonly StreamReader _reader; - private bool _disposedValue; + _stream = stream; + _reader = new StreamReader(stream); + } - public SseReader(Stream stream) - { - _stream = stream; - _reader = new StreamReader(stream); - } + /// + /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) + { + List fields = []; - public SseLine? TryReadSingleFieldEvent() + while (!cancellationToken.IsCancellationRequested) { - while (true) + string line = _reader.ReadLine(); + if (line == null) { - SseLine? line = TryReadLine(); - if (line == null) - return null; - if (line.Value.IsEmpty) - throw new InvalidDataException("event expected."); - SseLine? empty = TryReadLine(); - if (empty != null && !empty.Value.IsEmpty) - throw new NotSupportedException("Multi-filed events not supported."); - if (!line.Value.IsComment) - return line; // skip comment lines + // A null line indicates end of input + return null; } - } - - // TODO: we should support cancellation tokens, but StreamReader does not in NS2 - public async Task TryReadSingleFieldEventAsync() - { - while (true) + else if (line.Length == 0) { - SseLine? line = await TryReadLineAsync().ConfigureAwait(false); - if (line == null) - return null; - if (line.Value.IsEmpty) - throw new InvalidDataException("event expected."); - SseLine? empty = await TryReadLineAsync().ConfigureAwait(false); - if (empty != null && !empty.Value.IsEmpty) - throw new NotSupportedException("Multi-filed events not supported."); - if (!line.Value.IsComment) - return line; // skip comment lines + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + fields = []; + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); } } - public SseLine? TryReadLine() - { - string lineText = _reader.ReadLine(); - if (lineText == null) - return null; - if (lineText.Length == 0) - return SseLine.Empty; - if (TryParseLine(lineText, out SseLine line)) - return line; - return null; - } + return null; + } - // TODO: we should support cancellation tokens, but StreamReader does not in NS2 - public async Task TryReadLineAsync() - { - string lineText = await _reader.ReadLineAsync().ConfigureAwait(false); - if (lineText == null) - return null; - if (lineText.Length == 0) - return SseLine.Empty; - if (TryParseLine(lineText, out SseLine line)) - return line; - return null; - } + /// + /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) + { + List fields = []; - private static bool TryParseLine(string lineText, out SseLine line) + while (!cancellationToken.IsCancellationRequested) { - if (lineText.Length == 0) + string line = await _reader.ReadLineAsync().ConfigureAwait(false); + if (line == null) + { + // A null line indicates end of input + return null; + } + else if (line.Length == 0) + { + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + fields = []; + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else { - line = default; - return false; + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); } + } - ReadOnlySpan lineSpan = lineText.AsSpan(); - int colonIndex = lineSpan.IndexOf(':'); - ReadOnlySpan fieldValue = lineSpan.Slice(colonIndex + 1); + return null; + } - bool hasSpace = false; - if (fieldValue.Length > 0 && fieldValue[0] == ' ') - hasSpace = true; - line = new SseLine(lineText, colonIndex, hasSpace); - return true; - } + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } - private void Dispose(bool disposing) + private void Dispose(bool disposing) + { + if (!_disposedValue) { - if (!_disposedValue) + if (disposing) { - if (disposing) - { - _reader.Dispose(); - _stream.Dispose(); - } - - _disposedValue = true; + _reader.Dispose(); + _stream.Dispose(); } + + _disposedValue = true; } - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - } \ No newline at end of file + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/StreamingClientResult.cs b/.dotnet/src/Utility/StreamingClientResult.cs new file mode 100644 index 000000000..57c43a01e --- /dev/null +++ b/.dotnet/src/Utility/StreamingClientResult.cs @@ -0,0 +1,96 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI; + +/// +/// Represents an operation response with streaming content that can be deserialized and enumerated while the response +/// is still being received. +/// +/// The data type representative of distinct, streamable items. +public class StreamingClientResult : IAsyncEnumerable +{ + private readonly PipelineResponse _response; + private readonly Func> _asyncEnumeratorSourceDelegate; + + private bool _disposedValue; + + /// + /// Gets the underlying that contains headers and other response-wide information. + /// + /// + /// The instance used in this . + /// + public PipelineResponse GetRawResponse() => _response; + + private StreamingClientResult(PipelineResponse response, Func> asyncEnumeratorSourceDelegate) + { + _response = response; + _asyncEnumeratorSourceDelegate = asyncEnumeratorSourceDelegate; + } + + /// + /// Creates a new instance of that will yield items of the specified type + /// as they become available via server-sent event JSON data on the available + /// . This overload uses via the + /// interface and only supports single-item deserialization per server-sent event data + /// payload. + /// + /// The base for this result instance. + /// + /// The optional cancellation token used to control the enumeration. + /// + /// A new instance of . + public static StreamingClientResult Create(PipelineResponse response, CancellationToken cancellationToken = default) + where U : IJsonModel + { + return new(response, GetServerSentEventDeserializationEnumerator); + } + + public static StreamingClientResult Create(PipelineResponse response, CancellationToken cancellationToken = default) + where TJsonDataType : IJsonModel + { + return new(response, GetServerSentEventDeserializationEnumerator); + } + + private static IAsyncEnumerator GetServerSentEventDeserializationEnumerator(Stream stream, CancellationToken cancellationToken = default) + where U : IJsonModel + { + return GetServerSentEventDeserializationEnumerator(stream, cancellationToken); + } + + private static IAsyncEnumerator GetServerSentEventDeserializationEnumerator( + Stream stream, + CancellationToken cancellationToken = default) + where TJsonDataType : IJsonModel + { + ServerSentEventReader sseReader = null; + AsyncServerSentEventEnumerator sseEnumerator = null; + try + { + sseReader = new(stream); + sseEnumerator = new(sseReader, cancellationToken); + AsyncServerSentEventJsonDataEnumerator instanceEnumerator = new(sseEnumerator); + sseEnumerator = null; + sseReader = null; + return instanceEnumerator; + } + finally + { + sseEnumerator?.Dispose(); + sseReader?.Dispose(); + } + } + + IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) + { + return _asyncEnumeratorSourceDelegate.Invoke(_response.ContentStream, cancellationToken); + } +} diff --git a/.dotnet/src/Utility/StreamingResult.cs b/.dotnet/src/Utility/StreamingResult.cs deleted file mode 100644 index a1b6ff538..000000000 --- a/.dotnet/src/Utility/StreamingResult.cs +++ /dev/null @@ -1,95 +0,0 @@ -using System.ClientModel; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Threading; -using System.Collections.Generic; -using System; - -namespace OpenAI; - -/// -/// Represents an operation response with streaming content that can be deserialized and enumerated while the response -/// is still being received. -/// -/// The data type representative of distinct, streamable items. -public class StreamingClientResult - : IDisposable - , IAsyncEnumerable -{ - private ClientResult _rawResult { get; } - private IAsyncEnumerable _asyncEnumerableSource { get; } - private bool _disposedValue { get; set; } - - private StreamingClientResult() { } - - private StreamingClientResult( - ClientResult rawResult, - Func> asyncEnumerableProcessor) - { - _rawResult = rawResult; - _asyncEnumerableSource = asyncEnumerableProcessor.Invoke(rawResult); - } - - /// - /// Creates a new instance of using the provided underlying HTTP response. The - /// provided function will be used to resolve the response into an asynchronous enumeration of streamed response - /// items. - /// - /// The HTTP response. - /// - /// The function that will resolve the provided response into an IAsyncEnumerable. - /// - /// - /// A new instance of that will be capable of asynchronous enumeration of - /// items from the HTTP response. - /// - internal static StreamingClientResult CreateFromResponse( - ClientResult result, - Func> asyncEnumerableProcessor) - { - return new(result, asyncEnumerableProcessor); - } - - /// - /// Gets the underlying instance that this may enumerate - /// over. - /// - /// The instance attached to this . - public PipelineResponse GetRawResponse() => _rawResult.GetRawResponse(); - - /// - /// Gets the asynchronously enumerable collection of distinct, streamable items in the response. - /// - /// - /// The return value of this method may be used with the "await foreach" statement. - /// - /// As explicitly implements , callers may - /// enumerate a instance directly instead of calling this method. - /// - /// - /// - public IAsyncEnumerable EnumerateValues() => this; - - /// - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - /// - protected virtual void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _rawResult?.GetRawResponse()?.Dispose(); - } - _disposedValue = true; - } - } - - IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) - => _asyncEnumerableSource.GetAsyncEnumerator(cancellationToken); -} \ No newline at end of file