|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | + |
| 4 | +using System.Diagnostics.CodeAnalysis; |
| 5 | +using System.Net; |
| 6 | +using System.Net.Http.Headers; |
| 7 | +using System.Text; |
| 8 | +using System.Text.Json; |
| 9 | +using System.Text.RegularExpressions; |
| 10 | + |
| 11 | +using Valleysoft.DockerCredsProvider; |
| 12 | + |
| 13 | +using Microsoft.NET.Build.Containers.Credentials; |
| 14 | +using System.Net.Sockets; |
| 15 | +using Microsoft.NET.Build.Containers.Resources; |
| 16 | + |
| 17 | +namespace Microsoft.NET.Build.Containers; |
| 18 | + |
| 19 | +/// <summary> |
| 20 | +/// A delegating handler that performs the Docker auth handshake as described <see href="https://docs.docker.com/registry/spec/auth/token/">in their docs</see> if a request isn't authenticated |
| 21 | +/// </summary> |
| 22 | +internal sealed partial class AuthHandshakeMessageHandler : DelegatingHandler |
| 23 | +{ |
| 24 | + private const int MaxRequestRetries = 5; // Arbitrary but seems to work ok for chunked uploads to ghcr.io |
| 25 | + |
| 26 | + private sealed record AuthInfo(Uri Realm, string Service, string? Scope); |
| 27 | + |
| 28 | + /// <summary> |
| 29 | + /// the www-authenticate header must have realm, service, and scope information, so this method parses it into that shape if present |
| 30 | + /// </summary> |
| 31 | + /// <param name="msg"></param> |
| 32 | + /// <param name="authInfo"></param> |
| 33 | + /// <returns></returns> |
| 34 | + private static bool TryParseAuthenticationInfo(HttpResponseMessage msg, [NotNullWhen(true)] out string? scheme, [NotNullWhen(true)] out AuthInfo? authInfo) |
| 35 | + { |
| 36 | + authInfo = null; |
| 37 | + scheme = null; |
| 38 | + |
| 39 | + var authenticateHeader = msg.Headers.WwwAuthenticate; |
| 40 | + if (!authenticateHeader.Any()) |
| 41 | + { |
| 42 | + return false; |
| 43 | + } |
| 44 | + |
| 45 | + AuthenticationHeaderValue header = authenticateHeader.First(); |
| 46 | + if (header is { Scheme: "Bearer" or "Basic", Parameter: string bearerArgs }) |
| 47 | + { |
| 48 | + scheme = header.Scheme; |
| 49 | + Dictionary<string, string> keyValues = new(); |
| 50 | + foreach (Match match in BearerParameterSplitter().Matches(bearerArgs)) |
| 51 | + { |
| 52 | + keyValues.Add(match.Groups["key"].Value, match.Groups["value"].Value); |
| 53 | + } |
| 54 | + |
| 55 | + if (keyValues.TryGetValue("realm", out string? realm) && keyValues.TryGetValue("service", out string? service)) |
| 56 | + { |
| 57 | + string? scope = null; |
| 58 | + keyValues.TryGetValue("scope", out scope); |
| 59 | + authInfo = new AuthInfo(new Uri(realm), service, scope); |
| 60 | + return true; |
| 61 | + } |
| 62 | + } |
| 63 | + |
| 64 | + return false; |
| 65 | + } |
| 66 | + |
| 67 | + public AuthHandshakeMessageHandler(HttpMessageHandler innerHandler) : base(innerHandler) { } |
| 68 | + |
| 69 | + /// <summary> |
| 70 | + /// Response to a request to get a token using some auth. |
| 71 | + /// </summary> |
| 72 | + /// <remarks> |
| 73 | + /// <see href="https://docs.docker.com/registry/spec/auth/token/#token-response-fields"/> |
| 74 | + /// </remarks> |
| 75 | + private sealed record TokenResponse(string? token, string? access_token, int? expires_in, DateTimeOffset? issued_at) |
| 76 | + { |
| 77 | + public string ResolvedToken => token ?? access_token ?? throw new ArgumentException(Resource.GetString(nameof(Strings.InvalidTokenResponse))); |
| 78 | + } |
| 79 | + |
| 80 | + /// <summary> |
| 81 | + /// Uses the authentication information from a 401 response to perform the authentication dance for a given registry. |
| 82 | + /// Credentials for the request are retrieved from the credential provider, then used to acquire a token. |
| 83 | + /// That token is cached for some duration on a per-host basis. |
| 84 | + /// </summary> |
| 85 | + /// <param name="uri"></param> |
| 86 | + /// <param name="service"></param> |
| 87 | + /// <param name="scope"></param> |
| 88 | + /// <param name="cancellationToken"></param> |
| 89 | + /// <returns></returns> |
| 90 | + private async Task<AuthenticationHeaderValue?> GetAuthenticationAsync(string registry, string scheme, Uri realm, string service, string? scope, CancellationToken cancellationToken) |
| 91 | + { |
| 92 | + // Allow overrides for auth via environment variables |
| 93 | + string? credU = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectUser); |
| 94 | + string? credP = Environment.GetEnvironmentVariable(ContainerHelpers.HostObjectPass); |
| 95 | + |
| 96 | + // fetch creds for the host |
| 97 | + DockerCredentials? privateRepoCreds; |
| 98 | + |
| 99 | + if (!string.IsNullOrEmpty(credU) && !string.IsNullOrEmpty(credP)) |
| 100 | + { |
| 101 | + privateRepoCreds = new DockerCredentials(credU, credP); |
| 102 | + } |
| 103 | + else |
| 104 | + { |
| 105 | + try |
| 106 | + { |
| 107 | + privateRepoCreds = await CredsProvider.GetCredentialsAsync(registry).ConfigureAwait(false); |
| 108 | + } |
| 109 | + catch (Exception e) |
| 110 | + { |
| 111 | + throw new CredentialRetrievalException(registry, e); |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + if (scheme is "Basic") |
| 116 | + { |
| 117 | + var basicAuth = new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}"))); |
| 118 | + return AuthHeaderCache.AddOrUpdate(realm, basicAuth); |
| 119 | + } |
| 120 | + else if (scheme is "Bearer") |
| 121 | + { |
| 122 | + // use those creds when calling the token provider |
| 123 | + var header = privateRepoCreds.Username == "<token>" |
| 124 | + ? new AuthenticationHeaderValue("Bearer", privateRepoCreds.Password) |
| 125 | + : new AuthenticationHeaderValue("Basic", Convert.ToBase64String(Encoding.ASCII.GetBytes($"{privateRepoCreds.Username}:{privateRepoCreds.Password}"))); |
| 126 | + var builder = new UriBuilder(realm); |
| 127 | + var queryDict = System.Web.HttpUtility.ParseQueryString(""); |
| 128 | + queryDict["service"] = service; |
| 129 | + if (scope is string s) |
| 130 | + { |
| 131 | + queryDict["scope"] = s; |
| 132 | + } |
| 133 | + builder.Query = queryDict.ToString(); |
| 134 | + var message = new HttpRequestMessage(HttpMethod.Get, builder.ToString()); |
| 135 | + message.Headers.Authorization = header; |
| 136 | + |
| 137 | + var tokenResponse = await base.SendAsync(message, cancellationToken).ConfigureAwait(false); |
| 138 | + tokenResponse.EnsureSuccessStatusCode(); |
| 139 | + |
| 140 | + TokenResponse? token = JsonSerializer.Deserialize<TokenResponse>(tokenResponse.Content.ReadAsStream(cancellationToken)); |
| 141 | + if (token is null) |
| 142 | + { |
| 143 | + throw new ArgumentException(Resource.GetString(nameof(Strings.CouldntDeserializeJsonToken))); |
| 144 | + } |
| 145 | + |
| 146 | + // save the retrieved token in the cache |
| 147 | + var bearerAuth = new AuthenticationHeaderValue("Bearer", token.ResolvedToken); |
| 148 | + return AuthHeaderCache.AddOrUpdate(realm, bearerAuth); |
| 149 | + } |
| 150 | + else |
| 151 | + { |
| 152 | + return null; |
| 153 | + } |
| 154 | + } |
| 155 | + |
| 156 | + protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) |
| 157 | + { |
| 158 | + if (request.RequestUri is null) |
| 159 | + { |
| 160 | + throw new ArgumentException(Resource.GetString(nameof(Strings.NoRequestUriSpecified)), nameof(request)); |
| 161 | + } |
| 162 | + |
| 163 | + // attempt to use cached token for the request if available |
| 164 | + if (AuthHeaderCache.TryGet(request.RequestUri, out AuthenticationHeaderValue? cachedAuthentication)) |
| 165 | + { |
| 166 | + request.Headers.Authorization = cachedAuthentication; |
| 167 | + } |
| 168 | + |
| 169 | + int retryCount = 0; |
| 170 | + |
| 171 | + while (retryCount < MaxRequestRetries) |
| 172 | + { |
| 173 | + try |
| 174 | + { |
| 175 | + var response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false); |
| 176 | + if (response is { StatusCode: HttpStatusCode.OK }) |
| 177 | + { |
| 178 | + return response; |
| 179 | + } |
| 180 | + else if (response is { StatusCode: HttpStatusCode.Unauthorized } && TryParseAuthenticationInfo(response, out string? scheme, out AuthInfo? authInfo)) |
| 181 | + { |
| 182 | + if (await GetAuthenticationAsync(request.RequestUri.Host, scheme, authInfo.Realm, authInfo.Service, authInfo.Scope, cancellationToken).ConfigureAwait(false) is AuthenticationHeaderValue authentication) |
| 183 | + { |
| 184 | + request.Headers.Authorization = AuthHeaderCache.AddOrUpdate(request.RequestUri, authentication); |
| 185 | + return await base.SendAsync(request, cancellationToken).ConfigureAwait(false); |
| 186 | + } |
| 187 | + return response; |
| 188 | + } |
| 189 | + else |
| 190 | + { |
| 191 | + return response; |
| 192 | + } |
| 193 | + } |
| 194 | + catch (HttpRequestException e) when (e.InnerException is IOException ioe && ioe.InnerException is SocketException se) |
| 195 | + { |
| 196 | + retryCount += 1; |
| 197 | + |
| 198 | + // TODO: log in a way that is MSBuild-friendly |
| 199 | + Console.WriteLine($"Encountered a SocketException with message \"{se.Message}\". Pausing before retry."); |
| 200 | + |
| 201 | + await Task.Delay(TimeSpan.FromSeconds(1.0 * Math.Pow(2, retryCount)), cancellationToken).ConfigureAwait(false); |
| 202 | + |
| 203 | + // retry |
| 204 | + continue; |
| 205 | + } |
| 206 | + } |
| 207 | + |
| 208 | + throw new ApplicationException(Resource.GetString(nameof(Strings.TooManyRetries))); |
| 209 | + } |
| 210 | + |
| 211 | + [GeneratedRegex("(?<key>\\w+)=\"(?<value>[^\"]*)\"(?:,|$)")] |
| 212 | + private static partial Regex BearerParameterSplitter(); |
| 213 | +} |
0 commit comments