Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 10 additions & 24 deletions .dotnet/OpenAI.sln
Original file line number Diff line number Diff line change
@@ -1,37 +1,19 @@
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.29709.97
# Visual Studio Version 17
VisualStudioVersion = 17.9.34902.65
MinimumVisualStudioVersion = 10.0.40219.1
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OpenAI", "src\OpenAI.csproj", "{28FF4005-4467-4E36-92E7-DEA27DEB1519}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI", "src\OpenAI.csproj", "{28FF4005-4467-4E36-92E7-DEA27DEB1519}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OpenAI.Tests", "tests\OpenAI.Tests.csproj", "{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI.Tests", "tests\OpenAI.Tests.csproj", "{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "System.ClientModel", "..\..\azure-sdk-for-net\sdk\core\System.ClientModel\src\System.ClientModel.csproj", "{2DAD1811-2A5E-4C60-80D1-B94533FD1B74}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Release|Any CPU = Release|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{B0C276D1-2930-4887-B29A-D1A33E7009A2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{B0C276D1-2930-4887-B29A-D1A33E7009A2}.Debug|Any CPU.Build.0 = Debug|Any CPU
{B0C276D1-2930-4887-B29A-D1A33E7009A2}.Release|Any CPU.ActiveCfg = Release|Any CPU
{B0C276D1-2930-4887-B29A-D1A33E7009A2}.Release|Any CPU.Build.0 = Release|Any CPU
{8E9A77AC-792A-4432-8320-ACFD46730401}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{8E9A77AC-792A-4432-8320-ACFD46730401}.Debug|Any CPU.Build.0 = Debug|Any CPU
{8E9A77AC-792A-4432-8320-ACFD46730401}.Release|Any CPU.ActiveCfg = Release|Any CPU
{8E9A77AC-792A-4432-8320-ACFD46730401}.Release|Any CPU.Build.0 = Release|Any CPU
{A4241C1F-A53D-474C-9E4E-075054407E74}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A4241C1F-A53D-474C-9E4E-075054407E74}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A4241C1F-A53D-474C-9E4E-075054407E74}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A4241C1F-A53D-474C-9E4E-075054407E74}.Release|Any CPU.Build.0 = Release|Any CPU
{FA8BD3F1-8616-47B6-974C-7576CDF4717E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{FA8BD3F1-8616-47B6-974C-7576CDF4717E}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FA8BD3F1-8616-47B6-974C-7576CDF4717E}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FA8BD3F1-8616-47B6-974C-7576CDF4717E}.Release|Any CPU.Build.0 = Release|Any CPU
{85677AD3-C214-42FA-AE6E-49B956CAC8DC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{85677AD3-C214-42FA-AE6E-49B956CAC8DC}.Debug|Any CPU.Build.0 = Debug|Any CPU
{85677AD3-C214-42FA-AE6E-49B956CAC8DC}.Release|Any CPU.ActiveCfg = Release|Any CPU
{85677AD3-C214-42FA-AE6E-49B956CAC8DC}.Release|Any CPU.Build.0 = Release|Any CPU
{28FF4005-4467-4E36-92E7-DEA27DEB1519}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{28FF4005-4467-4E36-92E7-DEA27DEB1519}.Debug|Any CPU.Build.0 = Debug|Any CPU
{28FF4005-4467-4E36-92E7-DEA27DEB1519}.Release|Any CPU.ActiveCfg = Release|Any CPU
Expand All @@ -40,6 +22,10 @@ Global
{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}.Release|Any CPU.Build.0 = Release|Any CPU
{2DAD1811-2A5E-4C60-80D1-B94533FD1B74}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{2DAD1811-2A5E-4C60-80D1-B94533FD1B74}.Debug|Any CPU.Build.0 = Debug|Any CPU
{2DAD1811-2A5E-4C60-80D1-B94533FD1B74}.Release|Any CPU.ActiveCfg = Release|Any CPU
{2DAD1811-2A5E-4C60-80D1-B94533FD1B74}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
8 changes: 4 additions & 4 deletions .dotnet/src/Custom/Assistants/AssistantClient.Convenience.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public virtual ClientResult<bool> DeleteMessage(ThreadMessage message)
/// <param name="assistant"> The assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
/// <returns> A new <see cref="ThreadRun"/> instance. </returns>
public virtual Task<ClientResult<ThreadRun>> CreateRunAsync(AssistantThread thread, Assistant assistant, RunCreationOptions options = null)
public virtual Task<StatusBasedOperation<RunStatus, ThreadRun>> CreateRunAsync(AssistantThread thread, Assistant assistant, RunCreationOptions options = null)
=> CreateRunAsync(thread?.Id, assistant?.Id, options);

/// <summary>
Expand All @@ -226,7 +226,7 @@ public virtual Task<ClientResult<ThreadRun>> CreateRunAsync(AssistantThread thre
/// <param name="assistant"> The assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
/// <returns> A new <see cref="ThreadRun"/> instance. </returns>
public virtual ClientResult<ThreadRun> CreateRun(AssistantThread thread, Assistant assistant, RunCreationOptions options = null)
public virtual StatusBasedOperation<RunStatus, ThreadRun> CreateRun(AssistantThread thread, Assistant assistant, RunCreationOptions options = null)
=> CreateRun(thread?.Id, assistant?.Id, options);

/// <summary>
Expand Down Expand Up @@ -346,15 +346,15 @@ public virtual PageableCollection<ThreadRun> GetRuns(
/// </summary>
/// <param name="run"> The run to get a refreshed instance of. </param>
/// <returns> A new <see cref="ThreadRun"/> instance with updated information. </returns>
public virtual Task<ClientResult<ThreadRun>> GetRunAsync(ThreadRun run)
internal virtual Task<ClientResult<ThreadRun>> GetRunAsync(ThreadRun run)
=> GetRunAsync(run?.ThreadId, run?.Id);

/// <summary>
/// Gets a refreshed instance of an existing <see cref="ThreadRun"/>.
/// </summary>
/// <param name="run"> The run to get a refreshed instance of. </param>
/// <returns> A new <see cref="ThreadRun"/> instance with updated information. </returns>
public virtual ClientResult<ThreadRun> GetRun(ThreadRun run)
internal virtual ClientResult<ThreadRun> GetRun(ThreadRun run)
=> GetRun(run?.ThreadId, run?.Id);

/// <summary>
Expand Down
16 changes: 10 additions & 6 deletions .dotnet/src/Custom/Assistants/AssistantClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ public virtual ClientResult<bool> DeleteMessage(string threadId, string messageI
/// <param name="assistantId"> The ID of the assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
/// <returns> A new <see cref="ThreadRun"/> instance. </returns>
public virtual async Task<ClientResult<ThreadRun>> CreateRunAsync(string threadId, string assistantId, RunCreationOptions options = null)
public virtual async Task<StatusBasedOperation<RunStatus, ThreadRun>> CreateRunAsync(string threadId, string assistantId, RunCreationOptions options = null)
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));
Expand All @@ -461,7 +461,9 @@ public virtual async Task<ClientResult<ThreadRun>> CreateRunAsync(string threadI

ClientResult protocolResult = await CreateRunAsync(threadId, options.ToBinaryContent(), null)
.ConfigureAwait(false);
return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse);
ClientResult<ThreadRun> createResult = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse);

return new AssistantRunOperation(createResult, GetRun, GetRunAsync);
}

/// <summary>
Expand All @@ -472,7 +474,7 @@ public virtual async Task<ClientResult<ThreadRun>> CreateRunAsync(string threadI
/// <param name="assistantId"> The ID of the assistant that should be used when evaluating the thread. </param>
/// <param name="options"> Additional options for the run. </param>
/// <returns> A new <see cref="ThreadRun"/> instance. </returns>
public virtual ClientResult<ThreadRun> CreateRun(string threadId, string assistantId, RunCreationOptions options = null)
public virtual StatusBasedOperation<RunStatus, ThreadRun> CreateRun(string threadId, string assistantId, RunCreationOptions options = null)
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNullOrEmpty(assistantId, nameof(assistantId));
Expand All @@ -481,7 +483,9 @@ public virtual ClientResult<ThreadRun> CreateRun(string threadId, string assista
options.Stream = null;

ClientResult protocolResult = CreateRun(threadId, options.ToBinaryContent(), null);
return CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse);
ClientResult<ThreadRun> createResult = CreateResultFromProtocol(protocolResult, ThreadRun.FromResponse);

return new AssistantRunOperation(createResult, GetRun, GetRunAsync);
}

/// <summary>
Expand Down Expand Up @@ -662,7 +666,7 @@ public virtual PageableCollection<ThreadRun> GetRuns(
/// <param name="threadId"> The ID of the thread to retrieve the run from. </param>
/// <param name="runId"> The ID of the run to retrieve. </param>
/// <returns> The existing <see cref="ThreadRun"/> instance. </returns>
public virtual async Task<ClientResult<ThreadRun>> GetRunAsync(string threadId, string runId)
internal virtual async Task<ClientResult<ThreadRun>> GetRunAsync(string threadId, string runId)
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNullOrEmpty(runId, nameof(runId));
Expand All @@ -677,7 +681,7 @@ public virtual async Task<ClientResult<ThreadRun>> GetRunAsync(string threadId,
/// <param name="threadId"> The ID of the thread to retrieve the run from. </param>
/// <param name="runId"> The ID of the run to retrieve. </param>
/// <returns> The existing <see cref="ThreadRun"/> instance. </returns>
public virtual ClientResult<ThreadRun> GetRun(string threadId, string runId)
internal virtual ClientResult<ThreadRun> GetRun(string threadId, string runId)
{
Argument.AssertNotNullOrEmpty(threadId, nameof(threadId));
Argument.AssertNotNullOrEmpty(runId, nameof(runId));
Expand Down
210 changes: 210 additions & 0 deletions .dotnet/src/Custom/Assistants/AssistantRunOperation.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Threading;
using System.Threading.Tasks;

#nullable enable

namespace OpenAI.Assistants;

// TODO: add hooks for cancel run?

internal class AssistantRunOperation : StatusBasedOperation<RunStatus, ThreadRun>
{
private static readonly TimeSpan DefaultPollingInterval = TimeSpan.FromSeconds(2);

private readonly string _threadId;
private readonly string _runId;

private readonly Func<string, string, ClientResult<ThreadRun>> _getRun;
private readonly Func<string, string, Task<ClientResult<ThreadRun>>> _getRunAsync;

private ClientResult<ThreadRun> _lastSeenResult;

private bool _statusChanged;
private bool _paused;

public AssistantRunOperation(ClientResult<ThreadRun> createResult,
Func<string, string, ClientResult<ThreadRun>> getRun,
Func<string, string, Task<ClientResult<ThreadRun>>> getRunAsync) :
base(GetIdFromResult(createResult), createResult.Value.Status, GetResponseFromResult(createResult))
{
_lastSeenResult = createResult;
Value = _lastSeenResult.Value;

_threadId = createResult.Value.ThreadId;
_runId = createResult.Value.Id;

_getRun = getRun;
_getRunAsync = getRunAsync;
}

private static string GetIdFromResult(ClientResult<ThreadRun> result)
{
// TODO: validate this is reversible -- i.e. it will parse
return $"{result.Value.ThreadId};{result.Value.Id}";
}

private static PipelineResponse GetResponseFromResult(ClientResult<ThreadRun> result)
=> result.GetRawResponse();

public override ClientResult UpdateStatus()
{
if (HasCompleted)
{
return _lastSeenResult;
}

ClientResult<ThreadRun> result = _getRun(_threadId, _runId);

// Compute delta between result and _lastSeenResult
if (_lastSeenResult.Value.Status != result.Value.Status)
{
Status = result.Value.Status;
_statusChanged = true;
}

_lastSeenResult = result;

Value = result.Value;

if (result.Value.Status.IsTerminal)
{
HasCompleted = true;
}

SetRawResponse(result.GetRawResponse());

return result;
}

public override async ValueTask<ClientResult> UpdateStatusAsync()
{
if (HasCompleted)
{
return _lastSeenResult;
}

_lastSeenResult = await _getRunAsync(_threadId, _runId).ConfigureAwait(false);

Value = _lastSeenResult.Value;

if (_lastSeenResult.Value.Status.IsTerminal)
{
HasCompleted = true;
}

SetRawResponse(_lastSeenResult.GetRawResponse());

return _lastSeenResult;
}

public override ClientResult<ThreadRun> WaitForCompletion(TimeSpan? pollingInterval = default, CancellationToken cancellationToken = default)
{
pollingInterval ??= DefaultPollingInterval;

while (true)
{
cancellationToken.ThrowIfCancellationRequested();

if (!_paused)
{
UpdateStatus();

if (HasCompleted)
{
return _lastSeenResult;
}
}

// TODO: note pollling interval logic may change for e.g. exponential
// backoff if the operation is paused.
cancellationToken.WaitHandle.WaitOne(pollingInterval.Value);
}
}

public override async ValueTask<ClientResult<ThreadRun>> WaitForCompletionAsync(TimeSpan? pollingInterval = default, CancellationToken cancellationToken = default)
{
pollingInterval ??= DefaultPollingInterval;

while (true)
{
cancellationToken.ThrowIfCancellationRequested();

if (!_paused)
{
await UpdateStatusAsync().ConfigureAwait(false);

if (HasCompleted)
{
return _lastSeenResult;
}
}

await Task.Delay(pollingInterval.Value, cancellationToken).ConfigureAwait(false);
}
}

public override ClientResult WaitForCompletionResult(CancellationToken cancellationToken = default)
=> WaitForCompletion(DefaultPollingInterval, cancellationToken);

public override ClientResult WaitForCompletionResult(TimeSpan pollingInterval, CancellationToken cancellationToken = default)
=> WaitForCompletion(pollingInterval, cancellationToken);

public override async ValueTask<ClientResult> WaitForCompletionResultAsync(CancellationToken cancellationToken = default)
=> await WaitForCompletionAsync(DefaultPollingInterval, cancellationToken).ConfigureAwait(false);


public override async ValueTask<ClientResult> WaitForCompletionResultAsync(TimeSpan pollingInterval, CancellationToken cancellationToken = default)
=> await WaitForCompletionAsync(pollingInterval, cancellationToken).ConfigureAwait(false);

public override async ValueTask<ClientResult<(RunStatus Status, ThreadRun? Value)>> WaitForStatusUpdateAsync(TimeSpan? pollingInterval = null, CancellationToken cancellationToken = default)
{
pollingInterval ??= DefaultPollingInterval;
while (true)
{
cancellationToken.ThrowIfCancellationRequested();

if (!_paused)
{
ClientResult result = await UpdateStatusAsync().ConfigureAwait(false);

if (_statusChanged)
{
(RunStatus Status, ThreadRun? Value) tuple = new(_lastSeenResult.Value.Status, _lastSeenResult.Value);
return FromValue(tuple, result.GetRawResponse());
}
}

await Task.Delay(pollingInterval.Value, cancellationToken).ConfigureAwait(false);
}
}

public override ClientResult<(RunStatus Status, ThreadRun? Value)> WaitForStatusUpdate(TimeSpan? pollingInterval = null, CancellationToken cancellationToken = default)
{
pollingInterval ??= DefaultPollingInterval;

while (true)
{
cancellationToken.ThrowIfCancellationRequested();

if (!_paused)
{
ClientResult result = UpdateStatus();

if (_statusChanged)
{
(RunStatus Status, ThreadRun? Value) tuple = new(_lastSeenResult.Value.Status, _lastSeenResult.Value);
return FromValue(tuple, result.GetRawResponse());
}
}

cancellationToken.WaitHandle.WaitOne(pollingInterval.Value);
}
}

public override void Pause() => _paused = true;

public override void Resume() => _paused = false;
}
14 changes: 14 additions & 0 deletions .dotnet/src/Custom/Common/LroHelpers.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System;
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Threading;
using System.Threading.Tasks;

#nullable enable

namespace OpenAI;

internal class LroHelpers
{

}
2 changes: 1 addition & 1 deletion .dotnet/src/OpenAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
<SignAssembly>False</SignAssembly>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="System.ClientModel" Version="1.1.0-beta.4" />
<ProjectReference Include="..\..\..\azure-sdk-for-net\sdk\core\System.ClientModel\src\System.ClientModel.csproj" />
<PackageReference Include="System.Text.Json" Version="8.0.2" />
</ItemGroup>
</Project>
Loading