diff --git a/src/Custom/Embeddings/OpenAIEmbedding.cs b/src/Custom/Embeddings/OpenAIEmbedding.cs index ac644c843..cbb581e29 100644 --- a/src/Custom/Embeddings/OpenAIEmbedding.cs +++ b/src/Custom/Embeddings/OpenAIEmbedding.cs @@ -118,32 +118,36 @@ private static ReadOnlyMemory ConvertToVectorOfFloats(BinaryData binaryDa // Decode base64 string to bytes. byte[] bytes = ArrayPool.Shared.Rent(Base64.GetMaxDecodedFromUtf8Length(base64.Length)); - OperationStatus status = Base64.DecodeFromUtf8(base64, bytes.AsSpan(), out int bytesConsumed, out int bytesWritten); - if (status != OperationStatus.Done || bytesWritten % sizeof(float) != 0) - { - ThrowInvalidData(); - } - - // Interpret bytes as floats - float[] vector = new float[bytesWritten / sizeof(float)]; - bytes.AsSpan(0, bytesWritten).CopyTo(MemoryMarshal.AsBytes(vector.AsSpan())); - if (!BitConverter.IsLittleEndian) - { - Span ints = MemoryMarshal.Cast(vector.AsSpan()); + try + { + OperationStatus status = Base64.DecodeFromUtf8(base64, bytes.AsSpan(), out int bytesConsumed, out int bytesWritten); + if (status != OperationStatus.Done || bytesWritten % sizeof(float) != 0) + { + ThrowInvalidData(); + } + + // Interpret bytes as floats + float[] vector = new float[bytesWritten / sizeof(float)]; + bytes.AsSpan(0, bytesWritten).CopyTo(MemoryMarshal.AsBytes(vector.AsSpan())); + if (!BitConverter.IsLittleEndian) + { + Span ints = MemoryMarshal.Cast(vector.AsSpan()); #if NET8_0_OR_GREATER - BinaryPrimitives.ReverseEndianness(ints, ints); + BinaryPrimitives.ReverseEndianness(ints, ints); #else for (int i = 0; i < ints.Length; i++) { ints[i] = BinaryPrimitives.ReverseEndianness(ints[i]); } #endif + } + return new ReadOnlyMemory(vector); + } + finally + { + ArrayPool.Shared.Return(bytes); } - - ArrayPool.Shared.Return(bytes); - return new ReadOnlyMemory(vector); - - static void ThrowInvalidData() => + } + static void ThrowInvalidData() => throw new FormatException("The input is not a valid Base64 string of encoded floats."); - } } diff --git a/src/Custom/RealtimeConversation/Internal/AsyncWebsocketMessageEnumerator.cs b/src/Custom/RealtimeConversation/Internal/AsyncWebsocketMessageEnumerator.cs index 3df40ef6f..906823ff0 100644 --- a/src/Custom/RealtimeConversation/Internal/AsyncWebsocketMessageEnumerator.cs +++ b/src/Custom/RealtimeConversation/Internal/AsyncWebsocketMessageEnumerator.cs @@ -26,6 +26,7 @@ public AsyncWebsocketMessageResultEnumerator(WebSocket webSocket, CancellationTo public ValueTask DisposeAsync() { + ArrayPool.Shared.Return(_receiveBuffer); _webSocket?.Dispose(); return new ValueTask(Task.CompletedTask); } @@ -50,4 +51,4 @@ public async ValueTask MoveNextAsync() Current = ClientResult.FromResponse(websocketPipelineResponse); return true; } -} \ No newline at end of file +} diff --git a/src/Custom/RealtimeConversation/RealtimeConversationSession.cs b/src/Custom/RealtimeConversation/RealtimeConversationSession.cs index 31a4ef50a..f27bc600c 100644 --- a/src/Custom/RealtimeConversation/RealtimeConversationSession.cs +++ b/src/Custom/RealtimeConversation/RealtimeConversationSession.cs @@ -55,9 +55,9 @@ public virtual async Task SendInputAudioAsync(Stream audio, CancellationToken ca } _isSendingAudioStream = true; } + byte[] buffer = ArrayPool.Shared.Rent(1024 * 16); try { - byte[] buffer = ArrayPool.Shared.Rent(1024 * 16); while (true) { int bytesRead = await audio.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); @@ -75,6 +75,7 @@ public virtual async Task SendInputAudioAsync(Stream audio, CancellationToken ca } finally { + ArrayPool.Shared.Return(buffer); using (await _audioSendSemaphore.AutoReleaseWaitAsync(cancellationToken).ConfigureAwait(false)) { _isSendingAudioStream = false; @@ -93,9 +94,9 @@ public virtual void SendInputAudio(Stream audio, CancellationToken cancellationT } _isSendingAudioStream = true; } + byte[] buffer = ArrayPool.Shared.Rent(1024 * 16); try { - byte[] buffer = ArrayPool.Shared.Rent(1024 * 16); while (true) { int bytesRead = audio.Read(buffer, 0, buffer.Length); @@ -113,6 +114,7 @@ public virtual void SendInputAudio(Stream audio, CancellationToken cancellationT } finally { + ArrayPool.Shared.Return(buffer); using (_audioSendSemaphore.AutoReleaseWait(cancellationToken)) { _isSendingAudioStream = false; @@ -349,4 +351,4 @@ public void Dispose() { WebSocket?.Dispose(); } -} \ No newline at end of file +} diff --git a/tests/RealtimeConversation/ConversationTests.cs b/tests/RealtimeConversation/ConversationTests.cs index 9c855d009..66141cf44 100644 --- a/tests/RealtimeConversation/ConversationTests.cs +++ b/tests/RealtimeConversation/ConversationTests.cs @@ -341,7 +341,6 @@ public async Task AudioWithToolsWorks(TestAudioSendType audioSendType) { byte[] allAudioBytes = await File.ReadAllBytesAsync(inputAudioFilePath, CancellationToken); const int audioSendBufferLength = 8 * 1024; - byte[] audioSendBuffer = ArrayPool.Shared.Rent(audioSendBufferLength); for (int readPos = 0; readPos < allAudioBytes.Length; readPos += audioSendBufferLength) { int nextSegmentLength = Math.Min(audioSendBufferLength, allAudioBytes.Length - readPos);