From a4916d23a95be6153ee49e4bc4930061985c7b56 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Mon, 29 Apr 2024 20:11:24 -0700 Subject: [PATCH 1/3] standalone port of sse updates --- .dotnet/src/Custom/Chat/ChatClient.cs | 56 ++++-- .../Chat/StreamingChatUpdate.Serialization.cs | 181 +++++++++++++++++ .../src/Custom/Chat/StreamingChatUpdate.cs | 152 --------------- .../src/Utility/ModelSerializationHelpers.cs | 110 +++++++++++ .dotnet/src/Utility/ServerSentEvent.cs | 68 +++++++ .dotnet/src/Utility/ServerSentEventField.cs | 64 ++++++ .../src/Utility/ServerSentEventFieldKind.cs | 11 ++ .dotnet/src/Utility/SseAsyncEnumerator.cs | 59 ------ .dotnet/src/Utility/SseLine.cs | 29 --- .dotnet/src/Utility/SseReader.cs | 184 +++++++++--------- .dotnet/src/Utility/StreamingClientResult.cs | 134 +++++++++++++ .dotnet/src/Utility/StreamingResult.cs | 95 --------- 12 files changed, 699 insertions(+), 444 deletions(-) create mode 100644 .dotnet/src/Custom/Chat/StreamingChatUpdate.Serialization.cs create mode 100644 .dotnet/src/Utility/ModelSerializationHelpers.cs create mode 100644 .dotnet/src/Utility/ServerSentEvent.cs create mode 100644 .dotnet/src/Utility/ServerSentEventField.cs create mode 100644 .dotnet/src/Utility/ServerSentEventFieldKind.cs delete mode 100644 .dotnet/src/Utility/SseAsyncEnumerator.cs delete mode 100644 .dotnet/src/Utility/SseLine.cs create mode 100644 .dotnet/src/Utility/StreamingClientResult.cs delete mode 100644 .dotnet/src/Utility/StreamingResult.cs diff --git a/.dotnet/src/Custom/Chat/ChatClient.cs b/.dotnet/src/Custom/Chat/ChatClient.cs index e8ae289a7..4d115ba80 100644 --- a/.dotnet/src/Custom/Chat/ChatClient.cs +++ b/.dotnet/src/Custom/Chat/ChatClient.cs @@ -200,19 +200,27 @@ public virtual StreamingClientResult CompleteChatStreaming( PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options); requestMessage.BufferResponse = false; Shim.Pipeline.Send(requestMessage); - PipelineResponse response = requestMessage.ExtractResponse(); - - if (response.IsError) + if (requestMessage.Response.IsError) { - throw new ClientResultException(response); + throw new ClientResultException(requestMessage.Response); } - ClientResult genericResult = ClientResult.FromResponse(response); - return StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( - responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e))); + PipelineResponse response = null; + try + { + response = requestMessage.ExtractResponse(); + ClientResult genericResult = ClientResult.FromResponse(response); + StreamingClientResult streamingResult + = StreamingClientResult.Create( + response, + (sseJson) => StreamingChatUpdate.DeserializeStreamingChatUpdates(sseJson, options: default)); + response = null; + return streamingResult; + } + finally + { + response?.Dispose(); + } } /// @@ -238,19 +246,27 @@ public virtual async Task> CompleteCh PipelineMessage requestMessage = CreateCustomRequestMessage(messages, choiceCount, options); requestMessage.BufferResponse = false; await Shim.Pipeline.SendAsync(requestMessage).ConfigureAwait(false); - PipelineResponse response = requestMessage.ExtractResponse(); - - if (response.IsError) + if (requestMessage.Response.IsError) { - throw new ClientResultException(response); + throw new ClientResultException(requestMessage.Response); } - ClientResult genericResult = ClientResult.FromResponse(response); - return StreamingClientResult.CreateFromResponse( - genericResult, - (responseForEnumeration) => SseAsyncEnumerator.EnumerateFromSseStream( - responseForEnumeration.GetRawResponse().ContentStream, - e => StreamingChatUpdate.DeserializeStreamingChatUpdates(e))); + PipelineResponse response = null; + try + { + response = requestMessage.ExtractResponse(); + ClientResult genericResult = ClientResult.FromResponse(response); + StreamingClientResult streamingResult + = StreamingClientResult.Create( + response, + (sseJson) => StreamingChatUpdate.DeserializeStreamingChatUpdates(sseJson, options: default)); + response = null; + return streamingResult; + } + finally + { + response?.Dispose(); + } } private Internal.Models.CreateChatCompletionRequest CreateInternalRequest( diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdate.Serialization.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdate.Serialization.cs new file mode 100644 index 000000000..bc1a53ef8 --- /dev/null +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdate.Serialization.cs @@ -0,0 +1,181 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; + +namespace OpenAI.Chat; + +public partial class StreamingChatUpdate : IJsonModel> +{ + IEnumerable IJsonModel>.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + => ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdates, ref reader, options); + + IEnumerable IPersistableModel>.Create(BinaryData data, ModelReaderWriterOptions options) + => ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdates, data, options); + + void IJsonModel>.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + => ModelSerializationHelpers.SerializeInstance, StreamingChatUpdate>(this, SerializeStreamingChatUpdates, writer, options); + + BinaryData IPersistableModel>.Write(ModelReaderWriterOptions options) + => ModelSerializationHelpers.SerializeInstance, StreamingChatUpdate>(this, options); + + string IPersistableModel>.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; + + internal static IEnumerable DeserializeStreamingChatUpdates( + JsonElement sseDataJson, + ModelReaderWriterOptions options = default) + { + List results = []; + if (sseDataJson.ValueKind == JsonValueKind.Null) + { + return results; + } + string id = default; + DateTimeOffset created = default; + string systemFingerprint = null; + foreach (JsonProperty property in sseDataJson.EnumerateObject()) + { + if (property.NameEquals("id"u8)) + { + id = property.Value.GetString(); + continue; + } + if (property.NameEquals("created"u8)) + { + created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); + continue; + } + if (property.NameEquals("system_fingerprint")) + { + systemFingerprint = property.Value.GetString(); + continue; + } + if (property.NameEquals("choices"u8)) + { + foreach (JsonElement choiceElement in property.Value.EnumerateArray()) + { + ChatRole? role = null; + string contentUpdate = null; + string functionName = null; + string functionArgumentsUpdate = null; + int choiceIndex = 0; + ChatFinishReason? finishReason = null; + List toolCallUpdates = []; + ChatLogProbabilityCollection logProbabilities = new([]); + + foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) + { + if (choiceProperty.NameEquals("index"u8)) + { + choiceIndex = choiceProperty.Value.GetInt32(); + continue; + } + if (choiceProperty.NameEquals("finish_reason"u8)) + { + if (choiceProperty.Value.ValueKind == JsonValueKind.Null) + { + finishReason = null; + continue; + } + finishReason = choiceProperty.Value.GetString() switch + { + "stop" => ChatFinishReason.Stopped, + "length" => ChatFinishReason.Length, + "tool_calls" => ChatFinishReason.ToolCalls, + "function_call" => ChatFinishReason.FunctionCall, + "content_filter" => ChatFinishReason.ContentFilter, + _ => throw new ArgumentException(nameof(finishReason)), + }; + continue; + } + if (choiceProperty.NameEquals("delta"u8)) + { + foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) + { + if (deltaProperty.NameEquals("role"u8)) + { + role = deltaProperty.Value.GetString() switch + { + "system" => ChatRole.System, + "user" => ChatRole.User, + "assistant" => ChatRole.Assistant, + "tool" => ChatRole.Tool, + "function" => ChatRole.Function, + _ => throw new ArgumentException(nameof(role)), + }; + continue; + } + if (deltaProperty.NameEquals("content"u8)) + { + contentUpdate = deltaProperty.Value.GetString(); + continue; + } + if (deltaProperty.NameEquals("function_call"u8)) + { + foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) + { + if (functionProperty.NameEquals("name"u8)) + { + functionName = functionProperty.Value.GetString(); + continue; + } + if (functionProperty.NameEquals("arguments"u8)) + { + functionArgumentsUpdate = functionProperty.Value.GetString(); + } + } + } + if (deltaProperty.NameEquals("tool_calls")) + { + foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray()) + { + toolCallUpdates.Add( + StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement)); + } + } + } + } + if (choiceProperty.NameEquals("logprobs"u8)) + { + Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs + = Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs( + choiceProperty.Value); + logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs); + } + } + // In the unlikely event that more than one tool call arrives on a single chunk, we'll generate + // separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop. + if (toolCallUpdates.Count == 0) + { + toolCallUpdates.Add(null); + } + foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) + { + results.Add(new StreamingChatUpdate( + id, + created, + systemFingerprint, + choiceIndex, + role, + contentUpdate, + finishReason, + functionName, + functionArgumentsUpdate, + toolCallUpdate)); + } + } + continue; + } + } + if (results.Count == 0) + { + results.Add(new StreamingChatUpdate(id, created, systemFingerprint)); + } + return results; + } + + internal static void SerializeStreamingChatUpdates(StreamingChatUpdate StreamingChatUpdate, Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + throw new NotSupportedException(nameof(StreamingChatUpdate)); + } +} diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs index d359ea956..c5959bea0 100644 --- a/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdate.cs @@ -181,156 +181,4 @@ internal StreamingChatUpdate( ToolCallUpdate = toolCallUpdate; LogProbabilities = logProbabilities; } - - internal static List DeserializeStreamingChatUpdates(JsonElement element) - { - List results = []; - if (element.ValueKind == JsonValueKind.Null) - { - return results; - } - string id = default; - DateTimeOffset created = default; - string systemFingerprint = null; - foreach (JsonProperty property in element.EnumerateObject()) - { - if (property.NameEquals("id"u8)) - { - id = property.Value.GetString(); - continue; - } - if (property.NameEquals("created"u8)) - { - created = DateTimeOffset.FromUnixTimeSeconds(property.Value.GetInt64()); - continue; - } - if (property.NameEquals("system_fingerprint")) - { - systemFingerprint = property.Value.GetString(); - continue; - } - if (property.NameEquals("choices"u8)) - { - foreach (JsonElement choiceElement in property.Value.EnumerateArray()) - { - ChatRole? role = null; - string contentUpdate = null; - string functionName = null; - string functionArgumentsUpdate = null; - int choiceIndex = 0; - ChatFinishReason? finishReason = null; - List toolCallUpdates = []; - ChatLogProbabilityCollection logProbabilities = new([]); - - foreach (JsonProperty choiceProperty in choiceElement.EnumerateObject()) - { - if (choiceProperty.NameEquals("index"u8)) - { - choiceIndex = choiceProperty.Value.GetInt32(); - continue; - } - if (choiceProperty.NameEquals("finish_reason"u8)) - { - if (choiceProperty.Value.ValueKind == JsonValueKind.Null) - { - finishReason = null; - continue; - } - finishReason = choiceProperty.Value.GetString() switch - { - "stop" => ChatFinishReason.Stopped, - "length" => ChatFinishReason.Length, - "tool_calls" => ChatFinishReason.ToolCalls, - "function_call" => ChatFinishReason.FunctionCall, - "content_filter" => ChatFinishReason.ContentFilter, - _ => throw new ArgumentException(nameof(finishReason)), - }; - continue; - } - if (choiceProperty.NameEquals("delta"u8)) - { - foreach (JsonProperty deltaProperty in choiceProperty.Value.EnumerateObject()) - { - if (deltaProperty.NameEquals("role"u8)) - { - role = deltaProperty.Value.GetString() switch - { - "system" => ChatRole.System, - "user" => ChatRole.User, - "assistant" => ChatRole.Assistant, - "tool" => ChatRole.Tool, - "function" => ChatRole.Function, - _ => throw new ArgumentException(nameof(role)), - }; - continue; - } - if (deltaProperty.NameEquals("content"u8)) - { - contentUpdate = deltaProperty.Value.GetString(); - continue; - } - if (deltaProperty.NameEquals("function_call"u8)) - { - foreach (JsonProperty functionProperty in deltaProperty.Value.EnumerateObject()) - { - if (functionProperty.NameEquals("name"u8)) - { - functionName = functionProperty.Value.GetString(); - continue; - } - if (functionProperty.NameEquals("arguments"u8)) - { - functionArgumentsUpdate = functionProperty.Value.GetString(); - } - } - } - if (deltaProperty.NameEquals("tool_calls")) - { - foreach (JsonElement toolCallElement in deltaProperty.Value.EnumerateArray()) - { - toolCallUpdates.Add( - StreamingToolCallUpdate.DeserializeStreamingToolCallUpdate(toolCallElement)); - } - } - } - } - if (choiceProperty.NameEquals("logprobs"u8)) - { - Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs - = Internal.Models.CreateChatCompletionResponseChoiceLogprobs.DeserializeCreateChatCompletionResponseChoiceLogprobs( - choiceProperty.Value); - logProbabilities = ChatLogProbabilityCollection.FromInternalData(internalLogprobs); - } - } - // In the unlikely event that more than one tool call arrives on a single chunk, we'll generate - // separate updates just like for choices. Adding a "null" if empty lets us avoid a separate loop. - if (toolCallUpdates.Count == 0) - { - toolCallUpdates.Add(null); - } - foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) - { - results.Add(new StreamingChatUpdate( - id, - created, - systemFingerprint, - choiceIndex, - role, - contentUpdate, - finishReason, - functionName, - functionArgumentsUpdate, - toolCallUpdate, - logProbabilities)); - } - } - continue; - } - } - if (results.Count == 0) - { - results.Add(new StreamingChatUpdate(id, created, systemFingerprint)); - } - return results; - } } diff --git a/.dotnet/src/Utility/ModelSerializationHelpers.cs b/.dotnet/src/Utility/ModelSerializationHelpers.cs new file mode 100644 index 000000000..08356afed --- /dev/null +++ b/.dotnet/src/Utility/ModelSerializationHelpers.cs @@ -0,0 +1,110 @@ +using System; +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace OpenAI; + +internal static partial class ModelSerializationHelpers +{ + internal static TOutput DeserializeNewInstance( + UInstanceInput existingInstance, + Func deserializationFunc, + ref Utf8JsonReader reader, + ModelReaderWriterOptions options) + where UInstanceInput : IJsonModel + { + options ??= new("W"); + var format = options.Format == "W" ? ((IJsonModel)existingInstance).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(UInstanceInput)} does not support '{format}' format."); + } + + using JsonDocument document = JsonDocument.ParseValue(ref reader); + return deserializationFunc.Invoke(document.RootElement, options); + } + + internal static TOutput DeserializeNewInstance( + UInstanceInput existingInstance, + Func deserializationFunc, + BinaryData data, + ModelReaderWriterOptions options) + where UInstanceInput : IPersistableModel + { + options ??= new("W"); + var format = options.Format == "W" ? ((IPersistableModel)existingInstance).GetFormatFromOptions(options) : options.Format; + + switch (format) + { + case "J": + { + using JsonDocument document = JsonDocument.Parse(data); + return deserializationFunc.Invoke(document.RootElement, options)!; + } + default: + throw new FormatException($"The model {nameof(UInstanceInput)} does not support '{format}' format."); + } + } + + internal static void SerializeInstance( + UInstanceInput instance, + Action serializationFunc, + Utf8JsonWriter writer, + ModelReaderWriterOptions options) + where UInstanceInput : IJsonModel + { + options ??= new ModelReaderWriterOptions("W"); + AssertSupportedJsonWriteFormat(instance, options); + serializationFunc.Invoke(instance, writer, options); + } + + internal static void SerializeInstance( + T instance, + Action serializationFunc, + Utf8JsonWriter writer, + ModelReaderWriterOptions options) + where T : IJsonModel + => SerializeInstance(instance, serializationFunc, writer, options); + + internal static BinaryData SerializeInstance( + UInstanceInput instance, + ModelReaderWriterOptions options) + where UInstanceInput : IPersistableModel + { + options ??= new("W"); + AssertSupportedPersistableWriteFormat(instance, options); + return ModelReaderWriter.Write(instance, options); + } + + internal static BinaryData SerializeInstance(T instance, ModelReaderWriterOptions options) + where T : IPersistableModel + => SerializeInstance(instance, options); + + internal static void AssertSupportedJsonWriteFormat(T instance, ModelReaderWriterOptions options) + where T : IJsonModel + => AssertSupportedJsonWriteFormat(instance, options); + + internal static void AssertSupportedJsonWriteFormat(UInstanceInput instance, ModelReaderWriterOptions options) + where UInstanceInput : IJsonModel + { + var format = options.Format == "W" ? ((IJsonModel)instance).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(UInstanceInput)} does not support '{format}' format."); + } + } + + internal static void AssertSupportedPersistableWriteFormat(T instance, ModelReaderWriterOptions options) + where T : IPersistableModel + => AssertSupportedPersistableWriteFormat(instance, options); + + internal static void AssertSupportedPersistableWriteFormat(UInstanceInput instance, ModelReaderWriterOptions options) + where UInstanceInput : IPersistableModel + { + var format = options.Format == "W" ? ((IPersistableModel)instance).GetFormatFromOptions(options) : options.Format; + if (format != "J") + { + throw new FormatException($"The model {nameof(UInstanceInput)} does not support '{format}' format."); + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEvent.cs b/.dotnet/src/Utility/ServerSentEvent.cs new file mode 100644 index 000000000..ea91b1889 --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEvent.cs @@ -0,0 +1,68 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal readonly struct ServerSentEvent +{ + // Gets the value of the SSE "event type" buffer, used to distinguish between event kinds. + public ReadOnlyMemory EventName { get; } + // Gets the value of the SSE "data" buffer, which holds the payload of the server-sent event. + public ReadOnlyMemory Data { get; } + // Gets the value of the "last event ID" buffer, with which a user agent can reestablish a session. + public ReadOnlyMemory LastEventId { get; } + // If present, gets the defined "retry" value for the event, which represents the delay before reconnecting. + public TimeSpan? ReconnectionTime { get; } + + private readonly IReadOnlyList _fields; + private readonly string _multiLineData; + + internal ServerSentEvent(IReadOnlyList fields) + { + _fields = fields; + StringBuilder multiLineDataBuilder = null; + for (int i = 0; i < _fields.Count; i++) + { + ReadOnlyMemory fieldValue = _fields[i].Value; + switch (_fields[i].FieldType) + { + case ServerSentEventFieldKind.Event: + EventName = fieldValue; + break; + case ServerSentEventFieldKind.Data: + { + if (multiLineDataBuilder != null) + { + multiLineDataBuilder.Append(fieldValue); + } + else if (Data.IsEmpty) + { + Data = fieldValue; + } + else + { + multiLineDataBuilder ??= new(); + multiLineDataBuilder.Append(fieldValue); + Data = null; + } + break; + } + case ServerSentEventFieldKind.Id: + LastEventId = fieldValue; + break; + case ServerSentEventFieldKind.Retry: + ReconnectionTime = Int32.TryParse(fieldValue.ToString(), out int retry) ? TimeSpan.FromMilliseconds(retry) : null; + break; + default: + break; + } + if (multiLineDataBuilder != null) + { + _multiLineData = multiLineDataBuilder.ToString(); + Data = _multiLineData.AsMemory(); + } + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEventField.cs b/.dotnet/src/Utility/ServerSentEventField.cs new file mode 100644 index 000000000..c6032c931 --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEventField.cs @@ -0,0 +1,64 @@ +using System; + +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal readonly struct ServerSentEventField +{ + public ServerSentEventFieldKind FieldType { get; } + + // TODO: we should not expose UTF16 publicly + public ReadOnlyMemory Value + { + get + { + if (_valueStartIndex >= _original.Length) + { + return ReadOnlyMemory.Empty; + } + else + { + return _original.AsMemory(_valueStartIndex); + } + } + } + + private readonly string _original; + private readonly int _valueStartIndex; + + internal ServerSentEventField(string line) + { + _original = line; + int colonIndex = _original.AsSpan().IndexOf(':'); + + ReadOnlyMemory fieldName = colonIndex < 0 ? _original.AsMemory(): _original.AsMemory(0, colonIndex); + FieldType = fieldName.Span switch + { + var x when x.SequenceEqual(s_eventFieldName.Span) => ServerSentEventFieldKind.Event, + var x when x.SequenceEqual(s_dataFieldName.Span) => ServerSentEventFieldKind.Data, + var x when x.SequenceEqual(s_lastEventIdFieldName.Span) => ServerSentEventFieldKind.Id, + var x when x.SequenceEqual(s_retryFieldName.Span) => ServerSentEventFieldKind.Retry, + _ => ServerSentEventFieldKind.Ignored, + }; + + if (colonIndex < 0) + { + _valueStartIndex = _original.Length; + } + else if (colonIndex + 1 < _original.Length && _original[colonIndex + 1] == ' ') + { + _valueStartIndex = colonIndex + 2; + } + else + { + _valueStartIndex = colonIndex + 1; + } + } + + public override string ToString() => _original; + + private static readonly ReadOnlyMemory s_eventFieldName = "event".AsMemory(); + private static readonly ReadOnlyMemory s_dataFieldName = "data".AsMemory(); + private static readonly ReadOnlyMemory s_lastEventIdFieldName = "id".AsMemory(); + private static readonly ReadOnlyMemory s_retryFieldName = "retry".AsMemory(); +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEventFieldKind.cs b/.dotnet/src/Utility/ServerSentEventFieldKind.cs new file mode 100644 index 000000000..c3597b0ff --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEventFieldKind.cs @@ -0,0 +1,11 @@ +namespace OpenAI; + +// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream +internal enum ServerSentEventFieldKind +{ + Event, + Data, + Id, + Retry, + Ignored +} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseAsyncEnumerator.cs b/.dotnet/src/Utility/SseAsyncEnumerator.cs deleted file mode 100644 index 743a1bedd..000000000 --- a/.dotnet/src/Utility/SseAsyncEnumerator.cs +++ /dev/null @@ -1,59 +0,0 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Runtime.CompilerServices; -using System.Text.Json; -using System.Threading; - -namespace OpenAI; - -internal static class SseAsyncEnumerator -{ - internal static async IAsyncEnumerable EnumerateFromSseStream( - Stream stream, - Func> multiElementDeserializer, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - try - { - using SseReader sseReader = new(stream); - while (!cancellationToken.IsCancellationRequested) - { - SseLine? sseEvent = await sseReader.TryReadSingleFieldEventAsync().ConfigureAwait(false); - if (sseEvent is not null) - { - ReadOnlyMemory name = sseEvent.Value.FieldName; - if (!name.Span.SequenceEqual("data".AsSpan())) - { - throw new InvalidDataException(); - } - ReadOnlyMemory value = sseEvent.Value.FieldValue; - if (value.Span.SequenceEqual("[DONE]".AsSpan())) - { - break; - } - using JsonDocument sseMessageJson = JsonDocument.Parse(value); - IEnumerable newItems = multiElementDeserializer.Invoke(sseMessageJson.RootElement); - foreach (T item in newItems) - { - yield return item; - } - } - } - } - finally - { - // Always dispose the stream immediately once enumeration is complete for any reason - stream.Dispose(); - } - } - - internal static IAsyncEnumerable EnumerateFromSseStream( - Stream stream, - Func elementDeserializer, - CancellationToken cancellationToken = default) - => EnumerateFromSseStream( - stream, - (element) => new T[] { elementDeserializer.Invoke(element) }, - cancellationToken); -} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseLine.cs b/.dotnet/src/Utility/SseLine.cs deleted file mode 100644 index 4d82315f9..000000000 --- a/.dotnet/src/Utility/SseLine.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System; - -namespace OpenAI; - -// SSE specification: https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream -internal readonly struct SseLine -{ - private readonly string _original; - private readonly int _colonIndex; - private readonly int _valueIndex; - - public static SseLine Empty { get; } = new SseLine(string.Empty, 0, false); - - internal SseLine(string original, int colonIndex, bool hasSpaceAfterColon) - { - _original = original; - _colonIndex = colonIndex; - _valueIndex = colonIndex + (hasSpaceAfterColon ? 2 : 1); - } - - public bool IsEmpty => _original.Length == 0; - public bool IsComment => !IsEmpty && _original[0] == ':'; - - // TODO: we should not expose UTF16 publicly - public ReadOnlyMemory FieldName => _original.AsMemory(0, _colonIndex); - public ReadOnlyMemory FieldValue => _original.AsMemory(_valueIndex); - - public override string ToString() => _original; -} \ No newline at end of file diff --git a/.dotnet/src/Utility/SseReader.cs b/.dotnet/src/Utility/SseReader.cs index cf0301408..cab725caf 100644 --- a/.dotnet/src/Utility/SseReader.cs +++ b/.dotnet/src/Utility/SseReader.cs @@ -1,118 +1,124 @@ using System; -using System.ClientModel; -using System.ClientModel.Internal; +using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace OpenAI; internal sealed class SseReader : IDisposable +{ + private readonly Stream _stream; + private readonly StreamReader _reader; + private bool _disposedValue; + + public SseReader(Stream stream) { - private readonly Stream _stream; - private readonly StreamReader _reader; - private bool _disposedValue; + _stream = stream; + _reader = new StreamReader(stream); + } - public SseReader(Stream stream) - { - _stream = stream; - _reader = new StreamReader(stream); - } + /// + /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) + { + List fields = []; - public SseLine? TryReadSingleFieldEvent() + while (!cancellationToken.IsCancellationRequested) { - while (true) + string line = _reader.ReadLine(); + if (line == null) { - SseLine? line = TryReadLine(); - if (line == null) - return null; - if (line.Value.IsEmpty) - throw new InvalidDataException("event expected."); - SseLine? empty = TryReadLine(); - if (empty != null && !empty.Value.IsEmpty) - throw new NotSupportedException("Multi-filed events not supported."); - if (!line.Value.IsComment) - return line; // skip comment lines + // A null line indicates end of input + return null; } - } - - // TODO: we should support cancellation tokens, but StreamReader does not in NS2 - public async Task TryReadSingleFieldEventAsync() - { - while (true) + else if (line.Length == 0) { - SseLine? line = await TryReadLineAsync().ConfigureAwait(false); - if (line == null) - return null; - if (line.Value.IsEmpty) - throw new InvalidDataException("event expected."); - SseLine? empty = await TryReadLineAsync().ConfigureAwait(false); - if (empty != null && !empty.Value.IsEmpty) - throw new NotSupportedException("Multi-filed events not supported."); - if (!line.Value.IsComment) - return line; // skip comment lines + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + fields = []; + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); } } - public SseLine? TryReadLine() - { - string lineText = _reader.ReadLine(); - if (lineText == null) - return null; - if (lineText.Length == 0) - return SseLine.Empty; - if (TryParseLine(lineText, out SseLine line)) - return line; - return null; - } + return null; + } - // TODO: we should support cancellation tokens, but StreamReader does not in NS2 - public async Task TryReadLineAsync() - { - string lineText = await _reader.ReadLineAsync().ConfigureAwait(false); - if (lineText == null) - return null; - if (lineText.Length == 0) - return SseLine.Empty; - if (TryParseLine(lineText, out SseLine line)) - return line; - return null; - } + /// + /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) + { + List fields = []; - private static bool TryParseLine(string lineText, out SseLine line) + while (!cancellationToken.IsCancellationRequested) { - if (lineText.Length == 0) + string line = await _reader.ReadLineAsync().ConfigureAwait(false); + if (line == null) + { + // A null line indicates end of input + return null; + } + else if (line.Length == 0) + { + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + fields = []; + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else { - line = default; - return false; + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); } + } - ReadOnlySpan lineSpan = lineText.AsSpan(); - int colonIndex = lineSpan.IndexOf(':'); - ReadOnlySpan fieldValue = lineSpan.Slice(colonIndex + 1); + return null; + } - bool hasSpace = false; - if (fieldValue.Length > 0 && fieldValue[0] == ' ') - hasSpace = true; - line = new SseLine(lineText, colonIndex, hasSpace); - return true; - } + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } - private void Dispose(bool disposing) + private void Dispose(bool disposing) + { + if (!_disposedValue) { - if (!_disposedValue) + if (disposing) { - if (disposing) - { - _reader.Dispose(); - _stream.Dispose(); - } - - _disposedValue = true; + _reader.Dispose(); + _stream.Dispose(); } + + _disposedValue = true; } - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - } \ No newline at end of file + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/StreamingClientResult.cs b/.dotnet/src/Utility/StreamingClientResult.cs new file mode 100644 index 000000000..5e91fa233 --- /dev/null +++ b/.dotnet/src/Utility/StreamingClientResult.cs @@ -0,0 +1,134 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text.Json; +using System.Threading; + +namespace OpenAI; + +/// +/// Represents an operation response with streaming content that can be deserialized and enumerated while the response +/// is still being received. +/// +/// The data type representative of distinct, streamable items. +public class StreamingClientResult + : IDisposable + , IAsyncEnumerable +{ + private readonly PipelineResponse _baseResponse; + private readonly IAsyncEnumerable _asyncEnumerable; + private bool _disposedValue; + + /// + /// Gets the underlying that contains headers and other response-wide information. + /// + /// + /// The instance used in this . + /// + public PipelineResponse GetRawResponse() => _baseResponse; + + private StreamingClientResult() { } + + private StreamingClientResult( + PipelineResponse response, + Func> asyncEnumerableProcessor) + { + _baseResponse = response; + _asyncEnumerable = asyncEnumerableProcessor.Invoke(_baseResponse.ContentStream); + } + + public static StreamingClientResult Create( + PipelineResponse response, + Func> multiElementJsonDeserializerFunc, + CancellationToken cancellationToken = default) + { + return new(response, (stream) + => EnumerateFromJsonStream(stream, multiElementJsonDeserializerFunc, cancellationToken)); + } + + public static StreamingClientResult Create( + PipelineResponse response, + CancellationToken cancellationToken = default) + where U : IJsonModel + { + return new(response, (stream) => EnumerateFromJsonStream( + stream, + (sseChunkElement) => + { + BinaryData sseData = BinaryData.FromObjectAsJson(sseChunkElement.GetRawText()); + return [ModelReaderWriter.Read(sseData)]; + }, + cancellationToken)); + } + + private static async IAsyncEnumerable EnumerateFromJsonStream( + Stream contentStream, + Func> multiElementJsonDeserializerFunc, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + try + { + using SseReader sseReader = new(contentStream); + while (!cancellationToken.IsCancellationRequested) + { + ServerSentEvent? sseEvent = await sseReader + .TryGetNextEventAsync(cancellationToken) + .ConfigureAwait(false); + if (sseEvent is null) + { + break; + } + else + { + if (IsWellKnownDoneToken(sseEvent.Value.Data)) continue; + using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Value.Data); + IEnumerable sseDataItems = multiElementJsonDeserializerFunc.Invoke(sseDocument.RootElement); + foreach (U item in sseDataItems) + { + yield return item; + } + } + } + } + finally + { + // Always dispose the stream immediately once enumeration is complete for any reason + contentStream.Dispose(); + } + } + + /// + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _baseResponse?.Dispose(); + } + _disposedValue = true; + } + } + + private static bool IsWellKnownDoneToken(ReadOnlyMemory data) + { + ReadOnlyMemory[] wellKnownTokens = + [ + "[DONE]".AsMemory(), + ]; + return wellKnownTokens.Any(token => data.Span.SequenceEqual(token.Span)); + } + + IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) + => _asyncEnumerable.GetAsyncEnumerator(cancellationToken); +} diff --git a/.dotnet/src/Utility/StreamingResult.cs b/.dotnet/src/Utility/StreamingResult.cs deleted file mode 100644 index a1b6ff538..000000000 --- a/.dotnet/src/Utility/StreamingResult.cs +++ /dev/null @@ -1,95 +0,0 @@ -using System.ClientModel; -using System.ClientModel; -using System.ClientModel.Primitives; -using System.Threading; -using System.Collections.Generic; -using System; - -namespace OpenAI; - -/// -/// Represents an operation response with streaming content that can be deserialized and enumerated while the response -/// is still being received. -/// -/// The data type representative of distinct, streamable items. -public class StreamingClientResult - : IDisposable - , IAsyncEnumerable -{ - private ClientResult _rawResult { get; } - private IAsyncEnumerable _asyncEnumerableSource { get; } - private bool _disposedValue { get; set; } - - private StreamingClientResult() { } - - private StreamingClientResult( - ClientResult rawResult, - Func> asyncEnumerableProcessor) - { - _rawResult = rawResult; - _asyncEnumerableSource = asyncEnumerableProcessor.Invoke(rawResult); - } - - /// - /// Creates a new instance of using the provided underlying HTTP response. The - /// provided function will be used to resolve the response into an asynchronous enumeration of streamed response - /// items. - /// - /// The HTTP response. - /// - /// The function that will resolve the provided response into an IAsyncEnumerable. - /// - /// - /// A new instance of that will be capable of asynchronous enumeration of - /// items from the HTTP response. - /// - internal static StreamingClientResult CreateFromResponse( - ClientResult result, - Func> asyncEnumerableProcessor) - { - return new(result, asyncEnumerableProcessor); - } - - /// - /// Gets the underlying instance that this may enumerate - /// over. - /// - /// The instance attached to this . - public PipelineResponse GetRawResponse() => _rawResult.GetRawResponse(); - - /// - /// Gets the asynchronously enumerable collection of distinct, streamable items in the response. - /// - /// - /// The return value of this method may be used with the "await foreach" statement. - /// - /// As explicitly implements , callers may - /// enumerate a instance directly instead of calling this method. - /// - /// - /// - public IAsyncEnumerable EnumerateValues() => this; - - /// - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - /// - protected virtual void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _rawResult?.GetRawResponse()?.Dispose(); - } - _disposedValue = true; - } - } - - IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) - => _asyncEnumerableSource.GetAsyncEnumerator(cancellationToken); -} \ No newline at end of file From 2750a56c1b4b434b7ac3782cd6605e242133a85e Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Mon, 29 Apr 2024 20:20:29 -0700 Subject: [PATCH 2/3] comments --- .dotnet/src/Utility/StreamingClientResult.cs | 32 ++++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/.dotnet/src/Utility/StreamingClientResult.cs b/.dotnet/src/Utility/StreamingClientResult.cs index 5e91fa233..2217241f0 100644 --- a/.dotnet/src/Utility/StreamingClientResult.cs +++ b/.dotnet/src/Utility/StreamingClientResult.cs @@ -40,21 +40,47 @@ private StreamingClientResult( _asyncEnumerable = asyncEnumerableProcessor.Invoke(_baseResponse.ContentStream); } + /// + /// Creates a new instance of that will yield items of the specified type + /// as they become available via server-sent event JSON data on the available + /// . This overload supports deserializing multiple instaces of + /// per server-sent event using the provided multi-element deserialization delegate. + /// + /// The base for this result instance. + /// + /// The delegate that will be used to extract a collection of elements from each incoming JSON data payload. + /// + /// + /// The optional cancellation token used to control the enumeration. + /// + /// A new instance of . public static StreamingClientResult Create( PipelineResponse response, Func> multiElementJsonDeserializerFunc, CancellationToken cancellationToken = default) { return new(response, (stream) - => EnumerateFromJsonStream(stream, multiElementJsonDeserializerFunc, cancellationToken)); + => EnumerateFromSseJsonStream(stream, multiElementJsonDeserializerFunc, cancellationToken)); } + /// + /// Creates a new instance of that will yield items of the specified type + /// as they become available via server-sent event JSON data on the available + /// . This overload uses via the + /// interface and only supports single-item deserialization per server-sent event data + /// payload. + /// + /// The base for this result instance. + /// + /// The optional cancellation token used to control the enumeration. + /// + /// A new instance of . public static StreamingClientResult Create( PipelineResponse response, CancellationToken cancellationToken = default) where U : IJsonModel { - return new(response, (stream) => EnumerateFromJsonStream( + return new(response, (stream) => EnumerateFromSseJsonStream( stream, (sseChunkElement) => { @@ -64,7 +90,7 @@ public static StreamingClientResult Create( cancellationToken)); } - private static async IAsyncEnumerable EnumerateFromJsonStream( + private static async IAsyncEnumerable EnumerateFromSseJsonStream( Stream contentStream, Func> multiElementJsonDeserializerFunc, [EnumeratorCancellation] CancellationToken cancellationToken = default) From e225ab52c4b7f49eb3290b333a6bb6722cd25a67 Mon Sep 17 00:00:00 2001 From: Travis Wilson Date: Tue, 30 Apr 2024 12:40:20 -0700 Subject: [PATCH 3/3] partial pattern fixup --- .dotnet/src/Custom/Chat/ChatClient.cs | 38 +---- ...mingChatUpdateCollection.Serialization.cs} | 34 ++--- .../Chat/StreamingChatUpdateCollection.cs | 12 ++ .../Utility/AsyncServerSentEventEnumerator.cs | 64 ++++++++ .../AsyncServerSentEventJsonDataEnumerator.cs | 68 +++++++++ .dotnet/src/Utility/ServerSentEventReader.cs | 123 +++++++++++++++ .dotnet/src/Utility/StreamingClientResult.cs | 140 +++++------------- 7 files changed, 326 insertions(+), 153 deletions(-) rename .dotnet/src/Custom/Chat/{StreamingChatUpdate.Serialization.cs => StreamingChatUpdateCollection.Serialization.cs} (82%) create mode 100644 .dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs create mode 100644 .dotnet/src/Utility/AsyncServerSentEventEnumerator.cs create mode 100644 .dotnet/src/Utility/AsyncServerSentEventJsonDataEnumerator.cs create mode 100644 .dotnet/src/Utility/ServerSentEventReader.cs diff --git a/.dotnet/src/Custom/Chat/ChatClient.cs b/.dotnet/src/Custom/Chat/ChatClient.cs index 4d115ba80..67cb06ed4 100644 --- a/.dotnet/src/Custom/Chat/ChatClient.cs +++ b/.dotnet/src/Custom/Chat/ChatClient.cs @@ -204,23 +204,8 @@ public virtual StreamingClientResult CompleteChatStreaming( { throw new ClientResultException(requestMessage.Response); } - - PipelineResponse response = null; - try - { - response = requestMessage.ExtractResponse(); - ClientResult genericResult = ClientResult.FromResponse(response); - StreamingClientResult streamingResult - = StreamingClientResult.Create( - response, - (sseJson) => StreamingChatUpdate.DeserializeStreamingChatUpdates(sseJson, options: default)); - response = null; - return streamingResult; - } - finally - { - response?.Dispose(); - } + return StreamingClientResult + .Create(requestMessage.Response); } /// @@ -250,23 +235,8 @@ public virtual async Task> CompleteCh { throw new ClientResultException(requestMessage.Response); } - - PipelineResponse response = null; - try - { - response = requestMessage.ExtractResponse(); - ClientResult genericResult = ClientResult.FromResponse(response); - StreamingClientResult streamingResult - = StreamingClientResult.Create( - response, - (sseJson) => StreamingChatUpdate.DeserializeStreamingChatUpdates(sseJson, options: default)); - response = null; - return streamingResult; - } - finally - { - response?.Dispose(); - } + return StreamingClientResult + .Create(requestMessage.Response); } private Internal.Models.CreateChatCompletionRequest CreateInternalRequest( diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdate.Serialization.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.Serialization.cs similarity index 82% rename from .dotnet/src/Custom/Chat/StreamingChatUpdate.Serialization.cs rename to .dotnet/src/Custom/Chat/StreamingChatUpdateCollection.Serialization.cs index bc1a53ef8..224878b72 100644 --- a/.dotnet/src/Custom/Chat/StreamingChatUpdate.Serialization.cs +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.Serialization.cs @@ -5,30 +5,30 @@ namespace OpenAI.Chat; -public partial class StreamingChatUpdate : IJsonModel> +internal partial class StreamingChatUpdateCollection : IJsonModel { - IEnumerable IJsonModel>.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) - => ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdates, ref reader, options); + StreamingChatUpdateCollection IJsonModel.Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + => ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdateCollection, ref reader, options); - IEnumerable IPersistableModel>.Create(BinaryData data, ModelReaderWriterOptions options) - => ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdates, data, options); + StreamingChatUpdateCollection IPersistableModel.Create(BinaryData data, ModelReaderWriterOptions options) + => ModelSerializationHelpers.DeserializeNewInstance(this, DeserializeStreamingChatUpdateCollection, data, options); - void IJsonModel>.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) - => ModelSerializationHelpers.SerializeInstance, StreamingChatUpdate>(this, SerializeStreamingChatUpdates, writer, options); + void IJsonModel.Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + => ModelSerializationHelpers.SerializeInstance(this, SerializeStreamingChatUpdateCollections, writer, options); - BinaryData IPersistableModel>.Write(ModelReaderWriterOptions options) - => ModelSerializationHelpers.SerializeInstance, StreamingChatUpdate>(this, options); + BinaryData IPersistableModel.Write(ModelReaderWriterOptions options) + => ModelSerializationHelpers.SerializeInstance(this, options); - string IPersistableModel>.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; + string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; - internal static IEnumerable DeserializeStreamingChatUpdates( + internal static StreamingChatUpdateCollection DeserializeStreamingChatUpdateCollection( JsonElement sseDataJson, ModelReaderWriterOptions options = default) { List results = []; if (sseDataJson.ValueKind == JsonValueKind.Null) { - return results; + return new(results); } string id = default; DateTimeOffset created = default; @@ -151,7 +151,7 @@ Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs } foreach (StreamingToolCallUpdate toolCallUpdate in toolCallUpdates) { - results.Add(new StreamingChatUpdate( + results.Add(new( id, created, systemFingerprint, @@ -169,13 +169,13 @@ Internal.Models.CreateChatCompletionResponseChoiceLogprobs internalLogprobs } if (results.Count == 0) { - results.Add(new StreamingChatUpdate(id, created, systemFingerprint)); + results.Add(new(id, created, systemFingerprint)); } - return results; + return new(results); } - internal static void SerializeStreamingChatUpdates(StreamingChatUpdate StreamingChatUpdate, Utf8JsonWriter writer, ModelReaderWriterOptions options) + internal static void SerializeStreamingChatUpdateCollections(StreamingChatUpdateCollection StreamingChatUpdateCollection, Utf8JsonWriter writer, ModelReaderWriterOptions options) { - throw new NotSupportedException(nameof(StreamingChatUpdate)); + throw new NotSupportedException(nameof(StreamingChatUpdateCollection)); } } diff --git a/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs new file mode 100644 index 000000000..d0e2d79cb --- /dev/null +++ b/.dotnet/src/Custom/Chat/StreamingChatUpdateCollection.cs @@ -0,0 +1,12 @@ +namespace OpenAI.Chat; + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Text.Json; + +internal partial class StreamingChatUpdateCollection : ReadOnlyCollection +{ + internal StreamingChatUpdateCollection() : this([]) { } + internal StreamingChatUpdateCollection(IList list) : base(list) { } +} diff --git a/.dotnet/src/Utility/AsyncServerSentEventEnumerator.cs b/.dotnet/src/Utility/AsyncServerSentEventEnumerator.cs new file mode 100644 index 000000000..934773cf8 --- /dev/null +++ b/.dotnet/src/Utility/AsyncServerSentEventEnumerator.cs @@ -0,0 +1,64 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI; + +internal sealed class AsyncServerSentEventEnumerator : IAsyncEnumerator, IDisposable, IAsyncDisposable +{ + private static readonly ReadOnlyMemory _doneToken = "[DONE]".AsMemory(); + + private readonly ServerSentEventReader _reader; + private CancellationToken _cancellationToken; + private bool _disposedValue; + + public ServerSentEvent Current { get; private set; } + + public AsyncServerSentEventEnumerator(ServerSentEventReader reader, CancellationToken cancellationToken = default) + { + _reader = reader; + _cancellationToken = cancellationToken; + } + + public async ValueTask MoveNextAsync() + { + ServerSentEvent? nextEvent = await _reader.TryGetNextEventAsync(_cancellationToken).ConfigureAwait(false); + if (nextEvent.HasValue) + { + if (nextEvent.Value.Data.Span.SequenceEqual(_doneToken.Span)) + { + return false; + } + Current = nextEvent.Value; + return true; + } + return false; + } + + private void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _reader.Dispose(); + } + + _disposedValue = true; + } + } + + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + public ValueTask DisposeAsync() + { + Dispose(); + return new ValueTask(); + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/AsyncServerSentEventJsonDataEnumerator.cs b/.dotnet/src/Utility/AsyncServerSentEventJsonDataEnumerator.cs new file mode 100644 index 000000000..cd915b882 --- /dev/null +++ b/.dotnet/src/Utility/AsyncServerSentEventJsonDataEnumerator.cs @@ -0,0 +1,68 @@ +using System; +using System.ClientModel.Primitives; +using System.Collections.Generic; +using System.Text.Json; +using System.Threading.Tasks; + +namespace OpenAI; + +internal class AsyncServerSentEventJsonDataEnumerator : AsyncServerSentEventJsonDataEnumerator + where T : IJsonModel +{ + public AsyncServerSentEventJsonDataEnumerator(AsyncServerSentEventEnumerator eventEnumerator) : base(eventEnumerator) + { } +} + +internal class AsyncServerSentEventJsonDataEnumerator : IAsyncEnumerator, IDisposable, IAsyncDisposable + where TJsonDataType : IJsonModel +{ + private AsyncServerSentEventEnumerator _eventEnumerator; + private IEnumerator _currentInstanceEnumerator; + + public TInstanceType Current { get; private set; } + + public AsyncServerSentEventJsonDataEnumerator(AsyncServerSentEventEnumerator eventEnumerator) + { + _eventEnumerator = eventEnumerator; + } + + public async ValueTask MoveNextAsync() + { + if (_currentInstanceEnumerator?.MoveNext() == true) + { + Current = _currentInstanceEnumerator.Current; + return true; + } + if (await _eventEnumerator.MoveNextAsync()) + { + using JsonDocument eventDocument = JsonDocument.Parse(_eventEnumerator.Current.Data); + BinaryData eventData = BinaryData.FromObjectAsJson(eventDocument.RootElement); + TJsonDataType jsonData = ModelReaderWriter.Read(eventData); + if (jsonData is TInstanceType singleInstanceData) + { + Current = singleInstanceData; + return true; + } + if (jsonData is IEnumerable instanceCollectionData) + { + _currentInstanceEnumerator = instanceCollectionData.GetEnumerator(); + if (_currentInstanceEnumerator.MoveNext() == true) + { + Current = _currentInstanceEnumerator.Current; + return true; + } + } + } + return false; + } + + public async ValueTask DisposeAsync() + { + await _eventEnumerator.DisposeAsync(); + } + + public void Dispose() + { + _eventEnumerator.Dispose(); + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/ServerSentEventReader.cs b/.dotnet/src/Utility/ServerSentEventReader.cs new file mode 100644 index 000000000..d47938254 --- /dev/null +++ b/.dotnet/src/Utility/ServerSentEventReader.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenAI; + +internal sealed class ServerSentEventReader : IDisposable +{ + private readonly Stream _stream; + private readonly StreamReader _reader; + private bool _disposedValue; + + public ServerSentEventReader(Stream stream) + { + _stream = stream; + _reader = new StreamReader(stream); + } + + /// + /// Synchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public ServerSentEvent? TryGetNextEvent(CancellationToken cancellationToken = default) + { + List fields = []; + + while (!cancellationToken.IsCancellationRequested) + { + string line = _reader.ReadLine(); + if (line == null) + { + // A null line indicates end of input + return null; + } + else if (line.Length == 0) + { + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + fields = []; + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); + } + } + + return null; + } + + /// + /// Asynchronously retrieves the next server-sent event from the underlying stream, blocking until a new event is + /// available and returning null once no further data is present on the stream. + /// + /// An optional cancellation token that can abort subsequent reads. + /// + /// The next in the stream, or null once no more data can be read from the stream. + /// + public async Task TryGetNextEventAsync(CancellationToken cancellationToken = default) + { + List fields = []; + + while (!cancellationToken.IsCancellationRequested) + { + string line = await _reader.ReadLineAsync().ConfigureAwait(false); + if (line == null) + { + // A null line indicates end of input + return null; + } + else if (line.Length == 0) + { + // An empty line should dispatch an event for pending accumulated fields + ServerSentEvent nextEvent = new(fields); + return nextEvent; + } + else if (line[0] == ':') + { + // A line beginning with a colon is a comment and should be ignored + continue; + } + else + { + // Otherwise, process the the field + value and accumulate it for the next dispatched event + fields.Add(new ServerSentEventField(line)); + } + } + + return null; + } + + public void Dispose() + { + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _reader.Dispose(); + _stream.Dispose(); + } + + _disposedValue = true; + } + } +} \ No newline at end of file diff --git a/.dotnet/src/Utility/StreamingClientResult.cs b/.dotnet/src/Utility/StreamingClientResult.cs index 2217241f0..57c43a01e 100644 --- a/.dotnet/src/Utility/StreamingClientResult.cs +++ b/.dotnet/src/Utility/StreamingClientResult.cs @@ -6,6 +6,7 @@ using System.Runtime.CompilerServices; using System.Text.Json; using System.Threading; +using System.Threading.Tasks; namespace OpenAI; @@ -14,12 +15,11 @@ namespace OpenAI; /// is still being received. /// /// The data type representative of distinct, streamable items. -public class StreamingClientResult - : IDisposable - , IAsyncEnumerable +public class StreamingClientResult : IAsyncEnumerable { - private readonly PipelineResponse _baseResponse; - private readonly IAsyncEnumerable _asyncEnumerable; + private readonly PipelineResponse _response; + private readonly Func> _asyncEnumeratorSourceDelegate; + private bool _disposedValue; /// @@ -28,39 +28,12 @@ public class StreamingClientResult /// /// The instance used in this . /// - public PipelineResponse GetRawResponse() => _baseResponse; - - private StreamingClientResult() { } - - private StreamingClientResult( - PipelineResponse response, - Func> asyncEnumerableProcessor) - { - _baseResponse = response; - _asyncEnumerable = asyncEnumerableProcessor.Invoke(_baseResponse.ContentStream); - } + public PipelineResponse GetRawResponse() => _response; - /// - /// Creates a new instance of that will yield items of the specified type - /// as they become available via server-sent event JSON data on the available - /// . This overload supports deserializing multiple instaces of - /// per server-sent event using the provided multi-element deserialization delegate. - /// - /// The base for this result instance. - /// - /// The delegate that will be used to extract a collection of elements from each incoming JSON data payload. - /// - /// - /// The optional cancellation token used to control the enumeration. - /// - /// A new instance of . - public static StreamingClientResult Create( - PipelineResponse response, - Func> multiElementJsonDeserializerFunc, - CancellationToken cancellationToken = default) + private StreamingClientResult(PipelineResponse response, Func> asyncEnumeratorSourceDelegate) { - return new(response, (stream) - => EnumerateFromSseJsonStream(stream, multiElementJsonDeserializerFunc, cancellationToken)); + _response = response; + _asyncEnumeratorSourceDelegate = asyncEnumeratorSourceDelegate; } /// @@ -75,86 +48,49 @@ public static StreamingClientResult Create( /// The optional cancellation token used to control the enumeration. /// /// A new instance of . - public static StreamingClientResult Create( - PipelineResponse response, - CancellationToken cancellationToken = default) - where U : IJsonModel + public static StreamingClientResult Create(PipelineResponse response, CancellationToken cancellationToken = default) + where U : IJsonModel { - return new(response, (stream) => EnumerateFromSseJsonStream( - stream, - (sseChunkElement) => - { - BinaryData sseData = BinaryData.FromObjectAsJson(sseChunkElement.GetRawText()); - return [ModelReaderWriter.Read(sseData)]; - }, - cancellationToken)); + return new(response, GetServerSentEventDeserializationEnumerator); } - private static async IAsyncEnumerable EnumerateFromSseJsonStream( - Stream contentStream, - Func> multiElementJsonDeserializerFunc, - [EnumeratorCancellation] CancellationToken cancellationToken = default) + public static StreamingClientResult Create(PipelineResponse response, CancellationToken cancellationToken = default) + where TJsonDataType : IJsonModel { - try - { - using SseReader sseReader = new(contentStream); - while (!cancellationToken.IsCancellationRequested) - { - ServerSentEvent? sseEvent = await sseReader - .TryGetNextEventAsync(cancellationToken) - .ConfigureAwait(false); - if (sseEvent is null) - { - break; - } - else - { - if (IsWellKnownDoneToken(sseEvent.Value.Data)) continue; - using JsonDocument sseDocument = JsonDocument.Parse(sseEvent.Value.Data); - IEnumerable sseDataItems = multiElementJsonDeserializerFunc.Invoke(sseDocument.RootElement); - foreach (U item in sseDataItems) - { - yield return item; - } - } - } - } - finally - { - // Always dispose the stream immediately once enumeration is complete for any reason - contentStream.Dispose(); - } + return new(response, GetServerSentEventDeserializationEnumerator); } - /// - public void Dispose() + private static IAsyncEnumerator GetServerSentEventDeserializationEnumerator(Stream stream, CancellationToken cancellationToken = default) + where U : IJsonModel { - Dispose(disposing: true); - GC.SuppressFinalize(this); + return GetServerSentEventDeserializationEnumerator(stream, cancellationToken); } - /// - protected virtual void Dispose(bool disposing) + private static IAsyncEnumerator GetServerSentEventDeserializationEnumerator( + Stream stream, + CancellationToken cancellationToken = default) + where TJsonDataType : IJsonModel { - if (!_disposedValue) + ServerSentEventReader sseReader = null; + AsyncServerSentEventEnumerator sseEnumerator = null; + try + { + sseReader = new(stream); + sseEnumerator = new(sseReader, cancellationToken); + AsyncServerSentEventJsonDataEnumerator instanceEnumerator = new(sseEnumerator); + sseEnumerator = null; + sseReader = null; + return instanceEnumerator; + } + finally { - if (disposing) - { - _baseResponse?.Dispose(); - } - _disposedValue = true; + sseEnumerator?.Dispose(); + sseReader?.Dispose(); } } - private static bool IsWellKnownDoneToken(ReadOnlyMemory data) + IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) { - ReadOnlyMemory[] wellKnownTokens = - [ - "[DONE]".AsMemory(), - ]; - return wellKnownTokens.Any(token => data.Span.SequenceEqual(token.Span)); + return _asyncEnumeratorSourceDelegate.Invoke(_response.ContentStream, cancellationToken); } - - IAsyncEnumerator IAsyncEnumerable.GetAsyncEnumerator(CancellationToken cancellationToken) - => _asyncEnumerable.GetAsyncEnumerator(cancellationToken); }