Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk/ai/Azure.AI.Project/src/Custom/Agent/AgentClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ public virtual Task<Response<ThreadRun>> CreateRunAsync(AgentThread thread, Agen
/// <param name="cancellationToken"> The cancellation token to use. </param>
/// <exception cref="ArgumentNullException"> <paramref name="threadId"/> or <paramref name="assistantId"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="threadId"/> is an empty string, and was expected to be non-empty. </exception>
#pragma warning disable AZC0015 // Unexpected client method return type.
public virtual AsyncCollectionResult<StreamingUpdate> CreateRunStreamingAsync(string threadId, string assistantId, string overrideModelName = null, string overrideInstructions = null, string additionalInstructions = null, IEnumerable<ThreadMessage> additionalMessages = null, IEnumerable<ToolDefinition> overrideTools = null, float? temperature = null, float? topP = null, int? maxPromptTokens = null, int? maxCompletionTokens = null, TruncationObject truncationStrategy = null, BinaryData toolChoice = null, BinaryData responseFormat = null, IReadOnlyDictionary<string, string> metadata = null, CancellationToken cancellationToken = default)
#pragma warning restore AZC0015 // Unexpected client method return type.
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNull(assistantId, nameof(assistantId));
Expand Down Expand Up @@ -237,7 +239,9 @@ async Task<Response> sendRequestAsync() =>
/// <param name="cancellationToken"> The cancellation token to use. </param>
/// <exception cref="ArgumentNullException"> <paramref name="threadId"/> or <paramref name="assistantId"/> is null. </exception>
/// <exception cref="ArgumentException"> <paramref name="threadId"/> is an empty string, and was expected to be non-empty. </exception>
#pragma warning disable AZC0015 // Unexpected client method return type.
public virtual CollectionResult<StreamingUpdate> CreateRunStreaming(string threadId, string assistantId, string overrideModelName = null, string overrideInstructions = null, string additionalInstructions = null, IEnumerable<ThreadMessage> additionalMessages = null, IEnumerable<ToolDefinition> overrideTools = null, float? temperature = null, float? topP = null, int? maxPromptTokens = null, int? maxCompletionTokens = null, TruncationObject truncationStrategy = null, BinaryData toolChoice = null, BinaryData responseFormat = null, IReadOnlyDictionary<string, string> metadata = null, CancellationToken cancellationToken = default)
#pragma warning restore AZC0015 // Unexpected client method return type.
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNull(assistantId, nameof(assistantId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ namespace Azure.AI.Project;
/// </summary>
internal class AsyncStreamingUpdateCollection : AsyncCollectionResult<StreamingUpdate>
{
private readonly Func<Task<ClientResult>> _sendRequestAsync;
private readonly Func<Task<Response>> _sendRequestAsync;
private readonly CancellationToken _cancellationToken;

public AsyncStreamingUpdateCollection(Func<Task<ClientResult>> sendRequestAsync,
public AsyncStreamingUpdateCollection(Func<Task<Response>> sendRequestAsync,
CancellationToken cancellationToken)
{
Argument.AssertNotNull(sendRequestAsync, nameof(sendRequestAsync));
Expand All @@ -37,9 +37,12 @@ public AsyncStreamingUpdateCollection(Func<Task<ClientResult>> sendRequestAsync,

public async override IAsyncEnumerable<ClientResult> GetRawPagesAsync()
{
Response response = await _sendRequestAsync().ConfigureAwait(false);
PipelineResponse scmResponse = new ResponseAdapter(response);

// We don't currently support resuming a dropped connection from the
// last received event, so the response collection has a single element.
yield return await _sendRequestAsync().ConfigureAwait(false);
yield return ClientResult.FromResponse(scmResponse);
}

protected async override IAsyncEnumerable<StreamingUpdate> GetValuesFromPageAsync(ClientResult page)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ namespace Azure.AI.Project;
/// </summary>
internal class StreamingUpdateCollection : CollectionResult<StreamingUpdate>
{
private readonly Func<ClientResult> _sendRequest;
private readonly Func<Response> _sendRequest;
private readonly CancellationToken _cancellationToken;

public StreamingUpdateCollection(
Func<ClientResult> sendRequest,
Func<Response> sendRequest,
CancellationToken cancellationToken)
{
Argument.AssertNotNull(sendRequest, nameof(sendRequest));
Expand All @@ -37,9 +37,12 @@ public StreamingUpdateCollection(

public override IEnumerable<ClientResult> GetRawPages()
{
Response response = _sendRequest();
PipelineResponse scmResponse = new ResponseAdapter(response);

// We don't currently support resuming a dropped connection from the
// last received event, so the response collection has a single element.
yield return _sendRequest();
yield return ClientResult.FromResponse(scmResponse);
}
protected override IEnumerable<StreamingUpdate> GetValuesFromPage(ClientResult page)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public enum StreamingUpdateReason
/// Indicates that an update was generated as part of a <c>thread.created</c> event.
/// </summary>
/// <remarks> This reason is typically only associated with calls to
/// <see cref="AgentClient.CreateThreadAndRunStreaming(Agent, ThreadCreationOptions, RunCreationOptions)"/>,
/// <see cref="AgentClient.CreateThreadAndRun(string, AgentThreadCreationOptions, string, string, System.Collections.Generic.IEnumerable{ToolDefinition}, UpdateToolResourcesOptions, bool?, float?, float?, int?, int?, TruncationObject, System.BinaryData, System.BinaryData, System.Collections.Generic.IReadOnlyDictionary{string, string}, System.Threading.CancellationToken)"/>,
/// as other run-related methods operate on a thread that has previously been created.
/// </remarks>
ThreadCreated,
Expand Down
53 changes: 53 additions & 0 deletions sdk/ai/Azure.AI.Project/src/Custom/Utility/ResponseAdapter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.ClientModel.Primitives;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

#nullable enable

namespace Azure.AI.Project;

/// <summary>
/// Adapts an Azure.Core Response to an SCM PipelineResponse.
/// </summary>
internal class ResponseAdapter : PipelineResponse
{
private readonly Response _azureResponse;
private PipelineResponseHeaders? _headers;

public ResponseAdapter(Response azureResponse)
{
_azureResponse = azureResponse;
}

public override int Status => _azureResponse.Status;

public override string ReasonPhrase => _azureResponse.ReasonPhrase;

public override Stream? ContentStream
{
get => _azureResponse?.ContentStream;
set => _azureResponse.ContentStream = value;
}

public override BinaryData Content => _azureResponse.Content;

protected override PipelineResponseHeaders HeadersCore =>
_headers ??= new ResponseHeadersAdapter(_azureResponse.Headers);

public override BinaryData BufferContent(CancellationToken cancellationToken = default)
{
throw new NotSupportedException("Content buffering is not supported for SSE response streams.");
}

public override ValueTask<BinaryData> BufferContentAsync(CancellationToken cancellationToken = default)
{
throw new NotSupportedException("Content buffering is not supported for SSE response streams.");
}

public override void Dispose() => _azureResponse?.Dispose();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.ClientModel.Primitives;
using System.Collections.Generic;
using Azure.Core;

#nullable enable

namespace Azure.AI.Project;

/// <summary>
/// Adapts an Azure.Core ResponseHeaders to an SCM PipelineResponseHeaders.
/// </summary>
internal class ResponseHeadersAdapter : PipelineResponseHeaders
{
private readonly ResponseHeaders _azureHeaders;

public ResponseHeadersAdapter(ResponseHeaders azureHeaders)
{
_azureHeaders = azureHeaders;
}

public override IEnumerator<KeyValuePair<string, string>> GetEnumerator()
{
foreach (HttpHeader header in _azureHeaders)
{
yield return new KeyValuePair<string, string>(header.Name, header.Value);
}
}

public override bool TryGetValue(string name, out string? value)
=> _azureHeaders.TryGetValue(name, out value);

public override bool TryGetValues(string name, out IEnumerable<string>? values)
=> _azureHeaders.TryGetValue(name, out values);
}