Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ChatController for Asynchronous Plugin Registration #864

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 86 additions & 56 deletions webapi/Controllers/ChatController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,87 +194,112 @@ private async Task RegisterFunctionsAsync(Kernel kernel, Dictionary<string, stri
{
// Register authenticated functions with the kernel only if the request includes an auth header for the plugin.

var tasks = new List<Task>();

// GitHub
if (authHeaders.TryGetValue("GITHUB", out string? GithubAuthHeader))
{
this._logger.LogInformation("Enabling GitHub plugin.");
BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GithubAuthHeader));
await kernel.ImportPluginFromOpenApiAsync(
pluginName: "GitHubPlugin",
filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", "OpenApi/GitHubPlugin/openapi.json"),
new OpenApiFunctionExecutionParameters
{
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
});
tasks.Add(this.RegisterGithubPlugin(kernel, GithubAuthHeader));
}

// Jira
if (authHeaders.TryGetValue("JIRA", out string? JiraAuthHeader))
{
this._logger.LogInformation("Registering Jira plugin");
var authenticationProvider = new BasicAuthenticationProvider(() => { return Task.FromResult(JiraAuthHeader); });
var hasServerUrlOverride = variables.TryGetValue("jira-server-url", out object? serverUrlOverride);

await kernel.ImportPluginFromOpenApiAsync(
pluginName: "JiraPlugin",
filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", "OpenApi/JiraPlugin/openapi.json"),
new OpenApiFunctionExecutionParameters
{
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
ServerUrlOverride = hasServerUrlOverride ? new Uri(serverUrlOverride!.ToString()!) : null,
});
tasks.Add(this.RegisterJiraPlugin(kernel, JiraAuthHeader, variables));
}

// Microsoft Graph
if (authHeaders.TryGetValue("GRAPH", out string? GraphAuthHeader))
{
this._logger.LogInformation("Enabling Microsoft Graph plugin(s).");
BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GraphAuthHeader));
GraphServiceClient graphServiceClient = this.CreateGraphServiceClient(authenticationProvider.GraphClientAuthenticateRequestAsync);

kernel.ImportPluginFromObject(new TaskListPlugin(new MicrosoftToDoConnector(graphServiceClient)), "todo");
kernel.ImportPluginFromObject(new CalendarPlugin(new OutlookCalendarConnector(graphServiceClient)), "calendar");
kernel.ImportPluginFromObject(new EmailPlugin(new OutlookMailConnector(graphServiceClient)), "email");
tasks.Add(this.RegisterMicrosoftGraphPlugins(kernel, GraphAuthHeader));
}

if (variables.TryGetValue("customPlugins", out object? customPluginsString))
{
CustomPlugin[]? customPlugins = JsonSerializer.Deserialize<CustomPlugin[]>(customPluginsString!.ToString()!);
tasks.AddRange(this.RegisterCustomPlugins(kernel, customPluginsString, authHeaders));
}

if (customPlugins != null)
await Task.WhenAll(tasks);
}

private async Task RegisterGithubPlugin(Kernel kernel, string GithubAuthHeader)
{
this._logger.LogInformation("Enabling GitHub plugin.");
BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GithubAuthHeader));
await kernel.ImportPluginFromOpenApiAsync(
pluginName: "GitHubPlugin",
filePath: GetPluginFullPath("GitHubPlugin/openapi.json"),
new OpenApiFunctionExecutionParameters
{
foreach (CustomPlugin plugin in customPlugins)
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
});
}

private async Task RegisterJiraPlugin(Kernel kernel, string JiraAuthHeader, KernelArguments variables)
{
this._logger.LogInformation("Registering Jira plugin");
var authenticationProvider = new BasicAuthenticationProvider(() => { return Task.FromResult(JiraAuthHeader); });
var hasServerUrlOverride = variables.TryGetValue("jira-server-url", out object? serverUrlOverride);

await kernel.ImportPluginFromOpenApiAsync(
pluginName: "JiraPlugin",
filePath: GetPluginFullPath("OpenApi/JiraPlugin/openapi.json"),
new OpenApiFunctionExecutionParameters
{
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
ServerUrlOverride = hasServerUrlOverride ? new Uri(serverUrlOverride!.ToString()!) : null,
}); ; ;
}

private Task RegisterMicrosoftGraphPlugins(Kernel kernel, string GraphAuthHeader)
{
this._logger.LogInformation("Enabling Microsoft Graph plugin(s).");
BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GraphAuthHeader));
GraphServiceClient graphServiceClient = this.CreateGraphServiceClient(authenticationProvider.GraphClientAuthenticateRequestAsync);

kernel.ImportPluginFromObject(new TaskListPlugin(new MicrosoftToDoConnector(graphServiceClient)), "todo");
kernel.ImportPluginFromObject(new CalendarPlugin(new OutlookCalendarConnector(graphServiceClient)), "calendar");
kernel.ImportPluginFromObject(new EmailPlugin(new OutlookMailConnector(graphServiceClient)), "email");
return Task.CompletedTask;
}

private IEnumerable<Task> RegisterCustomPlugins(Kernel kernel, object? customPluginsString, Dictionary<string, string> authHeaders)
{
CustomPlugin[]? customPlugins = JsonSerializer.Deserialize<CustomPlugin[]>(customPluginsString!.ToString()!);

if (customPlugins != null)
{
foreach (CustomPlugin plugin in customPlugins)
{
if (authHeaders.TryGetValue(plugin.AuthHeaderTag.ToUpperInvariant(), out string? PluginAuthValue))
{
if (authHeaders.TryGetValue(plugin.AuthHeaderTag.ToUpperInvariant(), out string? PluginAuthValue))
// Register the ChatGPT plugin with the kernel.
this._logger.LogInformation("Enabling {0} plugin.", plugin.NameForHuman);

// TODO: [Issue #44] Support other forms of auth. Currently, we only support user PAT or no auth.
var requiresAuth = !plugin.AuthType.Equals("none", StringComparison.OrdinalIgnoreCase);
Task authCallback(HttpRequestMessage request, string _, OpenAIAuthenticationConfig __, CancellationToken ___ = default)
{
// Register the ChatGPT plugin with the kernel.
this._logger.LogInformation("Enabling {0} plugin.", plugin.NameForHuman);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", PluginAuthValue);

// TODO: [Issue #44] Support other forms of auth. Currently, we only support user PAT or no auth.
var requiresAuth = !plugin.AuthType.Equals("none", StringComparison.OrdinalIgnoreCase);
Task authCallback(HttpRequestMessage request, string _, OpenAIAuthenticationConfig __, CancellationToken ___ = default)
{
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", PluginAuthValue);

return Task.CompletedTask;
}

await kernel.ImportPluginFromOpenAIAsync(
$"{plugin.NameForModel}Plugin",
PluginUtils.GetPluginManifestUri(plugin.ManifestDomain),
new OpenAIFunctionExecutionParameters
{
HttpClient = this._httpClientFactory.CreateClient(),
IgnoreNonCompliantErrors = true,
AuthCallback = requiresAuth ? authCallback : null
});
return Task.CompletedTask;
}

yield return kernel.ImportPluginFromOpenAIAsync(
$"{plugin.NameForModel}Plugin",
PluginUtils.GetPluginManifestUri(plugin.ManifestDomain),
new OpenAIFunctionExecutionParameters
{
HttpClient = this._httpClientFactory.CreateClient(),
IgnoreNonCompliantErrors = true,
AuthCallback = requiresAuth ? authCallback : null
});
}
}
else
{
this._logger.LogDebug("Failed to deserialize custom plugin details: {0}", customPluginsString);
}
}
else
{
this._logger.LogDebug("Failed to deserialize custom plugin details: {0}", customPluginsString);
}
}

Expand Down Expand Up @@ -354,6 +379,11 @@ private static KernelArguments GetContextVariables(Ask ask, IAuthInfo authInfo,
return contextVariables;
}

private static string GetPluginFullPath(string pluginPath)
{
return Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", pluginPath);
}

/// <summary>
/// Dispose of the object.
/// </summary>
Expand Down
Loading