diff --git a/hashicorp-vault-cagateway/Client/HashicorpVaultClient.cs b/hashicorp-vault-cagateway/Client/HashicorpVaultClient.cs index 3ab133d..790443c 100644 --- a/hashicorp-vault-cagateway/Client/HashicorpVaultClient.cs +++ b/hashicorp-vault-cagateway/Client/HashicorpVaultClient.cs @@ -11,9 +11,11 @@ using Microsoft.Extensions.Logging; using Org.BouncyCastle.Asn1.X509; using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Text.Json; +using System.Threading; using System.Threading.Tasks; namespace Keyfactor.Extensions.CAPlugin.HashicorpVault @@ -25,15 +27,16 @@ namespace Keyfactor.Extensions.CAPlugin.HashicorpVault /// public class HashicorpVaultClient { - private VaultHttp _vaultHttp { get; set; } - private static readonly ILogger logger = LogHandler.GetClassLogger(); + private HashicorpVaultCAConfig _caConfig { get; set; } + private HashicorpVaultCATemplateConfig _templateConfig { get; set; } + private readonly ILogger logger; public HashicorpVaultClient(HashicorpVaultCAConfig caConfig, HashicorpVaultCATemplateConfig templateConfig = null) { + logger = LogHandler.GetClassLogger(); logger.MethodEntry(); - - SetClientValuesFromConfigs(caConfig, templateConfig); - + _caConfig = caConfig; + _templateConfig = templateConfig; logger.MethodExit(); } @@ -48,8 +51,9 @@ public async Task SignCSR(string csr, string subject, Dictionary SignCSR(string csr, string subject, Dictionary SignCSR(string csr, string subject, Dictionary SignCSR(string csr, string subject, Dictionary 0) { logger.LogTrace($"the response contained warnings: {string.Join(", ", response.Warnings)}"); } - - logger.LogTrace($"serialized SignResponse: {JsonSerializer.Serialize(response.Data)}"); - + + logger.LogTrace($"serialized SignResponse: {JsonSerializer.Serialize(response.Data)}"); + return response.Data; } catch (Exception ex) @@ -131,6 +136,7 @@ public async Task GetCertificate(string certSerial) try { + var _vaultHttp = ConfigureNewVaultClient(); var response = await _vaultHttp.GetAsync($"cert/{certSerial}"); logger.LogTrace($"successfully received a response for certificate with serial number: {certSerial}"); return response; @@ -151,7 +157,8 @@ public async Task RevokeCertificate(string serial) logger.MethodEntry(); logger.LogTrace($"making request to revoke cert with serial: {serial}"); try - { + { + var _vaultHttp = ConfigureNewVaultClient(); var response = await _vaultHttp.PostAsync("revoke", new RevokeRequest(serial)); logger.LogTrace($"successfully revoked cert with serial {serial}, revocation time: {response.RevocationTime}"); return response; @@ -170,11 +177,12 @@ public async Task PingServer() logger.LogTrace($"performing a system health check request to Vault"); try { + var _vaultHttp = ConfigureNewVaultClient(); var res = await _vaultHttp.HealthCheckAsync(); logger.LogTrace($"-- Vault health check response --"); logger.LogTrace($"Vault version : {res.VaultVersion}"); logger.LogTrace($"sealed? : {res.Sealed}"); - logger.LogTrace($"initialized? : {res.Initialized}"); + logger.LogTrace($"initialized? : {res.Initialized}"); return true; } catch (Exception ex) @@ -192,18 +200,30 @@ public async Task PingServer() /// Retreives all serial numbers for issued certificates /// /// a list of the certificate serial number strings - public async Task> GetAllCertSerialNumbers() + public async Task> GetAllCertSerialNumbers(CancellationToken token) { + var certSerials = new List(); + if (token.IsCancellationRequested) + { + logger.LogWarning($"cancelation was requested; stopping task"); + return certSerials; + } logger.MethodEntry(); - var keys = new List(); try { + var _vaultHttp = ConfigureNewVaultClient(); var res = await _vaultHttp.GetAsync>("certs/?list=true"); - return res.Data.Entries; + var serials = res.Data?.Entries; + if (serials == null || serials.Count == 0) + { + return certSerials; + } + logger.LogTrace($"got {res.Data?.Entries?.Count} serial numbers from {_caConfig.Host} namespace: {_caConfig.Namespace}, mount-point: {_caConfig.MountPoint}"); + return serials; } catch (Exception ex) { - logger.LogError($"there was an error retreiving the certificate keys: {ex.Message}"); + logger.LogError($"there was an error retreiving the certificate keys from {_caConfig.Host} using namespace: {_templateConfig.Namespace ?? _caConfig.Namespace} and mount-point: {_caConfig.MountPoint}. Error: {ex.Message}"); throw; } finally { logger.MethodExit(); } @@ -215,6 +235,7 @@ private async Task> GetRevokedSerialNumbers() var keys = new List(); try { + var _vaultHttp = ConfigureNewVaultClient(); var res = await _vaultHttp.GetAsync("certs/revoked"); keys = res.Entries; } @@ -233,6 +254,7 @@ public async Task> GetRoleNamesAsync() var roleNames = new List(); try { + var _vaultHttp = ConfigureNewVaultClient(); logger.LogTrace("getting the role names as a wrapped keyed-list response.."); var response = await _vaultHttp.GetAsync>("roles/?list=true"); logger.LogTrace($"received {response.Data?.Entries?.Count} role names (or product IDs)"); @@ -243,44 +265,47 @@ public async Task> GetRoleNamesAsync() logger.LogError($"There was a problem retreiving the PKI role names: {LogHandler.FlattenException(ex)}"); throw; } - finally { logger.MethodExit(); } + finally { logger.MethodExit(); } } - private void SetClientValuesFromConfigs(HashicorpVaultCAConfig caConfig, HashicorpVaultCATemplateConfig templateConfig) + private VaultHttp ConfigureNewVaultClient() { logger.MethodEntry(); + try + { + var hostUrl = _caConfig.Host; // host url and authentication details come from the CA config + logger.LogTrace($"set value for Host url: {hostUrl}"); - var hostUrl = caConfig.Host; // host url and authentication details come from the CA config - var token = caConfig.Token; - var nameSpace = string.IsNullOrEmpty(templateConfig?.Namespace) ? caConfig.Namespace : templateConfig.Namespace; // Namespace comes from templateconfig if available, otherwise defaults to caConfig; can be null - var mountPoint = string.IsNullOrEmpty(templateConfig?.MountPoint) ? caConfig.MountPoint : templateConfig.MountPoint; // Mountpoint comes from templateconfig if available, otherwise defaults to caConfig; if null, uses "pki" (Vault Default) - mountPoint = mountPoint ?? "pki"; // using the vault default PKI secrets engine mount point if not present in config - - logger.LogTrace($"set value for Host url: {hostUrl}"); - logger.LogTrace($"set value for authentication token: {token ?? "(not defined)"}"); - logger.LogTrace($"set value for Namespace: {nameSpace ?? "(not defined)"}"); - logger.LogTrace($"set value for Mountpoint: {mountPoint}"); + var token = _caConfig.Token; + logger.LogTrace($"set value for authentication token: {token ?? "(not defined)"}"); - // _certAuthInfo = caConfig?.ClientCertificate; - // logger.LogTrace($"set value for Certificate authentication; thumbprint: {_certAuthInfo?.Thumbprint ?? "(missing) - using token authentication"}"); + var nameSpace = string.IsNullOrEmpty(_templateConfig?.Namespace) ? _caConfig.Namespace : _templateConfig.Namespace; // Namespace comes from templateconfig if available, otherwise defaults to caConfig; can be null + logger.LogTrace($"set value for Namespace: {nameSpace ?? "(not defined)"}"); - //if (_token == null && _certAuthInfo == null) - //{ - // throw new MissingFieldException("Either an authentication token or certificate to use for authentication into Vault must be provided."); - //} + var mountPoint = string.IsNullOrEmpty(_templateConfig?.MountPoint) ? _caConfig.MountPoint : _templateConfig.MountPoint; // Mountpoint comes from templateconfig if available, otherwise defaults to caConfig; if null, uses "pki" (Vault Default) + mountPoint = mountPoint ?? "pki"; // using the vault default PKI secrets engine mount point if not present in config + logger.LogTrace($"set value for Mountpoint: {mountPoint}"); - _vaultHttp = new VaultHttp(hostUrl, mountPoint, token, nameSpace); - logger.MethodExit(); - } + // _certAuthInfo = caConfig?.ClientCertificate; + // logger.LogTrace($"set value for Certificate authentication; thumbprint: {_certAuthInfo?.Thumbprint ?? "(missing) - using token authentication"}"); - private static string ConvertSerialToTrackingId(string serialNumber) - { - // vault returns certificate serial formatted thusly: 17:67:16:b0:b9:45:58:c0:3a:29:e3:cb:d6:98:33:7a:a6:3b:66:c1 - // we cannot use the ':' character as part of our internal tracking id, but Vault requests can work with either ':' or '-' - // so we convert from colon-separated pairs to hyphen separated pairs. + //if (_token == null && _certAuthInfo == null) + //{ + // throw new MissingFieldException("Either an authentication token or certificate to use for authentication into Vault must be provided."); + //} - return serialNumber.Replace(":", "-"); + return new VaultHttp(hostUrl, mountPoint, token, nameSpace); + } + catch (Exception ex) + { + logger.LogError($"error when creating new vault client: {LogHandler.FlattenException(ex)}"); + throw; + } + finally + { + logger.MethodExit(); + } } } } \ No newline at end of file diff --git a/hashicorp-vault-cagateway/Client/VaultHttp.cs b/hashicorp-vault-cagateway/Client/VaultHttp.cs index c1c5f99..6396874 100644 --- a/hashicorp-vault-cagateway/Client/VaultHttp.cs +++ b/hashicorp-vault-cagateway/Client/VaultHttp.cs @@ -23,16 +23,36 @@ namespace Keyfactor.Extensions.CAPlugin.HashicorpVault.Client /// public class VaultHttp { + private string _vaultApiPath { get; set; } private string _mountPoint { get; set; } // not all requests are the the secrets engine, so can't append this permanently to the base path - private RestClient _restClient { get; set; } + private string _authToken { get; set; } + private string _nameSpace { get; set; } private JsonSerializerOptions _serializerOptions { get; set; } + private readonly ILogger logger; - private static readonly ILogger logger = LogHandler.GetClassLogger(); + private RestClient _restClient; + protected RestClient restClient + { + get + { + if (_restClient != null) { return _restClient; } + var restClientOptions = new RestClientOptions(_vaultApiPath) { ThrowOnAnyError = true }; + _restClient = new RestClient(restClientOptions, configureSerialization: s => s.UseSystemTextJson(_serializerOptions)); + _restClient.AddDefaultHeader("X-Vault-Request", "true"); + _restClient.AddDefaultHeader("X-Vault-Token", _authToken); + if (_nameSpace != null) _restClient.AddDefaultHeader("X-Vault-Namespace", _nameSpace); + logger.LogTrace($"configured a new instance of our Vault restsharp client with the configured values:"); + logger.LogTrace($"vault api path: {_vaultApiPath}"); + logger.LogTrace($"mount point: {_mountPoint}"); + logger.LogTrace($"namespace: {_nameSpace}"); + return _restClient; + } + } public VaultHttp(string host, string mountPoint, string authToken, string nameSpace = null) { + logger = LogHandler.GetClassLogger(); logger.MethodEntry(); - _serializerOptions = new() { DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingDefault, @@ -40,21 +60,16 @@ public VaultHttp(string host, string mountPoint, string authToken, string nameSp PropertyNameCaseInsensitive = true, PreferredObjectCreationHandling = JsonObjectCreationHandling.Replace, }; + _vaultApiPath = $"{host.TrimEnd('/')}/v1"; // api path ends in /v1 + _mountPoint = mountPoint.TrimStart('/').TrimEnd('/'); + _authToken = authToken; - var restClientOptions = new RestClientOptions($"{host.TrimEnd('/')}/v1") { ThrowOnAnyError = true }; - _restClient = new RestClient(restClientOptions, configureSerialization: s => s.UseSystemTextJson(_serializerOptions)); - - _mountPoint = mountPoint.TrimStart('/').TrimEnd('/'); // remove leading and trailing slashes - _restClient.AddDefaultHeader("X-Vault-Request", "true"); - _restClient.AddDefaultHeader("X-Vault-Token", authToken); - - if (nameSpace != null) _restClient.AddDefaultHeader("X-Vault-Namespace", nameSpace); - - logger.LogTrace($"configured an instance of our restsharp client with the provided values:"); - logger.LogTrace($"host url: {host}"); - logger.LogTrace($"mount point: {_mountPoint}"); - logger.LogTrace($"namespace: {nameSpace}"); + if (!string.IsNullOrEmpty(nameSpace)) + { + _nameSpace = nameSpace; + } + logger.LogTrace($"configured our httpclient wrapper with the following values:\nbase path: {_vaultApiPath}\nnamespace: {(!string.IsNullOrEmpty(_nameSpace) ? _nameSpace : "(no namespace configured)")}\nmount point: {_mountPoint}\nauth token: {_authToken.Substring(0, 8)}..."); logger.MethodExit(); } @@ -70,25 +85,26 @@ public async Task GetAsync(string path, Dictionary paramet { logger.MethodEntry(); logger.LogTrace($"preparing to send GET request to {path} with parameters {JsonSerializer.Serialize(parameters)}"); - logger.LogTrace($"will attempt to deserialize the response into a {typeof(T)}"); + logger.LogTrace($"will attempt to deserialize the response into a {typeof(T).Name}"); try { var request = new RestRequest($"{_mountPoint}/{path}", Method.Get); if (parameters != null) { request.AddJsonBody(parameters); } - var response = await _restClient.ExecuteGetAsync(request); - + var response = await restClient.ExecuteGetAsync(request); + logger.LogTrace($"response headers: {response.Headers}\nresponse content: {response.Content}\nresponse data: {response.Data}"); response.ThrowIfError(); return response.Data; } catch (Exception ex) { - logger.LogError($"there was an error making the request: {LogHandler.FlattenException(ex)}"); + logger.LogError($"there was an error making the GET request: {LogHandler.FlattenException(ex)}"); throw; } finally { + DisposeRestClient(); logger.MethodExit(); } } @@ -99,7 +115,7 @@ public async Task PostAsync(string path, dynamic parameters = default) var resourcePath = $"{_mountPoint}/{path}"; - logger.LogTrace($"preparing to send POST request to {_restClient.Options.BaseUrl}{resourcePath}"); + logger.LogTrace($"preparing to send POST request to {restClient.Options.BaseUrl}{resourcePath}"); logger.LogTrace($"will attempt to deserialize the response into a {typeof(T)}"); try @@ -112,14 +128,14 @@ public async Task PostAsync(string path, dynamic parameters = default) request.AddJsonBody(serializedParams); } - logger.LogTrace($"full url for the request: {_restClient.Options.BaseUrl}/{request.Resource}"); + logger.LogTrace($"full url for the request: {restClient.Options.BaseUrl}/{request.Resource}"); - var response = await _restClient.ExecutePostAsync(request); + var response = await restClient.ExecutePostAsync(request); logger.LogTrace($"request completed. response returned:"); logger.LogTrace($"response.StatusCode: {response!.StatusCode}"); logger.LogTrace($"response.contentType: {response!.ContentType}"); - + if (response.ErrorMessage != null) logger.LogTrace($"response.ErrorMessage: {response!.ErrorMessage}"); ErrorResponse errorResponse = null; @@ -139,11 +155,12 @@ public async Task PostAsync(string path, dynamic parameters = default) } catch (Exception ex) { - logger.LogError($"there was an error making the request: {LogHandler.FlattenException(ex)}"); + logger.LogError($"there was an error making the POST request: {LogHandler.FlattenException(ex)}"); throw; } finally { + DisposeRestClient(); logger.MethodExit(); } } @@ -154,36 +171,23 @@ public async Task HealthCheckAsync() try { - return await _restClient.GetAsync("sys/seal-status"); + return await restClient.GetAsync("sys/seal-status"); } catch (Exception ex) { - logger.LogError($"there was an error making the request: {LogHandler.FlattenException(ex)}"); + logger.LogError($"there was an error making the health-check request: {LogHandler.FlattenException(ex)}"); throw; } - finally { logger.MethodExit(); } - } - - /// - /// gets the capabilities for the current token in the given namespace. - /// using this method to verify connectivity - /// - /// - public async Task> GetCapabilitiesForThisTokenAndNamespace() - { - logger.MethodEntry(); - try - { - var response = await _restClient.GetAsync("sys/capabilities/self"); - response!.ThrowIfError(); - return response.Content?.data?.capabilities as List; - } - catch (Exception ex) + finally { - logger.LogError($"request to get capabilities for token failed: {LogHandler.FlattenException(ex)}"); - throw; + DisposeRestClient(); + logger.MethodExit(); } - finally { logger.MethodExit(); } + } + + private void DisposeRestClient() { + _restClient?.Dispose(); + _restClient = null; } } } diff --git a/hashicorp-vault-cagateway/HashicorpVaultCAConnector.cs b/hashicorp-vault-cagateway/HashicorpVaultCAConnector.cs index e020e9e..358ee4f 100644 --- a/hashicorp-vault-cagateway/HashicorpVaultCAConnector.cs +++ b/hashicorp-vault-cagateway/HashicorpVaultCAConnector.cs @@ -18,14 +18,16 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using System.Runtime.CompilerServices; namespace Keyfactor.Extensions.CAPlugin.HashicorpVault { public class HashicorpVaultCAConnector : IAnyCAPlugin { private readonly ILogger logger; - private HashicorpVaultCAConfig _caConfig { get; set; } - private HashicorpVaultClient _client { get; set; } + private HashicorpVaultCAConfig _caConfig; + + //private HashicorpVaultClient _client { get; set; } private ICertificateDataReader _certificateDataReader; private JsonSerializerOptions _serializerOptions; @@ -46,20 +48,20 @@ public HashicorpVaultCAConnector() /// Initialize the /// /// The config provider contains information required to connect to the CA. + [MethodImpl(MethodImplOptions.Synchronized)] public void Initialize(IAnyCAPluginConfigProvider configProvider, ICertificateDataReader certificateDataReader) { logger.MethodEntry(LogLevel.Trace); + _certificateDataReader = certificateDataReader; string rawConfig = JsonSerializer.Serialize(configProvider.CAConnectionData); logger.LogTrace($"serialized config: {rawConfig}"); _caConfig = JsonSerializer.Deserialize(rawConfig); logger.MethodExit(LogLevel.Trace); - _client = new HashicorpVaultClient(_caConfig); } /// /// Enrolls for a certificate through the API. /// - /// Reads certificate data from the database. /// The certificate request CSR in PEM format. /// The subject of the certificate request. /// Any SANs added to the request. @@ -70,6 +72,7 @@ public void Initialize(IAnyCAPluginConfigProvider configProvider, ICertificateDa public async Task Enroll(string csr, string subject, Dictionary san, EnrollmentProductInfo productInfo, RequestFormat requestFormat, EnrollmentType enrollmentType) { logger.MethodEntry(LogLevel.Trace); + logger.LogInformation($"Begin {enrollmentType} enrollment for {subject}"); string statusMessage; SignResponse signResponse; @@ -84,7 +87,7 @@ public async Task Enroll(string csr, string subject, Dictionar // create the client logger.LogTrace("instantiating the client.."); - _client = new HashicorpVaultClient(_caConfig, templateConfig); + var hashiClient = new HashicorpVaultClient(_caConfig, templateConfig); logger.LogDebug("Parse subject for Common Name"); string commonName = ParseSubject(subject, "CN="); @@ -94,7 +97,7 @@ public async Task Enroll(string csr, string subject, Dictionar logger.LogTrace($"using vault role name {vaultRole}"); - signResponse = await _client.SignCSR(csr, subject, san, vaultRole); + signResponse = await hashiClient.SignCSR(csr, subject, san, vaultRole); // trace logs logger.LogTrace($"back to calling method"); @@ -142,10 +145,13 @@ public async Task Enroll(string csr, string subject, Dictionar public async Task GetSingleRecord(string caRequestID) { logger.MethodEntry(); + + var hashiClient = new HashicorpVaultClient(_caConfig); + logger.LogTrace($"preparing to send request to retrieve certificate with id {caRequestID}"); try { - var cert = await _client.GetCertificate(caRequestID); + var cert = await hashiClient.GetCertificate(caRequestID); logger.LogTrace($"got a response from the request.."); logger.LogTrace($"revocation time: {cert.RevocationTime}"); @@ -176,10 +182,11 @@ public async Task GetSingleRecord(string caRequestID) public async Task Ping() { logger.MethodEntry(); + var hashiClient = new HashicorpVaultClient(_caConfig); logger.LogTrace("Attempting ping of Vault endpoint"); try { - var result = await _client.PingServer(); + var result = await hashiClient.PingServer(); } catch (Exception ex) { @@ -198,11 +205,12 @@ public async Task Ping() /// The status of the request as an int representing EndEntityStatus public async Task Revoke(string caRequestID, string hexSerialNumber, uint revocationReason) { + var hashiClient = new HashicorpVaultClient(_caConfig); logger.MethodEntry(); logger.LogTrace($"Sending request to revoke certificate with id: {caRequestID}"); try { - var response = await _client.RevokeCertificate(caRequestID); + var response = await hashiClient.RevokeCertificate(caRequestID); logger.LogTrace($"returning 'REVOKED' EndEntityStatus ({(int)EndEntityStatus.REVOKED})"); return (int)EndEntityStatus.REVOKED; } @@ -219,107 +227,162 @@ public async Task Revoke(string caRequestID, string hexSerialNumber, uint r /// /// Provides information about the gateway's certificate database. /// Buffer into which certificates are places from the CA. - /// Information about the last CA sync. - /// The cancellation token. + /// The cancellation token. public async Task Synchronize(BlockingCollection blockingBuffer, DateTime? lastSync, bool fullSync, CancellationToken cancelToken) { // !! Any certificates issued outside of this CA Gateway will not necessarily be associated with the role name / (product ID) that was used to generate it // !! since that value is not retreivable after the initial generation. - logger.MethodEntry(); + logger.LogTrace("Beginning Synchronization Task.."); var certSerials = new List(); var count = 0; + var changedCount = 0; - try - { - logger.LogTrace("getting all certificate serial numbers from vault"); - certSerials = await _client.GetAllCertSerialNumbers(); - } - catch (Exception ex) - { - logger.LogError($"failed to retreive serial numbers: {LogHandler.FlattenException(ex)}"); - throw; - } + HashicorpVaultClient hashiClient; - logger.LogTrace($"got {certSerials.Count()} serial numbers. Begin checking status for each..."); - - foreach (var certSerial in certSerials) - { - CertResponse certFromVault = null; - var dbStatus = -1; + try + { // wrapper for single thread block - // first, retreive the details from Vault try { - logger.LogTrace($"Calling GetCertificate on our client, passing serial number: {certSerial}"); - certFromVault = await _client.GetCertificate(certSerial); + + hashiClient = new HashicorpVaultClient(_caConfig); + logger.LogTrace($"getting all certificate serial numbers from vault from {_caConfig.Host} using namespace {_caConfig.Namespace} and mount-point {_caConfig.MountPoint}"); + certSerials = await hashiClient.GetAllCertSerialNumbers(cancelToken); } catch (Exception ex) { - logger.LogError($"Failed to retreive details for certificate with serial number {certSerial} from Vault. Errors: {LogHandler.FlattenException(ex)}"); + logger.LogError($"failed to retreive serial numbers: {LogHandler.FlattenException(ex)}"); + blockingBuffer.CompleteAdding(); throw; } - logger.LogTrace($"converting {certSerial} to database trackingId"); - var trackingId = certSerial.Replace(":", "-"); // we store with '-'; hashi stores with ':' + logger.LogTrace($"got {certSerials?.Count() ?? 0} serial numbers. Begin checking status for each..."); - // then, check for an existing local entry - try - { - logger.LogTrace($"attempting to retreive status of cert with tracking id {trackingId} from the database"); - dbStatus = await _certificateDataReader.GetStatusByRequestID(trackingId); - } - catch - { - logger.LogTrace($"tracking id {trackingId} was not found in the database. it will be added."); + if (certSerials == null || certSerials.Count == 0) + { // exit if no certs were found in vault + blockingBuffer.CompleteAdding(); + logger.LogTrace($"no certificates found at path {_caConfig.Host} using namespace {_caConfig.Namespace} and mount point {_caConfig.MountPoint}"); + logger.MethodExit(); + return; } - if (dbStatus == -1 || fullSync) // it's missing and needs added, or a full sync is requested + foreach (var certSerial in certSerials) { - logger.LogTrace($"adding cert with serial {trackingId} to the database. fullsync is {fullSync}, and the certificate {(dbStatus == -1 ? "does not yet exist" : "already exists")} in the database."); - - var newCert = new AnyCAPluginCertificate + if (cancelToken.IsCancellationRequested) { - CARequestID = trackingId, - Certificate = certFromVault.Certificate, - Status = certFromVault.RevocationTime != null ? (int)EndEntityStatus.REVOKED : (int)EndEntityStatus.GENERATED, - RevocationDate = certFromVault.RevocationTime, - }; + break; + } + CertResponse certFromVault = null; + var dbStatus = -1; + // first, retreive the details from Vault try { - logger.LogTrace($"writing the result."); - blockingBuffer.Add(newCert); - logger.LogTrace($"successfully added certificate to the database."); + logger.LogTrace($"Calling GetCertificate on our client, passing serial number: {certSerial}"); + certFromVault = hashiClient.GetCertificate(certSerial).Result; } catch (Exception ex) { - logger.LogError($"Failed to add the cert to the database: {LogHandler.FlattenException(ex)}"); + logger.LogError($"Failed to retreive details for certificate with serial number {certSerial} from Vault. Errors: {LogHandler.FlattenException(ex)}"); + blockingBuffer.CompleteAdding(); throw; } - } - else // the cert exists in the database; just update the status if necessary - { - var revoked = certFromVault.RevocationTime != null; - var vaultStatus = revoked ? (int)EndEntityStatus.REVOKED : (int)EndEntityStatus.GENERATED; - if (vaultStatus != dbStatus) // if there is a mismatch, we need to update + logger.LogTrace($"converting {certSerial} to database trackingId"); + + var trackingId = certSerial.Replace(":", "-"); // we store with '-'; hashi stores with ':' + + // then, check for an existing local entry + try + { + logger.LogTrace($"attempting to retreive status of cert with tracking id {trackingId} from the database"); + dbStatus = await _certificateDataReader.GetStatusByRequestID(trackingId); + } + catch (Exception ex) // an exception is thrown if cert doesn't exist { + if (!ex.Message.Contains("No matching")) + { // if the exception message doesn't contain this; something else happened. + logger.LogTrace($"exception when retrieving cert from database: {ex.Message}"); + blockingBuffer.CompleteAdding(); + throw; + } + logger.LogTrace($"tracking id {trackingId} was not found in the database. it will be added."); + } + + if (dbStatus == -1 || fullSync) // it's missing and needs added, or a full sync is requested + { + logger.LogTrace($"adding cert with serial {trackingId} to the database. fullsync is {fullSync}, and the certificate {(dbStatus == -1 ? "does not yet exist" : "already exists")} in the database."); + changedCount++; var newCert = new AnyCAPluginCertificate { CARequestID = trackingId, Certificate = certFromVault.Certificate, - Status = vaultStatus, - RevocationDate = certFromVault.RevocationTime - // ProductID is not available via the API after the initial issuance. we do not want to overwrite + Status = certFromVault.RevocationTime != null ? (int)EndEntityStatus.REVOKED : (int)EndEntityStatus.GENERATED, + RevocationDate = certFromVault.RevocationTime, }; + + try + { + logger.LogTrace($"writing the result."); + if (blockingBuffer.TryAdd(newCert, 50, cancelToken)) + { + logger.LogTrace($"successfully added certificate to the database."); + } + else + { + logger.LogTrace($"adding to queue for writing was blocked."); + } + } + catch (Exception ex) + { + logger.LogError($"Failed to add the cert to the database: {LogHandler.FlattenException(ex)}"); + logger.LogTrace($"closing buffer and aborting sync"); + blockingBuffer.CompleteAdding(); + throw; + } } + else // the cert exists in the database; just update the status if necessary + { + logger.LogTrace($"certificate with id {trackingId} was found in the database, comparing status."); + var revoked = certFromVault.RevocationTime != null; + var vaultStatus = revoked ? (int)EndEntityStatus.REVOKED : (int)EndEntityStatus.GENERATED; + if (vaultStatus != dbStatus) // if there is a mismatch, we need to update + { + changedCount++; + logger.LogTrace($"status in vault is {vaultStatus}, status in db is {dbStatus}; updating db."); + var newCert = new AnyCAPluginCertificate + { + CARequestID = trackingId, + Certificate = certFromVault.Certificate, + Status = vaultStatus, + RevocationDate = certFromVault.RevocationTime + // ProductID is not available via the API after the initial issuance. we do not want to overwrite + }; + if (blockingBuffer.TryAdd(newCert, 50, cancelToken)) + { + logger.LogTrace($"successfully updated certificate {trackingId} in the database."); + } + else + { + logger.LogTrace($"adding to queue for writing was blocked."); + } + } + else + { + logger.LogTrace($"The status is unchanged and we are doing an incremental scan; no need to update db."); + } + } + count++; } - count++; + blockingBuffer.CompleteAdding(); + logger.LogTrace($"Completed sync of {count} certificates. It was {(fullSync ? "a full" : "an incremental")} sync and {changedCount} records were updated."); } - logger.LogTrace($"Completed sync of {count} certificates"); - logger.MethodExit(); + finally + { + logger.MethodExit(); + } } /// @@ -352,18 +415,22 @@ public async Task ValidateCAConnectionInfo(Dictionary connection // make sure an authentication mechanism is defined (either certificate or token) var token = connectionInfo[Constants.CAConfig.TOKEN] as string; - var cert = connectionInfo[Constants.CAConfig.CLIENTCERT] as string; - if (string.IsNullOrEmpty(token) && string.IsNullOrEmpty(cert)) - { - errors.Add("Either an authentication token or client certificate must be defined for authentication into Vault."); - } - if (!string.IsNullOrEmpty(token) && !string.IsNullOrEmpty(cert)) - { - logger.LogWarning("Both an authentication token and client certificate are defined. Using the token for authentication."); - } + /// REMOVING CERT VALIDATION UNTIL CLIENT CERT AUTH IS IMPLEMENTED + + //var cert = connectionInfo[Constants.CAConfig.CLIENTCERT] as string; + + //if (string.IsNullOrEmpty(token) && string.IsNullOrEmpty(cert)) + //{ + // errors.Add("Either an authentication token or client certificate must be defined for authentication into Vault."); + //} + //if (!string.IsNullOrEmpty(token) && !string.IsNullOrEmpty(cert)) + //{ + // logger.LogWarning("Both an authentication token and client certificate are defined. Using the token for authentication."); + //} // if any errors, throw + if (errors.Any()) { var allErrors = string.Join("\n", errors); @@ -397,13 +464,13 @@ public async Task ValidateCAConnectionInfo(Dictionary connection // create an instance of our client with those values - _client = new HashicorpVaultClient(config); + var hashiClient = new HashicorpVaultClient(config); // attempt an authenticated request to retreive role names try { logger.LogTrace("making an authenticated request to the Vault server to verify credentials (listing role names).."); - var roleNames = await _client.GetRoleNamesAsync(); + var roleNames = await hashiClient.GetRoleNamesAsync(); logger.LogTrace($"successfule request: received a response containing {roleNames.Count} role names"); } catch (Exception ex) @@ -439,10 +506,15 @@ public Task ValidateProductInfo(EnrollmentProductInfo productInfo, Dictionary ProductIdIsValid(string productID, HashicorpVaultCAConfig config) + { + var hashiClient = new HashicorpVaultClient(config); + + // attempt an authenticated request to retreive role names + try + { + logger.LogTrace("making an authenticated request to the Vault server to verify credentials (listing role names).."); + var roleNames = await hashiClient.GetRoleNamesAsync(); + logger.LogTrace($"successfule request: received a response containing {roleNames.Count} role names"); + return roleNames.Any(rn => rn == productID); + } + catch (Exception ex) + { + logger.LogError($"Authenticated request failed. {ex.Message}"); + throw; + } + finally { logger.MethodExit(); } + } + /// /// Gets annotations for the CA connector properties. /// @@ -545,11 +637,11 @@ public Dictionary GetTemplateParameterAnnotations() public List GetProductIds() { logger.MethodEntry(); - // Initialize should have been called in order to populate the caConfig and create the client. + var hashiClient = new HashicorpVaultClient(_caConfig); try { logger.LogTrace("requesting role names from vault.."); - var roleNames = _client.GetRoleNamesAsync().Result; + var roleNames = hashiClient.GetRoleNamesAsync().Result; logger.LogTrace($"got {roleNames.Count} role names from vault:"); foreach (var name in roleNames) {