diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 79bcefd..27df08f 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -28,16 +28,16 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v4 - name: Initialize CodeQL - uses: github/codeql-action/init@4bdb89f48054571735e3792627da6195c57459e2 # v3 + uses: github/codeql-action/init@b20883b0cd1f46c72ae0ba6d1090936928f9fa30 # v3 with: languages: ${{ matrix.language }} - name: Setup .NET if: matrix.language == 'csharp' - uses: actions/setup-dotnet@67a3573c9a986a3f9c594539f4ab511d57bb3ce9 # v4 + uses: actions/setup-dotnet@baa11fbfe1d6520db94683bd5c7a3818018e4309 # v4 with: dotnet-version: "10.0.x" @@ -49,9 +49,9 @@ jobs: - name: Autobuild if: matrix.language != 'csharp' - uses: github/codeql-action/autobuild@4bdb89f48054571735e3792627da6195c57459e2 # v3 + uses: github/codeql-action/autobuild@b20883b0cd1f46c72ae0ba6d1090936928f9fa30 # v3 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@4bdb89f48054571735e3792627da6195c57459e2 # v3 + uses: github/codeql-action/analyze@b20883b0cd1f46c72ae0ba6d1090936928f9fa30 # v3 with: category: "/language:${{ matrix.language }}" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 6201d4b..7c35688 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -30,7 +30,7 @@ jobs: packages: write steps: - name: Checkout repository - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Set up QEMU uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3.7.0 @@ -40,7 +40,7 @@ jobs: - name: Log in to the Container registry if: github.event_name != 'pull_request' - uses: 
docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 + uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3.7.0 with: registry: ${{ env.REGISTRY }} username: ${{ github.actor }} diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index d056523..d93b51e 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -34,7 +34,7 @@ jobs: steps: - name: "Checkout code" - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false @@ -73,6 +73,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@cdefb33c0f6224e58673d9004f47f7cb3e328b89 # v4 + uses: github/codeql-action/upload-sarif@b20883b0cd1f46c72ae0ba6d1090936928f9fa30 # v4 with: sarif_file: results.sarif diff --git a/catalog.Dockerfile b/catalog.Dockerfile index c8be7d9..bd4f035 100644 --- a/catalog.Dockerfile +++ b/catalog.Dockerfile @@ -1,12 +1,12 @@ # Build UI first in a Node.js container -FROM --platform=$BUILDPLATFORM node:25-bookworm@sha256:839caad0185604c2e602024686408cdbcc37f1d2825e54ea3900f4dad3310a07 AS ui-build +FROM --platform=$BUILDPLATFORM node:25-bookworm@sha256:e6b32434aba48dcb8730d56de2df3d137de213f1f527a922a6bf7b2853a24e86 AS ui-build WORKDIR /ui COPY ui . 
RUN npm install RUN npm run build # create the build container -FROM --platform=$BUILDPLATFORM mcr.microsoft.com/dotnet/sdk:10.0@sha256:90f913c96383b4146ce45985fd97e723fa1b1b6359441c4b683240236052eb59 AS build +FROM --platform=$BUILDPLATFORM mcr.microsoft.com/dotnet/sdk:10.0@sha256:25d14b400b75fa4e89d5bd4487a92a604a4e409ab65becb91821e7dc4ac7f81f AS build ARG TARGETARCH LABEL stage=build WORKDIR /api @@ -16,7 +16,7 @@ COPY --from=ui-build /ui/dist/ ./wwwroot/ RUN dotnet publish -c Release -o out -a $TARGETARCH # create the runtime container -FROM mcr.microsoft.com/dotnet/aspnet:10.0@sha256:cc9c8da871c6e367a63122b858b10cfc464f5687bcfcf9d3761bcff1188cf257 +FROM mcr.microsoft.com/dotnet/aspnet:10.0@sha256:1aacc8154bc3071349907dae26849df301188be1a2e1f4560b903fb6275e481a ARG INSTALL_AZURE_CLI=false WORKDIR /app COPY --from=build /api/out . diff --git a/catalog/.github/agents/ask-catalog.agent.md b/catalog/.github/agents/ask-catalog.agent.md new file mode 100644 index 0000000..bd03159 --- /dev/null +++ b/catalog/.github/agents/ask-catalog.agent.md @@ -0,0 +1,14 @@ +--- +description: "Ask questions about experiments using the catalog MCP tools." +tools: ["read", "experiment-catalog/*"] +--- + +This agent uses the experiment catalog MCP server to analyze experiments. + +ALWAYS use this skill: [experiment-catalog](../skills/experiment-catalog/SKILL.md). + +## Tool Selection + +- When comparing a permutation (set) to the baseline, use `CompareExperiment` directly. Do not call `ListSetsForExperiment` first to validate the set name. +- Use `CompareByRef` only when the user asks about individual ground truth (ref) performance, such as which refs improved or regressed. +- Call each tool only when its output is needed. Avoid discovery or pre-check calls before comparison tools. 
diff --git a/catalog/.github/skills/experiment-catalog/SKILL.md b/catalog/.github/skills/experiment-catalog/SKILL.md new file mode 100644 index 0000000..4d2228b --- /dev/null +++ b/catalog/.github/skills/experiment-catalog/SKILL.md @@ -0,0 +1,85 @@ +# Experiment Catalog + +A comprehensive tool for cataloging, comparing, and analyzing experiment results. The Experiment Catalog enables teams to track evaluation runs across projects, compare metrics against baselines, and identify performance regressions or improvements in AI/ML experimentation workflows. + +## Overview + +The experiment catalog organizes experimental data in a hierarchical structure: + +| Level | Also Known As | Description | +| ---------- | ----------------- | ------------------------------------------------------------------------------ | +| Project | Sprint, Milestone | Fixed evaluation environment (baseline, ground truth, metrics) for experiments | +| Experiment | - | A hypothesis-driven test varying inference within a project | +| Set | Permutation | A configuration variant within an experiment | +| Result | - | All metric values for a single ground truth iteration | +| Ref | Ground Truth | Reference to the entity being evaluated, used for aggregation and comparison | + +## Key Concepts + +### Projects + +A project represents a fixed evaluation environment in which experiments are conducted. The project establishes: + +- Baseline measurements for comparison +- Ground truth data (often split into validation and test sets) +- Metric definitions and evaluation scripts +- Stable infrastructure configuration + +Projects align with milestones or sprints. During a project, the evaluation tooling and data remain constant while developers vary inference approaches through experiments. Each project iteration produces a new version of the solution that can be measured against the previous version. 
+ +### Experiments + +An experiment tests a specific hypothesis by varying inference parameters, code, or configuration. Experiments contain multiple evaluation runs (sets) to compare different approaches. The goal is to prove or disprove the hypothesis by comparing results against baselines. + +### Baselines + +Baselines provide measurement points for comparison: + +| Baseline Type | Purpose | +| ------------------------- | ----------------------------------------------------------- | +| Project Baseline | Initial measurement before experimentation begins | +| Experiment Baseline | First run of an experiment before making changes | +| Final Experiment Baseline | Best configuration run on both validation and test sets | +| Final Project Baseline | End-of-project measurement to compare against project start | + +When working with non-deterministic inference or evaluation systems, run baselines multiple times (commonly 5 iterations) to establish reliable averages. + +### Sets and Refs + +- **Set**: A collection of results from a single evaluation run. Running 5 iterations of 12 ground truths constitutes one set. Additional iterations can be added to an existing set. +- **Ref**: The catalog term for a ground truth. Every ground truth is stored and queried as a "ref" throughout the catalog API, MCP tools, and data model. When a user asks about ground truth performance, improvements, or regressions, translate "ground truth" to "ref" in all catalog operations. Refs enable aggregation across iterations and comparison of individual ground truth performance across evaluation runs. + +### Iterations + +An iteration is a single execution of inference and evaluation for a ground truth. 
Because AI agents and LLM-based systems are non-deterministic, running multiple iterations is essential: + +- **Minimum recommendation**: At least 5 iterations per ground truth +- **Averaging**: Multiple iterations allow averaging results to account for variance in non-deterministic systems +- **Statistical analysis**: P-values and confidence intervals are calculated per ground truth, requiring multiple iterations to determine a reasonable range versus baseline + +A result captures all metric values for one ground truth iteration. When a set contains 5 iterations of 12 ground truths, it stores 60 individual results (5 × 12). + +## Experimentation Workflow + +The recommended workflow follows these phases: + +1. **Create a Project**: Establish the fixed evaluation environment +2. **Run a Project Baseline**: Measure initial state before experimentation +3. **Run Experiments**: + - Create an experiment with a hypothesis + - Run an experiment baseline (or accept the project baseline) + - Run permutations varying inference parameters + - Determine the best permutation + - Run a final experiment baseline on validation and test sets + - Write a summary documenting the experiment + - Review with your team + - Approve (merge) or reject +4. **Run a Final Project Baseline**: Measure end state after all experiments + +## Determining Best Permutation + +With many ground truths, differences between permutations are often minimal. 
Techniques for identifying the best approach: + +- **Look at Subsets**: Subsets like "multi-turn" examples may show 20-30% differences where overall metrics show only 1% variance +- **Prioritize Metrics**: Rank metrics by importance and evaluate based on highest-priority metrics first +- **Statistical Significance**: Use p-value calculations to determine when metric changes are meaningful diff --git a/catalog/.gitignore b/catalog/.gitignore index 5992428..2034b94 100644 --- a/catalog/.gitignore +++ b/catalog/.gitignore @@ -1,6 +1,7 @@ wwwroot/ cache/ *.env +.copilot-tracking/ .vscode/* !.vscode/settings.json diff --git a/catalog/.vscode/settings.json b/catalog/.vscode/settings.json new file mode 100644 index 0000000..caa768f --- /dev/null +++ b/catalog/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "microsoft-authentication.implementation": "msal-no-broker" +} \ No newline at end of file diff --git a/catalog/Program.cs b/catalog/Program.cs index d7ffda7..c8e07c7 100644 --- a/catalog/Program.cs +++ b/catalog/Program.cs @@ -13,6 +13,10 @@ using Microsoft.Extensions.Options; using Microsoft.IdentityModel.Tokens; using Microsoft.OpenApi; +using ModelContextProtocol; +using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.Protocol; using NetBricks; // load environment variables from .env file @@ -50,10 +54,19 @@ builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddHostedService(); builder.Services.AddHostedService(sp => sp.GetRequiredService()); +// add MCP server with analysis tools +builder.Services + .AddMcpServer() + .WithHttpTransport() + .WithToolsFromAssembly() + .AddCallToolFilter(McpToolExceptionFilter.Create()); + // add controllers with swagger builder.Services.AddControllers().AddNewtonsoftJson(); 
builder.Services.AddEndpointsApiExplorer(); @@ -79,8 +92,14 @@ // add authentication with deferred configuration builder.Services.AddSingleton, JwtBearerConfigurator>(); -builder.Services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme) - .AddJwtBearer(); +builder.Services.AddSingleton, McpAuthenticationConfigurator>(); +builder.Services.AddAuthentication(options => + { + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + }) + .AddJwtBearer() + .AddMcp(); builder.Services.AddAuthorization(); builder.Services.AddSingleton, AuthorizationConfigurator>(); @@ -90,7 +109,10 @@ options.AddPolicy("default-policy", corsBuilder => { - corsBuilder.WithOrigins("http://localhost:6020") + corsBuilder.WithOrigins( + "http://localhost:6020", + "http://localhost:6274" // MCP Inspector + ) .AllowAnyHeader() .AllowAnyMethod() .AllowCredentials(); @@ -115,8 +137,9 @@ app.UseAuthentication(); app.UseAuthorization(); -// map controllers +// map controllers and MCP app.MapControllers(); +app.MapMcp("/mcp"); // run app.Run(); \ No newline at end of file diff --git a/catalog/config/McpAuthenticationConfigurator.cs b/catalog/config/McpAuthenticationConfigurator.cs new file mode 100644 index 0000000..4856c65 --- /dev/null +++ b/catalog/config/McpAuthenticationConfigurator.cs @@ -0,0 +1,52 @@ +using System; +using System.Threading; +using Microsoft.Extensions.Options; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.Authentication; +using NetBricks; + +namespace Catalog; + +/// +/// Configures MCP authentication options using the application's OIDC settings. +/// +/// +/// When authentication is enabled, this sets up the +/// so that MCP clients can discover +/// the OAuth authorization server and complete the OAuth 2.0 flow. 
+/// +public class McpAuthenticationConfigurator(IConfigFactory configFactory) + : IConfigureNamedOptions +{ + /// + public void Configure(string? name, McpAuthenticationOptions options) + { + if (name != McpAuthenticationDefaults.AuthenticationScheme) + { + return; + } + Configure(options); + } + + /// + public void Configure(McpAuthenticationOptions options) + { + var config = configFactory.GetAsync(CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult(); + if (!config.IsAuthenticationEnabled) + { + return; + } + + // the API scope ensures Azure AD issues a JWT access token whose audience + // matches this API rather than an opaque Microsoft Graph token. + // OIDC_CLIENT_ID is required for MCP authentication to work with Azure AD. + options.ResourceMetadata = new ProtectedResourceMetadata + { + AuthorizationServers = { new Uri(config.OIDC_AUTHORITY!) } + }; + if (!string.IsNullOrEmpty(config.OIDC_CLIENT_ID)) + { + options.ResourceMetadata.ScopesSupported = [$"api://{config.OIDC_CLIENT_ID}/.default"]; + } + } +} diff --git a/catalog/controllers/AnalysisController.cs b/catalog/controllers/AnalysisController.cs index 79c853c..a188ed4 100644 --- a/catalog/controllers/AnalysisController.cs +++ b/catalog/controllers/AnalysisController.cs @@ -1,14 +1,6 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Linq.Expressions; -using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using Azure; using Microsoft.AspNetCore.Mvc; -using NetBricks; namespace Catalog; @@ -27,64 +19,11 @@ public IActionResult CalculateStatistics( [HttpPost("meaningful-tags")] public async Task MeaningfulTags( - [FromServices] IStorageService storageService, + [FromServices] AnalysisService analysisService, [FromBody] MeaningfulTagsRequest request, CancellationToken cancellationToken) { - var diffs = new List(); - - var experiment = await storageService.GetExperimentAsync(request.Project, request.Experiment, 
cancellationToken: cancellationToken); - - var baseline = request.CompareTo == MeaningfulTagsComparisonMode.Baseline - ? await storageService.GetProjectBaselineAsync(request.Project, cancellationToken) - : null; - - var listOfTags = await storageService.ListTagsAsync(request.Project, cancellationToken); - var includeTags = await storageService.GetTagsAsync(request.Project, listOfTags, cancellationToken); - var excludeTags = request.ExcludeTags is not null - ? await storageService.GetTagsAsync(request.Project, request.ExcludeTags, cancellationToken) - : null; - - var compareToDefault = 0.0M; - if (request.CompareTo == MeaningfulTagsComparisonMode.Average) - { - var results = experiment.Filter(null, excludeTags); - var experimentResult = experiment.AggregateSet(request.Set, results); - Metric? experimentMetric = null; - experimentResult?.Metrics?.TryGetValue(request.Metric, out experimentMetric); - compareToDefault = experimentMetric?.Value ?? 0.0M; - } - - foreach (var tag in includeTags) - { - var experimentResults = experiment.Filter([tag], excludeTags); - var experimentResult = experiment.AggregateSet(request.Set, experimentResults); - Metric? experimentTagMetric = null; - experimentResult?.Metrics?.TryGetValue(request.Metric, out experimentTagMetric); - - decimal? compareTo = compareToDefault; - if (baseline is not null) - { - var baselineResults = baseline.Filter([tag], excludeTags); - var baselineResult = baseline.AggregateSet(baseline.BaselineSet ?? baseline.LastSet, baselineResults); - Metric? baselineTagMetric = null; - baselineResult?.Metrics?.TryGetValue(request.Metric, out baselineTagMetric); - compareTo = baselineTagMetric?.Value; - } - - if (experimentTagMetric?.Value is not null && compareTo is not null) - { - var diff = (decimal)(experimentTagMetric.Value - compareTo); - diffs.Add(new TagDiff - { - Tag = tag.Name, - Diff = diff, - Impact = diff * (experimentTagMetric.Count ?? 
0), - Count = experimentTagMetric.Count, - }); - } - } - - return Ok(new MeaningfulTagsResponse { Tags = diffs.OrderBy(x => x.Impact) }); + var response = await analysisService.GetMeaningfulTagsAsync(request, cancellationToken); + return Ok(response); } } diff --git a/catalog/controllers/ExperimentsController.cs b/catalog/controllers/ExperimentsController.cs index 5c40b10..8ee5daa 100644 --- a/catalog/controllers/ExperimentsController.cs +++ b/catalog/controllers/ExperimentsController.cs @@ -1,24 +1,15 @@ -using System; using System.Collections.Generic; using System.ComponentModel.DataAnnotations; -using System.Diagnostics; -using System.Linq; -using System.Net.NetworkInformation; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Mvc; -using Microsoft.Extensions.Logging; -using NetBricks; namespace Catalog; [ApiController] [Route("api/projects/{projectName}/experiments")] -public class ExperimentsController(ILogger logger) : ControllerBase +public class ExperimentsController : ControllerBase { - private readonly ILogger logger = logger; - [HttpGet] public async Task>> List( [FromServices] IStorageService storageService, @@ -50,6 +41,22 @@ public async Task> Get( return Ok(experiment); } + [HttpGet("{experimentName}/sets")] + public async Task>> ListSetsForExperiment( + [FromServices] ExperimentService experimentService, + [FromRoute, Required, ValidName, ValidProjectName] string projectName, + [FromRoute, Required, ValidName, ValidExperimentName] string experimentName, + CancellationToken cancellationToken) + { + if (string.IsNullOrEmpty(projectName) || string.IsNullOrEmpty(experimentName)) + { + return BadRequest("a project name and experiment name are required."); + } + + var sets = await experimentService.ListSetsForExperimentAsync(projectName, experimentName, cancellationToken); + return Ok(sets); + } + [HttpPost] public async Task Add( [FromServices] IStorageService storageService, @@ 
-104,22 +111,13 @@ public async Task SetBaselineForExperiment( return Ok(); } - private static async Task<(IList includeTags, IList excludeTags)> LoadTags( - IStorageService storageService, - string projectName, - string includeTagsStr, - string excludeTagsStr, - CancellationToken cancellationToken) - { - var includeTags = await storageService.GetTagsAsync(projectName, includeTagsStr.AsArray(() => [])!, cancellationToken); - var excludeTags = await storageService.GetTagsAsync(projectName, excludeTagsStr.AsArray(() => [])!, cancellationToken); - return (includeTags, excludeTags); - } - + /// + /// Compares an experiment's sets (permutations) against the baseline using aggregate metrics. + /// This is the default endpoint for comparing permutations to the baseline. + /// [HttpGet("{experimentName}/compare")] public async Task> Compare( - [FromServices] IConfigFactory configFactory, - [FromServices] IStorageService storageService, + [FromServices] ExperimentService experimentService, [FromRoute, Required, ValidName, ValidProjectName] string projectName, [FromRoute, Required, ValidName, ValidExperimentName] string experimentName, CancellationToken cancellationToken, @@ -132,108 +130,18 @@ public async Task> Compare( return BadRequest("a project name and experiment name are required."); } - // init - var watch = Stopwatch.StartNew(); - var comparison = new Comparison(); - var (includeTags, excludeTags) = await LoadTags(storageService, projectName, includeTagsStr, excludeTagsStr, cancellationToken); - comparison.MetricDefinitions = (await storageService.GetMetricsAsync(projectName, cancellationToken)) - .ToDictionary(x => x.Name); - logger.LogDebug("loaded tags and metric definitions in {ms} ms.", watch.ElapsedMilliseconds); - - // get the project baseline - try - { - watch.Restart(); - var baseline = await storageService.GetProjectBaselineAsync(projectName, cancellationToken); - var baselineSet = baseline.BaselineSet ?? 
baseline.LastSet; - var baselineFiltered = baseline.Filter(includeTags, excludeTags); - baseline.MetricDefinitions = comparison.MetricDefinitions; - comparison.ProjectBaseline = new ComparisonEntity - { - Project = projectName, - Experiment = baseline.Name, - Set = baselineSet, - Result = baseline.AggregateSet(baselineSet, baselineFiltered), - Count = baseline.Results?.Count(x => x.Set == baselineSet), // unfiltered count - }; - logger.LogDebug("loaded project baseline in {ms} ms.", watch.ElapsedMilliseconds); - } - catch (Exception e) - { - this.logger.LogWarning(e, "Failed to get baseline experiment for project {projectName}.", projectName); - } - - // get configuration - var config = await configFactory.GetAsync(cancellationToken); - - // get the experiment baseline - watch.Restart(); - var experiment = await storageService.GetExperimentAsync(projectName, experimentName, cancellationToken: cancellationToken); - var experimentBaselineSet = experiment.BaselineSet ?? experiment.FirstSet; - var experimentFiltered = experiment.Filter(includeTags, excludeTags); - experiment.MetricDefinitions = comparison.MetricDefinitions; - comparison.ExperimentBaseline = - string.Equals(experiment.Baseline, ":project", StringComparison.OrdinalIgnoreCase) - ? 
comparison.ProjectBaseline - : new ComparisonEntity - { - Project = projectName, - Experiment = experiment.Name, - Set = experimentBaselineSet, - Result = experiment.AggregateSet(experimentBaselineSet, experimentFiltered), - Count = experiment.Results?.Count(x => x.Set == experimentBaselineSet), // unfiltered count - }; - logger.LogDebug("loaded experiment baseline in {ms} ms.", watch.ElapsedMilliseconds); - - // get the sets - watch.Restart(); - comparison.Sets = experiment.AggregateAllSets(experimentFiltered) - .Select(x => - { - // find matching statistics - var statistics = experiment.Statistics?.LastOrDefault(y => - { - if (y.Set != x.Set) return false; - if (y.BaselineExperiment != comparison.ExperimentBaseline?.Experiment) return false; - if (y.BaselineSet != comparison.ExperimentBaseline?.Set) return false; - if (y.BaselineResultCount != comparison.ExperimentBaseline?.Count) return false; - if (y.SetResultCount != experiment.Results?.Count(z => z.Set == x.Set)) return false; // unfiltered count - if (y.NumSamples != config.CALC_PVALUES_USING_X_SAMPLES) return false; - if (y.ConfidenceLevel != config.CONFIDENCE_LEVEL) return false; - return true; - }); - - // fold statistics into result metrics - if (statistics?.Metrics is not null && x.Metrics is not null) - { - foreach (var (metricName, statisticsMetric) in statistics.Metrics) - { - if (x.Metrics.TryGetValue(metricName, out var resultMetric)) - { - resultMetric.PValue = statisticsMetric.PValue; - resultMetric.CILower = statisticsMetric.CILower; - resultMetric.CIUpper = statisticsMetric.CIUpper; - } - } - } - - return new ComparisonEntity - { - Project = projectName, - Experiment = experiment.Name, - Set = x.Set, - Result = x, - }; - }); - logger.LogDebug("aggregated sets in {ms} ms.", watch.ElapsedMilliseconds); - watch.Stop(); - + var comparison = await experimentService.CompareAsync(projectName, experimentName, includeTagsStr, excludeTagsStr, cancellationToken); return Ok(comparison); } + /// + /// 
Breaks down a comparison per ref (ground truth), showing which individual ground truths + /// improved or regressed. Only use when investigating individual ground truth performance. + /// For aggregate comparison of a permutation to the baseline, use the Compare endpoint instead. + /// [HttpGet("{experimentName}/sets/{setName}/compare-by-ref")] public async Task> CompareByRef( - [FromServices] IStorageService storageService, + [FromServices] ExperimentService experimentService, [FromRoute, Required, ValidName, ValidProjectName] string projectName, [FromRoute, Required, ValidName, ValidExperimentName] string experimentName, [FromRoute, Required, ValidName] string setName, @@ -246,85 +154,13 @@ public async Task> CompareByRef( return BadRequest("a project name, experiment name, and set name are required."); } - // init - var comparison = new ComparisonByRef(); - var (includeTags, excludeTags) = await LoadTags(storageService, projectName, includeTagsStr, excludeTagsStr, cancellationToken); - comparison.MetricDefinitions = (await storageService.GetMetricsAsync(projectName, cancellationToken)) - .ToDictionary(x => x.Name); - - // get the baseline - try - { - var baseline = await storageService.GetProjectBaselineAsync(projectName, cancellationToken); - var baselineFiltered = baseline.Filter(includeTags, excludeTags); - baseline.MetricDefinitions = comparison.MetricDefinitions; - comparison.ProjectBaseline = new ComparisonByRefEntity - { - Project = projectName, - Experiment = baseline.Name, - Set = baseline.BaselineSet ?? baseline.LastSet, - Results = baseline.AggregateSetByRef(baseline.BaselineSet ?? 
baseline.LastSet, baselineFiltered), - }; - } - catch (Exception e) - { - this.logger.LogWarning(e, "Failed to get baseline experiment for project {projectName}.", projectName); - } - - // get the experiment info - var experiment = await storageService.GetExperimentAsync(projectName, experimentName, cancellationToken: cancellationToken); - var experimentFiltered = experiment.Filter(includeTags, excludeTags); - experiment.MetricDefinitions = comparison.MetricDefinitions; - - // get the experiment baseline - if (string.Equals(experiment.Baseline, ":project", StringComparison.OrdinalIgnoreCase)) - { - comparison.ExperimentBaseline = comparison.ProjectBaseline; - } - else - { - comparison.ExperimentBaseline = new ComparisonByRefEntity - { - Project = projectName, - Experiment = experiment.Name, - Set = experiment.BaselineSet ?? experiment.FirstSet, - Results = experiment.AggregateSetByRef(experiment.BaselineSet ?? experiment.FirstSet, experimentFiltered), - }; - } - - // get the set experiment - comparison.ExperimentSet = new ComparisonByRefEntity - { - Project = projectName, - Experiment = experiment.Name, - Set = setName, - Results = experiment.AggregateSetByRef(setName, experimentFiltered), - }; - - // run policies - // if (comparison.ChosenResultsForChosenExperiment is not null - // && comparison.BaselineResultsForChosenExperiment is not null) - // { - // var policy = new PercentImprovement(); - // foreach (var (key, result) in comparison.ChosenResultsForChosenExperiment) - // { - // if (comparison.BaselineResultsForChosenExperiment.TryGetValue(key, out var baseline)) - // { - // policy.Evaluate(result, baseline, comparison.MetricDefinitions); - // } - // } - // this.logger.LogWarning("policy passed? 
{0}, {1}, {2}", policy.IsPassed, policy.NumResultsThatPassed, policy.NumResultsThatFailed); - // this.logger.LogWarning(policy.Requirement); - // this.logger.LogWarning(policy.Actual); - // } - + var comparison = await experimentService.CompareByRefAsync(projectName, experimentName, setName, includeTagsStr, excludeTagsStr, cancellationToken); return Ok(comparison); } [HttpGet("{experimentName}/sets/{setName}")] public async Task> GetNamedSet( - [FromServices] IConfigFactory configFactory, - [FromServices] IStorageService storageService, + [FromServices] ExperimentService experimentService, [FromRoute, Required, ValidName, ValidProjectName] string projectName, [FromRoute, Required, ValidName, ValidExperimentName] string experimentName, [FromRoute, Required, ValidName] string setName, @@ -332,31 +168,7 @@ public async Task> GetNamedSet( [FromQuery(Name = "include-tags")] string includeTagsStr = "", [FromQuery(Name = "exclude-tags")] string excludeTagsStr = "") { - // init - var metricDefinitions = (await storageService.GetMetricsAsync(projectName, cancellationToken)) - .ToDictionary(x => x.Name); - - // get the experiment and filter the results - var experiment = await storageService.GetExperimentAsync(projectName, experimentName, cancellationToken: cancellationToken); - var (includeTags, excludeTags) = await LoadTags(storageService, projectName, includeTagsStr, excludeTagsStr, cancellationToken); - var experimentFiltered = experiment.Filter(includeTags, excludeTags); - experiment.MetricDefinitions = metricDefinitions; - - // get the results - var results = experiment.AggregateSetByEachResult(setName, experimentFiltered) - ?? 
Enumerable.Empty(); - - // add the support docs - var config = await configFactory.GetAsync(cancellationToken); - if (!string.IsNullOrEmpty(config.PATH_TEMPLATE)) - { - foreach (var result in results) - { - if (!string.IsNullOrEmpty(result.InferenceUri)) result.InferenceUri = string.Format(config.PATH_TEMPLATE, result.InferenceUri); - if (!string.IsNullOrEmpty(result.EvaluationUri)) result.EvaluationUri = string.Format(config.PATH_TEMPLATE, result.EvaluationUri); - } - } - + var results = await experimentService.GetNamedSetAsync(projectName, experimentName, setName, includeTagsStr, excludeTagsStr, cancellationToken); return Ok(results); } diff --git a/catalog/HttpException.cs b/catalog/controllers/HttpException.cs similarity index 100% rename from catalog/HttpException.cs rename to catalog/controllers/HttpException.cs diff --git a/catalog/exp-catalog.csproj b/catalog/exp-catalog.csproj index a0b9cb1..465ecaa 100644 --- a/catalog/exp-catalog.csproj +++ b/catalog/exp-catalog.csproj @@ -8,20 +8,21 @@ - + - + + - - - - - - - + + + + + + + diff --git a/catalog/mcp.md b/catalog/mcp.md new file mode 100644 index 0000000..e1c95dd --- /dev/null +++ b/catalog/mcp.md @@ -0,0 +1,224 @@ +# MCP Services + +This document describes the MCP capabilities for experiment comparison and analysis, and how the integration was implemented. + +## Tests + +The following queries have been tested: + +- how many projects do I have? +- what projects do I have? +- list the projects I have +- what experiments do I have in amltest? +- list experiments under sprint-02 +- what permutations of the test_aml_run experiment exist? +- how good was the 20250805220419 permutation? +- how did the 20250807071317 permutation compare to the baseline? +- what were the top 5 ground truths that saw improvement in the recall? 
+- create me a new project called "sprint-02" +- create me an experiment under sprint-02 +- set the experiment known as "baseline" as the baseline for this project +- what tags are used in this project? +- what metrics are defined? +- what 3 tags would have the greatest impact on my recall metric? + +## Implementation Guide + +Follow these steps to add MCP (Model Context Protocol) support to an existing ASP.NET Core web API. This pattern exposes the same business logic through both REST endpoints and MCP tool calls. + +### 1. Add the NuGet package + +Add the `ModelContextProtocol.AspNetCore` package to the project file: + +```xml + +``` + +### 2. Extract business logic into service classes + +Move logic out of controllers into dedicated service classes so both controllers and MCP tools can share the same code. Two services were created in the `services/` folder: + +| Service | Purpose | +| ------------------- | ---------------------------------------------------------------- | +| `AnalysisService` | Tag impact analysis (meaningful tags) | +| `ExperimentService` | Comparison, per-ref comparison, named set retrieval, set listing | + +Controllers become thin wrappers that delegate to these services. MCP tool classes do the same. + +### 3. Create MCP tool classes + +Create classes in a `mcp/` folder. 
Each class groups related tools and follows this pattern: + +- Annotate the class with `[McpServerToolType]` +- Use constructor injection to receive the shared services +- Annotate each public method with `[McpServerTool(Name = "...")]` and `[Description("...")]` +- Annotate each parameter with `[Description("...")]` +- Return domain objects directly (the MCP SDK serializes them for you) + +Three tool classes were created: + +| Class | Tools | +| ------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `ProjectsTools` | ListProjects, AddProject, ListTags, GetMetricDefinitions | +| `ExperimentsTools` | ListExperiments, GetExperiment, AddExperiment, ListSetsForExperiment, SetExperimentAsBaseline, SetBaselineForExperiment, CompareExperiment, CompareByRef, GetNamedSet | +| `AnalysisTools` | CalculateStatistics, MeaningfulTags | + +### 4. Validate tool parameters explicitly + +ASP.NET controllers validate parameters automatically through the model validation pipeline, which processes `DataAnnotations` attributes such as `[Required]`, `[ValidName]`, and `[ValidProjectName]`. The MCP SDK does not have an equivalent pipeline. Tool arguments arrive as raw `JsonElement` values and are deserialized into method parameters without running any `ValidationAttribute` logic. The SDK documentation confirms this: arguments "should be considered unvalidated and untrusted." + +Adding `[Required, ValidName, ValidProjectName]` to MCP tool parameters has no effect because nothing invokes those attributes at runtime. Custom validators like `ValidProjectNameAttribute` also depend on `ValidationContext.GetService()` to resolve `IStorageService`, which the SDK never provides. 
+ +Instead, validate parameters explicitly at the start of each tool method using a shared helper class: + +```csharp +public static class McpValidationHelper +{ + public static void ValidateRequiredName(string? value, string parameterName) { ... } + public static void ValidateProjectName(string? value, IStorageService storageService) { ... } + public static void ValidateExperimentName(string? value, IStorageService storageService) { ... } + public static void ValidateOptionalNames(IEnumerable? values, string parameterName) { ... } +} +``` + +Each tool class injects `IStorageService` and exposes thin wrapper methods: + +```csharp +public class ExperimentsTools(IStorageService storageService, ExperimentService experimentService) +{ + private void ValidateProjectName(string? project) => + McpValidationHelper.ValidateProjectName(project, storageService); + private void ValidateExperimentName(string? experiment) => + McpValidationHelper.ValidateExperimentName(experiment, storageService); + + [McpServerTool(Name = "GetExperiment"), Description("...")] + public async Task GetExperiment(string project, string experiment, ...) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + return await storageService.GetExperimentAsync(project, experiment, false, cancellationToken); + } +} +``` + +Validation failures throw `HttpException(400, ...)`, which the exception filter (see next step) catches and returns as an MCP error result. + +### 5. Create an exception filter for MCP + +MCP tool calls do not pass through ASP.NET middleware, so `HttpExceptionMiddleware` does not catch exceptions thrown during tool execution. Create an `McpToolExceptionFilter` that mirrors the same error-handling behavior: + +- Catch `HttpWithResponseException`, `HttpException`, and generic `Exception` +- Return a `CallToolResult` with `IsError = true` and a text message +- Register the filter via `.AddCallToolFilter(McpToolExceptionFilter.Create())` + +### 6. 
Register services and MCP in Program.cs + +Add the following registrations: + +```csharp +// register the shared services +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); + +// register the MCP server +builder.Services + .AddMcpServer() + .WithHttpTransport() + .WithToolsFromAssembly() + .AddCallToolFilter(McpToolExceptionFilter.Create()); +``` + +`WithToolsFromAssembly()` discovers all classes annotated with `[McpServerToolType]` automatically. + +### 7. Map the MCP endpoint + +After `app.MapControllers()`, add: + +```csharp +app.MapMcp("/mcp"); +``` + +This exposes the MCP Streamable HTTP endpoint at `/mcp`. + +### 8. Update CORS for MCP Inspector + +If you use the MCP Inspector for testing, add its origin to the CORS policy: + +```csharp +corsBuilder.WithOrigins( + "http://localhost:6020", + "http://localhost:6274" // MCP Inspector +) +``` + +### 9. Handle enum serialization for MCP + +The MCP SDK uses `System.Text.Json` rather than `Newtonsoft.Json`. If any tool parameters use enums, add the `System.Text.Json` converter attribute alongside any existing Newtonsoft attributes: + +```csharp +[System.Text.Json.Serialization.JsonConverter(typeof(JsonStringEnumConverter))] +public enum MeaningfulTagsComparisonMode { Baseline, Zero, Average } +``` + +### 10. Authentication + +#### Local development (auth disabled) + +When running against localhost you can leave authentication off by omitting the `OIDC_AUTHORITY` environment variable. Without an authority the fallback policy is not set and the `/mcp` endpoint is open, so MCP clients connect without any token exchange. + +#### Deployed service (auth enabled) + +When the catalog is deployed with authentication enabled (`OIDC_AUTHORITY`, `OIDC_CLIENT_ID`, and optionally `OIDC_CLIENT_SECRET` are set), the MCP endpoint requires an OAuth 2.0 access token. 
The MCP SDK's authentication handler advertises the token requirements through `ProtectedResourceMetadata`, and compliant MCP clients (such as VS Code with GitHub Copilot) perform the OAuth flow automatically. + +##### Code changes + +Register the MCP authentication scheme alongside JWT Bearer in `Program.cs`: + +```csharp +builder.Services.AddSingleton, McpAuthenticationConfigurator>(); +builder.Services.AddAuthentication(options => + { + options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme; + options.DefaultChallengeScheme = McpAuthenticationDefaults.AuthenticationScheme; + }) + .AddJwtBearer() + .AddMcp(); +``` + +Create `McpAuthenticationConfigurator` to populate the protected resource metadata from the app's OIDC settings. It advertises the authorization server and an API scope of `api://{OIDC_CLIENT_ID}/.default` so that MCP clients know where to obtain a token and which scope to request. + +##### Azure AD app registration + +Configure the app registration that represents the catalog API: + +1. **Expose an API** — set the Application ID URI to `api://` and add a scope (for example `api:///all` with "Admins and users" consent). This scope is what MCP clients request when obtaining an access token. +2. **Authorized client applications** — add the VS Code client application ID and authorize it for the scope created above. This allows VS Code to acquire tokens for the API without a user consent prompt. +3. **Redirect URIs** — ensure the following are registered under the **Web** platform: + - `https://vscode.dev/redirect` (VS Code web) + - `http://localhost:33418` (VS Code desktop OAuth redirect) +4. **Mobile and desktop applications** — enable the MSAL redirect URI (`msal://auth`). +5. **Allow public client flows** — set to **Yes** so VS Code can authenticate as a public client without a client secret. 
+ +##### VS Code settings + +Add the following to your VS Code `settings.json` (workspace or user level) so that the Microsoft authentication extension uses the MSAL flow without a broker, which is required for the MCP OAuth handshake: + +```json +{ + "microsoft-authentication.implementation": "msal-no-broker" +} +``` + +##### Summary + +| Scenario | OIDC_AUTHORITY | MCP auth behavior | +| ------------------ | -------------- | ------------------------------------------------------------------------------------ | +| Local development | Not set | MCP endpoint is open, no token needed | +| Deployed with auth | Set | MCP clients perform OAuth 2.0 automatically using the advertised scope and authority | + +### 11. Add Copilot agent and skill files (optional) + +To enable GitHub Copilot Chat to use the MCP tools via an agent: + +- Create `.github/agents/ask-catalog.agent.md` with the agent definition, tool references, and tool selection guidance +- Create `.github/skills/experiment-catalog/SKILL.md` with domain context (hierarchy, terminology, workflows) so the agent can reason about the data model diff --git a/catalog/mcp/AnalysisTools.cs b/catalog/mcp/AnalysisTools.cs new file mode 100644 index 0000000..2020e52 --- /dev/null +++ b/catalog/mcp/AnalysisTools.cs @@ -0,0 +1,83 @@ +using System.Collections.Generic; +using System.ComponentModel; +using System.Threading; +using System.Threading.Tasks; +using ModelContextProtocol.Server; + +namespace Catalog; + +/// +/// MCP tools for experiment analysis operations. +/// +[McpServerToolType] +public class AnalysisTools(AnalysisService analysisService, CalculateStatisticsService calculateStatisticsService, IStorageService storageService) +{ + private void ValidateProjectName(string? project) => McpValidationHelper.ValidateProjectName(project, storageService); + private void ValidateExperimentName(string? 
experiment) => McpValidationHelper.ValidateExperimentName(experiment, storageService); + + /// + /// Enqueues a request to calculate statistics (p-values) for an experiment by comparing against the baseline. + /// + /// The project name. + /// The experiment name. + /// Token to cancel the operation. + /// A message indicating the request was enqueued. + [McpServerTool(Name = "CalculateStatistics"), Description("Enqueue a request to calculate statistics (p-values) for an experiment.")] + public string CalculateStatistics( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + + var request = new CalculateStatisticsRequest + { + Project = project, + Experiment = experiment + }; + + calculateStatisticsService.Enqueue(request); + return $"Statistics calculation enqueued for '{project}/{experiment}'"; + } + + /// + /// Analyzes which tags have the most meaningful impact on a specific metric. + /// + /// The project name. + /// The experiment name. + /// The result set to analyze. + /// The metric to analyze. + /// Optional tags to exclude from analysis. + /// Comparison mode: Baseline, Zero, or Average. + /// Token to cancel the operation. + /// A list of tags ordered by their impact on the metric. + [McpServerTool(Name = "MeaningfulTags"), Description("Analyze which tags have the most meaningful impact on a specific metric.")] + public async Task MeaningfulTags( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + [Description("The result set to analyze")] string set, + [Description("The metric to analyze")] string metric, + [Description("Optional tags to exclude from analysis")] IEnumerable? 
excludeTags = null, + [Description("Comparison mode: Baseline (compare to project baseline), Zero (compare to zero), or Average (compare to experiment average)")] MeaningfulTagsComparisonMode compareTo = MeaningfulTagsComparisonMode.Baseline, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + McpValidationHelper.ValidateRequiredName(set, "set"); + McpValidationHelper.ValidateRequiredName(metric, "metric"); + McpValidationHelper.ValidateOptionalNames(excludeTags, "excludeTags"); + + var request = new MeaningfulTagsRequest + { + Project = project, + Experiment = experiment, + Set = set, + Metric = metric, + ExcludeTags = excludeTags, + CompareTo = compareTo + }; + + return await analysisService.GetMeaningfulTagsAsync(request, cancellationToken); + } +} diff --git a/catalog/mcp/ExperimentsTools.cs b/catalog/mcp/ExperimentsTools.cs new file mode 100644 index 0000000..6f35eb3 --- /dev/null +++ b/catalog/mcp/ExperimentsTools.cs @@ -0,0 +1,217 @@ +using System.Collections.Generic; +using System.ComponentModel; +using System.Threading; +using System.Threading.Tasks; +using ModelContextProtocol.Server; + +namespace Catalog; + +/// +/// MCP tools for experiment management operations. +/// +[McpServerToolType] +public class ExperimentsTools(IStorageService storageService, ExperimentService experimentService) +{ + private void ValidateProjectName(string? project) => McpValidationHelper.ValidateProjectName(project, storageService); + private void ValidateExperimentName(string? experiment) => McpValidationHelper.ValidateExperimentName(experiment, storageService); + + /// + /// Lists all experiments in a project. + /// + /// The project name. + /// Token to cancel the operation. + /// A list of experiments. 
+ [McpServerTool(Name = "ListExperiments"), Description("List all experiments in a project.")] + public async Task> ListExperiments( + [Description("The project name")] string project, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + + return await storageService.GetExperimentsAsync(project, cancellationToken); + } + + /// + /// Gets a specific experiment by name. + /// + /// The project name. + /// The experiment name. + /// Token to cancel the operation. + /// The experiment details. + [McpServerTool(Name = "GetExperiment"), Description("Get a specific experiment by name.")] + public async Task GetExperiment( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + + return await storageService.GetExperimentAsync(project, experiment, false, cancellationToken); + } + + /// + /// Adds a new experiment to a project. + /// + /// The project name. + /// The experiment name. + /// The experiment hypothesis. + /// Token to cancel the operation. + /// A message indicating the experiment was added. 
+ [McpServerTool(Name = "AddExperiment"), Description("Add a new experiment to a project.")] + public async Task AddExperiment( + [Description("The project name")] string project, + [Description("The experiment name")] string name, + [Description("The experiment hypothesis")] string hypothesis, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(name); + McpValidationHelper.ValidateRequiredName(hypothesis, "hypothesis"); + + var experiment = new Experiment { Name = name, Hypothesis = hypothesis }; + await storageService.AddExperimentAsync(project, experiment, cancellationToken); + return $"Experiment '{name}' added to project '{project}'."; + } + + /// + /// Lists the distinct set names (permutations) for an experiment. + /// Use this to discover available permutations, not to validate a set name before comparison. + /// + /// The project name. + /// The experiment name. + /// Token to cancel the operation. + /// A list of set names. + [McpServerTool(Name = "ListSetsForExperiment"), Description("List the distinct set names (permutations) for an experiment. Use only when the user wants to see which permutations exist. Do not call this to validate a set name before comparison; call CompareExperiment directly instead.")] + public async Task> ListSetsForExperiment( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + + return await experimentService.ListSetsForExperimentAsync(project, experiment, cancellationToken); + } + + /// + /// Sets an experiment as the project baseline. + /// + /// The project name. + /// The experiment name. + /// Token to cancel the operation. + /// A message indicating the experiment was set as baseline. 
+ [McpServerTool(Name = "SetExperimentAsBaseline"), Description("Set an experiment as the project baseline.")] + public async Task SetExperimentAsBaseline( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + + await storageService.SetExperimentAsBaselineAsync(project, experiment, cancellationToken); + return $"Experiment '{experiment}' set as baseline for project '{project}'."; + } + + /// + /// Sets the baseline set for an experiment. + /// + /// The project name. + /// The experiment name. + /// The set name to use as baseline. + /// Token to cancel the operation. + /// A message indicating the baseline set was configured. + [McpServerTool(Name = "SetBaselineForExperiment"), Description("Set the baseline set for an experiment.")] + public async Task SetBaselineForExperiment( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + [Description("The set name to use as baseline")] string set, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + McpValidationHelper.ValidateRequiredName(set, "set"); + + await storageService.SetBaselineForExperiment(project, experiment, set, cancellationToken); + return $"Set '{set}' configured as baseline for experiment '{experiment}' in project '{project}'."; + } + + /// + /// Compares an experiment's sets (permutations) against the baseline using aggregate metrics. + /// This is the default tool for comparing permutations to the baseline. + /// + /// The project name. + /// The experiment name. + /// Optional comma-separated tag names to include. + /// Optional comma-separated tag names to exclude. + /// Token to cancel the operation. + /// The comparison result. 
+ [McpServerTool(Name = "CompareExperiment"), Description("Compare an experiment's sets (permutations) against the baseline using aggregate metrics. This is the default tool for any question about how a permutation or set compared to the baseline. Returns aggregate metrics, project baseline, experiment baseline, and statistics.")] + public async Task CompareExperiment( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + [Description("Optional comma-separated tag names to include")] string includeTags = "", + [Description("Optional comma-separated tag names to exclude")] string excludeTags = "", + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + + return await experimentService.CompareAsync(project, experiment, includeTags, excludeTags, cancellationToken); + } + + /// + /// Breaks down a comparison per ref (ground truth), showing which individual ground truths improved or regressed. + /// Only use when the user specifically asks about individual ground truth performance. + /// For aggregate comparison of a permutation to the baseline, use instead. + /// + /// The project name. + /// The experiment name. + /// The set name to compare. + /// Optional comma-separated tag names to include. + /// Optional comma-separated tag names to exclude. + /// Token to cancel the operation. + /// The per-ref comparison result with baseline and set metrics for each ground truth. + [McpServerTool(Name = "CompareByRef"), Description("Break down a comparison per ref (ground truth) to identify which individual ground truths improved or regressed. Only use when the user asks about individual ground truth performance. 
For comparing a permutation to the baseline, use CompareExperiment instead.")] + public async Task CompareByRef( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + [Description("The set name to compare")] string set, + [Description("Optional comma-separated tag names to include")] string includeTags = "", + [Description("Optional comma-separated tag names to exclude")] string excludeTags = "", + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + McpValidationHelper.ValidateRequiredName(set, "set"); + + return await experimentService.CompareByRefAsync(project, experiment, set, includeTags, excludeTags, cancellationToken); + } + + /// + /// Gets per-result details for a named set in an experiment. + /// + /// The project name. + /// The experiment name. + /// The set name. + /// Optional comma-separated tag names to include. + /// Optional comma-separated tag names to exclude. + /// Token to cancel the operation. + /// The individual results for the named set. 
+ [McpServerTool(Name = "GetNamedSet"), Description("Get per-result details for a named set in an experiment.")] + public async Task> GetNamedSet( + [Description("The project name")] string project, + [Description("The experiment name")] string experiment, + [Description("The set name")] string set, + [Description("Optional comma-separated tag names to include")] string includeTags = "", + [Description("Optional comma-separated tag names to exclude")] string excludeTags = "", + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + ValidateExperimentName(experiment); + McpValidationHelper.ValidateRequiredName(set, "set"); + + return await experimentService.GetNamedSetAsync(project, experiment, set, includeTags, excludeTags, cancellationToken); + } +} diff --git a/catalog/mcp/McpToolExceptionFilter.cs b/catalog/mcp/McpToolExceptionFilter.cs new file mode 100644 index 0000000..53014ef --- /dev/null +++ b/catalog/mcp/McpToolExceptionFilter.cs @@ -0,0 +1,63 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; + +namespace Catalog; + +/// +/// Marker class for logging from the MCP tool exception filter. +/// +public sealed class McpToolExceptionFilter +{ + private McpToolExceptionFilter() { } + + /// + /// Creates a filter that handles exceptions from MCP tool calls similar to HttpExceptionMiddleware. + /// + /// The filter function. + public static McpRequestFilter Create() + { + return next => async (context, cancellationToken) => + { + var logger = context.Services?.GetService>(); + var toolName = context.Params?.Name ?? 
"unknown"; + + try + { + return await next(context, cancellationToken); + } + catch (HttpWithResponseException ex) + { + logger?.LogWarning(ex, "MCP tool '{ToolName}' HTTP exception with response...", toolName); + return new CallToolResult + { + IsError = true, + Content = [new TextContentBlock { Text = ex.Message }] + }; + } + catch (HttpException ex) + { + logger?.LogWarning(ex, "MCP tool '{ToolName}' HTTP exception...", toolName); + return new CallToolResult + { + IsError = true, + Content = [new TextContentBlock { Text = ex.Message }] + }; + } + catch (Exception ex) + { + logger?.LogError(ex, "MCP tool '{ToolName}' internal exception...", toolName); + return new CallToolResult + { + IsError = true, + Content = [new TextContentBlock { Text = "There was an error processing the request." }] + }; + } + }; + } +} diff --git a/catalog/mcp/McpValidationHelper.cs b/catalog/mcp/McpValidationHelper.cs new file mode 100644 index 0000000..83a7d33 --- /dev/null +++ b/catalog/mcp/McpValidationHelper.cs @@ -0,0 +1,84 @@ +using System.Collections.Generic; + +namespace Catalog; + +/// +/// Provides validation methods for MCP tool parameters, matching the validation +/// performed by the API controllers via , +/// , and . +/// +public static class McpValidationHelper +{ + /// + /// Validates that a required string parameter is not null or empty and is a valid name. + /// + /// The value to validate. + /// The parameter name for error messages. + /// Thrown when validation fails. + public static void ValidateRequiredName(string? value, string parameterName) + { + if (string.IsNullOrWhiteSpace(value)) + { + throw new HttpException(400, $"The {parameterName} field is required."); + } + + if (!value.IsValidName()) + { + throw new HttpException(400, $"The {parameterName} field must contain only letters, digits, hyphens, underscores, periods, or colons (3-50 characters)."); + } + } + + /// + /// Validates a required project name using both name format and storage-specific rules. 
+ /// + /// The project name to validate. + /// The storage service for project-specific validation. + /// Thrown when validation fails. + public static void ValidateProjectName(string? value, IStorageService storageService) + { + ValidateRequiredName(value, "project"); + + if (!storageService.TryValidProjectName(value, out string? errorMessage)) + { + throw new HttpException(400, errorMessage ?? "The project name is invalid."); + } + } + + /// + /// Validates a required experiment name using both name format and storage-specific rules. + /// + /// The experiment name to validate. + /// The storage service for experiment-specific validation. + /// Thrown when validation fails. + public static void ValidateExperimentName(string? value, IStorageService storageService) + { + ValidateRequiredName(value, "experiment"); + + if (!storageService.TryValidExperimentName(value, out string? errorMessage)) + { + throw new HttpException(400, errorMessage ?? "The experiment name is invalid."); + } + } + + /// + /// Validates an optional collection of names, ensuring each is a valid name if provided. + /// + /// The collection of names to validate. + /// The parameter name for error messages. + /// Thrown when any name in the collection is invalid. + public static void ValidateOptionalNames(IEnumerable? values, string parameterName) + { + if (values is null) + { + return; + } + + foreach (var value in values) + { + if (!value.IsValidName()) + { + throw new HttpException(400, $"The {parameterName} field contains an invalid name '{value}'. 
Names must contain only letters, digits, hyphens, underscores, periods, or colons (3-50 characters)."); + } + } + } +} diff --git a/catalog/mcp/ProjectsTools.cs b/catalog/mcp/ProjectsTools.cs new file mode 100644 index 0000000..5e4fd82 --- /dev/null +++ b/catalog/mcp/ProjectsTools.cs @@ -0,0 +1,78 @@ +using System.Collections.Generic; +using System.ComponentModel; +using System.Threading; +using System.Threading.Tasks; +using ModelContextProtocol.Server; + +namespace Catalog; + +/// +/// MCP tools for project management operations. +/// +[McpServerToolType] +public class ProjectsTools(IStorageService storageService) +{ + private void ValidateProjectName(string? project) => McpValidationHelper.ValidateProjectName(project, storageService); + + /// + /// Lists all projects. + /// + /// Token to cancel the operation. + /// A list of all projects. + [McpServerTool(Name = "ListProjects"), Description("List all projects.")] + public async Task> ListProjects( + CancellationToken cancellationToken = default) + { + return await storageService.GetProjectsAsync(cancellationToken); + } + + /// + /// Adds a new project. + /// + /// The project name. + /// Token to cancel the operation. + /// A message indicating the project was added. + [McpServerTool(Name = "AddProject"), Description("Add a new project.")] + public async Task AddProject( + [Description("The project name")] string name, + CancellationToken cancellationToken = default) + { + ValidateProjectName(name); + + var project = new Project { Name = name }; + await storageService.AddProjectAsync(project, cancellationToken); + return $"Project '{name}' added."; + } + + /// + /// Lists all tag names in a project. + /// + /// The project name. + /// Token to cancel the operation. + /// A list of tag names. 
+ [McpServerTool(Name = "ListTags"), Description("List all tag names in a project.")] + public async Task> ListTags( + [Description("The project name")] string project, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + + return await storageService.ListTagsAsync(project, cancellationToken); + } + + /// + /// Gets the metric definitions for a project. + /// + /// The project name. + /// Token to cancel the operation. + /// A list of metric definitions. + [McpServerTool(Name = "GetMetricDefinitions"), Description("Get the metric definitions for a project.")] + public async Task> GetMetricDefinitions( + [Description("The project name")] string project, + CancellationToken cancellationToken = default) + { + ValidateProjectName(project); + + return await storageService.GetMetricsAsync(project, cancellationToken); + } +} diff --git a/catalog/models/MeaningfulTagsRequest.cs b/catalog/models/MeaningfulTagsRequest.cs index 0fdd3e4..d7cd8d3 100644 --- a/catalog/models/MeaningfulTagsRequest.cs +++ b/catalog/models/MeaningfulTagsRequest.cs @@ -5,6 +5,7 @@ namespace Catalog; +[System.Text.Json.Serialization.JsonConverter(typeof(JsonStringEnumConverter))] // for MCP support public enum MeaningfulTagsComparisonMode { Baseline, diff --git a/catalog/services/AnalysisService.cs b/catalog/services/AnalysisService.cs new file mode 100644 index 0000000..e8d11d5 --- /dev/null +++ b/catalog/services/AnalysisService.cs @@ -0,0 +1,82 @@ +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; + +namespace Catalog; + +/// +/// Provides analysis operations for experiments and metrics. +/// +public class AnalysisService(IStorageService storageService) +{ + /// + /// Analyzes which tags have the most meaningful impact on a specific metric. + /// + /// The meaningful tags request. + /// Token to cancel the operation. + /// A response containing tags ordered by their impact. 
+ public async Task GetMeaningfulTagsAsync( + MeaningfulTagsRequest request, + CancellationToken cancellationToken = default) + { + var diffs = new List(); + + var experiment = await storageService.GetExperimentAsync( + request.Project, + request.Experiment, + cancellationToken: cancellationToken); + + var baseline = request.CompareTo == MeaningfulTagsComparisonMode.Baseline + ? await storageService.GetProjectBaselineAsync(request.Project, cancellationToken) + : null; + + var listOfTags = await storageService.ListTagsAsync(request.Project, cancellationToken); + var includeTags = await storageService.GetTagsAsync(request.Project, listOfTags, cancellationToken); + var excludeTags = request.ExcludeTags is not null + ? await storageService.GetTagsAsync(request.Project, request.ExcludeTags, cancellationToken) + : null; + + var compareToDefault = 0.0M; + if (request.CompareTo == MeaningfulTagsComparisonMode.Average) + { + var results = experiment.Filter(null, excludeTags); + var experimentResult = experiment.AggregateSet(request.Set, results); + Metric? experimentMetric = null; + experimentResult?.Metrics?.TryGetValue(request.Metric, out experimentMetric); + compareToDefault = experimentMetric?.Value ?? 0.0M; + } + + foreach (var tag in includeTags) + { + var experimentResults = experiment.Filter([tag], excludeTags); + var experimentResult = experiment.AggregateSet(request.Set, experimentResults); + Metric? experimentTagMetric = null; + experimentResult?.Metrics?.TryGetValue(request.Metric, out experimentTagMetric); + + decimal? compareTo = compareToDefault; + if (baseline is not null) + { + var baselineResults = baseline.Filter([tag], excludeTags); + var baselineResult = baseline.AggregateSet(baseline.BaselineSet ?? baseline.LastSet, baselineResults); + Metric? 
baselineTagMetric = null; + baselineResult?.Metrics?.TryGetValue(request.Metric, out baselineTagMetric); + compareTo = baselineTagMetric?.Value; + } + + if (experimentTagMetric?.Value is not null && compareTo is not null) + { + var diff = (decimal)(experimentTagMetric.Value - compareTo); + diffs.Add(new TagDiff + { + Tag = tag.Name, + Diff = diff, + Impact = diff * (experimentTagMetric.Count ?? 0), + Count = experimentTagMetric.Count, + }); + } + } + + return new MeaningfulTagsResponse { Tags = diffs.OrderBy(x => x.Impact) }; + } +} diff --git a/catalog/services/ExperimentService.cs b/catalog/services/ExperimentService.cs new file mode 100644 index 0000000..596bad2 --- /dev/null +++ b/catalog/services/ExperimentService.cs @@ -0,0 +1,279 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using NetBricks; + +namespace Catalog; + +/// +/// Provides comparison operations for experiments, including aggregate and per-ref comparisons. +/// +public class ExperimentService( + ILogger logger, + IStorageService storageService, + IConfigFactory configFactory) +{ + /// + /// Lists the distinct set names for an experiment. + /// + /// The project name. + /// The experiment name. + /// Token to cancel the operation. + /// The list of set names. + public async Task> ListSetsForExperimentAsync( + string projectName, + string experimentName, + CancellationToken cancellationToken = default) + { + var experiment = await storageService.GetExperimentAsync(projectName, experimentName, cancellationToken: cancellationToken); + return experiment.Sets.ToList(); + } + + /// + /// Loads include and exclude tags from comma-separated name strings. + /// + /// The project name. + /// Comma-separated include tag names. + /// Comma-separated exclude tag names. + /// Token to cancel the operation. + /// The resolved include and exclude tag lists. 
+ public async Task<(IList IncludeTags, IList ExcludeTags)> LoadTagsAsync( + string projectName, + string includeTagsStr, + string excludeTagsStr, + CancellationToken cancellationToken = default) + { + var includeTags = await storageService.GetTagsAsync(projectName, includeTagsStr.AsArray(() => [])!, cancellationToken); + var excludeTags = await storageService.GetTagsAsync(projectName, excludeTagsStr.AsArray(() => [])!, cancellationToken); + return (includeTags, excludeTags); + } + + /// + /// Compares an experiment's sets against its baseline, including project baseline and statistics. + /// + /// The project name. + /// The experiment name. + /// Comma-separated include tag names. + /// Comma-separated exclude tag names. + /// Token to cancel the operation. + /// The comparison result. + public async Task CompareAsync( + string projectName, + string experimentName, + string includeTagsStr = "", + string excludeTagsStr = "", + CancellationToken cancellationToken = default) + { + var comparison = new Comparison(); + var (includeTags, excludeTags) = await LoadTagsAsync(projectName, includeTagsStr, excludeTagsStr, cancellationToken); + comparison.MetricDefinitions = (await storageService.GetMetricsAsync(projectName, cancellationToken)) + .ToDictionary(x => x.Name); + + // get the project baseline + try + { + var baseline = await storageService.GetProjectBaselineAsync(projectName, cancellationToken); + var baselineSet = baseline.BaselineSet ?? 
baseline.LastSet; + var baselineFiltered = baseline.Filter(includeTags, excludeTags); + baseline.MetricDefinitions = comparison.MetricDefinitions; + comparison.ProjectBaseline = new ComparisonEntity + { + Project = projectName, + Experiment = baseline.Name, + Set = baselineSet, + Result = baseline.AggregateSet(baselineSet, baselineFiltered), + Count = baseline.Results?.Count(x => x.Set == baselineSet), // unfiltered count + }; + } + catch (Exception e) + { + logger.LogWarning(e, "Failed to get baseline experiment for project {projectName}.", projectName); + } + + // get configuration + var config = await configFactory.GetAsync(cancellationToken); + + // get the experiment baseline + var experiment = await storageService.GetExperimentAsync(projectName, experimentName, cancellationToken: cancellationToken); + var experimentBaselineSet = experiment.BaselineSet ?? experiment.FirstSet; + var experimentFiltered = experiment.Filter(includeTags, excludeTags); + experiment.MetricDefinitions = comparison.MetricDefinitions; + comparison.ExperimentBaseline = + string.Equals(experiment.Baseline, ":project", StringComparison.OrdinalIgnoreCase) + ? 
comparison.ProjectBaseline + : new ComparisonEntity + { + Project = projectName, + Experiment = experiment.Name, + Set = experimentBaselineSet, + Result = experiment.AggregateSet(experimentBaselineSet, experimentFiltered), + Count = experiment.Results?.Count(x => x.Set == experimentBaselineSet), // unfiltered count + }; + + // get the sets + comparison.Sets = experiment.AggregateAllSets(experimentFiltered) + .Select(x => + { + // find matching statistics + var statistics = experiment.Statistics?.LastOrDefault(y => + { + if (y.Set != x.Set) return false; + if (y.BaselineExperiment != comparison.ExperimentBaseline?.Experiment) return false; + if (y.BaselineSet != comparison.ExperimentBaseline?.Set) return false; + if (y.BaselineResultCount != comparison.ExperimentBaseline?.Count) return false; + if (y.SetResultCount != experiment.Results?.Count(z => z.Set == x.Set)) return false; // unfiltered count + if (y.NumSamples != config.CALC_PVALUES_USING_X_SAMPLES) return false; + if (y.ConfidenceLevel != config.CONFIDENCE_LEVEL) return false; + return true; + }); + + // fold statistics into result metrics + if (statistics?.Metrics is not null && x.Metrics is not null) + { + foreach (var (metricName, statisticsMetric) in statistics.Metrics) + { + if (x.Metrics.TryGetValue(metricName, out var resultMetric)) + { + resultMetric.PValue = statisticsMetric.PValue; + resultMetric.CILower = statisticsMetric.CILower; + resultMetric.CIUpper = statisticsMetric.CIUpper; + } + } + } + + return new ComparisonEntity + { + Project = projectName, + Experiment = experiment.Name, + Set = x.Set, + Result = x, + }; + }); + + return comparison; + } + + /// + /// Compares an experiment set against its baseline on a per-ref basis. + /// + /// The project name. + /// The experiment name. + /// The set name to compare. + /// Comma-separated include tag names. + /// Comma-separated exclude tag names. + /// Token to cancel the operation. + /// The per-ref comparison result. 
+ public async Task CompareByRefAsync( + string projectName, + string experimentName, + string setName, + string includeTagsStr = "", + string excludeTagsStr = "", + CancellationToken cancellationToken = default) + { + var comparison = new ComparisonByRef(); + var (includeTags, excludeTags) = await LoadTagsAsync(projectName, includeTagsStr, excludeTagsStr, cancellationToken); + comparison.MetricDefinitions = (await storageService.GetMetricsAsync(projectName, cancellationToken)) + .ToDictionary(x => x.Name); + + // get the project baseline + try + { + var baseline = await storageService.GetProjectBaselineAsync(projectName, cancellationToken); + var baselineFiltered = baseline.Filter(includeTags, excludeTags); + baseline.MetricDefinitions = comparison.MetricDefinitions; + comparison.ProjectBaseline = new ComparisonByRefEntity + { + Project = projectName, + Experiment = baseline.Name, + Set = baseline.BaselineSet ?? baseline.LastSet, + Results = baseline.AggregateSetByRef(baseline.BaselineSet ?? baseline.LastSet, baselineFiltered), + }; + } + catch (Exception e) + { + logger.LogWarning(e, "Failed to get baseline experiment for project {projectName}.", projectName); + } + + // get the experiment info + var experiment = await storageService.GetExperimentAsync(projectName, experimentName, cancellationToken: cancellationToken); + var experimentFiltered = experiment.Filter(includeTags, excludeTags); + experiment.MetricDefinitions = comparison.MetricDefinitions; + + // get the experiment baseline + if (string.Equals(experiment.Baseline, ":project", StringComparison.OrdinalIgnoreCase)) + { + comparison.ExperimentBaseline = comparison.ProjectBaseline; + } + else + { + comparison.ExperimentBaseline = new ComparisonByRefEntity + { + Project = projectName, + Experiment = experiment.Name, + Set = experiment.BaselineSet ?? experiment.FirstSet, + Results = experiment.AggregateSetByRef(experiment.BaselineSet ?? 
experiment.FirstSet, experimentFiltered), + }; + } + + // get the set experiment + comparison.ExperimentSet = new ComparisonByRefEntity + { + Project = projectName, + Experiment = experiment.Name, + Set = setName, + Results = experiment.AggregateSetByRef(setName, experimentFiltered), + }; + + return comparison; + } + + /// + /// Gets per-result details for a named set in an experiment, with optional support doc URI formatting. + /// + /// The project name. + /// The experiment name. + /// The set name. + /// Comma-separated include tag names. + /// Comma-separated exclude tag names. + /// Token to cancel the operation. + /// The individual results for the named set. + public async Task> GetNamedSetAsync( + string projectName, + string experimentName, + string setName, + string includeTagsStr = "", + string excludeTagsStr = "", + CancellationToken cancellationToken = default) + { + // init + var metricDefinitions = (await storageService.GetMetricsAsync(projectName, cancellationToken)) + .ToDictionary(x => x.Name); + + // get the experiment and filter the results + var experiment = await storageService.GetExperimentAsync(projectName, experimentName, cancellationToken: cancellationToken); + var (includeTags, excludeTags) = await LoadTagsAsync(projectName, includeTagsStr, excludeTagsStr, cancellationToken); + var experimentFiltered = experiment.Filter(includeTags, excludeTags); + experiment.MetricDefinitions = metricDefinitions; + + // get the results + var results = experiment.AggregateSetByEachResult(setName, experimentFiltered) + ?? 
Enumerable.Empty(); + + // add the support docs + var config = await configFactory.GetAsync(cancellationToken); + if (!string.IsNullOrEmpty(config.PATH_TEMPLATE)) + { + foreach (var result in results) + { + if (!string.IsNullOrEmpty(result.InferenceUri)) result.InferenceUri = string.Format(config.PATH_TEMPLATE, result.InferenceUri); + if (!string.IsNullOrEmpty(result.EvaluationUri)) result.EvaluationUri = string.Format(config.PATH_TEMPLATE, result.EvaluationUri); + } + } + + return results; + } +} diff --git a/evaluator/Ext.cs b/evaluator/Ext.cs index e3b274f..02a3c6b 100644 --- a/evaluator/Ext.cs +++ b/evaluator/Ext.cs @@ -139,7 +139,7 @@ public static int[] AsIntArray(this string? value, Func dflt) { if (string.IsNullOrEmpty(value)) return dflt(); var total = new List(); - foreach (var raw in value.AsArray(() => [])) + foreach (var raw in value.AsArray(() => [])!) { if (int.TryParse(raw, out var valid)) total.Add(valid); } diff --git a/evaluator/Program.cs b/evaluator/Program.cs index 3d23b3d..1b7ad61 100644 --- a/evaluator/Program.cs +++ b/evaluator/Program.cs @@ -1,119 +1,83 @@ using System; -using System.Threading; using System.Threading.Tasks; using dotenv.net; using Evaluator; using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using NetBricks; // Load environment variables from .env file var ENV_FILES = System.Environment.GetEnvironmentVariable("ENV_FILES").AsArray(() => [".env"]); -Console.WriteLine($"ENV_FILES = {string.Join(", ", ENV_FILES)}"); +Console.WriteLine($"ENV_FILES = {string.Join(", ", ENV_FILES!)}"); DotEnv.Load(new DotEnvOptions(envFilePaths: ENV_FILES, overwriteExistingVars: false)); // create the web application 
var builder = WebApplication.CreateBuilder(args); -// add config -var netConfig = new NetBricks.Config(); -await netConfig.Apply(); -var config = new Evaluator.Config(netConfig); -config.Validate(); -builder.Services.AddSingleton(config); -builder.Services.AddSingleton(netConfig); - -// add credentials if connection string is not provided -if (string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING)) -{ - builder.Services.AddDefaultAzureCredential(); -} +// add config using NetBricks +builder.Services.AddHttpClient(); +builder.Services.AddDefaultAzureCredential(); +builder.Services.AddConfig(); // add logging builder.Logging.ClearProviders(); builder.Services.AddSingleLineConsoleLogger(); -if (!string.IsNullOrEmpty(config.OPEN_TELEMETRY_CONNECTION_STRING)) +builder.Logging.AddFilter("Microsoft.AspNetCore.Mvc.ModelBinding", LogLevel.Warning); +builder.Logging.AddFilter("Microsoft.AspNetCore.Server.Kestrel.Connections", LogLevel.Warning); +builder.Logging.AddFilter("Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets", LogLevel.Warning); + +// configure OpenTelemetry logging early using IConfiguration (before full config is available) +// NOTE: It is unfortunate, but there appears to be no way to add OpenTelemetry in an async +// manner such that config could pull from App Config or Key Vault at startup. 
+var openTelemetryConnectionString = builder.Configuration["OPEN_TELEMETRY_CONNECTION_STRING"]; +if (!string.IsNullOrEmpty(openTelemetryConnectionString)) { - builder.Logging.AddOpenTelemetry(config.OPEN_TELEMETRY_CONNECTION_STRING); - builder.Services.AddOpenTelemetry(DiagnosticService.Source.Name, builder.Environment.ApplicationName, config.OPEN_TELEMETRY_CONNECTION_STRING); + builder.Logging.AddOpenTelemetry(openTelemetryConnectionString); + builder.Services.AddOpenTelemetry("evaluator", builder.Environment.ApplicationName, openTelemetryConnectionString); } -// add http client -builder.Services.AddHttpClient(); - // add API services -if (config.ROLES.Contains(Roles.API)) +builder.Services.AddHostedService(); +builder.Services.AddControllers().AddNewtonsoftJson(); +builder.Services.AddEndpointsApiExplorer(); +builder.Services.AddSwaggerGen().AddSwaggerGenNewtonsoftSupport(); +builder.Services.AddCors(options => { - Console.WriteLine("ADDING SERVICE: AzureStorageQueueWriter"); - builder.Services.AddHostedService(); - builder.Services.AddControllers().AddNewtonsoftJson(); - builder.Services.AddEndpointsApiExplorer(); - builder.Services.AddSwaggerGen().AddSwaggerGenNewtonsoftSupport(); - builder.Services.AddCors(options => - { - options.AddPolicy("default-policy", - builder => - { - builder.WithOrigins("http://localhost:6020") - .AllowAnyHeader() - .AllowAnyMethod(); - }); - }); - builder.WebHost.UseKestrel(options => + options.AddPolicy("default-policy", + builder => { - options.ListenAnyIP(config.PORT); + builder.WithOrigins("http://localhost:6020") + .AllowAnyHeader() + .AllowAnyMethod(); }); -} -else -{ - // NOTE: This does not expose a port - builder.WebHost.UseTestServer(); -} +}); -// add InferenceProxy services -if (config.ROLES.Contains(Roles.InferenceProxy)) -{ - Console.WriteLine("ADDING SERVICE: AzureStorageQueueReaderForInference"); - builder.Services.AddHostedService(); -} - -// add EvaluationProxy services -if 
(config.ROLES.Contains(Roles.EvaluationProxy)) -{ - Console.WriteLine("ADDING SERVICE: AzureStorageQueueReaderForEvaluation"); - builder.Services.AddHostedService(); -} +// configure Kestrel using IConfigFactory +builder.Services.AddSingleton, KestrelConfigurator>(); -// add maintenance service -if (config.MINUTES_BETWEEN_RESTORE_AFTER_BUSY > 0) -{ - Console.WriteLine("ADDING SERVICE: Maintenance"); - builder.Services.AddHostedService(); -} +// add InferenceProxy, EvaluationProxy, and Maintenance services +builder.Services.AddHostedService(); +builder.Services.AddHostedService(); +builder.Services.AddHostedService(); // build var app = builder.Build(); -// add API endpoints and routing -if (config.ROLES.Contains(Roles.API)) -{ - // use swagger - app.UseSwagger(); - app.UseSwaggerUI(); +// use swagger (API only) +app.UseSwagger(); +app.UseSwaggerUI(); - // use CORS - app.UseCors("default-policy"); +// use CORS (API only) +app.UseCors("default-policy"); - // add endpoints - app.UseRouting(); - app.UseMiddleware(); - app.MapControllers(); -} +// add endpoints (API only) +app.UseRouting(); +app.UseMiddleware(); +app.MapControllers(); // run await app.RunAsync(); diff --git a/evaluator/config/Config.cs b/evaluator/config/Config.cs index 1d92749..e373a08 100644 --- a/evaluator/config/Config.cs +++ b/evaluator/config/Config.cs @@ -1,246 +1,339 @@ using System; using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; using System.IO; using System.Linq; using NetBricks; namespace Evaluator; -public class Config : IConfig +[LogConfig("Configuration:")] +public class Config : IConfig, IValidatableObject { - private readonly NetBricks.IConfig config; + private readonly List invalidRoles = []; - public Config(NetBricks.IConfig config) - { - this.config = config; - this.PORT = this.config.Get("PORT").AsInt(() => 6030); + [SetValue("PORT")] + public int PORT { get; set; } = 6030; - this.ROLES = this.config.Get("ROLES", value => - { - List roles = []; - var 
list = value.AsArray(() => throw new Exception("ROLES must be an array of strings")); - foreach (var entry in list) - { - var role = entry.AsEnum(() => throw new Exception("each ROLE must be one of API, InferenceProxy, or EvaluationProxy.")); - roles.Add(role); - } - return roles; - }); - - this.OPEN_TELEMETRY_CONNECTION_STRING = this.config.GetSecret("OPEN_TELEMETRY_CONNECTION_STRING").Result; - this.AZURE_STORAGE_ACCOUNT_NAME = this.config.Get("AZURE_STORAGE_ACCOUNT_NAME"); - this.AZURE_STORAGE_CONNECTION_STRING = this.config.GetSecret("AZURE_STORAGE_CONNECTION_STRING").Result; - this.INFERENCE_CONTAINER = this.config.Get("INFERENCE_CONTAINER"); - this.EVALUATION_CONTAINER = this.config.Get("EVALUATION_CONTAINER"); - this.INBOUND_INFERENCE_QUEUES = this.config.Get("INBOUND_INFERENCE_QUEUES").AsArray(() => []); - this.INBOUND_EVALUATION_QUEUES = this.config.Get("INBOUND_EVALUATION_QUEUES").AsArray(() => []); - this.OUTBOUND_INFERENCE_QUEUE = this.config.Get("OUTBOUND_INFERENCE_QUEUE"); - this.INFERENCE_CONCURRENCY = this.config.Get("INFERENCE_CONCURRENCY, CONCURRENCY").AsInt(() => 1); - this.EVALUATION_CONCURRENCY = this.config.Get("EVALUATION_CONCURRENCY, CONCURRENCY").AsInt(() => 1); - this.MS_TO_PAUSE_WHEN_EMPTY = this.config.Get("MS_TO_PAUSE_WHEN_EMPTY").AsInt(() => 500); - this.DEQUEUE_FOR_X_SECONDS = this.config.Get("DEQUEUE_FOR_X_SECONDS").AsInt(() => 300); - this.MS_BETWEEN_DEQUEUE = this.config.Get("MS_BETWEEN_DEQUEUE").AsInt(() => 0); - this.MS_BETWEEN_DEQUEUE_CURRENT = this.MS_BETWEEN_DEQUEUE; - this.MAX_ATTEMPTS_TO_DEQUEUE = this.config.Get("MAX_ATTEMPTS_TO_DEQUEUE").AsInt(() => 5); - this.MS_TO_ADD_ON_BUSY = this.config.Get("MS_TO_ADD_ON_BUSY").AsInt(() => 0); - this.MINUTES_BETWEEN_RESTORE_AFTER_BUSY = this.config.Get("MINUTES_BETWEEN_RESTORE_AFTER_BUSY").AsInt(() => 0); - this.INFERENCE_URL = this.config.Get("INFERENCE_URL"); - this.EVALUATION_URL = this.config.Get("EVALUATION_URL"); - this.SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING = 
this.config.Get("SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING").AsInt(() => 300); - this.BACKOFF_ON_STATUS_CODES = this.config.Get("BACKOFF_ON_STATUS_CODES").AsIntArray(() => [429]); - this.DEADLETTER_ON_STATUS_CODES = this.config.Get("DEADLETTER_ON_STATUS_CODES").AsIntArray(() => [400, 401, 403, 404, 405]); - this.EXPERIMENT_CATALOG_BASE_URL = this.config.Get("EXPERIMENT_CATALOG_BASE_URL"); - - this.INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE = this.config.Get("INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE, INBOUND_GROUNDTRUTH_TRANSFORM_FILE"); - this.INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY = config.Get("INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY, INBOUND_GROUNDTRUTH_TRANSFORM_QUERY").AsString(() => - { - return string.IsNullOrEmpty(this.INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE) - ? string.Empty - : File.ReadAllText(this.INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE); - }); + [SetValue("ROLES")] + [LogConfig(mode: LogConfigMode.Never)] + public string[]? ROLES_RAW { get; set; } - this.INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE = this.config.Get("INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE, INBOUND_GROUNDTRUTH_TRANSFORM_FILE"); - this.INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY = config.Get("INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY, INBOUND_GROUNDTRUTH_TRANSFORM_QUERY").AsString(() => - { - return string.IsNullOrEmpty(this.INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE) - ? string.Empty - : File.ReadAllText(this.INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE); - }); + public List ROLES { get; set; } = []; - this.INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE = this.config.Get("INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE, INBOUND_GROUNDTRUTH_TRANSFORM_FILE"); - this.INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY = config.Get("INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY, INBOUND_GROUNDTRUTH_TRANSFORM_QUERY").AsString(() => - { - return string.IsNullOrEmpty(this.INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE) - ? 
string.Empty - : File.ReadAllText(this.INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE); - }); + [SetValue("OPEN_TELEMETRY_CONNECTION_STRING")] + [ResolveSecret] + [LogConfig(mode: LogConfigMode.Masked)] + public string? OPEN_TELEMETRY_CONNECTION_STRING { get; set; } - this.INBOUND_INFERENCE_TRANSFORM_FILE = this.config.Get("INBOUND_INFERENCE_TRANSFORM_FILE"); - this.INBOUND_INFERENCE_TRANSFORM_QUERY = config.Get("INBOUND_INFERENCE_TRANSFORM_QUERY").AsString(() => - { - return string.IsNullOrEmpty(this.INBOUND_INFERENCE_TRANSFORM_FILE) - ? string.Empty - : File.ReadAllText(this.INBOUND_INFERENCE_TRANSFORM_FILE); - }); + [SetValue("AZURE_STORAGE_ACCOUNT_NAME")] + public string? AZURE_STORAGE_ACCOUNT_NAME { get; set; } - this.INBOUND_EVALUATION_TRANSFORM_FILE = this.config.Get("INBOUND_EVALUATION_TRANSFORM_FILE"); - this.INBOUND_EVALUATION_TRANSFORM_QUERY = config.Get("INBOUND_EVALUATION_TRANSFORM_QUERY").AsString(() => - { - return string.IsNullOrEmpty(this.INBOUND_EVALUATION_TRANSFORM_FILE) - ? string.Empty - : File.ReadAllText(this.INBOUND_EVALUATION_TRANSFORM_FILE); - }); + [SetValue("AZURE_STORAGE_CONNECTION_STRING")] + [ResolveSecret] + [LogConfig(mode: LogConfigMode.Masked)] + public string? AZURE_STORAGE_CONNECTION_STRING { get; set; } - this.PROCESS_METRICS_IN_INFERENCE_RESPONSE = this.config.Get("PROCESS_METRICS_IN_INFERENCE_RESPONSE").AsBool(() => false); - this.PROCESS_METRICS_IN_EVALUATION_RESPONSE = this.config.Get("PROCESS_METRICS_IN_EVALUATION_RESPONSE").AsBool(() => true); - } + [SetValue("INFERENCE_CONTAINER")] + public string? INFERENCE_CONTAINER { get; set; } - public int PORT { get; } + [SetValue("EVALUATION_CONTAINER")] + public string? 
EVALUATION_CONTAINER { get; set; } - public List ROLES { get; } + [SetValue("INBOUND_INFERENCE_QUEUES")] + public string[] INBOUND_INFERENCE_QUEUES { get; set; } = []; - public string OPEN_TELEMETRY_CONNECTION_STRING { get; } + [SetValue("INBOUND_EVALUATION_QUEUES")] + public string[] INBOUND_EVALUATION_QUEUES { get; set; } = []; - public string AZURE_STORAGE_ACCOUNT_NAME { get; } + [SetValue("OUTBOUND_INFERENCE_QUEUE")] + public string? OUTBOUND_INFERENCE_QUEUE { get; set; } - public string AZURE_STORAGE_CONNECTION_STRING { get; } + [SetValue("INFERENCE_CONCURRENCY", "CONCURRENCY")] + [Range(1, 100)] + public int INFERENCE_CONCURRENCY { get; set; } = 1; - public string INFERENCE_CONTAINER { get; } + [SetValue("EVALUATION_CONCURRENCY", "CONCURRENCY")] + [Range(1, 100)] + public int EVALUATION_CONCURRENCY { get; set; } = 1; - public string EVALUATION_CONTAINER { get; } + [SetValue("MS_TO_PAUSE_WHEN_EMPTY")] + public int MS_TO_PAUSE_WHEN_EMPTY { get; set; } = 500; - public string[] INBOUND_INFERENCE_QUEUES { get; } + [SetValue("DEQUEUE_FOR_X_SECONDS")] + public int DEQUEUE_FOR_X_SECONDS { get; set; } = 300; - public string[] INBOUND_EVALUATION_QUEUES { get; } + [SetValue("MS_BETWEEN_DEQUEUE")] + public int MS_BETWEEN_DEQUEUE { get; set; } = 0; - public string OUTBOUND_INFERENCE_QUEUE { get; } + public int MS_BETWEEN_DEQUEUE_CURRENT { get; set; } - public int INFERENCE_CONCURRENCY { get; } + [SetValue("MAX_ATTEMPTS_TO_DEQUEUE")] + public int MAX_ATTEMPTS_TO_DEQUEUE { get; set; } = 5; - public int EVALUATION_CONCURRENCY { get; } + [SetValue("MS_TO_ADD_ON_BUSY")] + public int MS_TO_ADD_ON_BUSY { get; set; } = 0; - public int MS_TO_PAUSE_WHEN_EMPTY { get; } + [SetValue("MINUTES_BETWEEN_RESTORE_AFTER_BUSY")] + public int MINUTES_BETWEEN_RESTORE_AFTER_BUSY { get; set; } = 0; - public int DEQUEUE_FOR_X_SECONDS { get; } + [SetValue("INFERENCE_URL")] + public string? 
INFERENCE_URL { get; set; } - public int MS_BETWEEN_DEQUEUE { get; } + [SetValue("EVALUATION_URL")] + public string? EVALUATION_URL { get; set; } - public int MS_BETWEEN_DEQUEUE_CURRENT { get; set; } + [SetValue("SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING")] + public int SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING { get; set; } = 300; - public int MAX_ATTEMPTS_TO_DEQUEUE { get; } + [SetValue("BACKOFF_ON_STATUS_CODES")] + [LogConfig(mode: LogConfigMode.Never)] + public string[]? BACKOFF_ON_STATUS_CODES_RAW { get; set; } - public int MS_TO_ADD_ON_BUSY { get; } + public int[] BACKOFF_ON_STATUS_CODES { get; set; } = [429]; - public int MINUTES_BETWEEN_RESTORE_AFTER_BUSY { get; } + [SetValue("DEADLETTER_ON_STATUS_CODES")] + [LogConfig(mode: LogConfigMode.Never)] + public string[]? DEADLETTER_ON_STATUS_CODES_RAW { get; set; } - public string INFERENCE_URL { get; } + public int[] DEADLETTER_ON_STATUS_CODES { get; set; } = [400, 401, 403, 404, 405]; - public string EVALUATION_URL { get; } + [SetValue("EXPERIMENT_CATALOG_BASE_URL")] + public string? EXPERIMENT_CATALOG_BASE_URL { get; set; } - public int SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING { get; } + [SetValue("INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE", "INBOUND_GROUNDTRUTH_TRANSFORM_FILE")] + public string? INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE { get; set; } - public int[] BACKOFF_ON_STATUS_CODES { get; } + [SetValue("INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY", "INBOUND_GROUNDTRUTH_TRANSFORM_QUERY")] + public string? INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY { get; set; } - public int[] DEADLETTER_ON_STATUS_CODES { get; } + [SetValue("INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE", "INBOUND_GROUNDTRUTH_TRANSFORM_FILE")] + public string? INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE { get; set; } - public string EXPERIMENT_CATALOG_BASE_URL { get; } + [SetValue("INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY", "INBOUND_GROUNDTRUTH_TRANSFORM_QUERY")] + public string? 
INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY { get; set; } - public string INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE { get; } + [SetValue("INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE", "INBOUND_GROUNDTRUTH_TRANSFORM_FILE")] + public string? INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE { get; set; } - public string INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY { get; } + [SetValue("INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY", "INBOUND_GROUNDTRUTH_TRANSFORM_QUERY")] + public string? INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY { get; set; } - public string INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE { get; } + [SetValue("INBOUND_INFERENCE_TRANSFORM_FILE")] + public string? INBOUND_INFERENCE_TRANSFORM_FILE { get; set; } - public string INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY { get; } + [SetValue("INBOUND_INFERENCE_TRANSFORM_QUERY")] + public string? INBOUND_INFERENCE_TRANSFORM_QUERY { get; set; } - public string INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE { get; } + [SetValue("INBOUND_EVALUATION_TRANSFORM_FILE")] + public string? INBOUND_EVALUATION_TRANSFORM_FILE { get; set; } - public string INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY { get; } + [SetValue("INBOUND_EVALUATION_TRANSFORM_QUERY")] + public string? 
INBOUND_EVALUATION_TRANSFORM_QUERY { get; set; } - public string INBOUND_INFERENCE_TRANSFORM_FILE { get; } + [SetValue("PROCESS_METRICS_IN_INFERENCE_RESPONSE")] + public bool PROCESS_METRICS_IN_INFERENCE_RESPONSE { get; set; } = false; - public string INBOUND_INFERENCE_TRANSFORM_QUERY { get; } + [SetValue("PROCESS_METRICS_IN_EVALUATION_RESPONSE")] + public bool PROCESS_METRICS_IN_EVALUATION_RESPONSE { get; set; } = true; + + [SetValues] + public void ApplyDerivedValues() + { + ROLES = ParseRoles(ROLES_RAW); + BACKOFF_ON_STATUS_CODES = ParseIntArray(BACKOFF_ON_STATUS_CODES_RAW, [429]); + DEADLETTER_ON_STATUS_CODES = ParseIntArray(DEADLETTER_ON_STATUS_CODES_RAW, [400, 401, 403, 404, 405]); + + INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY = ResolveTransformQuery( + INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY, + INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE); + INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY = ResolveTransformQuery( + INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY, + INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE); + INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY = ResolveTransformQuery( + INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY, + INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE); + INBOUND_INFERENCE_TRANSFORM_QUERY = ResolveTransformQuery( + INBOUND_INFERENCE_TRANSFORM_QUERY, + INBOUND_INFERENCE_TRANSFORM_FILE); + INBOUND_EVALUATION_TRANSFORM_QUERY = ResolveTransformQuery( + INBOUND_EVALUATION_TRANSFORM_QUERY, + INBOUND_EVALUATION_TRANSFORM_FILE); + + MS_BETWEEN_DEQUEUE_CURRENT = MS_BETWEEN_DEQUEUE; + } - public string INBOUND_EVALUATION_TRANSFORM_FILE { get; } + public IEnumerable Validate(ValidationContext validationContext) + { + if (ROLES.Count == 0) + { + yield return new ValidationResult("ROLES must include at least one role.", new[] { nameof(ROLES) }); + } - public string INBOUND_EVALUATION_TRANSFORM_QUERY { get; } + if (invalidRoles.Count > 0) + { + yield return new ValidationResult( + $"ROLES contains invalid values: 
{string.Join(", ", invalidRoles)}.", + new[] { nameof(ROLES) }); + } - public bool PROCESS_METRICS_IN_INFERENCE_RESPONSE { get; } + if (string.IsNullOrEmpty(AZURE_STORAGE_ACCOUNT_NAME) && string.IsNullOrEmpty(AZURE_STORAGE_CONNECTION_STRING)) + { + yield return new ValidationResult( + "Either AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must be set.", + new[] { nameof(AZURE_STORAGE_ACCOUNT_NAME), nameof(AZURE_STORAGE_CONNECTION_STRING) }); + } - public bool PROCESS_METRICS_IN_EVALUATION_RESPONSE { get; } + var hasInference = ROLES.Contains(Roles.InferenceProxy); + var hasEvaluation = ROLES.Contains(Roles.EvaluationProxy); - public void Validate() - { - // applies regardless of role - this.config.Require("PORT", this.PORT.ToString()); - this.config.Require("ROLES", this.ROLES.Select(r => r.ToString()).ToArray()); - this.config.Optional("OPEN_TELEMETRY_CONNECTION_STRING", OPEN_TELEMETRY_CONNECTION_STRING, hideValue: true); - - this.config.Optional("AZURE_STORAGE_ACCOUNT_NAME", this.AZURE_STORAGE_ACCOUNT_NAME); - this.config.Optional("AZURE_STORAGE_CONNECTION_STRING", this.AZURE_STORAGE_CONNECTION_STRING, hideValue: true); - if (string.IsNullOrEmpty(this.AZURE_STORAGE_ACCOUNT_NAME) && string.IsNullOrEmpty(this.AZURE_STORAGE_CONNECTION_STRING)) + if (hasInference) { - throw new Exception("Either AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must be specified."); + if (string.IsNullOrEmpty(INFERENCE_CONTAINER)) + { + yield return new ValidationResult( + "INFERENCE_CONTAINER must be set when using the InferenceProxy role.", + new[] { nameof(INFERENCE_CONTAINER) }); + } + + if (string.IsNullOrEmpty(INFERENCE_URL)) + { + yield return new ValidationResult( + "INFERENCE_URL must be set when using the InferenceProxy role.", + new[] { nameof(INFERENCE_URL) }); + } + + if (INBOUND_INFERENCE_QUEUES.Length == 0) + { + yield return new ValidationResult( + "INBOUND_INFERENCE_QUEUES must be set when using the InferenceProxy role.", + new[] { 
nameof(INBOUND_INFERENCE_QUEUES) }); + } + + if (string.IsNullOrEmpty(OUTBOUND_INFERENCE_QUEUE)) + { + yield return new ValidationResult( + "OUTBOUND_INFERENCE_QUEUE must be set when using the InferenceProxy role.", + new[] { nameof(OUTBOUND_INFERENCE_QUEUE) }); + } } - // API-specific - if (this.ROLES.Contains(Roles.API)) + if (hasEvaluation) { - this.config.Optional("INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE", this.INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE); - this.config.Optional("INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY", this.INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY, hideValue: true); + if (string.IsNullOrEmpty(INFERENCE_CONTAINER)) + { + yield return new ValidationResult( + "INFERENCE_CONTAINER must be set when using the EvaluationProxy role.", + new[] { nameof(INFERENCE_CONTAINER) }); + } + + if (string.IsNullOrEmpty(EVALUATION_CONTAINER)) + { + yield return new ValidationResult( + "EVALUATION_CONTAINER must be set when using the EvaluationProxy role.", + new[] { nameof(EVALUATION_CONTAINER) }); + } + + if (string.IsNullOrEmpty(EVALUATION_URL)) + { + yield return new ValidationResult( + "EVALUATION_URL must be set when using the EvaluationProxy role.", + new[] { nameof(EVALUATION_URL) }); + } + + if (INBOUND_EVALUATION_QUEUES.Length == 0) + { + yield return new ValidationResult( + "INBOUND_EVALUATION_QUEUES must be set when using the EvaluationProxy role.", + new[] { nameof(INBOUND_EVALUATION_QUEUES) }); + } } - // InferenceProxy-specific - if (this.ROLES.Contains(Roles.InferenceProxy)) + if (hasInference || hasEvaluation) { - this.config.Require("INFERENCE_CONCURRENCY", this.INFERENCE_CONCURRENCY); - this.config.Require("INFERENCE_CONTAINER", this.INFERENCE_CONTAINER); - this.config.Require("INFERENCE_URL", this.INFERENCE_URL); - this.config.Require("INBOUND_INFERENCE_QUEUES", this.INBOUND_INFERENCE_QUEUES); - if (this.INBOUND_INFERENCE_QUEUES.Length == 0) + if (SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING <= 0) + { + yield return new ValidationResult( + 
"SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING must be greater than 0 for proxy roles.", + new[] { nameof(SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING) }); + } + + if (BACKOFF_ON_STATUS_CODES.Length == 0) + { + yield return new ValidationResult( + "BACKOFF_ON_STATUS_CODES must include at least one status code for proxy roles.", + new[] { nameof(BACKOFF_ON_STATUS_CODES) }); + } + + if (DEADLETTER_ON_STATUS_CODES.Length == 0) { - throw new Exception("When configured for the InferenceProxy role, INBOUND_INFERENCE_QUEUES must be specified."); + yield return new ValidationResult( + "DEADLETTER_ON_STATUS_CODES must include at least one status code for proxy roles.", + new[] { nameof(DEADLETTER_ON_STATUS_CODES) }); } - this.config.Require("OUTBOUND_INFERENCE_QUEUE", this.OUTBOUND_INFERENCE_QUEUE); - this.config.Optional("INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE", this.INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE); - this.config.Optional("INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY", this.INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY, hideValue: true); - this.config.Optional("INBOUND_INFERENCE_TRANSFORM_FILE", this.INBOUND_INFERENCE_TRANSFORM_FILE); - this.config.Optional("INBOUND_INFERENCE_TRANSFORM_QUERY", this.INBOUND_INFERENCE_TRANSFORM_QUERY, hideValue: true); - this.config.Optional("PROCESS_METRICS_IN_INFERENCE_RESPONSE", this.PROCESS_METRICS_IN_INFERENCE_RESPONSE.ToString()); + } + } + + private List ParseRoles(string[]? 
raw) + { + invalidRoles.Clear(); + List roles = []; + if (raw is null || raw.Length == 0) + { + return roles; } - // EvaluationProxy-specific - if (this.ROLES.Contains(Roles.EvaluationProxy)) + foreach (var entry in raw) { - this.config.Require("EVALUATION_CONCURRENCY", this.EVALUATION_CONCURRENCY); - this.config.Require("INFERENCE_CONTAINER", this.INFERENCE_CONTAINER); - this.config.Require("EVALUATION_CONTAINER", this.EVALUATION_CONTAINER); - this.config.Require("EVALUATION_URL", this.EVALUATION_URL); - this.config.Require("INBOUND_EVALUATION_QUEUES", this.INBOUND_EVALUATION_QUEUES); - if (this.INBOUND_EVALUATION_QUEUES.Length == 0) + if (Enum.TryParse(entry, true, out Roles role)) { - throw new Exception("When configured for the EvaluationProxy role, INBOUND_EVALUATION_QUEUES must be specified."); + roles.Add(role); + } + else + { + invalidRoles.Add(entry); } - this.config.Optional("INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE", this.INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE); - this.config.Optional("INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY", this.INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY, hideValue: true); - this.config.Optional("INBOUND_EVALUATION_TRANSFORM_FILE", this.INBOUND_EVALUATION_TRANSFORM_FILE); - this.config.Optional("INBOUND_EVALUATION_TRANSFORM_QUERY", this.INBOUND_EVALUATION_TRANSFORM_QUERY, hideValue: true); - this.config.Optional("PROCESS_METRICS_IN_EVALUATION_RESPONSE", this.PROCESS_METRICS_IN_EVALUATION_RESPONSE.ToString()); } - // any proxy - if (this.ROLES.Contains(Roles.InferenceProxy) || this.ROLES.Contains(Roles.EvaluationProxy)) + return roles; + } + + private static int[] ParseIntArray(string[]? raw, int[] defaults) + { + if (raw is null || raw.Length == 0) + { + return defaults; + } + + List values = []; + foreach (var entry in raw) + { + if (int.TryParse(entry, out var parsed)) + { + values.Add(parsed); + } + } + + return values.Count == 0 ? defaults : [.. values]; + } + + private static string? 
ResolveTransformQuery(string? query, string? filePath) + { + if (!string.IsNullOrEmpty(query)) + { + return query; + } + + if (string.IsNullOrEmpty(filePath)) { - this.config.Require("SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING", this.SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING); - this.config.Require("BACKOFF_ON_STATUS_CODES", this.BACKOFF_ON_STATUS_CODES.Select(c => c.ToString()).ToArray()); - this.config.Require("DEADLETTER_ON_STATUS_CODES", this.DEADLETTER_ON_STATUS_CODES.Select(c => c.ToString()).ToArray()); - this.config.Require("MAX_ATTEMPTS_TO_DEQUEUE", this.MAX_ATTEMPTS_TO_DEQUEUE.ToString()); - this.config.Require("MS_TO_PAUSE_WHEN_EMPTY", this.MS_TO_PAUSE_WHEN_EMPTY.ToString()); - this.config.Require("DEQUEUE_FOR_X_SECONDS", this.DEQUEUE_FOR_X_SECONDS.ToString()); - this.config.Require("MS_BETWEEN_DEQUEUE", this.MS_BETWEEN_DEQUEUE.ToString()); - this.config.Require("MS_TO_ADD_ON_BUSY", this.MS_TO_ADD_ON_BUSY.ToString()); - this.config.Require("MINUTES_BETWEEN_RESTORE_AFTER_BUSY", this.MINUTES_BETWEEN_RESTORE_AFTER_BUSY.ToString()); - this.config.Optional("EXPERIMENT_CATALOG_BASE_URL", this.EXPERIMENT_CATALOG_BASE_URL); + return query; } + + if (!File.Exists(filePath)) + { + throw new FileNotFoundException($"transform file not found: {filePath}", filePath); + } + + return File.ReadAllText(filePath); } } \ No newline at end of file diff --git a/evaluator/config/IConfig.cs b/evaluator/config/IConfig.cs index 74f7ec5..3674b97 100644 --- a/evaluator/config/IConfig.cs +++ b/evaluator/config/IConfig.cs @@ -4,43 +4,41 @@ namespace Evaluator; public interface IConfig { - int PORT { get; } - List ROLES { get; } - string OPEN_TELEMETRY_CONNECTION_STRING { get; } - string AZURE_STORAGE_ACCOUNT_NAME { get; } - string AZURE_STORAGE_CONNECTION_STRING { get; } - string INFERENCE_CONTAINER { get; } - string EVALUATION_CONTAINER { get; } - string[] INBOUND_INFERENCE_QUEUES { get; } - string[] INBOUND_EVALUATION_QUEUES { get; } - string OUTBOUND_INFERENCE_QUEUE { get; } - int 
INFERENCE_CONCURRENCY { get; } - int EVALUATION_CONCURRENCY { get; } - int MS_TO_PAUSE_WHEN_EMPTY { get; } - int DEQUEUE_FOR_X_SECONDS { get; } - int MS_BETWEEN_DEQUEUE { get; } + int PORT { get; set; } + List ROLES { get; set; } + string? OPEN_TELEMETRY_CONNECTION_STRING { get; set; } + string? AZURE_STORAGE_ACCOUNT_NAME { get; set; } + string? AZURE_STORAGE_CONNECTION_STRING { get; set; } + string? INFERENCE_CONTAINER { get; set; } + string? EVALUATION_CONTAINER { get; set; } + string[] INBOUND_INFERENCE_QUEUES { get; set; } + string[] INBOUND_EVALUATION_QUEUES { get; set; } + string? OUTBOUND_INFERENCE_QUEUE { get; set; } + int INFERENCE_CONCURRENCY { get; set; } + int EVALUATION_CONCURRENCY { get; set; } + int MS_TO_PAUSE_WHEN_EMPTY { get; set; } + int DEQUEUE_FOR_X_SECONDS { get; set; } + int MS_BETWEEN_DEQUEUE { get; set; } int MS_BETWEEN_DEQUEUE_CURRENT { get; set; } - int MAX_ATTEMPTS_TO_DEQUEUE { get; } - int MS_TO_ADD_ON_BUSY { get; } - int MINUTES_BETWEEN_RESTORE_AFTER_BUSY { get; } - string INFERENCE_URL { get; } - string EVALUATION_URL { get; } - int SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING { get; } - int[] BACKOFF_ON_STATUS_CODES { get; } - int[] DEADLETTER_ON_STATUS_CODES { get; } - string EXPERIMENT_CATALOG_BASE_URL { get; } - string INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE { get; } - string INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY { get; } - string INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE { get; } - string INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY { get; } - string INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE { get; } - string INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY { get; } - string INBOUND_INFERENCE_TRANSFORM_FILE { get; } - string INBOUND_INFERENCE_TRANSFORM_QUERY { get; } - string INBOUND_EVALUATION_TRANSFORM_FILE { get; } - string INBOUND_EVALUATION_TRANSFORM_QUERY { get; } - bool PROCESS_METRICS_IN_INFERENCE_RESPONSE { get; } - bool PROCESS_METRICS_IN_EVALUATION_RESPONSE { get; } - - void Validate(); + int 
MAX_ATTEMPTS_TO_DEQUEUE { get; set; } + int MS_TO_ADD_ON_BUSY { get; set; } + int MINUTES_BETWEEN_RESTORE_AFTER_BUSY { get; set; } + string? INFERENCE_URL { get; set; } + string? EVALUATION_URL { get; set; } + int SECONDS_BEFORE_TIMEOUT_FOR_PROCESSING { get; set; } + int[] BACKOFF_ON_STATUS_CODES { get; set; } + int[] DEADLETTER_ON_STATUS_CODES { get; set; } + string? EXPERIMENT_CATALOG_BASE_URL { get; set; } + string? INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_FILE { get; set; } + string? INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY { get; set; } + string? INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_FILE { get; set; } + string? INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY { get; set; } + string? INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_FILE { get; set; } + string? INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY { get; set; } + string? INBOUND_INFERENCE_TRANSFORM_FILE { get; set; } + string? INBOUND_INFERENCE_TRANSFORM_QUERY { get; set; } + string? INBOUND_EVALUATION_TRANSFORM_FILE { get; set; } + string? 
INBOUND_EVALUATION_TRANSFORM_QUERY { get; set; } + bool PROCESS_METRICS_IN_INFERENCE_RESPONSE { get; set; } + bool PROCESS_METRICS_IN_EVALUATION_RESPONSE { get; set; } } \ No newline at end of file diff --git a/evaluator/config/KestrelConfigurator.cs b/evaluator/config/KestrelConfigurator.cs new file mode 100644 index 0000000..fe2bbff --- /dev/null +++ b/evaluator/config/KestrelConfigurator.cs @@ -0,0 +1,23 @@ +using System.Net; +using System.Threading; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.Extensions.Options; +using NetBricks; + +namespace Evaluator; + +public class KestrelConfigurator(IConfigFactory configFactory) : IConfigureOptions +{ + public void Configure(KestrelServerOptions options) + { + var config = configFactory.GetAsync(CancellationToken.None).ConfigureAwait(false).GetAwaiter().GetResult(); + if (config.ROLES.Contains(Roles.API)) + { + options.ListenAnyIP(config.PORT); + } + else + { + options.Listen(IPAddress.None, 0); + } + } +} diff --git a/evaluator/evaluator.csproj b/evaluator/evaluator.csproj index 3258c07..7524a51 100644 --- a/evaluator/evaluator.csproj +++ b/evaluator/evaluator.csproj @@ -7,26 +7,26 @@ - - + + - - - - - - - - - - - - + + + + + + + + + + + + diff --git a/evaluator/evaluator.http b/evaluator/evaluator.http index 46701f4..e8d9177 100644 --- a/evaluator/evaluator.http +++ b/evaluator/evaluator.http @@ -1,5 +1,5 @@ ### enqueue an evaluation job -POST http://localhost:7000/api/evaluations HTTP/1.1 +POST http://localhost:6030/api/evaluations HTTP/1.1 Content-Type: application/json { @@ -36,5 +36,5 @@ Content-Type: application/json } ### get status of an evaluation job -GET http://localhost:7000/api/evaluations/status HTTP/1.1 +GET http://localhost:6030/api/evaluations/status HTTP/1.1 Accept: application/json diff --git a/evaluator/services/AzureStorageQueueReaderBase.cs b/evaluator/services/AzureStorageQueueReaderBase.cs index 12f2164..a9e74ce 100644 --- a/evaluator/services/AzureStorageQueueReaderBase.cs 
+++ b/evaluator/services/AzureStorageQueueReaderBase.cs @@ -15,34 +15,32 @@ using Azure.Storage.Queues.Models; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using NetBricks; using Newtonsoft.Json; namespace Evaluator; -public abstract class AzureStorageQueueReaderBase(IConfig config, +public abstract class AzureStorageQueueReaderBase( + IConfigFactory configFactory, IHttpClientFactory httpClientFactory, - ILogger logger, - DefaultAzureCredential? defaultAzureCredential = null) + DefaultAzureCredential defaultAzureCredential, + ILogger logger) : BackgroundService { - protected readonly IConfig config = config; - protected readonly DefaultAzureCredential? defaultAzureCredential = defaultAzureCredential; - protected readonly IHttpClientFactory httpClientFactory = httpClientFactory; - protected readonly ILogger logger = logger; - - protected BlobClient GetBlobClient(string containerName, string blobName) + protected async Task GetBlobClientAsync(string containerName, string blobName, CancellationToken cancellationToken) { - var blobUrl = $"https://{this.config.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"; - var blobClient = string.IsNullOrEmpty(this.config.AZURE_STORAGE_CONNECTION_STRING) - ? new BlobServiceClient(new Uri(blobUrl), this.defaultAzureCredential) - : new BlobServiceClient(this.config.AZURE_STORAGE_CONNECTION_STRING); + var config = await configFactory.GetAsync(cancellationToken); + var blobUrl = $"https://{config.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"; + var blobClient = string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING) + ? 
new BlobServiceClient(new Uri(blobUrl), defaultAzureCredential) + : new BlobServiceClient(config.AZURE_STORAGE_CONNECTION_STRING); var containerClient = blobClient.GetBlobContainerClient(containerName); return containerClient.GetBlobClient(blobName); } public async Task GetQueueMessageCountAsync(QueueClient queueClient) { - this.logger.LogDebug("getting message count for queue {q}...", queueClient.Name); + logger.LogDebug("getting message count for queue {q}...", queueClient.Name); var properties = await queueClient.GetPropertiesAsync(); var x = properties.Value.ApproximateMessagesCount; return x; @@ -59,7 +57,7 @@ public async Task> GetAllQueueMessageCountsAsync(List> GetAllQueueMessageCountsAsync(List UploadBlobAsync(string containerName, string blobName, string content, CancellationToken cancellationToken) { - this.logger.LogDebug("attempting to upload {c}/{b}...", containerName, blobName); - var blobClient = this.GetBlobClient(containerName, blobName); + logger.LogDebug("attempting to upload {c}/{b}...", containerName, blobName); + var blobClient = await this.GetBlobClientAsync(containerName, blobName, cancellationToken); using var stream = new MemoryStream(Encoding.UTF8.GetBytes(content)); await blobClient.UploadAsync(stream, overwrite: true, cancellationToken); - this.logger.LogInformation("successfully uploaded {c}/{b}.", containerName, blobName); + logger.LogInformation("successfully uploaded {c}/{b}.", containerName, blobName); return blobClient.Uri.ToString(); } - protected string GetBlobUri(string containerName, string blobName) + protected async Task GetBlobUriAsync(string containerName, string blobName, CancellationToken cancellationToken) { - var blobClient = this.GetBlobClient(containerName, blobName); + var blobClient = await this.GetBlobClientAsync(containerName, blobName, cancellationToken); return blobClient.Uri.ToString(); } @@ -92,14 +90,16 @@ protected async Task RecordMetricsAsync( { return; } - if 
(string.IsNullOrEmpty(this.config.EXPERIMENT_CATALOG_BASE_URL)) + + var config = await configFactory.GetAsync(cancellationToken); + if (string.IsNullOrEmpty(config.EXPERIMENT_CATALOG_BASE_URL)) { - this.logger.LogWarning("there is no EXPERIMENT_CATALOG_BASE_URL provided, so no metrics will be logged."); + logger.LogWarning("there is no EXPERIMENT_CATALOG_BASE_URL provided, so no metrics will be logged."); return; } - this.logger.LogDebug("attempting to record {x} metrics...", metrics.Count); - using var httpClient = this.httpClientFactory.CreateClient(); + logger.LogDebug("attempting to record {x} metrics...", metrics.Count); + using var httpClient = httpClientFactory.CreateClient(); var result = new Result { Ref = pipelineRequest.Ref, @@ -111,7 +111,7 @@ protected async Task RecordMetricsAsync( }; var resultJson = JsonConvert.SerializeObject(result); var response = await httpClient.PostAsync( - $"{this.config.EXPERIMENT_CATALOG_BASE_URL}/api/projects/{pipelineRequest.Project}/experiments/{pipelineRequest.Experiment}/results", + $"{config.EXPERIMENT_CATALOG_BASE_URL}/api/projects/{pipelineRequest.Project}/experiments/{pipelineRequest.Experiment}/results", new StringContent(resultJson, Encoding.UTF8, "application/json"), cancellationToken); if (!response.IsSuccessStatusCode) @@ -119,7 +119,7 @@ protected async Task RecordMetricsAsync( var content = await response.Content.ReadAsStringAsync(cancellationToken); throw new Exception($"status code {response.StatusCode} when recording metrics: {content}"); } - this.logger.LogInformation("successfully recorded {x} metrics ({y}).", + logger.LogInformation("successfully recorded {x} metrics ({y}).", metrics.Count, string.Join(", ", metrics.Select(x => x.Key))); } @@ -144,7 +144,7 @@ protected async Task HandleResponseAsync( catch (JsonException) { // responseContent is not valid JSON, skip metrics extraction - this.logger.LogDebug("Response content is not valid JSON, skipping metrics extraction."); + logger.LogDebug("Response 
content is not valid JSON, skipping metrics extraction."); } } @@ -190,7 +190,7 @@ private static void ExtractMetricsFromJson(object? jsonObject, Dictionary SendForProcessingAsync( PipelineRequest pipelineRequest, - string url, + string? url, string content, QueueMessage queueMessage, string queueBody, @@ -198,11 +198,14 @@ private static void ExtractMetricsFromJson(object? jsonObject, Dictionary 0) { - this.config.MS_BETWEEN_DEQUEUE_CURRENT += ms; - this.logger.LogWarning( + config.MS_BETWEEN_DEQUEUE_CURRENT += ms; + logger.LogWarning( "received {code} from id {id}; delaying {ms} ms for a MS_BETWEEN_DEQUEUE of {total} ms.", response.StatusCode, callId, ms, - this.config.MS_BETWEEN_DEQUEUE_CURRENT); + config.MS_BETWEEN_DEQUEUE_CURRENT); } } - if (this.config.DEADLETTER_ON_STATUS_CODES.Contains((int)response.StatusCode)) + if (config.DEADLETTER_ON_STATUS_CODES.Contains((int)response.StatusCode)) { throw new DeadletterException($"status code {response.StatusCode} from id {callId} is considered fatal.", queueMessage, queueBody); } @@ -254,19 +257,21 @@ private static void ExtractMetricsFromJson(object? jsonObject, Dictionary configFactory, IHttpClientFactory httpClientFactory, - ILogger logger, - DefaultAzureCredential? defaultAzureCredential = null) - : AzureStorageQueueReaderBase(config, httpClientFactory, logger, defaultAzureCredential) + DefaultAzureCredential defaultAzureCredential, + ILogger logger) + : AzureStorageQueueReaderBase(configFactory, httpClientFactory, defaultAzureCredential, logger) { + private readonly IConfigFactory configFactory = configFactory; + private readonly DefaultAzureCredential defaultAzureCredential = defaultAzureCredential; + private readonly ILogger logger = logger; private readonly List inboundQueues = []; private readonly List inboundDeadletterQueues = []; - private readonly TaskRunner taskRunner = new(config.EVALUATION_CONCURRENCY); + private TaskRunner? 
taskRunner; private async Task ProcessRequestAsync( QueueClient inboundQueue, @@ -29,9 +34,14 @@ private async Task ProcessRequestAsync( var isConsideredToHaveProcessed = false; try { + // get config + var config = await this.configFactory.GetAsync(cancellationToken); + // check for a message this.logger.LogDebug("checking for a message in queue {q}...", inboundQueue.Name); - var message = await inboundQueue.ReceiveMessageAsync(TimeSpan.FromSeconds(this.config.DEQUEUE_FOR_X_SECONDS), cancellationToken); + var message = await inboundQueue.ReceiveMessageAsync( + TimeSpan.FromSeconds(config.DEQUEUE_FOR_X_SECONDS), + cancellationToken); var body = message?.Value?.Body?.ToString(); if (string.IsNullOrEmpty(body)) { @@ -39,9 +49,12 @@ private async Task ProcessRequestAsync( } // handle deadletter - if (message!.Value.DequeueCount > this.config.MAX_ATTEMPTS_TO_DEQUEUE) + if (message!.Value.DequeueCount > config.MAX_ATTEMPTS_TO_DEQUEUE) { - throw new DeadletterException($"message {message.Value.MessageId} has been dequeued {message.Value.DequeueCount} times", message.Value, body); + throw new DeadletterException( + $"message {message.Value.MessageId} has been dequeued {message.Value.DequeueCount} times", + message.Value, + body); } // deserialize the pipeline request @@ -53,10 +66,18 @@ private async Task ProcessRequestAsync( // it is considered to have processed once it starts doing something related to the actual request isConsideredToHaveProcessed = true; + // ensure required config values are present + var inferenceContainer = config.INFERENCE_CONTAINER + ?? throw new InvalidOperationException("INFERENCE_CONTAINER must be set for evaluation processing."); + var evaluationContainer = config.EVALUATION_CONTAINER + ?? throw new InvalidOperationException("EVALUATION_CONTAINER must be set for evaluation processing."); + var evaluationUrl = config.EVALUATION_URL + ?? 
throw new InvalidOperationException("EVALUATION_URL must be set for evaluation processing."); + // download and transform the inference file first - var inferenceBlobClient = this.GetBlobClient(this.config.INFERENCE_CONTAINER, $"{request.RunId}/{request.Id}.json"); + var inferenceBlobClient = await this.GetBlobClientAsync(inferenceContainer, $"{request.RunId}/{request.Id}.json", cancellationToken); var inferenceContent = await inferenceBlobClient.DownloadAndTransformAsync( - this.config.INBOUND_INFERENCE_TRANSFORM_QUERY, + config.INBOUND_INFERENCE_TRANSFORM_QUERY, this.logger, cancellationToken); @@ -65,7 +86,6 @@ private async Task ProcessRequestAsync( var inferenceJson = JsonConvert.DeserializeObject(inferenceContent); bool hasGroundTruthNode = inferenceJson?.ground_truth != null; bool hasInferenceNode = inferenceJson?.inference != null; - if (hasGroundTruthNode && hasInferenceNode) { payload = inferenceContent; @@ -74,9 +94,9 @@ private async Task ProcessRequestAsync( { // download and transform the ground truth file var groundTruthBlobRef = new BlobRef(request.GroundTruthUri); - var groundTruthBlobClient = this.GetBlobClient(groundTruthBlobRef.Container, groundTruthBlobRef.BlobName); + var groundTruthBlobClient = await this.GetBlobClientAsync(groundTruthBlobRef.Container, groundTruthBlobRef.BlobName, cancellationToken); var groundTruthContent = await groundTruthBlobClient.DownloadAndTransformAsync( - this.config.INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY, + config.INBOUND_GROUNDTRUTH_FOR_EVALUATION_TRANSFORM_QUERY, this.logger, cancellationToken); @@ -91,7 +111,7 @@ private async Task ProcessRequestAsync( // call processing URL var (responseHeaders, responseContent) = await this.SendForProcessingAsync( request, - this.config.EVALUATION_URL, + evaluationUrl, payload, message.Value, body, @@ -99,13 +119,17 @@ private async Task ProcessRequestAsync( cancellationToken); // upload the result - var evaluationUri = await 
this.UploadBlobAsync(this.config.EVALUATION_CONTAINER, $"{request.RunId}/{request.Id}.json", responseContent, cancellationToken); + var evaluationUri = await this.UploadBlobAsync( + evaluationContainer, + $"{request.RunId}/{request.Id}.json", + responseContent, + cancellationToken); // get reference to the inferenceUri - var inferenceUri = this.GetBlobUri(this.config.INFERENCE_CONTAINER, $"{request.RunId}/{request.Id}.json"); + var inferenceUri = await this.GetBlobUriAsync(inferenceContainer, $"{request.RunId}/{request.Id}.json", cancellationToken); // handle the response headers (metrics, etc.) - if (this.config.PROCESS_METRICS_IN_EVALUATION_RESPONSE) + if (config.PROCESS_METRICS_IN_EVALUATION_RESPONSE) { await this.HandleResponseAsync(request, responseContent, inferenceUri, evaluationUri, cancellationToken); } @@ -137,7 +161,8 @@ private async Task GetMessagesFromInboundQueuesAsync(CancellationToken canc { var queue = this.inboundQueues[i]; var deadletter = this.inboundDeadletterQueues[i]; - await this.taskRunner.StartAsync(() => + var runner = this.taskRunner ?? 
throw new InvalidOperationException("task runner not initialized."); + await runner.StartAsync(() => this.ProcessRequestAsync(queue, deadletter, cancellationToken), onSuccess: async isConsideredToHaveProcessed => { @@ -181,19 +206,28 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) public override async Task StartAsync(CancellationToken cancellationToken) { - foreach (var queue in this.config.INBOUND_EVALUATION_QUEUES) + var config = await this.configFactory.GetAsync(cancellationToken); + if (!config.ROLES.Contains(Roles.EvaluationProxy)) + { + this.logger.LogInformation("EvaluationProxy role not configured; skipping AzureStorageQueueReaderForEvaluation."); + return; + } + + this.taskRunner = new TaskRunner(config.EVALUATION_CONCURRENCY); + + foreach (var queue in config.INBOUND_EVALUATION_QUEUES) { - var queueUrl = $"https://{this.config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}"; - var queueClient = string.IsNullOrEmpty(this.config.AZURE_STORAGE_CONNECTION_STRING) + var queueUrl = $"https://{config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}"; + var queueClient = string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING) ? new QueueClient(new Uri(queueUrl), this.defaultAzureCredential) - : new QueueClient(this.config.AZURE_STORAGE_CONNECTION_STRING, queue); + : new QueueClient(config.AZURE_STORAGE_CONNECTION_STRING, queue); await queueClient.ConnectAsync(this.logger, cancellationToken); this.inboundQueues.Add(queueClient); - var deadletterUrl = $"https://{this.config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}-deadletter"; - var deadletterClient = string.IsNullOrEmpty(this.config.AZURE_STORAGE_CONNECTION_STRING) + var deadletterUrl = $"https://{config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}-deadletter"; + var deadletterClient = string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING) ? 
new QueueClient(new Uri(deadletterUrl), this.defaultAzureCredential) - : new QueueClient(this.config.AZURE_STORAGE_CONNECTION_STRING, queue + "-deadletter"); + : new QueueClient(config.AZURE_STORAGE_CONNECTION_STRING, queue + "-deadletter"); await deadletterClient.ConnectAsync(this.logger, cancellationToken); this.inboundDeadletterQueues.Add(deadletterClient); } diff --git a/evaluator/services/AzureStorageQueueReaderForInference.cs b/evaluator/services/AzureStorageQueueReaderForInference.cs index c84827c..ad599c8 100644 --- a/evaluator/services/AzureStorageQueueReaderForInference.cs +++ b/evaluator/services/AzureStorageQueueReaderForInference.cs @@ -7,19 +7,24 @@ using Azure.Identity; using Azure.Storage.Queues; using Microsoft.Extensions.Logging; +using NetBricks; using Newtonsoft.Json; namespace Evaluator; -public class AzureStorageQueueReaderForInference(IConfig config, +public class AzureStorageQueueReaderForInference( + IConfigFactory configFactory, IHttpClientFactory httpClientFactory, - ILogger logger, - DefaultAzureCredential? defaultAzureCredential = null) - : AzureStorageQueueReaderBase(config, httpClientFactory, logger, defaultAzureCredential) + DefaultAzureCredential defaultAzureCredential, + ILogger logger) + : AzureStorageQueueReaderBase(configFactory, httpClientFactory, defaultAzureCredential, logger) { + private readonly IConfigFactory configFactory = configFactory; + private readonly DefaultAzureCredential defaultAzureCredential = defaultAzureCredential; + private readonly ILogger logger = logger; private readonly List inboundQueues = []; private readonly List inboundDeadletterQueues = []; - private readonly TaskRunner taskRunner = new(config.INFERENCE_CONCURRENCY); + private TaskRunner? taskRunner; private QueueClient? 
outboundQueue; private async Task ProcessRequestAsync( @@ -30,9 +35,14 @@ private async Task ProcessRequestAsync( var isConsideredToHaveProcessed = false; try { + // get config + var config = await this.configFactory.GetAsync(cancellationToken); + // check for a message this.logger.LogDebug("checking for a message in queue {q}...", inboundQueue.Name); - var message = await inboundQueue.ReceiveMessageAsync(TimeSpan.FromSeconds(this.config.DEQUEUE_FOR_X_SECONDS), cancellationToken); + var message = await inboundQueue.ReceiveMessageAsync( + TimeSpan.FromSeconds(config.DEQUEUE_FOR_X_SECONDS), + cancellationToken); var body = message?.Value?.Body?.ToString(); if (string.IsNullOrEmpty(body)) { @@ -40,9 +50,12 @@ private async Task ProcessRequestAsync( } // handle deadletter - if (message!.Value.DequeueCount > this.config.MAX_ATTEMPTS_TO_DEQUEUE) + if (message!.Value.DequeueCount > config.MAX_ATTEMPTS_TO_DEQUEUE) { - throw new DeadletterException($"message {message.Value.MessageId} has been dequeued {message.Value.DequeueCount} times.", message.Value, body); + throw new DeadletterException( + $"message {message.Value.MessageId} has been dequeued {message.Value.DequeueCount} times.", + message.Value, + body); } // deserialize the pipeline request @@ -54,18 +67,24 @@ private async Task ProcessRequestAsync( // it is considered to have processed once it starts doing something related to the actual request isConsideredToHaveProcessed = true; + // ensure required config values are present + var inferenceContainer = config.INFERENCE_CONTAINER + ?? throw new InvalidOperationException("INFERENCE_CONTAINER must be set for inference processing."); + var inferenceUrl = config.INFERENCE_URL + ?? 
throw new InvalidOperationException("INFERENCE_URL must be set for inference processing."); + // download and transform the ground truth file var groundTruthBlobRef = new BlobRef(request.GroundTruthUri); - var groundTruthBlobClient = this.GetBlobClient(groundTruthBlobRef.Container, groundTruthBlobRef.BlobName); + var groundTruthBlobClient = await this.GetBlobClientAsync(groundTruthBlobRef.Container, groundTruthBlobRef.BlobName, cancellationToken); var groundTruthContent = await groundTruthBlobClient.DownloadAndTransformAsync( - this.config.INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY, + config.INBOUND_GROUNDTRUTH_FOR_INFERENCE_TRANSFORM_QUERY, this.logger, cancellationToken); // call processing URL var (responseHeaders, responseContent) = await this.SendForProcessingAsync( request, - this.config.INFERENCE_URL, + inferenceUrl, groundTruthContent, message.Value, body, @@ -73,10 +92,14 @@ private async Task ProcessRequestAsync( cancellationToken); // upload the result - var inferenceUri = await this.UploadBlobAsync(this.config.INFERENCE_CONTAINER, $"{request.RunId}/{request.Id}.json", responseContent, cancellationToken); + var inferenceUri = await this.UploadBlobAsync( + inferenceContainer, + $"{request.RunId}/{request.Id}.json", + responseContent, + cancellationToken); // handle the response headers (metrics, etc.) - if (this.config.PROCESS_METRICS_IN_INFERENCE_RESPONSE) + if (config.PROCESS_METRICS_IN_INFERENCE_RESPONSE) { await this.HandleResponseAsync(request, responseContent, inferenceUri, null, cancellationToken); } @@ -111,7 +134,8 @@ private async Task GetMessagesFromInboundQueuesAsync(CancellationToken canc { var queue = this.inboundQueues[i]; var deadletter = this.inboundDeadletterQueues[i]; - await this.taskRunner.StartAsync(() => + var runner = this.taskRunner ?? 
throw new InvalidOperationException("task runner not initialized."); + await runner.StartAsync(() => this.ProcessRequestAsync(queue, deadletter, cancellationToken), onSuccess: async isConsideredToHaveProcessed => { @@ -155,31 +179,40 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) public override async Task StartAsync(CancellationToken cancellationToken) { + var config = await this.configFactory.GetAsync(cancellationToken); + if (!config.ROLES.Contains(Roles.InferenceProxy)) + { + this.logger.LogInformation("InferenceProxy role not configured; skipping AzureStorageQueueReaderForInference."); + return; + } + + this.taskRunner = new TaskRunner(config.INFERENCE_CONCURRENCY); + // try and connect to all the inbound inference queues - foreach (var queue in this.config.INBOUND_INFERENCE_QUEUES) + foreach (var queue in config.INBOUND_INFERENCE_QUEUES) { - var queueUrl = $"https://{this.config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}"; - var queueClient = string.IsNullOrEmpty(this.config.AZURE_STORAGE_CONNECTION_STRING) + var queueUrl = $"https://{config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}"; + var queueClient = string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING) ? new QueueClient(new Uri(queueUrl), this.defaultAzureCredential) - : new QueueClient(this.config.AZURE_STORAGE_CONNECTION_STRING, queue); + : new QueueClient(config.AZURE_STORAGE_CONNECTION_STRING, queue); await queueClient.ConnectAsync(this.logger, cancellationToken); this.inboundQueues.Add(queueClient); - var deadletterUrl = $"https://{this.config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}-deadletter"; - var deadletterClient = string.IsNullOrEmpty(this.config.AZURE_STORAGE_CONNECTION_STRING) + var deadletterUrl = $"https://{config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}-deadletter"; + var deadletterClient = string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING) ? 
new QueueClient(new Uri(deadletterUrl), this.defaultAzureCredential) - : new QueueClient(this.config.AZURE_STORAGE_CONNECTION_STRING, queue + "-deadletter"); + : new QueueClient(config.AZURE_STORAGE_CONNECTION_STRING, queue + "-deadletter"); await deadletterClient.ConnectAsync(this.logger, cancellationToken); this.inboundDeadletterQueues.Add(deadletterClient); } // try and connect to the outbound inference queue - if (!string.IsNullOrEmpty(this.config.OUTBOUND_INFERENCE_QUEUE)) + if (!string.IsNullOrEmpty(config.OUTBOUND_INFERENCE_QUEUE)) { - var queueUrl = $"https://{this.config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{this.config.OUTBOUND_INFERENCE_QUEUE}"; - var queueClient = string.IsNullOrEmpty(this.config.AZURE_STORAGE_CONNECTION_STRING) + var queueUrl = $"https://{config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{config.OUTBOUND_INFERENCE_QUEUE}"; + var queueClient = string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING) ? new QueueClient(new Uri(queueUrl), this.defaultAzureCredential) - : new QueueClient(this.config.AZURE_STORAGE_CONNECTION_STRING, this.config.OUTBOUND_INFERENCE_QUEUE); + : new QueueClient(config.AZURE_STORAGE_CONNECTION_STRING, config.OUTBOUND_INFERENCE_QUEUE); await queueClient.ConnectAsync(this.logger, cancellationToken); this.outboundQueue = queueClient; } diff --git a/evaluator/services/AzureStorageQueueWriter.cs b/evaluator/services/AzureStorageQueueWriter.cs index 514490c..916fc4c 100644 --- a/evaluator/services/AzureStorageQueueWriter.cs +++ b/evaluator/services/AzureStorageQueueWriter.cs @@ -9,18 +9,19 @@ using Azure.Storage.Queues; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using NetBricks; using Newtonsoft.Json; namespace Evaluator; public class AzureStorageQueueWriter( - IConfig config, - ILogger logger, - DefaultAzureCredential? 
defaultAzureCredential = null) + IConfigFactory configFactory, + DefaultAzureCredential defaultAzureCredential, + ILogger logger) : BackgroundService { - private readonly IConfig config = config; - private readonly DefaultAzureCredential? defaultAzureCredential = defaultAzureCredential; + private readonly IConfigFactory configFactory = configFactory; + private readonly DefaultAzureCredential defaultAzureCredential = defaultAzureCredential; private readonly ILogger logger = logger; private readonly Channel enqueueRequests = Channel.CreateUnbounded(); @@ -45,10 +46,13 @@ private async Task EnqueueBlobAsync( { try { + // get configuration + var config = await this.configFactory.GetAsync(cancellationToken); + // load the blob file var blobClient = containerClient.GetBlobClient(blob.Name); string content = await blobClient.DownloadAndTransformAsync( - this.config.INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY, + config.INBOUND_GROUNDTRUTH_FOR_API_TRANSFORM_QUERY, this.logger, cancellationToken); var groundTruthFile = JsonConvert.DeserializeObject(content) @@ -91,27 +95,35 @@ private async Task EnqueueBlobAsync( } } - private QueueClient GetQueueClient(string queue) + private async Task GetQueueClientAsync(string queue, CancellationToken cancellationToken) { - var queueUrl = $"https://{this.config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}"; - var queueClient = string.IsNullOrEmpty(this.config.AZURE_STORAGE_CONNECTION_STRING) + var config = await this.configFactory.GetAsync(cancellationToken); + var queueUrl = $"https://{config.AZURE_STORAGE_ACCOUNT_NAME}.queue.core.windows.net/{queue}"; + var queueClient = string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING) ? 
new QueueClient(new Uri(queueUrl), this.defaultAzureCredential) - : new QueueClient(this.config.AZURE_STORAGE_CONNECTION_STRING, queue); + : new QueueClient(config.AZURE_STORAGE_CONNECTION_STRING, queue); return queueClient; } - private BlobContainerClient GetBlobContainerClient(string container) + private async Task GetBlobContainerClientAsync(string container, CancellationToken cancellationToken) { - var blobUrl = $"https://{this.config.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"; - var blobClient = string.IsNullOrEmpty(this.config.AZURE_STORAGE_CONNECTION_STRING) + var config = await this.configFactory.GetAsync(cancellationToken); + var blobUrl = $"https://{config.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"; + var blobClient = string.IsNullOrEmpty(config.AZURE_STORAGE_CONNECTION_STRING) ? new BlobServiceClient(new Uri(blobUrl), this.defaultAzureCredential) - : new BlobServiceClient(this.config.AZURE_STORAGE_CONNECTION_STRING); + : new BlobServiceClient(config.AZURE_STORAGE_CONNECTION_STRING); return blobClient.GetBlobContainerClient(container); } protected override async Task ExecuteAsync(CancellationToken stoppingToken) { + var config = await configFactory.GetAsync(stoppingToken); + if (!config.ROLES.Contains(Roles.API)) + { + return; + } this.logger.LogInformation("starting to listen for enqueue requests in AzureStorageQueueWriter..."); + while (!stoppingToken.IsCancellationRequested) { try @@ -119,7 +131,7 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) var enqueueRequest = await this.enqueueRequests.Reader.ReadAsync(stoppingToken); // try and connect to the output queue - var queueClient = this.GetQueueClient(enqueueRequest.Queue); + var queueClient = await this.GetQueueClientAsync(enqueueRequest.Queue, stoppingToken); await queueClient.ConnectAsync(this.logger, stoppingToken); // enqueue everything from each specified container @@ -127,7 +139,7 @@ protected override async Task ExecuteAsync(CancellationToken 
stoppingToken) { // connect to the blob container var containerAndPath = containerPlusPath.Split('/', 2); - var containerClient = this.GetBlobContainerClient(containerAndPath[0]); + var containerClient = await this.GetBlobContainerClientAsync(containerAndPath[0], stoppingToken); // enqueue blobs from that container var prefix = containerAndPath.Length == 2 ? containerAndPath[1] : null; diff --git a/evaluator/services/Maintenance.cs b/evaluator/services/Maintenance.cs index 4f56ea6..8defc49 100644 --- a/evaluator/services/Maintenance.cs +++ b/evaluator/services/Maintenance.cs @@ -2,26 +2,26 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Hosting; +using NetBricks; namespace Evaluator; -public class Maintenance(IConfig config) : BackgroundService +public class Maintenance(IConfigFactory configFactory) : BackgroundService { - private readonly IConfig config = config; - protected override async Task ExecuteAsync(CancellationToken stoppingToken) { - if (this.config.MINUTES_BETWEEN_RESTORE_AFTER_BUSY == 0) + var config = await configFactory.GetAsync(stoppingToken); + if (config.MINUTES_BETWEEN_RESTORE_AFTER_BUSY == 0) { return; } while (!stoppingToken.IsCancellationRequested) { - await Task.Delay(TimeSpan.FromMinutes(this.config.MINUTES_BETWEEN_RESTORE_AFTER_BUSY), stoppingToken); - var proposed = this.config.MS_BETWEEN_DEQUEUE_CURRENT - this.config.MS_TO_ADD_ON_BUSY; - this.config.MS_BETWEEN_DEQUEUE_CURRENT = proposed < this.config.MS_BETWEEN_DEQUEUE - ? this.config.MS_BETWEEN_DEQUEUE + await Task.Delay(TimeSpan.FromMinutes(config.MINUTES_BETWEEN_RESTORE_AFTER_BUSY), stoppingToken); + var proposed = config.MS_BETWEEN_DEQUEUE_CURRENT - config.MS_TO_ADD_ON_BUSY; + config.MS_BETWEEN_DEQUEUE_CURRENT = proposed < config.MS_BETWEEN_DEQUEUE + ? 
config.MS_BETWEEN_DEQUEUE : proposed; } } diff --git a/ui/src/lib/ComparisonTable.svelte b/ui/src/lib/ComparisonTable.svelte index 3a295e2..517d5f1 100644 --- a/ui/src/lib/ComparisonTable.svelte +++ b/ui/src/lib/ComparisonTable.svelte @@ -115,7 +115,7 @@ annotations: [annotation], }), credentials: "include", - } + }, ); if (response.ok) { fetchComparison(); @@ -154,26 +154,20 @@ // fetch comparison const response = await fetch( `${prefix}/api/projects/${project.name}/experiments/${experiment.name}/compare?${tagFilters ?? ""}`, - { credentials: "include" } + { credentials: "include" }, ); comparison = await response.json(); // get a list of metrics const allKeys = [ - ...(comparison.project_baseline - ? Object.keys(comparison.project_baseline?.result?.metrics) - : []), - ...(comparison.experiment_baseline - ? Object.keys(comparison.experiment_baseline?.result?.metrics) - : []), - ...(comparison.sets - ? comparison.sets?.flatMap((experiment) => - Object.keys(experiment.result?.metrics) - ) - : []), + ...Object.keys(comparison.project_baseline?.result?.metrics ?? {}), + ...Object.keys(comparison.experiment_baseline?.result?.metrics ?? {}), + ...(comparison.sets ?? []).flatMap((experiment) => + Object.keys(experiment.result?.metrics ?? {}), + ), ]; metrics = [...new Set(allKeys)].sort((a, b) => - sortMetrics(comparison.metric_definitions, a, b) + sortMetrics(comparison.metric_definitions, a, b), ); // apply the set list