Skip to content

Commit

Permalink
Refactor ChatController for Asynchronous Plugin Registration (#864)
Browse files Browse the repository at this point in the history
This commit introduces asynchronous registration of plugins in the
ChatController. The changes improve the efficiency of the chat function
by allowing multiple plugins to be registered concurrently. This
enhancement is expected to improve response times and overall
performance of the chat functionality.

### Motivation and Context

1. This change is required to improve the efficiency and performance of
the chat functionality in the application.
2. The current synchronous registration of plugins in the ChatController
can lead to slower response times. This change solves this problem by
introducing asynchronous registration, allowing multiple plugins to be
registered concurrently.
3. This contributes to the scenario where the application needs to
handle multiple chat sessions simultaneously, each potentially requiring
different sets of plugins. With the asynchronous registration, the
application can now handle these scenarios more efficiently.
4. This change does not directly fix an open issue, but it is a
proactive measure to enhance the application's performance and
scalability.

### Description
This edit is just a simple separation of registration of plugins [ JIRA
- GRAPH - GITHUB - Custom plugins ] each one of them in a separate
**Asynchronous** function and registering the enabled plugins
simultaneously, as this process occurs each time a message is sent to
chatbot.

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [Contribution
Guidelines](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/chat-copilot/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
  • Loading branch information
Ahmed-Adel3 authored Mar 24, 2024
1 parent 216f4f2 commit 55b1c9b
Showing 1 changed file with 86 additions and 56 deletions.
142 changes: 86 additions & 56 deletions webapi/Controllers/ChatController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,87 +194,112 @@ private async Task RegisterFunctionsAsync(Kernel kernel, Dictionary<string, stri
{
// Register authenticated functions with the kernel only if the request includes an auth header for the plugin.

var tasks = new List<Task>();

// GitHub
if (authHeaders.TryGetValue("GITHUB", out string? GithubAuthHeader))
{
this._logger.LogInformation("Enabling GitHub plugin.");
BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GithubAuthHeader));
await kernel.ImportPluginFromOpenApiAsync(
pluginName: "GitHubPlugin",
filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", "OpenApi/GitHubPlugin/openapi.json"),
new OpenApiFunctionExecutionParameters
{
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
});
tasks.Add(this.RegisterGithubPlugin(kernel, GithubAuthHeader));
}

// Jira
if (authHeaders.TryGetValue("JIRA", out string? JiraAuthHeader))
{
this._logger.LogInformation("Registering Jira plugin");
var authenticationProvider = new BasicAuthenticationProvider(() => { return Task.FromResult(JiraAuthHeader); });
var hasServerUrlOverride = variables.TryGetValue("jira-server-url", out object? serverUrlOverride);

await kernel.ImportPluginFromOpenApiAsync(
pluginName: "JiraPlugin",
filePath: Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", "OpenApi/JiraPlugin/openapi.json"),
new OpenApiFunctionExecutionParameters
{
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
ServerUrlOverride = hasServerUrlOverride ? new Uri(serverUrlOverride!.ToString()!) : null,
});
tasks.Add(this.RegisterJiraPlugin(kernel, JiraAuthHeader, variables));
}

// Microsoft Graph
if (authHeaders.TryGetValue("GRAPH", out string? GraphAuthHeader))
{
this._logger.LogInformation("Enabling Microsoft Graph plugin(s).");
BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GraphAuthHeader));
GraphServiceClient graphServiceClient = this.CreateGraphServiceClient(authenticationProvider.GraphClientAuthenticateRequestAsync);

kernel.ImportPluginFromObject(new TaskListPlugin(new MicrosoftToDoConnector(graphServiceClient)), "todo");
kernel.ImportPluginFromObject(new CalendarPlugin(new OutlookCalendarConnector(graphServiceClient)), "calendar");
kernel.ImportPluginFromObject(new EmailPlugin(new OutlookMailConnector(graphServiceClient)), "email");
tasks.Add(this.RegisterMicrosoftGraphPlugins(kernel, GraphAuthHeader));
}

if (variables.TryGetValue("customPlugins", out object? customPluginsString))
{
CustomPlugin[]? customPlugins = JsonSerializer.Deserialize<CustomPlugin[]>(customPluginsString!.ToString()!);
tasks.AddRange(this.RegisterCustomPlugins(kernel, customPluginsString, authHeaders));
}

if (customPlugins != null)
await Task.WhenAll(tasks);
}

private async Task RegisterGithubPlugin(Kernel kernel, string GithubAuthHeader)
{
this._logger.LogInformation("Enabling GitHub plugin.");
BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GithubAuthHeader));
await kernel.ImportPluginFromOpenApiAsync(
pluginName: "GitHubPlugin",
filePath: GetPluginFullPath("GitHubPlugin/openapi.json"),
new OpenApiFunctionExecutionParameters
{
foreach (CustomPlugin plugin in customPlugins)
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
});
}

private async Task RegisterJiraPlugin(Kernel kernel, string JiraAuthHeader, KernelArguments variables)
{
this._logger.LogInformation("Registering Jira plugin");
var authenticationProvider = new BasicAuthenticationProvider(() => { return Task.FromResult(JiraAuthHeader); });
var hasServerUrlOverride = variables.TryGetValue("jira-server-url", out object? serverUrlOverride);

await kernel.ImportPluginFromOpenApiAsync(
pluginName: "JiraPlugin",
filePath: GetPluginFullPath("OpenApi/JiraPlugin/openapi.json"),
new OpenApiFunctionExecutionParameters
{
AuthCallback = authenticationProvider.AuthenticateRequestAsync,
ServerUrlOverride = hasServerUrlOverride ? new Uri(serverUrlOverride!.ToString()!) : null,
}); ; ;
}

private Task RegisterMicrosoftGraphPlugins(Kernel kernel, string GraphAuthHeader)
{
this._logger.LogInformation("Enabling Microsoft Graph plugin(s).");
BearerAuthenticationProvider authenticationProvider = new(() => Task.FromResult(GraphAuthHeader));
GraphServiceClient graphServiceClient = this.CreateGraphServiceClient(authenticationProvider.GraphClientAuthenticateRequestAsync);

kernel.ImportPluginFromObject(new TaskListPlugin(new MicrosoftToDoConnector(graphServiceClient)), "todo");
kernel.ImportPluginFromObject(new CalendarPlugin(new OutlookCalendarConnector(graphServiceClient)), "calendar");
kernel.ImportPluginFromObject(new EmailPlugin(new OutlookMailConnector(graphServiceClient)), "email");
return Task.CompletedTask;
}

private IEnumerable<Task> RegisterCustomPlugins(Kernel kernel, object? customPluginsString, Dictionary<string, string> authHeaders)
{
CustomPlugin[]? customPlugins = JsonSerializer.Deserialize<CustomPlugin[]>(customPluginsString!.ToString()!);

if (customPlugins != null)
{
foreach (CustomPlugin plugin in customPlugins)
{
if (authHeaders.TryGetValue(plugin.AuthHeaderTag.ToUpperInvariant(), out string? PluginAuthValue))
{
if (authHeaders.TryGetValue(plugin.AuthHeaderTag.ToUpperInvariant(), out string? PluginAuthValue))
// Register the ChatGPT plugin with the kernel.
this._logger.LogInformation("Enabling {0} plugin.", plugin.NameForHuman);

// TODO: [Issue #44] Support other forms of auth. Currently, we only support user PAT or no auth.
var requiresAuth = !plugin.AuthType.Equals("none", StringComparison.OrdinalIgnoreCase);
Task authCallback(HttpRequestMessage request, string _, OpenAIAuthenticationConfig __, CancellationToken ___ = default)
{
// Register the ChatGPT plugin with the kernel.
this._logger.LogInformation("Enabling {0} plugin.", plugin.NameForHuman);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", PluginAuthValue);

// TODO: [Issue #44] Support other forms of auth. Currently, we only support user PAT or no auth.
var requiresAuth = !plugin.AuthType.Equals("none", StringComparison.OrdinalIgnoreCase);
Task authCallback(HttpRequestMessage request, string _, OpenAIAuthenticationConfig __, CancellationToken ___ = default)
{
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", PluginAuthValue);

return Task.CompletedTask;
}

await kernel.ImportPluginFromOpenAIAsync(
$"{plugin.NameForModel}Plugin",
PluginUtils.GetPluginManifestUri(plugin.ManifestDomain),
new OpenAIFunctionExecutionParameters
{
HttpClient = this._httpClientFactory.CreateClient(),
IgnoreNonCompliantErrors = true,
AuthCallback = requiresAuth ? authCallback : null
});
return Task.CompletedTask;
}

yield return kernel.ImportPluginFromOpenAIAsync(
$"{plugin.NameForModel}Plugin",
PluginUtils.GetPluginManifestUri(plugin.ManifestDomain),
new OpenAIFunctionExecutionParameters
{
HttpClient = this._httpClientFactory.CreateClient(),
IgnoreNonCompliantErrors = true,
AuthCallback = requiresAuth ? authCallback : null
});
}
}
else
{
this._logger.LogDebug("Failed to deserialize custom plugin details: {0}", customPluginsString);
}
}
else
{
this._logger.LogDebug("Failed to deserialize custom plugin details: {0}", customPluginsString);
}
}

Expand Down Expand Up @@ -354,6 +379,11 @@ private static KernelArguments GetContextVariables(Ask ask, IAuthInfo authInfo,
return contextVariables;
}

private static string GetPluginFullPath(string pluginPath)
{
return Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location)!, "Plugins", pluginPath);
}

/// <summary>
/// Dispose of the object.
/// </summary>
Expand Down

0 comments on commit 55b1c9b

Please sign in to comment.