diff --git a/doc/DistributedTracing.md b/doc/DistributedTracing.md index 40c787180b..ad4eb4ff09 100644 --- a/doc/DistributedTracing.md +++ b/doc/DistributedTracing.md @@ -158,8 +158,7 @@ And construct the new tracing factory in the service constructor: Azure::Core::Context const& context = Azure::Core::Context{}) { // Create a new context and span for this request. - auto contextAndSpan = m_tracingFactory.CreateSpan( - "ServiceMethod", Azure::Core::Tracing::_internal::SpanKind::Internal, context); + auto contextAndSpan = m_tracingFactory.CreateSpan("ServiceMethod", context); // contextAndSpan.Context is the new context for the operation. // contextAndSpan.Span is the new span for the operation. @@ -179,7 +178,6 @@ And construct the new tracing factory in the service constructor: { // Register that the exception has happened and that the span is now in error. contextAndSpan.Span.AddEvent(ex); - contextAndSpan.Span.SetStatus(Azure::Core::Tracing::_internal::SpanStatus::Error); throw; } diff --git a/sdk/attestation/azure-security-attestation/inc/azure/attestation/attestation_administration_client.hpp b/sdk/attestation/azure-security-attestation/inc/azure/attestation/attestation_administration_client.hpp index bafac89517..2cd86ca987 100644 --- a/sdk/attestation/azure-security-attestation/inc/azure/attestation/attestation_administration_client.hpp +++ b/sdk/attestation/azure-security-attestation/inc/azure/attestation/attestation_administration_client.hpp @@ -6,6 +6,7 @@ #include "azure/attestation/attestation_client_models.hpp" #include "azure/attestation/attestation_client_options.hpp" #include +#include #include #include @@ -68,10 +69,7 @@ namespace Azure { namespace Security { namespace Attestation { * @param attestationClient An existing attestation client. */ AttestationAdministrationClient(AttestationAdministrationClient const& attestationClient) - : m_endpoint(attestationClient.m_endpoint), m_apiVersion(attestationClient.m_apiVersion), - m_pipeline(attestationClient.m_pipeline), - m_tokenValidationOptions(attestationClient.m_tokenValidationOptions), - m_attestationSigners(attestationClient.m_attestationSigners){}; + = default; /** * @brief Destructor. @@ -258,6 +256,7 @@ namespace Azure { namespace Security { namespace Attestation { std::shared_ptr m_credentials; std::shared_ptr m_pipeline; AttestationTokenValidationOptions m_tokenValidationOptions; + Azure::Core::Tracing::_internal::TracingContextFactory m_tracingFactory; std::vector m_attestationSigners; diff --git a/sdk/attestation/azure-security-attestation/inc/azure/attestation/attestation_client.hpp b/sdk/attestation/azure-security-attestation/inc/azure/attestation/attestation_client.hpp index 2d55b76768..69cbb039b9 100644 --- a/sdk/attestation/azure-security-attestation/inc/azure/attestation/attestation_client.hpp +++ b/sdk/attestation/azure-security-attestation/inc/azure/attestation/attestation_client.hpp @@ -6,6 +6,7 @@ #include "azure/attestation/attestation_client_models.hpp" #include "azure/attestation/attestation_client_options.hpp" #include +#include #include #include @@ -162,11 +163,7 @@ namespace Azure { namespace Security { namespace Attestation { * * @param attestationClient An existing attestation client. */ - AttestationClient(AttestationClient const& attestationClient) - : m_endpoint(attestationClient.m_endpoint), m_apiVersion(attestationClient.m_apiVersion), - m_pipeline(attestationClient.m_pipeline), - m_tokenValidationOptions(attestationClient.m_tokenValidationOptions), - m_attestationSigners(attestationClient.m_attestationSigners){}; + AttestationClient(AttestationClient const& attestationClient) = default; std::string const Endpoint() const { return m_endpoint.GetAbsoluteUrl(); } @@ -255,6 +252,7 @@ namespace Azure { namespace Security { namespace Attestation { std::shared_ptr m_pipeline; AttestationTokenValidationOptions m_tokenValidationOptions; std::vector m_attestationSigners; + Azure::Core::Tracing::_internal::TracingContextFactory m_tracingFactory; /** @brief Construct a new Attestation Client object * diff --git a/sdk/attestation/azure-security-attestation/src/attestation_administration_client.cpp b/sdk/attestation/azure-security-attestation/src/attestation_administration_client.cpp index 9ad1c63e36..1728294d73 100644 --- a/sdk/attestation/azure-security-attestation/src/attestation_administration_client.cpp +++ b/sdk/attestation/azure-security-attestation/src/attestation_administration_client.cpp @@ -21,6 +21,7 @@ using namespace Azure::Security::Attestation; using namespace Azure::Security::Attestation::Models; using namespace Azure::Security::Attestation::_detail; using namespace Azure::Security::Attestation::Models::_detail; +using namespace Azure::Core::Tracing::_internal; using namespace Azure::Core::Http; using namespace Azure::Core::Http::Policies; using namespace Azure::Core::Http::Policies::_internal; @@ -42,7 +43,8 @@ AttestationAdministrationClient::AttestationAdministrationClient( std::shared_ptr credential, AttestationAdministrationClientOptions const& options) : m_endpoint(endpoint), m_apiVersion(options.Version.ToString()), - m_tokenValidationOptions(options.TokenValidationOptions) + m_tokenValidationOptions(options.TokenValidationOptions), + m_tracingFactory(options, "security.attestation", PackageVersion::ToString()) { std::vector> perRetrypolicies; if (credential) @@ -58,11 +60,7 @@ AttestationAdministrationClient::AttestationAdministrationClient( std::vector> perCallpolicies; m_pipeline = std::make_shared( - options, - "Attestation", - PackageVersion::ToString(), - std::move(perRetrypolicies), - std::move(perCallpolicies)); + options, std::move(perRetrypolicies), std::move(perCallpolicies)); } AttestationAdministrationClient AttestationAdministrationClient::Create( @@ -86,53 +84,64 @@ AttestationAdministrationClient::GetAttestationPolicy( GetPolicyOptions const& options, Azure::Core::Context const& context) const { - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, - m_apiVersion, - HttpMethod::Get, - {"policies/" + attestationType.ToString()}, - nullptr); - - // Send the request to the service. - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); + auto tracingContext(m_tracingFactory.CreateTracingContext("GetAttestationPolicy", context)); + try + { - // Deserialize the Service response token and return the JSON web token returned by the - // service. - std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, + m_apiVersion, + HttpMethod::Get, + {"policies/" + attestationType.ToString()}, + nullptr); - // Parse the JWT returned by the attestation service. - const auto resultToken - = AttestationTokenInternal( - responseToken); + // Send the request to the service. + auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - // Validate the token returned by the service. Use the cached attestation signers in the - // validation. - resultToken.ValidateToken( - options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride - : this->m_tokenValidationOptions, - m_attestationSigners); + // Deserialize the Service response token and return the JSON web token returned by the + // service. + std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); + + // Parse the JWT returned by the attestation service. + const auto resultToken + = AttestationTokenInternal( + responseToken); + + // Validate the token returned by the service. Use the cached attestation signers in the + // validation. + resultToken.ValidateToken( + options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride + : this->m_tokenValidationOptions, + m_attestationSigners); + + // Extract the underlying policy token from the response. + std::string policyTokenValue + = *static_cast>(resultToken) + .Body.PolicyToken; + + // TPM policies are empty by default, at least in our test instances, so handle the empty policy + // token case. + const auto policyTokenI + = AttestationTokenInternal( + policyTokenValue); + AttestationToken policyToken(policyTokenI); + std::string returnPolicy; + if (policyToken.Body.AttestationPolicy) + { + std::vector policyUtf8 = *policyToken.Body.AttestationPolicy; + returnPolicy = std::string(policyUtf8.begin(), policyUtf8.end()); + } - // Extract the underlying policy token from the response. - std::string policyTokenValue - = *static_cast>(resultToken).Body.PolicyToken; - - // TPM policies are empty by default, at least in our test instances, so handle the empty policy - // token case. - const auto policyTokenI - = AttestationTokenInternal( - policyTokenValue); - AttestationToken policyToken(policyTokenI); - std::string returnPolicy; - if (policyToken.Body.AttestationPolicy) + // Construct a token whose body is the policy, but whose token is the response from the + // service. + const auto returnedToken = AttestationTokenInternal(responseToken, &returnPolicy); + return Response>(returnedToken, std::move(response)); + } + catch (std::runtime_error const& ex) { - std::vector policyUtf8 = *policyToken.Body.AttestationPolicy; - returnPolicy = std::string(policyUtf8.begin(), policyUtf8.end()); + tracingContext.Span.AddEvent(ex); + throw; } - - // Construct a token whose body is the policy, but whose token is the response from the - // service. - const auto returnedToken = AttestationTokenInternal(responseToken, &returnPolicy); - return Response>(returnedToken, std::move(response)); } Models::AttestationToken AttestationAdministrationClient::CreateAttestationPolicyToken( @@ -163,62 +172,72 @@ AttestationAdministrationClient::SetAttestationPolicy( SetPolicyOptions const& options, Azure::Core::Context const& context) const { - // Calculate a signed (or unsigned) attestation policy token to send to the service. - Models::AttestationToken const tokenToSend( - CreateAttestationPolicyToken(newAttestationPolicy, options.SigningKey)); + auto tracingContext(m_tracingFactory.CreateTracingContext("SetAttestationPolicy", context)); + try + { + // Calculate a signed (or unsigned) attestation policy token to send to the service. + Models::AttestationToken const tokenToSend( + CreateAttestationPolicyToken(newAttestationPolicy, options.SigningKey)); - Azure::Core::IO::MemoryBodyStream stream( - reinterpret_cast(tokenToSend.RawToken.data()), tokenToSend.RawToken.size()); + Azure::Core::IO::MemoryBodyStream stream( + reinterpret_cast(tokenToSend.RawToken.data()), tokenToSend.RawToken.size()); - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, - m_apiVersion, - HttpMethod::Put, - {"policies/" + attestationType.ToString()}, - &stream); + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, + m_apiVersion, + HttpMethod::Put, + {"policies/" + attestationType.ToString()}, + &stream); - // Send the request to the service. - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); + // Send the request to the service. + auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - // Deserialize the Service response token and return the JSON web token returned by the - // service. - std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); + // Deserialize the Service response token and return the JSON web token returned by the + // service. + std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); - // Parse the JWT returned by the attestation service. - auto resultToken - = AttestationTokenInternal( - responseToken); + // Parse the JWT returned by the attestation service. + auto resultToken + = AttestationTokenInternal( + responseToken); - // Validate the token returned by the service. Use the cached attestation signers in the - // validation. - resultToken.ValidateToken( - options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride - : this->m_tokenValidationOptions, - m_attestationSigners); + // Validate the token returned by the service. Use the cached attestation signers in the + // validation. + resultToken.ValidateToken( + options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride + : this->m_tokenValidationOptions, + m_attestationSigners); - // Extract the underlying policy token from the response. - auto internalResult - = static_cast>(resultToken).Body; + // Extract the underlying policy token from the response. + auto internalResult + = static_cast>(resultToken).Body; - Models::PolicyResult returnedResult; - if (internalResult.PolicyResolution) - { - returnedResult.PolicyResolution = Models::PolicyModification(*internalResult.PolicyResolution); - } - if (internalResult.PolicySigner) - { - returnedResult.PolicySigner = AttestationSignerInternal(*internalResult.PolicySigner); + Models::PolicyResult returnedResult; + if (internalResult.PolicyResolution) + { + returnedResult.PolicyResolution + = Models::PolicyModification(*internalResult.PolicyResolution); + } + if (internalResult.PolicySigner) + { + returnedResult.PolicySigner = AttestationSignerInternal(*internalResult.PolicySigner); + } + if (internalResult.PolicyTokenHash) + { + returnedResult.PolicyTokenHash = Base64Url::Base64UrlDecode(*internalResult.PolicyTokenHash); + } + + // Construct a token whose body is the policy result, but whose token is the response from + // the service. + auto returnedToken + = AttestationTokenInternal(responseToken, &returnedResult); + return Response>(returnedToken, std::move(response)); } - if (internalResult.PolicyTokenHash) + catch (std::runtime_error const& ex) { - returnedResult.PolicyTokenHash = Base64Url::Base64UrlDecode(*internalResult.PolicyTokenHash); + tracingContext.Span.AddEvent(ex); + throw; } - - // Construct a token whose body is the policy result, but whose token is the response from the - // service. - auto returnedToken - = AttestationTokenInternal(responseToken, &returnedResult); - return Response>(returnedToken, std::move(response)); } Azure::Response> @@ -227,67 +246,78 @@ AttestationAdministrationClient::ResetAttestationPolicy( SetPolicyOptions const& options, Azure::Core::Context const& context) const { - // Calculate a signed (or unsigned) attestation policy token to send to the service. - Models::AttestationToken tokenToSend( - CreateAttestationPolicyToken(Azure::Nullable(), options.SigningKey)); + auto tracingContext(m_tracingFactory.CreateTracingContext("ResetAttestationPolicy", context)); + try + { + // Calculate a signed (or unsigned) attestation policy token to send to the service. + Models::AttestationToken tokenToSend( + CreateAttestationPolicyToken(Azure::Nullable(), options.SigningKey)); - Azure::Core::IO::MemoryBodyStream stream( - reinterpret_cast(tokenToSend.RawToken.data()), tokenToSend.RawToken.size()); + Azure::Core::IO::MemoryBodyStream stream( + reinterpret_cast(tokenToSend.RawToken.data()), tokenToSend.RawToken.size()); - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, - m_apiVersion, - HttpMethod::Post, - {"policies/" + attestationType.ToString() + ":reset"}, - &stream); + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, + m_apiVersion, + HttpMethod::Post, + {"policies/" + attestationType.ToString() + ":reset"}, + &stream); - // Send the request to the service. - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); + // Send the request to the service. + auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - // Deserialize the Service response token and return the JSON web token returned by the - // service. - std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); + // Deserialize the Service response token and return the JSON web token returned by the + // service. + std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); - // Parse the JWT returned by the attestation service. - auto resultToken - = AttestationTokenInternal( - responseToken); + // Parse the JWT returned by the attestation service. + auto resultToken + = AttestationTokenInternal( + responseToken); - // Validate the token returned by the service. Use the cached attestation signers in the - // validation. - resultToken.ValidateToken( - options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride - : this->m_tokenValidationOptions, - m_attestationSigners); + // Validate the token returned by the service. Use the cached attestation signers in the + // validation. + resultToken.ValidateToken( + options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride + : this->m_tokenValidationOptions, + m_attestationSigners); - // Extract the underlying policy token from the response. - auto internalResult - = static_cast>(resultToken).Body; + // Extract the underlying policy token from the response. + auto internalResult + = static_cast>(resultToken).Body; - Models::PolicyResult returnedResult; - if (internalResult.PolicyResolution) + Models::PolicyResult returnedResult; + if (internalResult.PolicyResolution) + { + returnedResult.PolicyResolution + = Models::PolicyModification(*internalResult.PolicyResolution); + } + // Note that the attestation service currently never returns these values on Reset, even + // though they are meaningful. Commenting them out to improve code coverage numbers. At + // some point the attestation service may start returning these values, at which point + // they can be un-commented out. + // if (internalResult.PolicySigner) + // { + // returnedResult.PolicySigner = + // AttestationSignerInternal(*internalResult.PolicySigner); + // } + // if (internalResult.PolicyTokenHash) + // { + // returnedResult.PolicyTokenHash = + // Base64Url::Base64UrlDecode(*internalResult.PolicyTokenHash); + // } + + // Construct a token whose body is the policy result, but whose token is the response from + // the service. + auto returnedToken + = AttestationTokenInternal(responseToken, &returnedResult); + return Response>(returnedToken, std::move(response)); + } + catch (std::runtime_error const& ex) { - returnedResult.PolicyResolution = Models::PolicyModification(*internalResult.PolicyResolution); + tracingContext.Span.AddEvent(ex); + throw; } - // Note that the attestation service currently never returns these values on Reset, even though - // they are meaningful. Commenting them out to improve code coverage numbers. At some point the - // attestation service may start returning these values, at which point they can be un-commented - // out. - // if (internalResult.PolicySigner) - // { - // returnedResult.PolicySigner = AttestationSignerInternal(*internalResult.PolicySigner); - // } - // if (internalResult.PolicyTokenHash) - // { - // returnedResult.PolicyTokenHash = - // Base64Url::Base64UrlDecode(*internalResult.PolicyTokenHash); - // } - - // Construct a token whose body is the policy result, but whose token is the response from the - // service. - auto returnedToken - = AttestationTokenInternal(responseToken, &returnedResult); - return Response>(returnedToken, std::move(response)); } Azure::Response> @@ -295,44 +325,54 @@ AttestationAdministrationClient::GetIsolatedModeCertificates( GetIsolatedModeCertificatesOptions const& options, Azure::Core::Context const& context) const { - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, m_apiVersion, HttpMethod::Get, {"certificates"}, nullptr); - - // Send the request to the service. - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); + auto tracingContext( + m_tracingFactory.CreateTracingContext("GetIsolatedModeCertificates", context)); + try + { + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, m_apiVersion, HttpMethod::Get, {"certificates"}, nullptr); - // Deserialize the Service response token and return the JSON web token returned by the - // service. - std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); + // Send the request to the service. + auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - // Parse the JWT returned by the attestation service. - auto resultToken = AttestationTokenInternal< - Models::_detail::GetIsolatedModeCertificatesResult, - IsolatedModeCertificateGetResultSerializer>(responseToken); + // Deserialize the Service response token and return the JSON web token returned by the + // service. + std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); + + // Parse the JWT returned by the attestation service. + auto resultToken = AttestationTokenInternal< + Models::_detail::GetIsolatedModeCertificatesResult, + IsolatedModeCertificateGetResultSerializer>(responseToken); + + // Validate the token returned by the service. Use the cached attestation signers in the + // validation. + resultToken.ValidateToken( + options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride + : this->m_tokenValidationOptions, + m_attestationSigners); + + Models::_detail::JsonWebKeySet jwks( + *static_cast>( + resultToken) + .Body.PolicyCertificates); + Models::IsolatedModeCertificateListResult returnedResult; + for (const auto& certificate : jwks.Keys) + { + returnedResult.Certificates.push_back(AttestationSignerInternal(certificate)); + } - // Validate the token returned by the service. Use the cached attestation signers in the - // validation. - resultToken.ValidateToken( - options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride - : this->m_tokenValidationOptions, - m_attestationSigners); - - Models::_detail::JsonWebKeySet jwks( - *static_cast>( - resultToken) - .Body.PolicyCertificates); - Models::IsolatedModeCertificateListResult returnedResult; - for (const auto& certificate : jwks.Keys) + // Construct a token whose body is the get policy certificates result, but whose token + // is the response from the service. + auto returnedToken = AttestationTokenInternal( + responseToken, &returnedResult); + return Response>( + returnedToken, std::move(response)); + } + catch (std::runtime_error const& ex) { - returnedResult.Certificates.push_back(AttestationSignerInternal(certificate)); + tracingContext.Span.AddEvent(ex); + throw; } - - // Construct a token whose body is the get policy certificates result, but whose token is the - // response from the service. - auto returnedToken = AttestationTokenInternal( - responseToken, &returnedResult); - return Response>( - returnedToken, std::move(response)); } std::string AttestationAdministrationClient::CreateIsolatedModeModificationToken( @@ -397,8 +437,8 @@ AttestationAdministrationClient::ProcessIsolatedModeModificationResult( returnValue.CertificateThumbprint = (*internalResult.CertificateThumbprint); } - // Construct a token whose body is the policy result, but whose token is the response from the - // service. + // Construct a token whose body is the policy result, but whose token is the response + // from the service. auto const returnedToken = AttestationTokenInternal( responseToken, &returnValue); @@ -412,23 +452,32 @@ AttestationAdministrationClient::AddIsolatedModeCertificate( AddIsolatedModeCertificateOptions const& options, Azure::Core::Context const& context) const { - auto const policyCertToken( - CreateIsolatedModeModificationToken(pemEncodedX509CertificateToAdd, existingSigningKey)); - Azure::Core::IO::MemoryBodyStream stream( - reinterpret_cast(policyCertToken.data()), policyCertToken.size()); - - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, m_apiVersion, HttpMethod::Post, {"certificates:add"}, &stream); - - // Send the request to the service. - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - AttestationToken returnValue( - ProcessIsolatedModeModificationResult( - response, - options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride - : this->m_tokenValidationOptions)); - return Response>( - returnValue, std::move(response)); + auto tracingContext(m_tracingFactory.CreateTracingContext("AddIsolatedModeCertificate", context)); + try + { + auto const policyCertToken( + CreateIsolatedModeModificationToken(pemEncodedX509CertificateToAdd, existingSigningKey)); + Azure::Core::IO::MemoryBodyStream stream( + reinterpret_cast(policyCertToken.data()), policyCertToken.size()); + + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, m_apiVersion, HttpMethod::Post, {"certificates:add"}, &stream); + + // Send the request to the service. + auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); + AttestationToken returnValue( + ProcessIsolatedModeModificationResult( + response, + options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride + : this->m_tokenValidationOptions)); + return Response>( + returnValue, std::move(response)); + } + catch (std::runtime_error const& ex) + { + tracingContext.Span.AddEvent(ex); + throw; + } } Azure::Response> @@ -438,31 +487,41 @@ AttestationAdministrationClient::RemoveIsolatedModeCertificate( RemoveIsolatedModeCertificateOptions const& options, Azure::Core::Context const& context) const { - // Calculate a signed (or unsigned) attestation policy token to send to the service. - // Embed the encoded policy in the StoredAttestationPolicy. - auto const policyCertToken( - CreateIsolatedModeModificationToken(pemEncodedX509CertificateToRemove, existingSigningKey)); - - Azure::Core::IO::MemoryBodyStream stream( - reinterpret_cast(policyCertToken.data()), policyCertToken.size()); - - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, m_apiVersion, HttpMethod::Post, {"certificates:remove"}, &stream); - - // Send the request to the service. - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - AttestationToken returnValue( - ProcessIsolatedModeModificationResult( - response, - options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride - : this->m_tokenValidationOptions)); - return Response>( - returnValue, std::move(response)); + auto tracingContext( + m_tracingFactory.CreateTracingContext("RemoveIsolatedModeCertificate", context)); + try + { + // Calculate a signed (or unsigned) attestation policy token to send to the service. + // Embed the encoded policy in the StoredAttestationPolicy. + auto const policyCertToken( + CreateIsolatedModeModificationToken(pemEncodedX509CertificateToRemove, existingSigningKey)); + + Azure::Core::IO::MemoryBodyStream stream( + reinterpret_cast(policyCertToken.data()), policyCertToken.size()); + + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, m_apiVersion, HttpMethod::Post, {"certificates:remove"}, &stream); + + // Send the request to the service. + auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); + AttestationToken returnValue( + ProcessIsolatedModeModificationResult( + response, + options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride + : this->m_tokenValidationOptions)); + return Response>( + returnValue, std::move(response)); + } + catch (std::runtime_error const& ex) + { + tracingContext.Span.AddEvent(ex); + throw; + } } /** - * @brief Retrieves the information needed to validate the response returned from the attestation - * service. + * @brief Retrieves the information needed to validate the response returned from the + * attestation service. * * @details Validating the response returned by the attestation service requires a set of * possible signers for the attestation token. @@ -472,26 +531,35 @@ AttestationAdministrationClient::RemoveIsolatedModeCertificate( void AttestationAdministrationClient::RetrieveResponseValidationCollateral( Azure::Core::Context const& context) { - std::unique_lock stateLock(SharedStateLock); - - if (m_attestationSigners.empty()) + auto tracingContext(m_tracingFactory.CreateTracingContext("Create", context)); + try { - stateLock.unlock(); - auto request - = AttestationCommonRequest::CreateRequest(m_endpoint, HttpMethod::Get, {"certs"}, nullptr); - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - auto jsonWebKeySet(JsonWebKeySetSerializer::Deserialize(response)); - TokenValidationCertificateResult returnValue; - std::vector newValue; - for (const auto& jwk : jsonWebKeySet.Keys) - { - AttestationSignerInternal internalSigner(jwk); - newValue.push_back(internalSigner); - } - stateLock.lock(); + std::unique_lock stateLock(SharedStateLock); + if (m_attestationSigners.empty()) { - m_attestationSigners = newValue; + stateLock.unlock(); + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, HttpMethod::Get, {"certs"}, nullptr); + auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); + auto jsonWebKeySet(JsonWebKeySetSerializer::Deserialize(response)); + TokenValidationCertificateResult returnValue; + std::vector newValue; + for (const auto& jwk : jsonWebKeySet.Keys) + { + AttestationSignerInternal internalSigner(jwk); + newValue.push_back(internalSigner); + } + stateLock.lock(); + if (m_attestationSigners.empty()) + { + m_attestationSigners = newValue; + } } } + catch (std::runtime_error const& ex) + { + tracingContext.Span.AddEvent(ex); + throw; + } } diff --git a/sdk/attestation/azure-security-attestation/src/attestation_client.cpp b/sdk/attestation/azure-security-attestation/src/attestation_client.cpp index d39ea699dd..ee970ed8c3 100644 --- a/sdk/attestation/azure-security-attestation/src/attestation_client.cpp +++ b/sdk/attestation/azure-security-attestation/src/attestation_client.cpp @@ -21,6 +21,7 @@ using namespace Azure::Security::Attestation; using namespace Azure::Security::Attestation::Models; using namespace Azure::Security::Attestation::_detail; using namespace Azure::Security::Attestation::Models::_detail; +using namespace Azure::Core::Tracing::_internal; using namespace Azure::Core::Http; using namespace Azure::Core::Http::Policies; using namespace Azure::Core::Http::Policies::_internal; @@ -31,7 +32,8 @@ AttestationClient::AttestationClient( std::shared_ptr credential, AttestationClientOptions options) : m_endpoint(endpoint), m_credentials(credential), - m_tokenValidationOptions(options.TokenValidationOptions) + m_tokenValidationOptions(options.TokenValidationOptions), + m_tracingFactory(options, "security.attestation", PackageVersion::ToString()) { std::vector> perRetrypolicies; if (credential) @@ -47,39 +49,58 @@ AttestationClient::AttestationClient( std::vector> perCallpolicies; m_pipeline = std::make_shared( - options, - "Attestation", - PackageVersion::ToString(), - std::move(perRetrypolicies), - std::move(perCallpolicies)); + options, std::move(perRetrypolicies), std::move(perCallpolicies)); } Azure::Response AttestationClient::GetOpenIdMetadata( Azure::Core::Context const& context) const { - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, HttpMethod::Get, {".well-known/openid-configuration"}, nullptr); + auto tracingContext(m_tracingFactory.CreateTracingContext("GetOpenIdMetadata", context)); + try + { + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, HttpMethod::Get, {".well-known/openid-configuration"}, nullptr); + + auto response + = AttestationCommonRequest::SendRequest(*m_pipeline, request, tracingContext.Context); + auto openIdMetadata(OpenIdMetadataSerializer::Deserialize(response)); - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - auto openIdMetadata(OpenIdMetadataSerializer::Deserialize(response)); - return Response(std::move(openIdMetadata), std::move(response)); + return Response(std::move(openIdMetadata), std::move(response)); + } + catch (std::runtime_error const& ex) + { + tracingContext.Span.AddEvent(ex); + throw; + } } Azure::Response AttestationClient::GetTokenValidationCertificates( Azure::Core::Context const& context) const { - auto request - = AttestationCommonRequest::CreateRequest(m_endpoint, HttpMethod::Get, {"certs"}, nullptr); + auto tracingContext( + m_tracingFactory.CreateTracingContext("GetTokenValidationCertificates", context)); + try + { - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - auto jsonWebKeySet(JsonWebKeySetSerializer::Deserialize(response)); - TokenValidationCertificateResult returnValue; - for (const auto& jwk : jsonWebKeySet.Keys) + auto request + = AttestationCommonRequest::CreateRequest(m_endpoint, HttpMethod::Get, {"certs"}, nullptr); + + auto response + = AttestationCommonRequest::SendRequest(*m_pipeline, request, tracingContext.Context); + auto jsonWebKeySet(JsonWebKeySetSerializer::Deserialize(response)); + TokenValidationCertificateResult returnValue; + for (const auto& jwk : jsonWebKeySet.Keys) + { + AttestationSignerInternal internalSigner(jwk); + returnValue.Signers.push_back(internalSigner); + } + return Response(returnValue, std::move(response)); + } + catch (std::runtime_error const& ex) { - AttestationSignerInternal internalSigner(jwk); - returnValue.Signers.push_back(internalSigner); + tracingContext.Span.AddEvent(ex); + throw; } - return Response(returnValue, std::move(response)); } Azure::Response> AttestationClient::AttestSgxEnclave( @@ -87,41 +108,53 @@ Azure::Response> AttestationClient::AttestSg AttestSgxEnclaveOptions options, Azure::Core::Context const& context) const { - AttestSgxEnclaveRequest attestRequest{ - sgxQuote, - options.InitTimeData, - options.RunTimeData, - options.DraftPolicyForAttestation, - options.Nonce}; - - const std::string serializedRequest(AttestSgxEnclaveRequestSerializer::Serialize(attestRequest)); - - const auto encodedVector - = std::vector(serializedRequest.begin(), serializedRequest.end()); - Azure::Core::IO::MemoryBodyStream stream(encodedVector); - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, m_apiVersion, HttpMethod::Post, {"attest/SgxEnclave"}, &stream); - - // Send the request to the service. - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - - // Deserialize the Service response token and return the JSON web token returned by the service. - std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); - - // Parse the JWT returned by the attestation service. - auto const token - = AttestationTokenInternal(responseToken); - - // Validate the token returned by the service. Use the cached attestation signers in the - // validation. - token.ValidateToken( - options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride - : this->m_tokenValidationOptions, - m_attestationSigners); - - // And return the attestation result to the caller. - auto returnedToken = AttestationToken(token); - return Response>(returnedToken, std::move(response)); + auto tracingContext(m_tracingFactory.CreateTracingContext("AttestSgxEnclave", context)); + try + { + + AttestSgxEnclaveRequest attestRequest{ + sgxQuote, + options.InitTimeData, + options.RunTimeData, + options.DraftPolicyForAttestation, + options.Nonce}; + + const std::string serializedRequest( + AttestSgxEnclaveRequestSerializer::Serialize(attestRequest)); + + const auto encodedVector + = std::vector(serializedRequest.begin(), serializedRequest.end()); + Azure::Core::IO::MemoryBodyStream stream(encodedVector); + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, m_apiVersion, HttpMethod::Post, {"attest/SgxEnclave"}, &stream); + + // Send the request to the service. + auto response + = AttestationCommonRequest::SendRequest(*m_pipeline, request, tracingContext.Context); + + // Deserialize the Service response token and return the JSON web token returned by the service. + std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); + + // Parse the JWT returned by the attestation service. + auto const token + = AttestationTokenInternal(responseToken); + + // Validate the token returned by the service. Use the cached attestation signers in the + // validation. + token.ValidateToken( + options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride + : this->m_tokenValidationOptions, + m_attestationSigners); + + // And return the attestation result to the caller. + auto returnedToken = AttestationToken(token); + return Response>(returnedToken, std::move(response)); + } + catch (std::runtime_error const& ex) + { + tracingContext.Span.AddEvent(ex); + throw; + } } Azure::Response> AttestationClient::AttestOpenEnclave( @@ -129,46 +162,66 @@ Azure::Response> AttestationClient::AttestOp AttestOpenEnclaveOptions options, Azure::Core::Context const& context) const { - AttestOpenEnclaveRequest attestRequest{ - openEnclaveReport, - options.InitTimeData, - options.RunTimeData, - options.DraftPolicyForAttestation, - options.Nonce}; - std::string serializedRequest(AttestOpenEnclaveRequestSerializer::Serialize(attestRequest)); - - auto encodedVector = std::vector(serializedRequest.begin(), serializedRequest.end()); - Azure::Core::IO::MemoryBodyStream stream(encodedVector); - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, m_apiVersion, HttpMethod::Post, {"attest/OpenEnclave"}, &stream); - - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); - auto token - = AttestationTokenInternal(responseToken); - token.ValidateToken( - options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride - : this->m_tokenValidationOptions, - m_attestationSigners); - - return Response>(token, std::move(response)); + auto tracingContext(m_tracingFactory.CreateTracingContext("AttestOpenEnclave", context)); + try + { + AttestOpenEnclaveRequest attestRequest{ + openEnclaveReport, + options.InitTimeData, + options.RunTimeData, + options.DraftPolicyForAttestation, + options.Nonce}; + std::string serializedRequest(AttestOpenEnclaveRequestSerializer::Serialize(attestRequest)); + + auto encodedVector = std::vector(serializedRequest.begin(), serializedRequest.end()); + Azure::Core::IO::MemoryBodyStream stream(encodedVector); + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, m_apiVersion, HttpMethod::Post, {"attest/OpenEnclave"}, &stream); + + auto response + = AttestationCommonRequest::SendRequest(*m_pipeline, request, tracingContext.Context); + std::string responseToken = AttestationServiceTokenResponseSerializer::Deserialize(response); + auto token + = AttestationTokenInternal(responseToken); + token.ValidateToken( + options.TokenValidationOptionsOverride ? *options.TokenValidationOptionsOverride + : this->m_tokenValidationOptions, + m_attestationSigners); + + return Response>(token, std::move(response)); + } + catch (std::runtime_error const& ex) + { + tracingContext.Span.AddEvent(ex); + throw; + } } Azure::Response AttestationClient::AttestTpm( AttestTpmOptions const& attestTpmOptions, Azure::Core::Context const& context) const { - std::string jsonToSend = TpmDataSerializer::Serialize(attestTpmOptions.Payload); - auto encodedVector = std::vector(jsonToSend.begin(), jsonToSend.end()); - Azure::Core::IO::MemoryBodyStream stream(encodedVector); + auto tracingContext(m_tracingFactory.CreateTracingContext("AttestTpm", context)); + try + { + std::string jsonToSend = TpmDataSerializer::Serialize(attestTpmOptions.Payload); + auto encodedVector = std::vector(jsonToSend.begin(), jsonToSend.end()); + Azure::Core::IO::MemoryBodyStream stream(encodedVector); - auto request = AttestationCommonRequest::CreateRequest( - m_endpoint, m_apiVersion, HttpMethod::Post, {"attest/Tpm"}, &stream); + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, m_apiVersion, HttpMethod::Post, {"attest/Tpm"}, &stream); - // Send the request to the service. - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - std::string returnedBody(TpmDataSerializer::Deserialize(response)); - return Response(TpmAttestationResult{returnedBody}, std::move(response)); + // Send the request to the service. + auto response + = AttestationCommonRequest::SendRequest(*m_pipeline, request, tracingContext.Context); + std::string returnedBody(TpmDataSerializer::Deserialize(response)); + return Response(TpmAttestationResult{returnedBody}, std::move(response)); + } + catch (std::runtime_error const& ex) + { + tracingContext.Span.AddEvent(ex); + throw; + } } namespace { @@ -186,28 +239,39 @@ std::shared_timed_mutex SharedStateLock; */ void AttestationClient::RetrieveResponseValidationCollateral(Azure::Core::Context const& context) { - std::unique_lock stateLock(SharedStateLock); - - if (m_attestationSigners.empty()) + auto tracingContext(m_tracingFactory.CreateTracingContext("Create", context)); + try { - stateLock.unlock(); - auto request - = AttestationCommonRequest::CreateRequest(m_endpoint, HttpMethod::Get, {"certs"}, nullptr); - auto response = AttestationCommonRequest::SendRequest(*m_pipeline, request, context); - auto jsonWebKeySet(JsonWebKeySetSerializer::Deserialize(response)); - TokenValidationCertificateResult returnValue; - std::vector newValue; - for (const auto& jwk : jsonWebKeySet.Keys) - { - AttestationSignerInternal internalSigner(jwk); - newValue.push_back(internalSigner); - } - stateLock.lock(); + std::unique_lock stateLock(SharedStateLock); + if (m_attestationSigners.empty()) { - m_attestationSigners = newValue; + stateLock.unlock(); + auto request = AttestationCommonRequest::CreateRequest( + m_endpoint, HttpMethod::Get, {"certs"}, nullptr); + auto response + = AttestationCommonRequest::SendRequest(*m_pipeline, request, tracingContext.Context); + auto jsonWebKeySet(JsonWebKeySetSerializer::Deserialize(response)); + TokenValidationCertificateResult returnValue; + std::vector newValue; + for (const auto& jwk : jsonWebKeySet.Keys) + { + AttestationSignerInternal internalSigner(jwk); + newValue.push_back(internalSigner); + } + stateLock.lock(); + if (m_attestationSigners.empty()) + { + m_attestationSigners = newValue; + } + tracingContext.Span.SetStatus(SpanStatus::Ok); } } + catch (std::runtime_error const& ex) + { + tracingContext.Span.AddEvent(ex); + throw; + } } /** @brief Construct a new Attestation Client object diff --git a/sdk/attestation/azure-security-attestation/vcpkg.json b/sdk/attestation/azure-security-attestation/vcpkg.json index bbe326b509..d0882d9519 100644 --- a/sdk/attestation/azure-security-attestation/vcpkg.json +++ b/sdk/attestation/azure-security-attestation/vcpkg.json @@ -1,6 +1,6 @@ { "name": "azure-security-attestation-cpp", - "version": "1.0.0-beta.1", + "version": "1.0.0-beta.3", "dependencies": [ { "name": "azure-core-cpp" diff --git a/sdk/attestation/azure-security-attestation/vcpkg/vcpkg.json b/sdk/attestation/azure-security-attestation/vcpkg/vcpkg.json index 4c97583075..a92bf0339b 100644 --- a/sdk/attestation/azure-security-attestation/vcpkg/vcpkg.json +++ b/sdk/attestation/azure-security-attestation/vcpkg/vcpkg.json @@ -14,7 +14,7 @@ { "name": "azure-core-cpp", "default-features": false, - "version>=": "1.5.0" + "version>=": "1.7.0-beta.1" }, { "name": "vcpkg-cmake", diff --git a/sdk/core/azure-core-tracing-opentelemetry/test/ut/service_support_test.cpp b/sdk/core/azure-core-tracing-opentelemetry/test/ut/service_support_test.cpp index 630296ccb8..b3c173821c 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/test/ut/service_support_test.cpp +++ b/sdk/core/azure-core-tracing-opentelemetry/test/ut/service_support_test.cpp @@ -590,8 +590,6 @@ class ServiceClient { : m_tracingFactory(clientOptions, "Azure.Core.OpenTelemetry.Test.Service", "1.0.0.beta-2") { std::vector> policies; - policies.emplace_back(std::make_unique( - "Azure.Core.OpenTelemetry.Test.Service", "1.0.0.beta-2", clientOptions.Telemetry)); policies.emplace_back(std::make_unique()); policies.emplace_back(std::make_unique(RetryOptions{})); diff --git a/sdk/core/azure-core/CMakeLists.txt b/sdk/core/azure-core/CMakeLists.txt index 90857472f0..931dace33a 100644 --- a/sdk/core/azure-core/CMakeLists.txt +++ b/sdk/core/azure-core/CMakeLists.txt @@ -81,6 +81,7 @@ set( inc/azure/core/internal/environment.hpp inc/azure/core/internal/extendable_enumeration.hpp inc/azure/core/internal/http/pipeline.hpp + inc/azure/core/internal/http/user_agent.hpp inc/azure/core/internal/io/null_body_stream.hpp inc/azure/core/internal/json/json.hpp inc/azure/core/internal/json/json_optional.hpp @@ -128,6 +129,7 @@ set( src/http/telemetry_policy.cpp src/http/transport_policy.cpp src/http/url.cpp + src/http/user_agent.cpp src/io/body_stream.cpp src/io/random_access_file_body_stream.cpp src/logger.cpp @@ -138,7 +140,7 @@ set( src/strings.cpp src/tracing/tracing.cpp src/uuid.cpp -) + ) add_library(azure-core ${AZURE_CORE_HEADER} ${AZURE_CORE_SOURCE}) diff --git a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp index 2262537a25..525c2c8648 100644 --- a/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/policies/policy.hpp @@ -14,8 +14,8 @@ #include "azure/core/dll_import_export.hpp" #include "azure/core/http/http.hpp" #include "azure/core/http/transport.hpp" +#include "azure/core/internal/http/user_agent.hpp" #include "azure/core/internal/input_sanitizer.hpp" -#include "azure/core/tracing/tracing.hpp" #include "azure/core/uuid.hpp" #include @@ -428,16 +428,17 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { * @details Applies an HTTP header with a component name and version to each HTTP request, * includes Azure SDK version information, and operating system information. * @remark See https://azure.github.io/azure-sdk/general_azurecore.html#telemetry-policy. + * + * @remark Note that for clients which are using distributed tracing, this functionality is + * merged into the RequestActivityPolicy policy. + * + * Eventually, when all service have converted to using distributed tracing, this policy can be + * deprecated. */ class TelemetryPolicy final : public HttpPolicy { private: std::string const m_telemetryId; - static std::string BuildTelemetryId( - std::string const& componentName, - std::string const& componentVersion, - std::string const& applicationId); - public: /** * @brief Construct HTTP telemetry policy. @@ -450,7 +451,10 @@ namespace Azure { namespace Core { namespace Http { namespace Policies { std::string const& componentName, std::string const& componentVersion, TelemetryOptions options = TelemetryOptions()) - : m_telemetryId(BuildTelemetryId(componentName, componentVersion, options.ApplicationId)) + : m_telemetryId(Azure::Core::Http::_detail::UserAgentGenerator::GenerateUserAgent( + componentName, + componentVersion, + options.ApplicationId)) { } diff --git a/sdk/core/azure-core/inc/azure/core/internal/http/pipeline.hpp b/sdk/core/azure-core/inc/azure/core/internal/http/pipeline.hpp index f2ee822752..7e2e9c8319 100644 --- a/sdk/core/azure-core/inc/azure/core/internal/http/pipeline.hpp +++ b/sdk/core/azure-core/inc/azure/core/internal/http/pipeline.hpp @@ -31,33 +31,7 @@ namespace Azure { namespace Core { namespace Http { namespace _internal { * @remark See #policy.hpp */ class HttpPipeline final { - protected: - std::vector> m_policies; - - public: - /** - * @brief Construct HTTP pipeline with the sequence of HTTP policies provided. - * - * @param policies A sequence of #Azure::Core::Http::Policies::HttpPolicy - * representing a stack, first element corresponding to the top of the stack. - * - * @throw `std::invalid_argument` when policies is empty. - */ - explicit HttpPipeline( - const std::vector>& policies) - { - if (policies.size() == 0) - { - throw std::invalid_argument("policies cannot be empty"); - } - - m_policies.reserve(policies.size()); - for (auto& policy : policies) - { - m_policies.emplace_back(policy->Clone()); - } - } - + private: /** * @brief Construct a new HTTP Pipeline object from clientOptions. * @@ -72,25 +46,26 @@ namespace Azure { namespace Core { namespace Http { namespace _internal { */ explicit HttpPipeline( Azure::Core::_internal::ClientOptions const& clientOptions, - std::string const& telemetryServiceName, - std::string const& telemetryServiceVersion, std::vector>&& perRetryPolicies, - std::vector>&& perCallPolicies) + std::vector>&& perCallPolicies, + bool includeTelemetryPolicy, + std::string const& telemetryServiceName = {}, + std::string const& telemetryServiceVersion = {}) { Azure::Core::_internal::InputSanitizer inputSanitizer( clientOptions.Log.AllowedHttpQueryParameters, clientOptions.Log.AllowedHttpHeaders); auto const& perCallClientPolicies = clientOptions.PerOperationPolicies; auto const& perRetryClientPolicies = clientOptions.PerRetryPolicies; - // Adding 6 for: - // - TelemetryPolicy + // Adding 5/6 for: + // - TelemetryPolicy (if required) // - RequestIdPolicy // - RetryPolicy // - LogPolicy // - RequestActivityPolicy // - TransportPolicy auto pipelineSize = perCallClientPolicies.size() + perRetryClientPolicies.size() - + perRetryPolicies.size() + perCallPolicies.size() + 6; + + perRetryPolicies.size() + perCallPolicies.size() + 5 + (includeTelemetryPolicy ? 1 : 0); m_policies.reserve(pipelineSize); @@ -105,9 +80,12 @@ namespace Azure { namespace Core { namespace Http { namespace _internal { std::make_unique()); // Telemetry - m_policies.emplace_back( - std::make_unique( - telemetryServiceName, telemetryServiceVersion, clientOptions.Telemetry)); + if (includeTelemetryPolicy) + { + m_policies.emplace_back( + std::make_unique( + telemetryServiceName, telemetryServiceVersion, clientOptions.Telemetry)); + } // client-options per call policies. for (auto& policy : perCallClientPolicies) @@ -145,6 +123,88 @@ namespace Azure { namespace Core { namespace Http { namespace _internal { clientOptions.Transport)); } + protected: + std::vector> m_policies; + + public: + /** + * @brief Construct HTTP pipeline with the sequence of HTTP policies provided. + * + * @param policies A sequence of #Azure::Core::Http::Policies::HttpPolicy + * representing a stack, first element corresponding to the top of the stack. + * + * @throw `std::invalid_argument` when policies is empty. + */ + explicit HttpPipeline( + const std::vector>& policies) + { + if (policies.size() == 0) + { + throw std::invalid_argument("policies cannot be empty"); + } + + m_policies.reserve(policies.size()); + for (auto& policy : policies) + { + m_policies.emplace_back(policy->Clone()); + } + } + + /** + * @brief Construct a new HTTP Pipeline object from clientOptions. + * + * @remark The client options includes per retry and per call policies which are merged with the + * service-specific per retry policies. + * + * @param clientOptions The SDK client options. + * @param telemetryServiceName The name of the service for sending telemetry. + * @param telemetryServiceVersion The version of the service for sending telemetry. + * @param perRetryPolicies The service-specific per retry policies. + * @param perCallPolicies The service-specific per call policies. + */ + explicit HttpPipeline( + Azure::Core::_internal::ClientOptions const& clientOptions, + std::string const& telemetryServiceName, + std::string const& telemetryServiceVersion, + std::vector>&& perRetryPolicies, + std::vector>&& perCallPolicies) + : HttpPipeline( + clientOptions, + std::move(perRetryPolicies), + std::move(perCallPolicies), + true, + telemetryServiceName, + telemetryServiceVersion) + { + } + + /** + * @brief Construct a new HTTP Pipeline object from clientOptions. + * + * @remark The client options includes per retry and per call policies which are merged with the + * service-specific per retry policies. + * + * @remark This specialization of the HttpPipeline constructor constructs an HTTP pipeline + * *without* a telemetry policy. It is intended for use by service clients which have converted + * to use distributed tracing - the distributed tracing policy adds the User-Agent header to the + * request. + * + * @param clientOptions The SDK client options. + * @param perRetryPolicies The service-specific per retry policies. + * @param perCallPolicies The service-specific per call policies. + */ + explicit HttpPipeline( + Azure::Core::_internal::ClientOptions const& clientOptions, + std::vector>&& perRetryPolicies, + std::vector>&& perCallPolicies) + : HttpPipeline( + clientOptions, + std::move(perRetryPolicies), + std::move(perCallPolicies), + false) + { + } + /** * @brief Construct HTTP pipeline with the sequence of HTTP policies provided. * diff --git a/sdk/core/azure-core/inc/azure/core/internal/http/user_agent.hpp b/sdk/core/azure-core/inc/azure/core/internal/http/user_agent.hpp new file mode 100644 index 0000000000..ae43c51485 --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/internal/http/user_agent.hpp @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief HTTP pipeline is a stack of HTTP policies. + * @remark See #policy.hpp + */ + +#pragma once + +#include + +namespace Azure { namespace Core { namespace Http { namespace _detail { + class UserAgentGenerator { + public: + static std::string GenerateUserAgent( + std::string const& componentName, + std::string const& componentVersion, + std::string const& applicationId); + }; +}}}} // namespace Azure::Core::Http::_detail diff --git a/sdk/core/azure-core/inc/azure/core/internal/tracing/service_tracing.hpp b/sdk/core/azure-core/inc/azure/core/internal/tracing/service_tracing.hpp index b023a850cf..7c7c42d627 100644 --- a/sdk/core/azure-core/inc/azure/core/internal/tracing/service_tracing.hpp +++ b/sdk/core/azure-core/inc/azure/core/internal/tracing/service_tracing.hpp @@ -3,6 +3,7 @@ #include "azure/core/context.hpp" #include "azure/core/internal/client_options.hpp" +#include "azure/core/internal/http/user_agent.hpp" #include "azure/core/tracing/tracing.hpp" #pragma once @@ -128,7 +129,8 @@ namespace Azure { namespace Core { namespace Tracing { namespace _internal { } /** - * @brief Records an exception occurring in the span. + * @brief Records an exception occurring in the span. Also marks the status of the span as + * SpanStatus::Error * * @param exception Exception which has occurred. */ @@ -137,6 +139,7 @@ namespace Azure { namespace Core { namespace Tracing { namespace _internal { if (m_span) { m_span->AddEvent(exception); + m_span->SetStatus(SpanStatus::Error); } } @@ -165,6 +168,7 @@ namespace Azure { namespace Core { namespace Tracing { namespace _internal { private: std::string m_serviceName; std::string m_serviceVersion; + std::string m_userAgent; std::shared_ptr m_serviceTracer; /** @brief The key used to retrieve the span and tracer associated with a context object. @@ -184,6 +188,10 @@ namespace Azure { namespace Core { namespace Tracing { namespace _internal { std::string serviceName, std::string serviceVersion) : m_serviceName(serviceName), m_serviceVersion(serviceVersion), + m_userAgent(Azure::Core::Http::_detail::UserAgentGenerator::GenerateUserAgent( + serviceName, + serviceVersion, + options.Telemetry.ApplicationId)), m_serviceTracer( options.Telemetry.TracingProvider ? options.Telemetry.TracingProvider->CreateTracer(serviceName, serviceVersion) @@ -243,6 +251,14 @@ namespace Azure { namespace Core { namespace Tracing { namespace _internal { std::unique_ptr CreateAttributeSet() const; + /** @brief Retrieves the User-Agent header value for this tracing context factory. + */ + std::string const& GetUserAgent() const { return m_userAgent; } + + /** @brief Returns true if this TracingContextFactory is connected to a service tracer. + */ + bool HasTracer() const { return static_cast(m_serviceTracer); } + static std::unique_ptr CreateFromContext( Azure::Core::Context const& context); }; diff --git a/sdk/core/azure-core/src/http/log_policy.cpp b/sdk/core/azure-core/src/http/log_policy.cpp index 2fe4cf5884..ae2db52fe2 100644 --- a/sdk/core/azure-core/src/http/log_policy.cpp +++ b/sdk/core/azure-core/src/http/log_policy.cpp @@ -89,7 +89,7 @@ Azure::Core::CaseInsensitiveSet const "traceparent", "tracestate", "Transfer-Encoding", - "User-Agent" + "User-Agent", "x-ms-client-request-id", "x-ms-request-id", "x-ms-return-client-request-id", diff --git a/sdk/core/azure-core/src/http/request_activity_policy.cpp b/sdk/core/azure-core/src/http/request_activity_policy.cpp index f52a54861d..2ca84be7c7 100644 --- a/sdk/core/azure-core/src/http/request_activity_policy.cpp +++ b/sdk/core/azure-core/src/http/request_activity_policy.cpp @@ -23,21 +23,38 @@ std::unique_ptr RequestActivityPolicy::Send( NextHttpPolicy nextPolicy, Context const& context) const { + Azure::Nullable userAgent; // Find a tracing factory from our context. Note that the factory value is owned by the // context chain so we can manage a raw pointer to the factory. auto tracingFactory = TracingContextFactory::CreateFromContext(context); if (tracingFactory) { + // Determine the value of the "User-Agent" header. + // + // If nobody has previously set a user agent header, then set the user agent header + // based on the value calculated by the tracing factory. + userAgent = request.GetHeader("User-Agent"); + if (!userAgent.HasValue()) + { + userAgent = tracingFactory->GetUserAgent(); + request.SetHeader("User-Agent", userAgent.Value()); + } + } + + // If our tracing factory has a tracer attached to it, register the request with the tracer. + if (tracingFactory && tracingFactory->HasTracer()) + { + // Create a tracing span over the HTTP request. - std::stringstream ss; - ss << "HTTP " << request.GetMethod().ToString(); + std::string spanName("HTTP "); + spanName.append(request.GetMethod().ToString()); CreateSpanOptions createOptions; createOptions.Kind = SpanKind::Client; createOptions.Attributes = tracingFactory->CreateAttributeSet(); - // Note that the AttributeSet takes a *reference* to the values passed into the AttributeSet. - // This means that all the values passed into the AttributeSet MUST be stabilized across the - // lifetime of the AttributeSet. + // Note that the AttributeSet takes a *reference* to the values passed into the + // AttributeSet. This means that all the values passed into the AttributeSet MUST be + // stabilized across the lifetime of the AttributeSet. // Note that request.GetMethod() returns an HttpMethod object, which is always a static // object, and thus its lifetime is constant. That is not the case for the other values @@ -55,14 +72,11 @@ std::unique_ptr RequestActivityPolicy::Send( TracingAttributes::RequestId.ToString(), requestId.Value()); } - const auto userAgent = request.GetHeader("User-Agent"); - if (userAgent.HasValue()) - { - createOptions.Attributes->AddAttribute( - TracingAttributes::HttpUserAgent.ToString(), userAgent.Value()); - } + // We retrieved the value of the user-agent header above. + createOptions.Attributes->AddAttribute( + TracingAttributes::HttpUserAgent.ToString(), userAgent.Value()); - auto contextAndSpan = tracingFactory->CreateTracingContext(ss.str(), createOptions, context); + auto contextAndSpan = tracingFactory->CreateTracingContext(spanName, createOptions, context); auto scope = std::move(contextAndSpan.Span); // Propagate information from the scope to the HTTP headers. diff --git a/sdk/core/azure-core/src/http/telemetry_policy.cpp b/sdk/core/azure-core/src/http/telemetry_policy.cpp index 0a012d5efd..6e17ef576d 100644 --- a/sdk/core/azure-core/src/http/telemetry_policy.cpp +++ b/sdk/core/azure-core/src/http/telemetry_policy.cpp @@ -2,162 +2,13 @@ // SPDX-License-Identifier: MIT #include "azure/core/http/policies/policy.hpp" -#include "azure/core/platform.hpp" - -#include -#include - -#if defined(AZ_PLATFORM_WINDOWS) -#if !defined(WIN32_LEAN_AND_MEAN) -#define WIN32_LEAN_AND_MEAN -#endif -#if !defined(NOMINMAX) -#define NOMINMAX -#endif - -#include - -#if !defined(WINAPI_PARTITION_DESKTOP) \ - || WINAPI_PARTITION_DESKTOP // See azure/core/platform.hpp for explanation. - -namespace Azure { namespace Core { namespace _internal { - - /** - * @brief HkeyHolder ensures native handle resource is released. - * - */ - class HkeyHolder final { - private: - HKEY m_value = nullptr; - - public: - explicit HkeyHolder() noexcept : m_value(nullptr) {} - - ~HkeyHolder() noexcept - { - if (m_value != nullptr) - { - ::RegCloseKey(m_value); - } - } - - void operator=(HKEY p) noexcept - { - if (p != nullptr) - { - m_value = p; - } - } - - operator HKEY() noexcept { return m_value; } - - operator HKEY*() noexcept { return &m_value; } - - HKEY* operator&() noexcept { return &m_value; } - }; - -}}} // namespace Azure::Core::_internal - -#endif - -#elif defined(AZ_PLATFORM_POSIX) -#include -#endif - -namespace { -std::string GetOSVersion() -{ - std::ostringstream osVersionInfo; - -#if defined(AZ_PLATFORM_WINDOWS) -#if !defined(WINAPI_PARTITION_DESKTOP) \ - || WINAPI_PARTITION_DESKTOP // See azure/core/platform.hpp for explanation. - { - Azure::Core::_internal::HkeyHolder regKey; - if (RegOpenKeyExA( - HKEY_LOCAL_MACHINE, - "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion", - 0, - KEY_READ, - ®Key) - == ERROR_SUCCESS) - { - auto first = true; - static constexpr char const* regValues[]{ - "ProductName", "CurrentVersion", "CurrentBuildNumber", "BuildLabEx"}; - for (auto regValue : regValues) - { - char valueBuf[200] = {}; - DWORD valueBufSize = sizeof(valueBuf); - - if (RegQueryValueExA(regKey, regValue, NULL, NULL, (LPBYTE)valueBuf, &valueBufSize) - == ERROR_SUCCESS) - { - if (valueBufSize > 0) - { - osVersionInfo << (first ? "" : " ") - << std::string(valueBuf, valueBuf + (valueBufSize - 1)); - first = false; - } - } - } - } - } -#else - { - osVersionInfo << "UWP"; - } -#endif -#elif defined(AZ_PLATFORM_POSIX) - { - utsname sysInfo{}; - if (uname(&sysInfo) == 0) - { - osVersionInfo << sysInfo.sysname << " " << sysInfo.release << " " << sysInfo.machine << " " - << sysInfo.version; - } - } -#endif - - return osVersionInfo.str(); -} - -std::string TrimString(std::string s) -{ - auto const isSpace = [](int c) { return !std::isspace(c); }; - - s.erase(s.begin(), std::find_if(s.begin(), s.end(), isSpace)); - s.erase(std::find_if(s.rbegin(), s.rend(), isSpace).base(), s.end()); - - return s; -} -} // namespace using Azure::Core::Context; using namespace Azure::Core::Http; using namespace Azure::Core::Http::Policies; using namespace Azure::Core::Http::Policies::_internal; -std::string TelemetryPolicy::BuildTelemetryId( - std::string const& componentName, - std::string const& componentVersion, - std::string const& applicationId) -{ - // Spec: https://azure.github.io/azure-sdk/general_azurecore.html#telemetry-policy - std::ostringstream telemetryId; - - if (!applicationId.empty()) - { - telemetryId << TrimString(applicationId).substr(0, 24) << " "; - } - - static std::string const osVer = GetOSVersion(); - telemetryId << "azsdk-cpp-" << componentName << "/" << componentVersion << " (" << osVer << ")"; - - return telemetryId.str(); -} - -std::unique_ptr TelemetryPolicy::Send( +std::unique_ptr Azure::Core::Http::Policies::_internal::TelemetryPolicy::Send( Request& request, NextHttpPolicy nextPolicy, Context const& context) const diff --git a/sdk/core/azure-core/src/http/user_agent.cpp b/sdk/core/azure-core/src/http/user_agent.cpp new file mode 100644 index 0000000000..2aaaee93be --- /dev/null +++ b/sdk/core/azure-core/src/http/user_agent.cpp @@ -0,0 +1,166 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief HTTP pipeline is a stack of HTTP policies. + * @remark See #policy.hpp + */ + +#include + +#include "azure/core/context.hpp" +#include "azure/core/http/policies/policy.hpp" +#include "azure/core/internal/tracing/service_tracing.hpp" +#include "azure/core/platform.hpp" +#include "azure/core/tracing/tracing.hpp" +#include +#include + +#if defined(AZ_PLATFORM_WINDOWS) +#if !defined(WIN32_LEAN_AND_MEAN) +#define WIN32_LEAN_AND_MEAN +#endif +#if !defined(NOMINMAX) +#define NOMINMAX +#endif + +#include + +#if !defined(WINAPI_PARTITION_DESKTOP) \ + || WINAPI_PARTITION_DESKTOP // See azure/core/platform.hpp for explanation. + +namespace { + +/** + * @brief HkeyHolder ensures native handle resource is released. + * + */ +class HkeyHolder final { +private: + HKEY m_value = nullptr; + +public: + explicit HkeyHolder() noexcept : m_value(nullptr) {} + + ~HkeyHolder() noexcept + { + if (m_value != nullptr) + { + ::RegCloseKey(m_value); + } + } + + void operator=(HKEY p) noexcept + { + if (p != nullptr) + { + m_value = p; + } + } + + operator HKEY() noexcept { return m_value; } + + operator HKEY*() noexcept { return &m_value; } + + HKEY* operator&() noexcept { return &m_value; } +}; + +} // namespace + +#endif + +#elif defined(AZ_PLATFORM_POSIX) +#include +#endif + +namespace { +std::string GetOSVersion() +{ + std::ostringstream osVersionInfo; + +#if defined(AZ_PLATFORM_WINDOWS) +#if !defined(WINAPI_PARTITION_DESKTOP) \ + || WINAPI_PARTITION_DESKTOP // See azure/core/platform.hpp for explanation. + { + HkeyHolder regKey; + if (RegOpenKeyExA( + HKEY_LOCAL_MACHINE, + "SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion", + 0, + KEY_READ, + ®Key) + == ERROR_SUCCESS) + { + auto first = true; + static constexpr char const* regValues[]{ + "ProductName", "CurrentVersion", "CurrentBuildNumber", "BuildLabEx"}; + for (auto regValue : regValues) + { + char valueBuf[200] = {}; + DWORD valueBufSize = sizeof(valueBuf); + + if (RegQueryValueExA(regKey, regValue, NULL, NULL, (LPBYTE)valueBuf, &valueBufSize) + == ERROR_SUCCESS) + { + if (valueBufSize > 0) + { + osVersionInfo << (first ? "" : " ") + << std::string(valueBuf, valueBuf + (valueBufSize - 1)); + first = false; + } + } + } + } + } +#else + { + osVersionInfo << "UWP"; + } +#endif +#elif defined(AZ_PLATFORM_POSIX) + { + utsname sysInfo{}; + if (uname(&sysInfo) == 0) + { + osVersionInfo << sysInfo.sysname << " " << sysInfo.release << " " << sysInfo.machine << " " + << sysInfo.version; + } + } +#endif + + return osVersionInfo.str(); +} + +std::string TrimString(std::string s) +{ + auto const isSpace = [](int c) { return !std::isspace(c); }; + + s.erase(s.begin(), std::find_if(s.begin(), s.end(), isSpace)); + s.erase(std::find_if(s.rbegin(), s.rend(), isSpace).base(), s.end()); + + return s; +} +} // namespace + +namespace Azure { namespace Core { namespace Http { namespace _detail { + + std::string UserAgentGenerator::GenerateUserAgent( + std::string const& componentName, + std::string const& componentVersion, + std::string const& applicationId) + { + // Spec: https://azure.github.io/azure-sdk/general_azurecore.html#telemetry-policy + std::ostringstream telemetryId; + + if (!applicationId.empty()) + { + telemetryId << TrimString(applicationId).substr(0, 24) << " "; + } + + static std::string const osVer = GetOSVersion(); + telemetryId << "azsdk-cpp-" << componentName << "/" << componentVersion << " (" << osVer << ")"; + + return telemetryId.str(); + } +}}}} // namespace Azure::Core::Http::_detail diff --git a/sdk/core/azure-core/src/tracing/tracing.cpp b/sdk/core/azure-core/src/tracing/tracing.cpp index de183e822c..23e4eb0d20 100644 --- a/sdk/core/azure-core/src/tracing/tracing.cpp +++ b/sdk/core/azure-core/src/tracing/tracing.cpp @@ -1,5 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + #include "azure/core/tracing/tracing.hpp" +#include "azure/core/context.hpp" +#include "azure/core/http/policies/policy.hpp" #include "azure/core/internal/tracing/service_tracing.hpp" +#include +#include namespace Azure { namespace Core { namespace Tracing { namespace _internal { @@ -21,23 +28,21 @@ namespace Azure { namespace Core { namespace Tracing { namespace _internal { const TracingAttributes TracingAttributes::RequestId("requestId"); const TracingAttributes TracingAttributes::HttpStatusCode("http.status_code"); + using Azure::Core::Context; + TracingContextFactory::TracingContext TracingContextFactory::CreateTracingContext( std::string const& methodName, Azure::Core::Context const& context) const { - if (m_serviceTracer) - { - Azure::Core::Context contextToUse = context; - CreateSpanOptions createOptions; + Azure::Core::Context contextToUse = context; + CreateSpanOptions createOptions; - createOptions.Kind = SpanKind::Internal; - createOptions.Attributes = m_serviceTracer->CreateAttributeSet(); - return CreateTracingContext(methodName, createOptions, context); - } - else + createOptions.Kind = SpanKind::Internal; + if (HasTracer()) { - return TracingContext{context, ServiceSpan{}}; + createOptions.Attributes = m_serviceTracer->CreateAttributeSet(); } + return CreateTracingContext(methodName, createOptions, context); } TracingContextFactory::TracingContext TracingContextFactory::CreateTracingContext( @@ -45,17 +50,20 @@ namespace Azure { namespace Core { namespace Tracing { namespace _internal { Azure::Core::Tracing::_internal::CreateSpanOptions& createOptions, Azure::Core::Context const& context) const { - if (m_serviceTracer) - { - Azure::Core::Context contextToUse = context; + Azure::Core::Context contextToUse = context; - // Ensure that the factory is available in the context chain. - TracingContextFactory const* tracingFactoryFromContext; - if (!context.TryGetValue(TracingFactoryContextKey, tracingFactoryFromContext)) - { - contextToUse = context.WithValue(TracingFactoryContextKey, this); - } + // Ensure that the factory is available in the context chain. + // Note that we do this even if we don't have distributed tracing enabled, that's because + // the tracing context factory is also responsible for creating the User-Agent HTTP header, so + // it needs to be available for all requests. + TracingContextFactory const* tracingFactoryFromContext; + if (!context.TryGetValue(TracingFactoryContextKey, tracingFactoryFromContext)) + { + contextToUse = context.WithValue(TracingFactoryContextKey, this); + } + if (HasTracer()) + { std::shared_ptr traceContext; // Find a span in the context hierarchy. if (contextToUse.TryGetValue(ContextSpanKey, traceContext)) @@ -83,7 +91,7 @@ namespace Azure { namespace Core { namespace Tracing { namespace _internal { } else { - return TracingContext{context, ServiceSpan{}}; + return TracingContext{contextToUse, ServiceSpan{}}; } } diff --git a/sdk/core/azure-core/test/ut/log_policy_test.cpp b/sdk/core/azure-core/test/ut/log_policy_test.cpp index 32ddb6f43a..5fdb38bc09 100644 --- a/sdk/core/azure-core/test/ut/log_policy_test.cpp +++ b/sdk/core/azure-core/test/ut/log_policy_test.cpp @@ -13,7 +13,10 @@ using Azure::Core::Http::Policies::LogOptions; // cspell:ignore qparam namespace { -void SendRequest(LogOptions const& logOptions, std::string const& portAndPath = "") +void SendRequest( + LogOptions const& logOptions, + bool addDefaultAllowedHeaders = false, + std::string const& portAndPath = "") { using namespace Azure::Core; using namespace Azure::Core::IO; @@ -61,8 +64,41 @@ void SendRequest(LogOptions const& logOptions, std::string const& portAndPath = request.SetHeader("hEaDeR1", "HvAlUe1"); request.SetHeader("HeAdEr2", "hVaLuE2"); - request.SetHeader("x-ms-request-id", "6c536700-4c36-4e22-9161-76e7b3bf8269"); + // Add in all the default allowed HTTP headers to the request. We'll make sure they're not + // redacted on the way out. + if (addDefaultAllowedHeaders) + { + + request.SetHeader("Accept", "Accept"); + request.SetHeader("Cache-Control", "Cache-Control"); + request.SetHeader("Connection", "Connection"); + request.SetHeader("Content-Length", "Content-Length"); + request.SetHeader("Content-Type", "Content-Type"); + request.SetHeader("Date", "Date"); + request.SetHeader("ETag", "ETag"); + request.SetHeader("Expires", "Expires"); + request.SetHeader("If-Match", "If-Match"); + request.SetHeader("If-Modified-Since", "If-Modified-Since"); + request.SetHeader("If-None-Match", "If-None-Match"); + request.SetHeader("If-Unmodified-Since", "If-Unmodified-Since"); + request.SetHeader("Last-Modified", "Last-Modified"); + request.SetHeader("Pragma", "Pragma"); + request.SetHeader("Request-Id", "Request-Id"); + request.SetHeader("Retry-After", "Retry-After"); + request.SetHeader("Server", "Server"); + request.SetHeader("traceparent", "traceparent"); + request.SetHeader("tracestate", "tracestate"); + request.SetHeader("Transfer-Encoding", "Transfer-Encoding"); + request.SetHeader("User-Agent", "User-Agent"); + request.SetHeader("x-ms-client-request-id", "x-ms-client-request-id"); + request.SetHeader("x-ms-request-id", "x-ms-request-id"); + request.SetHeader("x-ms-return-client-request-id", "x-ms-return-client-request-id"); + } + else + { + request.SetHeader("x-ms-request-id", "6c536700-4c36-4e22-9161-76e7b3bf8269"); + } { std::vector> policies; @@ -156,7 +192,7 @@ TEST(LogPolicy, Default) TEST(LogPolicy, PortAndPath) { TestLogger const Log; - SendRequest(LogOptions(), ":8080/path"); + SendRequest(LogOptions(), false, ":8080/path"); EXPECT_EQ(Log.Entries.size(), 2); @@ -189,7 +225,7 @@ TEST(LogPolicy, Headers) { auto logOptions = LogOptions(); logOptions.AllowedHttpHeaders.insert({"HeAder1", "heaDer3"}); - SendRequest(logOptions); + SendRequest(logOptions, false); } EXPECT_EQ(Log.Entries.size(), 2); @@ -216,6 +252,74 @@ TEST(LogPolicy, Headers) EXPECT_TRUE(EndsWith(entry2.Message, "ms) : 200 OKAY")); } +TEST(LogPolicy, DefaultHeaders) +{ + TestLogger const Log; + + { + auto logOptions = LogOptions(); + logOptions.AllowedHttpHeaders.insert({"HeAder1", "heaDer3"}); + SendRequest(logOptions, true); + } + + EXPECT_EQ(Log.Entries.size(), 2); + + auto const entry1 = Log.Entries.at(0); + auto const entry2 = Log.Entries.at(1); + + EXPECT_EQ(entry1.Level, Logger::Level::Informational); + EXPECT_EQ(entry2.Level, Logger::Level::Informational); + + EXPECT_EQ( + entry1.Message, + "HTTP Request : GET https://www.microsoft.com" + "?Qparam2=REDACTED" + "&qParam3=REDACTED" + "&qparam%204=REDACTED" + "&qparam%25204=REDACTED" + "&qparam1=REDACTED" + "\naccept : Accept" + "\ncache-control : Cache-Control" + "\nconnection : Connection" + "\ncontent-length : Content-Length" + "\ncontent-type : Content-Type" + "\ndate : Date" + "\netag : ETag" + "\nexpires : Expires" + "\nheader1 : HvAlUe1" + "\nheader2 : REDACTED" + "\nif-match : If-Match" + "\nif-modified-since : If-Modified-Since" + "\nif-none-match : If-None-Match" + "\nif-unmodified-since : If-Unmodified-Since" + "\nlast-modified : Last-Modified" + "\npragma : Pragma" + "\nrequest-id : Request-Id" + "\nretry-after : Retry-After" + "\nserver : Server" + "\ntraceparent : traceparent" + "\ntracestate : tracestate" + "\ntransfer-encoding : Transfer-Encoding" + "\nuser-agent : User-Agent" + "\nx-ms-client-request-id : x-ms-client-request-id" + "\nx-ms-request-id : x-ms-request-id" + "\nx-ms-return-client-request-id : x-ms-return-client-request-id"); + + EXPECT_TRUE(StartsWith(entry2.Message, "HTTP Response (")); + EXPECT_TRUE(EndsWith(entry2.Message, "ms) : 200 OKAY")); + + // Ensure that the entire list of allowed headers is in the list of headers. + // This ensures that if a new header is added to the default allow list, we have a test case + // covering it. + for (auto const& allowedHeader : + Azure::Core::Http::Policies::_detail::g_defaultAllowedHttpHeaders) + { + // NOTE: If this fails, it means that we need to update the SendRequest function + // to add support for the missing allowed header. + EXPECT_NE(entry1.Message.find(allowedHeader), std::string::npos); + } +} + TEST(LogPolicy, QueryParams) { TestLogger const Log; diff --git a/sdk/core/azure-core/test/ut/request_activity_policy_test.cpp b/sdk/core/azure-core/test/ut/request_activity_policy_test.cpp index 927b3777ac..1144eea97b 100644 --- a/sdk/core/azure-core/test/ut/request_activity_policy_test.cpp +++ b/sdk/core/azure-core/test/ut/request_activity_policy_test.cpp @@ -174,7 +174,8 @@ TEST(RequestActivityPolicy, Basic) // Final policy - equivalent to HTTP policy. policies.emplace_back(std::make_unique()); - Azure::Core::Http::_internal::HttpPipeline(policies).Send(request, callContext); + auto response + = Azure::Core::Http::_internal::HttpPipeline(policies).Send(request, callContext); } EXPECT_EQ(1ul, testTracer->GetTracers().size()); @@ -192,22 +193,24 @@ TEST(RequestActivityPolicy, Basic) Azure::Core::_internal::ClientOptions clientOptions; clientOptions.Telemetry.TracingProvider = testTracer; Azure::Core::Tracing::_internal::TracingContextFactory serviceTrace( - clientOptions, "my-service-cpp", "1.0b2"); + clientOptions, "my-service-cpp", "1.0.0.beta-2"); auto contextAndSpan = serviceTrace.CreateTracingContext("My API", {}); Azure::Core::Context callContext = std::move(contextAndSpan.Context); Request request(HttpMethod::Get, Url("https://www.microsoft.com")); + Azure::Nullable userAgent; { std::vector> policies; // Add the request ID policy - this adds the x-ms-request-id attribute to the pipeline. policies.emplace_back(std::make_unique()); - policies.emplace_back( - std::make_unique("my-service-cpp", "1.0b2", clientOptions.Telemetry)); policies.emplace_back(std::make_unique(RetryOptions{})); policies.emplace_back( std::make_unique(Azure::Core::_internal::InputSanitizer{})); // Final policy - equivalent to HTTP policy. - policies.emplace_back(std::make_unique()); + policies.emplace_back(std::make_unique([&](Request& request) { + userAgent = request.GetHeader("user-agent"); // Return success. + return std::make_unique(1, 1, HttpStatusCode::Ok, "Something"); + })); Azure::Core::Http::_internal::HttpPipeline(policies).Send(request, callContext); } @@ -218,6 +221,8 @@ TEST(RequestActivityPolicy, Basic) EXPECT_EQ("My API", tracer->GetSpans()[0]->GetName()); EXPECT_EQ("HTTP GET", tracer->GetSpans()[1]->GetName()); EXPECT_EQ("GET", tracer->GetSpans()[1]->GetAttributes().at("http.method")); + std::string expectedUserAgentPrefix{"azsdk-cpp-my-service-cpp/1.0.0.beta-2 ("}; + EXPECT_EQ(expectedUserAgentPrefix, userAgent.Value().substr(0, expectedUserAgentPrefix.size())); } } diff --git a/sdk/core/azure-core/test/ut/service_tracing_test.cpp b/sdk/core/azure-core/test/ut/service_tracing_test.cpp index 1ac5e8eea3..204d366d79 100644 --- a/sdk/core/azure-core/test/ut/service_tracing_test.cpp +++ b/sdk/core/azure-core/test/ut/service_tracing_test.cpp @@ -32,6 +32,82 @@ TEST(TracingContextFactory, ServiceTraceEnums) std::string tracingAttributeName = TracingAttributes::AzNamespace.ToString(); } +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include +#include +#include + +using namespace Azure::Core; +using namespace Azure::Core::Http; +using namespace Azure::Core::Http::_internal; +using namespace Azure::Core::Http::Policies; +using namespace Azure::Core::Http::Policies::_internal; + +namespace { + +class NoOpPolicy final : public HttpPolicy { +private: + std::unique_ptr Send( + Request& request, + NextHttpPolicy nextPolicy, + Context const& context) const override + { + (void)context; + (void)request; + (void)nextPolicy; + + return std::unique_ptr(); + } + + std::unique_ptr Clone() const override { return std::make_unique(*this); } +}; + +} // namespace + +TEST(TracingContextFactory, UserAgentTests) +{ + struct + { + const std::string serviceName; + const std::string serviceVersion; + const std::string applicationId; + const std::string expectedPrefix; + } UserAgentTests[] + = {{"storage-blob", "11.0.0", "", "azsdk-cpp-storage-blob/11.0.0 ("}, + {"storage-blob", + "11.0.0", + "AzCopy/10.0.4-Preview", + "AzCopy/10.0.4-Preview azsdk-cpp-storage-blob/11.0.0 ("}, + {"storage-blob", + "11.0.0", + "AzCopy / 10.0.4-Preview ", + "AzCopy / 10.0.4-Preview azsdk-cpp-storage-blob/11.0.0 ("}, + {"storage-blob", + "11.0.0", + " 01234567890123456789abcde ", + "01234567890123456789abcd azsdk-cpp-storage-blob/11.0.0 ("}}; + + constexpr auto UserAgentEnd = ')'; + constexpr auto OSInfoMinLength = 10; + + for (auto const& test : UserAgentTests) + { + Azure::Core::_internal::ClientOptions clientOptions; + clientOptions.Telemetry.ApplicationId = test.applicationId; + Azure::Core::Tracing::_internal::TracingContextFactory traceFactory( + clientOptions, test.serviceName, test.serviceVersion); + std::string userAgent = traceFactory.GetUserAgent(); + + EXPECT_FALSE(userAgent.empty()); + EXPECT_LT( + test.expectedPrefix.size() + OSInfoMinLength + sizeof(UserAgentEnd), userAgent.size()); + EXPECT_EQ(test.expectedPrefix, userAgent.substr(0, test.expectedPrefix.size())); + EXPECT_EQ(UserAgentEnd, userAgent[userAgent.size() - 1]); + } +} + TEST(TracingContextFactory, SimpleServiceSpanTests) { { diff --git a/sdk/template/azure-template/inc/azure/template/template_client.hpp b/sdk/template/azure-template/inc/azure/template/template_client.hpp index e9821dc1b1..6ee21c0f31 100644 --- a/sdk/template/azure-template/inc/azure/template/template_client.hpp +++ b/sdk/template/azure-template/inc/azure/template/template_client.hpp @@ -15,9 +15,11 @@ namespace Azure { namespace Template { class TemplateClient final { public: - TemplateClient(TemplateClientOptions options = TemplateClientOptions()); - std::string ClientVersion() const; - int GetValue(int key) const; + TemplateClient(TemplateClientOptions const& options = TemplateClientOptions{}); + int GetValue(int key, Azure::Core::Context const& context = Azure::Core::Context{}) const; + + private: + Azure::Core::Tracing::_internal::TracingContextFactory m_tracingFactory; }; }} // namespace Azure::Template diff --git a/sdk/template/azure-template/src/template_client.cpp b/sdk/template/azure-template/src/template_client.cpp index 794349b07c..cb1e74e4cf 100644 --- a/sdk/template/azure-template/src/template_client.cpp +++ b/sdk/template/azure-template/src/template_client.cpp @@ -10,16 +10,57 @@ using namespace Azure::Template; using namespace Azure::Template::_detail; -std::string TemplateClient::ClientVersion() const { return PackageVersion::ToString(); } +TemplateClient::TemplateClient(TemplateClientOptions const& options) + : m_tracingFactory(options, "Template", PackageVersion::ToString()) -TemplateClient::TemplateClient(TemplateClientOptions) {} +{ +} -int TemplateClient::GetValue(int key) const +int TemplateClient::GetValue(int key, Azure::Core::Context const& context) const { - if (key < 0) + auto tracingContext = m_tracingFactory.CreateTracingContext("GetValue", context); + + try { - return 0; - } - return key + 1; + if (key < 0) + { + return 0; + } + + // Blackjack basic strategy vs dealer 10, 6+ decks, H17. + if (key <= 0) + { + return 0; + } // we were not dealt a hand + else if (key > 21) + { + return -100; + } // we busted + else if (key == 21) + { + return 150; + } // celebrate + else if (key == 11) + { + return 20; + } // double down + else if (key < 11) + { + return 10; + } // hit + else if (key > 11 && key < 17) + { + return 1; + } // hit, but be less happy about it + else + { + return 0; + } // >= 17 we always stay + } + catch (std::exception const& e) + { + tracingContext.Span.AddEvent(e); + throw; + } } diff --git a/sdk/template/azure-template/test/ut/template_test.cpp b/sdk/template/azure-template/test/ut/template_test.cpp index 044d0fa48b..5e07d27bd0 100644 --- a/sdk/template/azure-template/test/ut/template_test.cpp +++ b/sdk/template/azure-template/test/ut/template_test.cpp @@ -7,18 +7,18 @@ using namespace Azure::Template; -TEST(Template, Basic) -{ - TemplateClient templateClient; - - EXPECT_FALSE(templateClient.ClientVersion().empty()); -} +TEST(Template, Basic) { TemplateClient templateClient; } TEST(Template, GetValue) { TemplateClient templateClient; EXPECT_EQ(templateClient.GetValue(-1), 0); - EXPECT_EQ(templateClient.GetValue(0), 1); - EXPECT_EQ(templateClient.GetValue(1), 2); + EXPECT_EQ(templateClient.GetValue(0), 0); + EXPECT_EQ(templateClient.GetValue(1), 10); + EXPECT_EQ(templateClient.GetValue(22), -100); + EXPECT_EQ(templateClient.GetValue(21), 150); + EXPECT_EQ(templateClient.GetValue(11), 20); + EXPECT_EQ(templateClient.GetValue(14), 1); + EXPECT_EQ(templateClient.GetValue(18), 0); }