Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,30 @@
using Azure.Provisioning.Authorization;
using Azure.Provisioning.CognitiveServices;
using OpenAI.Chat;
using OpenAI.Embeddings;

namespace Azure.Provisioning.CloudMachine.OpenAI;

public class OpenAIFeature : CloudMachineFeature
{
public string Model { get; }
public string ModelVersion { get; }
private OpenAiModelDeployment _chatDeployment;
private OpenAiModelDeployment? _embeddingsDeployment;

public OpenAIFeature(string model, string modelVersion) { Model = model; ModelVersion = modelVersion; }
public OpenAIFeature(OpenAiModelDeployment chatDeployment, OpenAiModelDeployment? embeddingsDeployment = default)
{
if (chatDeployment == null)
{
throw new ArgumentNullException(nameof(chatDeployment));
}
_chatDeployment = chatDeployment;
_embeddingsDeployment = embeddingsDeployment;
}

public override void AddTo(CloudMachineInfrastructure cloudMachine)
{
CognitiveServicesAccount cognitiveServices = new("openai")
{
Name = cloudMachine.Id,
Name = $"{cloudMachine.Id}/{cloudMachine.Id}",
Kind = "OpenAI",
Sku = new CognitiveServicesSku { Name = "S0" },
Properties = new CognitiveServicesAccountProperties()
Expand All @@ -31,31 +40,52 @@ public override void AddTo(CloudMachineInfrastructure cloudMachine)
CustomSubDomainName = cloudMachine.Id
},
};
cloudMachine.AddResource(cognitiveServices);

cloudMachine.AddResource(cognitiveServices.CreateRoleAssignment(
CognitiveServicesBuiltInRole.CognitiveServicesOpenAIContributor,
RoleManagementPrincipalType.User,
cloudMachine.PrincipalIdParameter)
);

// TODO: if we every support more than one deployment, they need to be chained using DependsOn.
// The reason is that deployments need to be deployed/created serially.
CognitiveServicesAccountDeployment deployment = new("openai_deployment", "2023-05-01")
CognitiveServicesAccountDeployment chat = new("openai_deployment", "2023-05-01")
{
Parent = cognitiveServices,
Name = cloudMachine.Id,
Properties = new CognitiveServicesAccountDeploymentProperties()
{
Model = new CognitiveServicesAccountDeploymentModel() {
Name = this.Model,
Model = new CognitiveServicesAccountDeploymentModel()
{
Name = _chatDeployment.Model,
Format = "OpenAI",
Version = this.ModelVersion,
Version = _chatDeployment.ModelVersion
}
},
};
cloudMachine.AddResource(chat);

cloudMachine.AddResource(cognitiveServices);
cloudMachine.AddResource(deployment);
if (_embeddingsDeployment != null)
{
CognitiveServicesAccountDeployment embeddings = new("openai_deployment", "2023-05-01")
{
Parent = cognitiveServices,
Name = $"{cloudMachine.Id}/{cloudMachine.Id}-embedding",
Properties = new CognitiveServicesAccountDeploymentProperties()
{
Model = new CognitiveServicesAccountDeploymentModel()
{
Name = _embeddingsDeployment.Model,
Format = "OpenAI",
Version = _embeddingsDeployment.ModelVersion
}
},
};

// Ensure that additional deployments, are chained using DependsOn.
// The reason is that deployments need to be deployed/created serially.
embeddings.DependsOn.Add(chat);
cloudMachine.AddResource(embeddings);
}
}
}

Expand All @@ -72,6 +102,17 @@ public static ChatClient GetOpenAIChatClient(this ClientWorkspace workspace)
return chatClient;
}

public static EmbeddingClient GetOpenAIEmbeddingsClient(this ClientWorkspace workspace)
{
EmbeddingClient embeddingsClient = workspace.Subclients.Get(() =>
{
AzureOpenAIClient aoiaClient = workspace.Subclients.Get(() => CreateAzureOpenAIClient(workspace));
return workspace.CreateEmbeddingsClient(aoiaClient);
});

return embeddingsClient;
}

private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace workspace)
{
ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(AzureOpenAIClient));
Expand All @@ -81,7 +122,7 @@ private static AzureOpenAIClient CreateAzureOpenAIClient(this ClientWorkspace wo
}
else
{
return new(connection.Endpoint, new ApiKeyCredential(connection.ApiKeyCredential!));
return new(connection.Endpoint, new ApiKeyCredential(connection.ApiKeyCredential!));
}
}

Expand All @@ -91,4 +132,11 @@ private static ChatClient CreateChatClient(this ClientWorkspace workspace, Azure
ChatClient chat = client.GetChatClient(connection.Id);
return chat;
}

private static EmbeddingClient CreateEmbeddingsClient(this ClientWorkspace workspace, AzureOpenAIClient client)
{
ClientConnectionOptions connection = workspace.GetConnectionOptions(typeof(EmbeddingClient));
EmbeddingClient embeddings = client.GetEmbeddingClient(connection.Id);
return embeddings;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

public class OpenAiModelDeployment
{
public OpenAiModelDeployment(string model, string modelVersion) { Model = model; ModelVersion = modelVersion; }
public string Model { get; }
public string ModelVersion { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public override ClientConnectionOptions GetConnectionOptions(Type clientType, st
return new ClientConnectionOptions(new($"https://{this.Id}.openai.azure.com"), Credential);
case "OpenAI.Chat.ChatClient":
return new ClientConnectionOptions(Id);
case "OpenAI.Embeddings.EmbeddingClient":
return new ClientConnectionOptions($"{Id}-embedding");
default:
throw new Exception($"unknown client {clientId}");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,50 @@ private BlobContainerClient GetContainer(string containerName)
return container;
}

public string UploadJson(object json, string? name = default)
public string UploadJson(object json, string? name = default, bool overwrite = false)
{
bool doOverwrite = false;
BlobContainerClient container = GetDefaultContainer();

if (name == default)
name = $"b{Guid.NewGuid()}";

container.UploadBlob(name, BinaryData.FromObjectAsJson(json));
try
{
container.UploadBlob(name, BinaryData.FromObjectAsJson(json));
}
catch (RequestFailedException e) when (overwrite && e.Status == 409 && e.ErrorCode == BlobErrorCode.BlobAlreadyExists)
{
doOverwrite = true;
}
if (doOverwrite)
{
container.GetBlobClient(name).Upload(BinaryData.FromObjectAsJson(json), overwrite: true);
}

return name;
}

public string UploadBlob(Stream fileStream, string? name = default, bool overwrite = false)
{
bool doOverwrite = false;
BlobContainerClient container = GetDefaultContainer();

if (name == default)
name = $"b{Guid.NewGuid()}";
try
{
container.UploadBlob(name, BinaryData.FromStream(fileStream));
}
catch (RequestFailedException e) when (overwrite && e.Status == 409 && e.ErrorCode == BlobErrorCode.BlobAlreadyExists)
{
doOverwrite = true;
}
if (doOverwrite)
{
fileStream.Position = 0;
container.GetBlobClient(name).Upload(fileStream, overwrite: true);
}

return name;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ public void Provisioning(string[] args)
if (CloudMachineInfrastructure.Configure(args, (cm) =>
{
cm.AddFeature(new KeyVaultFeature());
cm.AddFeature(new OpenAIFeature("gpt-35-turbo", "0125"));
cm.AddFeature(new OpenAIFeature(new OpenAiModelDeployment("gpt-35-turbo", "0125"), new OpenAiModelDeployment("text-embedding-ada-002", "2")));
}))
return;

CloudMachineWorkspace cm = new();
Console.WriteLine(cm.Id);
var embeddings = cm.GetOpenAIEmbeddingsClient();
}

[Ignore("no recordings yet")]
Expand Down Expand Up @@ -71,7 +72,7 @@ public void OpenAI(string[] args)
{
if (CloudMachineInfrastructure.Configure(args, (cm) =>
{
cm.AddFeature(new OpenAIFeature("gpt-35-turbo", "0125"));
cm.AddFeature(new OpenAIFeature(new OpenAiModelDeployment("gpt-35-turbo", "0125")));
}))
return;

Expand Down
Loading